Repository: muxueChen/ComfyUI_NTCosyVoice
Branch: main
Commit: f5e55835df4e
Files: 163
Total size: 2.0 MB

Directory structure:
gitextract_rz35ckig/

├── .github/
│   └── workflows/
│       └── publish.yml
├── .gitignore
├── README.md
├── __init__.py
├── cosyvoice/
│   ├── __init__.py
│   ├── bin/
│   │   ├── __init__.py
│   │   ├── average_model.py
│   │   ├── export_jit.py
│   │   ├── export_onnx.py
│   │   ├── export_trt.sh
│   │   ├── inference.py
│   │   └── train.py
│   ├── cli/
│   │   ├── __init__.py
│   │   ├── cosyvoice.py
│   │   ├── frontend.py
│   │   └── model.py
│   ├── dataset/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── processor.py
│   ├── flow/
│   │   ├── __init__.py
│   │   ├── decoder.py
│   │   ├── flow.py
│   │   ├── flow_matching.py
│   │   └── length_regulator.py
│   ├── hifigan/
│   │   ├── __init__.py
│   │   ├── discriminator.py
│   │   ├── f0_predictor.py
│   │   ├── generator.py
│   │   └── hifigan.py
│   ├── llm/
│   │   ├── __init__.py
│   │   └── llm.py
│   ├── tokenizer/
│   │   ├── __init__.py
│   │   ├── assets/
│   │   │   └── multilingual_zh_ja_yue_char_del.tiktoken
│   │   └── tokenizer.py
│   ├── transformer/
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── convolution.py
│   │   ├── decoder.py
│   │   ├── decoder_layer.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── subsampling.py
│   │   └── upsample_encoder.py
│   └── utils/
│       ├── __init__.py
│       ├── class_utils.py
│       ├── common.py
│       ├── executor.py
│       ├── file_utils.py
│       ├── frontend_utils.py
│       ├── losses.py
│       ├── mask.py
│       ├── scheduler.py
│       └── train_utils.py
├── downloadmodel.py
├── examples/
│   ├── CrossLingual.json
│   ├── Instruct2.json
│   └── ZeroShot.json
├── pyproject.toml
├── requirements.txt
└── third_party/
    ├── Matcha-TTS/
    │   ├── LICENSE
    │   ├── MANIFEST.in
    │   ├── Makefile
    │   ├── README.md
    │   ├── __init__.py
    │   ├── configs/
    │   │   ├── __init__.py
    │   │   ├── callbacks/
    │   │   │   ├── default.yaml
    │   │   │   ├── model_checkpoint.yaml
    │   │   │   ├── model_summary.yaml
    │   │   │   ├── none.yaml
    │   │   │   └── rich_progress_bar.yaml
    │   │   ├── data/
    │   │   │   ├── hi-fi_en-US_female.yaml
    │   │   │   ├── ljspeech.yaml
    │   │   │   └── vctk.yaml
    │   │   ├── debug/
    │   │   │   ├── default.yaml
    │   │   │   ├── fdr.yaml
    │   │   │   ├── limit.yaml
    │   │   │   ├── overfit.yaml
    │   │   │   └── profiler.yaml
    │   │   ├── eval.yaml
    │   │   ├── experiment/
    │   │   │   ├── hifi_dataset_piper_phonemizer.yaml
    │   │   │   ├── ljspeech.yaml
    │   │   │   ├── ljspeech_min_memory.yaml
    │   │   │   └── multispeaker.yaml
    │   │   ├── extras/
    │   │   │   └── default.yaml
    │   │   ├── hparams_search/
    │   │   │   └── mnist_optuna.yaml
    │   │   ├── hydra/
    │   │   │   └── default.yaml
    │   │   ├── local/
    │   │   │   └── .gitkeep
    │   │   ├── logger/
    │   │   │   ├── aim.yaml
    │   │   │   ├── comet.yaml
    │   │   │   ├── csv.yaml
    │   │   │   ├── many_loggers.yaml
    │   │   │   ├── mlflow.yaml
    │   │   │   ├── neptune.yaml
    │   │   │   ├── tensorboard.yaml
    │   │   │   └── wandb.yaml
    │   │   ├── model/
    │   │   │   ├── cfm/
    │   │   │   │   └── default.yaml
    │   │   │   ├── decoder/
    │   │   │   │   └── default.yaml
    │   │   │   ├── encoder/
    │   │   │   │   └── default.yaml
    │   │   │   ├── matcha.yaml
    │   │   │   └── optimizer/
    │   │   │       └── adam.yaml
    │   │   ├── paths/
    │   │   │   └── default.yaml
    │   │   ├── train.yaml
    │   │   └── trainer/
    │   │       ├── cpu.yaml
    │   │       ├── ddp.yaml
    │   │       ├── ddp_sim.yaml
    │   │       ├── default.yaml
    │   │       ├── gpu.yaml
    │   │       └── mps.yaml
    │   ├── matcha/
    │   │   ├── VERSION
    │   │   ├── __init__.py
    │   │   ├── app.py
    │   │   ├── cli.py
    │   │   ├── data/
    │   │   │   ├── __init__.py
    │   │   │   ├── components/
    │   │   │   │   └── __init__.py
    │   │   │   └── text_mel_datamodule.py
    │   │   ├── hifigan/
    │   │   │   ├── LICENSE
    │   │   │   ├── README.md
    │   │   │   ├── __init__.py
    │   │   │   ├── config.py
    │   │   │   ├── denoiser.py
    │   │   │   ├── env.py
    │   │   │   ├── meldataset.py
    │   │   │   ├── models.py
    │   │   │   └── xutils.py
    │   │   ├── models/
    │   │   │   ├── __init__.py
    │   │   │   ├── baselightningmodule.py
    │   │   │   ├── components/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── decoder.py
    │   │   │   │   ├── flow_matching.py
    │   │   │   │   ├── text_encoder.py
    │   │   │   │   └── transformer.py
    │   │   │   └── matcha_tts.py
    │   │   ├── onnx/
    │   │   │   ├── __init__.py
    │   │   │   ├── export.py
    │   │   │   └── infer.py
    │   │   ├── text/
    │   │   │   ├── __init__.py
    │   │   │   ├── cleaners.py
    │   │   │   ├── numbers.py
    │   │   │   └── symbols.py
    │   │   ├── train.py
    │   │   └── utils/
    │   │       ├── __init__.py
    │   │       ├── audio.py
    │   │       ├── generate_data_statistics.py
    │   │       ├── instantiators.py
    │   │       ├── logging_utils.py
    │   │       ├── model.py
    │   │       ├── monotonic_align/
    │   │       │   ├── __init__.py
    │   │       │   ├── core.pyx
    │   │       │   └── setup.py
    │   │       ├── pylogger.py
    │   │       ├── rich_utils.py
    │   │       └── utils.py
    │   ├── notebooks/
    │   │   └── .gitkeep
    │   ├── pyproject.toml
    │   ├── requirements.txt
    │   ├── scripts/
    │   │   └── schedule.sh
    │   ├── setup.py
    │   └── synthesis.ipynb
    └── __init__.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
      - master
    paths:
      - "pyproject.toml"

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    if: ${{ github.repository_owner == 'muxueChen' }}
    steps:
      - name: Check out code
        uses: actions/checkout@v4
        with:
          submodules: true
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@main
        with:
          ## Add your own personal access token to your Github Repository secrets and reference it here.
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}


================================================
FILE: .gitignore
================================================
/idea
/__pycache__
**/__pycache__
**/**/__pycache__
/pretrained_models
.vscode
.vs
.idea


================================================
FILE: README.md
================================================
# CosyVoice2 for ComfyUI
ComfyUI_NTCosyVoice is a ComfyUI plugin for CosyVoice2.
## Install plugin
Clone the repository into ComfyUI's `custom_nodes` directory:
```shell
cd ComfyUI/custom_nodes
git clone https://github.com/muxueChen/ComfyUI_NTCosyVoice.git
```
## Install dependency packages
```shell
cd ComfyUI_NTCosyVoice
pip install -r requirements.txt
```
## Download models
```shell
python downloadmodel.py
```
## Install ttsfrd (Optional)
This step is optional. If the ttsfrd package is not installed, WeTextProcessing is used for text normalization by default. The wheels below are built for Python 3.10 on Linux x86_64.
```shell
cd pretrained_models/CosyVoice-ttsfrd/
unzip resource.zip -d .
pip install ttsfrd_dependency-0.1-py3-none-any.whl
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
```
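Whether the optional ttsfrd step took effect can be checked from Python. A minimal sketch, assuming only that the package imports as `ttsfrd`; the helper name `pick_text_frontend` is hypothetical and not part of this plugin or of CosyVoice:

```python
import importlib.util


def pick_text_frontend() -> str:
    """Report which text-normalization backend would be picked.

    Hypothetical helper: returns "ttsfrd" when the optional ttsfrd package
    is importable, otherwise "wetextprocessing" (the documented default).
    """
    if importlib.util.find_spec("ttsfrd") is not None:
        return "ttsfrd"
    return "wetextprocessing"


if __name__ == "__main__":
    print(pick_text_frontend())
```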

================================================
FILE: __init__.py
================================================
import sys
import os
nor_dir = os.path.dirname(__file__)
Matcha_path = os.path.join(nor_dir, 'third_party/Matcha-TTS')
sys.path.append(nor_dir)
sys.path.append(Matcha_path)

from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio
import torch


def nt_load_wav(speech, sample_rate, target_sr):
    # Down-mix a [channels, samples] waveform to mono with shape [1, samples].
    speech = speech.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        # Only downsampling is allowed; prompts below target_sr are rejected.
        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech


class NTCosyVoiceZeroShotSampler:
    def __init__(self):
        self.__cosyvoice = None

    @property
    def cosyvoice(self):
        if self.__cosyvoice is None:
            model_path = os.path.join(nor_dir, 'pretrained_models/CosyVoice2-0.5B')
            self.__cosyvoice = CosyVoice2(model_path, load_jit=True, load_onnx=False, load_trt=False)
        return self.__cosyvoice

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "audio": ("AUDIO",),
                "speed": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 1.5, "step": 0.1}),
                "text": ("STRING", {"multiline": True}),
                "prompt_text": ("STRING", {"multiline": True}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("tts_speech",)
    FUNCTION = "main_func"
    CATEGORY = "Nineton Nodes"

    def main_func(self, audio, speed, text, prompt_text):
        # ComfyUI AUDIO waveforms are [batch, channels, samples]; use the first item.
        waveform = audio["waveform"][0]
        sample_rate = audio["sample_rate"]
        prompt_speech_16k = nt_load_wav(waveform, sample_rate, 16000)
        speechs = []
        for output in self.cosyvoice.inference_zero_shot(tts_text=text, prompt_text=prompt_text, prompt_speech_16k=prompt_speech_16k, stream=False, speed=speed):
            speechs.append(output['tts_speech'])

        # Join the generated segments along time and restore the batch dim.
        tts_speech = torch.cat(speechs, dim=1)
        tts_speech = tts_speech.unsqueeze(0)
        outaudio = {"waveform": tts_speech, "sample_rate": self.cosyvoice.sample_rate}

        return (outaudio,)


class NTCosyVoiceCrossLingualSampler:
    def __init__(self):
        self.__cosyvoice = None

    @property
    def cosyvoice(self):
        if self.__cosyvoice is None:
            model_path = os.path.join(nor_dir, 'pretrained_models/CosyVoice2-0.5B')
            self.__cosyvoice = CosyVoice2(model_path, load_jit=True, load_onnx=False, load_trt=False)
        return self.__cosyvoice

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "audio": ("AUDIO",),
                "speed": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 1.5, "step": 0.1}),
                "text": ("STRING", {"multiline": True}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("tts_speech",)
    FUNCTION = "main_func"
    CATEGORY = "Nineton Nodes"

    def main_func(self, audio, speed, text):
        # ComfyUI AUDIO waveforms are [batch, channels, samples]; use the first item.
        waveform = audio["waveform"][0]
        sample_rate = audio["sample_rate"]

        prompt_speech_16k = nt_load_wav(waveform, sample_rate, 16000)
        speechs = []
        for output in self.cosyvoice.inference_cross_lingual(tts_text=text,
                prompt_speech_16k=prompt_speech_16k, stream=False, speed=speed):
            speechs.append(output['tts_speech'])

        # Join the generated segments along time and restore the batch dim.
        tts_speech = torch.cat(speechs, dim=1)
        tts_speech = tts_speech.unsqueeze(0)
        outaudio = {"waveform": tts_speech, "sample_rate": self.cosyvoice.sample_rate}

        return (outaudio,)


class NTCosyVoiceInstruct2Sampler:
    def __init__(self):
        self.__cosyvoice = None

    @property
    def cosyvoice(self):
        if self.__cosyvoice is None:
            model_path = os.path.join(nor_dir, 'pretrained_models/CosyVoice2-0.5B')
            self.__cosyvoice = CosyVoice2(model_path, load_jit=True, load_onnx=False, load_trt=False)
        return self.__cosyvoice

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "audio": ("AUDIO",),
                "speed": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 1.5, "step": 0.1}),
                "text": ("STRING", {"multiline": True}),
                "instruct": ("STRING", {"multiline": True}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("tts_speech",)
    FUNCTION = "main_func"
    CATEGORY = "Nineton Nodes"

    def main_func(self, audio, speed, text, instruct):
        # ComfyUI AUDIO waveforms are [batch, channels, samples]; use the first item.
        waveform = audio["waveform"][0]
        sample_rate = audio["sample_rate"]

        prompt_speech_16k = nt_load_wav(waveform, sample_rate, 16000)

        speechs = []
        for output in self.cosyvoice.inference_instruct2(tts_text=text, instruct_text=instruct, prompt_speech_16k=prompt_speech_16k, stream=False, speed=speed):
            speechs.append(output['tts_speech'])

        # Join the generated segments along time and restore the batch dim.
        tts_speech = torch.cat(speechs, dim=1)
        tts_speech = tts_speech.unsqueeze(0)
        outaudio = {"waveform": tts_speech, "sample_rate": self.cosyvoice.sample_rate}

        return (outaudio,)


NODE_CLASS_MAPPINGS = {
    "NTCosyVoiceZeroShotSampler": NTCosyVoiceZeroShotSampler,
    "NTCosyVoiceInstruct2Sampler": NTCosyVoiceInstruct2Sampler,
    "NTCosyVoiceCrossLingualSampler": NTCosyVoiceCrossLingualSampler
}

__all__ = ['NODE_CLASS_MAPPINGS']
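Each sampler class above keeps its own private `__cosyvoice` attribute, so a workflow that uses all three nodes can load three copies of the same CosyVoice2-0.5B checkpoint. One way to share a single instance is a module-level cache keyed by model path. This is a sketch, not the plugin's current behavior: `get_shared_model` and its `factory` argument are hypothetical names, and in the plugin the factory would be something like `lambda p: CosyVoice2(p, load_jit=True, load_onnx=False, load_trt=False)`.

```python
import threading
from typing import Any, Callable, Dict

# Hypothetical module-level cache: one loaded model per model directory.
_MODELS: Dict[str, Any] = {}
_LOCK = threading.Lock()


def get_shared_model(model_path: str, factory: Callable[[str], Any]) -> Any:
    """Return the cached model for model_path, building it on first use.

    `factory` is whatever actually loads the model (assumed, e.g. a
    CosyVoice2 constructor call); it runs at most once per path.
    """
    with _LOCK:
        model = _MODELS.get(model_path)
        if model is None:
            model = factory(model_path)
            _MODELS[model_path] = model
        return model
```

Each node's `cosyvoice` property could then return `get_shared_model(model_path, factory)` instead of assigning to `self.__cosyvoice`. The lock keeps two nodes from loading the same checkpoint at the same time.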

================================================
FILE: cosyvoice/__init__.py
================================================


================================================
FILE: cosyvoice/bin/__init__.py
================================================


================================================
FILE: cosyvoice/bin/average_model.py
================================================
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import glob

import yaml
import torch


def get_args():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='output path for the averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='directory containing the checkpoints to average')
    parser.add_argument('--val_best',
                        action="store_true",
                        help='average the checkpoints with the lowest validation loss')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='number of checkpoints to average')

    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    val_scores = []
    if args.val_best:
        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
        yamls = [
            f for f in yamls
            if not (os.path.basename(f).startswith('train')
                    or os.path.basename(f).startswith('init'))
        ]
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
                loss = float(dic_yaml['loss_dict']['loss'])
                epoch = int(dic_yaml['epoch'])
                step = int(dic_yaml['step'])
                tag = dic_yaml['tag']
                val_scores += [[epoch, step, loss, tag]]
        sorted_val_scores = sorted(val_scores,
                                   key=lambda x: x[2],
                                   reverse=False)
        print("best val (epoch, step, loss, tag) = " +
              str(sorted_val_scores[:args.num]))
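        path_list = [
            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
            for score in sorted_val_scores[:args.num]
        ]
    else:
        # Without --val_best, average the most recently written checkpoints.
        path_list = sorted(glob.glob('{}/epoch_*_whole.pt'.format(args.src_path)),
                           key=os.path.getmtime)[-args.num:]
    print(path_list)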
    avg = {}
    num = args.num
    assert num == len(path_list)
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k in states.keys():
            if k not in avg.keys():
                avg[k] = states[k].clone()
            else:
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()
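The `--val_best` path in `average_model.py` ranks checkpoints by validation loss and then takes an element-wise mean of their parameters. The same two steps in plain Python, as a sketch with lists standing in for tensors; `select_best` and `average_states` are hypothetical helper names:

```python
from typing import Dict, List, Sequence, Tuple

# (epoch, step, loss, tag), mirroring the val_scores entries in average_model.py
Score = Tuple[int, int, float, str]


def select_best(scores: Sequence[Score], num: int) -> List[str]:
    """Return checkpoint file names for the `num` lowest validation losses."""
    best = sorted(scores, key=lambda s: s[2])[:num]
    return ['epoch_{}_whole.pt'.format(s[0]) for s in best]


def average_states(states: Sequence[Dict[str, List[float]]]) -> Dict[str, List[float]]:
    """Element-wise mean of parameter dicts (lists stand in for tensors)."""
    if not states:
        raise ValueError('need at least one state dict')
    n = len(states)
    # Start from a copy of the first dict, then accumulate the rest.
    total = {k: list(v) for k, v in states[0].items()}
    for state in states[1:]:
        for k, v in state.items():
            total[k] = [a + b for a, b in zip(total[k], v)]
    return {k: [x / n for x in v] for k, v in total.items()}
```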


================================================
FILE: cosyvoice/bin/export_jit.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice


def get_args():
    parser = argparse.ArgumentParser(description='export your model for deployment')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice-300M',
                        help='local path')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)

    # 1. export llm text_encoder
    llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
    script = torch.jit.script(llm_text_encoder)
    script = torch.jit.freeze(script)
    script = torch.jit.optimize_for_inference(script)
    script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))

    # 2. export the LLM backbone (llm.llm)
    llm_llm = cosyvoice.model.llm.llm.half()
    script = torch.jit.script(llm_llm)
    script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
    script = torch.jit.optimize_for_inference(script)
    script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))

    # 3. export flow encoder
    flow_encoder = cosyvoice.model.flow.encoder
    script = torch.jit.script(flow_encoder)
    script = torch.jit.freeze(script)
    script = torch.jit.optimize_for_inference(script)
    script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))


if __name__ == '__main__':
    main()


================================================
FILE: cosyvoice/bin/export_onnx.py
================================================
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import onnxruntime
import random
import torch
from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice


def get_dummy_input(batch_size, seq_len, out_channels, device):
    x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
    mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
    mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
    t = torch.rand((batch_size), dtype=torch.float32, device=device)
    spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
    cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
    return x, mask, mu, t, spks, cond


def get_args():
    parser = argparse.ArgumentParser(description='export your model for deployment')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice-300M',
                        help='local path')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)

    # 1. export flow decoder estimator
    estimator = cosyvoice.model.flow.decoder.estimator

    device = cosyvoice.model.device
    batch_size, seq_len = 1, 256
    out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
    torch.onnx.export(
        estimator,
        (x, mask, mu, t, spks, cond),
        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
        output_names=['estimator_out'],
        dynamic_axes={
            'x': {0: 'batch_size', 2: 'seq_len'},
            'mask': {0: 'batch_size', 2: 'seq_len'},
            'mu': {0: 'batch_size', 2: 'seq_len'},
            'cond': {0: 'batch_size', 2: 'seq_len'},
            't': {0: 'batch_size'},
            'spks': {0: 'batch_size'},
            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
        }
    )

    # 2. test computation consistency
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
                                                  sess_options=option, providers=providers)

    for _ in tqdm(range(10)):
        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
        output_pytorch = estimator(x, mask, mu, t, spks, cond)
        ort_inputs = {
            'x': x.cpu().numpy(),
            'mask': mask.cpu().numpy(),
            'mu': mu.cpu().numpy(),
            't': t.cpu().numpy(),
            'spks': spks.cpu().numpy(),
            'cond': cond.cpu().numpy()
        }
        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
        # assert_allclose is deprecated; assert_close applies the same rtol/atol check.
        torch.testing.assert_close(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)


if __name__ == "__main__":
    main()
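`export_onnx.py` accepts the ONNX export when the ONNX Runtime and PyTorch outputs agree within `rtol=1e-2, atol=1e-4`. `torch.testing.assert_close` treats two values as close when `|actual - expected| <= atol + rtol * |expected|`. A plain-Python sketch of that per-element rule for finite values (the name `within_tolerance` is hypothetical):

```python
from typing import Sequence


def within_tolerance(actual: Sequence[float], expected: Sequence[float],
                     rtol: float = 1e-2, atol: float = 1e-4) -> bool:
    """True if every pair satisfies |a - e| <= atol + rtol * |e|.

    Mismatched lengths count as a failure, like a shape mismatch.
    """
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e)
               for a, e in zip(actual, expected))
```

The relative term lets large activations differ more in absolute terms, while `atol` covers values near zero.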


================================================
FILE: cosyvoice/bin/export_trt.sh
================================================
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
# download TensorRT from https://developer.nvidia.com/tensorrt/download/10x and check that it is compatible with your system and CUDA version
# for example, for Linux + CUDA 12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
TRT_DIR=<YOUR_TRT_DIR>
MODEL_DIR=<COSYVOICE2_MODEL_DIR>

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx \
    --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan \
    --fp16 \
    --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 \
    --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 \
    --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 \
    --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw \
    --outputIOFormats=fp16:chw
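The `--minShapes`/`--maxShapes` flags define the TensorRT optimization profile: with the values above the engine expects a batch of 2 and a sequence length between 4 and 6800 frames, and inputs outside that range should be rejected at run time. A hypothetical helper (`parse_shapes` and `fits_profile` are illustrative names, not part of trtexec or CosyVoice) that parses the trtexec shape strings and checks a candidate input shape before calling the engine:

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]

# Profile bounds copied from the trtexec command above.
MIN = 'x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4'
MAX = 'x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800'


def parse_shapes(spec: str) -> Dict[str, Shape]:
    """Parse a 'name:AxBxC,name2:...' shape string into a dict of tuples."""
    shapes = {}
    for item in spec.split(','):
        name, dims = item.split(':')
        shapes[name] = tuple(int(d) for d in dims.split('x'))
    return shapes


def fits_profile(name: str, shape: Shape, min_spec: str = MIN, max_spec: str = MAX) -> bool:
    """True if every dimension of `shape` lies within the [min, max] profile."""
    lo = parse_shapes(min_spec)[name]
    hi = parse_shapes(max_spec)[name]
    return len(shape) == len(lo) and all(a <= d <= b for a, d, b in zip(lo, shape, hi))
```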


================================================
FILE: cosyvoice/bin/inference.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.dataset.dataset import Dataset


def get_args():
    parser = argparse.ArgumentParser(description='inference with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--prompt_data', required=True, help='prompt data file')
    parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data mapping file')
    parser.add_argument('--tts_text', required=True, help='tts input file')
    parser.add_argument('--llm_model', required=True, help='llm model file')
    parser.add_argument('--flow_model', required=True, help='flow model file')
    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot'],
                        help='inference mode')
    parser.add_argument('--result_dir', required=True, help='output directory for synthesized wavs and wav.scp')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Init cosyvoice models from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f)

    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
    model.load(args.llm_model, args.flow_model, args.hifigan_model)

    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    del configs
    os.makedirs(args.result_dir, exist_ok=True)
    fn = os.path.join(args.result_dir, 'wav.scp')
    with open(fn, 'w') as f, torch.no_grad():
        for batch in tqdm(test_data_loader):
            utts = batch["utts"]
            assert len(utts) == 1, "inference mode only supports batch size 1"
            text_token = batch["text_token"].to(device)
            text_token_len = batch["text_token_len"].to(device)
            tts_index = batch["tts_index"]
            tts_text_token = batch["tts_text_token"].to(device)
            tts_text_token_len = batch["tts_text_token_len"].to(device)
            speech_token = batch["speech_token"].to(device)
            speech_token_len = batch["speech_token_len"].to(device)
            speech_feat = batch["speech_feat"].to(device)
            speech_feat_len = batch["speech_feat_len"].to(device)
            utt_embedding = batch["utt_embedding"].to(device)
            spk_embedding = batch["spk_embedding"].to(device)
            if args.mode == 'sft':
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
            else:
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
            tts_speeches = []
            for model_output in model.tts(**model_input):
                tts_speeches.append(model_output['tts_speech'])
            tts_speeches = torch.concat(tts_speeches, dim=1)
            tts_key = '{}_{}'.format(utts[0], tts_index[0])
            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
            # CosyVoiceModel (v1, e.g. CosyVoice-300M) generates 22.05 kHz audio.
            torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
            f.write('{} {}\n'.format(tts_key, tts_fn))
            f.flush()
    logging.info('Result wav.scp saved in {}'.format(fn))


if __name__ == '__main__':
    main()
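`inference.py` writes one `<key> <wav path>` line per synthesized utterance to `wav.scp`, with keys built as `{utt}_{tts_index}`. A sketch of reading that file back into a dictionary; `read_wav_scp` is a hypothetical helper, not part of CosyVoice:

```python
from typing import Dict, Iterable


def read_wav_scp(lines: Iterable[str]) -> Dict[str, str]:
    """Map utterance keys to wav paths from 'key path' lines.

    Splits on the first run of whitespace only, so paths containing spaces
    are kept intact; blank lines are skipped.
    """
    table = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        key, path = line.split(maxsplit=1)
        table[key] = path
    return table
```

For example, `with open(os.path.join(result_dir, 'wav.scp')) as f: wavs = read_wav_scp(f)` yields a key-to-path mapping for the synthesized files.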


================================================
FILE: cosyvoice/bin/train.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import os
import torch
import torch.distributed as dist
import deepspeed

from hyperpyyaml import load_hyperpyyaml

from torch.distributed.elastic.multiprocessing.errors import record

from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
    init_distributed,
    init_dataset_and_dataloader,
    init_optimizer_and_scheduler,
    init_summarywriter, save_model,
    wrap_cuda_model, check_modify_and_save_config)
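

# Hypothetical helper (not part of upstream CosyVoice): a minimal sketch of the
# checkpoint resume logic in main(). It assumes a checkpoint dict may carry
# optional 'step' and 'epoch' entries. Training resumes at the epoch after the
# saved one, and a fresh run (default epoch -1) starts at epoch 0.
def resume_schedule(state_dict, max_epoch):
    start_step = state_dict.get('step', 0)
    start_epoch = state_dict.get('epoch', -1)
    # Epochs that the training loop in main() would still iterate over
    return start_step, list(range(start_epoch + 1, max_epoch))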


def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--train_engine',
                        default='torch_ddp',
                        choices=['torch_ddp', 'deepspeed'],
                        help='Engine for parallel training')
    parser.add_argument('--model', required=True, help='model which will be trained')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='number of subprocess workers for data reading')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers for data reading')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        default='model_only',
                        choices=['model_only', 'model+optimizer'],
                        help='save model/optimizer states')
    parser.add_argument('--timeout',
                        default=60,
                        type=int,
                        help='timeout (in seconds) of cosyvoice_join.')
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args
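

# Hypothetical helper (not part of upstream CosyVoice): a sketch of the
# override logic at the start of main(). Every top-level model key except the
# one being trained is overridden to None, so hyperpyyaml does not build it.
# When training 'hifigan', 'hift' is kept as well, on the assumption that the
# HiFiGAN wrapper in the config references the hift generator.
def build_override_dict(model_name):
    override = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != model_name}
    if model_name == 'hifigan':
        override.pop('hift')
    return override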


@record
def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # GAN training has some special initialization logic
    gan = args.model == 'hifigan'

    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
    if gan is True:
        override_dict.pop('hift')
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f, overrides=override_dict)
    if gan is True:
        configs['train_conf'] = configs['train_conf_gan']
    configs['train_conf'].update(vars(args))

    # Init env for ddp
    init_distributed(args)

    # Get dataset & dataloader
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs, gan)

    # Do some sanity checks and save config to args.model_dir
    configs = check_modify_and_save_config(args, configs)

    # Tensorboard summary
    writer = init_summarywriter(args)

    # load checkpoint
    model = configs[args.model]
    start_step, start_epoch = 0, -1
    if args.checkpoint is not None:
        if os.path.exists(args.checkpoint):
            state_dict = torch.load(args.checkpoint, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
            if 'step' in state_dict:
                start_step = state_dict['step']
            if 'epoch' in state_dict:
                start_epoch = state_dict['epoch']
        else:
            logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))

    # Dispatch model from cpu to gpu
    model = wrap_cuda_model(args, model)

    # Get optimizer & scheduler
    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
    scheduler.set_step(start_step)
    if scheduler_d is not None:
        scheduler_d.set_step(start_step)

    # Save init checkpoints
    info_dict = deepcopy(configs['train_conf'])
    info_dict['step'] = start_step
    info_dict['epoch'] = start_epoch
    save_model(model, 'init', info_dict)

    # Get executor
    executor = Executor(gan=gan)
    executor.step = start_step

    # Init scaler, used for pytorch amp mixed precision training
    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
    print('start step {} start epoch {}'.format(start_step, start_epoch))
    # Start training loop
    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
        executor.epoch = epoch
        train_dataset.set_epoch(epoch)
        dist.barrier()
        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        if gan is True:
            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                        writer, info_dict, scaler, group_join)
        else:
            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
        dist.destroy_process_group(group_join)


if __name__ == '__main__':
    main()


================================================
FILE: cosyvoice/cli/__init__.py
================================================


================================================
FILE: cosyvoice/cli/cosyvoice.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type


class CosyVoice:

    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
        self.instruct = '-Instruct' in model_dir
        self.model_dir = model_dir
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
            configs = load_hyperpyyaml(f)
        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
            load_jit = False
            fp16 = False
            logging.warning('cpu does not support fp16 and jit, force setting both to False')
        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                '{}/llm.llm.fp16.zip'.format(model_dir),
                                '{}/flow.encoder.fp32.zip'.format(model_dir))
        if load_onnx:
            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
        del configs

    def list_available_spks(self):
        spks = list(self.frontend.spk2info.keys())
        return spks

    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_sft(i, spk_id)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            if len(i) < 0.5 * len(prompt_text):
                logging.warning('synthesis text {} is too short compared with prompt text {}, this may lead to bad performance'.format(i, prompt_text))
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
        if self.instruct is False:
            raise ValueError('{} does not support instruct inference'.format(self.model_dir))
        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
        start_time = time.time()
        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            start_time = time.time()
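

# Hypothetical helper (not part of upstream CosyVoice): the rtf logged by the
# inference_* methods above is wall-clock synthesis time divided by the
# duration (in seconds) of the yielded audio chunk. An rtf below 1 means audio
# is produced faster than real time.
def real_time_factor(elapsed_seconds, num_samples, sample_rate):
    speech_len = num_samples / sample_rate
    return elapsed_seconds / speech_len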


class CosyVoice2(CosyVoice):

    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
        self.instruct = '-Instruct' in model_dir
        self.model_dir = model_dir
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        if torch.cuda.is_available() is False and load_jit is True:
            load_jit = False
            logging.warning('cpu does not support jit, force setting load_jit to False')
        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
        if load_trt is True and load_onnx is True:
            load_onnx = False
            logging.warning('cannot set both load_trt and load_onnx to True, force setting load_onnx to False')
        if load_onnx:
            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
        del configs

    def inference_instruct(self, *args, **kwargs):
        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')

    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()


================================================
FILE: cosyvoice/cli/frontend.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import json
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
    import ttsfrd
    use_ttsfrd = True
except ImportError:
    print("failed to import ttsfrd, use WeTextProcessing instead")
    from tn.chinese.normalizer import Normalizer as ZhNormalizer
    from tn.english.normalizer import Normalizer as EnNormalizer
    use_ttsfrd = False
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


class CosyVoiceFrontEnd:

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 allowed_special: str = 'all'):
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
        else:
            self.spk2info = {}
        self.allowed_special = allowed_special
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyinvg')
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
            self.en_tn_model = EnNormalizer()
            self.inflect_parser = inflect.engine()

    def _extract_text_token(self, text):
        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech):
        assert speech.shape[1] / 16000 <= 30, 'extracting speech tokens from audio longer than 30s is not supported'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,
                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech):
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None,
                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech):
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def text_normalize(self, text, split=True, text_frontend=True):
        if text_frontend is False:
            return [text] if split is True else text
        text = text.strip()
        if self.use_ttsfrd:
            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
            text = ''.join(texts)
        else:
            if contains_chinese(text):
                text = self.zh_tn_model.normalize(text)
                text = text.replace("\n", "")
                text = replace_blank(text)
                text = replace_corner_mark(text)
                text = text.replace(".", "。")
                text = text.replace(" - ", "，")
                text = remove_bracket(text)
                text = re.sub(r'[，,、]+$', '。', text)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
            else:
                text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
        texts = [i for i in texts if not is_only_punctuation(i)]
        return texts if split is True else text

    def frontend_sft(self, tts_text, spk_id):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        if resample_rate == 24000:
            # cosyvoice2: force len(speech_feat) == 2 * len(speech_token)
            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                       'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                       'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                       'flow_embedding': embedding}
        return model_input
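

# Hypothetical helper (not part of upstream CosyVoice): a standalone sketch of
# the CosyVoice2 trimming in frontend_zero_shot(). At a 24 kHz resample rate
# the code expects two mel frames per speech token, so both prompts are cut to
# the longest lengths that keep len(feat) == 2 * len(token).
def align_feat_and_token_lens(feat_len, token_len):
    # feat_len // 2 matches int(feat_len / 2) for non-negative lengths
    n = min(feat_len // 2, token_len)
    return 2 * n, n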


================================================
FILE: cosyvoice/cli/model.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
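

# Hypothetical helper (not part of upstream CosyVoice): the arithmetic behind
# CosyVoiceModel.mel_overlap_len. A token overlap is converted to seconds via
# the token frame rate, then to mel frames via the 22050 Hz sample rate and a
# 256-sample hop. With 20 overlap tokens at an assumed 50 Hz token rate this
# gives int(0.4 * 22050 / 256) = int(34.45...) = 34 mel frames.
def mel_overlap_frames(token_overlap_len, token_frame_rate, sample_rate=22050, hop_length=256):
    return int(token_overlap_len / token_frame_rate * sample_rate / hop_length)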


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # we fix flow.decoder.estimator.static_chunk_size = 0 here for compatibility
        self.flow.decoder.estimator.static_chunk_size = 0
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variable
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        if self.fp16 is True:
            self.llm.half()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_onnx(self, flow_decoder_estimator_model):
        import onnxruntime
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
        del self.flow.decoder.estimator
        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        if self.fp16 is True:
            llm_embedding = llm_embedding.half()
        with self.llm_context:
            for i in self.llm.inference(text=text.to(self.device),
                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_text=prompt_text.to(self.device),
                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                        embedding=llm_embedding.to(self.device)):
                self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_token=prompt_token.to(self.device),
                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_feat=prompt_feat.to(self.device),
                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                  embedding=embedding.to(self.device),
                                                  flow_cache=self.flow_cache_dict[uuid])
        self.flow_cache_dict[uuid] = flow_cache

        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)

    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
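            self.mel_overlap_dict.pop(this_uuid)
            # also release the flow cache, as tts() does; otherwise every vc() call leaks one entry
            self.flow_cache_dict.pop(this_uuid)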
            self.hift_cache_dict.pop(this_uuid)
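
In streaming mode, `CosyVoiceModel` keeps the last `mel_overlap_len` mel frames of each chunk and blends them into the head of the next chunk with a Hamming window (`fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)`). The sketch below reproduces the `mel_overlap_len` arithmetic from `__init__` and shows one plausible way such a window cross-fade works. It assumes `token_overlap_len = 20` and `input_frame_rate = 50`, which are not shown in this excerpt, and `crossfade` is a hypothetical stand-in, not the package's own `fade_in_out` (which is imported from elsewhere and not shown here).

```python
import numpy as np

# Assumed values for illustration only; the real ones come from the model/config.
TOKEN_OVERLAP_LEN = 20             # speech tokens shared by consecutive chunks (assumed)
INPUT_FRAME_RATE = 50              # speech tokens per second (assumed)
SAMPLE_RATE, HOP_SIZE = 22050, 256  # one mel frame per 256 samples at 22.05 kHz

# Same arithmetic as mel_overlap_len in CosyVoiceModel.__init__:
# overlap seconds * mel frames per second
mel_overlap_len = int(TOKEN_OVERLAP_LEN / INPUT_FRAME_RATE * SAMPLE_RATE / HOP_SIZE)
mel_window = np.hamming(2 * mel_overlap_len)


def crossfade(new_chunk: np.ndarray, prev_tail: np.ndarray, window: np.ndarray) -> np.ndarray:
    """Hypothetical cross-fade along the last (time) axis.

    The rising half of the window weights the head of the new chunk, and the
    falling half weights the tail kept from the previous chunk.
    """
    overlap = window.shape[0] // 2
    out = new_chunk.copy()
    out[..., :overlap] = (new_chunk[..., :overlap] * window[:overlap]
                          + prev_tail[..., -overlap:] * window[overlap:])
    return out


print(mel_overlap_len)  # 34
blended = crossfade(np.ones((1, 80, 100)), np.full((1, 80, mel_overlap_len), 2.0), mel_window)
```

Only the first `mel_overlap_len` frames of the new chunk change; the rest pass through untouched, which is why the model also trims those frames from the previous chunk before vocoding it.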


class CosyVoice2Model:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.token_hop_len = 2 * self.flow.input_frame_rate
        # here we fix the flow encoder/decoder decoding_chunk_size; in the future it will be passed as an argument or handled with a cache
        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store per-session variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        self.flow.decoder.fp16 = False
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_onnx(self, flow_decoder_estimator_model):
        import onnxruntime
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
        del self.flow.decoder.estimator
        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)

    def load_trt(self, flow_decoder_estimator_model):
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        if self.flow.decoder.estimator_engine is None:
            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
        self.flow.decoder.fp16 = True

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context:
            for i in self.llm.inference(text=text.to(self.device),
                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_text=prompt_text.to(self.device),
                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                        embedding=llm_embedding.to(self.device)):
                self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
        tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_token=prompt_token.to(self.device),
                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_feat=prompt_feat.to(self.device),
                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                         embedding=embedding.to(self.device),
                                         finalize=finalize)
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            token_offset = 0
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     token_offset=token_offset,
                                                     finalize=False)
                    token_offset += self.token_hop_len
                    yield {'tts_speech': this_tts_speech.cpu()}
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             token_offset=token_offset,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             token_offset=0,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
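
Unlike the v1 model, `CosyVoice2Model.tts` does not overlap chunks. It always re-feeds the token prefix `[0, token_offset + token_hop_len + pre_lookahead_len)` to the flow model, drops the mel frames before `token_offset * token_mel_ratio`, and advances `token_offset` by `token_hop_len`. The sketch below replays that schedule for the case where the LLM has already produced all tokens, so the polling and `time.sleep` are left out. `chunk_schedule` is a hypothetical helper, and the values `token_hop_len = 50` (that is, `2 * input_frame_rate` with an assumed 25 Hz token rate) and `pre_lookahead_len = 3` are assumptions.

```python
from typing import List, Tuple


def chunk_schedule(num_tokens: int, token_hop_len: int,
                   pre_lookahead_len: int) -> List[Tuple[int, int, bool]]:
    """Return (end, token_offset, finalize) for every token2wav call.

    Tokens [0, end) are passed to the flow model; output mel frames before
    token_offset * token_mel_ratio are discarded because they were already emitted.
    """
    calls = []
    token_offset = 0
    # Same condition as the streaming loop once llm_end_dict[uuid] is True
    while num_tokens - token_offset >= token_hop_len + pre_lookahead_len:
        calls.append((token_offset + token_hop_len + pre_lookahead_len, token_offset, False))
        token_offset += token_hop_len
    # Final call consumes every remaining token with finalize=True
    calls.append((num_tokens, token_offset, True))
    return calls


print(chunk_schedule(120, token_hop_len=50, pre_lookahead_len=3))
# [(53, 0, False), (103, 50, False), (120, 100, True)]
```

The `pre_lookahead_len` extra tokens let each chunk see a little future context, while the offset ensures that no mel frame is emitted twice.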


================================================
FILE: cosyvoice/dataset/__init__.py
================================================


================================================
FILE: cosyvoice/dataset/dataset.py
================================================
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import json
import math
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from cosyvoice.utils.file_utils import read_lists, read_json_lists


class Processor(IterableDataset):

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)
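
Each `Processor` wraps the previous dataset and a generator function, so iterating the outermost one pulls samples lazily through every stage, and each new iteration restarts from the original source. The sketch below shows that composition with plain Python. `Stage`, `tokenize`, and `drop_short` are hypothetical stand-ins (no `torch` dependency), and the loop mirrors the `for func in data_pipeline: dataset = Processor(dataset, func, mode=mode)` pattern used by `Dataset`.

```python
from typing import Callable, Iterable, Iterator


class Stage:
    """Minimal stand-in for Processor: a source iterable plus a generator function."""

    def __init__(self, source: Iterable, f: Callable[..., Iterator], **kw):
        assert callable(f)
        self.source, self.f, self.kw = source, f, kw

    def __iter__(self):
        # Re-runs the whole chain from the original source on every iteration
        return self.f(iter(self.source), **self.kw)


def tokenize(data, mode='train'):
    for sample in data:
        yield {**sample, 'tokens': sample['text'].split()}


def drop_short(data, min_tokens=2, mode='train'):
    for sample in data:
        if len(sample['tokens']) >= min_tokens:
            yield sample


dataset = [{'text': 'hello world'}, {'text': 'hi'}]
for func in [tokenize, drop_short]:
    dataset = Stage(dataset, func, mode='train')

print([s['tokens'] for s in dataset])  # [['hello', 'world']]
```

Because every stage receives the same keyword arguments, each processor function in the real pipeline must accept `mode` even if it ignores it.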


class DistributedSampler:

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            Args:
                data(List): input data list

            Returns:
                List: data list after sample
        """
        data = list(range(len(data)))
        # repeat the list when it is shorter than world_size so every rank gets at least one item
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            if len(data) < self.world_size:
                data = data * math.ceil(self.world_size / len(data))
                data = data[:self.world_size]
            data = data[self.rank::self.world_size]
        if len(data) < self.num_workers:
            data = data * math.ceil(self.num_workers / len(data))
            data = data[:self.num_workers]
        data = data[self.worker_id::self.num_workers]
        return data
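
`DistributedSampler.sample` shards the file list twice: first across ranks (after an epoch-seeded shuffle that yields the same order on every rank), then across DataLoader workers within a rank. Short lists are repeated so that no rank or worker ends up empty. The sketch below restates that logic as a standalone function so the resulting slices can be inspected; `shard_indexes` is a hypothetical name, and like the original it assumes `n > 0`.

```python
import math
import random
from typing import List


def shard_indexes(n: int, rank: int, world_size: int, worker_id: int,
                  num_workers: int, epoch: int = 0, shuffle: bool = True,
                  partition: bool = True) -> List[int]:
    data = list(range(n))
    if partition:
        if shuffle:
            # Same seed on every rank -> identical permutation, disjoint slices
            random.Random(epoch).shuffle(data)
        if len(data) < world_size:
            data = data * math.ceil(world_size / len(data))
            data = data[:world_size]
        data = data[rank::world_size]
    if len(data) < num_workers:
        data = data * math.ceil(num_workers / len(data))
        data = data[:num_workers]
    return data[worker_id::num_workers]


print(shard_indexes(5, rank=0, world_size=2, worker_id=0, num_workers=2, shuffle=False))  # [0, 4]
print(shard_indexes(5, rank=0, world_size=2, worker_id=1, num_workers=2, shuffle=False))  # [2]
```

When the list is long enough, the rank slices are disjoint and together cover every index; when it is shorter than `world_size`, some files are read by more than one rank.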


class DataList(IterableDataset):

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data


def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            shuffle=True,
            partition=True,
            tts_file='',
            prompt_utt2data=''):
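    """ Construct dataset from arguments

        We have two shuffle stages in the Dataset. The first is a global
        shuffle at the data-file level. The second is a global shuffle
        at the training-sample level.

        Args:
            data_list_file(str): file listing the data files to read
            data_pipeline(List[Callable]): processor functions, applied in order
            mode(str): 'train' or 'inference'
            gan(bool): if True, bind gan=True to the last (padding) processor
            shuffle(bool): whether to shuffle the data files each epoch
            partition(bool): whether to do data partition in terms of rank
            tts_file(str): json file mapping utt to texts to synthesize (inference mode)
            prompt_utt2data(str): json file mapping utt to its data file (inference mode)
    """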
    assert mode in ['train', 'inference']
    lists = read_lists(data_list_file)
    if mode == 'inference':
        with open(tts_file) as f:
            tts_data = json.load(f)
        utt2lists = read_json_lists(prompt_utt2data)
        # filter unnecessary file in inference mode
        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
    dataset = DataList(lists,
                       shuffle=shuffle,
                       partition=partition)
    if mode == 'inference':
        # map partial arg to parquet_opener func in inference mode
        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
    if gan is True:
        # map partial arg to padding func in gan mode
        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
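
In inference mode, `Dataset` does two extra things: it keeps only the data files that contain at least one utterance listed in `tts_file`, and it binds `tts_data` to the first pipeline stage with `functools.partial`. The sketch below reproduces both steps on toy dictionaries. `select_inference_files` and `toy_opener` are hypothetical helpers (the real first stage is `parquet_opener`), and `sorted()` replaces the original `list(set)` only to make the order deterministic.

```python
from functools import partial
from typing import Dict, Iterable, Iterator, List


def select_inference_files(lists: List[str], tts_data: Dict[str, List[str]],
                           utt2lists: Dict[str, str]) -> List[str]:
    # Keep each data file once, and only if it holds an utt we must synthesize
    return sorted({utt2lists[utt] for utt in tts_data if utt2lists[utt] in lists})


def toy_opener(data: Iterable[dict], mode: str = 'train', tts_data=None) -> Iterator[dict]:
    # Hypothetical first stage: fan out one sample per text to synthesize
    tts_data = tts_data or {}
    for sample in data:
        for index, text in enumerate(tts_data.get(sample['utt'], [])):
            yield {**sample, 'tts_index': index, 'tts_text': text}


tts_data = {'utt1': ['hello', 'world'], 'utt2': ['bye']}
utt2lists = {'utt1': 'a.parquet', 'utt2': 'c.parquet'}
print(select_inference_files(['a.parquet', 'b.parquet'], tts_data, utt2lists))  # ['a.parquet']

pipeline = [toy_opener]
pipeline[0] = partial(pipeline[0], tts_data=tts_data)  # same binding pattern as Dataset()
samples = list(pipeline[0]([{'utt': 'utt1'}], mode='inference'))
```

As in the original, an utterance in `tts_file` that is missing from `prompt_utt2data` raises a `KeyError` rather than being skipped.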


================================================
FILE: cosyvoice/dataset/processor.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random

import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

torchaudio.set_audio_backend('soundfile')

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train', tts_data={}):
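    """ Given a url or local parquet file, yield one sample per row.
        Inplace operation.

        Args:
            data(Iterable[dict]): samples carrying 'src' (url or local file)
            mode(str): 'train' or 'inference'
            tts_data(dict): utt -> list of texts to synthesize (inference mode)

        Returns:
            Iterable[dict]: the input sample updated with the row's columns
    """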
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter sample according to feature and label length
        Inplace operation.

        Args:
            data: Iterable[{audio_data, text_token, speech_token, ...}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance whose text token count is
                greater than token_max_length, especially when using char
                units for English modeling
            token_min_length: drop utterance whose text token count is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length(10ms)

            Utterances without any speech tokens are always dropped.

        Returns:
            Iterable[{speech, sample_rate, text_token, speech_token, ...}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is a (1, num_samples) torch.Tensor; measure its duration in 10 ms frames (100 per second)
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
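

# Worked example (a hypothetical helper added for illustration only; it is not
# used by the data pipeline). `filter` measures duration in 10 ms frames, i.e.
# num_samples / sample_rate * 100, and bounds the ratio of text tokens to frames.
# This sketch reproduces only that arithmetic; it does not decode any audio.
def _example_filter_frames(num_samples, sample_rate, num_text_tokens):
    """Return (num_frames, tokens_per_frame) as computed inside `filter`."""
    num_frames = num_samples / sample_rate * 100
    # `filter` skips the ratio check when num_frames is 0; mirror that with None.
    ratio = num_text_tokens / num_frames if num_frames != 0 else None
    return num_frames, ratio


# For example, 3 s of 16 kHz audio (48000 samples) with 30 text tokens gives
# 300 frames and a ratio of 0.1, inside the default bounds [0.0005, 1].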


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
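    """ Resample speech to `resample_rate` and peak-normalize it.
        Inplace operation. Samples that need resampling but whose original
        rate is below `min_sample_rate` are dropped, and waveforms whose peak
        magnitude exceeds 1 are scaled into [-1, 1].

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: minimal original sample rate accepted for resampling

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """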
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Randomly crop speech to `truncate_length` samples, or right-pad it
        with zeros if it is shorter.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: output length in samples

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample
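

# Illustrative sketch (hypothetical helper, not used by the pipeline): the same
# crop-or-pad rule as `truncate`, applied to a plain Python list so it runs
# without torch. Longer inputs are cropped at a random offset and shorter inputs
# are right-padded with zeros, so the output length is always `truncate_length`.
def _example_truncate(samples, truncate_length, rng=None):
    """Crop or zero-pad a list of samples to exactly `truncate_length` items."""
    import random
    rng = rng if rng is not None else random.Random()
    if len(samples) > truncate_length:
        start = rng.randint(0, len(samples) - truncate_length)
        return samples[start:start + truncate_length]
    return samples + [0.0] * (truncate_length - len(samples))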


def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
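    """ Extract acoustic features with `feat_extractor` and store them in
        `speech_feat` with shape (num_frames, feat_dim)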

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        yield sample


def compute_f0(data, pitch_extractor, mode='train'):
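    """ Extract f0 with `pitch_extractor`, interpolate it to the frame count
        of `speech_feat` (so compute_fbank must run first), and store it in `pitch_feat`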

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        mat = pitch_extractor(waveform).transpose(1, 2)
        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
        sample['pitch_feat'] = mat[0, 0]
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Convert utt_embedding/spk_embedding to float32 tensors, optionally L2-normalized

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
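    """ Encode text into token ids
        Inplace operation: adds `text_token` (and, in inference mode,
        `tts_text_token`) using the tokenizer returned by `get_tokenizer()`.

        Args:
            data: Iterable[{utt, text, ...}] (plus `tts_text` in inference mode)

        Returns:
            Iterable[{utt, text, text_token, ...}]
    """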
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Locally shuffle the data within a buffer of `shuffle_size` samples

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The sample left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """

    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The sample left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x
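

# Illustrative sketch (hypothetical helper operating on plain lengths): `sort`
# only orders samples inside each buffer of `sort_size`, so the output stream is
# locally sorted rather than globally sorted. This mirrors that buffering on ints.
def _example_buffered_sort(lengths, sort_size):
    out, buf = [], []
    for length in lengths:
        buf.append(length)
        if len(buf) >= sort_size:
            out.extend(sorted(buf))
            buf = []
    out.extend(sorted(buf))  # leftover samples
    return out


# _example_buffered_sort([5, 3, 9, 1, 7], sort_size=2) -> [3, 5, 1, 9, 7]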


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the padded frame count of the batch
        (longest sample length * batch size) would exceed `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max padded frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
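        # Flush the current batch if adding this sample would exceed the budget;
        # never flush an empty buffer, so an over-long sample forms its own batch.
        if frames_after_padding > max_frames_in_batch and len(buf) > 0: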
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
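

# Illustrative sketch (hypothetical helper operating on plain frame counts):
# `dynamic_batch` charges each batch for its padded size, longest_frames * size,
# rather than the sum of the true lengths. As an assumption of this sketch, the
# check never emits an empty batch when the first sample alone exceeds the budget.
def _example_dynamic_batch(frame_counts, max_frames_in_batch):
    batches, buf, longest = [], [], 0
    for n in frame_counts:
        longest = max(longest, n)
        if longest * (len(buf) + 1) > max_frames_in_batch and buf:
            batches.append(buf)
            buf, longest = [n], n
        else:
            buf.append(n)
    if buf:
        batches.append(buf)
    return batches


# With a budget of 1200 frames, [300, 400, 500, 200] becomes
# [[300, 400], [500, 200]]: adding 500 to the first batch would cost 500 * 3 = 1500.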


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch; inference always uses static batches of size 1
    """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))
            raise ValueError('Unsupported batch type {}'.format(batch_type))


def padding(data, use_spk_embedding, mode='train', gan=False):
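    """ Pad a list of samples into one batch dict

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Dict]: padded tensors (speech_token, speech_feat,
                text_token, ...) with their lengths; `embedding` is
                spk_embedding or utt_embedding depending on `use_spk_embedding`
    """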
    for sample in data:
        assert isinstance(sample, list)
        # speech_feat is (num_frames, num_mels), so dim 0 is the frame count used for ordering
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch


================================================
FILE: cosyvoice/flow/__init__.py
================================================


================================================
FILE: cosyvoice/flow/decoder.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        output = self.block(x * mask)
        return output * mask


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)


class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x


class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        causal=False,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same time length as the target. So, if your text content
        is shorter or longer than the outputs, please resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
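        self.causal = causal
        # Chunk size used to build the attention mask in `forward`; 0 disables
        # chunking (every frame attends to all valid frames). Callers that need
        # chunked attention can override this attribute after construction.
        self.static_chunk_size = 0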
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else
                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            output_channel = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            ) if self.causal else ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
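        """Forward pass of the UNet1DConditional model.

        x, mu, the broadcast spks and cond are concatenated along the channel
        axis, so their channel counts must sum to `in_channels`.

        Args:
            x (torch.Tensor): noisy input, shape (batch_size, x_channels, time)
            mask (torch.Tensor): padding mask, shape (batch_size, 1, time)
            mu (torch.Tensor): encoder output, shape (batch_size, mu_channels, time)
            t (torch.Tensor): flow-matching timestep, shape (batch_size,)
            spks (torch.Tensor, optional): speaker embedding, shape (batch_size, spk_channels).
                Defaults to None.
            cond (torch.Tensor, optional): extra conditioning features (e.g. prompt mel),
                shape (batch_size, cond_channels, time). Defaults to None.

        Returns:
            torch.Tensor: output, shape (batch_size, out_channels, time)
        """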

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
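

# Illustrative arithmetic (hypothetical helper, not used by the model), assuming
# matcha's Downsample1D is Conv1d(kernel=3, stride=2, padding=1) and Upsample1D
# with use_conv_transpose=True is ConvTranspose1d(kernel=4, stride=2, padding=1).
# Under those assumptions a level of length T is downsampled to ceil(T / 2)
# (mask[:, :, ::2] keeps the same number of frames) and upsampled back to
# 2 * ceil(T / 2). That exceeds T by one frame when T is odd, which is why
# `forward` trims x to skip.shape[-1] before concatenating the skip connection.
def _example_unet_lengths(t):
    """Return (downsampled, upsampled, trimmed) lengths for one U-Net level."""
    down = (t + 2 * 1 - 3) // 2 + 1       # Conv1d output length
    up = (down - 1) * 2 - 2 * 1 + 4       # ConvTranspose1d output length
    return down, up, min(up, t)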


================================================
FILE: cosyvoice/flow/flow.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask


class MaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # embed speech tokens and zero out padded positions
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # encode tokens and regulate them to the target mel length
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

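        # get conditions: for about half of the utterances, copy a random prefix
        # (up to 30% of the frames) of the target mel into conds as a prompt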
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}
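

# Illustrative sketch (hypothetical helper, not used by the model) of the prompt
# conditioning drawn in `MaskedDiffWithXvec.forward`: with probability 0.5 an
# utterance gets no prompt; otherwise its first `index` mel frames are copied
# into `conds`, with index drawn uniformly from [0, int(0.3 * feat_len)].
def _example_prompt_prefix_len(feat_len, rng=None):
    """Return how many leading mel frames would be revealed as the prompt."""
    import random
    rng = rng if rng is not None else random.Random()
    if rng.random() < 0.5:
        return 0
    return rng.randint(0, int(0.3 * feat_len))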

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concatenate prompt speech tokens and target speech tokens
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # encode the concatenated tokens and regulate them to the mel length
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        # target mel frames = tokens / token rate (Hz) * 22050 Hz / hop size 256
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # get conditions: the prompt mel fills the first mel_len1 frames, zeros elsewhere
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            flow_cache=flow_cache
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat, flow_cache
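

# Worked example (hypothetical helper, not used by the model): `inference` sizes
# the generated mel as token_len2 / input_frame_rate seconds at 22050 Hz with a
# hop of 256 samples, truncated to an int, matching the constants used above.
def _example_target_mel_len(num_tokens, input_frame_rate=50):
    return int(num_tokens / input_frame_rate * 22050 / 256)


# 100 tokens at 50 Hz are 2 s of speech: 2 * 22050 / 256 = 172.27 -> 172 frames.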


class CausalMaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  finalize):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        # for a non-final streaming chunk, drop the frames produced by the pre-lookahead tokens
        if finalize is False:
            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat, None
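

# Illustrative sketch (hypothetical helpers, not used by the model): how many mel
# frames each flow variant above produces for a chunk of speech tokens. Assumes the
# defaults in this file (22050 Hz audio, hop size 256, input_frame_rate=50,
# token_mel_ratio=2, pre_lookahead_len=3). For the causal variant it also assumes an
# encoder that upsamples each token to exactly token_mel_ratio frames.
def _sketch_mel_len_from_tokens(num_tokens, input_frame_rate=50, sampling_rate=22050, hop_size=256):
    # MaskedDiffWithXvec: tokens -> seconds -> mel frames, truncated like int(...) above
    return int(num_tokens / input_frame_rate * sampling_rate / hop_size)


def _sketch_causal_mel_len(num_tokens, finalize, token_mel_ratio=2, pre_lookahead_len=3):
    # CausalMaskedDiffWithXvec: total frames for prompt + new tokens; a non-final
    # chunk drops the frames that belong to the lookahead tokens
    frames = num_tokens * token_mel_ratio
    if not finalize:
        frames -= pre_lookahead_len * token_mel_ratio
    return frames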


================================================
FILE: cosyvoice/flow/flow_matching.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import onnxruntime
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM


class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
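        """Generate a mel-spectrogram by solving the flow-matching ODE from noise.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): prompt mel-spectrogram used as condition,
                zeros outside the prompt region. Defaults to None.
                shape: (batch_size, n_feats, mel_timesteps)
            prompt_len (int, optional): number of prompt mel frames. Defaults to 0.
            flow_cache (torch.Tensor, optional): noise and mu cached from the previous call,
                stacked on the last dim.
                shape: (1, n_feats, cache_len, 2)

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
            flow_cache: updated cache for the next call
        """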

        z = torch.randn_like(mu) * temperature
        cache_size = flow_cache.shape[2]
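        # reuse the cached noise and mu for the prompt and overlap part, so consecutive chunks stay consistent
        if cache_size != 0:
            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
        # cache the prompt frames plus the last 34 frames (the mel overlap; 34 presumably equals
        # int(20 / 50 * 22050 / 256), i.e. a 20-token overlap at a 50 Hz token rate)
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)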

        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed-step Euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor): prompt mel condition
                shape: (batch_size, n_feats, mel_timesteps)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        if self.inference_cfg_rate > 0:
            # Do not use concat here: it may change the memory format and make TensorRT inference return wrong results!
            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        else:
            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox:
            # batch row 0 carries mu/spks/cond, row 1 keeps them at zero (unconditional)
            if self.inference_cfg_rate > 0:
                x_in[:] = x
                mask_in[:] = mask
                mu_in[0] = mu
                t_in[:] = t.unsqueeze(0)
                spks_in[0] = spks
                cond_in[0] = cond
            else:
                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in
            )
            if self.inference_cfg_rate > 0:
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1].float()

    def forward_estimator(self, x, mask, mu, t, spks, cond):
        if isinstance(self.estimator, torch.nn.Module):
            return self.estimator.forward(x, mask, mu, t, spks, cond)
        elif isinstance(self.estimator, onnxruntime.InferenceSession):
            ort_inputs = {
                'x': x.cpu().numpy(),
                'mask': mask.cpu().numpy(),
                'mu': mu.cpu().numpy(),
                't': t.cpu().numpy(),
                'spks': spks.cpu().numpy(),
                'cond': cond.cpu().numpy()
            }
            output = self.estimator.run(None, ort_inputs)[0]
            return torch.tensor(output, dtype=x.dtype, device=x.device)
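        else:
            # otherwise the estimator is treated as a TensorRT execution context; the output is written into x in place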
            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
            self.estimator.set_input_shape('t', (2,))
            self.estimator.set_input_shape('spks', (2, 80))
            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
            # run trt engine
            self.estimator.execute_v2([x.contiguous().data_ptr(),
                                       mask.contiguous().data_ptr(),
                                       mu.contiguous().data_ptr(),
                                       t.contiguous().data_ptr(),
                                       spks.contiguous().data_ptr(),
                                       cond.contiguous().data_ptr(),
                                       x.data_ptr()])
            return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
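            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): prompt mel condition. Defaults to None.
                shape: (batch_size, n_feats, mel_timesteps)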

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
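

import math


# Illustrative sketch (hypothetical helpers, scalar floats instead of tensors): the
# training target built in ConditionalCFM.compute_loss. y is a point on the straight
# path from noise z (t=0) to x1 + sigma_min * z (t=1); u is the velocity dy/dt that
# the estimator is trained to predict. The cosine scheduler only changes where t is
# sampled, not the target itself.
def _sketch_cfm_target(x1, z, t, sigma_min=1e-6):
    y = (1 - (1 - sigma_min) * t) * z + t * x1
    u = x1 - (1 - sigma_min) * z
    return y, u


def _sketch_cosine_t(t):
    # same warp as t_scheduler == 'cosine': 0 -> 0, 1 -> 1, samples denser near t=0
    return 1 - math.cos(t * 0.5 * math.pi)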


class CausalConditionalCFM(ConditionalCFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        self.rand_noise = torch.randn([1, 80, 50 * 300])

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Generate a mel-spectrogram by solving the flow-matching ODE from a fixed noise buffer.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): prompt mel condition. Defaults to None.
                shape: (batch_size, n_feats, mel_timesteps)

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
            None: this variant keeps no flow cache
        """

        # take the fixed noise buffer and cast it to mu's device and dtype (covers fp16 inference)
        z = self.rand_noise[:, :, :mu.size(2)].to(device=mu.device, dtype=mu.dtype) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
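

# Illustrative sketch (hypothetical helper, scalar floats instead of tensors): the
# fixed-step Euler loop of solve_euler with classifier-free guidance. velocity_fn is a
# stand-in for the estimator. It is called once with the condition and once without,
# mirroring the two batch rows, and the two velocities are mixed with inference_cfg_rate.
def _sketch_euler_cfg(x0, t_span, velocity_fn, cfg_rate=0.7):
    x, t = x0, t_span[0]
    dt = t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        v_cond = velocity_fn(x, t, True)
        v_uncond = velocity_fn(x, t, False)
        v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond  # guided velocity
        x = x + dt * v
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t  # supports non-uniform steps, e.g. the cosine schedule
    return x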


================================================
FILE: cosyvoice/flow/length_regulator.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
import torch
from torch.nn import functional as F
from cosyvoice.utils.mask import make_pad_mask


class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            out_channels: int = None,
            groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)

    def forward(self, x, ylens=None):
        # x in (B, T, D)
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
        out = self.model(x).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        # in inference mode, interpolate the prompt tokens and the new tokens (head/mid/tail) separately, so we get a clear separation point in the mel
        # x in (B, T, D)
        if x2.shape[1] > 40:
            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
                                   mode='linear')
            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
        else:
            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
        if x1.shape[1] != 0:
            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
            x = torch.concat([x1, x2], dim=2)
        else:
            x = x2
        out = self.model(x).transpose(1, 2).contiguous()
        return out, mel_len1 + mel_len2
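

# Illustrative sketch (hypothetical helper, not used by the model): the target frame
# counts InterpolateRegulator.inference uses for the new tokens. With more than 2 * 20
# tokens, the first and last 20 tokens are always stretched to a fixed
# int(20 / input_frame_rate * 22050 / 256) frames and the middle absorbs the remainder,
# so the chunk boundaries fall on a predictable frame index.
def _sketch_regulator_sizes(num_tokens, mel_len2, input_frame_rate=50, edge_tokens=20):
    if num_tokens > 2 * edge_tokens:
        edge = int(edge_tokens / input_frame_rate * 22050 / 256)
        return [edge, mel_len2 - 2 * edge, edge]  # head, mid, tail
    return [mel_len2]  # short input: a single interpolation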


================================================
FILE: cosyvoice/hifigan/__init__.py
================================================


================================================
FILE: cosyvoice/hifigan/discriminator.py
================================================
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from typing import List, Optional, Tuple
from einops import rearrange
from torchaudio.transforms import Spectrogram


class MultipleDiscriminator(nn.Module):
    def __init__(
            self, mpd: nn.Module, mrd: nn.Module
    ):
        super().__init__()
        self.mpd = mpd
        self.mrd = mrd

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
        y_d_rs += this_y_d_rs
        y_d_gs += this_y_d_gs
        fmap_rs += this_fmap_rs
        fmap_gs += this_fmap_gs
        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
        y_d_rs += this_y_d_rs
        y_d_gs += this_y_d_gs
        fmap_rs += this_fmap_rs
        fmap_gs += this_fmap_gs
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
        num_embeddings: Optional[int] = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """

        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorR(nn.Module):
    def __init__(
        self,
        window_length: int,
        num_embeddings: Optional[int] = None,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        self.spec_fn = Spectrogram(
            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
        )
        n_bins = window_length // 2 + 1  # number of one-sided STFT frequency bins
        bands = [(int(b[0] * n_bins), int(b[1] * n_bins)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

    def spectrogram(self, x):
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdim=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = rearrange(x, "b f t c -> b c t f")
        # Split into bands
        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
        x_bands = self.spectrogram(x)
        fmap = []
        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)
        x = torch.cat(x, dim=-1)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdim=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h

        return x, fmap
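

# Illustrative sketch (hypothetical helper): the STFT hop length and frequency-bin
# ranges that DiscriminatorR derives from window_length. A one-sided STFT has
# window_length // 2 + 1 bins, and each fractional band (lo, hi) is truncated to
# integer bin indices, so neighbouring bands share their boundary index.
def _sketch_discriminator_r_bins(window_length, hop_factor=0.25,
                                 bands=((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))):
    n_bins = window_length // 2 + 1
    hop_length = int(window_length * hop_factor)
    return hop_length, [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]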


================================================
FILE: cosyvoice/hifigan/f0_predictor.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class ConvRNNF0Predictor(nn.Module):
    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class
        self.condnet = nn.Sequential(
            weight_norm(
                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
        )
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.condnet(x)
        x = x.transpose(1, 2)
        return torch.abs(self.classifier(x).squeeze(-1))


================================================
FILE: cosyvoice/hifigan/generator.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HIFI-GAN"""

from typing import Dict, Optional, List
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform

from cosyvoice.transformer.activation import Snake
from cosyvoice.utils.common import get_padding
from cosyvoice.utils.common import init_weights


"""HiFi-GAN based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN

"""


class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN."""
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            remove_weight_norm(self.convs2[idx])


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise); sine_waves and noise are
            [B, harmonic_num + 1, sample_len], uv is [B, 1, sample_len]
        """

        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate

        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced regions the std is sine_amp / 3, so its peak is roughly sine_amp;
        #        for voiced regions the std is noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
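

import math


# Illustrative sketch (hypothetical helper, plain Python floats): the fundamental
# (harmonic 0) path of SineGen.forward, without the additive noise. The phase is the
# running sum of f0 / sampling_rate wrapped to [0, 1), so a change in frequency never
# causes a phase jump. Samples with f0 <= voiced_threshold are unvoiced and zeroed.
def _sketch_sine_excitation(f0_hz, sampling_rate, sine_amp=0.1, voiced_threshold=0.0):
    phase = 0.0
    sine, uv = [], []
    for f in f0_hz:
        phase = (phase + f / sampling_rate) % 1.0  # cumulative phase in cycles
        voiced = 1.0 if f > voiced_threshold else 0.0
        sine.append(sine_amp * math.sin(2 * math.pi * phase) * voiced)
        uv.append(voiced)
    return sine, uv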


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModuleHnNSF(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                      add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    upsample_scale: not used in this implementation
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced regions is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    sine_merge, noise, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    sine_merge (batchsize, length, 1)
    noise (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        uv (batchsize, length, 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
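
# Hypothetical illustration (not part of upstream CosyVoice): SourceModuleHnNSF
# collapses its harmonic_num + 1 sine channels into one excitation with
# tanh(Linear(...)). This pure-Python sketch does the same per time step, assuming
# `harmonics` is a list of steps, each holding harmonic_num + 1 sine values, and
# that `weights`/`bias` stand in for the learned l_linear parameters.
import math
from typing import List, Sequence


def merge_harmonics_sketch(harmonics: Sequence[Sequence[float]], weights: Sequence[float],
                           bias: float = 0.0) -> List[float]:
    """Merge per-step harmonic values into a single excitation bounded in (-1, 1)."""
    merged = []
    for step in harmonics:
        if len(step) != len(weights):
            raise ValueError("each step needs one value per harmonic (harmonic_num + 1)")
        # weighted sum over harmonics (the Linear layer), then tanh squashes it
        merged.append(math.tanh(sum(w * h for w, h in zip(weights, step)) + bias))
    return merged
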


class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493
    """
    def __init__(
            self,
            in_channels: int = 80,
            base_channels: int = 512,
            nb_harmonics: int = 8,
            sampling_rate: int = 22050,
            nsf_alpha: float = 0.1,
            nsf_sigma: float = 0.003,
            nsf_voiced_threshold: float = 10,
            upsample_rates: List[int] = [8, 8],
            upsample_kernel_sizes: List[int] = [16, 16],
            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
            resblock_kernel_sizes: List[int] = [3, 7, 11],
            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            source_resblock_kernel_sizes: List[int] = [7, 11],
            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
            lrelu_slope: float = 0.1,
            audio_limit: float = 0.99,
            f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        # m_source and source_downs are created without weight_norm (see __init__),
        # so there is nothing to remove for them; calling remove_weight_norm on
        # them would raise an error
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # note: the sin here is redundant

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel+source->speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # overwrite the head of the source with cache_source to avoid glitches between chunks
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, s
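
# Hypothetical illustration (not part of upstream CosyVoice): two stdlib-only
# sketches of HiFTGenerator bookkeeping, assuming the default hyper-parameters.
# 1) hift_frame_counts_sketch checks that, at each fusion point in decode(), the
#    upsampled mel branch and the downsampled STFT of the source have the same
#    number of frames. It assumes torch.stft's default center=True, which yields
#    1 + n_samples // hop_len frames.
# 2) hift_frame_to_spectrum_sketch shows how decode() and _istft() turn the
#    n_fft + 2 conv_post channels of one frame into n_fft // 2 + 1 complex bins.
import cmath
import math
from typing import List, Sequence, Tuple


def hift_frame_counts_sketch(n_mel_frames: int,
                             upsample_rates: Tuple[int, ...] = (8, 8),
                             upsample_kernel_sizes: Tuple[int, ...] = (16, 16),
                             hop_len: int = 4) -> List[Tuple[int, int]]:
    """Return (mel_branch_len, source_branch_len) at each fusion point."""
    samples_per_frame = hop_len * math.prod(upsample_rates)  # 256 with the defaults
    stft_frames = 1 + (n_mel_frames * samples_per_frame) // hop_len
    # same strides as __init__: cumprod([1] + upsample_rates[::-1][:-1]), reversed
    rates = [1] + list(upsample_rates[::-1][:-1])
    strides = [math.prod(rates[:j + 1]) for j in range(len(rates))][::-1]

    pairs, x_len = [], n_mel_frames
    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
        # ConvTranspose1d(kernel=k, stride=u, padding=(k - u) // 2)
        x_len = (x_len - 1) * u - 2 * ((k - u) // 2) + k
        if i == len(upsample_rates) - 1:
            x_len += 1  # ReflectionPad1d((1, 0)) adds one frame on the left
        d = strides[i]
        if d == 1:
            s_len = stft_frames  # 1x1 Conv1d keeps the length
        else:
            # Conv1d(kernel=2 * d, stride=d, padding=d // 2)
            s_len = (stft_frames + 2 * (d // 2) - 2 * d) // d + 1
        pairs.append((x_len, s_len))
    return pairs


def hift_frame_to_spectrum_sketch(frame: Sequence[float], n_fft: int = 16,
                                  max_magnitude: float = 1e2) -> List[complex]:
    """Map one frame of conv_post outputs to n_fft // 2 + 1 complex STFT bins."""
    n_bins = n_fft // 2 + 1
    if len(frame) != 2 * n_bins:  # n_fft + 2 channels when n_fft is even
        raise ValueError("expected n_fft + 2 values per frame")
    spectrum = []
    for log_mag, raw_phase in zip(frame[:n_bins], frame[n_bins:]):
        # exp() then clip at max_magnitude, done in the log domain to avoid overflow
        magnitude = math.exp(min(log_mag, math.log(max_magnitude)))
        phase = math.sin(raw_phase)  # decode() applies sin before _istft
        spectrum.append(cmath.rect(magnitude, phase))  # magnitude * (cos + i*sin)
    return spectrum
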


================================================
FILE: cosyvoice/hifigan/hifigan.py
================================================
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
from cosyvoice.utils.losses import tpr_loss, mel_loss


class HiFiGan(nn.Module):
    def __init__(self, generator, discriminator, mel_spec_transform,
                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
        super(HiFiGan, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.mel_spec_transform = mel_spec_transform
        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
        self.feat_match_loss_weight = feat_match_loss_weight
        self.tpr_loss_weight = tpr_loss_weight
        self.tpr_loss_tau = tpr_loss_tau

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        if batch['turn'] == 'generator':
            return self.forward_generator(batch, device)
        else:
            return self.forward_discriminator(batch, device)

    def forward_generator(self, batch, device):
        real_speech = batch['speech'].to(device)
        pitch_feat = batch['pitch_feat'].to(device)
        # 1. calculate generator outputs
        generated_speech, generated_f0 = self.generator(batch, device)
        # 2. calculate discriminator outputs
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
        # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
        loss_gen, _ = generator_loss(y_d_gs)
        loss_fm = feature_loss(fmap_rs, fmap_gs)
        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
        if self.tpr_loss_weight != 0:
            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
        else:
            loss_tpr = torch.zeros(1).to(device)
        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
        loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
            self.multi_mel_spectral_recon_loss_weight * loss_mel + \
            self.tpr_loss_weight * loss_tpr + loss_f0
        return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}

    def forward_discriminator(self, batch, device):
        real_speech = batch['speech'].to(device)
        # 1. calculate generator outputs
        with torch.no_grad():
            generated_speech, generated_f0 = self.generator(batch, device)
        # 2. calculate discriminator outputs
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
        # 3. calculate discriminator losses, tpr losses [Optional]
        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
        if self.tpr_loss_weight != 0:
            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
        else:
            loss_tpr = torch.zeros(1).to(device)
        loss = loss_disc + self.tpr_loss_weight * loss_tpr
        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
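
# Hypothetical illustration (not part of upstream CosyVoice): forward_generator
# combines its terms as
#   loss = loss_gen + feat_match_loss_weight * loss_fm
#          + multi_mel_spectral_recon_loss_weight * loss_mel
#          + tpr_loss_weight * loss_tpr + loss_f0
# and forward_discriminator as loss_disc + tpr_loss_weight * loss_tpr.
# This plain-float sketch reproduces that weighting with the constructor defaults.


def hifigan_generator_loss_sketch(loss_gen: float, loss_fm: float, loss_mel: float,
                                  loss_tpr: float, loss_f0: float,
                                  mel_weight: float = 45.0, fm_weight: float = 2.0,
                                  tpr_weight: float = 1.0) -> float:
    """Weighted sum used for the generator turn (plain floats instead of tensors)."""
    if tpr_weight == 0:
        loss_tpr = 0.0  # mirrors the torch.zeros(1) fallback
    return loss_gen + fm_weight * loss_fm + mel_weight * loss_mel + tpr_weight * loss_tpr + loss_f0


def hifigan_discriminator_loss_sketch(loss_disc: float, loss_tpr: float,
                                      tpr_weight: float = 1.0) -> float:
    """Weighted sum used for the discriminator turn."""
    if tpr_weight == 0:
        loss_tpr = 0.0
    return loss_disc + tpr_weight * loss_tpr
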


================================================
FILE: cosyvoice/llm/__init__.py
================================================


================================================
FILE: cosyvoice/llm/llm.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
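        """
        Args:
            batch: dict with 'text_token' (B, L), 'text_token_len' (B,),
                'speech_token' (B, T), 'speech_token_len' (B,) and
                'embedding' (B, spk_embed_dim) speaker embeddings
            device: device to move the batch tensors to
        Returns:
            dict with 'loss' and 'acc'
        """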
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)

        # 0. prepare llm_target: ignore the sos, embedding and text positions, end with eos
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 1. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)

        # 2. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)

        # 3. sos/eos and task_id embeddings
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

        # 4. encode speech_token
        speech_token = self.speech_embedding(speech_token)

        # 5. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)

        # 6. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
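        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos while ignore_eos is True, check your input!'.format(max_trials))
        return top_ids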

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)

        # 1. encode text
        text, text_len = self.encode(text, text_len)

        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. calculate min/max decoding length from the non-prompt text length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # forbid eos at the first step so that at least one token is decoded
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
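
# Hypothetical illustration (not part of upstream CosyVoice): TransformerLM.forward
# lays out each training sequence as
#   [sos_eos, spk_embedding, text_0 .. text_{L-1}, task_id, speech_0 .. speech_{S-1}]
# and builds lm_target so that each position is trained to predict what follows it:
# the first 2 + L positions are IGNORE_ID, the task_id position predicts speech_0,
# and the last speech position predicts eos (= speech_token_size). This sketch
# builds the same target list for one utterance with plain ints; ignore_id=-1 is
# an assumption about the value of IGNORE_ID in cosyvoice.utils.common.
from typing import List, Sequence


def transformer_lm_target_sketch(text_len: int, speech_tokens: Sequence[int],
                                 speech_token_size: int, ignore_id: int = -1) -> List[int]:
    """Return the per-position targets for one (text, speech) training pair."""
    target = [ignore_id] * (2 + text_len) + list(speech_tokens) + [speech_token_size]
    # sanity check: input length is sos (1) + embedding (1) + text (L) + task_id (1) + speech (S)
    input_len = 1 + 1 + text_len + 1 + len(speech_tokens)
    assert len(target) == input_len, "targets must align one-to-one with lm_input positions"
    return target
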


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache


class Qwen2LM(torch.nn.Module):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2

        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)

        # 4. sampling method
        self.sampling = sampling

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos while ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)

        # 2. Qwen2LM does not use the speaker embedding; insert an empty placeholder
        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. calculate min/max decoding length from the non-prompt text length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        out_tokens = []
        cache = None
        for i in range(max_len):
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            if top_ids > self.speech_token_size:
                continue
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
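
# Hypothetical illustration (not part of upstream CosyVoice): the decoding policy
# of Qwen2LM.inference reduced to plain Python. `sample_next` is a stand-in for
# llm forward + llm_decoder + self.sampling and returns one token id per call.
# eos is speech_token_size; during the first min_len steps an eos draw is
# resampled (bounded by max_trials, as in Qwen2LM.sampling_ids); ids above
# speech_token_size are skipped without being emitted. TransformerLM additionally
# masks the eos logit at step 0, which this sketch does not model.
from typing import Callable, Iterator, List


def decode_policy_sketch(sample_next: Callable[[List[int]], int], speech_token_size: int,
                         min_len: int, max_len: int, max_trials: int = 100) -> Iterator[int]:
    """Yield speech token ids one by one, as the streaming inference loop does."""
    out_tokens: List[int] = []
    for i in range(max_len):
        trials = 0
        while True:
            top_id = sample_next(out_tokens)
            if i >= min_len or top_id != speech_token_size:
                break
            trials += 1
            if trials > max_trials:
                raise RuntimeError("still sampling eos after max_trials")
        if top_id == speech_token_size:
            break  # eos: stop decoding
        if top_id > speech_token_size:
            continue  # special ids are not emitted
        yield top_id  # stream tokens one by one
        out_tokens.append(top_id)
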


================================================
FILE: cosyvoice/tokenizer/__init__.py
================================================


================================================
FILE: cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
================================================
IQ== 0
Ig== 1
Iw== 2
JA== 3
JQ== 4
Jg== 5
Jw== 6
KA== 7
KQ== 8
Kg== 9
Kw== 10
LA== 11
LQ== 12
Lg== 13
Lw== 14
MA== 15
MQ== 16
Mg== 17
Mw== 18
NA== 19
NQ== 20
Ng== 21
Nw== 22
OA== 23
OQ== 24
Og== 25
Ow== 26
PA== 27
PQ== 28
Pg== 29
Pw== 30
QA== 31
QQ== 32
Qg== 33
Qw== 34
RA== 35
RQ== 36
Rg== 37
Rw== 38
SA== 39
SQ== 40
Sg== 41
Sw== 42
TA== 43
TQ== 44
Tg== 45
Tw== 46
UA== 47
UQ== 48
Ug== 49
Uw== 50
VA== 51
VQ== 52
Vg== 53
Vw== 54
WA== 55
WQ== 56
Wg== 57
Ww== 58
XA== 59
XQ== 60
Xg== 61
Xw== 62
YA== 63
YQ== 64
Yg== 65
Yw== 66
ZA== 67
ZQ== 68
Zg== 69
Zw== 70
aA== 71
aQ== 72
ag== 73
aw== 74
bA== 75
bQ== 76
bg== 77
bw== 78
cA== 79
cQ== 80
cg== 81
cw== 82
dA== 83
dQ== 84
dg== 85
dw== 86
eA== 87
eQ== 88
eg== 89
ew== 90
fA== 91
fQ== 92
fg== 93
oQ== 94
og== 95
ow== 96
pA== 97
pQ== 98
pg== 99
pw== 100
qA== 101
qQ== 102
qg== 103
qw== 104
rA== 105
rg== 106
rw== 107
sA== 108
sQ== 109
sg== 110
sw== 111
tA== 112
tQ== 113
tg== 114
tw== 115
uA== 116
uQ== 117
ug== 118
uw== 119
vA== 120
vQ== 121
vg== 122
vw== 123
wA== 124
wQ== 125
wg== 126
ww== 127
xA== 128
xQ== 129
xg== 130
xw== 131
yA== 132
yQ== 133
yg== 134
yw== 135
zA== 136
zQ== 137
zg== 138
zw== 139
0A== 140
0Q== 141
0g== 142
0w== 143
1A== 144
1Q== 145
1g== 146
1w== 147
2A== 148
2Q== 149
2g== 150
2w== 151
3A== 152
3Q== 153
3g== 154
3w== 155
4A== 156
4Q== 157
4g== 158
4w== 159
5A== 160
5Q== 161
5g== 162
5w== 163
6A== 164
6Q== 165
6g== 166
6w== 167
7A== 168
7Q== 169
7g== 170
7w== 171
8A== 172
8Q== 173
8g== 174
8w== 175
9A== 176
9Q== 177
9g== 178
9w== 179
+A== 180
+Q== 181
+g== 182
+w== 183
/A== 184
/Q== 185
/g== 186
/w== 187
AA== 188
AQ== 189
Ag== 190
Aw== 191
BA== 192
BQ== 193
Bg== 194
Bw== 195
CA== 196
CQ== 197
Cg== 198
Cw== 199
DA== 200
DQ== 201
Dg== 202
Dw== 203
EA== 204
EQ== 205
Eg== 206
Ew== 207
FA== 208
FQ== 209
Fg== 210
Fw== 211
GA== 212
GQ== 213
Gg== 214
Gw== 215
HA== 216
HQ== 217
Hg== 218
Hw== 219
IA== 220
fw== 221
gA== 222
gQ== 223
gg== 224
gw== 225
hA== 226
hQ== 227
hg== 228
hw== 229
iA== 230
iQ== 231
ig== 232
iw== 233
jA== 234
jQ== 235
jg== 236
jw== 237
kA== 238
kQ== 239
kg== 240
kw== 241
lA== 242
lQ== 243
lg== 244
lw== 245
mA== 246
mQ== 247
mg== 248
mw== 249
nA== 250
nQ== 251
ng== 252
nw== 253
oA== 254
rQ== 255
IHQ= 256
IGE= 257
IHRo 258
aW4= 259
ZXI= 260
IHc= 261
IHM= 262
b3U= 263
IHRoZQ== 264
cmU= 265
b24= 266
YXQ= 267
ZW4= 268
IGM= 269
aXQ= 270
aXM= 271
IGI= 272
bmQ= 273
IGQ= 274
IG0= 275
IGg= 276
IG8= 277
aW5n 278
ZXM= 279
IHA= 280
IHRv 281
YW4= 282
IGY= 283
b3I= 284
bGw= 285
IEk= 286
IGw= 287
IHk= 288
YXI= 289
IGc= 290
IHlvdQ== 291
ZWQ= 292
IGFuZA== 293
IGlu 294
IG9m 295
YXM= 296
IG4= 297
b20= 298
aWM= 299
IHRoYXQ= 300
dXM= 301
ZXQ= 302
dmU= 303
YWw= 304
b3c= 305
bGU= 306
IGlz 307
IGU= 308
IGl0 309
b3Q= 310
J3M= 311
IGJl 312
aW9u 313
IFQ= 314
IHdo 315
IEE= 316
ZW50 317
IFM= 318
IHJl 319
YXk= 320
IHdl 321
IG9u 322
ZXJl 323
IGhh 324
dXQ= 325
YWM= 326
aWQ= 327
aWc= 328
b3M= 329
a2U= 330
dmVy 331
aW0= 332
INA= 333
IFRo 334
YW0= 335
YWxs 336
IGZvcg== 337
ZWw= 338
Y2g= 339
cm8= 340
IHRoaXM= 341
IHN0 342
IFc= 343
IHU= 344
YWQ= 345
b3V0 346
aXI= 347
bGQ= 348
Y3Q= 349
IGs= 350
aWY= 351
IGdv 352
Li4= 353
0L4= 354
aXRo 355
bHk= 356
aHQ= 357
cXU= 358
IC0= 359
IGRv 360
IGo= 361
IGhhdmU= 362
IEI= 363
IGFu 364
IHdpdGg= 365
IGFyZQ== 366
IHI= 367
IGRl 368
IHNl 369
IHNv 370
IHY= 371
c3Q= 372
aWxs 373
dXI= 374
IGxp 375
IE0= 376
ZXN0 377
b2Q= 378
YWxseQ== 379
J3Q= 380
dXN0 381
IGFz 382
IEM= 383
Y2U= 384
IG1l 385
0LA= 386
0LU= 387
aWw= 388
IEg= 389
IHdhcw== 390
dGVy 391
dGg= 392
IGNhbg== 393
YW50 394
IGNvbQ== 395
b3Vy 396
aWdodA== 397
IFk= 398
YXRpb24= 399
IEFuZA== 400
b2w= 401
IHNo 402
0YI= 403
b3A= 404
c2U= 405
IG5vdA== 406
IFNv 407
IG5l 408
dW4= 409
IGFi 410
IGxpa2U= 411
IGF0 412
IEQ= 413
aWU= 414
IGhl 415
IGNvbg== 416
IGNo 417
b3Jl 418
IGFs 419
IG9y 420
IHF1 421
IE8= 422
b21l 423
cmE= 424
dWw= 425
IE4= 426
cHA= 427
IHlvdXI= 428
b3VsZA== 429
IFA= 430
IGZy 431
Z2U= 432
ZXJz 433
J3Jl 434
0Lg= 435
IHRoZXk= 436
IHdoYXQ= 437
dXNl 438
IGFsbA== 439
IFRoZQ== 440
IEw= 441
ZXNz 442
ZW0= 443
IGtu 444
IGp1c3Q= 445
YXJ0 446
IHBybw== 447
dmVyeQ== 448
dW0= 449
IGxv 450
IOw= 451
IG15 452
b2s= 453
IGV4 454
YWI= 455
IHRoZXJl 456
IGJ1dA== 457
IGtub3c= 458
IHN1 459
IEc= 460
0YE= 461
IEU= 462
IG1h 463
0L7Q 464
IGVu 465
IGFib3V0 466
IEl0 467
aXN0 468
IHdvcg== 469
cmk= 470
aW5k 471
IG9uZQ== 472
YXRl 473
YW5k 474
aW5r 475
IGxl 476
b3J0 477
J20= 478
IEY= 479
aWNo 480
0YA= 481
aWRl 482
IGdldA== 483
IG91dA== 484
Li4u 485
IHdpbGw= 486
44E= 487
aXZl 488
0L0= 489
IGZyb20= 490
YWlu 491
IFdl 492
IHVw 493
cGU= 494
cmVz 495
Y2E= 496
IFI= 497
IGlm 498
IHBs 499
IGRvbg== 500
YWNr 501
IDE= 502
ICI= 503
IHRy 504
IHVz 505
IFdo 506
aXR5 507
IEo= 508
IFlvdQ== 509
IGhlcmU= 510
aGVy 511
IHNvbWU= 512
b3Vn 513
YWs= 514
YXJk 515
IGdvaW5n 516
IHVu 517
bWVudA== 518
IHRoaW5r 519
IHBl 520
ZW5k 521
ICg= 522
Y2F1c2U= 523
IHRpbQ== 524
YXN0 525
w6k= 526
IG91cg== 527
IHdhbnQ= 528
YW1l 529
aWVz 530
IOs= 531
dWQ= 532
aW5l 533
IHJlYWxseQ== 534
IHRl 535
IHNlZQ== 536
Y2k= 537
IGJ5 538
c28= 539
dXJl 540
b3Nl 541
IFs= 542
YXJl 543
IG1vcmU= 544
YWg= 545
b25l 546
Y2s= 547
b3BsZQ== 548
0LDQ 549
IHRoZW4= 550
IHRoaW5n 551
IHRoZW0= 552
dmVu 553
b3VuZA== 554
b3N0 555
b25n 556
ZWN0 557
IHJpZ2h0 558
YWc= 559
IGludA== 560
IHBlb3BsZQ== 561
IHdoZW4= 562
b3Vz 563
cGw= 564
IHRpbWU= 565
IGlt 566
IHdobw== 567
IDI= 568
YXA= 569
IGJlY2F1c2U= 570
aGluZw== 571
IG5v 572
aWNl 573
IGxvb2s= 574
IGhhcw== 575
IHdvdWxk 576
IGhvdw== 577
YWN0 578
IGZl 579
bnQ= 580
b3VnaA== 581
IHBy 582
IEJ1dA== 583
IHNheQ== 584
0YM= 585
IG5vdw== 586
IG1hbg== 587
IHZlcnk= 588
IHdvcms= 589
aXo= 590
IEs= 591
aXY= 592
aXR0 593
IGFy 594
ZXA= 595
IGNs 596
IHdoaWNo 597
IGNv 598
YW5z 599
J3Zl 600
IHNh 601
ZmY= 602
J2xs 603
IGFueQ== 604
IGFjdA== 605
IHll 606
YmVy 607
YWNo 608
YWdl 609
cGVy 610
IGFsc28= 611
ZmVy 612
IHRoZXNl 613
IGFk 614
0LXQ 615
dGhlcg== 616
YWNl 617
aWNr 618
YWtl 619
cmVhdA== 620
aXJl 621
dWU= 622
IGFn 623
IFU= 624
dWNo 625
aW9ucw== 626
cnk= 627
MDA= 628
bmE= 629
IGRpZA== 630
IHF1ZQ== 631
IGhhZA== 632
IGV2ZXJ5 633
IEhl 634
IGxh 635
IHdheQ== 636
IHNw 637
Ymxl 638
IFRoaXM= 639
YXNz 640
IHRoZWly 641
aXRl 642
IG5lZWQ= 643
IHBhcnQ= 644
IHdlcmU= 645
IGJhY2s= 646
aXA= 647
b3du 648
b21ldA== 649
YmU= 650
YXNl 651
IG1ha2U= 652
aXJzdA== 653
aWE= 654
ZW5jZQ== 655
YW5n 656
YW5r 657
IGdvdA== 658
IHByZQ== 659
IGNvbnQ= 660
IG90aGVy 661
cHQ= 662
IFRoYXQ= 663
b2c= 664
IGdvb2Q= 665
IGludG8= 666
YWxr 667
IGJlZW4= 668
IGFt 669
IG92ZXI= 670
dWFsbHk= 671
IOI= 672
7J0= 673
IHVuZA== 674
aGU= 675
d2F5 676
IGdy 677
0Yw= 678
IGRpZg== 679
IHBlcg== 680
0Y8= 681
IElu 682
IHR3 683
b25k 684
YXJz 685
aW50 686
b3Jt 687
IGxvdA== 688
IHdoZXJl 689
IMM= 690
IFY= 691
IHNvbWV0 692
0Ls= 693
ZW5z 694
IGd1 695
IGFj 696
dWc= 697
0Ys= 698
xLE= 699
IGZpcnN0 700
cmVl 701
IGhpcw== 702
aXR0bGU= 703
IGltcA== 704
IG1v 705
YXY= 706
IGxpdHRsZQ== 707
IFdoYXQ= 708
IG11Y2g= 709
IHo= 710
IOo= 711
YWJsZQ== 712
INC/ 713
IHBv 714
IGNvbXA= 715
bmU= 716
IGRpcw== 717
IGxldA== 718
YW5jZQ== 719
IGhlcg== 720
IHRoaW5ncw== 721
IHN0YXJ0 722
dWx0 723
IGFwcA== 724
IHJlcw== 725
IGZv 726
IGNvdWxk 727
IGludGVy 728
IHRob3Nl 729
IGRlcw== 730
IHdlbGw= 731
IHR3bw== 732
IGtpbmQ= 733
eHQ= 734
cmVzcw== 735
ZWx5 736
w6Q= 737
IGJy 738
IHRocg== 739
INCy 740
IGk= 741
aXNo 742
IGRpZmZlcg== 743
IHJv 744
IFN0 745
IHNvbWV0aGluZw== 746
IHRha2U= 747
IGJv 748
eXM= 749
IHNoZQ== 750
IHRhbGs= 751
bG8= 752
0Yc= 753
IGV2ZW4= 754
0Lo= 755
44A= 756
INC9 757
IGJ1 758
IElm 759
IGRvd24= 760
IENo 761
YWRl 762
YXRpb25z 763
IHVzZQ== 764
b3Jk 765
IG9mZg== 766
IGFjdHVhbGx5 767
IHNwZQ== 768
ZHU= 769
YXRlZA== 770
YXRlcg== 771
b3Nz 772
bmluZw== 773
w7w= 774
IGRvZXM= 775
INGB 776
IG5ldw== 777
IGJldA== 778
dmVs 779
Y2Vzcw== 780
cGxl 781
IGhhcHA= 782
dGluZw== 783
b25uYQ== 784
IGVz 785
IGRheQ== 786
IG9ubHk= 787
aWdu 788
a2F5 789
c2Vs 790
ZW50cw== 791
b3VudA== 792
aWxk 793
aWxl 794
IHNj 795
IGhpbQ== 796
IGFnYWlu 797
dmluZw== 798
IGdvbm5h 799
IGNvbW0= 800
IGhlbA== 801
b3RoZXI= 802
IGtl 803
aWNhbA== 804
IDM= 805
IGVs 806
IHRocm91Z2g= 807
IGNvbWU= 808
YXJr 809
ZGF5 810
aWVy 811
w7M= 812
IHRoYW4= 813
IFRoZXk= 814
IG1heQ== 815
IHNlcg== 816
7ZU= 817
IGNhbGw= 818
IGRpZmZlcmVudA== 819
IHNob3VsZA== 820
IFRoZXJl 821
YXJ5 822
IE5vdw== 823
44I= 824
dGhpbmc= 825
d2U= 826
b3J5 827
ZnRlcg== 828
IHB1dA== 829
b3Jz 830
aWFs 831
64s= 832
IHVuZGVy 833
IGluYw== 834
IFll 835
dWI= 836
Zm9ybQ== 837
IHZpZGU= 838
4Lg= 839
dmVycw== 840
IGZlZWw= 841
w6E= 842
b2R5 843
ZnQ= 844
Zm9yZQ== 845
IGVt 846
Z2V0 847
IHNhaWQ= 848
aXRpb24= 849
IHJlYw== 850
aW91cw== 851
YXRjaA== 852
IHRyeQ== 853
IGhlbHA= 854
IHNob3c= 855
0LQ= 856
IGJpdA== 857
dWxs 858
0LI= 859
0YLQvg== 860
Z3I= 861
IHBsYXk= 862
aWZl 863
YWls 864
IFllYWg= 865
IHF1ZXN0 866
IG1hbnk= 867
IHBlcnM= 868
IGdyZWF0 869
w60= 870
IGVzdA== 871
bmc= 872
IOKZ 873
dHk= 874
bGE= 875
IE9o 876
INc= 877
4K4= 878
IEJl 879
YWR5 880
IG1vc3Q= 881
Y3Rpb24= 882
IE5v 883
IGRvaW5n 884
IGJlaW5n 885
IHRvbw== 886
Y2Vz 887
IGJs 888
LiI= 889
IHJlbQ== 890
aXNz 891
b25z 892
Pj4= 893
cnU= 894
d24= 895
b250 896
aWI= 897
ZWxs 898
IHNt 899
b3Ro 900
dWFs 901
ID4+ 902
IHBo 903
bGVz 904
b2M= 905
ZnVs 906
IHNlYw== 907
aXNl 908
IGFkZA== 909
aWdo 910
ZXJ0 911
IHNhbWU= 912
4oA= 913
IG1lYW4= 914
IGZpbmQ= 915
ZWs= 916
IGVuZA== 917
LS0= 918
0Lw= 919
IHN0aWxs 920
YXo= 921
ICc= 922
IG1pbg== 923
IHllYXJz 924
dXJu 925
IGFyb3VuZA== 926
c2VsZg== 927
IHdy 928
YnM= 929
b3VnaHQ= 930
IOKZqg== 931
IGZs 932
YW5nZQ== 933
IGFmdGVy 934
IHBvaW50 935
bWVy 936
dmVk 937
IGxvbmc= 938
b3k= 939
5Lg= 940
IGNy 941
d2F5cw== 942
IHN5 943
IHRyYQ== 944
IDIw 945
YXZl 946
IGNoZQ== 947
IGVudA== 948
IGJlZm9yZQ== 949
cGg= 950
IGF0dA== 951
aWFu 952
aWx5 953
IHBlcnNvbg== 954
IGJpZw== 955
IHNjaA== 956
IHJlYWw= 957
IG5leHQ= 958
IGxvdmU= 959
IHZpZGVv 960
IExldA== 961
IGZpbg== 962
IG1haw== 963
aWJsZQ== 964
IHRvZGF5 965
ZXJt 966
IEFs 967
b3dlcg== 968
YW5u 969
aXg= 970
IHBhcg== 971
IHN0dWQ= 972
w7Y= 973
IGltcG9ydA== 974
dGU= 975
IGdpdmU= 976
dmVz 977
IGRpZQ== 978
IGRlYw== 979
IHRlbGw= 980
INC6 981
0YHRgg== 982
IHdoeQ== 983
aWNhbGx5 984
aWN0 985
cmVk 986
IGJhcw== 987
IHN1cmU= 988
IGJlbA== 989
YXRpbmc= 990
IHRhaw== 991
IHNldA== 992
IGxpZmU= 993
IGRpZG4= 994
2Kc= 995
b2I= 996
dW5k 997
YXRo 998
IG9w 999
INC+ 1000
YWl0 1001
IHdvcmxk 1002
IHN1cHA= 1003
aW8= 1004
IGNvdXI= 1005
INC4 1006
d2FyZA== 1007
0LXQvQ== 1008
IGFsd2F5cw== 1009
dXA= 1010
IGhhbmQ= 1011
IEhvdw== 1012
Y2lhbA== 1013
IGNvbnM= 1014
INE= 1015
IGluZA== 1016
IDQ= 1017
IEFz 1018
IGZ1bg== 1019
amVjdA== 1020
IGltcG9ydGFudA== 1021
IHN1cg== 1022
ZXc= 1023
YXRlcw== 1024
IDU= 1025
IGRp 1026
IG1hZGU= 1027
IGlucw== 1028
IGFzaw== 1029
IGV0 1030
IG51bQ== 1031
IGNhcg== 1032
IE9rYXk= 1033
IHNpbQ== 1034
aWs= 1035
IGxhc3Q= 1036
IEdv 1037
IG11cw== 1038
IHJlbA== 1039
dWxhcg== 1040
tOw= 1041
IFdlbGw= 1042
cGVjdA== 1043
IFRoYW5r 1044
IHRocmVl 1045
w6M= 1046
44M= 1047
IGludg== 1048
IGdlbg== 1049
bGlj 1050
IGhhcHBlbg== 1051
64o= 1052
aWVu 1053
ZXZlcg== 1054
0L7Qsg== 1055
IHN0cg== 1056
IEFsbA== 1057
IGluc3Q= 1058
IOKA 1059
IGRlZg== 1060
IHNs 1061
IG1pZ2h0 1062
dW5n 1063
IHllYXI= 1064
IG93bg== 1065
IGtlZXA= 1066
Ym9keQ== 1067
ZGVy 1068
INGC 1069
INC0 1070
IGFub3RoZXI= 1071
IG1vZA== 1072
IGV2 1073
IGd1eXM= 1074
IGFibGU= 1075
w6Nv 1076
cXVl 1077
aWRlbnQ= 1078
IFllcw== 1079
IGl0cw== 1080
IHBsYWNl 1081
IHByb2R1 1082
YXJu 1083
INC8 1084
IHJlcA== 1085
IGV4cGVy 1086
IGZhbQ== 1087
aXRpZXM= 1088
aWZpYw== 1089
IGhpZ2g= 1090
aWVk 1091
b29s 1092
aWV3 1093
0LXRgg== 1094
cmVu 1095
IGRvbmU= 1096
IC4uLg== 1097
64qU 1098
c3RlbQ== 1099
IFNl 1100
IGJldHRlcg== 1101
Y29tZQ== 1102
IGRlbA== 1103
IHR5 1104
IHVt 1105
IGhv 1106
IEFu 1107
IG1vbg== 1108
aW5ncw== 1109
IHNr 1110
IG9i 1111
Y29t 1112
YmxlbQ== 1113
b3Bl 1114
c3RhbmQ= 1115
J2Q= 1116
bWVudHM= 1117
IGVsZQ== 1118
IElz 1119
IGRh 1120
IHJlZw== 1121
bGVhc2U= 1122
aWtl 1123
YWxz 1124
aXpl 1125
6rA= 1126
IGNhcmU= 1127
IG5ldmVy 1128
7J20 1129
ZXNl 1130
IG1ldA== 1131
b2xvZw== 1132
IFdoZW4= 1133
dWNr 1134
0LXRgA== 1135
IMOp 1136
IGRhdA== 1137
w6c= 1138
IGV4YW0= 1139
aWxpdHk= 1140
IGRldA== 1141
Y3Jp 1142
IHVzZWQ= 1143
IERv 1144
IHRyYW5z 1145
ZWc= 1146
dGVu 1147
0Y4= 1148
Y3Vz 1149
IHNlY29uZA== 1150
IGJlc3Q= 1151
IGhhcmQ= 1152
IGlkZQ== 1153
IHByb2JsZW0= 1154
6rM= 1155
IFVu 1156
0YU= 1157
IM4= 1158
IHdhdGNo 1159
IFNo 1160
YXR0ZXI= 1161
IHByZXQ= 1162
IGRlcg== 1163
IGNvdXJzZQ== 1164
xZ8= 1165
YXRpdmU= 1166
aWNz 1167
IHF1ZXN0aW9u 1168
dXRl 1169
7Jc= 1170
IEZvcg== 1171
YXRoZXI= 1172
IGNvbA== 1173
aWVuZA== 1174
IO0= 1175
IFo= 1176
IGRvZXNu 1177
YXJjaA== 1178
IGludGVyZXN0 1179
IHBvbA== 1180
IGNvcg== 1181
aWVuY2U= 1182
IHByZXM= 1183
IGVhY2g= 1184
IHN5c3RlbQ== 1185
IGZhY3Q= 1186
aWVs 1187
YWJseQ== 1188
IGVy 1189
IHJ1bg== 1190
IOyd 1191
IHRvcA== 1192
bmVy 1193
IHRob3VnaHQ= 1194
IGVhcw== 1195
aWVudA== 1196
IGNyZQ== 1197
0Yg= 1198
IGNvbW11bg== 1199
eWU= 1200
cmVhZHk= 1201
bGxvdw== 1202
IGV2ZXJ5dGhpbmc= 1203
b21t 1204
IG1lZA== 1205
mpQ= 1206
IGNvdW50 1207
aXRz 1208
IGNvbXBs 1209
aGlw 1210
2YQ= 1211
b29r 1212
IHRvZ2V0 1213
IHRvZ2V0aGVy 1214
YW1w 1215
IGdhbWU= 1216
IGFscmVhZHk= 1217
0LDQuw== 1218
IGNhbGxlZA== 1219
YWxl 1220
xYI= 1221
IE15 1222
IHVuZGVyc3RhbmQ= 1223
IGRy 1224
IG1vbQ== 1225
aXRlZA== 1226
0L7Quw== 1227
IHVzaW5n 1228
enk= 1229
IG51bWJlcg== 1230
44CB 1231
Y2Vk 1232
IGNsZQ== 1233
0L3Qvg== 1234
64uk 1235
aW5jZQ== 1236
IGxvb2tpbmc= 1237
IHByZXR0eQ== 1238
IHByb2I= 1239
IFNoZQ== 1240
IHZl 1241
IGdldHRpbmc= 1242
IHdlZWs= 1243
IGVmZg== 1244
dWZm 1245
YWly 1246
dWVz 1247
ZXJu 1248
IFE= 1249
b3Vw 1250
ZW50aW9u 1251
IHNpZGU= 1252
0L7QvA== 1253
IGZvcm0= 1254
IGJ1cw== 1255
IGFzcw== 1256
IGVk 1257
YXNvbg== 1258
d2Vlbg== 1259
4oCm 1260
IHR1cm4= 1261
IGN1cg== 1262
IGNvbGw= 1263
IGRpcmU= 1264
IEdvZA== 1265
IDEw 1266
IGVxdQ== 1267
INCx 1268
IG9wZW4= 1269
IHN1Y2g= 1270
aXJk 1271
0LDQug== 1272
IGVhcg== 1273
xJk= 1274
Z2Fu 1275
IHBhcnRpYw== 1276
IGZyaWVuZA== 1277
IGV4cA== 1278
IGV4dA== 1279
IGhvbWU= 1280
IHdhdGVy 1281
IE9u 1282
0YLRjA== 1283
b3Jr 1284
INC/0YA= 1285
IG1vdmU= 1286
bmVzcw== 1287
ZW5zZQ== 1288
aG8= 1289
IGNoYXI= 1290
Y28= 1291
aW5z 1292
IGJvdGg= 1293
IDE5 1294
IGdyYQ== 1295
IGJldHdlZW4= 1296
4bs= 1297
IOyV 1298
YXNo 1299
IFJl 1300
YWk= 1301
YWx0aA== 1302
dXJlcw== 1303
ZW1iZXI= 1304
IGF2 1305
IHZlcg== 1306
w6o= 1307
b25leQ== 1308
IHRoYW5r 1309
IG1heWJl 1310
dWM= 1311
aW1l 1312
6rOg 1313
IGF3YXk= 1314
IG5hbWU= 1315
b3VzZQ== 1316
IGFjYw== 1317
IG11c2lj 1318
IGNoYW5nZQ== 1319
IHBhc3M= 1320
Z2Vy 1321
IGJ1aWxk 1322
IHZhbA== 1323
aW5lc3M= 1324
YW55 1325
IGZldw== 1326
tOs= 1327
dGE= 1328
IGxpc3Q= 1329
w6U= 1330
IG9sZA== 1331
IOye 1332
IHNvcnQ= 1333
IG1lbQ== 1334
IGNh 1335
Y2VwdA== 1336
IGdlbmVy 1337
IHllYWg= 1338
IHdoaWxl 1339
IGFueXRoaW5n 1340
cmlj 1341
Z3JhbQ== 1342
IGVpbg== 1343
Y3k= 1344
dXJpbmc= 1345
IERl 1346
IHBvd2Vy 1347
IGNvbWluZw== 1348
IHdvcmQ= 1349
IC0t 1350
IGJlbGll 1351
IGZvdW5k 1352
dG8= 1353
0L8= 1354
IG1lYW5z 1355
IGluZm9ybQ== 1356
INg= 1357
INGH 1358
IHNtYWxs 1359
MDAw 1360
IGNhbWU= 1361
IO2V 1362
d2g= 1363
IHdvcmtpbmc= 1364
IGV4YW1wbGU= 1365
IHBvcw== 1366
IGRlcA== 1367
6rI= 1368
5Lo= 1369
b3Rl 1370
IGRlbQ== 1371
7Kc= 1372
dHM= 1373
IHZhcg== 1374
YXV0 1375
IHRyaQ== 1376
Y2hu 1377
IGhlYWQ= 1378
IHdob2xl 1379
15k= 1380
emU= 1381
IHRyeWluZw== 1382
IHRlbQ== 1383
IGNvdQ== 1384
ZXRz 1385
IDY= 1386
IGZpbA== 1387
dmVsb3A= 1388
IGNhc2U= 1389
4K8= 1390
IHByb2JhYmx5 1391
IG9rYXk= 1392
IHBsYW4= 1393
IHNpdA== 1394
IHNjaG9vbA== 1395
IFRoZW4= 1396
uOs= 1397
bWU= 1398
IHByb2Nlc3M= 1399
IGZhcg== 1400
IHJlYWQ= 1401
IHBvc3M= 1402
IGJyZQ== 1403
IHNvbA== 1404
aWNodA== 1405
IHN1cHBvcnQ= 1406
IFRv 1407
ZXJ0YWlu 1408
IHN0YXJ0ZWQ= 1409
IGNhcA== 1410
IGxlZnQ= 1411
IGRhdGE= 1412
IHRpbWVz 1413
0LXQuw== 1414
IHdhbnRlZA== 1415
0LDQvQ== 1416
IHRhbGtpbmc= 1417
IGlzdA== 1418
IGhhdmluZw== 1419
dW1w 1420
IGNvbnRpbg== 1421
IHN1Yg== 1422
INC3 1423
cHI= 1424
64uI 1425
aW5h 1426
xbw= 1427
IGNyZWF0 1428
b2Rl 1429
15U= 1430
5pg= 1431
ISE= 1432
IHRlcm0= 1433
aXNt 1434
0L7QtA== 1435
IEJlY2F1c2U= 1436
IHdlbnQ= 1437
aWRlcg== 1438
IHByb3Y= 1439
IGNoaWxk 1440
IGRlbg== 1441
IGxpZ2h0 1442
YnI= 1443
s9C+ 1444
b2g= 1445
IGJvb2s= 1446
INk= 1447
dXRpb24= 1448
IEp1c3Q= 1449
ZW5l 1450
IGZvdXI= 1451
IHZpcw== 1452
6rCA 1453
IGhvcGU= 1454
IG1ha2luZw== 1455
IExl 1456
7JU= 1457
IG9wcA== 1458
YXU= 1459
IG1vbmV5 1460
IHByb2dyYW0= 1461
w6g= 1462
IHN0YW5k 1463
SU4= 1464
IHNpZ24= 1465
IGxlYXJu 1466
w6A= 1467
IERvbg== 1468
IHRlYW0= 1469
INC90LA= 1470
bHVk 1471
IHJlc3Q= 1472
aWNlcw== 1473
5pw= 1474
INGA 1475
IGF1dA== 1476
IGxlYWQ= 1477
YXRpb25hbA== 1478
ZGU= 1479
Z3k= 1480
IG5pY2U= 1481
IGRhcw== 1482
IGRpc3Q= 1483
IGh1bQ== 1484
IE9uZQ== 1485
5og= 1486
IGNvbWVz 1487
IGpv 1488
IGNlbnQ= 1489
IGV4cGw= 1490
IG1hcms= 1491
cmVlbg== 1492
bGVk 1493
Z2lu 1494
7JqU 1495
IGxldmVs 1496
IGNvbmY= 1497
dXNo 1498
IGRldmVsb3A= 1499
IHRlc3Q= 1500
ZW5n 1501
dmlvdXM= 1502
YXR1cmU= 1503
0LXQvA== 1504
cmV0 1505
IGpl 1506
IHN0dWZm 1507
IGNsYXNz 1508
b3dz 1509
IOq3 1510
IHNp 1511
IGxlcw== 1512
cm9w 1513
55o= 1514
IHBvcg== 1515
IHdhcg== 1516
7JeQ 1517
IGV2ZXJ5b25l 1518
IGdl 1519
IGNoZWNr 1520
b3R0 1521
IHNpbmc= 1522
IGFydA== 1523
IGZvbGxvdw== 1524
IDIwMQ== 1525
IEZy 1526
YWlz 1527
7JY= 1528
zrE= 1529
5bA= 1530
IMOg 1531
aW1lcw== 1532
IHJldA== 1533
IGNoYW5n 1534
IHB1Yg== 1535
IGluZg== 1536
IHRlY2hu 1537
YWRh 1538
aXZlcw== 1539
IGJlaA== 1540
IGxvb2tz 1541
44CC 1542
0Lc= 1543
IFdoeQ== 1544
IGVub3VnaA== 1545
IGJyYQ== 1546
aXRjaA== 1547
5Ls= 1548
IGFkdg== 1549
0LE= 1550
IHdpdGhvdXQ= 1551
d2Vy 1552
bWVyaWM= 1553
ZGVu 1554
IGNvbXBsZXQ= 1555
IGlkZWE= 1556
dGVycw== 1557
b2Nr 1558
IGRlZmlu 1559
IGV2ZXI= 1560
IGds 1561
IG9uY2U= 1562
IGJyaW5n 1563
IHNheWluZw== 1564
IGFucw== 1565
IGhlYXI= 1566
bmVjdA== 1567
IGxlc3M= 1568
Z28= 1569
cmVhbQ== 1570
YWRv 1571
7J4= 1572
IG1pbmQ= 1573
ZW50ZQ== 1574
IGZ1bGw= 1575
IGJhZA== 1576
IHdvbQ== 1577
IHNvbWVvbmU= 1578
IGR1 1579
IHdvbg== 1580
IGNvbnRybw== 1581
b3J0dW4= 1582
IGhlYWx0aA== 1583
IGNobw== 1584
IEFy 1585
IGNvbmM= 1586
IGluZm9ybWF0aW9u 1587
IHN0b3A= 1588
YXR0 1589
YXRlbHk= 1590
5L0= 1591
IGdyb3Vw 1592
INGD 1593
IHF1aXRl 1594
IHJlc3A= 1595
RVI= 1596
dWdodA== 1597
6rg= 1598
bWFu 1599
aXplZA== 1600
IEJy 1601
IHJlbWVtYmVy 1602
IGZhbWlseQ== 1603
IGJ1c2luZXNz 1604
YXc= 1605
IHNwZWM= 1606
IGF1 1607
IE9y 1608
xIU= 1609
IHNlZW4= 1610
IGxhcg== 1611
IDc= 1612
Z2c= 1613
YmVycw== 1614
IGRyYQ== 1615
IG1vbnRo 1616
IHNheXM= 1617
IGlzcw== 1618
IGxpdmU= 1619
IGxpbmU= 1620
IG1vbWVudA== 1621
IGV4Yw== 1622
ZWxz 1623
IHNvdW5k 1624
IGNvb2w= 1625
IGxvYw== 1626
IGNlcnRhaW4= 1627
IGRyaQ== 1628
0L7Rgg== 1629
YW1lcw== 1630
IG11c3Q= 1631
bnk= 1632
0LjRgg== 1633
IGtpZA== 1634
IGluY2x1ZA== 1635
7J2E 1636
YXRvcg== 1637
xJ8= 1638
aGE= 1639
YXJlZA== 1640
IHNlZW0= 1641
0Lk= 1642
7IQ= 1643
IGVsc2U= 1644
IOyg 1645
aXJs 1646
IDg= 1647
IHZv 1648
IHF1ZXN0aW9ucw== 1649
aW5lcw== 1650
ZWU= 1651
w7xy 1652
IEFtZXJpYw== 1653
IHN0b3J5 1654
IHNlcnY= 1655
dmVybg== 1656
YWdlcw== 1657
bGFuZA== 1658
IOKAkw== 1659
ZXJh 1660
IENhbg== 1661
IHBvcA== 1662
ZXRoZXI= 1663
IG5h 1664
IG9yZGVy 1665
IG1ha2Vz 1666
IHNpbmNl 1667
Y29u 1668
Y3Rvcg== 1669
IHRob3VnaA== 1670
IHByb2R1Y3Q= 1671
0LvQuA== 1672
IGxlZw== 1673
IG1lZXQ= 1674
YWxm 1675
0YHRjw== 1676
dW5jaA== 1677
aXRlcg== 1678
b3Zl 1679
15XX 1680
aWV0 1681
0LDQvA== 1682
aXRhbA== 1683
IHN1cGVy 1684
bGluZw== 1685
IHBheQ== 1686
IHBhcmE= 1687
IGpvYg== 1688
IEhlcmU= 1689
IHN3 1690
a3M= 1691
cHRpb24= 1692
bWE= 1693
IGJlbGlldmU= 1694
rOs= 1695
IHdhaXQ= 1696
0L7QuQ== 1697
IHVudA== 1698
IHF1aWNr 1699
aHI= 1700
I
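The token listing above appears to be the contents of `cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken`, truncated by the preview. Files in this format hold one entry per line: a base64-encoded byte sequence, a space, and the token's integer rank. For example, `YWQ=` decodes to the bytes `ad` at rank 345, and `0L4=` decodes to the UTF-8 bytes of Cyrillic `о` at rank 354. The stdlib-only sketch below parses that format. `parse_tiktoken_ranks` is a hypothetical helper; it is not a function from this repository or from the tiktoken package.

```python
import base64
from typing import Dict


def parse_tiktoken_ranks(text: str) -> Dict[bytes, int]:
    """Map each decoded token (raw bytes) to its integer rank; blank lines are skipped."""
    ranks: Dict[bytes, int] = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        token_b64, rank = line.split()
        ranks[base64.b64decode(token_b64)] = int(rank)
    return ranks


# Three entries copied from the listing above.
sample = "YWQ= 345\nIGs= 350\n0L4= 354\n"
vocab = parse_tiktoken_ranks(sample)
print(vocab[b"ad"], vocab[b" k"], b"\xd0\xbe".decode("utf-8"))
```

The keys stay as `bytes` rather than `str` because some entries are fragments of a multi-byte UTF-8 character. For instance, `44E=` decodes to `e3 81`, which is not valid text on its own.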
gitextract_rz35ckig/

├── .github/
│   └── workflows/
│       └── publish.yml
├── .gitignore
├── README.md
├── __init__.py
├── cosyvoice/
│   ├── __init__.py
│   ├── bin/
│   │   ├── __init__.py
│   │   ├── average_model.py
│   │   ├── export_jit.py
│   │   ├── export_onnx.py
│   │   ├── export_trt.sh
│   │   ├── inference.py
│   │   └── train.py
│   ├── cli/
│   │   ├── __init__.py
│   │   ├── cosyvoice.py
│   │   ├── frontend.py
│   │   └── model.py
│   ├── dataset/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── processor.py
│   ├── flow/
│   │   ├── __init__.py
│   │   ├── decoder.py
│   │   ├── flow.py
│   │   ├── flow_matching.py
│   │   └── length_regulator.py
│   ├── hifigan/
│   │   ├── __init__.py
│   │   ├── discriminator.py
│   │   ├── f0_predictor.py
│   │   ├── generator.py
│   │   └── hifigan.py
│   ├── llm/
│   │   ├── __init__.py
│   │   └── llm.py
│   ├── tokenizer/
│   │   ├── __init__.py
│   │   ├── assets/
│   │   │   └── multilingual_zh_ja_yue_char_del.tiktoken
│   │   └── tokenizer.py
│   ├── transformer/
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── convolution.py
│   │   ├── decoder.py
│   │   ├── decoder_layer.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── subsampling.py
│   │   └── upsample_encoder.py
│   └── utils/
│       ├── __init__.py
│       ├── class_utils.py
│       ├── common.py
│       ├── executor.py
│       ├── file_utils.py
│       ├── frontend_utils.py
│       ├── losses.py
│       ├── mask.py
│       ├── scheduler.py
│       └── train_utils.py
├── downloadmodel.py
├── examples/
│   ├── CrossLingual.json
│   ├── Instruct2.json
│   └── ZeroShot.json
├── pyproject.toml
├── requirements.txt
└── third_party/
    ├── Matcha-TTS/
    │   ├── LICENSE
    │   ├── MANIFEST.in
    │   ├── Makefile
    │   ├── README.md
    │   ├── __init__.py
    │   ├── configs/
    │   │   ├── __init__.py
    │   │   ├── callbacks/
    │   │   │   ├── default.yaml
    │   │   │   ├── model_checkpoint.yaml
    │   │   │   ├── model_summary.yaml
    │   │   │   ├── none.yaml
    │   │   │   └── rich_progress_bar.yaml
    │   │   ├── data/
    │   │   │   ├── hi-fi_en-US_female.yaml
    │   │   │   ├── ljspeech.yaml
    │   │   │   └── vctk.yaml
    │   │   ├── debug/
    │   │   │   ├── default.yaml
    │   │   │   ├── fdr.yaml
    │   │   │   ├── limit.yaml
    │   │   │   ├── overfit.yaml
    │   │   │   └── profiler.yaml
    │   │   ├── eval.yaml
    │   │   ├── experiment/
    │   │   │   ├── hifi_dataset_piper_phonemizer.yaml
    │   │   │   ├── ljspeech.yaml
    │   │   │   ├── ljspeech_min_memory.yaml
    │   │   │   └── multispeaker.yaml
    │   │   ├── extras/
    │   │   │   └── default.yaml
    │   │   ├── hparams_search/
    │   │   │   └── mnist_optuna.yaml
    │   │   ├── hydra/
    │   │   │   └── default.yaml
    │   │   ├── local/
    │   │   │   └── .gitkeep
    │   │   ├── logger/
    │   │   │   ├── aim.yaml
    │   │   │   ├── comet.yaml
    │   │   │   ├── csv.yaml
    │   │   │   ├── many_loggers.yaml
    │   │   │   ├── mlflow.yaml
    │   │   │   ├── neptune.yaml
    │   │   │   ├── tensorboard.yaml
    │   │   │   └── wandb.yaml
    │   │   ├── model/
    │   │   │   ├── cfm/
    │   │   │   │   └── default.yaml
    │   │   │   ├── decoder/
    │   │   │   │   └── default.yaml
    │   │   │   ├── encoder/
    │   │   │   │   └── default.yaml
    │   │   │   ├── matcha.yaml
    │   │   │   └── optimizer/
    │   │   │       └── adam.yaml
    │   │   ├── paths/
    │   │   │   └── default.yaml
    │   │   ├── train.yaml
    │   │   └── trainer/
    │   │       ├── cpu.yaml
    │   │       ├── ddp.yaml
    │   │       ├── ddp_sim.yaml
    │   │       ├── default.yaml
    │   │       ├── gpu.yaml
    │   │       └── mps.yaml
    │   ├── matcha/
    │   │   ├── VERSION
    │   │   ├── __init__.py
    │   │   ├── app.py
    │   │   ├── cli.py
    │   │   ├── data/
    │   │   │   ├── __init__.py
    │   │   │   ├── components/
    │   │   │   │   └── __init__.py
    │   │   │   └── text_mel_datamodule.py
    │   │   ├── hifigan/
    │   │   │   ├── LICENSE
    │   │   │   ├── README.md
    │   │   │   ├── __init__.py
    │   │   │   ├── config.py
    │   │   │   ├── denoiser.py
    │   │   │   ├── env.py
    │   │   │   ├── meldataset.py
    │   │   │   ├── models.py
    │   │   │   └── xutils.py
    │   │   ├── models/
    │   │   │   ├── __init__.py
    │   │   │   ├── baselightningmodule.py
    │   │   │   ├── components/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── decoder.py
    │   │   │   │   ├── flow_matching.py
    │   │   │   │   ├── text_encoder.py
    │   │   │   │   └── transformer.py
    │   │   │   └── matcha_tts.py
    │   │   ├── onnx/
    │   │   │   ├── __init__.py
    │   │   │   ├── export.py
    │   │   │   └── infer.py
    │   │   ├── text/
    │   │   │   ├── __init__.py
    │   │   │   ├── cleaners.py
    │   │   │   ├── numbers.py
    │   │   │   └── symbols.py
    │   │   ├── train.py
    │   │   └── utils/
    │   │       ├── __init__.py
    │   │       ├── audio.py
    │   │       ├── generate_data_statistics.py
    │   │       ├── instantiators.py
    │   │       ├── logging_utils.py
    │   │       ├── model.py
    │   │       ├── monotonic_align/
    │   │       │   ├── __init__.py
    │   │       │   ├── core.pyx
    │   │       │   └── setup.py
    │   │       ├── pylogger.py
    │   │       ├── rich_utils.py
    │   │       └── utils.py
    │   ├── notebooks/
    │   │   └── .gitkeep
    │   ├── pyproject.toml
    │   ├── requirements.txt
    │   ├── scripts/
    │   │   └── schedule.sh
    │   ├── setup.py
    │   └── synthesis.ipynb
    └── __init__.py
SYMBOL INDEX (656 symbols across 71 files)

FILE: __init__.py
  function nt_load_wav (line 14) | def nt_load_wav(speech, sample_rate, target_sr):
  class NTCosyVoiceZeroShotSampler (line 22) | class NTCosyVoiceZeroShotSampler:
    method __init__ (line 23) | def __init__(self):
    method cosyvoice (line 27) | def cosyvoice(self):
    method INPUT_TYPES (line 34) | def INPUT_TYPES(s):
    method main_func (line 49) | def main_func(self, audio, speed, text, prompt_text):
  class NTCosyVoiceCrossLingualSampler (line 65) | class NTCosyVoiceCrossLingualSampler:
    method __init__ (line 66) | def __init__(self):
    method cosyvoice (line 70) | def cosyvoice(self):
    method INPUT_TYPES (line 77) | def INPUT_TYPES(s):
    method main_func (line 91) | def main_func(self, audio, speed, text):
  class NTCosyVoiceInstruct2Sampler (line 108) | class NTCosyVoiceInstruct2Sampler:
    method __init__ (line 109) | def __init__(self):
    method cosyvoice (line 113) | def cosyvoice(self):
    method INPUT_TYPES (line 120) | def INPUT_TYPES(s):
    method main_func (line 135) | def main_func(self, audio, speed, text, instruct):

FILE: cosyvoice/bin/average_model.py
  function get_args (line 24) | def get_args():
  function main (line 43) | def main():

FILE: cosyvoice/bin/export_jit.py
  function get_args (line 29) | def get_args():
  function main (line 40) | def main():

FILE: cosyvoice/bin/export_onnx.py
  function get_dummy_input (line 33) | def get_dummy_input(batch_size, seq_len, out_channels, device):
  function get_args (line 43) | def get_args():
  function main (line 54) | def main():

FILE: cosyvoice/bin/inference.py
  function get_args (line 30) | def get_args():
  function main (line 53) | def main():

FILE: cosyvoice/bin/train.py
  function get_args (line 39) | def get_args():
  function main (line 90) | def main():

FILE: cosyvoice/cli/cosyvoice.py
  class CosyVoice (line 26) | class CosyVoice:
    method __init__ (line 28) | def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
    method list_available_spks (line 59) | def list_available_spks(self):
    method inference_sft (line 63) | def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, tex...
    method inference_zero_shot (line 74) | def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k...
    method inference_cross_lingual (line 88) | def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=...
    method inference_instruct (line 99) | def inference_instruct(self, tts_text, spk_id, instruct_text, stream=F...
    method inference_vc (line 114) | def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=Fa...
  class CosyVoice2 (line 124) | class CosyVoice2(CosyVoice):
    method __init__ (line 126) | def __init__(self, model_dir, load_jit=False, load_onnx=False, load_tr...
    method inference_instruct (line 159) | def inference_instruct(self, *args, **kwargs):
    method inference_instruct2 (line 162) | def inference_instruct2(self, tts_text, instruct_text, prompt_speech_1...

FILE: cosyvoice/cli/frontend.py
  class CosyVoiceFrontEnd (line 37) | class CosyVoiceFrontEnd:
    method __init__ (line 39) | def __init__(self,
    method _extract_text_token (line 73) | def _extract_text_token(self, text):
    method _extract_speech_token (line 79) | def _extract_speech_token(self, speech):
    method _extract_spk_embedding (line 91) | def _extract_spk_embedding(self, speech):
    method _extract_speech_feat (line 102) | def _extract_speech_feat(self, speech):
    method text_normalize (line 108) | def text_normalize(self, text, split=True, text_frontend=True):
    method frontend_sft (line 135) | def frontend_sft(self, tts_text, spk_id):
    method frontend_zero_shot (line 141) | def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k,...
    method frontend_cross_lingual (line 161) | def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample...
    method frontend_instruct (line 170) | def frontend_instruct(self, tts_text, spk_id, instruct_text):
    method frontend_instruct2 (line 179) | def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16...
    method frontend_vc (line 185) | def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_r...

FILE: cosyvoice/cli/model.py
  class CosyVoiceModel (line 24) | class CosyVoiceModel:
    method __init__ (line 26) | def __init__(self,
    method load (line 61) | def load(self, llm_model, flow_model, hift_model):
    method load_jit (line 73) | def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder...
    method load_onnx (line 82) | def load_onnx(self, flow_decoder_estimator_model):
    method llm_job (line 91) | def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embe...
    method token2wav (line 105) | def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid,...
    method tts (line 145) | def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
    method vc (line 208) | def vc(self, source_speech_token, flow_prompt_speech_token, prompt_spe...
  class CosyVoice2Model (line 262) | class CosyVoice2Model:
    method __init__ (line 264) | def __init__(self,
    method load (line 290) | def load(self, llm_model, flow_model, hift_model):
    method load_jit (line 301) | def load_jit(self, flow_encoder_model):
    method load_onnx (line 305) | def load_onnx(self, flow_decoder_estimator_model):
    method load_trt (line 314) | def load_trt(self, flow_decoder_estimator_model):
    method llm_job (line 324) | def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embe...
    method token2wav (line 336) | def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid,...
    method tts (line 370) | def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),

FILE: cosyvoice/dataset/dataset.py
  class Processor (line 27) | class Processor(IterableDataset):
    method __init__ (line 29) | def __init__(self, source, f, *args, **kw):
    method set_epoch (line 36) | def set_epoch(self, epoch):
    method __iter__ (line 39) | def __iter__(self):
    method apply (line 47) | def apply(self, f):
  class DistributedSampler (line 52) | class DistributedSampler:
    method __init__ (line 54) | def __init__(self, shuffle=True, partition=True):
    method update (line 60) | def update(self):
    method set_epoch (line 80) | def set_epoch(self, epoch):
    method sample (line 83) | def sample(self, data):
  class DataList (line 108) | class DataList(IterableDataset):
    method __init__ (line 110) | def __init__(self, lists, shuffle=True, partition=True):
    method set_epoch (line 114) | def set_epoch(self, epoch):
    method __iter__ (line 117) | def __iter__(self):
  function Dataset (line 126) | def Dataset(data_list_file,
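
The index lists `DistributedSampler.sample(data)` without its body. WeNet-style pipelines, which this dataset code resembles, usually handle this step in two stages. First they shuffle the shard list with an epoch-seeded RNG, so every rank sees the same permutation. Then they split the list by strided slicing, first across distributed ranks and then across DataLoader workers. The sketch below shows that two-level split on plain indices. It is an assumption about the technique, not this repository's code, and `partition_indices` is a hypothetical name.

```python
import random
from typing import List


def partition_indices(num_items: int, rank: int, world_size: int,
                      worker_id: int, num_workers: int,
                      epoch: int = 0, shuffle: bool = True) -> List[int]:
    """Return the item indices that one DataLoader worker on one rank should read."""
    indices = list(range(num_items))
    if shuffle:
        # Seeding with the epoch gives every rank the same permutation.
        random.Random(epoch).shuffle(indices)
    # Split across ranks first, then across workers within that rank.
    indices = indices[rank::world_size]
    return indices[worker_id::num_workers]


# Every index is read exactly once across 2 ranks x 3 workers.
seen = [i for r in range(2) for w in range(3)
        for i in partition_indices(20, r, 2, w, 3, epoch=5)]
print(sorted(seen) == list(range(20)))
```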

FILE: cosyvoice/dataset/processor.py
  function parquet_opener (line 29) | def parquet_opener(data, mode='train', tts_data={}):
  function filter (line 59) | def filter(data,
  function resample (line 111) | def resample(data, resample_rate=22050, min_sample_rate=16000, mode='tra...
  function truncate (line 139) | def truncate(data, truncate_length=24576, mode='train'):
  function compute_fbank (line 160) | def compute_fbank(data,
  function compute_f0 (line 182) | def compute_f0(data, pitch_extractor, mode='train'):
  function parse_embedding (line 203) | def parse_embedding(data, normalize, mode='train'):
  function tokenize (line 221) | def tokenize(data, get_tokenizer, allowed_special, mode='train'):
  function shuffle (line 240) | def shuffle(data, shuffle_size=10000, mode='train'):
  function sort (line 264) | def sort(data, sort_size=500, mode='train'):
  function static_batch (line 292) | def static_batch(data, batch_size=16):
  function dynamic_batch (line 312) | def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
  function batch (line 341) | def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=...
  function padding (line 355) | def padding(data, use_spk_embedding, mode='train', gan=False):
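
The signature `dynamic_batch(data, max_frames_in_batch=12000, mode='train')` indicates batching by a frame budget rather than by a fixed sample count. WeNet-style pipelines commonly cap the padded size of a batch, `longest_item_frames * batch_size`, and start a new batch whenever adding the next sample would exceed that cap. The sketch below applies this rule to plain frame counts instead of feature tensors. `dynamic_batch_lengths` is a hypothetical name, and unlike some implementations it never emits an empty batch when a single item exceeds the budget on its own.

```python
from typing import Iterable, Iterator, List


def dynamic_batch_lengths(lengths: Iterable[int],
                          max_frames_in_batch: int = 12000) -> Iterator[List[int]]:
    """Group frame counts so that (longest item * batch size) stays within budget."""
    buf: List[int] = []
    longest = 0
    for n in lengths:
        new_longest = max(longest, n)
        # Cost of the batch once every item is padded to the longest one.
        padded_frames = new_longest * (len(buf) + 1)
        if buf and padded_frames > max_frames_in_batch:
            yield buf
            buf, longest = [n], n
        else:
            buf.append(n)
            longest = new_longest
    if buf:
        yield buf


print(list(dynamic_batch_lengths([100, 200, 300, 400], max_frames_in_batch=600)))
```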

FILE: cosyvoice/flow/decoder.py
  class Transpose (line 24) | class Transpose(torch.nn.Module):
    method __init__ (line 25) | def __init__(self, dim0: int, dim1: int):
    method forward (line 30) | def forward(self, x: torch.Tensor):
  class CausalBlock1D (line 35) | class CausalBlock1D(Block1D):
    method __init__ (line 36) | def __init__(self, dim: int, dim_out: int):
    method forward (line 46) | def forward(self, x: torch.Tensor, mask: torch.Tensor):
  class CausalResnetBlock1D (line 51) | class CausalResnetBlock1D(ResnetBlock1D):
    method __init__ (line 52) | def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: ...
  class CausalConv1d (line 58) | class CausalConv1d(torch.nn.Conv1d):
    method __init__ (line 59) | def __init__(
    method forward (line 81) | def forward(self, x: torch.Tensor):
  class ConditionalDecoder (line 87) | class ConditionalDecoder(nn.Module):
    method __init__ (line 88) | def __init__(
    method initialize_weights (line 203) | def initialize_weights(self):
    method forward (line 217) | def forward(self, x, mask, mu, t, spks=None, cond=None):
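
`CausalConv1d` subclasses `torch.nn.Conv1d`. A standard way to make a 1-D convolution causal is to pad only on the left, by `(kernel_size - 1) * dilation` frames. With this padding, output frame `t` depends only on input frames up to `t`, and the output length matches the input length. The dependency-free sketch below illustrates the rule on a single channel and assumes zero padding; the index does not show how this repository actually pads. Like `torch.nn.Conv1d`, it computes cross-correlation. `causal_conv1d` is a hypothetical name.

```python
from typing import List


def causal_conv1d(x: List[float], kernel: List[float], dilation: int = 1) -> List[float]:
    """Single-channel causal convolution (cross-correlation, as in torch.nn.Conv1d)."""
    k = len(kernel)
    pad = (k - 1) * dilation
    padded = [0.0] * pad + list(x)  # left padding only, so no future frames leak in
    out: List[float] = []
    for t in range(len(x)):
        # Output t reads padded[t], padded[t + dilation], ..., padded[t + pad],
        # which correspond to original frames t - pad, ..., t.
        out.append(sum(kernel[i] * padded[t + i * dilation] for i in range(k)))
    return out


print(causal_conv1d([1.0, 2.0, 3.0, 4.0], [1.0, 1.0]))
```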

FILE: cosyvoice/flow/flow.py
  class MaskedDiffWithXvec (line 24) | class MaskedDiffWithXvec(torch.nn.Module):
    method __init__ (line 25) | def __init__(self,
    method forward (line 60) | def forward(
    method inference (line 105) | def inference(self,
  class CausalMaskedDiffWithXvec (line 151) | class CausalMaskedDiffWithXvec(torch.nn.Module):
    method __init__ (line 152) | def __init__(self,
    method inference (line 190) | def inference(self,

FILE: cosyvoice/flow/flow_matching.py
  class ConditionalCFM (line 20) | class ConditionalCFM(BASECFM):
    method __init__ (line 21) | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, ...
    method forward (line 36) | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, c...
    method solve_euler (line 70) | def solve_euler(self, x, t_span, mu, mask, spks, cond):
    method forward_estimator (line 130) | def forward_estimator(self, x, mask, mu, t, spks, cond):
    method compute_loss (line 161) | def compute_loss(self, x1, mask, mu, spks=None, cond=None):
  class CausalConditionalCFM (line 203) | class CausalConditionalCFM(ConditionalCFM):
    method __init__ (line 204) | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, ...
    method forward (line 209) | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, c...
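
The signature `ConditionalCFM.solve_euler(x, t_span, mu, mask, spks, cond)` suggests that inference integrates the learned flow-matching vector field with fixed-step Euler updates over the time grid `t_span`, using `x <- x + (t_next - t) * v(x, t)`. The sketch below shows only that update loop. In the real model the estimator is a network conditioned on `mu`, `mask`, `spks` and `cond`, and it may apply classifier-free guidance; here all of that is replaced by a plain Python callable. `euler_integrate` is a hypothetical name.

```python
from typing import Callable, List

Vector = List[float]


def euler_integrate(x0: Vector, t_span: List[float],
                    velocity: Callable[[Vector, float], Vector]) -> Vector:
    """Integrate dx/dt = velocity(x, t) across consecutive points of t_span."""
    x = list(x0)
    for t, t_next in zip(t_span[:-1], t_span[1:]):
        dt = t_next - t
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# A straight-line path from x0 to x1 has the constant velocity x1 - x0, so
# Euler integration from t = 0 to t = 1 should land on x1.
x0, x1 = [0.0, 1.0], [2.0, -1.0]
const_v = [b - a for a, b in zip(x0, x1)]
print(euler_integrate(x0, [0.0, 0.25, 0.5, 0.75, 1.0], lambda x, t: const_v))
```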

FILE: cosyvoice/flow/length_regulator.py
  class InterpolateRegulator (line 21) | class InterpolateRegulator(nn.Module):
    method __init__ (line 22) | def __init__(
    method forward (line 44) | def forward(self, x, ylens=None):
    method inference (line 52) | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
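
`InterpolateRegulator.forward(x, ylens)` stretches token-rate features to a target frame length. Assume it does this with linear interpolation along time, as `torch.nn.functional.interpolate(..., mode='linear')` does with its default `align_corners=False`. Under that assumption, the per-channel arithmetic is shown in the pure-Python sketch below, where `linear_resample` is a hypothetical name.

```python
from typing import List


def linear_resample(x: List[float], out_len: int) -> List[float]:
    """Resample one channel to out_len points (half-pixel centers, align_corners=False)."""
    in_len = len(x)
    scale = in_len / out_len
    out: List[float] = []
    for j in range(out_len):
        # Map the output sample center back into input coordinates, clamped at 0.
        src = max((j + 0.5) * scale - 0.5, 0.0)
        i0 = min(int(src), in_len - 1)
        i1 = min(i0 + 1, in_len - 1)
        w = src - i0
        out.append((1.0 - w) * x[i0] + w * x[i1])
    return out


print(linear_resample([0.0, 10.0], 4))
```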

FILE: cosyvoice/hifigan/discriminator.py
  class MultipleDiscriminator (line 9) | class MultipleDiscriminator(nn.Module):
    method __init__ (line 10) | def __init__(
    method forward (line 17) | def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
  class MultiResolutionDiscriminator (line 32) | class MultiResolutionDiscriminator(nn.Module):
    method __init__ (line 33) | def __init__(
    method forward (line 53) | def forward(
  class DiscriminatorR (line 72) | class DiscriminatorR(nn.Module):
    method __init__ (line 73) | def __init__(
    method spectrogram (line 107) | def spectrogram(self, x):
    method forward (line 119) | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = N...

FILE: cosyvoice/hifigan/f0_predictor.py
  class ConvRNNF0Predictor (line 19) | class ConvRNNF0Predictor(nn.Module):
    method __init__ (line 20) | def __init__(self,
    method forward (line 52) | def forward(self, x: torch.Tensor) -> torch.Tensor:

FILE: cosyvoice/hifigan/generator.py
  class ResBlock (line 43) | class ResBlock(torch.nn.Module):
    method __init__ (line 45) | def __init__(
    method forward (line 91) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method remove_weight_norm (line 100) | def remove_weight_norm(self):
  class SineGen (line 106) | class SineGen(torch.nn.Module):
    method __init__ (line 122) | def __init__(self, samp_rate, harmonic_num=0,
    method _f02uv (line 132) | def _f02uv(self, f0):
    method forward (line 138) | def forward(self, f0):
  class SourceModuleHnNSF (line 171) | class SourceModuleHnNSF(torch.nn.Module):
    method __init__ (line 189) | def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine...
    method forward (line 204) | def forward(self, x):
  class HiFTGenerator (line 223) | class HiFTGenerator(nn.Module):
    method __init__ (line 228) | def __init__(
    method remove_weight_norm (line 319) | def remove_weight_norm(self):
    method _stft (line 333) | def _stft(self, x):
    method _istft (line 341) | def _istft(self, magnitude, phase):
    method decode (line 349) | def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, ...
    method forward (line 383) | def forward(
    method inference (line 400) | def inference(self, speech_feat: torch.Tensor, cache_source: torch.Ten...

FILE: cosyvoice/hifigan/hifigan.py
  class HiFiGan (line 9) | class HiFiGan(nn.Module):
    method __init__ (line 10) | def __init__(self, generator, discriminator, mel_spec_transform,
    method forward (line 22) | def forward(
    method forward_generator (line 32) | def forward_generator(self, batch, device):
    method forward_discriminator (line 53) | def forward_discriminator(self, batch, device):

FILE: cosyvoice/llm/llm.py
  class TransformerLM (line 25) | class TransformerLM(torch.nn.Module):
    method __init__ (line 26) | def __init__(
    method encode (line 71) | def encode(
    method pad_unpad_sequence (line 81) | def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_...
    method forward (line 90) | def forward(
    method sampling_ids (line 140) | def sampling_ids(
    method inference (line 154) | def inference(
  class Qwen2Encoder (line 219) | class Qwen2Encoder(torch.nn.Module):
    method __init__ (line 220) | def __init__(self, pretrain_path):
    method forward_one_step (line 224) | def forward_one_step(self, xs, masks, cache=None):
  class Qwen2LM (line 239) | class Qwen2LM(torch.nn.Module):
    method __init__ (line 240) | def __init__(
    method sampling_ids (line 276) | def sampling_ids(
    method inference (line 294) | def inference(

FILE: cosyvoice/tokenizer/tokenizer.py
  function get_encoding (line 170) | def get_encoding(name: str = "gpt2", num_languages: int = 99):
  function get_tokenizer (line 210) | def get_tokenizer(
  class QwenTokenizer (line 241) | class QwenTokenizer():
    method __init__ (line 242) | def __init__(self, token_path, skip_special_tokens=True):
    method encode (line 263) | def encode(self, text, **kwargs):
    method decode (line 268) | def decode(self, tokens):
  function get_qwen_tokenizer (line 275) | def get_qwen_tokenizer(

FILE: cosyvoice/transformer/activation.py
  class Swish (line 24) | class Swish(torch.nn.Module):
    method forward (line 27) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Snake (line 34) | class Snake(nn.Module):
    method __init__ (line 50) | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha...
    method forward (line 73) | def forward(self, x):

FILE: cosyvoice/transformer/attention.py
  class MultiHeadedAttention (line 26) | class MultiHeadedAttention(nn.Module):
    method __init__ (line 36) | def __init__(self,
    method forward_qkv (line 53) | def forward_qkv(
    method forward_attention (line 82) | def forward_attention(
    method forward (line 129) | def forward(
  class RelPositionMultiHeadedAttention (line 200) | class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    method __init__ (line 209) | def __init__(self,
    method rel_shift (line 225) | def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
    method forward (line 249) | def forward(

FILE: cosyvoice/transformer/convolution.py
  class ConvolutionModule (line 24) | class ConvolutionModule(nn.Module):
    method __init__ (line 27) | def __init__(self,
    method forward (line 90) | def forward(

FILE: cosyvoice/transformer/decoder.py
  class TransformerDecoder (line 33) | class TransformerDecoder(torch.nn.Module):
    method __init__ (line 58) | def __init__(
    method forward (line 116) | def forward(
    method forward_layers (line 169) | def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
    method forward_layers_checkpointed (line 178) | def forward_layers_checkpointed(self, x: torch.Tensor,
    method forward_one_step (line 187) | def forward_one_step(
    method tie_or_clone_weights (line 230) | def tie_or_clone_weights(self, jit_mode: bool = True):
  class BiTransformerDecoder (line 256) | class BiTransformerDecoder(torch.nn.Module):
    method __init__ (line 276) | def __init__(
    method forward (line 332) | def forward(
    method forward_one_step (line 367) | def forward_one_step(
    method tie_or_clone_weights (line 392) | def tie_or_clone_weights(self, jit_mode: bool = True):

FILE: cosyvoice/transformer/decoder_layer.py
  class DecoderLayer (line 22) | class DecoderLayer(nn.Module):
    method __init__ (line 41) | def __init__(
    method forward (line 62) | def forward(

FILE: cosyvoice/transformer/embedding.py
  class PositionalEncoding (line 26) | class PositionalEncoding(torch.nn.Module):
    method __init__ (line 37) | def __init__(self,
    method forward (line 59) | def forward(self,
    method position_encoding (line 79) | def position_encoding(self,
  class RelPositionalEncoding (line 120) | class RelPositionalEncoding(PositionalEncoding):
    method __init__ (line 129) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5...
    method forward (line 133) | def forward(self,
  class WhisperPositionalEncoding (line 150) | class WhisperPositionalEncoding(PositionalEncoding):
    method __init__ (line 154) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1...
  class LearnablePositionalEncoding (line 167) | class LearnablePositionalEncoding(PositionalEncoding):
    method __init__ (line 171) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 4...
  class NoPositionalEncoding (line 178) | class NoPositionalEncoding(torch.nn.Module):
    method __init__ (line 182) | def __init__(self, d_model: int, dropout_rate: float):
    method forward (line 187) | def forward(self,
    method position_encoding (line 196) | def position_encoding(self, offset: Union[int, torch.Tensor],
  class EspnetRelPositionalEncoding (line 201) | class EspnetRelPositionalEncoding(torch.nn.Module):
    method __init__ (line 215) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5...
    method extend_pe (line 224) | def extend_pe(self, x: torch.Tensor):
    method forward (line 256) | def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = ...
    method position_encoding (line 272) | def position_encoding(self,

FILE: cosyvoice/transformer/encoder.py
  class BaseEncoder (line 37) | class BaseEncoder(torch.nn.Module):
    method __init__ (line 39) | def __init__(
    method output_size (line 108) | def output_size(self) -> int:
    method forward (line 111) | def forward(
    method forward_layers (line 165) | def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
    method forward_layers_checkpointed (line 173) | def forward_layers_checkpointed(self, xs: torch.Tensor,
    method forward_chunk (line 184) | def forward_chunk(
    method forward_chunk_by_chunk (line 275) | def forward_chunk_by_chunk(
  class TransformerEncoder (line 338) | class TransformerEncoder(BaseEncoder):
    method __init__ (line 341) | def __init__(
  class ConformerEncoder (line 387) | class ConformerEncoder(BaseEncoder):
    method __init__ (line 390) | def __init__(

FILE: cosyvoice/transformer/encoder_layer.py
  class TransformerEncoderLayer (line 24) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 40) | def __init__(
    method forward (line 58) | def forward(
  class ConformerEncoderLayer (line 109) | class ConformerEncoderLayer(nn.Module):
    method __init__ (line 129) | def __init__(
    method forward (line 160) | def forward(

FILE: cosyvoice/transformer/label_smoothing_loss.py
  class LabelSmoothingLoss (line 21) | class LabelSmoothingLoss(nn.Module):
    method __init__ (line 54) | def __init__(self,
    method forward (line 68) | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

FILE: cosyvoice/transformer/positionwise_feed_forward.py
  class PositionwiseFeedForward (line 20) | class PositionwiseFeedForward(torch.nn.Module):
    method __init__ (line 33) | def __init__(
    method forward (line 47) | def forward(self, xs: torch.Tensor) -> torch.Tensor:
  class MoEFFNLayer (line 58) | class MoEFFNLayer(torch.nn.Module):
    method __init__ (line 75) | def __init__(
    method forward (line 91) | def forward(self, xs: torch.Tensor) -> torch.Tensor:

FILE: cosyvoice/transformer/subsampling.py
  class BaseSubsampling (line 23) | class BaseSubsampling(torch.nn.Module):
    method __init__ (line 25) | def __init__(self):
    method position_encoding (line 30) | def position_encoding(self, offset: Union[int, torch.Tensor],
  class EmbedinigNoSubsampling (line 35) | class EmbedinigNoSubsampling(BaseSubsampling):
    method __init__ (line 39) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 45) | def forward(
  class LinearNoSubsampling (line 69) | class LinearNoSubsampling(BaseSubsampling):
    method __init__ (line 79) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 92) | def forward(
  class Conv1dSubsampling2 (line 116) | class Conv1dSubsampling2(BaseSubsampling):
    method __init__ (line 128) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 145) | def forward(
  class Conv2dSubsampling4 (line 173) | class Conv2dSubsampling4(BaseSubsampling):
    method __init__ (line 183) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 202) | def forward(
  class Conv2dSubsampling6 (line 230) | class Conv2dSubsampling6(BaseSubsampling):
    method __init__ (line 239) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 256) | def forward(
  class Conv2dSubsampling8 (line 282) | class Conv2dSubsampling8(BaseSubsampling):
    method __init__ (line 292) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 311) | def forward(
  class LegacyLinearNoSubsampling (line 338) | class LegacyLinearNoSubsampling(BaseSubsampling):
    method __init__ (line 348) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 362) | def forward(

FILE: cosyvoice/transformer/upsample_encoder.py
  class Upsample1D (line 37) | class Upsample1D(nn.Module):
    method __init__ (line 51) | def __init__(self, channels: int, out_channels: int, stride: int = 2):
    method forward (line 59) | def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
  class PreLookaheadLayer (line 66) | class PreLookaheadLayer(nn.Module):
    method __init__ (line 67) | def __init__(self, channels: int, pre_lookahead_len: int = 1):
    method forward (line 81) | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  class UpsampleConformerEncoder (line 99) | class UpsampleConformerEncoder(torch.nn.Module):
    method __init__ (line 101) | def __init__(
    method output_size (line 234) | def output_size(self) -> int:
    method forward (line 237) | def forward(
    method forward_layers (line 306) | def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
    method forward_up_layers (line 313) | def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,

FILE: cosyvoice/utils/class_utils.py
  function get_model_type (line 77) | def get_model_type(configs):

FILE: cosyvoice/utils/common.py
  function pad_list (line 27) | def pad_list(xs: List[torch.Tensor], pad_value: int):
  function th_accuracy (line 76) | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
  function get_padding (line 98) | def get_padding(kernel_size, dilation=1):
  function init_weights (line 102) | def init_weights(m, mean=0.0, std=0.01):
  function ras_sampling (line 109) | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, t...
  function nucleus_sampling (line 117) | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
  function random_sampling (line 135) | def random_sampling(weighted_scores, decoded_tokens, sampling):
  function fade_in_out (line 140) | def fade_in_out(fade_in_mel, fade_out_mel, window):
  function set_all_random_seed (line 151) | def set_all_random_seed(seed):
  function mask_to_bias (line 158) | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:

FILE: cosyvoice/utils/executor.py
  class Executor (line 26) | class Executor:
    method __init__ (line 28) | def __init__(self, gan: bool = False):
    method train_one_epoc (line 35) | def train_one_epoc(self, model, optimizer, scheduler, train_data_loade...
    method train_one_epoc_gan (line 84) | def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d,...
    method cv (line 143) | def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=Tr...

FILE: cosyvoice/utils/file_utils.py
  function read_lists (line 24) | def read_lists(list_file):
  function read_json_lists (line 32) | def read_json_lists(list_file):
  function load_wav (line 41) | def load_wav(wav, target_sr):

FILE: cosyvoice/utils/frontend_utils.py
  function contains_chinese (line 21) | def contains_chinese(text):
  function replace_corner_mark (line 26) | def replace_corner_mark(text):
  function remove_bracket (line 33) | def remove_bracket(text):
  function spell_out_number (line 42) | def spell_out_number(text: str, inflect_parser):
  function split_paragraph (line 65) | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, toke...
  function replace_blank (line 121) | def replace_blank(text: str):
  function is_only_punctuation (line 133) | def is_only_punctuation(text):

FILE: cosyvoice/utils/losses.py
  function tpr_loss (line 5) | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
  function mel_loss (line 14) | def mel_loss(real_speech, generated_speech, mel_transforms):

FILE: cosyvoice/utils/mask.py
  function subsequent_mask (line 53) | def subsequent_mask(
  function subsequent_chunk_mask (line 89) | def subsequent_chunk_mask(
  function add_optional_chunk_mask (line 127) | def add_optional_chunk_mask(xs: torch.Tensor,
  function make_pad_mask (line 201) | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:

FILE: cosyvoice/utils/scheduler.py
  class WarmupLR (line 27) | class WarmupLR(_LRScheduler):
    method __init__ (line 44) | def __init__(
    method __repr__ (line 56) | def __repr__(self):
    method get_lr (line 59) | def get_lr(self):
    method set_step (line 70) | def set_step(self, step: int):
  class WarmupPolicy (line 74) | class WarmupPolicy(_LRScheduler):
    method __init__ (line 84) | def __init__(self,
    method get_lr (line 110) | def get_lr(self):
    method _get_warmup_lr (line 128) | def _get_warmup_lr(self, step):
    method _get_lr (line 132) | def _get_lr(self, step):
  class SquareRootConstantPolicy (line 137) | class SquareRootConstantPolicy(_LRScheduler):
    method __init__ (line 147) | def __init__(self,
    method get_lr (line 175) | def get_lr(self):
    method _get_lr (line 193) | def _get_lr(self, step):
  class WarmupHoldPolicy (line 198) | class WarmupHoldPolicy(WarmupPolicy):
    method __init__ (line 212) | def __init__(
    method get_lr (line 257) | def get_lr(self):
  class WarmupAnnealHoldPolicy (line 282) | class WarmupAnnealHoldPolicy(_LRScheduler):
    method __init__ (line 295) | def __init__(
    method get_lr (line 340) | def get_lr(self):
    method _get_warmup_lr (line 365) | def _get_warmup_lr(self, step):
    method _get_constant_lr (line 369) | def _get_constant_lr(self, step):
    method _get_lr (line 372) | def _get_lr(self, step):
  function _squareroot_annealing (line 377) | def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
  function _square_annealing (line 384) | def _square_annealing(initial_lr, step, max_steps, min_lr):
  function _cosine_annealing (line 391) | def _cosine_annealing(initial_lr, step, max_steps, min_lr):
  function _linear_warmup_with_cosine_annealing (line 397) | def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
  function _poly_decay (line 421) | def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
  function _noam_hold_annealing (line 433) | def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
  class SquareAnnealing (line 444) | class SquareAnnealing(WarmupPolicy):
    method __init__ (line 446) | def __init__(self,
    method _get_lr (line 459) | def _get_lr(self, step):
  class SquareRootAnnealing (line 471) | class SquareRootAnnealing(WarmupPolicy):
    method __init__ (line 473) | def __init__(self,
    method _get_lr (line 486) | def _get_lr(self, step):
  class CosineAnnealing (line 497) | class CosineAnnealing(WarmupAnnealHoldPolicy):
    method __init__ (line 499) | def __init__(self,
    method _get_lr (line 512) | def _get_lr(self, step):
    method _get_warmup_lr (line 532) | def _get_warmup_lr(self, step):
    method _get_constant_lr (line 539) | def _get_constant_lr(self, step):
    method _get_linear_warmup_with_cosine_annealing_lr (line 543) | def _get_linear_warmup_with_cosine_annealing_lr(self, step):
  class NoamAnnealing (line 558) | class NoamAnnealing(_LRScheduler):
    method __init__ (line 560) | def __init__(self,
    method get_lr (line 588) | def get_lr(self):
    method _noam_annealing (line 610) | def _noam_annealing(self, initial_lr, step):
  class NoamHoldAnnealing (line 623) | class NoamHoldAnnealing(WarmupHoldPolicy):
    method __init__ (line 625) | def __init__(self,
    method _get_lr (line 693) | def _get_lr(self, step):
    method set_step (line 715) | def set_step(self, step: int):
  class ConstantLR (line 719) | class ConstantLR(_LRScheduler):
    method __init__ (line 726) | def __init__(
    method get_lr (line 734) | def get_lr(self):
    method set_step (line 737) | def set_step(self, step: int):

FILE: cosyvoice/utils/train_utils.py
  function init_distributed (line 39) | def init_distributed(args):
  function init_dataset_and_dataloader (line 53) | def init_dataset_and_dataloader(args, configs, gan):
  function check_modify_and_save_config (line 72) | def check_modify_and_save_config(args, configs):
  function wrap_cuda_model (line 94) | def wrap_cuda_model(args, model):
  function init_optimizer_and_scheduler (line 111) | def init_optimizer_and_scheduler(args, configs, model, gan):
  function init_summarywriter (line 187) | def init_summarywriter(args):
  function save_model (line 195) | def save_model(model, model_name, info_dict):
  function cosyvoice_join (line 217) | def cosyvoice_join(group_join, info_dict):
  function batch_forward (line 238) | def batch_forward(model, batch, scaler, info_dict):
  function batch_backward (line 259) | def batch_backward(model, scaler, info_dict):
  function update_parameter_and_lr (line 273) | def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_di...
  function log_per_step (line 301) | def log_per_step(writer, info_dict):
  function log_per_save (line 330) | def log_per_save(writer, info_dict):

FILE: third_party/Matcha-TTS/matcha/app.py
  function MATCHA_TTS_LOC (line 33) | def MATCHA_TTS_LOC(x):
  function VOCODER_LOC (line 37) | def VOCODER_LOC(x):
  function load_model (line 66) | def load_model(model_name, vocoder_name):
  function load_model_ui (line 72) | def load_model_ui(model_type, textbox):
  function process_text_gradio (line 102) | def process_text_gradio(text):
  function synthesise_mel (line 108) | def synthesise_mel(text, text_length, n_timesteps, temperature, length_s...
  function multispeaker_example_cacher (line 125) | def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scal...
  function ljspeech_example_cacher (line 137) | def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, s...
  function main (line 149) | def main():

FILE: third_party/Matcha-TTS/matcha/cli.py
  function plot_spectrogram_to_numpy (line 37) | def plot_spectrogram_to_numpy(spectrogram, filename):
  function process_text (line 48) | def process_text(i: int, text: str, device: torch.device):
  function get_texts (line 62) | def get_texts(args):
  function assert_required_models_available (line 71) | def assert_required_models_available(args):
  function load_hifigan (line 84) | def load_hifigan(checkpoint_path, device):
  function load_vocoder (line 93) | def load_vocoder(vocoder_name, checkpoint_path, device):
  function load_matcha (line 108) | def load_matcha(model_name, checkpoint_path, device):
  function to_waveform (line 117) | def to_waveform(mel, vocoder, denoiser=None):
  function save_to_folder (line 125) | def save_to_folder(filename: str, output: dict, folder: str):
  function validate_args (line 134) | def validate_args(args):
  function validate_args_for_multispeaker_model (line 163) | def validate_args_for_multispeaker_model(args):
  function validate_args_for_single_speaker_model (line 188) | def validate_args_for_single_speaker_model(args):
  function cli (line 208) | def cli():
  class BatchedSynthesisDataset (line 292) | class BatchedSynthesisDataset(torch.utils.data.Dataset):
    method __init__ (line 293) | def __init__(self, processed_texts):
    method __len__ (line 296) | def __len__(self):
    method __getitem__ (line 299) | def __getitem__(self, idx):
  function batched_collate_fn (line 303) | def batched_collate_fn(batch):
  function batched_synthesis (line 316) | def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
  function unbatched_synthesis (line 358) | def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, s...
  function print_config (line 397) | def print_config(args):
  function get_device (line 407) | def get_device(args):

FILE: third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py
  function parse_filelist (line 15) | def parse_filelist(filelist_path, split_char="|"):
  class TextMelDataModule (line 21) | class TextMelDataModule(LightningDataModule):
    method __init__ (line 22) | def __init__(  # pylint: disable=unused-argument
    method setup (line 49) | def setup(self, stage: Optional[str] = None):  # pylint: disable=unuse...
    method train_dataloader (line 88) | def train_dataloader(self):
    method val_dataloader (line 98) | def val_dataloader(self):
    method teardown (line 108) | def teardown(self, stage: Optional[str] = None):
    method state_dict (line 112) | def state_dict(self):  # pylint: disable=no-self-use
    method load_state_dict (line 116) | def load_state_dict(self, state_dict: Dict[str, Any]):
  class TextMelDataset (line 121) | class TextMelDataset(torch.utils.data.Dataset):
    method __init__ (line 122) | def __init__(
    method get_datapoint (line 156) | def get_datapoint(self, filepath_and_text):
    method get_mel (line 172) | def get_mel(self, filepath):
    method get_text (line 189) | def get_text(self, text, add_blank=True):
    method __getitem__ (line 196) | def __getitem__(self, index):
    method __len__ (line 200) | def __len__(self):
  class TextMelBatchCollate (line 204) | class TextMelBatchCollate:
    method __init__ (line 205) | def __init__(self, n_spks):
    method __call__ (line 208) | def __call__(self, batch):

FILE: third_party/Matcha-TTS/matcha/hifigan/denoiser.py
  class Denoiser (line 7) | class Denoiser(torch.nn.Module):
    method __init__ (line 10) | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_lengt...
    method forward (line 59) | def forward(self, audio, strength=0.0005):

FILE: third_party/Matcha-TTS/matcha/hifigan/env.py
  class AttrDict (line 7) | class AttrDict(dict):
    method __init__ (line 8) | def __init__(self, *args, **kwargs):
  function build_env (line 13) | def build_env(config, config_name, path):

FILE: third_party/Matcha-TTS/matcha/hifigan/meldataset.py
  function load_wav (line 17) | def load_wav(full_path):
  function dynamic_range_compression (line 22) | def dynamic_range_compression(x, C=1, clip_val=1e-5):
  function dynamic_range_decompression (line 26) | def dynamic_range_decompression(x, C=1):
  function dynamic_range_compression_torch (line 30) | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
  function dynamic_range_decompression_torch (line 34) | def dynamic_range_decompression_torch(x, C=1):
  function spectral_normalize_torch (line 38) | def spectral_normalize_torch(magnitudes):
  function spectral_de_normalize_torch (line 43) | def spectral_de_normalize_torch(magnitudes):
  function mel_spectrogram (line 52) | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_siz...
  function get_dataset_filelist (line 92) | def get_dataset_filelist(a):
  class MelDataset (line 105) | class MelDataset(torch.utils.data.Dataset):
    method __init__ (line 106) | def __init__(
    method __getitem__ (line 146) | def __getitem__(self, index):
    method __len__ (line 216) | def __len__(self):

FILE: third_party/Matcha-TTS/matcha/hifigan/models.py
  class ResBlock1 (line 14) | class ResBlock1(torch.nn.Module):
    method __init__ (line 15) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
    method forward (line 90) | def forward(self, x):
    method remove_weight_norm (line 99) | def remove_weight_norm(self):
  class ResBlock2 (line 106) | class ResBlock2(torch.nn.Module):
    method __init__ (line 107) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
    method forward (line 136) | def forward(self, x):
    method remove_weight_norm (line 143) | def remove_weight_norm(self):
  class Generator (line 148) | class Generator(torch.nn.Module):
    method __init__ (line 149) | def __init__(self, h):
    method forward (line 181) | def forward(self, x):
    method remove_weight_norm (line 199) | def remove_weight_norm(self):
  class DiscriminatorP (line 209) | class DiscriminatorP(torch.nn.Module):
    method __init__ (line 210) | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=...
    method forward (line 225) | def forward(self, x):
  class MultiPeriodDiscriminator (line 247) | class MultiPeriodDiscriminator(torch.nn.Module):
    method __init__ (line 248) | def __init__(self):
    method forward (line 260) | def forward(self, y, y_hat):
  class DiscriminatorS (line 276) | class DiscriminatorS(torch.nn.Module):
    method __init__ (line 277) | def __init__(self, use_spectral_norm=False):
    method forward (line 293) | def forward(self, x):
  class MultiScaleDiscriminator (line 306) | class MultiScaleDiscriminator(torch.nn.Module):
    method __init__ (line 307) | def __init__(self):
    method forward (line 318) | def forward(self, y, y_hat):
  function feature_loss (line 337) | def feature_loss(fmap_r, fmap_g):
  function discriminator_loss (line 346) | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
  function generator_loss (line 360) | def generator_loss(disc_outputs):

FILE: third_party/Matcha-TTS/matcha/hifigan/xutils.py
  function plot_spectrogram (line 14) | def plot_spectrogram(spectrogram):
  function init_weights (line 25) | def init_weights(m, mean=0.0, std=0.01):
  function apply_weight_norm (line 31) | def apply_weight_norm(m):
  function get_padding (line 37) | def get_padding(kernel_size, dilation=1):
  function load_checkpoint (line 41) | def load_checkpoint(filepath, device):
  function save_checkpoint (line 49) | def save_checkpoint(filepath, obj):
  function scan_checkpoint (line 55) | def scan_checkpoint(cp_dir, prefix):

FILE: third_party/Matcha-TTS/matcha/models/baselightningmodule.py
  class BaseLightningClass (line 19) | class BaseLightningClass(LightningModule, ABC):
    method update_data_statistics (line 20) | def update_data_statistics(self, data_statistics):
    method configure_optimizers (line 30) | def configure_optimizers(self) -> Any:
    method get_losses (line 56) | def get_losses(self, batch):
    method on_load_checkpoint (line 75) | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    method training_step (line 78) | def training_step(self, batch: Any, batch_idx: int):
    method validation_step (line 127) | def validation_step(self, batch: Any, batch_idx: int):
    method on_validation_end (line 167) | def on_validation_end(self) -> None:
    method on_before_optimizer_step (line 208) | def on_before_optimizer_step(self, optimizer):

FILE: third_party/Matcha-TTS/matcha/models/components/decoder.py
  class SinusoidalPosEmb (line 14) | class SinusoidalPosEmb(torch.nn.Module):
    method __init__ (line 15) | def __init__(self, dim):
    method forward (line 20) | def forward(self, x, scale=1000):
  class Block1D (line 32) | class Block1D(torch.nn.Module):
    method __init__ (line 33) | def __init__(self, dim, dim_out, groups=8):
    method forward (line 41) | def forward(self, x, mask):
  class ResnetBlock1D (line 46) | class ResnetBlock1D(torch.nn.Module):
    method __init__ (line 47) | def __init__(self, dim, dim_out, time_emb_dim, groups=8):
    method forward (line 56) | def forward(self, x, mask, time_emb):
  class Downsample1D (line 64) | class Downsample1D(nn.Module):
    method __init__ (line 65) | def __init__(self, dim):
    method forward (line 69) | def forward(self, x):
  class TimestepEmbedding (line 73) | class TimestepEmbedding(nn.Module):
    method __init__ (line 74) | def __init__(
    method forward (line 105) | def forward(self, sample, condition=None):
  class Upsample1D (line 120) | class Upsample1D(nn.Module):
    method __init__ (line 134) | def __init__(self, channels, use_conv=False, use_conv_transpose=True, ...
    method forward (line 148) | def forward(self, inputs):
  class ConformerWrapper (line 161) | class ConformerWrapper(ConformerBlock):
    method __init__ (line 162) | def __init__(  # pylint: disable=useless-super-delegation
    method forward (line 189) | def forward(
  class Decoder (line 200) | class Decoder(nn.Module):
    method __init__ (line 201) | def __init__(
    method get_block (line 319) | def get_block(block_type, dim, attention_head_dim, num_heads, dropout,...
    method initialize_weights (line 345) | def initialize_weights(self):
    method forward (line 363) | def forward(self, x, mask, mu, t, spks=None, cond=None):

FILE: third_party/Matcha-TTS/matcha/models/components/flow_matching.py
  class BASECFM (line 12) | class BASECFM(torch.nn.Module, ABC):
    method __init__ (line 13) | def __init__(
    method forward (line 33) | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, c...
    method solve_euler (line 55) | def solve_euler(self, x, t_span, mu, mask, spks, cond):
    method compute_loss (line 87) | def compute_loss(self, x1, mask, mu, spks=None, cond=None):
  class CFM (line 121) | class CFM(BASECFM):
    method __init__ (line 122) | def __init__(self, in_channels, out_channel, cfm_params, decoder_param...

FILE: third_party/Matcha-TTS/matcha/models/components/text_encoder.py
  class LayerNorm (line 15) | class LayerNorm(nn.Module):
    method __init__ (line 16) | def __init__(self, channels, eps=1e-4):
    method forward (line 24) | def forward(self, x):
  class ConvReluNorm (line 36) | class ConvReluNorm(nn.Module):
    method __init__ (line 37) | def __init__(self, in_channels, hidden_channels, out_channels, kernel_...
    method forward (line 60) | def forward(self, x, x_mask):
  class DurationPredictor (line 70) | class DurationPredictor(nn.Module):
    method __init__ (line 71) | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
    method forward (line 84) | def forward(self, x, x_mask):
  class RotaryPositionalEmbeddings (line 97) | class RotaryPositionalEmbeddings(nn.Module):
    method __init__ (line 107) | def __init__(self, d: int, base: int = 10_000):
    method _build_cache (line 119) | def _build_cache(self, x: torch.Tensor):
    method _neg_half (line 147) | def _neg_half(self, x: torch.Tensor):
    method forward (line 154) | def forward(self, x: torch.Tensor):
  class MultiHeadAttention (line 175) | class MultiHeadAttention(nn.Module):
    method __init__ (line 176) | def __init__(
    method forward (line 216) | def forward(self, x, c, attn_mask=None):
    method attention (line 226) | def attention(self, query, key, value, mask=None):
    method _attention_bias_proximal (line 249) | def _attention_bias_proximal(length):
  class FFN (line 255) | class FFN(nn.Module):
    method __init__ (line 256) | def __init__(self, in_channels, out_channels, filter_channels, kernel_...
    method forward (line 268) | def forward(self, x, x_mask):
  class Encoder (line 276) | class Encoder(nn.Module):
    method __init__ (line 277) | def __init__(
    method forward (line 314) | def forward(self, x, x_mask):
  class TextEncoder (line 328) | class TextEncoder(nn.Module):
    method __init__ (line 329) | def __init__(
    method forward (line 378) | def forward(self, x, x_lengths, spks=None):

FILE: third_party/Matcha-TTS/matcha/models/components/transformer.py
  class SnakeBeta (line 17) | class SnakeBeta(nn.Module):
    method __init__ (line 35) | def __init__(self, in_features, out_features, alpha=1.0, alpha_trainab...
    method forward (line 64) | def forward(self, x):
  class FeedForward (line 83) | class FeedForward(nn.Module):
    method __init__ (line 96) | def __init__(
    method forward (line 131) | def forward(self, hidden_states):
  class BasicTransformerBlock (line 138) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 159) | def __init__(
    method set_chunk_feed_forward (line 238) | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
    method forward (line 243) | def forward(

FILE: third_party/Matcha-TTS/matcha/models/matcha_tts.py
  class MatchaTTS (line 23) | class MatchaTTS(BaseLightningClass):  # 🍵
    method __init__ (line 24) | def __init__(
    method synthesise (line 74) | def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=...
    method forward (line 150) | def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None...

FILE: third_party/Matcha-TTS/matcha/onnx/export.py
  class MatchaWithVocoder (line 22) | class MatchaWithVocoder(LightningModule):
    method __init__ (line 23) | def __init__(self, matcha, vocoder):
    method forward (line 28) | def forward(self, x, x_lengths, scales, spks=None):
  function get_exportable_module (line 35) | def get_exportable_module(matcha, vocoder, n_timesteps):
  function get_inputs (line 63) | def get_inputs(is_multi_speaker):
  function main (line 91) | def main():

FILE: third_party/Matcha-TTS/matcha/onnx/infer.py
  function validate_args (line 15) | def validate_args(args):
  function write_wavs (line 24) | def write_wavs(model, inputs, output_dir, external_vocoder=None):
  function write_mels (line 66) | def write_mels(model, inputs, output_dir):
  function main (line 85) | def main():

FILE: third_party/Matcha-TTS/matcha/text/__init__.py
  function text_to_sequence (line 10) | def text_to_sequence(text, cleaner_names):
  function cleaned_text_to_sequence (line 27) | def cleaned_text_to_sequence(cleaned_text):
  function sequence_to_text (line 38) | def sequence_to_text(sequence):
  function _clean_text (line 47) | def _clean_text(text, cleaner_names):

FILE: third_party/Matcha-TTS/matcha/text/cleaners.py
  function expand_abbreviations (line 66) | def expand_abbreviations(text):
  function lowercase (line 72) | def lowercase(text):
  function collapse_whitespace (line 76) | def collapse_whitespace(text):
  function convert_to_ascii (line 80) | def convert_to_ascii(text):
  function basic_cleaners (line 84) | def basic_cleaners(text):
  function transliteration_cleaners (line 91) | def transliteration_cleaners(text):
  function english_cleaners2 (line 99) | def english_cleaners2(text):
  function english_cleaners_piper (line 109) | def english_cleaners_piper(text):

FILE: third_party/Matcha-TTS/matcha/text/numbers.py
  function _remove_commas (line 16) | def _remove_commas(m):
  function _expand_decimal_point (line 20) | def _expand_decimal_point(m):
  function _expand_dollars (line 24) | def _expand_dollars(m):
  function _expand_ordinal (line 45) | def _expand_ordinal(m):
  function _expand_number (line 49) | def _expand_number(m):
  function normalize_numbers (line 64) | def normalize_numbers(text):

FILE: third_party/Matcha-TTS/matcha/train.py
  function train (line 35) | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  function main (line 101) | def main(cfg: DictConfig) -> Optional[float]:

FILE: third_party/Matcha-TTS/matcha/utils/audio.py
  function load_wav (line 10) | def load_wav(full_path):
  function dynamic_range_compression (line 15) | def dynamic_range_compression(x, C=1, clip_val=1e-5):
  function dynamic_range_decompression (line 19) | def dynamic_range_decompression(x, C=1):
  function dynamic_range_compression_torch (line 23) | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
  function dynamic_range_decompression_torch (line 27) | def dynamic_range_decompression_torch(x, C=1):
  function spectral_normalize_torch (line 31) | def spectral_normalize_torch(magnitudes):
  function spectral_de_normalize_torch (line 36) | def spectral_de_normalize_torch(magnitudes):
  function mel_spectrogram (line 45) | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_siz...

FILE: third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py
  function compute_data_statistics (line 25) | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, ou...
  function main (line 50) | def main():

FILE: third_party/Matcha-TTS/matcha/utils/instantiators.py
  function instantiate_callbacks (line 13) | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
  function instantiate_loggers (line 36) | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:

FILE: third_party/Matcha-TTS/matcha/utils/logging_utils.py
  function log_hyperparameters (line 12) | def log_hyperparameters(object_dict: Dict[str, Any]) -> None:

FILE: third_party/Matcha-TTS/matcha/utils/model.py
  function sequence_mask (line 7) | def sequence_mask(length, max_length=None):
  function fix_len_compatibility (line 14) | def fix_len_compatibility(length, num_downsamplings_in_unet=2):
  function convert_pad_shape (line 23) | def convert_pad_shape(pad_shape):
  function generate_path (line 29) | def generate_path(duration, mask):
  function duration_loss (line 44) | def duration_loss(logw, logw_, lengths):
  function normalize (line 49) | def normalize(data, mu, std):
  function denormalize (line 71) | def denormalize(data, mu, std):

FILE: third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py
  function maximum_path (line 7) | def maximum_path(value, mask):

FILE: third_party/Matcha-TTS/matcha/utils/pylogger.py
  function get_pylogger (line 6) | def get_pylogger(name: str = __name__) -> logging.Logger:

FILE: third_party/Matcha-TTS/matcha/utils/rich_utils.py
  function print_config_tree (line 18) | def print_config_tree(
  function enforce_tags (line 80) | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:

FILE: third_party/Matcha-TTS/matcha/utils/utils.py
  function extras (line 20) | def extras(cfg: DictConfig) -> None:
  function task_wrapper (line 51) | def task_wrapper(task_func: Callable) -> Callable:
  function get_metric_value (line 106) | def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> f...
  function intersperse (line 130) | def intersperse(lst, item):
  function save_figure_to_numpy (line 137) | def save_figure_to_numpy(fig):
  function plot_tensor (line 143) | def plot_tensor(tensor):
  function save_plot (line 155) | def save_plot(tensor, savepath):
  function to_numpy (line 166) | def to_numpy(tensor):
  function get_user_data_dir (line 177) | def get_user_data_dir(appname="matcha_tts"):
  function assert_model_downloaded (line 208) | def assert_model_downloaded(checkpoint_path, url, use_wget=True):
Condensed preview: 163 files, each showing path, character count, and a content snippet (2,192K chars of structured content in total).
[
  {
    "path": ".github/workflows/publish.yml",
    "chars": 676,
    "preview": "name: Publish to Comfy registry\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - main\n      - master\n    paths:\n  "
  },
  {
    "path": ".gitignore",
    "chars": 89,
    "preview": "/idea\n/__pycache__\n**/__pycache__\n**/**/__pycache__\n/pretrained_models\n.vscode\n.vs\n.idea\n"
  },
  {
    "path": "README.md",
    "chars": 689,
    "preview": "# CosyVoice2 for ComfyUI\nComfyUI_NTCosyVoice is a plugin of ComfyUI for Cosysvoice2\n## install plugin\n```angular2html\ngi"
  },
  {
    "path": "__init__.py",
    "chars": 5517,
    "preview": "import sys\nimport os\nnor_dir = os.path.dirname(__file__)\nMatcha_path = os.path.join(nor_dir, 'third_party/Matcha-TTS')\ns"
  },
  {
    "path": "cosyvoice/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/bin/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/bin/average_model.py",
    "chars": 3143,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Di Wu)\n# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apa"
  },
  {
    "path": "cosyvoice/bin/export_jit.py",
    "chars": 2628,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/bin/export_onnx.py",
    "chars": 4560,
    "preview": "# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)\n# Copyright (c) 2024 Alibaba Inc (authors:"
  },
  {
    "path": "cosyvoice/bin/export_trt.sh",
    "chars": 931,
    "preview": "#!/bin/bash\n# Copyright 2024 Alibaba Inc. All Rights Reserved.\n# download tensorrt from https://developer.nvidia.com/ten"
  },
  {
    "path": "cosyvoice/bin/inference.py",
    "chars": 5471,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/bin/train.py",
    "chars": 6806,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/cli/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/cli/cosyvoice.py",
    "chars": 10424,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/cli/frontend.py",
    "chars": 11234,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/cli/model.py",
    "chars": 26497,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/dataset/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/dataset/dataset.py",
    "chars": 5403,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licen"
  },
  {
    "path": "cosyvoice/dataset/processor.py",
    "chars": 15416,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/flow/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/flow/decoder.py",
    "chars": 12311,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/flow/flow.py",
    "chars": 10252,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/flow/flow_matching.py",
    "chars": 10647,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/flow/length_regulator.py",
    "chars": 3060,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/hifigan/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/hifigan/discriminator.py",
    "chars": 5341,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn.utils import weight_norm\nfrom typing import List, Optional, Tuple\nfrom "
  },
  {
    "path": "cosyvoice/hifigan/f0_predictor.py",
    "chars": 1976,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
  },
  {
    "path": "cosyvoice/hifigan/generator.py",
    "chars": 15498,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
  },
  {
    "path": "cosyvoice/hifigan/hifigan.py",
    "chars": 3231,
    "preview": "from typing import Dict, Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom matcha.hifigan"
  },
  {
    "path": "cosyvoice/llm/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/llm/llm.py",
    "chars": 14677,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/tokenizer/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken",
    "chars": 907395,
    "preview": "IQ== 0\nIg== 1\nIw== 2\nJA== 3\nJQ== 4\nJg== 5\nJw== 6\nKA== 7\nKQ== 8\nKg== 9\nKw== 10\nLA== 11\nLQ== 12\nLg== 13\nLw== 14\nMA== 15\nMQ"
  },
  {
    "path": "cosyvoice/tokenizer/tokenizer.py",
    "chars": 7456,
    "preview": "import base64\nimport os\nfrom functools import lru_cache\nfrom typing import Optional\nimport torch\nfrom transformers impor"
  },
  {
    "path": "cosyvoice/transformer/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/transformer/activation.py",
    "chars": 3087,
    "preview": "# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)\n#               2020 Northwestern Polytechnical Universi"
  },
  {
    "path": "cosyvoice/transformer/attention.py",
    "chars": 14389,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#               2022 Xingchen Song (s"
  },
  {
    "path": "cosyvoice/transformer/convolution.py",
    "chars": 5230,
    "preview": "# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# License"
  },
  {
    "path": "cosyvoice/transformer/decoder.py",
    "chars": 16580,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# License"
  },
  {
    "path": "cosyvoice/transformer/decoder_layer.py",
    "chars": 4807,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#\n# Licensed under the Apache License"
  },
  {
    "path": "cosyvoice/transformer/embedding.py",
    "chars": 11399,
    "preview": "# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# License"
  },
  {
    "path": "cosyvoice/transformer/encoder.py",
    "chars": 21434,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)\n#"
  },
  {
    "path": "cosyvoice/transformer/encoder_layer.py",
    "chars": 9596,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)\n#"
  },
  {
    "path": "cosyvoice/transformer/label_smoothing_loss.py",
    "chars": 3459,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#\n# Licensed under the Apache License"
  },
  {
    "path": "cosyvoice/transformer/positionwise_feed_forward.py",
    "chars": 4219,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#\n# Licensed under the Apache License"
  },
  {
    "path": "cosyvoice/transformer/subsampling.py",
    "chars": 12666,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# Licensed under th"
  },
  {
    "path": "cosyvoice/transformer/upsample_encoder.py",
    "chars": 13766,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)\n#"
  },
  {
    "path": "cosyvoice/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/utils/class_utils.py",
    "chars": 3270,
    "preview": "# Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>\n#            2024 Alibaba Inc (authors: Xiang Lyu)"
  },
  {
    "path": "cosyvoice/utils/common.py",
    "chars": 5834,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under "
  },
  {
    "path": "cosyvoice/utils/executor.py",
    "chars": 8559,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under "
  },
  {
    "path": "cosyvoice/utils/file_utils.py",
    "chars": 1661,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licen"
  },
  {
    "path": "cosyvoice/utils/frontend_utils.py",
    "chars": 4231,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/utils/losses.py",
    "chars": 607,
    "preview": "import torch\nimport torch.nn.functional as F\n\n\ndef tpr_loss(disc_real_outputs, disc_generated_outputs, tau):\n    loss = "
  },
  {
    "path": "cosyvoice/utils/mask.py",
    "chars": 8351,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#               2024 Alibaba Inc (aut"
  },
  {
    "path": "cosyvoice/utils/scheduler.py",
    "chars": 24920,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)\n#               2022 Ximalaya Inc (Yuguang Yang)\n#               2024 Ali"
  },
  {
    "path": "cosyvoice/utils/train_utils.py",
    "chars": 15154,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n#               2023 Horizon Inc. (authors: Xingchen Song)\n#   "
  },
  {
    "path": "downloadmodel.py",
    "chars": 613,
    "preview": "from modelscope import snapshot_download\nsnapshot_download('chenmingyu/CosyVoice2-0.5B', local_dir='pretrained_models/Co"
  },
  {
    "path": "examples/CrossLingual.json",
    "chars": 3554,
    "preview": "{\n  \"last_node_id\": 11,\n  \"last_link_id\": 11,\n  \"nodes\": [\n    {\n      \"id\": 1,\n      \"type\": \"LoadAudio\",\n      \"pos\": "
  },
  {
    "path": "examples/Instruct2.json",
    "chars": 2669,
    "preview": "{\n  \"last_node_id\": 5,\n  \"last_link_id\": 4,\n  \"nodes\": [\n    {\n      \"id\": 1,\n      \"type\": \"LoadAudio\",\n      \"pos\": [\n"
  },
  {
    "path": "examples/ZeroShot.json",
    "chars": 2652,
    "preview": "{\n  \"last_node_id\": 4,\n  \"last_link_id\": 2,\n  \"nodes\": [\n    {\n      \"id\": 3,\n      \"type\": \"PreviewAudio\",\n      \"pos\":"
  },
  {
    "path": "pyproject.toml",
    "chars": 1012,
    "preview": "[project]\nname = \"ntcosyvoice\"\ndescription = \"ComfyUI_NTCosyVoice is a plugin of ComfyUI for Cosysvoice2\"\nversion = \"1.0"
  },
  {
    "path": "requirements.txt",
    "chars": 260,
    "preview": "conformer\ndeepspeed\ndiffusers\ngdown\nmodelscope\nhydra-core\nHyperPyYAML\ninflect\nlibrosa\npyarrow\nlightning\nmatplotlib\nomega"
  },
  {
    "path": "third_party/Matcha-TTS/LICENSE",
    "chars": 1069,
    "preview": "MIT License\n\nCopyright (c) 2023 Shivam Mehta\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "third_party/Matcha-TTS/MANIFEST.in",
    "chars": 352,
    "preview": "include README.md\ninclude LICENSE.txt\ninclude requirements.*.txt\ninclude *.cff\ninclude requirements.txt\ninclude matcha/V"
  },
  {
    "path": "third_party/Matcha-TTS/Makefile",
    "chars": 1155,
    "preview": "\nhelp:  ## Show help\n\t@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = \":.*?## \"}; {printf \"\\033["
  },
  {
    "path": "third_party/Matcha-TTS/README.md",
    "chars": 8647,
    "preview": "<div align=\"center\">\n\n# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching\n\n### [Shivam Mehta](https:/"
  },
  {
    "path": "third_party/Matcha-TTS/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/configs/__init__.py",
    "chars": 81,
    "preview": "# this file is needed here to include configs when building project as a package\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/callbacks/default.yaml",
    "chars": 97,
    "preview": "defaults:\n  - model_checkpoint.yaml\n  - model_summary.yaml\n  - rich_progress_bar.yaml\n  - _self_\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml",
    "chars": 1199,
    "preview": "# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html\n\nmodel_checkpoint:\n  _ta"
  },
  {
    "path": "third_party/Matcha-TTS/configs/callbacks/model_summary.yaml",
    "chars": 252,
    "preview": "# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html\n\nmodel_summary:\n  _targ"
  },
  {
    "path": "third_party/Matcha-TTS/configs/callbacks/none.yaml",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml",
    "chars": 172,
    "preview": "# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html\n\nrich_progress_bar:\n  _t"
  },
  {
    "path": "third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml",
    "chars": 472,
    "preview": "defaults:\n  - ljspeech\n  - _self_\n\n# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/\n_target_: matc"
  },
  {
    "path": "third_party/Matcha-TTS/configs/data/ljspeech.yaml",
    "chars": 520,
    "preview": "_target_: matcha.data.text_mel_datamodule.TextMelDataModule\nname: ljspeech\ntrain_filelist_path: data/filelists/ljs_audio"
  },
  {
    "path": "third_party/Matcha-TTS/configs/data/vctk.yaml",
    "chars": 385,
    "preview": "defaults:\n  - ljspeech\n  - _self_\n\n_target_: matcha.data.text_mel_datamodule.TextMelDataModule\nname: vctk\ntrain_filelist"
  },
  {
    "path": "third_party/Matcha-TTS/configs/debug/default.yaml",
    "chars": 903,
    "preview": "# @package _global_\n\n# default debugging setup, runs 1 full epoch\n# other debugging configs can inherit from this one\n\n#"
  },
  {
    "path": "third_party/Matcha-TTS/configs/debug/fdr.yaml",
    "chars": 120,
    "preview": "# @package _global_\n\n# runs 1 train, 1 validation and 1 test step\n\ndefaults:\n  - default\n\ntrainer:\n  fast_dev_run: true\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/debug/limit.yaml",
    "chars": 218,
    "preview": "# @package _global_\n\n# uses only 1% of the training data and 5% of validation/test data\n\ndefaults:\n  - default\n\ntrainer:"
  },
  {
    "path": "third_party/Matcha-TTS/configs/debug/overfit.yaml",
    "chars": 204,
    "preview": "# @package _global_\n\n# overfits to 3 batches\n\ndefaults:\n  - default\n\ntrainer:\n  max_epochs: 20\n  overfit_batches: 3\n\n# m"
  },
  {
    "path": "third_party/Matcha-TTS/configs/debug/profiler.yaml",
    "chars": 225,
    "preview": "# @package _global_\n\n# runs with execution time profiling\n\ndefaults:\n  - default\n\ntrainer:\n  max_epochs: 1\n  # profiler:"
  },
  {
    "path": "third_party/Matcha-TTS/configs/eval.yaml",
    "chars": 335,
    "preview": "# @package _global_\n\ndefaults:\n  - _self_\n  - data: mnist # choose datamodule with `test_dataloader()` for evaluation\n  "
  },
  {
    "path": "third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml",
    "chars": 423,
    "preview": "# @package _global_\n\n# to execute this experiment run:\n# python train.py experiment=multispeaker\n\ndefaults:\n  - override"
  },
  {
    "path": "third_party/Matcha-TTS/configs/experiment/ljspeech.yaml",
    "chars": 332,
    "preview": "# @package _global_\n\n# to execute this experiment run:\n# python train.py experiment=multispeaker\n\ndefaults:\n  - override"
  },
  {
    "path": "third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml",
    "chars": 361,
    "preview": "# @package _global_\n\n# to execute this experiment run:\n# python train.py experiment=multispeaker\n\ndefaults:\n  - override"
  },
  {
    "path": "third_party/Matcha-TTS/configs/experiment/multispeaker.yaml",
    "chars": 336,
    "preview": "# @package _global_\n\n# to execute this experiment run:\n# python train.py experiment=multispeaker\n\ndefaults:\n  - override"
  },
  {
    "path": "third_party/Matcha-TTS/configs/extras/default.yaml",
    "chars": 232,
    "preview": "# disable python warnings if they annoy you\nignore_warnings: False\n\n# ask user for tags if none are provided in the conf"
  },
  {
    "path": "third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml",
    "chars": 1818,
    "preview": "# @package _global_\n\n# example hyperparameter optimization of some experiment with Optuna:\n# python train.py -m hparams_"
  },
  {
    "path": "third_party/Matcha-TTS/configs/hydra/default.yaml",
    "chars": 608,
    "preview": "# https://hydra.cc/docs/configure_hydra/intro/\n\n# enable color logging\ndefaults:\n  - override hydra_logging: colorlog\n  "
  },
  {
    "path": "third_party/Matcha-TTS/configs/local/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/aim.yaml",
    "chars": 1267,
    "preview": "# https://aimstack.io/\n\n# example usage in lightning module:\n# https://github.com/aimhubio/aim/blob/main/examples/pytorc"
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/comet.yaml",
    "chars": 372,
    "preview": "# https://www.comet.ml\n\ncomet:\n  _target_: lightning.pytorch.loggers.comet.CometLogger\n  api_key: ${oc.env:COMET_API_TOK"
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/csv.yaml",
    "chars": 157,
    "preview": "# csv logger built in lightning\n\ncsv:\n  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger\n  save_dir: \"${paths.outp"
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/many_loggers.yaml",
    "chars": 118,
    "preview": "# train with many loggers at once\n\ndefaults:\n  # - comet\n  - csv\n  # - mlflow\n  # - neptune\n  - tensorboard\n  - wandb\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/mlflow.yaml",
    "chars": 339,
    "preview": "# https://mlflow.org\n\nmlflow:\n  _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger\n  # experiment_name: \"\"\n  # run_"
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/neptune.yaml",
    "chars": 277,
    "preview": "# https://neptune.ai\n\nneptune:\n  _target_: lightning.pytorch.loggers.neptune.NeptuneLogger\n  api_key: ${oc.env:NEPTUNE_A"
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/tensorboard.yaml",
    "chars": 258,
    "preview": "# https://www.tensorflow.org/tensorboard/\n\ntensorboard:\n  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLog"
  },
  {
    "path": "third_party/Matcha-TTS/configs/logger/wandb.yaml",
    "chars": 522,
    "preview": "# https://wandb.ai\n\nwandb:\n  _target_: lightning.pytorch.loggers.wandb.WandbLogger\n  # name: \"\" # name of the run (norma"
  },
  {
    "path": "third_party/Matcha-TTS/configs/model/cfm/default.yaml",
    "chars": 40,
    "preview": "name: CFM\nsolver: euler\nsigma_min: 1e-4\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/model/decoder/default.yaml",
    "chars": 119,
    "preview": "channels: [256, 256]\ndropout: 0.05\nattention_head_dim: 64\nn_blocks: 1\nnum_mid_blocks: 2\nnum_heads: 2\nact_fn: snakebeta\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/model/encoder/default.yaml",
    "chars": 417,
    "preview": "encoder_type: RoPE Encoder\nencoder_params:\n  n_feats: ${model.n_feats}\n  n_channels: 192\n  filter_channels: 768\n  filter"
  },
  {
    "path": "third_party/Matcha-TTS/configs/model/matcha.yaml",
    "chars": 328,
    "preview": "defaults:\n  - _self_\n  - encoder: default.yaml\n  - decoder: default.yaml\n  - cfm: default.yaml\n  - optimizer: adam.yaml\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/model/optimizer/adam.yaml",
    "chars": 70,
    "preview": "_target_: torch.optim.Adam\n_partial_: true\nlr: 1e-4\nweight_decay: 0.0\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/paths/default.yaml",
    "chars": 632,
    "preview": "# path to root directory\n# this requires PROJECT_ROOT environment variable to exist\n# you can replace it with \".\" if you"
  },
  {
    "path": "third_party/Matcha-TTS/configs/train.yaml",
    "chars": 1557,
    "preview": "# @package _global_\n\n# specify here default configuration\n# order of defaults determines the order in which configs over"
  },
  {
    "path": "third_party/Matcha-TTS/configs/trainer/cpu.yaml",
    "chars": 51,
    "preview": "defaults:\n  - default\n\naccelerator: cpu\ndevices: 1\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/trainer/ddp.yaml",
    "chars": 104,
    "preview": "defaults:\n  - default\n\nstrategy: ddp\n\naccelerator: gpu\ndevices: [0,1]\nnum_nodes: 1\nsync_batchnorm: True\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml",
    "chars": 115,
    "preview": "defaults:\n  - default\n\n# simulate DDP on CPU, useful for debugging\naccelerator: cpu\ndevices: 2\nstrategy: ddp_spawn\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/trainer/default.yaml",
    "chars": 439,
    "preview": "_target_: lightning.pytorch.trainer.Trainer\n\ndefault_root_dir: ${paths.output_dir}\n\nmax_epochs: -1\n\naccelerator: gpu\ndev"
  },
  {
    "path": "third_party/Matcha-TTS/configs/trainer/gpu.yaml",
    "chars": 51,
    "preview": "defaults:\n  - default\n\naccelerator: gpu\ndevices: 1\n"
  },
  {
    "path": "third_party/Matcha-TTS/configs/trainer/mps.yaml",
    "chars": 51,
    "preview": "defaults:\n  - default\n\naccelerator: mps\ndevices: 1\n"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/VERSION",
    "chars": 8,
    "preview": "0.0.5.1\n"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/matcha/app.py",
    "chars": 13981,
    "preview": "import tempfile\nfrom argparse import Namespace\nfrom pathlib import Path\n\nimport gradio as gr\nimport soundfile as sf\nimpo"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/cli.py",
    "chars": 15467,
    "preview": "import argparse\nimport datetime as dt\nimport os\nimport warnings\nfrom pathlib import Path\n\nimport matplotlib.pyplot as pl"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/matcha/data/components/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py",
    "chars": 7555,
    "preview": "import random\nfrom typing import Any, Dict, Optional\n\nimport torch\nimport torchaudio as ta\nfrom lightning import Lightni"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/LICENSE",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2020 Jungil Kong\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/README.md",
    "chars": 5570,
    "preview": "# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis\n\n### Jungil Kong, Jaehyeon "
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/config.py",
    "chars": 779,
    "preview": "v1 = {\n    \"resblock\": \"1\",\n    \"num_gpus\": 0,\n    \"batch_size\": 16,\n    \"learning_rate\": 0.0004,\n    \"adam_b1\": 0.8,\n  "
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/denoiser.py",
    "chars": 2644,
    "preview": "# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/env.py",
    "chars": 429,
    "preview": "\"\"\" from https://github.com/jik876/hifi-gan \"\"\"\n\nimport os\nimport shutil\n\n\nclass AttrDict(dict):\n    def __init__(self, "
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/meldataset.py",
    "chars": 6786,
    "preview": "\"\"\" from https://github.com/jik876/hifi-gan \"\"\"\n\nimport math\nimport os\nimport random\n\nimport numpy as np\nimport torch\nim"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/models.py",
    "chars": 11668,
    "preview": "\"\"\" from https://github.com/jik876/hifi-gan \"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/hifigan/xutils.py",
    "chars": 1396,
    "preview": "\"\"\" from https://github.com/jik876/hifi-gan \"\"\"\n\nimport glob\nimport os\n\nimport matplotlib\nimport torch\nfrom torch.nn.uti"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/baselightningmodule.py",
    "chars": 7003,
    "preview": "\"\"\"\nThis is a base lightning module that can be used to train a model.\nThe benefit of this abstraction is that all the l"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/components/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/components/decoder.py",
    "chars": 14459,
    "preview": "import math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom conform"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/components/flow_matching.py",
    "chars": 4657,
    "preview": "from abc import ABC\n\nimport torch\nimport torch.nn.functional as F\n\nfrom matcha.models.components.decoder import Decoder\n"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/components/text_encoder.py",
    "chars": 14845,
    "preview": "\"\"\" from https://github.com/jaywalnut310/glow-tts \"\"\"\n\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom einops impor"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/components/transformer.py",
    "chars": 13235,
    "preview": "from typing import Any, Dict, Optional\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.models.attention import (\n    "
  },
  {
    "path": "third_party/Matcha-TTS/matcha/models/matcha_tts.py",
    "chars": 10056,
    "preview": "import datetime as dt\nimport math\nimport random\n\nimport torch\n\nimport matcha.utils.monotonic_align as monotonic_align\nfr"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/onnx/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/matcha/onnx/export.py",
    "chars": 5377,
    "preview": "import argparse\nimport random\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom lightning import LightningM"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/onnx/infer.py",
    "chars": 6287,
    "preview": "import argparse\nimport os\nimport warnings\nfrom pathlib import Path\nfrom time import perf_counter\n\nimport numpy as np\nimp"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/text/__init__.py",
    "chars": 1696,
    "preview": "\"\"\" from https://github.com/keithito/tacotron \"\"\"\nfrom matcha.text import cleaners\nfrom matcha.text.symbols import symbo"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/text/cleaners.py",
    "chars": 3560,
    "preview": "\"\"\" from https://github.com/keithito/tacotron\n\nCleaners are transformations that run over the input text at both trainin"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/text/numbers.py",
    "chars": 2248,
    "preview": "\"\"\" from https://github.com/keithito/tacotron \"\"\"\n\nimport re\n\nimport inflect\n\n_inflect = inflect.engine()\n_comma_number_"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/text/symbols.py",
    "chars": 509,
    "preview": "\"\"\" from https://github.com/keithito/tacotron\n\nDefines the set of symbols used in text input to the model.\n\"\"\"\n_pad = \"_"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/train.py",
    "chars": 4613,
    "preview": "from typing import Any, Dict, List, Optional, Tuple\n\nimport hydra\nimport lightning as L\nimport rootutils\nfrom lightning "
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/__init__.py",
    "chars": 326,
    "preview": "from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers\nfrom matcha.utils.logging_utils import"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/audio.py",
    "chars": 2282,
    "preview": "import numpy as np\nimport torch\nimport torch.utils.data\nfrom librosa.filters import mel as librosa_mel_fn\nfrom scipy.io."
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py",
    "chars": 3269,
    "preview": "r\"\"\"\nThe file creates a pickle file where the values needed for loading of dataset is stored and the model can load it\nw"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/instantiators.py",
    "chars": 1828,
    "preview": "from typing import List\n\nimport hydra\nfrom lightning import Callback\nfrom lightning.pytorch.loggers import Logger\nfrom o"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/logging_utils.py",
    "chars": 1711,
    "preview": "from typing import Any, Dict\n\nfrom lightning.pytorch.utilities import rank_zero_only\nfrom omegaconf import OmegaConf\n\nfr"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/model.py",
    "chars": 2935,
    "preview": "\"\"\" from https://github.com/jaywalnut310/glow-tts \"\"\"\n\nimport numpy as np\nimport torch\n\n\ndef sequence_mask(length, max_l"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py",
    "chars": 646,
    "preview": "import numpy as np\nimport torch\n\nfrom matcha.utils.monotonic_align.core import maximum_path_c\n\n\ndef maximum_path(value, "
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx",
    "chars": 1236,
    "preview": "import numpy as np\n\ncimport cython\ncimport numpy as np\n\nfrom cython.parallel import prange\n\n\n@cython.boundscheck(False)\n"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py",
    "chars": 207,
    "preview": "# from distutils.core import setup\n# from Cython.Build import cythonize\n# import numpy\n\n# setup(name='monotonic_align',\n"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/pylogger.py",
    "chars": 720,
    "preview": "import logging\n\nfrom lightning.pytorch.utilities import rank_zero_only\n\n\ndef get_pylogger(name: str = __name__) -> loggi"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/rich_utils.py",
    "chars": 3279,
    "preview": "from pathlib import Path\nfrom typing import Sequence\n\nimport rich\nimport rich.syntax\nimport rich.tree\nfrom hydra.core.hy"
  },
  {
    "path": "third_party/Matcha-TTS/matcha/utils/utils.py",
    "chars": 7159,
    "preview": "import os\nimport sys\nimport warnings\nfrom importlib.util import find_spec\nfrom pathlib import Path\nfrom typing import An"
  },
  {
    "path": "third_party/Matcha-TTS/notebooks/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/Matcha-TTS/pyproject.toml",
    "chars": 982,
    "preview": "[build-system]\nrequires = [\"setuptools\", \"wheel\", \"cython==0.29.35\", \"numpy==1.24.3\", \"packaging\"]\n\n[tool.black]\nline-le"
  },
  {
    "path": "third_party/Matcha-TTS/requirements.txt",
    "chars": 201,
    "preview": "# --------- pytorch --------- #\ntorch>=2.0.0\ntorchvision>=0.15.0\nlightning>=2.0.0\ntorchmetrics>=0.11.4\n\n# --------- hydr"
  },
  {
    "path": "third_party/Matcha-TTS/scripts/schedule.sh",
    "chars": 207,
    "preview": "#!/bin/bash\n# Schedule execution of many runs\n# Run from root folder with: bash scripts/schedule.sh\n\npython src/train.py"
  },
  {
    "path": "third_party/Matcha-TTS/setup.py",
    "chars": 1295,
    "preview": "#!/usr/bin/env python\nimport os\n\nimport numpy\nfrom Cython.Build import cythonize\nfrom setuptools import Extension, find_"
  },
  {
    "path": "third_party/Matcha-TTS/synthesis.ipynb",
    "chars": 590014,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f37f4e3b-f764-4502-a6a2-6417bd9bfab9\",\n   \"metadata\": {},\n   \"so"
  },
  {
    "path": "third_party/__init__.py",
    "chars": 0,
    "preview": ""
  }
]

About this extraction

This page contains the full source code of the muxueChen/ComfyUI_NTCosyVoice GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 163 files (2.0 MB), approximately 532.8k tokens, and a symbol index with 656 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.