Repository: resemble-ai/chatterbox
Branch: master
Commit: eaf93540a1dd
Files: 64
Total size: 354.7 KB
Directory structure:
gitextract_lols3utf/
├── .github/
│ └── workflows/
│ └── install_check.yml
├── .gitignore
├── LICENSE
├── README.md
├── example_for_mac.py
├── example_tts.py
├── example_tts_turbo.py
├── example_vc.py
├── gradio_tts_app.py
├── gradio_tts_turbo_app.py
├── gradio_vc_app.py
├── multilingual_app.py
├── pyproject.toml
└── src/
└── chatterbox/
├── __init__.py
├── models/
│ ├── __init__.py
│ ├── s3gen/
│ │ ├── __init__.py
│ │ ├── configs.py
│ │ ├── const.py
│ │ ├── decoder.py
│ │ ├── f0_predictor.py
│ │ ├── flow.py
│ │ ├── flow_matching.py
│ │ ├── hifigan.py
│ │ ├── matcha/
│ │ │ ├── decoder.py
│ │ │ ├── flow_matching.py
│ │ │ ├── text_encoder.py
│ │ │ └── transformer.py
│ │ ├── s3gen.py
│ │ ├── transformer/
│ │ │ ├── __init__.py
│ │ │ ├── activation.py
│ │ │ ├── attention.py
│ │ │ ├── convolution.py
│ │ │ ├── embedding.py
│ │ │ ├── encoder_layer.py
│ │ │ ├── positionwise_feed_forward.py
│ │ │ ├── subsampling.py
│ │ │ └── upsample_encoder.py
│ │ ├── utils/
│ │ │ ├── class_utils.py
│ │ │ ├── intmeanflow.py
│ │ │ ├── mask.py
│ │ │ └── mel.py
│ │ └── xvector.py
│ ├── s3tokenizer/
│ │ ├── __init__.py
│ │ └── s3tokenizer.py
│ ├── t3/
│ │ ├── __init__.py
│ │ ├── inference/
│ │ │ ├── alignment_stream_analyzer.py
│ │ │ └── t3_hf_backend.py
│ │ ├── llama_configs.py
│ │ ├── modules/
│ │ │ ├── cond_enc.py
│ │ │ ├── learned_pos_emb.py
│ │ │ ├── perceiver.py
│ │ │ └── t3_config.py
│ │ └── t3.py
│ ├── tokenizers/
│ │ ├── __init__.py
│ │ └── tokenizer.py
│ ├── utils.py
│ └── voice_encoder/
│ ├── __init__.py
│ ├── config.py
│ ├── melspec.py
│ └── voice_encoder.py
├── mtl_tts.py
├── tts.py
├── tts_turbo.py
└── vc.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/install_check.yml
================================================
# CI smoke test: verifies the package installs in editable mode on every push/PR to master.
name: Test Installation

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      # checkout@v4 / setup-python@v5 run on Node 20; the v3/v4 releases used the deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Test Standard Install
        run: |
          pip install -e .
================================================
FILE: .gitignore
================================================
.vscode
# Pylance
pyrightconfig.json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
syn_out/
checkpoints/
.gradio
# Ignore generated sample .wav files
**/*.wav
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 Resemble AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================

# Chatterbox TTS
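[Listen to the Turbo demo samples](https://resemble-ai.github.io/chatterbox_turbo_demopage/)
[Try Chatterbox-Turbo on Hugging Face](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo)
[Podonos evaluation](https://podonos.com/resembleai/chatterbox)
[Join us on Discord](https://discord.gg/rJq9cRJBJ6)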
*Made with ♥️ by Resemble AI*
**Chatterbox** is a family of three state-of-the-art, open-source text-to-speech models by Resemble AI.
We are excited to introduce **Chatterbox-Turbo**, our most efficient model yet. Built on a streamlined 350M parameter architecture, **Turbo** delivers high-quality speech with less compute and VRAM than our previous models. We have also distilled the speech-token-to-mel decoder, previously a bottleneck, reducing generation from 10 steps to just **one**, while retaining high-fidelity audio output.
**Paralinguistic tags** are now native to the Turbo model, allowing you to use `[cough]`, `[laugh]`, `[chuckle]`, and more to add distinct realism. While Turbo was built primarily for low-latency voice agents, it excels at narration and creative workflows.
If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced [TTS service](https://www.resemble.ai). It delivers reliable performance with ultra-low latency of under 200 ms—ideal for production use in agents, applications, or interactive media.
### ⚡ Model Zoo
Choose the right model for your application.
| Model | Size | Languages | Key Features | Best For | 🤗 | Examples |
|:----------------------------------------------------------------------------------------------------------------| :--- | :--- |:--------------------------------------------------------|:---------------------------------------------|:--------------------------------------------------------------------------| :--- |
| **Chatterbox-Turbo** | **350M** | **English** | Paralinguistic Tags (`[laugh]`), Lower Compute and VRAM | Zero-shot voice agents, Production | [Demo](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo) | [Listen](https://resemble-ai.github.io/chatterbox_turbo_demopage/) |
| Chatterbox-Multilingual [(Language list)](#supported-languages) | 500M | 23+ | Zero-shot cloning, Multiple Languages | Global applications, Localization | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
| Chatterbox [(Tips and Tricks)](#original-chatterbox-tips) | 500M | English | CFG & Exaggeration tuning | General zero-shot TTS with creative controls | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
## Installation
```shell
pip install chatterbox-tts
```
Alternatively, you can install from source:
```shell
# conda create -yn chatterbox python=3.11
# conda activate chatterbox
git clone https://github.com/resemble-ai/chatterbox.git
cd chatterbox
pip install -e .
```
We developed and tested Chatterbox on Python 3.11 on Debian 11 OS; the versions of the dependencies are pinned in `pyproject.toml` to ensure consistency. You can modify the code or dependencies in this installation mode.
## Usage
##### Chatterbox-Turbo
```python
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS
# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")
# Generate with Paralinguistic Tags
text = "Hi there, Sarah here from MochaFone calling you back [chuckle], have you got one minute to chat about the billing issue?"
# Generate audio (requires a reference clip for voice cloning)
wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")
ta.save("test-turbo.wav", wav, model.sr)
```
##### Chatterbox and Chatterbox-Multilingual
```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
# English example
model = ChatterboxTTS.from_pretrained(device="cuda")
text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-english.wav", wav, model.sr)
# Multilingual examples
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")
french_text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
wav_french = multilingual_model.generate(french_text, language_id="fr")
ta.save("test-french.wav", wav_french, multilingual_model.sr)
chinese_text = "你好,今天天气真不错,希望你有一个愉快的周末。"
wav_chinese = multilingual_model.generate(chinese_text, language_id="zh")
ta.save("test-chinese.wav", wav_chinese, multilingual_model.sr)
# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` and `example_vc.py` for more examples.
## Supported Languages
Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)
## Original Chatterbox Tips
- **General Use (TTS and Voice Agents):**
- Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip’s language. To mitigate this, set `cfg_weight` to `0`.
- The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
- If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.
- **Expressive or Dramatic Speech:**
- Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher.
- Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing.
## Built-in PerTh Watermarking for Responsible AI
Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.
## Watermark extraction
You can look for the watermark using the following script.
```python
import perth
import librosa
AUDIO_PATH = "YOUR_FILE.wav"
# Load the watermarked audio
watermarked_audio, sr = librosa.load(AUDIO_PATH, sr=None)
# Initialize watermarker (same as used for embedding)
watermarker = perth.PerthImplicitWatermarker()
# Extract watermark
watermark = watermarker.get_watermark(watermarked_audio, sample_rate=sr)
print(f"Extracted watermark: {watermark}")
# Output: 0.0 (no watermark) or 1.0 (watermarked)
```
## Official Discord
👋 Join us on [Discord](https://discord.gg/rJq9cRJBJ6) and let's build something awesome together!
## Evaluation
Chatterbox Turbo was evaluated using Podonos, a platform for reproducible subjective speech evaluation.
We compared Chatterbox Turbo to competitive TTS systems using Podonos' standardized evaluation suite, focusing on overall preference, naturalness, and expressiveness.
Evaluation reports:
- [Chatterbox Turbo vs ElevenLabs Turbo v2.5](https://podonos.com/resembleai/chatterbox-turbo-vs-elevenlabs-turbo)
- [Chatterbox Turbo vs Cartesia Sonic 3](https://podonos.com/resembleai/chatterbox-turbo-vs-cartesia-sonic3)
- [Chatterbox Turbo vs VibeVoice 7B](https://podonos.com/resembleai/chatterbox-turbo-vs-vibevoice7b)
These evaluations were conducted under identical conditions and are publicly accessible via Podonos.
## Acknowledgements
- [Podonos](https://podonos.com) — for supporting reproducible subjective speech evaluation
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
- [Llama 3](https://github.com/meta-llama/llama3)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)
## Citation
If you find this model useful, please consider citing.
```
@misc{chatterboxtts2025,
author = {{Resemble AI}},
title = {{Chatterbox-TTS}},
year = {2025},
howpublished = {\url{https://github.com/resemble-ai/chatterbox}},
note = {GitHub repository}
}
```
## Disclaimer
Don't use this model to do bad things. Prompts are sourced from freely available data on the internet.
================================================
FILE: example_for_mac.py
================================================
from pathlib import Path

import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
# Detect device (Mac with M1/M2/M3/M4)
device = "mps" if torch.backends.mps.is_available() else "cpu"
map_location = torch.device(device)

# Wrap torch.load so that any call which does not choose its own map_location
# deserializes onto the local device (presumably so checkpoints saved on other
# devices load on a Mac — confirm against the checkpoints in use).
torch_load_original = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load``, defaulting ``map_location`` to the local device.

    ``map_location`` is ``torch.load``'s second positional parameter. It is only
    injected when the caller supplied it neither positionally nor by keyword.
    Injecting it on top of a positional value would raise
    ``TypeError: got multiple values for argument 'map_location'``.

    Returns:
        Whatever the original ``torch.load`` returns.
    """
    if len(args) < 2 and "map_location" not in kwargs:
        kwargs["map_location"] = map_location
    return torch_load_original(*args, **kwargs)


torch.load = patched_torch_load
model = ChatterboxTTS.from_pretrained(device=device)
text = "Today is the day. I want to move like a titan at dawn, sweat like a god forging lightning. No more excuses. From now on, my mornings will be temples of discipline. I am going to work out like the gods… every damn day."
# If you want to synthesize with a different voice, specify the audio prompt.
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
if Path(AUDIO_PROMPT_PATH).exists():
    audio_prompt_path = AUDIO_PROMPT_PATH
else:
    # Don't crash on the placeholder path; generate with the model's default voice instead.
    print(f"Warning: audio prompt file '{AUDIO_PROMPT_PATH}' not found, using the default voice.")
    audio_prompt_path = None
wav = model.generate(
    text,
    audio_prompt_path=audio_prompt_path,
    # NOTE: 2.0 is a very strong exaggeration (0.5 is neutral); lower it if output becomes unstable.
    exaggeration=2.0,
    cfg_weight=0.5
)
ta.save("test-2.wav", wav, model.sr)
================================================
FILE: example_tts.py
================================================
import torchaudio as ta
import torch
from pathlib import Path
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
# Automatically detect the best available device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# English model with its built-in default voice.
model = ChatterboxTTS.from_pretrained(device=device)
english_text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(english_text)
ta.save("test-1.wav", wav, model.sr)

# Multilingual model, French output.
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)
french_text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
wav = multilingual_model.generate(french_text, language_id="fr")
ta.save("test-2.wav", wav, multilingual_model.sr)

# If you want to synthesize with a different voice, specify the audio prompt.
# The English model gets the English text; reusing a single `text` variable previously
# fed it the French sentence from the multilingual example above.
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
if Path(AUDIO_PROMPT_PATH).exists():
    wav = model.generate(english_text, audio_prompt_path=AUDIO_PROMPT_PATH)
    ta.save("test-3.wav", wav, model.sr)
else:
    print(f"Warning: audio prompt file '{AUDIO_PROMPT_PATH}' not found, skipping voice cloning example.")
================================================
FILE: example_tts_turbo.py
================================================
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS
# Use the GPU when present, otherwise fall back to CPU (same policy as gradio_tts_turbo_app.py).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device=device)

# Generate with Paralinguistic Tags
text = "Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?"

# This call passes no reference clip. To clone a voice instead, pass
# audio_prompt_path="your_10s_ref_clip.wav" to model.generate().
wav = model.generate(text)
ta.save("test-turbo.wav", wav, model.sr)
================================================
FILE: example_vc.py
================================================
import torch
import torchaudio as ta
from chatterbox.vc import ChatterboxVC
# Prefer CUDA, then Apple's MPS backend, and finally plain CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Placeholders: the clip to convert, and the clip whose voice it should take on.
AUDIO_PATH = "YOUR_FILE.wav"
TARGET_VOICE_PATH = "YOUR_FILE.wav"

model = ChatterboxVC.from_pretrained(device)
wav = model.generate(audio=AUDIO_PATH, target_voice_path=TARGET_VOICE_PATH)
ta.save("testvc.wav", wav, model.sr)
================================================
FILE: gradio_tts_app.py
================================================
import random
import numpy as np
import torch
import gradio as gr
from chatterbox.tts import ChatterboxTTS
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and torch so a given seed reproduces a run.

    torch seeding covers the CPU generator plus the current and all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)
def load_model():
    """Build a ChatterboxTTS instance on DEVICE (used by ``demo.load`` to fill the session state)."""
    return ChatterboxTTS.from_pretrained(DEVICE)
def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, min_p, top_p, repetition_penalty):
    """Synthesize *text* for the Gradio UI.

    Args:
        model: the session's ChatterboxTTS, or None (a fresh one is then loaded on DEVICE).
        text: text to speak.
        audio_prompt_path: reference clip filepath, or None.
        exaggeration, temperature, cfgw, min_p, top_p, repetition_penalty: sampling
            controls forwarded to ``model.generate`` (``cfgw`` becomes ``cfg_weight``).
        seed_num: RNG seed; 0 leaves the generators unseeded.

    Returns:
        ``(sample_rate, waveform)`` where waveform is ``wav.squeeze(0).numpy()``.
    """
    tts = model if model is not None else ChatterboxTTS.from_pretrained(DEVICE)
    if seed_num != 0:
        set_seed(int(seed_num))
    sampling_kwargs = {
        "audio_prompt_path": audio_prompt_path,
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfgw,
        "min_p": min_p,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
    wav = tts.generate(text, **sampling_kwargs)
    return tts.sr, wav.squeeze(0).numpy()
# --- UI layout ---
# Left column: input text, reference clip and sampling controls. Right column: generated audio.
with gr.Blocks() as demo:
    # Holds this session's model; filled by demo.load(load_model) below.
    model_state = gr.State(None)  # Loaded once per session/user
    with gr.Row():
        with gr.Column():
            # NOTE(review): the 300-character limit in the label is not enforced anywhere in this file.
            text = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            # type="filepath" hands generate() a path (or None when empty).
            ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value=None)
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5)
            cfg_weight = gr.Slider(0.0, 1, step=.05, label="CFG/Pace", value=0.5)
            with gr.Accordion("More options", open=False):
                # generate() only seeds the RNGs when this is non-zero.
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
                min_p = gr.Slider(0.00, 1.00, step=0.01, label="min_p || Newer Sampler. Recommend 0.02 > 0.1. Handles Higher Temperatures better. 0.00 Disables", value=0.05)
                top_p = gr.Slider(0.00, 1.00, step=0.01, label="top_p || Original Sampler. 1.0 Disables(recommended). Original 0.8", value=1.00)
                repetition_penalty = gr.Slider(1.00, 2.00, step=0.1, label="repetition_penalty", value=1.2)
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")
    # Load the model into the session state when the page opens.
    demo.load(fn=load_model, inputs=[], outputs=model_state)
    # The order of `inputs` must match generate()'s positional parameters.
    run_btn.click(
        fn=generate,
        inputs=[
            model_state,
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            min_p,
            top_p,
            repetition_penalty,
        ],
        outputs=audio_output,
    )
if __name__ == "__main__":
    # Queue up to 50 pending requests; each event handler runs one request at a time.
    # share=True also requests a public Gradio share link.
    demo.queue(
        max_size=50,
        default_concurrency_limit=1,
    ).launch(share=True)
================================================
FILE: gradio_tts_turbo_app.py
================================================
import random
import numpy as np
import torch
import gradio as gr
from chatterbox.tts_turbo import ChatterboxTurboTTS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Paralinguistic tags offered as one-click buttons; a click inserts the tag text into the textbox.
EVENT_TAGS = [
    "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
    "[sniff]", "[gasp]", "[chuckle]", "[laugh]"
]
# --- Tag button CSS ---
# 1. tag-container: lets the button row wrap instead of scrolling, and strips its border/background.
# 2. tag-btn: indigo styling for each tag button; the width fits its label instead of stretching.
CUSTOM_CSS = """
.tag-container {
display: flex !important;
flex-wrap: wrap !important; /* This fixes the one-per-line issue */
gap: 8px !important;
margin-top: 5px !important;
margin-bottom: 10px !important;
border: none !important;
background: transparent !important;
}
.tag-btn {
min-width: fit-content !important;
width: auto !important;
height: 32px !important;
font-size: 13px !important;
background: #eef2ff !important;
border: 1px solid #c7d2fe !important;
color: #3730a3 !important;
border-radius: 6px !important;
padding: 0 10px !important;
margin: 0 !important;
box-shadow: none !important;
}
.tag-btn:hover {
background: #c7d2fe !important;
transform: translateY(-1px);
}
"""
# Browser-side handler for the tag buttons (passed as Gradio's `js=`).
# Takes (tag, current text) and returns the new text. The tag is inserted at the cursor of the
# '#main_textbox' textarea, replacing any selected text. A leading space is added unless the
# insertion point is at the start or follows a space. A trailing space is added unless a space
# already follows. If the textarea cannot be found, the tag is appended to the end after a space.
INSERT_TAG_JS = """
(tag_val, current_text) => {
const textarea = document.querySelector('#main_textbox textarea');
if (!textarea) return current_text + " " + tag_val;
const start = textarea.selectionStart;
const end = textarea.selectionEnd;
let prefix = " ";
let suffix = " ";
if (start === 0) prefix = "";
else if (current_text[start - 1] === ' ') prefix = "";
if (end < current_text.length && current_text[end] === ' ') suffix = "";
return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
}
"""
def set_seed(seed: int):
    """Make sampling reproducible by seeding every RNG in play.

    Seeds torch (CPU generator, current CUDA device and all CUDA devices),
    Python's ``random`` module and NumPy's global generator, in that order.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
        np.random.seed,
    )
    for seeder in seeders:
        seeder(seed)
def load_model():
    """Load Chatterbox-Turbo onto DEVICE, announcing the target device first."""
    banner = f"Loading Chatterbox-Turbo on {DEVICE}..."
    print(banner)
    return ChatterboxTurboTTS.from_pretrained(DEVICE)
def generate(
    model,
    text,
    audio_prompt_path,
    temperature,
    seed_num,
    min_p,
    top_p,
    top_k,
    repetition_penalty,
    norm_loudness
):
    """Synthesize *text* with Chatterbox-Turbo for the Gradio UI.

    Args:
        model: the session's ChatterboxTurboTTS, or None (a fresh one is then loaded on DEVICE).
        text: text to speak; may contain paralinguistic tags.
        audio_prompt_path: reference clip filepath or URL, or None.
        temperature, min_p, top_p, repetition_penalty, norm_loudness: forwarded unchanged
            to ``model.generate``.
        seed_num: RNG seed; 0 leaves the generators unseeded.
        top_k: slider value, converted to ``int`` before forwarding.

    Returns:
        ``(sample_rate, waveform)`` where waveform is ``wav.squeeze(0).numpy()``.
    """
    tts = ChatterboxTurboTTS.from_pretrained(DEVICE) if model is None else model
    if seed_num != 0:
        set_seed(int(seed_num))
    options = {
        "audio_prompt_path": audio_prompt_path,
        "temperature": temperature,
        "min_p": min_p,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": repetition_penalty,
        "norm_loudness": norm_loudness,
    }
    wav = tts.generate(text, **options)
    return tts.sr, wav.squeeze(0).numpy()
# --- UI layout ---
with gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo:
    gr.Markdown("# ⚡ Chatterbox Turbo")
    # Holds this session's model; filled by demo.load(load_model) below.
    model_state = gr.State(None)
    with gr.Row():
        # Input column: text, tag buttons, reference clip and the Generate button.
        with gr.Column():
            # NOTE(review): the 300-character limit in the label is not enforced anywhere in this file.
            text = gr.Textbox(
                value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and um all that jazz. Would you like me to get some prices for you?",
                label="Text to synthesize (max chars 300)",
                max_lines=5,
                elem_id="main_textbox"  # INSERT_TAG_JS locates the textarea via '#main_textbox textarea'
            )
            # --- Event Tags ---
            # One button per tag. The tag-container CSS class makes the row wrap.
            with gr.Row(elem_classes=["tag-container"]):
                for tag in EVENT_TAGS:
                    # The tag-btn class applies the per-button styling.
                    btn = gr.Button(tag, elem_classes=["tag-btn"])
                    # fn=None: no Python handler, only INSERT_TAG_JS runs in the browser.
                    # It receives the button's value (the tag) and the current text, and
                    # its return value replaces the textbox contents.
                    btn.click(
                        fn=None,
                        inputs=[btn, text],
                        outputs=text,
                        js=INSERT_TAG_JS
                    )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav"
            )
            run_btn = gr.Button("Generate ⚡", variant="primary")
        # Output column: generated audio plus the sampling controls.
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")
            with gr.Accordion("Advanced Options", open=False):
                # generate() only seeds the RNGs when this is non-zero.
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
                top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
                # generate() casts this slider value to int before passing it on.
                top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
                repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
                min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
                norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
    # Load the model into the session state when the page opens.
    demo.load(fn=load_model, inputs=[], outputs=model_state)
    # The order of `inputs` must match generate()'s positional parameters.
    run_btn.click(
        fn=generate,
        inputs=[
            model_state,
            text,
            ref_wav,
            temp,
            seed_num,
            min_p,
            top_p,
            top_k,
            repetition_penalty,
            norm_loudness,
        ],
        outputs=audio_output,
    )
if __name__ == "__main__":
    # Queue up to 50 pending requests; each event handler runs one request at a time.
    # share=True also requests a public Gradio share link.
    demo.queue(
        max_size=50,
        default_concurrency_limit=1,
    ).launch(share=True)
================================================
FILE: gradio_vc_app.py
================================================
import torch
import gradio as gr
from chatterbox.vc import ChatterboxVC
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time and shared by every request.
model = ChatterboxVC.from_pretrained(DEVICE)


def generate(audio, target_voice_path):
    """Convert the source recording into the target voice.

    Args:
        audio: filepath of the recording to convert. Gradio passes None when the
            user supplied nothing.
        target_voice_path: filepath of the target voice clip, or None to use the
            default voice (per the UI label).

    Returns:
        ``(sample_rate, waveform)`` where waveform is ``wav.squeeze(0).numpy()``.

    Raises:
        gr.Error: if no input audio was provided. Shown to the user instead of
            failing deep inside the model.
    """
    if not audio:
        raise gr.Error("Please provide an input audio file.")
    wav = model.generate(
        audio, target_voice_path=target_voice_path,
    )
    return model.sr, wav.squeeze(0).numpy()
# Source recording to convert, and (optionally) the clip whose voice it should adopt.
input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input audio file")
target_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Target voice audio file (if none, the default voice is used)", value=None)

demo = gr.Interface(
    fn=generate,
    inputs=[input_audio, target_audio],
    outputs="audio",
)

if __name__ == "__main__":
    demo.launch()
================================================
FILE: multilingual_app.py
================================================
import random
import numpy as np
import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
import gradio as gr
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Module-wide model slot; stays None until get_or_load_model() fills it.
MODEL = None
# Per-language demo defaults, keyed by language code:
#   "audio": URL of a sample reference clip in that language
#   "text":  the same sample sentence translated into that language
LANGUAGE_CONFIG = {
    "ar": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
        "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
    },
    "da": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac",
        "text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal."
    },
    "de": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/de_f1.flac",
        "text": "Letzten Monat haben wir einen neuen Meilenstein erreicht: zwei Milliarden Aufrufe auf unserem YouTube-Kanal."
    },
    "el": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/el_m.flac",
        "text": "Τον περασμένο μήνα, φτάσαμε σε ένα νέο ορόσημο με δύο δισεκατομμύρια προβολές στο κανάλι μας στο YouTube."
    },
    "en": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
        "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
    },
    "es": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/es_f1.flac",
        "text": "El mes pasado alcanzamos un nuevo hito: dos mil millones de visualizaciones en nuestro canal de YouTube."
    },
    "fi": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fi_m.flac",
        "text": "Viime kuussa saavutimme uuden virstanpylvään kahden miljardin katselukerran kanssa YouTube-kanavallamme."
    },
    "fr": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
        "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."
    },
    "he": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac",
        "text": "בחודש שעבר הגענו לאבן דרך חדשה עם שני מיליארד צפיות בערוץ היוטיוב שלנו."
    },
    "hi": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
        "text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।"
    },
    "it": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/it_m1.flac",
        "text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
    },
    "ja": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja/ja_prompts1.flac",
        "text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
    },
    "ko": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ko_f.flac",
        "text": "지난달 우리는 유튜브 채널에서 이십억 조회수라는 새로운 이정표에 도달했습니다."
    },
    "ms": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ms_f.flac",
        "text": "Bulan lepas, kami mencapai pencapaian baru dengan dua bilion tontonan di saluran YouTube kami."
    },
    "nl": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/nl_m.flac",
        "text": "Vorige maand bereikten we een nieuwe mijlpaal met twee miljard weergaven op ons YouTube-kanaal."
    },
    "no": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/no_f1.flac",
        "text": "Forrige måned nådde vi en ny milepæl med to milliarder visninger på YouTube-kanalen vår."
    },
    "pl": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pl_m.flac",
        "text": "W zeszłym miesiącu osiągnęliśmy nowy kamień milowy z dwoma miliardami wyświetleń na naszym kanale YouTube."
    },
    "pt": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pt_m1.flac",
        "text": "No mês passado, alcançámos um novo marco: dois mil milhões de visualizações no nosso canal do YouTube."
    },
    "ru": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
        "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
    },
    "sv": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sv_f.flac",
        "text": "Förra månaden nådde vi en ny milstolpe med två miljarder visningar på vår YouTube-kanal."
    },
    "sw": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sw_m.flac",
        # "kwenye" (on/in); previously misspelled "kweny".
        "text": "Mwezi uliopita, tulifika hatua mpya ya maoni ya bilioni mbili kwenye kituo chetu cha YouTube."
    },
    "tr": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
        "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
    },
    "zh": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
        "text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
    },
}
# --- UI Helpers ---
def default_audio_for_ui(lang: str) -> str | None:
    """Return the sample reference-audio URL configured for *lang*, or None if there is none."""
    lang_entry = LANGUAGE_CONFIG.get(lang, {})
    return lang_entry.get("audio")


def default_text_for_ui(lang: str) -> str:
    """Return the sample sentence configured for *lang*, or an empty string if there is none."""
    lang_entry = LANGUAGE_CONFIG.get(lang, {})
    return lang_entry.get("text", "")
def get_supported_languages_display() -> str:
"""Generate a formatted display of all supported languages."""
language_items = []
for code, name in sorted(SUPPORTED_LANGUAGES.items()):
language_items.append(f"**{name}** (`{code}`)")
# Split into 2 lines
mid = len(language_items) // 2
line1 = " • ".join(language_items[:mid])
line2 = " • ".join(language_items[mid:])
return f"""
### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
{line1}
{line2}
"""
def get_or_load_model():
    """Return the shared ChatterboxMultilingualTTS model, loading it on first use.

    After loading, the model is moved to DEVICE if it exposes ``to`` and reports a
    different device. Load errors are printed and then re-raised.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
        on_wrong_device = hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE
        if on_wrong_device:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL
# Warm the model at import time. A failure is reported but not raised, so the app can
# still start. If from_pretrained itself failed, MODEL is still None and
# get_or_load_model() tries again on the next request.
try:
    get_or_load_model()
except Exception as startup_error:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {startup_error}")
def set_seed(seed: int):
    """Seed torch, NumPy and ``random`` for reproducible generation.

    The CUDA generators are only seeded when the app runs with ``DEVICE == "cuda"``.
    """
    torch.manual_seed(seed)
    if DEVICE != "cuda":
        cuda_seeders = ()
    else:
        cuda_seeders = (torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for cuda_seeder in cuda_seeders:
        cuda_seeder(seed)
    random.seed(seed)
    np.random.seed(seed)
def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
"""
Decide which audio prompt to use:
- If user provided a path (upload/mic/url), use it.
- Else, fall back to language-specific default (if any).
"""
if provided_path and str(provided_path).strip():
return provided_path
return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
def generate_tts_audio(
    text_input: str,
    language_id: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5
) -> tuple[int, np.ndarray]:
    """
    Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling.
    Supported languages: 23 languages, including English, French, German, Spanish, Italian, Portuguese and Hindi (use the language code, e.g. en, fr, de, es, it, pt, hi).
    This tool synthesizes natural-sounding speech from input text. When a reference audio file
    is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
    maintains the prosody, tone, and vocal qualities of the reference speaker, or uses the language's default voice if no reference is provided.
    Args:
        text_input (str): The text to synthesize into speech (maximum 300 characters; longer input is truncated)
        language_id (str): The language code for synthesis (eg. en, fr, de, es, it, pt, hi)
        audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Empty or whitespace-only values fall back to the language default. Defaults to None.
        exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
        temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
        seed_num_input (int, optional): Random seed for reproducible results (0 or empty for random generation). Defaults to 0.
        cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer.
    Returns:
        tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
    Raises:
        RuntimeError: if the TTS model could not be loaded.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")
    # The seed may arrive as None (e.g. a cleared number field); both None and 0
    # mean "do not reseed" instead of crashing in int(None).
    if seed_num_input:
        set_seed(int(seed_num_input))
    print(f"Generating audio for text: '{text_input[:50]}...'")
    # Handle optional audio prompt: reuse the shared resolver so blank paths fall
    # back to the language default exactly like everywhere else in the app.
    chosen_prompt = resolve_audio_prompt(language_id, audio_prompt_path_input)
    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt
        print(f"Using audio prompt: {chosen_prompt}")
    else:
        print("No audio prompt provided; using default voice.")
    wav = current_model.generate(
        text_input[:300],  # Truncate text to max chars
        language_id=language_id,
        **generate_kwargs
    )
    print("Audio generation complete.")
    return (current_model.sr, wav.squeeze(0).numpy())
# Gradio UI: the left column holds the inputs, the right column shows the output.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Chatterbox Multilingual Demo
Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
"""
    )
    # Display supported languages
    gr.Markdown(get_supported_languages_display())
    with gr.Row():
        with gr.Column():
            # Language pre-selected on page load; text and reference audio are seeded from it.
            initial_lang = "fr"
            text = gr.Textbox(
                value=default_text_for_ui(initial_lang),
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            language_id = gr.Dropdown(
                choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
                value=initial_lang,
                label="Language",
                info="Select the language for text-to-speech synthesis"
            )
            # type="filepath" makes the handler receive a path string (or None).
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value=default_audio_for_ui(initial_lang)
            )
            gr.Markdown(
                "💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
                elem_classes=["audio-note"]
            )
            exaggeration = gr.Slider(
                0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
            )
            cfg_weight = gr.Slider(
                0.2, 1, step=.05, label="CFG/Pace", value=0.5
            )
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    def on_language_change(lang, current_ref, current_text):
        # current_ref / current_text are ignored: switching language always resets
        # the reference audio and the text box to that language's defaults.
        return default_audio_for_ui(lang), default_text_for_ui(lang)

    language_id.change(
        fn=on_language_change,
        inputs=[language_id, ref_wav, text],
        outputs=[ref_wav, text],
        show_progress=False
    )
    # Input order must match generate_tts_audio's positional parameters.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            language_id,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[audio_output],
    )
# NOTE(review): mcp_server=True presumably also exposes the event handlers (e.g.
# generate_tts_audio, described by its docstring) as MCP tools — see Gradio docs.
demo.launch(mcp_server=True)
================================================
FILE: pyproject.toml
================================================
[project]
name = "chatterbox-tts"
version = "0.1.6"
description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
readme = "README.md"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "resemble-ai", email = "engineering@resemble.ai"}
]
dependencies = [
"numpy>=1.24.0,<1.26.0",
"librosa==0.11.0",
"s3tokenizer",
"torch==2.6.0",
"torchaudio==2.6.0",
"transformers==5.2.0",
"diffusers==0.29.0",
"resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master",
"conformer==0.3.2",
"safetensors==0.5.3",
"spacy-pkuseg",
"pykakasi==2.3.0",
"gradio==6.8.0",
"pyloudnorm",
"omegaconf"
]
[project.urls]
Homepage = "https://github.com/resemble-ai/chatterbox"
Repository = "https://github.com/resemble-ai/chatterbox"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
================================================
FILE: src/chatterbox/__init__.py
================================================
try:
from importlib.metadata import version
except ImportError:
from importlib_metadata import version # For Python <3.8
__version__ = version("chatterbox-tts")
from .tts import ChatterboxTTS
from .vc import ChatterboxVC
from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
================================================
FILE: src/chatterbox/models/__init__.py
================================================
================================================
FILE: src/chatterbox/models/s3gen/__init__.py
================================================
from .s3gen import S3Token2Wav as S3Gen
from .const import S3GEN_SR
================================================
FILE: src/chatterbox/models/s3gen/configs.py
================================================
from ..utils import AttrDict
# Default conditional flow-matching (CFM) hyper-parameters; used as the default
# `cfm_params` of CausalConditionalCFM.
#   t_scheduler        -- "cosine" warps the Euler time grid as 1 - cos(t * pi / 2)
#   training_cfg_rate  -- probability of dropping the conditioning during training
#   inference_cfg_rate -- classifier-free guidance weight used by the Euler solver
# NOTE(review): sigma_min / solver / reg_loss_type are presumably consumed by the
# BASECFM base class or training code — TODO confirm.
CFM_PARAMS = AttrDict({
    "sigma_min": 1e-06,
    "solver": "euler",
    "t_scheduler": "cosine",
    "training_cfg_rate": 0.2,
    "inference_cfg_rate": 0.7,
    "reg_loss_type": "l1"
})
================================================
FILE: src/chatterbox/models/s3gen/const.py
================================================
# Sample rate (Hz) associated with S3Gen — presumably its output audio rate; TODO confirm.
S3GEN_SR = 24_000
# NOTE(review): a special speech-token id, presumably "silence" — confirm against s3gen.py.
S3GEN_SIL = 4_299
================================================
FILE: src/chatterbox/models/s3gen/decoder.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from .utils.mask import add_optional_chunk_mask
from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
TimestepEmbedding, Upsample1D
from .matcha.transformer import BasicTransformerBlock
from .utils.intmeanflow import get_intmeanflow_time_mixer
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean keep-mask into an additive attention bias.

    True (keep) positions become 0 and False (masked) positions become ``-1e10``,
    in the requested ``dtype``, so adding the result to attention scores suppresses
    the masked positions.

    Args:
        mask: boolean tensor (asserted).
        dtype: one of float32, bfloat16 or float16 (asserted).
    """
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    # NOTE(Mddct): torch.finfo jit issues -> a fixed large negative constant is used
    # instead of torch.finfo(dtype).min.
    # NOTE(review): -1e10 is outside the float16 range (max ~65504), so in float16
    # the masked entries overflow to -inf rather than a large finite value.
    keep = mask.to(dtype)
    return (1.0 - keep) * -1.0e+10
class Transpose(torch.nn.Module):
    """Module form of ``transpose(dim0, dim1)`` so a dimension swap can live inside ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor):
        return x.transpose(self.dim0, self.dim1)
class CausalBlock1D(Block1D):
    """Block1D variant whose conv is causal: CausalConv1d(k=3) -> LayerNorm over channels -> Mish.

    The input is multiplied by ``mask`` before the block and the output again after it.
    """

    def __init__(self, dim: int, dim_out: int):
        super().__init__(dim, dim_out)
        # LayerNorm normalises the last axis, so channels are moved there and back.
        layers = [
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        return self.block(x * mask) * mask
class CausalResnetBlock1D(ResnetBlock1D):
    """ResnetBlock1D whose two inner blocks are replaced by causal CausalBlock1D blocks."""

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super().__init__(dim, dim_out, time_emb_dim, groups)
        # Same attribute names as the parent, so parameter/state_dict keys are unchanged.
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)
class CausalConv1d(torch.nn.Conv1d):
    """1D convolution that only sees the current and past frames.

    The input is left-padded with zeros by ``dilation * (kernel_size - 1)`` frames
    (the span of the dilated kernel minus one), so the output has the same length as
    the input and frame ``t`` never depends on frames after ``t``. Only ``stride == 1``
    is supported.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        # A dilated kernel covers dilation * (kernel_size - 1) + 1 frames; padding by
        # one less than that on the left keeps the length (previously the dilation was
        # ignored, which shortened the output whenever dilation > 1).
        self.causal_padding = (dilation * (kernel_size - 1), 0)

    def forward(self, x: torch.Tensor):
        # F.pad zero-pads regardless of padding_mode; padding is applied only on the left.
        x = F.pad(x, self.causal_padding)
        return super(CausalConv1d, self).forward(x)
class ConditionalDecoder(nn.Module):
    """1D U-Net-style conditional decoder.

    It has down, mid and up stages; each stage is a (causal) resnet block followed by
    ``n_blocks`` transformer blocks, with skip connections from the down path to the
    up path. Its ``forward(x, mask, mu, t, spks, cond, r)`` signature matches how the
    flow-matching solvers call their ``estimator`` — presumably it is used as that
    velocity estimator; TODO confirm against the model builder.
    """
    def __init__(
        self,
        in_channels=320,
        out_channels=80,
        causal=True,
        channels=[256],
        dropout=0.0,
        attention_head_dim=64,
        n_blocks=4,
        num_mid_blocks=12,
        num_heads=8,
        act_fn="gelu",
        meanflow=False,
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.

        Args:
            in_channels: channel count after ``forward`` packs x, mu, spks (repeated over
                time) and cond along the channel axis; must equal their total.
            out_channels: channels produced by ``final_proj``.
            causal: use CausalResnetBlock1D / CausalConv1d / CausalBlock1D instead of
                the non-causal blocks.
            channels: per-stage widths of the down path (converted to a tuple, so the
                list default is never mutated).
            dropout, attention_head_dim, num_heads, act_fn: BasicTransformerBlock settings.
            n_blocks: transformer blocks per stage.
            num_mid_blocks: number of mid (bottleneck) stages.
            meanflow: when True, ``forward`` also embeds an end time ``r`` and mixes it
                with ``t`` through ``time_embed_mixer``.
        """
        super().__init__()
        channels = tuple(channels)
        self.meanflow = meanflow
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.causal = causal
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        # NOTE jrm: `static_chunk_size` is missing?
        self.static_chunk_size = 0
        # Down path: stage i maps the previous width to channels[i].
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # The last down stage uses a length-preserving 3-tap conv instead of
            # Downsample1D (forward drops the mask it appends for that stage).
            downsample = (
                Downsample1D(output_channel) if not is_last else
                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
        # Mid path: num_mid_blocks stages at the bottleneck width channels[-1].
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this rebinds the local `out_channels` (the constructor arg,
            # already stored on self), not `output_channel`; it is never read again.
            # It is harmless because `output_channel` already equals channels[-1]
            # after the down loop.
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
        # Up path mirrors the down path and ends at channels[0].
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            # Doubled input width: forward concatenates the skip connection on channels.
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            ) if self.causal else ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()
        # Created after initialize_weights(), so the mixer keeps its own default init.
        self.time_embed_mixer = None
        if self.meanflow:
            self.time_embed_mixer = get_intmeanflow_time_mixer(time_embed_dim)

    @property
    def dtype(self):
        # Parameter dtype, read from final_proj's weight.
        return self.final_proj.weight.dtype

    def initialize_weights(self):
        """Kaiming-normal init for Conv1d/Linear weights (zero bias); GroupNorm to weight 1 / bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None, r=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x: (B, 80, T)
            mask: padding mask; indexed as (B, 1, T) here (the CFM solvers pass that shape)
            mu: encoder output, packed onto x along the channel axis
            t (_type_): shape (batch_size)
            spks (_type_, optional): (B, C_spk); repeated over time and packed onto x. Defaults to None.
            cond (_type_, optional): packed onto x along the channel axis
            r: end time for meanflow mode (shape (1,) tensor)

        Returns:
            Tensor with ``out_channels`` channels and the same length as ``mask``,
            multiplied by ``mask``.
        """
        # Time conditioning; in meanflow mode the start (t) and end (r) embeddings are mixed.
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        if self.meanflow:
            r = self.time_embeddings(r).to(t.dtype)
            r = self.time_mlp(r)
            concat_embed = torch.cat([t, r], dim=1)
            t = self.time_embed_mixer(concat_embed)
        # Stack all conditioning along the channel axis: [x | mu | spks | cond].
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            # static_chunk_size is 0 here; presumably this yields full attention over
            # valid frames — see add_optional_chunk_mask.
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        # The last stage did not downsample, so its halved mask is discarded.
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Trim to the skip's length before concatenating on channels.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        # After the loop mask_up is masks[0], i.e. the original full-resolution mask.
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
================================================
FILE: src/chatterbox/models/s3gen/f0_predictor.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
class ConvRNNF0Predictor(nn.Module):
def __init__(self,
num_class: int = 1,
in_channels: int = 80,
cond_channels: int = 512
):
super().__init__()
self.num_class = num_class
self.condnet = nn.Sequential(
weight_norm(
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.condnet(x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))
================================================
FILE: src/chatterbox/models/s3gen/flow.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
from torch.nn import functional as F
from .utils.mask import make_pad_mask
from .configs import CFM_PARAMS
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
def _repeat_batch_dim(tnsr, B, ndim):
    """Give *tnsr* exactly *ndim* dims and a batch size of *B* when its batch is 1.

    Leading singleton dims are added until ``tnsr.ndim == ndim``. A batch of size 1
    is then tiled (``repeat``, i.e. copied) to *B* rows when ``B > 1``. None is
    passed through unchanged.
    """
    if tnsr is None:
        return tnsr
    # add missing batch dim if needed (no-op when already deep enough)
    for _ in range(ndim - tnsr.ndim):
        tnsr = tnsr.unsqueeze(0)
    # repeat batch dim as needed
    if B > 1 and tnsr.size(0) == 1:
        tnsr = tnsr.repeat(B, *((1,) * (ndim - 1)))
    assert tnsr.ndim == ndim, f"Expected {ndim=}, got {tnsr.ndim=}"
    return tnsr
class CausalMaskedDiffWithXvec(torch.nn.Module):
    """Speech-token -> mel flow model conditioned on a speaker x-vector embedding.

    Prompt tokens are prepended to the target tokens, embedded, encoded by
    ``encoder``, projected to ``output_size`` channels and decoded into mel frames by
    ``decoder``, with the prompt's mel frames supplied as conditioning.
    """
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 6561,
                 input_frame_rate: int = 25,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Optional[Dict] = None,
                 mel_feat_conf: Optional[Dict] = None):
        """
        Args:
            input_size: token embedding size.
            output_size: channels after ``encoder_proj``; also the per-frame feature size
                of the conditioning buffer built in :meth:`inference`.
            spk_embed_dim: size of the incoming speaker embedding.
            output_type: stored only.
            vocab_size: number of token ids accepted by ``input_embedding``.
            input_frame_rate: stored and logged only.
            only_mask_loss: stored only.
            token_mel_ratio: multiplied by ``pre_lookahead_len`` to get the number of
                encoder frames trimmed when ``finalize`` is False.
            pre_lookahead_len: lookahead length (in tokens) trimmed when not finalizing.
            encoder: required despite the None default; must provide ``output_size()``.
            decoder: flow-matching decoder used by :meth:`inference` / :meth:`compute_loss`.
            decoder_conf: stored as ``self.decoder_conf``; None builds the original default.
            mel_feat_conf: stored as ``self.mel_feat_conf``; None builds the original default.
        """
        super().__init__()
        # Defaults are built per instance: dict literals as default arguments would be
        # a single object shared (and mutable) across every instance.
        if decoder_conf is None:
            decoder_conf = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                            'cfm_params': DictConfig(
                                {'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7,
                                 'reg_loss_type': 'l1'}),
                            'decoder_params': {'channels': [256, 256], 'dropout': 0.0,
                                               'attention_head_dim': 64,
                                               'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8,
                                               'act_fn': 'gelu'}}
        if mel_feat_conf is None:
            mel_feat_conf = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                             'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        # Use the module logger rather than the root logger.
        logger.info("input frame rate=%s", self.input_frame_rate)
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    # NOTE: copied in from cosyvoice repo
    def compute_loss(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Training loss.

        Reads ``speech_token``, ``speech_token_len``, ``speech_feat`` (B, 80, T),
        ``speech_feat_len`` and ``embedding`` from *batch*. For each sample, with
        probability 0.5, a random prefix (up to 30% of its length) of the target
        features is given as conditioning.

        Returns:
            ``{'loss': loss}`` as computed by ``self.decoder.compute_loss``.
        """
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)  # (B, 80, T)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        # streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)  # (B, T, 1)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask  # (B, T, emb)

        # text encode
        h, h_lengths = self.encoder(token, token_len)  # (B, T, C) -> (B, 2T, C)
        h = self.encoder_proj(h)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :, :index] = feat[i, :, :index]

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            # streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  finalize,
                  n_timesteps=10,
                  noised_mels=None,
                  meanflow=False):
        """Generate mel features for *token*, conditioned on a prompt and a speaker embedding.

        Prompt tensors and the embedding with batch size 1 are tiled to the batch
        size of *token*. When ``finalize is False`` the last
        ``pre_lookahead_len * token_mel_ratio`` encoder frames are dropped.

        Returns:
            ``(feat, None)`` where ``feat`` holds only the generated (non-prompt) frames,
            laid out as (B, output_size, frames).
        """
        # token: (B, n_toks)
        # token_len: (B,)
        B = token.size(0)

        # xvec projection
        embedding = torch.atleast_2d(embedding)
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)  # (1 or B, emb_dim)

        # adjust shapes (batching logic)
        prompt_token = _repeat_batch_dim(prompt_token, B, ndim=2)  # (B, n_prompt)
        prompt_token_len = _repeat_batch_dim(prompt_token_len, B, ndim=1)  # (B,)
        prompt_feat = _repeat_batch_dim(prompt_feat, B, ndim=3)  # (B, n_feat, feat_dim=80)
        prompt_feat_len = _repeat_batch_dim(prompt_feat_len, B, ndim=1)  # (B,) or None
        embedding = _repeat_batch_dim(embedding, B, ndim=2)  # (B, emb_dim)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)

        if (token >= self.vocab_size).any():
            logger.error(f"{token.max()}>={self.vocab_size}\n out-of-range special tokens found in flow, fix inputs!")
        token = self.input_embedding(token.long()) * mask

        # text encode
        h, h_masks = self.encoder(token, token_len)
        if finalize is False:
            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]

        h_lengths = h_masks.sum(dim=-1).squeeze(dim=-1)
        # mel_len1: prompt frames; mel_len2: frames to generate.
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # # get conditions
        conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths)).unsqueeze(1).to(h)
        if mask.shape[0] != B:
            mask = mask.repeat(B, 1, 1)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask,
            spks=embedding,
            cond=conds,
            n_timesteps=n_timesteps,
            noised_mels=noised_mels,
            meanflow=meanflow,
        )
        # Drop the prompt frames; only the newly generated ones are returned.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat, None  # NOTE jrm: why are they returning None here?
================================================
FILE: src/chatterbox/models/s3gen/flow_matching.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import torch
import torch.nn.functional as F
from .matcha.flow_matching import BASECFM
from .configs import CFM_PARAMS
from tqdm import tqdm
def cast_all(*args, dtype):
    """Cast every floating-point tensor in *args* to *dtype*.

    Non-floating tensors (e.g. integer/bool) and tensors already in *dtype* are
    returned as the same objects; ``None`` entries (optional inputs such as ``spks``
    or ``cond``) are passed through instead of raising ``AttributeError``.

    Returns:
        A list with one entry per argument, in order.
    """
    return [
        a if a is None or (not a.dtype.is_floating_point) or a.dtype == dtype else a.to(dtype)
        for a in args
    ]
class ConditionalCFM(BASECFM):
    """Conditional flow-matching (CFM) model wrapping a velocity ``estimator`` network.

    Inference integrates the estimator's output with a fixed-step Euler solver and
    classifier-free guidance (CFG) weighted by ``inference_cfg_rate``. Training uses
    :meth:`compute_loss`, which drops the conditioning with probability
    ``training_cfg_rate``.
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        """
        Args:
            in_channels: forwarded to the base class as ``n_feats``.
            cfm_params: attribute-style config providing ``t_scheduler``,
                ``training_cfg_rate`` and ``inference_cfg_rate``; the base class may read
                more (e.g. ``sigma_min``) — TODO confirm.
            n_spks: forwarded to the base class.
            spk_emb_dim: forwarded to the base class.
            estimator: network called as ``estimator(x, mask, mu, t, spks, cond, r=...)``
                and exposing a ``dtype`` property.
        """
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate

        # Just change the architecture of the estimator here
        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
        """Forward diffusion (currently disabled).

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Raises:
            NotImplementedError: always; this cached/streaming path has not been updated
                for the meanflow model.
        """
        raise NotImplementedError("unused, needs updating for meanflow model")

        # NOTE: everything below is unreachable and kept only as a reference for
        # re-enabling the cached (streaming) path.
        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
        cache_size = flow_cache.shape[2]
        # fix prompt and overlap part mu and z
        if cache_size != 0:
            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)

        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond, meanflow=False):
        """
        Fixed-step Euler solver for the flow ODE, with classifier-free guidance.

        Each step runs the estimator once on a doubled batch. The first B rows carry the
        conditioning (mu, spks, cond); the last B rows get zeros for it. The two
        predictions are combined as ``(1 + w) * cond_pred - w * uncond_pred`` with
        ``w = self.inference_cfg_rate``.

        Args:
            x (torch.Tensor): random noise, shape (batch_size, n_feats, mel_timesteps)
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor): speaker embedding, shape (batch_size, spk_emb_dim)
            cond (torch.Tensor): conditioning features, shape (batch_size, n_feats, mel_timesteps)
            meanflow: when True, the next time point ``r`` is also passed to the estimator

        Returns:
            The final sample, cast back to the dtype ``x`` had on entry.
        """
        in_dtype = x.dtype
        x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)

        # Duplicated batch dims are for CFG
        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        # Buffer widths come from the inputs themselves rather than a fixed 80 channels.
        B, T = mu.size(0), x.size(2)
        x_in = torch.zeros([2 * B, x.size(1), T], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2 * B, 1, T], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2 * B, mu.size(1), T], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2 * B], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2 * B, spks.size(1)], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2 * B, cond.size(1), T], device=x.device, dtype=x.dtype)
        r_in = torch.zeros([2 * B], device=x.device, dtype=x.dtype)  # (only used for meanflow)

        for t, r in zip(t_span[:-1], t_span[1:]):
            t = t.unsqueeze(dim=0)
            r = r.unsqueeze(dim=0)
            # Shapes (C = feature channels, S = speaker-embedding size):
            #   x_in (2B, C, T), mask_in (2B, 1, T), mu_in (2B, C, T), t_in (2B,),
            #   spks_in (2B, S), cond_in (2B, C, T), r_in (2B,)
            # Rows [:B] are conditional; rows [B:] keep zero mu/spks/cond (unconditional).
            x_in[:B] = x_in[B:] = x
            mask_in[:B] = mask_in[B:] = mask
            mu_in[:B] = mu
            t_in[:B] = t_in[B:] = t
            spks_in[:B] = spks
            cond_in[:B] = cond
            r_in[:B] = r_in[B:] = r  # (only used for meanflow)
            dxdt = self.estimator.forward(
                x=x_in, mask=mask_in, mu=mu_in, t=t_in, spks=spks_in, cond=cond_in,
                r=r_in if meanflow else None,
            )
            dxdt, cfg_dxdt = torch.split(dxdt, [B, B], dim=0)
            dxdt = ((1.0 + self.inference_cfg_rate) * dxdt - self.inference_cfg_rate * cfg_dxdt)
            dt = r - t
            x = x + dt * dxdt

        return x.to(in_dtype)

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): conditioning features. Defaults to None.

        Returns:
            loss: conditional flow matching loss (masked MSE, normalised by the number
                of valid frames times the feature count)
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, _ = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            # spks / cond are optional; only drop them when present.
            if spks is not None:
                spks = spks * cfg_mask.view(-1, 1)
            if cond is not None:
                cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
class CausalConditionalCFM(ConditionalCFM):
    """ConditionalCFM whose forward supports partially-noised ground-truth mels and a
    meanflow fast path (plain Euler, no CFG)."""

    def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        # TODO: BAD BAD IDEA - IT'LL MESS UP DISTILLATION - SETTING TO NONE
        self.rand_noise = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, noised_mels=None, meanflow=False):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: conditioning features forwarded to the solver
            noised_mels: gt mels noised a time t; they overwrite the last
                ``noised_mels.size(2)`` frames of the initial noise
            meanflow: use the plain (no-CFG) Euler solver with linear time steps

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
            None: placeholder kept for interface compatibility
        """
        # Scale the initial noise by `temperature` as documented (the parent's reference
        # path does the same); the default 1.0 leaves it unchanged.
        z = torch.randn_like(mu) * temperature
        if noised_mels is not None:
            prompt_len = mu.size(2) - noised_mels.size(2)
            z[..., prompt_len:] = noised_mels

        # time steps for reverse diffusion
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if (not meanflow) and (self.t_scheduler == 'cosine'):
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)

        # NOTE: right now, the only meanflow models are also distilled models, which don't need CFG
        # because they were distilled with CFG outputs. We would need to add another hparam and
        # change the conditional logic here if we want to use CFG inference with a meanflow model.
        if meanflow:
            return self.basic_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None

        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, meanflow=meanflow), None

    def basic_euler(self, x, t_span, mu, mask, spks, cond):
        """Plain fixed-step Euler integration without CFG.

        The estimator receives both the current time ``t`` and the next time ``r`` of
        each step. Returns the result cast back to the entry dtype of ``x``.
        """
        in_dtype = x.dtype
        x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)

        print("S3 Token -> Mel Inference...")
        for t, r in tqdm(zip(t_span[..., :-1], t_span[..., 1:]), total=t_span.shape[-1] - 1):
            t, r = t[None], r[None]
            dxdt = self.estimator.forward(x, mask=mask, mu=mu, t=t, spks=spks, cond=cond, r=r)
            dt = r - t
            x = x + dt * dxdt

        return x.to(in_dtype)
================================================
FILE: src/chatterbox/models/s3gen/hifigan.py
================================================
# jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
# most modules should be reusable, but I found that their SineGen changed a bit.
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HIFI-GAN"""
from typing import Dict, Optional, List
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.distributions.uniform import Uniform
from torch import nn, sin, pow
from torch.nn import Parameter
class Snake(nn.Module):
    '''
    Snake activation with a learnable per-channel frequency alpha:
        snake(x) = x + 1/alpha * sin^2(alpha * x)

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel frequency parameter (trainable unless disabled)
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Args:
            in_features: number of channels C (one alpha per channel).
            alpha: initial value for linear-scale alphas; higher values = higher frequency.
                In log scale the parameter starts at zeros (exp(0) == 1) because
                zeros * alpha == 0, so this argument has no effect there.
            alpha_trainable: whether alpha receives gradients.
            alpha_logscale: store log(alpha) and exponentiate in forward.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        start = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(start * alpha)
        self.alpha.requires_grad = alpha_trainable
        # guards the 1/alpha term against alpha == 0
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''Apply x + 1/alpha * sin^2(alpha * x) elementwise, broadcasting alpha over [B, C, T].'''
        a = self.alpha[None, :, None]
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.no_div_by_zero)) * pow(sin(x * a), 2)
def get_padding(kernel_size, dilation=1):
    """Return the 'same' padding for a stride-1 conv: half of the dilated kernel span."""
    span = dilation * (kernel_size - 1)
    return int(span / 2)
def init_weights(m, mean=0.0, std=0.01):
    """Fill ``m.weight`` in place from N(mean, std**2) if ``m`` looks like a convolution.

    Matching is by class name: any module whose class name contains "Conv"
    (Conv1d, ConvTranspose1d, ...) is affected; other modules are left untouched.
    Intended for use with ``nn.Module.apply``.

    NOTE(review): for modules wrapped by ``parametrizations.weight_norm`` the
    ``weight`` attribute is recomputed on access, so this in-place fill may not
    persist -- confirm if the init matters (pretrained weights override it anyway).
    """
    if "Conv" in type(m).__name__:
        m.weight.data.normal_(mean, std)
"""HiFi-GAN based generator implementation.
This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    For each dilation d, one stage computes
        x = conv2(snake2(conv1_d(snake1(x)))) + x
    where conv1_d is a dilated, weight-normalized Conv1d and conv2 is an undilated
    one. Padding keeps the time length unchanged.
    """

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every residual stage in order; the output has the input's shape."""
        stages = zip(self.activations1, self.convs1, self.activations2, self.convs2)
        for act1, conv1, act2, conv2 in stages:
            xt = conv2(act2(conv1(act1(x))))
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Fold the weight normalization of every conv into a plain weight.

        The convs are wrapped with ``torch.nn.utils.parametrizations.weight_norm``.
        The legacy ``torch.nn.utils.remove_weight_norm`` cannot undo that wrapper
        and raises ``ValueError``. Instead, the parametrization is removed while
        keeping the current effective weight. Convs without weight norm are left
        unchanged, so calling this twice is safe.
        """
        from torch.nn.utils import parametrize

        for conv in [*self.convs1, *self.convs2]:
            if parametrize.is_parametrized(conv, "weight"):
                parametrize.remove_parametrizations(conv, "weight")
            elif hasattr(conv, "weight_g"):
                # legacy hook-based weight_norm
                remove_weight_norm(conv)
class SineGen(torch.nn.Module):
    """Sine-source generator for neural source-filter (NSF) vocoders.

    SineGen(samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0)

    Args:
        samp_rate: sampling rate in Hz.
        harmonic_num: number of harmonic overtones above the fundamental (default 0).
        sine_amp: amplitude of the sine waveforms (default 0.1).
        noise_std: std of the Gaussian noise added in voiced frames (default 0.003).
        voiced_threshold: F0 threshold (Hz) for voiced/unvoiced classification (default 0).
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a float32 mask that is 1.0 where f0 exceeds the voiced threshold."""
        return (f0 > self.voiced_threshold).type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise) with shapes
                 [B, harmonic_num + 1, sample_len], [B, 1, sample_len],
                 [B, harmonic_num + 1, sample_len]
        """
        batch, length = f0.size(0), f0.size(-1)
        n_tones = self.harmonic_num + 1

        # Normalized frequency (cycles per sample) of each harmonic k * f0.
        freq = torch.zeros((batch, n_tones, length)).to(f0.device)
        for k in range(n_tones):
            freq[:, k: k + 1, :] = f0 * (k + 1) / self.sampling_rate

        # Instantaneous phase, with the accumulated cycles wrapped to [0, 1).
        phase = 2 * np.pi * (torch.cumsum(freq, dim=-1) % 1)

        # Random initial phase per harmonic; the fundamental always starts at 0.
        sampler = Uniform(low=-np.pi, high=np.pi)
        init_phase = sampler.sample(sample_shape=(batch, n_tones, 1)).to(freq.device)
        init_phase[:, 0, :] = 0

        sines = self.sine_amp * torch.sin(phase + init_phase)

        uv = self._f02uv(f0)

        # Voiced frames get noise with std noise_std. Unvoiced frames get std
        # sine_amp / 3, so the noise peaks at roughly sine_amp.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sines)

        # silence the sines in unvoiced frames, then add the noise everywhere
        return sines * uv + noise, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
    """Harmonic-plus-noise source module for hn-NSF.

    SourceModuleHnNSF(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                      add_noise_std=0.003, voiced_threshod=0)

    Args:
        sampling_rate: sampling rate in Hz.
        upsample_scale: accepted for interface compatibility; not used here.
        harmonic_num: number of harmonics above F0 (default: 0).
        sine_amp: amplitude of the sine source signal (default: 0.1).
        add_noise_std: std of the additive Gaussian noise in voiced frames
            (default: 0.003); noise amplitude in unvoiced frames is set by sine_amp.
        voiced_threshod: threshold to set U/V given F0 (default: 0).

    Call:
        sine_merge, noise, uv = module(F0_sampled)
        F0_sampled (batchsize, length, 1)
        sine_merge (batchsize, length, 1)
        noise      (batchsize, length, 1)
        uv         (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # sine generator for the fundamental plus harmonics
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)

        # learned mix of the (harmonic_num + 1) sines into one excitation channel
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """Return (sine_merge, noise, uv) for an F0 input of shape (batchsize, length, 1)."""
        # harmonic branch: the sine generator itself is never trained
        with torch.no_grad():
            harmonics, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            harmonics = harmonics.transpose(1, 2)
            uv = uv.transpose(1, 2)
        excitation = self.l_tanh(self.l_linear(harmonics))

        # noise branch, same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return excitation, noise, uv
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    As implemented below:
      1. mel -> f0 (``f0_predictor``);
      2. f0 -> harmonic source signal (``m_source``);
      3. a transposed-conv upsampling stack is fused at each level with a
         downsampled STFT of the source;
      4. the result is mapped to log-magnitude/phase, inverted with an iSTFT,
         and clamped to ``audio_limit``.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: List[int] = [8, 8],
        upsample_kernel_sizes: List[int] = [16, 16],
        istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: List[int] = [3, 7, 11],
        resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: List[int] = [7, 11],
        source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        # Upsamples frame-rate f0 to the sample rate of the source signal.
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: bring the source STFT (n_fft + 2 channels: real + imag) to each up level's rate.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def remove_weight_norm(self):
        """Fold weight normalization into plain weights (typically before inference).

        The weight-normalized layers here (``conv_pre``, ``ups``, ``conv_post`` and
        the convs inside the ResBlocks) use ``torch.nn.utils.parametrizations.weight_norm``.
        The legacy ``torch.nn.utils.remove_weight_norm`` cannot undo that and raises
        ``ValueError``, so the parametrization is removed instead and the current
        effective weight is kept.

        The previous implementation had two further problems:
          - it called ``self.m_source.remove_weight_norm()``, which
            ``SourceModuleHnNSF`` does not define;
          - it tried to strip ``source_downs``, which are plain (never
            weight-normalized) convs.
        Modules without weight norm are now skipped, so this method is also
        idempotent. ``f0_predictor`` is intentionally left untouched, as before.
        """
        from torch.nn.utils import parametrize

        print('Removing weight norm...')
        roots = (self.conv_pre, self.ups, self.resblocks, self.conv_post,
                 self.source_downs, self.source_resblocks)
        for root in roots:
            # Snapshot first: removing a parametrization mutates the child-module registry.
            for module in list(root.modules()):
                if parametrize.is_parametrized(module, "weight"):
                    parametrize.remove_parametrizations(module, "weight")
                elif hasattr(module, "weight_g"):
                    # legacy hook-based weight_norm
                    remove_weight_norm(module)

    def _stft(self, x):
        """Return the real and imaginary parts of the STFT of ``x`` (each [B, F, TT])."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from magnitude (clipped at 1e2) and phase."""
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        """Synthesize a waveform from mel features ``x`` and a source signal ``s``.

        ``s`` is squeezed on dim 1 before the STFT, so it is presumably [B, 1, T_audio]
        (TODO confirm). The returned waveform is clamped to [-audio_limit, audio_limit].
        """
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # average the outputs of this level's parallel ResBlocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # NOTE: default negative slope (0.01) here, not self.lrelu_slope -- kept as-is.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Generate speech from ``batch['speech_feat']``; returns (generated_speech, f0)."""
        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel+source->speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        """Generate speech from mel features; returns (generated_speech, source).

        If ``cache_source`` is non-empty it overwrites the leading samples of the
        freshly generated source, so a previously returned source can be fed back.
        """
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # use cache_source to avoid glitch
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, s
================================================
FILE: src/chatterbox/models/s3gen/matcha/decoder.py
================================================
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat
from .transformer import BasicTransformerBlock
class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal embedding of (diffusion) timesteps.

    Maps ``x`` of shape (batch,) -- or a 0-dim scalar, treated as batch 1 -- to
    (batch, dim). The first dim/2 columns are sines and the last dim/2 are cosines
    of ``scale * x`` times frequencies spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        # With a single frequency (dim == 2) the original `/ (half_dim - 1)` divided by
        # zero; the lone frequency is exp(0) == 1 regardless of the step, so clamp the
        # denominator. For half_dim >= 2 this is unchanged.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
class Block1D(torch.nn.Module):
    """Masked Conv1d(kernel 3) -> GroupNorm -> Mish on (batch, channels, time) inputs.

    The mask is applied both to the input and to the output.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        conv = torch.nn.Conv1d(dim, dim_out, 3, padding=1)
        norm = torch.nn.GroupNorm(groups, dim_out)
        self.block = torch.nn.Sequential(conv, norm, nn.Mish())

    def forward(self, x, mask):
        masked_in = x * mask
        return self.block(masked_in) * mask
class ResnetBlock1D(torch.nn.Module):
    """Two masked Block1D layers with a timestep-dependent bias and a 1x1 residual conv."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        # per-channel bias from the time embedding, broadcast over the time axis
        time_bias = self.mlp(time_emb)[..., None]
        hidden = self.block1(x, mask) + time_bias
        hidden = self.block2(hidden, mask)
        return hidden + self.res_conv(x * mask)
class Downsample1D(nn.Module):
    """Stride-2, kernel-3, padding-1 Conv1d: time length T becomes ceil(T / 2)."""

    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled
class TimestepEmbedding(nn.Module):
    """Two-layer MLP over a timestep embedding: linear_1 -> act -> linear_2 [-> post_act].

    If ``cond_proj_dim`` is given, ``forward`` can also add a bias-free linear
    projection of ``condition`` to the input before ``linear_1``.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False) if cond_proj_dim is not None else None
        )

        self.act = get_activation(act_fn)

        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)
        hidden = self.linear_2(hidden)
        return hidden if self.post_act is None else self.post_act(hidden)
class Upsample1D(nn.Module):
    """Double the time resolution of a (batch, channels, time) tensor.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            when not using a transposed conv: apply a kernel-3 Conv1d after
            nearest-neighbour upsampling.
        use_conv_transpose (`bool`, default `True`):
            upsample with ConvTranspose1d(kernel 4, stride 2, padding 1) instead.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        upsampled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
class ConformerWrapper(ConformerBlock):
    """Adapter giving ``ConformerBlock`` the call signature of ``BasicTransformerBlock``.

    Only ``hidden_states`` and ``attention_mask`` (converted to bool) are forwarded;
    the encoder and timestep arguments are accepted and ignored.
    """

    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        options = dict(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )
        super().__init__(**options)

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        bool_mask = attention_mask.bool()
        return super().forward(x=hidden_states, mask=bool_mask)
class Decoder(nn.Module):
    """1D U-Net estimator used by the flow-matching models.

    Structure:
      - Down path: per entry of ``channels``, a ResnetBlock1D, ``n_blocks``
        attention blocks, then a stride-2 downsample (a plain conv at the last level).
      - Middle: ``num_mid_blocks`` ResnetBlock1D + attention stages at ``channels[-1]``.
      - Up path: mirrors the down path with skip connections and transposed-conv upsampling.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # Previously this assigned the local `out_channels`, shadowing the constructor
            # argument; `output_channel` (the width used below) is what was meant.
            output_channel = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            # 2x input channels: the skip connection is concatenated onto the features.
            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        """Build one attention block: "conformer" or "transformer"; raises ValueError otherwise."""
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        """Kaiming-normal init for Conv1d/Linear weights, zero biases, unit GroupNorm scale."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _run_transformer_blocks(x, mask, transformer_blocks, t):
        """Apply attention blocks (which take (b, t, c) input) to (b, c, t) features.

        ``mask`` is (b, 1, t); the blocks receive it as (b, t).
        """
        x = rearrange(x, "b c t -> b t c")
        attn_mask = rearrange(mask, "b 1 t -> b t")
        for transformer_block in transformer_blocks:
            x = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        return rearrange(x, "b t c -> b c t")

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): encoder output, concatenated with ``x`` on the channel axis
            t (torch.Tensor): shape (batch_size) (a 0-dim scalar is also accepted)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels);
                repeated over time and concatenated on the channel axis. Defaults to None.
            cond (optional): placeholder for future use; not used. Defaults to None.

        Returns:
            torch.Tensor: shape (batch_size, out_channels, time), multiplied by ``mask``.
        """
        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = self._run_transformer_blocks(x, mask_down, transformer_blocks, t)
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        # The mask appended after the last (non-downsampling) level is not used.
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = self._run_transformer_blocks(x, mask_mid, transformer_blocks, t)

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = self._run_transformer_blocks(x, mask_up, transformer_blocks, t)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
================================================
FILE: src/chatterbox/models/s3gen/matcha/flow_matching.py
================================================
from abc import ABC
import torch
import torch.nn.functional as F
from .decoder import Decoder
class BASECFM(torch.nn.Module, ABC):
    """Base class for conditional flow matching (CFM) models.

    Subclasses must set ``self.estimator``, a callable invoked as
    ``estimator(x, mask, mu, t, spks, cond)`` that predicts the flow's velocity.
    """

    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        # `sigma_min` is optional in the config; fall back to 1e-4.
        self.sigma_min = getattr(cfm_params, "sigma_min", 1e-4)

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded unchanged to the estimator.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded unchanged to the estimator.

        Returns:
            The state after the final step.
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (optional): forwarded to the estimator, as in ``solve_euler``.

        Returns:
            loss: conditional flow matching loss (MSE over valid, unpadded frames)
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, _ = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        # The denominator counts only valid frames, so padded frames must also be
        # excluded from the numerator (previously `u` was unmasked and inflated it).
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
class CFM(BASECFM):
    """Conditional flow matching model whose estimator is the U-Net ``Decoder``."""

    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        # Multi-speaker models widen the estimator input by the speaker-embedding size.
        spk_channels = spk_emb_dim if n_spks > 1 else 0
        self.estimator = Decoder(
            in_channels=in_channels + spk_channels,
            out_channels=out_channel,
            **decoder_params,
        )
================================================
FILE: src/chatterbox/models/s3gen/matcha/text_encoder.py
================================================
""" from https://github.com/jaywalnut310/glow-tts """
import math
import torch
import torch.nn as nn
from einops import rearrange
def sequence_mask(length, max_length=None):
    """Return a boolean mask that is True where position < length.

    For a 1-D ``length`` of size B the result has shape (B, max_length);
    ``max_length`` defaults to ``length.max()``.
    """
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
class LayerNorm(nn.Module):
    """Layer normalization over dim 1 (the channel axis), with per-channel gamma/beta.

    Works for inputs of any rank >= 2; gamma and beta broadcast over all other dims.
    """

    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        centre = x.mean(dim=1, keepdim=True)
        spread = ((x - centre) ** 2).mean(dim=1, keepdim=True)
        normed = (x - centre) * torch.rsqrt(spread + self.eps)

        bshape = (1, -1) + (1,) * (x.dim() - 2)
        return normed * self.gamma.view(bshape) + self.beta.view(bshape)
class ConvReluNorm(nn.Module):
    """Residual stack of (Conv1d -> LayerNorm -> ReLU -> Dropout) layers plus a 1x1 projection.

    ``proj`` is zero-initialized, so a freshly built module returns ``x * x_mask``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        pad = kernel_size // 2
        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=pad))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=pad))
            self.norm_layers.append(LayerNorm(hidden_channels))

        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        residual = x
        for idx in range(self.n_layers):
            hidden = self.conv_layers[idx](x * x_mask)
            x = self.relu_drop(self.norm_layers[idx](hidden))
        return (residual + self.proj(x)) * x_mask
class DurationPredictor(nn.Module):
    """Predict one value per input frame (``TextEncoder`` uses it as log-duration).

    Two ``Conv1d -> ReLU -> LayerNorm -> Dropout`` stages are followed by a
    1x1 projection to a single channel. ``x_mask`` zeroes padded frames before
    every convolution and on the output.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout
        pad = kernel_size // 2
        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=pad)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=pad)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        h = x
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            h = self.drop(norm(torch.relu(conv(h * x_mask))))
        return self.proj(h * x_mask) * x_mask
class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.

    Only the first ``d`` features of the last dimension are rotated; any
    remaining features pass through unchanged.
    (from https://nn.labml.ai/transformers/rope/index.html)
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ to rotate (coerced to ``int``)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()
        self.base = base
        self.d = int(d)
        # Plain attributes (not registered buffers): they are not moved by
        # ``Module.to``, so ``_build_cache`` also checks the device.
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values for positions ``0 .. x.shape[0] - 1``.

        An existing cache is reused only if it is long enough AND lives on
        ``x``'s device; otherwise it is rebuilt (e.g. after the module moved
        to another device).
        """
        seq_len = x.shape[0]
        if (
            self.cos_cached is not None
            and seq_len <= self.cos_cached.shape[0]
            and self.cos_cached.device == x.device
        ):
            return
        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        # float32 tables shaped [seq_len, 1, 1, d] to broadcast over batch and heads
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2
        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape
          `[batch_size, n_heads, seq_len, head_dim]`.

        Returns a tensor with the same shape and dtype as `x`.
        """
        # [b, h, t, d] -> [t, b, h, d] so positions run along dim 0
        x = x.permute(2, 0, 1, 3)
        self._build_cache(x)
        seq_len = x.shape[0]
        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]
        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)
        # Cast the float32 tables to x's dtype so half/bfloat16 inputs are not
        # silently promoted to float32 (which breaks later matmuls with values).
        cos = self.cos_cached[:seq_len].to(x.dtype)
        sin = self.sin_cached[:seq_len].to(x.dtype)
        x_rope = (x_rope * cos) + (neg_half_x * sin)
        # [t, b, h, d] -> [b, h, t, d]
        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention for channels-first ``(batch, channels, time)`` tensors.

    Queries are projected from ``x`` and keys/values from ``c`` with 1x1
    convolutions. Rotary position embeddings of width ``k_channels * 0.5`` are
    applied to queries and keys (from https://nn.labml.ai/transformers/rope/index.html).

    Args:
        channels: input width; must be divisible by ``n_heads``.
        out_channels: width produced by the output projection.
        n_heads: number of attention heads.
        heads_share: stored on the module; not otherwise used here.
        p_dropout: dropout applied to the attention probabilities.
        proximal_bias: add a ``-log(1 + |i - j|)`` bias (self-attention only).
        proximal_init: initialise the key projection as a copy of the query projection.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        heads_share=True,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None  # last attention probabilities, written by forward()
        self.k_channels = channels // n_heads
        # creation order is kept stable: it determines RNG use during init
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        rope_dim = self.k_channels * 0.5
        self.query_rotary_pe = RotaryPositionalEmbeddings(rope_dim)
        self.key_rotary_pe = RotaryPositionalEmbeddings(rope_dim)
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        for qk_proj in (self.conv_q, self.conv_k):
            torch.nn.init.xavier_uniform_(qk_proj.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
        out, self.attn = self.attention(q, k, v, mask=attn_mask)
        return self.conv_o(out)

    def attention(self, query, key, value, mask=None):
        b, d, t_s = key.size()
        t_t = query.size(2)

        def to_heads(proj):
            # [b, h*c, t] -> [b, h, t, c]
            return rearrange(proj, "b (h c) t-> b h t c", h=self.n_heads)

        q_heads = to_heads(query)
        k_heads = to_heads(key)
        v_heads = to_heads(value)
        q_heads = self.query_rotary_pe(q_heads)
        k_heads = self.key_rotary_pe(k_heads)

        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            bias = self._attention_bias_proximal(t_s)
            scores = scores + bias.to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = self.drop(torch.nn.functional.softmax(scores, dim=-1))
        heads_out = torch.matmul(p_attn, v_heads)  # [b, h, t_t, c]
        merged = heads_out.transpose(2, 3).contiguous().view(b, d, t_t)
        return merged, p_attn

    @staticmethod
    def _attention_bias_proximal(length):
        """Return a ``[1, 1, length, length]`` bias equal to ``-log(1 + |i - j|)``."""
        pos = torch.arange(length, dtype=torch.float32)
        dist = torch.abs(pos[None, :] - pos[:, None])
        return (-torch.log1p(dist))[None, None]
class FFN(nn.Module):
    """Position-wise feed-forward block made of two 1-D convolutions with a ReLU between.

    ``x_mask`` zeroes padded frames before each convolution and on the output.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        same_pad = kernel_size // 2
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_pad)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=same_pad)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        hidden = self.drop(torch.relu(self.conv_1(x * x_mask)))
        return self.conv_2(hidden * x_mask) * x_mask
class Encoder(nn.Module):
    """Stack of post-norm transformer layers (multi-head self-attention + conv FFN).

    Each layer computes ``x = norm_1(x + drop(attn(x)))`` followed by
    ``x = norm_2(x + drop(ffn(x)))``; padded frames are zeroed with ``x_mask``.
    Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            # attention is created before the FFN of the same layer (fixes RNG order at init)
            attn = MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
            self.attn_layers.append(attn)
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            ffn = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
            self.ffn_layers.append(ffn)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        # pairwise mask: position (i, j) is valid only if both frames are valid
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        layers = zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2)
        for attn, norm_1, ffn, norm_2 in layers:
            x = x * x_mask
            x = norm_1(x + self.drop(attn(x, x, attn_mask)))
            x = norm_2(x + self.drop(ffn(x, x_mask)))
        return x * x_mask
class TextEncoder(nn.Module):
    """Text encoder: token embedding, optional conv prenet, transformer encoder,
    mean projection (``proj_m``) and duration predictor (``proj_w``).

    Args:
        encoder_type: stored on the module; not otherwise used here.
        encoder_params: config object read for ``n_feats``, ``n_channels``,
            ``prenet``, ``filter_channels``, ``n_heads``, ``n_layers``,
            ``kernel_size`` and ``p_dropout``.
        duration_predictor_params: config object read for
            ``filter_channels_dp``, ``kernel_size`` and ``p_dropout``.
        n_vocab: size of the token vocabulary.
        n_spks: number of speakers; when > 1 a speaker embedding of width
            ``spk_emb_dim`` is concatenated to every frame before the encoder.
        spk_emb_dim: width of that speaker embedding.
    """

    def __init__(
        self,
        encoder_type,
        encoder_params,
        duration_predictor_params,
        n_vocab,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.encoder_type = encoder_type
        self.n_vocab = n_vocab
        self.n_feats = encoder_params.n_feats
        self.n_channels = encoder_params.n_channels
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks
        self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
        if encoder_params.prenet:
            self.prenet = ConvReluNorm(
                self.n_channels,
                self.n_channels,
                self.n_channels,
                kernel_size=5,
                n_layers=3,
                p_dropout=0.5,
            )
        else:
            # identity prenet keeping the (x, x_mask) call signature
            self.prenet = lambda x, x_mask: x
        # the speaker embedding (multi-speaker only) is appended on the channel axis
        enc_channels = self.n_channels + (spk_emb_dim if n_spks > 1 else 0)
        self.encoder = Encoder(
            enc_channels,
            encoder_params.filter_channels,
            encoder_params.n_heads,
            encoder_params.n_layers,
            encoder_params.kernel_size,
            encoder_params.p_dropout,
        )
        self.proj_m = torch.nn.Conv1d(enc_channels, self.n_feats, 1)
        self.proj_w = DurationPredictor(
            enc_channels,
            duration_predictor_params.filter_channels_dp,
            duration_predictor_params.kernel_size,
            duration_predictor_params.p_dropout,
        )

    def forward(self, x, x_lengths, spks=None):
        """Run forward pass to the transformer based encoder and duration predictor

        Args:
            x (torch.Tensor): text input
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): text input lengths
                shape: (batch_size,)
            spks (torch.Tensor, optional): speaker embeddings; required when
                ``n_spks > 1`` and ignored otherwise. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            mu (torch.Tensor): average output of the encoder
                shape: (batch_size, n_feats, max_text_length)
            logw (torch.Tensor): log duration predicted by the duration predictor
                shape: (batch_size, 1, max_text_length)
            x_mask (torch.Tensor): mask for the text input
                shape: (batch_size, 1, max_text_length)

        Raises:
            ValueError: if ``n_spks > 1`` and ``spks`` is None.
        """
        if self.n_spks > 1 and spks is None:
            raise ValueError("spks must be provided when n_spks > 1")
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)  # (B, T, C) -> (B, C, T)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            # broadcast the per-utterance embedding over time: (B, E) -> (B, E, T)
            x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask
        # detached copy: gradients from the duration predictor do not reach the encoder
        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)
        return mu, logw, x_mask
================================================
FILE: src/chatterbox/models/s3gen/matcha/transformer.py
================================================
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from diffusers.models.attention import (
GEGLU,
GELU,
AdaLayerNorm,
AdaLayerNormZero,
ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
class SnakeBeta(nn.Module):
    """
    Linear projection followed by a SnakeBeta activation with separate
    trainable frequency (alpha) and magnitude (beta) parameters.

        SnakeBeta(x) := x + 1/beta * sin^2(alpha * x), applied to ``proj(x)``

    The parameters have shape ``[out_features]`` and broadcast against the
    LAST dimension of the projected input, so inputs are ``(..., in_features)``
    and outputs ``(..., out_features)``.

    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195

    Examples:
        >>> act = SnakeBeta(256, 512)
        >>> y = act(torch.randn(4, 10, 256))  # shape (4, 10, 512)
    """

    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        """
        Args:
            in_features: input width of the linear projection.
            out_features: output width of the projection; also the shape of
                alpha/beta (a list is used directly as the parameter shape).
            alpha: scale applied to the initial alpha and beta values.
            alpha_trainable: whether alpha and beta receive gradients.
            alpha_logscale: if True, alpha/beta are stored as logarithms
                (initialised to zeros) and exponentiated in ``forward``;
                otherwise they are initialised to ones.
        """
        super().__init__()
        # NOTE: named ``in_features`` but holds the parameter shape, i.e. out_features.
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)
        self.alpha_logscale = alpha_logscale
        init_fn = torch.zeros if alpha_logscale else torch.ones
        self.alpha = nn.Parameter(init_fn(self.in_features) * alpha)
        self.beta = nn.Parameter(init_fn(self.in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Project ``x`` and apply ``x + 1/beta * sin^2(alpha * x)`` elementwise."""
        projected = self.proj(x)
        if not self.alpha_logscale:
            alpha, beta = self.alpha, self.beta
        else:
            alpha, beta = torch.exp(self.alpha), torch.exp(self.beta)
        magnitude = 1.0 / (beta + self.no_div_by_zero)
        return projected + magnitude * torch.pow(torch.sin(projected * alpha), 2)
class FeedForward(nn.Module):
    r"""
    A feed-forward layer: activation with built-in input projection, dropout,
    linear output projection and an optional final dropout.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`, `"snakebeta"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: if `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Exactly one branch must match; without the final ``else`` an unknown
        # name would leave ``act_fn`` unbound and fail with UnboundLocalError.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate', 'snakebeta'."
            )

        self.net = nn.ModuleList([])
        # project in (the activation modules include the dim -> inner_dim projection)
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply the modules in ``self.net`` in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: attention, optional second attention, feed-forward.

    Each sub-block normalises its input first and adds its output back to the
    residual stream.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See
            `Transformer2DModel`. Required when `norm_type` is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to both `Attention` layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Forwarded to the `nn.LayerNorm` layers.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`): `"ada_norm"` / `"ada_norm_zero"` select
            timestep-conditioned norms (together with `num_embeds_ada_norm`); other values use `nn.LayerNorm`.
        final_dropout (`bool`, *optional*, defaults to `False`): Forwarded to `FeedForward`.

    Raises:
        ValueError: if `norm_type` is an ada-norm variant but `num_embeds_ada_norm` is None.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )
        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        # attn1 becomes a cross-attention layer only when only_cross_attention is set
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )
        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            ) # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None
        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Run the feed-forward in chunks of ``chunk_size`` along dimension ``dim``
        (None disables chunking). The chunked dimension must be divisible by
        ``chunk_size`` at call time, otherwise ``forward`` raises ValueError."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """Apply attention, optional second attention and feed-forward, each residually.

        Args:
            hidden_states: input sequence; presumably ``(batch, seq_len, dim)``
                (the ada-norm-zero gates are broadcast with ``unsqueeze(1)``) -- TODO confirm.
            attention_mask: mask for ``attn1`` when it is self-attention.
            encoder_hidden_states: context for ``attn2`` (and for ``attn1`` when
                ``only_cross_attention`` is set).
            encoder_attention_mask: mask used together with ``encoder_hidden_states``.
            timestep: conditioning passed to ``AdaLayerNorm`` / ``AdaLayerNormZero``.
            cross_attention_kwargs: extra keyword arguments forwarded to both attention calls.
            class_labels: passed to ``AdaLayerNormZero``.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Raises:
            ValueError: if feed-forward chunking is enabled and the chunked
                dimension is not divisible by the chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # AdaLayerNormZero also returns the gate/shift/scale terms used further below
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states
        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states
        return hidden_states
================================================
FILE: src/chatterbox/models/s3gen/s3gen.py
================================================
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
import torchaudio as ta
from functools import lru_cache
from typing import Optional
from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
from .const import S3GEN_SR
from .flow import CausalMaskedDiffWithXvec
from .xvector import CAMPPlus
from .utils.mel import mel_spectrogram
from .f0_predictor import ConvRNNF0Predictor
from .hifigan import HiFTGenerator
from .transformer.upsample_encoder import UpsampleConformerEncoder
from .flow_matching import CausalConditionalCFM
from .decoder import ConditionalDecoder
from .configs import CFM_PARAMS
def drop_invalid_tokens(x, vocab_size=None):
    """Remove token ids that fall outside the speech vocabulary.

    Args:
        x: token ids of a single utterance, shape ``(T,)`` or ``(1, T)``.
        vocab_size: exclusive upper bound on valid ids; defaults to
            ``SPEECH_VOCAB_SIZE``.

    Returns:
        1-D tensor of the ids ``< vocab_size``, in their original order.
    """
    # A 1-D tensor has no batch dimension, so any length is accepted; a 2-D
    # tensor must have batch size 1. (Checking ``x.shape[0] == 1`` on 1-D
    # input wrongly rejected every sequence longer than one token.)
    assert x.dim() == 1 or (x.dim() == 2 and x.shape[0] == 1), "only batch size of one allowed for now"
    if vocab_size is None:
        vocab_size = SPEECH_VOCAB_SIZE
    return x[x < vocab_size]
# TODO: global resampler cache
@lru_cache(maxsize=100)
def get_resampler(src_sr, dst_sr, device):
    """Return a torchaudio resampler from ``src_sr`` to ``dst_sr``, placed on ``device``.

    Results are memoised per ``(src_sr, dst_sr, device)`` (at most 100 entries),
    so repeated calls with the same arguments return the same module object.
    """
    resampler = ta.transforms.Resample(src_sr, dst_sr)
    return resampler.to(device)
class S3Token2Mel(torch.nn.Module):
    """
    S3Gen's CFM decoder maps S3 speech tokens to mel-spectrograms.

    ``embed_ref`` turns a reference waveform into a prompt (speech tokens, mel
    frames and a CAMPPlus speaker embedding); ``forward`` runs the causal
    flow-matching decoder conditioned on that prompt.

    TODO: make these modules configurable?
    """
    def __init__(self, meanflow=False):
        super().__init__()
        self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
        self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
        self.speaker_encoder = CAMPPlus(
            # NOTE: This doesn't affect inference. It turns off activation checkpointing
            # (a training optimization), which causes a crazy DDP error with accelerate
            memory_efficient=False,
        )
        self.meanflow = meanflow
        encoder = UpsampleConformerEncoder(
            output_size=512,
            attention_heads=8,
            linear_units=2048,
            num_blocks=6,
            dropout_rate=0.1,
            positional_dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalize_before=True,
            input_layer='linear',
            pos_enc_layer_type='rel_pos_espnet',
            selfattention_layer_type='rel_selfattn',
            input_size=512,
            use_cnn_module=False,
            macaron_style=False,
        )
        estimator = ConditionalDecoder(
            in_channels=320,
            out_channels=80,
            causal=True,
            channels=[256],
            dropout=0.0,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
            meanflow=self.meanflow,
        )
        cfm_params = CFM_PARAMS
        decoder = CausalConditionalCFM(
            spk_emb_dim=80,
            cfm_params=cfm_params,
            estimator=estimator,
        )
        self.flow = CausalMaskedDiffWithXvec(
            encoder=encoder,
            decoder=decoder
        )
        # NOTE(review): not read anywhere in this class; resamplers come from get_resampler().
        self.resamplers = {}

    @property
    def device(self):
        """Device of the module, taken from the tokenizer's first parameter."""
        params = self.tokenizer.parameters()
        return next(params).device

    @property
    def dtype(self):
        """Dtype of the flow decoder, taken from its first parameter."""
        params = self.flow.parameters()
        return next(params).dtype

    @staticmethod
    def _prepare_ref_value(value, device, dtype):
        """Convert one ``ref_dict`` entry to a tensor on ``device``.

        numpy arrays become tensors. Floating-point tensors are cast to
        ``dtype``; integer/bool tensors (token ids, lengths) are only moved, so
        ids are never rounded (e.g. float16 cannot represent every id above 2048).
        Non-tensor values are returned unchanged.
        """
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        if not torch.is_tensor(value):
            return value
        if torch.is_floating_point(value):
            return value.to(device=device, dtype=dtype)
        return value.to(device=device)

    @staticmethod
    def _align_prompt_lengths(ref_mels, ref_tokens, ref_token_lens):
        """Trim prompt mels (time on dim 1) and tokens so that ``mel_len == 2 * token_len``.

        Mismatches happen when the input is not padded to a multiple of 40ms.
        ``ref_token_lens[0]`` is updated in place when trimming happens.

        Returns:
            ``(ref_mels, ref_tokens, ref_token_lens)``
        """
        if ref_mels.shape[1] != 2 * ref_tokens.shape[1]:
            logging.warning(
                "Reference mel length is not equal to 2 * reference token length.\n"
            )
            token_len = min(ref_mels.shape[1] // 2, ref_tokens.shape[1])
            ref_tokens = ref_tokens[:, :token_len]
            # also trim surplus mel frames, otherwise the mismatch persists when
            # the mel is the longer side
            ref_mels = ref_mels[:, :2 * token_len]
            ref_token_lens[0] = token_len
        return ref_mels, ref_tokens, ref_token_lens

    def embed_ref(
        self,
        ref_wav: torch.Tensor,
        ref_sr: int,
        device="auto",
        ref_fade_out=True,
    ):
        """Build the conditioning dict for a reference waveform.

        Args:
            ref_wav: waveform tensor or numpy array, ``(L,)`` or ``(B, L)``.
            ref_sr: sample rate of ``ref_wav``.
            device: target device, or ``"auto"`` for ``self.device``.
            ref_fade_out: accepted for API compatibility; not used here.

        Returns:
            dict with ``prompt_token``, ``prompt_token_len``, ``prompt_feat``
            (24 kHz mel frames), ``prompt_feat_len`` (always None) and
            ``embedding`` (speaker x-vector).
        """
        device = self.device if device == "auto" else device
        if isinstance(ref_wav, np.ndarray):
            ref_wav = torch.from_numpy(ref_wav).float()
        if ref_wav.device != device:
            ref_wav = ref_wav.to(device)
        if len(ref_wav.shape) == 1:
            ref_wav = ref_wav.unsqueeze(0) # (B, L)
        if ref_wav.size(1) > 10 * ref_sr:
            logging.warning("s3gen received ref longer than 10s")

        # 24 kHz branch: mel-spectrogram prompt for the flow decoder
        ref_wav_24 = ref_wav
        if ref_sr != S3GEN_SR:
            ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
        ref_wav_24 = ref_wav_24.to(device=device, dtype=self.dtype)
        ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(dtype=self.dtype)
        ref_mels_24_len = None

        # 16 kHz branch: speaker embedding and speech tokens
        ref_wav_16 = ref_wav
        if ref_sr != S3_SR:
            ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav)
        ref_x_vector = self.speaker_encoder.inference(ref_wav_16.to(dtype=self.dtype))
        ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16.float())

        ref_mels_24, ref_speech_tokens, ref_speech_token_lens = self._align_prompt_lengths(
            ref_mels_24, ref_speech_tokens, ref_speech_token_lens
        )

        return dict(
            prompt_token=ref_speech_tokens.to(device),
            prompt_token_len=ref_speech_token_lens,
            prompt_feat=ref_mels_24,
            prompt_feat_len=ref_mels_24_len,
            embedding=ref_x_vector,
        )

    def forward(
        self,
        speech_tokens: torch.LongTensor,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        n_cfm_timesteps = None,
        finalize: bool = False,
        speech_token_lens=None,
        noised_mels=None,
    ):
        """
        Generate mel-spectrograms from S3 speech tokens and a reference, from which the speaker timbre is inferred.

        NOTE:
        - The speaker encoder accepts 16 kHz waveform.
        - S3TokenizerV2 accepts 16 kHz waveform.
        - The mel-spectrogram for the reference assumes 24 kHz input signal.
        - This function is designed for batch_size=1 only.

        Args
        ----
        - `speech_tokens`: S3 speech tokens [B=1, T]
        - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
        - `ref_sr`: reference sample rate
        - `ref_dict`: pre-computed output of `embed_ref` (mutually exclusive with `ref_wav`);
          its entries are converted in place.
        - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
        """
        assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"

        if ref_dict is None:
            ref_dict = self.embed_ref(ref_wav, ref_sr)
        else:
            # type/device casting (all values will be numpy if it's from a prod API call);
            # integer entries (token ids / lengths) keep their integer dtype
            device, dtype = self.device, self.dtype
            for rk in list(ref_dict):
                ref_dict[rk] = self._prepare_ref_value(ref_dict[rk], device, dtype)

        speech_tokens = torch.atleast_2d(speech_tokens)

        # backcompat
        if speech_token_lens is None:
            speech_token_lens = torch.LongTensor([st.size(-1) for st in speech_tokens]).to(self.device)

        output_mels, _ = self.flow.inference(
            token=speech_tokens,
            token_len=speech_token_lens,
            finalize=finalize,
            noised_mels=noised_mels,
            n_timesteps=n_cfm_timesteps,
            meanflow=self.meanflow,
            **ref_dict,
        )
        return output_mels
class S3Token2Wav(S3Token2Mel):
    """
    The decoder of S3Gen is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.

    TODO: make these modules configurable?
    """

    # presumably state-dict keys tolerated as missing when loading checkpoints -- confirm with the loader
    ignore_state_dict_missing = ("tokenizer._mel_filters", "tokenizer.window")

    def __init__(self, meanflow=False):
        super().__init__(meanflow)

        f0_predictor = ConvRNNF0Predictor()
        self.mel2wav = HiFTGenerator(
            sampling_rate=S3GEN_SR,
            upsample_rates=[8, 5, 3],
            upsample_kernel_sizes=[16, 11, 7],
            source_resblock_kernel_sizes=[7, 7, 11],
            source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            f0_predictor=f0_predictor,
        )

        # silence out a few ms and fade audio in to reduce artifacts:
        # n_trim zeros followed by an n_trim-sample raised-cosine ramp from 0 to 1
        n_trim = S3GEN_SR // 50  # 20ms = half of a frame
        trim_fade = torch.zeros(2 * n_trim)
        trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
        self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
        self.estimator_dtype = "fp32"

    def _apply_trim_fade(self, output_wavs):
        """Multiply the start of ``output_wavs`` (samples on dim 1) by ``trim_fade`` in place.

        NOTE: ad-hoc method to reduce "spillover" from the reference clip.
        Outputs shorter than ``trim_fade`` are faded over their available
        length instead of failing on a shape mismatch.
        """
        n = min(len(self.trim_fade), output_wavs.shape[1])
        output_wavs[:, :n] *= self.trim_fade[:n]
        return output_wavs

    def forward(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
        speech_token_lens=None,
        skip_vocoder=False,
        n_cfm_timesteps=None,
        noised_mels=None,
    ):
        """
        Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
        NOTE: used for sync synthesis only. Please use `S3GenStreamer` for streaming synthesis.

        Returns the mel-spectrogram instead of a waveform when ``skip_vocoder`` is True.
        """
        output_mels = super().forward(
            speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav,
            ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize,
            n_cfm_timesteps=n_cfm_timesteps, noised_mels=noised_mels,
        )

        if skip_vocoder:
            return output_mels

        # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
        hift_cache_source = torch.zeros(1, 1, 0).to(self.device)

        output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)

        if not self.training:
            output_wavs = self._apply_trim_fade(output_wavs)

        return output_wavs

    @torch.inference_mode()
    def flow_inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        n_cfm_timesteps = None,
        finalize: bool = False,
        speech_token_lens=None,
    ):
        """Run token-to-mel only; defaults to 2 CFM steps with meanflow, else 10."""
        n_cfm_timesteps = n_cfm_timesteps or (2 if self.meanflow else 10)
        noise = None
        if self.meanflow:
            noise = torch.randn(1, 80, speech_tokens.size(-1) * 2, dtype=self.dtype, device=self.device)
        output_mels = super().forward(
            speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict,
            n_cfm_timesteps=n_cfm_timesteps, finalize=finalize, noised_mels=noise,
        )
        return output_mels

    @torch.inference_mode()
    def hift_inference(self, speech_feat, cache_source: Optional[torch.Tensor] = None):
        """Run the HiFT vocoder on mel features; an empty cache source is used when none is given."""
        if cache_source is None:
            cache_source = torch.zeros(1, 1, 0).to(device=self.device, dtype=self.dtype)
        return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)

    @torch.inference_mode()
    def inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        # left as a kwarg because this can change input/output size ratio
        drop_invalid_tokens=True,
        n_cfm_timesteps=None,
        speech_token_lens=None,
    ):
        """Tokens -> mel -> waveform with the reference fade-in applied.

        NOTE: ``drop_invalid_tokens`` is currently ignored (the dropping step is disabled below).

        Returns:
            ``(output_wavs, output_sources)`` from the vocoder.
        """
        # hallucination prevention, drop special tokens
        # if drop_invalid_tokens:
        #     speech_tokens, speech_token_lens = drop_invalid(speech_tokens, pad=S3_QUIET_PAD)

        output_mels = self.flow_inference(
            speech_tokens,
            speech_token_lens=speech_token_lens,
            ref_wav=ref_wav,
            ref_sr=ref_sr,
            ref_dict=ref_dict,
            n_cfm_timesteps=n_cfm_timesteps,
            finalize=True,
        )
        output_mels = output_mels.to(dtype=self.dtype) # FIXME (fp16 mode) is this still needed?
        output_wavs, output_sources = self.hift_inference(output_mels, None)

        output_wavs = self._apply_trim_fade(output_wavs)

        return output_wavs, output_sources
================================================
FILE: src/chatterbox/models/s3gen/transformer/__init__.py
================================================
================================================
FILE: src/chatterbox/models/s3gen/transformer/activation.py
================================================
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
class Swish(torch.nn.Module):
    """Swish (a.k.a. SiLU) activation: ``f(x) = x * sigmoid(x)``, elementwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Swish to ``x``; the result has the same shape as ``x``."""
        gate = torch.sigmoid(x)
        return x * gate
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
class Snake(nn.Module):
    '''
    Sine-based periodic activation: ``x + 1/alpha * sin^2(alpha * x)``.

    One ``alpha`` is kept per channel and broadcast over batch and time.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel parameter of shape (C,), trainable by default
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Args:
            in_features: number of channels C of the input.
            alpha: initial alpha for the linear-scale variant (higher values
                mean higher frequency). With ``alpha_logscale=True`` the stored
                parameter is ``zeros * alpha``, i.e. 0 for any finite alpha,
                which gives an effective alpha of exp(0) == 1.
            alpha_trainable: whether alpha receives gradients.
            alpha_logscale: store log(alpha) and exponentiate it in ``forward``.
        '''
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log-scale alphas start from zeros, linear-scale alphas from ones (times `alpha`).
        base = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(base * alpha)
        self.alpha.requires_grad = alpha_trainable
        # Epsilon keeping 1/alpha finite when alpha is (near) zero.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply Snake elementwise: ``x + 1/alpha * sin^2(alpha * x)``.
        '''
        alpha = self.alpha[None, :, None]  # (1, C, 1): broadcasts over [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        periodic = pow(sin(x * alpha), 2)
        return x + (1.0 / (alpha + self.no_div_by_zero)) * periodic
================================================
FILE: src/chatterbox/models/s3gen/transformer/attention.py
================================================
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple
import torch
from torch import nn
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        n_head (int): Number of attention heads.
        n_feat (int): Model dimension; must be divisible by ``n_head``.
        dropout_rate (float): Dropout applied to the attention weights.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Build the query/key/value/output projections."""
        super().__init__()
        assert n_feat % n_head == 0
        # Per-head width; values use the same width as keys (d_v == d_k).
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query/key/value and split them into heads.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).

        Returns:
            Tuple of q (#batch, n_head, time1, d_k), k and v
            (#batch, n_head, time2, d_k).
        """
        batch_size = query.size(0)
        # (batch, time, size) -> (batch, time, head, d_k) -> (batch, head, time, d_k)
        q_heads = self.linear_q(query).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        k_heads = self.linear_k(key).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v_heads = self.linear_v(value).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        return q_heads, k_heads, v_heads

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Turn attention scores into the projected context vectors.

        Args:
            value (torch.Tensor): (#batch, n_head, time2, d_k).
            scores (torch.Tensor): (#batch, n_head, time1, time2).
            mask (torch.Tensor): (#batch, 1, time2) or (#batch, time1, time2),
                nonzero/True = attend; a (0, 0, 0) tensor means "no mask".

        Returns:
            torch.Tensor: (#batch, time1, d_model).
        """
        batch_size = value.size(0)
        if mask.size(2) > 0:  # a real mask was supplied (time2 > 0)
            ignore = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For the last chunk the mask may cover more keys than `scores`.
            ignore = ignore[:, :, :, :scores.size(-1)]
            scores = scores.masked_fill(ignore, -float('inf'))
            # Rows whose keys are all masked softmax to NaN; zero them out.
            attn = torch.softmax(scores, dim=-1).masked_fill(ignore, 0.0)
        else:
            # Fake (0, 0, 0) mask: every query sees every key.
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        context = torch.matmul(self.dropout(attn), value)  # (batch, head, time1, d_k)
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.linear_out(merged)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention with an optional key/value cache.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).
            mask (torch.Tensor): (#batch, 1, time2) padding mask (cross
                attention) or (#batch, time1, time2) (self attention);
                (0, 0, 0) means no mask.
            pos_emb (torch.Tensor): unused here; accepted for interface
                compatibility with ``RelPositionMultiHeadedAttention``.
            cache (torch.Tensor): (1, head, cache_t, d_k * 2) keys and values
                concatenated on the last dim; a tensor with size(0) == 0
                means no cache.

        Returns:
            torch.Tensor: Output (#batch, time1, d_model).
            torch.Tensor: New cache (1, head, cache_t + time1, d_k * 2). Cache
                slicing is left to the encoder's ``forward_chunk``.
        """
        q, k, v = self.forward_qkv(query, key, value)

        # Prepend cached keys/values. Concatenating/splitting zero-length
        # tensors is a no-op, which keeps ONNX export on a single code path;
        # JIT export feeds a (0, 0, 0, 0) cache for the first chunk.
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        new_cache = torch.cat((k, v), dim=-1)

        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        return self.forward_attention(v, scores, mask), new_cache
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-head attention with relative positional encoding (Transformer-XL).

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): Number of attention heads.
        n_feat (int): Model dimension; must be divisible by ``n_head``.
        dropout_rate (float): Dropout applied to the attention weights.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Add the positional projection and the two learnable biases."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # Projection applied to the positional embedding.
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Per-head biases u and v used in terms (c) and (d) of Section 3.3.
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Realign relative-position scores to absolute key positions.

        Args:
            x (torch.Tensor): (batch, head, time1, 2*time1-1); time1 is the
                query length.

        Returns:
            torch.Tensor: (batch, head, time1, time1), keeping relative
            positions 0 .. time2.
        """
        n_pos = x.size(3)
        pad = torch.zeros((x.size(0), x.size(1), x.size(2), 1),
                          device=x.device,
                          dtype=x.dtype)
        # Left-pad one column, reinterpret with rows/columns swapped and drop
        # the first row: every query row ends up shifted by its own index.
        padded = torch.cat([pad, x], dim=-1).view(x.size(0), x.size(1), n_pos + 1, x.size(2))
        shifted = padded[:, :, 1:].view_as(x)
        return shifted[:, :, :, : n_pos // 2 + 1]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention with relative positional scores.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).
            mask (torch.Tensor): (#batch, 1, time2) or (#batch, time1, time2);
                (0, 0, 0) means no mask.
            pos_emb (torch.Tensor): (1 or #batch, n_pos, size). When the
                positional scores do not match the content scores in shape
                (e.g. an n_pos == 2*time1-1 table), they are realigned with
                ``rel_shift``.
            cache (torch.Tensor): (1, head, cache_t, d_k * 2) cached keys and
                values; size(0) == 0 means no cache.

        Returns:
            torch.Tensor: Output (#batch, time1, d_model).
            torch.Tensor: New cache (1, head, cache_t + time1, d_k * 2).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k) so the (head, d_k) biases broadcast.
        q = q.transpose(1, 2)

        # Prepend cached keys/values. Zero-length concat/split is a no-op,
        # which keeps ONNX export on a single path. Cache slicing happens in
        # the encoder's forward_chunk.
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        new_cache = torch.cat((k, v), dim=-1)

        pos_batch = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(pos_batch, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, n_pos, d_k)

        # Back to (batch, head, time1, d_k) after adding the biases.
        q_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
        q_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)

        # Terms (a)+(c): content scores; terms (b)+(d): positional scores.
        content_scores = torch.matmul(q_u, k.transpose(-2, -1))
        position_scores = torch.matmul(q_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if content_scores.shape != position_scores.shape:
            position_scores = self.rel_shift(position_scores)

        scores = (content_scores + position_scores) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
================================================
FILE: src/chatterbox/models/s3gen/transformer/convolution.py
================================================
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple
import torch
from torch import nn
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Pipeline: pointwise conv + GLU -> depthwise conv -> norm -> activation
    -> pointwise conv. Input and output are (#batch, time, channels).
    """

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
        """Construct an ConvolutionModule object.

        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of the depthwise conv; must be odd
                when ``causal`` is False.
            activation (nn.Module): Activation applied after normalization.
                The default ``nn.ReLU()`` instance is created once, at
                definition time, and shared (harmless: ReLU is stateless).
            norm (str): "batch_norm" or "layer_norm".
            causal (bool): Use causal convolution (left padding of
                ``kernel_size - 1`` frames) instead of symmetric padding.
            bias (bool): Whether the conv layers have a bias term.
        """
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder distinguishes causal from symmetric convolution:
        # lorder > 0 -> causal; forward() left-pads the input with lorder frames
        # (or prepends the cache). lorder == 0 -> symmetric conv padding.
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for none causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        The input tensor is never modified.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                True = valid frame; (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New left-context cache (#batch, channels, lorder)
                for causal convolution, otherwise an empty (0, 0, 0) tensor.
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            # Out-of-place on purpose: `x` is a view of the caller's tensor, so
            # an in-place fill would zero the caller's padded frames (e.g. a
            # residual it still holds) and fails on leaves requiring grad.
            x = x.masked_fill(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # JIT export cannot return None here, so return an empty tensor.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)  # LayerNorm normalizes the last (channel) dim
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            # In-place is safe: `x` was freshly produced by pointwise_conv2.
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
================================================
FILE: src/chatterbox/models/s3gen/transformer/embedding.py
================================================
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positonal Encoding Module."""
import math
from typing import Tuple, Union
import torch
import torch.nn.functional as F
import numpy as np
class PositionalEncoding(torch.nn.Module):
    """Absolute sinusoidal positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length
    :param bool reverse: accepted for subclass compatibility; not used here

    PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Precompute the (1, max_len, d_model) sinusoid table."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        positions = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        angles = positions * inv_freq
        table = torch.zeros(self.max_len, self.d_model)
        table[:, 0::2] = torch.sin(angles)  # even channels
        table[:, 1::2] = torch.cos(angles)  # odd channels
        # Plain attribute, not a registered buffer: forward() moves it to the
        # input's device on every call.
        self.pe = table.unsqueeze(0)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` by sqrt(d_model) and add the positional encoding.

        Args:
            x (torch.Tensor): Input of shape (batch, time, ...).
            offset (int, torch.tensor): position offset.

        Returns:
            torch.Tensor: Encoded tensor (batch, time, ...), after dropout.
            torch.Tensor: The positional encoding, after dropout (returned for
                interface compatibility with RelPositionalEncoding).
        """
        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        encoded = x * self.xscale + pos_emb
        return self.dropout(encoded), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """Slice ``size`` positions of the table starting at ``offset``.

        Intended for streaming. Caveat: dropout is meant to be applied once per
        utterance, but a streaming caller invokes this repeatedly with growing
        sizes, so dropout may be applied several times.

        Args:
            offset (int or torch.tensor): start offset; a 1-D tensor gives one
                offset per batch element.
            size (int): required size of position encoding.
            apply_dropout (bool): whether to apply dropout to the slice.

        Returns:
            torch.Tensor: (1, size, d_model) for a scalar offset, or
            (B, size, d_model) for a 1-D offset tensor.
        """
        # Separate isinstance branches so TorchScript can refine the Union:
        # https://github.com/pytorch/pytorch/issues/69434
        if isinstance(offset, int):
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:  # batched streaming decoding on GPU: one offset per utterance
            assert torch.max(offset) + size <= self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size).to(offset.device)  # B X T
            # Negative positions are mapped to position 0.
            index = index * (index > 0)
            pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model

        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    Same sinusoid table as ``PositionalEncoding``, but the table is returned
    separately instead of being added to the input.
    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Build the parent table, passing ``reverse=True`` through."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and fetch the matching positional table.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int, torch.Tensor): position offset.

        Returns:
            torch.Tensor: ``x * sqrt(d_model)`` after dropout (batch, time, `*`).
            torch.Tensor: Positional embedding after dropout; (1, time, `*`)
                for a scalar offset.
        """
        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        scaled = x * self.xscale
        return self.dropout(scaled), self.dropout(pos_emb)
class WhisperPositionalEncoding(PositionalEncoding):
    """Sinusoidal position encoding as used by the openai-whisper encoder.

    Differences from ``PositionalEncoding``: the sin and cos halves are
    concatenated (not interleaved), the table is a registered buffer, and the
    input is not rescaled (``xscale == 1``).
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
        super().__init__(d_model, dropout_rate, max_len)
        self.xscale = 1.0
        half = d_model // 2
        log_timescale_increment = np.log(10000) / (half - 1)
        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
        scaled_time = torch.arange(max_len)[:, None] * inv_timescales[None, :]
        table = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
        # The parent stored `pe` as a plain attribute; remove it first because
        # register_buffer rejects a name that already exists as an attribute.
        delattr(self, "pe")
        self.register_buffer("pe", table.unsqueeze(0))
class LearnablePositionalEncoding(PositionalEncoding):
    """Learnable position encoding as used by the openai-whisper decoder.

    Replaces the parent's sinusoid table with an uninitialized (``torch.empty``)
    trainable (1, max_len, d_model) parameter, presumably filled from a
    checkpoint (TODO confirm), and disables input scaling.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
        super().__init__(d_model, dropout_rate, max_len)
        # NOTE(xcsong): overwrite self.pe & self.xscale
        self.xscale = 1.0
        self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
class NoPositionalEncoding(torch.nn.Module):
    """ No position encoding.

    The input passes through (with dropout) and an all-zero positional
    embedding is returned, so this can stand in for the other encodings.
    """

    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """ Return ``(dropout(x), zeros(1, time, d_model))``.

        The zero embedding is allocated directly on ``x``'s device and in
        ``x``'s dtype, so consumers (e.g. a half-precision positional
        projection) do not hit a dtype/device mismatch, and no host-to-device
        copy is needed.
        """
        pos_emb = torch.zeros(1, x.size(1), self.d_model,
                              dtype=x.dtype, device=x.device)
        return self.dropout(x), pos_emb

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """ Return an all-zero (1, size, d_model) embedding.

        ``offset`` and ``apply_dropout`` are accepted only to mirror
        ``PositionalEncoding.position_encoding``; dropout of zeros is zeros.
        """
        return torch.zeros(1, size, self.d_model)
class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See : Appendix B in https://arxiv.org/abs/1901.02860

    The cached table ``self.pe`` for length L has shape (1, 2 * L - 1, d_model).
    It holds positive relative positions L-1 .. 0 (reversed) followed by
    negative positions -1 .. -(L-1), so position 0 sits at index
    ``pe.size(1) // 2``.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an EspnetRelPositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-build the table for max_len positions.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """(Re)build ``self.pe`` so it covers ``x.size(1)`` positions.

        If the cached table is already long enough, it is only converted to
        ``x``'s dtype/device when those differ.
        """
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` the position
        # of the key vector. Positive relative positions are used when keys are
        # to the left (i > j) and negative relative positions otherwise (i < j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of the positive indices and concat the positive and
        # negative parts (dropping the duplicate position 0 from the negative
        # part). This supports the shifting trick of
        # https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int, torch.Tensor): ignored; kept for interface compatibility.

        Returns:
            torch.Tensor: ``x * sqrt(d_model)`` after dropout (batch, time, `*`).
            torch.Tensor: Relative positional embedding after dropout
                (1, 2 * time - 1, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """ For getting encoding in a streaming fashion.

        Returns the 2 * size - 1 entries centered on relative position 0.
        No dropout is applied here; ``forward`` applies it. Note that
        ``offset`` is not used by this implementation.

        Args:
            offset (int or torch.tensor): start offset (unused).
            size (int): required size of position encoding.

        Returns:
            torch.Tensor: Corresponding encoding (1, 2 * size - 1, d_model).
        """
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
        ]
        return pos_emb
================================================
FILE: src/chatterbox/models/s3gen/transformer/encoder_layer.py
================================================
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple
import torch
from torch import nn
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention then feed-forward.

    Each sub-block sits in a residual branch. LayerNorm is applied before the
    sub-block (pre-norm) or after the residual sum (post-norm).

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: to use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-12)  # attention sub-block
        self.norm2 = nn.LayerNorm(size, eps=1e-12)  # feed-forward sub-block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run self-attention and feed-forward with residual connections.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): forwarded to ``self_attn``; present for
                interface compatibility with ConformerEncoderLayer.
            mask_pad (torch.Tensor): unused in a transformer layer; kept for
                a unified API with conformer.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): unused here; kept for interface
                compatibility with ConformerEncoderLayer.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: The input ``mask``, returned unchanged.
            torch.Tensor: att_cache tensor returned by ``self_attn``,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: Empty (0, 0, 0) cnn cache placeholder.
        """
        pre_norm = self.normalize_before

        # Self-attention sub-block.
        attn_in = self.norm1(x) if pre_norm else x
        x_att, new_att_cache = self.self_attn(
            attn_in, attn_in, attn_in, mask, pos_emb=pos_emb, cache=att_cache)
        x = x + self.dropout(x_att)
        if not pre_norm:
            x = self.norm1(x)

        # Feed-forward sub-block.
        ff_in = self.norm2(x) if pre_norm else x
        x = x + self.dropout(self.feed_forward(ff_in))
        if not pre_norm:
            x = self.norm2(x)

        # No convolution module here: hand back an empty cache for API parity.
        fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, new_att_cache, fake_cnn_cache
class ConformerEncoderLayer(nn.Module):
    """Conformer encoder layer.

    Sub-blocks are applied in this order:
    1. An optional macaron-style feed-forward module.
    2. Multi-headed self-attention.
    3. An optional convolution module.
    4. An optional feed-forward module.

    Each sub-block is wrapped in a residual connection with dropout. It is
    normalized before (pre-norm) or after (post-norm) the sub-block,
    depending on ``normalize_before``. When a convolution module is present,
    an extra LayerNorm (``norm_final``) is applied to the block output.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
            If None, the feed-forward sub-block is skipped.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance applied before self-attention.
            `PositionwiseFeedForward` instance can be used as the argument.
            When given, both feed-forward outputs are scaled by 0.5.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
            # Macaron style: two half-step feed-forward modules.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-12)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time), returned unchanged.
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2); an empty
                (0, 0, 0) tensor when there is no convolution module.
        """
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module; `feed_forward` is Optional in the constructor,
        # so skip the sub-block instead of calling None.
        if self.feed_forward is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
            if not self.normalize_before:
                x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)
        return x, mask, new_att_cache, new_cnn_cache
================================================
FILE: src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
================================================
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import torch
class PositionwiseFeedForward(torch.nn.Module):
    """Two-layer feed-forward network applied to every position of a sequence.

    Computes ``w_2(dropout(activation(w_1(xs))))``. The last dimension is
    projected ``idim -> hidden_units -> idim``, so the output has the same
    shape as the input.

    Args:
        idim (int): Input (and output) dimension.
        hidden_units (int): Width of the hidden projection.
        dropout_rate (float): Dropout rate applied after the activation.
        activation (torch.nn.Module): Activation module. The default
            ``ReLU()`` instance is created once, at definition time, and is
            shared by every layer that relies on the default.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super().__init__()
        # Attribute names are state_dict keys; keep them (and their order).
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network.

        Args:
            xs: input tensor (B, L, D)

        Returns:
            output tensor, (B, L, D)
        """
        hidden = self.w_1(xs)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
class MoEFFNLayer(torch.nn.Module):
    """Mixture-of-experts position-wise feed-forward layer.

    A bias-free linear gate scores ``n_expert`` experts for every frame. The
    ``n_expert_per_token`` best-scoring experts are run on that frame, and
    their outputs are summed, weighted by a softmax over the selected gate
    scores. The output dim is the same as the input dim.

    See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
    Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
    https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219

    Args:
        n_expert: number of experts.
        n_expert_per_token: The actual number of experts used for each frame
            (must not exceed ``n_expert``; ``torch.topk`` raises otherwise).
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units of each expert.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation module, the same instance
            shared by all experts.
    """

    def __init__(
        self,
        n_expert: int,
        n_expert_per_token: int,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        super(MoEFFNLayer, self).__init__()
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.experts = torch.nn.ModuleList(
            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
                                    activation) for _ in range(n_expert))
        self.n_expert_per_token = n_expert_per_token

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D); need not be contiguous.

        Returns:
            output tensor, (B, L, D)
        """
        B, L, D = xs.size()  # batch size, sequence length, embedding dim (idim)
        # reshape rather than view: view() raises on non-contiguous inputs
        # (e.g. a transposed tensor), reshape copies only when it must.
        xs = xs.reshape(-1, D)  # (B*L, D)
        router = self.gate(xs)  # (B*L, n_expert)
        # logits, indices: (B*L, n_expert_per_token)
        logits, indices = torch.topk(router, self.n_expert_per_token)
        # Softmax in float32 for stability, then back to the input dtype.
        weights = torch.nn.functional.softmax(
            logits, dim=1,
            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
        output = torch.zeros_like(xs)  # (B*L, D)
        for i, expert in enumerate(self.experts):
            # Frames routed to expert i, and which top-k slot selected it.
            # topk indices are unique per row, so token_idx has no repeats and
            # the indexed += below cannot drop contributions.
            token_idx, slot_idx = torch.where(indices == i)
            output[token_idx] += weights[token_idx, slot_idx, None] * expert(
                xs[token_idx])
        return output.view(B, L, D)
================================================
FILE: src/chatterbox/models/s3gen/transformer/subsampling.py
================================================
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
from typing import Tuple, Union
import torch
class BaseSubsampling(torch.nn.Module):
    """Common base for the input / subsampling layers in this module.

    Subclasses are expected to assign ``self.pos_enc`` (the positional
    encoding module). They may override ``right_context`` and
    ``subsampling_rate``.
    """

    def __init__(self):
        super().__init__()
        # Defaults describe a layer that neither subsamples nor looks ahead.
        self.subsampling_rate = 1
        self.right_context = 0

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Delegate to ``self.pos_enc.position_encoding(offset, size)``."""
        encoder = self.pos_enc
        return encoder.position_encoding(offset, size)
class EmbedinigNoSubsampling(BaseSubsampling):
    """Token-embedding input layer that keeps the time resolution.

    Args:
        idim (int): Vocabulary size of the ``nn.Embedding`` table.
        odim (int): Embedding (output) dimension.
        dropout_rate (float): Accepted for signature compatibility with the
            other input layers; not used here.
        pos_enc_class (torch.nn.Module): Positional encoding module applied
            to the embeddings.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        super().__init__()
        self.embed = torch.nn.Embedding(idim, odim)
        self.pos_enc = pos_enc_class

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed token ids and add positional information.

        Args:
            x (torch.Tensor): Integer token ids, presumably (#batch, time).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: embedded tensor, ``x.shape + (odim,)``.
            torch.Tensor: positional embedding from ``pos_enc``.
            torch.Tensor: ``x_mask``, returned unchanged (time' = time).
        """
        embedded = self.embed(x)
        embedded, pos_emb = self.pos_enc(embedded, offset)
        return embedded, pos_emb, x_mask
class LinearNoSubsampling(BaseSubsampling):
    """Linear + LayerNorm + Dropout input layer that keeps the time resolution.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Positional encoding module.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an linear object."""
        super().__init__()
        # Sequential indices (out.0 / out.1) are state_dict keys; keep order.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 1
        self.right_context = 0

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project features and add positional information.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: projected tensor (#batch, time, odim).
            torch.Tensor: positional embedding from ``pos_enc``.
            torch.Tensor: ``x_mask``, returned unchanged (time' = time).
        """
        projected = self.out(x)
        projected, pos_emb = self.pos_enc(projected, offset)
        return projected, pos_emb, x_mask
class Conv1dSubsampling2(BaseSubsampling):
    """Convolutional 1D subsampling (to 1/2 length).

    It is designed for Whisper, ref:
    https://github.com/openai/whisper/blob/main/whisper/model.py

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (not used by this layer).
        pos_enc_class (torch.nn.Module): Positional encoding module.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv1dSubsampling2 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
            torch.nn.GELU(),
            torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
            torch.nn.GELU(),
        )
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 2
        # 4 = (3 - 1) * 1 + (3 - 1) * 1
        self.right_context = 4

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ceil(time / 2).
            torch.Tensor: positional encoding
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        n_frames = x.size(1)
        # Conv1d is channels-first: (b, t, f) -> (b, f, t) and back again.
        feats = self.conv(x.transpose(1, 2)).transpose(1, 2)
        feats, pos_emb = self.pos_enc(feats, offset)
        # Kernel 3 / stride 2 / padding 1 gives ceil(time / 2) frames; pick
        # every other mask entry, starting where the counts line up.
        first = (n_frames + 1) % 2
        return feats, pos_emb, x_mask[:, :, first::2]
class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to roughly 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (not used by this layer).
        pos_enc_class (torch.nn.Module): Positional encoding module.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Frequency bins left after two unpadded kernel-3 / stride-2 convs.
        freq_bins = ((idim - 1) // 2 - 1) // 2
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * freq_bins, odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time - 1) // 2 - 1) // 2.
            torch.Tensor: positional encoding
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        feats = self.conv(x.unsqueeze(1))  # (b, odim, t', f')
        b, c, t, f = feats.size()
        # Merge channels and frequency per frame: (b, t', c * f').
        feats = self.out(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        feats, pos_emb = self.pos_enc(feats, offset)
        # Each kernel-3 / stride-2 conv keeps mask frames 2, 4, 6, ...
        sub_mask = x_mask[:, :, 2::2][:, :, 2::2]
        return feats, pos_emb, sub_mask
class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to roughly 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (not used by this layer).
        pos_enc_class (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv2dSubsampling6 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        # Frequency bins left after a (3, stride 2) then (5, stride 3) conv.
        freq_bins = ((idim - 1) // 2 - 2) // 3
        self.linear = torch.nn.Linear(odim * freq_bins, odim)
        self.pos_enc = pos_enc_class
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.subsampling_rate = 6
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time - 1) // 2 - 2) // 3.
            torch.Tensor: positional encoding
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        feats = self.conv(x.unsqueeze(1))  # (b, odim, t', f')
        b, c, t, f = feats.size()
        # Merge channels and frequency per frame: (b, t', c * f').
        feats = self.linear(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        feats, pos_emb = self.pos_enc(feats, offset)
        # Mask frames surviving the stride-2 conv, then the kernel-5 stride-3 conv.
        sub_mask = x_mask[:, :, 2::2][:, :, 4::3]
        return feats, pos_emb, sub_mask
class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to roughly 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (not used by this layer).
        pos_enc_class (torch.nn.Module): Positional encoding module.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv2dSubsampling8 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Frequency bins left after three unpadded kernel-3 / stride-2 convs.
        freq_bins = (((idim - 1) // 2 - 1) // 2 - 1) // 2
        self.linear = torch.nn.Linear(odim * freq_bins, odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = (((time - 1) // 2 - 1) // 2 - 1) // 2.
            torch.Tensor: positional encoding
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        feats = self.conv(x.unsqueeze(1))  # (b, odim, t', f')
        b, c, t, f = feats.size()
        # Merge channels and frequency per frame: (b, t', c * f').
        feats = self.linear(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        feats, pos_emb = self.pos_enc(feats, offset)
        # One [2::2] slice per kernel-3 / stride-2 conv.
        sub_mask = x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
        return feats, pos_emb, sub_mask
class LegacyLinearNoSubsampling(BaseSubsampling):
    """Linear + LayerNorm + Dropout + ReLU input layer without subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Positional encoding module.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an linear object."""
        super().__init__()
        # Sequential indices are state_dict keys; keep this exact order.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
        )
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 1
        self.right_context = 0

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project features and add positional information.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: projected tensor (#batch, time, odim).
            torch.Tensor: positional embedding from ``pos_enc``.
            torch.Tensor: ``x_mask``, returned unchanged (time' = time).
        """
        projected = self.out(x)
        projected, pos_emb = self.pos_enc(projected, offset)
        return projected, pos_emb, x_mask
================================================
FILE: src/chatterbox/models/s3gen/transformer/upsample_encoder.py
================================================
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
import torch
from torch import nn
from torch.nn import functional as F
from .convolution import ConvolutionModule
from .encoder_layer import ConformerEncoderLayer
from .positionwise_feed_forward import PositionwiseFeedForward
from ..utils.class_utils import (
COSYVOICE_EMB_CLASSES,
COSYVOICE_SUBSAMPLE_CLASSES,
COSYVOICE_ATTENTION_CLASSES,
COSYVOICE_ACTIVATION_CLASSES,
)
from ..utils.mask import make_pad_mask
from ..utils.mask import add_optional_chunk_mask
class Upsample1D(nn.Module):
    """Causal 1D upsampling: nearest-neighbour repeat, then a convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        out_channels (`int`):
            number of channels in the outputs.
        stride (`int`, default `2`):
            integer upsampling factor along the time axis.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # Kernel width 2 * stride + 1 with no built-in padding; forward()
        # supplies exactly 2 * stride frames of left padding.
        kernel = 2 * stride + 1
        self.conv = nn.Conv1d(self.channels, self.out_channels, kernel, stride=1, padding=0)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
        """Upsample ``inputs`` (B, C, T) to (B, out_channels, T * stride).

        Returns the upsampled tensor and ``input_lengths * stride``.
        """
        stretched = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
        # Left-only padding: each output frame depends on current and past
        # frames only, and the time length is preserved through the conv.
        padded = F.pad(stretched, (2 * self.stride, 0), value=0.0)
        upsampled = self.conv(padded)
        return upsampled, input_lengths * self.stride
class PreLookaheadLayer(nn.Module):
    """Residual two-conv block that gives each frame a short future context.

    ``conv1`` sees the current frame plus ``pre_lookahead_len`` future frames.
    ``conv2`` (kernel 3) then mixes the current and two past frames. The
    result is added back to the input.

    Args:
        channels: feature dimension (input and output).
        pre_lookahead_len: number of future frames visible to ``conv1``.
    """

    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels); output has the same shape.
        """
        residual = inputs
        hidden = inputs.transpose(1, 2).contiguous()  # (B, C, T)
        # Right-pad so the lookahead conv keeps length T at the sequence end.
        hidden = F.pad(hidden, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        hidden = F.leaky_relu(self.conv1(hidden))
        # Left-pad by 2 so the kernel-3 conv only looks backwards.
        hidden = F.pad(hidden, (2, 0), mode='constant', value=0.0)
        hidden = self.conv2(hidden).transpose(1, 2).contiguous()
        return hidden + residual
class UpsampleConformerEncoder(torch.nn.Module):
    """Conformer encoder with a temporal upsampling stage in the middle.

    Pipeline: input layer (``embed``) -> ``PreLookaheadLayer`` ->
    ``num_blocks`` Conformer layers -> ``Upsample1D`` (x ``upsample_stride``)
    -> second input layer (``up_embed``) -> ``num_up_blocks`` Conformer
    layers -> final LayerNorm when ``normalize_before`` is True.
    """

    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 512,
        attention_heads: int = 8,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.1,
        input_layer: str = "linear",
        pos_enc_layer_type: str = "rel_pos_espnet",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = False,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = False,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
        pre_lookahead_len: int = 3,
        upsample_stride: int = 2,
        num_up_blocks: int = 4,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention; also the channel
                count of the lookahead and upsampling layers
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks before upsampling
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether use dynamic chunk size for
                training or not, You can only use fixed chunk(chunk_size > 0)
                or dynamic chunk size(use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                dynamic chunk training
            positionwise_conv_kernel_size (int): accepted but not used here.
            key_bias: whether use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward. Stored on the module;
                forward_layers / forward_up_layers do not consult it.
            pre_lookahead_len (int): future frames seen by the lookahead layer.
            upsample_stride (int): temporal upsampling factor.
            num_up_blocks (int): number of encoder blocks after upsampling.
        """
        super().__init__()
        self._output_size = output_size
        self.global_cmvn = global_cmvn

        def make_input_layer(idim: int) -> torch.nn.Module:
            # Input/subsampling layer followed by its positional encoding.
            return COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
                idim,
                output_size,
                dropout_rate,
                COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                          positional_dropout_rate),
            )

        self.embed = make_input_layer(input_size)
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing
        # One activation instance is shared by every feed-forward and conv module.
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        def make_conformer_layer() -> ConformerEncoderLayer:
            return ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            )

        # The lookahead and upsampling layers act on the encoder features, so
        # they are sized from output_size (a fixed 512 broke other widths).
        self.pre_lookahead_layer = PreLookaheadLayer(
            channels=output_size, pre_lookahead_len=pre_lookahead_len)
        self.encoders = torch.nn.ModuleList(
            [make_conformer_layer() for _ in range(num_blocks)])
        self.up_layer = Upsample1D(channels=output_size,
                                   out_channels=output_size,
                                   stride=upsample_stride)
        # up_embed consumes the upsampled encoder output (last dim
        # output_size), not the raw input features.
        self.up_embed = make_input_layer(output_size)
        self.up_encoders = torch.nn.ModuleList(
            [make_conformer_layer() for _ in range(num_up_blocks)])

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode, upsample and re-encode a padded batch.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and the masks of the upsampled stage
            xs: padded output tensor (B, T', D)
            masks: torch.Tensor batch padding mask (B, 1, T')
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        # lookahead + conformer encoder
        xs = self.pre_lookahead_layer(xs)
        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

        # upsample + conformer encoder (Upsample1D is channels-first)
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        # Chunk size is measured in frames, so scale it by the upsampling factor.
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size * self.up_layer.stride,
                                              num_decoding_left_chunks)
        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        """Run the pre-upsampling Conformer layers."""
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                          pos_emb: torch.Tensor,
                          mask_pad: torch.Tensor) -> torch.Tensor:
        """Run the post-upsampling Conformer layers."""
        for layer in self.up_encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs
================================================
FILE: src/chatterbox/models/s3gen/utils/class_utils.py
================================================
# Copyright [2023-11-28]
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..transformer.activation import Swish
from ..transformer.subsampling import (
LinearNoSubsampling,
EmbedinigNoSubsampling,
Conv1dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
)
from ..transformer.embedding import (
PositionalEncoding,
RelPositionalEncoding,
WhisperPositionalEncoding,
LearnablePositionalEncoding,
NoPositionalEncoding)
from ..transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from ..transformer.embedding import EspnetRelPositionalEncoding
from ..transformer.subsampling import LegacyLinearNoSubsampling
# Registries mapping config strings to module classes (name -> class).

COSYVOICE_ACTIVATION_CLASSES = dict(
    hardtanh=torch.nn.Hardtanh,
    tanh=torch.nn.Tanh,
    relu=torch.nn.ReLU,
    selu=torch.nn.SELU,
    # Use nn.SiLU when this torch build provides it, else the local Swish.
    swish=getattr(torch.nn, "SiLU", Swish),
    gelu=torch.nn.GELU,
)

COSYVOICE_SUBSAMPLE_CLASSES = dict(
    linear=LinearNoSubsampling,
    linear_legacy=LegacyLinearNoSubsampling,
    embed=EmbedinigNoSubsampling,
    conv1d2=Conv1dSubsampling2,
    conv2d=Conv2dSubsampling4,
    conv2d6=Conv2dSubsampling6,
    conv2d8=Conv2dSubsampling8,
    paraformer_dummy=torch.nn.Identity,
)

COSYVOICE_EMB_CLASSES = dict(
    embed=PositionalEncoding,
    abs_pos=PositionalEncoding,
    rel_pos=RelPositionalEncoding,
    rel_pos_espnet=EspnetRelPositionalEncoding,
    no_pos=NoPositionalEncoding,
    abs_pos_whisper=WhisperPositionalEncoding,
    embed_learnable_pe=LearnablePositionalEncoding,
)

COSYVOICE_ATTENTION_CLASSES = dict(
    selfattn=MultiHeadedAttention,
    rel_selfattn=RelPositionMultiHeadedAttention,
)
================================================
FILE: src/chatterbox/models/s3gen/utils/intmeanflow.py
================================================
import torch
import torch.nn as nn
def get_intmeanflow_time_mixer(dims):
    """Build a bias-free ``Linear(2 * dims -> dims)`` that mixes two time embeddings.

    Diagonal init as described in 3.3 https://arxiv.org/pdf/2510.07979. The
    weight block acting on the first ``dims`` input features is the identity,
    and the block acting on the last ``dims`` features is zero. At
    initialization, ``layer(cat([e_t, e_r]))`` therefore returns ``e_t``.

    Args:
        dims: size of each time embedding and of the output.

    Returns:
        nn.Linear with weight of shape (dims, 2 * dims) and no bias.
    """
    layer = nn.Linear(dims * 2, dims, bias=False)
    with torch.no_grad():
        # Write into the existing Parameter in place (instead of swapping
        # `.data`) so it keeps its identity, dtype and device.
        layer.weight.zero_()
        layer.weight[:, :dims].copy_(
            torch.eye(dims, dtype=layer.weight.dtype, device=layer.weight.device))
    return layer
if __name__ == '__main__':
    # Demo: at init the mixer passes e_t through and ignores e_r.
    n_dims = 6
    mixer = get_intmeanflow_time_mixer(n_dims)
    print(f"Layer weight (AFTER init):\n{mixer.weight.data}\n")

    e_t = torch.arange(0., 6.)
    e_r = torch.arange(6., 12.)
    e_concat = torch.cat([e_t, e_r]).unsqueeze(0)  # Shape (1, 12)
    mixed = mixer(e_concat)

    for label, value in (
        ("Test Input e_t", e_t),
        ("Test Input e_r", e_r),
        ("Test Input concat", e_concat),
        ("Forward Pass Output", mixed.squeeze(0)),
    ):
        print(f"{label}: \n{value}")
================================================
FILE: src/chatterbox/models/s3gen/utils/mask.py
================================================
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
'''
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
In encoder, fully attention is used when streaming is not necessary and
the sequence is not long. In this case, no attention mask is needed.
When streaming is need, chunk-based attention is used in encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
str device (str): "cpu" or "cuda" or torch.Tensor.device
dtype (torch.device): result dtype
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
ret = torch.ones(size, size, device=device, dtype=torch.bool)
return torch.tril(ret)
'''
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Build a (size, size) boolean chunk-causal mask for a streaming encoder.

    Row ``r`` may attend to every column that lies before the end of the
    chunk containing ``r``: all earlier chunks plus the whole current chunk.

    Args:
        size (int): number of frames; the mask is ``size x size``.
        chunk_size (int): number of frames per chunk.
        num_left_chunks (int): accepted for API compatibility but IGNORED by
            this implementation; every left chunk is always visible.
        device (torch.device): device on which the mask is allocated.

    Returns:
        torch.Tensor: bool mask of shape (size, size).

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE: loop-free so the mask stays ONNX-exportable, which is why
    # num_left_chunks is not honoured. Per the upstream note, this should
    # become unnecessary once an inference cache is implemented.
    frame_idx = torch.arange(size, device=device)
    # Exclusive end index of the chunk that each row (query frame) belongs to.
    chunk_end = (torch.div(frame_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    # Column j is visible from row i iff j comes before the end of row i's chunk.
    return frame_idx[None, :] < chunk_end[:, None]
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): bool mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
            (NOTE: subsequent_chunk_mask currently ignores this value.)
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: bool chunk mask of the input xs, (B, L, L) when a chunk
        mask is applied, otherwise ``masks`` (or a patched copy of it).
        The caller's ``masks`` tensor is never modified in place.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == torch.bool
    # Rows with no visible position would give an all -inf attention row
    # downstream; force them open and rely on later masking.
    empty_rows = chunk_masks.sum(dim=-1) == 0
    if empty_rows.any():
        logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
        if chunk_masks is masks:
            # Pass-through branch: chunk_masks aliases the caller's tensor, so copy
            # before patching to avoid silently rewriting the caller's padding mask.
            chunk_masks = chunk_masks.clone()
        chunk_masks[empty_rows] = True
    return chunk_masks
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a boolean mask that is True at padded positions.

    Entry ``[b, t]`` is True exactly when ``t >= lengths[b]``; this is the
    complement of a "non-pad" mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,); cast to int64.
        max_len (int): width of the mask; when not positive, ``lengths.max()``
            is used.

    Returns:
        torch.Tensor: bool tensor of shape (B, max_len).

    Examples:
        >>> make_pad_mask(torch.tensor([5, 3, 2]))
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    lengths = lengths.long()
    if max_len <= 0:
        max_len = lengths.max().item()
    time_steps = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    # Broadcast the time index across the batch and compare it with each length.
    grid = time_steps.unsqueeze(0).expand(lengths.size(0), max_len)
    return grid >= lengths.unsqueeze(-1)
================================================
FILE: src/chatterbox/models/s3gen/utils/mel.py
================================================
"""mel-spectrogram extraction in Matcha-TTS"""
import logging
from librosa.filters import mel as librosa_mel_fn
import torch
import numpy as np
logger = logging.getLogger(__name__)
# NOTE: they declared these global vars (module-level caches, as in the upstream Matcha-TTS code).
# `mel_basis` holds mel filterbank matrices and `hann_window` holds STFT windows; both are filled
# lazily by `mel_spectrogram` and reused on later calls.
mel_basis = {}
hann_window = {}
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(C * max(x, clip_val))`` element-wise.

    Flooring at ``clip_val`` keeps the logarithm finite for zero magnitudes.
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()
def spectral_normalize_torch(magnitudes):
    """Log-compress spectrogram magnitudes with the default clip value and gain."""
    return dynamic_range_compression_torch(magnitudes)
"""
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1920
num_mels: 80
sampling_rate: 24000
hop_size: 480
win_size: 1920
fmin: 0
fmax: 8000
center: False
"""
def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920,
                    fmin=0, fmax=8000, center=False):
    """Compute a log-mel spectrogram.

    Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py
    Set default values according to Cosyvoice's config.

    Args:
        y: waveform as a numpy array or torch tensor, 1-D (T,) or batched (B, T).
        n_fft, num_mels, sampling_rate, fmin, fmax: mel filterbank settings
            (passed to ``librosa.filters.mel``).
        hop_size, win_size, center: STFT settings (passed to ``torch.stft``).

    Returns:
        torch.Tensor: log-mel spectrogram of shape (B, num_mels, frames).
    """
    if isinstance(y, np.ndarray):
        y = torch.tensor(y).float()
    if len(y.shape) == 1:
        y = y[None, ]

    # Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
    min_val = torch.min(y)
    max_val = torch.max(y)
    if min_val < -1.0 or max_val > 1.0:
        logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")

    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
    device_key = str(y.device)
    # Key each cache on every argument the cached tensor depends on. Keying the
    # filterbank on (fmax, device) alone, and the window on device alone, would
    # silently reuse tensors built for a different n_fft / num_mels /
    # sampling_rate / fmin / win_size.
    basis_key = f"{fmax}_{device_key}_{sampling_rate}_{n_fft}_{num_mels}_{fmin}"
    window_key = f"{win_size}_{device_key}"
    if basis_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).float().to(y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    # Reflect-pad by (n_fft - hop_size) / 2 on each side, as in the upstream
    # Matcha-TTS implementation, and run the STFT with center=False by default.
    side_pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (side_pad, side_pad), mode="reflect")
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[window_key],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude; the small epsilon avoids sqrt(0).
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
================================================
FILE: src/chatterbox/models/s3gen/xvector.py
================================================
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torchaudio.compliance.kaldi as Kaldi
def pad_list(xs, pad_value):
    """Stack variable-length tensors into one tensor, right-padding along dim 0.

    Args:
        xs (List): tensors [(T_1, `*`), ..., (T_B, `*`)] sharing trailing dims.
        pad_value (float): fill value for padded positions.

    Returns:
        Tensor: (B, Tmax, `*`) with the dtype and device of ``xs[0]``.

    Examples:
        >>> pad_list([torch.ones(4), torch.ones(2), torch.ones(1)], 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    longest = max(t.size(0) for t in xs)
    template = xs[0]
    out = template.new(len(xs), longest, *template.size()[1:]).fill_(pad_value)
    for row, t in enumerate(xs):
        out[row, : t.size(0)] = t
    return out
def extract_feature(audio):
    """Compute per-utterance 80-bin Kaldi fbank features and pad them into a batch.

    Each feature matrix is mean-normalised over its frames before padding.

    Returns:
        (features_padded, feature_lengths, feature_times): the zero-padded feature
        batch, the frame count of each utterance, and each input's sample count
        (``shape[0]``).
    """
    features, feature_times, feature_lengths = [], [], []
    for clip in audio:
        fbank = Kaldi.fbank(clip.unsqueeze(0), num_mel_bins=80)
        fbank = fbank - fbank.mean(dim=0, keepdim=True)
        features.append(fbank)
        feature_times.append(clip.shape[0])
        feature_lengths.append(fbank.shape[0])
    # padding for batch inference
    return pad_list(features, pad_value=0), feature_lengths, feature_times
class BasicResBlock(torch.nn.Module):
    """Two 3x3 Conv2d + BatchNorm layers with a residual connection.

    The stride applies only to the first spatial axis (``stride=(stride, 1)``).
    When the stride or channel count changes, the residual path is a 1x1
    Conv2d + BatchNorm projection; otherwise it is an empty (identity) Sequential.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
        )
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=(stride, 1), bias=False),
                torch.nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = torch.nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + residual)
class FCM(torch.nn.Module):
    """2-D convolutional front-end: conv stem, two residual stages and a final conv.

    The input (B, F, T) is treated as a one-channel image (B, 1, F, T). The two
    residual stages and the final conv each stride by 2 along the frequency axis
    only. Channels and frequency are then flattened into one axis, giving
    (B, out_channels, T).

    Args:
        block: residual block class; must expose an ``expansion`` attribute.
        num_blocks: number of blocks in ``layer1`` and ``layer2`` respectively.
        m_channels: channel width used throughout the module.
        feat_dim: number of input frequency bins.
    """

    def __init__(self, block=BasicResBlock, num_blocks=(2, 2), m_channels=32, feat_dim=80):
        super().__init__()
        self.in_planes = m_channels
        self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(m_channels)

        # Each residual stage takes its own entry of num_blocks.
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        self.conv2 = torch.nn.Conv2d(
            m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
        )
        self.bn2 = torch.nn.BatchNorm2d(m_channels)

        # Frequency bins left after the three stride-2 (kernel 3, padding 1) stages.
        # Each stage maps h -> (h - 1) // 2 + 1 (= ceil(h / 2)). For a multiple of 8
        # this equals feat_dim // 8, and it remains correct for other sizes.
        freq_bins = feat_dim
        for _ in range(3):
            freq_bins = (freq_bins - 1) // 2 + 1
        self.out_channels = m_channels * freq_bins

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)  # (B, F, T) -> (B, 1, F, T)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        # Merge channel and frequency axes: (B, C, F', T) -> (B, C * F', T).
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
def get_nonlinear(config_str, channels):
    """Build a Sequential of activation/normalisation modules from a dash-separated spec.

    Supported tokens are ``relu``, ``prelu``, ``batchnorm`` and ``batchnorm_``.
    ``batchnorm_`` is BatchNorm1d without affine parameters and is registered
    under the name ``batchnorm``. Modules are added in spec order.

    Raises:
        ValueError: for an unknown token.
    """
    factories = {
        "relu": ("relu", lambda: torch.nn.ReLU(inplace=True)),
        "prelu": ("prelu", lambda: torch.nn.PReLU(channels)),
        "batchnorm": ("batchnorm", lambda: torch.nn.BatchNorm1d(channels)),
        "batchnorm_": ("batchnorm", lambda: torch.nn.BatchNorm1d(channels, affine=False)),
    }
    nonlinear = torch.nn.Sequential()
    for token in config_str.split("-"):
        if token not in factories:
            raise ValueError(f"Unexpected module ({token}).")
        module_name, make = factories[token]
        nonlinear.add_module(module_name, make())
    return nonlinear
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate the mean and standard deviation of ``x`` reduced over ``dim``.

    The two statistics are joined along the last axis. If ``keepdim`` is true,
    the result is unsqueezed at ``dim``. ``eps`` is accepted for interface
    compatibility but is not used.
    """
    moments = (x.mean(dim=dim), x.std(dim=dim, unbiased=unbiased))
    pooled = torch.cat(moments, dim=-1)
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
class StatsPool(torch.nn.Module):
    """Module wrapper around `statistics_pooling` with its default arguments."""

    def forward(self, x):
        pooled = statistics_pooling(x)
        return pooled
class TDNNLayer(torch.nn.Module):
    """1-D convolution followed by a configurable nonlinearity (see `get_nonlinear`).

    A negative ``padding`` requests symmetric padding computed as
    ``(kernel_size - 1) // 2 * dilation``, which requires an odd kernel size.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=False,
        config_str="batchnorm-relu",
    ):
        super().__init__()
        if padding < 0:
            assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format(kernel_size)
            padding = dilation * ((kernel_size - 1) // 2)
        self.linear = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, bias=bias,
        )
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
class CAMLayer(torch.nn.Module):
    """Context-aware masking: a local Conv1d whose output is gated by a sigmoid mask.

    The gate is predicted from a context signal. That signal is the input's
    global time average plus segment-level pooling over ``seg_len``-frame
    windows, broadcast back to every frame.
    """

    def __init__(
        self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2
    ):
        super().__init__()
        self.linear_local = torch.nn.Conv1d(
            bn_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, bias=bias,
        )
        hidden = bn_channels // reduction
        self.linear1 = torch.nn.Conv1d(bn_channels, hidden, 1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.linear2 = torch.nn.Conv1d(hidden, out_channels, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        local = self.linear_local(x)
        ctx = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        gate = self.sigmoid(self.linear2(self.relu(self.linear1(ctx))))
        return local * gate

    def seg_pooling(self, x, seg_len=100, stype="avg"):
        """Pool over non-overlapping ``seg_len`` windows (last partial window
        included) and repeat each pooled value back over its frames.

        Returns a tensor with the same length as ``x`` along the last axis.

        Raises:
            ValueError: if ``stype`` is neither "avg" nor "max".
        """
        if stype == "avg":
            pool = F.avg_pool1d
        elif stype == "max":
            pool = F.max_pool1d
        else:
            raise ValueError("Wrong segment pooling type.")
        pooled = pool(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        # Repeat every segment value seg_len times, then trim to the input length.
        return pooled.repeat_interleave(seg_len, dim=-1)[..., : x.shape[-1]]
class CAMDenseTDNNLayer(torch.nn.Module):
    """Nonlinearity -> 1x1 bottleneck Conv1d -> nonlinearity -> CAMLayer.

    With ``memory_efficient=True``, the bottleneck (``bn_function``) is
    gradient-checkpointed while the module is in training mode.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bn_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=False,
        config_str="batchnorm-relu",
        memory_efficient=False,
    ):
        super().__init__()
        assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format(
            kernel_size
        )
        same_padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(
            bn_channels, out_channels, kernel_size,
            stride=stride, padding=same_padding, dilation=dilation, bias=bias,
        )

    def bn_function(self, x):
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        use_checkpoint = self.training and self.memory_efficient
        h = cp.checkpoint(self.bn_function, x) if use_checkpoint else self.bn_function(x)
        return self.cam_layer(self.nonlinear2(h))
class CAMDenseTDNNBlock(torch.nn.ModuleList):
    """Densely connected stack of CAMDenseTDNNLayer modules.

    Layer ``k`` (0-based) receives the input concatenated with the outputs of
    all earlier layers, so its input width is ``in_channels + k * out_channels``.
    The block output has ``in_channels + num_layers * out_channels`` channels.
    Sub-modules are registered as ``tdnnd1``, ``tdnnd2``, ...
    """

    def __init__(
        self,
        num_layers,
        in_channels,
        out_channels,
        bn_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=False,
        config_str="batchnorm-relu",
        memory_efficient=False,
    ):
        super().__init__()
        for idx in range(num_layers):
            self.add_module(
                f"tdnnd{idx + 1}",
                CAMDenseTDNNLayer(
                    in_channels=in_channels + idx * out_channels,
                    out_channels=out_channels,
                    bn_channels=bn_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    bias=bias,
                    config_str=config_str,
                    memory_efficient=memory_efficient,
                ),
            )

    def forward(self, x):
        for layer in self:
            x = torch.cat((x, layer(x)), dim=1)
        return x
class TransitLayer(torch.nn.Module):
    """Transition layer: nonlinearity on the input, then a 1x1 Conv1d projection."""

    def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"):
        super().__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        return self.linear(self.nonlinear(x))
class DenseLayer(torch.nn.Module):
    """1x1 Conv1d projection followed by a nonlinearity.

    A 2-D (B, C) input is temporarily given a length-1 time axis for the
    convolution.
    """

    def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"):
        super().__init__()
        self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        is_flat = x.dim() == 2
        h = self.linear(x.unsqueeze(-1)).squeeze(-1) if is_flat else self.linear(x)
        return self.nonlinear(h)
# @tables.register("model_classes", "CAMPPlus")
class CAMPPlus(torch.nn.Module):
    """CAM++ embedding network (x-vector style, per this module's name and `self.xvector`).

    Pipeline: FCM 2-D conv front-end (`self.head`), then inside `self.xvector`:
    a TDNN layer; three CAMDenseTDNNBlocks, each followed by a TransitLayer that
    halves the channel count; an output nonlinearity; and, at segment level only,
    statistics pooling plus a DenseLayer projecting to `embedding_size`.

    Args:
        feat_dim: number of input feature bins per frame (passed to FCM).
        embedding_size: output size when `output_level == "segment"`.
        growth_rate: channels added by each dense layer.
        bn_size: bottleneck multiplier (`bn_channels = bn_size * growth_rate`).
        init_channels: output channels of the first TDNN layer.
        config_str: nonlinearity spec passed to `get_nonlinear`.
        memory_efficient: gradient-checkpoint the dense-layer bottlenecks in training mode.
        output_level: "segment" for a pooled embedding, "frame" for per-frame features.
        **kwargs: accepted and ignored.
    """
    def __init__(
        self,
        feat_dim=80,
        embedding_size=192,
        growth_rate=32,
        bn_size=4,
        init_channels=128,
        config_str="batchnorm-relu",
        memory_efficient=True,
        output_level="segment",
        **kwargs,
    ):
        super().__init__()
        # 2-D conv front-end; out_channels is its flattened (channels x frequency) width.
        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level
        # NOTE(review): sub-module names ("tdnn", "block1", "transit1", ...) become state_dict
        # keys and presumably must match pretrained checkpoints; keep names and order stable.
        self.xvector = torch.nn.Sequential(
            OrderedDict(
                [
                    (
                        "tdnn",
                        TDNNLayer(
                            channels,
                            init_channels,
                            5,
                            stride=2,
                            dilation=1,
                            padding=-1,
                            config_str=config_str,
                        ),
                    ),
                ]
            )
        )
        channels = init_channels
        # Three dense blocks with (num_layers, kernel_size, dilation) = (12, 3, 1), (24, 3, 2), (16, 3, 2).
        for i, (num_layers, kernel_size, dilation) in enumerate(
            zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
        ):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient,
            )
            self.xvector.add_module("block%d" % (i + 1), block)
            # Each dense layer appends growth_rate channels; the transit layer then halves the width.
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                "transit%d" % (i + 1),
                TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
            )
            channels //= 2
        self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
        if self.output_level == "segment":
            # Mean + std pooling doubles the channel count ahead of the projection.
            self.xvector.add_module("stats", StatsPool())
            self.xvector.add_module(
                "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
            )
        else:
            assert (
                self.output_level == "frame"
            ), "`output_level` should be set to 'segment' or 'frame'. "
        # Kaiming-normal init for Conv1d/Linear weights (FCM's Conv2d layers are not matched) and zero biases.
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
    def forward(self, x):
        """Embed features `x` of shape (B, T, F).

        Returns (B, embedding_size) when output_level == "segment", otherwise the
        frame-level features transposed to (B, T', C).
        """
        x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == "frame":
            x = x.transpose(1, 2)
        return x
    def inference(self, audio_list):
        """Compute embeddings from a list of 1-D waveforms.

        `extract_feature` converts each waveform to mean-normalised 80-bin Kaldi fbank
        features and zero-pads them to a common length. The batch is then passed to
        `forward` as float32.
        NOTE(review): `speech_lengths`/`speech_times` are unused, so for batches of
        unequal-length audio the zero padding takes part in the statistics pooling.
        """
        speech, speech_lengths, speech_times = extract_feature(audio_list)
        results = self.forward(speech.to(torch.float32))
        return results
================================================
FILE: src/chatterbox/models/s3tokenizer/__init__.py
================================================
from .s3tokenizer import (
S3_SR,
S3_HOP,
S3_TOKEN_HOP,
S3_TOKEN_RATE,
SPEECH_VOCAB_SIZE,
S3Tokenizer,
)
# Start-/end-of-speech token ids, allocated just above the speech vocabulary
# (removed again by `drop_invalid_tokens`).
SOS = SPEECH_VOCAB_SIZE
EOS = SPEECH_VOCAB_SIZE + 1
def drop_invalid_tokens(x):
    """Drop SoS and EoS.

    Returns the tokens strictly after the first SOS (or from the start if there
    is none) and before the first EOS (or to the end if there is none). Only a
    single sequence is supported: ``x`` must be 1-D ``(T,)`` or 2-D ``(1, T)``.
    The slice is taken along the last axis, so the input rank is preserved.
    """
    assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now"
    # Search token positions on a flattened view. On a (1, T) tensor,
    # nonzero(as_tuple=True)[0] would return row indices (always 0), not time steps.
    flat = x.reshape(-1)
    sos_pos = (flat == SOS).nonzero(as_tuple=True)[0]
    eos_pos = (flat == EOS).nonzero(as_tuple=True)[0]
    # Use the first occurrence of each marker. Slicing with a multi-element index
    # tensor (e.g. when the sequence contains several EOS tokens) would raise.
    start = sos_pos[0].item() + 1 if sos_pos.numel() > 0 else 0
    end = eos_pos[0].item() if eos_pos.numel() > 0 else None
    return x[..., start:end]
================================================
FILE: src/chatterbox/models/s3tokenizer/s3tokenizer.py
================================================
from typing import List, Tuple
import numpy as np
import librosa
import torch
import torch.nn.functional as F
from s3tokenizer.utils import padding
from s3tokenizer.model_v2 import (
S3TokenizerV2,
ModelConfig,
)
# Sampling rate of the inputs to S3TokenizerV2
S3_SR = 16_000
S3_HOP = 160  # 100 frames/sec (STFT hop in samples at S3_SR)
S3_TOKEN_HOP = 640  # 25 tokens/sec (samples per speech token at S3_SR)
S3_TOKEN_RATE = 25  # speech tokens per second
SPEECH_VOCAB_SIZE = 6561  # size of the speech-token vocabulary; SOS/EOS ids are placed right above it
class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:
    - a more integrated `forward`
    - compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers`
    """

    # Names of the buffers created in __init__ (not part of upstream checkpoints).
    # Presumably consulted by the loading code to tolerate their absence; confirm at the call site.
    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: ModelConfig = None,
    ):
        """Build the tokenizer and its mel front-end buffers.

        Args:
            name: model name forwarded to ``S3TokenizerV2``.
            config: only ``config.n_mels`` is read here (mel filterbank size).
                ``None`` (the default) means a fresh ``ModelConfig()``.
        """
        super().__init__(name)
        # Create the default per instance instead of sharing one import-time object.
        if config is None:
            config = ModelConfig()

        self.n_fft = 400
        _mel_filters = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels
        )
        self.register_buffer(
            "_mel_filters",
            torch.FloatTensor(_mel_filters),
        )

        self.register_buffer(
            "window",
            torch.hann_window(self.n_fft),
        )

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Given a list of wavs with the same `sample_rate`, pad them so that the length is multiple of 40ms (S3 runs at 25 token/sec).

        NumPy inputs are converted to tensors and 1-D inputs become (1, T). Each wav is
        zero-padded on the right.
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            # ceil(T * rate / sr) with an exact integer numerator. The form
            # `(T / sr) * rate` can round just above an integer (e.g. T=4480 at
            # 16 kHz gives 7.000000000000001) and add a spurious token of padding.
            n_tokens = np.ceil(wav.shape[1] * S3_TOKEN_RATE / sr)
            intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
            intended_wav_len = int(intended_wav_len)
            wav = torch.nn.functional.pad(
                wav,
                (0, intended_wav_len - wav.shape[-1]),
                mode="constant",
                value=0
            )
            processed_wavs.append(wav)
        return processed_wavs

    def _prepare_audio(self, wavs):
        """Prepare a list of audios for s3tokenizer processing."""
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            processed_wavs.append(wav)
        return processed_wavs

    @torch.no_grad()
    def forward(
        self,
        wavs: torch.Tensor,
        accelerator: 'Accelerator'=None,
        max_len: int=None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
        FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected.

        Args
        ----
        - `wavs`: 16 kHz speech audio
        - `accelerator`: optional object with `unwrap_model`; when given, quantization runs on the unwrapped model.
        - `max_len` max length to truncate the output sequence to (25 token/sec).
        NOTE: please pad the waveform if longer sequence is needed.

        Returns
        -------
        (speech_tokens, speech_token_lens), both as int64 tensors.
        """
        processed_wavs = self._prepare_audio(wavs)
        mels = []
        for wav in processed_wavs:
            wav = wav.to(self.device)
            mel = self.log_mel_spectrogram(wav)  # [B=1, F, T]
            if max_len is not None:
                mel = mel[..., :max_len * 4]  # num_mel_frames = 4 * num_tokens
            mels.append(mel.squeeze(0))

        mels, mel_lens = padding(mels)
        if accelerator is None:
            tokenizer = self
        else:
            tokenizer = accelerator.unwrap_model(self)

        speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
        return (
            speech_tokens.long().detach(),
            speech_token_lens.long().detach(),
        )

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        padding: int = 0,
    ):
        """
        Compute the log-Mel spectrogram of 16 kHz audio.

        Parameters
        ----------
        audio: torch.Tensor or np.ndarray
            Audio waveform sampled at 16 kHz; NumPy input is converted with `torch.from_numpy`.

        padding: int
            Number of zero samples to pad to the right

        Returns
        -------
        torch.Tensor, shape = (..., n_mels, n_frames)
            Log10-mel spectrogram, floored at 8 below its maximum and rescaled as (x + 4) / 4.
            The last STFT frame is dropped.
        """
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)

        audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        stft = torch.stft(
            audio, self.n_fft, S3_HOP,
            window=self.window.to(self.device),
            return_complex=True
        )
        magnitudes = stft[..., :-1].abs()**2

        mel_spec = self._mel_filters.to(self.device) @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
================================================
FILE: src/chatterbox/models/t3/__init__.py
================================================
from .t3 import T3
================================================
FILE: src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
================================================
# Copyright (c) 2025 Resemble AI
# Author: John Meade, Jeremy Hsu
# MIT License
import logging
import torch
from dataclasses import dataclass
from types import MethodType
logger = logging.getLogger(__name__)
# (layer_idx, head_idx) pairs whose attention maps AlignmentStreamAnalyzer hooks and averages
# to estimate the text-speech alignment.
LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
@dataclass
class AlignmentAnalysisResult:
    """Per-frame diagnostics about the text-speech alignment during streaming generation.

    NOTE(review): `AlignmentStreamAnalyzer.step` in this file does not construct this
    result; it only returns (possibly modified) logits.
    """
    # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
    false_start: bool
    # was this frame detected as being part of a long tail with potential hallucinations?
    long_tail: bool
    # was this frame detected as repeating existing text content?
    repetition: bool
    # was the alignment position of this frame too far from the previous frame?
    discontinuity: bool
    # has inference reached the end of the text tokens? eg, this remains false if inference stops early
    complete: bool
    # approximate position in the text token sequence. Can be used for generating online timestamps.
    position: int
class AlignmentStreamAnalyzer:
    def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
        """
        Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
        activation maps. This module exploits this to perform online integrity checks while streaming.
        Forward hooks are injected into the attention layers listed in LLAMA_ALIGNED_HEADS, and heuristics are used
        to determine alignment position, repetition, etc.

        NOTE: currently requires no queues.

        Args:
            tfmr: transformer whose `layers[layer_idx].self_attn` modules are hooked.
            queue: unused (kept for interface compatibility).
            text_tokens_slice: (i, j) column range of the text tokens in the attention maps.
            alignment_layer_idx: unused; the hooked layers come from LLAMA_ALIGNED_HEADS.
            eos_idx: logit index of the EOS token.
        """
        # self.queue = queue
        self.text_tokens_slice = (i, j) = text_tokens_slice
        self.eos_idx = eos_idx
        self.alignment = torch.zeros(0, j-i)
        # self.alignment_bin = torch.zeros(0, j-i)
        self.curr_frame_pos = 0
        self.text_position = 0

        self.started = False
        self.started_at = None

        self.complete = False
        self.completed_at = None

        # Track generated tokens for repetition detection
        self.generated_tokens = []

        # Handles returned by register_forward_hook, kept so the hooks can be removed (handle.remove()).
        self.hook_handles = []

        # Using `output_attentions=True` is incompatible with optimized attention kernels, so
        # using it for all layers slows things down too much. We can apply it to just one layer
        # by intercepting the kwargs and adding a forward hook (credit: jrm)
        # NOTE: `_add_attention_spy` still sets `tfmr.config.output_attentions = True` (and switches
        # sdpa to eager) when the config exposes these fields.
        self.last_aligned_attns = []
        for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
            self.last_aligned_attns += [None]
            self._add_attention_spy(tfmr, i, layer_idx, head_idx)

    def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
        """
        Adds a forward hook to a specific attention layer to collect outputs.
        The hook stores head `head_idx` of the batch-0 attention weights in `self.last_aligned_attns[buffer_idx]`.
        """
        def attention_forward_hook(module, input, output):
            """
            See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
            NOTE:
            - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
            - The attention weights (`output[1]`) have shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
            """
            if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
                step_attention = output[1].cpu()  # (B, n_heads, T0, Ti)
                self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx]  # (T0, Ti)

        target_layer = tfmr.layers[layer_idx].self_attn
        # Register hook and store the handle
        self.hook_handles.append(target_layer.register_forward_hook(attention_forward_hook))
        if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
            # Record the pre-existing settings only on the first call. This method runs once per
            # aligned head, and later calls would otherwise record the values overwritten below.
            if not hasattr(self, 'original_output_attentions'):
                self.original_output_attentions = tfmr.config.output_attentions
                self.original_attn_implementation = getattr(tfmr.config, '_attn_implementation', None)
            if getattr(tfmr.config, '_attn_implementation', None) == 'sdpa':
                tfmr.config._attn_implementation = 'eager'
            tfmr.config.output_attentions = True

    def step(self, logits, next_token=None):
        """
        Update the alignment state with the attention captured for the latest frame(s) and return `logits`.

        The logits may be modified:
        - EOS is suppressed (in place) while the alignment is more than 3 tokens from the end of a text longer
          than 5 tokens.
        - All logits are replaced to force EOS when a long tail, alignment repetition or token repetition
          is detected.

        Args:
            logits: next-token logits.
            next_token: optional most recent token id (tensor or int), used for token-repetition detection.
        """
        # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
        aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0)  # (N, N)
        i, j = self.text_tokens_slice
        if self.curr_frame_pos == 0:
            # first chunk has conditioning info, text tokens, and BOS token
            A_chunk = aligned_attn[j:, i:j].clone().cpu()  # (T, S)
        else:
            # subsequent chunks have 1 frame due to KV-caching
            A_chunk = aligned_attn[:, i:j].clone().cpu()  # (1, S)

        # TODO: monotonic masking; could have issue b/c spaces are often skipped.
        A_chunk[:, self.curr_frame_pos + 1:] = 0

        self.alignment = torch.cat((self.alignment, A_chunk), dim=0)

        A = self.alignment
        T, S = A.shape

        # update position
        cur_text_posn = A_chunk[-1].argmax()
        discontinuity = not(-4 < cur_text_posn - self.text_position < 7)  # NOTE: very lenient!
        if not discontinuity:
            self.text_position = cur_text_posn

        # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
        # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
        # and there are some strong activations in the first few tokens.
        false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
        self.started = not false_start
        if self.started and self.started_at is None:
            self.started_at = T

        # Is generation likely complete?
        self.complete = self.complete or self.text_position >= S - 3
        if self.complete and self.completed_at is None:
            self.completed_at = T

        # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
        # Activations for the final token that last too long are likely hallucinations.
        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5)  # 200ms

        # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
        # The `S > 5` guard matters because for S <= 5, `A[..., :-5]` has no columns and `.max(dim=1)` would raise.
        alignment_repetition = self.complete and S > 5 and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)

        # Track generated tokens for repetition detection
        if next_token is not None:
            # Convert tensor to scalar if needed
            if isinstance(next_token, torch.Tensor):
                token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
            else:
                token_id = next_token
            self.generated_tokens.append(token_id)

            # Keep only last 8 tokens to prevent memory issues
            if len(self.generated_tokens) > 8:
                self.generated_tokens = self.generated_tokens[-8:]

        # Check for token repetition: the last 2 tokens are identical (once at least 3 tokens have been seen)
        token_repetition = (
            # self.complete and
            len(self.generated_tokens) >= 3 and
            len(set(self.generated_tokens[-2:])) == 1
        )

        if token_repetition:
            repeated_token = self.generated_tokens[-1]
            logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")

        # Suppress EoS to prevent early termination
        if cur_text_posn < S - 3 and S > 5:  # Only suppress if text is longer than 5 tokens
            logits[..., self.eos_idx] = -2**15

        # If a bad ending is detected, force emit EOS by modifying logits
        # NOTE: this means logits may be inconsistent with latents!
        if long_tail or alignment_repetition or token_repetition:
            logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
            # (±2**15 is safe for all dtypes >= 16bit)
            logits = -(2**15) * torch.ones_like(logits)
            logits[..., self.eos_idx] = 2**15

        self.curr_frame_pos += 1
        return logits
================================================
FILE: src/chatterbox/models/t3/inference/t3_hf_backend.py
================================================
from typing import Optional
import torch
from torch import nn as nn
from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
    """
    Override some HuggingFace interface methods so we can use the standard `generate` method with our
    custom embedding / logit layers.

    NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights!
    """
    def __init__(
        self,
        config: LlamaConfig,
        llama: LlamaModel,
        *,
        speech_enc,
        speech_head,
        latents_queue=None,
        logits_queue=None,
        alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
    ):
        # NOTE: `latents_queue` and `logits_queue` are accepted but neither stored nor used.
        super().__init__(config)
        self.model = llama
        self.speech_enc = speech_enc
        self.speech_head = speech_head
        # Set to True once the conditioning prefix has been prepended in prepare_inputs_for_generation.
        # NOTE(review): it is never reset in this class, so an instance supports a single generation run.
        self._added_cond = False
        # Stored only; the call to it in `forward` is commented out.
        self.alignment_stream_analyzer = alignment_stream_analyzer
    @torch.inference_mode()
    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None,
        # This argument was introduced in some recent version of transformers (>=4.29.1)
        cache_position=None
    ):
        """
        This is a method used by huggingface's generate() method.
        Overridden here to apply our custom speech token embedding layer.

        :param input_ids: (B, S) int64 tensors of input tokens.
        :param decoder_cond: (B, T, C) float32 tensor of conditioning, prepended to the embedded
            input_ids on the first call only (expanded along the batch axis if its batch size differs).
        :return: dict with `inputs_embeds`, `past_key_values` and `use_cache` for `forward`.
        """
        # Make use of the kv cache: only the last input ID is new, we trim away all the ones before
        if not use_cache:
            past_key_values = None
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        # custom speech token embedding layer
        inputs_embeds = self.speech_enc(input_ids)
        # prefix decoder conditioning if applicable
        if not self._added_cond:
            # NOTE(review): on the first step the caller presumably passes an empty cache object rather than
            # None (`forward` treats an empty cache as "no cache"), so this assert only fails when use_cache
            # is False. Confirm against the transformers version in use.
            assert past_key_values is not None # should be first step
            if decoder_cond.size(0) != inputs_embeds.size(0):
                decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1)
            inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1)
            self._added_cond = True
        return {
            "inputs_embeds": inputs_embeds,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
    @torch.inference_mode()
    def forward(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values: Optional[torch.Tensor]=None,
        use_cache=True,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    ):
        """
        This is a method used by huggingface's generate() method.
        Overridden here to apply our speech logit projection layer: logits are `self.speech_head`
        applied to the last hidden state returned by the wrapped LlamaModel (no separate norm is
        applied in this method).

        :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given,
            S should be 1.
        :return: CausalLMOutputWithCrossAttentions with speech logits, updated past_key_values,
            all hidden states and (if requested) attentions.
        """
        # A multi-position input is only allowed when there is no (non-empty) cache yet.
        is_large_input = inputs_embeds.size(1) != 1
        has_cache = past_key_values is not None and len(past_key_values) > 0
        assert not (is_large_input and has_cache)
        assert return_dict
        assert output_hidden_states
        tfmr_out = self.model(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim)
        logits = self.speech_head(hidden_states)
        # assert inputs_embeds.size(0) == 1 # (disabled for CFG)
        # NOTE: hallucination handler may modify logits to force emit an EOS token
        # logits = self.alignment_stream_analyzer.step(logits)
        return CausalLMOutputWithCrossAttentions(
            logits=logits,
            past_key_values=tfmr_out.past_key_values,
            hidden_states=tfmr_out.hidden_states,
            attentions=tfmr_out.attentions,
        )
================================================
FILE: src/chatterbox/models/t3/llama_configs.py
================================================
# Llama backbone settings registered below as "Llama_520M".
LLAMA_520M_CONFIG_DICT = {
    # Deliberately tiny vocabulary: T3 uses its own text/speech embedding and head
    # layers, so the backbone's token embedding is never used.
    "vocab_size": 8,
    # Remaining entries mirror a standard Llama-3-style config so the pretrained
    # weights load without shape or config mismatches.
    "max_position_embeddings": 131072,
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "num_hidden_layers": 30,
    "num_attention_heads": 16,
    "attn_implementation": "sdpa",
    "head_dim": 64,
    "tie_word_embeddings": False,
    "hidden_act": "silu",
    "attention_bias": False,
    "attention_dropout": 0.0,
    "initializer_range": 0.02,
    "mlp_bias": False,
    "model_type": "llama",
    "num_key_value_heads": 16,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "torch_dtype": "bfloat16",
    "use_cache": True,
}

# GPT-2 medium style backbone settings registered below as "GPT2_medium".
# NOTE(review): n_ctx / n_positions are 8196 (not 8192); n_positions sizes the position
# table, so presumably it must match the checkpoint — do not "fix" without checking.
GPT2_MEDIUM_CONFIG = dict(
    activation_function="gelu_new",
    architectures=["GPT2LMHeadModel"],
    attn_pdrop=0.1,
    bos_token_id=50256,
    embd_pdrop=0.1,
    eos_token_id=50256,
    initializer_range=0.02,
    layer_norm_epsilon=1e-05,
    model_type="gpt2",
    n_ctx=8196,
    n_embd=1024,
    hidden_size=1024,
    n_head=16,
    n_layer=24,
    n_positions=8196,
    n_special=0,
    predict_special_tokens=True,
    resid_pdrop=0.1,
    summary_activation=None,
    summary_first_dropout=0.1,
    summary_proj_to_labels=True,
    summary_type="cls_index",
    summary_use_proj=True,
    task_specific_params={
        "text-generation": dict(do_sample=True, max_length=50),
    },
    vocab_size=50276,
)

# Lookup by `T3Config.llama_config_name`.
LLAMA_CONFIGS = dict(
    Llama_520M=LLAMA_520M_CONFIG_DICT,
    GPT2_medium=GPT2_MEDIUM_CONFIG,
)
================================================
FILE: src/chatterbox/models/t3/modules/cond_enc.py
================================================
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn, Tensor
from .perceiver import Perceiver
from .t3_config import T3Config
@dataclass
class T3Cond:
    """
    Dataclass container for most / all conditioning info.

    TODO: serialization methods aren't used, keeping them around for convenience
    """

    speaker_emb: Tensor
    clap_emb: Optional[Tensor] = None
    cond_prompt_speech_tokens: Optional[Tensor] = None
    cond_prompt_speech_emb: Optional[Tensor] = None
    # NOTE: the default is a plain float, not a tensor.
    emotion_adv: Optional[Tensor] = 0.5

    def to(self, *, device=None, dtype=None):
        """Move every tensor field to ``device``; cast only floating-point tensors to ``dtype``.

        Integer and bool tensors (e.g. token ids) keep their dtype, and non-tensor fields are
        left untouched. Mutates in place and returns ``self`` so calls can be chained.
        """
        for name, value in self.__dict__.items():
            if torch.is_tensor(value):
                # Decide from the dtype, not from the first element: this works for empty
                # tensors, avoids a device->host sync, and does not mistake bool for float.
                target_dtype = dtype if value.is_floating_point() else None
                setattr(self, name, value.to(device=device, dtype=target_dtype))
        return self

    def save(self, fpath):
        """Serialize the field dict with ``torch.save``."""
        torch.save(self.__dict__, fpath)

    @staticmethod
    def load(fpath, map_location="cpu"):
        """Inverse of ``save``: load the field dict (``weights_only=True``) and rebuild a T3Cond."""
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return T3Cond(**kwargs)
class T3CondEnc(nn.Module):
    """
    Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.

    ``forward`` returns the concatenation (along dim 1) of:
    the projected speaker embedding (1 position), CLAP (currently always empty),
    the prompt-speech embeddings (optionally resampled by a Perceiver), and the
    projected emotion-exaggeration scalar (1 position, when ``hp.emotion_adv``).
    """

    def __init__(self, hp: T3Config):
        super().__init__()
        self.hp = hp
        if hp.encoder_type == "voice_encoder":
            self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
        else:
            raise NotImplementedError(str(hp.encoder_type))

        # emotion adv
        self.emotion_adv_fc = None
        if hp.emotion_adv:
            self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)

        # perceiver resampler
        self.perceiver = None
        if hp.use_perceiver_resampler:
            self.perceiver = Perceiver()

    def forward(self, cond: T3Cond):
        """Build the conditioning prefix, shape (B, len_cond, n_channels)."""
        # Validate: prompt tokens and their embeddings must be provided together
        # (T3.prepare_conditioning fills the embeddings from the tokens).
        assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
            "cond_prompt_speech_tokens and cond_prompt_speech_emb must both be set or both be None"

        # Speaker embedding projection
        cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None]  # (B, 1, dim)
        empty = torch.zeros_like(cond_spkr[:, :0])  # (B, 0, dim)

        # TODO CLAP
        assert cond.clap_emb is None, "clap_embed not implemented"
        cond_clap = empty  # (B, 0, dim)

        # Cond prompt
        cond_prompt_speech_emb = cond.cond_prompt_speech_emb
        if cond_prompt_speech_emb is None:
            cond_prompt_speech_emb = empty  # (B, 0, dim)
        elif self.hp.use_perceiver_resampler:
            cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)

        # Emotion Adv: must provide a value if this model uses emotion conditioning
        cond_emotion_adv = empty  # (B, 0, dim)
        if self.hp.emotion_adv:
            assert cond.emotion_adv is not None
            # T3Cond.emotion_adv defaults to a plain float (0.5), which has no `.view`;
            # coerce to a tensor on the projection's device / dtype (no-op if already matching).
            fc_weight = self.emotion_adv_fc.weight
            emotion_adv = torch.as_tensor(cond.emotion_adv, dtype=fc_weight.dtype, device=fc_weight.device)
            cond_emotion_adv = self.emotion_adv_fc(emotion_adv.view(-1, 1, 1))

        # Concat and return
        cond_embeds = torch.cat((
            cond_spkr,
            cond_clap,
            cond_prompt_speech_emb,
            cond_emotion_adv,
        ), dim=1)
        return cond_embeds
================================================
FILE: src/chatterbox/models/t3/modules/learned_pos_emb.py
================================================
from typing import Union
import torch
from torch import nn, Tensor
class LearnedPositionEmbeddings(nn.Module):
    """Absolute position embeddings stored as a learned ``seq_len`` x ``model_dim`` table."""

    def __init__(self, seq_len, model_dim, init=.02):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # GPT-2 style initialisation: zero-mean normal with std `init`.
        self.emb.weight.data.normal_(mean=0.0, std=init)

    def forward(self, x):
        """
        Embeddings for positions 0 .. x.shape[1] - 1, shape (x.shape[1], dim).

        Only the size of dim 1 of ``x`` matters; its contents are ignored.
        """
        positions = torch.arange(x.shape[1], device=x.device)
        return self.emb(positions)

    def get_fixed_embedding(self, idx: 'Union[int, Tensor]'):
        """
        Args:
            idx: scalar int or an integer tensor of shape (T,) or (B, T)
        Returns:
            positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input
        """
        target_device = self.emb.weight.device
        if torch.is_tensor(idx):
            idx = idx.to(target_device)
        else:
            idx = torch.tensor(idx, device=target_device)
        # Promote scalars / 1D index tensors to (B, T).
        idx = torch.atleast_2d(idx)
        assert idx.ndim == 2
        return self.emb(idx)  # (B, T, dim)
================================================
FILE: src/chatterbox/models/t3/modules/perceiver.py
================================================
# Copyright (c) 2025 Resemble AI
# Author: Manmay Nakhashi
# MIT License
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
class RelativePositionBias(nn.Module):
    """
    T5-style learned relative-position bias added to attention logits.

    Each (query, key) offset is mapped to one of ``num_buckets`` buckets: exact buckets for
    small distances and log-spaced buckets up to ``max_distance``. Every bucket holds one
    learned scalar per head.
    """

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # (bucket, head) -> bias value
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """Map integer offsets (key_pos - query_pos) to bucket indices in [0, num_buckets)."""
        distance = -relative_position
        if causal:
            # Only backward offsets are distinguished; future offsets collapse to distance 0.
            direction_offset = 0
            distance = torch.max(distance, torch.zeros_like(distance))
        else:
            # Half of the buckets per direction; the upper half is used when distance < 0.
            num_buckets //= 2
            direction_offset = (distance < 0).long() * num_buckets
            distance = torch.abs(distance)

        exact_limit = num_buckets // 2
        # Log-spaced bucket for distances >= exact_limit (the value at distance 0 is
        # discarded by the torch.where below).
        log_ratio = torch.log(distance.float() / exact_limit) / math.log(max_distance / exact_limit)
        far_bucket = exact_limit + (log_ratio * (num_buckets - exact_limit)).long()
        far_bucket = torch.min(far_bucket, torch.full_like(far_bucket, num_buckets - 1))
        return direction_offset + torch.where(distance < exact_limit, distance, far_bucket)

    def forward(self, qk_dots):
        """Return ``qk_dots`` plus the scaled per-head bias over its last two (query, key) dims."""
        device = qk_dots.device
        n_query, n_key = qk_dots.shape[-2:]
        query_pos = torch.arange(n_query, dtype=torch.long, device=device)
        key_pos = torch.arange(n_key, dtype=torch.long, device=device)
        offsets = key_pos[None, :] - query_pos[:, None]
        buckets = self._relative_position_bucket(
            offsets,
            causal=self.causal,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        per_head = self.relative_attention_bias(buckets)  # (i, j, heads)
        bias = per_head.permute(2, 0, 1).unsqueeze(0)  # (1, heads, i, j)
        return qk_dots + (bias * self.scale)
class AttentionQKV(nn.Module):
    """
    Multi-head attention core that operates on already-projected Q, K and V.

    Inputs are (B, length, n_heads * head_dim). Heads are split into
    (B, n_heads, length, head_dim), attention is computed per head, and the heads are
    merged back into (B, length, n_heads * head_dim). Query and key/value lengths may
    differ (cross-attention).
    """

    def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = scale if scale is not None else head_dim ** -0.5
        self.flash = flash
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(dropout_rate)
        self.flash_config = self.setup_flash_config() if flash else None

    def setup_flash_config(self):
        """Backend flags for the fused SDPA path (all backends enabled)."""
        flash_config = {
            'enable_flash': True,
            'enable_math': True,
            'enable_mem_efficient': True
        }
        return flash_config

    def forward(self, q, k, v, mask=None):
        """Attend ``q`` over ``k``/``v``; see the class docstring for shapes."""
        q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
        if self.flash:
            out = self.flash_attention(q, k, v, mask=mask)
        else:
            out = self.scaled_dot_product_attention(q, k, v, mask=mask)

        return self.combine_heads(out)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """
        Reference (non-fused) attention.

        :param q: (B, H, Lq, D)
        :param k: (B, H, Lk, D)
        :param v: (B, H, Lk, D)
        :param mask: optional, broadcastable to (B, H, Lq, Lk); entries equal to 0 / False are
            excluded from attention.
        :return: (B, H, Lq, D)
        """
        # Score every query position against every key position by contracting the head
        # dimension. (Previously the sequence axis was contracted, which produced a D x D
        # matrix and failed whenever Lq != Lk.)
        sim = torch.einsum("bhqd,bhkd->bhqk", q, k) * self.scale
        if mask is not None:
            sim = sim.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(sim, dim=-1)
        attn = self.dropout(attn)
        return torch.einsum("bhqk,bhkd->bhqd", attn, v)

    def flash_attention(self, q, k, v, mask=None):
        """
        Fused attention via ``F.scaled_dot_product_attention``.

        NOTE(review): SDPA applies its default 1/sqrt(D) scaling here, so a custom
        ``self.scale`` only affects the non-fused path. A float ``mask`` is *added* to the
        scores by SDPA, whereas the non-fused path treats it as a keep/drop mask.
        """
        config = self.flash_config if self.flash_config else {}
        with torch.backends.cuda.sdp_kernel(**config):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout_rate if self.training else 0.
            )
        return out

    def split_heads(self, x):
        """(B, L, H * D) -> (B, H, L, D)."""
        bs, length, _ = x.shape
        x = x.view(bs, length, self.n_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        """(B, H, L, D) -> (B, L, H * D)."""
        bs, _, length, _ = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(bs, length, -1)
class AttentionBlock2(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other,
    using AttentionQKV and separate linear transformations for Q, K, and V.

    ``forward(x1, x2)`` lets ``x1`` (queries) attend to ``x2`` (keys / values). Both inputs
    go through the same LayerNorm, and the projected result is added residually to ``x1``.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        relative_pos_embeddings=False,
        flash_attention=True,
        dropout_rate=0.2,
        scale=None
    ):
        super().__init__()
        self.channels = channels

        # Head count comes either directly from `num_heads` or from the per-head width.
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.norm = nn.LayerNorm(channels)

        # Separate linear layers for Q, K, and V
        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)

        self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale)

        self.proj_out = nn.Linear(channels, channels)

        if relative_pos_embeddings:
            # Use the resolved head count so the bias table matches the attention heads
            # even when `num_head_channels` (not `num_heads`) determines it.
            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=self.num_heads, num_buckets=32, max_distance=64)
        else:
            self.relative_pos_embeddings = None

    def forward(self, x1, x2, mask=None):
        """
        :param x1: query input, (B, T1, channels)
        :param x2: key/value input, (B, T2, channels)
        :param mask: optional attention mask forwarded to AttentionQKV
        :return: ``x1`` plus the attention output, reshaped to ``x1``'s shape

        NOTE(review): ``self.relative_pos_embeddings`` is constructed but not applied here.
        """
        b1, c1, *spatial1 = x1.shape

        x1_norm = self.norm(x1)
        x2_norm = self.norm(x2)

        q = self.to_q(x1_norm)
        k = self.to_k(x2_norm)
        v = self.to_v(x2_norm)

        h = self.attention(q, k, v, mask=mask)
        h = self.proj_out(h)

        return (x1 + h).reshape(b1, c1, *spatial1)
class Perceiver(nn.Module):
    """
    Perceiver-style resampler (inspired by https://arxiv.org/abs/2103.03206).

    A fixed set of learned query tokens cross-attends to the input sequence and then
    self-attends. Both steps reuse the same AttentionBlock2 weights. The output has
    ``pre_attention_query_token`` positions regardless of the input length.
    """

    def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
        """
        :param pre_attention_query_token: number of learned query tokens (output length)
        :param pre_attention_query_size: width of each query token
        :param embedding_dim: channel width of the attention block
        :param num_attn_heads: number of attention heads
        """
        super().__init__()
        n_queries = pre_attention_query_token
        query_width = pre_attention_query_size

        self.pre_attention_query = nn.Parameter(torch.empty(1, n_queries, query_width))

        # Uniform bound sqrt(3) * sqrt(2 / (n + n)).
        # NOTE(review): Xavier-uniform would use fan_in + fan_out (n_queries + query_width);
        # this only affects training from scratch, since pretrained weights overwrite it.
        bound = math.sqrt(3.0) * math.sqrt(2.0 / (n_queries + n_queries))
        self.pre_attention_query.data.uniform_(-bound, bound)

        self.attn = AttentionBlock2(embedding_dim, num_attn_heads)

    def forward(self, h):
        """
        :param h: input sequence, (B, T, embedding_dim)
        :return: resampled sequence, (B, pre_attention_query_token, embedding_dim)
        """
        # One copy of the learned queries per batch element (expand: no copy).
        queries = self.pre_attention_query.expand(h.shape[0], -1, -1)
        # Cross-attention: queries read from the input sequence.
        resampled = self.attn(queries, h)
        # Self-attention among the resampled tokens (same block, same weights).
        return self.attn(resampled, resampled)
================================================
FILE: src/chatterbox/models/t3/modules/t3_config.py
================================================
from ..llama_configs import LLAMA_CONFIGS
class T3Config:
    """
    Hyper-parameters for the T3 model.

    The English-only and multilingual variants differ only in ``text_tokens_dict_size``
    (704 vs 2454); ``is_multilingual`` is derived from that size.
    """

    def __init__(self, text_tokens_dict_size=704):
        # --- text vocabulary ---
        self.start_text_token = 255
        self.stop_text_token = 0
        self.text_tokens_dict_size = text_tokens_dict_size
        self.max_text_tokens = 2048

        # --- speech vocabulary ---
        self.start_speech_token = 6561
        self.stop_speech_token = 6562
        self.speech_tokens_dict_size = 8194
        self.max_speech_tokens = 4096

        # --- backbone / positional encoding ---
        self.llama_config_name = "Llama_520M"
        self.input_pos_emb = "learned"
        self.speech_cond_prompt_len = 150

        # --- conditioning ---
        self.encoder_type = "voice_encoder"
        self.speaker_embed_size = 256
        self.use_perceiver_resampler = True
        self.emotion_adv = True

    @property
    def n_channels(self):
        """Hidden size of the configured backbone."""
        backbone = LLAMA_CONFIGS[self.llama_config_name]
        return backbone["hidden_size"]

    @property
    def is_multilingual(self):
        """True when the text vocabulary is the multilingual one (2454 entries)."""
        return self.text_tokens_dict_size == 2454

    @classmethod
    def english_only(cls):
        """Create configuration for English-only TTS model."""
        return cls(text_tokens_dict_size=704)

    @classmethod
    def multilingual(cls):
        """Create configuration for multilingual TTS model."""
        return cls(text_tokens_dict_size=2454)
================================================
FILE: src/chatterbox/models/t3/t3.py
================================================
# Copyright (c) 2025 Resemble AI
# MIT License
import logging
from typing import Union, Optional, List
logger = logging.getLogger(__name__)
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import LlamaModel, LlamaConfig, GPT2Config, GPT2Model
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
MinPLogitsWarper,
)
from .modules.learned_pos_emb import LearnedPositionEmbeddings
from .modules.cond_enc import T3CondEnc, T3Cond
from .modules.t3_config import T3Config
from .llama_configs import LLAMA_CONFIGS
from .inference.t3_hf_backend import T3HuggingfaceBackend
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
from ..utils import AttrDict
def _ensure_BOT_EOT(text_tokens: Tensor, hp):
    """
    Assert that ``text_tokens`` contains start- and stop-of-text tokens.

    ``text_tokens`` may be 1D (a single unbatched sequence) or 2D (B, T). Counts are taken
    over the whole batch: there must be at least B occurrences of ``hp.start_text_token``
    and at least B occurrences of ``hp.stop_text_token``.

    Raises:
        AssertionError: if either token is under-represented.
    """
    # Treat a 1D sequence as a batch of one; otherwise its length would be mistaken for B.
    text_tokens = torch.atleast_2d(text_tokens)
    B = text_tokens.size(0)
    assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
    assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
class T3(nn.Module):
    """
    Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
        * tokenization, including start / stop tokens are always added externally to this class
        * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
        * careful! this class assumes relative positional encoding -- with absolute PE, we would at
            least want to reset the position to 0 when speech tokens begin, and optionally use a
            different PE embedding space for speech.
    """

    def __init__(self, hp=None):
        if hp is None:
            hp = T3Config.english_only()
        super().__init__()
        self.hp = hp

        config_dict = LLAMA_CONFIGS[hp.llama_config_name]
        self.is_gpt = config_dict.get("model_type") == "gpt2"

        if self.is_gpt:
            self.cfg = GPT2Config(**config_dict)
            self.tfmr = GPT2Model(self.cfg)
        else:
            self.cfg = LlamaConfig(**config_dict)
            self.tfmr = LlamaModel(self.cfg)

        self.dim = self.cfg.hidden_size
        self.deepspeed_patch_applied = False

        # conditioning / embedding
        self.cond_enc = T3CondEnc(hp)
        self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
        self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)

        # custom position embedding
        self.text_pos_emb = None
        self.speech_pos_emb = None
        if hp.input_pos_emb == "learned":
            max_text_seq_len = hp.max_text_tokens + 2
            self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)

            max_mel_seq_len = hp.max_speech_tokens + 2 + 2
            self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)

        # logit projection
        self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
        self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt)
        self.compiled = False

    @property
    def device(self):
        return self.speech_head.weight.device

    def prepare_conditioning(self, t3_cond: T3Cond):
        """
        Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.

        Fills ``t3_cond.cond_prompt_speech_emb`` from ``cond_prompt_speech_tokens`` (mutating
        ``t3_cond``) when only the tokens are present, then runs the conditioning encoder.
        """
        if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
            t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens)
            if not self.is_gpt:
                t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
        return self.cond_enc(t3_cond)  # (B, len_cond, dim)

    def prepare_input_embeds(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,
        speech_tokens: torch.LongTensor,
        cfg_weight: float = 0.0,
    ):
        """
        Build ``[cond | text | speech]`` input embeddings (the backbone's own embeddings are skipped).

        With ``cfg_weight > 0`` (non-GPT backbones) and at least two rows, row 1's text
        embeddings are zeroed so that row 1 acts as the unconditional CFG branch.

        :return: (embeds of shape (B, len_cond + len_text + len_speech, dim), len_cond)
        """
        cond_emb = self.prepare_conditioning(t3_cond)  # (B, len_cond, dim)
        text_emb = self.text_emb(text_tokens)  # (B, len_text, dim)
        # Only zero row 1 when it exists: a single-row call simply runs without CFG.
        if cfg_weight > 0.0 and not self.is_gpt and text_emb.size(0) > 1:
            text_emb[1].zero_()  # CFG uncond

        speech_emb = self.speech_emb(speech_tokens)  # (B, len_speech, dim)
        if self.hp.input_pos_emb == "learned":
            text_emb = text_emb + self.text_pos_emb(text_tokens)
            speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
        len_cond = cond_emb.size(1)

        if cond_emb.size(0) != text_emb.size(0):
            cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)

        # concat
        embeds = torch.stack([
            torch.cat((ce, te, se))
            for ce, te, se in zip(cond_emb, text_emb, speech_emb)
        ])  # (B, length, dim)
        return embeds, len_cond

    def forward(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,
        text_token_lens: torch.LongTensor,
        speech_tokens: torch.LongTensor,
        speech_token_lens: torch.LongTensor,
        training=False,
    ):
        """
        Teacher-forced pass over ``[cond | text | speech]``.

        :return: AttrDict with text/speech logits and latents (zero beyond each row's true
            length) and the final-layer hidden states.
        """
        _ensure_BOT_EOT(text_tokens, self.hp)

        # prepare custom input embeds
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=speech_tokens,
        )

        # backbone tranformer forward
        tfmr_out = self.tfmr.forward(
            input_ids=None,
            # position_ids=position_ids, # TODO? ROPE should be fine?
            inputs_embeds=embeds,
            output_hidden_states=True,
            return_dict=True,
            use_cache=(not training),
        )
        hidden_states = tfmr_out.hidden_states[-1]  # final tfmr layer output, (B, seq, dim)

        # post-processing: splice out text and speech parts of hidden states
        len_text = text_tokens.size(1)
        len_speech = speech_tokens.size(1)
        B, _, dim = hidden_states.shape
        device, dtype = hidden_states.device, hidden_states.dtype
        text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
        speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
        # Speech always starts after the full (padded) text block.
        speech_start = len_cond + len_text
        ttl, stl = text_token_lens, speech_token_lens
        for i in range(B):
            text_end = len_cond + ttl[i].item()
            speech_end = speech_start + stl[i].item()
            text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
            speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]

        # logit projection
        text_logits = self.text_head(text_latents)
        speech_logits = self.speech_head(speech_latents)

        return AttrDict(
            text_logits=text_logits,
            text_latents=text_latents,
            speech_logits=speech_logits,
            speech_latents=speech_latents,
            hidden_states=hidden_states,
        )

    def loss(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,
        text_token_lens: torch.LongTensor,
        speech_tokens: torch.LongTensor,
        speech_token_lens: torch.LongTensor,
    ):
        "training method"
        len_text = text_tokens.size(1)
        len_speech = speech_tokens.size(1)
        assert len_text == text_token_lens.max()
        assert len_speech == speech_token_lens.max()

        out = self.forward(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            text_token_lens=text_token_lens,
            speech_tokens=speech_tokens,
            speech_token_lens=speech_token_lens,
            training=True,
        )  # (B, seq, vocab_size)

        # Calc CCE losses
        IGNORE_ID = -100
        device = out.text_logits.device
        mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]  # (B, len_text)
        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]  # (B, len_speech)
        masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
        masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
        # F.cross_entropy expects the class dimension at dim 1: (B, V, T) logits vs (B, T) targets.
        loss_text = F.cross_entropy(out.text_logits.transpose(1, 2), masked_text, ignore_index=IGNORE_ID)
        loss_speech = F.cross_entropy(out.speech_logits.transpose(1, 2), masked_speech, ignore_index=IGNORE_ID)

        return loss_text, loss_speech

    @torch.inference_mode()
    def inference(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: Tensor,
        initial_speech_tokens: Optional[Tensor]=None,

        # misc conditioning
        prepend_prompt_speech_tokens: Optional[Tensor]=None,

        # HF generate args
        num_return_sequences=1,
        max_new_tokens=None,
        stop_on_eos=True,
        do_sample=True,
        temperature=0.8,
        top_p=0.95,
        min_p=0.05,
        length_penalty=1.0,
        repetition_penalty=1.2,
        cfg_weight=0.5,
    ):
        """
        Autoregressively sample speech tokens with a KV cache.

        Args:
            text_tokens: a 1D (unbatched) or 2D (batched) tensor, already containing the
                start / stop text tokens. For classifier-free guidance pass two rows: row 0 is
                the conditional branch and row 1 becomes the unconditional branch (its text
                embeddings are zeroed when ``cfg_weight > 0``). With a single row no CFG is applied.
            max_new_tokens: sampling budget; falls back to ``hp.max_speech_tokens`` when falsy.
            stop_on_eos: stop as soon as ``hp.stop_speech_token`` is sampled.
            temperature, top_p, min_p, repetition_penalty: sampling controls.
            cfg_weight: guidance strength; logits = cond + cfg_weight * (cond - uncond).

        ``num_return_sequences``, ``do_sample`` and ``length_penalty`` are accepted for
        compatibility with the HF ``generate`` signature but are not used by this loop.

        Returns:
            (1, n) LongTensor of sampled speech tokens (including the stop token if it was sampled).
        """
        # Validate / sanitize inputs
        assert prepend_prompt_speech_tokens is None, "not implemented"
        text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
        _ensure_BOT_EOT(text_tokens, self.hp)
        max_new_tokens = max_new_tokens or self.hp.max_speech_tokens

        # Default initial speech to a single start-of-speech token
        if initial_speech_tokens is None:
            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])

        # Prepare custom input embeds
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=initial_speech_tokens,
            cfg_weight=cfg_weight,
        )

        # The backend is rebuilt on every call because the alignment analyzer is tied to this
        # call's text-token span. Note the llama-specific logic; other tfmr types can be added later.
        self.compiled = False
        alignment_stream_analyzer = None  # only created for multilingual models
        if self.hp.is_multilingual:
            alignment_stream_analyzer = AlignmentStreamAnalyzer(
                self.tfmr,
                None,
                text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
                alignment_layer_idx=9,  # TODO: hparam or something?
                eos_idx=self.hp.stop_speech_token,
            )
            assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
        self.patched_model = T3HuggingfaceBackend(
            config=self.cfg,
            llama=self.tfmr,
            speech_enc=self.speech_emb,
            speech_head=self.speech_head,
            alignment_stream_analyzer=alignment_stream_analyzer,
        )
        self.compiled = True

        device = embeds.device
        n_rows = embeds.size(0)  # 2 with CFG (cond, uncond), 1 without
        use_cfg = n_rows > 1

        bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
        bos_embed = self.speech_emb(bos_token)  # (1, 1, dim)
        bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
        # Same BOS for every row (cond and uncond).
        bos_embed = bos_embed.repeat(n_rows, 1, 1)

        # Combine condition and BOS token for the initial input
        inputs_embeds = torch.cat([embeds, bos_embed], dim=1)

        # Track generated token ids (row 0 only); start with the BOS token.
        generated_ids = bos_token.clone()
        predicted = []  # sampled tokens, each (1, 1)

        # Instantiate the logits processors.
        min_p_warper = MinPLogitsWarper(min_p=min_p)
        top_p_warper = TopPLogitsWarper(top_p=top_p)
        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))

        # ---- Initial Forward Pass (no kv_cache yet) ----
        output = self.patched_model(
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=True,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )
        # Initialize kv_cache with the full context.
        past = output.past_key_values
        cfg = torch.as_tensor(cfg_weight, device=output.logits.device, dtype=output.logits.dtype)
        analyzer = self.patched_model.alignment_stream_analyzer

        # ---- Generation Loop using kv_cache ----
        for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
            logits_step = output.logits[:, -1, :]

            # CFG combine -> (1, V)
            cond = logits_step[0:1, :]
            if use_cfg:
                uncond = logits_step[1:2, :]
                logits = cond + cfg * (cond - uncond)
            else:
                logits = cond

            # Apply alignment stream analyzer integrity checks
            if analyzer is not None:
                if logits.dim() == 1:  # guard in case something upstream squeezed
                    logits = logits.unsqueeze(0)  # (1, V)
                # Pass the last generated token for repetition tracking
                last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
                logits = analyzer.step(logits, next_token=last_token)  # (1, V)

            # Apply repetition penalty
            ids_for_proc = generated_ids[:1, ...]  # batch = 1
            logits = repetition_penalty_processor(ids_for_proc, logits)  # expects (B,V)

            # Apply temperature scaling.
            if temperature != 1.0:
                logits = logits / temperature

            # Apply min_p and top_p filtering
            logits = min_p_warper(ids_for_proc, logits)
            logits = top_p_warper(ids_for_proc, logits)

            # Convert logits to probabilities and sample the next token.
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

            predicted.append(next_token)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Check for EOS token.
            if stop_on_eos and next_token.view(-1) == self.hp.stop_speech_token:
                logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
                break

            # Get embedding for the new token.
            next_token_embed = self.speech_emb(next_token)
            next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
            # Feed the same token to every row (cond and uncond).
            next_token_embed = next_token_embed.repeat(n_rows, 1, 1)

            # Forward pass with only the new token and the cached past.
            output = self.patched_model(
                inputs_embeds=next_token_embed,
                past_key_values=past,
                output_attentions=True,
                output_hidden_states=True,
                return_dict=True,
            )
            # Update the kv_cache.
            past = output.past_key_values

        if not predicted:
            return torch.empty(1, 0, dtype=torch.long, device=device)
        # Concatenate all predicted tokens along the sequence dimension.
        predicted_tokens = torch.cat(predicted, dim=1)  # (1, num_tokens)
        return predicted_tokens

    @torch.inference_mode()
    def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2,
                        max_gen_len=1000):
        """
        Sample speech tokens by calling the backbone directly (no CFG, no alignment analyzer).

        Only processors whose setting is non-neutral are applied. Stops when every row samples
        ``hp.stop_speech_token``, when all logits are -inf, or after ``max_gen_len`` loop steps.
        A trailing stop token (checked on row 0) is stripped from the result.
        """
        logits_processors = LogitsProcessorList()
        if temperature > 0 and temperature != 1.0:
            logits_processors.append(TemperatureLogitsWarper(temperature))
        if top_k > 0:
            logits_processors.append(TopKLogitsWarper(top_k))
        if top_p < 1.0:
            logits_processors.append(TopPLogitsWarper(top_p))
        if repetition_penalty != 1.0:
            logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))

        speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
        embeds, _ = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=speech_start_token,
            cfg_weight=0.0,
        )

        generated_speech_tokens = []

        # Prefill: full prompt, populate the KV cache.
        llm_outputs = self.tfmr(
            inputs_embeds=embeds,
            use_cache=True
        )

        hidden_states = llm_outputs[0]
        past_key_values = llm_outputs.past_key_values

        speech_hidden = hidden_states[:, -1:]
        speech_logits = self.speech_head(speech_hidden)

        processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :])
        probs = F.softmax(processed_logits, dim=-1)
        next_speech_token = torch.multinomial(probs, num_samples=1)

        generated_speech_tokens.append(next_speech_token)
        current_speech_token = next_speech_token

        for _ in tqdm(range(max_gen_len)):
            current_speech_embed = self.speech_emb(current_speech_token)

            llm_outputs = self.tfmr(
                inputs_embeds=current_speech_embed,
                past_key_values=past_key_values,
                use_cache=True
            )

            hidden_states = llm_outputs[0]
            past_key_values = llm_outputs.past_key_values
            speech_logits = self.speech_head(hidden_states)

            input_ids = torch.cat(generated_speech_tokens, dim=1)
            processed_logits = logits_processors(input_ids, speech_logits[:, -1, :])
            if torch.all(processed_logits == -float("inf")):
                logger.warning("All logits are -inf; stopping generation early")
                break

            probs = F.softmax(processed_logits, dim=-1)
            next_speech_token = torch.multinomial(probs, num_samples=1)

            generated_speech_tokens.append(next_speech_token)
            current_speech_token = next_speech_token
            if torch.all(next_speech_token == self.hp.stop_speech_token):
                break

        all_tokens = torch.cat(generated_speech_tokens, dim=1)

        # Remove EOS token if present
        if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token:
            all_tokens = all_tokens[:, :-1]

        return all_tokens
================================================
FILE: src/chatterbox/models/tokenizers/__init__.py
================================================
from .tokenizer import EnTokenizer, MTLTokenizer
================================================
FILE: src/chatterbox/models/tokenizers/tokenizer.py
================================================
import logging
import json
import torch
from pathlib import Path
from unicodedata import category, normalize
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
logger = logging.getLogger(__name__)

# Special tokens that the tokenizer vocabulary is expected to define.
SOT = "[START]"    # start of text
EOT = "[STOP]"     # end of text
UNK = "[UNK]"      # unknown
SPACE = "[SPACE]"  # stands in for the ' ' character in encoded text
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
class EnTokenizer:
    """
    English tokenizer wrapping a HuggingFace ``tokenizers.Tokenizer`` loaded from a vocab file.

    Spaces are represented by the special SPACE token rather than by whitespace.
    """

    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        """Assert that the vocabulary contains the start- and stop-of-text tokens."""
        voc = self.tokenizer.get_vocab()
        assert SOT in voc, f"{SOT} missing from tokenizer vocabulary"
        assert EOT in voc, f"{EOT} missing from tokenizer vocabulary"

    def text_to_tokens(self, text: str):
        """Encode ``text`` into a (1, T) int32 tensor of token ids."""
        text_tokens = self.encode(text)
        return torch.IntTensor(text_tokens).unsqueeze(0)

    def encode(self, txt: str):
        """
        Replace every ' ' with the SPACE token, then encode with the wrapped Tokenizer.

        No text cleaning and no language-id prefixing happen here. Returns a list of ids.
        """
        txt = txt.replace(' ', SPACE)
        return self.tokenizer.encode(txt).ids

    def decode(self, seq):
        """
        Convert token ids back to text.

        ``seq`` may be a 1D tensor (on any device) or a sequence of ints. Every ' ' in the
        decoded string is dropped (presumably separators added by the tokenizer's decoder),
        SPACE tokens become ' ', and EOT / UNK tokens are removed.
        """
        if isinstance(seq, torch.Tensor):
            # Plain Python ints work regardless of the tensor's device or dtype.
            seq = seq.tolist()

        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')
        txt = txt.replace(UNK, '')
        return txt
# Hugging Face Hub repository that hosts the model assets.
REPO_ID = "ResembleAI/chatterbox"

# Lazily created handles for optional language-specific dependencies. Each starts as None
# and is presumably filled on first use by its normalisation helper.
_kakasi = _dicta = _russian_stresser = None
def is_kanji(c: str) -> bool:
    """Return True if ``c`` is in the CJK Unified Ideographs block (U+4E00..U+9FFF)."""
    code_point = ord(c)
    return 0x4E00 <= code_point <= 0x9FFF
def is_katakana(c: str) -> bool:
    """Return True if ``c`` is a katakana letter in U+30A1 (ァ)..U+30FA (ヺ)."""
    code_point = ord(c)
    return 0x30A1 <= code_point <= 0x30FA
def hiragana_normalize(text: str) -> str:
    """Japanese text normalization: converts kanji to hiragana; katakana remains the same.

    pykakasi splits ``text`` into phrases. Each phrase that contains at least one
    kanji is replaced by its hiragana reading; when that reading starts with "は"
    or "へ", a space is prepended. Every other phrase (katakana, kana, Latin,
    digits, ...) is kept exactly as written. The joined result is NFKD-decomposed
    for tokenizer compatibility.

    If pykakasi is not installed, a warning is logged and ``text`` is returned unchanged.
    """
    global _kakasi
    try:
        if _kakasi is None:
            import pykakasi
            _kakasi = pykakasi.kakasi()
    except ImportError:
        logger.warning("pykakasi not available - Japanese text processing skipped")
        return text

    out = []
    for r in _kakasi.convert(text):
        inp = r['orig']
        if any(is_kanji(c) for c in inp):
            hira = r["hira"]
            # The leading-space rule for "は"/"へ" readings is kept from the original
            # implementation; its rationale is not documented here.
            if hira and hira[0] in ("は", "へ"):  # guard against an empty reading
                hira = " " + hira
            out.append(hira)
        else:
            # Katakana-only phrases and all other non-kanji phrases pass through verbatim.
            out.append(inp)

    # Decompose Japanese characters for tokenizer compatibility
    return normalize('NFKD', "".join(out))
def add_hebrew_diacritics(text: str) -> str:
    """Hebrew text normalization: add vowel diacritics using dicta_onnx.

    The ``Dicta`` model is created on first use and cached in ``_dicta``. This is
    best-effort: if the package is missing or diacritization raises, a warning is
    logged and ``text`` is returned unchanged.
    """
    global _dicta
    try:
        if _dicta is None:
            from dicta_onnx import Dicta
            _dicta = Dicta()
        vocalized = _dicta.add_diacritics(text)
    except ImportError:
        logger.warning("dicta_onnx not available - Hebrew text processing skipped")
        return text
    except Exception as err:
        logger.warning(f"Hebrew diacritization failed: {err}")
        return text
    else:
        return vocalized
def korean_normalize(text: str) -> str:
    """Korean text normalization: decompose Hangul syllables into conjoining Jamo.

    Each precomposed syllable in U+AC00..U+D7A3 is split into:
    - a leading consonant (U+1100 block);
    - a vowel (U+1161 block);
    - an optional trailing consonant (U+11A8 block).

    The split uses the standard Unicode Hangul decomposition arithmetic. All other
    characters are kept unchanged. Leading and trailing whitespace is stripped.
    """
    s_base, l_base, v_base, t_base = 0xAC00, 0x1100, 0x1161, 0x11A7
    t_count = 28                     # trailing-consonant slots (incl. "none")
    n_count = 21 * t_count           # 588 syllables per leading consonant
    s_count = 19 * n_count           # 11172 syllables: U+AC00..U+D7A3

    def decompose_hangul(char):
        """Decompose one Hangul syllable into Jamo; return other characters unchanged."""
        s_index = ord(char) - s_base
        # U+D7A4..U+D7AF are not syllables, so the range ends at s_count, not U+D7AF.
        if not 0 <= s_index < s_count:
            return char
        initial = chr(l_base + s_index // n_count)
        medial = chr(v_base + (s_index % n_count) // t_count)
        t_index = s_index % t_count
        final = chr(t_base + t_index) if t_index else ''
        return initial + medial + final

    return ''.join(decompose_hangul(ch) for ch in text).strip()
class ChineseCangjieConverter:
    """Converts Chinese characters to Cangjie codes for tokenization.

    Each glyph in the Cangjie table becomes a run of ``[cj_<c>]`` tokens, one per
    code character. When several glyphs share a code, a disambiguation index is
    appended. The run is terminated by ``[cj_.]``. Glyphs not in the table pass
    through unchanged.
    """

    def __init__(self, model_dir=None):
        self.word2cj = {}   # glyph -> Cangjie code
        self.cj2word = {}   # Cangjie code -> glyphs sharing it, in table order
        self.segmenter = None
        self._load_cangjie_mapping(model_dir)
        self._init_segmenter()

    def _load_cangjie_mapping(self, model_dir=None):
        """Load the Cangjie table, preferring a copy already present in ``model_dir``.

        If ``model_dir/Cangjie5_TC.json`` exists (e.g. in an already-downloaded
        checkpoint directory), it is read directly. Otherwise the file is fetched
        from the HuggingFace repository, with ``model_dir`` as the cache directory.
        Loading is best-effort: errors are logged and the tables stay empty, so
        Chinese text then passes through unconverted.
        """
        try:
            local_file = Path(model_dir) / "Cangjie5_TC.json" if model_dir is not None else None
            if local_file is not None and local_file.is_file():
                cangjie_file = local_file
            else:
                cangjie_file = hf_hub_download(
                    repo_id=REPO_ID,
                    filename="Cangjie5_TC.json",
                    cache_dir=model_dir
                )
            with open(cangjie_file, "r", encoding="utf-8") as fp:
                data = json.load(fp)
            skipped = self._ingest_cangjie_entries(data)
            if skipped:
                logger.warning(f"Skipped {skipped} malformed Cangjie mapping entries")
        except Exception as e:
            logger.warning(f"Could not load Cangjie mapping: {e}")

    def _ingest_cangjie_entries(self, entries):
        """Add ``"<glyph>\\t<code>[\\t...]"`` strings to the lookup tables.

        Returns the number of entries skipped because they were not tab-separated strings.
        """
        skipped = 0
        for entry in entries:
            parts = entry.split("\t") if isinstance(entry, str) else []
            if len(parts) < 2:
                skipped += 1
                continue
            word, code = parts[0], parts[1]
            self.word2cj[word] = code
            self.cj2word.setdefault(code, []).append(word)
        return skipped

    def _init_segmenter(self):
        """Initialize pkuseg segmenter."""
        try:
            from spacy_pkuseg import pkuseg
            self.segmenter = pkuseg()
        except ImportError:
            logger.warning("pkuseg not available - Chinese segmentation will be skipped")
            self.segmenter = None

    def _cangjie_encode(self, glyph: str):
        """Return the Cangjie code of ``glyph`` (plus its index among same-code glyphs if > 0), or None."""
        code = self.word2cj.get(glyph)
        if code is None:  # e.g. Japanese hiragana
            return None
        index = self.cj2word[code].index(glyph)
        return code + (str(index) if index > 0 else "")

    def __call__(self, text):
        """Convert Chinese characters in text to Cangjie tokens.

        With a segmenter available, words are first joined with single spaces.
        """
        if self.segmenter is not None:
            full_text = " ".join(self.segmenter.cut(text))
        else:
            full_text = text

        output = []
        for t in full_text:
            # Only "other letter" characters (category Lo) can be CJK glyphs.
            cangjie = self._cangjie_encode(t) if category(t) == "Lo" else None
            if cangjie is None:
                output.append(t)
            else:
                output.append("".join(f"[cj_{c}]" for c in cangjie) + "[cj_.]")
        return "".join(output)
def add_russian_stress(text: str) -> str:
    """Russian text normalization: add stress marks using russian_text_stresser.

    The ``RussianTextStresser`` is created on first use and cached in
    ``_russian_stresser``. This is best-effort: if the package is missing or
    stressing raises, a warning is logged and ``text`` is returned unchanged.
    """
    global _russian_stresser
    try:
        if _russian_stresser is None:
            from russian_text_stresser.text_stresser import RussianTextStresser
            _russian_stresser = RussianTextStresser()
        stressed = _russian_stresser.stress_text(text)
    except ImportError:
        logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
        return text
    except Exception as err:
        logger.warning(f"Russian stress labeling failed: {err}")
        return text
    else:
        return stressed
class MTLTokenizer:
    """Multilingual grapheme tokenizer with language-specific text preprocessing.

    Encoding proceeds in four steps:
    1. Optionally lower-case and NFKD-normalise the text.
    2. Apply a per-language normaliser: zh -> Cangjie codes, ja -> kanji to
       hiragana, he -> diacritics, ko -> Jamo decomposition, ru -> stress marks.
    3. Prefix a ``[<lang>]`` tag.
    4. Encode with spaces mapped to ``[SPACE]``.
    """

    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        # The vocab file's directory is handed to the Cangjie converter as its model_dir.
        model_dir = Path(vocab_file_path).parent
        self.cangjie_converter = ChineseCangjieConverter(model_dir)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        """Assert that the vocabulary defines the start/stop-of-text tokens."""
        voc = self.tokenizer.get_vocab()
        assert SOT in voc
        assert EOT in voc

    @staticmethod
    def _normalize_language_id(language_id):
        """Return ``language_id`` lower-cased, or None when it is empty/None."""
        return language_id.lower() if language_id else None

    def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        """
        Text preprocessor that handles lowercase conversion and NFKD normalization.

        ``language_id`` is accepted for API symmetry but not used here.
        """
        preprocessed_text = raw_text
        if lowercase:
            preprocessed_text = preprocessed_text.lower()
        if nfkd_normalize:
            preprocessed_text = normalize("NFKD", preprocessed_text)
        return preprocessed_text

    def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        """Encode ``text`` and return the ids as a ``(1, N)`` ``torch.IntTensor``."""
        text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        """Normalise ``txt`` for ``language_id`` and return its token ids.

        ``language_id`` is matched case-insensitively. The same lower-cased value
        selects the language-specific normaliser and forms the ``[<lang>]`` prefix,
        so e.g. "ZH" and "zh" are treated identically.
        """
        lang = self._normalize_language_id(language_id)
        txt = self.preprocess_text(txt, language_id=lang, lowercase=lowercase, nfkd_normalize=nfkd_normalize)

        # Language-specific text processing
        language_processors = {
            'zh': self.cangjie_converter,
            'ja': hiragana_normalize,
            'he': add_hebrew_diacritics,
            'ko': korean_normalize,
            'ru': add_russian_stress,
        }
        processor = language_processors.get(lang)
        if processor is not None:
            txt = processor(txt)

        # Prepend language token
        if lang:
            txt = f"[{lang}]{txt}"
        txt = txt.replace(' ', SPACE)
        return self.tokenizer.encode(txt).ids

    def decode(self, seq):
        """Turn token ids back into text: drop tokenizer spaces, restore [SPACE], strip [STOP]/[UNK]."""
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt = self.tokenizer.decode(seq, skip_special_tokens=False)
        txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
        return txt
================================================
FILE: src/chatterbox/models/utils.py
================================================
class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the instance namespace to the mapping itself, so ``d.key`` and
        # ``d["key"]`` refer to the same storage.
        self.__dict__ = self
================================================
FILE: src/chatterbox/models/voice_encoder/__init__.py
================================================
from .voice_encoder import VoiceEncoder, VoiceEncConfig
================================================
FILE: src/chatterbox/models/voice_encoder/config.py
================================================
class VoiceEncConfig:
    """Hyper-parameters of the voice encoder and its mel-spectrogram front end."""
    num_mels = 40                # mel bands = LSTM input size
    sample_rate = 16000          # Hz
    speaker_embed_size = 256     # output embedding dimension
    ve_hidden_size = 256         # LSTM hidden size
    flatten_lstm_params = False  # call LSTM.flatten_parameters() at construction
    n_fft = 400                  # STFT size (25 ms at 16 kHz)
    hop_size = 160               # STFT hop in samples (10 ms at 16 kHz)
    win_size = 400               # STFT window length in samples
    fmax = 8000                  # mel filterbank upper edge, Hz
    fmin = 0                     # mel filterbank lower edge, Hz
    preemphasis = 0.0            # pre-emphasis coefficient; 0 disables the filter
    mel_power = 2.0              # exponent applied to STFT magnitudes
    mel_type = "amp"             # "db" converts the mel to decibels
    normalized_mels = False      # scale dB mels to [0, 1]
    ve_partial_frames = 160      # mel frames per partial window
    ve_final_relu = True         # ReLU before L2-normalising the embedding
    stft_magnitude_min = 1e-4    # amplitude floor for the dB conversion
================================================
FILE: src/chatterbox/models/voice_encoder/melspec.py
================================================
from functools import lru_cache
from scipy import signal
import numpy as np
import librosa
@lru_cache()
def mel_basis(hp):
    """Return the mel filterbank for ``hp``, shape ``(num_mels, 1 + n_fft // 2)``.

    Results are memoised per ``hp`` object. ``hp.fmax`` must not exceed the
    Nyquist frequency.
    """
    nyquist = hp.sample_rate // 2
    assert hp.fmax <= nyquist
    filterbank_kwargs = dict(
        sr=hp.sample_rate,
        n_fft=hp.n_fft,
        n_mels=hp.num_mels,
        fmin=hp.fmin,
        fmax=hp.fmax,
    )
    return librosa.filters.mel(**filterbank_kwargs)
def preemphasis(wav, hp):
    """Apply ``y[n] = x[n] - k * x[n-1]`` with ``k = hp.preemphasis``, then clip to [-1, 1]."""
    assert hp.preemphasis != 0
    numerator = [1, -hp.preemphasis]
    filtered = signal.lfilter(numerator, [1], wav)
    return np.clip(filtered, -1, 1)
def melspectrogram(wav, hp, pad=True):
    """Compute the mel spectrogram of ``wav`` in ``(M, T)`` layout.

    Steps:
    1. Pre-emphasis, when ``hp.preemphasis > 0``.
    2. STFT magnitude raised to ``hp.mel_power``.
    3. Projection onto the mel filterbank.
    4. dB conversion, when ``hp.mel_type == "db"``.
    5. [0, 1] normalisation (float32), when ``hp.normalized_mels``.

    With ``pad=True`` the frame count is checked to be ``1 + len(wav) // hp.hop_size``.
    """
    if hp.preemphasis > 0:
        wav = preemphasis(wav, hp)
        assert np.abs(wav).max() - 1 < 1e-07

    magnitudes = np.abs(_stft(wav, hp, pad=pad))
    if hp.mel_power != 1.0:
        magnitudes **= hp.mel_power

    mel = np.dot(mel_basis(hp), magnitudes)
    if hp.mel_type == "db":
        mel = _amp_to_db(mel, hp)
    if hp.normalized_mels:
        mel = _normalize(mel, hp).astype(np.float32)

    # Sanity check on the number of frames produced by the centred STFT.
    assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size
    return mel
def _stft(y, hp, pad=True):
    """Complex STFT of ``y`` using the window/hop settings in ``hp``.

    ``pad`` controls librosa's ``center`` option.
    """
    # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for
    # historical consistency and streaming-version consistency
    stft_kwargs = dict(
        n_fft=hp.n_fft,
        hop_length=hp.hop_size,
        win_length=hp.win_size,
        center=pad,
        pad_mode="reflect",
    )
    return librosa.stft(y, **stft_kwargs)
def _amp_to_db(x, hp):
return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x))
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def _normalize(s, hp, headroom_db=15):
min_level_db = 20 * np.log10(hp.stft_magnitude_min)
s = (s - min_level_db) / (-min_level_db + headroom_db)
return s
================================================
FILE: src/chatterbox/models/voice_encoder/voice_encoder.py
================================================
# Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning
# MIT License
from typing import List, Union, Optional
import numpy as np
from numpy.lib.stride_tricks import as_strided
import librosa
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from .config import VoiceEncConfig
from .melspec import melspectrogram
def pack(arrays, seq_len: int=None, pad_value=0):
    """
    Right-pad a list of B array-likes with shapes (Ti, ...) into one (B, T, ...) tensor.

    :param arrays: array-likes (lists, numpy arrays or tensors) whose shapes match
    except along the first axis.
    :param seq_len: the padded length T; must be >= every Ti. Defaults to max(Ti).
    :param pad_value: the value written into padded positions.
    :return: a (B, T, ...) tensor with the dtype of the first array. It is on the
    first array's device when the inputs are tensors.
    """
    longest = max(len(array) for array in arrays)
    if seq_len is None:
        seq_len = longest
    else:
        assert seq_len >= longest

    # Python lists are turned into numpy arrays first.
    if isinstance(arrays[0], list):
        arrays = [np.array(array) for array in arrays]

    if isinstance(arrays[0], torch.Tensor):
        tensors, device = arrays, arrays[0].device
    else:
        tensors, device = [torch.as_tensor(array) for array in arrays], None

    first = tensors[0]
    packed_tensor = torch.full(
        (len(tensors), seq_len, *first.shape[1:]), pad_value, dtype=first.dtype, device=device
    )
    for row, tensor in enumerate(tensors):
        packed_tensor[row, :tensor.size(0)] = tensor
    return packed_tensor
def get_num_wins(
    n_frames: int,
    step: int,
    min_coverage: float,
    hp: VoiceEncConfig,
):
    """Count the partial windows for an utterance of ``n_frames`` mel frames.

    Windows are ``hp.ve_partial_frames`` long and start every ``step`` frames. An
    extra trailing window is added when the utterance does not fill a single
    window, or when the leftover tail covers at least ``min_coverage`` of a window.

    :return: ``(n_wins, target_n)``, where ``target_n`` is the frame count the mel
    must be padded or trimmed to so that it holds exactly ``n_wins`` windows.
    """
    assert n_frames > 0
    win_size = hp.ve_partial_frames
    span = max(n_frames - win_size + step, 0)
    n_wins, leftover = span // step, span % step
    tail_coverage = (leftover + win_size - step) / win_size
    if n_wins == 0 or tail_coverage >= min_coverage:
        n_wins += 1
    return n_wins, win_size + step * (n_wins - 1)
def get_frame_step(
    overlap: float,
    rate: float,
    hp: "VoiceEncConfig",
):
    """Return the number of mel frames between the starts of consecutive partial windows.

    :param overlap: fractional overlap of consecutive windows in [0, 1); used only
    when ``rate`` is None.
    :param rate: if given, the number of partial windows per second of audio. The
    step is then ``(sample_rate / rate)`` samples, converted to mel frames by
    dividing by the STFT hop (``hp.hop_size`` samples per frame).
    :return: an int step in ``(0, hp.ve_partial_frames]``.
    """
    assert 0 <= overlap < 1
    if rate is None:
        frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap)))
    else:
        # Samples between window starts / samples per mel frame. Dividing by
        # ve_partial_frames was only correct because it equals hop_size (160) by default.
        frame_step = int(np.round((hp.sample_rate / rate) / hp.hop_size))
    assert 0 < frame_step <= hp.ve_partial_frames
    return frame_step
def stride_as_partials(
    mel: np.ndarray,
    hp: VoiceEncConfig,
    overlap=0.5,
    rate: float=None,
    min_coverage=0.8,
):
    """
    View an unscaled (T, M) mel spectrogram as overlapping partial windows.

    The mel is zero-padded or trimmed to the window layout from get_num_wins. It is
    then copied to a C-contiguous float32 array and returned as a zero-copy strided
    view of shape (N, P, M): N partials of P = hp.ve_partial_frames frames with
    M = hp.num_mels channels, consecutive partials starting frame_step frames apart.
    """
    assert 0 < min_coverage <= 1
    frame_step = get_frame_step(overlap, rate, hp)

    n_frames = len(mel)
    n_partials, target_len = get_num_wins(n_frames, frame_step, min_coverage, hp)

    if target_len > n_frames:
        padding = np.full((target_len - n_frames, hp.num_mels), 0)
        mel = np.concatenate((mel, padding))
    elif target_len < n_frames:
        mel = mel[:target_len]

    # as_strided needs predictable strides: float32, C order.
    mel = mel.astype(np.float32, order="C")
    row_stride, col_stride = mel.strides
    return as_strided(
        mel,
        (n_partials, hp.ve_partial_frames, hp.num_mels),
        (row_stride * frame_step, row_stride, col_stride),
    )
class VoiceEncoder(nn.Module):
    """Speaker encoder: a 3-layer LSTM over mel windows, projected and L2-normalised.

    Utterances are cut into overlapping windows of ``hp.ve_partial_frames`` mel
    frames ("partials"). Each partial is embedded by :meth:`forward`. The partial
    embeddings of one utterance are then averaged and re-normalised by :meth:`inference`.
    """

    def __init__(self, hp=VoiceEncConfig()):
        super().__init__()
        self.hp = hp

        # Network definition
        self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True)
        if hp.flatten_lstm_params:
            self.lstm.flatten_parameters()
        self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size)

        # Cosine similarity scaling (fixed initial parameter values). Not used by the
        # methods below; presumably kept so saved state dicts load strictly — confirm.
        self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True)

    @property
    def device(self):
        """Device of the module's parameters."""
        return next(self.parameters()).device

    def forward(self, mels: torch.FloatTensor):
        """
        Computes the embeddings of a batch of partial utterances.

        :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor
        of shape (B, T, M) where T is hp.ve_partial_frames
        :return: the embeddings as a float32 tensor of shape (B, E) where E is
        hp.speaker_embed_size. Embeddings are L2-normed and thus lie in the range [-1, 1].
        :raises Exception: if hp.normalized_mels is set and mels fall outside [0, 1].
        """
        if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1):
            raise Exception(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}")

        # Pass the input through the LSTM layers; only the final hidden states are used.
        _, (hidden, _) = self.lstm(mels)

        # Project the last layer's final hidden state
        raw_embeds = self.proj(hidden[-1])
        if self.hp.ve_final_relu:
            raw_embeds = F.relu(raw_embeds)

        # L2 normalize the embeddings.
        return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)

    def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None):
        """
        Computes the embeddings of a batch of full utterances.

        Gradients are not disabled here; embeds_from_mels wraps this call in
        ``torch.inference_mode()``.

        :param mels: (B, T, M) unscaled mels
        :param mel_lens: length in frames of each utterance (list or tensor of B ints)
        :param overlap, rate, min_coverage: partial windowing, see get_frame_step / get_num_wins
        :param batch_size: max number of partials per forward pass (None = all at once)
        :return: (B, E) embeddings on CPU
        """
        mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens

        # Compute where to split the utterances into partials
        frame_step = get_frame_step(overlap, rate, self.hp)
        n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens))

        # Possibly pad the mels to reach the target lengths. The padding takes the
        # dtype and device of `mels`, so torch.cat never mixes dtypes (e.g. half precision).
        len_diff = max(target_lens) - mels.size(1)
        if len_diff > 0:
            pad = mels.new_zeros((mels.size(0), len_diff, self.hp.num_mels))
            mels = torch.cat((mels, pad), dim=1)

        # Group all partials together so that we can batch them easily
        partials = [
            mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames]
            for mel, n_partial in zip(mels, n_partials) for i in range(n_partial)
        ]
        assert all(partials[0].shape == partial.shape for partial in partials)
        partials = torch.stack(partials)

        # Forward the partials, at most `batch_size` at a time
        n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials))))
        partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu()

        # Reduce the partial embeds into full embeds and L2-normalize them
        slices = np.concatenate(([0], np.cumsum(n_partials)))
        raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])]
        raw_embeds = torch.stack(raw_embeds)
        embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)

        return embeds

    @staticmethod
    def utt_to_spk_embed(utt_embeds: np.ndarray):
        """
        Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a
        speaker embedding.
        """
        assert utt_embeds.ndim == 2
        utt_embeds = np.mean(utt_embeds, axis=0)
        return utt_embeds / np.linalg.norm(utt_embeds, 2)

    @staticmethod
    def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray):
        """
        Cosine similarity for L2-normalized utterance embeddings or speaker embeddings.

        2-D inputs (several utterance embeddings) are first reduced to one speaker embedding.
        """
        embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x)
        embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y)
        return embeds_x @ embeds_y

    def embeds_from_mels(
        self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs
    ):
        """
        Convenience function for deriving utterance or speaker embeddings from mel spectrograms.

        :param mels: unscaled mels as either a (B, T, M) tensor or a list of (Ti, M) arrays
        (values must lie in [0, 1] when hp.normalized_mels is set).
        :param mel_lens: if passing mels as a tensor, individual mel lengths. Defaults to
        the full length T for every row when omitted.
        :param as_spk: whether to return utterance embeddings or a single speaker embedding
        :param kwargs: args for inference()
        :returns: embeds as a (B, E) float32 numpy array if as_spk is False, else as a (E,) array
        """
        # Load mels in memory and pack them
        if isinstance(mels, list):
            mels = [np.asarray(mel) for mel in mels]
            assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format"
            mel_lens = [mel.shape[0] for mel in mels]
            mels = pack(mels)
        elif mel_lens is None:
            # Tensor input without lengths: every row is taken to span all T frames.
            mel_lens = [mels.size(1)] * mels.size(0)

        # Embed them
        with torch.inference_mode():
            utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy()

        return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds

    def embeds_from_wavs(
        self,
        wavs: List[np.ndarray],
        sample_rate,
        as_spk=False,
        batch_size=32,
        trim_top_db: Optional[float]=20,
        **kwargs
    ):
        """
        Wrapper around embeds_from_mels.

        Waveforms are first resampled to hp.sample_rate (if needed). Leading and
        trailing silence is then trimmed (if trim_top_db is truthy). Finally the
        waveforms are converted to (T, M) mels. When no ``rate`` kwarg is given,
        1.3 partials per second is used.

        :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation
        """
        if sample_rate != self.hp.sample_rate:
            wavs = [
                librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast")
                for wav in wavs
            ]

        if trim_top_db:
            wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs]

        if "rate" not in kwargs:
            kwargs["rate"] = 1.3  # Resemble's default value.

        mels = [melspectrogram(w, self.hp).T for w in wavs]

        return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
================================================
FILE: src/chatterbox/mtl_tts.py
================================================
from dataclasses import dataclass
from pathlib import Path
import os
import librosa
import torch
import perth
import torch.nn.functional as F
from safetensors.torch import load_file as load_safetensors
from huggingface_hub import snapshot_download
from .models.t3 import T3
from .models.t3.modules.t3_config import T3Config
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import MTLTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
# HuggingFace model repo that from_pretrained downloads the multilingual checkpoints from.
REPO_ID = "ResembleAI/chatterbox"
# Supported languages for the multilingual model
# Maps language code -> display name. generate() validates `language_id`
# (case-insensitively) against these keys, and the insertion order below is the
# order used in its "Supported languages" error message.
SUPPORTED_LANGUAGES = {
  "ar": "Arabic",
  "da": "Danish",
  "de": "German",
  "el": "Greek",
  "en": "English",
  "es": "Spanish",
  "fi": "Finnish",
  "fr": "French",
  "he": "Hebrew",
  "hi": "Hindi",
  "it": "Italian",
  "ja": "Japanese",
  "ko": "Korean",
  "ms": "Malay",
  "nl": "Dutch",
  "no": "Norwegian",
  "pl": "Polish",
  "pt": "Portuguese",
  "ru": "Russian",
  "sv": "Swedish",
  "sw": "Swahili",
  "tr": "Turkish",
  "zh": "Chinese",
}
def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset.

    Steps:
    1. Collapse runs of whitespace.
    2. Capitalise the first character.
    3. Map uncommon punctuation (ellipses, colons, dashes, curly quotes, ...) to
       common equivalents.
    4. Append "." unless the text already ends with sentence punctuation
       (ASCII or CJK, including fullwidth forms).

    Empty or whitespace-only input yields a fixed placeholder sentence.
    """
    # Collapse whitespace first, so leading spaces cannot hide the first letter
    # and whitespace-only input is treated as empty.
    text = " ".join(text.split())
    if not text:
        return "You need to add some text for me to talk."

    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]

    # Replace uncommon/llm punc
    punc_to_replace = [
        ("...", ", "),
        ("…", ", "),
        (":", ","),
        (" - ", ", "),
        (";", ", "),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", "\""),
        ("”", "\""),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add full stop if no ending punc (ASCII plus CJK / fullwidth enders)
    text = text.rstrip(" ")
    sentence_enders = (".", "!", "?", "-", ",", "、", "，", "。", "？", "！")
    if not text.endswith(sentence_enders):
        text += "."
    return text
@dataclass
class Conditionals:
    """
    Voice-conditioning inputs for T3 and S3Gen.

    - T3 conditionals (``t3``, a T3Cond):
        - speaker_emb
        - clap_emb
        - cond_prompt_speech_tokens
        - cond_prompt_speech_emb
        - emotion_adv
    - S3Gen conditionals (``gen``, a plain dict):
        - prompt_token
        - prompt_token_len
        - prompt_feat
        - prompt_feat_len
        - embedding
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        """Move every tensor to ``device`` (updating ``self.gen`` in place) and return ``self``."""
        self.t3 = self.t3.to(device=device)
        moved = {key: value.to(device=device) for key, value in self.gen.items() if torch.is_tensor(value)}
        self.gen.update(moved)
        return self

    def save(self, fpath: Path):
        """Serialise the T3 fields and the S3Gen dict to ``fpath`` with ``torch.save``."""
        torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        """Rebuild ``Conditionals`` from a file written by :meth:`save`.

        ``weights_only=True`` restricts unpickling to tensors and plain containers.
        """
        if isinstance(map_location, str):
            map_location = torch.device(map_location)
        payload = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(T3Cond(**payload['t3']), payload['gen'])
class ChatterboxMultilingualTTS:
    """Multilingual TTS pipeline: MTLTokenizer -> T3 (text to speech tokens) -> S3Gen (speech tokens to waveform).

    Voice conditioning is held in ``self.conds``: the T3 speaker embedding and
    prompt speech tokens, plus the S3Gen reference dict. It is loaded from
    ``conds.pt`` by ``from_local`` or built from a reference clip by
    ``prepare_conditionals``. Generated audio is watermarked before it is returned.
    """
    # Reference-audio caps in samples: 6 s at S3_SR for the T3 prompt tokens and
    # 10 s at S3GEN_SR for the S3Gen reference.
    ENC_COND_LEN = 6 * S3_SR
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer: MTLTokenizer,
        device: str,
        conds: Conditionals = None,
    ):
        self.sr = S3GEN_SR  # sample rate of synthesized audio
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.device = device
        self.conds = conds
        self.watermarker = perth.PerthImplicitWatermarker()

    @classmethod
    def get_supported_languages(cls):
        """Return dictionary of supported language codes and names."""
        return SUPPORTED_LANGUAGES.copy()

    @classmethod
    def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
        """Build the pipeline from checkpoint files in ``ckpt_dir``.

        Expects ve.pt, t3_mtl23ls_v2.safetensors, s3gen.pt and
        grapheme_mtl_merged_expanded_v1.json. conds.pt (a built-in voice) is
        loaded when present. All models are moved to ``device`` in eval mode.
        """
        ckpt_dir = Path(ckpt_dir)

        # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
        # (map_location=None lets torch.load restore tensors to their saved device).
        if device in ["cpu", "mps"]:
            map_location = torch.device('cpu')
        else:
            map_location = None

        ve = VoiceEncoder()
        ve.load_state_dict(
            torch.load(ckpt_dir / "ve.pt", map_location=map_location, weights_only=True)
        )
        ve.to(device).eval()

        t3 = T3(T3Config.multilingual())
        t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
        # Compatibility path for checkpoints that appear to nest weights under "model".
        if "model" in t3_state.keys():
            t3_state = t3_state["model"][0]
        t3.load_state_dict(t3_state)
        t3.to(device).eval()

        s3gen = S3Gen()
        s3gen.load_state_dict(
            torch.load(ckpt_dir / "s3gen.pt", map_location=map_location, weights_only=True)
        )
        s3gen.to(device).eval()

        tokenizer = MTLTokenizer(
            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
        )

        conds = None
        if (builtin_voice := ckpt_dir / "conds.pt").exists():
            conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)

        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)

    @classmethod
    def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
        """Download the multilingual checkpoint files from REPO_ID and load them via from_local.

        A request for "mps" falls back to "cpu" (with a printed reason) when MPS is unavailable.
        The HF_TOKEN environment variable, if set, is used for the download.
        """
        # Check if MPS is available on macOS
        if device == "mps" and not torch.backends.mps.is_available():
            if not torch.backends.mps.is_built():
                print("MPS not available because the current PyTorch install was not built with MPS enabled.")
            else:
                print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
            device = "cpu"

        ckpt_dir = Path(
            snapshot_download(
                repo_id=REPO_ID,
                repo_type="model",
                revision="main",
                allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
                token=os.getenv("HF_TOKEN"),
            )
        )

        return cls.from_local(ckpt_dir, device)

    def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
        """Build ``self.conds`` from the reference audio at ``wav_fpath``.

        The clip is loaded at S3GEN_SR and also resampled to S3_SR. Three parts are built:
        - the S3Gen reference (first DEC_COND_LEN samples);
        - T3 prompt speech tokens (first ENC_COND_LEN samples at S3_SR), only when
          t3.hp.speech_cond_prompt_len is set;
        - a voice-encoder speaker embedding.
        ``exaggeration`` becomes the T3 ``emotion_adv`` value.
        """
        ## Load reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)

        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

        # Speech cond prompt tokens
        t3_cond_prompt_tokens = None
        if plen := self.t3.hp.speech_cond_prompt_len:
            s3_tokzr = self.s3gen.tokenizer
            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)

        # Voice-encoder speaker embedding (averaged over utterances, kept as a batch of 1)
        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)

        t3_cond = T3Cond(
            speaker_emb=ve_embed,
            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)
        self.conds = Conditionals(t3_cond, s3gen_ref_dict)

    def generate(
        self,
        text,
        language_id,
        audio_prompt_path=None,
        exaggeration=0.5,
        cfg_weight=0.5,
        temperature=0.8,
        repetition_penalty=2.0,
        min_p=0.05,
        top_p=1.0,
    ):
        """Synthesize ``text`` in ``language_id`` and return watermarked audio at ``self.sr``.

        :param language_id: a key of SUPPORTED_LANGUAGES (case-insensitive); a falsy
            value skips validation and the language tag.
        :param audio_prompt_path: reference clip; when given, conditionals are rebuilt
            from it, otherwise the existing ``self.conds`` is used.
        :param exaggeration: T3 emotion_adv value; conditionals are updated when it differs.
        :param cfg_weight, temperature, repetition_penalty, min_p, top_p: forwarded to T3.inference.
        :return: the watermarked waveform with a leading batch dim added by
            ``unsqueeze(0)`` (presumably (1, num_samples) — confirm S3Gen output shape).
        :raises ValueError: for an unsupported ``language_id``.
        """
        # Validate language_id
        if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
            supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
            raise ValueError(
                f"Unsupported language_id '{language_id}'. "
                f"Supported languages: {supported_langs}"
            )

        if audio_prompt_path:
            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

        # Update exaggeration if needed
        if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
            _cond: T3Cond = self.conds.t3
            self.conds.t3 = T3Cond(
                speaker_emb=_cond.speaker_emb,
                cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
                emotion_adv=exaggeration * torch.ones(1, 1, 1),
            ).to(device=self.device)

        # Norm and tokenize text
        text = punc_norm(text)
        text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG

        # Wrap every sequence in start/stop-of-text tokens.
        sot = self.t3.hp.start_text_token
        eot = self.t3.hp.stop_text_token
        text_tokens = F.pad(text_tokens, (1, 0), value=sot)
        text_tokens = F.pad(text_tokens, (0, 1), value=eot)

        with torch.inference_mode():
            speech_tokens = self.t3.inference(
                t3_cond=self.conds.t3,
                text_tokens=text_tokens,
                max_new_tokens=1000,  # TODO: use the value in config
                temperature=temperature,
                cfg_weight=cfg_weight,
                repetition_penalty=repetition_penalty,
                min_p=min_p,
                top_p=top_p,
            )
            # Extract only the conditional batch.
            speech_tokens = speech_tokens[0]

            # TODO: output becomes 1D
            speech_tokens = drop_invalid_tokens(speech_tokens)
            speech_tokens = speech_tokens.to(self.device)

            wav, _ = self.s3gen.inference(
                speech_tokens=speech_tokens,
                ref_dict=self.conds.gen,
            )
            wav = wav.squeeze(0).detach().cpu().numpy()
            watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
================================================
FILE: src/chatterbox/tts.py
================================================
from dataclasses import dataclass
from pathlib import Path
import librosa
import torch
import perth
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from .models.t3 import T3
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
# HuggingFace model repo that ChatterboxTTS.from_pretrained downloads checkpoint files from.
REPO_ID = "ResembleAI/chatterbox"
def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset.

    Steps:
    1. Collapse runs of whitespace.
    2. Capitalise the first character.
    3. Map uncommon punctuation (ellipses, colons, dashes, curly quotes, ...) to
       common equivalents.
    4. Append "." unless the text already ends with . ! ? - or ,.

    Empty or whitespace-only input yields a fixed placeholder sentence.
    """
    # Collapse whitespace first, so leading spaces cannot hide the first letter
    # and whitespace-only input is treated as empty.
    text = " ".join(text.split())
    if not text:
        return "You need to add some text for me to talk."

    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]

    # Replace uncommon/llm punc
    punc_to_replace = [
        ("...", ", "),
        ("…", ", "),
        (":", ","),
        (" - ", ", "),
        (";", ", "),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", "\""),
        ("”", "\""),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add full stop if no ending punc
    text = text.rstrip(" ")
    sentence_enders = (".", "!", "?", "-", ",")
    if not text.endswith(sentence_enders):
        text += "."
    return text
@dataclass
class Conditionals:
    """
    Voice-conditioning inputs for T3 and S3Gen.

    - T3 conditionals (``t3``, a T3Cond):
        - speaker_emb
        - clap_emb
        - cond_prompt_speech_tokens
        - cond_prompt_speech_emb
        - emotion_adv
    - S3Gen conditionals (``gen``, a plain dict):
        - prompt_token
        - prompt_token_len
        - prompt_feat
        - prompt_feat_len
        - embedding
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        """Move every tensor to ``device`` (updating ``self.gen`` in place) and return ``self``."""
        self.t3 = self.t3.to(device=device)
        moved = {key: value.to(device=device) for key, value in self.gen.items() if torch.is_tensor(value)}
        self.gen.update(moved)
        return self

    def save(self, fpath: Path):
        """Serialise the T3 fields and the S3Gen dict to ``fpath`` with ``torch.save``."""
        torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        """Rebuild ``Conditionals`` from a file written by :meth:`save`.

        ``weights_only=True`` restricts unpickling to tensors and plain containers.
        """
        if isinstance(map_location, str):
            map_location = torch.device(map_location)
        payload = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(T3Cond(**payload['t3']), payload['gen'])
class ChatterboxTTS:
ENC_COND_LEN = 6 * S3_SR
DEC_COND_LEN = 10 * S3GEN_SR
def __init__(
self,
t3: T3,
s3gen: S3Gen,
ve: VoiceEncoder,
tokenizer: EnTokenizer,
device: str,
conds: Conditionals = None,
):
self.sr = S3GEN_SR # sample rate of synthesized audio
self.t3 = t3
self.s3gen = s3gen
self.ve = ve
self.tokenizer = tokenizer
self.device = device
self.conds = conds
self.watermarker = perth.PerthImplicitWatermarker()
@classmethod
def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
ckpt_dir = Path(ckpt_dir)
# Always load to CPU first for non-CUDA devices to handle CUDA-saved models
if device in ["cpu", "mps"]:
map_location = torch.device('cpu')
else:
map_location = None
ve = VoiceEncoder()
ve.load_state_dict(
load_file(ckpt_dir / "ve.safetensors")
)
ve.to(device).eval()
t3 = T3()
t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
if "model" in t3_state.keys():
t3_state = t3_state["model"][0]
t3.load_state_dict(t3_state)
t3.to(device).eval()
s3gen = S3Gen()
s3gen.load_state_dict(
load_file(ckpt_dir / "s3gen.safetensors"), strict=False
)
s3gen.to(device).eval()
tokenizer = EnTokenizer(
str(ckpt_dir / "tokenizer.json")
)
conds = None
if (builtin_voice := ckpt_dir / "conds.pt").exists():
conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
@classmethod
def from_pretrained(cls, device) -> 'ChatterboxTTS':
# Check if MPS is available on macOS
if device == "mps" and not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]:
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
return cls.from_local(Path(local_path).parent, device)
def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
## Load reference wav
s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
# Speech cond prompt tokens
if plen := self.t3.hp.speech_cond_prompt_len:
s3_tokzr = self.s3gen.tokenizer
t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
# Voice-encoder speaker embedding
ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
t3_cond = T3Cond(
speaker_emb=ve_embed,
cond_prompt_speech_tokens=t3_cond_prompt_tokens,
emotion_adv=exaggeration * torch.ones(1, 1, 1),
).to(device=self.device)
self.conds = Conditionals(t3_cond, s3gen_ref_dict)
def generate(
    self,
    text,
    repetition_penalty=1.2,
    min_p=0.05,
    top_p=1.0,
    audio_prompt_path=None,
    exaggeration=0.5,
    cfg_weight=0.5,
    temperature=0.8,
    max_new_tokens=1000,
):
    """Synthesize ``text`` and return the watermarked waveform as a tensor.

    Args:
        text: input text; normalised with ``punc_norm`` before tokenization.
        repetition_penalty, min_p, top_p, temperature: sampling controls
            forwarded to ``T3.inference``.
        audio_prompt_path: optional reference recording. When given, the
            conditionals are rebuilt from it via ``prepare_conditionals``.
        exaggeration: value written into ``conds.t3.emotion_adv``.
        cfg_weight: classifier-free guidance weight. Values > 0 duplicate the
            text batch so T3 gets two sequences.
        max_new_tokens: upper bound on generated speech tokens (default 1000).

    Returns:
        ``torch.from_numpy(watermarked_wav).unsqueeze(0)``, i.e. the watermarked
        numpy waveform with a leading batch dimension added.

    Raises:
        AssertionError: if no conditionals are cached and ``audio_prompt_path``
            is not given.
    """
    if audio_prompt_path:
        self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
    else:
        assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

    # Update exaggeration if needed: rebuild T3Cond only when the requested
    # value differs from the cached one, reusing speaker embedding and prompt tokens.
    if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]:
        _cond: T3Cond = self.conds.t3
        self.conds.t3 = T3Cond(
            speaker_emb=_cond.speaker_emb,
            cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)

    # Norm and tokenize text
    text = punc_norm(text)
    text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)

    if cfg_weight > 0.0:
        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG

    # Wrap the token sequence(s) with the start/stop text tokens from the T3 config.
    sot = self.t3.hp.start_text_token
    eot = self.t3.hp.stop_text_token
    text_tokens = F.pad(text_tokens, (1, 0), value=sot)
    text_tokens = F.pad(text_tokens, (0, 1), value=eot)

    # Run the whole token -> waveform pipeline without autograd bookkeeping.
    with torch.inference_mode():
        speech_tokens = self.t3.inference(
            t3_cond=self.conds.t3,
            text_tokens=text_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            cfg_weight=cfg_weight,
            repetition_penalty=repetition_penalty,
            min_p=min_p,
            top_p=top_p,
        )
        # Extract only the conditional batch.
        speech_tokens = speech_tokens[0]

        # TODO: output becomes 1D
        speech_tokens = drop_invalid_tokens(speech_tokens)
        # Keep only ids below 6561.
        # NOTE(review): presumably the size of the S3Gen speech-token codebook — confirm.
        speech_tokens = speech_tokens[speech_tokens < 6561]
        speech_tokens = speech_tokens.to(self.device)

        wav, _ = self.s3gen.inference(
            speech_tokens=speech_tokens,
            ref_dict=self.conds.gen,
        )
        wav = wav.squeeze(0).detach().cpu().numpy()
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
    # Build the returned tensor outside inference mode so callers get a normal tensor.
    return torch.from_numpy(watermarked_wav).unsqueeze(0)
================================================
FILE: src/chatterbox/tts_turbo.py
================================================
import os
import math
from dataclasses import dataclass
from pathlib import Path
import librosa
import torch
import perth
import pyloudnorm as ln
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from .models.t3 import T3
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
from .models.t3.modules.t3_config import T3Config
from .models.s3gen.const import S3GEN_SIL
import logging
# Module-level logger; generate() reports ignored Turbo options through it.
logger = logging.getLogger(name=__name__)
# Hugging Face Hub repository that from_pretrained snapshots the Turbo checkpoint from.
REPO_ID = "ResembleAI/chatterbox-turbo"
def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset.

    Steps:
    1. Collapse all whitespace runs to single spaces.
    2. Capitalise the first character.
    3. Map uncommon punctuation to ASCII equivalents.
    4. Append a full stop unless the text already ends in one of ``. ! ? - ,``.

    Args:
        text: raw input text.

    Returns:
        The normalised text. If ``text`` is empty or whitespace-only, a fixed
        prompt asking for input is returned instead.
    """
    # Collapse whitespace runs first (this also strips leading/trailing
    # whitespace). Doing this before the emptiness check and capitalisation
    # means whitespace-only input hits the fallback and "  hello" is capitalised.
    text = " ".join(text.split())
    if not text:
        return "You need to add some text for me to talk."

    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]

    # Replace uncommon/llm punc
    punc_to_replace = [
        ("…", ", "),
        (":", ","),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", "\""),
        ("”", "\""),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add full stop if no ending punc. Replacements such as "…" -> ", " can
    # leave a trailing space, hence the rstrip.
    text = text.rstrip(" ")
    sentence_enders = (".", "!", "?", "-", ",")
    if not text.endswith(sentence_enders):
        text += "."
    return text
@dataclass
class Conditionals:
    """
    Bundle of the conditioning inputs consumed by T3 and S3Gen.

    - ``t3``: a ``T3Cond`` holding speaker_emb, clap_emb,
      cond_prompt_speech_tokens, cond_prompt_speech_emb and emotion_adv.
    - ``gen``: the S3Gen reference dict holding prompt_token,
      prompt_token_len, prompt_feat, prompt_feat_len and embedding.
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        """Move every tensor in the bundle to ``device`` and return ``self``."""
        self.t3 = self.t3.to(device=device)
        # Update the existing dict in place (non-tensor values are left untouched),
        # so anyone already holding ``self.gen`` sees the moved tensors.
        self.gen.update(
            {key: value.to(device=device) for key, value in self.gen.items() if torch.is_tensor(value)}
        )
        return self

    def save(self, fpath: Path):
        """Serialize the bundle to ``fpath`` with ``torch.save``."""
        torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        """Restore a bundle written by ``save``.

        Loading uses ``torch.load(..., weights_only=True)``. A string
        ``map_location`` is converted to a ``torch.device`` first.
        """
        device = torch.device(map_location) if isinstance(map_location, str) else map_location
        payload = torch.load(fpath, map_location=device, weights_only=True)
        return cls(T3Cond(**payload['t3']), payload['gen'])
class ChatterboxTurboTTS:
    """Chatterbox Turbo text-to-speech pipeline.

    Components:
      - ``t3``: T3 built from a ``GPT2_medium`` config; generates speech tokens from text tokens.
      - ``s3gen``: S3Gen built with ``meanflow=True``; decodes speech tokens into a waveform.
      - ``ve``: VoiceEncoder producing the speaker embedding.
      - ``tokenizer``: Hugging Face ``AutoTokenizer`` loaded by ``from_local``.

    Output audio is watermarked with ``perth.PerthImplicitWatermarker``.
    """
    # Reference-length limits in samples:
    # - ENC_COND_LEN: of the S3_SR copy fed to the speech tokenizer.
    # - DEC_COND_LEN: of the S3GEN_SR copy fed to ``s3gen.embed_ref``.
    ENC_COND_LEN = 15 * S3_SR
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer: EnTokenizer,
        device: str,
        conds: Conditionals = None,
    ):
        # NOTE(review): from_local passes a Hugging Face AutoTokenizer here,
        # despite the EnTokenizer annotation.
        self.sr = S3GEN_SR  # sample rate of synthesized audio
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.device = device
        self.conds = conds
        self.watermarker = perth.PerthImplicitWatermarker()

    @classmethod
    def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS':
        """Load every Turbo component from a local checkpoint directory.

        Args:
            ckpt_dir: directory containing ``ve.safetensors``,
                ``t3_turbo_v1.safetensors``, ``s3gen_meanflow.safetensors``, the
                tokenizer files and, optionally, a built-in voice ``conds.pt``.
            device: device the models (and built-in conditionals) are moved to.

        Returns:
            A ``ChatterboxTurboTTS`` instance. ``conds`` is None when no
            ``conds.pt`` is present.
        """
        ckpt_dir = Path(ckpt_dir)

        # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
        if device in ["cpu", "mps"]:
            map_location = torch.device('cpu')
        else:
            map_location = None

        ve = VoiceEncoder()
        ve.load_state_dict(
            load_file(ckpt_dir / "ve.safetensors")
        )
        ve.to(device).eval()

        # Turbo specific hp. The text vocabulary size is also checked against
        # the tokenizer below, so keep it in one place.
        text_vocab_size = 50276
        hp = T3Config(text_tokens_dict_size=text_vocab_size)
        hp.llama_config_name = "GPT2_medium"
        hp.speech_tokens_dict_size = 6563
        hp.input_pos_emb = None
        hp.speech_cond_prompt_len = 375
        hp.use_perceiver_resampler = False
        hp.emotion_adv = False

        t3 = T3(hp)
        t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors")
        if "model" in t3_state.keys():
            t3_state = t3_state["model"][0]
        t3.load_state_dict(t3_state)
        # NOTE(review): the backbone's ``wte`` submodule is removed after loading;
        # presumably T3 supplies its own embeddings — confirm.
        del t3.tfmr.wte
        t3.to(device).eval()

        s3gen = S3Gen(meanflow=True)
        weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors")
        s3gen.load_state_dict(
            weights, strict=True
        )
        s3gen.to(device).eval()

        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if len(tokenizer) != text_vocab_size:
            logger.warning("Tokenizer len %d != %d", len(tokenizer), text_vocab_size)

        conds = None
        builtin_voice = ckpt_dir / "conds.pt"
        if builtin_voice.exists():
            conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)

        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)

    @classmethod
    def from_pretrained(cls, device) -> 'ChatterboxTurboTTS':
        """Download (or reuse a cached) snapshot of ``REPO_ID`` and load it via ``from_local``.

        ``HF_TOKEN`` from the environment is used when set. A requested ``"mps"``
        device falls back to ``"cpu"`` (printing the reason) when MPS is unavailable.
        """
        # Check if MPS is available on macOS
        if device == "mps" and not torch.backends.mps.is_available():
            if not torch.backends.mps.is_built():
                print("MPS not available because the current PyTorch install was not built with MPS enabled.")
            else:
                print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
            device = "cpu"

        local_path = snapshot_download(
            repo_id=REPO_ID,
            token=os.getenv("HF_TOKEN") or None,
            # Optional: Filter to download only what you need
            allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
        )

        return cls.from_local(local_path, device)

    def norm_loudness(self, wav, sr, target_lufs=-27):
        """Scale ``wav`` toward ``target_lufs`` integrated loudness using a pyloudnorm meter.

        This is best effort: ``wav`` is returned unchanged if measurement fails
        or the computed linear gain is not a finite positive number.
        """
        try:
            meter = ln.Meter(sr)
            loudness = meter.integrated_loudness(wav)
            gain_db = target_lufs - loudness
            gain_linear = 10.0 ** (gain_db / 20.0)
            # Guard against a non-finite gain (presumably from an unmeasurable /
            # silent signal) before scaling.
            if math.isfinite(gain_linear) and gain_linear > 0.0:
                wav = wav * gain_linear
        except Exception as e:
            # Deliberately best-effort: a normalisation failure must not abort conditioning.
            logger.warning("Error in norm_loudness, skipping: %s", e)
        return wav

    def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True):
        """Compute speaker conditionals from a reference recording and cache them on ``self.conds``.

        Args:
            wav_fpath: reference audio, anything ``librosa.load`` accepts.
            exaggeration: scalar multiplied into ``torch.ones(1, 1, 1)`` for ``emotion_adv``.
            norm_loudness: loudness-normalise the reference before use.

        Raises:
            AssertionError: if the reference is not longer than 5 seconds.
        """
        ## Load and norm reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!"

        if norm_loudness:
            s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr)

        # Copy resampled to S3_SR for the speech tokenizer and the voice encoder.
        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)

        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

        # Speech cond prompt tokens. Default to None so that a falsy
        # speech_cond_prompt_len does not leave this name unbound below.
        # NOTE(review): assumes T3Cond accepts cond_prompt_speech_tokens=None — confirm.
        t3_cond_prompt_tokens = None
        if plen := self.t3.hp.speech_cond_prompt_len:
            s3_tokzr = self.s3gen.tokenizer
            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)

        # Voice-encoder speaker embedding
        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)

        t3_cond = T3Cond(
            speaker_emb=ve_embed,
            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)
        self.conds = Conditionals(t3_cond, s3gen_ref_dict)

    def generate(
        self,
        text,
        repetition_penalty=1.2,
        min_p=0.00,
        top_p=0.95,
        audio_prompt_path=None,
        exaggeration=0.0,
        cfg_weight=0.0,
        temperature=0.8,
        top_k=1000,
        norm_loudness=True,
    ):
        """Synthesize ``text`` with the Turbo model and return the watermarked waveform tensor.

        Args:
            text: input text; normalised with ``punc_norm`` before tokenization.
            repetition_penalty, top_p, top_k, temperature: sampling controls
                forwarded to ``T3.inference_turbo``.
            min_p, cfg_weight, exaggeration: not forwarded to Turbo sampling. A
                warning is logged when any of them is > 0. ``exaggeration`` is
                still stored in the conditionals when ``audio_prompt_path`` is given.
            audio_prompt_path: optional reference recording. When given, the
                conditionals are rebuilt from it.
            norm_loudness: loudness-normalise the reference before conditioning.

        Returns:
            The watermarked waveform with a leading batch dimension
            (``unsqueeze(0)``), watermarked at ``self.sr``.

        Raises:
            AssertionError: if no conditionals are cached and
                ``audio_prompt_path`` is not given.
        """
        if audio_prompt_path:
            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness)
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

        if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
            logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.")

        # Norm and tokenize text
        text = punc_norm(text)
        text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        text_tokens = text_tokens.input_ids.to(self.device)

        speech_tokens = self.t3.inference_turbo(
            t3_cond=self.conds.t3,
            text_tokens=text_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # Remove OOV tokens (ids >= 6561; boolean masking also flattens the
        # tensor) and append three S3GEN_SIL tokens at the end.
        speech_tokens = speech_tokens[speech_tokens < 6561]
        speech_tokens = speech_tokens.to(self.device)
        silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device)
        speech_tokens = torch.cat([speech_tokens, silence])

        wav, _ = self.s3gen.inference(
            speech_tokens=speech_tokens,
            ref_dict=self.conds.gen,
            n_cfm_timesteps=2,
        )
        wav = wav.squeeze(0).detach().cpu().numpy()
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
================================================
FILE: src/chatterbox/vc.py
================================================
from pathlib import Path
import librosa
import torch
import perth
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
# Hugging Face Hub repository that ChatterboxVC.from_pretrained downloads from.
REPO_ID: str = "ResembleAI/chatterbox"
class ChatterboxVC:
    """Voice conversion with S3Gen.

    Source audio is turned into speech tokens by ``s3gen.tokenizer``. Those
    tokens are re-synthesised by ``s3gen.inference``, conditioned on a
    target-voice reference dict (``ref_dict``). Output audio is watermarked with
    ``perth.PerthImplicitWatermarker``.
    """
    # NOTE(review): ENC_COND_LEN is not referenced inside this class.
    ENC_COND_LEN = 6 * S3_SR
    # Maximum number of S3GEN_SR samples of the target reference passed to embed_ref.
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        s3gen: S3Gen,
        device: str,
        ref_dict: dict = None,
    ):
        self.sr = S3GEN_SR
        self.s3gen = s3gen
        self.device = device
        self.watermarker = perth.PerthImplicitWatermarker()
        if ref_dict is None:
            self.ref_dict = None
        else:
            # Move tensor entries to the target device; keep other values as-is.
            self.ref_dict = {
                k: v.to(device) if torch.is_tensor(v) else v
                for k, v in ref_dict.items()
            }

    @classmethod
    def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC':
        """Load S3Gen weights (and the built-in target voice, if any) from ``ckpt_dir``.

        Args:
            ckpt_dir: directory containing ``s3gen.safetensors`` and, optionally, ``conds.pt``.
            device: device the model and reference tensors are moved to.
        """
        ckpt_dir = Path(ckpt_dir)

        # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
        if device in ["cpu", "mps"]:
            map_location = torch.device('cpu')
        else:
            map_location = None

        ref_dict = None
        if (builtin_voice := ckpt_dir / "conds.pt").exists():
            # weights_only=True limits unpickling to tensors and plain containers,
            # so a tampered checkpoint cannot execute arbitrary code on load.
            states = torch.load(builtin_voice, map_location=map_location, weights_only=True)
            ref_dict = states['gen']

        s3gen = S3Gen()
        s3gen.load_state_dict(
            load_file(ckpt_dir / "s3gen.safetensors"), strict=False
        )
        s3gen.to(device).eval()

        return cls(s3gen, device, ref_dict=ref_dict)

    @classmethod
    def from_pretrained(cls, device) -> 'ChatterboxVC':
        """Download the VC checkpoint files from ``REPO_ID`` and load them via ``from_local``.

        A requested ``"mps"`` device falls back to ``"cpu"`` (printing the
        reason) when MPS is unavailable.
        """
        # Check if MPS is available on macOS
        if device == "mps" and not torch.backends.mps.is_available():
            if not torch.backends.mps.is_built():
                print("MPS not available because the current PyTorch install was not built with MPS enabled.")
            else:
                print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
            device = "cpu"

        for fpath in ["s3gen.safetensors", "conds.pt"]:
            local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)

        return cls.from_local(Path(local_path).parent, device)

    def set_target_voice(self, wav_fpath):
        """Replace ``self.ref_dict`` with an S3Gen embedding of the reference at ``wav_fpath``.

        Only the first ``DEC_COND_LEN`` samples (at S3GEN_SR) are used.
        """
        ## Load reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

    def generate(
        self,
        audio,
        target_voice_path=None,
    ):
        """Convert ``audio`` into the target voice.

        Args:
            audio: source audio, anything ``librosa.load`` accepts; loaded at S3_SR.
            target_voice_path: optional reference recording. When given, it
                replaces the cached target voice via ``set_target_voice``.

        Returns:
            The watermarked waveform with a leading batch dimension (``unsqueeze(0)``).

        Raises:
            AssertionError: if no target voice is set and ``target_voice_path`` is not given.
        """
        if target_voice_path:
            self.set_target_voice(target_voice_path)
        else:
            assert self.ref_dict is not None, "Please `set_target_voice` first or specify `target_voice_path`"

        with torch.inference_mode():
            audio_16, _ = librosa.load(audio, sr=S3_SR)
            # Add a leading batch dimension for the tokenizer.
            audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ]

            s3_tokens, _ = self.s3gen.tokenizer(audio_16)
            wav, _ = self.s3gen.inference(
                speech_tokens=s3_tokens,
                ref_dict=self.ref_dict,
            )
            wav = wav.squeeze(0).detach().cpu().numpy()
            watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)