Repository: 2noise/ChatTTS
Branch: main
Commit: c26573a61ebd
Files: 96
Total size: 4.0 MB
Directory structure:
gitextract_f8g4d5at/
├── .gitattributes
├── .github/
│   └── workflows/
│       ├── checksum.yml
│       ├── close-issue.yml
│       ├── pull-format.yml
│       ├── push-format.yml
│       ├── unitest.yml
│       └── upload-pypi.yml
├── .gitignore
├── ChatTTS/
│   ├── __init__.py
│   ├── config/
│   │   ├── __init__.py
│   │   └── config.py
│   ├── core.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── cuda/
│   │   │   ├── __init__.py
│   │   │   ├── patch.py
│   │   │   └── te_llama.py
│   │   ├── dvae.py
│   │   ├── embed.py
│   │   ├── gpt.py
│   │   ├── processors.py
│   │   ├── speaker.py
│   │   ├── tokenizer.py
│   │   └── velocity/
│   │       ├── __init__.py
│   │       ├── block_manager.py
│   │       ├── configs.py
│   │       ├── llama.py
│   │       ├── llm.py
│   │       ├── llm_engine.py
│   │       ├── model_loader.py
│   │       ├── model_runner.py
│   │       ├── output.py
│   │       ├── sampler.py
│   │       ├── sampling_params.py
│   │       ├── scheduler.py
│   │       ├── sequence.py
│   │       └── worker.py
│   ├── norm.py
│   ├── res/
│   │   ├── __init__.py
│   │   ├── homophones_map.json
│   │   └── sha256_map.json
│   └── utils/
│       ├── __init__.py
│       ├── dl.py
│       ├── gpu.py
│       ├── io.py
│       └── log.py
├── LICENSE
├── README.md
├── docs/
│   ├── cn/
│   │   └── README.md
│   ├── es/
│   │   └── README.md
│   ├── fr/
│   │   └── README.md
│   ├── jp/
│   │   └── README.md
│   ├── kr/
│   │   └── README.md
│   └── ru/
│       └── README.md
├── examples/
│   ├── __init__.py
│   ├── api/
│   │   ├── README.md
│   │   ├── client.py
│   │   ├── main.py
│   │   ├── openai_api.py
│   │   ├── postScript.py
│   │   └── requirements.txt
│   ├── cmd/
│   │   ├── run.py
│   │   └── stream.py
│   ├── ipynb/
│   │   ├── colab.ipynb
│   │   └── example.ipynb
│   ├── onnx/
│   │   ├── README.md
│   │   ├── exporter.py
│   │   ├── gpt.py
│   │   └── modeling_llama.py
│   └── web/
│       ├── __init__.py
│       ├── ex.py
│       ├── funcs.py
│       └── webui.py
├── openai_api.ipynb
├── requirements.txt
├── setup.py
├── tests/
│   ├── #511.py
│   ├── #588.py
│   ├── #655.py
│   └── testall.sh
└── tools/
    ├── __init__.py
    ├── audio/
    │   ├── __init__.py
    │   ├── av.py
    │   ├── ffmpeg.py
    │   ├── np.py
    │   └── pcm.py
    ├── checksum/
    │   ├── main.go
    │   └── tmpl.go
    ├── llm/
    │   ├── __init__.py
    │   └── llm.py
    ├── logger/
    │   ├── __init__.py
    │   └── log.py
    ├── normalizer/
    │   ├── __init__.py
    │   ├── en.py
    │   └── zh.py
    └── seeder/
        ├── __init__.py
        └── ctx.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
# ignore jupyter notebooks in the language bar on github
**/*.ipynb linguist-vendored
*.ipynb
================================================
FILE: .github/workflows/checksum.yml
================================================
name: Calculate and Sync SHA256
on:
  workflow_dispatch:
jobs:
  checksum:
    runs-on: ubuntu-24.04
    steps:
      - uses: actions/checkout@v4
      - name: Setup Go Environment
        uses: actions/setup-go@v5
      - name: Run RVC-Models-Downloader
        run: |
          wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.11/rvcmd_linux_amd64.deb
          sudo apt -y install ./rvcmd_linux_amd64.deb
          rm -f ./rvcmd_linux_amd64.deb
          rvcmd -notrs -w 1 -notui assets/chtts
      - name: Calculate all Checksums
        run: go run tools/checksum/*.go
      - name: Commit back
        if: ${{ !github.head_ref }}
        id: commitback
        continue-on-error: true
        run: |
          git config --local user.name 'github-actions[bot]'
          git config --local user.email 'github-actions[bot]@users.noreply.github.com'
          git add --all
          git commit -m "chore(env): sync checksum on ${{github.ref_name}}"
      - name: Create Pull Request
        if: steps.commitback.outcome == 'success'
        continue-on-error: true
        uses: peter-evans/create-pull-request@v5
        with:
          delete-branch: true
          body: "Automatically sync checksum in .env"
          title: "chore(env): sync checksum on ${{github.ref_name}}"
          commit-message: "chore(env): sync checksum on ${{github.ref_name}}"
          branch: checksum-${{github.ref_name}}
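This workflow recomputes the SHA256 digests of the downloaded assets and commits them back; `ChatTTS/core.py` later loads `res/sha256_map.json` and hands it to `check_all_assets`. A minimal sketch of such a verification, assuming the map is a plain `{relative_path: hex_digest}` dict (the key format the real `check_all_assets` uses is not shown here and may differ; `sha256_of` and `find_mismatches` are hypothetical names):

```python
import hashlib
from pathlib import Path
from typing import Dict, List


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks so large checkpoints never sit fully in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def find_mismatches(base: Path, expected: Dict[str, str]) -> List[str]:
    """Return the relative paths that are missing or whose digest differs.

    Hypothetical helper: assumes `expected` maps paths relative to `base`
    to lowercase hex SHA256 digests.
    """
    bad = []
    for rel_path, want in expected.items():
        file_path = base / rel_path
        if not file_path.is_file() or sha256_of(file_path) != want:
            bad.append(rel_path)
    return bad
```

An empty result means every listed asset is present and intact; anything else is a candidate for re-download.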
================================================
FILE: .github/workflows/close-issue.yml
================================================
name: Close Inactive Issues
on:
  schedule:
    - cron: "0 4 * * *"
jobs:
  close-issues:
    runs-on: ubuntu-24.04
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v5
        with:
          exempt-issue-labels: "help wanted,following up,todo list,enhancement,algorithm,delayed,performance"
          days-before-issue-stale: 30
          days-before-issue-close: 15
          stale-issue-label: "stale"
          close-issue-message: "This issue was closed because it has been inactive for 15 days since being marked as stale."
          days-before-pr-stale: -1
          days-before-pr-close: -1
          operations-per-run: 10000
          repo-token: ${{ secrets.GITHUB_TOKEN }}
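With these settings an untouched issue is labelled `stale` 30 days after its last update and closed 15 days after that, so roughly 45 days in total; issues carrying any exempt label are skipped, and pull requests are never marked because both PR values are `-1`. A small sketch of that arithmetic, under the simplifying assumptions that nothing happens after `last_update` and that the daily 04:00 UTC cron granularity can be ignored (`planned_close_date` is an illustrative name):

```python
from datetime import datetime, timedelta
from typing import Optional, Set

DAYS_BEFORE_STALE = 30  # days-before-issue-stale
DAYS_BEFORE_CLOSE = 15  # days-before-issue-close
EXEMPT = {
    "help wanted", "following up", "todo list", "enhancement",
    "algorithm", "delayed", "performance",
}


def planned_close_date(last_update: datetime, labels: Set[str]) -> Optional[datetime]:
    """Earliest date the bot could close an idle issue, or None if exempt."""
    if labels & EXEMPT:
        return None
    marked_stale = last_update + timedelta(days=DAYS_BEFORE_STALE)
    return marked_stale + timedelta(days=DAYS_BEFORE_CLOSE)
```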
================================================
FILE: .github/workflows/pull-format.yml
================================================
name: Check Pull Request Format
on:
  pull_request_target:
    types: [opened, reopened, synchronize]
jobs:
  # This workflow closes invalid PRs
  change-or-close-pr:
    # The type of runner that the job will run on
    runs-on: ubuntu-24.04
    permissions: write-all
    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      - name: Change Base Branch
        if: github.event.pull_request.base.ref != 'dev'
        uses: actions/github-script@v4
        id: change-base
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { owner, repo, number } = context.issue;
            const newBase = 'dev';
            try {
              const result = await github.pulls.update({
                owner,
                repo,
                pull_number: number,
                base: newBase
              });
              console.log(result);
              return 'success';
            } catch (error) {
              console.log(error);
              return 'failed';
            }
      - name: Close PR if it is not pointed to dev Branch
        if: "github.event.pull_request.base.ref != 'dev' && steps.change-base.outputs.result == 'failed'"
        uses: superbrothers/close-pull-request@v3
        with:
          # Optional. Post an issue comment just before closing a pull request.
          comment: "Invalid PR to `non-dev` branch `${{ github.event.pull_request.base.ref }}`."
  pull-format:
    runs-on: ubuntu-latest
    permissions:
      contents: write
    continue-on-error: true
    steps:
      - name: Checkout Repo
        continue-on-error: true
        uses: actions/checkout@v4
      - name: Checkout PR # see https://github.com/orgs/community/discussions/24945
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: gh pr checkout ${{ github.event.pull_request.number }}
      - name: Set up Python
        uses: actions/setup-python@v5
      - name: Create venv
        run: python3 -m venv .venv
      - name: Activate venv
        run: |
          . .venv/bin/activate
          echo PATH=$PATH >> $GITHUB_ENV
      - name: Install Black
        run: pip install "black[jupyter]"
      - name: Run Black
        # run: black $(git ls-files '*.py')
        run: black .
      - name: Commit back
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        continue-on-error: true
        run: |
          git config --local user.name 'github-actions[bot]'
          git config --local user.email 'github-actions[bot]@users.noreply.github.com'
          git add --all
          git commit -m "chore(format): run black on ${{github.ref_name}}"
          git push
================================================
FILE: .github/workflows/push-format.yml
================================================
name: Standardize Code Format
on:
  push:
    branches:
      - main
      - dev
jobs:
  push-format:
    runs-on: ubuntu-latest
    if: "!contains(github.event.head_commit.message, 'chore(format): ') && !contains(github.event.head_commit.message, 'chore(env): ')"
    permissions:
      contents: write
      pull-requests: write
    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{github.ref_name}}
      - name: Set up Python
        uses: actions/setup-python@v5
      - name: Create venv
        run: python3 -m venv .venv
      - name: Activate venv
        run: |
          . .venv/bin/activate
          echo PATH=$PATH >> $GITHUB_ENV
      - name: Install Black
        run: pip install "black[jupyter]"
      - name: Run Black
        # run: black $(git ls-files '*.py')
        run: black .
      - name: Commit Back
        continue-on-error: true
        id: commitback
        run: |
          git config --local user.email "github-actions[bot]@users.noreply.github.com"
          git config --local user.name "github-actions[bot]"
          git add --all
          git commit -m "chore(format): run black on ${{github.ref_name}}"
      - name: Create Pull Request
        if: steps.commitback.outcome == 'success'
        continue-on-error: true
        uses: peter-evans/create-pull-request@v5
        with:
          delete-branch: true
          body: "Automatically apply code formatter change"
          title: "chore(format): run black on ${{github.ref_name}}"
          commit-message: "chore(format): run black on ${{github.ref_name}}"
          branch: formatter-${{github.ref_name}}
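This workflow and the unit-test workflow both skip runs whose head commit message contains `chore(format): ` or `chore(env): `, so the bot's own formatting and checksum commits do not retrigger the pipeline. GitHub's `contains()` is a case-insensitive substring test. The sketch below mirrors the `if:` predicate in Python (`should_run` is an illustrative name); as far as the event payloads go, `github.event.head_commit` is only populated for pushes, so on `pull_request` events the condition is effectively always true.

```python
BOT_MARKERS = ("chore(format): ", "chore(env): ")


def should_run(head_commit_message: str) -> bool:
    """Mirror of `!contains(msg, 'chore(format): ') && !contains(msg, 'chore(env): ')`.

    contains() ignores case, so both sides are lower-cased before comparing.
    """
    msg = head_commit_message.lower()
    return not any(marker.lower() in msg for marker in BOT_MARKERS)
```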
================================================
FILE: .github/workflows/unitest.yml
================================================
name: Unit Test
on: [ push, pull_request ]
jobs:
  build:
    runs-on: ${{ matrix.os }}
    if: "!contains(github.event.head_commit.message, 'chore(format): ') && !contains(github.event.head_commit.message, 'chore(env): ')"
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
        os: [ubuntu-latest]
      fail-fast: true
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install System Dependencies
        run: |
          sudo apt-get install -y portaudio19-dev python3-pyaudio
      - name: Create venv
        run: python3 -m venv .venv
      - name: Activate venv
        run: |
          . .venv/bin/activate
          echo PATH=$PATH >> $GITHUB_ENV
      - name: Test Install
        run: pip install .
      - name: Install Dependencies
        run: pip install -r requirements.txt
      - name: Run Test
        run: tests/testall.sh
================================================
FILE: .github/workflows/upload-pypi.yml
================================================
name: Upload to PyPI
on:
  push:
    tags:
      - 'v*'
jobs:
  build:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{github.ref_name}}
      - name: Set up Python
        uses: actions/setup-python@v5
      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install --upgrade setuptools
          python -m pip install --upgrade wheel
          pip install twine
      - name: Build Package
        env:
          CHTTS_VER: ${{ github.ref_name }}
        run: |
          echo "Release Tag: ${{ github.ref_name }}"
          sed -i 's/v0.0.0/${{ github.ref_name }}/g' setup.py
          python setup.py sdist
      - name: Upload Package
        run: |
          twine upload dist/* -u "__token__" -p ${{ secrets.PYPI_TOKEN }}
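On a `v*` tag push the build step stamps the tag name into `setup.py` by replacing the placeholder `v0.0.0` before running `python setup.py sdist`. In the `sed` expression the dots are regex wildcards, so any `v0?0?0` string would match. A literal-string sketch of the same substitution in Python; `stamp_version` is a hypothetical name, and the tag-format check is an extra guard the workflow itself does not perform:

```python
import re


def stamp_version(setup_py_source: str, tag: str) -> str:
    """Swap every literal 'v0.0.0' placeholder for the release tag.

    str.replace matches the dots literally, unlike the unescaped sed pattern.
    """
    # Extra guard (not in the workflow): accept tags like v1, v1.2, v1.2.3.
    if not re.fullmatch(r"v\d+(\.\d+)*", tag):
        raise ValueError(f"unexpected tag format: {tag!r}")
    return setup_py_source.replace("v0.0.0", tag)
```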
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.ckpt
# C extensions
*.so
*.pt
# Distribution / packaging
.Python
outputs/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
asset/*
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# MacOS System
.DS_Store
# assets and configs of ChatTTS
/asset
/config
# inferred result
*.wav
*.mp3
================================================
FILE: ChatTTS/__init__.py
================================================
from .core import Chat
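`Chat`, re-exported here from `core.py`, pre-splits string input in `infer()` when `split_text=True`: on newlines if the text contains any, otherwise right after every `。` or after a `.` followed by whitespace, discarding empty pieces. `_infer()` then feeds the pieces to the model in consecutive groups of at most `max_split_batch` (with `split_text=False` it uses a single group holding all of the text). A standalone sketch of those two steps; `split_like_infer` and `batches_like_infer` are illustrative names, not part of the ChatTTS API:

```python
import math
import re
from typing import List


def split_like_infer(text: str) -> List[str]:
    """Split a string the way Chat.infer does when split_text=True."""
    if "\n" in text:
        parts = text.split("\n")
    else:
        # Zero-width split points: right after '。' or after '.' + whitespace.
        parts = re.split(r"(?<=。)|(?<=\.\s)", text)
    return [p for p in parts if p]  # drop empty pieces


def batches_like_infer(parts: List[str], max_split_batch: int = 4) -> List[List[str]]:
    """Group parts into ceil(len(parts) / max_split_batch) consecutive batches."""
    n = math.ceil(len(parts) / max_split_batch)
    return [parts[i * max_split_batch:(i + 1) * max_split_batch] for i in range(n)]
```

A decimal such as `3.14` is not split because no whitespace follows the dot.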
================================================
FILE: ChatTTS/config/__init__.py
================================================
from .config import Config
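`Config` is re-exported here. `Chat.load()` in `core.py` expands its `path` section into keyword arguments for `_load()`, joining each relative asset path onto the directory returned by `download_models()`. A self-contained sketch of that expansion, using a trimmed, hypothetical stand-in (`AssetPaths`) that keeps only two of the `Path` dataclass's fields, and a made-up download directory:

```python
import os
from dataclasses import asdict, dataclass


@dataclass(repr=False, eq=False)
class AssetPaths:
    # Hypothetical, trimmed stand-in for ChatTTS.config.config.Path (same defaults).
    vocos_ckpt_path: str = "asset/Vocos.safetensors"
    gpt_ckpt_path: str = "asset/gpt"


def expand_paths(download_path: str, paths: AssetPaths) -> dict:
    """Build {field_name: joined_path}, as the dict comprehension in Chat.load does."""
    return {k: os.path.join(download_path, v) for k, v in asdict(paths).items()}
```

Because the defaults are relative (`asset/...`), the same `Config` serves the local, Hugging Face and custom sources; only the base directory changes.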
================================================
FILE: ChatTTS/config/config.py
================================================
from dataclasses import dataclass
@dataclass(repr=False, eq=False)
class Path:
    vocos_ckpt_path: str = "asset/Vocos.safetensors"
    dvae_ckpt_path: str = "asset/DVAE.safetensors"
    gpt_ckpt_path: str = "asset/gpt"
    decoder_ckpt_path: str = "asset/Decoder.safetensors"
    tokenizer_path: str = "asset/tokenizer"
    embed_path: str = "asset/Embed.safetensors"
@dataclass(repr=False, eq=False)
class Decoder:
    idim: int = 384
    odim: int = 384
    hidden: int = 512
    n_layer: int = 12
    bn_dim: int = 128
@dataclass(repr=False, eq=False)
class VQ:
    dim: int = 1024
    levels: tuple = (5, 5, 5, 5)
    G: int = 2
    R: int = 2
@dataclass(repr=False, eq=False)
class DVAE:
    encoder: Decoder = Decoder(
        idim=512,
        odim=1024,
        hidden=256,
        n_layer=12,
        bn_dim=128,
    )
    decoder: Decoder = Decoder(
        idim=512,
        odim=512,
        hidden=256,
        n_layer=12,
        bn_dim=128,
    )
    vq: VQ = VQ()
@dataclass(repr=False, eq=False)
class GPT:
    hidden_size: int = 768
    intermediate_size: int = 3072
    num_attention_heads: int = 12
    num_hidden_layers: int = 20
    use_cache: bool = False
    max_position_embeddings: int = 4096
    spk_emb_dim: int = 192
    spk_KL: bool = False
    num_audio_tokens: int = 626
    num_text_tokens: int = 21178
    num_vq: int = 4
@dataclass(repr=False, eq=False)
class Embed:
    hidden_size: int = 768
    num_audio_tokens: int = 626
    num_text_tokens: int = 21178
    num_vq: int = 4
@dataclass(repr=False, eq=False)
class FeatureExtractorInitArgs:
    sample_rate: int = 24000
    n_fft: int = 1024
    hop_length: int = 256
    n_mels: int = 100
    padding: str = "center"
@dataclass(repr=False, eq=False)
class FeatureExtractor:
    class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures"
    init_args: FeatureExtractorInitArgs = FeatureExtractorInitArgs()
@dataclass(repr=False, eq=False)
class BackboneInitArgs:
    input_channels: int = 100
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8
@dataclass(repr=False, eq=False)
class Backbone:
    class_path: str = "vocos.models.VocosBackbone"
    init_args: BackboneInitArgs = BackboneInitArgs()
@dataclass(repr=False, eq=False)
class FourierHeadInitArgs:
    dim: int = 512
    n_fft: int = 1024
    hop_length: int = 256
    padding: str = "center"
@dataclass(repr=False, eq=False)
class FourierHead:
    class_path: str = "vocos.heads.ISTFTHead"
    init_args: FourierHeadInitArgs = FourierHeadInitArgs()
@dataclass(repr=False, eq=False)
class Vocos:
    feature_extractor: FeatureExtractor = FeatureExtractor()
    backbone: Backbone = Backbone()
    head: FourierHead = FourierHead()
@dataclass(repr=False, eq=False)
class Config:
    path: Path = Path()
    decoder: Decoder = Decoder()
    dvae: DVAE = DVAE()
    gpt: GPT = GPT()
    embed: Embed = Embed()
    vocos: Vocos = Vocos()
    spk_stat: str = (
"愐穤巩噅廷戇笉屈癐媄垹垧帶爲漈塀殐慄亅倴庲舴猂瑈圐狴夥圓帍戛挠腉耐劤坽喳幾战謇聀崒栄呥倸庭燡欈杁襐褄乭埗幺爃弔摁斐捔兕佖廐舏竾豃磐姓趡佄幒爚欄豄讐皳訵仩帆投謌荃蝐叄圝伆幦抂茁呄掑斃讹傮庞爣蜀橁偐祄亥兡常爂欍扉丐浔佱僈強払伅扂蛐徴憍傞巀戺欀艂琐嗴啥値彷刂權穈扒卤俔贲庛初笂卄贐枴仭亁庛剎猢扃缐趤刁偵幪舏伌煁婐潤晍位弾舙茥穁葏蠣訑企庤刊笍橁溑僔云偁庯戚伍潉膐脴僵噔廃艅匊祂唐憴壝嗙席爥欁虁谐牴帽势弿牳蜁兀蛐傄喩丿帔刔圆衁廐罤庁促帙劢伈汄樐檄勵伴弝舑欍罅虐昴劭勅帜刼朊蕁虐蓴樑伫幨扑謪剀堐稴丵伱弐舮諸赁習俔容厱幫牶謃孄糐答嗝僊帜燲笄終瀒判久僤帘爴茇千孑冄凕佳引扐蜁歁缏裄剽儺恘爋朏眿廐呄塍嘇幻爱茠詁訐剴唭俐幾戊欀硁菐贄楕偒巡爀弎屄莐睳賙凶彎刅漄區唐溴剑劋庽舽猄煃跐夔惥伾庮舎伈罁垑坄怅业怯刁朇獁嶏覔坩俳巶爜朐潁崐萄俹凛常爺笌穀聐此夡倛帡刀匉終窏舣販侽怿扉伥贿憐忓謩姆幌犊漂慆癒却甝兎帼戏欅詂浐朔仹壭帰臷弎恇菐獤帡偖帘爞伅腂皐纤囅充幓戠伥灂丐訤戱倱弋爮嬌癁恐孄侥劬忶刓國詀桒古偩嘄庬戚茝赂监燤嘑勌幦舽持呂諐棤姑再底舡笍艃瀐孴倉傔弋爔猠乁濑塄偽嘧恂舛缇襃厐窴仡刱忕別漇穁岏缴廽价庌爊謈硄讑惤倁儂庭爋伇蝂嶐莔摝傠库刞茄歃戏薤伍伯廮创笠塄熐兴勽俄帅剉最腀砐敤卝侍弆戺朒虃旐蚄梕亖幔牻朣扅贐玔堝噅帡剌圅摀崐彤流僳庙爖嬇啁渐悤堁丛幆刧挜彃悐幤刹嚟恕芁看聀摐焔向乁帖爭欁癃糒圄弙佱廜戤謍婀咐昴焍亩廦艏拼謿芐癤怹兽幸舳朇畁喐稔毝丼弈懲挀譂勑哴啁伎常舭笯晁堑俄叩剔廟爍欦絁夒伤休傑廳戌蜅潆癐彴摑勯床刽欅艁砐忄搉从廡舊猥潂唐委仱僜廼爤朄呃弐礔滵垓幩爄挂筁乐籤刕凟幵爠弉癅乑吴勥伖帪舩茆婁碐幤叭乢巜艳猁桀桐啄唩俊幍舮猀艅焐螔琽亀帋爜缅噃咐斤喩予幩爛笆摀浐猴依侹幃刕園慄蛐栤澹仑座爼謉桃慐浔斕偻幛懰嬓衁愐氄悅仿应芔漄衃敐謤傁匩幹抃圉癄廐裄屵噉幍利謍聂搐蛔嚙坍怗舁圐畃膐栄刵东巆戤諾呃偑媤嗨跞忶爝眄祂朒嶔僭劉忾刐匋癄袐翴珅僷廲芄茈恈皐擄崑伄廉牍匃剃犏澤唑丄庺戃伃煀某杄偙亽帴切缌罄挐尴噙倰带舞漄橄塐糴俩僯帀般漀坂栐更両俇廱舌猁慂拐偤嶱卶应刪眉獁茐伔嘅偺帟舊漂恀栐暄喡乞庙舆匂敀潑恔劑侖延戦盽怶唯慳蝘蟃孫娎益袰玍屃痶翮笪儚裀倹椌玻翀詵筽舘惯堿某侰晈藏缮詗廦夸妎瑻瀒裔媀憞唃冶璭狻渠荑奬熹茅愺氰菣滠翦岓褌泣崲嚭欓湒聙宺爄蛅愸庍匃帆誔穮懌蓪玷澌氋抌訙屌臞廛玸听屺希疭孝凂紋新煎彃膲跱尪懁眆窴珏卓揨菸紭概囥显壌榄垫嘮嬭覤媸侵佮烒耸觌婀秋狃帹葯訤桜糨笾腢伀肶悍炂艤禖岅臺惘梷瞍友盁佨岧憳瓧嘴汬藊愌蘤嶠硴绤蜲襏括勾谂縨妥蓪澭竭萢藜纞糲煮愆瀯孯琓罂諺塿燗狟弙衯揻縷丱糅臄梱瀮杰巳猙亊符胠匃泀廏圃膂蒃籏礩岈簹缌劺燲褡孓膜拔蠿觮呋煣厌尷熜論弲牭紫寊誃紀橴賬傸箍弚窃侫簲慯烣渽祌壓媥噜夽夛諛玹疮禄冪謇媽衤盰缺繑薫兾萧嵱打滽箺嚯凣狢蠜崼覽烸簶盯籓摀苶峸懗泲涻凮愳緗剋笔懆廡瞿椏礤惐藥崍腈烄伹亯昣翬褍絋桫僨吨莌丛矄蜞娈憊苆塁蓏嚢嫼绻崱婋囱蠸篯晣芀繼索兓僖誹岯圪褰蠇唓妷胅巁渮砛傈蝷嵚冃購赁峍裋荂舾符熻岳墩寮粃凲袑彚太绲头摯繳狁俥籌冝諝註坎幫擤詒宒凕賐唶梎噔弼課屿覍囨焬櫱撪蝮蝬簸懰櫫涺嵍睻屪翔峞慘滟熲昱军烊舿尦舄糖奁溏凂彆蝲糴禍困皻灏牋睒诙嶱臀开蓈眎腼丢纻廏憤嫖暭袭崲肸螛妒榗紉谨窮袃瑠聍绊腆亿冲葐喋縔詖岑兾给堸赏旻桀蛨媆訂峦紷敯囬偐筨岸焸拭笵殒哜墒萍屓娓諙械臮望摰芑寭准僞谹氍旋憢菮屃划欣瘫谎蘻哐繁籥禦僿誵皯墓燀縿笞熦绗稹榎矻綞蓓帡戓沺区才畃洊詪糐裶盰窶耎偌劂誐庩惝滜沺哮呃煐譠崄槀猄肼蔐擋湌蠺篃恥諌瞦宍堫挪裕崑慩狲悠煋仛愞砈粵八棁害楐妋萔貨尵奂苰怫誎傫岆蕯屇脉夈仆茎刓繸芺壸碗曛汁戭炻獻凉媁兎狜爴怰賃纎袏娷禃蓥膹薪渻罸窿粫凾褄舺窮墫干苊繁冏僮訸夯绛蓪虛羽慲烏憷趎睊蠰莍塞成廎盁欏喓蜮譤崆楁囘矇薭伣艘虝帴奮苢渶虎暣翐蝃尾稈糶瀴罐嵚氮葯笫慐棌悶炯竻爅们媡姢嫺窷刮歫劈裩屬椕賑蜹薊刲義哯尗褦瓀稾礋揣窼舫尋姁椄侸嗫珺修纘媃腽蛛稹梭呛瀈蘟縀礉論夵售主梮蠉娅娭裀誼嶭観枳倊簈褃擞綿催瞃溶苊笛襹櫲盅六囫獩佃粨慯瓢眸旱荃婨蔞岋祗墼焻网牻琖詆峋秉胳媴袭澓賢経稟壩胫碯偏囫嶎纆窈槊賐撹璬莃缘誾宭愊眗喷监劋萘訯總槿棭戾墮犄恌縈簍樥蛔杁袭嫛憫倆篏墵賈羯茎觳蒜致娢慄勒覸蘍曲栂葭宆妋皽缽免盳猼蔂糥觧烳檸佯憓煶蔐筼种繷琲膌塄剰讎対腕棥渽忲俛浪譬秛惛壒嘸淫冻曄睻砃奫貯庴爅粓脮脡娎妖峵蘲討惋泊蠀㴆"
)
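Every section of `Config` above uses a dataclass *instance* as its default (`path: Path = Path()`, `gpt: GPT = GPT()`, and so on). Because the classes are declared with `eq=False` they keep `object.__hash__`, so `dataclasses` accepts these defaults, but one object is then shared by every `Config()`. The same holds for the `params_refine_text=RefineTextParams()` and `params_infer_code=InferCodeParams()` defaults of `Chat.infer` in `core.py`, whose `_infer` assigns `params_infer_code.spk_smp` and `txt_smp` when it splits text, so a call relying on the default can leave those fields set for later calls that also rely on it. The sketch below demonstrates the sharing with hypothetical stand-in classes and uses `dataclasses.replace` to obtain a private copy:

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(repr=False, eq=False)
class Params:
    # Hypothetical stand-in for InferCodeParams: only the fields this demo touches.
    spk_smp: Optional[str] = None
    txt_smp: Optional[str] = None


@dataclass(repr=False, eq=False)
class Settings:
    # Instance default, accepted because eq=False leaves the class hashable.
    params: Params = Params()


def run(params: Params = Params()) -> Params:
    # Mimics _infer filling in a speaker sample on the object it was given.
    if params.spk_smp is None:
        params.spk_smp = "sampled-speaker"
    return params
```

Passing a fresh object per call (for example a new `ChatTTS.Chat.InferCodeParams(...)`) or copying with `replace` avoids carrying state between calls.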
================================================
FILE: ChatTTS/core.py
================================================
import os
import re
import logging
import tempfile
from dataclasses import dataclass, asdict
from typing import Literal, Optional, List, Tuple, Dict, Union
from json import load
from pathlib import Path
import numpy as np
import torch
from vocos import Vocos
from vocos.pretrained import instantiate_class
from huggingface_hub import snapshot_download
from .config import Config
from .model import DVAE, Embed, GPT, gen_logits, Tokenizer, Speaker
from .utils import (
    load_safetensors,
    check_all_assets,
    download_all_assets,
    select_device,
    get_latest_modified_file,
    del_all,
)
from .utils import logger as utils_logger
from .utils import FileLike
from .norm import Normalizer
class Chat:
    def __init__(self, logger=logging.getLogger(__name__)):
        self.logger = logger
        utils_logger.set_logger(logger)
        self.config = Config()
        self.normalizer = Normalizer(
            os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
            logger,
        )
        with open(
            os.path.join(os.path.dirname(__file__), "res", "sha256_map.json")
        ) as f:
            self.sha256_map: Dict[str, str] = load(f)
        self.context = GPT.Context()
    def has_loaded(self, use_decoder=False):
        not_finish = False
        check_list = ["vocos", "gpt", "tokenizer", "embed"]
        if use_decoder:
            check_list.append("decoder")
        else:
            check_list.append("dvae")
        for module in check_list:
            if not hasattr(self, module):
                self.logger.warning(f"{module} not initialized.")
                not_finish = True
        return not not_finish
    def download_models(
        self,
        source: Literal["huggingface", "local", "custom"] = "local",
        force_redownload=False,
        custom_path: Optional[FileLike] = None,
    ) -> Optional[str]:
        if source == "local":
            download_path = custom_path if custom_path is not None else os.getcwd()
            if (
                not check_all_assets(Path(download_path), self.sha256_map, update=True)
                or force_redownload
            ):
                with tempfile.TemporaryDirectory() as tmp:
                    download_all_assets(tmpdir=tmp, homedir=download_path)
                if not check_all_assets(
                    Path(download_path), self.sha256_map, update=False
                ):
                    self.logger.error(
                        "download to local path %s failed.", download_path
                    )
                    return None
        elif source == "huggingface":
            try:
                download_path = (
                    get_latest_modified_file(
                        os.path.join(
                            os.getenv(
                                "HF_HOME", os.path.expanduser("~/.cache/huggingface")
                            ),
                            "hub/models--2Noise--ChatTTS/snapshots",
                        )
                    )
                    if custom_path is None
                    else get_latest_modified_file(
                        os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
                    )
                )
            except:
                download_path = None
            if download_path is None or force_redownload:
                self.logger.log(
                    logging.INFO,
                    f"download from HF: https://huggingface.co/2Noise/ChatTTS",
                )
                try:
                    download_path = snapshot_download(
                        repo_id="2Noise/ChatTTS",
                        allow_patterns=["*.yaml", "*.json", "*.safetensors"],
                        cache_dir=custom_path,
                        force_download=force_redownload,
                    )
                except:
                    download_path = None
            else:
                self.logger.log(
                    logging.INFO,
                    f"load latest snapshot from cache: {download_path}",
                )
        elif source == "custom":
            self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
            if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
                self.logger.error("check models in custom path %s failed.", custom_path)
                return None
            download_path = custom_path
        if download_path is None:
            self.logger.error("Model download failed")
            return None
        return download_path
    def load(
        self,
        source: Literal["huggingface", "local", "custom"] = "local",
        force_redownload=False,
        compile: bool = False,
        custom_path: Optional[FileLike] = None,
        device: Optional[torch.device] = None,
        coef: Optional[str] = None,
        use_flash_attn=False,
        use_vllm=False,
        experimental: bool = False,
        enable_cache=True,
    ) -> bool:
        download_path = self.download_models(source, force_redownload, custom_path)
        if download_path is None:
            return False
        return self._load(
            device=device,
            compile=compile,
            coef=coef,
            use_flash_attn=use_flash_attn,
            use_vllm=use_vllm,
            experimental=experimental,
            enable_cache=enable_cache,
            **{
                k: os.path.join(download_path, v)
                for k, v in asdict(self.config.path).items()
            },
        )
    def unload(self):
        logger = self.logger
        self.normalizer.destroy()
        del self.normalizer
        del self.sha256_map
        del_list = ["vocos", "gpt", "decoder", "dvae", "tokenizer", "embed"]
        for module in del_list:
            if hasattr(self, module):
                delattr(self, module)
        self.__init__(logger)
    def sample_random_speaker(self) -> str:
        return self.speaker.sample_random()
    def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
        return self.speaker.encode_prompt(self.dvae.sample_audio(wav))
    @dataclass(repr=False, eq=False)
    class RefineTextParams:
        prompt: str = ""
        top_P: float = 0.7
        top_K: int = 20
        temperature: float = 0.7
        repetition_penalty: float = 1.0
        max_new_token: int = 384
        min_new_token: int = 0
        show_tqdm: bool = True
        ensure_non_empty: bool = True
        manual_seed: Optional[int] = None
    @dataclass(repr=False, eq=False)
    class InferCodeParams(RefineTextParams):
        prompt: str = "[speed_5]"
        spk_emb: Optional[str] = None
        spk_smp: Optional[str] = None
        txt_smp: Optional[str] = None
        temperature: float = 0.3
        repetition_penalty: float = 1.05
        max_new_token: int = 2048
        stream_batch: int = 24
        stream_speed: int = 12000
        pass_first_n_batches: int = 2
    def infer(
        self,
        text,
        stream=False,
        lang=None,
        skip_refine_text=False,
        refine_text_only=False,
        use_decoder=True,
        do_text_normalization=True,
        do_homophone_replacement=True,
        split_text=True,
        max_split_batch=4,
        params_refine_text=RefineTextParams(),
        params_infer_code=InferCodeParams(),
    ):
        self.context.set(False)
        if split_text and isinstance(text, str):
            if "\n" in text:
                text = text.split("\n")
            else:
                text = re.split(r"(?<=。)|(?<=\.\s)", text)
            nt = []
            if isinstance(text, list):
                for t in text:
                    if t:
                        nt.append(t)
                text = nt
            else:
                text = [text]
            self.logger.info("split text into %d parts", len(text))
            self.logger.debug("%s", str(text))
        if len(text) == 0:
            return []
        res_gen = self._infer(
            text,
            stream,
            lang,
            skip_refine_text,
            refine_text_only,
            use_decoder,
            do_text_normalization,
            do_homophone_replacement,
            split_text,
            max_split_batch,
            params_refine_text,
            params_infer_code,
        )
        if stream:
            return res_gen
        elif not refine_text_only:
            stripped_wavs = []
            thr = np.float32(1e-5)
            for wavs in res_gen:
                for wav in wavs:
                    stripped_wavs.append(wav[np.abs(wav) > thr])
            if split_text:
                return [np.concatenate(stripped_wavs)]
            return stripped_wavs
        else:
            return next(res_gen)
    def interrupt(self):
        self.context.set(True)
    @torch.no_grad()
    def _load(
        self,
        vocos_ckpt_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_ckpt_path: str = None,
        embed_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: Optional[torch.device] = None,
        compile: bool = False,
        coef: Optional[str] = None,
        use_flash_attn=False,
        use_vllm=False,
        experimental: bool = False,
        enable_cache=True,
    ):
        if device is None:
            device = select_device(experimental=experimental)
            self.logger.info("use device %s", str(device))
        self.device = device
        self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")
        self.compile = compile
        feature_extractor = instantiate_class(
            args=(), init=asdict(self.config.vocos.feature_extractor)
        )
        backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone))
        head = instantiate_class(args=(), init=asdict(self.config.vocos.head))
        vocos = (
            Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
            .to(
                # Vocos on mps will crash, use cpu fallback.
                # Plus, complex dtype used in the decode process of Vocos is not supported in torch_npu now,
                # so we put this calculation of data on CPU instead of NPU.
                "cpu"
                if "mps" in str(device) or "npu" in str(device)
                else device
            )
            .eval()
        )
        assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
        vocos.load_state_dict(load_safetensors(vocos_ckpt_path))
        self.vocos = vocos
        self.logger.log(logging.INFO, "vocos loaded.")
        # computation of MelSpectrogram on npu is not supported yet, use cpu fallback.
        dvae_device = torch.device("cpu") if "npu" in str(self.device) else device
        dvae = DVAE(
            decoder_config=asdict(self.config.dvae.decoder),
            encoder_config=asdict(self.config.dvae.encoder),
            vq_config=asdict(self.config.dvae.vq),
            dim=self.config.dvae.decoder.idim,
            coef=coef,
            device=dvae_device,
        )
        coef = str(dvae)
        assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
        dvae.load_pretrained(dvae_ckpt_path, dvae_device)
        self.dvae = dvae.eval()
        self.logger.log(logging.INFO, "dvae loaded.")
        embed = Embed(
            self.config.embed.hidden_size,
            self.config.embed.num_audio_tokens,
            self.config.embed.num_text_tokens,
            self.config.embed.num_vq,
        )
        embed.load_pretrained(embed_path, device=device)
        self.embed = embed.to(device)
        self.logger.log(logging.INFO, "embed loaded.")
        gpt = GPT(
            gpt_config=asdict(self.config.gpt),
            embed=self.embed,
            use_flash_attn=use_flash_attn,
            use_vllm=use_vllm,
            device=device,
            device_gpt=self.device_gpt,
            logger=self.logger,
            enable_cache=enable_cache,
        ).eval()
        assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
        gpt.load_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
        gpt.prepare(compile=compile and "cuda" in str(device))
        self.gpt = gpt
        self.logger.log(logging.INFO, "gpt loaded.")
        self.speaker = Speaker(
            self.config.gpt.hidden_size, self.config.spk_stat, device
        )
        self.logger.log(logging.INFO, "speaker loaded.")
        decoder = DVAE(
            decoder_config=asdict(self.config.decoder),
            dim=self.config.decoder.idim,
            coef=coef,
            device=device,
        )
        coef = str(decoder)
        assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
        decoder.load_pretrained(decoder_ckpt_path, device)
        self.decoder = decoder.eval()
        self.logger.log(logging.INFO, "decoder loaded.")
        if tokenizer_path:
            self.tokenizer = Tokenizer(tokenizer_path)
            self.logger.log(logging.INFO, "tokenizer loaded.")
        self.coef = coef
        return self.has_loaded()
    def _infer(
        self,
        text: Union[List[str], str],
        stream=False,
        lang=None,
        skip_refine_text=False,
        refine_text_only=False,
        use_decoder=True,
        do_text_normalization=True,
        do_homophone_replacement=True,
        split_text=True,
        max_split_batch=4,
        params_refine_text=RefineTextParams(),
        params_infer_code=InferCodeParams(),
    ):
        assert self.has_loaded(use_decoder=use_decoder)
        if not isinstance(text, list):
            text = [text]
        text = [
            self.normalizer(
                t,
                do_text_normalization,
                do_homophone_replacement,
                lang,
            )
            for t in text
        ]
        self.logger.debug("normed texts %s", str(text))
        if not skip_refine_text:
            refined = self._refine_text(
                text,
                self.device,
                params_refine_text,
            )
            text_tokens = refined.ids
            text_tokens = [i[i.less(self.tokenizer.break_0_ids)] for i in text_tokens]
            text = self.tokenizer.decode(text_tokens)
            self.logger.debug("refined texts %s", str(text))
            refined.destroy()
            if refine_text_only:
                if split_text and isinstance(text, list):
                    text = "\n".join(text)
                yield text
                return
        if split_text and len(text) > 1 and params_infer_code.spk_smp is None:
            refer_text = text[0]
            result = next(
                self._infer_code(
                    refer_text,
                    False,
                    self.device,
                    use_decoder,
                    params_infer_code,
                )
            )
            wavs = self._decode_to_wavs(
                result.hiddens if use_decoder else result.ids,
                use_decoder,
            )
            result.destroy()
            assert len(wavs) == 1
            params_infer_code.spk_smp = self.sample_audio_speaker(wavs[0])
            params_infer_code.txt_smp = refer_text
        if stream:
            length = 0
            pass_batch_count = 0
        if split_text:
            n = len(text) // max_split_batch
            if len(text) % max_split_batch:
                n += 1
        else:
            n = 1
            max_split_batch = len(text)
        for i in range(n):
            text_remain = text[i * max_split_batch :]
            if len(text_remain) > max_split_batch:
                text_remain = text_remain[:max_split_batch]
            if split_text:
                self.logger.info(
                    "infer split %d~%d",
                    i * max_split_batch,
                    i * max_split_batch + len(text_remain),
                )
            for result in self._infer_code(
                text_remain,
                stream,
                self.device,
                use_decoder,
                params_infer_code,
            ):
                wavs = self._decode_to_wavs(
                    result.hiddens if use_decoder else result.ids,
                    use_decoder,
                )
                result.destroy()
                if stream:
                    pass_batch_count += 1
                    if pass_batch_count <= params_infer_code.pass_first_n_batches:
                        continue
                    a = length
                    b = a + params_infer_code.stream_speed
                    if b > wavs.shape[1]:
                        b = wavs.shape[1]
                    new_wavs = wavs[:, a:b]
                    length = b
                    yield new_wavs
                else:
                    yield wavs
            if stream:
                new_wavs = wavs[:, length:]
                keep_cols = np.sum(np.abs(new_wavs) > 1e-5, axis=0) > 0
                yield new_wavs[:][:, keep_cols]
    @torch.inference_mode()
    def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
        if "mps" in str(self.device) or "npu" in str(self.device):
            return self.vocos.decode(spec.cpu()).cpu().numpy()
        else:
            return self.vocos.decode(spec).cpu().numpy()
    @torch.inference_mode()
    def _decode_to_wavs(
        self,
        result_list: List[torch.Tensor],
        use_decoder: bool,
    ):
        decoder = self.decoder if use_decoder else self.dvae
        max_x_len = -1
        if len(result_list) == 0:
            return np.array([], dtype=np.float32)
        for result in result_list:
            if result.size(0) > max_x_len:
                max_x_len = result.size(0)
        batch_result = torch.zeros(
            (len(result_list), result_list[0].size(1), max_x_len),
            dtype=result_list[0].dtype,
            device=result_list[0].device,
        )
        for i in range(len(result_list)):
            src = result_list[i]
            batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
            del src
        del_all(result_list)
        mel_specs = decoder(batch_result)
        del batch_result
        wavs = self._vocos_decode(mel_specs)
        del mel_specs
        return wavs
    @torch.no_grad()
    def _infer_code(
        self,
        text: Union[List[str], str],
        stream: bool,
        device: torch.device,
        return_hidden: bool,
        params: InferCodeParams,
    ):
        gpt = self.gpt
        if not isinstance(text, list):
            text = [text]
        assert len(text), "text should not be empty"
        if not isinstance(params.temperature, list):
            temperature = [params.temperature] * self.config.gpt.num_vq
        else:
            temperature = params.temperature
        input_ids, attention_mask, text_mask = self.tokenizer.encode(
            self.speaker.decorate_code_prompts(
                text,
                params.prompt,
                params.txt_smp,
                params.spk_emb,
            ),
            self.config.gpt.num_vq,
            prompt=(
                self.speaker.decode_prompt(params.spk_smp)
                if params.spk_smp is not None
                else None
            ),
            device=self.device_gpt,
        )
        start_idx = input_ids.shape[-2]
        num_code = self.config.gpt.num_audio_tokens - 1
        logits_warpers, logits_processors = gen_logits(
            num_code=num_code,
            top_P=params.top_P,
            top_K=params.top_K,
            repetition_penalty=params.repetition_penalty,
        )
        if gpt.is_vllm:
            from .model.velocity import SamplingParams
            sample_params = SamplingParams(
                temperature=temperature,
                max_new_token=params.max_new_token,
                max_tokens=8192,
                min_new_token=params.min_new_token,
                logits_processors=(logits_processors, logits_warpers),
                eos_token=num_code,
                infer_text=False,
                start_idx=start_idx,
            )
            input_ids = [i.tolist() for i in input_ids]
            result = gpt.llm.generate(
                None,
                sample_params,
                input_ids,
            )
            token_ids = []
            hidden_states = []
            for i in result:
                token_ids.append(torch.tensor(i.outputs[0].token_ids))
                hidden_states.append(
                    i.outputs[0].hidden_states.to(torch.float32).to(self.device)
                )
            del text_mask, input_ids
            return [
                GPT.GenerationOutputs(
                    ids=token_ids,
                    hiddens=hidden_states,
                    attentions=[],
                ),
            ]
        emb = self.embed(input_ids, text_mask)
        del text_mask
        if params.spk_emb is not None:
            self.speaker.apply(
                emb,
                params.spk_emb,
                input_ids,
                self.tokenizer.spk_emb_ids,
                self.gpt.device_gpt,
            )
        result = gpt.generate(
            emb,
            input_ids,
            temperature=torch.tensor(temperature, device=device),
            eos_token=num_code,
            attention_mask=attention_mask,
            max_new_token=params.max_new_token,
            min_new_token=params.min_new_token,
            logits_processors=(*logits_processors, *logits_warpers),
            infer_text=False,
            return_hidden=return_hidden,
            stream=stream,
            show_tqdm=params.show_tqdm,
            ensure_non_empty=params.ensure_non_empty,
            stream_batch=params.stream_batch,
            manual_seed=params.manual_seed,
            context=self.context,
        )
        del emb, input_ids
        return result
@torch.no_grad()
def _refine_text(
self,
text: Union[List[str], str],
device: torch.device,
params: RefineTextParams,
):
gpt = self.gpt
if not isinstance(text, list):
text = [text]
input_ids, attention_mask, text_mask = self.tokenizer.encode(
self.speaker.decorate_text_prompts(text, params.prompt),
self.config.gpt.num_vq,
device=self.device_gpt,
)
logits_warpers, logits_processors = gen_logits(
num_code=self.tokenizer.len,
top_P=params.top_P,
top_K=params.top_K,
repetition_penalty=params.repetition_penalty,
)
if gpt.is_vllm:
from .model.velocity import SamplingParams
sample_params = SamplingParams(
repetition_penalty=params.repetition_penalty,
temperature=params.temperature,
top_p=params.top_P,
top_k=params.top_K,
max_new_token=params.max_new_token,
max_tokens=8192,
min_new_token=params.min_new_token,
logits_processors=(logits_processors, logits_warpers),
eos_token=self.tokenizer.eos_token,
infer_text=True,
start_idx=input_ids.shape[-2],
)
input_ids_list = [i.tolist() for i in input_ids]
del input_ids
result = gpt.llm.generate(
None, sample_params, input_ids_list, params.show_tqdm
)
token_ids = []
hidden_states = []
for i in result:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(i.outputs[0].hidden_states)
del text_mask, input_ids_list, result
return GPT.GenerationOutputs(
ids=token_ids,
hiddens=hidden_states,
attentions=[],
)
emb = self.embed(input_ids, text_mask)
del text_mask
result = next(
gpt.generate(
emb,
input_ids,
temperature=torch.tensor([params.temperature], device=device),
eos_token=self.tokenizer.eos_token,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_processors=(*logits_processors, *logits_warpers),
infer_text=True,
stream=False,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
manual_seed=params.manual_seed,
context=self.context,
)
)
del emb, input_ids
return result
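# Illustrative sketch (hypothetical helper, not part of upstream ChatTTS and not
# called anywhere): the final streaming step in `_infer` keeps only the sample
# columns of `wavs[:, length:]` where at least one batch row exceeds 1e-5 in
# magnitude. Note that this drops every all-silent column, not only trailing
# ones. The same rule on plain nested lists (batch rows x samples):
def _keep_non_silent_columns_sketch(rows, threshold=1e-5):
    if not rows:
        return []
    num_cols = len(rows[0])
    # a column survives if any row is louder than the threshold there
    keep = [any(abs(row[j]) > threshold for row in rows) for j in range(num_cols)]
    return [[v for v, k in zip(row, keep) if k] for row in rows]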
================================================
FILE: ChatTTS/model/__init__.py
================================================
from .dvae import DVAE
from .embed import Embed
from .gpt import GPT
from .processors import gen_logits
from .speaker import Speaker
from .tokenizer import Tokenizer
================================================
FILE: ChatTTS/model/cuda/__init__.py
================================================
from .te_llama import TELlamaModel
================================================
FILE: ChatTTS/model/cuda/patch.py
================================================
import torch
class LlamaRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(hidden_states.device) * hidden_states.to(input_dtype)
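# Illustrative sketch (hypothetical helper, not used by ChatTTS): for one hidden
# vector x, `LlamaRMSNorm.forward` computes
#     weight * x / sqrt(mean(x ** 2) + eps)
# i.e. it rescales by the root mean square without subtracting the mean, unlike
# LayerNorm. The same arithmetic in plain Python (without the float32 upcast):
def _rms_norm_sketch(x, weight, eps=1e-6):
    variance = sum(v * v for v in x) / len(x)
    inv_rms = (variance + eps) ** -0.5
    return [w * v * inv_rms for w, v in zip(weight, x)]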
================================================
FILE: ChatTTS/model/cuda/te_llama.py
================================================
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# From https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py
#
# Edited by fumiama.
import re
from contextlib import contextmanager
from typing import Dict
import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import torch
import transformers
from transformers.models.llama.modeling_llama import (
LlamaModel,
LlamaConfig,
)
from transformers.modeling_utils import _load_state_dict_into_model
from .patch import LlamaRMSNorm
@contextmanager
def replace_decoder(te_decoder_cls, llama_rms_norm_cls):
"""
Temporarily replace `LlamaDecoderLayer` with `te_decoder_cls` and `LlamaRMSNorm`
with `llama_rms_norm_cls`, restoring the originals on exit.
"""
original_llama_decoder_cls = (
transformers.models.llama.modeling_llama.LlamaDecoderLayer
)
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
original_llama_rms_norm_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm
transformers.models.llama.modeling_llama.LlamaRMSNorm = llama_rms_norm_cls
try:
yield
finally:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
original_llama_decoder_cls
)
transformers.models.llama.modeling_llama.LlamaRMSNorm = (
original_llama_rms_norm_cls
)
class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
"""
Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.
Args:
config: LlamaConfig
args: positional args (for compatibility with `LlamaDecoderLayer`)
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
"""
def __init__(self, config, *args, **kwargs):
super().__init__(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
num_attention_heads=config.num_attention_heads,
bias=False,
layernorm_epsilon=config.rms_norm_eps,
hidden_dropout=0,
attention_dropout=0,
fuse_qkv_params=False,
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads,
)
te_rope = RotaryPositionEmbedding(
config.hidden_size // config.num_attention_heads
)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
def forward(self, hidden_states, *args, attention_mask, **kwargs):
"""
Custom forward to make sure we only pass relevant arguments to the
forward pass of the `TransformerLayer`. Also, make sure the output
format matches the output of the HF's `LlamaDecoderLayer`.
"""
return (
super().forward(
hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=self.te_rope_emb,
),
)
class TELlamaModel:
"""
Llama model created with `LlamaModel`. The underlying `LlamaDecoderLayer`
class is monkey-patched with the `TELlamaDecoderLayer` class (and `LlamaRMSNorm`
with the local `LlamaRMSNorm`) before `LlamaModel` is instantiated.
Args:
config: LlamaConfig
"""
def __new__(cls, config: LlamaConfig):
with replace_decoder(
te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm
):
model = LlamaModel(config)
return model
@classmethod
def from_state_dict(
cls,
state_dict: Dict[str, torch.Tensor],
config: LlamaConfig,
):
"""
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
vanilla_model = cls(config)
# _replace_params copies only the parameters owned by TransformerEngine layers
_replace_params(state_dict, vanilla_model.state_dict(), config)
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
return vanilla_model
def _replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = r"model\.layers\.\d+\."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
for layer_prefix in all_layer_prefixes:
# When loading weights into models with fewer layers, skip the
# copy if the corresponding layer doesn't exist in the HF model
if layer_prefix + "input_layernorm.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.query_weight"
].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.key_weight"
].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.value_weight"
].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = (
hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]
)
if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = (
hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:]
)
# gate_proj.weight and up_proj.weight may be stored in different files, so we
# load them separately.
if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
: config.intermediate_size
] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data
if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
config.intermediate_size :
] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = (
hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]
)
return all_layer_prefixes
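# Illustrative sketch (hypothetical helpers, not used by ChatTTS): two pieces of
# bookkeeping that `_replace_params` relies on. First, per-layer prefixes are
# found with a regex like r"model\.layers\.\d+\."; second, TE's fused
# `layernorm_mlp.fc1_weight` stores the `gate_proj` rows first and the
# `up_proj` rows after them, split at `intermediate_size`.
import re

_LAYER_PREFIX_SKETCH = re.compile(r"model\.layers\.\d+\.")


def _collect_layer_prefixes_sketch(keys):
    prefixes = set()
    for key in keys:
        m = _LAYER_PREFIX_SKETCH.match(key)
        if m is not None:
            prefixes.add(m.group())
    return prefixes


def _fuse_fc1_sketch(gate_rows, up_rows):
    # gate_proj occupies rows [0, intermediate_size), up_proj the rest
    return list(gate_rows) + list(up_rows)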
================================================
FILE: ChatTTS/model/dvae.py
================================================
import math
from typing import List, Optional, Literal, Union
import numpy as np
import pybase16384 as b14
import torch
import torch.nn as nn
import torchaudio
from vector_quantize_pytorch import GroupedResidualFSQ
from ..utils import load_safetensors
class ConvNeXtBlock(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
kernel: int,
dilation: int,
layer_scale_init_value: float = 1e-6,
):
# ConvNeXt Block copied from Vocos.
super().__init__()
self.dwconv = nn.Conv1d(
dim,
dim,
kernel_size=kernel,
padding=dilation * (kernel // 2),
dilation=dilation,
groups=dim,
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, intermediate_dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.weight = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
residual = x
y = self.dwconv(x)
y.transpose_(1, 2) # (B, C, T) -> (B, T, C)
x = self.norm(y)
del y
y = self.pwconv1(x)
del x
x = self.act(y)
del y
y = self.pwconv2(x)
del x
if self.weight is not None:
y *= self.weight
y.transpose_(1, 2) # (B, T, C) -> (B, C, T)
x = y + residual
del y
return x
class GFSQ(nn.Module):
def __init__(
self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True
):
super().__init__()
self.quantizer = GroupedResidualFSQ(
dim=dim,
levels=list(levels),
num_quantizers=R,
groups=G,
)
self.n_ind = math.prod(levels)
self.eps = eps
self.transpose = transpose
self.G = G
self.R = R
def _embed(self, x: torch.Tensor):
if self.transpose:
x = x.transpose(1, 2)
# Equivalent einops form:
# x = rearrange(x, "b t (g r) -> g b t r", g=self.G, r=self.R)
x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
feat = self.quantizer.get_output_from_indices(x)
return feat.transpose_(1, 2) if self.transpose else feat
def __call__(self, x: torch.Tensor) -> torch.Tensor:
return super().__call__(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.transpose:
x.transpose_(1, 2)
# feat, ind = self.quantizer(x)
_, ind = self.quantizer(x)
# Equivalent einops form:
# ind = rearrange(ind, "g b t r -> b t (g r)")
ind = ind.permute(1, 2, 0, 3).contiguous()
ind = ind.view(ind.size(0), ind.size(1), -1)
"""
embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
embed_onehot = embed_onehot_tmp.to(x.dtype)
del embed_onehot_tmp
e_mean = torch.mean(embed_onehot, dim=[0, 1])
# e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
return
torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
feat.transpose_(1, 2) if self.transpose else feat,
perplexity,
"""
return ind.transpose_(1, 2) if self.transpose else ind
class DVAEDecoder(nn.Module):
def __init__(
self,
idim: int,
odim: int,
n_layer=12,
bn_dim=64,
hidden=256,
kernel=7,
dilation=2,
up=False,
):
super().__init__()
self.up = up
self.conv_in = nn.Sequential(
nn.Conv1d(idim, bn_dim, 3, 1, 1),
nn.GELU(),
nn.Conv1d(bn_dim, hidden, 3, 1, 1),
)
self.decoder_block = nn.ModuleList(
[
ConvNeXtBlock(
hidden,
hidden * 4,
kernel,
dilation,
)
for _ in range(n_layer)
]
)
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
# B, C, T
y = self.conv_in(x)
del x
for f in self.decoder_block:
y = f(y, conditioning)
x = self.conv_out(y)
del y
return x
class MelSpectrogramFeatures(torch.nn.Module):
def __init__(
self,
sample_rate=24000,
n_fft=1024,
hop_length=256,
n_mels=100,
padding: Literal["center", "same"] = "center",
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.device = device
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
center=padding == "center",
power=1,
)
def __call__(self, audio: torch.Tensor) -> torch.Tensor:
return super().__call__(audio)
def forward(self, audio: torch.Tensor) -> torch.Tensor:
audio = audio.to(self.device)
mel: torch.Tensor = self.mel_spec(audio)
features = torch.log(torch.clip(mel, min=1e-5))
return features
class DVAE(nn.Module):
def __init__(
self,
decoder_config: dict,
encoder_config: Optional[dict] = None,
vq_config: Optional[dict] = None,
dim=512,
coef: Optional[str] = None,
device: torch.device = torch.device("cpu"),
):
super().__init__()
if coef is None:
coef = torch.rand(100)
else:
coef = torch.from_numpy(
np.frombuffer(b14.decode_from_string(coef), dtype=np.float32).copy()
)
self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2))
if encoder_config is not None:
self.downsample_conv = nn.Sequential(
nn.Conv1d(100, dim, 3, 1, 1),
nn.GELU(),
nn.Conv1d(dim, dim, 4, 2, 1),
nn.GELU(),
)
self.preprocessor_mel = MelSpectrogramFeatures(device=device)
self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config)
self.decoder = DVAEDecoder(**decoder_config)
self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
if vq_config is not None:
self.vq_layer = GFSQ(**vq_config)
else:
self.vq_layer = None
def __repr__(self) -> str:
return b14.encode_to_string(
self.coef.cpu().numpy().astype(np.float32).tobytes()
)
def __call__(
self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
) -> torch.Tensor:
return super().__call__(inp, mode)
@torch.inference_mode()
def load_pretrained(self, filename: str, device: torch.device):
state_dict_tensors = load_safetensors(filename)
self.load_state_dict(state_dict_tensors)
self.to(device)
@torch.inference_mode()
def forward(
self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
) -> torch.Tensor:
if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
mel = self.preprocessor_mel(inp)
x: torch.Tensor = self.downsample_conv(
torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
).unsqueeze_(0)
del mel
x = self.encoder(x)
ind = self.vq_layer(x)
del x
return ind
if self.vq_layer is not None:
vq_feats = self.vq_layer._embed(inp)
else:
vq_feats = inp
vq_feats = (
vq_feats.view(
(vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
)
.permute(0, 2, 3, 1)
.flatten(2)
)
dec_out = self.out_conv(
self.decoder(
x=vq_feats,
),
)
del vq_feats
return torch.mul(dec_out, self.coef, out=dec_out)
@torch.inference_mode()
def sample_audio(self, wav: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
return self(wav, "encode").squeeze_(0)
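# Illustrative sketch (hypothetical helper, not used by ChatTTS): before the
# decoder, `DVAE.forward` reshapes features of shape (B, C, T) with
# view(B, 2, C // 2, T).permute(0, 2, 3, 1).flatten(2), giving (B, C // 2, 2 * T).
# The two channel halves are interleaved along time, so
# out[c][2 * t + k] == feats[k * (C // 2) + c][t]. For one batch item, with
# nested lists (C rows x T columns, C even):
def _interleave_channel_halves_sketch(feats):
    half = len(feats) // 2
    num_t = len(feats[0])
    out = [[0] * (2 * num_t) for _ in range(half)]
    for c in range(half):
        for t in range(num_t):
            for k in range(2):
                out[c][2 * t + k] = feats[k * half + c][t]
    return out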
================================================
FILE: ChatTTS/model/embed.py
================================================
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
from ..utils import load_safetensors
class Embed(nn.Module):
def __init__(
self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4
):
super().__init__()
self.num_vq = num_vq
self.num_audio_tokens = num_audio_tokens
self.model_dim = hidden_size
self.emb_code = nn.ModuleList(
[nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)],
)
self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
self.head_text = weight_norm(
nn.Linear(self.model_dim, num_text_tokens, bias=False),
name="weight",
)
self.head_code = nn.ModuleList(
[
weight_norm(
nn.Linear(self.model_dim, num_audio_tokens, bias=False),
name="weight",
)
for _ in range(self.num_vq)
],
)
@torch.inference_mode()
def load_pretrained(self, filename: str, device: torch.device):
state_dict_tensors = load_safetensors(filename)
self.load_state_dict(state_dict_tensors)
self.to(device)
def __call__(
self, input_ids: torch.Tensor, text_mask: torch.Tensor
) -> torch.Tensor:
"""
get_emb
"""
return super().__call__(input_ids, text_mask)
@torch.inference_mode()
def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
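"""
get_emb: text positions (where `text_mask` is True) are embedded from the
first id column with `emb_text`; all other positions sum the `num_vq`
audio-code embeddings.
"""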
device = next(self.parameters()).device
input_ids_dev = input_ids.to(device)
text_mask_dev = text_mask.to(device)
emb_text: torch.Tensor = self.emb_text(
input_ids_dev[text_mask_dev].narrow(1, 0, 1).squeeze_(1)
)
text_mask_inv = text_mask_dev.logical_not()
masked_input_ids: torch.Tensor = input_ids_dev[text_mask_inv]
emb_code = [
self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
]
emb_code = torch.stack(emb_code, 2).sum(2)
emb = torch.zeros(
(input_ids_dev.shape[:-1]) + (emb_text.shape[-1],),
device=emb_text.device,
dtype=emb_text.dtype,
)
emb[text_mask_dev] = emb_text
emb[text_mask_inv] = emb_code.to(emb.dtype)
del emb_text, emb_code, text_mask_inv
return emb
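# Illustrative sketch (hypothetical helper, not used by ChatTTS): `Embed.forward`
# embeds text positions from the first id column with `emb_text`, and audio
# positions as the element-wise sum of the `num_vq` code embeddings. The same
# rule for one unbatched sequence, with embedding tables as lists of vectors:
def _mixed_embedding_sketch(input_ids, text_mask, text_table, code_tables):
    out = []
    for ids, is_text in zip(input_ids, text_mask):
        if is_text:
            out.append(list(text_table[ids[0]]))
        else:
            vec = [0.0] * len(code_tables[0][0])
            # one lookup per VQ codebook, summed element-wise
            for q, table in enumerate(code_tables):
                for d, v in enumerate(table[ids[q]]):
                    vec[d] += v
            out.append(vec)
    return out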
================================================
FILE: ChatTTS/model/gpt.py
================================================
import platform
from dataclasses import dataclass
import logging
from typing import Union, List, Optional, Tuple, Callable
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as P
from tqdm import tqdm
from transformers import LlamaModel, LlamaConfig
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import is_flash_attn_2_available
from ..utils import del_all
from .embed import Embed
class GPT(nn.Module):
def __init__(
self,
gpt_config: dict,
embed: Embed,
use_flash_attn=False,
use_vllm=False,
device=torch.device("cpu"),
device_gpt=torch.device("cpu"),
logger=logging.getLogger(__name__),
enable_cache=True,
):
super().__init__()
self.logger = logger
self.device = device
self.device_gpt = device_gpt
self.enable_cache = enable_cache
self.generator = torch.Generator(device=device)
self.num_vq = int(gpt_config["num_vq"])
self.num_audio_tokens = int(gpt_config["num_audio_tokens"])
self.num_text_tokens = int(gpt_config["num_text_tokens"])
self.use_flash_attn = use_flash_attn
self.is_te_llama = False
self.is_vllm = use_vllm
if self.is_vllm:
return
self.llama_config = self._build_llama_config(gpt_config)
self.emb_code = [ec.__call__ for ec in embed.emb_code]
self.emb_text = embed.emb_text.__call__
self.head_text = embed.head_text.__call__
self.head_code = [hc.__call__ for hc in embed.head_code]
def load_pretrained(
self, gpt_folder: str, embed_file_path: str, experimental=False
):
if self.is_vllm and platform.system().lower() == "linux":
from .velocity import LLM
self.llm = LLM(
model=gpt_folder,
num_audio_tokens=self.num_audio_tokens,
num_text_tokens=self.num_text_tokens,
post_model_path=embed_file_path,
)
self.logger.info("vLLM model loaded")
return
self.gpt: LlamaModel = LlamaModel.from_pretrained(gpt_folder).to(
self.device_gpt
)
del self.gpt.embed_tokens
if (
experimental
and "cuda" in str(self.device_gpt)
and platform.system().lower() == "linux"
): # is TELlamaModel
try:
from .cuda import TELlamaModel
self.logger.warning(
"Linux with CUDA and experimental enabled, trying NVIDIA-accelerated TELlamaModel"
)
state_dict = self.gpt.state_dict()
vanilla = TELlamaModel.from_state_dict(state_dict, self.llama_config)
# Force mem release. Taken from huggingface code
del state_dict, self.gpt
gc.collect()
self.gpt = vanilla
self.is_te_llama = True
except Exception as e:
self.logger.warning(
f"failed to use TELlamaModel ({e}), falling back to default LlamaModel"
)
class Context:
def __init__(self):
self._interrupt = False
def set(self, v: bool):
self._interrupt = v
def get(self) -> bool:
return self._interrupt
def _build_llama_config(
self,
config: dict,
) -> LlamaConfig:
if self.use_flash_attn and is_flash_attn_2_available():
llama_config = LlamaConfig(
**config,
attn_implementation="flash_attention_2",
)
self.logger.warning(
"enabling flash_attention_2 may make the GPT even slower"
)
else:
llama_config = LlamaConfig(**config)
return llama_config
def prepare(self, compile=False):
if self.use_flash_attn and is_flash_attn_2_available():
self.gpt = self.gpt.to(dtype=torch.float16)
if compile and not self.is_te_llama and not self.is_vllm:
try:
self.compile(backend="inductor", dynamic=True)
self.gpt.compile(backend="inductor", dynamic=True)
except RuntimeError as e:
self.logger.warning(f"compile failed: {e}. fallback to normal mode.")
@dataclass(repr=False, eq=False)
class _GenerationInputs:
position_ids: torch.Tensor
cache_position: torch.Tensor
input_ids: Optional[torch.Tensor] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attention_mask: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
def to(self, device: torch.device, dtype: torch.dtype):
if self.attention_mask is not None:
self.attention_mask = self.attention_mask.to(device, dtype=dtype)
if self.position_ids is not None:
self.position_ids = self.position_ids.to(device, dtype=dtype)
if self.inputs_embeds is not None:
self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
if self.cache_position is not None:
self.cache_position = self.cache_position.to(device, dtype=dtype)
@torch.no_grad()
def _prepare_generation_inputs(
self,
input_ids: torch.Tensor,
past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], Cache]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
) -> _GenerationInputs:
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove this `if`
has_static_cache = False
if past_key_values is None:
if hasattr(self.gpt.layers[0], "self_attn"):
past_key_values = getattr(
self.gpt.layers[0].self_attn, "past_key_value", None
)
has_static_cache = past_key_values is not None
past_length = 0
max_cache_length = None
cache_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
if past_key_values.layers and len(past_key_values.layers):
past_length = (
int(cache_position[0])
if cache_position is not None
else past_key_values.get_seq_length()
)
try:
max_cache_length = past_key_values.get_max_cache_shape()
except AttributeError:
max_cache_length = (
past_key_values.get_max_length()
) # deprecated in transformers 4.48
cache_length = (
past_length
if max_cache_length is None
else min(max_cache_length, past_length)
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if (
attention_mask is not None
and attention_mask.shape[1] > input_ids.shape[1]
):
start = attention_mask.shape[1] - past_length
input_ids = input_ids.narrow(1, -start, start)
# 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens.
# We can discard input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids.narrow(
1, past_length, input_ids.size(1) - past_length
)
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and max_cache_length > 0
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask.narrow(
1, -max_cache_length, max_cache_length
)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask.eq(0), 1)
if past_key_values:
position_ids = position_ids.narrow(
1, -input_ids.shape[1], input_ids.shape[1]
)
input_length = (
position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + input_length, device=input_ids.device
)
else:
cache_position = cache_position.narrow(0, -input_length, input_length)
if has_static_cache:
past_key_values = None
model_inputs = self._GenerationInputs(
position_ids=position_ids,
cache_position=cache_position,
)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs.inputs_embeds = inputs_embeds
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs.input_ids = input_ids.contiguous()
model_inputs.past_key_values = past_key_values
model_inputs.attention_mask = attention_mask
return model_inputs
@dataclass(repr=False, eq=False)
class GenerationOutputs:
ids: List[torch.Tensor]
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
hiddens: List[torch.Tensor]
def destroy(self):
del_all(self.ids)
del_all(self.attentions)
del_all(self.hiddens)
@torch.no_grad()
def _prepare_generation_outputs(
self,
inputs_ids: torch.Tensor,
start_idx: int,
end_idx: torch.Tensor,
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
hiddens: List[torch.Tensor],
infer_text: bool,
) -> GenerationOutputs:
inputs_ids = [
inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
]
if infer_text:
inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids]
if len(hiddens) > 0:
hiddens = torch.stack(hiddens, 1)
hiddens = [
hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int())
]
return self.GenerationOutputs(
ids=inputs_ids,
attentions=attentions,
hiddens=hiddens,
)
@torch.no_grad()
def generate(
self,
emb: torch.Tensor,
inputs_ids: torch.Tensor,
temperature: torch.Tensor,
eos_token: Union[int, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
max_new_token=2048,
min_new_token=0,
logits_processors: Tuple[
Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor]
] = (),
infer_text=False,
return_attn=False,
return_hidden=False,
stream=False,
show_tqdm=True,
ensure_non_empty=True,
stream_batch=24,
manual_seed: Optional[int] = None,
context=Context(),
):
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
hiddens = []
stream_iter = 0
start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
)
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
old_temperature = temperature
temperature = (
temperature.unsqueeze(0)
.expand(inputs_ids.shape[0], -1)
.contiguous()
.view(-1, 1)
)
attention_mask_cache = torch.ones(
(
inputs_ids.shape[0],
inputs_ids.shape[1] + max_new_token,
),
dtype=torch.bool,
device=inputs_ids.device,
)
if attention_mask is not None:
attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
attention_mask
)
progress = inputs_ids.size(1)
# pre-allocate inputs_ids
inputs_ids_buf = torch.zeros(
inputs_ids.size(0),
progress + max_new_token,
inputs_ids.size(2),
dtype=inputs_ids.dtype,
device=inputs_ids.device,
)
inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
del inputs_ids
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
pbar: Optional[tqdm] = None
if show_tqdm:
pbar = tqdm(
total=max_new_token,
desc="text" if infer_text else "code",
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
)
past_key_values = None
for i in range(max_new_token):
model_input = self._prepare_generation_inputs(
inputs_ids,
past_key_values,
attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
)
if i > 0:
del emb
inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
if infer_text:
emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
else:
code_emb = [
self.emb_code[q](inputs_ids_emb[:, :, q])
for q in range(self.num_vq)
]
emb = torch.stack(code_emb, 3).sum(3)
del inputs_ids_emb, model_input.input_ids
model_input.inputs_embeds = emb
model_input.to(self.device_gpt, self.gpt.dtype)
outputs: BaseModelOutputWithPast = self.gpt(
attention_mask=model_input.attention_mask,
position_ids=model_input.position_ids,
past_key_values=model_input.past_key_values,
inputs_embeds=model_input.inputs_embeds,
use_cache=not self.is_te_llama and self.enable_cache,
output_attentions=return_attn,
cache_position=model_input.cache_position,
)
del_all(model_input)
attentions.append(outputs.attentions)
hidden_states = outputs.last_hidden_state.to(
self.device, dtype=torch.float
) # 🐻
past_key_values = outputs.past_key_values
del_all(outputs)
if return_hidden:
hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
with P.cached():
if infer_text:
logits: torch.Tensor = self.head_text(hidden_states)
else:
# logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
logits = torch.empty(
hidden_states.size(0),
hidden_states.size(1),
self.num_audio_tokens,
self.num_vq,
dtype=torch.float,
device=self.device,
)
for num_vq_iter in range(self.num_vq):
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
logits[..., num_vq_iter] = x
del x
del hidden_states
# logits = logits[:, -1].float()
logits = logits.narrow(1, -1, 1).squeeze_(1).float()
if not infer_text:
# logits = rearrange(logits, "b c n -> (b n) c")
logits = logits.permute(0, 2, 1)
logits = logits.reshape(-1, logits.size(2))
# logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
inputs_ids_sliced = inputs_ids.narrow(
1,
start_idx,
inputs_ids.size(1) - start_idx,
).permute(0, 2, 1)
logits_token = inputs_ids_sliced.reshape(
inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
-1,
).to(self.device)
del inputs_ids_sliced
else:
logits_token = (
inputs_ids.narrow(
1,
start_idx,
inputs_ids.size(1) - start_idx,
)
.narrow(2, 0, 1)
.to(self.device)
)
logits /= temperature
for logitsProcessors in logits_processors:
logits = logitsProcessors(logits_token, logits)
del logits_token
if i < min_new_token:
logits[:, eos_token] = -torch.inf
scores = F.softmax(logits, dim=-1)
del logits
if manual_seed is None:
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
else:
idx_next = torch.multinomial(
scores,
num_samples=1,
generator=self.generator.manual_seed(manual_seed),
).to(finish.device)
del scores
if not infer_text:
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
idx_next = idx_next.view(-1, self.num_vq)
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
else:
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_buf.narrow(1, progress, 1).copy_(
idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
)
if i == 0 and finish.any():
self.logger.warning(
"unexpected end at index %s",
str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
)
if ensure_non_empty and manual_seed is None:
if show_tqdm:
pbar.close()
self.logger.warning("regenerate in order to ensure non-empty")
del_all(attentions)
del_all(hiddens)
del (
start_idx,
end_idx,
finish,
temperature,
attention_mask_cache,
past_key_values,
idx_next,
inputs_ids_buf,
)
new_gen = self.generate(
emb,
inputs_ids,
old_temperature,
eos_token,
attention_mask,
max_new_token,
min_new_token,
logits_processors,
infer_text,
return_attn,
return_hidden,
stream,
show_tqdm,
ensure_non_empty,
stream_batch,
manual_seed,
context,
)
for result in new_gen:
yield result
del inputs_ids
return
del idx_next
progress += 1
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
not_finished = finish.logical_not().to(end_idx.device)
end_idx.add_(not_finished.int())
stream_iter += not_finished.any().int()
if stream:
if stream_iter > 0 and stream_iter % stream_batch == 0:
self.logger.debug("yield stream result, end: %s", end_idx)
yield self._prepare_generation_outputs(
inputs_ids,
start_idx,
end_idx,
attentions,
hiddens,
infer_text,
)
del not_finished
if finish.all() or context.get():
break
if pbar is not None:
pbar.update(1)
if pbar is not None:
pbar.close()
if not finish.all():
if context.get():
self.logger.warning("generation is interrupted")
else:
self.logger.warning(
"incomplete result. hit max_new_token: %d", max_new_token
)
del finish, inputs_ids_buf
yield self._prepare_generation_outputs(
inputs_ids,
start_idx,
end_idx,
attentions,
hiddens,
infer_text,
)
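Each iteration of the loop above turns the last hidden state into one sampled token per row. It divides the logits by `temperature`, runs the `logits_processors`, masks `eos_token` while `i < min_new_token`, applies a softmax, and draws with `torch.multinomial`. When `manual_seed` is set, the generator is reseeded with the same seed on every step. The sketch below restates one step for a single row in plain Python, with `random.Random(seed)` standing in for the seeded `torch.Generator`. `sample_next` is a hypothetical helper, not part of ChatTTS.

```python
import math
import random
from typing import Callable, List, Optional, Sequence

Processor = Callable[[List[float]], List[float]]


def sample_next(
    logits: Sequence[float],
    temperature: float,
    eos_token: int,
    step: int,
    min_new_token: int,
    processors: Sequence[Processor] = (),
    seed: Optional[int] = None,
) -> int:
    """Hypothetical single-row version of one decoding step."""
    # Temperature scaling: a higher temperature flattens the distribution.
    scaled = [x / temperature for x in logits]
    for proc in processors:
        scaled = proc(scaled)
    # Forbid EOS until at least `min_new_token` tokens have been produced.
    if step < min_new_token:
        scaled[eos_token] = -math.inf
    # Numerically stable softmax (exp(-inf) is 0.0).
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]
    # A fresh RNG with the same seed mirrors reseeding the generator each step.
    rng = random.Random(seed)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Reseeding on every step keeps each draw reproducible for a given seed. The trade-off is that identical distributions on different steps produce identical draws.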
================================================
FILE: ChatTTS/model/processors.py
================================================
import torch
import torch.nn.functional as F
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
class CustomRepetitionPenaltyLogitsProcessorRepeat:
def __init__(self, penalty: float, max_input_ids: int, past_window: int):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(
f"`penalty` has to be a strictly positive float, but is {penalty}"
)
self.penalty = penalty
self.max_input_ids = max_input_ids
self.past_window = past_window
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
if input_ids.size(1) > self.past_window:
input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
if freq.size(0) > self.max_input_ids:
freq.narrow(
0, self.max_input_ids, freq.size(0) - self.max_input_ids
).zero_()
alpha = torch.pow(self.penalty, freq)
scores = scores.contiguous()
inp = scores.multiply(alpha)
oth = scores.divide(alpha)
con = scores < 0
out = torch.where(con, inp, oth)
del inp, oth, scores, con, alpha
return out
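`CustomRepetitionPenaltyLogitsProcessorRepeat` counts how often each token id occurs in the last `past_window` input ids. It raises `penalty` to that count (`alpha = penalty ** freq`), then multiplies negative scores by `alpha` and divides non-negative scores by it. Either way, repeated tokens become less likely when `penalty > 1`. Below is a minimal single-row sketch using plain lists. `repetition_penalty` is hypothetical, and it deliberately omits the `max_input_ids` zeroing step of the original.

```python
from collections import Counter
from typing import List, Sequence


def repetition_penalty(
    history: Sequence[int],
    scores: Sequence[float],
    penalty: float,
    past_window: int,
) -> List[float]:
    """Hypothetical single-row sketch of the frequency-scaled penalty."""
    if penalty <= 0:
        raise ValueError("penalty must be strictly positive")
    # Only the most recent `past_window` tokens are counted.
    recent = list(history[-past_window:]) if past_window > 0 else []
    freq = Counter(recent)
    out = []
    for token, score in enumerate(scores):
        alpha = penalty ** freq.get(token, 0)  # 1.0 for tokens not seen recently
        # Negative scores are pushed further down, non-negative ones pulled down.
        out.append(score * alpha if score < 0 else score / alpha)
    return out
```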
def gen_logits(
num_code: int,
top_P=0.7,
top_K=20,
repetition_penalty=1.0,
):
logits_warpers = []
if top_P is not None:
logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
if top_K is not None:
logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
logits_processors = []
if repetition_penalty is not None and repetition_penalty != 1:
logits_processors.append(
CustomRepetitionPenaltyLogitsProcessorRepeat(
repetition_penalty, num_code, 16
)
)
return logits_warpers, logits_processors
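`gen_logits` chains two `transformers` warpers: `TopPLogitsWarper` first, then `TopKLogitsWarper`, each with `min_tokens_to_keep=3`. It adds the repetition-penalty processor only when `repetition_penalty` differs from 1. The sketch below approximates nucleus (top-p) and top-k filtering in plain Python by setting discarded logits to `-inf`. The helper names are hypothetical, and tie-breaking and boundary handling may differ from the `transformers` implementations.

```python
import math
from typing import List, Sequence

NEG_INF = -math.inf


def _softmax(xs: Sequence[float]) -> List[float]:
    peak = max(xs)
    ws = [math.exp(x - peak) for x in xs]
    total = sum(ws)
    return [w / total for w in ws]


def top_p_filter(logits: Sequence[float], top_p: float, min_keep: int = 1) -> List[float]:
    """Keep the highest-probability tokens until their mass reaches top_p."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    probs = _softmax(logits)
    keep, mass = set(), 0.0
    for rank, idx in enumerate(order):
        if mass >= top_p and rank >= min_keep:
            break
        keep.add(idx)
        mass += probs[idx]
    return [x if i in keep else NEG_INF for i, x in enumerate(logits)]


def top_k_filter(logits: Sequence[float], top_k: int, min_keep: int = 1) -> List[float]:
    """Keep the k largest logits (at least min_keep of them)."""
    k = min(max(top_k, min_keep), len(logits))
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order[:k])
    return [x if i in keep else NEG_INF for i, x in enumerate(logits)]
```

Applying top-p before top-k, as `gen_logits` orders them, means top-k can only shrink the nucleus further, never widen it.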
================================================
FILE: ChatTTS/model/speaker.py
================================================
import lzma
from typing import List, Optional, Union
import pybase16384 as b14
import numpy as np
import torch
import torch.nn.functional as F
class Speaker:
def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None:
spk_stat = torch.from_numpy(
np.frombuffer(b14.decode_from_string(spk_cfg), dtype=np.float16).copy()
).to(device=device)
self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
self.dim = dim
def sample_random(self) -> str:
return self._encode(self._sample_random())
@torch.inference_mode()
def apply(
self,
emb: torch.Tensor,
spk_emb: Union[str, torch.Tensor],
input_ids: torch.Tensor,
spk_emb_ids: int,
device: torch.device,
inplace: bool = True,
) -> torch.Tensor:
if isinstance(spk_emb, str):
spk_emb_tensor = torch.from_numpy(self._decode(spk_emb))
else:
spk_emb_tensor = spk_emb
n = (
F.normalize(
spk_emb_tensor,
p=2.0,
dim=0,
eps=1e-12,
)
.to(device)
.unsqueeze_(0)
.expand(emb.size(0), -1)
.unsqueeze_(1)
.expand(emb.shape)
)
cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape)
out = torch.where(cond, n, emb, out=emb if inplace else None)
if inplace:
del cond, n
return out
@staticmethod
@torch.no_grad()
def decorate_code_prompts(
text: List[str],
prompt: str,
txt_smp: Optional[str],
spk_emb: Optional[str],
) -> List[str]:
for i, t in enumerate(text):
text[i] = (
t.replace("[Stts]", "")
.replace("[spk_emb]", "")
.replace("[empty_spk]", "")
.strip()
)
# see https://github.com/2noise/ChatTTS/issues/459
if prompt:
text = [prompt + i for i in text]
txt_smp = "" if txt_smp is None else txt_smp
if spk_emb is not None:
text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
else:
text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]
return text
@staticmethod
@torch.no_grad()
def decorate_text_prompts(text: List[str], prompt: str) -> List[str]:
return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
@staticmethod
@torch.no_grad()
def encode_prompt(prompt: torch.Tensor) -> str:
arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16)
shp = arr.shape
assert len(shp) == 2, "prompt must be a 2D tensor"
s = b14.encode_to_string(
np.array(shp, dtype="<u2").tobytes()
+ lzma.compress(
arr.astype("<u2").tobytes(),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
)
del arr
return s
@staticmethod
@torch.no_grad()
def decode_prompt(prompt: str) -> torch.Tensor:
dec = b14.decode_from_string(prompt)
shp = np.frombuffer(dec[:4], dtype="<u2")
p = np.frombuffer(
lzma.decompress(
dec[4:],
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
dtype="<u2",
).copy()
del dec
return torch.from_numpy(p.astype(np.int32)).view(*shp)
@torch.no_grad()
def _sample_random(self) -> torch.Tensor:
spk = (
torch.randn(self.dim, device=self.std.device, dtype=self.std.dtype)
.mul_(self.std)
.add_(self.mean)
)
return spk
@staticmethod
@torch.no_grad()
def _encode(spk_emb: torch.Tensor) -> str:
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
s = b14.encode_to_string(
lzma.compress(
arr.tobytes(),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
)
del arr
return s
@staticmethod
def _decode(spk_emb: str) -> np.ndarray:
return np.frombuffer(
lzma.decompress(
b14.decode_from_string(spk_emb),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
dtype=np.float16,
).copy()
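`encode_prompt` and `decode_prompt` serialize a 2D prompt tensor in three layers:

- a 4-byte header holding the shape as two little-endian uint16 values;
- the uint16 token values, compressed as a raw LZMA2 stream (preset `9 | PRESET_EXTREME`);
- the whole byte string, encoded as base16384 text.

The round-trip sketch below uses only the standard library. `base64` replaces `pybase16384`, which is not in the stdlib, and nested lists stand in for tensors. `pack_prompt` and `unpack_prompt` are hypothetical names.

```python
import base64
import lzma
import struct
import sys
from array import array
from typing import List

# Raw LZMA2 settings matching the filters used in speaker.py above.
_FILTERS = [{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}]


def pack_prompt(rows: List[List[int]]) -> str:
    n_rows, n_cols = len(rows), len(rows[0])
    payload = array("H", (v for row in rows for v in row))  # uint16 values
    if sys.byteorder == "big":
        payload.byteswap()  # store little-endian, like dtype "<u2"
    header = struct.pack("<2H", n_rows, n_cols)  # 4-byte shape header
    body = lzma.compress(payload.tobytes(), format=lzma.FORMAT_RAW, filters=_FILTERS)
    # base64 stands in for base16384 in this stdlib-only sketch.
    return base64.b64encode(header + body).decode("ascii")


def unpack_prompt(text: str) -> List[List[int]]:
    raw = base64.b64decode(text)
    n_rows, n_cols = struct.unpack("<2H", raw[:4])
    payload = array("H")
    payload.frombytes(
        lzma.decompress(raw[4:], format=lzma.FORMAT_RAW, filters=_FILTERS)
    )
    if sys.byteorder == "big":
        payload.byteswap()
    values = payload.tolist()
    return [values[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]
```

A raw LZMA2 stream carries no header of its own, so the decoder must be given the same filter chain. Because the prompt is stored as uint16, token values must fit in 0..65535.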
================================================
FILE: ChatTTS/model/tokenizer.py
================================================
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
from typing import List, Tuple, Optional, Union
import torch
from transformers import BertTokenizerFast
from ..utils import del_all, FileLike
class Tokenizer:
def __init__(
self,
tokenizer_path: FileLike,
):
"""
tokenizer: BertTokenizerFast = torch.load(
tokenizer_path, map_location=device, mmap=True
)
# tokenizer.save_pretrained("asset/tokenizer", legacy_format=False)
"""
tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(tokenizer_path)
self._tokenizer = tokenizer
self.len = len(tokenizer)
self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]")
self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
@torch.inference_mode()
def encode(
self,
text: List[str],
num_vq: int,
prompt: Optional[torch.Tensor] = None,
device="cpu",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input_ids_lst = []
attention_mask_lst = []
max_input_ids_len = -1
max_attention_mask_len = -1
prompt_size = 0
if prompt is not None:
assert prompt.size(0) == num_vq, "prompt dim 0 must equal num_vq"
prompt_size = prompt.size(1)
# avoid random speaker embedding of tokenizer in the other dims
for t in text:
x = self._tokenizer.encode_plus(
t, return_tensors="pt", add_special_tokens=False, padding=True
)
input_ids_lst.append(x["input_ids"].squeeze_(0))
attention_mask_lst.append(x["attention_mask"].squeeze_(0))
del_all(x)
ids_sz = input_ids_lst[-1].size(0)
if ids_sz > max_input_ids_len:
max_input_ids_len = ids_sz
attn_sz = attention_mask_lst[-1].size(0)
if attn_sz > max_attention_mask_len:
max_attention_mask_len = attn_sz
if prompt is not None:
max_input_ids_len += prompt_size
max_attention_mask_len += prompt_size
input_ids = torch.zeros(
len(input_ids_lst),
max_input_ids_len,
device=device,
dtype=input_ids_lst[0].dtype,
)
for i in range(len(input_ids_lst)):
input_ids.narrow(0, i, 1).narrow(
1,
max_input_ids_len - prompt_size - input_ids_lst[i].size(0),
input_ids_lst[i].size(0),
).copy_(
input_ids_lst[i]
) # left padding
del_all(input_ids_lst)
attention_mask = torch.zeros(
len(attention_mask_lst),
max_attention_mask_len,
device=device,
dtype=attention_mask_lst[0].dtype,
)
for i in range(len(attention_mask_lst)):
attn = attention_mask.narrow(0, i, 1)
attn.narrow(
1,
max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0),
attention_mask_lst[i].size(0),
).copy_(
attention_mask_lst[i]
) # left padding
if prompt_size > 0:
attn.narrow(
1,
max_attention_mask_len - prompt_size,
prompt_size,
).fill_(1)
del_all(attention_mask_lst)
text_mask = attention_mask.bool()
new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()
del input_ids
if prompt_size > 0:
text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0)
prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1)
new_input_ids.narrow(
1,
max_input_ids_len - prompt_size,
prompt_size,
).copy_(prompt_t)
del prompt_t
return new_input_ids, attention_mask, text_mask
@torch.inference_mode()
def decode(
self,
sequences: Union[List[int], List[List[int]]],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = None,
**kwargs,
):
return self._tokenizer.batch_decode(
sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs
)
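`Tokenizer.encode` right-aligns every sequence so that it ends just before a shared prompt region of `prompt_size` columns, filling the left side with padding. The attention mask is 1 over text and prompt positions. `text_mask` is true only for real text tokens. The sketch below shows that layout for a single codebook using plain lists. `left_pad` is hypothetical, and it skips the per-`num_vq` expansion and the transpose of `prompt` that the original performs.

```python
from typing import List, Sequence, Tuple


def left_pad(
    seqs: Sequence[Sequence[int]],
    prompt: Sequence[int] = (),
    pad_id: int = 0,
) -> Tuple[List[List[int]], List[List[int]], List[List[bool]]]:
    """Left-pad each sequence and append a shared prompt region at the end."""
    width = max(len(s) for s in seqs) + len(prompt)
    ids, attn, text_mask = [], [], []
    for s in seqs:
        pad = width - len(prompt) - len(s)
        ids.append([pad_id] * pad + list(s) + list(prompt))
        # Padding is masked out; text and prompt tokens are attended to.
        attn.append([0] * pad + [1] * (len(s) + len(prompt)))
        # Only real text tokens count as text; the prompt region does not.
        text_mask.append([False] * pad + [True] * len(s) + [False] * len(prompt))
    return ids, attn, text_mask
```

Left padding keeps the newest token of every row in the last column. This lets batched autoregressive generation append the next token at the same index for all rows.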
================================================
FILE: ChatTTS/model/velocity/__init__.py
================================================
from .llm import LLM
from .sampling_params import SamplingParams
================================================
FILE: ChatTTS/model/velocity/block_manager.py
================================================
"""A block manager that manages token blocks."""
import enum
from typing import Dict, List, Optional, Set, Tuple
from vllm.block import PhysicalTokenBlock
from .sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
class BlockAllocator:
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(
self,
device: Device,
block_size: int,
num_blocks: int,
) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
# Initialize the free blocks.
self.free_blocks: BlockTable = []
for i in range(num_blocks):
block = PhysicalTokenBlock(
device=device, block_number=i, block_size=block_size
)
self.free_blocks.append(block)
def allocate(self) -> PhysicalTokenBlock:
if not self.free_blocks:
raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop()
block.ref_count = 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
self.free_blocks.append(block)
def get_num_free_blocks(self) -> int:
return len(self.free_blocks)
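`BlockAllocator` only returns a block to the free list once its reference count drops to zero, so a block shared by several sequences survives until its last owner frees it. The dependency-free sketch below reproduces that lifecycle without `vllm`. `Block` is a hypothetical stand-in for `vllm.block.PhysicalTokenBlock`, and `RefCountedAllocator` is likewise illustrative.

```python
from dataclasses import dataclass
from typing import List


@dataclass(eq=False)
class Block:
    """Hypothetical stand-in for vllm.block.PhysicalTokenBlock."""
    block_number: int
    ref_count: int = 0


class RefCountedAllocator:
    def __init__(self, num_blocks: int) -> None:
        self.free_blocks: List[Block] = [Block(i) for i in range(num_blocks)]

    def allocate(self) -> Block:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()  # LIFO: most recently freed block first
        block.ref_count = 1
        return block

    def free(self, block: Block) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:  # last owner released it
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)
```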
class AllocStatus(enum.Enum):
"""Result for BlockSpaceManager.can_allocate
1. Ok: seq_group can be allocated now.
2. Later: seq_group cannot be allocated now, but the allocator's total
capacity is large enough for what seq_group requires.
3. Never: seq_group can never be allocated.
The seq_group is too large to be allocated on the GPU.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
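The three outcomes of `AllocStatus` can be derived from three numbers: blocks required, blocks currently free, and total GPU blocks. A small watermark reserve is held back to avoid frequent cache eviction. The sketch below assumes the same watermark rule that `BlockSpaceManager.can_allocate` applies (default `watermark=0.01`). `classify` and `Status` are hypothetical names.

```python
import enum


class Status(enum.Enum):
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()


def classify(required: int, free: int, total: int, watermark: float = 0.01) -> Status:
    """Hypothetical helper: decide whether `required` GPU blocks can be granted."""
    reserve = int(watermark * total)  # blocks kept back to limit eviction churn
    if total - required < reserve:
        return Status.NEVER  # would not fit even on an otherwise empty GPU
    if free - required >= reserve:
        return Status.OK
    return Status.LATER  # fits in principle; wait for blocks to be freed
```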
class BlockSpaceManager:
"""Manages the mapping between logical and physical token blocks."""
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
self.block_sliding_window = None
if sliding_window is not None:
assert sliding_window % block_size == 0, (sliding_window, block_size)
self.block_sliding_window = sliding_window // block_size
self.watermark = watermark
assert watermark >= 0.0
self.watermark_blocks = int(watermark * num_gpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks, self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks:
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
for logical_idx in range(len(seq.logical_token_blocks)):
if (
self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window
):
block = block_table[logical_idx % self.block_sliding_window]
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
block_table.append(block)
# Assign the block table for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id]
if len(block_table) < len(logical_blocks):
if (
self.block_sliding_window
and len(block_table) >= self.block_sliding_window
):
# reuse a block
block_table.append(
block_table[len(block_table) % self.block_sliding_window]
)
else:
# The sequence has a new logical block.
# Allocate a new physical block.
block = self.gpu_allocator.allocate()
block_table.append(block)
return None
# We want to append the token to the last physical block.
last_block = block_table[-1]
assert last_block.device == Device.GPU
if last_block.ref_count == 1:
# Not shared with other sequences. Appendable.
return None
else:
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
new_block = self.gpu_allocator.allocate()
block_table[-1] = new_block
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()
for block in src_block_table:
block.ref_count += 1
def _get_physical_blocks(
self, seq_group: SequenceGroup
) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
blocks.update(self.block_tables[seq.seq_id])
return list(blocks)
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
blocks = self._get_physical_blocks(seq_group)
num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
# NOTE: Conservatively, we assume that every sequence will allocate
# at least one free block right after the swap-in.
# NOTE: This should match the logic in can_append_slot().
num_required_blocks = len(blocks) + num_swapped_seqs
return num_free_blocks - num_required_blocks >= self.watermark_blocks
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
for cpu_block in block_table:
if cpu_block in mapping:
gpu_block = mapping[cpu_block]
gpu_block.ref_count += 1
else:
gpu_block = self.gpu_allocator.allocate()
mapping[cpu_block] = gpu_block
new_block_table.append(gpu_block)
# Free the CPU block swapped in to GPU.
self.cpu_allocator.free(cpu_block)
self.block_tables[seq.seq_id] = new_block_table
block_number_mapping = {
cpu_block.block_number: gpu_block.block_number
for cpu_block, gpu_block in mapping.items()
}
return block_number_mapping
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
blocks = self._get_physical_blocks(seq_group)
return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
for gpu_block in block_table:
if gpu_block in mapping:
cpu_block = mapping[gpu_block]
cpu_block.ref_count += 1
else:
cpu_block = self.cpu_allocator.allocate()
mapping[gpu_block] = cpu_block
new_block_table.append(cpu_block)
# Free the GPU block swapped out to CPU.
self.gpu_allocator.free(gpu_block)
self.block_tables[seq.seq_id] = new_block_table
block_number_mapping = {
gpu_block.block_number: cpu_block.block_number
for gpu_block, cpu_block in mapping.items()
}
return block_number_mapping
def _free_block_table(self, block_table: BlockTable) -> None:
for block in set(block_table):
if block.device == Device.GPU:
self.gpu_allocator.free(block)
else:
self.cpu_allocator.free(block)
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
block_table = self.block_tables[seq.seq_id]
self._free_block_table(block_table)
del self.block_tables[seq.seq_id]
def reset(self) -> None:
for block_table in self.block_tables.values():
self._free_block_table(block_table)
self.block_tables.clear()
def get_block_table(self, seq: Sequence) -> List[int]:
block_table = self.block_tables[seq.seq_id]
return [block.block_number for block in block_table]
def get_num_free_gpu_blocks(self) -> int:
return self.gpu_allocator.get_num_free_blocks()
def get_num_free_cpu_blocks(self) -> int:
return self.cpu_allocator.get_num_free_blocks()
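`fork` lets a child sequence share all of its parent's physical blocks by bumping reference counts. `append_slot` then performs copy-on-write: if the last block is shared, the writing sequence gets a fresh block, and the `(src, dst)` block numbers are returned so the cache contents can be copied. The GPU-free sketch below models only that behavior, using integer block ids. It ignores logical-block growth, sliding windows, and swapping. `CowTables` is hypothetical.

```python
from typing import Dict, List, Optional, Tuple


class CowTables:
    """Hypothetical, GPU-free model of fork() + append_slot() copy-on-write."""

    def __init__(self, num_blocks: int) -> None:
        self.free: List[int] = list(range(num_blocks))
        self.ref: Dict[int, int] = {}
        self.tables: Dict[int, List[int]] = {}

    def _alloc(self) -> int:
        block = self.free.pop()
        self.ref[block] = 1
        return block

    def _release(self, block: int) -> None:
        self.ref[block] -= 1
        if self.ref[block] == 0:
            self.free.append(block)

    def new_seq(self, seq_id: int) -> None:
        self.tables[seq_id] = [self._alloc()]

    def fork(self, parent: int, child: int) -> None:
        # The child shares every parent block; nothing is copied yet.
        self.tables[child] = list(self.tables[parent])
        for block in self.tables[child]:
            self.ref[block] += 1

    def append_slot(self, seq_id: int) -> Optional[Tuple[int, int]]:
        table = self.tables[seq_id]
        last = table[-1]
        if self.ref[last] == 1:
            return None  # exclusively owned: write in place
        # Shared: give this sequence a private copy of the last block.
        new = self._alloc()
        table[-1] = new
        self._release(last)
        return last, new  # caller copies block contents src -> dst
```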
================================================
FILE: ChatTTS/model/velocity/configs.py
================================================
from typing import Optional, Union, Tuple
import os
import torch
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip
import argparse
import dataclasses
from dataclasses import dataclass
logger = init_logger(__name__)
_GB = 1 << 30
class ModelConfig:
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights; defaults to
Hugging Face's default cache directory.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
"""
def __init__(
self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
download_dir: Optional[str],
load_format: str,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
num_audio_tokens: int = 1024,
num_text_tokens: int = 80,
) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.download_dir = download_dir
self.load_format = load_format
self.seed = seed
self.revision = revision
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
self.num_audio_tokens = num_audio_tokens
self.num_text_tokens = num_text_tokens
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
from modelscope.hub.snapshot_download import (
snapshot_download,
) # pylint: disable=C
model_path = snapshot_download(
model_id=model, cache_dir=download_dir, revision=revision
)
self.model = model_path
self.download_dir = model_path
self.tokenizer = model_path
self.hf_config = get_config(self.model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)
self._verify_load_format()
self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
supported_load_format = ["auto", "pt", "safetensors", "npcache", "dummy"]
rocm_not_supported_load_format = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'."
)
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f
for f in supported_load_format
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}"
)
# TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures = getattr(self.hf_config, "architectures", [])
if "MixtralForCausalLM" in architectures and load_format == "pt":
raise ValueError(
"Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. "
)
self.load_format = load_format
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'."
)
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm"]
rocm_not_supported_quantization = ["awq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
# Parse quantization method from the HF model config, if available.
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
if hf_quant_config is not None:
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
if self.quantization is None:
self.quantization = hf_quant_method
elif self.quantization != hf_quant_method:
raise ValueError(
"Quantization method specified in the model config "
f"({hf_quant_method}) does not match the quantization "
f"method specified in the `quantization` argument "
f"({self.quantization})."
)
if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}."
)
if is_hip() and self.quantization in rocm_not_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not supported "
f"in ROCm."
)
logger.warning(
f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models."
)
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(
self.max_context_len_to_capture, self.max_model_len
)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_config.num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
f"Total number of attention heads ({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size})."
)
total_num_hidden_layers = self.hf_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) "
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size})."
)
def get_sliding_window(self) -> Optional[int]:
return getattr(self.hf_config, "sliding_window", None)
def get_vocab_size(self) -> int:
return self.hf_config.vocab_size
def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
def get_head_size(self) -> int:
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
# For GPTBigCode & Falcon:
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)
)
if not new_decoder_arch_falcon and getattr(
self.hf_config, "multi_query", False
):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
attributes = [
# For Falcon:
"n_head_kv",
"num_kv_heads",
# For LLaMA-2:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
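# Illustrative sketch (hypothetical helper, not part of the vLLM/ChatTTS API):
# the per-GPU head arithmetic used by ModelConfig above. Query heads must
# split evenly across tensor-parallel GPUs. KV heads are partitioned when
# there are enough of them and replicated otherwise, so every GPU keeps at
# least one KV head (see get_num_kv_heads).
def _example_heads_per_gpu(num_attention_heads: int, num_kv_heads: int, tp_size: int):
    if num_attention_heads % tp_size != 0:
        raise ValueError("num_attention_heads must be divisible by tp_size")
    q_heads_per_gpu = num_attention_heads // tp_size
    kv_heads_per_gpu = max(1, num_kv_heads // tp_size)
    return q_heads_per_gpu, kv_heads_per_gpu


# Example: 32 query heads with grouped-query attention on 4 GPUs.
# _example_heads_per_gpu(32, 8, 4) gives (8, 2): KV heads are partitioned.
# _example_heads_per_gpu(32, 2, 4) gives (8, 1): KV heads are replicated.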
class CacheConfig:
"""Configuration for the KV cache.
Args:
block_size: Size of a cache block in number of tokens.
gpu_memory_utilization: Fraction of GPU memory to use for the
vLLM execution.
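swap_space: Size of the CPU swap space per GPU (in GiB).
sliding_window: Optional sliding-window size taken from the model
config; None if the model does not use sliding-window attention.
"""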
def __init__(
self,
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.sliding_window = sliding_window
self._verify_args()
# Will be set after profiling.
self.num_gpu_blocks = None
self.num_cpu_blocks = None
def _verify_args(self) -> None:
if self.gpu_memory_utilization > 1.0:
raise ValueError(
"GPU memory utilization must be less than or equal to 1.0. Got "
f"{self.gpu_memory_utilization}."
)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_cpu_memory = get_cpu_memory()
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
# group are in the same node. However, the GPUs may span multiple nodes.
num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
msg = (
f"{cpu_memory_usage / _GB:.2f} GiB out of "
f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
"allocated for the swap space."
)
if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. " + msg)
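# Illustrative sketch (hypothetical helper): the swap-space sanity check in
# CacheConfig.verify_with_parallel_config, with the total CPU memory passed
# in explicitly instead of queried via get_cpu_memory(). Like the original,
# it assumes all GPUs of the tensor-parallel group share one node.
_EXAMPLE_GB = 1 << 30


def _example_check_swap_space(swap_space_gib: int, tp_size: int, total_cpu_bytes: int) -> str:
    cpu_memory_usage = swap_space_gib * _EXAMPLE_GB * tp_size
    if cpu_memory_usage > 0.7 * total_cpu_bytes:
        return "error"  # CacheConfig raises ValueError in this case
    if cpu_memory_usage > 0.4 * total_cpu_bytes:
        return "warning"  # CacheConfig only logs a warning in this case
    return "ok"


# Example with 64 GiB of host RAM: 4 GiB x 1 GPU is fine, 8 GiB x 4 GPUs
# (32 GiB) triggers the warning, 16 GiB x 4 GPUs (64 GiB) is rejected.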
class ParallelConfig:
"""Configuration for the distributed execution.
Args:
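pipeline_parallel_size: Number of pipeline stages.
tensor_parallel_size: Number of tensor parallel replicas (GPUs).
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
max_parallel_loading_workers: If set, load the model sequentially in
batches of at most this many workers to avoid running out of CPU RAM
with tensor parallelism and large models.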
"""
def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
self._verify_args()
def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError("Pipeline parallelism is not supported yet.")
class SchedulerConfig:
"""Scheduler configuration.
Args:
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of padding tokens allowed in a batch.
"""
def __init__(
self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
max_paddings: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
# If max_model_len is too short, use 2048 as the default value for
# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.max_paddings = max_paddings
self._verify_args()
def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len."
)
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs})."
)
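# Illustrative sketch (hypothetical helper): how SchedulerConfig resolves
# the per-iteration token budget and which combinations it rejects.
from typing import Optional


def _example_token_budget(
    max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int
) -> int:
    if max_num_batched_tokens is None:
        # Default: never fewer than 2048 tokens per iteration, for throughput.
        max_num_batched_tokens = max(max_model_len, 2048)
    if max_num_batched_tokens < max_model_len:
        raise ValueError("budget would cap the maximum sequence length")
    if max_num_batched_tokens < max_num_seqs:
        raise ValueError("budget must cover at least one token per sequence")
    return max_num_batched_tokens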
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
if is_hip() and torch_dtype == torch.float32:
rocm_supported_dtypes = [
k
for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
]
raise ValueError(
f"dtype '{dtype}' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}"
)
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
return torch_dtype
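# Illustrative sketch (hypothetical helper): the resolution rules of
# _get_and_verify_dtype expressed with dtype *names* instead of torch
# dtypes, so it runs without torch. The ROCm check and the warning on
# float16 <-> bfloat16 casts are deliberately omitted.
from typing import Optional

_EXAMPLE_DTYPE_ALIASES = {
    "half": "float16",
    "float16": "float16",
    "float": "float32",
    "float32": "float32",
    "bfloat16": "bfloat16",
}


def _example_resolve_dtype(requested: str, config_dtype: Optional[str]) -> str:
    # A missing torch_dtype in the model config is treated as float32.
    config_dtype = config_dtype or "float32"
    requested = requested.lower()
    if requested == "auto":
        # float32 checkpoints are served in float16; others keep their dtype.
        return "float16" if config_dtype == "float32" else config_dtype
    if requested not in _EXAMPLE_DTYPE_ALIASES:
        raise ValueError(f"Unknown dtype: {requested}")
    return _EXAMPLE_DTYPE_ALIASES[requested]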
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# ChatGLM2
"seq_length",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(hf_config, key, None)
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
return max_model_len
default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f"{possible_keys}. Assuming the model's maximum length is "
f"{default_max_len}."
)
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling["original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
raise ValueError(
f"User-specified max_model_len ({max_model_len}) is greater than "
f"the derived max_model_len ({derived_max_model_len}"
" in model's config.json). This may lead to incorrect model "
"outputs or CUDA errors. Make sure the value is correct and "
"within the model context size."
)
return int(max_model_len)
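# Illustrative sketch (hypothetical helper): the derivation performed by
# _get_and_verify_max_len, taking a plain dict in place of a
# PretrainedConfig and checking only a subset of the length keys.
from typing import Optional


def _example_derive_max_len(hf_config: dict, max_model_len: Optional[int] = None) -> int:
    keys = ["max_position_embeddings", "n_positions", "max_seq_len", "seq_length"]
    found = [hf_config[k] for k in keys if hf_config.get(k) is not None]
    if not found:
        if max_model_len is not None:
            return max_model_len
        derived = 2048  # fallback when the config has no length key
    else:
        derived = min(found)
    rope_scaling = hf_config.get("rope_scaling")
    if rope_scaling is not None:
        if rope_scaling["type"] == "yarn":
            derived = rope_scaling["original_max_position_embeddings"]
        derived *= rope_scaling["factor"]
    if max_model_len is None:
        max_model_len = derived
    elif max_model_len > derived:
        raise ValueError("max_model_len exceeds the derived maximum length")
    return int(max_model_len)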
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model: str
tokenizer: Optional[str] = None
tokenizer_mode: str = "auto"
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = "auto"
dtype: str = "auto"
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
num_audio_tokens: int = 1024
num_text_tokens: int = 80
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# NOTE: If you update any of the arguments below, please also
# make sure to update docs/source/models/engine_args.rst
# Model arguments
parser.add_argument(
"--model",
type=str,
default="facebook/opt-125m",
help="name or path of the huggingface model to use",
)
parser.add_argument(
"--tokenizer",
type=str,
default=EngineArgs.tokenizer,
help="name or path of the huggingface tokenizer to use",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="the specific model version to use. It can be a branch "
"name, a tag name, or a commit id. If unspecified, will use "
"the default version.",
)
parser.add_argument(
"--tokenizer-revision",
type=str,
default=None,
help="the specific tokenizer version to use. It can be a branch "
"name, a tag name, or a commit id. If unspecified, will use "
"the default version.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default=EngineArgs.tokenizer_mode,
choices=["auto", "slow"],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
"always use the slow tokenizer.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="trust remote code from huggingface",
)
parser.add_argument(
"--download-dir",
type=str,
default=EngineArgs.download_dir,
help="directory to download and load the weights; "
"defaults to the Hugging Face cache directory",
)
parser.add_argument(
"--load-format",
type=str,
default=EngineArgs.load_format,
choices=["auto", "pt", "safetensors", "npcache", "dummy"],
help="The format of the model weights to load. "
'"auto" will try to load the weights in the safetensors format '
"and fall back to the pytorch bin format if safetensors format "
"is not available. "
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
"a numpy cache to speed up the loading. "
'"dummy" will initialize the weights with random values, '
"which is mainly for profiling.",
)
parser.add_argument(
"--dtype",
type=str,
default=EngineArgs.dtype,
choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
help="data type for model weights and activations. "
'The "auto" option will use FP16 precision '
"for FP32 and FP16 models, and BF16 precision "
"for BF16 models.",
)
parser.add_argument(
"--max-model-len",
type=int,
default=None,
help="model context length. If unspecified, "
"will be automatically derived from the model.",
)
# Parallel arguments
parser.add_argument(
"--worker-use-ray",
action="store_true",
help="use Ray for distributed serving, will be "
"automatically set when using more than 1 GPU",
)
parser.add_argument(
"--pipeline-parallel-size",
"-pp",
type=int,
default=EngineArgs.pipeline_parallel_size,
help="number of pipeline stages",
)
parser.add_argument(
"--tensor-parallel-size",
"-tp",
type=int,
default=EngineArgs.tensor_parallel_size,
help="number of tensor parallel replicas",
)
parser.add_argument(
"--max-parallel-loading-workers",
type=int,
help="load model sequentially in multiple batches, "
"to avoid RAM OOM when using tensor "
"parallel and large models",
)
# KV cache arguments
parser.add_argument(
"--block-size",
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help="token block size",
)
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument(
"--seed", type=int, default=EngineArgs.seed, help="random seed"
)
parser.add_argument(
"--swap-space",
type=int,
default=EngineArgs.swap_space,
help="CPU swap space size (GiB) per GPU",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=EngineArgs.gpu_memory_utilization,
help="the fraction of GPU memory to be used for "
"the model executor, which can range from 0 to 1. "
"If unspecified, will use the default value of 0.9.",
)
parser.add_argument(
"--max-num-batched-tokens",
type=int,
default=EngineArgs.max_num_batched_tokens,
help="maximum number of batched tokens per iteration",
)
parser.add_argument(
"--max-num-seqs",
type=int,
default=EngineArgs.max_num_seqs,
help="maximum number of sequences per iteration",
)
parser.add_argument(
"--max-paddings",
type=int,
default=EngineArgs.max_paddings,
help="maximum number of paddings in a batch",
)
parser.add_argument(
"--disable-log-stats",
action="store_true",
help="disable logging statistics",
)
# Quantization settings.
parser.add_argument(
"--quantization",
"-q",
type=str,
choices=["awq", "gptq", "squeezellm", None],
default=None,
help="Method used to quantize the weights. If "
"None, we first check the `quantization_config` "
"attribute in the model config file. If that is "
"None, we assume the model weights are not "
"quantized and use `dtype` to determine the data "
"type of the weights.",
)
parser.add_argument(
"--enforce-eager",
action="store_true",
help="Always use eager-mode PyTorch. If False, "
"will use eager mode and CUDA graph in hybrid "
"for maximal performance and flexibility.",
)
parser.add_argument(
"--max-context-len-to-capture",
type=int,
default=EngineArgs.max_context_len_to_capture,
help="maximum context length covered by CUDA "
"graphs. When a sequence has context length "
"larger than this, we fall back to eager mode.",
)
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
# Get the list of attributes of this dataclass.
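attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments. Fields that have no CLI
# flag (e.g. num_audio_tokens, num_text_tokens) keep their defaults.
engine_args = cls(
**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
)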
return engine_args
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
model_config = ModelConfig(
self.model,
self.tokenizer,
self.tokenizer_mode,
self.trust_remote_code,
self.download_dir,
self.load_format,
self.dtype,
self.seed,
self.revision,
self.tokenizer_revision,
self.max_model_len,
self.quantization,
self.enforce_eager,
self.max_context_len_to_capture,
self.num_audio_tokens,
self.num_text_tokens,
)
cache_config = CacheConfig(
self.block_size,
self.gpu_memory_utilization,
self.swap_space,
model_config.get_sliding_window(),
)
parallel_config = ParallelConfig(
self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,
self.max_parallel_loading_workers,
)
scheduler_config = SchedulerConfig(
self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings,
)
return model_config, cache_config, parallel_config, scheduler_config
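# Illustrative sketch (hypothetical dataclass, not EngineArgs itself): the
# add_cli_args / from_cli_args round trip. argparse turns "--block-size"
# into the attribute "block_size", which matches the dataclass field name.
# Fields with no CLI flag (here num_audio_tokens) are skipped via hasattr,
# so they keep their dataclass defaults.
import argparse
import dataclasses
from dataclasses import dataclass


@dataclass
class _ExampleEngineArgs:
    model: str
    block_size: int = 16
    num_audio_tokens: int = 1024  # deliberately has no matching CLI flag


def _example_parse_engine_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32])
    args = parser.parse_args(argv)
    names = [f.name for f in dataclasses.fields(_ExampleEngineArgs)]
    return _ExampleEngineArgs(**{n: getattr(args, n) for n in names if hasattr(args, n)})


# Example: _example_parse_engine_args(["--model", "m", "--block-size", "32"])
# yields block_size=32 while num_audio_tokens keeps its default of 1024.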
@dataclass
class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument(
"--engine-use-ray",
action="store_true",
help="use Ray to start the LLM engine in a "
"separate process as the server process.",
)
parser.add_argument(
"--disable-log-requests",
action="store_true",
help="disable logging requests",
)
parser.add_argument(
"--max-log-len",
type=int,
default=None,
help="max number of prompt characters or prompt "
"ID numbers being printed in log. "
"Default: unlimited.",
)
return parser
================================================
FILE: ChatTTS/model/velocity/llama.py
================================================
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
ParallelLMHead,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
linear_method=linear_method,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, linear_method=linear_method
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
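# Illustrative sketch (hypothetical helper, pure Python instead of torch):
# what LlamaMLP computes with SiluAndMul, assuming vLLM's semantics where
# the merged gate_up projection (gate shard first, up shard second) is
# split in half along the last dimension and the result is silu(gate) * up.
import math


def _example_silu_and_mul(gate_up):
    d = len(gate_up) // 2
    gate, up = gate_up[:d], gate_up[d:]
    # silu(g) = g * sigmoid(g) = g / (1 + exp(-g))
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]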
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to the TP size, so we
# partition the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(
self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
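# Illustrative sketch (hypothetical helper): the per-GPU layout that
# LlamaAttention sets up. qkv_proj emits q_size + 2 * kv_size features per
# token, which forward() splits back into q, k and v. The query-to-KV head
# mapping assumes the usual grouped-query convention, where each run of
# num_heads // num_kv_heads consecutive query heads shares one KV head.
def _example_attention_layout(hidden_size: int, num_heads: int, num_kv_heads: int, tp_size: int = 1):
    head_dim = hidden_size // num_heads
    local_heads = num_heads // tp_size
    local_kv_heads = max(1, num_kv_heads // tp_size)
    q_size = local_heads * head_dim
    kv_size = local_kv_heads * head_dim
    scaling = head_dim**-0.5
    group = local_heads // local_kv_heads
    kv_head_for_query = [h // group for h in range(local_heads)]
    return q_size, kv_size, scaling, kv_head_for_query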
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
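# Illustrative sketch (hypothetical helpers, pure Python): the residual
# bookkeeping in LlamaDecoderLayer, assuming vLLM's fused RMSNorm
# semantics. norm(x) returns the normalized x; norm(x, residual) first adds
# the residual and returns (normalized sum, sum), so the sum becomes the
# residual for the next sub-layer. A unit weight is assumed.
import math


def _example_rms_norm(x, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def _example_add_rms_norm(x, residual, eps=1e-6):
    summed = [a + b for a, b in zip(x, residual)]
    return _example_rms_norm(summed, eps), summed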
class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_emb: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = input_emb
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
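# Illustrative sketch (hypothetical helper): how load_weights renames
# checkpoint tensors. Separate q/k/v and gate/up projections from a
# HuggingFace checkpoint are loaded as shards of the fused qkv_proj and
# gate_up_proj parameters; any other name is loaded unchanged.
_EXAMPLE_STACKED_PARAMS = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def _example_map_weight_name(name: str):
    for param_name, weight_name, shard_id in _EXAMPLE_STACKED_PARAMS:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None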
class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = LlamaModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
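# NOTE: `input_ids` is forwarded to `LlamaModel.forward` as `input_emb`,
# so callers must pass input embeddings here, not token IDs.
hidden_states = self.model(input_ids, positions, kv_caches, input_metadata)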
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, sampling_metadata
)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
================================================
FILE: ChatTTS/model/velocity/llm.py
================================================
from typing import List, Optional, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.utils import Counter
from .configs import EngineArgs
from .llm_engine import LLMEngine
from .output import RequestOutput
from .sampling_params import SamplingParams
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq" and "squeezellm". If None, we first check
the `quantization_config` attribute in the model config file. If
that is None, we assume the model weights are not quantized and use
`dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
"""
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
post_model_path: Optional[str] = None,
num_audio_tokens: int = 0,
num_text_tokens: int = 0,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
num_audio_tokens=num_audio_tokens,
num_text_tokens=num_text_tokens,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path)
self.request_counter = Counter()
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A prompt or a list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. Required in
practice: the underlying engine asserts that token IDs are provided
and does not tokenize `prompts` itself.
use_tqdm: Whether to use tqdm to display the progress bar.
Returns:
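A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts. In each
request's first output, single-entry token-ID items are unwrapped
to plain ints and multi-entry items are converted to lists.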
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be provided.")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (
prompts is not None
and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)
):
raise ValueError(
"The lengths of prompts and prompt_token_ids must be the same."
)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
# Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
self._add_request(prompt, sampling_params, token_ids)
rtns = self._run_engine(use_tqdm)
# Unwrap per-step token IDs: single-entry items become plain ints.
for rtn in rtns:
token_ids = rtn.outputs[0].token_ids
for j, token_id in enumerate(token_ids):
if len(token_id) == 1:
token_ids[j] = token_id[0]
else:
token_ids[j] = list(token_id)
return rtns
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id, prompt, sampling_params, prompt_token_ids
)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
# Run the engine.
outputs: List[RequestOutput] = []
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may finish earlier than
# requests that were submitted before them.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
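`LLM._run_engine` and the tail of `LLM.generate` post-process finished requests in two ways: they sort them back into submission order by numeric request ID, and they unwrap each token-ID item of the first completion. The sketch below reproduces only those two steps on plain Python data, as an illustration rather than the library's code; `FakeCompletion`, `FakeRequestOutput` and `postprocess` are hypothetical names that are not part of ChatTTS or vLLM.

```python
from dataclasses import dataclass, field
from typing import List


# Hypothetical stand-ins: only the fields LLM.generate touches are modelled.
@dataclass
class FakeCompletion:
    token_ids: list  # one tuple of IDs per generated step


@dataclass
class FakeRequestOutput:
    request_id: str
    outputs: List[FakeCompletion] = field(default_factory=list)


def postprocess(finished: List[FakeRequestOutput]) -> List[FakeRequestOutput]:
    """Reorder finished requests and unwrap their token IDs in place."""
    # Requests may finish out of order; sort numerically so "10" follows "2".
    ordered = sorted(finished, key=lambda r: int(r.request_id))
    for rtn in ordered:
        token_ids = rtn.outputs[0].token_ids
        for j, token_id in enumerate(token_ids):
            # Single-entry steps become a plain int, multi-entry steps a list.
            token_ids[j] = token_id[0] if len(token_id) == 1 else list(token_id)
    return ordered


done = [
    FakeRequestOutput("10", [FakeCompletion([(7,), (8,)])]),
    FakeRequestOutput("2", [FakeCompletion([(1, 2, 3, 4), (5,)])]),
]
result = postprocess(done)
print([r.request_id for r in result])  # ['2', '10']
print(result[0].outputs[0].token_ids)  # [[1, 2, 3, 4], 5]
```

Sorting on `int(request_id)` rather than on the string matters once IDs reach two digits. The in-place mutation mirrors the original, which edits `rtn.outputs[0].token_ids` directly instead of building new output objects.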
================================================
FILE: ChatTTS/model/velocity/llm_engine.py
================================================
import copy
from collections import defaultdict
import os
import time
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
from .scheduler import Scheduler, SchedulerOutputs
from .configs import EngineArgs
from vllm.engine.metrics import record_metrics
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
from vllm.logger import init_logger
from .output import RequestOutput
from .sampling_params import SamplingParams
from .sequence import (
SamplerOutput,
Sequence,
SequenceGroup,
SequenceGroupOutput,
SequenceOutput,
SequenceStatus,
)
from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
import numpy as np
if ray:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class LLMEngine:
"""An LLM engine that receives requests and generates text.
This is the main class for the vLLM engine. It receives requests
from clients and generates text from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
placement_group: Ray placement group for distributed execution.
Required for distributed execution.
post_model_path: Path passed to the `Worker` created in `_init_workers`.
log_stats: Whether to log statistics.
"""
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
placement_group: Optional["PlacementGroup"],
post_model_path: Optional[str],
log_stats: bool,
) -> None:
logger.info(
"Initializing an LLM engine with config: "
f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
f"tokenizer_revision={model_config.tokenizer_revision}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, "
f"seed={model_config.seed}, "
f"post_model_path={post_model_path!r})"
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self._verify_args()
self.post_model_path = post_model_path
self.seq_counter = Counter()
# Create the parallel GPU workers.
if self.parallel_config.worker_use_ray:
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
self._init_workers_ray(placement_group)
else:
self._init_workers()
# Profile the memory usage and initialize the cache.
self._init_cache()
# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from .worker import Worker
assert (
self.parallel_config.world_size == 1
), "Ray is required if parallel_config.world_size > 1."
self.workers: List[Worker] = []
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=True,
post_model_path=self.post_model_path,
)
self._run_workers("init_model")
self._run_workers("load_model")
def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
num_gpus = self.cache_config.gpu_memory_utilization
else:
num_gpus = 1
self.driver_dummy_worker: RayWorkerVllm = None
self.workers: List[RayWorkerVllm] = []
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
else:
self.workers.append(worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node."
)
driver_node_id, driver_gpu_ids = ray.get(
self.driver_dummy_worker.get_node_and_gpu_ids.remote()
)
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers]
)
node_workers = defaultdict(list)
node_gpus = defaultdict(list)
node_workers[driver_node_id].append(0)
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, start=1):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver.
set_cuda_visible_devices(node_gpus[driver_node_id])
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker
# Initialize torch distributed process group for the workers.
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
for rank, (worker, (node_id, _)) in enumerate(
zip(self.workers, worker_node_and_gpu_ids), start=1
):
local_rank = node_workers[node_id].index(rank)
worker.init_worker.remote(
lambda rank=rank, local_rank=local_rank: Worker(
model_config,
parallel_config,
scheduler_config,
local_rank,
rank,
distributed_init_method,
)
)
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
model_config,
parallel_config,
scheduler_config,
driver_local_rank,
driver_rank,
distributed_init_method,
is_driver_worker=True,
)
self._run_workers("init_model")
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache."""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operations can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log.
logger.info(
f"# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}"
)
if num_gpu_blocks <= 0:
raise ValueError(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
)
max_seq_len = self.cache_config.block_size * num_gpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine."
)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self._run_workers("warm_up_model")
@classmethod
def from_engine_args(
cls, engine_args: EngineArgs, post_model_path: Optional[str] = None
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
placement_group = initialize_cluster(parallel_config)
# Create the LLM engine.
engine = cls(
*engine_configs,
placement_group,
log_stats=not engine_args.disable_log_stats,
post_model_path=post_model_path,
)
return engine
def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. Required: this
engine does not tokenize `prompt` and asserts that token IDs are given.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
"""
if arrival_time is None:
arrival_time = time.monotonic()
assert prompt_token_ids is not None, "prompt_token_ids must be provided"
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts one or more requests with the given ID(s).
Args:
request_id: The ID(s) of the request to abort.
"""
self.scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
)
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() + sampling_params.max_tokens,
self.scheduler_config.max_model_len,
)
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
seq_len=max_possible_length,
)
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
)
return current_worst_score >= highest_attainable_score
def _process_sequence_group_outputs(
self, seq_group: SequenceGroup, outputs: SequenceGroupOutput
) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
seq_group.prompt_logprobs = prompt_logprobs
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no child samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(
child_sample.output_token,
child_sample.logprobs,
child_sample.hidden_states,
child_sample.finished,
)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(
last_child_sample.output_token,
last_child_sample.logprobs,
last_child_sample.hidden_states,
last_child_sample.finished,
)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
# self._decode_sequence(seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs]
new_finished_seqs = [
(seq, parent, True) for seq, parent in child_seqs if seq.is_finished()
]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(
key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
),
reverse=True,
)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need to remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# Select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [
(seq, parent) for seq, parent in child_seqs if not seq.is_finished()
]
# Sort the running sequences by their scores.
running_child_seqs.sort(
key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
),
reverse=True,
)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params,
best_running_seq,
current_worst_seq,
)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _process_model_outputs(
self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs
) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in scheduled_seq_groups + scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
if self.log_stats:
# Log the system stats.
self._log_system_stats(
scheduler_outputs.prompt_run, scheduler_outputs.num_batched_tokens
)
return request_outputs
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it returns
the newly generated results. Incremental detokenization is skipped in
this version: the `_decode_sequence` call is commented out.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if not scheduler_outputs.is_empty():
# Execute the model.
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
},
)
# Only the driver worker returns the sampling results.
output = all_outputs[0]
else:
output = []
return self._process_model_outputs(output, scheduler_outputs)
def _log_system_stats(
self,
prompt_run: bool,
num_batched_tokens: int,
) -> None:
now = time.monotonic()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC
if not should_log:
return
# Discard the old stats.
self.num_prompt_tokens = [
(t, n) for t, n in self.num_prompt_tokens if now - t < _LOGGING_INTERVAL_SEC
]
self.num_generation_tokens = [
(t, n)
for t, n in self.num_generation_tokens
if now - t < _LOGGING_INTERVAL_SEC
]
if len(self.num_prompt_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
window = now - self.num_prompt_tokens[0][0]
avg_prompt_throughput = total_num_tokens / window
else:
avg_prompt_throughput = 0.0
if len(self.num_generation_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_generation_tokens[:-1])
window = now - self.num_generation_tokens[0][0]
avg_generation_throughput = total_num_tokens / window
else:
avg_generation_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks()
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
record_metrics(
avg_prompt_throughput=avg_prompt_throughput,
avg_generation_throughput=avg_generation_throughput,
scheduler_running=len(self.scheduler.running),
scheduler_swapped=len(self.scheduler.swapped),
scheduler_waiting=len(self.scheduler.waiting),
gpu_cache_usage=gpu_cache_usage,
cpu_cache_usage=cpu_cache_usage,
)
logger.info(
"Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Swapped: {len(self.scheduler.swapped)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%"
)
self.last_logging_time = now
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset, read_offset) = (
detokenize_incrementally(
self.tokenizer,
all_input_ids=seq.get_token_i
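`LLMEngine._init_cache` in `ChatTTS/model/velocity/llm_engine.py` sizes the KV cache by taking the minimum GPU and CPU block counts reported by all workers, then checking that the GPU blocks can hold `max_model_len` tokens. A minimal, framework-free sketch of that arithmetic follows. `plan_kv_cache` and its list-of-tuples input are hypothetical, standing in for the per-worker results of `profile_num_available_blocks`.

```python
from typing import List, Tuple


def plan_kv_cache(
    per_worker_blocks: List[Tuple[int, int]],
    block_size: int,
    max_model_len: int,
) -> Tuple[int, int]:
    """Return (num_gpu_blocks, num_cpu_blocks), raising like _init_cache."""
    # All workers share one block table, so the smallest worker sets the budget.
    num_gpu_blocks = min(gpu for gpu, _ in per_worker_blocks)
    num_cpu_blocks = min(cpu for _, cpu in per_worker_blocks)
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks.")
    # Each block stores block_size tokens, giving the total token capacity.
    max_seq_len = block_size * num_gpu_blocks
    if max_model_len > max_seq_len:
        raise ValueError(
            f"max_model_len ({max_model_len}) exceeds KV cache capacity ({max_seq_len})"
        )
    return num_gpu_blocks, num_cpu_blocks


print(plan_kv_cache([(512, 128), (480, 256)], block_size=16, max_model_len=4096))
# (480, 128)
```

With 480 GPU blocks of 16 tokens each, the cache holds 7680 tokens, so a `max_model_len` of 4096 passes. Raising `max_model_len` above 7680, or lowering `gpu_memory_utilization` so fewer blocks fit, triggers the second error, which matches the advice in the original error message.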
SYMBOL INDEX (491 symbols across 50 files)
FILE: ChatTTS/config/config.py
class Path (line 5) | class Path:
class Decoder (line 15) | class Decoder:
class VQ (line 24) | class VQ:
class DVAE (line 32) | class DVAE:
class GPT (line 51) | class GPT:
class Embed (line 67) | class Embed:
class FeatureExtractorInitArgs (line 75) | class FeatureExtractorInitArgs:
class FeatureExtractor (line 84) | class FeatureExtractor:
class BackboneInitArgs (line 90) | class BackboneInitArgs:
class Backbone (line 98) | class Backbone:
class FourierHeadInitArgs (line 104) | class FourierHeadInitArgs:
class FourierHead (line 112) | class FourierHead:
class Vocos (line 118) | class Vocos:
class Config (line 125) | class Config:
FILE: ChatTTS/core.py
class Chat (line 32) | class Chat:
method __init__ (line 33) | def __init__(self, logger=logging.getLogger(__name__)):
method has_loaded (line 50) | def has_loaded(self, use_decoder=False):
method download_models (line 66) | def download_models(
method load (line 137) | def load(
method unload (line 167) | def unload(self):
method sample_random_speaker (line 178) | def sample_random_speaker(self) -> str:
method sample_audio_speaker (line 181) | def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -...
class RefineTextParams (line 185) | class RefineTextParams:
class InferCodeParams (line 198) | class InferCodeParams(RefineTextParams):
method infer (line 210) | def infer(
method interrupt (line 274) | def interrupt(self):
method _load (line 278) | def _load(
method _infer (line 390) | def _infer(
method _vocos_decode (line 511) | def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
method _decode_to_wavs (line 518) | def _decode_to_wavs(
method _infer_code (line 547) | def _infer_code(
method _refine_text (line 670) | def _refine_text(
FILE: ChatTTS/model/cuda/patch.py
class LlamaRMSNorm (line 4) | class LlamaRMSNorm(torch.nn.Module):
method __init__ (line 5) | def __init__(self, hidden_size, eps=1e-6):
method forward (line 13) | def forward(self, hidden_states: torch.Tensor):
FILE: ChatTTS/model/cuda/te_llama.py
function replace_decoder (line 29) | def replace_decoder(te_decoder_cls, llama_rms_norm_cls):
class TELlamaDecoderLayer (line 50) | class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
method __init__ (line 61) | def __init__(self, config, *args, **kwargs):
method forward (line 81) | def forward(self, hidden_states, *args, attention_mask, **kwargs):
class TELlamaModel (line 96) | class TELlamaModel:
method __new__ (line 106) | def __new__(cls, config: LlamaConfig):
method from_state_dict (line 114) | def from_state_dict(
function _replace_params (line 134) | def _replace_params(hf_state_dict, te_state_dict, config):
FILE: ChatTTS/model/dvae.py
class ConvNeXtBlock (line 14) | class ConvNeXtBlock(nn.Module):
method __init__ (line 15) | def __init__(
method forward (line 46) | def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
class GFSQ (line 69) | class GFSQ(nn.Module):
method __init__ (line 71) | def __init__(
method _embed (line 87) | def _embed(self, x: torch.Tensor):
method __call__ (line 99) | def __call__(self, x: torch.Tensor) -> torch.Tensor:
method forward (line 102) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class DVAEDecoder (line 131) | class DVAEDecoder(nn.Module):
method __init__ (line 132) | def __init__(
method forward (line 163) | def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
class MelSpectrogramFeatures (line 175) | class MelSpectrogramFeatures(torch.nn.Module):
method __init__ (line 176) | def __init__(
method __call__ (line 199) | def __call__(self, audio: torch.Tensor) -> torch.Tensor:
method forward (line 202) | def forward(self, audio: torch.Tensor) -> torch.Tensor:
class DVAE (line 209) | class DVAE(nn.Module):
method __init__ (line 210) | def __init__(
method __repr__ (line 245) | def __repr__(self) -> str:
method __call__ (line 250) | def __call__(
method load_pretrained (line 256) | def load_pretrained(self, filename: str, device: torch.device):
method forward (line 262) | def forward(
method sample_audio (line 300) | def sample_audio(self, wav: Union[np.ndarray, torch.Tensor]) -> torch....
FILE: ChatTTS/model/embed.py
class Embed (line 8) | class Embed(nn.Module):
method __init__ (line 9) | def __init__(
method load_pretrained (line 38) | def load_pretrained(self, filename: str, device: torch.device):
method __call__ (line 43) | def __call__(
method forward (line 52) | def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) ->...
FILE: ChatTTS/model/gpt.py
class GPT (line 21) | class GPT(nn.Module):
method __init__ (line 22) | def __init__(
method load_pretrained (line 62) | def load_pretrained(
class Context (line 106) | class Context:
method __init__ (line 107) | def __init__(self):
method set (line 110) | def set(self, v: bool):
method get (line 113) | def get(self) -> bool:
method _build_llama_config (line 116) | def _build_llama_config(
method prepare (line 134) | def prepare(self, compile=False):
class _GenerationInputs (line 145) | class _GenerationInputs:
method to (line 153) | def to(self, device: torch.device, dtype: torch.dtype):
method _prepare_generation_inputs (line 164) | def _prepare_generation_inputs(
class GenerationOutputs (line 280) | class GenerationOutputs:
method destroy (line 285) | def destroy(self):
method _prepare_generation_outputs (line 291) | def _prepare_generation_outputs(
method generate (line 319) | def generate(
FILE: ChatTTS/model/processors.py
class CustomRepetitionPenaltyLogitsProcessorRepeat (line 6) | class CustomRepetitionPenaltyLogitsProcessorRepeat:
method __init__ (line 8) | def __init__(self, penalty: float, max_input_ids: int, past_window: int):
method __call__ (line 18) | def __call__(
function gen_logits (line 38) | def gen_logits(
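`CustomRepetitionPenaltyLogitsProcessorRepeat` takes `penalty`, `max_input_ids` and `past_window`. The sketch below shows what a windowed repetition penalty with those three knobs could look like on plain Python lists. It is a hypothetical illustration: the function name `apply_repetition_penalty` and the exact scaling rule (`penalty ** count`) are assumptions and were not read from `processors.py`.

```python
from typing import Dict, List, Sequence


def apply_repetition_penalty(
    logits: List[float],
    history: Sequence[int],
    penalty: float,
    max_input_ids: int,
    past_window: int,
) -> List[float]:
    """Hypothetical sketch: damp logits of tokens repeated in the recent window."""
    # Only the most recent `past_window` generated tokens are considered.
    window = list(history)[-past_window:] if past_window > 0 else []
    counts: Dict[int, int] = {}
    for tok in window:
        # Token ids at or above max_input_ids are never penalized.
        if 0 <= tok < max_input_ids and tok < len(logits):
            counts[tok] = counts.get(tok, 0) + 1
    out = list(logits)
    for tok, n in counts.items():
        alpha = penalty ** n
        # Make the token less likely regardless of the logit's sign.
        out[tok] = out[tok] * alpha if out[tok] < 0 else out[tok] / alpha
    return out
```

Scaling by `penalty ** count` means a token repeated many times in the window is suppressed more strongly than one repeated once.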
FILE: ChatTTS/model/speaker.py
class Speaker (line 10) | class Speaker:
method __init__ (line 11) | def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu"))...
method sample_random (line 18) | def sample_random(self) -> str:
method apply (line 22) | def apply(
method decorate_code_prompts (line 56) | def decorate_code_prompts(
method decorate_text_prompts (line 86) | def decorate_text_prompts(text: List[str], prompt: str) -> List[str]:
method encode_prompt (line 91) | def encode_prompt(prompt: torch.Tensor) -> str:
method decode_prompt (line 108) | def decode_prompt(prompt: str) -> torch.Tensor:
method _sample_random (line 123) | def _sample_random(self) -> torch.Tensor:
method _encode (line 133) | def _encode(spk_emb: torch.Tensor) -> str:
method _decode (line 146) | def _decode(spk_emb: str) -> np.ndarray:
FILE: ChatTTS/model/tokenizer.py
class Tokenizer (line 16) | class Tokenizer:
method __init__ (line 17) | def __init__(
method encode (line 36) | def encode(
method decode (line 129) | def decode(
FILE: ChatTTS/model/velocity/block_manager.py
class BlockAllocator (line 14) | class BlockAllocator:
method __init__ (line 22) | def __init__(
method allocate (line 40) | def allocate(self) -> PhysicalTokenBlock:
method free (line 47) | def free(self, block: PhysicalTokenBlock) -> None:
method get_num_free_blocks (line 54) | def get_num_free_blocks(self) -> int:
class AllocStatus (line 58) | class AllocStatus(enum.Enum):
class BlockSpaceManager (line 73) | class BlockSpaceManager:
method __init__ (line 76) | def __init__(
method can_allocate (line 102) | def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
method allocate (line 119) | def allocate(self, seq_group: SequenceGroup) -> None:
method can_append_slot (line 142) | def can_append_slot(self, seq_group: SequenceGroup) -> bool:
method append_slot (line 149) | def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
method fork (line 184) | def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
method _get_physical_blocks (line 192) | def _get_physical_blocks(
method can_swap_in (line 204) | def can_swap_in(self, seq_group: SequenceGroup) -> bool:
method swap_in (line 214) | def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
method can_swap_out (line 239) | def can_swap_out(self, seq_group: SequenceGroup) -> bool:
method swap_out (line 243) | def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
method _free_block_table (line 268) | def _free_block_table(self, block_table: BlockTable) -> None:
method free (line 275) | def free(self, seq: Sequence) -> None:
method reset (line 283) | def reset(self) -> None:
method get_block_table (line 288) | def get_block_table(self, seq: Sequence) -> List[int]:
method get_num_free_gpu_blocks (line 292) | def get_num_free_gpu_blocks(self) -> int:
method get_num_free_cpu_blocks (line 295) | def get_num_free_cpu_blocks(self) -> int:
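The `velocity` package uses the same class names as vLLM (`BlockAllocator`, `BlockSpaceManager`, `LLMEngine`). A minimal sketch of a reference-counted free-list allocator with the `allocate` / `free` / `get_num_free_blocks` interface listed above might look like this. The `Block` dataclass and `SimpleBlockAllocator` name are hypothetical stand-ins for `PhysicalTokenBlock` and the real allocator.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Block:
    # Hypothetical stand-in for a physical KV-cache block.
    block_number: int
    ref_count: int = 0


class SimpleBlockAllocator:
    """Illustrative free-list allocator for fixed-size KV-cache blocks."""

    def __init__(self, num_blocks: int) -> None:
        self.free_blocks: List[Block] = [Block(i) for i in range(num_blocks)]

    def allocate(self) -> Block:
        if not self.free_blocks:
            raise ValueError("Out of memory: no free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: Block) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free: {block} is already free.")
        block.ref_count -= 1
        # Return the block to the pool only when no sequence references it.
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)
```

Reference counting is what lets `fork` share a parent's blocks with a child sequence without copying them.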
FILE: ChatTTS/model/velocity/configs.py
class ModelConfig (line 21) | class ModelConfig:
method __init__ (line 65) | def __init__(
method _verify_load_format (line 121) | def _verify_load_format(self) -> None:
method _verify_tokenizer_mode (line 151) | def _verify_tokenizer_mode(self) -> None:
method _verify_quantization (line 160) | def _verify_quantization(self) -> None:
method _verify_cuda_graph (line 197) | def _verify_cuda_graph(self) -> None:
method verify_with_parallel_config (line 204) | def verify_with_parallel_config(
method get_sliding_window (line 226) | def get_sliding_window(self) -> Optional[int]:
method get_vocab_size (line 229) | def get_vocab_size(self) -> int:
method get_hidden_size (line 232) | def get_hidden_size(self) -> int:
method get_head_size (line 235) | def get_head_size(self) -> int:
method get_total_num_kv_heads (line 239) | def get_total_num_kv_heads(self) -> int:
method get_num_kv_heads (line 275) | def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
method get_num_layers (line 284) | def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
class CacheConfig (line 289) | class CacheConfig:
method __init__ (line 299) | def __init__(
method _verify_args (line 316) | def _verify_args(self) -> None:
method verify_with_parallel_config (line 323) | def verify_with_parallel_config(
class ParallelConfig (line 344) | class ParallelConfig:
method __init__ (line 355) | def __init__(
method _verify_args (line 372) | def _verify_args(self) -> None:
class SchedulerConfig (line 377) | class SchedulerConfig:
method __init__ (line 390) | def __init__(
method _verify_args (line 408) | def _verify_args(self) -> None:
function _get_and_verify_dtype (line 437) | def _get_and_verify_dtype(
function _get_and_verify_max_len (line 491) | def _get_and_verify_max_len(
class EngineArgs (line 551) | class EngineArgs:
method __post_init__ (line 582) | def __post_init__(self):
method add_cli_args (line 587) | def add_cli_args(parser: argparse.ArgumentParser) -> argparse.Argument...
method from_cli_args (line 786) | def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
method create_engine_configs (line 793) | def create_engine_configs(
class AsyncEngineArgs (line 836) | class AsyncEngineArgs(EngineArgs):
method add_cli_args (line 844) | def add_cli_args(parser: argparse.ArgumentParser) -> argparse.Argument...
FILE: ChatTTS/model/velocity/llama.py
class LlamaMLP (line 59) | class LlamaMLP(nn.Module):
method __init__ (line 61) | def __init__(
method forward (line 85) | def forward(self, x):
class LlamaAttention (line 92) | class LlamaAttention(nn.Module):
method __init__ (line 94) | def __init__(
method forward (line 153) | def forward(
class LlamaDecoderLayer (line 169) | class LlamaDecoderLayer(nn.Module):
method __init__ (line 171) | def __init__(
method forward (line 201) | def forward(
class LlamaModel (line 228) | class LlamaModel(nn.Module):
method __init__ (line 230) | def __init__(
method forward (line 251) | def forward(
method load_weights (line 272) | def load_weights(
class LlamaForCausalLM (line 317) | class LlamaForCausalLM(nn.Module):
method __init__ (line 319) | def __init__(
method forward (line 331) | def forward(
method sample (line 341) | def sample(
method load_weights (line 351) | def load_weights(
FILE: ChatTTS/model/velocity/llm.py
class LLM (line 13) | class LLM:
method __init__ (line 68) | def __init__(
method get_tokenizer (line 113) | def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokeni...
method set_tokenizer (line 116) | def set_tokenizer(
method generate (line 122) | def generate(
method _add_request (line 182) | def _add_request(
method _run_engine (line 193) | def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
FILE: ChatTTS/model/velocity/llm_engine.py
class LLMEngine (line 38) | class LLMEngine:
method __init__ (line 65) | def __init__(
method _init_workers (line 127) | def _init_workers(self):
method _init_workers_ray (line 151) | def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_r...
method _verify_args (line 258) | def _verify_args(self) -> None:
method _init_cache (line 262) | def _init_cache(self) -> None:
method from_engine_args (line 308) | def from_engine_args(
method add_request (line 326) | def add_request(
method abort_request (line 365) | def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
method get_model_config (line 373) | def get_model_config(self) -> ModelConfig:
method get_num_unfinished_requests (line 377) | def get_num_unfinished_requests(self) -> int:
method has_unfinished_requests (line 381) | def has_unfinished_requests(self) -> bool:
method _check_beam_search_early_stopping (line 385) | def _check_beam_search_early_stopping(
method _process_sequence_group_outputs (line 429) | def _process_sequence_group_outputs(
method _process_model_outputs (line 613) | def _process_model_outputs(
method step (line 637) | def step(self) -> List[RequestOutput]:
method _log_system_stats (line 667) | def _log_system_stats(
method _decode_sequence (line 742) | def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
method _check_stop (line 763) | def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) ...
method _run_workers (line 799) | def _run_workers(
FILE: ChatTTS/model/velocity/model_loader.py
function _set_default_torch_dtype (line 16) | def _set_default_torch_dtype(dtype: torch.dtype):
function get_model (line 24) | def get_model(model_config: ModelConfig) -> nn.Module:
FILE: ChatTTS/model/velocity/model_runner.py
class ModelRunner (line 38) | class ModelRunner:
method __init__ (line 40) | def __init__(
method load_model (line 80) | def load_model(self) -> None:
method set_block_size (line 95) | def set_block_size(self, block_size: int) -> None:
method _prepare_prompt (line 105) | def _prepare_prompt(
method _prepare_decode (line 179) | def _prepare_decode(
method _prepare_sample (line 279) | def _prepare_sample(
method prepare_input_tensors (line 353) | def prepare_input_tensors(
method execute_model (line 460) | def execute_model(
method profile_run (line 599) | def profile_run(self) -> None:
method capture_model (line 633) | def capture_model(self, kv_caches: List[KVCache]) -> None:
class CUDAGraphRunner (line 692) | class CUDAGraphRunner:
method __init__ (line 694) | def __init__(self, model: nn.Module):
method capture (line 700) | def capture(
method forward (line 743) | def forward(
method __call__ (line 772) | def __call__(self, *args, **kwargs):
function _pad_to_max (line 776) | def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
function _make_tensor_with_pad (line 783) | def _make_tensor_with_pad(
function _get_graph_batch_size (line 806) | def _get_graph_batch_size(batch_size: int) -> int:
function _async_h2d (line 815) | def _async_h2d(data: list, dtype, pin_memory):
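`_pad_to_max` and `_get_graph_batch_size` are small batching helpers. The sketch below reproduces, from memory, the behavior of the same-named helpers in vLLM 0.2.x, which this module appears to mirror. Whether ChatTTS keeps the same rounding rule is an assumption that the index alone cannot confirm, so the functions are given hypothetical names.

```python
from typing import List


def pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    # Right-pad a token/slot list so a batch can be stacked into one tensor.
    if len(x) > max_len:
        raise ValueError("input is longer than max_len")
    return x + [pad] * (max_len - len(x))


def graph_batch_size(batch_size: int) -> int:
    # Round a batch size up to one of the sizes a CUDA graph was captured for:
    # 1, 2, 4, then multiples of 8.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8
```

Rounding to a small set of sizes keeps the number of captured CUDA graphs bounded while wasting at most 7 padded slots per batch.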
FILE: ChatTTS/model/velocity/output.py
class CompletionOutput (line 12) | class CompletionOutput:
method __init__ (line 26) | def __init__(
method finished (line 44) | def finished(self) -> bool:
method __repr__ (line 47) | def __repr__(self) -> str:
class RequestOutput (line 59) | class RequestOutput:
method __init__ (line 71) | def __init__(
method from_seq_group (line 88) | def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
method __repr__ (line 136) | def __repr__(self) -> str:
FILE: ChatTTS/model/velocity/sampler.py
class Sampler (line 8) | class Sampler:
method __init__ (line 9) | def __init__(self, post_model: Embed, num_audio_tokens: int, num_vq: i...
method sample (line 15) | def sample(
FILE: ChatTTS/model/velocity/sampling_params.py
class SamplingType (line 12) | class SamplingType(IntEnum):
class SamplingParams (line 24) | class SamplingParams:
method __init__ (line 94) | def __init__(
method _verify_args (line 180) | def _verify_args(self) -> None:
method _verify_beam_search (line 222) | def _verify_beam_search(self) -> None:
method _verify_non_beam_search (line 240) | def _verify_non_beam_search(self) -> None:
method _verify_greedy_sampling (line 255) | def _verify_greedy_sampling(self) -> None:
method sampling_type (line 262) | def sampling_type(self) -> SamplingType:
method __repr__ (line 269) | def __repr__(self) -> str:
FILE: ChatTTS/model/velocity/scheduler.py
class PreemptionMode (line 20) | class PreemptionMode(enum.Enum):
class SchedulerOutputs (line 34) | class SchedulerOutputs:
method __init__ (line 36) | def __init__(
method is_empty (line 56) | def is_empty(self) -> bool:
class Scheduler (line 66) | class Scheduler:
method __init__ (line 68) | def __init__(
method add_seq_group (line 99) | def add_seq_group(self, seq_group: SequenceGroup) -> None:
method abort_seq_group (line 103) | def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
method has_unfinished_seqs (line 124) | def has_unfinished_seqs(self) -> bool:
method get_num_unfinished_seq_groups (line 127) | def get_num_unfinished_seq_groups(self) -> int:
method _schedule (line 130) | def _schedule(self) -> SchedulerOutputs:
method schedule (line 295) | def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutp...
method fork_seq (line 321) | def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
method free_seq (line 324) | def free_seq(self, seq: Sequence) -> None:
method free_finished_seq_groups (line 327) | def free_finished_seq_groups(self) -> None:
method _allocate (line 332) | def _allocate(self, seq_group: SequenceGroup) -> None:
method _append_slot (line 337) | def _append_slot(
method _preempt (line 351) | def _preempt(
method _preempt_by_recompute (line 380) | def _preempt_by_recompute(
method _preempt_by_swap (line 393) | def _preempt_by_swap(
method _swap_in (line 401) | def _swap_in(
method _swap_out (line 411) | def _swap_out(
FILE: ChatTTS/model/velocity/sequence.py
class SequenceStatus (line 14) | class SequenceStatus(enum.Enum):
method is_finished (line 26) | def is_finished(status: "SequenceStatus") -> bool:
method get_finished_reason (line 35) | def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
class SequenceData (line 52) | class SequenceData:
method __init__ (line 65) | def __init__(
method append_token_id (line 75) | def append_token_id(self, token_id: int, logprob: float) -> None:
method append_hidden_states (line 84) | def append_hidden_states(self, hidden_states: torch.Tensor) -> None:
method get_len (line 90) | def get_len(self) -> int:
method get_prompt_len (line 93) | def get_prompt_len(self) -> int:
method get_output_len (line 96) | def get_output_len(self) -> int:
method get_token_ids (line 99) | def get_token_ids(self) -> List[int]:
method get_last_token_id (line 102) | def get_last_token_id(self) -> int:
method __repr__ (line 107) | def __repr__(self) -> str:
class Sequence (line 118) | class Sequence:
method __init__ (line 129) | def __init__(
method _append_logical_block (line 155) | def _append_logical_block(self) -> None:
method _append_tokens_to_blocks (line 162) | def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
method append_token_id (line 177) | def append_token_id(
method get_len (line 191) | def get_len(self) -> int:
method get_prompt_len (line 194) | def get_prompt_len(self) -> int:
method get_output_len (line 197) | def get_output_len(self) -> int:
method get_token_ids (line 200) | def get_token_ids(self) -> List[int]:
method get_last_token_id (line 203) | def get_last_token_id(self) -> int:
method get_output_token_ids (line 206) | def get_output_token_ids(self) -> List[int]:
method get_cumulative_logprob (line 209) | def get_cumulative_logprob(self) -> float:
method get_beam_search_score (line 212) | def get_beam_search_score(
method is_finished (line 232) | def is_finished(self) -> bool:
method fork (line 235) | def fork(self, new_seq_id: int) -> "Sequence":
method __repr__ (line 240) | def __repr__(self) -> str:
class SequenceGroup (line 248) | class SequenceGroup:
method __init__ (line 258) | def __init__(
method prompt (line 272) | def prompt(self) -> str:
method prompt_token_ids (line 278) | def prompt_token_ids(self) -> List[int]:
method get_max_num_running_seqs (line 283) | def get_max_num_running_seqs(self) -> int:
method get_seqs (line 300) | def get_seqs(
method get_unfinished_seqs (line 309) | def get_unfinished_seqs(self) -> List[Sequence]:
method get_finished_seqs (line 312) | def get_finished_seqs(self) -> List[Sequence]:
method num_seqs (line 315) | def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
method num_unfinished_seqs (line 318) | def num_unfinished_seqs(self) -> int:
method num_finished_seqs (line 321) | def num_finished_seqs(self) -> int:
method find (line 324) | def find(self, seq_id: int) -> Sequence:
method add (line 329) | def add(self, seq: Sequence) -> None:
method remove (line 334) | def remove(self, seq_id: int) -> None:
method is_finished (line 339) | def is_finished(self) -> bool:
method __repr__ (line 342) | def __repr__(self) -> str:
class SequenceGroupMetadata (line 350) | class SequenceGroupMetadata:
method __init__ (line 363) | def __init__(
class SequenceOutput (line 378) | class SequenceOutput:
method __init__ (line 389) | def __init__(
method __repr__ (line 403) | def __repr__(self) -> str:
method __eq__ (line 412) | def __eq__(self, other: object) -> bool:
class SequenceGroupOutput (line 422) | class SequenceGroupOutput:
method __init__ (line 425) | def __init__(
method __repr__ (line 433) | def __repr__(self) -> str:
method __eq__ (line 439) | def __eq__(self, other: object) -> bool:
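`SequenceStatus.is_finished` and `get_finished_reason` map a sequence's lifecycle state to a stop reason. A hedged sketch, modeled on vLLM's enum of the same name (the member names and reason strings are recalled from vLLM, not read from `sequence.py`), using the hypothetical name `SeqStatus`:

```python
import enum
from typing import Optional


class SeqStatus(enum.Enum):
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SeqStatus") -> bool:
        return status in (
            SeqStatus.FINISHED_STOPPED,
            SeqStatus.FINISHED_LENGTH_CAPPED,
            SeqStatus.FINISHED_ABORTED,
            SeqStatus.FINISHED_IGNORED,
        )

    @staticmethod
    def get_finished_reason(status: "SeqStatus") -> Optional[str]:
        if status == SeqStatus.FINISHED_STOPPED:
            return "stop"
        if status in (SeqStatus.FINISHED_LENGTH_CAPPED, SeqStatus.FINISHED_IGNORED):
            # Ignored sequences had prompts that were too long.
            return "length"
        if status == SeqStatus.FINISHED_ABORTED:
            return "abort"
        return None  # still waiting, running, or swapped out
```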
FILE: ChatTTS/model/velocity/worker.py
class Worker (line 19) | class Worker:
method __init__ (line 27) | def __init__(
method init_model (line 64) | def init_model(self) -> None:
method load_model (line 88) | def load_model(self):
method profile_num_available_blocks (line 92) | def profile_num_available_blocks(
method init_cache_engine (line 125) | def init_cache_engine(self, cache_config: CacheConfig) -> None:
method warm_up_model (line 134) | def warm_up_model(self) -> None:
method cache_swap (line 141) | def cache_swap(
method execute_model (line 168) | def execute_model(
function _init_distributed_environment (line 207) | def _init_distributed_environment(
function _check_if_gpu_supports_dtype (line 241) | def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
FILE: ChatTTS/norm.py
function _find_index (line 14) | def _find_index(table: np.ndarray, val: np.uint16):
function _fast_replace (line 22) | def _fast_replace(
function _split_tags (line 38) | def _split_tags(text: str) -> Tuple[List[str], List[str]]:
function _combine_tags (line 61) | def _combine_tags(texts: List[str], tags: List[str]) -> str:
class Normalizer (line 71) | class Normalizer:
method __init__ (line 72) | def __init__(self, map_file_path: str, logger=logging.getLogger(__name...
method __call__ (line 163) | def __call__(
method register (line 203) | def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
method unregister (line 218) | def unregister(self, name: str):
method destroy (line 222) | def destroy(self):
method _load_homophones_map (line 226) | def _load_homophones_map(self, map_file_path: str) -> np.ndarray:
method _count_invalid_characters (line 235) | def _count_invalid_characters(self, s: str):
method _apply_half2full_map (line 240) | def _apply_half2full_map(self, text: str) -> str:
method _apply_character_map (line 243) | def _apply_character_map(self, text: str) -> str:
method _detect_language (line 246) | def _detect_language(self, sentence: str) -> Literal["zh", "en"]:
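`_split_tags` returns a pair of lists and `_combine_tags(texts, tags)` joins them back into one string. This pair presumably lets the normalizer rewrite plain text while leaving inline control tags untouched. A minimal regex sketch, assuming tags are written in square brackets such as `[uv_break]` (the tag syntax and function names here are assumptions):

```python
import re
from typing import List, Tuple

# Assumed tag syntax: control tokens written as [name], e.g. [uv_break].
TAG_RE = re.compile(r"\[[^\[\]]+\]")


def split_tags(text: str) -> Tuple[List[str], List[str]]:
    # re.split without capture groups drops the tags, so texts has exactly
    # one more element than tags.
    return TAG_RE.split(text), TAG_RE.findall(text)


def combine_tags(texts: List[str], tags: List[str]) -> str:
    # Interleave text chunks and tags back into their original order.
    out = [texts[0]]
    for tag, chunk in zip(tags, texts[1:]):
        out.append(tag)
        out.append(chunk)
    return "".join(out)
```

Usage: split, normalize each element of `texts` independently, then call `combine_tags` with the original `tags`.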
FILE: ChatTTS/utils/dl.py
function sha256 (line 12) | def sha256(fileno: int) -> str:
function check_model (line 19) | def check_model(
function check_folder (line 46) | def check_folder(
function check_all_assets (line 66) | def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=...
function download_and_extract_tar_gz (line 114) | def download_and_extract_tar_gz(
function download_and_extract_zip (line 130) | def download_and_extract_zip(
function download_all_assets (line 146) | def download_all_assets(tmpdir: str, homedir: str, version="0.2.11"):
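`sha256(fileno: int) -> str` hashes a file given its descriptor, and `check_model` / `check_all_assets` compare results against `sha256_map`. A hedged stdlib sketch of descriptor-based hashing (the name `sha256_of_fd`, the chunk size, and rewinding to offset 0 are assumptions):

```python
import hashlib
import os


def sha256_of_fd(fileno: int, chunk_size: int = 1 << 20) -> str:
    h = hashlib.sha256()
    os.lseek(fileno, 0, os.SEEK_SET)  # hash from the beginning of the file
    while True:
        # Read in chunks so large checkpoints never need to fit in memory.
        chunk = os.read(fileno, chunk_size)
        if not chunk:
            break
        h.update(chunk)
    return h.hexdigest()
```

Usage: `with open(path, "rb") as f: ok = sha256_of_fd(f.fileno()) == expected_hex`.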
FILE: ChatTTS/utils/gpu.py
function select_device (line 13) | def select_device(min_memory=2047, experimental=False):
function _is_torch_npu_available (line 59) | def _is_torch_npu_available():
FILE: ChatTTS/utils/io.py
function load_safetensors (line 20) | def load_safetensors(filename: str):
function get_latest_modified_file (line 28) | def get_latest_modified_file(directory):
function del_all (line 41) | def del_all(d: Union[dict, list]):
FILE: ChatTTS/utils/log.py
class Logger (line 5) | class Logger:
method __init__ (line 6) | def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)):
method set_logger (line 9) | def set_logger(self, logger: logging.Logger):
method get_logger (line 12) | def get_logger(self) -> logging.Logger:
FILE: examples/api/main.py
function startup_event (line 37) | async def startup_event():
function validation_exception_handler (line 53) | async def validation_exception_handler(request, exc: RequestValidationEr...
class ChatTTSParams (line 58) | class ChatTTSParams(BaseModel):
function generate_voice (line 72) | async def generate_voice(params: ChatTTSParams):
FILE: examples/api/openai_api.py
function startup_event (line 63) | async def startup_event():
class OpenAITTSRequest (line 106) | class OpenAITTSRequest(BaseModel):
method validate_request (line 129) | def validate_request(cls, request_data: Dict):
function custom_exception_handler (line 140) | async def custom_exception_handler(request, exc):
function generate_voice (line 150) | async def generate_voice(request_data: Dict):
function health_check (line 283) | async def health_check():
FILE: examples/api/postScript.py
function parse_arguments (line 15) | def parse_arguments():
function main (line 179) | def main():
FILE: examples/cmd/run.py
function save_mp3_file (line 24) | def save_mp3_file(wav, index):
function load_normalizer (line 32) | def load_normalizer(chat: ChatTTS.Chat):
function main (line 54) | def main(
FILE: examples/cmd/stream.py
class ChatStreamer (line 9) | class ChatStreamer:
method __init__ (line 10) | def __init__(self, base_block_size=8000):
method _update_stream (line 15) | def _update_stream(history_stream_wav, new_stream_wav, thre):
method _accum (line 33) | def _accum(accum_wavs, stream_wav):
method batch_stream_formatted (line 42) | def batch_stream_formatted(stream_wav, output_format="PCM16_byte"):
method formatted (line 51) | def formatted(data, output_format="PCM16_byte"):
method checkvoice (line 60) | def checkvoice(data):
method _subgen (line 68) | def _subgen(data, thre=12000):
method generate (line 74) | def generate(self, streamchat, output_format=None):
method play (line 148) | def play(self, streamchat, wait=5):
FILE: examples/onnx/exporter.py
function export_gpt (line 40) | def export_gpt():
function export_decoder (line 349) | def export_decoder():
function export_vocos (line 373) | def export_vocos():
FILE: examples/onnx/gpt.py
class GPT (line 11) | class GPT(nn.Module):
method __init__ (line 12) | def __init__(
method from_pretrained (line 75) | def from_pretrained(self, file_path: str):
method _build_llama (line 80) | def _build_llama(
FILE: examples/onnx/modeling_llama.py
function _make_causal_mask (line 55) | def _make_causal_mask(
function _expand_mask (line 86) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
class LlamaRMSNorm (line 102) | class LlamaRMSNorm(nn.Module):
method __init__ (line 103) | def __init__(self, hidden_size, eps=1e-6):
method forward (line 111) | def forward(self, hidden_states):
class LlamaRotaryEmbedding (line 119) | class LlamaRotaryEmbedding(torch.nn.Module):
method __init__ (line 120) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
method _set_cos_sin_cache (line 138) | def _set_cos_sin_cache(self, seq_len, device, dtype):
method forward (line 154) | def forward(self, x, seq_len=None):
class LlamaLinearScalingRotaryEmbedding (line 165) | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
method __init__ (line 168) | def __init__(
method _set_cos_sin_cache (line 179) | def _set_cos_sin_cache(self, seq_len, device, dtype):
class LlamaDynamicNTKScalingRotaryEmbedding (line 197) | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
method __init__ (line 200) | def __init__(
method _set_cos_sin_cache (line 211) | def _set_cos_sin_cache(self, seq_len, device, dtype):
function rotate_half (line 239) | def rotate_half(x):
function apply_rotary_pos_emb (line 246) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
class LlamaMLP (line 259) | class LlamaMLP(nn.Module):
method __init__ (line 260) | def __init__(self, config):
method forward (line 270) | def forward(self, x):
function repeat_kv (line 298) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class LlamaAttention (line 312) | class LlamaAttention(nn.Module):
method __init__ (line 315) | def __init__(self, config: LlamaConfig):
method _init_rope (line 345) | def _init_rope(self):
method _shape (line 368) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
method forward (line 375) | def forward(
class LlamaDecoderLayer (line 508) | class LlamaDecoderLayer(nn.Module):
method __init__ (line 509) | def __init__(self, config: LlamaConfig):
method forward (line 519) | def forward(
class LlamaPreTrainedModel (line 597) | class LlamaPreTrainedModel(PreTrainedModel):
method _init_weights (line 604) | def _init_weights(self, module):
method _set_gradient_checkpointing (line 615) | def _set_gradient_checkpointing(self, module, value=False):
class LlamaModel (line 688) | class LlamaModel(LlamaPreTrainedModel):
method __init__ (line 696) | def __init__(self, config: LlamaConfig):
method get_input_embeddings (line 713) | def get_input_embeddings(self):
method set_input_embeddings (line 716) | def set_input_embeddings(self, value):
method _prepare_decoder_attention_mask (line 720) | def _prepare_decoder_attention_mask(
method forward (line 748) | def forward(
class LlamaForCausalLM (line 901) | class LlamaForCausalLM(LlamaPreTrainedModel):
method __init__ (line 904) | def __init__(self, config):
method get_input_embeddings (line 914) | def get_input_embeddings(self):
method set_input_embeddings (line 917) | def set_input_embeddings(self, value):
method get_output_embeddings (line 920) | def get_output_embeddings(self):
method set_output_embeddings (line 923) | def set_output_embeddings(self, new_embeddings):
method set_decoder (line 926) | def set_decoder(self, decoder):
method get_decoder (line 929) | def get_decoder(self):
method forward (line 936) | def forward(
method prepare_inputs_for_generation (line 1041) | def prepare_inputs_for_generation(
method _reorder_cache (line 1077) | def _reorder_cache(past_key_values, beam_idx):
class LlamaForSequenceClassification (line 1104) | class LlamaForSequenceClassification(LlamaPreTrainedModel):
method __init__ (line 1105) | def __init__(self, config):
method get_input_embeddings (line 1114) | def get_input_embeddings(self):
method set_input_embeddings (line 1117) | def set_input_embeddings(self, value):
method forward (line 1121) | def forward(
FILE: examples/web/funcs.py
function generate_seed (line 50) | def generate_seed():
function on_voice_change (line 55) | def on_voice_change(vocie_selection):
function on_audio_seed_change (line 59) | def on_audio_seed_change(audio_seed_input):
function load_chat (line 65) | def load_chat(cust_path: Optional[str], coef: Optional[str], enable_cach...
function reload_chat (line 97) | def reload_chat(coef: Optional[str]) -> str:
function on_upload_sample_audio (line 120) | def on_upload_sample_audio(sample_audio_input: Optional[str]) -> str:
function _set_generate_buttons (line 129) | def _set_generate_buttons(generate_button, interrupt_button, is_reset=Fa...
function refine_text (line 135) | def refine_text(
function generate_audio (line 166) | def generate_audio(
function interrupt_generate (line 214) | def interrupt_generate():
function set_buttons_before_generate (line 221) | def set_buttons_before_generate(generate_button, interrupt_button):
function set_buttons_after_generate (line 233) | def set_buttons_after_generate(generate_button, interrupt_button, audio_...
FILE: examples/web/webui.py
function main (line 17) | def main():
FILE: tests/#588.py
function trim_tags (line 39) | def trim_tags(txt: str) -> str:
FILE: tools/audio/av.py
function wav2 (line 21) | def wav2(i: BytesIO, o: BufferedWriter, format: str):
function load_audio (line 43) | def load_audio(
FILE: tools/audio/ffmpeg.py
function has_ffmpeg_installed (line 4) | def has_ffmpeg_installed() -> bool:
FILE: tools/audio/np.py
function float_to_int16 (line 8) | def float_to_int16(audio: np.ndarray) -> np.ndarray:
FILE: tools/audio/pcm.py
function _pcm_to_wav_buffer (line 8) | def _pcm_to_wav_buffer(wav: np.ndarray, sample_rate: int = 24000) -> Byt...
function pcm_arr_to_mp3_view (line 35) | def pcm_arr_to_mp3_view(wav: np.ndarray, sample_rate: int = 24000) -> me...
function pcm_arr_to_ogg_view (line 54) | def pcm_arr_to_ogg_view(wav: np.ndarray, sample_rate: int = 24000) -> me...
function pcm_arr_to_wav_view (line 73) | def pcm_arr_to_wav_view(
FILE: tools/checksum/main.go
function main (line 11) | func main() {
FILE: tools/checksum/tmpl.go
constant jsontmpl (line 17) | jsontmpl = `{
FILE: tools/llm/llm.py
class ChatOpenAI (line 55) | class ChatOpenAI:
method __init__ (line 56) | def __init__(self, api_key, base_url, model):
method call (line 63) | def call(self, user_question, temperature=0.3, prompt_version="kimi", ...
FILE: tools/logger/log.py
class Formatter (line 37) | class Formatter(logging.Formatter):
method __init__ (line 38) | def __init__(self, color=platform.system().lower() != "windows"):
method format (line 43) | def format(self, record: logging.LogRecord):
function get_logger (line 58) | def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_ro...
FILE: tools/normalizer/en.py
function normalizer_en_nemo_text (line 5) | def normalizer_en_nemo_text() -> Callable[[str], str]:
FILE: tools/normalizer/zh.py
function normalizer_zh_tn (line 4) | def normalizer_zh_tn() -> Callable[[str], str]:
FILE: tools/seeder/ctx.py
class TorchSeedContext (line 4) | class TorchSeedContext:
method __init__ (line 5) | def __init__(self, seed):
method __enter__ (line 9) | def __enter__(self):
method __exit__ (line 13) | def __exit__(self, type, value, traceback):
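The symbol index above uses a regular line format. A `FILE:` header is followed by one entry per symbol, giving its kind, name, source line number and a truncated signature. The sketch below groups the entries by file. It assumes every entry matches the `<kind> <name> (line <n>) | <signature>` pattern seen here. `parse_symbol_index` and `Symbol` are hypothetical helpers, not part of ChatTTS or GitExtract. Because the signatures are truncated in the index (for example `enable_cach...`), the `signature` field holds only a snippet.

```python
import re
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional

# Assumed entry shape, inferred from the index: "<kind> <name> (line <n>) | <signature snippet>"
ENTRY_RE = re.compile(
    r"^(?P<kind>\w+)\s+(?P<name>\S+)\s+\(line (?P<line>\d+)\)\s+\|\s+(?P<sig>.*)$"
)


class Symbol(NamedTuple):
    kind: str        # e.g. "class", "method", "function", "constant"
    name: str        # symbol name as printed in the index
    line: int        # line number in the original source file
    signature: str   # possibly truncated signature snippet


def parse_symbol_index(text: str) -> Dict[str, List[Symbol]]:
    """Group index entries under the most recent 'FILE:' header."""
    index: Dict[str, List[Symbol]] = defaultdict(list)
    current_file: Optional[str] = None
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("FILE: "):
            current_file = line[len("FILE: "):]
            continue
        m = ENTRY_RE.match(line)
        # Skip entries that appear before any FILE header, and any non-matching lines
        if m and current_file is not None:
            index[current_file].append(
                Symbol(m["kind"], m["name"], int(m["line"]), m["sig"])
            )
    return dict(index)
```

Entries are kept in index order within each file, so a method's class appears immediately before it.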
Condensed preview of the 96 files. Each entry shows its path, character count and a content snippet (full structured content: 4,344K chars).
[
{
"path": ".gitattributes",
"chars": 94,
"preview": "# ignore jupyter notebooks in the language bar on github\n**/*.ipynb linguist-vendored\n*.ipynb\n"
},
{
"path": ".github/workflows/checksum.yml",
"chars": 1446,
"preview": "name: Calculate and Sync SHA256\non:\n workflow_dispatch:\n\njobs:\n checksum:\n runs-on: ubuntu-24.04\n steps:\n -"
},
{
"path": ".github/workflows/close-issue.yml",
"chars": 741,
"preview": "name: Close Inactive Issues\non:\n schedule:\n - cron: \"0 4 * * *\"\n\njobs:\n close-issues:\n runs-on: ubuntu-24.04\n "
},
{
"path": ".github/workflows/pull-format.yml",
"chars": 2730,
"preview": "name: Check Pull Request Format\n\non:\n pull_request_target:\n types: [opened, reopened, synchronize]\n\njobs:\n # This w"
},
{
"path": ".github/workflows/push-format.yml",
"chars": 1647,
"preview": "name: Standardize Code Format\n\non:\n push:\n branches:\n - main\n - dev\n\njobs:\n push-format:\n runs-on: ubu"
},
{
"path": ".github/workflows/unitest.yml",
"chars": 1046,
"preview": "name: Unit Test\non: [ push, pull_request ]\njobs:\n build:\n runs-on: ${{ matrix.os }}\n\n if: \"!contains(github.event"
},
{
"path": ".github/workflows/upload-pypi.yml",
"chars": 856,
"preview": "name: Upload to PyPI\n\non:\n push:\n tags:\n - 'v*'\n\njobs:\n build:\n runs-on: ubuntu-22.04\n\n steps:\n\n - "
},
{
"path": ".gitignore",
"chars": 3210,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n*.ckpt\n# C extensions\n*.so\n*.pt\n\n# Distributio"
},
{
"path": "ChatTTS/__init__.py",
"chars": 23,
"preview": "from .core import Chat\n"
},
{
"path": "ChatTTS/config/__init__.py",
"chars": 27,
"preview": "from .config import Config\n"
},
{
"path": "ChatTTS/config/config.py",
"chars": 4742,
"preview": "from dataclasses import dataclass\n\n\n@dataclass(repr=False, eq=False)\nclass Path:\n vocos_ckpt_path: str = \"asset/Vocos"
},
{
"path": "ChatTTS/core.py",
"chars": 25058,
"preview": "import os\nimport re\nimport logging\nimport tempfile\nfrom dataclasses import dataclass, asdict\nfrom typing import Literal,"
},
{
"path": "ChatTTS/model/__init__.py",
"chars": 166,
"preview": "from .dvae import DVAE\nfrom .embed import Embed\nfrom .gpt import GPT\nfrom .processors import gen_logits\nfrom .speaker im"
},
{
"path": "ChatTTS/model/cuda/__init__.py",
"chars": 35,
"preview": "from .te_llama import TELlamaModel\n"
},
{
"path": "ChatTTS/model/cuda/patch.py",
"chars": 686,
"preview": "import torch\n\n\nclass LlamaRMSNorm(torch.nn.Module):\n def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n L"
},
{
"path": "ChatTTS/model/cuda/te_llama.py",
"chars": 7445,
"preview": "# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# See LICENSE for license information"
},
{
"path": "ChatTTS/model/dvae.py",
"chars": 9067,
"preview": "import math\nfrom typing import List, Optional, Literal, Union\n\nimport numpy as np\nimport pybase16384 as b14\nimport torch"
},
{
"path": "ChatTTS/model/embed.py",
"chars": 2505,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn.utils.parametrizations import weight_norm\n\nfrom ..utils import load_saf"
},
{
"path": "ChatTTS/model/gpt.py",
"chars": 22803,
"preview": "import platform\nfrom dataclasses import dataclass\nimport logging\nfrom typing import Union, List, Optional, Tuple, Callab"
},
{
"path": "ChatTTS/model/processors.py",
"chars": 1939,
"preview": "import torch\nimport torch.nn.functional as F\nfrom transformers.generation import TopKLogitsWarper, TopPLogitsWarper\n\n\ncl"
},
{
"path": "ChatTTS/model/speaker.py",
"chars": 4726,
"preview": "import lzma\nfrom typing import List, Optional, Union\n\nimport pybase16384 as b14\nimport numpy as np\nimport torch\nimport t"
},
{
"path": "ChatTTS/model/tokenizer.py",
"chars": 4610,
"preview": "import os\n\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\"\"\"\nhttps://stackoverflow.com/questions/62691279/how-to-disabl"
},
{
"path": "ChatTTS/model/velocity/__init__.py",
"chars": 65,
"preview": "from .llm import LLM\nfrom .sampling_params import SamplingParams\n"
},
{
"path": "ChatTTS/model/velocity/block_manager.py",
"chars": 11886,
"preview": "\"\"\"A block manager that manages token blocks.\"\"\"\n\nimport enum\nfrom typing import Dict, List, Optional, Set, Tuple\n\nfrom "
},
{
"path": "ChatTTS/model/velocity/configs.py",
"chars": 33105,
"preview": "from typing import Optional, Union, Tuple\nimport os\n\nimport torch\nfrom transformers import PretrainedConfig\n\nfrom vllm.l"
},
{
"path": "ChatTTS/model/velocity/llama.py",
"chars": 14387,
"preview": "# coding=utf-8\n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/m"
},
{
"path": "ChatTTS/model/velocity/llm.py",
"chars": 9477,
"preview": "from typing import List, Optional, Union\n\nfrom tqdm import tqdm\nfrom transformers import PreTrainedTokenizer, PreTrained"
},
{
"path": "ChatTTS/model/velocity/llm_engine.py",
"chars": 34955,
"preview": "import copy\nfrom collections import defaultdict\nimport os\nimport time\nfrom typing import TYPE_CHECKING, Any, Dict, Itera"
},
{
"path": "ChatTTS/model/velocity/model_loader.py",
"chars": 2613,
"preview": "\"\"\"Utilities for selecting and loading models.\"\"\"\n\nimport contextlib\n\nimport torch\nimport torch.nn as nn\n\nfrom vllm.conf"
},
{
"path": "ChatTTS/model/velocity/model_runner.py",
"chars": 32922,
"preview": "import time\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn"
},
{
"path": "ChatTTS/model/velocity/output.py",
"chars": 4938,
"preview": "from typing import List, Optional\nimport torch\n\nfrom .sequence import (\n PromptLogprobs,\n SampleLogprobs,\n Sequ"
},
{
"path": "ChatTTS/model/velocity/sampler.py",
"chars": 4062,
"preview": "import torch\nfrom torch.functional import F\nfrom typing import List, Callable\n\nfrom ..embed import Embed\n\n\nclass Sampler"
},
{
"path": "ChatTTS/model/velocity/sampling_params.py",
"chars": 13211,
"preview": "\"\"\"Sampling parameters for text generation.\"\"\"\n\nfrom enum import IntEnum\nfrom functools import cached_property\nfrom typi"
},
{
"path": "ChatTTS/model/velocity/scheduler.py",
"chars": 17746,
"preview": "import enum\nimport time\nfrom typing import Dict, Iterable, List, Optional, Tuple, Union\n\nfrom vllm.config import CacheCo"
},
{
"path": "ChatTTS/model/velocity/sequence.py",
"chars": 15322,
"preview": "\"\"\"Sequence and its related classes.\"\"\"\n\nimport copy\nimport enum\nfrom typing import Dict, List, Optional, Union\nimport t"
},
{
"path": "ChatTTS/model/velocity/worker.py",
"chars": 9483,
"preview": "\"\"\"A GPU worker class.\"\"\"\n\nimport os\nfrom typing import Dict, List, Optional, Tuple\n\nimport torch\nimport torch.distribut"
},
{
"path": "ChatTTS/norm.py",
"chars": 8754,
"preview": "import json\nimport logging\nimport re\nfrom typing import Dict, Tuple, List, Literal, Callable, Optional\nimport sys\n\nfrom "
},
{
"path": "ChatTTS/res/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ChatTTS/res/homophones_map.json",
"chars": 229952,
"preview": "{\n \"粡\": \"同\",\n \"為\": \"位\",\n \"瀹\": \"月\",\n \"滆\": \"格\",\n \"摲\": \"颤\",\n \"渹\": \"轰\",\n \"於\": \"鱼\",\n \"満\": \"满\",\n \"鍑"
},
{
"path": "ChatTTS/res/sha256_map.json",
"chars": 1002,
"preview": "{\n\t\"sha256_asset_Decoder_safetensors\": \"77aa55e0a977949c4733df3c6f876fa85860d3298cba63295a7bc6901729d4e0\",\n\t\"sha256_asse"
},
{
"path": "ChatTTS/utils/__init__.py",
"chars": 187,
"preview": "from .dl import check_all_assets, download_all_assets\nfrom .gpu import select_device\nfrom .io import load_safetensors, g"
},
{
"path": "ChatTTS/utils/dl.py",
"chars": 5182,
"preview": "import os\nfrom pathlib import Path\nimport hashlib\nimport requests\nfrom io import BytesIO\nfrom typing import Dict, Tuple,"
},
{
"path": "ChatTTS/utils/gpu.py",
"chars": 2286,
"preview": "import importlib.util\n\nimport torch\n\ntry:\n import torch_npu\nexcept ImportError:\n pass\n\nfrom .log import logger\n\n\nd"
},
{
"path": "ChatTTS/utils/io.py",
"chars": 1719,
"preview": "import os\nimport logging\nfrom typing import Union, IO\nfrom dataclasses import is_dataclass\n\nfrom safetensors import safe"
},
{
"path": "ChatTTS/utils/log.py",
"chars": 335,
"preview": "import logging\nfrom pathlib import Path\n\n\nclass Logger:\n def __init__(self, logger=logging.getLogger(Path(__file__).p"
},
{
"path": "LICENSE",
"chars": 34523,
"preview": " GNU AFFERO GENERAL PUBLIC LICENSE\n Version 3, 19 November 2007\n\n Copyright (C)"
},
{
"path": "README.md",
"chars": 11441,
"preview": "<div align=\"center\">\n\n<a href=\"https://trendshift.io/repositories/10489\" target=\"_blank\"><img src=\"https://trendshift.io"
},
{
"path": "docs/cn/README.md",
"chars": 7563,
"preview": "<div align=\"center\">\n\n<a href=\"https://trendshift.io/repositories/10489\" target=\"_blank\"><img src=\"https://trendshift.io"
},
{
"path": "docs/es/README.md",
"chars": 10023,
"preview": "<div align=\"center\">\n\n<a href=\"https://trendshift.io/repositories/10489\" target=\"_blank\"><img src=\"https://trendshift.io"
},
{
"path": "docs/fr/README.md",
"chars": 11701,
"preview": "<div align=\"center\">\n\n<a href=\"https://trendshift.io/repositories/10489\" target=\"_blank\"><img src=\"https://trendshift.io"
},
{
"path": "docs/jp/README.md",
"chars": 4953,
"preview": "# ChatTTS\n> [!NOTE]\n> 以下の内容は最新情報ではない可能性がありますのでご了承ください。全ての内容は英語版に基準することになります。\n\n[\n"
},
{
"path": "examples/cmd/stream.py",
"chars": 7454,
"preview": "import random\n\nimport numpy as np\n\nfrom tools.audio import float_to_int16\n\n\n# 流式推理数据获取器,支持流式获取音频编码字节流\nclass ChatStreamer"
},
{
"path": "examples/ipynb/colab.ipynb",
"chars": 11361,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"xYJFXKP9xhQM\"\n },\n \"source\": [\n \"## Clo"
},
{
"path": "examples/ipynb/example.ipynb",
"chars": 9972,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"## Import packages\"\n ]\n },\n {\n "
},
{
"path": "examples/onnx/README.md",
"chars": 293,
"preview": "# Export onnx or JIT models for deployment\n\n## Run `pip install onnx -U`.\n\n## Export GPT\n\n3. Run `python examples/onnx/e"
},
{
"path": "examples/onnx/exporter.py",
"chars": 13438,
"preview": "import os, sys\n\nif sys.platform == \"darwin\":\n os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n\nnow_dir = os.getcwd()\n"
},
{
"path": "examples/onnx/gpt.py",
"chars": 2464,
"preview": "import logging\nfrom typing import Tuple\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.utils.parametrizations import "
},
{
"path": "examples/onnx/modeling_llama.py",
"chars": 49104,
"preview": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on"
},
{
"path": "examples/web/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/web/ex.py",
"chars": 839,
"preview": "ex = [\n [\n \"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。\",\n 0.3,\n 0.7,\n 20"
},
{
"path": "examples/web/funcs.py",
"chars": 6191,
"preview": "import random\nfrom typing import Optional\nfrom time import sleep\n\nimport gradio as gr\n\nimport sys\n\nsys.path.append(\"..\")"
},
{
"path": "examples/web/webui.py",
"chars": 9560,
"preview": "import os, sys\n\nif sys.platform == \"darwin\":\n os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n\nnow_dir = os.getcwd()\n"
},
{
"path": "openai_api.ipynb",
"chars": 3419501,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 36,\n \"metadata\": {},\n \"outputs\": [\n {\n \"data\""
},
{
"path": "requirements.txt",
"chars": 261,
"preview": "numpy<3.0.0\nnumba\ntorch>=2.1.0\ntorchaudio\ntqdm\nvector_quantize_pytorch\ntransformers>=4.41.1\nvocos\nIPython\ngradio\npybase1"
},
{
"path": "setup.py",
"chars": 1130,
"preview": "import os\nfrom setuptools import setup, find_packages\n\nversion = \"v0.0.0\"\n\nsetup(\n name=\"chattts\",\n version=os.env"
},
{
"path": "tests/#511.py",
"chars": 1634,
"preview": "import os, sys\n\nif sys.platform == \"darwin\":\n os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n\nnow_dir = os.getcwd()\n"
},
{
"path": "tests/#588.py",
"chars": 975,
"preview": "import os, sys\n\nif sys.platform == \"darwin\":\n os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n\nnow_dir = os.getcwd()\n"
},
{
"path": "tests/#655.py",
"chars": 2422,
"preview": "import os, sys\n\nif sys.platform == \"darwin\":\n os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n\nnow_dir = os.getcwd()\n"
},
{
"path": "tests/testall.sh",
"chars": 261,
"preview": "#!/bin/sh\n\nexitcode=0\n\nfor file in tests/*.py\ndo\n echo \"Testing $file...\"\n python \"$file\"\n if [ $? -ne 0 ]\n "
},
{
"path": "tools/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tools/audio/__init__.py",
"chars": 178,
"preview": "from .av import load_audio\nfrom .pcm import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view\nfrom .ffmpeg i"
},
{
"path": "tools/audio/av.py",
"chars": 3831,
"preview": "from io import BufferedWriter, BytesIO\nfrom pathlib import Path\nfrom typing import Dict, Tuple, Optional, Union, List\n\ni"
},
{
"path": "tools/audio/ffmpeg.py",
"chars": 116,
"preview": "from pydub.utils import which\n\n\ndef has_ffmpeg_installed() -> bool:\n return which(\"ffmpeg\") and which(\"ffprobe\")\n"
},
{
"path": "tools/audio/np.py",
"chars": 269,
"preview": "import math\n\nimport numpy as np\nfrom numba import jit\n\n\n@jit(nopython=True)\ndef float_to_int16(audio: np.ndarray) -> np."
},
{
"path": "tools/audio/pcm.py",
"chars": 2959,
"preview": "import wave\nfrom io import BytesIO\nimport numpy as np\nfrom .np import float_to_int16\nfrom .av import wav2\n\n\ndef _pcm_to_"
},
{
"path": "tools/checksum/main.go",
"chars": 608,
"preview": "package main\n\nimport (\n\t\"crypto/sha256\"\n\t\"encoding/hex\"\n\t\"fmt\"\n\t\"io\"\n\t\"os\"\n)\n\nfunc main() {\n\tvar buf [32]byte\n\th := sha2"
},
{
"path": "tools/checksum/tmpl.go",
"chars": 800,
"preview": "package main\n\nvar files = [...]string{\n\t\"asset/Decoder.safetensors\",\n\t\"asset/DVAE.safetensors\",\n\t\"asset/Embed.safetensor"
},
{
"path": "tools/llm/__init__.py",
"chars": 28,
"preview": "from .llm import ChatOpenAI\n"
},
{
"path": "tools/llm/llm.py",
"chars": 2437,
"preview": "from openai import OpenAI\n\nprompt_dict = {\n \"kimi\": [\n {\n \"role\": \"system\",\n \"content\": "
},
{
"path": "tools/logger/__init__.py",
"chars": 28,
"preview": "from .log import get_logger\n"
},
{
"path": "tools/logger/log.py",
"chars": 2535,
"preview": "import platform, sys\nimport logging\nfrom datetime import datetime, timezone\n\nlogging.getLogger(\"numba\").setLevel(logging"
},
{
"path": "tools/normalizer/__init__.py",
"chars": 73,
"preview": "from .en import normalizer_en_nemo_text\nfrom .zh import normalizer_zh_tn\n"
},
{
"path": "tools/normalizer/en.py",
"chars": 336,
"preview": "from typing import Callable\nfrom functools import partial\n\n\ndef normalizer_en_nemo_text() -> Callable[[str], str]:\n f"
},
{
"path": "tools/normalizer/zh.py",
"chars": 188,
"preview": "from typing import Callable\n\n\ndef normalizer_zh_tn() -> Callable[[str], str]:\n from tn.chinese.normalizer import Norm"
},
{
"path": "tools/seeder/__init__.py",
"chars": 34,
"preview": "from .ctx import TorchSeedContext\n"
},
{
"path": "tools/seeder/ctx.py",
"chars": 329,
"preview": "import torch\n\n\nclass TorchSeedContext:\n def __init__(self, seed):\n self.seed = seed\n self.state = None\n"
}
]
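Each manifest entry is a JSON object with `path`, `chars` and `preview` fields. The `chars` field gives the full file size, while `preview` is only a truncated snippet. The sketch below totals the character counts per top-level directory, which shows where the repository's bulk lives. It assumes the manifest is available as a JSON array of such objects. `chars_by_top_dir` is a hypothetical helper, not part of ChatTTS.

```python
import json
from collections import Counter
from typing import Dict


def chars_by_top_dir(manifest_json: str) -> Dict[str, int]:
    """Sum each entry's 'chars' value, grouped by the first path component."""
    entries = json.loads(manifest_json)
    totals = Counter()
    for entry in entries:
        # Top-level files such as "setup.py" form their own group
        top = entry["path"].split("/", 1)[0]
        totals[top] += entry["chars"]
    return dict(totals)
```

Applied to this manifest, the totals show that the single notebook `openai_api.ipynb` (3,419,501 chars) accounts for most of the extraction's 4.0 MB.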
About this extraction
This page contains the full source code of the 2noise/ChatTTS GitHub repository, extracted and formatted as plain text. The extraction includes 96 files (4.0 MB), approximately 1.1M tokens, and a symbol index with 491 extracted functions, classes, methods, constants and types.
Extracted by GitExtract.