Full Code of FunAudioLLM/CosyVoice for AI

main ace7c47f41bb cached
153 files
1.9 MB
845.7k tokens
786 symbols
1 requests
Download .txt
Showing preview only (2,032K chars total). Download the full file or copy to clipboard to get everything.
Repository: FunAudioLLM/CosyVoice
Branch: main
Commit: ace7c47f41bb
Files: 153
Total size: 1.9 MB

Directory structure:
gitextract_xzebwrda/

├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows/
│       ├── lint.yml
│       └── stale-issues.yml
├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── FAQ.md
├── LICENSE
├── README.md
├── cosyvoice/
│   ├── __init__.py
│   ├── bin/
│   │   ├── average_model.py
│   │   ├── export_jit.py
│   │   ├── export_onnx.py
│   │   └── train.py
│   ├── cli/
│   │   ├── __init__.py
│   │   ├── cosyvoice.py
│   │   ├── frontend.py
│   │   └── model.py
│   ├── dataset/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── processor.py
│   ├── flow/
│   │   ├── DiT/
│   │   │   ├── dit.py
│   │   │   └── modules.py
│   │   ├── decoder.py
│   │   ├── flow.py
│   │   ├── flow_matching.py
│   │   └── length_regulator.py
│   ├── hifigan/
│   │   ├── discriminator.py
│   │   ├── f0_predictor.py
│   │   ├── generator.py
│   │   └── hifigan.py
│   ├── llm/
│   │   └── llm.py
│   ├── tokenizer/
│   │   ├── assets/
│   │   │   └── multilingual_zh_ja_yue_char_del.tiktoken
│   │   └── tokenizer.py
│   ├── transformer/
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── convolution.py
│   │   ├── decoder.py
│   │   ├── decoder_layer.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── subsampling.py
│   │   └── upsample_encoder.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── class_utils.py
│   │   ├── common.py
│   │   ├── executor.py
│   │   ├── file_utils.py
│   │   ├── frontend_utils.py
│   │   ├── losses.py
│   │   ├── mask.py
│   │   ├── onnx.py
│   │   ├── scheduler.py
│   │   └── train_utils.py
│   └── vllm/
│       └── cosyvoice2.py
├── docker/
│   └── Dockerfile
├── example.py
├── examples/
│   ├── grpo/
│   │   └── cosyvoice2/
│   │       ├── Dockerfile
│   │       ├── README.md
│   │       ├── huggingface_to_pretrained.py
│   │       ├── infer_dataset.py
│   │       ├── prepare_data.py
│   │       ├── pretrained_to_huggingface.py
│   │       ├── requirements.txt
│   │       ├── reward_tts.py
│   │       ├── run.sh
│   │       ├── scripts/
│   │       │   ├── compute_wer.sh
│   │       │   └── offline-decode-files.py
│   │       └── token2wav_asr_server.py
│   ├── libritts/
│   │   ├── cosyvoice/
│   │   │   ├── conf/
│   │   │   │   ├── cosyvoice.yaml
│   │   │   │   └── ds_stage2.json
│   │   │   ├── local/
│   │   │   │   ├── download_and_untar.sh
│   │   │   │   ├── prepare_data.py
│   │   │   │   └── prepare_reject_sample.py
│   │   │   ├── path.sh
│   │   │   ├── run.sh
│   │   │   └── tts_text.json
│   │   ├── cosyvoice2/
│   │   │   ├── conf/
│   │   │   │   ├── cosyvoice2.yaml
│   │   │   │   └── ds_stage2.json
│   │   │   ├── run.sh
│   │   │   └── run_dpo.sh
│   │   └── cosyvoice3/
│   │       ├── conf/
│   │       │   ├── cosyvoice3.yaml
│   │       │   └── ds_stage2.json
│   │       └── run.sh
│   └── magicdata-read/
│       └── cosyvoice/
│           ├── local/
│           │   ├── download_and_untar.sh
│           │   └── prepare_data.py
│           ├── run.sh
│           └── tts_text.json
├── requirements.txt
├── runtime/
│   ├── python/
│   │   ├── Dockerfile
│   │   ├── fastapi/
│   │   │   ├── client.py
│   │   │   └── server.py
│   │   └── grpc/
│   │       ├── client.py
│   │       ├── cosyvoice.proto
│   │       └── server.py
│   └── triton_trtllm/
│       ├── Dockerfile.server
│       ├── README.Cosyvoice2.DiT.md
│       ├── README.Cosyvoice2.Unet.md
│       ├── README.Cosyvoice3.md
│       ├── README.md
│       ├── client_grpc.py
│       ├── client_http.py
│       ├── docker-compose.cosyvoice2.dit.yml
│       ├── docker-compose.cosyvoice2.unet.yml
│       ├── docker-compose.cosyvoice3.yml
│       ├── infer_cosyvoice3.py
│       ├── model_repo/
│       │   ├── audio_tokenizer/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── cosyvoice2/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── cosyvoice2_dit/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── speaker_embedding/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── tensorrt_llm/
│       │   │   ├── 1/
│       │   │   │   └── .gitkeep
│       │   │   └── config.pbtxt
│       │   ├── token2wav/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   └── token2wav_dit/
│       │       ├── 1/
│       │       │   ├── model.py
│       │       │   └── token2wav_dit.py
│       │       └── config.pbtxt
│       ├── model_repo_cosyvoice3/
│       │   ├── audio_tokenizer/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── cosyvoice3/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── speaker_embedding/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── token2wav/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   └── vocoder/
│       │       ├── 1/
│       │       │   └── model.py
│       │       └── config.pbtxt
│       ├── offline_inference.py
│       ├── requirements.txt
│       ├── run.sh
│       ├── run_cosyvoice3.sh
│       ├── run_stepaudio2_dit_token2wav.sh
│       ├── scripts/
│       │   ├── convert_checkpoint.py
│       │   ├── convert_cosyvoice3_to_hf.py
│       │   ├── fill_template.py
│       │   └── test_llm.py
│       ├── streaming_inference.py
│       ├── token2wav.py
│       └── token2wav_cosyvoice3.py
├── tools/
│   ├── extract_embedding.py
│   ├── extract_speech_token.py
│   └── make_parquet_list.py
├── vllm_example.py
└── webui.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
 - OS: [e.g. iOS]
 - Browser [e.g. chrome, safari]
 - Version [e.g. 22]

**Smartphone (please complete the following information):**
 - Device: [e.g. iPhone6]
 - OS: [e.g. iOS8.1]
 - Browser [e.g. stock browser, safari]
 - Version [e.g. 22]

**Additional context**
Add any other context about the problem here.


================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.


================================================
FILE: .github/workflows/lint.yml
================================================
name: Lint

on:
  pull_request:
  push:

jobs:
  quick-checks:
    runs-on: ubuntu-latest
    steps:
      - name: Fetch CosyVoice
        uses: actions/checkout@v1
      - name: Checkout PR tip
        run: |
          set -eux
          if [[ "${{ github.event_name }}" == "pull_request" ]]; then
            # We are on a PR, so actions/checkout leaves us on a merge commit.
            # Check out the actual tip of the branch.
            git checkout ${{ github.event.pull_request.head.sha }}
          fi
          echo ::set-output name=commit_sha::$(git rev-parse HEAD)
        id: get_pr_tip
      - name: Ensure no tabs
        run: |
          (! git grep -I -l $'\t' -- . ':(exclude)*.txt' ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false))
      - name: Ensure no trailing whitespace
        run: |
          (! git grep -I -n $' $' -- . ':(exclude)*.txt' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false))

  flake8-py3:
    runs-on: ubuntu-latest
    steps:
      - name: Setup Python
        uses: actions/setup-python@v1
        with:
          python-version: 3.9
          architecture: x64
      - name: Fetch CosyVoice
        uses: actions/checkout@v1
      - name: Checkout PR tip
        run: |
          set -eux
          if [[ "${{ github.event_name }}" == "pull_request" ]]; then
            # We are on a PR, so actions/checkout leaves us on a merge commit.
            # Check out the actual tip of the branch.
            git checkout ${{ github.event.pull_request.head.sha }}
          fi
          echo ::set-output name=commit_sha::$(git rev-parse HEAD)
        id: get_pr_tip
      - name: Run flake8
        run: |
          set -eux
          pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
          flake8 --version
          flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F722,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
          if [ $? != 0 ]; then exit 1; fi

================================================
FILE: .github/workflows/stale-issues.yml
================================================
name: Close inactive issues
on:
  schedule:
    - cron: "30 1 * * *"

jobs:
  close-issues:
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v5
        with:
          days-before-issue-stale: 30
          days-before-issue-close: 14
          stale-issue-label: "stale"
          stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
          close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
          days-before-pr-stale: -1
          days-before-pr-close: -1
          repo-token: ${{ secrets.GITHUB_TOKEN }}


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Visual Studio Code files
.vscode
.vs

# PyCharm files
.idea

# Eclipse Project settings
*.*project
.settings

# Sublime Text settings
*.sublime-workspace
*.sublime-project

# Editor temporaries
*.swn
*.swo
*.swp
*.swm
*~

# IPython notebook checkpoints
.ipynb_checkpoints

# macOS dir files
.DS_Store

exp
data
raw_wav
tensorboard
**/*build*

# Clangd files
.cache
compile_commands.json

# train/inference files
*.wav
*.m4a
*.aac
*.pt
pretrained_models/*
*_pb2_grpc.py
*_pb2.py
*.tar

================================================
FILE: .gitmodules
================================================
[submodule "third_party/Matcha-TTS"]
	path = third_party/Matcha-TTS
	url = https://github.com/shivammehta25/Matcha-TTS.git


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
 advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
 address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
 professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at mikelei@mobvoi.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: FAQ.md
================================================
## ModuleNotFoundError: No module named 'matcha'

Matcha-TTS is a third-party submodule. Please check the `third_party` directory. If `Matcha-TTS` is missing there, run `git submodule update --init --recursive`.

Run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in a Python script.

## cannot find resource.zip or cannot unzip resource.zip

Please make sure you have git-lfs installed, then execute:

```sh
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
cd pretrained_models/CosyVoice-ttsfrd/
unzip resource.zip -d .
pip install ttsfrd_dependency-0.1-py3-none-any.whl
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
```


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)

## 👉🏻 CosyVoice 👈🏻

**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)

**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)

**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)

## Highlight🔥

**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLM), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
### Key Features
- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian) and 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.), and supports both multilingual and cross-lingual zero-shot voice cloning.
- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
- **Pronunciation Inpainting**: Supports pronunciation inpainting of Chinese Pinyin and English CMU phonemes, providing more controllability and thus suitable for production use.
- **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
- **Bi-Streaming**: Supports both text-in streaming and audio-out streaming, and achieves latency as low as 150ms while maintaining high-quality audio output.
- **Instruct Support**: Supports various instructions such as languages, dialects, emotions, speed, volume, etc.


## Roadmap

- [x] 2025/12

    - [x] release Fun-CosyVoice3-0.5B-2512 base model, RL model, and their training/inference scripts
    - [x] release Fun-CosyVoice3-0.5B modelscope gradio space

- [x] 2025/08

    - [x] Thanks to a contribution from Yuekai Zhang (NVIDIA), add Triton TensorRT-LLM runtime support and CosyVoice2 GRPO training support

- [x] 2025/07

    - [x] release Fun-CosyVoice 3.0 eval set

- [x] 2025/05

    - [x] add CosyVoice2-0.5B vllm support

- [x] 2024/12

    - [x] 25Hz CosyVoice2-0.5B released

- [x] 2024/09

    - [x] 25Hz CosyVoice-300M base model
    - [x] 25Hz CosyVoice-300M voice conversion function

- [x] 2024/08

    - [x] Repetition Aware Sampling (RAS) inference for LLM stability
    - [x] Streaming inference mode support, including KV cache and SDPA for RTF optimization

- [x] 2024/07

    - [x] Flow matching training support
    - [x] WeTextProcessing support when ttsfrd is not available
    - [x] Fastapi server and client

## Evaluation

| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |


## Install

### Clone and install

- Clone the repo
    ``` sh
    git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
    # If cloning the submodule fails due to network issues, run the following commands until they succeed
    cd CosyVoice
    git submodule update --init --recursive
    ```

- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
- Create Conda env:

    ``` sh
    conda create -n cosyvoice -y python=3.10
    conda activate cosyvoice
    pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com

    # If you encounter sox compatibility issues
    # ubuntu
    sudo apt-get install sox libsox-dev
    # centos
    sudo yum install sox sox-devel
    ```

### Model download

We strongly recommend that you download our pretrained models (`Fun-CosyVoice3-0.5B`, `CosyVoice2-0.5B`, `CosyVoice-300M`, `CosyVoice-300M-SFT`, `CosyVoice-300M-Instruct`) and the `CosyVoice-ttsfrd` resource.

``` python
# modelscope SDK model download
from modelscope import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')

# for oversea users, huggingface SDK model download
from huggingface_hub import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
```

Optionally, you can unzip the `ttsfrd` resource and install the `ttsfrd` package for better text normalization performance.

Note that this step is optional. If the `ttsfrd` package is not installed, wetext is used by default.

``` sh
cd pretrained_models/CosyVoice-ttsfrd/
unzip resource.zip -d .
pip install ttsfrd_dependency-0.1-py3-none-any.whl
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
```

### Basic Usage

We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
Follow the code in `example.py` for detailed usage of each model.
```sh
python example.py
```

#### vLLM Usage
CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
Older vLLM versions (<0.9.0) do not support CosyVoice inference, and the versions in between (e.g., 0.10.x) have not been tested.

Note that `vllm` has many specific dependency requirements. We recommend creating a separate env (cloned from the existing one, as below), so that your original env is not corrupted if your hardware does not support vllm.

``` sh
conda create -n cosyvoice_vllm --clone cosyvoice
conda activate cosyvoice_vllm
# for vllm==0.9.0
pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# for vllm>=0.11.0
pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
python vllm_example.py
```

#### Start web demo

You can use our web demo page to get familiar with CosyVoice quickly.

Please see the demo website for details.

``` sh
# change --model_dir to pretrained_models/CosyVoice-300M-SFT for sft inference, or pretrained_models/CosyVoice-300M-Instruct for instruct inference
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
```

#### Advanced Usage

For advanced users, we have provided training and inference scripts in `examples/libritts`.

#### Build for deployment

Optionally, if you want to deploy CosyVoice as a service,
you can run the following steps.

``` sh
cd runtime/python
docker build -t cosyvoice:v1.0 .
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
# for grpc usage
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
# for fastapi usage
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```

#### Using Nvidia TensorRT-LLM for deployment

Using TensorRT-LLM to accelerate the CosyVoice2 LLM can give a 4x speedup compared with the Hugging Face Transformers implementation.
To get started quickly:

``` sh
cd runtime/triton_trtllm
docker compose up -d
```
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)

## Discussion & Communication

You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).

You can also scan the QR code to join our official Dingding chat group.

<img src="./asset/dingding.png" width="250px">

## Acknowledgements

1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).

## Citations

``` bibtex
@article{du2024cosyvoice,
  title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
  author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
  journal={arXiv preprint arXiv:2407.05407},
  year={2024}
}

@article{du2024cosyvoice2,
  title={Cosyvoice 2: Scalable streaming speech synthesis with large language models},
  author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
  journal={arXiv preprint arXiv:2412.10117},
  year={2024}
}

@article{du2025cosyvoice,
  title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
  author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
  journal={arXiv preprint arXiv:2505.17589},
  year={2025}
}

@inproceedings{lyu2025build,
  title={Build LLM-Based Zero-Shot Streaming TTS System with Cosyvoice},
  author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
  booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={1--2},
  year={2025},
  organization={IEEE}
}
```

## Disclaimer
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.


================================================
FILE: cosyvoice/__init__.py
================================================


================================================
FILE: cosyvoice/bin/average_model.py
================================================
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import glob

import yaml
import torch


def get_args():
    """Parse the checkpoint-averaging command line and echo the parsed namespace.

    Returns:
        argparse.Namespace with ``dst_model``, ``src_path``, ``val_best`` and ``num``.
    """
    cli = argparse.ArgumentParser(description='average model')
    cli.add_argument('--dst_model', required=True, help='averaged model')
    cli.add_argument('--src_path', required=True, help='src model path for average')
    cli.add_argument('--val_best', action="store_true", help='averaged model')
    cli.add_argument('--num', default=5, type=int, help='nums for averaged model')
    parsed = cli.parse_args()
    print(parsed)
    return parsed


def main():
    """Average the weights of several checkpoints under ``--src_path``.

    With ``--val_best``, every ``*.yaml`` info file whose name does not start
    with ``train`` or ``init`` is read. The ``epoch_<epoch>_whole.pt``
    checkpoints of the ``--num`` entries with the lowest ``loss_dict.loss``
    are then averaged. Without ``--val_best``, the ``--num``
    ``epoch_<N>_whole.pt`` files with the largest epoch numbers are averaged.
    The ``step`` and ``epoch`` entries are skipped. The averaged state dict is
    saved to ``--dst_model``.

    Raises:
        AssertionError: if the number of selected checkpoints differs from ``--num``.
    """
    args = get_args()
    val_scores = []
    if args.val_best:
        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
        yamls = [
            f for f in yamls
            if not (os.path.basename(f).startswith('train')
                    or os.path.basename(f).startswith('init'))
        ]
        for y in yamls:
            with open(y, 'r') as f:
                # BaseLoader yields every scalar as a string, hence the explicit float()/int() conversions.
                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
                loss = float(dic_yaml['loss_dict']['loss'])
                epoch = int(dic_yaml['epoch'])
                step = int(dic_yaml['step'])
                tag = dic_yaml['tag']
                val_scores += [[epoch, step, loss, tag]]
        sorted_val_scores = sorted(val_scores,
                                   key=lambda x: x[2],
                                   reverse=False)
        print("best val (epoch, step, loss, tag) = " +
              str(sorted_val_scores[:args.num]))
        path_list = [
            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
            for score in sorted_val_scores[:args.num]
        ]
    else:
        # Without --val_best, fall back to the checkpoints with the largest epoch
        # numbers. path_list was previously undefined on this path (NameError).
        epoch_ckpts = []
        for ckpt_path in glob.glob('{}/epoch_*_whole.pt'.format(args.src_path)):
            epoch_str = os.path.basename(ckpt_path)[len('epoch_'):-len('_whole.pt')]
            if epoch_str.isdigit():
                epoch_ckpts.append((int(epoch_str), ckpt_path))
        epoch_ckpts.sort()
        # Guard num <= 0: a slice like [-0:] would select every checkpoint.
        path_list = [p for _, p in epoch_ckpts[-args.num:]] if args.num > 0 else []
    print(path_list)
    avg = {}
    num = args.num
    assert num == len(path_list), \
        'expected {} checkpoints to average, found {}'.format(num, len(path_list))
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k in states.keys():
            if k not in ['step', 'epoch']:
                if k not in avg.keys():
                    avg[k] = states[k].clone()
                else:
                    avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()


================================================
FILE: cosyvoice/bin/export_jit.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging


def get_args():
    """Parse the TorchScript export command line and echo the parsed namespace.

    Returns:
        argparse.Namespace with a single ``model_dir`` attribute.
    """
    cli = argparse.ArgumentParser(description='export your model for deployment')
    cli.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice-300M', help='local path')
    parsed = cli.parse_args()
    print(parsed)
    return parsed


def get_optimized_script(model, preserved_attrs=None):
    """Script, freeze and inference-optimize a module with TorchScript.

    Args:
        model: ``torch.nn.Module`` to compile. ``torch.jit.freeze`` only accepts
            modules in eval mode.
        preserved_attrs: optional list of attribute/method names that freezing
            must keep, e.g. ``['forward_chunk']``. ``None`` (the default) or an
            empty list freezes with torch's default behaviour.

    Returns:
        The frozen, inference-optimized ``torch.jit.ScriptModule``.
    """
    script = torch.jit.script(model)
    # ``None`` replaces the former mutable ``[]`` default; an explicit ``[]`` still behaves the same.
    if preserved_attrs:
        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
    else:
        script = torch.jit.freeze(script)
    return torch.jit.optimize_for_inference(script)


def main():
    """Export TorchScript sub-modules of the model in ``--model_dir`` as fp32 and fp16 archives.

    For ``CosyVoice`` the llm text encoder, the llm and the flow encoder are
    exported; for ``CosyVoice2`` only the flow encoder. The archives are written
    into ``--model_dir``.

    Raises:
        ValueError: for any other model class.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    # TorchScript JIT settings: static fusion strategy, profiling mode and profiling executor disabled.
    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    model = AutoModel(model_dir=args.model_dir)

    def export_both(module, fp32_template, fp16_template, *script_args):
        # fp32 has to go first: ``.half()`` casts ``module`` in place.
        get_optimized_script(module, *script_args).save(fp32_template.format(args.model_dir))
        get_optimized_script(module.half(), *script_args).save(fp16_template.format(args.model_dir))

    model_kind = model.__class__.__name__
    if model_kind == 'CosyVoice':
        export_both(model.model.llm.text_encoder,
                    '{}/llm.text_encoder.fp32.zip', '{}/llm.text_encoder.fp16.zip')
        logging.info('successfully export llm_text_encoder')

        # ``forward_chunk`` must survive freezing alongside ``forward``.
        export_both(model.model.llm.llm,
                    '{}/llm.llm.fp32.zip', '{}/llm.llm.fp16.zip', ['forward_chunk'])
        logging.info('successfully export llm_llm')
    elif model_kind != 'CosyVoice2':
        raise ValueError('unsupported model type')

    # Shared by CosyVoice and CosyVoice2.
    export_both(model.model.flow.encoder,
                '{}/flow.encoder.fp32.zip', '{}/flow.encoder.fp16.zip')
    logging.info('successfully export flow_encoder')


if __name__ == '__main__':
    main()


================================================
FILE: cosyvoice/bin/export_onnx.py
================================================
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import onnxruntime
import random
import torch
from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging


def get_dummy_input(batch_size, seq_len, out_channels, device):
    """Build random float32 example inputs for the flow decoder estimator.

    Returns:
        ``(x, mask, mu, t, spks, cond)``:
        - ``x``, ``mu`` and ``cond`` are ``torch.rand`` tensors of shape
          ``(batch_size, out_channels, seq_len)``.
        - ``mask`` is all ones with shape ``(batch_size, 1, seq_len)``.
        - ``t`` has shape ``(batch_size,)``.
        - ``spks`` has shape ``(batch_size, out_channels)``.
    """
    tensor_opts = {'dtype': torch.float32, 'device': device}
    feat_shape = (batch_size, out_channels, seq_len)
    # Random draws keep the original order: x, mu, t, spks, cond.
    x = torch.rand(feat_shape, **tensor_opts)
    mask = torch.ones((batch_size, 1, seq_len), **tensor_opts)
    mu = torch.rand(feat_shape, **tensor_opts)
    t = torch.rand((batch_size,), **tensor_opts)
    spks = torch.rand((batch_size, out_channels), **tensor_opts)
    cond = torch.rand(feat_shape, **tensor_opts)
    return x, mask, mu, t, spks, cond


def get_args():
    """Parse the ONNX export command line and echo the parsed namespace.

    Returns:
        argparse.Namespace with a single ``model_dir`` attribute.
    """
    cli = argparse.ArgumentParser(description='export your model for deployment')
    cli.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice-300M', help='local path')
    parsed = cli.parse_args()
    print(parsed)
    return parsed


@torch.no_grad()
def main():
    """Export the flow decoder estimator to ONNX and check it against PyTorch.

    Writes ``<model_dir>/flow.decoder.estimator.fp32.onnx`` (opset 18). The
    ``seq_len`` axis (dim 2) of ``x``/``mask``/``mu``/``cond`` and of the output
    is dynamic. The function then runs 10 random inputs with sequence lengths
    drawn from [16, 512] through both ONNX Runtime and the PyTorch estimator
    and compares the results.

    Raises:
        AssertionError: if ONNX Runtime output deviates from PyTorch beyond
            ``rtol=1e-2`` / ``atol=1e-4``.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    model = AutoModel(model_dir=args.model_dir)

    # 1. export flow decoder estimator
    estimator = model.model.flow.decoder.estimator
    estimator.eval()

    device = model.model.device
    batch_size, seq_len = 2, 256
    out_channels = model.model.flow.decoder.estimator.out_channels
    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
    torch.onnx.export(
        estimator,
        (x, mask, mu, t, spks, cond),
        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
        output_names=['estimator_out'],
        dynamic_axes={
            'x': {2: 'seq_len'},
            'mask': {2: 'seq_len'},
            'mu': {2: 'seq_len'},
            'cond': {2: 'seq_len'},
            'estimator_out': {2: 'seq_len'},
        }
    )

    # 2. test computation consistency
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
                                                  sess_options=option, providers=providers)

    for _ in tqdm(range(10)):
        # Use a different length each time to exercise the dynamic seq_len axis.
        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
        output_pytorch = estimator(x, mask, mu, t, spks, cond)
        ort_inputs = {
            'x': x.cpu().numpy(),
            'mask': mask.cpu().numpy(),
            'mu': mu.cpu().numpy(),
            't': t.cpu().numpy(),
            'spks': spks.cpu().numpy(),
            'cond': cond.cpu().numpy()
        }
        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
        # torch.testing.assert_allclose is deprecated; assert_close is its replacement with the same rtol/atol semantics.
        torch.testing.assert_close(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
    logging.info('successfully export estimator')


if __name__ == "__main__":
    main()


================================================
FILE: cosyvoice/bin/train.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import os
import torch
import torch.distributed as dist
import deepspeed

from hyperpyyaml import load_hyperpyyaml

from torch.distributed.elastic.multiprocessing.errors import record

from cosyvoice.utils.losses import DPOLoss
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
    init_distributed,
    init_dataset_and_dataloader,
    init_optimizer_and_scheduler,
    init_summarywriter, save_model,
    wrap_cuda_model, check_modify_and_save_config)


def get_args():
    """Build the training CLI (plus DeepSpeed's config flags) and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the training options, including those added by
        ``deepspeed.add_config_arguments``.
    """
    ap = argparse.ArgumentParser(description='training your network')
    ap.add_argument('--train_engine', default='torch_ddp', choices=['torch_ddp', 'deepspeed'],
                    help='Engine for paralleled training')
    # model, config, data and checkpoint locations
    ap.add_argument('--model', required=True, help='model which will be trained')
    ap.add_argument('--ref_model', required=False, help='ref model used in dpo')
    ap.add_argument('--config', required=True, help='config file')
    ap.add_argument('--train_data', required=True, help='train data file')
    ap.add_argument('--cv_data', required=True, help='cv data file')
    ap.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
    ap.add_argument('--onnx_path', required=False,
                    help='onnx path, which is required for online feature extraction')
    ap.add_argument('--checkpoint', help='checkpoint model')
    ap.add_argument('--model_dir', required=True, help='save model dir')
    ap.add_argument('--tensorboard_dir', default='tensorboard', help='tensorboard log dir')
    ap.add_argument('--ddp.dist_backend', dest='dist_backend', default='nccl', choices=['nccl', 'gloo'],
                    help='distributed backend')
    # data loading
    ap.add_argument('--num_workers', default=0, type=int, help='num of subprocess workers for reading')
    ap.add_argument('--prefetch', default=100, type=int, help='prefetch number')
    ap.add_argument('--pin_memory', action='store_true', default=False,
                    help='Use pinned memory buffers used for reading')
    # training mode switches
    ap.add_argument('--use_amp', action='store_true', default=False,
                    help='Use automatic mixed precision training')
    ap.add_argument('--dpo', action='store_true', default=False,
                    help='Use Direct Preference Optimization')
    ap.add_argument('--deepspeed.save_states', dest='save_states', default='model_only',
                    choices=['model_only', 'model+optimizer'], help='save model/optimizer states')
    ap.add_argument('--timeout', default=60, type=int, help='timeout (in seconds) of cosyvoice_join.')
    ap = deepspeed.add_config_arguments(ap)
    return ap.parse_args()


@record
def main():
    """Train one CosyVoice component, optionally distributed, with DPO or GAN.

    ``--model`` selects the top-level config entry to train ('llm', 'flow',
    'hift' or 'hifigan'). The other entries among those four are overridden
    with ``None``; ``hift`` is kept when training ``hifigan``. Supports resuming
    from ``--checkpoint`` and DPO training against a reference model loaded
    from ``--ref_model``.

    Raises:
        ValueError: if ``--dpo`` is set but ``--ref_model`` is not given.
    """
    args = get_args()
    # os.environ only accepts str values and --onnx_path is optional: assigning
    # None would raise TypeError before training starts.
    if args.onnx_path is not None:
        os.environ['onnx_path'] = args.onnx_path
    # Fail fast instead of an opaque torch.load(None) error after the dataloaders are built.
    if args.dpo is True and args.ref_model is None:
        raise ValueError('--ref_model is required when --dpo is set')
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # gan train has some special initialization logic
    gan = args.model == 'hifigan'

    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
    if gan is True:
        # GAN training of hifigan still needs the hift generator from the config.
        override_dict.pop('hift')
    if args.qwen_pretrain_path is not None:
        override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f, overrides=override_dict)
    if gan is True:
        configs['train_conf'] = configs['train_conf_gan']
    configs['train_conf'].update(vars(args))

    # Init env for ddp
    init_distributed(args)

    # Get dataset & dataloader
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs, gan, args.dpo)

    # Do some sanity checks and save config to args.model_dir
    configs = check_modify_and_save_config(args, configs)

    # Tensorboard summary
    writer = init_summarywriter(args)

    # load checkpoint
    if args.dpo is True:
        configs[args.model].forward = configs[args.model].forward_dpo
    model = configs[args.model]
    start_step, start_epoch = 0, -1
    if args.checkpoint is not None:
        if os.path.exists(args.checkpoint):
            state_dict = torch.load(args.checkpoint, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
            if 'step' in state_dict:
                start_step = state_dict['step']
            if 'epoch' in state_dict:
                start_epoch = state_dict['epoch']
        else:
            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))

    # Dispatch model from cpu to gpu
    model = wrap_cuda_model(args, model)

    # Get optimizer & scheduler
    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
    scheduler.set_step(start_step)
    if scheduler_d is not None:
        scheduler_d.set_step(start_step)

    # Save init checkpoints
    info_dict = deepcopy(configs['train_conf'])
    info_dict['step'] = start_step
    info_dict['epoch'] = start_epoch
    save_model(model, 'init', info_dict)

    # DPO related
    if args.dpo is True:
        ref_model = deepcopy(configs[args.model])
        state_dict = torch.load(args.ref_model, map_location='cpu')
        ref_model.load_state_dict(state_dict, strict=False)
        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
        # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
        ref_model = wrap_cuda_model(args, ref_model)
    else:
        ref_model, dpo_loss = None, None

    # Get executor
    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
    executor.step = start_step

    # Init scaler, used for pytorch amp mixed precision training
    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
    print('start step {} start epoch {}'.format(start_step, start_epoch))

    # Start training loop
    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
        executor.epoch = epoch
        train_dataset.set_epoch(epoch)
        dist.barrier()
        # A fresh gloo group per epoch, with a timeout, passed to the executor's epoch routines.
        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        if gan is True:
            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                        writer, info_dict, scaler, group_join)
        else:
            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
        dist.destroy_process_group(group_join)


if __name__ == '__main__':
    main()


================================================
FILE: cosyvoice/cli/__init__.py
================================================


================================================
FILE: cosyvoice/cli/cosyvoice.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import Generator
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type


class CosyVoice:
    """Inference wrapper for CosyVoice (v1) checkpoints.

    Builds the frontend and ``CosyVoiceModel`` from ``cosyvoice.yaml`` in the
    model directory and exposes generator-based synthesis methods. Every method
    yields the dicts produced by ``self.model.tts``, each carrying a
    ``'tts_speech'`` tensor, and logs the length and real-time factor of each
    chunk.
    """

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
        """Load configuration, frontend and model weights.

        Args:
            model_dir: local model directory, or a model id that is fetched with
                ``snapshot_download`` when no such local path exists.
            load_jit: also load the TorchScript llm text encoder, llm and flow
                encoder archives (``*.fp16.zip`` or ``*.fp32.zip`` depending on ``fp16``).
            load_trt: also load the TensorRT flow decoder estimator plan
                ``flow.decoder.estimator.<fp16|fp32>.mygpu.plan``. The fp32 ONNX
                export path is passed to ``load_trt`` as well.
            fp16: run the model in half precision.
            trt_concurrent: forwarded to ``CosyVoiceModel.load_trt``.

        ``load_jit``, ``load_trt`` and ``fp16`` are forced to False, with a
        warning, when CUDA is unavailable.

        Raises:
            ValueError: if ``cosyvoice.yaml`` is missing from the model directory.
            AssertionError: if the config does not describe a ``CosyVoiceModel``.
        """
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        # Store the resolved local directory: save_spkinfo() writes into it, which
        # fails when the caller passed a hub model id instead of a local path.
        self.model_dir = model_dir
        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f)
        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        # Assigned after the CUDA fallback so self.fp16 reflects the precision actually used.
        self.fp16 = fp16
        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)
        del configs

    def list_available_spks(self):
        """Return the speaker ids currently registered in the frontend's ``spk2info``."""
        spks = list(self.frontend.spk2info.keys())
        return spks

    def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
        """Register a zero-shot speaker built from a prompt under ``zero_shot_spk_id``.

        The text fields of the frontend output are dropped, so only the prompt
        features are stored. Returns True.
        """
        assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
        del model_input['text']
        del model_input['text_len']
        self.frontend.spk2info[zero_shot_spk_id] = model_input
        return True

    def save_spkinfo(self):
        """Persist ``spk2info`` (including added zero-shot speakers) to ``<model_dir>/spk2info.pt``."""
        torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))

    def _tts_with_rtf(self, model_input, stream, speed):
        """Run ``self.model.tts`` and log the length and real-time factor of every yielded chunk."""
        start_time = time.time()
        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            # Restart the clock so the next chunk's RTF excludes time spent by the consumer.
            start_time = time.time()

    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
        """Synthesize ``tts_text`` with the registered speaker ``spk_id``."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_sft(i, spk_id)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_rtf(model_input, stream, speed)

    def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        """Synthesize ``tts_text`` in the voice of a prompt (text + wav) or a stored zero-shot speaker."""
        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            # Generator inputs have no length, so the short-text check only applies to plain text segments.
            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_rtf(model_input, stream, speed)

    def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        """Synthesize ``tts_text`` using only a prompt wav (or stored speaker) without prompt text."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_rtf(model_input, stream, speed)

    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
        """Synthesize ``tts_text`` for ``spk_id`` guided by ``instruct_text`` (exact ``CosyVoice`` class only)."""
        assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_rtf(model_input, stream, speed)

    def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
        """Voice conversion: re-render ``source_wav`` in the voice of ``prompt_wav``."""
        model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
        yield from self._tts_with_rtf(model_input, stream, speed)


class CosyVoice2(CosyVoice):
    """CosyVoice 2 entry point: builds the text/audio frontend and a CosyVoice2Model from ``model_dir``."""

    def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
        """Load ``cosyvoice2.yaml``, the frontend assets and the model weights.

        Args:
            model_dir: local model directory; if the path does not exist it is passed to
                ``snapshot_download`` and the returned directory is used instead.
            load_jit: swap the flow encoder for its TorchScript export (CUDA only).
            load_trt: swap the flow decoder estimator for a TensorRT plan (CUDA only).
            load_vllm: load the LLM through vLLM (CUDA only).
            fp16: half precision for the model and exported artifacts (CUDA only).
            trt_concurrent: forwarded to ``load_trt``.

        Raises:
            ValueError: ``cosyvoice2.yaml`` does not exist in the model directory.
        """
        # keep the caller-supplied value (local path or hub id)
        self.model_dir = model_dir
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
            load_jit, load_trt, load_vllm, fp16 = False, False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
        # record fp16 only after the CPU fallback so self.fp16 reflects the precision actually used
        self.fp16 = fp16
        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_vllm:
            self.model.load_vllm('{}/vllm'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)
        del configs

    def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        """Synthesize ``tts_text`` in the voice of ``prompt_wav`` guided by ``instruct_text``.

        Args:
            tts_text: text to synthesize; normalized and split into segments.
            instruct_text: instruction passed to ``frontend.frontend_instruct2``.
            prompt_wav: reference audio for the voice.
            zero_shot_spk_id: when non-empty, cached speaker info is used instead of ``prompt_wav``.
            stream, speed: forwarded to ``self.model.tts``.
            text_frontend: whether text normalization is applied.

        Yields:
            The model-output dicts produced by ``self.model.tts``.
        """
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()


class CosyVoice3(CosyVoice2):
    """CosyVoice 3 entry point: builds the text/audio frontend and a CosyVoice3Model from ``model_dir``."""

    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
        """Load ``cosyvoice3.yaml``, the frontend assets and the model weights.

        Args:
            model_dir: local model directory; if the path does not exist it is passed to
                ``snapshot_download`` and the returned directory is used instead.
            load_trt: swap the flow decoder estimator for a TensorRT plan (CUDA only).
            load_vllm: load the LLM through vLLM (CUDA only).
            fp16: half precision for the model and exported artifacts (CUDA only).
            trt_concurrent: forwarded to ``load_trt``.

        Raises:
            ValueError: ``cosyvoice3.yaml`` does not exist in the model directory.
        """
        # keep the caller-supplied value (local path or hub id)
        self.model_dir = model_dir
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v3.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        # vLLM needs a GPU as well, so it is disabled together with trt/fp16 (matches CosyVoice2)
        if torch.cuda.is_available() is False and (load_trt is True or load_vllm is True or fp16 is True):
            load_trt, load_vllm, fp16 = False, False, False
            logging.warning('no cuda device, set load_trt/load_vllm/fp16 to False')
        # record fp16 only after the CPU fallback so self.fp16 reflects the precision actually used
        self.fp16 = fp16
        self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_vllm:
            self.model.load_vllm('{}/vllm'.format(model_dir))
        if load_trt:
            if self.fp16 is True:
                logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)
        del configs


def AutoModel(**kwargs):
    """Build the CosyVoice variant that matches the config file found in ``kwargs['model_dir']``.

    A ``model_dir`` that does not exist locally is resolved with ``snapshot_download``,
    and the resulting path is written back into ``kwargs``. The variant is chosen by the
    first config present, checked in the order cosyvoice.yaml, cosyvoice2.yaml,
    cosyvoice3.yaml. All ``kwargs`` are forwarded to the chosen constructor.

    Raises:
        KeyError: ``model_dir`` was not given.
        TypeError: none of the three config files exists.
    """
    if not os.path.exists(kwargs['model_dir']):
        kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
    resolved_dir = kwargs['model_dir']

    def has_config(stem):
        return os.path.exists('{}/{}.yaml'.format(resolved_dir, stem))

    if has_config('cosyvoice'):
        return CosyVoice(**kwargs)
    if has_config('cosyvoice2'):
        return CosyVoice2(**kwargs)
    if has_config('cosyvoice3'):
        return CosyVoice3(**kwargs)
    raise TypeError('No valid model type found!')


================================================
FILE: cosyvoice/cli/frontend.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Generator
import json
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import os
import re
import inflect
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


class CosyVoiceFrontEnd:
    """Turns text and prompt audio into the model-input dicts consumed by CosyVoice models.

    Covers text normalization and splitting (ttsfrd, wetext, or a rule-based fallback),
    text tokenization, speech-token extraction with an ONNX speech tokenizer, speaker
    embedding extraction with the campplus ONNX model, and prompt mel-feature extraction.
    """

    # Trailing run of commas at the end of Chinese text: full-width '，', ASCII ',' and
    # the enumeration comma '、'. text_normalize replaces the run with a final '。'.
    _ZH_TRAILING_COMMA_RE = re.compile(r'[，,、]+$')

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 allowed_special: str = 'all'):
        """
        Args:
            get_tokenizer: zero-argument factory returning the text tokenizer (used via ``encode``).
            feat_extractor: callable mapping a waveform to mel features for the flow prompt.
            campplus_model: path to the speaker-embedding ONNX model (run on CPU).
            speech_tokenizer_model: path to the speech-tokenizer ONNX model (CUDA provider when available).
            spk2info: optional path to a saved speaker-info dict; ignored when the file is missing.
            allowed_special: forwarded to ``tokenizer.encode``.
        """
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
        else:
            self.spk2info = {}
        self.allowed_special = allowed_special
        self.inflect_parser = inflect.engine()
        # NOTE stay usable when no text frontend tool is available: try ttsfrd, then wetext, else none.
        # Exception (not bare except) so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            import ttsfrd
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyinvg')
            self.text_frontend = 'ttsfrd'
            logging.info('use ttsfrd frontend')
        except Exception:
            try:
                from wetext import Normalizer as ZhNormalizer
                from wetext import Normalizer as EnNormalizer
                self.zh_tn_model = ZhNormalizer(remove_erhua=False)
                self.en_tn_model = EnNormalizer()
                self.text_frontend = 'wetext'
                logging.info('use wetext frontend')
            except Exception:
                self.text_frontend = ''
                logging.info('no frontend is avaliable')

    def _extract_text_token(self, text):
        """Tokenize ``text``.

        Returns:
            ``(text_token, text_token_len)``: int32 tensors of shape (1, T) and (1,) on
            ``self.device``. If ``text`` is a generator, returns a token generator plus
            a dummy length tensor ``[0]``.
        """
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will return _extract_text_token_generator!')
            # NOTE add a dummy text_token_len for compatibility
            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
        else:
            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
            return text_token, text_token_len

    def _extract_text_token_generator(self, text_generator):
        """Tokenize each incoming text piece and yield its tokens one at a time as (1, 1) tensors."""
        for text in text_generator:
            text_token, _ = self._extract_text_token(text)
            for i in range(text_token.shape[1]):
                yield text_token[:, i: i + 1]

    def _extract_speech_token(self, prompt_wav):
        """Run the ONNX speech tokenizer on ``prompt_wav`` (loaded at 16 kHz, at most 30 s).

        Returns:
            ``(speech_token, speech_token_len)`` int32 tensors of shape (1, N) and (1,).
        """
        speech = load_wav(prompt_wav, 16000)
        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,
                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, prompt_wav):
        """Compute a speaker embedding: 80-bin kaldi fbank (mean-normalized) fed to the campplus model."""
        speech = load_wav(prompt_wav, 16000)
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        # per-utterance mean normalization over frames
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None,
                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, prompt_wav):
        """Extract flow prompt features from ``prompt_wav`` loaded at 24 kHz.

        Returns:
            ``(speech_feat, speech_feat_len)``: features with a leading batch dim of 1 and
            frames on dim 1 (after the transpose below), plus their frame count.
        """
        speech = load_wav(prompt_wav, 24000)
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def text_normalize(self, text, split=True, text_frontend=True):
        """Normalize ``text`` and optionally split it into synthesis segments.

        Returns:
            A list of segments when ``split`` is True, otherwise the normalized string.
            Generator input is returned unchanged as ``[text]``. Text containing ``<|``
            and ``|>`` (ssml-like markers), or empty text, skips normalization.
        """
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will skip text_normalize!')
            return [text]
        # NOTE skip text_frontend when ssml symbol in text
        if '<|' in text and '|>' in text:
            text_frontend = False
        if text_frontend is False or text == '':
            return [text] if split is True else text
        text = text.strip()
        if self.text_frontend == 'ttsfrd':
            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
            text = ''.join(texts)
        else:
            if contains_chinese(text):
                if self.text_frontend == 'wetext':
                    text = self.zh_tn_model.normalize(text)
                text = text.replace("\n", "")
                text = replace_blank(text)
                text = replace_corner_mark(text)
                text = text.replace(".", "。")
                text = text.replace(" - ", ",")
                text = remove_bracket(text)
                # end the text with a full stop instead of a dangling comma
                text = self._ZH_TRAILING_COMMA_RE.sub('。', text)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
            else:
                if self.text_frontend == 'wetext':
                    text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
        texts = [i for i in texts if not is_only_punctuation(i)]
        return texts if split is True else text

    def frontend_sft(self, tts_text, spk_id):
        """Build model input for a built-in speaker: text tokens plus the stored speaker embedding."""
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
        """Build zero-shot model input from a prompt (text + audio) or a cached speaker id.

        When ``zero_shot_spk_id`` is empty, prompt tokens, features and embedding are
        extracted from ``prompt_wav``; otherwise a shallow copy of the cached entry is used.
        """
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        if zero_shot_spk_id == '':
            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
            speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
            if resample_rate == 24000:
                # cosyvoice2, force speech_feat % speech_token = 2
                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
            embedding = self._extract_spk_embedding(prompt_wav)
            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                           'llm_embedding': embedding, 'flow_embedding': embedding}
        else:
            # copy so callers that delete keys do not mutate the cached speaker entry
            model_input = {**self.spk2info[zero_shot_spk_id]}
        model_input['text'] = tts_text_token
        model_input['text_len'] = tts_text_token_len
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
        """Zero-shot input with the LLM prompt (prompt text and prompt speech tokens) removed."""
        model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        """Built-in-speaker input where the instruction text is used as the LLM prompt text."""
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
        """Zero-shot input with ``instruct_text`` as prompt text and no LLM prompt speech tokens."""
        model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
        """Voice-conversion input: source speech tokens plus flow prompt tokens/features/embedding from ``prompt_wav``."""
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
        embedding = self._extract_spk_embedding(prompt_wav)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                       'flow_embedding': embedding}
        return model_input


================================================
FILE: cosyvoice/cli/model.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper


class CosyVoiceModel:
    """CosyVoice (v1) inference pipeline: llm -> flow -> hift, with optional streaming.

    Each ``tts`` call is a session identified by a uuid string. Per-session state lives in
    the ``*_dict`` attributes keyed by that uuid and is mutated under ``self.lock``. Speech
    tokens are produced by a background thread (``llm_job`` or ``vc_job``) and consumed by
    ``tts``.
    """

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        """Store the sub-models and set the streaming / caching parameters.

        Args:
            llm: text -> speech-token model.
            flow: speech-token -> mel model; must expose ``input_frame_rate``.
            hift: mel -> waveform vocoder.
            fp16: run llm/flow under CUDA autocast.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        # streaming hop sizes, expressed as multiples of flow.input_frame_rate tokens
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out: token overlap converted to mel frames via the 22050 / 256 constants
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache: mel frames kept between chunks and the matching sample count (256 per frame)
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
        # the LLM runs inside its own CUDA stream when a GPU is present
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variable, keyed by session uuid
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
        # tokens whose long runs llm_job drops; empty for v1
        self.silent_tokens = []

    def load(self, llm_model, flow_model, hift_model):
        """Load state dicts (strict) for llm, flow and hift, move them to ``self.device`` and set eval mode."""
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        """Replace the llm text encoder, llm backbone and flow encoder with TorchScript modules."""
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
        """Replace the flow decoder estimator with a TensorRT engine, building the plan from ONNX if absent/empty."""
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        # nn.Module rejects assigning a non-Module over a registered submodule, so delete first
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs(self):
        """Dynamic-shape profile (min/opt/max) for the estimator inputs x, mask, mu, cond."""
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        """Producer thread: run the LLM and append speech tokens to session ``uuid``.

        Consecutive tokens from ``self.silent_tokens`` beyond ``max_silent_token_num`` are
        dropped. ``text`` may be a generator (bistream) for CosyVoice2/3 without vllm.

        ``llm_end_dict[uuid]`` is set in ``finally`` so a failing LLM cannot leave the
        streaming loop in ``tts`` polling forever. If ``tts`` has already discarded the
        session, generation stops and no state is re-created.
        """
        cur_silent_token_num, max_silent_token_num = 0, 5
        try:
            with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
                if isinstance(text, Generator):
                    assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
                    token_generator = self.llm.inference_bistream(text=text,
                                                                  prompt_text=prompt_text.to(self.device),
                                                                  prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                                  prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                                  prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                  embedding=llm_embedding.to(self.device))
                else:
                    token_generator = self.llm.inference(text=text.to(self.device),
                                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                                         prompt_text=prompt_text.to(self.device),
                                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                         embedding=llm_embedding.to(self.device),
                                                         uuid=uuid)
                for i in token_generator:
                    if i in self.silent_tokens:
                        cur_silent_token_num += 1
                        if cur_silent_token_num > max_silent_token_num:
                            continue
                    else:
                        cur_silent_token_num = 0
                    # append under the lock: while streaming, tts() replaces this list under the
                    # same lock, and an unlocked append could land on the discarded list
                    with self.lock:
                        session_tokens = self.tts_speech_token_dict.get(uuid)
                        if session_tokens is not None:
                            session_tokens.append(i)
                    if session_tokens is None:
                        # consumer already closed this session; stop generating
                        break
        finally:
            with self.lock:
                if uuid in self.tts_speech_token_dict:
                    self.llm_end_dict[uuid] = True

    def vc_job(self, source_speech_token, uuid):
        """Producer for voice conversion: the source speech tokens are the whole token stream."""
        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        """Convert a chunk of speech tokens to waveform, using and updating session ``uuid`` caches.

        For non-final chunks the trailing ``mel_overlap_len`` mel frames are held back for
        cross-fading with the next chunk, and the last ``mel_cache_len`` frames /
        ``source_cache_len`` samples are cached for hift continuity. ``speed`` != 1.0
        time-stretches the mel and is only allowed in non-stream mode.
        """
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      flow_cache=self.flow_cache_dict[uuid])

        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        """Synthesize speech, yielding ``{'tts_speech': tensor}`` dicts (CPU tensors).

        A non-empty ``source_speech_token`` selects voice conversion (``vc_job``);
        otherwise the LLM produces tokens in a background thread (``llm_job``). With
        ``stream=True`` audio is yielded in growing chunks; otherwise once at the end.

        Session state is released in ``finally``, so it is not leaked when the caller
        stops iterating early or when synthesis raises.
        """
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        try:
            if source_speech_token.shape[1] == 0:
                p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
            else:
                p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
            p.start()
            if stream is True:
                token_hop_len = self.token_min_hop_len
                while True:
                    time.sleep(0.1)
                    if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                        this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                            .unsqueeze(dim=0)
                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                         prompt_token=flow_prompt_speech_token,
                                                         prompt_feat=prompt_speech_feat,
                                                         embedding=flow_embedding,
                                                         uuid=this_uuid,
                                                         finalize=False)
                        yield {'tts_speech': this_tts_speech.cpu()}
                        with self.lock:
                            self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                        # increase token_hop_len for better speech quality
                        token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                    if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                        break
                p.join()
                # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 finalize=True)
                yield {'tts_speech': this_tts_speech.cpu()}
            else:
                # deal with all tokens
                p.join()
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 finalize=True,
                                                 speed=speed)
                yield {'tts_speech': this_tts_speech.cpu()}
        finally:
            # runs on normal completion, on exceptions and on generator close();
            # a still-running llm_job sees the missing session and stops
            with self.lock:
                self.tts_speech_token_dict.pop(this_uuid, None)
                self.llm_end_dict.pop(this_uuid, None)
                self.mel_overlap_dict.pop(this_uuid, None)
                self.hift_cache_dict.pop(this_uuid, None)
                self.flow_cache_dict.pop(this_uuid, None)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.current_stream().synchronize()


class CosyVoice2Model(CosyVoiceModel):
    """CosyVoice2 inference pipeline: speech tokens -> flow mel -> HiFT waveform.

    ``tts`` starts token generation (``llm_job`` for text input, ``vc_job`` for
    voice conversion; both inherited) in a background thread and turns the tokens
    it finds in ``tts_speech_token_dict[uuid]`` into audio with ``token2wav``,
    either in one shot or chunk by chunk when ``stream=True``. Per-request state
    lives in dicts keyed by a per-call uuid string; entries are added and removed
    under ``self.lock``.
    """

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        # NOTE must match training static_chunk_size
        # Initial streaming hop in speech tokens. tts() copies it into a per-call
        # variable, so inference never modifies this attribute.
        self.token_hop_len = 25
        # NOTE increase token_hop_len incrementally to avoid duplicate inference
        self.token_max_hop_len = 4 * self.token_hop_len
        self.stream_scale_factor = 2
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
        # hift cache: number of mel frames / source samples carried between chunks
        self.mel_cache_len = 8
        # presumably 480 waveform samples per mel frame (hift upsampling) -- TODO confirm
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out window used to cross-fade consecutive chunks
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related: dedicated CUDA stream context (no-op on CPU)
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variable
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}
        self.silent_tokens = []

    def load_jit(self, flow_encoder_model):
        """Replace the flow encoder with a TorchScript module loaded from ``flow_encoder_model``."""
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_vllm(self, model_dir):
        """Export the LLM to ``model_dir`` and serve it through a vLLM ``LLMEngine``.

        The engine is stored on ``self.llm.vllm`` with its own lock. Afterwards the
        original transformer layers are deleted (presumably to free memory now that
        vLLM owns the weights -- confirm against llm.py).
        """
        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
        from vllm import EngineArgs, LLMEngine
        engine_args = EngineArgs(model=model_dir,
                                 skip_tokenizer_init=True,
                                 enable_prompt_embeds=True,
                                 gpu_memory_utilization=0.2)
        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
        self.llm.lock = threading.Lock()
        del self.llm.llm.model.model.layers

    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        """Convert speech tokens to a waveform chunk for session ``uuid``.

        Mel frames belonging to the first ``token_offset`` tokens are dropped, since
        they were already rendered by earlier chunks. When ``finalize`` is False, the
        last ``mel_cache_len`` mel frames / ``source_cache_len`` samples are kept in
        ``hift_cache_dict[uuid]`` and withheld from the output, so the next chunk can
        be cross-faded with ``speech_window``. ``speed`` resamples the mel in time
        and is only allowed when no cache exists (non-streaming).
        """
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             streaming=stream,
                                             finalize=finalize)
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        """Synthesize speech, yielding ``{'tts_speech': tensor}`` chunks moved to CPU.

        Tokens are produced in a background thread: ``llm_job`` when
        ``source_speech_token`` is empty, ``vc_job`` otherwise. With ``stream=True``
        a chunk is yielded whenever enough new tokens exist. The hop starts at
        ``self.token_hop_len`` for every call, is multiplied by
        ``stream_scale_factor`` after each chunk, and is capped at
        ``token_max_hop_len``. A final chunk follows once generation ends.
        Otherwise a single chunk is yielded after generation finishes. ``speed``
        only applies in non-streaming mode.
        """
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_offset = 0
            # Per-call hop length. Growing it locally keeps one stream's schedule
            # from leaking into later or concurrent requests on this model.
            token_hop_len = self.token_hop_len
            # pad the prompt so the first chunk ends on a token_hop_len boundary
            prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / token_hop_len) * token_hop_len - flow_prompt_speech_token.shape[1])
            while True:
                time.sleep(0.1)
                this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     token_offset=token_offset,
                                                     uuid=this_uuid,
                                                     stream=stream,
                                                     finalize=False)
                    token_offset += this_token_hop_len
                    token_hop_len = min(self.token_max_hop_len, token_hop_len * self.stream_scale_factor)
                    yield {'tts_speech': this_tts_speech.cpu()}
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=token_offset,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=0,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()


class CosyVoice3Model(CosyVoice2Model):
    """CosyVoice3 inference wrapper.

    Reuses the ``tts`` driver inherited from CosyVoice2Model and overrides
    ``token2wav``. Instead of an overlap/fade cache, the full mel produced so far
    for a session is accumulated in ``hift_cache_dict[uuid]['mel']``. The vocoder
    is rerun on that whole mel, and only waveform samples past the stored
    ``speech_offset`` are returned.
    """

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        # The parent __init__ is not called; all state is initialised here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm, self.flow, self.hift = llm, flow, hift
        self.fp16 = fp16

        # Streaming hop schedule, in speech tokens.
        # NOTE must match training static_chunk_size
        self.token_hop_len = 25
        # NOTE increase token_hop_len incrementally to avoid duplicate inference
        self.token_max_hop_len = self.token_hop_len * 4
        self.stream_scale_factor = 2
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'

        # Dedicated CUDA stream context when a GPU is present, no-op context otherwise.
        if torch.cuda.is_available():
            self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device))
        else:
            self.llm_context = nullcontext()
        self.lock = threading.Lock()

        # Per-session state, keyed by request uuid.
        self.tts_speech_token_dict, self.llm_end_dict, self.hift_cache_dict = {}, {}, {}

        # FSQ silent and breath token
        self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]

    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        """Render the waveform for the newly available tokens of session ``uuid``.

        The new mel frames (those after ``token_offset`` tokens) are appended to
        the session's accumulated mel, and the vocoder runs on the whole
        accumulation. Only samples not yet returned by earlier calls are
        returned. ``speed`` resamples the mel in time and is only permitted for a
        single finalized call with ``token_offset == 0``.
        """
        device = self.device
        with torch.cuda.amp.autocast(self.fp16):
            mel, _ = self.flow.inference(token=token.to(device, dtype=torch.int32),
                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(device),
                                         prompt_token=prompt_token.to(device),
                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(device),
                                         prompt_feat=prompt_feat.to(device),
                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(device),
                                         embedding=embedding.to(device),
                                         streaming=stream,
                                         finalize=finalize)
            # drop frames already covered by earlier chunks
            mel = mel[:, :, token_offset * self.flow.token_mel_ratio:]
            cache = self.hift_cache_dict[uuid]
            if cache is None:
                cache = {'mel': mel, 'speech_offset': 0}
                self.hift_cache_dict[uuid] = cache
            else:
                mel = torch.concat([cache['mel'], mel], dim=2)
                cache['mel'] = mel
            if speed != 1.0:
                assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
                mel = F.interpolate(mel, size=int(mel.shape[2] / speed), mode='linear')
            speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)
            # emit only the samples not returned before, then advance the offset
            speech = speech[:, cache['speech_offset']:]
            cache['speech_offset'] += speech.shape[1]
        return speech


================================================
FILE: cosyvoice/dataset/__init__.py
================================================


================================================
FILE: cosyvoice/dataset/dataset.py
================================================
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import math
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from cosyvoice.utils.file_utils import read_lists


class Processor(IterableDataset):
    """ Lazily apply a generator-style function to an iterable source.

        Iterating the processor calls ``f(iter(source), *args, **kw)`` and
        returns whatever iterator that produces.
    """

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source, self.f, self.args, self.kw = source, f, args, kw

    def set_epoch(self, epoch):
        """ Forward the epoch number to the wrapped source. """
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        upstream = iter(self.source)
        return self.f(upstream, *self.args, **self.kw)

    def apply(self, f):
        """ Chain ``f`` after this processor, reusing this processor's args/kw. """
        assert callable(f)
        chained_args, chained_kw = self.args, self.kw
        return Processor(self, f, *chained_args, **chained_kw)


class DistributedSampler:
    """ Split a list of items across distributed ranks and dataloader workers.

        Works on indices: ``sample`` returns the positions in ``data`` that the
        current (rank, worker) pair should read.
    """

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        """ Refresh rank/world_size (torch.distributed) and worker_id/num_workers
            (torch DataLoader worker info), falling back to a single process.

            Returns:
                dict: rank, world_size, worker_id and num_workers
        """
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        """ Set the epoch used to seed the per-epoch shuffle. """
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            If there are fewer items than ranks (or workers), the index list is
            repeated so every rank/worker gets at least one item.

            Args:
                data(List): input data list

            Returns:
                List[int]: indices into ``data`` for this rank/worker
                (empty when ``data`` is empty)
        """
        indexes = list(range(len(data)))
        # Nothing to distribute; also avoids dividing by zero below.
        if not indexes:
            return indexes
        # force datalist even
        if self.partition:
            if self.shuffle:
                # seeded by epoch so every rank produces the same permutation
                random.Random(self.epoch).shuffle(indexes)
            if len(indexes) < self.world_size:
                indexes = indexes * math.ceil(self.world_size / len(indexes))
                indexes = indexes[:self.world_size]
            indexes = indexes[self.rank::self.world_size]
        if len(indexes) < self.num_workers:
            indexes = indexes * math.ceil(self.num_workers / len(indexes))
            indexes = indexes[:self.num_workers]
        indexes = indexes[self.worker_id::self.num_workers]
        return indexes


class DataList(IterableDataset):
    """ Iterable over the entries of ``lists`` assigned to the current rank/worker.

        Each yielded item is a dict with the entry under 'src' plus the sampler's
        rank/world_size/worker_id/num_workers.
    """

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        """ Forward the epoch to the sampler (affects shuffling). """
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        info = self.sampler.update()
        for idx in self.sampler.sample(self.lists):
            item = {'src': self.lists[idx]}
            item.update(info)
            yield item


def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            dpo=False,
            shuffle=True,
            partition=True):
    """ Construct dataset from arguments

        Reads the source list from ``data_list_file`` into a DataList and chains
        every stage of ``data_pipeline`` on top of it with Processor. There are
        two shuffle stages: the first shuffles the source list (per epoch) in
        DataList, and the second is whatever shuffle stage the pipeline contains
        at the sample level.

        Args:
            data_list_file(str): file listing data sources, read with read_lists
            data_pipeline(List[Callable]): processing stages; every entry after
                the first is expected to be a functools.partial
            mode(str): passed as the ``mode`` keyword to every stage
            gan(bool): disable token-frame alignment in compute_fbank and pass
                ``gan`` to padding
            dpo(bool): passed as ``dpo`` to padding
            shuffle(bool): whether to shuffle the source list
            partition(bool): whether to do data partition in terms of rank

        Returns:
            Processor: the last stage of the chained pipeline
    """
    lists = read_lists(data_list_file)
    dataset = DataList(lists,
                       shuffle=shuffle,
                       partition=partition)
    # Work on a copy so the caller's pipeline list (which may be shared, e.g.
    # between train and cv datasets) is not rewritten in place.
    data_pipeline = list(data_pipeline)
    # map partial arg to padding func
    for i in range(1, len(data_pipeline)):
        func_name = data_pipeline[i].func.__name__
        if func_name == 'compute_fbank' and gan is True:
            # compute_fbank takes num_frames (-1 disables aligning speech to
            # token frames); it has no token_mel_ratio parameter.
            data_pipeline[i] = partial(data_pipeline[i], num_frames=-1)
        if func_name == 'padding':
            data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset


================================================
FILE: cosyvoice/dataset/processor.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random

import pyarrow.parquet as pq
from io import BytesIO
import numpy as np
import whisper
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
from cosyvoice.utils.onnx import embedding_extractor, online_feature

# Audio file extensions (lower-case, without the leading dot).
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train'):
    """ Expand each source entry into one sample per parquet row

        For every item with a 'src' path/url, the parquet file is read in
        batches of 64 rows, and each row's columns are merged into the item. A
        fresh shallow copy is yielded per row. Files that fail are logged and
        skipped.

        Args:
            data(Iterable[dict]): items with a 'src' key

        Returns:
            Iterable[dict]: one dict per parquet row
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            parquet = pq.ParquetFile(url)
            for record_batch in parquet.iter_batches(batch_size=64):
                frame = record_batch.to_pandas()
                for row in range(len(frame)):
                    sample.update(dict(frame.loc[row]))
                    # NOTE do not return sample directly, must initialize a new dict
                    yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Decode audio and filter samples by speech duration and text-token length
        Inplace operation.

        Each sample's 'audio_data' bytes are decoded with torchaudio.load,
        averaged over dim 0 into 'speech' (keepdim, so 1 x T) and stored with
        'sample_rate'; 'audio_data' is then deleted.

        Args::
            data: Iterable[dict] with 'audio_data' (encoded audio bytes),
                'text_token' and, unless features are extracted online,
                'speech_token' (optionally 'reject_speech_token')
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance whose text token count is greater
                than token_max_length, especially when use char unit for
                english modeling
            token_min_length: drop utterance whose text token count is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                text_token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                text_token_length / feats_length(10ms)

        Returns:
            Iterable[dict]: surviving samples with 'speech' and 'sample_rate'
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length or num_frames > max_length:
            continue
        text_len = len(sample['text_token'])
        if text_len < token_min_length or text_len > token_max_length:
            continue
        if online_feature is False:
            # offline features must come with non-empty speech tokens
            if len(sample['speech_token']) == 0:
                continue
            if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
                continue
        if num_frames != 0:
            ratio = text_len / num_frames
            if ratio < min_output_input_ratio or ratio > max_output_input_ratio:
                continue
        yield sample


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Bring every sample to ``resample_rate`` and peak-normalize if clipping.
        Inplace operation.

        Samples whose rate differs from ``resample_rate`` and is below
        ``min_sample_rate`` are dropped. After resampling, speech whose absolute
        peak exceeds 1 is divided by that peak.

        Args:
            data: Iterable[dict] with 'sample_rate' and 'speech'
            resample_rate: target resample rate
            min_sample_rate: lowest source rate accepted for resampling

        Returns:
            Iterable[dict]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        orig_rate = sample['sample_rate']
        speech = sample['speech']
        needs_resample = orig_rate != resample_rate
        if needs_resample and orig_rate < min_sample_rate:
            continue
        if needs_resample:
            sample['sample_rate'] = resample_rate
            resampler = torchaudio.transforms.Resample(orig_freq=orig_rate, new_freq=resample_rate)
            sample['speech'] = resampler(speech)
        peak = sample['speech'].abs().max()
        if peak > 1:
            sample['speech'] /= peak
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Force every waveform to exactly ``truncate_length`` samples.

        Longer waveforms are cropped to a random window; shorter ones are
        right-padded with zeros.

        Args:
            data: Iterable[dict] with 'speech' (assumed 1 x T -- TODO confirm)
            truncate_length: target length in samples

        Returns:
            Iterable[dict]
    """
    for sample in data:
        speech = sample['speech']
        length = speech.shape[1]
        if length <= truncate_length:
            pad = torch.zeros(1, truncate_length - length)
            sample['speech'] = torch.concat([speech, pad], dim=1)
        else:
            begin = random.randint(0, length - truncate_length)
            sample['speech'] = speech[:, begin: begin + truncate_length]
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  num_frames=-1,
                  mode='train'):
    """ Extract fbank features with ``feat_extractor``

        When ``num_frames`` is not -1, speech is first right-padded with zeros
        to a multiple of ``num_frames`` samples.

        Args:
            data: Iterable[dict] with 'sample_rate', 'speech', 'utt', 'text_token'
            feat_extractor: callable producing features from speech
            num_frames: alignment unit in samples, or -1 to disable padding

        Returns:
            Iterable[dict] with 'speech_feat' (squeezed at dim 0, then transposed)
    """
    align = num_frames != -1
    for sample in data:
        for key in ('sample_rate', 'speech', 'utt', 'text_token'):
            assert key in sample
        # NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
        if align:
            length = sample['speech'].shape[1]
            target = int(np.ceil(length / num_frames)) * num_frames
            sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, target - length)], dim=1)
        feat = feat_extractor(sample['speech'])
        sample['speech_feat'] = feat.squeeze(dim=0).transpose(0, 1)
        yield sample


def compute_whisper_fbank(data, num_frames=-1, mode='train'):
    """ Compute 128-bin whisper log-mel features from 16 kHz audio

        Args:
            data: Iterable[dict] with 'speech' and 'sample_rate'
            num_frames: if not -1, speech length must be a multiple of it

        Returns:
            Iterable[dict] with 'speech_16k' and 'whisper_feat' added
    """
    check_alignment = num_frames != -1
    for sample in data:
        if check_alignment:
            assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
        to_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)
        speech_16k = to_16k(sample['speech'])
        sample['speech_16k'] = speech_16k
        mel = whisper.log_mel_spectrogram(speech_16k, n_mels=128)
        sample['whisper_feat'] = mel.squeeze(dim=0).transpose(0, 1)
        yield sample


def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0 with pyworld and resample it to the speech_feat frame count

        Harvest is tried first; if it finds fewer than 5 voiced (non-zero)
        frames, dio is used instead. The result is refined with stonemask and
        linearly interpolated to ``speech_feat.shape[0]`` frames.

        Args:
            data: Iterable[dict] with 'speech' (squeezed at dim 0 to 1-D --
                assumed 1 x T), 'speech_feat', 'sample_rate', 'utt', 'text_token'
            sample_rate: rate passed to pyworld. NOTE this argument, not
                sample['sample_rate'], is what gets used.
            hop_size: hop in samples, converted to a frame period in ms

        Returns:
            Iterable[dict] with 'pitch_feat' (1-D float64 tensor)
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        # pyworld expects float64 numpy input; convert once and reuse for every call
        wav = sample['speech'].squeeze(dim=0).numpy().astype('double')
        _f0, t = pw.harvest(wav, sample_rate, frame_period=frame_period)
        if np.count_nonzero(_f0) < 5:  # this happens when the algorithm fails
            _f0, t = pw.dio(wav, sample_rate, frame_period=frame_period)  # if harvest fails, try dio
        f0 = pw.stonemask(wav, _f0, t, sample_rate)
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Provide float32 utt_embedding/spk_embedding tensors for every sample

        If neither key is present, one embedding is extracted from 16 kHz speech
        and used for both. Otherwise both stored values are converted to float32
        tensors. With ``normalize``, each is L2-normalized along dim 0.

        Args:
            data: Iterable[dict]
            normalize: whether to L2-normalize the embeddings

        Returns:
            Iterable[dict]
    """
    keys = ('utt_embedding', 'spk_embedding')
    for sample in data:
        has_precomputed = 'utt_embedding' in sample or 'spk_embedding' in sample
        if has_precomputed:
            for key in keys:
                sample[key] = torch.tensor(sample[key], dtype=torch.float32)
        else:
            sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
            embedding = embedding_extractor.inference(sample['speech_16k'])
            sample['spk_embedding'] = sample['utt_embedding'] = embedding
        if normalize:
            for key in keys:
                sample[key] = F.normalize(sample[key], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Encode 'text' (and 'instruct' when present) into token ids
        Inplace operation

        The tokenizer is created by calling ``get_tokenizer()`` once, when
        iteration starts.

        Args:
            data: Iterable[dict] with 'text' and optionally 'instruct'

        Returns:
            Iterable[dict] with 'text_token' (and 'instruct_token') added
    """
    tokenizer = get_tokenizer()

    def _encode(text):
        return tokenizer.encode(text, allowed_special=allowed_special)

    for sample in data:
        assert 'text' in sample
        sample['text_token'] = _encode(sample['text'])
        if 'instruct' in sample:
            sample['instruct_token'] = _encode(sample['instruct'])
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Samples accumulate in a buffer. Whenever it holds ``shuffle_size``
        items, it is shuffled and the first half (at least one item) is
        emitted. The remainder is shuffled and emitted at the end.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    # Flush at least one sample per full buffer. With shuffle_size < 2,
    # int(shuffle_size / 2) is 0, which would never shrink the buffer and
    # would reshuffle the entire stream on every new sample.
    yield_size = max(1, int(shuffle_size / 2))
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf[:yield_size]:
                yield x
            buf = buf[yield_size:]
    # The sample left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Samples are buffered in groups of ``sort_size``. Each group (and the
        final partial group) is emitted in ascending ``speech_feat.size(0)``
        order; the sort is stable.

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    def _by_frames(sample):
        return sample['speech_feat'].size(0)

    pending = []
    for sample in data:
        pending.append(sample)
        if len(pending) < sort_size:
            continue
        pending.sort(key=_by_frames)
        yield from pending
        pending = []
    # flush the final partial group
    pending.sort(key=_by_frames)
    yield from pending


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Emits lists of exactly ``batch_size`` samples, plus one shorter
        trailing list if samples remain.

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    current = []
    for sample in data:
        current.append(sample)
        if len(current) < batch_size:
            continue
        yield current
        current = []
    if current:
        yield current


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        The padded size of a batch is (longest speech_feat length) x (batch
        size). A batch is emitted before adding a sample would push that over
        ``max_frames_in_batch``. A single sample longer than the limit forms a
        batch on its own; empty batches are never emitted.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        # Only flush a non-empty buffer: an oversized first sample would
        # otherwise emit an empty batch.
        if buf and frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch

        Args:
            data: Iterable of samples
            batch_type: 'static' (fixed batch_size) or 'dynamic'
                (bounded by max_frames_in_batch)
            batch_size: samples per batch for 'static'
            max_frames_in_batch: padded frame budget for 'dynamic'

        Returns:
            Iterable[List[sample]]

        Raises:
            ValueError: if batch_type is not 'static' or 'dynamic'
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        # logging.fatal only logs; raise so the pipeline does not silently get None
        logging.fatal('Unsupported batch type {}'.format(batch_type))
        raise ValueError('Unsupported batch type {}'.format(batch_type))


def _pad_field(batch, name, seqs):
    """Zero-pad `seqs` into `batch[name]` and store their lengths in `batch[name + '_len']`."""
    batch[name + '_len'] = torch.tensor([s.size(0) for s in seqs], dtype=torch.int32)
    batch[name] = pad_sequence(seqs, batch_first=True, padding_value=0)


def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    """ Padding the data into training data

        Samples in each batch are reordered by descending `speech.size(1)`,
        and every sequence field is zero-padded along its first dimension
        (its lengths are stored under the same key with a `_len` suffix).

        Args:
            data: Iterable[List[dict]], batches of samples
            use_spk_embedding: if True, `embedding` is the stacked
                `spk_embedding`; otherwise the stacked `utt_embedding`
            mode: unused
            gan: also pad `speech` (dim 0 squeezed) and `pitch_feat`
            dpo: also pad `reject_speech_token`

        Returns:
            Iterable[dict]: one dict per batch with keys such as `utts`,
            `text`, `text_token`/`text_token_len`,
            `speech_feat`/`speech_feat_len`, `utt_embedding`,
            `spk_embedding`, `embedding` and the optional padded fields.
    """
    for sample in data:
        assert isinstance(sample, list)
        order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
        batch = {}
        batch['utts'] = [sample[i]['utt'] for i in order]
        batch['text'] = [sample[i]['text'] for i in order]
        _pad_field(batch, 'text_token', [torch.tensor(sample[i]['text_token']) for i in order])
        _pad_field(batch, 'speech_feat', [sample[i]['speech_feat'] for i in order])
        batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        # Optional fields are only batched when every sample in the batch carries them.
        if all('instruct_token' in sample[i] for i in order):
            _pad_field(batch, 'instruct_token', [torch.tensor(sample[i]['instruct_token']) for i in order])
        if all('whisper_feat' in sample[i] for i in order):
            _pad_field(batch, 'whisper_feat', [sample[i]['whisper_feat'] for i in order])
        if all('speech_token' in sample[i] for i in order):
            _pad_field(batch, 'speech_token', [torch.tensor(sample[i]['speech_token']) for i in order])
        if gan is True:
            # in gan train, we need speech/pitch_feat
            _pad_field(batch, 'speech', [sample[i]['speech'].squeeze(dim=0) for i in order])
            _pad_field(batch, 'pitch_feat', [sample[i]['pitch_feat'] for i in order])
        if dpo is True:
            _pad_field(batch, 'reject_speech_token', [torch.tensor(sample[i]['reject_speech_token']) for i in order])
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch


================================================
FILE: cosyvoice/flow/DiT/dit.py
================================================

"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from x_transformers.x_transformers import RotaryEmbedding
from cosyvoice.utils.mask import add_optional_chunk_mask
from cosyvoice.flow.DiT.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    CausalConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis,
    get_pos_embed_indices,
)


# Text embedding


class TextEmbedding(nn.Module):
    """Embed text tokens and align them to a target sequence length.

    Token ids are shifted by +1 so id 0 is reserved as the filler token; the
    sequence is then truncated or right-padded with 0 to exactly `seq_len`.
    With `conv_layers > 0`, a precomputed sinusoidal position table is added
    and the result is refined by ConvNeXtV2 blocks.
    """

    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        """Return text embeddings of shape (batch, seq_len, text_dim).

        Args:
            text: integer token ids, (batch, nt); a padding id of -1 becomes the filler 0 after the shift.
            seq_len: target length; longer text is truncated, shorter text is padded with filler.
            drop_text: replace every token with the filler (classifier-free guidance dropout).
        """
        text = text + 1  # use 0 as filler token
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        # Measure the length *after* truncation: using the original length would give a
        # negative pad when text is longer than seq_len, and F.pad would crop further.
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb; build the indices on the same device as the embeddings
            batch_start = torch.zeros((batch,), dtype=torch.long, device=text.device)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding


class InputEmbedding(nn.Module):
    """Fuse noised input, conditioning, text/mu features and an optional speaker embedding.

    The inputs are concatenated on the channel axis, projected to `out_dim`,
    and a causal convolutional position embedding is added as a residual.
    """

    def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
        super().__init__()
        self.spk_dim = spk_dim if spk_dim is not None else 0
        self.proj = nn.Linear(mel_dim * 2 + text_dim + self.spk_dim, out_dim)
        self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)

    def forward(
            self,
            x: float["b n d"],  # noqa: F722
            cond: float["b n d"],  # noqa: F722
            text_embed: float["b n d"],  # noqa: F722
            spks: float["b d"],  # noqa: F722
    ):
        features = [x, cond, text_embed]
        if self.spk_dim > 0:
            # broadcast the per-utterance speaker vector over every time step
            features.append(repeat(spks, "b c -> b t c", t=x.shape[1]))

        hidden = self.proj(torch.cat(features, dim=-1))
        return hidden + self.conv_pos_embed(hidden)


# Transformer backbone using DiT blocks


class DiT(nn.Module):
    """Diffusion Transformer (DiT) estimator built from DiT blocks.

    Channel-first inputs are transposed to (batch, seq, channels), fused by
    `InputEmbedding`, passed through `depth` DiT blocks with rotary position
    embedding and timestep modulation, and projected back to `mel_dim` channels.

    Args:
        dim: model width.
        depth: number of DiT blocks.
        heads: attention heads per block.
        dim_head: width of each head (also the rotary embedding width).
        dropout: dropout inside the DiT blocks.
        ff_mult: feed-forward expansion factor.
        mel_dim: channels of the noised input / conditioning and of the output.
        mu_dim: channels of `mu`; defaults to `mel_dim`.
        long_skip_connection: if True, concatenate the stack's input and output
            and project back to `dim`.
        spk_dim: speaker-embedding width given to `InputEmbedding` (None disables it).
        out_channels: stored on the module; not used by `forward`.
        static_chunk_size: attention chunk size used when `streaming=True`.
        num_decoding_left_chunks: stored on the module; `forward` passes -1.
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=80,
        mu_dim=None,
        long_skip_connection=False,
        spk_dim=None,
        out_channels=None,
        static_chunk_size=50,
        num_decoding_left_chunks=2
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if mu_dim is None:
            mu_dim = mel_dim
        self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)
        self.out_channels = out_channels
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Run the estimator.

        Args:
            x: noised input, channel-first (transposed to (b, n, mel_dim) here).
            mask: padding mask handed to `add_optional_chunk_mask`; presumably
                (b, 1, n) — TODO confirm against callers.
            mu: text/token features, channel-first like `x`.
            t: timestep, scalar (broadcast to the batch) or (b,).
            spks: speaker embedding passed to `InputEmbedding`; may be None when
                the model was built without `spk_dim`.
            cond: conditioning features, channel-first like `x` (required).
            streaming: restrict attention to chunks of `static_chunk_size`.

        Returns:
            Output with `mel_dim` channels, channel-first.
        """
        x = x.transpose(1, 2)
        mu = mu.transpose(1, 2)
        cond = cond.transpose(1, 2)
        batch, seq_len = x.shape[0], x.shape[1]
        if t.ndim == 0:
            t = t.repeat(batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(t)
        # The former spks.unsqueeze(1)...squeeze(1) round-trip was an identity; passing spks
        # directly keeps the same values and lets spks=None work when spk_dim is disabled.
        x = self.input_embed(x, cond, mu, spks)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        if streaming is True:
            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
        else:
            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
        # Convert once instead of once per block.
        attn_mask = attn_mask.bool()

        for block in self.transformer_blocks:
            x = block(x, t, mask=attn_mask, rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x).transpose(1, 2)
        return output


================================================
FILE: cosyvoice/flow/DiT/modules.py
================================================

"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Optional
import math

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio

from x_transformers.x_transformers import apply_rotary_pos_emb


# raw wav to mel spec
class MelSpec(nn.Module):
    """Log-magnitude mel spectrogram computed with torchaudio's MelSpectrogram.

    `forward` accepts (b, nw) or (b, 1, nw) waveforms, moves the module to the
    input's device if needed, and returns log(max(mel, 1e-5)).
    """

    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        normalize=False,
        power=1,
        norm=None,
        center=True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels

        stft_kwargs = dict(
            sample_rate=target_sample_rate,
            n_fft=filter_length,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mel_channels,
            power=power,
            center=center,
            normalized=normalize,
            norm=norm,
        )
        self.mel_stft = torchaudio.transforms.MelSpectrogram(**stft_kwargs)

        # non-persistent scalar buffer used only to detect which device the module lives on
        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, inp):
        if inp.dim() == 3:
            inp = inp.squeeze(1)  # 'b 1 nw -> b nw'

        assert inp.dim() == 2

        if inp.device != self.dummy.device:
            self.to(inp.device)

        return self.mel_stft(inp).clamp(min=1e-5).log()


# sinusoidal position embedding


class SinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of (possibly fractional) positions.

    Returns [sin(scale * x * f), cos(scale * x * f)] concatenated on the last
    axis, with `dim // 2` geometrically spaced frequencies f.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        inv_freq = torch.exp(torch.arange(half, device=x.device).float() * -log_step)
        angles = scale * x[:, None] * inv_freq[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


# convolutional position embedding


class ConvPositionEmbedding(nn.Module):
    """Two same-length grouped Conv1d + Mish layers over (b, n, d) features.

    Positions where `mask` is False are zeroed before and after the convolutions.
    """

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        same_pad = kernel_size // 2
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=same_pad),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=same_pad),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        keep = None if mask is None else mask.unsqueeze(-1)
        if keep is not None:
            x = x.masked_fill(~keep, 0.0)

        # Conv1d expects channels first: b n d -> b d n -> b n d
        out = self.conv1d(x.transpose(1, 2)).transpose(1, 2)

        if keep is not None:
            out = out.masked_fill(~keep, 0.0)
        return out


class CausalConvPositionEmbedding(nn.Module):
    """Causal variant of the convolutional position embedding.

    Each of the two grouped Conv1d + Mish stages is left-padded by
    `kernel_size - 1`, so output step i depends only on inputs <= i.
    Positions where `mask` is False are zeroed before and after.
    """

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.kernel_size = kernel_size
        self.conv1 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        keep = None if mask is None else mask.unsqueeze(-1)
        if keep is not None:
            x = x.masked_fill(~keep, 0.0)

        left_pad = self.kernel_size - 1
        hidden = x.transpose(1, 2)  # b n d -> b d n
        for stage in (self.conv1, self.conv2):
            hidden = stage(F.pad(hidden, (left_pad, 0)))
        out = hidden.transpose(1, 2)  # b d n -> b n d

        if keep is not None:
            out = out.masked_fill(~keep, 0.0)
        return out


# rotary positional embedding related


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    """Precompute a (end, dim) table of rotary/sinusoidal angles as [cos | sin].

    `theta_rescale_factor` applies NTK-aware scaling of the base (proposed by reddit
    user bloc97) to stretch rotary embeddings to longer sequences without fine-tuning:
    https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    """
    theta = theta * theta_rescale_factor ** (dim / (dim - 2))
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # first half: real part (cos), second half: imaginary part (sin)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)


def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    """Build per-row position indices start[b] + floor(arange(length) * scale[b]).

    `scale` may be a scalar or a per-row tensor. Indices at or beyond `max_pos`
    are replaced by `max_pos - 1` to avoid out-of-range lookups.
    """
    per_row_scale = torch.ones_like(start, dtype=torch.float32) * scale  # in case scale is a scalar
    steps = torch.arange(length, device=start.device, dtype=torch.float32)
    offsets = (steps[None, :] * per_row_scale[:, None]).long()
    positions = start[:, None] + offsets
    # avoid extra long error.
    return torch.where(positions < max_pos, positions, max_pos - 1)


# Global Response Normalization (GRN) layer, as used in ConvNeXt-V2


class GRN(nn.Module):
    """Global Response Normalization over (b, n, d) inputs.

    The L2 norm is taken over dim 1, divided by its mean over the last dim, and
    used to rescale x; the zero-initialised gamma/beta make the layer start as identity.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        global_norm = x.norm(p=2, dim=1, keepdim=True)
        relative = global_norm / (global_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative) + self.beta + x


# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108


class ConvNeXtV2Block(nn.Module):
    """Residual ConvNeXt-V2 block over (b, n, d) sequences.

    Depthwise conv (kernel 7) -> LayerNorm -> pointwise expand -> GELU -> GRN ->
    pointwise project, added back onto the input.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        # "same" padding for a kernel of 7 with the given dilation
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3 * dilation, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # depthwise conv runs channels-first: b n d -> b d n -> b n d
        branch = self.dwconv(x.transpose(1, 2)).transpose(1, 2)
        branch = self.pwconv1(self.norm(branch))
        branch = self.pwconv2(self.grn(self.act(branch)))
        return x + branch


# AdaLayerNormZero
# returns the modulated x for the attention input, plus the gate/shift/scale parameters used later in the block


class AdaLayerNormZero(nn.Module):
    """Adaptive LayerNorm conditioned on an embedding `emb` of width `dim`.

    `emb` is mapped to six `dim`-wide chunks; the first two (shift, scale) modulate the
    normalised x, and the remaining four are returned in this order:
    (x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp).
    """

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        params = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=1)

        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp


# AdaLayerNormZero for the final layer
# returns only the modulated x, because no MLP modulation follows


class AdaLayerNormZero_Final(nn.Module):
    """Final-layer adaptive LayerNorm: `emb` yields (scale, shift) and only the modulated x is returned."""

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        params = self.linear(self.silu(emb))
        scale, shift = params.chunk(2, dim=1)  # note: scale comes first here
        return self.norm(x) * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)


# FeedForward


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    The nested `nn.Sequential` layout (ff.0.0 / ff.2) is kept so checkpoint keys match.
    """

    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        hidden = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        self.ff = nn.Sequential(
            nn.Sequential(nn.Linear(dim, hidden), nn.GELU(approximate=approximate)),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.ff(x)


# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py


class Attention(nn.Module):
    """Multi-head attention whose computation is delegated to a processor object.

    Holds the q/k/v/output projections; `processor(self, x, ...)` performs the
    attention. When `context_dim` is given, key/value projections for a context
    stream are created too; `context_pre_only` not None also adds a context
    query projection, and `context_pre_only == False` adds a context output projection.
    """

    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        # [output projection, dropout]
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        """Dispatch to the processor; context arguments are only forwarded when `c` is given."""
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)


# Attention processor


class AttnProcessor:
    """Plain self-attention processor for `Attention` modules.

    Projects x to q/k/v, optionally applies rotary embedding, runs scaled dot-product
    attention, applies the output projection + dropout, and zeroes masked rows.
    A 2-D mask (b, n) is treated as a key-padding mask; a higher-rank mask is passed to
    attention as-is, and its `[:, 0, -1]` row is used as the per-position output mask.
    """

    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        bsz = x.shape[0]
        q, k, v = attn.to_q(x), attn.to_k(x), attn.to_v(x)

        if rope is not None:
            freqs, xpos_scale = rope
            if xpos_scale is None:
                q_scale, k_scale = 1.0, 1.0
            else:
                q_scale, k_scale = xpos_scale, xpos_scale**-1.0
            q = apply_rotary_pos_emb(q, freqs, q_scale)
            k = apply_rotary_pos_emb(k, freqs, k_scale)

        head_dim = k.shape[-1] // attn.heads

        def _split_heads(t):
            # b n (h d) -> b h n d
            return t.view(bsz, -1, attn.heads, head_dim).transpose(1, 2)

        q, k, v = _split_heads(q), _split_heads(k), _split_heads(v)

        attn_mask = mask
        if mask is not None and mask.dim() == 2:
            # 'b n -> b 1 1 n', broadcast over heads and queries
            attn_mask = mask[:, None, None, :].expand(bsz, attn.heads, q.shape[-2], k.shape[-2])

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        out = out.transpose(1, 2).reshape(bsz, -1, attn.heads * head_dim).to(q.dtype)

        proj, drop = attn.to_out[0], attn.to_out[1]
        out = drop(proj(out))

        if mask is not None:
            row_mask = mask.unsqueeze(-1) if mask.dim() == 2 else mask[:, 0, -1].unsqueeze(-1)
            out = out.masked_fill(~row_mask, 0.0)

        return out


# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py


class JointAttnProcessor:
    """Joint (MM-DiT style) attention over the concatenation of x and a context c.

    Rotary embeddings are applied to each stream independently. The key mask for x is
    extended with True for every context position, so the context is never masked.
    Returns the x part (output projection + dropout, masked rows zeroed) and the
    c part (projected by `to_out_c` unless `attn.context_pre_only` is truthy).
    """

    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b nt d"] = None,  # context c, here text # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        n_x = x.shape[1]
        bsz = c.shape[0]

        def _rotate(q, k, rope_pair):
            if rope_pair is None:
                return q, k
            freqs, xpos_scale = rope_pair
            if xpos_scale is None:
                q_scale, k_scale = 1.0, 1.0
            else:
                q_scale, k_scale = xpos_scale, xpos_scale**-1.0
            return apply_rotary_pos_emb(q, freqs, q_scale), apply_rotary_pos_emb(k, freqs, k_scale)

        # per-stream projections, rotated independently
        q_x, k_x = _rotate(attn.to_q(x), attn.to_k(x), rope)
        v_x = attn.to_v(x)
        q_c, k_c = _rotate(attn.to_q_c(c), attn.to_k_c(c), c_rope)
        v_c = attn.to_v_c(c)

        head_dim = k_x.shape[-1] // attn.heads

        def _split_heads(t):
            # b n (h d) -> b h n d
            return t.view(bsz, -1, attn.heads, head_dim).transpose(1, 2)

        q = _split_heads(torch.cat([q_x, q_c], dim=1))
        k = _split_heads(torch.cat([k_x, k_c], dim=1))
        v = _split_heads(torch.cat([v_x, v_c], dim=1))

        attn_mask = None
        if mask is not None:
            key_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = key_mask[:, None, None, :].expand(bsz, attn.heads, q.shape[-2], k.shape[-2])

        joint = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        joint = joint.transpose(1, 2).reshape(bsz, -1, attn.heads * head_dim).to(q.dtype)

        x_out, c_out = joint[:, :n_x], joint[:, n_x:]

        x_out = attn.to_out[1](attn.to_out[0](x_out))
        if not attn.context_pre_only:
            c_out = attn.to_out_c(c_out)

        if mask is not None:
            x_out = x_out.masked_fill(~mask.unsqueeze(-1), 0.0)

        return x_out, c_out


# DiT Block


class DiTBlock(nn.Module):
    """DiT block: adaptive-norm self-attention and feed-forward, each as a gated residual.

    The time embedding `t` supplies (via AdaLayerNormZero) the attention-input modulation,
    the attention gate, and the shift/scale/gate for the feed-forward branch.
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        attn_in, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # gated residual around attention
        x = x + gate_msa[:, None] * self.attn(x=attn_in, mask=mask, rope=rope)

        # gated residual around the modulated feed-forward
        mlp_in = self.ff_norm(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        return x + gate_mlp[:, None] * self.ff(mlp_in)


# MMDiT Block https://arxiv.org/abs/2403.03206


class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py

    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: last layer only does prenorm + modulation because no ffn follows

    `forward` returns (c, x); c is None when `context_pre_only` is set.
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=dim,
            context_pre_only=context_pre_only,
        )

        if context_pre_only:
            self.ff_norm_c = None
            self.ff_c = None
        else:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    @staticmethod
    def _gated_update(h, attn_out, gate_msa, shift_mlp, scale_mlp, gate_mlp, ff_norm, ff):
        # gated attention residual followed by a modulated, gated feed-forward residual
        h = h + gate_msa.unsqueeze(1) * attn_out
        ff_in = ff_norm(h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        return h + gate_mlp.unsqueeze(1) * ff(ff_in)

    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        x_attn, c_attn = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # the context stream is dropped after the last (pre-only) layer
        if self.context_pre_only:
            c = None
        else:
            c = self._gated_update(c, c_attn, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp,
                                   self.ff_norm_c, self.ff_c)

        x = self._gated_update(x, x_attn, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp,
                               self.ff_norm_x, self.ff_x)
        return c, x


# time step conditioning embedding


class TimestepEmbedding(nn.Module):
    """Map (b,) timesteps to (b, dim) conditioning vectors.

    The sinusoidal features are cast to the timestep's dtype and passed through a
    Linear -> SiLU -> Linear MLP.
    """

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        features = self.time_embed(timestep).to(timestep.dtype)
        return self.time_mlp(features)  # b d


================================================
FILE: cosyvoice/flow/decoder.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock


class Transpose(torch.nn.Module):
    """Module wrapper that swaps two tensor dimensions (usable inside nn.Sequential)."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(self.dim0, self.dim1)


class CausalConv1d(torch.nn.Conv1d):
    """Conv1d that only sees current and past time steps and preserves sequence length.

    The input is left-padded with zeros by the kernel's receptive span,
    ``dilation * (kernel_size - 1)``, before an unpadded convolution. Only stride 1
    is supported (asserted after construction).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        # A dilated kernel spans dilation * (kernel_size - 1) + 1 samples; padding only
        # kernel_size - 1 would shrink the output and let it read future samples.
        self.causal_padding = dilation * (kernel_size - 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (self.causal_padding, 0), value=0.0)
        x = super(CausalConv1d, self).forward(x)
        return x


class CausalBlock1D(Block1D):
    """Causal version of ``Block1D``: causal 3-tap conv -> LayerNorm over channels -> Mish.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
    """

    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        # Replace the parent's ``block`` with a causal one. LayerNorm normalises the last axis,
        # so the channel-first conv output is transposed to channel-last and back around it.
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the block to the masked input and re-apply the mask to the output.

        Args:
            x: input features, (batch, dim, time) for the Conv1d inside ``block``.
            mask: multiplicative mask broadcastable to ``x`` (presumably (batch, 1, time)).

        Returns:
            A single tensor of shape (batch, dim_out, time).
        """
        output = self.block(x * mask)
        return output * mask


class CausalResnetBlock1D(ResnetBlock1D):
    """``ResnetBlock1D`` whose two inner blocks are replaced by ``CausalBlock1D``.

    Everything else, including ``forward``, is inherited from ``ResnetBlock1D``. ``groups`` is
    forwarded to the parent constructor but not used by the causal blocks, which normalise with
    LayerNorm instead.
    """

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super().__init__(dim, dim_out, time_emb_dim, groups)
        first_causal = CausalBlock1D(dim, dim_out)
        second_causal = CausalBlock1D(dim_out, dim_out)
        self.block1, self.block2 = first_causal, second_causal


class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """U-Net style 1-D decoder with transformer blocks at every level.

        The decoder requires its input to have the same time length as the target. Text or
        token features that are shorter or longer than the output must be resampled before
        they are fed to the decoder.

        Args:
            in_channels: channel count of the packed input ``x | mu | spks | cond`` built in
                ``forward``. It is also the width of the sinusoidal time embedding.
            out_channels: channel count of the output (``final_proj``).
            channels: channel width of each down level; the up path mirrors it. A
                ``Downsample1D`` follows every down level except the last.
            dropout: dropout of every ``BasicTransformerBlock``.
            attention_head_dim: per-head width of the attention layers.
            n_blocks: number of ``BasicTransformerBlock`` s per level.
            num_mid_blocks: number of (resnet, transformer) stages at the bottleneck.
            num_heads: number of attention heads.
            act_fn: activation name passed to ``BasicTransformerBlock``.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # The last level keeps its resolution (same-length 3-tap conv) instead of downsampling.
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            # Bottleneck stages keep the width of the deepest level. Assign ``output_channel``
            # explicitly; previously ``out_channels`` (the constructor argument) was rebound here
            # by mistake and the value leaked in from the down loop instead.
            input_channel = channels[-1]
            output_channel = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # Up path mirrors the down path; each level's input is the running features concatenated
        # with the matching skip connection, hence ``* 2``.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialise Conv1d/Linear weights (Kaiming normal, ReLU gain; zero bias) and GroupNorm (1/0).

        Modules of other types (e.g. LayerNorm, ConvTranspose1d, which is not an ``nn.Conv1d``
        subclass) keep their default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _transformer_pass(self, x, mask, t, transformer_blocks):
        """Run a stack of transformer blocks over channel-first features with full attention.

        Args:
            x (torch.Tensor): features, shape (batch, channels, time).
            mask (torch.Tensor): padding mask for this resolution, shape (batch, 1, time).
            t (torch.Tensor): time embedding, passed to every block as ``timestep``.
            transformer_blocks (nn.ModuleList): blocks applied in order.

        Returns:
            torch.Tensor: features in the same (batch, channels, time) layout as ``x``.
        """
        x = rearrange(x, "b c t -> b t c").contiguous()
        # Chunking is disabled (both chunk sizes 0); the resulting mask is repeated to one row
        # per query frame before being turned into an additive attention bias.
        attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for transformer_block in transformer_blocks:
            x = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        return rearrange(x, "b t c -> b c t").contiguous()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Predict the decoder output for noisy features ``x`` at flow time ``t``.

        ``x``, ``mu``, the time-broadcast ``spks`` and ``cond`` are concatenated along the
        channel axis; the total must equal ``in_channels``.

        Args:
            x (torch.Tensor): shape (batch, x_channels, time).
            mask (torch.Tensor): shape (batch, 1, time).
            mu (torch.Tensor): conditioning features, shape (batch, mu_channels, time).
            t (torch.Tensor): shape (batch,).
            spks (torch.Tensor, optional): shape (batch, spk_channels); repeated over time.
            cond (torch.Tensor, optional): extra conditioning, shape (batch, cond_channels, time).
            streaming (bool): accepted but unused by this non-causal decoder.

        Returns:
            torch.Tensor: shape (batch, out_channels, time), zeroed where ``mask`` is 0.
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = self._transformer_pass(x, mask_down, t, transformer_blocks)
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        # The last level does not downsample, so drop the extra halved mask.
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = self._transformer_pass(x, mask_mid, t, transformer_blocks)

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Crop any extra upsampled frames so x lines up with the skip connection.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = self._transformer_pass(x, mask_up, t, transformer_blocks)
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask


class CausalConditionalDecoder(ConditionalDecoder):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """Causal variant of ConditionalDecoder with optional chunked (streaming) attention.

        Same layout as ``ConditionalDecoder``, with these changes:
        - it is built from ``CausalResnetBlock1D`` / ``CausalBlock1D``;
        - a ``CausalConv1d`` replaces the same-length 3-tap ``nn.Conv1d`` at the last down
          level and the last up level.

        The input must already have the same time length as the target; resample shorter or
        longer text/token features before feeding them in.

        Args:
            in_channels: channel count of the packed input ``x | mu | spks | cond`` and width
                of the sinusoidal time embedding.
            out_channels: channel count of the output.
            channels: channel width of each down level; the up path mirrors it.
            dropout: dropout of every ``BasicTransformerBlock``.
            attention_head_dim: per-head width of the attention layers.
            n_blocks: number of ``BasicTransformerBlock`` s per level.
            num_mid_blocks: number of (resnet, transformer) stages at the bottleneck.
            num_heads: number of attention heads.
            act_fn: activation name passed to ``BasicTransformerBlock``.
            static_chunk_size: chunk size passed to ``add_optional_chunk_mask`` when ``forward``
                runs with ``streaming=True``.
            num_decoding_left_chunks: stored on the instance. NOTE(review): the attention-mask
                calls pass -1 rather than this value — confirm that is intended.
        """
        # Initialise nn.Module directly, bypassing ConditionalDecoder.__init__ (which would build
        # the non-causal layers); every submodule is defined below.
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            # Bottleneck stages keep the width of the deepest level. Assign ``output_channel``
            # explicitly instead of rebinding the ``out_channels`` constructor argument.
            input_channel = channels[-1]
            output_channel = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # Up path mirrors the down path; inputs are features concatenated with the skip (``* 2``).
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def _transformer_pass(self, x, mask, t, transformer_blocks, streaming=False):
        """Run a stack of transformer blocks over channel-first features.

        Args:
            x (torch.Tensor): features, shape (batch, channels, time).
            mask (torch.Tensor): padding mask for this resolution, shape (batch, 1, time).
            t (torch.Tensor): time embedding, passed to every block as ``timestep``.
            transformer_blocks (nn.ModuleList): blocks applied in order.
            streaming (bool): only the literal ``True`` selects the chunked attention mask
                built with ``self.static_chunk_size``. Any other value selects full attention
                over valid frames.

        Returns:
            torch.Tensor: features in the same (batch, channels, time) layout as ``x``.
        """
        x = rearrange(x, "b c t -> b t c").contiguous()
        if streaming is True:
            # NOTE(review): the same static_chunk_size is used at every resolution level,
            # including after Downsample1D — confirm intended when len(channels) > 1.
            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1)
        else:
            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for transformer_block in transformer_blocks:
            x = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        return rearrange(x, "b t c -> b c t").contiguous()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Predict the decoder output for noisy features ``x`` at flow time ``t``.

        ``x``, ``mu``, the time-broadcast ``spks`` and ``cond`` are concatenated along the
        channel axis; the total must equal ``in_channels``.

        Args:
            x (torch.Tensor): shape (batch, x_channels, time).
            mask (torch.Tensor): shape (batch, 1, time).
            mu (torch.Tensor): conditioning features, shape (batch, mu_channels, time).
            t (torch.Tensor): shape (batch,).
            spks (torch.Tensor, optional): shape (batch, spk_channels); repeated over time.
            cond (torch.Tensor, optional): extra conditioning, shape (batch, cond_channels, time).
            streaming (bool): when ``True``, attention uses the chunk mask built from
                ``static_chunk_size``; otherwise it attends over all valid frames.

        Returns:
            torch.Tensor: shape (batch, out_channels, time), zeroed where ``mask`` is 0.
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = self._transformer_pass(x, mask_down, t, transformer_blocks, streaming)
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        # The last level does not downsample, so drop the extra halved mask.
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = self._transformer_pass(x, mask_mid, t, transformer_blocks, streaming)

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Crop any extra upsampled frames so x lines up with the skip connection.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = self._transformer_pass(x, mask_up, t, transformer_blocks, streaming)
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask


================================================
FILE: cosyvoice/flow/flow.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path


class MaskedDiffWithXvec(torch.nn.Module):
    """Flow-matching feature generator conditioned on speech tokens and a speaker embedding.

    The pipeline is:
    1. Speech tokens are embedded, encoded by ``encoder`` and projected to ``output_size``
       channels.
    2. ``length_regulator`` is called with the target frame count(s); presumably it resamples
       the encoder output to that length.
    3. The result goes to ``decoder`` as ``mu``, together with the projected speaker embedding
       and a prompt-feature condition.

    Args:
        input_size: token embedding width (``nn.Embedding(vocab_size, input_size)``).
        output_size: channel count of the generated features. It is also the width of the
            projected speaker embedding and of the condition tensor.
        spk_embed_dim: width of the incoming speaker embedding.
        output_type: stored on the instance; not read in this class.
        vocab_size: number of speech-token ids.
        input_frame_rate: presumably speech tokens per second; ``inference`` uses it to size
            the generated output.
        only_mask_loss: stored on the instance; not read in this class.
        encoder: token encoder; must provide ``output_size()``.
        length_regulator: module called in ``forward`` and via ``.inference`` in ``inference``.
        decoder: flow-matching decoder providing ``compute_loss`` and a call interface.
        decoder_conf: stored on the instance; not read in this class. NOTE(review): the
            default is a mutable dict shared by every instance built without this argument.
    """
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        # NOTE: submodule creation order determines how the RNG is consumed during default
        # parameter initialisation; keep it stable for reproducible fresh models.
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Args:
            batch: dict read for ``speech_token``, ``speech_token_len``, ``speech_feat``,
                ``speech_feat_len`` and ``embedding``; each is moved to ``device``.
                ``speech_feat`` is presumably (batch, frames, output_size), since it is
                transposed to channel-first before reaching the decoder.
            device: device the batch tensors are moved to.

        Returns:
            ``{'loss': loss}``, where ``loss`` is the first value returned by
            ``self.decoder.compute_loss``.
        """
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # embed speech tokens: negative ids are clamped to 0 and padded positions are zeroed
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # encode tokens, project to output_size, then bring h to feat_len frames
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # prompt condition: for each utterance, with probability 0.5 copy the first
        # randint(0, int(0.3 * feat_len)) target frames; every other frame stays zero.
        # NOTE: the order of random.random()/random.randint() calls affects reproducibility.
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE(review): an older note here read "this is unnecessary, feat/h already same shape";
        # no explicit feat/h alignment step remains, so it is unclear what it referred to.
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        """Generate features for ``token`` continuing a prompt (batch size must be 1).

        Args:
            token, token_len: speech tokens to synthesise and their length.
            prompt_token, prompt_token_len: prompt speech tokens, prepended to ``token``.
            prompt_feat: prompt features; ``prompt_feat.shape[1]`` is taken as the number of
                prompt frames, and they are copied into the condition tensor.
            prompt_feat_len: accepted but not used.
            embedding: speaker embedding; L2-normalised and projected before use.
            flow_cache: passed to the decoder as ``cache``; the decoder's updated cache is returned.

        Returns:
            (feat, flow_cache): ``feat`` holds only the newly generated frames (prompt frames
            are sliced off), cast to float32.
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # encode the concatenated tokens; length_regulator maps the prompt part to mel_len1
        # frames and the new part to mel_len2 frames
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        # NOTE(review): 22050 and 256 look like a hard-coded sample rate and hop size — confirm
        # they match the feature extractor used to build prompt_feat.
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # condition: prompt features in the first mel_len1 frames, zeros afterwards
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        # drop the prompt frames; only the newly generated part is returned
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache


class CausalMaskedDiffWithXvec(torch.nn.Module):
    """Causal (streaming-capable) flow-matching feature generator.

    The model is conditioned on speech tokens and a speaker embedding. The encoder output is
    used directly as the decoder's ``mu``; its time length gives the frame count. ``streaming``
    is forwarded to both the encoder and the decoder.

    Args:
        input_size: token embedding width (``nn.Embedding(vocab_size, input_size)``).
        output_size: channel count of the generated features. It is also the width of the
            projected speaker embedding and of the condition tensor.
        spk_embed_dim: width of the incoming speaker embedding.
        output_type: stored on the instance; not read in this class.
        vocab_size: number of speech-token ids.
        input_frame_rate: stored and logged; not otherwise read in this class.
        only_mask_loss: stored on the instance; not read in this class.
        token_mel_ratio: stored on the instance; not read in this class.
        pre_lookahead_len: number of trailing tokens held back and passed to the encoder as
            ``context`` when ``inference`` runs with ``finalize=False``.
        encoder: token encoder; must provide ``output_size()`` and accept ``streaming``
            (and ``context`` during non-final inference).
        decoder: flow-matching decoder providing ``compute_loss`` and a call interface.
        decoder_conf: stored on the instance; not read in this class. NOTE(review): the
            default is a mutable dict shared by every instance built without this argument.
    """
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len
        # The on-the-fly token extractor exists only when online_feature is enabled; forward()
        # needs it for batches that carry no precomputed 'speech_token'.
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Each step is randomly trained in streaming or non-streaming mode, with probability 0.5.

        Args:
            batch: dict read for ``speech_feat``, ``speech_feat_len`` and ``embedding``, plus
                either ``speech_token``/``speech_token_len`` or, when those are absent,
                ``whisper_feat``/``whisper_feat_len`` for on-the-fly token extraction.
            device: device the batch tensors are moved to.

        Returns:
            ``{'loss': loss}``, where ``loss`` is the first value returned by
            ``self.decoder.compute_loss``.

        Raises:
            RuntimeError: if ``batch`` has no ``speech_token`` and no token extractor was built.
        """
        if 'speech_token' not in batch:
            if not hasattr(self, 'speech_token_extractor'):
                raise RuntimeError("batch has no 'speech_token' and no speech token extractor is available "
                                   "(online_feature was not enabled when the model was built)")
            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
        else:
            token = batch['speech_token'].to(device)
            token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # embed speech tokens: negative ids are clamped to 0 and padded positions are zeroed
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # encode tokens and project to output_size
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)

        # prompt condition: for each utterance, with probability 0.5 copy the first
        # randint(0, int(0.3 * feat_len)) target frames; every other frame stays zero.
        # NOTE: the order of random calls affects reproducibility.
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        # frame mask from the encoder's output mask (summed over time to get lengths)
        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
        """Generate features for ``token`` continuing a prompt (batch size must be 1).

        Args:
            token, token_len: speech tokens to synthesise and their length.
            prompt_token, prompt_token_len: prompt speech tokens, prepended to ``token``.
            prompt_feat: prompt features; ``prompt_feat.shape[1]`` is taken as the number of
                prompt frames, and they are copied into the condition tensor.
            prompt_feat_len: accepted but not used.
            embedding: speaker embedding; L2-normalised and projected before use.
            streaming: forwarded to the encoder and decoder.
            finalize: if not ``True``, the last ``pre_lookahead_len`` tokens are not encoded
                as output but passed to the encoder as lookahead ``context``.

        Returns:
            (feat, None): ``feat`` holds only the newly generated frames, cast to float32.
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat prompt tokens and new tokens, then embed
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # encode tokens
        if finalize is True:
            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        else:
            # Split at an explicit index: ``token[:, :-0]`` would be empty when
            # pre_lookahead_len == 0. This is identical to the negative slices otherwise.
            split = max(token.shape[1] - self.pre_lookahead_len, 0)
            token, context = token[:, :split], token[:, split:]
            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # condition: prompt features in the first mel_len1 frames, zeros afterwards
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        # drop the prompt frames; only the newly generated part is returned
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None


class CausalMaskedDiffWithDiT(torch.nn.Module):
    """Speech-token-to-feature flow model with a look-ahead token front-end.

    Tokens are embedded, passed through ``pre_lookahead_layer``, repeated
    ``token_mel_ratio`` times along the time axis and handed to ``decoder`` as
    ``mu``. A speaker embedding (normalized + linearly projected) is passed as
    ``spks`` and a partially revealed feature prefix as ``cond``.
    """

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 pre_lookahead_layer: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Optional[Dict] = None):
        """
        Args:
            input_size: embedding dimension of ``input_embedding``.
            output_size: output dimension of the speaker projection and the
                feature dimension of the condition built in ``inference``.
            spk_embed_dim: input dimension of the speaker projection.
            output_type: stored as ``self.output_type``.
            vocab_size: number of token embeddings.
            input_frame_rate: stored and logged.
            only_mask_loss: stored as ``self.only_mask_loss`` (not read in this class).
            token_mel_ratio: number of output frames produced per token (repeat factor).
            pre_lookahead_len: number of trailing tokens used only as look-ahead
                context when ``inference`` is called with ``finalize=False``.
            pre_lookahead_layer: module applied to the token embeddings.
            decoder: flow-matching decoder providing ``compute_loss`` and ``__call__``.
            decoder_conf: stored as ``self.decoder_conf``; ``None`` selects the
                built-in default configuration, created fresh for every instance.
        """
        super().__init__()
        if decoder_conf is None:
            # Built per instance so that no two models share one mutable default dict.
            decoder_conf = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                            'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                      'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                            'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                               'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.pre_lookahead_len = pre_lookahead_len
        self.pre_lookahead_layer = pre_lookahead_layer
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        # NOTE(review): `online_feature` / `onnx_path` are module-level names not defined
        # in this class — presumably imported at the top of the file; confirm.
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Reads ``speech_feat``, ``speech_feat_len`` and ``embedding`` from ``batch``,
        plus either ``speech_token``/``speech_token_len`` or, when those are absent,
        tokens extracted online from ``whisper_feat``/``whisper_feat_len`` via
        ``self.speech_token_extractor`` (which only exists when ``online_feature``
        was true at construction).

        Returns:
            ``{'loss': loss}`` with the loss from ``self.decoder.compute_loss``.
        """
        if 'speech_token' not in batch:
            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
        else:
            token = batch['speech_token'].to(device)
            token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training: each batch is trained streaming or non-streaming with equal probability
        streaming = random.random() < 0.5

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # embed tokens and zero out padded positions
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # token encode, then repeat each step token_mel_ratio times along time
        h = self.pre_lookahead_layer(token)
        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
        mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)

        # get conditions: for roughly half of the utterances, reveal a random
        # prefix (at most 30% of feat_len) of the target features
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
        """Generate features for ``token`` conditioned on a prompt (batch size 1 only).

        Args:
            token: tokens to synthesize; ``token.shape[0]`` must be 1.
            token_len: length tensor for ``token``.
            prompt_token: prompt tokens, concatenated in front of ``token``.
            prompt_token_len: length tensor for ``prompt_token``.
            prompt_feat: prompt features; ``prompt_feat.shape[1]`` frames are copied
                into the condition.
            prompt_feat_len: not used by this method (kept for interface parity).
            embedding: speaker embedding; L2-normalized along dim 1, then projected.
            streaming: forwarded to the decoder.
            finalize: if ``False`` the last ``self.pre_lookahead_len`` tokens are given
                to ``pre_lookahead_layer`` only as look-ahead ``context``.

        Returns:
            ``(feat, None)`` where ``feat`` is the float32 decoder output restricted
            to the frames after the prompt.
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat prompt tokens in front of the target tokens
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # token encode
        if finalize is True:
            h = self.pre_lookahead_layer(token)
        else:
            # Split at an explicit index: `token[:, :-k]` is an empty slice when k == 0.
            split = token.shape[1] - self.pre_lookahead_len
            h = self.pre_lookahead_layer(token[:, :split], context=token[:, split:])
        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]

        # get conditions: prompt features occupy the first mel_len1 frames, the rest are zeros
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        # drop the prompt frames; only the newly generated part is returned
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None


if __name__ == '__main__':
    # Streaming-consistency smoke test: run inference once over the whole token
    # sequence, then chunk by chunk with look-ahead context, and print the max
    # absolute difference between each chunk's output and the matching slice of
    # the full-sequence output.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    from hyperpyyaml import load_hyperpyyaml
    with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
        configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
    model = configs['flow']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    chunk_size = model.decoder.estimator.static_chunk_size
    max_len = 10 * chunk_size
    context_size = model.pre_lookahead_layer.pre_lookahead_len
    # Derive prompt shapes from the model rather than hard-coding 2 / 80 / 192, so the
    # prompt features stay aligned with the prompt tokens for any configuration.
    mel_ratio = model.token_mel_ratio
    feat_dim = model.output_size
    spk_dim = model.spk_embed_affine_layer.in_features
    speech_token_size = 6561
    token = torch.randint(0, speech_token_size, size=(1, max_len)).to(device)
    token_len = torch.tensor([max_len]).to(device)
    prompt_token = torch.randint(0, speech_token_size, size=(1, chunk_size)).to(device)
    prompt_token_len = torch.tensor([chunk_size]).to(device)
    prompt_feat = torch.rand(1, chunk_size * mel_ratio, feat_dim).to(device)
    prompt_feat_len = torch.tensor([chunk_size * mel_ratio]).to(device)
    prompt_embedding = torch.rand(1, spk_dim).to(device)
    pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
    for i in range(0, max_len, chunk_size):
        end = i + chunk_size + context_size
        finalize = end >= max_len
        chunk_token = token[:, :end]
        chunk_token_len = torch.tensor([chunk_token.shape[1]]).to(device)
        pred_chunk, _ = model.inference(chunk_token, chunk_token_len,
                                        prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
        start = i * mel_ratio
        pred_chunk = pred_chunk[:, :, start:]
        print((pred_gt[:, :, start: start + pred_chunk.shape[2]] - pred_chunk).abs().max().item())


================================================
FILE: cosyvoice/flow/flow_matching.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
from cosyvoice.utils.common import set_all_random_seed


class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = 
Download .txt
gitextract_xzebwrda/

├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows/
│       ├── lint.yml
│       └── stale-issues.yml
├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── FAQ.md
├── LICENSE
├── README.md
├── cosyvoice/
│   ├── __init__.py
│   ├── bin/
│   │   ├── average_model.py
│   │   ├── export_jit.py
│   │   ├── export_onnx.py
│   │   └── train.py
│   ├── cli/
│   │   ├── __init__.py
│   │   ├── cosyvoice.py
│   │   ├── frontend.py
│   │   └── model.py
│   ├── dataset/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── processor.py
│   ├── flow/
│   │   ├── DiT/
│   │   │   ├── dit.py
│   │   │   └── modules.py
│   │   ├── decoder.py
│   │   ├── flow.py
│   │   ├── flow_matching.py
│   │   └── length_regulator.py
│   ├── hifigan/
│   │   ├── discriminator.py
│   │   ├── f0_predictor.py
│   │   ├── generator.py
│   │   └── hifigan.py
│   ├── llm/
│   │   └── llm.py
│   ├── tokenizer/
│   │   ├── assets/
│   │   │   └── multilingual_zh_ja_yue_char_del.tiktoken
│   │   └── tokenizer.py
│   ├── transformer/
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── convolution.py
│   │   ├── decoder.py
│   │   ├── decoder_layer.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── subsampling.py
│   │   └── upsample_encoder.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── class_utils.py
│   │   ├── common.py
│   │   ├── executor.py
│   │   ├── file_utils.py
│   │   ├── frontend_utils.py
│   │   ├── losses.py
│   │   ├── mask.py
│   │   ├── onnx.py
│   │   ├── scheduler.py
│   │   └── train_utils.py
│   └── vllm/
│       └── cosyvoice2.py
├── docker/
│   └── Dockerfile
├── example.py
├── examples/
│   ├── grpo/
│   │   └── cosyvoice2/
│   │       ├── Dockerfile
│   │       ├── README.md
│   │       ├── huggingface_to_pretrained.py
│   │       ├── infer_dataset.py
│   │       ├── prepare_data.py
│   │       ├── pretrained_to_huggingface.py
│   │       ├── requirements.txt
│   │       ├── reward_tts.py
│   │       ├── run.sh
│   │       ├── scripts/
│   │       │   ├── compute_wer.sh
│   │       │   └── offline-decode-files.py
│   │       └── token2wav_asr_server.py
│   ├── libritts/
│   │   ├── cosyvoice/
│   │   │   ├── conf/
│   │   │   │   ├── cosyvoice.yaml
│   │   │   │   └── ds_stage2.json
│   │   │   ├── local/
│   │   │   │   ├── download_and_untar.sh
│   │   │   │   ├── prepare_data.py
│   │   │   │   └── prepare_reject_sample.py
│   │   │   ├── path.sh
│   │   │   ├── run.sh
│   │   │   └── tts_text.json
│   │   ├── cosyvoice2/
│   │   │   ├── conf/
│   │   │   │   ├── cosyvoice2.yaml
│   │   │   │   └── ds_stage2.json
│   │   │   ├── run.sh
│   │   │   └── run_dpo.sh
│   │   └── cosyvoice3/
│   │       ├── conf/
│   │       │   ├── cosyvoice3.yaml
│   │       │   └── ds_stage2.json
│   │       └── run.sh
│   └── magicdata-read/
│       └── cosyvoice/
│           ├── local/
│           │   ├── download_and_untar.sh
│           │   └── prepare_data.py
│           ├── run.sh
│           └── tts_text.json
├── requirements.txt
├── runtime/
│   ├── python/
│   │   ├── Dockerfile
│   │   ├── fastapi/
│   │   │   ├── client.py
│   │   │   └── server.py
│   │   └── grpc/
│   │       ├── client.py
│   │       ├── cosyvoice.proto
│   │       └── server.py
│   └── triton_trtllm/
│       ├── Dockerfile.server
│       ├── README.Cosyvoice2.DiT.md
│       ├── README.Cosyvoice2.Unet.md
│       ├── README.Cosyvoice3.md
│       ├── README.md
│       ├── client_grpc.py
│       ├── client_http.py
│       ├── docker-compose.cosyvoice2.dit.yml
│       ├── docker-compose.cosyvoice2.unet.yml
│       ├── docker-compose.cosyvoice3.yml
│       ├── infer_cosyvoice3.py
│       ├── model_repo/
│       │   ├── audio_tokenizer/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── cosyvoice2/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── cosyvoice2_dit/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── speaker_embedding/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── tensorrt_llm/
│       │   │   ├── 1/
│       │   │   │   └── .gitkeep
│       │   │   └── config.pbtxt
│       │   ├── token2wav/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   └── token2wav_dit/
│       │       ├── 1/
│       │       │   ├── model.py
│       │       │   └── token2wav_dit.py
│       │       └── config.pbtxt
│       ├── model_repo_cosyvoice3/
│       │   ├── audio_tokenizer/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── cosyvoice3/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── speaker_embedding/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── token2wav/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   └── vocoder/
│       │       ├── 1/
│       │       │   └── model.py
│       │       └── config.pbtxt
│       ├── offline_inference.py
│       ├── requirements.txt
│       ├── run.sh
│       ├── run_cosyvoice3.sh
│       ├── run_stepaudio2_dit_token2wav.sh
│       ├── scripts/
│       │   ├── convert_checkpoint.py
│       │   ├── convert_cosyvoice3_to_hf.py
│       │   ├── fill_template.py
│       │   └── test_llm.py
│       ├── streaming_inference.py
│       ├── token2wav.py
│       └── token2wav_cosyvoice3.py
├── tools/
│   ├── extract_embedding.py
│   ├── extract_speech_token.py
│   └── make_parquet_list.py
├── vllm_example.py
└── webui.py
Download .txt
SYMBOL INDEX (786 symbols across 87 files)

FILE: cosyvoice/bin/average_model.py
  function get_args (line 24) | def get_args():
  function main (line 43) | def main():

FILE: cosyvoice/bin/export_jit.py
  function get_args (line 30) | def get_args():
  function get_optimized_script (line 41) | def get_optimized_script(model, preserved_attrs=[]):
  function main (line 51) | def main():

FILE: cosyvoice/bin/export_onnx.py
  function get_dummy_input (line 34) | def get_dummy_input(batch_size, seq_len, out_channels, device):
  function get_args (line 44) | def get_args():
  function main (line 56) | def main():

FILE: cosyvoice/bin/train.py
  function get_args (line 40) | def get_args():
  function main (line 98) | def main():

FILE: cosyvoice/cli/cosyvoice.py
  class CosyVoice (line 27) | class CosyVoice:
    method __init__ (line 29) | def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=Fal...
    method list_available_spks (line 65) | def list_available_spks(self):
    method add_zero_shot_spk (line 69) | def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
    method save_spkinfo (line 77) | def save_spkinfo(self):
    method inference_sft (line 80) | def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, tex...
    method inference_zero_shot (line 91) | def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_...
    method inference_cross_lingual (line 105) | def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_...
    method inference_instruct (line 116) | def inference_instruct(self, tts_text, spk_id, instruct_text, stream=F...
    method inference_vc (line 129) | def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
  class CosyVoice2 (line 139) | class CosyVoice2(CosyVoice):
    method __init__ (line 141) | def __init__(self, model_dir, load_jit=False, load_trt=False, load_vll...
    method inference_instruct2 (line 177) | def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zer...
  class CosyVoice3 (line 189) | class CosyVoice3(CosyVoice2):
    method __init__ (line 191) | def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=Fa...
  function AutoModel (line 228) | def AutoModel(**kwargs):

FILE: cosyvoice/cli/frontend.py
  class CosyVoiceFrontEnd (line 30) | class CosyVoiceFrontEnd:
    method __init__ (line 32) | def __init__(self,
    method _extract_text_token (line 78) | def _extract_text_token(self, text):
    method _extract_text_token_generator (line 89) | def _extract_text_token_generator(self, text_generator):
    method _extract_speech_token (line 95) | def _extract_speech_token(self, prompt_wav):
    method _extract_spk_embedding (line 108) | def _extract_spk_embedding(self, prompt_wav):
    method _extract_speech_feat (line 120) | def _extract_speech_feat(self, prompt_wav):
    method text_normalize (line 127) | def text_normalize(self, text, split=True, text_frontend=True):
    method frontend_sft (line 162) | def frontend_sft(self, tts_text, spk_id):
    method frontend_zero_shot (line 168) | def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resamp...
    method frontend_cross_lingual (line 191) | def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, ...
    method frontend_instruct (line 200) | def frontend_instruct(self, tts_text, spk_id, instruct_text):
    method frontend_instruct2 (line 209) | def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resa...
    method frontend_vc (line 215) | def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):

FILE: cosyvoice/cli/model.py
  class CosyVoiceModel (line 29) | class CosyVoiceModel:
    method __init__ (line 31) | def __init__(self,
    method load (line 65) | def load(self, llm_model, flow_model, hift_model):
    method load_jit (line 75) | def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder...
    method load_trt (line 83) | def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_mod...
    method get_trt_kwargs (line 94) | def get_trt_kwargs(self):
    method llm_job (line 101) | def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embe...
    method vc_job (line 131) | def vc_job(self, source_speech_token, uuid):
    method token2wav (line 135) | def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid,...
    method tts (line 175) | def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embeddin...
  class CosyVoice2Model (line 245) | class CosyVoice2Model(CosyVoiceModel):
    method __init__ (line 247) | def __init__(self,
    method load_jit (line 277) | def load_jit(self, flow_encoder_model):
    method load_vllm (line 281) | def load_vllm(self, model_dir):
    method token2wav (line 292) | def token2wav(self, token, prompt_token, prompt_feat, embedding, token...
    method tts (line 328) | def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embeddin...
  class CosyVoice3Model (line 397) | class CosyVoice3Model(CosyVoice2Model):
    method __init__ (line 399) | def __init__(self,
    method token2wav (line 425) | def token2wav(self, token, prompt_token, prompt_feat, embedding, token...

FILE: cosyvoice/dataset/dataset.py
  class Processor (line 26) | class Processor(IterableDataset):
    method __init__ (line 28) | def __init__(self, source, f, *args, **kw):
    method set_epoch (line 35) | def set_epoch(self, epoch):
    method __iter__ (line 38) | def __iter__(self):
    method apply (line 46) | def apply(self, f):
  class DistributedSampler (line 51) | class DistributedSampler:
    method __init__ (line 53) | def __init__(self, shuffle=True, partition=True):
    method update (line 59) | def update(self):
    method set_epoch (line 79) | def set_epoch(self, epoch):
    method sample (line 82) | def sample(self, data):
  class DataList (line 107) | class DataList(IterableDataset):
    method __init__ (line 109) | def __init__(self, lists, shuffle=True, partition=True):
    method set_epoch (line 113) | def set_epoch(self, epoch):
    method __iter__ (line 116) | def __iter__(self):
  function Dataset (line 125) | def Dataset(data_list_file,

FILE: cosyvoice/dataset/processor.py
  function parquet_opener (line 31) | def parquet_opener(data, mode='train'):
  function filter (line 55) | def filter(data,
  function resample (line 109) | def resample(data, resample_rate=22050, min_sample_rate=16000, mode='tra...
  function truncate (line 137) | def truncate(data, truncate_length=24576, mode='train'):
  function compute_fbank (line 158) | def compute_fbank(data,
  function compute_whisper_fbank (line 183) | def compute_whisper_fbank(data, num_frames=-1, mode='train'):
  function compute_f0 (line 200) | def compute_f0(data, sample_rate, hop_size, mode='train'):
  function parse_embedding (line 225) | def parse_embedding(data, normalize, mode='train'):
  function tokenize (line 248) | def tokenize(data, get_tokenizer, allowed_special, mode='train'):
  function shuffle (line 267) | def shuffle(data, shuffle_size=10000, mode='train'):
  function sort (line 292) | def sort(data, sort_size=500, mode='train'):
  function static_batch (line 320) | def static_batch(data, batch_size=16):
  function dynamic_batch (line 340) | def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
  function batch (line 369) | def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=...
  function padding (line 380) | def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):

FILE: cosyvoice/flow/DiT/dit.py
  class TextEmbedding (line 33) | class TextEmbedding(nn.Module):
    method __init__ (line 34) | def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult...
    method forward (line 48) | def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noq...
  class InputEmbedding (line 76) | class InputEmbedding(nn.Module):
    method __init__ (line 77) | def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
    method forward (line 84) | def forward(
  class DiT (line 104) | class DiT(nn.Module):
    method __init__ (line 105) | def __init__(
    method forward (line 145) | def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):

FILE: cosyvoice/flow/DiT/modules.py
  class MelSpec (line 24) | class MelSpec(nn.Module):
    method __init__ (line 25) | def __init__(
    method forward (line 54) | def forward(self, inp):
  class SinusPositionEmbedding (line 71) | class SinusPositionEmbedding(nn.Module):
    method __init__ (line 72) | def __init__(self, dim):
    method forward (line 76) | def forward(self, x, scale=1000):
  class ConvPositionEmbedding (line 89) | class ConvPositionEmbedding(nn.Module):
    method __init__ (line 90) | def __init__(self, dim, kernel_size=31, groups=16):
    method forward (line 100) | def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):...
  class CausalConvPositionEmbedding (line 115) | class CausalConvPositionEmbedding(nn.Module):
    method __init__ (line 116) | def __init__(self, dim, kernel_size=31, groups=16):
    method forward (line 129) | def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):...
  function precompute_freqs_cis (line 150) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, the...
  function get_pos_embed_indices (line 164) | def get_pos_embed_indices(start, length, max_pos, scale=1.0):
  class GRN (line 179) | class GRN(nn.Module):
    method __init__ (line 180) | def __init__(self, dim):
    method forward (line 185) | def forward(self, x):
  class ConvNeXtV2Block (line 195) | class ConvNeXtV2Block(nn.Module):
    method __init__ (line 196) | def __init__(
    method forward (line 213) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class AdaLayerNormZero (line 230) | class AdaLayerNormZero(nn.Module):
    method __init__ (line 231) | def __init__(self, dim):
    method forward (line 239) | def forward(self, x, emb=None):
  class AdaLayerNormZero_Final (line 251) | class AdaLayerNormZero_Final(nn.Module):
    method __init__ (line 252) | def __init__(self, dim):
    method forward (line 260) | def forward(self, x, emb):
  class FeedForward (line 271) | class FeedForward(nn.Module):
    method __init__ (line 272) | def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate...
    method forward (line 281) | def forward(self, x):
  class Attention (line 289) | class Attention(nn.Module):
    method __init__ (line 290) | def __init__(
    method forward (line 332) | def forward(
  class AttnProcessor (line 349) | class AttnProcessor:
    method __init__ (line 350) | def __init__(self):
    method __call__ (line 353) | def __call__(
  class JointAttnProcessor (line 414) | class JointAttnProcessor:
    method __init__ (line 415) | def __init__(self):
    method __call__ (line 418) | def __call__(
  class DiTBlock (line 500) | class DiTBlock(nn.Module):
    method __init__ (line 501) | def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
    method forward (line 516) | def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: ...
  class MMDiTBlock (line 536) | class MMDiTBlock(nn.Module):
    method __init__ (line 546) | def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, conte...
    method forward (line 572) | def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: n...
  class TimestepEmbedding (line 606) | class TimestepEmbedding(nn.Module):
    method __init__ (line 607) | def __init__(self, dim, freq_embed_dim=256):
    method forward (line 612) | def forward(self, timestep: float["b"]):  # noqa: F821

FILE: cosyvoice/flow/decoder.py
  class Transpose (line 25) | class Transpose(torch.nn.Module):
    method __init__ (line 26) | def __init__(self, dim0: int, dim1: int):
    method forward (line 31) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class CausalConv1d (line 36) | class CausalConv1d(torch.nn.Conv1d):
    method __init__ (line 37) | def __init__(
    method forward (line 59) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class CausalBlock1D (line 65) | class CausalBlock1D(Block1D):
    method __init__ (line 66) | def __init__(self, dim: int, dim_out: int):
    method forward (line 76) | def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch....
  class CausalResnetBlock1D (line 81) | class CausalResnetBlock1D(ResnetBlock1D):
    method __init__ (line 82) | def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: ...
  class ConditionalDecoder (line 88) | class ConditionalDecoder(nn.Module):
    method __init__ (line 89) | def __init__(
    method initialize_weights (line 196) | def initialize_weights(self):
    method forward (line 210) | def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
  class CausalConditionalDecoder (line 294) | class CausalConditionalDecoder(ConditionalDecoder):
    method __init__ (line 295) | def __init__(
    method forward (line 405) | def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):

FILE: cosyvoice/flow/flow.py
  class MaskedDiffWithXvec (line 25) | class MaskedDiffWithXvec(torch.nn.Module):
    method __init__ (line 26) | def __init__(self,
    method forward (line 58) | def forward(
    method inference (line 103) | def inference(self,
  class CausalMaskedDiffWithXvec (line 149) | class CausalMaskedDiffWithXvec(torch.nn.Module):
    method __init__ (line 150) | def __init__(self,
    method forward (line 186) | def forward(
    method inference (line 236) | def inference(self,
  class CausalMaskedDiffWithDiT (line 284) | class CausalMaskedDiffWithDiT(torch.nn.Module):
    method __init__ (line 285) | def __init__(self,
    method forward (line 320) | def forward(
    method inference (line 370) | def inference(self,

FILE: cosyvoice/flow/flow_matching.py
  class ConditionalCFM (line 21) | class ConditionalCFM(BASECFM):
    method __init__ (line 22) | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, ...
    method forward (line 37) | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, c...
    method solve_euler (line 71) | def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
    method forward_estimator (line 126) | def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
    method compute_loss (line 155) | def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=F...
  class CausalConditionalCFM (line 196) | class CausalConditionalCFM(ConditionalCFM):
    method __init__ (line 197) | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, ...
    method forward (line 203) | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, c...

FILE: cosyvoice/flow/length_regulator.py
  class InterpolateRegulator (line 21) | class InterpolateRegulator(nn.Module):
    method __init__ (line 22) | def __init__(
    method forward (line 44) | def forward(self, x, ylens=None):
    method inference (line 52) | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):

FILE: cosyvoice/hifigan/discriminator.py
  class MultipleDiscriminator (line 15) | class MultipleDiscriminator(nn.Module):
    method __init__ (line 16) | def __init__(
    method forward (line 23) | def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
  class MultiResolutionDiscriminator (line 38) | class MultiResolutionDiscriminator(nn.Module):
    method __init__ (line 39) | def __init__(
    method forward (line 59) | def forward(
  class DiscriminatorR (line 78) | class DiscriminatorR(nn.Module):
    method __init__ (line 79) | def __init__(
    method spectrogram (line 113) | def spectrogram(self, x):
    method forward (line 125) | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = N...
  class MultiResSpecDiscriminator (line 149) | class MultiResSpecDiscriminator(torch.nn.Module):
    method __init__ (line 151) | def __init__(self,
    method forward (line 163) | def forward(self, y, y_hat):
  function stft (line 179) | def stft(x, fft_size, hop_size, win_length, window):
  class SpecDiscriminator (line 196) | class SpecDiscriminator(nn.Module):
    method __init__ (line 199) | def __init__(self, fft_size=1024, shift_size=120, win_length=600, wind...
    method forward (line 216) | def forward(self, y):

FILE: cosyvoice/hifigan/f0_predictor.py
  class ConvRNNF0Predictor (line 23) | class ConvRNNF0Predictor(nn.Module):
    method __init__ (line 24) | def __init__(self,
    method forward (line 56) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class CausalConvRNNF0Predictor (line 62) | class CausalConvRNNF0Predictor(nn.Module):
    method __init__ (line 63) | def __init__(self,
    method forward (line 95) | def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Ten...

FILE: cosyvoice/hifigan/generator.py
  class ResBlock (line 46) | class ResBlock(torch.nn.Module):
    method __init__ (line 48) | def __init__(
    method forward (line 110) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method remove_weight_norm (line 119) | def remove_weight_norm(self):
  class SineGen (line 125) | class SineGen(torch.nn.Module):
    method __init__ (line 141) | def __init__(self, samp_rate, harmonic_num=0,
    method _f02uv (line 151) | def _f02uv(self, f0):
    method forward (line 157) | def forward(self, f0):
  class SineGen2 (line 192) | class SineGen2(torch.nn.Module):
    method __init__ (line 208) | def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
    method _f02uv (line 228) | def _f02uv(self, f0):
    method _f02sine (line 233) | def _f02sine(self, f0_values):
    method forward (line 289) | def forward(self, f0):
  class SourceModuleHnNSF (line 320) | class SourceModuleHnNSF(torch.nn.Module):
    method __init__ (line 338) | def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine...
    method forward (line 358) | def forward(self, x):
  class HiFTGenerator (line 378) | class HiFTGenerator(nn.Module):
    method __init__ (line 383) | def __init__(
    method remove_weight_norm (line 477) | def remove_weight_norm(self):
    method _stft (line 491) | def _stft(self, x):
    method _istft (line 499) | def _istft(self, magnitude, phase):
    method decode (line 507) | def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, ...
    method forward (line 541) | def forward(
    method inference (line 558) | def inference(self, speech_feat: torch.Tensor, cache_source: torch.Ten...
  class CausalHiFTGenerator (line 572) | class CausalHiFTGenerator(HiFTGenerator):
    method __init__ (line 577) | def __init__(
    method decode (line 672) | def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, ...
    method inference (line 714) | def inference(self, speech_feat: torch.Tensor, finalize: bool = True) ...

FILE: cosyvoice/hifigan/hifigan.py
  class HiFiGan (line 9) | class HiFiGan(nn.Module):
    method __init__ (line 10) | def __init__(self, generator, discriminator, mel_spec_transform,
    method forward (line 22) | def forward(
    method forward_generator (line 32) | def forward_generator(self, batch, device):
    method forward_discriminator (line 53) | def forward_discriminator(self, batch, device):

FILE: cosyvoice/llm/llm.py
  class TransformerLM (line 34) | class TransformerLM(torch.nn.Module):
    method __init__ (line 35) | def __init__(
    method encode (line 81) | def encode(
    method pad_unpad_sequence (line 91) | def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_toke...
    method forward (line 100) | def forward(
    method sampling_ids (line 150) | def sampling_ids(
    method inference (line 163) | def inference(
  class Qwen2Encoder (line 226) | class Qwen2Encoder(torch.nn.Module):
    method __init__ (line 227) | def __init__(self, pretrain_path):
    method forward (line 231) | def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
    method forward_one_step (line 242) | def forward_one_step(self, xs, masks, cache=None):
  class Qwen2LM (line 257) | class Qwen2LM(TransformerLM):
    method __init__ (line 258) | def __init__(
    method prepare_lm_input_target (line 302) | def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb,...
    method forward (line 351) | def forward(
    method forward_dpo (line 407) | def forward_dpo(
    method inference (line 459) | def inference(
    method inference_wrapper (line 505) | def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
    method inference_bistream (line 552) | def inference_bistream(
  class CosyVoice3LM (line 664) | class CosyVoice3LM(Qwen2LM):
    method __init__ (line 665) | def __init__(

FILE: cosyvoice/tokenizer/tokenizer.py
  function get_encoding (line 170) | def get_encoding(name: str = "gpt2", num_languages: int = 99):
  function get_tokenizer (line 210) | def get_tokenizer(
  class CosyVoice2Tokenizer (line 241) | class CosyVoice2Tokenizer():
    method __init__ (line 242) | def __init__(self, token_path, skip_special_tokens=True):
    method encode (line 263) | def encode(self, text, **kwargs):
    method decode (line 268) | def decode(self, tokens):
  class CosyVoice3Tokenizer (line 274) | class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
    method __init__ (line 275) | def __init__(self, token_path, skip_special_tokens=True):
  function get_qwen_tokenizer (line 317) | def get_qwen_tokenizer(

FILE: cosyvoice/transformer/activation.py
  class Swish (line 24) | class Swish(torch.nn.Module):
    method forward (line 27) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Snake (line 34) | class Snake(nn.Module):
    method __init__ (line 50) | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha...
    method forward (line 73) | def forward(self, x):

FILE: cosyvoice/transformer/attention.py
  class MultiHeadedAttention (line 26) | class MultiHeadedAttention(nn.Module):
    method __init__ (line 36) | def __init__(self,
    method forward_qkv (line 53) | def forward_qkv(
    method forward_attention (line 82) | def forward_attention(
    method forward (line 129) | def forward(
  class RelPositionMultiHeadedAttention (line 200) | class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    method __init__ (line 209) | def __init__(self,
    method rel_shift (line 225) | def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
    method forward (line 249) | def forward(

FILE: cosyvoice/transformer/convolution.py
  class ConvolutionModule (line 25) | class ConvolutionModule(nn.Module):
    method __init__ (line 28) | def __init__(self,
    method forward (line 91) | def forward(
  class CausalConv1d (line 150) | class CausalConv1d(torch.nn.Conv1d):
    method __init__ (line 151) | def __init__(
    method forward (line 176) | def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0...
  class CausalConv1dDownSample (line 190) | class CausalConv1dDownSample(torch.nn.Conv1d):
    method __init__ (line 191) | def __init__(
    method forward (line 214) | def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0...
  class CausalConv1dUpsample (line 224) | class CausalConv1dUpsample(torch.nn.Conv1d):
    method __init__ (line 225) | def __init__(
    method forward (line 248) | def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0...

FILE: cosyvoice/transformer/decoder.py
  class TransformerDecoder (line 33) | class TransformerDecoder(torch.nn.Module):
    method __init__ (line 58) | def __init__(
    method forward (line 116) | def forward(
    method forward_layers (line 169) | def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
    method forward_layers_checkpointed (line 178) | def forward_layers_checkpointed(self, x: torch.Tensor,
    method forward_one_step (line 187) | def forward_one_step(
    method tie_or_clone_weights (line 230) | def tie_or_clone_weights(self, jit_mode: bool = True):
  class BiTransformerDecoder (line 256) | class BiTransformerDecoder(torch.nn.Module):
    method __init__ (line 276) | def __init__(
    method forward (line 332) | def forward(
    method forward_one_step (line 367) | def forward_one_step(
    method tie_or_clone_weights (line 392) | def tie_or_clone_weights(self, jit_mode: bool = True):

FILE: cosyvoice/transformer/decoder_layer.py
  class DecoderLayer (line 22) | class DecoderLayer(nn.Module):
    method __init__ (line 41) | def __init__(
    method forward (line 62) | def forward(

FILE: cosyvoice/transformer/embedding.py
  class PositionalEncoding (line 26) | class PositionalEncoding(torch.nn.Module):
    method __init__ (line 37) | def __init__(self,
    method forward (line 59) | def forward(self,
    method position_encoding (line 79) | def position_encoding(self,
  class RelPositionalEncoding (line 120) | class RelPositionalEncoding(PositionalEncoding):
    method __init__ (line 129) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5...
    method forward (line 133) | def forward(self,
  class WhisperPositionalEncoding (line 150) | class WhisperPositionalEncoding(PositionalEncoding):
    method __init__ (line 154) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1...
  class LearnablePositionalEncoding (line 167) | class LearnablePositionalEncoding(PositionalEncoding):
    method __init__ (line 171) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 4...
  class NoPositionalEncoding (line 178) | class NoPositionalEncoding(torch.nn.Module):
    method __init__ (line 182) | def __init__(self, d_model: int, dropout_rate: float):
    method forward (line 187) | def forward(self,
    method position_encoding (line 196) | def position_encoding(self, offset: Union[int, torch.Tensor],
  class EspnetRelPositionalEncoding (line 201) | class EspnetRelPositionalEncoding(torch.nn.Module):
    method __init__ (line 215) | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5...
    method extend_pe (line 224) | def extend_pe(self, x: torch.Tensor):
    method forward (line 256) | def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = ...
    method position_encoding (line 272) | def position_encoding(self,

FILE: cosyvoice/transformer/encoder.py
  class BaseEncoder (line 37) | class BaseEncoder(torch.nn.Module):
    method __init__ (line 39) | def __init__(
    method output_size (line 108) | def output_size(self) -> int:
    method forward (line 111) | def forward(
    method forward_layers (line 165) | def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
    method forward_layers_checkpointed (line 173) | def forward_layers_checkpointed(self, xs: torch.Tensor,
    method forward_chunk (line 184) | def forward_chunk(
    method forward_chunk_by_chunk (line 275) | def forward_chunk_by_chunk(
  class TransformerEncoder (line 338) | class TransformerEncoder(BaseEncoder):
    method __init__ (line 341) | def __init__(
  class ConformerEncoder (line 387) | class ConformerEncoder(BaseEncoder):
    method __init__ (line 390) | def __init__(

FILE: cosyvoice/transformer/encoder_layer.py
  class TransformerEncoderLayer (line 24) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 40) | def __init__(
    method forward (line 58) | def forward(
  class ConformerEncoderLayer (line 109) | class ConformerEncoderLayer(nn.Module):
    method __init__ (line 129) | def __init__(
    method forward (line 160) | def forward(

FILE: cosyvoice/transformer/label_smoothing_loss.py
  class LabelSmoothingLoss (line 21) | class LabelSmoothingLoss(nn.Module):
    method __init__ (line 54) | def __init__(self,
    method forward (line 68) | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

FILE: cosyvoice/transformer/positionwise_feed_forward.py
  class PositionwiseFeedForward (line 20) | class PositionwiseFeedForward(torch.nn.Module):
    method __init__ (line 33) | def __init__(
    method forward (line 47) | def forward(self, xs: torch.Tensor) -> torch.Tensor:
  class MoEFFNLayer (line 58) | class MoEFFNLayer(torch.nn.Module):
    method __init__ (line 75) | def __init__(
    method forward (line 91) | def forward(self, xs: torch.Tensor) -> torch.Tensor:

FILE: cosyvoice/transformer/subsampling.py
  class BaseSubsampling (line 23) | class BaseSubsampling(torch.nn.Module):
    method __init__ (line 25) | def __init__(self):
    method position_encoding (line 30) | def position_encoding(self, offset: Union[int, torch.Tensor],
  class EmbedinigNoSubsampling (line 35) | class EmbedinigNoSubsampling(BaseSubsampling):
    method __init__ (line 39) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 45) | def forward(
  class LinearNoSubsampling (line 69) | class LinearNoSubsampling(BaseSubsampling):
    method __init__ (line 79) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 92) | def forward(
  class Conv1dSubsampling2 (line 116) | class Conv1dSubsampling2(BaseSubsampling):
    method __init__ (line 128) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 145) | def forward(
  class Conv2dSubsampling4 (line 173) | class Conv2dSubsampling4(BaseSubsampling):
    method __init__ (line 183) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 202) | def forward(
  class Conv2dSubsampling6 (line 230) | class Conv2dSubsampling6(BaseSubsampling):
    method __init__ (line 239) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 256) | def forward(
  class Conv2dSubsampling8 (line 282) | class Conv2dSubsampling8(BaseSubsampling):
    method __init__ (line 292) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 311) | def forward(
  class LegacyLinearNoSubsampling (line 338) | class LegacyLinearNoSubsampling(BaseSubsampling):
    method __init__ (line 348) | def __init__(self, idim: int, odim: int, dropout_rate: float,
    method forward (line 362) | def forward(

FILE: cosyvoice/transformer/upsample_encoder.py
  class Upsample1D (line 37) | class Upsample1D(nn.Module):
    method __init__ (line 51) | def __init__(self, channels: int, out_channels: int, stride: int = 2):
    method forward (line 59) | def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -...
  class PreLookaheadLayer (line 66) | class PreLookaheadLayer(nn.Module):
    method __init__ (line 67) | def __init__(self, in_channels: int, channels: int, pre_lookahead_len:...
    method forward (line 82) | def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch....
  class UpsampleConformerEncoder (line 106) | class UpsampleConformerEncoder(torch.nn.Module):
    method __init__ (line 108) | def __init__(
    method output_size (line 241) | def output_size(self) -> int:
    method forward (line 244) | def forward(
    method forward_layers (line 309) | def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
    method forward_up_layers (line 316) | def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,

FILE: cosyvoice/utils/class_utils.py
  function get_model_type (line 77) | def get_model_type(configs):

FILE: cosyvoice/utils/common.py
  function pad_list (line 56) | def pad_list(xs: List[torch.Tensor], pad_value: int):
  function th_accuracy (line 105) | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
  function get_padding (line 127) | def get_padding(kernel_size, dilation=1):
  function init_weights (line 131) | def init_weights(m, mean=0.0, std=0.01):
  function ras_sampling (line 138) | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, t...
  function nucleus_sampling (line 147) | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
  function random_sampling (line 165) | def random_sampling(weighted_scores, decoded_tokens, sampling):
  function fade_in_out (line 170) | def fade_in_out(fade_in_mel, fade_out_mel, window):
  function set_all_random_seed (line 181) | def set_all_random_seed(seed):
  function mask_to_bias (line 188) | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  class TrtContextWrapper (line 199) | class TrtContextWrapper:
    method __init__ (line 200) | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
    method acquire_estimator (line 210) | def acquire_estimator(self):
    method release_estimator (line 213) | def release_estimator(self, context, stream):

FILE: cosyvoice/utils/executor.py
  class Executor (line 26) | class Executor:
    method __init__ (line 28) | def __init__(self, gan: bool = False, ref_model: torch.nn.Module = Non...
    method train_one_epoc (line 37) | def train_one_epoc(self, model, optimizer, scheduler, train_data_loade...
    method train_one_epoc_gan (line 88) | def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d,...
    method cv (line 147) | def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=Tr...

FILE: cosyvoice/utils/file_utils.py
  function read_lists (line 27) | def read_lists(list_file):
  function read_json_lists (line 35) | def read_json_lists(list_file):
  function load_wav (line 44) | def load_wav(wav, target_sr, min_sr=16000):
  function convert_onnx_to_trt (line 53) | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
  function export_cosyvoice2_vllm (line 92) | def export_cosyvoice2_vllm(model, model_path, device):

FILE: cosyvoice/utils/frontend_utils.py
  function contains_chinese (line 21) | def contains_chinese(text):
  function replace_corner_mark (line 26) | def replace_corner_mark(text):
  function remove_bracket (line 33) | def remove_bracket(text):
  function spell_out_number (line 42) | def spell_out_number(text: str, inflect_parser):
  function split_paragraph (line 65) | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, toke...
  function replace_blank (line 121) | def replace_blank(text: str):
  function is_only_punctuation (line 133) | def is_only_punctuation(text):

FILE: cosyvoice/utils/losses.py
  function tpr_loss (line 6) | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
  function mel_loss (line 15) | def mel_loss(real_speech, generated_speech, mel_transforms):
  class DPOLoss (line 24) | class DPOLoss(torch.nn.Module):
    method __init__ (line 29) | def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: boo...
    method forward (line 35) | def forward(

FILE: cosyvoice/utils/mask.py
  function subsequent_mask (line 53) | def subsequent_mask(
  function subsequent_chunk_mask_deprecated (line 89) | def subsequent_chunk_mask_deprecated(
  function subsequent_chunk_mask (line 127) | def subsequent_chunk_mask(
  function add_optional_chunk_mask (line 161) | def add_optional_chunk_mask(xs: torch.Tensor,
  function make_pad_mask (line 239) | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:

FILE: cosyvoice/utils/onnx.py
  class SpeechTokenExtractor (line 7) | class SpeechTokenExtractor():
    method __init__ (line 8) | def __init__(self, model_path):
    method inference (line 17) | def inference(self, feat, feat_lengths, device):
  class EmbeddingExtractor (line 26) | class EmbeddingExtractor():
    method __init__ (line 27) | def __init__(self, model_path):
    method inference (line 36) | def inference(self, speech):

FILE: cosyvoice/utils/scheduler.py
  class WarmupLR (line 27) | class WarmupLR(_LRScheduler):
    method __init__ (line 44) | def __init__(
    method __repr__ (line 56) | def __repr__(self):
    method get_lr (line 59) | def get_lr(self):
    method set_step (line 70) | def set_step(self, step: int):
  class WarmupPolicy (line 74) | class WarmupPolicy(_LRScheduler):
    method __init__ (line 84) | def __init__(self,
    method get_lr (line 110) | def get_lr(self):
    method _get_warmup_lr (line 128) | def _get_warmup_lr(self, step):
    method _get_lr (line 132) | def _get_lr(self, step):
  class SquareRootConstantPolicy (line 137) | class SquareRootConstantPolicy(_LRScheduler):
    method __init__ (line 147) | def __init__(self,
    method get_lr (line 175) | def get_lr(self):
    method _get_lr (line 193) | def _get_lr(self, step):
  class WarmupHoldPolicy (line 198) | class WarmupHoldPolicy(WarmupPolicy):
    method __init__ (line 212) | def __init__(
    method get_lr (line 257) | def get_lr(self):
  class WarmupAnnealHoldPolicy (line 282) | class WarmupAnnealHoldPolicy(_LRScheduler):
    method __init__ (line 295) | def __init__(
    method get_lr (line 340) | def get_lr(self):
    method _get_warmup_lr (line 365) | def _get_warmup_lr(self, step):
    method _get_constant_lr (line 369) | def _get_constant_lr(self, step):
    method _get_lr (line 372) | def _get_lr(self, step):
  function _squareroot_annealing (line 377) | def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
  function _square_annealing (line 384) | def _square_annealing(initial_lr, step, max_steps, min_lr):
  function _cosine_annealing (line 391) | def _cosine_annealing(initial_lr, step, max_steps, min_lr):
  function _linear_warmup_with_cosine_annealing (line 397) | def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
  function _poly_decay (line 421) | def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
  function _noam_hold_annealing (line 433) | def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
  class SquareAnnealing (line 444) | class SquareAnnealing(WarmupPolicy):
    method __init__ (line 446) | def __init__(self,
    method _get_lr (line 459) | def _get_lr(self, step):
  class SquareRootAnnealing (line 471) | class SquareRootAnnealing(WarmupPolicy):
    method __init__ (line 473) | def __init__(self,
    method _get_lr (line 486) | def _get_lr(self, step):
  class CosineAnnealing (line 497) | class CosineAnnealing(WarmupAnnealHoldPolicy):
    method __init__ (line 499) | def __init__(self,
    method _get_lr (line 512) | def _get_lr(self, step):
    method _get_warmup_lr (line 532) | def _get_warmup_lr(self, step):
    method _get_constant_lr (line 539) | def _get_constant_lr(self, step):
    method _get_linear_warmup_with_cosine_annealing_lr (line 543) | def _get_linear_warmup_with_cosine_annealing_lr(self, step):
  class NoamAnnealing (line 558) | class NoamAnnealing(_LRScheduler):
    method __init__ (line 560) | def __init__(self,
    method get_lr (line 588) | def get_lr(self):
    method _noam_annealing (line 610) | def _noam_annealing(self, initial_lr, step):
  class NoamHoldAnnealing (line 623) | class NoamHoldAnnealing(WarmupHoldPolicy):
    method __init__ (line 625) | def __init__(self,
    method _get_lr (line 693) | def _get_lr(self, step):
    method set_step (line 715) | def set_step(self, step: int):
  class ConstantLR (line 719) | class ConstantLR(_LRScheduler):
    method __init__ (line 726) | def __init__(
    method get_lr (line 734) | def get_lr(self):
    method set_step (line 737) | def set_step(self, step: int):

FILE: cosyvoice/utils/train_utils.py
  function init_distributed (line 39) | def init_distributed(args):
  function init_dataset_and_dataloader (line 53) | def init_dataset_and_dataloader(args, configs, gan, dpo):
  function check_modify_and_save_config (line 72) | def check_modify_and_save_config(args, configs):
  function wrap_cuda_model (line 94) | def wrap_cuda_model(args, model):
  function init_optimizer_and_scheduler (line 111) | def init_optimizer_and_scheduler(args, configs, model, gan):
  function init_summarywriter (line 187) | def init_summarywriter(args):
  function save_model (line 195) | def save_model(model, model_name, info_dict):
  function cosyvoice_join (line 217) | def cosyvoice_join(group_join, info_dict):
  function batch_forward (line 238) | def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_l...
  function batch_backward (line 277) | def batch_backward(model, scaler, info_dict):
  function update_parameter_and_lr (line 291) | def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_di...
  function log_per_step (line 323) | def log_per_step(writer, info_dict):
  function log_per_save (line 352) | def log_per_save(writer, info_dict):

FILE: cosyvoice/vllm/cosyvoice2.py
  class CosyVoice2ForCausalLM (line 38) | class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    method __init__ (line 51) | def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    method get_input_embeddings (line 82) | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    method forward (line 85) | def forward(
    method compute_logits (line 96) | def compute_logits(
    method load_weights (line 109) | def load_weights(self, weights: Iterable[tuple[str,

FILE: example.py
  function cosyvoice_example (line 7) | def cosyvoice_example():
  function cosyvoice2_example (line 36) | def cosyvoice2_example():
  function cosyvoice3_example (line 71) | def cosyvoice3_example():
  function main (line 105) | def main():

FILE: examples/grpo/cosyvoice2/huggingface_to_pretrained.py
  function get_args (line 25) | def get_args():

FILE: examples/grpo/cosyvoice2/infer_dataset.py
  function audio_decode_cosyvoice2 (line 59) | def audio_decode_cosyvoice2(
  function extract_speech_ids (line 96) | def extract_speech_ids(speech_tokens_str):
  function convert_cosy2_tokens_to_speech_id_str (line 109) | def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
  function get_args (line 117) | def get_args():
  function data_collator (line 186) | def data_collator(batch, tokenizer, s3_tokenizer):
  function init_distributed (line 269) | def init_distributed():
  function main (line 282) | def main():

FILE: examples/grpo/cosyvoice2/prepare_data.py
  function make_map_fn (line 40) | def make_map_fn(split):

FILE: examples/grpo/cosyvoice2/pretrained_to_huggingface.py
  function get_args (line 33) | def get_args():

FILE: examples/grpo/cosyvoice2/reward_tts.py
  function _parse_ids (line 34) | def _parse_ids(token_str: str) -> List[int]:
  function _remote_reward (line 38) | def _remote_reward(tokens: List[int], ground_truth: str, timeout: float ...
  function compute_score (line 86) | def compute_score(
  function get_args (line 121) | def get_args():
  function load_jsonl (line 156) | def load_jsonl(file_path: str):
  function code_to_solution_str (line 164) | def code_to_solution_str(code_list: List[int]) -> str:

FILE: examples/grpo/cosyvoice2/scripts/offline-decode-files.py
  function remove_punctuation (line 104) | def remove_punctuation(text: str) -> str:
  function store_transcripts (line 112) | def store_transcripts(
  function write_error_stats (line 137) | def write_error_stats(
  function get_args (line 308) | def get_args():
  function assert_file_exists (line 556) | def assert_file_exists(filename: str):
  function read_wave (line 564) | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
  function normalize_text_alimeeting (line 588) | def normalize_text_alimeeting(text: str) -> str:
  function main (line 623) | def main():

FILE: examples/grpo/cosyvoice2/token2wav_asr_server.py
  class _ASR_Server (line 53) | class _ASR_Server:
    method __init__ (line 56) | def __init__(self, device_id: int):
    method __call__ (line 60) | def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np...
  function audio_decode_cosyvoice2 (line 79) | def audio_decode_cosyvoice2(
  function get_random_prompt_from_dataset (line 116) | def get_random_prompt_from_dataset(dataset):
  class _Token2Wav_ASR (line 142) | class _Token2Wav_ASR:
    method __init__ (line 145) | def __init__(self, device_id: int):
    method __call__ (line 166) | def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT...
  function _infer_function_factory (line 241) | def _infer_function_factory(device_ids: List[int], model_name: str):
  function main (line 252) | def main():

FILE: examples/libritts/cosyvoice/local/prepare_data.py
  function main (line 11) | def main():

FILE: examples/libritts/cosyvoice/local/prepare_reject_sample.py
  function main (line 14) | def main():

FILE: examples/magicdata-read/cosyvoice/local/prepare_data.py
  function main (line 10) | def main():

FILE: runtime/python/fastapi/client.py
  function main (line 22) | def main():

FILE: runtime/python/fastapi/server.py
  function generate_data (line 40) | def generate_data(model_output):
  function inference_sft (line 48) | async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
  function inference_zero_shot (line 55) | async def inference_zero_shot(tts_text: str = Form(), prompt_text: str =...
  function inference_cross_lingual (line 63) | async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: Up...
  function inference_instruct (line 71) | async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(...
  function inference_instruct2 (line 78) | async def inference_instruct2(tts_text: str = Form(), instruct_text: str...

FILE: runtime/python/grpc/client.py
  function main (line 30) | def main():

FILE: runtime/python/grpc/server.py
  class CosyVoiceServiceImpl (line 34) | class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
    method __init__ (line 35) | def __init__(self, args):
    method Inference (line 39) | def Inference(self, request, context):
  function main (line 68) | def main():

FILE: runtime/triton_trtllm/client_grpc.py
  class UserData (line 63) | class UserData:
    method __init__ (line 64) | def __init__(self):
    method record_start_time (line 70) | def record_start_time(self):
    method get_first_chunk_latency (line 73) | def get_first_chunk_latency(self):
    method get_second_chunk_latency (line 78) | def get_second_chunk_latency(self):
  function callback (line 84) | def callback(user_data, result, error):
  function stream_callback (line 97) | def stream_callback(user_data_map, result, error):
  function write_triton_stats (line 112) | def write_triton_stats(stats, summary_file):
  function subtract_stats (line 159) | def subtract_stats(stats_after, stats_before):
  function get_args (line 217) | def get_args():
  function load_audio (line 344) | def load_audio(wav_path, target_sample_rate=16000):
  function prepare_request_input_output (line 359) | def prepare_request_input_output(
  function run_sync_streaming_inference (line 412) | def run_sync_streaming_inference(
  function send_streaming (line 501) | async def send_streaming(
  function send (line 595) | async def send(
  function load_manifests (line 644) | def load_manifests(manifest_path):
  function split_data (line 664) | def split_data(data, k):
  function main (line 687) | async def main():
  function run_main (line 915) | async def run_main():

FILE: runtime/triton_trtllm/client_http.py
  function get_args (line 32) | def get_args():
  function prepare_request (line 86) | def prepare_request(

FILE: runtime/triton_trtllm/infer_cosyvoice3.py
  function send_request_async (line 38) | async def send_request_async(client, url, payload):
  function send_batch_requests_async (line 45) | async def send_batch_requests_async(api_base, model_name, chats, tempera...
  function extract_speech_ids (line 64) | def extract_speech_ids(speech_tokens_str):
  function convert_cosy3_tokens_to_speech_id_str (line 77) | def convert_cosy3_tokens_to_speech_id_str(cosy3_tokens):
  function get_args (line 87) | def get_args():
  function data_collator (line 163) | def data_collator(batch, tokenizer, s3_tokenizer):
  function main (line 219) | def main(args):

FILE: runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
  class TritonPythonModel (line 39) | class TritonPythonModel:
    method initialize (line 46) | def initialize(self, args):
    method execute (line 60) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py
  class TritonPythonModel (line 48) | class TritonPythonModel:
    method initialize (line 55) | def initialize(self, args):
    method forward_llm (line 89) | def forward_llm(self, input_ids):
    method forward_audio_tokenizer (line 175) | def forward_audio_tokenizer(self, wav, wav_len):
    method forward_speaker_embedding (line 201) | def forward_speaker_embedding(self, wav):
    method forward_token2wav (line 226) | def forward_token2wav(
    method parse_input (line 282) | def parse_input(self, text, prompt_text, prompt_speech_tokens):
    method _extract_speech_feat (line 290) | def _extract_speech_feat(self, speech):
    method _llm_gen_thread (line 307) | def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, ...
    method execute (line 315) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
  function parse_speech_token_string (line 52) | def parse_speech_token_string(response_text: str) -> List[int]:
  class TritonPythonModel (line 70) | class TritonPythonModel:
    method initialize (line 77) | def initialize(self, args):
    method _convert_speech_tokens_to_str (line 108) | def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Ten...
    method forward_llm_async (line 121) | async def forward_llm_async(self, target_text: str, reference_text: st...
    method forward_audio_tokenizer (line 181) | def forward_audio_tokenizer(self, wav, wav_len):
    method forward_speaker_embedding (line 207) | def forward_speaker_embedding(self, wav):
    method forward_token2wav (line 232) | async def forward_token2wav(
    method _extract_speech_feat (line 278) | def _extract_speech_feat(self, speech):
    method _process_request (line 295) | async def _process_request(self, request):
    method execute (line 375) | async def execute(self, requests):
    method finalize (line 391) | def finalize(self):

FILE: runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py
  class TritonPythonModel (line 40) | class TritonPythonModel:
    method initialize (line 47) | def initialize(self, args):
    method load_spk_trt (line 74) | def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp...
    method get_spk_trt_kwargs (line 84) | def get_spk_trt_kwargs(self):
    method _extract_spk_embedding (line 91) | def _extract_spk_embedding(self, speech):
    method execute (line 127) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo/token2wav/1/model.py
  class CosyVoice2 (line 52) | class CosyVoice2:
    method __init__ (line 54) | def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=Fal...
  class CosyVoice2Model (line 75) | class CosyVoice2Model:
    method __init__ (line 77) | def __init__(self,
    method load_jit (line 96) | def load_jit(self, flow_encoder_model):
    method load (line 100) | def load(self, flow_model, hift_model):
    method load_trt (line 108) | def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_mod...
    method get_trt_kwargs (line 119) | def get_trt_kwargs(self):
    method token2wav (line 126) | def token2wav(self, token, prompt_token, prompt_feat, embedding, token...
  class TritonPythonModel (line 163) | class TritonPythonModel:
    method initialize (line 170) | def initialize(self, args):
    method execute (line 197) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
  function get_spk_id_from_prompt_audio (line 56) | def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
  class TritonPythonModel (line 71) | class TritonPythonModel:
    method initialize (line 78) | def initialize(self, args):
    method execute (line 99) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
  function fade_in_out (line 38) | def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, w...
  function convert_onnx_to_trt (line 49) | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
  class TrtContextWrapper (line 95) | class TrtContextWrapper:
    method __init__ (line 96) | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
    method acquire_estimator (line 107) | def acquire_estimator(self):
    method release_estimator (line 110) | def release_estimator(self, context, stream):
  class CosyVoice2_Token2Wav (line 114) | class CosyVoice2_Token2Wav(torch.nn.Module):
    method __init__ (line 115) | def __init__(self, model_dir: str, enable_trt: bool = False, device_id...
    method forward_spk_embedding (line 175) | def forward_spk_embedding(self, spk_feat):
    method load_spk_trt (line 204) | def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp...
    method get_spk_trt_kwargs (line 214) | def get_spk_trt_kwargs(self):
    method load_trt (line 221) | def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_mod...
    method get_trt_kwargs_dynamic_batch (line 237) | def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_siz...
    method prompt_audio_tokenization (line 264) | def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Ten...
    method get_spk_emb (line 279) | def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch...
    method get_prompt_mels (line 293) | def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prom...
    method forward_flow (line 311) | def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
    method forward_hift (line 334) | def forward_hift(self, generated_mels: torch.Tensor, generated_mels_le...
    method forward (line 344) | def forward(
    method prepare_prompt_audio (line 359) | def prepare_prompt_audio(
    method get_prompt_audio_cache_for_streaming_tts (line 371) | def get_prompt_audio_cache_for_streaming_tts(
    method forward_streaming (line 390) | def forward_streaming(
  function collate_fn (line 465) | def collate_fn(batch):
  function get_args (line 477) | def get_args():

FILE: runtime/triton_trtllm/model_repo_cosyvoice3/audio_tokenizer/1/model.py
  class TritonPythonModel (line 39) | class TritonPythonModel:
    method initialize (line 46) | def initialize(self, args):
    method execute (line 60) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo_cosyvoice3/cosyvoice3/1/model.py
  function parse_speech_token_string (line 25) | def parse_speech_token_string(response_text):
  class TritonPythonModel (line 39) | class TritonPythonModel:
    method initialize (line 47) | def initialize(self, args):
    method _convert_speech_tokens_to_str (line 72) | def _convert_speech_tokens_to_str(self, speech_tokens):
    method _extract_speech_feat (line 78) | def _extract_speech_feat(self, speech):
    method forward_llm_streaming (line 84) | async def forward_llm_streaming(self, target_text, reference_text, pro...
    method forward_llm_offline (line 139) | async def forward_llm_offline(self, target_text, reference_text, promp...
    method forward_audio_tokenizer (line 167) | def forward_audio_tokenizer(self, wav, wav_len):
    method forward_speaker_embedding (line 181) | def forward_speaker_embedding(self, wav):
    method forward_token2wav (line 195) | async def forward_token2wav(self, target_speech_tokens, prompt_speech_...
    method forward_vocoder (line 232) | async def forward_vocoder(self, mel, finalize):
    method _prepare_prompt (line 253) | def _prepare_prompt(self, request):
    method _process_request_streaming (line 305) | async def _process_request_streaming(self, request):
    method _process_request_offline (line 434) | async def _process_request_offline(self, request):
    method execute (line 469) | async def execute(self, requests):
    method finalize (line 489) | def finalize(self):

FILE: runtime/triton_trtllm/model_repo_cosyvoice3/speaker_embedding/1/model.py
  class TritonPythonModel (line 40) | class TritonPythonModel:
    method initialize (line 47) | def initialize(self, args):
    method load_spk_trt (line 74) | def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp...
    method get_spk_trt_kwargs (line 84) | def get_spk_trt_kwargs(self):
    method _extract_spk_embedding (line 91) | def _extract_spk_embedding(self, speech):
    method execute (line 127) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo_cosyvoice3/token2wav/1/model.py
  class TrtContextWrapper (line 16) | class TrtContextWrapper:
    method __init__ (line 17) | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
    method acquire_estimator (line 27) | def acquire_estimator(self):
    method release_estimator (line 30) | def release_estimator(self, context, stream):
  function convert_onnx_to_trt (line 34) | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocas...
  class TritonPythonModel (line 75) | class TritonPythonModel:
    method initialize (line 81) | def initialize(self, args):
    method load_trt (line 109) | def load_trt(self, model_dir, trt_concurrent=1):
    method get_trt_kwargs (line 126) | def get_trt_kwargs(self):
    method execute (line 134) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo_cosyvoice3/vocoder/1/model.py
  class TritonPythonModel (line 16) | class TritonPythonModel:
    method initialize (line 23) | def initialize(self, args):
    method execute (line 47) | def execute(self, requests):

FILE: runtime/triton_trtllm/offline_inference.py
  function send_request_async (line 56) | async def send_request_async(client, url, payload):
  function send_batch_requests_async (line 63) | async def send_batch_requests_async(api_base, model_name, chats, tempera...
  function extract_speech_ids (line 82) | def extract_speech_ids(speech_tokens_str):
  function convert_cosy2_tokens_to_speech_id_str (line 95) | def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
  function get_args (line 103) | def get_args():
  function data_collator (line 209) | def data_collator(batch, tokenizer, s3_tokenizer):
  function init_distributed (line 312) | def init_distributed():
  function main (line 325) | def main(args):

FILE: runtime/triton_trtllm/scripts/convert_checkpoint.py
  function parse_arguments (line 18) | def parse_arguments():
  function args_to_quant_config (line 152) | def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
  function update_quant_config_from_hf (line 186) | def update_quant_config_from_hf(quant_config, hf_config,
  function args_to_build_options (line 216) | def args_to_build_options(args):
  function convert_and_save_hf (line 225) | def convert_and_save_hf(args):
  function execute (line 282) | def execute(workers, func, args):
  function main (line 301) | def main():

FILE: runtime/triton_trtllm/scripts/convert_cosyvoice3_to_hf.py
  function parse_args (line 48) | def parse_args():
  function load_cosyvoice3_model (line 72) | def load_cosyvoice3_model(model_dir: str):
  function get_speech_token_size (line 100) | def get_speech_token_size(llm) -> int:
  function convert_cosyvoice3_to_hf (line 109) | def convert_cosyvoice3_to_hf(
  function main (line 346) | def main():

FILE: runtime/triton_trtllm/scripts/fill_template.py
  function split (line 6) | def split(string, delimiter):
  function main (line 34) | def main(file_path, substitutions, in_place):

FILE: runtime/triton_trtllm/scripts/test_llm.py
  function parse_arguments (line 29) | def parse_arguments(args=None):
  function parse_input (line 47) | def parse_input(tokenizer,
  function main (line 69) | def main(args):

FILE: runtime/triton_trtllm/streaming_inference.py
  function collate_fn (line 13) | def collate_fn(batch):
  function get_args (line 28) | def get_args():

FILE: runtime/triton_trtllm/token2wav.py
  function convert_onnx_to_trt (line 36) | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
  class TrtContextWrapper (line 74) | class TrtContextWrapper:
    method __init__ (line 75) | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
    method acquire_estimator (line 86) | def acquire_estimator(self):
    method release_estimator (line 89) | def release_estimator(self, context, stream):
  class CosyVoice2_Token2Wav (line 93) | class CosyVoice2_Token2Wav(torch.nn.Module):
    method __init__ (line 94) | def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: b...
    method forward_spk_embedding (line 127) | def forward_spk_embedding(self, spk_feat):
    method load_spk_trt (line 156) | def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp...
    method get_spk_trt_kwargs (line 166) | def get_spk_trt_kwargs(self):
    method load_trt (line 173) | def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_mod...
    method get_trt_kwargs_dynamic_batch (line 185) | def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
    method prompt_audio_tokenization (line 193) | def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Ten...
    method get_spk_emb (line 208) | def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch...
    method get_prompt_mels (line 220) | def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prom...
    method forward_flow (line 236) | def forward_flow(self, prompt_speech_tokens_list: list[list[int]], gen...
    method forward_hift (line 257) | def forward_hift(self, generated_mels: torch.Tensor, generated_mels_le...
    method forward (line 267) | def forward(
  function collate_fn (line 287) | def collate_fn(batch):
  function get_args (line 299) | def get_args():

FILE: runtime/triton_trtllm/token2wav_cosyvoice3.py
  function convert_onnx_to_trt (line 31) | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocas...
  class TrtContextWrapper (line 73) | class TrtContextWrapper:
    method __init__ (line 74) | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
    method acquire_estimator (line 85) | def acquire_estimator(self):
    method release_estimator (line 88) | def release_estimator(self, context, stream):
  class CosyVoice3_Token2Wav (line 92) | class CosyVoice3_Token2Wav(torch.nn.Module):
    method __init__ (line 93) | def __init__(self, model_dir, enable_trt=False, device_id=0, autocast_...
    method load_trt (line 140) | def load_trt(self, model_dir, trt_concurrent=1):
    method get_trt_kwargs (line 162) | def get_trt_kwargs(self):
    method load_spk_trt (line 173) | def load_spk_trt(self, model_dir, trt_concurrent=1, fp16=False):
    method get_spk_trt_kwargs (line 185) | def get_spk_trt_kwargs(self):
    method forward_spk_embedding (line 193) | def forward_spk_embedding(self, spk_feat):
    method prompt_audio_tokenization (line 219) | def prompt_audio_tokenization(self, prompt_audios_list):
    method get_spk_emb (line 234) | def get_spk_emb(self, prompt_audios_list):
    method get_prompt_mels (line 245) | def get_prompt_mels(self, prompt_audios_list, prompt_audios_sample_rate):
    method forward_flow (line 263) | def forward_flow(self, prompt_speech_tokens_list, generated_speech_tok...
    method forward_hift (line 296) | def forward_hift(self, generated_mels_list):
    method forward_stream (line 304) | def forward_stream(self, generated_speech_tokens, prompt_speech_tokens,
    method forward (line 379) | def forward(self, generated_speech_tokens_list, prompt_audios_list,

FILE: tools/extract_embedding.py
  function single_job (line 24) | def single_job(utt):
  function main (line 37) | def main(args):

FILE: tools/extract_speech_token.py
  function single_job (line 26) | def single_job(utt):
  function main (line 43) | def main(args):

FILE: tools/make_parquet_list.py
  function job (line 26) | def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):

FILE: vllm_example.py
  function cosyvoice2_example (line 12) | def cosyvoice2_example():
  function cosyvoice3_example (line 22) | def cosyvoice3_example():
  function main (line 33) | def main():

FILE: webui.py
  function generate_seed (line 38) | def generate_seed():
  function change_instruction (line 46) | def change_instruction(mode_checkbox_group):
  function generate_audio (line 50) | def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_t...
  function main (line 118) | def main():
Condensed preview — 153 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (2,110K chars).
[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "chars": 834,
    "preview": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the b"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "chars": 595,
    "preview": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Is your fea"
  },
  {
    "path": ".github/workflows/lint.yml",
    "chars": 2346,
    "preview": "name: Lint\n\non:\n  pull_request:\n  push:\n\njobs:\n  quick-checks:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Fetch"
  },
  {
    "path": ".github/workflows/stale-issues.yml",
    "chars": 705,
    "preview": "name: Close inactive issues\non:\n  schedule:\n    - cron: \"30 1 * * *\"\n\njobs:\n  close-issues:\n    runs-on: ubuntu-latest\n "
  },
  {
    "path": ".gitignore",
    "chars": 560,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# Visual Studio Code files\n.vscode\n.vs\n\n# PyC"
  },
  {
    "path": ".gitmodules",
    "chars": 123,
    "preview": "[submodule \"third_party/Matcha-TTS\"]\n\tpath = third_party/Matcha-TTS\n\turl = https://github.com/shivammehta25/Matcha-TTS.g"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 3350,
    "preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, w"
  },
  {
    "path": "FAQ.md",
    "chars": 671,
    "preview": "## ModuleNotFoundError: No module named 'matcha'\n\nMatcha-TTS is a third_party module. Please check `third_party` directo"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 12587,
    "preview": "![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Langua"
  },
  {
    "path": "cosyvoice/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/bin/average_model.py",
    "chars": 3202,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Di Wu)\n# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apa"
  },
  {
    "path": "cosyvoice/bin/export_jit.py",
    "chars": 3839,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/bin/export_onnx.py",
    "chars": 4486,
    "preview": "# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)\n# Copyright (c) 2024 Alibaba Inc (authors:"
  },
  {
    "path": "cosyvoice/bin/train.py",
    "chars": 8153,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/cli/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/cli/cosyvoice.py",
    "chars": 14356,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/cli/frontend.py",
    "chars": 12543,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/cli/model.py",
    "chars": 28138,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)\n"
  },
  {
    "path": "cosyvoice/dataset/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/dataset/dataset.py",
    "chars": 5063,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licen"
  },
  {
    "path": "cosyvoice/dataset/processor.py",
    "chars": 16749,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "cosyvoice/flow/DiT/dit.py",
    "chars": 5688,
    "preview": "\n\"\"\"\nein notation:\nb - batch\nn - sequence\nnt - text sequence\nnw - raw wave length\nd - dimension\n\"\"\"\n\nfrom __future__ imp"
  },
  {
    "path": "cosyvoice/flow/DiT/modules.py",
    "chars": 20551,
    "preview": "\n\"\"\"\nein notation:\nb - batch\nn - sequence\nnt - text sequence\nnw - raw wave length\nd - dimension\n\"\"\"\n\nfrom __future__ imp"
  },
  {
    "path": "cosyvoice/flow/decoder.py",
    "chars": 19866,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/flow/flow.py",
    "chars": 19736,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/flow/flow_matching.py",
    "chars": 10552,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#               2025 Alibaba Inc (authors: Xiang Lyu, B"
  },
  {
    "path": "cosyvoice/flow/length_regulator.py",
    "chars": 3137,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/hifigan/discriminator.py",
    "chars": 8617,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\ntry:\n    from torch.nn.utils.parametrizations import "
  },
  {
    "path": "cosyvoice/hifigan/f0_predictor.py",
    "chars": 3757,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
  },
  {
    "path": "cosyvoice/hifigan/generator.py",
    "chars": 30487,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
  },
  {
    "path": "cosyvoice/hifigan/hifigan.py",
    "chars": 3240,
    "preview": "from typing import Dict, Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom matcha.hifigan"
  },
  {
    "path": "cosyvoice/llm/llm.py",
    "chars": 35971,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#               2025 Alibaba Inc (authors: Xiang Lyu, Y"
  },
  {
    "path": "cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken",
    "chars": 907395,
    "preview": "IQ== 0\nIg== 1\nIw== 2\nJA== 3\nJQ== 4\nJg== 5\nJw== 6\nKA== 7\nKQ== 8\nKg== 9\nKw== 10\nLA== 11\nLQ== 12\nLg== 13\nLw== 14\nMA== 15\nMQ"
  },
  {
    "path": "cosyvoice/tokenizer/tokenizer.py",
    "chars": 11122,
    "preview": "import base64\nimport os\nfrom functools import lru_cache\nfrom typing import Optional\nimport torch\nfrom transformers impor"
  },
  {
    "path": "cosyvoice/transformer/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/transformer/activation.py",
    "chars": 3087,
    "preview": "# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)\n#               2020 Northwestern Polytechnical Universi"
  },
  {
    "path": "cosyvoice/transformer/attention.py",
    "chars": 14389,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#               2022 Xingchen Song (s"
  },
  {
    "path": "cosyvoice/transformer/convolution.py",
    "chars": 9772,
    "preview": "# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# License"
  },
  {
    "path": "cosyvoice/transformer/decoder.py",
    "chars": 16580,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# License"
  },
  {
    "path": "cosyvoice/transformer/decoder_layer.py",
    "chars": 4807,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#\n# Licensed under the Apache License"
  },
  {
    "path": "cosyvoice/transformer/embedding.py",
    "chars": 11777,
    "preview": "# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# License"
  },
  {
    "path": "cosyvoice/transformer/encoder.py",
    "chars": 21434,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)\n#"
  },
  {
    "path": "cosyvoice/transformer/encoder_layer.py",
    "chars": 9596,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)\n#"
  },
  {
    "path": "cosyvoice/transformer/label_smoothing_loss.py",
    "chars": 3459,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#\n# Licensed under the Apache License"
  },
  {
    "path": "cosyvoice/transformer/positionwise_feed_forward.py",
    "chars": 4219,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#\n# Licensed under the Apache License"
  },
  {
    "path": "cosyvoice/transformer/subsampling.py",
    "chars": 12666,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2024 Alibaba Inc (Xiang Lyu)\n#\n# Licensed under th"
  },
  {
    "path": "cosyvoice/transformer/upsample_encoder.py",
    "chars": 14229,
    "preview": "# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)\n#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)\n#"
  },
  {
    "path": "cosyvoice/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cosyvoice/utils/class_utils.py",
    "chars": 3588,
    "preview": "# Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>\n#            2024 Alibaba Inc (authors: Xiang Lyu)"
  },
  {
    "path": "cosyvoice/utils/common.py",
    "chars": 8863,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu)\n#               202"
  },
  {
    "path": "cosyvoice/utils/executor.py",
    "chars": 8842,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under "
  },
  {
    "path": "cosyvoice/utils/file_utils.py",
    "chars": 4887,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n#               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)"
  },
  {
    "path": "cosyvoice/utils/frontend_utils.py",
    "chars": 4231,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "cosyvoice/utils/losses.py",
    "chars": 2121,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom typing import Tuple\n\n\ndef tpr_loss(disc_real_outputs, disc_generated_o"
  },
  {
    "path": "cosyvoice/utils/mask.py",
    "chars": 9728,
    "preview": "# Copyright (c) 2019 Shigeki Karita\n#               2020 Mobvoi Inc (Binbin Zhang)\n#               2024 Alibaba Inc (aut"
  },
  {
    "path": "cosyvoice/utils/onnx.py",
    "chars": 2867,
    "preview": "import onnxruntime\nimport torch, random\nimport os\nimport torchaudio.compliance.kaldi as kaldi\n\n\nclass SpeechTokenExtract"
  },
  {
    "path": "cosyvoice/utils/scheduler.py",
    "chars": 24920,
    "preview": "# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)\n#               2022 Ximalaya Inc (Yuguang Yang)\n#               2024 Ali"
  },
  {
    "path": "cosyvoice/utils/train_utils.py",
    "chars": 16626,
    "preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n#               2023 Horizon Inc. (authors: Xingchen Song)\n#   "
  },
  {
    "path": "cosyvoice/vllm/cosyvoice2.py",
    "chars": 4541,
    "preview": "# SPDX-License-Identifier: Apache-2.0\n\n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.28.0/src/tra"
  },
  {
    "path": "docker/Dockerfile",
    "chars": 2085,
    "preview": "FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04\n\nARG VENV_NAME=\"cosyvoice\"\nENV VENV=$VENV_NAME\nENV LANG=C.UTF-8 LC_ALL=C"
  },
  {
    "path": "example.py",
    "chars": 7579,
    "preview": "import sys\nsys.path.append('third_party/Matcha-TTS')\nfrom cosyvoice.cli.cosyvoice import AutoModel\nimport torchaudio\n\n\nd"
  },
  {
    "path": "examples/grpo/cosyvoice2/Dockerfile",
    "chars": 486,
    "preview": "FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2\nCOPY requirements.txt /myworkspace/requirements.txt\nRUN pip ins"
  },
  {
    "path": "examples/grpo/cosyvoice2/README.md",
    "chars": 4543,
    "preview": "# CosyVoice2 LLM Reinforcement Learning Recipe\n\nThis recipe demonstrates how to fine-tune the **CosyVoice2** large langu"
  },
  {
    "path": "examples/grpo/cosyvoice2/huggingface_to_pretrained.py",
    "chars": 2932,
    "preview": "\n# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apac"
  },
  {
    "path": "examples/grpo/cosyvoice2/infer_dataset.py",
    "chars": 14065,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apach"
  },
  {
    "path": "examples/grpo/cosyvoice2/prepare_data.py",
    "chars": 2986,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "examples/grpo/cosyvoice2/pretrained_to_huggingface.py",
    "chars": 5297,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apach"
  },
  {
    "path": "examples/grpo/cosyvoice2/requirements.txt",
    "chars": 457,
    "preview": "conformer==0.3.2\ndiffusers==0.29.0\ngdown==5.1.0\ngradio\nhydra-core==1.3.2\nHyperPyYAML==1.2.2\ninflect==7.3.1\nlibrosa==0.10"
  },
  {
    "path": "examples/grpo/cosyvoice2/reward_tts.py",
    "chars": 7468,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apach"
  },
  {
    "path": "examples/grpo/cosyvoice2/run.sh",
    "chars": 6234,
    "preview": "#!/usr/bin/env bash\n\nset -eou pipefail\n\nstage=-1\nstop_stage=4\n\nlog() {\n  # This function is from espnet\n  local fname=${"
  },
  {
    "path": "examples/grpo/cosyvoice2/scripts/compute_wer.sh",
    "chars": 942,
    "preview": "wav_dir=$1\nwav_files=$(ls $wav_dir/*.wav)\n# if wav_files is empty, then exit\nif [ -z \"$wav_files\" ]; then\n    exit 1\nfi\n"
  },
  {
    "path": "examples/grpo/cosyvoice2/scripts/offline-decode-files.py",
    "chars": 23511,
    "preview": "# Copyright (c)  2023 by manyeyes\n# Copyright (c)  2023  Xiaomi Corporation\n\n\"\"\"\nThis file demonstrates how to use sherp"
  },
  {
    "path": "examples/grpo/cosyvoice2/token2wav_asr_server.py",
    "chars": 12912,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apach"
  },
  {
    "path": "examples/libritts/cosyvoice/conf/cosyvoice.yaml",
    "chars": 8349,
    "preview": "# set random seed, so that you may reproduce your result.\n__set_seed1: !apply:random.seed [1986]\n__set_seed2: !apply:num"
  },
  {
    "path": "examples/libritts/cosyvoice/conf/ds_stage2.json",
    "chars": 925,
    "preview": "{\n  \"train_micro_batch_size_per_gpu\": 1,\n  \"gradient_accumulation_steps\": 1,\n  \"steps_per_print\": 100,\n  \"gradient_clipp"
  },
  {
    "path": "examples/libritts/cosyvoice/local/download_and_untar.sh",
    "chars": 2845,
    "preview": "#!/bin/bash\n\n# Copyright   2014  Johns Hopkins University (author: Daniel Povey)\n# Apache 2.0\n\nremove_archive=false\n\nif "
  },
  {
    "path": "examples/libritts/cosyvoice/local/prepare_data.py",
    "chars": 1946,
    "preview": "import argparse\nimport logging\nimport glob\nimport os\nfrom tqdm import tqdm\n\n\nlogger = logging.getLogger()\n\n\ndef main():\n"
  },
  {
    "path": "examples/libritts/cosyvoice/local/prepare_reject_sample.py",
    "chars": 1754,
    "preview": "import argparse\nimport logging\nimport os\nfrom tqdm import tqdm\nimport torch\nimport torchaudio\nfrom cosyvoice.cli.cosyvoi"
  },
  {
    "path": "examples/libritts/cosyvoice/path.sh",
    "chars": 185,
    "preview": "# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C\nexport PYTHONIOENCODING=UTF-8\nexport "
  },
  {
    "path": "examples/libritts/cosyvoice/run.sh",
    "chars": 4379,
    "preview": "#!/bin/bash\n# Copyright 2024 Alibaba Inc. All Rights Reserved.\n. ./path.sh || exit 1;\n\nstage=-1\nstop_stage=3\n\ndata_url=w"
  },
  {
    "path": "examples/libritts/cosyvoice/tts_text.json",
    "chars": 89,
    "preview": "{\n  \"1089_134686_000002_000000\": [\n    \"hello, my name is Jack. What is your name?\"\n  ]\n}"
  },
  {
    "path": "examples/libritts/cosyvoice2/conf/cosyvoice2.yaml",
    "chars": 7481,
    "preview": "# set random seed, so that you may reproduce your result.\n__set_seed1: !apply:random.seed [1986]\n__set_seed2: !apply:num"
  },
  {
    "path": "examples/libritts/cosyvoice2/conf/ds_stage2.json",
    "chars": 925,
    "preview": "{\n  \"train_micro_batch_size_per_gpu\": 1,\n  \"gradient_accumulation_steps\": 1,\n  \"steps_per_print\": 100,\n  \"gradient_clipp"
  },
  {
    "path": "examples/libritts/cosyvoice2/run.sh",
    "chars": 4610,
    "preview": "#!/bin/bash\n# Copyright 2024 Alibaba Inc. All Rights Reserved.\n. ./path.sh || exit 1;\n\nstage=-1\nstop_stage=3\n\ndata_url=w"
  },
  {
    "path": "examples/libritts/cosyvoice2/run_dpo.sh",
    "chars": 5127,
    "preview": "#!/bin/bash\n# Copyright 2024 Alibaba Inc. All Rights Reserved.\n. ./path.sh || exit 1;\n\nstage=-1\nstop_stage=3\n\ndata_url=w"
  },
  {
    "path": "examples/libritts/cosyvoice3/conf/cosyvoice3.yaml",
    "chars": 7106,
    "preview": "# set random seed, so that you may reproduce your result.\n__set_seed1: !apply:random.seed [1986]\n__set_seed2: !apply:num"
  },
  {
    "path": "examples/libritts/cosyvoice3/conf/ds_stage2.json",
    "chars": 925,
    "preview": "{\n  \"train_micro_batch_size_per_gpu\": 1,\n  \"gradient_accumulation_steps\": 1,\n  \"steps_per_print\": 100,\n  \"gradient_clipp"
  },
  {
    "path": "examples/libritts/cosyvoice3/run.sh",
    "chars": 4719,
    "preview": "#!/bin/bash\n# Copyright 2024 Alibaba Inc. All Rights Reserved.\n. ./path.sh || exit 1;\n\nstage=-1\nstop_stage=3\n\ndata_url=w"
  },
  {
    "path": "examples/magicdata-read/cosyvoice/local/download_and_untar.sh",
    "chars": 2615,
    "preview": "#!/bin/bash\n\n# Copyright   2014  Johns Hopkins University (author: Daniel Povey)\n# Apache 2.0\n\nremove_archive=false\n\nif "
  },
  {
    "path": "examples/magicdata-read/cosyvoice/local/prepare_data.py",
    "chars": 1698,
    "preview": "import argparse\nimport logging\nimport os\nfrom tqdm import tqdm\n\n\nlogger = logging.getLogger()\n\n\ndef main():\n    utt2wav,"
  },
  {
    "path": "examples/magicdata-read/cosyvoice/run.sh",
    "chars": 4021,
    "preview": "#!/bin/bash\n# Copyright 2024 Alibaba Inc. All Rights Reserved.\n. ./path.sh || exit 1;\n\nstage=-1\nstop_stage=3\n\ndata_url=w"
  },
  {
    "path": "examples/magicdata-read/cosyvoice/tts_text.json",
    "chars": 334,
    "preview": "{\n  \"38_5718_20170915093303\": [\n    \"我想这出最好歌曲把歌词发到网上请别人帮我作曲急急\",\n    \"叫他明天早上差五分儿九点去机场\"\n  ],\n  \"38_5721_20170915091235\": ["
  },
  {
    "path": "requirements.txt",
    "chars": 1125,
    "preview": "--extra-index-url https://download.pytorch.org/whl/cu121\n--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicP"
  },
  {
    "path": "runtime/python/Dockerfile",
    "chars": 732,
    "preview": "FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime\nENV DEBIAN_FRONTEND=noninteractive\n\nWORKDIR /opt/CosyVoice\n\nRUN sed -"
  },
  {
    "path": "runtime/python/fastapi/client.py",
    "chars": 3695,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "runtime/python/fastapi/server.py",
    "chars": 3732,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "runtime/python/grpc/client.py",
    "chars": 4660,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "runtime/python/grpc/cosyvoice.proto",
    "chars": 757,
    "preview": "syntax = \"proto3\";\n\npackage cosyvoice;\noption go_package = \"protos/\";\n\nservice CosyVoice{\n  rpc Inference(Request) retur"
  },
  {
    "path": "runtime/python/grpc/server.py",
    "chars": 4255,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "runtime/triton_trtllm/Dockerfile.server",
    "chars": 411,
    "preview": "FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3\nLABEL maintainer=\"zhangyuekai@foxmail.com\"\n\nRUN apt-get update "
  },
  {
    "path": "runtime/triton_trtllm/README.Cosyvoice2.DiT.md",
    "chars": 7658,
    "preview": "## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM\n\nContributed by Yuek"
  },
  {
    "path": "runtime/triton_trtllm/README.Cosyvoice2.Unet.md",
    "chars": 6365,
    "preview": "## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM\n\nContributed by Yuekai Zhang (NVIDIA).\n\n#"
  },
  {
    "path": "runtime/triton_trtllm/README.Cosyvoice3.md",
    "chars": 3389,
    "preview": "## Accelerating CosyVoice3 with NVIDIA Triton Inference Server and TensorRT-LLM\n\nContributed by Yuekai Zhang (NVIDIA).\n\n"
  },
  {
    "path": "runtime/triton_trtllm/README.md",
    "chars": 1547,
    "preview": "# Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM\n\nContributed by Yuekai Zhang (NVIDIA).\n\nTh"
  },
  {
    "path": "runtime/triton_trtllm/client_grpc.py",
    "chars": 36518,
    "preview": "# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)\n#                2023  Nvidia              (authors:"
  },
  {
    "path": "runtime/triton_trtllm/client_http.py",
    "chars": 5379,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/docker-compose.cosyvoice2.dit.yml",
    "chars": 735,
    "preview": "services:\n  tts:\n    image: soar97/triton-cosyvoice:25.06\n    shm_size: '1gb'\n    ports:\n      - \"8000:8000\"\n      - \"80"
  },
  {
    "path": "runtime/triton_trtllm/docker-compose.cosyvoice2.unet.yml",
    "chars": 630,
    "preview": "services:\n  tts:\n    image: soar97/triton-cosyvoice:25.06\n    shm_size: '1gb'\n    ports:\n      - \"8000:8000\"\n      - \"80"
  },
  {
    "path": "runtime/triton_trtllm/docker-compose.cosyvoice3.yml",
    "chars": 615,
    "preview": "services:\n  tts:\n    image: soar97/triton-cosyvoice:25.06\n    shm_size: '1gb'\n    ports:\n      - \"8000:8000\"\n      - \"80"
  },
  {
    "path": "runtime/triton_trtllm/infer_cosyvoice3.py",
    "chars": 21416,
    "preview": "\"\"\" Example Usage\n    CUDA_VISIBLE_DEVICES=0 \\\n        python3 infer_cosyvoice3_token2wav.py \\\n            --output-dir "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py",
    "chars": 3995,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt",
    "chars": 1181,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py",
    "chars": 21008,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt",
    "chars": 1537,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py",
    "chars": 17071,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt",
    "chars": 1541,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py",
    "chars": 6778,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/speaker_embedding/config.pbtxt",
    "chars": 1103,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt",
    "chars": 18057,
    "preview": "# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/token2wav/1/model.py",
    "chars": 14605,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/token2wav/config.pbtxt",
    "chars": 1640,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py",
    "chars": 5508,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py",
    "chars": 25760,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apach"
  },
  {
    "path": "runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt",
    "chars": 1426,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/audio_tokenizer/1/model.py",
    "chars": 3827,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/audio_tokenizer/config.pbtxt",
    "chars": 1181,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/cosyvoice3/1/model.py",
    "chars": 22340,
    "preview": "import json\nimport re\nimport time\nimport asyncio\n\nimport numpy as np\nimport torch\nfrom torch.utils.dlpack import to_dlpa"
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/cosyvoice3/config.pbtxt",
    "chars": 1537,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/speaker_embedding/1/model.py",
    "chars": 6638,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/speaker_embedding/config.pbtxt",
    "chars": 1103,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/token2wav/1/model.py",
    "chars": 9309,
    "preview": "import json\nimport os\nimport logging\nimport queue\n\nimport torch\nimport numpy as np\nfrom torch.utils.dlpack import to_dlp"
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/token2wav/config.pbtxt",
    "chars": 1103,
    "preview": "name: \"token2wav\"\nbackend: \"python\"\nmax_batch_size: ${triton_max_batch_size}\n\ndynamic_batching {\n    max_queue_delay_mic"
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/vocoder/1/model.py",
    "chars": 2599,
    "preview": "import json\nimport os\nimport logging\n\nimport torch\nfrom torch.utils.dlpack import to_dlpack\nimport triton_python_backend"
  },
  {
    "path": "runtime/triton_trtllm/model_repo_cosyvoice3/vocoder/config.pbtxt",
    "chars": 567,
    "preview": "name: \"vocoder\"\nbackend: \"python\"\nmax_batch_size: ${triton_max_batch_size}\ndynamic_batching {\n    max_queue_delay_micros"
  },
  {
    "path": "runtime/triton_trtllm/offline_inference.py",
    "chars": 27137,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apach"
  },
  {
    "path": "runtime/triton_trtllm/requirements.txt",
    "chars": 161,
    "preview": "hyperpyyaml\ns3tokenizer\nonnxruntime-gpu\nomegaconf\nconformer\nhydra-core\nlightning\ngdown\nwget\nlibrosa\npyworld\nopenai-whisp"
  },
  {
    "path": "runtime/triton_trtllm/run.sh",
    "chars": 6638,
    "preview": "#!/bin/bash\n# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)\nexport CUDA_VISIBLE_DEVICES=0\ncosyvoice_path=/workspace/"
  },
  {
    "path": "runtime/triton_trtllm/run_cosyvoice3.sh",
    "chars": 6535,
    "preview": "#!/bin/bash\n# Copyright (c) 2026 NVIDIA (authors: Yuekai Zhang)\nexport CUDA_VISIBLE_DEVICES=0\ncosyvoice_path=/workspace/"
  },
  {
    "path": "runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh",
    "chars": 9919,
    "preview": "#!/bin/bash\n# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)\nexport CUDA_VISIBLE_DEVICES=0\ncosyvoice_path=/workspace/"
  },
  {
    "path": "runtime/triton_trtllm/scripts/convert_checkpoint.py",
    "chars": 13444,
    "preview": "import argparse\nimport os\nimport time\nimport traceback\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\n\n"
  },
  {
    "path": "runtime/triton_trtllm/scripts/convert_cosyvoice3_to_hf.py",
    "chars": 16145,
    "preview": "#!/usr/bin/env python3\n# Copyright 2025 CosyVoice3 TRT-LLM Integration\n#\n# Licensed under the Apache License, Version 2."
  },
  {
    "path": "runtime/triton_trtllm/scripts/fill_template.py",
    "chars": 1827,
    "preview": "# /usr/bin/env python3\nfrom argparse import ArgumentParser\nfrom string import Template\n\n\ndef split(string, delimiter):\n "
  },
  {
    "path": "runtime/triton_trtllm/scripts/test_llm.py",
    "chars": 5070,
    "preview": "\n# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-"
  },
  {
    "path": "runtime/triton_trtllm/streaming_inference.py",
    "chars": 5254,
    "preview": "import torch\nimport os\nimport argparse\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader\nimport "
  },
  {
    "path": "runtime/triton_trtllm/token2wav.py",
    "chars": 17183,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n# SPDX-License-Identifier: Apach"
  },
  {
    "path": "runtime/triton_trtllm/token2wav_cosyvoice3.py",
    "chars": 19483,
    "preview": "\"\"\" Example Usage\n    CUDA_VISIBLE_DEVICES=0 \\\n        python3 token2wav_cosyvoice3.py --enable-trt || exit 1\n\"\"\"\nimport"
  },
  {
    "path": "tools/extract_embedding.py",
    "chars": 2998,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Vers"
  },
  {
    "path": "tools/extract_speech_token.py",
    "chars": 2797,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Vers"
  },
  {
    "path": "tools/make_parquet_list.py",
    "chars": 5738,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)\n#\n# Licensed under the Apache License, Vers"
  },
  {
    "path": "vllm_example.py",
    "chars": 1464,
    "preview": "import sys\nsys.path.append('third_party/Matcha-TTS')\nfrom vllm import ModelRegistry\nfrom cosyvoice.vllm.cosyvoice2 impor"
  },
  {
    "path": "webui.py",
    "chars": 8496,
    "preview": "# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)\n#\n# Licensed under the Apache License, Version 2.0 (the \""
  }
]

About this extraction

This page contains the full source code of the FunAudioLLM/CosyVoice GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 153 files (1.9 MB), approximately 845.7k tokens, and a symbol index with 786 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!