Repository: resemble-ai/chatterbox Branch: master Commit: eaf93540a1dd Files: 64 Total size: 354.7 KB Directory structure: gitextract_lols3utf/ ├── .github/ │ └── workflows/ │ └── install_check.yml ├── .gitignore ├── LICENSE ├── README.md ├── example_for_mac.py ├── example_tts.py ├── example_tts_turbo.py ├── example_vc.py ├── gradio_tts_app.py ├── gradio_tts_turbo_app.py ├── gradio_vc_app.py ├── multilingual_app.py ├── pyproject.toml └── src/ └── chatterbox/ ├── __init__.py ├── models/ │ ├── __init__.py │ ├── s3gen/ │ │ ├── __init__.py │ │ ├── configs.py │ │ ├── const.py │ │ ├── decoder.py │ │ ├── f0_predictor.py │ │ ├── flow.py │ │ ├── flow_matching.py │ │ ├── hifigan.py │ │ ├── matcha/ │ │ │ ├── decoder.py │ │ │ ├── flow_matching.py │ │ │ ├── text_encoder.py │ │ │ └── transformer.py │ │ ├── s3gen.py │ │ ├── transformer/ │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── attention.py │ │ │ ├── convolution.py │ │ │ ├── embedding.py │ │ │ ├── encoder_layer.py │ │ │ ├── positionwise_feed_forward.py │ │ │ ├── subsampling.py │ │ │ └── upsample_encoder.py │ │ ├── utils/ │ │ │ ├── class_utils.py │ │ │ ├── intmeanflow.py │ │ │ ├── mask.py │ │ │ └── mel.py │ │ └── xvector.py │ ├── s3tokenizer/ │ │ ├── __init__.py │ │ └── s3tokenizer.py │ ├── t3/ │ │ ├── __init__.py │ │ ├── inference/ │ │ │ ├── alignment_stream_analyzer.py │ │ │ └── t3_hf_backend.py │ │ ├── llama_configs.py │ │ ├── modules/ │ │ │ ├── cond_enc.py │ │ │ ├── learned_pos_emb.py │ │ │ ├── perceiver.py │ │ │ └── t3_config.py │ │ └── t3.py │ ├── tokenizers/ │ │ ├── __init__.py │ │ └── tokenizer.py │ ├── utils.py │ └── voice_encoder/ │ ├── __init__.py │ ├── config.py │ ├── melspec.py │ └── voice_encoder.py ├── mtl_tts.py ├── tts.py ├── tts_turbo.py └── vc.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/install_check.yml ================================================ 
name: Test Installation on: push: branches: [ "master" ] pull_request: branches: [ "master" ] jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python 3.10 uses: actions/setup-python@v4 with: python-version: "3.10" - name: Test Standard Install run: | pip install -e . ================================================ FILE: .gitignore ================================================ .vscode # Pylance pyrightconfig.json # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt syn_out/ checkpoints/ .gradio # Ignore generated sample .wav files **/*.wav ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2025 Resemble AI Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ ![Chatterbox Turbo Image](./Chatterbox-Turbo.jpg) # Chatterbox TTS [![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_turbo_demopage/) [![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo) [![Alt Text](https://static-public.podonos.com/badges/insight-on-pdns-sm-dark.svg)](https://podonos.com/resembleai/chatterbox) [![Discord](https://img.shields.io/discord/1377773249798344776?label=join%20discord&logo=discord&style=flat)](https://discord.gg/rJq9cRJBJ6) *Made with ♥️ by* resemble-logo-horizontal **Chatterbox** is a family of three state-of-the-art, open-source text-to-speech models by Resemble AI. We are excited to introduce **Chatterbox-Turbo**, our most efficient model yet. Built on a streamlined 350M parameter architecture, **Turbo** delivers high-quality speech with less compute and VRAM than our previous models. We have also distilled the speech-token-to-mel decoder, previously a bottleneck, reducing generation from 10 steps to just **one**, while retaining high-fidelity audio output. **Paralinguistic tags** are now native to the Turbo model, allowing you to use `[cough]`, `[laugh]`, `[chuckle]`, and more to add distinct realism. While Turbo was built primarily for low-latency voice agents, it excels at narration and creative workflows. If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (link). 
It delivers reliable performance with ultra-low latency of sub 200ms—ideal for production use in agents, applications, or interactive media. Podonos Turbo Eval ### ⚡ Model Zoo Choose the right model for your application. | Model | Size | Languages | Key Features | Best For | 🤗 | Examples | |:----------------------------------------------------------------------------------------------------------------| :--- | :--- |:--------------------------------------------------------|:---------------------------------------------|:--------------------------------------------------------------------------| :--- | | **Chatterbox-Turbo** | **350M** | **English** | Paralinguistic Tags (`[laugh]`), Lower Compute and VRAM | Zero-shot voice agents, Production | [Demo](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo) | [Listen](https://resemble-ai.github.io/chatterbox_turbo_demopage/) | | Chatterbox-Multilingual [(Language list)](#supported-languages) | 500M | 23+ | Zero-shot cloning, Multiple Languages | Global applications, Localization | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) | | Chatterbox [(Tips and Tricks)](#original-chatterbox-tips) | 500M | English | CFG & Exaggeration tuning | General zero-shot TTS with creative controls | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) | ## Installation ```shell pip install chatterbox-tts ``` Alternatively, you can install from source: ```shell # conda create -yn chatterbox python=3.11 # conda activate chatterbox git clone https://github.com/resemble-ai/chatterbox.git cd chatterbox pip install -e . ``` We developed and tested Chatterbox on Python 3.11 on Debian 11 OS; the versions of the dependencies are pinned in `pyproject.toml` to ensure consistency. You can modify the code or dependencies in this installation mode. 
## Usage

##### Chatterbox-Turbo

```python
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS

# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")

# Generate with Paralinguistic Tags
text = "Hi there, Sarah here from MochaFone calling you back [chuckle], have you got one minute to chat about the billing issue?"

# Generate audio (requires a reference clip for voice cloning)
wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")

ta.save("test-turbo.wav", wav, model.sr)
```

##### Chatterbox and Chatterbox-Multilingual

```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

device = "cuda"

# English example
model = ChatterboxTTS.from_pretrained(device=device)

text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-english.wav", wav, model.sr)

# Multilingual examples
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)

french_text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
wav_french = multilingual_model.generate(french_text, language_id="fr")
ta.save("test-french.wav", wav_french, multilingual_model.sr)

chinese_text = "你好,今天天气真不错,希望你有一个愉快的周末。"
wav_chinese = multilingual_model.generate(chinese_text, language_id="zh")
ta.save("test-chinese.wav", wav_chinese, multilingual_model.sr)

# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-2.wav", wav, model.sr)
```

See `example_tts.py` and `example_vc.py` for more examples.
## Supported Languages Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh) ## Original Chatterbox Tips - **General Use (TTS and Voice Agents):** - Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip’s language. To mitigate this, set `cfg_weight` to `0`. - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages. - If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing. - **Expressive or Dramatic Speech:** - Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher. - Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing. ## Built-in PerTh Watermarking for Responsible AI Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy. ## Watermark extraction You can look for the watermark using the following script. 
```python import perth import librosa AUDIO_PATH = "YOUR_FILE.wav" # Load the watermarked audio watermarked_audio, sr = librosa.load(AUDIO_PATH, sr=None) # Initialize watermarker (same as used for embedding) watermarker = perth.PerthImplicitWatermarker() # Extract watermark watermark = watermarker.get_watermark(watermarked_audio, sample_rate=sr) print(f"Extracted watermark: {watermark}") # Output: 0.0 (no watermark) or 1.0 (watermarked) ``` ## Official Discord 👋 Join us on [Discord](https://discord.gg/rJq9cRJBJ6) and let's build something awesome together! ## Evaluation Chatterbox Turbo was evaluated using Podonos, a platform for reproducible subjective speech evaluation. We compared Chatterbox Turbo to competitive TTS systems using Podonos' standardized evaluation suite, focusing on overall preference, naturalness, and expressiveness. Evaluation reports: - [Chatterbox Turbo vs ElevenLabs Turbo v2.5](https://podonos.com/resembleai/chatterbox-turbo-vs-elevenlabs-turbo) - [Chatterbox Turbo vs Cartesia Sonic 3](https://podonos.com/resembleai/chatterbox-turbo-vs-cartesia-sonic3) - [Chatterbox Turbo vs VibeVoice 7B](https://podonos.com/resembleai/chatterbox-turbo-vs-vibevoice7b) These evaluations were conducted under identical conditions and are publicly accessible via Podonos. ## Acknowledgements - [Podonos](https://podonos.com) — for supporting reproducible subjective speech evaluation - [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice) - [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) - [HiFT-GAN](https://github.com/yl4579/HiFTNet) - [Llama 3](https://github.com/meta-llama/llama3) - [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer) ## Citation If you find this model useful, please consider citing. 
``` @misc{chatterboxtts2025, author = {{Resemble AI}}, title = {{Chatterbox-TTS}}, year = {2025}, howpublished = {\url{https://github.com/resemble-ai/chatterbox}}, note = {GitHub repository} } ``` ## Disclaimer Don't use this model to do bad things. Prompts are sourced from freely available data on the internet. ================================================ FILE: example_for_mac.py ================================================ import torch import torchaudio as ta from chatterbox.tts import ChatterboxTTS # Detect device (Mac with M1/M2/M3/M4) device = "mps" if torch.backends.mps.is_available() else "cpu" map_location = torch.device(device) torch_load_original = torch.load def patched_torch_load(*args, **kwargs): if 'map_location' not in kwargs: kwargs['map_location'] = map_location return torch_load_original(*args, **kwargs) torch.load = patched_torch_load model = ChatterboxTTS.from_pretrained(device=device) text = "Today is the day. I want to move like a titan at dawn, sweat like a god forging lightning. No more excuses. From now on, my mornings will be temples of discipline. I am going to work out like the gods… every damn day." 
# If you want to synthesize with a different voice, specify the audio prompt AUDIO_PROMPT_PATH = "YOUR_FILE.wav" wav = model.generate( text, audio_prompt_path=AUDIO_PROMPT_PATH, exaggeration=2.0, cfg_weight=0.5 ) ta.save("test-2.wav", wav, model.sr) ================================================ FILE: example_tts.py ================================================ import torchaudio as ta import torch from pathlib import Path from chatterbox.tts import ChatterboxTTS from chatterbox.mtl_tts import ChatterboxMultilingualTTS # Automatically detect the best available device if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" else: device = "cpu" print(f"Using device: {device}") model = ChatterboxTTS.from_pretrained(device=device) text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill." wav = model.generate(text) ta.save("test-1.wav", wav, model.sr) multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device) text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues." 
wav = multilingual_model.generate(text, language_id="fr") ta.save("test-2.wav", wav, multilingual_model.sr) # If you want to synthesize with a different voice, specify the audio prompt AUDIO_PROMPT_PATH = "YOUR_FILE.wav" if Path(AUDIO_PROMPT_PATH).exists(): wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH) ta.save("test-3.wav", wav, model.sr) else: print(f"Warning: audio prompt file '{AUDIO_PROMPT_PATH}' not found, skipping voice cloning example.") ================================================ FILE: example_tts_turbo.py ================================================ import torchaudio as ta import torch from chatterbox.tts_turbo import ChatterboxTurboTTS # Load the Turbo model model = ChatterboxTurboTTS.from_pretrained(device="cuda") # Generate with Paralinguistic Tags text = "Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?" 
# Generate audio (requires a reference clip for voice cloning) # wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav") wav = model.generate(text) ta.save("test-turbo.wav", wav, model.sr) ================================================ FILE: example_vc.py ================================================ import torch import torchaudio as ta from chatterbox.vc import ChatterboxVC # Automatically detect the best available device if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" else: device = "cpu" print(f"Using device: {device}") AUDIO_PATH = "YOUR_FILE.wav" TARGET_VOICE_PATH = "YOUR_FILE.wav" model = ChatterboxVC.from_pretrained(device) wav = model.generate( audio=AUDIO_PATH, target_voice_path=TARGET_VOICE_PATH, ) ta.save("testvc.wav", wav, model.sr) ================================================ FILE: gradio_tts_app.py ================================================ import random import numpy as np import torch import gradio as gr from chatterbox.tts import ChatterboxTTS DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def set_seed(seed: int): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) def load_model(): model = ChatterboxTTS.from_pretrained(DEVICE) return model def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, min_p, top_p, repetition_penalty): if model is None: model = ChatterboxTTS.from_pretrained(DEVICE) if seed_num != 0: set_seed(int(seed_num)) wav = model.generate( text, audio_prompt_path=audio_prompt_path, exaggeration=exaggeration, temperature=temperature, cfg_weight=cfgw, min_p=min_p, top_p=top_p, repetition_penalty=repetition_penalty, ) return (model.sr, wav.squeeze(0).numpy()) with gr.Blocks() as demo: model_state = gr.State(None) # Loaded once per session/user with gr.Row(): with gr.Column(): text = gr.Textbox( value="Now let's make my mum's favourite. 
So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.", label="Text to synthesize (max chars 300)", max_lines=5 ) ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value=None) exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5) cfg_weight = gr.Slider(0.0, 1, step=.05, label="CFG/Pace", value=0.5) with gr.Accordion("More options", open=False): seed_num = gr.Number(value=0, label="Random seed (0 for random)") temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8) min_p = gr.Slider(0.00, 1.00, step=0.01, label="min_p || Newer Sampler. Recommend 0.02 > 0.1. Handles Higher Temperatures better. 0.00 Disables", value=0.05) top_p = gr.Slider(0.00, 1.00, step=0.01, label="top_p || Original Sampler. 1.0 Disables(recommended). 
Original 0.8", value=1.00) repetition_penalty = gr.Slider(1.00, 2.00, step=0.1, label="repetition_penalty", value=1.2) run_btn = gr.Button("Generate", variant="primary") with gr.Column(): audio_output = gr.Audio(label="Output Audio") demo.load(fn=load_model, inputs=[], outputs=model_state) run_btn.click( fn=generate, inputs=[ model_state, text, ref_wav, exaggeration, temp, seed_num, cfg_weight, min_p, top_p, repetition_penalty, ], outputs=audio_output, ) if __name__ == "__main__": demo.queue( max_size=50, default_concurrency_limit=1, ).launch(share=True) ================================================ FILE: gradio_tts_turbo_app.py ================================================ import random import numpy as np import torch import gradio as gr from chatterbox.tts_turbo import ChatterboxTurboTTS DEVICE = "cuda" if torch.cuda.is_available() else "cpu" EVENT_TAGS = [ "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]", "[sniff]", "[gasp]", "[chuckle]", "[laugh]" ] # --- REFINED CSS --- # 1. tag-container: Forces the row to wrap items instead of scrolling. Removes borders/backgrounds. # 2. tag-btn: Sets the specific look (indigo theme) and stops them from stretching. 
CUSTOM_CSS = """ .tag-container { display: flex !important; flex-wrap: wrap !important; /* This fixes the one-per-line issue */ gap: 8px !important; margin-top: 5px !important; margin-bottom: 10px !important; border: none !important; background: transparent !important; } .tag-btn { min-width: fit-content !important; width: auto !important; height: 32px !important; font-size: 13px !important; background: #eef2ff !important; border: 1px solid #c7d2fe !important; color: #3730a3 !important; border-radius: 6px !important; padding: 0 10px !important; margin: 0 !important; box-shadow: none !important; } .tag-btn:hover { background: #c7d2fe !important; transform: translateY(-1px); } """ INSERT_TAG_JS = """ (tag_val, current_text) => { const textarea = document.querySelector('#main_textbox textarea'); if (!textarea) return current_text + " " + tag_val; const start = textarea.selectionStart; const end = textarea.selectionEnd; let prefix = " "; let suffix = " "; if (start === 0) prefix = ""; else if (current_text[start - 1] === ' ') prefix = ""; if (end < current_text.length && current_text[end] === ' ') suffix = ""; return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end); } """ def set_seed(seed: int): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) def load_model(): print(f"Loading Chatterbox-Turbo on {DEVICE}...") model = ChatterboxTurboTTS.from_pretrained(DEVICE) return model def generate( model, text, audio_prompt_path, temperature, seed_num, min_p, top_p, top_k, repetition_penalty, norm_loudness ): if model is None: model = ChatterboxTurboTTS.from_pretrained(DEVICE) if seed_num != 0: set_seed(int(seed_num)) wav = model.generate( text, audio_prompt_path=audio_prompt_path, temperature=temperature, min_p=min_p, top_p=top_p, top_k=int(top_k), repetition_penalty=repetition_penalty, norm_loudness=norm_loudness, ) return (model.sr, wav.squeeze(0).numpy()) with 
gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo: gr.Markdown("# ⚡ Chatterbox Turbo") model_state = gr.State(None) with gr.Row(): with gr.Column(): text = gr.Textbox( value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and um all that jazz. Would you like me to get some prices for you?", label="Text to synthesize (max chars 300)", max_lines=5, elem_id="main_textbox" ) # --- Event Tags --- # Switched back to Row, but applied specific CSS to force wrapping with gr.Row(elem_classes=["tag-container"]): for tag in EVENT_TAGS: # elem_classes targets the button specifically btn = gr.Button(tag, elem_classes=["tag-btn"]) btn.click( fn=None, inputs=[btn, text], outputs=text, js=INSERT_TAG_JS ) ref_wav = gr.Audio( sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav" ) run_btn = gr.Button("Generate ⚡", variant="primary") with gr.Column(): audio_output = gr.Audio(label="Output Audio") with gr.Accordion("Advanced Options", open=False): seed_num = gr.Number(value=0, label="Random seed (0 for random)") temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8) top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95) top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000) repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2) min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00) norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)") demo.load(fn=load_model, inputs=[], outputs=model_state) run_btn.click( fn=generate, inputs=[ model_state, text, ref_wav, temp, seed_num, min_p, top_p, top_k, repetition_penalty, norm_loudness, ], outputs=audio_output, ) if __name__ == "__main__": demo.queue( 
max_size=50, default_concurrency_limit=1, ).launch(share=True) ================================================ FILE: gradio_vc_app.py ================================================ import torch import gradio as gr from chatterbox.vc import ChatterboxVC DEVICE = "cuda" if torch.cuda.is_available() else "cpu" model = ChatterboxVC.from_pretrained(DEVICE) def generate(audio, target_voice_path): wav = model.generate( audio, target_voice_path=target_voice_path, ) return model.sr, wav.squeeze(0).numpy() demo = gr.Interface( generate, [ gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input audio file"), gr.Audio(sources=["upload", "microphone"], type="filepath", label="Target voice audio file (if none, the default voice is used)", value=None), ], "audio", ) if __name__ == "__main__": demo.launch() ================================================ FILE: multilingual_app.py ================================================ import random import numpy as np import torch from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES import gradio as gr DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"🚀 Running on device: {DEVICE}") # --- Global Model Initialization --- MODEL = None LANGUAGE_CONFIG = { "ar": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac", "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب." }, "da": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac", "text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal." }, "de": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/de_f1.flac", "text": "Letzten Monat haben wir einen neuen Meilenstein erreicht: zwei Milliarden Aufrufe auf unserem YouTube-Kanal." 
}, "el": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/el_m.flac", "text": "Τον περασμένο μήνα, φτάσαμε σε ένα νέο ορόσημο με δύο δισεκατομμύρια προβολές στο κανάλι μας στο YouTube." }, "en": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac", "text": "Last month, we reached a new milestone with two billion views on our YouTube channel." }, "es": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/es_f1.flac", "text": "El mes pasado alcanzamos un nuevo hito: dos mil millones de visualizaciones en nuestro canal de YouTube." }, "fi": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fi_m.flac", "text": "Viime kuussa saavutimme uuden virstanpylvään kahden miljardin katselukerran kanssa YouTube-kanavallamme." }, "fr": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac", "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube." }, "he": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac", "text": "בחודש שעבר הגענו לאבן דרך חדשה עם שני מיליארד צפיות בערוץ היוטיוב שלנו." }, "hi": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac", "text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।" }, "it": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/it_m1.flac", "text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube." }, "ja": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja/ja_prompts1.flac", "text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。" }, "ko": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ko_f.flac", "text": "지난달 우리는 유튜브 채널에서 이십억 조회수라는 새로운 이정표에 도달했습니다." 
}, "ms": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ms_f.flac", "text": "Bulan lepas, kami mencapai pencapaian baru dengan dua bilion tontonan di saluran YouTube kami." }, "nl": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/nl_m.flac", "text": "Vorige maand bereikten we een nieuwe mijlpaal met twee miljard weergaven op ons YouTube-kanaal." }, "no": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/no_f1.flac", "text": "Forrige måned nådde vi en ny milepæl med to milliarder visninger på YouTube-kanalen vår." }, "pl": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pl_m.flac", "text": "W zeszłym miesiącu osiągnęliśmy nowy kamień milowy z dwoma miliardami wyświetleń na naszym kanale YouTube." }, "pt": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pt_m1.flac", "text": "No mês passado, alcançámos um novo marco: dois mil milhões de visualizações no nosso canal do YouTube." }, "ru": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac", "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале." }, "sv": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sv_f.flac", "text": "Förra månaden nådde vi en ny milstolpe med två miljarder visningar på vår YouTube-kanal." }, "sw": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sw_m.flac", "text": "Mwezi uliopita, tulifika hatua mpya ya maoni ya bilioni mbili kweny kituo chetu cha YouTube." }, "tr": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac", "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık." }, "zh": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac", "text": "上个月,我们达到了一个新的里程碑. 
我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。" }, } # --- UI Helpers --- def default_audio_for_ui(lang: str) -> str | None: return LANGUAGE_CONFIG.get(lang, {}).get("audio") def default_text_for_ui(lang: str) -> str: return LANGUAGE_CONFIG.get(lang, {}).get("text", "") def get_supported_languages_display() -> str: """Generate a formatted display of all supported languages.""" language_items = [] for code, name in sorted(SUPPORTED_LANGUAGES.items()): language_items.append(f"**{name}** (`{code}`)") # Split into 2 lines mid = len(language_items) // 2 line1 = " • ".join(language_items[:mid]) line2 = " • ".join(language_items[mid:]) return f""" ### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total) {line1} {line2} """ def get_or_load_model(): """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already, and ensures it's on the correct device.""" global MODEL if MODEL is None: print("Model not loaded, initializing...") try: MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE) if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE: MODEL.to(DEVICE) print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}") except Exception as e: print(f"Error loading model: {e}") raise return MODEL # Attempt to load the model at startup. try: get_or_load_model() except Exception as e: print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}") def set_seed(seed: int): """Sets the random seed for reproducibility across torch, numpy, and random.""" torch.manual_seed(seed) if DEVICE == "cuda": torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None: """ Decide which audio prompt to use: - If user provided a path (upload/mic/url), use it. - Else, fall back to language-specific default (if any). 
""" if provided_path and str(provided_path).strip(): return provided_path return LANGUAGE_CONFIG.get(language_id, {}).get("audio") def generate_tts_audio( text_input: str, language_id: str, audio_prompt_path_input: str = None, exaggeration_input: float = 0.5, temperature_input: float = 0.8, seed_num_input: int = 0, cfgw_input: float = 0.5 ) -> tuple[int, np.ndarray]: """ Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling. Supported languages: English, French, German, Spanish, Italian, Portuguese, and Hindi. This tool synthesizes natural-sounding speech from input text. When a reference audio file is provided, it captures the speaker's voice characteristics and speaking style. The generated audio maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided. Args: text_input (str): The text to synthesize into speech (maximum 300 characters) language_id (str): The language code for synthesis (eg. en, fr, de, es, it, pt, hi) audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None. exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5. temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8. seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0. cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer. 
Returns: tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray) """ current_model = get_or_load_model() if current_model is None: raise RuntimeError("TTS model is not loaded.") if seed_num_input != 0: set_seed(int(seed_num_input)) print(f"Generating audio for text: '{text_input[:50]}...'") # Handle optional audio prompt chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id) generate_kwargs = { "exaggeration": exaggeration_input, "temperature": temperature_input, "cfg_weight": cfgw_input, } if chosen_prompt: generate_kwargs["audio_prompt_path"] = chosen_prompt print(f"Using audio prompt: {chosen_prompt}") else: print("No audio prompt provided; using default voice.") wav = current_model.generate( text_input[:300], # Truncate text to max chars language_id=language_id, **generate_kwargs ) print("Audio generation complete.") return (current_model.sr, wav.squeeze(0).numpy()) with gr.Blocks() as demo: gr.Markdown( """ # Chatterbox Multilingual Demo Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages. """ ) # Display supported languages gr.Markdown(get_supported_languages_display()) with gr.Row(): with gr.Column(): initial_lang = "fr" text = gr.Textbox( value=default_text_for_ui(initial_lang), label="Text to synthesize (max chars 300)", max_lines=5 ) language_id = gr.Dropdown( choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()), value=initial_lang, label="Language", info="Select the language for text-to-speech synthesis" ) ref_wav = gr.Audio( sources=["upload", "microphone"], type="filepath", label="Reference Audio File (Optional)", value=default_audio_for_ui(initial_lang) ) gr.Markdown( "💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. 
To mitigate this, set the CFG weight to 0.", elem_classes=["audio-note"] ) exaggeration = gr.Slider( 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5 ) cfg_weight = gr.Slider( 0.2, 1, step=.05, label="CFG/Pace", value=0.5 ) with gr.Accordion("More options", open=False): seed_num = gr.Number(value=0, label="Random seed (0 for random)") temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8) run_btn = gr.Button("Generate", variant="primary") with gr.Column(): audio_output = gr.Audio(label="Output Audio") def on_language_change(lang, current_ref, current_text): return default_audio_for_ui(lang), default_text_for_ui(lang) language_id.change( fn=on_language_change, inputs=[language_id, ref_wav, text], outputs=[ref_wav, text], show_progress=False ) run_btn.click( fn=generate_tts_audio, inputs=[ text, language_id, ref_wav, exaggeration, temp, seed_num, cfg_weight, ], outputs=[audio_output], ) demo.launch(mcp_server=True) ================================================ FILE: pyproject.toml ================================================ [project] name = "chatterbox-tts" version = "0.1.6" description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI" readme = "README.md" requires-python = ">=3.10" license = {file = "LICENSE"} authors = [ {name = "resemble-ai", email = "engineering@resemble.ai"} ] dependencies = [ "numpy>=1.24.0,<1.26.0", "librosa==0.11.0", "s3tokenizer", "torch==2.6.0", "torchaudio==2.6.0", "transformers==5.2.0", "diffusers==0.29.0", "resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master", "conformer==0.3.2", "safetensors==0.5.3", "spacy-pkuseg", "pykakasi==2.3.0", "gradio==6.8.0", "pyloudnorm", "omegaconf" ] [project.urls] Homepage = "https://github.com/resemble-ai/chatterbox" Repository = "https://github.com/resemble-ai/chatterbox" [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] where = 
["src"]
================================================
FILE: src/chatterbox/__init__.py
================================================
# Package entry point: exposes the installed version and the public model classes.
try:
    from importlib.metadata import version
except ImportError:
    from importlib_metadata import version  # Backport for Python <3.8; pyproject declares requires-python >=3.10, so this branch should not be reached

# Read the version from the installed "chatterbox-tts" distribution metadata.
__version__ = version("chatterbox-tts")


# Re-export the main model entry points at package level.
from .tts import ChatterboxTTS
from .vc import ChatterboxVC
from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
================================================
FILE: src/chatterbox/models/__init__.py
================================================
================================================
FILE: src/chatterbox/models/s3gen/__init__.py
================================================
# Public aliases for the s3gen subpackage (S3Token2Wav is exported as S3Gen).
from .s3gen import S3Token2Wav as S3Gen
from .const import S3GEN_SR
================================================
FILE: src/chatterbox/models/s3gen/configs.py
================================================
from ..utils import AttrDict

# Default conditional flow-matching (CFM) hyperparameters. Wrapped in AttrDict so
# consumers can use attribute access (e.g. `cfm_params.t_scheduler`); used as the
# default `cfm_params` of CausalConditionalCFM.
CFM_PARAMS = AttrDict({
    "sigma_min": 1e-06,
    "solver": "euler",
    "t_scheduler": "cosine",
    "training_cfg_rate": 0.2,
    "inference_cfg_rate": 0.7,
    "reg_loss_type": "l1"
})
================================================
FILE: src/chatterbox/models/s3gen/const.py
================================================
# NOTE(review): presumably the S3Gen output sample rate in Hz — confirm.
S3GEN_SR = 24000
# NOTE(review): presumably the speech-token id used to represent silence — confirm against callers.
S3GEN_SIL = 4299
================================================
FILE: src/chatterbox/models/s3gen/decoder.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn as nn import torch.nn.functional as F from einops import pack, rearrange, repeat from .utils.mask import add_optional_chunk_mask from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \ TimestepEmbedding, Upsample1D from .matcha.transformer import BasicTransformerBlock from .utils.intmeanflow import get_intmeanflow_time_mixer def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: assert mask.dtype == torch.bool assert dtype in [torch.float32, torch.bfloat16, torch.float16] mask = mask.to(dtype) # attention mask bias # NOTE(Mddct): torch.finfo jit issues # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min mask = (1.0 - mask) * -1.0e+10 return mask class Transpose(torch.nn.Module): def __init__(self, dim0: int, dim1: int): super().__init__() self.dim0 = dim0 self.dim1 = dim1 def forward(self, x: torch.Tensor): x = torch.transpose(x, self.dim0, self.dim1) return x class CausalBlock1D(Block1D): def __init__(self, dim: int, dim_out: int): super(CausalBlock1D, self).__init__(dim, dim_out) self.block = torch.nn.Sequential( CausalConv1d(dim, dim_out, 3), Transpose(1, 2), nn.LayerNorm(dim_out), Transpose(1, 2), nn.Mish(), ) def forward(self, x: torch.Tensor, mask: torch.Tensor): output = self.block(x * mask) return output * mask class CausalResnetBlock1D(ResnetBlock1D): def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) self.block1 = CausalBlock1D(dim, dim_out) self.block2 = CausalBlock1D(dim_out, dim_out) class CausalConv1d(torch.nn.Conv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None ) -> None: super(CausalConv1d, 
self).__init__(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype) assert stride == 1 self.causal_padding = (kernel_size - 1, 0) def forward(self, x: torch.Tensor): x = F.pad(x, self.causal_padding) x = super(CausalConv1d, self).forward(x) return x class ConditionalDecoder(nn.Module): def __init__( self, in_channels=320, out_channels=80, causal=True, channels=[256], dropout=0.0, attention_head_dim=64, n_blocks=4, num_mid_blocks=12, num_heads=8, act_fn="gelu", meanflow=False, ): """ This decoder requires an input with the same shape of the target. So, if your text content is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. """ super().__init__() channels = tuple(channels) self.meanflow = meanflow self.in_channels = in_channels self.out_channels = out_channels self.causal = causal self.time_embeddings = SinusoidalPosEmb(in_channels) time_embed_dim = channels[0] * 4 self.time_mlp = TimestepEmbedding( in_channels=in_channels, time_embed_dim=time_embed_dim, act_fn="silu", ) self.down_blocks = nn.ModuleList([]) self.mid_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) # NOTE jrm: `static_chunk_size` is missing? 
self.static_chunk_size = 0 output_channel = in_channels for i in range(len(channels)): # pylint: disable=consider-using-enumerate input_channel = output_channel output_channel = channels[i] is_last = i == len(channels) - 1 resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, dropout=dropout, activation_fn=act_fn, ) for _ in range(n_blocks) ] ) downsample = ( Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) for _ in range(num_mid_blocks): input_channel = channels[-1] out_channels = channels[-1] resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, dropout=dropout, activation_fn=act_fn, ) for _ in range(n_blocks) ] ) self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) channels = channels[::-1] + (channels[0],) for i in range(len(channels) - 1): input_channel = channels[i] * 2 output_channel = channels[i + 1] is_last = i == len(channels) - 2 resnet = CausalResnetBlock1D( dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim, ) if self.causal else ResnetBlock1D( dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim, ) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( dim=output_channel, 
num_attention_heads=num_heads, attention_head_dim=attention_head_dim, dropout=dropout, activation_fn=act_fn, ) for _ in range(n_blocks) ] ) upsample = ( Upsample1D(output_channel, use_conv_transpose=True) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1]) self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) self.initialize_weights() self.time_embed_mixer = None if self.meanflow: self.time_embed_mixer = get_intmeanflow_time_mixer(time_embed_dim) @property def dtype(self): return self.final_proj.weight.dtype def initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv1d): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.GroupNorm): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x, mask, mu, t, spks=None, cond=None, r=None): """Forward pass of the UNet1DConditional model. Args: x: (B, 80, T) mask (_type_) t (_type_): shape (batch_size) spks (_type_, optional) Defaults to None. 
cond (_type_, optional) r: end time for meanflow mode (shape (1,) tensor) Raises: ValueError: _description_ ValueError: _description_ Returns: _type_: _description_ """ t = self.time_embeddings(t).to(t.dtype) t = self.time_mlp(t) if self.meanflow: r = self.time_embeddings(r).to(t.dtype) r = self.time_mlp(r) concat_embed = torch.cat([t, r], dim=1) t = self.time_embed_mixer(concat_embed) x = pack([x, mu], "b * t")[0] if spks is not None: spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) x = pack([x, spks], "b * t")[0] if cond is not None: x = pack([x, cond], "b * t")[0] hiddens = [] masks = [mask] for resnet, transformer_blocks, downsample in self.down_blocks: mask_down = masks[-1] x = resnet(x, mask_down, t) x = rearrange(x, "b c t -> b t c").contiguous() # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask == 1, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() hiddens.append(x) # Save hidden states for skip connections x = downsample(x * mask_down) masks.append(mask_down[:, :, ::2]) masks = masks[:-1] mask_mid = masks[-1] for resnet, transformer_blocks in self.mid_blocks: x = resnet(x, mask_mid, t) x = rearrange(x, "b c t -> b t c").contiguous() # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask == 1, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() for resnet, transformer_blocks, upsample in self.up_blocks: mask_up = masks.pop() skip = hiddens.pop() x = pack([x[:, :, :skip.shape[-1]], 
skip], "b * t")[0] x = resnet(x, mask_up, t) x = rearrange(x, "b c t -> b t c").contiguous() # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask == 1, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() x = upsample(x * mask_up) x = self.final_block(x, mask_up) output = self.final_proj(x * mask_up) return output * mask ================================================ FILE: src/chatterbox/models/s3gen/f0_predictor.py ================================================ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch import torch.nn as nn from torch.nn.utils.parametrizations import weight_norm class ConvRNNF0Predictor(nn.Module): def __init__(self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512 ): super().__init__() self.num_class = num_class self.condnet = nn.Sequential( weight_norm( nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) ), nn.ELU(), weight_norm( nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) ), nn.ELU(), weight_norm( nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) ), nn.ELU(), weight_norm( nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) ), nn.ELU(), weight_norm( nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) ), nn.ELU(), ) self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.condnet(x) x = x.transpose(1, 2) return torch.abs(self.classifier(x).squeeze(-1)) ================================================ FILE: src/chatterbox/models/s3gen/flow.py ================================================ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import logging
import random
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig

from .utils.mask import make_pad_mask
from .configs import CFM_PARAMS

logger = logging.getLogger(__name__)


def _repeat_batch_dim(tnsr, B, ndim):
    """Bring ``tnsr`` to rank ``ndim`` and batch size ``B``.

    Missing leading dims are added with size 1; a batch dim of size 1 is then
    repeated ``B`` times. ``None`` is passed through unchanged. Tensors with a
    batch size other than 1 are not repeated.
    """
    if tnsr is not None:
        # add missing batch dim if needed
        while tnsr.ndim < ndim:
            tnsr = tnsr[None]
        # repeat batch dim as needed
        if B > 1 and tnsr.size(0) == 1:
            tnsr = tnsr.repeat(B, *([1] * (ndim - 1)))
        assert tnsr.ndim == ndim, f"Expected {ndim=}, got {tnsr.ndim=}"
    return tnsr


def _default_decoder_conf() -> Dict:
    """Return a fresh copy of the default decoder configuration."""
    return {
        'in_channels': 240,
        'out_channel': 80,
        'spk_emb_dim': 80,
        'n_spks': 1,
        'cfm_params': DictConfig({
            'sigma_min': 1e-06,
            'solver': 'euler',
            't_scheduler': 'cosine',
            'training_cfg_rate': 0.2,
            'inference_cfg_rate': 0.7,
            'reg_loss_type': 'l1',
        }),
        'decoder_params': {
            'channels': [256, 256],
            'dropout': 0.0,
            'attention_head_dim': 64,
            'n_blocks': 4,
            'num_mid_blocks': 12,
            'num_heads': 8,
            'act_fn': 'gelu',
        },
    }


def _default_mel_feat_conf() -> Dict:
    """Return a fresh copy of the default mel-feature configuration."""
    return {
        'n_fft': 1024,
        'num_mels': 80,
        'sampling_rate': 22050,
        'hop_size': 256,
        'win_size': 1024,
        'fmin': 0,
        'fmax': 8000,
    }


class CausalMaskedDiffWithXvec(torch.nn.Module):
    """Token-to-mel flow model conditioned on a speaker x-vector and a mel prompt.

    Speech tokens are embedded and encoded by ``encoder``, projected to
    ``output_size`` channels, and turned into mel features by the flow-matching
    ``decoder``, which is conditioned on the projected speaker embedding and on
    the prompt mel frames.
    """

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 6561,
                 input_frame_rate: int = 25,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Optional[Dict] = None,
                 mel_feat_conf: Optional[Dict] = None):
        """
        Args:
            decoder_conf: decoder configuration to store; None means a fresh copy of
                the built-in default (avoids sharing one mutable default between instances).
            mel_feat_conf: mel-feature configuration to store; None means a fresh copy
                of the built-in default.
            encoder: token encoder; must provide ``output_size()``.
            decoder: flow-matching decoder called in ``inference``/``compute_loss``.
        """
        super().__init__()
        if decoder_conf is None:
            decoder_conf = _default_decoder_conf()
        if mel_feat_conf is None:
            mel_feat_conf = _default_mel_feat_conf()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        # Module logger: the root-level `logging.info` would implicitly call basicConfig().
        logger.info("input frame rate=%s", self.input_frame_rate)
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    # NOTE: copied in from cosyvoice repo
    def compute_loss(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Reads 'speech_token', 'speech_token_len', 'speech_feat', 'speech_feat_len'
        and 'embedding' from ``batch``. For each item, with probability 0.5 a random
        prefix (up to 30% of its frames) of the target mel is given as condition.

        Returns:
            {'loss': loss} as returned by ``decoder.compute_loss``.
        """
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)  # (B, 80, T)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        # streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)  # (B, T, 1)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask  # (B, T, emb)

        # text encode
        h, h_lengths = self.encoder(token, token_len)  # (B, T, C) -> (B, 2T, C)
        h = self.encoder_proj(h)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :, :index] = feat[i, :, :index]

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            # streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  finalize,
                  n_timesteps=10,
                  noised_mels=None,
                  meanflow=False):
        """Generate mel features for ``token`` continuing the given prompt.

        Prompt tensors and ``embedding`` may be given without a batch dim or with
        batch 1; they are broadcast to the batch size of ``token``.

        Args:
            token: (B, n_toks) speech tokens; token_len: (B,).
            prompt_token / prompt_token_len: prompt tokens prepended to ``token``.
            prompt_feat: prompt mel frames, (B, n_frames, 80); they fill the first
                frames of the decoder condition.
            embedding: speaker x-vector.
            finalize: when exactly ``False``, the trailing
                ``pre_lookahead_len * token_mel_ratio`` encoder frames are dropped.

        Returns:
            (feat, None): ``feat`` holds only the newly generated frames
            (prompt frames removed), shape (B, C, mel_len2).
        """
        # token: (B, n_toks)
        # token_len: (B,)
        B = token.size(0)

        # xvec projection
        embedding = torch.atleast_2d(embedding)
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)  # (1 or B, emb_dim)

        # adjust shapes (batching logic)
        prompt_token = _repeat_batch_dim(prompt_token, B, ndim=2)  # (B, n_prompt)
        prompt_token_len = _repeat_batch_dim(prompt_token_len, B, ndim=1)  # (B,)
        prompt_feat = _repeat_batch_dim(prompt_feat, B, ndim=3)  # (B, n_feat, feat_dim=80)
        prompt_feat_len = _repeat_batch_dim(prompt_feat_len, B, ndim=1)  # (B,) or None
        embedding = _repeat_batch_dim(embedding, B, ndim=2)  # (B, emb_dim)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)

        if (token >= self.vocab_size).any():
            logger.error(f"{token.max()}>={self.vocab_size}\n out-of-range special tokens found in flow, fix inputs!")
        token = self.input_embedding(token.long()) * mask

        # text encode
        h, h_masks = self.encoder(token, token_len)
        if finalize is False:
            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]

        # NOTE(review): h_lengths comes from h_masks, which is not trimmed above when
        # finalize is False; confirm the decoder mask length still matches h there.
        h_lengths = h_masks.sum(dim=-1).squeeze(dim=-1)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # get conditions: prompt mel frames first, zeros for the frames to generate
        conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device, dtype=h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths)).unsqueeze(1).to(h)
        if mask.shape[0] != B:
            mask = mask.repeat(B, 1, 1)

        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask,
            spks=embedding,
            cond=conds,
            n_timesteps=n_timesteps,
            noised_mels=noised_mels,
            meanflow=meanflow,
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat, None  # NOTE jrm: why are they returning None here?
================================================ FILE: src/chatterbox/models/s3gen/flow_matching.py ================================================ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import threading import torch import torch.nn.functional as F from .matcha.flow_matching import BASECFM from .configs import CFM_PARAMS from tqdm import tqdm def cast_all(*args, dtype): return [a if (not a.dtype.is_floating_point) or a.dtype == dtype else a.to(dtype) for a in args] class ConditionalCFM(BASECFM): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): super().__init__( n_feats=in_channels, cfm_params=cfm_params, n_spks=n_spks, spk_emb_dim=spk_emb_dim, ) self.t_scheduler = cfm_params.t_scheduler self.training_cfg_rate = cfm_params.training_cfg_rate self.inference_cfg_rate = cfm_params.inference_cfg_rate in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) # Just change the architecture of the estimator here self.estimator = estimator @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): """Forward diffusion Args: mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): output_mask shape: (batch_size, 1, mel_timesteps) n_timesteps (int): number of diffusion steps temperature (float, optional): temperature 
for scaling noise. Defaults to 1.0. spks (torch.Tensor, optional): speaker ids. Defaults to None. shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes Returns: sample: generated mel-spectrogram shape: (batch_size, n_feats, mel_timesteps) """ raise NotImplementedError("unused, needs updating for meanflow model") z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature cache_size = flow_cache.shape[2] # fix prompt and overlap part mu and z if cache_size != 0: z[:, :, :cache_size] = flow_cache[:, :, :, 0] mu[:, :, :cache_size] = flow_cache[:, :, :, 1] z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) flow_cache = torch.stack([z_cache, mu_cache], dim=-1) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache def solve_euler(self, x, t_span, mu, mask, spks, cond, meanflow=False): """ Fixed euler solver for ODEs. Args: x (torch.Tensor): random noise t_span (torch.Tensor): n_timesteps interpolated shape: (n_timesteps + 1,) mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): output_mask shape: (batch_size, 1, mel_timesteps) spks (torch.Tensor, optional): speaker ids. Defaults to None. shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes meanflow: meanflow mode """ in_dtype = x.dtype x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype) # Duplicated batch dims are for CFG # Do not use concat, it may cause memory format changed and trt infer with wrong results! 
B, T = mu.size(0), x.size(2) x_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype) mask_in = torch.zeros([2 * B, 1, T], device=x.device, dtype=x.dtype) mu_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype) t_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype) spks_in = torch.zeros([2 * B, 80 ], device=x.device, dtype=x.dtype) cond_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype) r_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype) # (only used for meanflow) for t, r in zip(t_span[:-1], t_span[1:]): t = t.unsqueeze(dim=0) r = r.unsqueeze(dim=0) # Shapes: # x_in ( 2B, 80, T ) # mask_in ( 2B, 1, T ) # mu_in ( 2B, 80, T ) # t_in ( 2B, ) # spks_in ( 2B, 80, ) # cond_in ( 2B, 80, T ) # r_in ( 2B, ) # x ( B, 80, T ) # mask ( B, 1, T ) # mu ( B, 80, T ) # t ( B, ) # spks ( B, 80, ) # cond ( B, 80, T ) # r ( B, ) x_in[:B] = x_in[B:] = x mask_in[:B] = mask_in[B:] = mask mu_in[:B] = mu t_in[:B] = t_in[B:] = t spks_in[:B] = spks cond_in[:B] = cond r_in[:B] = r_in[B:] = r # (only used for meanflow) dxdt = self.estimator.forward( x=x_in, mask=mask_in, mu=mu_in, t=t_in, spks=spks_in, cond=cond_in, r=r_in if meanflow else None, ) dxdt, cfg_dxdt = torch.split(dxdt, [B, B], dim=0) dxdt = ((1.0 + self.inference_cfg_rate) * dxdt - self.inference_cfg_rate * cfg_dxdt) dt = r - t x = x + dt * dxdt return x.to(in_dtype) def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss Args: x1 (torch.Tensor): Target shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): target mask shape: (batch_size, 1, mel_timesteps) mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
shape: (batch_size, spk_emb_dim) Returns: loss: conditional flow matching loss y: conditional flow shape: (batch_size, n_feats, mel_timesteps) """ b, _, t = mu.shape # random timestep t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t = 1 - torch.cos(t * 0.5 * torch.pi) # sample noise p(x_0) z = torch.randn_like(x1) y = (1 - (1 - self.sigma_min) * t) * z + t * x1 u = x1 - (1 - self.sigma_min) * z # during training, we randomly drop condition to trade off mode coverage and sample fidelity if self.training_cfg_rate > 0: cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate mu = mu * cfg_mask.view(-1, 1, 1) spks = spks * cfg_mask.view(-1, 1) cond = cond * cfg_mask.view(-1, 1, 1) pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) return loss, y class CausalConditionalCFM(ConditionalCFM): def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None): super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator) # TODO: BAD BAD IDEA - IT'LL MESS UP DISTILLATION - SETTING TO NONE self.rand_noise = None @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, noised_mels=None, meanflow=False): """Forward diffusion Args: mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): output_mask shape: (batch_size, 1, mel_timesteps) n_timesteps (int): number of diffusion steps temperature (float, optional): temperature for scaling noise. Defaults to 1.0. spks (torch.Tensor, optional): speaker ids. Defaults to None. 
shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes noised_mels: gt mels noised a time t Returns: sample: generated mel-spectrogram shape: (batch_size, n_feats, mel_timesteps) """ B = mu.size(0) z = torch.randn_like(mu) if noised_mels is not None: prompt_len = mu.size(2) - noised_mels.size(2) z[..., prompt_len:] = noised_mels # time steps for reverse diffusion t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if (not meanflow) and (self.t_scheduler == 'cosine'): t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) # NOTE: right now, the only meanflow models are also distilled models, which don't need CFG # because they were distilled with CFG outputs. We would need to add another hparam and # change the conditional logic here if we want to use CFG inference with a meanflow model. if meanflow: return self.basic_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, meanflow=meanflow), None def basic_euler(self, x, t_span, mu, mask, spks, cond): in_dtype = x.dtype x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype) print("S3 Token -> Mel Inference...") for t, r in tqdm(zip(t_span[..., :-1], t_span[..., 1:]), total=t_span.shape[-1] - 1): t, r = t[None], r[None] dxdt = self.estimator.forward(x, mask=mask, mu=mu, t=t, spks=spks, cond=cond, r=r) dt = r - t x = x + dt * dxdt return x.to(in_dtype) ================================================ FILE: src/chatterbox/models/s3gen/hifigan.py ================================================ # jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py # most modules should be reusable, but I found their SineGen changed a git. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """HIFI-GAN""" from typing import Dict, Optional, List import numpy as np from scipy.signal import get_window import torch import torch.nn.functional as F from torch.nn import Conv1d from torch.nn import ConvTranspose1d from torch.nn.utils import remove_weight_norm from torch.nn.utils.parametrizations import weight_norm from torch.distributions.uniform import Uniform from torch import nn, sin, pow from torch.nn import Parameter class Snake(nn.Module): ''' Implementation of a sine-based periodic activation function Shape: - Input: (B, C, T) - Output: (B, C, T), same shape as the input Parameters: - alpha - trainable parameter References: - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195 Examples: >>> a1 = snake(256) >>> x = torch.randn(256) >>> x = a1(x) ''' def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): ''' Initialization. INPUT: - in_features: shape of the input - alpha: trainable parameter alpha is initialized to 1 by default, higher values = higher-frequency. alpha will be trained along with the rest of your model. 
''' super(Snake, self).__init__() self.in_features = in_features # initialize alpha self.alpha_logscale = alpha_logscale if self.alpha_logscale: # log scale alphas initialized to zeros self.alpha = Parameter(torch.zeros(in_features) * alpha) else: # linear scale alphas initialized to ones self.alpha = Parameter(torch.ones(in_features) * alpha) self.alpha.requires_grad = alpha_trainable self.no_div_by_zero = 0.000000001 def forward(self, x): ''' Forward pass of the function. Applies the function to the input elementwise. Snake ∶= x + 1/a * sin^2 (xa) ''' alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] if self.alpha_logscale: alpha = torch.exp(alpha) x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) return x def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.normal_(mean, std) """hifigan based generator implementation. 
This code is modified from https://github.com/jik876/hifi-gan ,https://github.com/kan-bayashi/ParallelWaveGAN and https://github.com/NVIDIA/BigVGAN """ class ResBlock(torch.nn.Module): """Residual block module in HiFiGAN/BigVGAN.""" def __init__( self, channels: int = 512, kernel_size: int = 3, dilations: List[int] = [1, 3, 5], ): super(ResBlock, self).__init__() self.convs1 = nn.ModuleList() self.convs2 = nn.ModuleList() for dilation in dilations: self.convs1.append( weight_norm( Conv1d( channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation) ) ) ) self.convs2.append( weight_norm( Conv1d( channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1) ) ) ) self.convs1.apply(init_weights) self.convs2.apply(init_weights) self.activations1 = nn.ModuleList([ Snake(channels, alpha_logscale=False) for _ in range(len(self.convs1)) ]) self.activations2 = nn.ModuleList([ Snake(channels, alpha_logscale=False) for _ in range(len(self.convs2)) ]) def forward(self, x: torch.Tensor) -> torch.Tensor: for idx in range(len(self.convs1)): xt = self.activations1[idx](x) xt = self.convs1[idx](xt) xt = self.activations2[idx](xt) xt = self.convs2[idx](xt) x = xt + x return x def remove_weight_norm(self): for idx in range(len(self.convs1)): remove_weight_norm(self.convs1[idx]) remove_weight_norm(self.convs2[idx]) class SineGen(torch.nn.Module): """ Definition of sine generator SineGen(samp_rate, harmonic_num = 0, sine_amp = 0.1, noise_std = 0.003, voiced_threshold = 0, flag_for_pulse=False) samp_rate: sampling rate in Hz harmonic_num: number of harmonic overtones (default 0) sine_amp: amplitude of sine-wavefrom (default 0.1) noise_std: std of Gaussian noise (default 0.003) voiced_thoreshold: F0 threshold for U/V classification (default 0) flag_for_pulse: this SinGen is used inside PulseGen (default False) Note: when flag_for_pulse is True, the first time step of a voiced segment is always sin(np.pi) or cos(0) """ def 
__init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0): super(SineGen, self).__init__() self.sine_amp = sine_amp self.noise_std = noise_std self.harmonic_num = harmonic_num self.sampling_rate = samp_rate self.voiced_threshold = voiced_threshold def _f02uv(self, f0): # generate uv signal uv = (f0 > self.voiced_threshold).type(torch.float32) return uv @torch.no_grad() def forward(self, f0): """ :param f0: [B, 1, sample_len], Hz :return: [B, 1, sample_len] """ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) for i in range(self.harmonic_num + 1): F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) u_dist = Uniform(low=-np.pi, high=np.pi) phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) phase_vec[:, 0, :] = 0 # generate sine waveforms sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) # generate uv signal uv = self._f02uv(f0) # noise: for unvoiced should be similar to sine_amp # std = self.sine_amp/3 -> max value ~ self.sine_amp # . 
for voiced regions is self.noise_std noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 noise = noise_amp * torch.randn_like(sine_waves) # first: set the unvoiced part to 0 by uv # then: additive noise sine_waves = sine_waves * uv + noise return sine_waves, uv, noise class SourceModuleHnNSF(torch.nn.Module): """ SourceModule for hn-nsf SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0) sampling_rate: sampling_rate in Hz harmonic_num: number of harmonic above F0 (default: 0) sine_amp: amplitude of sine source signal (default: 0.1) add_noise_std: std of additive Gaussian noise (default: 0.003) note that amplitude of noise in unvoiced is decided by sine_amp voiced_threshold: threhold to set U/V given F0 (default: 0) Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) F0_sampled (batchsize, length, 1) Sine_source (batchsize, length, 1) noise_source (batchsize, length 1) uv (batchsize, length, 1) """ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0): super(SourceModuleHnNSF, self).__init__() self.sine_amp = sine_amp self.noise_std = add_noise_std # to produce sine waveforms self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod) # to merge source harmonics into a single excitation self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) self.l_tanh = torch.nn.Tanh() def forward(self, x): """ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) F0_sampled (batchsize, length, 1) Sine_source (batchsize, length, 1) noise_source (batchsize, length 1) """ # source for harmonic branch with torch.no_grad(): sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) sine_wavs = sine_wavs.transpose(1, 2) uv = uv.transpose(1, 2) sine_merge = self.l_tanh(self.l_linear(sine_wavs)) # source for noise branch, in the same shape as uv noise = torch.randn_like(uv) * self.sine_amp / 3 return sine_merge, noise, uv class 
HiFTGenerator(nn.Module): """ HiFTNet Generator: Neural Source Filter + ISTFTNet https://arxiv.org/abs/2309.09493 """ def __init__( self, in_channels: int = 80, base_channels: int = 512, nb_harmonics: int = 8, sampling_rate: int = 22050, nsf_alpha: float = 0.1, nsf_sigma: float = 0.003, nsf_voiced_threshold: float = 10, upsample_rates: List[int] = [8, 8], upsample_kernel_sizes: List[int] = [16, 16], istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, resblock_kernel_sizes: List[int] = [3, 7, 11], resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], source_resblock_kernel_sizes: List[int] = [7, 11], source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]], lrelu_slope: float = 0.1, audio_limit: float = 0.99, f0_predictor: torch.nn.Module = None, ): super(HiFTGenerator, self).__init__() self.out_channels = 1 self.nb_harmonics = nb_harmonics self.sampling_rate = sampling_rate self.istft_params = istft_params self.lrelu_slope = lrelu_slope self.audio_limit = audio_limit self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) self.m_source = SourceModuleHnNSF( sampling_rate=sampling_rate, upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], harmonic_num=nb_harmonics, sine_amp=nsf_alpha, add_noise_std=nsf_sigma, voiced_threshod=nsf_voiced_threshold) self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) self.conv_pre = weight_norm( Conv1d(in_channels, base_channels, 7, 1, padding=3) ) # Up self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): self.ups.append( weight_norm( ConvTranspose1d( base_channels // (2**i), base_channels // (2**(i + 1)), k, u, padding=(k - u) // 2, ) ) ) # Down self.source_downs = nn.ModuleList() self.source_resblocks = nn.ModuleList() downsample_rates = [1] + upsample_rates[::-1][:-1] downsample_cum_rates = np.cumprod(downsample_rates) for i, (u, k, d) in 
enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)): if u == 1: self.source_downs.append( Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) ) else: self.source_downs.append( Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) ) self.source_resblocks.append( ResBlock(base_channels // (2 ** (i + 1)), k, d) ) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = base_channels // (2**(i + 1)) for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(ResBlock(ch, k, d)) self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) self.ups.apply(init_weights) self.conv_post.apply(init_weights) self.reflection_pad = nn.ReflectionPad1d((1, 0)) self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) self.f0_predictor = f0_predictor def remove_weight_norm(self): print('Removing weight norm...') for l in self.ups: remove_weight_norm(l) for l in self.resblocks: l.remove_weight_norm() remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) self.m_source.remove_weight_norm() for l in self.source_downs: remove_weight_norm(l) for l in self.source_resblocks: l.remove_weight_norm() def _stft(self, x): spec = torch.stft( x, self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), return_complex=True) spec = torch.view_as_real(spec) # [B, F, TT, 2] return spec[..., 0], spec[..., 1] def _istft(self, magnitude, phase): magnitude = torch.clip(magnitude, max=1e2) real = magnitude * torch.cos(phase) img = magnitude * torch.sin(phase) inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) return inverse_transform def 
decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) x = self.conv_pre(x) for i in range(self.num_upsamples): x = F.leaky_relu(x, self.lrelu_slope) x = self.ups[i](x) if i == self.num_upsamples - 1: x = self.reflection_pad(x) # fusion si = self.source_downs[i](s_stft) si = self.source_resblocks[i](si) x = x + si xs = None for j in range(self.num_kernels): if xs is None: xs = self.resblocks[i * self.num_kernels + j](x) else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels x = F.leaky_relu(x) x = self.conv_post(x) magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy x = self._istft(magnitude, phase) x = torch.clamp(x, -self.audio_limit, self.audio_limit) return x def forward( self, batch: dict, device: torch.device, ) -> Dict[str, Optional[torch.Tensor]]: speech_feat = batch['speech_feat'].transpose(1, 2).to(device) # mel->f0 f0 = self.f0_predictor(speech_feat) # f0->source s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t s, _, _ = self.m_source(s) s = s.transpose(1, 2) # mel+source->speech generated_speech = self.decode(x=speech_feat, s=s) return generated_speech, f0 @torch.inference_mode() def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: # mel->f0 f0 = self.f0_predictor(speech_feat) # f0->source s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t s, _, _ = self.m_source(s) s = s.transpose(1, 2) # use cache_source to avoid glitch if cache_source.shape[2] != 0: s[:, :, :cache_source.shape[2]] = cache_source generated_speech = self.decode(x=speech_feat, s=s) return generated_speech, s ================================================ FILE: src/chatterbox/models/s3gen/matcha/decoder.py 
================================================
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat

from .transformer import BasicTransformerBlock


class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal embedding of a (batch of) scalar time value(s) into ``dim`` features."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        # a 0-dim time value is promoted to a batch of one
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        # first half sin, second half cos -> (batch, dim)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Block1D(torch.nn.Module):
    """Conv1d(k=3) -> GroupNorm -> Mish, with the mask applied to both input and output."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(torch.nn.Module):
    """Two Block1D layers with the time embedding added in between, plus a 1x1 residual conv."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        # broadcast the projected time embedding over the time axis
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(nn.Module):
    """Stride-2 Conv1d (k=3, padding=1) that roughly halves the time axis."""

    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(nn.Module):
    """MLP (linear -> activation -> linear [-> post-activation]) applied to a time embedding.

    If ``cond_proj_dim`` is set, a projected ``condition`` is added to the input first.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution (applied after nearest-neighbour x2 interpolation;
            ignored when `use_conv_transpose` is True).
        use_conv_transpose (`bool`, default `True`):
            option to use a convolution transpose (k=4, stride=2) instead of interpolation.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class ConformerWrapper(ConformerBlock):
    """Adapts ``ConformerBlock`` to the transformer-block call signature used by ``Decoder``.

    Only ``hidden_states`` and ``attention_mask`` (cast to bool) are forwarded; the other
    keyword arguments are accepted and ignored.
    """

    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        return super().forward(x=hidden_states, mask=attention_mask.bool())


class Decoder(nn.Module):
    """1D U-Net estimator: down blocks -> mid blocks -> up blocks with skip connections.

    Each stage is a ResnetBlock1D (conditioned on the time embedding) followed by ``n_blocks``
    transformer/conformer blocks. Module construction order matters for state_dict keys.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # the last down stage keeps the time resolution (plain conv instead of stride-2)
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this local `out_channels` is never read afterwards (it only shadows
            # the constructor argument); output_channel already equals channels[-1] here.
            out_channels = channels[-1]

            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # mirror the channel schedule for the up path; inputs are doubled by the skip concat
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        """Build one "conformer" or "transformer" block; raises ValueError for other types."""
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        """Kaiming-normal init for Conv1d/Linear weights, zero biases, and unit/zero GroupNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, channels_x, time); concatenated with ``mu``
                (and ``spks`` if given) along channels before the first block. The combined
                channel count presumably has to match the first ResnetBlock1D input — confirm.
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): encoder output, shape (batch_size, channels_mu, time)
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape: (batch_size, condition_channels); repeated
                over time and concatenated to the input. Defaults to None.
            cond (optional): accepted but not used by this implementation. Defaults to None.

        Returns:
            torch.Tensor: ``final_proj`` output (``out_channels`` channels) multiplied by ``mask``.
        """

        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            # transformer blocks work channel-last
            x = rearrange(x, "b c t -> b t c")
            mask_down = rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_down,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        # the last appended mask corresponds to no real downsample; drop it
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c")
            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_mid,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            # concatenate the matching skip connection along channels
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = rearrange(x, "b c t -> b t c")
            mask_up = rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_up,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
================================================
FILE: src/chatterbox/models/s3gen/matcha/flow_matching.py
================================================
from abc import ABC

import torch
import
torch.nn.functional as F

from .decoder import Decoder


class BASECFM(torch.nn.Module, ABC):
    """Base conditional flow-matching model with a fixed-step Euler sampler.

    Subclasses must set ``self.estimator`` (left as ``None`` here).
    ``cfm_params`` must provide ``solver``; ``sigma_min`` is optional and defaults to 1e-4.
    """

    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 1e-4

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        # NOTE(review): t_span uses torch's default dtype, not mu.dtype — confirm for half-precision use
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            the state ``x`` after the final Euler step.
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            # step size for the next iteration, taken from the (possibly non-uniform) t_span
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: accepted but not passed to the estimator here.

        Returns:
            loss: conditional flow matching loss (sum of squared errors divided by
                ``sum(mask) * n_feats``; the target ``u`` itself is not masked)
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y


class CFM(BASECFM):
    """BASECFM whose estimator is a matcha ``Decoder`` built from ``decoder_params``."""

    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        # multi-speaker models feed the speaker embedding as extra input channels
        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
        # Just change the architecture of the estimator here
        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
================================================
FILE: src/chatterbox/models/s3gen/matcha/text_encoder.py
================================================
""" from
https://github.com/jaywalnut310/glow-tts """

import math

import torch
import torch.nn as nn
from einops import rearrange


def sequence_mask(length, max_length=None):
    """Boolean mask that is True where the position index is below ``length``.

    ``max_length`` defaults to ``length.max()``. The result is ``arange(max_length)[None]``
    compared with ``length[:, None]``, so a 1-D ``length`` of size B gives (B, max_length).
    """
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


class LayerNorm(nn.Module):
    """Layer norm over dimension 1 (channels), with per-channel ``gamma``/``beta``.

    ``gamma``/``beta`` are reshaped to ``[1, channels, 1, ...]`` so they broadcast over
    any trailing dimensions of the input.
    """

    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    """Stack of (masked conv -> LayerNorm -> ReLU -> dropout) layers with a residual 1x1 projection.

    The projection is zero-initialised, so the module starts out as the identity on the
    masked input.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(nn.Module):
    """Two masked conv -> ReLU -> LayerNorm -> dropout layers, then a 1-channel projection."""

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ (converted with ``int``, so fractional values
          such as ``k_channels * 0.5`` are truncated)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        # NOTE(review): the cache is reused based on length only; it is not rebuilt if a later
        # call arrives on a different device/dtype — confirm callers keep these fixed.
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape
          `[batch_size, n_heads, seq_len, d_head]`; it is internally rearranged to
          `[seq_len, batch_size, n_heads, d_head]` and returned in the input layout.
          Only the first `self.d` features are rotated; the rest pass through unchanged.
        """
        # Cache $\cos$ and $\sin$ values
        x = rearrange(x, "b h t d -> t b h d")

        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])

        return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")


class MultiHeadAttention(nn.Module): def __init__( self, channels, out_channels, n_heads, heads_share=True, p_dropout=0.0, proximal_bias=False, proximal_init=False, ): super().__init__() assert channels % n_heads == 0 self.channels = channels self.out_channels = out_channels self.n_heads = n_heads self.heads_share = heads_share self.proximal_bias = proximal_bias self.p_dropout = p_dropout self.attn = None self.k_channels = channels // n_heads self.conv_q = torch.nn.Conv1d(channels, channels, 1) self.conv_k = torch.nn.Conv1d(channels, channels, 1) self.conv_v = torch.nn.Conv1d(channels, channels, 1) # from https://nn.labml.ai/transformers/rope/index.html self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) self.drop = torch.nn.Dropout(p_dropout) torch.nn.init.xavier_uniform_(self.conv_q.weight) torch.nn.init.xavier_uniform_(self.conv_k.weight) if proximal_init: self.conv_k.weight.data.copy_(self.conv_q.weight.data) self.conv_k.bias.data.copy_(self.conv_q.bias.data) torch.nn.init.xavier_uniform_(self.conv_v.weight) def forward(self, x, c, attn_mask=None): q = self.conv_q(x) k = self.conv_k(c) v = self.conv_v(c) x, self.attn = self.attention(q, k, v, mask=attn_mask) x = self.conv_o(x) return x def attention(self, query, key, value, mask=None): b, d, t_s, t_t = (*key.size(), query.size(2)) query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) value = rearrange(value, "b (h c) t-> 
b h t c", h=self.n_heads) query = self.query_rotary_pe(query) key = self.key_rotary_pe(key) scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) if self.proximal_bias: assert t_s == t_t, "Proximal bias is only available for self-attention." scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) if mask is not None: scores = scores.masked_fill(mask == 0, -1e4) p_attn = torch.nn.functional.softmax(scores, dim=-1) p_attn = self.drop(p_attn) output = torch.matmul(p_attn, value) output = output.transpose(2, 3).contiguous().view(b, d, t_t) return output, p_attn @staticmethod def _attention_bias_proximal(length): r = torch.arange(length, dtype=torch.float32) diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) class FFN(nn.Module): def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.filter_channels = filter_channels self.kernel_size = kernel_size self.p_dropout = p_dropout self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) self.drop = torch.nn.Dropout(p_dropout) def forward(self, x, x_mask): x = self.conv_1(x * x_mask) x = torch.relu(x) x = self.drop(x) x = self.conv_2(x * x_mask) return x * x_mask class Encoder(nn.Module): def __init__( self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, **kwargs, ): super().__init__() self.hidden_channels = hidden_channels self.filter_channels = filter_channels self.n_heads = n_heads self.n_layers = n_layers self.kernel_size = kernel_size self.p_dropout = p_dropout self.drop = torch.nn.Dropout(p_dropout) self.attn_layers = torch.nn.ModuleList() self.norm_layers_1 = 
torch.nn.ModuleList() self.ffn_layers = torch.nn.ModuleList() self.norm_layers_2 = torch.nn.ModuleList() for _ in range(self.n_layers): self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) self.norm_layers_1.append(LayerNorm(hidden_channels)) self.ffn_layers.append( FFN( hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, ) ) self.norm_layers_2.append(LayerNorm(hidden_channels)) def forward(self, x, x_mask): attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) for i in range(self.n_layers): x = x * x_mask y = self.attn_layers[i](x, x, attn_mask) y = self.drop(y) x = self.norm_layers_1[i](x + y) y = self.ffn_layers[i](x, x_mask) y = self.drop(y) x = self.norm_layers_2[i](x + y) x = x * x_mask return x class TextEncoder(nn.Module): def __init__( self, encoder_type, encoder_params, duration_predictor_params, n_vocab, n_spks=1, spk_emb_dim=128, ): super().__init__() self.encoder_type = encoder_type self.n_vocab = n_vocab self.n_feats = encoder_params.n_feats self.n_channels = encoder_params.n_channels self.spk_emb_dim = spk_emb_dim self.n_spks = n_spks self.emb = torch.nn.Embedding(n_vocab, self.n_channels) torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) if encoder_params.prenet: self.prenet = ConvReluNorm( self.n_channels, self.n_channels, self.n_channels, kernel_size=5, n_layers=3, p_dropout=0.5, ) else: self.prenet = lambda x, x_mask: x self.encoder = Encoder( encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), encoder_params.filter_channels, encoder_params.n_heads, encoder_params.n_layers, encoder_params.kernel_size, encoder_params.p_dropout, ) self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) self.proj_w = DurationPredictor( self.n_channels + (spk_emb_dim if n_spks > 1 else 0), duration_predictor_params.filter_channels_dp, duration_predictor_params.kernel_size, 
duration_predictor_params.p_dropout, ) def forward(self, x, x_lengths, spks=None): """Run forward pass to the transformer based encoder and duration predictor Args: x (torch.Tensor): text input shape: (batch_size, max_text_length) x_lengths (torch.Tensor): text input lengths shape: (batch_size,) spks (torch.Tensor, optional): speaker ids. Defaults to None. shape: (batch_size,) Returns: mu (torch.Tensor): average output of the encoder shape: (batch_size, n_feats, max_text_length) logw (torch.Tensor): log duration predicted by the duration predictor shape: (batch_size, 1, max_text_length) x_mask (torch.Tensor): mask for the text input shape: (batch_size, 1, max_text_length) """ x = self.emb(x) * math.sqrt(self.n_channels) x = torch.transpose(x, 1, -1) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.prenet(x, x_mask) if self.n_spks > 1: x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) x = self.encoder(x, x_mask) mu = self.proj_m(x) * x_mask x_dp = torch.detach(x) logw = self.proj_w(x_dp, x_mask) return mu, logw, x_mask ================================================ FILE: src/chatterbox/models/s3gen/matcha/transformer.py ================================================ from typing import Any, Dict, Optional import torch import torch.nn as nn from diffusers.models.attention import ( GEGLU, GELU, AdaLayerNorm, AdaLayerNormZero, ApproximateGELU, ) from diffusers.models.attention_processor import Attention from diffusers.models.lora import LoRACompatibleLinear from diffusers.utils.torch_utils import maybe_allow_in_graph class SnakeBeta(nn.Module): """ A modified Snake function which uses separate parameters for the magnitude of the periodic components Shape: - Input: (B, C, T) - Output: (B, C, T), same shape as the input Parameters: - alpha - trainable parameter that controls frequency - beta - trainable parameter that controls magnitude References: - This activation function is a modified version based on this 
paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195 Examples: >>> a1 = snakebeta(256) >>> x = torch.randn(256) >>> x = a1(x) """ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): """ Initialization. INPUT: - in_features: shape of the input - alpha - trainable parameter that controls frequency - beta - trainable parameter that controls magnitude alpha is initialized to 1 by default, higher values = higher-frequency. beta is initialized to 1 by default, higher values = higher-magnitude. alpha will be trained along with the rest of your model. """ super().__init__() self.in_features = out_features if isinstance(out_features, list) else [out_features] self.proj = LoRACompatibleLinear(in_features, out_features) # initialize alpha self.alpha_logscale = alpha_logscale if self.alpha_logscale: # log scale alphas initialized to zeros self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) else: # linear scale alphas initialized to ones self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) self.alpha.requires_grad = alpha_trainable self.beta.requires_grad = alpha_trainable self.no_div_by_zero = 0.000000001 def forward(self, x): """ Forward pass of the function. Applies the function to the input elementwise. SnakeBeta ∶= x + 1/b * sin^2 (xa) """ x = self.proj(x) if self.alpha_logscale: alpha = torch.exp(self.alpha) beta = torch.exp(self.beta) else: alpha = self.alpha beta = self.beta x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) return x class FeedForward(nn.Module): r""" A feed-forward layer. Parameters: dim (`int`): The number of channels in the input. dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. """ def __init__( self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim if activation_fn == "gelu": act_fn = GELU(dim, inner_dim) if activation_fn == "gelu-approximate": act_fn = GELU(dim, inner_dim, approximate="tanh") elif activation_fn == "geglu": act_fn = GEGLU(dim, inner_dim) elif activation_fn == "geglu-approximate": act_fn = ApproximateGELU(dim, inner_dim) elif activation_fn == "snakebeta": act_fn = SnakeBeta(dim, inner_dim) self.net = nn.ModuleList([]) # project in self.net.append(act_fn) # project dropout self.net.append(nn.Dropout(dropout)) # project out self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states): for module in self.net: hidden_states = module(hidden_states) return hidden_states @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 
only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used. double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", final_dropout: bool = False, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) # Define 3 blocks. Each block has its own normalization layer. # 1. 
Self-Attn if self.use_ada_layer_norm: self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_zero: self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. self.norm2 = ( AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) ) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, # scale_qk=False, # uncomment this to not to use flash attention ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None # 3. 
Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) else: norm_hidden_states = self.norm1(hidden_states) cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, **cross_attention_kwargs, ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states # 2. 
Cross-Attention if self.attn2 is not None: norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: raise ValueError( f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." ) num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], dim=self._chunk_dim, ) else: ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states return hidden_states ================================================ FILE: src/chatterbox/models/s3gen/s3gen.py ================================================ # Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import numpy as np import torch import torchaudio as ta from functools import lru_cache from typing import Optional from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer from .const import S3GEN_SR from .flow import CausalMaskedDiffWithXvec from .xvector import CAMPPlus from .utils.mel import mel_spectrogram from .f0_predictor import ConvRNNF0Predictor from .hifigan import HiFTGenerator from .transformer.upsample_encoder import UpsampleConformerEncoder from .flow_matching import CausalConditionalCFM from .decoder import ConditionalDecoder from .configs import CFM_PARAMS def drop_invalid_tokens(x): assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now" return x[x < SPEECH_VOCAB_SIZE] # TODO: global resampler cache @lru_cache(100) def get_resampler(src_sr, dst_sr, device): return ta.transforms.Resample(src_sr, dst_sr).to(device) class S3Token2Mel(torch.nn.Module): """ S3Gen's CFM decoder maps S3 speech tokens to mel-spectrograms. TODO: make these modules configurable? """ def __init__(self, meanflow=False): super().__init__() self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz") self.mel_extractor = mel_spectrogram # TODO: make it a torch module? self.speaker_encoder = CAMPPlus( # NOTE: This doesn't affect inference. 
It turns off activation checkpointing # (a training optimization), which causes a crazy DDP error with accelerate memory_efficient=False, ) self.meanflow = meanflow encoder = UpsampleConformerEncoder( output_size=512, attention_heads=8, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, attention_dropout_rate=0.1, normalize_before=True, input_layer='linear', pos_enc_layer_type='rel_pos_espnet', selfattention_layer_type='rel_selfattn', input_size=512, use_cnn_module=False, macaron_style=False, ) estimator = ConditionalDecoder( in_channels=320, out_channels=80, causal=True, channels=[256], dropout=0.0, attention_head_dim=64, n_blocks=4, num_mid_blocks=12, num_heads=8, act_fn='gelu', meanflow=self.meanflow, ) cfm_params = CFM_PARAMS decoder = CausalConditionalCFM( spk_emb_dim=80, cfm_params=cfm_params, estimator=estimator, ) self.flow = CausalMaskedDiffWithXvec( encoder=encoder, decoder=decoder ) self.resamplers = {} @property def device(self): params = self.tokenizer.parameters() return next(params).device @property def dtype(self): params = self.flow.parameters() return next(params).dtype def embed_ref( self, ref_wav: torch.Tensor, ref_sr: int, device="auto", ref_fade_out=True, ): device = self.device if device == "auto" else device if isinstance(ref_wav, np.ndarray): ref_wav = torch.from_numpy(ref_wav).float() if ref_wav.device != device: ref_wav = ref_wav.to(device) if len(ref_wav.shape) == 1: ref_wav = ref_wav.unsqueeze(0) # (B, L) if ref_wav.size(1) > 10 * ref_sr: print("WARNING: s3gen received ref longer than 10s") ref_wav_24 = ref_wav if ref_sr != S3GEN_SR: ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav) ref_wav_24 = ref_wav_24.to(device=device, dtype=self.dtype) ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(dtype=self.dtype) ref_mels_24_len = None # Resample to 16kHz ref_wav_16 = ref_wav if ref_sr != S3_SR: ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav) # Speaker embedding ref_x_vector = 
self.speaker_encoder.inference(ref_wav_16.to(dtype=self.dtype)) # Tokenize 16khz reference ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16.float()) # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms) if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]: logging.warning( "Reference mel length is not equal to 2 * reference token length.\n" ) ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2] ref_speech_token_lens[0] = ref_speech_tokens.shape[1] return dict( prompt_token=ref_speech_tokens.to(device), prompt_token_len=ref_speech_token_lens, prompt_feat=ref_mels_24, prompt_feat_len=ref_mels_24_len, embedding=ref_x_vector, ) def forward( self, speech_tokens: torch.LongTensor, # locally-computed ref embedding (mutex with ref_dict) ref_wav: Optional[torch.Tensor], ref_sr: Optional[int], # pre-computed ref embedding (prod API) ref_dict: Optional[dict] = None, n_cfm_timesteps = None, finalize: bool = False, speech_token_lens=None, noised_mels=None, ): """ Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from. NOTE: - The speaker encoder accepts 16 kHz waveform. - S3TokenizerV2 accepts 16 kHz waveform. - The mel-spectrogram for the reference assumes 24 kHz input signal. - This function is designed for batch_size=1 only. Args ---- - `speech_tokens`: S3 speech tokens [B=1, T] - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T]) - `ref_sr`: reference sample rate - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored. 
""" assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})" if ref_dict is None: ref_dict = self.embed_ref(ref_wav, ref_sr) else: # type/device casting (all values will be numpy if it's from a prod API call) for rk in list(ref_dict): if isinstance(ref_dict[rk], np.ndarray): ref_dict[rk] = torch.from_numpy(ref_dict[rk]) if torch.is_tensor(ref_dict[rk]): ref_dict[rk] = ref_dict[rk].to(device=self.device, dtype=self.dtype) speech_tokens = torch.atleast_2d(speech_tokens) # backcompat if speech_token_lens is None: speech_token_lens = torch.LongTensor([st.size(-1) for st in speech_tokens]).to(self.device) output_mels, _ = self.flow.inference( token=speech_tokens, token_len=speech_token_lens, finalize=finalize, noised_mels=noised_mels, n_timesteps=n_cfm_timesteps, meanflow=self.meanflow, **ref_dict, ) return output_mels class S3Token2Wav(S3Token2Mel): """ The decoder of S3Gen is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules. TODO: make these modules configurable? 
""" ignore_state_dict_missing = ("tokenizer._mel_filters", "tokenizer.window") def __init__(self, meanflow=False): super().__init__(meanflow) f0_predictor = ConvRNNF0Predictor() self.mel2wav = HiFTGenerator( sampling_rate=S3GEN_SR, upsample_rates=[8, 5, 3], upsample_kernel_sizes=[16, 11, 7], source_resblock_kernel_sizes=[7, 7, 11], source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], f0_predictor=f0_predictor, ) # silence out a few ms and fade audio in to reduce artifacts n_trim = S3GEN_SR // 50 # 20ms = half of a frame trim_fade = torch.zeros(2 * n_trim) trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2 self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting) self.estimator_dtype = "fp32" def forward( self, speech_tokens, # locally-computed ref embedding (mutex with ref_dict) ref_wav: Optional[torch.Tensor], ref_sr: Optional[int], # pre-computed ref embedding (prod API) ref_dict: Optional[dict] = None, finalize: bool = False, speech_token_lens=None, skip_vocoder=False, n_cfm_timesteps=None, noised_mels=None, ): """ Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from. NOTE: used for sync synthesis only. Please use `S3GenStreamer` for streaming synthesis. """ output_mels = super().forward( speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize, n_cfm_timesteps=n_cfm_timesteps, noised_mels=noised_mels, ) if skip_vocoder: return output_mels # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now. hift_cache_source = torch.zeros(1, 1, 0).to(self.device) output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source) if not self.training: # NOTE: ad-hoc method to reduce "spillover" from the reference clip. 
output_wavs[:, :len(self.trim_fade)] *= self.trim_fade return output_wavs @torch.inference_mode() def flow_inference( self, speech_tokens, # locally-computed ref embedding (mutex with ref_dict) ref_wav: Optional[torch.Tensor] = None, ref_sr: Optional[int] = None, # pre-computed ref embedding (prod API) ref_dict: Optional[dict] = None, n_cfm_timesteps = None, finalize: bool = False, speech_token_lens=None, ): n_cfm_timesteps = n_cfm_timesteps or (2 if self.meanflow else 10) noise = None if self.meanflow: noise = torch.randn(1, 80, speech_tokens.size(-1) * 2, dtype=self.dtype, device=self.device) output_mels = super().forward( speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, n_cfm_timesteps=n_cfm_timesteps, finalize=finalize, noised_mels=noise, ) return output_mels @torch.inference_mode() def hift_inference(self, speech_feat, cache_source: torch.Tensor = None): if cache_source is None: cache_source = torch.zeros(1, 1, 0).to(device=self.device, dtype=self.dtype) return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source) @torch.inference_mode() def inference( self, speech_tokens, # locally-computed ref embedding (mutex with ref_dict) ref_wav: Optional[torch.Tensor] = None, ref_sr: Optional[int] = None, # pre-computed ref embedding (prod API) ref_dict: Optional[dict] = None, # left as a kwarg because this can change input/output size ratio drop_invalid_tokens=True, n_cfm_timesteps=None, speech_token_lens=None, ): # hallucination prevention, drop special tokens # if drop_invalid_tokens: # speech_tokens, speech_token_lens = drop_invalid(speech_tokens, pad=S3_QUIET_PAD) output_mels = self.flow_inference( speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, n_cfm_timesteps=n_cfm_timesteps, finalize=True, ) output_mels = output_mels.to(dtype=self.dtype) # FIXME (fp16 mode) is this still needed? 
output_wavs, output_sources = self.hift_inference(output_mels, None) # NOTE: ad-hoc method to reduce "spillover" from the reference clip. output_wavs[:, :len(self.trim_fade)] *= self.trim_fade return output_wavs, output_sources ================================================ FILE: src/chatterbox/models/s3gen/transformer/__init__.py ================================================ ================================================ FILE: src/chatterbox/models/s3gen/transformer/activation.py ================================================ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) # 2020 Northwestern Polytechnical University (Pengcheng Guo) # 2020 Mobvoi Inc (Binbin Zhang) # 2024 Alibaba Inc (Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Swish() activation function for Conformer.""" import torch from torch import nn, sin, pow from torch.nn import Parameter class Swish(torch.nn.Module): """Construct an Swish object.""" def forward(self, x: torch.Tensor) -> torch.Tensor: """Return Swish activation function.""" return x * torch.sigmoid(x) # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. # LICENSE is in incl_licenses directory. 
class Snake(nn.Module): ''' Implementation of a sine-based periodic activation function Shape: - Input: (B, C, T) - Output: (B, C, T), same shape as the input Parameters: - alpha - trainable parameter References: - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195 Examples: >>> a1 = snake(256) >>> x = torch.randn(256) >>> x = a1(x) ''' def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): ''' Initialization. INPUT: - in_features: shape of the input - alpha: trainable parameter alpha is initialized to 1 by default, higher values = higher-frequency. alpha will be trained along with the rest of your model. ''' super(Snake, self).__init__() self.in_features = in_features # initialize alpha self.alpha_logscale = alpha_logscale if self.alpha_logscale: # log scale alphas initialized to zeros self.alpha = Parameter(torch.zeros(in_features) * alpha) else: # linear scale alphas initialized to ones self.alpha = Parameter(torch.ones(in_features) * alpha) self.alpha.requires_grad = alpha_trainable self.no_div_by_zero = 0.000000001 def forward(self, x): ''' Forward pass of the function. Applies the function to the input elementwise. Snake ∶= x + 1/a * sin^2 (xa) ''' alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] if self.alpha_logscale: alpha = torch.exp(alpha) x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) return x ================================================ FILE: src/chatterbox/models/s3gen/transformer/attention.py ================================================ # Copyright (c) 2019 Shigeki Karita # 2020 Mobvoi Inc (Binbin Zhang) # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) # 2024 Alibaba Inc (Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features (must be divisible by n_head).
        dropout_rate (float): Dropout rate applied to the attention weights.
        key_bias (bool): Whether the key projection (linear_k) has a bias.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Projects each input with its own linear layer and splits the feature
        dimension into (head, d_k).

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
                Entries equal to 0 are masked out.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            # invert: True now marks the positions to hide
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            # Re-apply the mask after softmax so that fully-masked rows (softmax
            # over all -inf yields NaN) come out as zeros.
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        # merge heads back into the feature dimension
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T)  shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L)  shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            pos_emb (torch.Tensor): Not used by this class; accepted so the
                signature matches RelPositionMultiHeadedAttention.forward.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`


        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            # cache stores keys and values concatenated on the last dim
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection (linear_k) has a bias.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (allocated uninitialized, then filled by xavier_uniform_ below)
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Implements the Transformer-XL "shift" trick: left-pad one zero column,
        reinterpret the buffer with the last two dims swapped, drop the first
        row, view back to the input shape, and keep the first
        ``x.size(-1) // 2 + 1`` columns (== time1 for a 2*time1-1 input).

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
            time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        # back to (batch, time1, head, d_k) so the (head, d_k) position
        # biases below broadcast over the last two dims
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        # NOTE(review): the time axis of p is the pos_emb length, which for the
        # ESPnet-style encoding is presumably 2*time1-1 rather than time1.
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        # (only applied when the positional axis differs from the key axis)
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache


================================================
FILE: src/chatterbox/models/s3gen/transformer/convolution.py
================================================
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet) """ConvolutionModule definition.""" from typing import Tuple import torch from torch import nn class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model.""" def __init__(self, channels: int, kernel_size: int = 15, activation: nn.Module = nn.ReLU(), norm: str = "batch_norm", causal: bool = False, bias: bool = True): """Construct an ConvolutionModule object. Args: channels (int): The number of channels of conv layers. kernel_size (int): Kernel size of conv layers. causal (int): Whether use causal convolution or not """ super().__init__() self.pointwise_conv1 = nn.Conv1d( channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) # self.lorder is used to distinguish if it's a causal convolution, # if self.lorder > 0: it's a causal convolution, the input will be # padded with self.lorder frames on the left in forward. # else: it's a symmetrical convolution if causal: padding = 0 self.lorder = kernel_size - 1 else: # kernel_size should be an odd number for none causal convolution assert (kernel_size - 1) % 2 == 0 padding = (kernel_size - 1) // 2 self.lorder = 0 self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, stride=1, padding=padding, groups=channels, bias=bias, ) assert norm in ['batch_norm', 'layer_norm'] if norm == "batch_norm": self.use_layer_norm = False self.norm = nn.BatchNorm1d(channels) else: self.use_layer_norm = True self.norm = nn.LayerNorm(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.activation = activation def forward( self, x: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), cache: torch.Tensor = torch.zeros((0, 0, 0)), ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute convolution module. Args: x (torch.Tensor): Input tensor (#batch, time, channels). mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), (0, 0, 0) means fake mask. 
cache (torch.Tensor): left context cache, it is only used in causal convolution (#batch, channels, cache_t), (0, 0, 0) meas fake cache. Returns: torch.Tensor: Output tensor (#batch, time, channels). """ # exchange the temporal dimension and the feature dimension x = x.transpose(1, 2) # (#batch, channels, time) # mask batch padding if mask_pad.size(2) > 0: # time > 0 x.masked_fill_(~mask_pad, 0.0) if self.lorder > 0: if cache.size(2) == 0: # cache_t == 0 x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) else: assert cache.size(0) == x.size(0) # equal batch assert cache.size(1) == x.size(1) # equal channel x = torch.cat((cache, x), dim=2) assert (x.size(2) > self.lorder) new_cache = x[:, :, -self.lorder:] else: # It's better we just return None if no cache is required, # However, for JIT export, here we just fake one tensor instead of # None. new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channel, dim) x = nn.functional.glu(x, dim=1) # (batch, channel, dim) # 1D Depthwise Conv x = self.depthwise_conv(x) if self.use_layer_norm: x = x.transpose(1, 2) x = self.activation(self.norm(x)) if self.use_layer_norm: x = x.transpose(1, 2) x = self.pointwise_conv2(x) # mask batch padding if mask_pad.size(2) > 0: # time > 0 x.masked_fill_(~mask_pad, 0.0) return x.transpose(1, 2), new_cache ================================================ FILE: src/chatterbox/models/s3gen/transformer/embedding.py ================================================ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) # 2024 Alibaba Inc (Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) """Positonal Encoding Module.""" import math from typing import Tuple, Union import torch import torch.nn.functional as F import numpy as np class PositionalEncoding(torch.nn.Module): """Positional encoding. :param int d_model: embedding dim :param float dropout_rate: dropout rate :param int max_len: maximum input length PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) """ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000, reverse: bool = False): """Construct an PositionalEncoding object.""" super().__init__() self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.max_len = max_len self.pe = torch.zeros(self.max_len, self.d_model) position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)) self.pe[:, 0::2] = torch.sin(position * div_term) self.pe[:, 1::2] = torch.cos(position * div_term) self.pe = self.pe.unsqueeze(0) def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \ -> Tuple[torch.Tensor, torch.Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input. Its shape is (batch, time, ...) offset (int, torch.tensor): position offset Returns: torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 
torch.Tensor: for compatibility to RelPositionalEncoding """ self.pe = self.pe.to(x.device) pos_emb = self.position_encoding(offset, x.size(1), False) x = x * self.xscale + pos_emb return self.dropout(x), self.dropout(pos_emb) def position_encoding(self, offset: Union[int, torch.Tensor], size: int, apply_dropout: bool = True) -> torch.Tensor: """ For getting encoding in a streaming fashion Attention!!!!! we apply dropout only once at the whole utterance level in a none streaming way, but will call this function several times with increasing input size in a streaming scenario, so the dropout will be applied several times. Args: offset (int or torch.tensor): start offset size (int): required size of position encoding Returns: torch.Tensor: Corresponding encoding """ # How to subscript a Union type: # https://github.com/pytorch/pytorch/issues/69434 if isinstance(offset, int): assert offset + size <= self.max_len pos_emb = self.pe[:, offset:offset + size] elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar assert offset + size <= self.max_len pos_emb = self.pe[:, offset:offset + size] else: # for batched streaming decoding on GPU assert torch.max(offset) + size <= self.max_len index = offset.unsqueeze(1) + \ torch.arange(0, size).to(offset.device) # B X T flag = index > 0 # remove negative offset index = index * flag pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model if apply_dropout: pos_emb = self.dropout(pos_emb) return pos_emb class RelPositionalEncoding(PositionalEncoding): """Relative positional encoding module. See : Appendix B in https://arxiv.org/abs/1901.02860 Args: d_model (int): Embedding dimension. dropout_rate (float): Dropout rate. max_len (int): Maximum input length. 
""" def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): """Initialize class.""" super().__init__(d_model, dropout_rate, max_len, reverse=True) def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \ -> Tuple[torch.Tensor, torch.Tensor]: """Compute positional encoding. Args: x (torch.Tensor): Input tensor (batch, time, `*`). Returns: torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Positional embedding tensor (1, time, `*`). """ self.pe = self.pe.to(x.device) x = x * self.xscale pos_emb = self.position_encoding(offset, x.size(1), False) return self.dropout(x), self.dropout(pos_emb) class WhisperPositionalEncoding(PositionalEncoding): """ Sinusoids position encoding used in openai-whisper.encoder """ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): super().__init__(d_model, dropout_rate, max_len) self.xscale = 1.0 log_timescale_increment = np.log(10000) / (d_model // 2 - 1) inv_timescales = torch.exp(-log_timescale_increment * torch.arange(d_model // 2)) scaled_time = torch.arange(max_len)[:, np.newaxis] * \ inv_timescales[np.newaxis, :] pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) delattr(self, "pe") self.register_buffer("pe", pe.unsqueeze(0)) class LearnablePositionalEncoding(PositionalEncoding): """ Learnable position encoding used in openai-whisper.decoder """ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): super().__init__(d_model, dropout_rate, max_len) # NOTE(xcsong): overwrite self.pe & self.xscale self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) self.xscale = 1.0 class NoPositionalEncoding(torch.nn.Module): """ No position encoding """ def __init__(self, d_model: int, dropout_rate: float): super().__init__() self.d_model = d_model self.dropout = torch.nn.Dropout(p=dropout_rate) def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \ -> Tuple[torch.Tensor, torch.Tensor]: """ Just return zero 
vector for interface compatibility """ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) return self.dropout(x), pos_emb def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor: return torch.zeros(1, size, self.d_model) class EspnetRelPositionalEncoding(torch.nn.Module): """Relative positional encoding module (new implementation). Details can be found in https://github.com/espnet/espnet/pull/2816. See : Appendix B in https://arxiv.org/abs/1901.02860 Args: d_model (int): Embedding dimension. dropout_rate (float): Dropout rate. max_len (int): Maximum input length. """ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): """Construct an PositionalEncoding object.""" super(EspnetRelPositionalEncoding, self).__init__() self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) def extend_pe(self, x: torch.Tensor): """Reset the positional encodings.""" if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.dtype != x.dtype or self.pe.device != x.device: self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the # position of key vector. We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input tensor (batch, time, `*`). Returns: torch.Tensor: Encoded tensor (batch, time, `*`). 
""" self.extend_pe(x) x = x * self.xscale pos_emb = self.position_encoding(size=x.size(1), offset=offset) return self.dropout(x), self.dropout(pos_emb) def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor: """ For getting encoding in a streaming fashion Attention!!!!! we apply dropout only once at the whole utterance level in a none streaming way, but will call this function several times with increasing input size in a streaming scenario, so the dropout will be applied several times. Args: offset (int or torch.tensor): start offset size (int): required size of position encoding Returns: torch.Tensor: Corresponding encoding """ pos_emb = self.pe[ :, self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, ] return pos_emb ================================================ FILE: src/chatterbox/models/s3gen/transformer/encoder_layer.py ================================================ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) """Encoder self-attention layer definition.""" from typing import Optional, Tuple import torch from torch import nn class TransformerEncoderLayer(nn.Module): """Encoder layer module. Args: size (int): Input dimension. self_attn (torch.nn.Module): Self-attention module instance. 
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance can be used as the argument. feed_forward (torch.nn.Module): Feed-forward module instance. `PositionwiseFeedForward`, instance can be used as the argument. dropout_rate (float): Dropout rate. normalize_before (bool): True: use layer_norm before each sub-block. False: to use layer_norm after each sub-block. """ def __init__( self, size: int, self_attn: torch.nn.Module, feed_forward: torch.nn.Module, dropout_rate: float, normalize_before: bool = True, ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward self.norm1 = nn.LayerNorm(size, eps=1e-12) self.norm2 = nn.LayerNorm(size, eps=1e-12) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before def forward( self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute encoded features. Args: x (torch.Tensor): (#batch, time, size) mask (torch.Tensor): Mask tensor for the input (#batch, time,time), (0, 0, 0) means fake mask. pos_emb (torch.Tensor): just for interface compatibility to ConformerEncoderLayer mask_pad (torch.Tensor): does not used in transformer layer, just for unified api with conformer. att_cache (torch.Tensor): Cache tensor of the KEY & VALUE (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. cnn_cache (torch.Tensor): Convolution cache in conformer layer (#batch=1, size, cache_t2), not used here, it's for interface compatibility to ConformerEncoderLayer. Returns: torch.Tensor: Output tensor (#batch, time, size). torch.Tensor: Mask tensor (#batch, time, time). torch.Tensor: att_cache tensor, (#batch=1, head, cache_t1 + time, d_k * 2). 
torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). """ residual = x if self.normalize_before: x = self.norm1(x) x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm1(x) residual = x if self.normalize_before: x = self.norm2(x) x = residual + self.dropout(self.feed_forward(x)) if not self.normalize_before: x = self.norm2(x) fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) return x, mask, new_att_cache, fake_cnn_cache class ConformerEncoderLayer(nn.Module): """Encoder layer module. Args: size (int): Input dimension. self_attn (torch.nn.Module): Self-attention module instance. `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance can be used as the argument. feed_forward (torch.nn.Module): Feed-forward module instance. `PositionwiseFeedForward` instance can be used as the argument. feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. `PositionwiseFeedForward` instance can be used as the argument. conv_module (torch.nn.Module): Convolution module instance. `ConvlutionModule` instance can be used as the argument. dropout_rate (float): Dropout rate. normalize_before (bool): True: use layer_norm before each sub-block. False: use layer_norm after each sub-block. 
""" def __init__( self, size: int, self_attn: torch.nn.Module, feed_forward: Optional[nn.Module] = None, feed_forward_macaron: Optional[nn.Module] = None, conv_module: Optional[nn.Module] = None, dropout_rate: float = 0.1, normalize_before: bool = True, ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward self.feed_forward_macaron = feed_forward_macaron self.conv_module = conv_module self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module if feed_forward_macaron is not None: self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) self.ff_scale = 0.5 else: self.ff_scale = 1.0 if self.conv_module is not None: self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module self.norm_final = nn.LayerNorm( size, eps=1e-12) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before def forward( self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute encoded features. Args: x (torch.Tensor): (#batch, time, size) mask (torch.Tensor): Mask tensor for the input (#batch, time,time), (0, 0, 0) means fake mask. pos_emb (torch.Tensor): positional encoding, must not be None for ConformerEncoderLayer. mask_pad (torch.Tensor): batch padding mask used for conv module. (#batch, 1,time), (0, 0, 0) means fake mask. att_cache (torch.Tensor): Cache tensor of the KEY & VALUE (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. cnn_cache (torch.Tensor): Convolution cache in conformer layer (#batch=1, size, cache_t2) Returns: torch.Tensor: Output tensor (#batch, time, size). 
torch.Tensor: Mask tensor (#batch, time, time). torch.Tensor: att_cache tensor, (#batch=1, head, cache_t1 + time, d_k * 2). torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). """ # whether to use macaron style if self.feed_forward_macaron is not None: residual = x if self.normalize_before: x = self.norm_ff_macaron(x) x = residual + self.ff_scale * self.dropout( self.feed_forward_macaron(x)) if not self.normalize_before: x = self.norm_ff_macaron(x) # multi-headed self-attention module residual = x if self.normalize_before: x = self.norm_mha(x) x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm_mha(x) # convolution module # Fake new cnn cache here, and then change it in conv_module new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) if self.conv_module is not None: residual = x if self.normalize_before: x = self.norm_conv(x) x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) x = residual + self.dropout(x) if not self.normalize_before: x = self.norm_conv(x) # feed forward module residual = x if self.normalize_before: x = self.norm_ff(x) x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) if not self.normalize_before: x = self.norm_ff(x) if self.conv_module is not None: x = self.norm_final(x) return x, mask, new_att_cache, new_cnn_cache ================================================ FILE: src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py ================================================ # Copyright (c) 2019 Shigeki Karita # 2020 Mobvoi Inc (Binbin Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Positionwise feed forward layer definition.""" import torch class PositionwiseFeedForward(torch.nn.Module): """Positionwise feed forward layer. FeedForward are appied on each position of the sequence. The output dim is same with the input dim. Args: idim (int): Input dimenstion. hidden_units (int): The number of hidden units. dropout_rate (float): Dropout rate. activation (torch.nn.Module): Activation function """ def __init__( self, idim: int, hidden_units: int, dropout_rate: float, activation: torch.nn.Module = torch.nn.ReLU(), ): """Construct a PositionwiseFeedForward object.""" super(PositionwiseFeedForward, self).__init__() self.w_1 = torch.nn.Linear(idim, hidden_units) self.activation = activation self.dropout = torch.nn.Dropout(dropout_rate) self.w_2 = torch.nn.Linear(hidden_units, idim) def forward(self, xs: torch.Tensor) -> torch.Tensor: """Forward function. Args: xs: input tensor (B, L, D) Returns: output tensor, (B, L, D) """ return self.w_2(self.dropout(self.activation(self.w_1(xs)))) class MoEFFNLayer(torch.nn.Module): """ Mixture of expert with Positionwise feed forward layer See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf The output dim is same with the input dim. Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 Args: n_expert: number of expert. n_expert_per_token: The actual number of experts used for each frame idim (int): Input dimenstion. hidden_units (int): The number of hidden units. dropout_rate (float): Dropout rate. 
activation (torch.nn.Module): Activation function """ def __init__( self, n_expert: int, n_expert_per_token: int, idim: int, hidden_units: int, dropout_rate: float, activation: torch.nn.Module = torch.nn.ReLU(), ): super(MoEFFNLayer, self).__init__() self.gate = torch.nn.Linear(idim, n_expert, bias=False) self.experts = torch.nn.ModuleList( PositionwiseFeedForward(idim, hidden_units, dropout_rate, activation) for _ in range(n_expert)) self.n_expert_per_token = n_expert_per_token def forward(self, xs: torch.Tensor) -> torch.Tensor: """Foward function. Args: xs: input tensor (B, L, D) Returns: output tensor, (B, L, D) """ B, L, D = xs.size( ) # batch size, sequence length, embedding dimension (idim) xs = xs.view(-1, D) # (B*L, D) router = self.gate(xs) # (B*L, n_expert) logits, indices = torch.topk( router, self.n_expert_per_token ) # probs:(B*L, n_expert), indices: (B*L, n_expert) weights = torch.nn.functional.softmax( logits, dim=1, dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) output = torch.zeros_like(xs) # (B*L, D) for i, expert in enumerate(self.experts): mask = indices == i batch_idx, ith_expert = torch.where(mask) output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( xs[batch_idx]) return output.view(B, L, D) ================================================ FILE: src/chatterbox/models/s3gen/transformer/subsampling.py ================================================ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) # 2024 Alibaba Inc (Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) """Subsampling layer definition.""" from typing import Tuple, Union import torch class BaseSubsampling(torch.nn.Module): def __init__(self): super().__init__() self.right_context = 0 self.subsampling_rate = 1 def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor: return self.pos_enc.position_encoding(offset, size) class EmbedinigNoSubsampling(BaseSubsampling): """Embedding input without subsampling """ def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): super().__init__() self.embed = torch.nn.Embedding(idim, odim) self.pos_enc = pos_enc_class def forward( self, x: torch.Tensor, x_mask: torch.Tensor, offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: linear input tensor (#batch, time', odim), where time' = time . torch.Tensor: linear input mask (#batch, 1, time'), where time' = time . """ x = self.embed(x) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask class LinearNoSubsampling(BaseSubsampling): """Linear transform the input without subsampling Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. 
""" def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): """Construct an linear object.""" super().__init__() self.out = torch.nn.Sequential( torch.nn.Linear(idim, odim), torch.nn.LayerNorm(odim, eps=1e-5), torch.nn.Dropout(dropout_rate), ) self.pos_enc = pos_enc_class self.right_context = 0 self.subsampling_rate = 1 def forward( self, x: torch.Tensor, x_mask: torch.Tensor, offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: linear input tensor (#batch, time', odim), where time' = time . torch.Tensor: linear input mask (#batch, 1, time'), where time' = time . """ x = self.out(x) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask class Conv1dSubsampling2(BaseSubsampling): """Convolutional 1D subsampling (to 1/2 length). It is designed for Whisper, ref: https://github.com/openai/whisper/blob/main/whisper/model.py Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. """ def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): """Construct an Conv1dSubsampling2 object.""" super().__init__() self.conv = torch.nn.Sequential( torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1), torch.nn.GELU(), torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1), torch.nn.GELU(), ) self.pos_enc = pos_enc_class # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer self.subsampling_rate = 2 # 4 = (3 - 1) * 1 + (3 - 1) * 1 self.right_context = 4 def forward( self, x: torch.Tensor, x_mask: torch.Tensor, offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Subsample x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). 
x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 2. torch.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 2. torch.Tensor: positional encoding """ time = x.size(1) x = x.transpose(1, 2) # (b, f, t) x = self.conv(x) x = x.transpose(1, 2) # (b, t, f) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, (time + 1) % 2::2] class Conv2dSubsampling4(BaseSubsampling): """Convolutional 2D subsampling (to 1/4 length). Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. """ def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): """Construct an Conv2dSubsampling4 object.""" super().__init__() self.conv = torch.nn.Sequential( torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU(), ) self.out = torch.nn.Sequential( torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) self.pos_enc = pos_enc_class # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer self.subsampling_rate = 4 # 6 = (3 - 1) * 1 + (3 - 1) * 2 self.right_context = 6 def forward( self, x: torch.Tensor, x_mask: torch.Tensor, offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Subsample x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 4. torch.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 4. 
torch.Tensor: positional encoding """ x = x.unsqueeze(1) # (b, c=1, t, f) x = self.conv(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] class Conv2dSubsampling6(BaseSubsampling): """Convolutional 2D subsampling (to 1/6 length). Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. pos_enc (torch.nn.Module): Custom position encoding layer. """ def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): """Construct an Conv2dSubsampling6 object.""" super().__init__() self.conv = torch.nn.Sequential( torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 5, 3), torch.nn.ReLU(), ) self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim) self.pos_enc = pos_enc_class # 10 = (3 - 1) * 1 + (5 - 1) * 2 self.subsampling_rate = 6 self.right_context = 10 def forward( self, x: torch.Tensor, x_mask: torch.Tensor, offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Subsample x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 6. torch.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 6. torch.Tensor: positional encoding """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) b, c, t, f = x.size() x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] class Conv2dSubsampling8(BaseSubsampling): """Convolutional 2D subsampling (to 1/8 length). Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. 
""" def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): """Construct an Conv2dSubsampling8 object.""" super().__init__() self.conv = torch.nn.Sequential( torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU(), ) self.linear = torch.nn.Linear( odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) self.pos_enc = pos_enc_class self.subsampling_rate = 8 # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 self.right_context = 14 def forward( self, x: torch.Tensor, x_mask: torch.Tensor, offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Subsample x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 8. torch.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 8. torch.Tensor: positional encoding """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) b, c, t, f = x.size() x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] class LegacyLinearNoSubsampling(BaseSubsampling): """Linear transform the input without subsampling Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. 
""" def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): """Construct an linear object.""" super().__init__() self.out = torch.nn.Sequential( torch.nn.Linear(idim, odim), torch.nn.LayerNorm(odim, eps=1e-5), torch.nn.Dropout(dropout_rate), torch.nn.ReLU(), ) self.pos_enc = pos_enc_class self.right_context = 0 self.subsampling_rate = 1 def forward( self, x: torch.Tensor, x_mask: torch.Tensor, offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: linear input tensor (#batch, time', odim), where time' = time . torch.Tensor: linear input mask (#batch, 1, time'), where time' = time . """ x = self.out(x) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask ================================================ FILE: src/chatterbox/models/s3gen/transformer/upsample_encoder.py ================================================ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) # 2024 Alibaba Inc (Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Modified from ESPnet(https://github.com/espnet/espnet) """Encoder definition.""" from typing import Tuple import torch from torch import nn from torch.nn import functional as F from .convolution import ConvolutionModule from .encoder_layer import ConformerEncoderLayer from .positionwise_feed_forward import PositionwiseFeedForward from ..utils.class_utils import ( COSYVOICE_EMB_CLASSES, COSYVOICE_SUBSAMPLE_CLASSES, COSYVOICE_ATTENTION_CLASSES, COSYVOICE_ACTIVATION_CLASSES, ) from ..utils.mask import make_pad_mask from ..utils.mask import add_optional_chunk_mask class Upsample1D(nn.Module): """A 1D upsampling layer with an optional convolution. Parameters: channels (`int`): number of channels in the inputs and outputs. use_conv (`bool`, default `False`): option to use a convolution. use_conv_transpose (`bool`, default `False`): option to use a convolution transpose. out_channels (`int`, optional): number of output channels. Defaults to `channels`. """ def __init__(self, channels: int, out_channels: int, stride: int = 2): super().__init__() self.channels = channels self.out_channels = out_channels self.stride = stride # In this mode, first repeat interpolate, than conv with stride=1 self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) outputs = self.conv(outputs) return outputs, input_lengths * self.stride class PreLookaheadLayer(nn.Module): def __init__(self, channels: int, pre_lookahead_len: int = 1): super().__init__() self.channels = channels self.pre_lookahead_len = pre_lookahead_len self.conv1 = nn.Conv1d( channels, channels, kernel_size=pre_lookahead_len + 1, stride=1, padding=0, ) self.conv2 = nn.Conv1d( channels, channels, kernel_size=3, stride=1, padding=0, ) def forward(self, inputs: torch.Tensor) -> 
torch.Tensor: """ inputs: (batch_size, seq_len, channels) """ outputs = inputs.transpose(1, 2).contiguous() # look ahead outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) outputs = F.leaky_relu(self.conv1(outputs)) # outputs outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0) outputs = self.conv2(outputs) outputs = outputs.transpose(1, 2).contiguous() # residual connection outputs = outputs + inputs return outputs class UpsampleConformerEncoder(torch.nn.Module): def __init__( self, input_size: int = 512, output_size: int = 512, attention_heads: int = 8, linear_units: int = 2048, num_blocks: int = 6, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.1, input_layer: str = "linear", pos_enc_layer_type: str = "rel_pos_espnet", normalize_before: bool = True, static_chunk_size: int = 0, use_dynamic_chunk: bool = False, global_cmvn: torch.nn.Module = None, use_dynamic_left_chunk: bool = False, positionwise_conv_kernel_size: int = 1, macaron_style: bool = False, selfattention_layer_type: str = "rel_selfattn", activation_type: str = "swish", use_cnn_module: bool = False, cnn_module_kernel: int = 15, causal: bool = False, cnn_module_norm: str = "batch_norm", key_bias: bool = True, gradient_checkpointing: bool = False, ): """ Args: input_size (int): input dim output_size (int): dimension of attention attention_heads (int): the number of heads of multi head attention linear_units (int): the hidden units number of position-wise feed forward num_blocks (int): the number of decoder blocks dropout_rate (float): dropout rate attention_dropout_rate (float): dropout rate in attention positional_dropout_rate (float): dropout rate after adding positional encoding input_layer (str): input layer type. optional [linear, conv2d, conv2d6, conv2d8] pos_enc_layer_type (str): Encoder positional encoding layer type. 
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] normalize_before (bool): True: use layer_norm before each sub-block of a layer. False: use layer_norm after each sub-block of a layer. static_chunk_size (int): chunk size for static chunk training and decoding use_dynamic_chunk (bool): whether use dynamic chunk size for training or not, You can only use fixed chunk(chunk_size > 0) or dyanmic chunk size(use_dynamic_chunk = True) global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module use_dynamic_left_chunk (bool): whether use dynamic left chunk in dynamic chunk training key_bias: whether use bias in attention.linear_k, False for whisper models. gradient_checkpointing: rerunning a forward-pass segment for each checkpointed segment during backward. """ super().__init__() self._output_size = output_size self.global_cmvn = global_cmvn self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( input_size, output_size, dropout_rate, COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size, positional_dropout_rate), ) self.normalize_before = normalize_before self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) self.static_chunk_size = static_chunk_size self.use_dynamic_chunk = use_dynamic_chunk self.use_dynamic_left_chunk = use_dynamic_left_chunk self.gradient_checkpointing = gradient_checkpointing activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() # self-attention module definition encoder_selfattn_layer_args = ( attention_heads, output_size, attention_dropout_rate, key_bias, ) # feed-forward module definition positionwise_layer_args = ( output_size, linear_units, dropout_rate, activation, ) # convolution module definition convolution_layer_args = (output_size, cnn_module_kernel, activation, cnn_module_norm, causal) self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3) self.encoders = torch.nn.ModuleList([ ConformerEncoderLayer( output_size, COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( 
*encoder_selfattn_layer_args), PositionwiseFeedForward(*positionwise_layer_args), PositionwiseFeedForward( *positionwise_layer_args) if macaron_style else None, ConvolutionModule( *convolution_layer_args) if use_cnn_module else None, dropout_rate, normalize_before, ) for _ in range(num_blocks) ]) self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2) self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( input_size, output_size, dropout_rate, COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size, positional_dropout_rate), ) self.up_encoders = torch.nn.ModuleList([ ConformerEncoderLayer( output_size, COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( *encoder_selfattn_layer_args), PositionwiseFeedForward(*positionwise_layer_args), PositionwiseFeedForward( *positionwise_layer_args) if macaron_style else None, ConvolutionModule( *convolution_layer_args) if use_cnn_module else None, dropout_rate, normalize_before, ) for _ in range(4) ]) def output_size(self) -> int: return self._output_size def forward( self, xs: torch.Tensor, xs_lens: torch.Tensor, decoding_chunk_size: int = 0, num_decoding_left_chunks: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """Embed positions in tensor. Args: xs: padded input tensor (B, T, D) xs_lens: input length (B) decoding_chunk_size: decoding chunk size for dynamic chunk 0: default for training, use random dynamic chunk. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. num_decoding_left_chunks: number of left chunks, this is for decoding, the chunk size is decoding_chunk_size. 
>=0: use num_decoding_left_chunks <0: use all left chunks Returns: encoder output tensor xs, and subsampled masks xs: padded output tensor (B, T' ~= T/subsample_rate, D) masks: torch.Tensor batch padding mask after subsample (B, 1, T' ~= T/subsample_rate) NOTE(xcsong): We pass the `__call__` method of the modules instead of `forward` to the checkpointing API because `__call__` attaches all the hooks of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 """ T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) if self.global_cmvn is not None: xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) mask_pad = masks # (B, 1, T/subsample_rate) chunk_masks = add_optional_chunk_mask(xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, num_decoding_left_chunks) # lookahead + conformer encoder xs = self.pre_lookahead_layer(xs) xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) # upsample + conformer encoder xs = xs.transpose(1, 2).contiguous() xs, xs_lens = self.up_layer(xs, xs_lens) xs = xs.transpose(1, 2).contiguous() T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) xs, pos_emb, masks = self.up_embed(xs, masks) mask_pad = masks # (B, 1, T/subsample_rate) chunk_masks = add_optional_chunk_mask(xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size * self.up_layer.stride, num_decoding_left_chunks) xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad) if self.normalize_before: xs = self.after_norm(xs) # Here we assume the mask is not changed in encoder layers, so just # return the masks before encoder layers, and the masks will be used # for cross attention with decoder later return xs, masks def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor) -> torch.Tensor: for layer 
in self.encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor) -> torch.Tensor: for layer in self.up_encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs ================================================ FILE: src/chatterbox/models/s3gen/utils/class_utils.py ================================================ # Copyright [2023-11-28] # 2024 Alibaba Inc (authors: Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from ..transformer.activation import Swish from ..transformer.subsampling import ( LinearNoSubsampling, EmbedinigNoSubsampling, Conv1dSubsampling2, Conv2dSubsampling4, Conv2dSubsampling6, Conv2dSubsampling8, ) from ..transformer.embedding import ( PositionalEncoding, RelPositionalEncoding, WhisperPositionalEncoding, LearnablePositionalEncoding, NoPositionalEncoding) from ..transformer.attention import (MultiHeadedAttention, RelPositionMultiHeadedAttention) from ..transformer.embedding import EspnetRelPositionalEncoding from ..transformer.subsampling import LegacyLinearNoSubsampling COSYVOICE_ACTIVATION_CLASSES = { "hardtanh": torch.nn.Hardtanh, "tanh": torch.nn.Tanh, "relu": torch.nn.ReLU, "selu": torch.nn.SELU, "swish": getattr(torch.nn, "SiLU", Swish), "gelu": torch.nn.GELU, } COSYVOICE_SUBSAMPLE_CLASSES = { "linear": LinearNoSubsampling, "linear_legacy": LegacyLinearNoSubsampling, "embed": EmbedinigNoSubsampling, "conv1d2": Conv1dSubsampling2, "conv2d": Conv2dSubsampling4, "conv2d6": Conv2dSubsampling6, "conv2d8": Conv2dSubsampling8, 'paraformer_dummy': torch.nn.Identity } COSYVOICE_EMB_CLASSES = { "embed": PositionalEncoding, "abs_pos": PositionalEncoding, "rel_pos": RelPositionalEncoding, "rel_pos_espnet": EspnetRelPositionalEncoding, "no_pos": NoPositionalEncoding, "abs_pos_whisper": WhisperPositionalEncoding, "embed_learnable_pe": LearnablePositionalEncoding, } COSYVOICE_ATTENTION_CLASSES = { "selfattn": MultiHeadedAttention, "rel_selfattn": RelPositionMultiHeadedAttention, } ================================================ FILE: src/chatterbox/models/s3gen/utils/intmeanflow.py ================================================ import torch import torch.nn as nn def get_intmeanflow_time_mixer(dims): """" Diagonal init as described in 3.3 https://arxiv.org/pdf/2510.07979 """ layer = nn.Linear(dims * 2, dims, bias=False) with torch.no_grad(): target_weight = torch.zeros(dims, 2 * dims) target_weight[:, 0:dims] = torch.eye(dims) layer.weight.data = 
target_weight return layer if __name__ == '__main__': D_example = 6 W_layer = get_intmeanflow_time_mixer(D_example) print(f"Layer weight (AFTER init):\n{W_layer.weight.data}\n") e_t = torch.tensor([0., 1., 2., 3., 4., 5.]) e_r = torch.tensor([6., 7., 8., 9., 10., 11.]) e_concat = torch.cat([e_t, e_r]).unsqueeze(0) # Shape (1, 12) output = W_layer(e_concat) print(f"Test Input e_t: \n{e_t}") print(f"Test Input e_r: \n{e_r}") print(f"Test Input concat: \n{e_concat}") print(f"Forward Pass Output: \n{output.squeeze(0)}") ================================================ FILE: src/chatterbox/models/s3gen/utils/mask.py ================================================ # Copyright (c) 2019 Shigeki Karita # 2020 Mobvoi Inc (Binbin Zhang) # 2024 Alibaba Inc (authors: Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch ''' def subsequent_mask( size: int, device: torch.device = torch.device("cpu"), ) -> torch.Tensor: """Create mask for subsequent steps (size, size). This mask is used only in decoder which works in an auto-regressive mode. This means the current step could only do attention with its left steps. In encoder, fully attention is used when streaming is not necessary and the sequence is not long. In this case, no attention mask is needed. When streaming is need, chunk-based attention is used in encoder. See subsequent_chunk_mask for the chunk-based attention mask. 
Args: size (int): size of mask str device (str): "cpu" or "cuda" or torch.Tensor.device dtype (torch.device): result dtype Returns: torch.Tensor: mask Examples: >>> subsequent_mask(3) [[1, 0, 0], [1, 1, 0], [1, 1, 1]] """ ret = torch.ones(size, size, device=device, dtype=torch.bool) return torch.tril(ret) ''' def subsequent_chunk_mask( size: int, chunk_size: int, num_left_chunks: int = -1, device: torch.device = torch.device("cpu"), ) -> torch.Tensor: """Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder Args: size (int): size of mask chunk_size (int): size of chunk num_left_chunks (int): number of left chunks <0: use full chunk >=0: use num_left_chunks device (torch.device): "cpu" or "cuda" or torch.Tensor.device Returns: torch.Tensor: mask Examples: >>> subsequent_chunk_mask(4, 2) [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]] """ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks # actually this is not needed after we have inference cache implemented, will remove it later pos_idx = torch.arange(size, device=device) block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) return ret def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, use_dynamic_chunk: bool, use_dynamic_left_chunk: bool, decoding_chunk_size: int, static_chunk_size: int, num_decoding_left_chunks: int, enable_full_context: bool = True): """ Apply optional mask for encoder. Args: xs (torch.Tensor): padded input, (B, L, D), L for max length mask (torch.Tensor): mask for xs, (B, 1, L) use_dynamic_chunk (bool): whether to use dynamic chunk or not use_dynamic_left_chunk (bool): whether to use dynamic left chunk for training. decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's 0: default for training, use random dynamic chunk. <0: for decoding, use full chunk. 
>0: for decoding, use fixed chunk size as set. static_chunk_size (int): chunk size for static chunk training/decoding if it's greater than 0, if use_dynamic_chunk is true, this parameter will be ignored num_decoding_left_chunks: number of left chunks, this is for decoding, the chunk size is decoding_chunk_size. >=0: use num_decoding_left_chunks <0: use all left chunks enable_full_context (bool): True: chunk size is either [1, 25] or full context(max_len) False: chunk size ~ U[1, 25] Returns: torch.Tensor: chunk mask of the input xs. """ # Whether to use chunk mask or not if use_dynamic_chunk: max_len = xs.size(1) if decoding_chunk_size < 0: chunk_size = max_len num_left_chunks = -1 elif decoding_chunk_size > 0: chunk_size = decoding_chunk_size num_left_chunks = num_decoding_left_chunks else: # chunk size is either [1, 25] or full context(max_len). # Since we use 4 times subsampling and allow up to 1s(100 frames) # delay, the maximum frame is 100 / 4 = 25. chunk_size = torch.randint(1, max_len, (1, )).item() num_left_chunks = -1 if chunk_size > max_len // 2 and enable_full_context: chunk_size = max_len else: chunk_size = chunk_size % 25 + 1 if use_dynamic_left_chunk: max_left_chunks = (max_len - 1) // chunk_size num_left_chunks = torch.randint(0, max_left_chunks, (1, )).item() chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, num_left_chunks, xs.device) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = masks & chunk_masks # (B, L, L) elif static_chunk_size > 0: num_left_chunks = num_decoding_left_chunks chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, num_left_chunks, xs.device) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = masks & chunk_masks # (B, L, L) else: chunk_masks = masks assert chunk_masks.dtype == torch.bool if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer 
computation!') chunk_masks[chunk_masks.sum(dim=-1)==0] = True return chunk_masks def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """Make mask tensor containing indices of padded part. See description of make_non_pad_mask. Args: lengths (torch.Tensor): Batch of lengths (B,). Returns: torch.Tensor: Mask tensor containing indices of padded part. Examples: >>> lengths = [5, 3, 2] >>> make_pad_mask(lengths) masks = [[0, 0, 0, 0 ,0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]] """ lengths = lengths.long() batch_size = lengths.size(0) max_len = max_len if max_len > 0 else lengths.max().item() seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand return mask ================================================ FILE: src/chatterbox/models/s3gen/utils/mel.py ================================================ """mel-spectrogram extraction in Matcha-TTS""" import logging from librosa.filters import mel as librosa_mel_fn import torch import numpy as np logger = logging.getLogger(__name__) # NOTE: they decalred these global vars mel_basis = {} hann_window = {} def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): return torch.log(torch.clamp(x, min=clip_val) * C) def spectral_normalize_torch(magnitudes): output = dynamic_range_compression_torch(magnitudes) return output """ feat_extractor: !name:matcha.utils.audio.mel_spectrogram n_fft: 1920 num_mels: 80 sampling_rate: 24000 hop_size: 480 win_size: 1920 fmin: 0 fmax: 8000 center: False """ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000, center=False): """Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py Set default values according to Cosyvoice's config. 
""" if isinstance(y, np.ndarray): y = torch.tensor(y).float() if len(y.shape) == 1: y = y[None, ] # Debug: Check for audio clipping (values outside [-1.0, 1.0] range) min_val = torch.min(y) max_val = torch.max(y) if min_val < -1.0 or max_val > 1.0: logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}") global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned if f"{str(fmax)}_{str(y.device)}" not in mel_basis: mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) y = torch.nn.functional.pad( y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" ) y = y.squeeze(1) spec = torch.view_as_real( torch.stft( y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True, ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) spec = spectral_normalize_torch(spec) return spec ================================================ FILE: src/chatterbox/models/s3gen/xvector.py ================================================ #!/usr/bin/env python3 # -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker) from collections import OrderedDict import torch import torch.nn.functional as F import torch.utils.checkpoint as cp import torchaudio.compliance.kaldi as Kaldi def pad_list(xs, pad_value): """Perform padding for the list of tensors. 
Args: xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. pad_value (float): Value for padding. Returns: Tensor: Padded tensor (B, Tmax, `*`). Examples: >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] >>> x [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] >>> pad_list(x, 0) tensor([[1., 1., 1., 1.], [1., 1., 0., 0.], [1., 0., 0., 0.]]) """ n_batch = len(xs) max_len = max(x.size(0) for x in xs) pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) for i in range(n_batch): pad[i, : xs[i].size(0)] = xs[i] return pad def extract_feature(audio): features = [] feature_times = [] feature_lengths = [] for au in audio: feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80) feature = feature - feature.mean(dim=0, keepdim=True) features.append(feature) feature_times.append(au.shape[0]) feature_lengths.append(feature.shape[0]) # padding for batch inference features_padded = pad_list(features, pad_value=0) # features = torch.cat(features) return features_padded, feature_lengths, feature_times class BasicResBlock(torch.nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1): super(BasicResBlock, self).__init__() self.conv1 = torch.nn.Conv2d( in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False ) self.bn1 = torch.nn.BatchNorm2d(planes) self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = torch.nn.BatchNorm2d(planes) self.shortcut = torch.nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = torch.nn.Sequential( torch.nn.Conv2d( in_planes, self.expansion * planes, kernel_size=1, stride=(stride, 1), bias=False, ), torch.nn.BatchNorm2d(self.expansion * planes), ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = F.relu(out) return out class FCM(torch.nn.Module): def __init__(self, block=BasicResBlock, num_blocks=[2, 2], m_channels=32, 
feat_dim=80): super(FCM, self).__init__() self.in_planes = m_channels self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = torch.nn.BatchNorm2d(m_channels) self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2) self.conv2 = torch.nn.Conv2d( m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False ) self.bn2 = torch.nn.BatchNorm2d(m_channels) self.out_channels = m_channels * (feat_dim // 8) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes * block.expansion return torch.nn.Sequential(*layers) def forward(self, x): x = x.unsqueeze(1) out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = F.relu(self.bn2(self.conv2(out))) shape = out.shape out = out.reshape(shape[0], shape[1] * shape[2], shape[3]) return out def get_nonlinear(config_str, channels): nonlinear = torch.nn.Sequential() for name in config_str.split("-"): if name == "relu": nonlinear.add_module("relu", torch.nn.ReLU(inplace=True)) elif name == "prelu": nonlinear.add_module("prelu", torch.nn.PReLU(channels)) elif name == "batchnorm": nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels)) elif name == "batchnorm_": nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels, affine=False)) else: raise ValueError("Unexpected module ({}).".format(name)) return nonlinear def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): mean = x.mean(dim=dim) std = x.std(dim=dim, unbiased=unbiased) stats = torch.cat([mean, std], dim=-1) if keepdim: stats = stats.unsqueeze(dim=dim) return stats class StatsPool(torch.nn.Module): def forward(self, x): return statistics_pooling(x) class TDNNLayer(torch.nn.Module): def __init__( self, 
in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=False, config_str="batchnorm-relu", ): super(TDNNLayer, self).__init__() if padding < 0: assert ( kernel_size % 2 == 1 ), "Expect equal paddings, but got even kernel size ({})".format(kernel_size) padding = (kernel_size - 1) // 2 * dilation self.linear = torch.nn.Conv1d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, ) self.nonlinear = get_nonlinear(config_str, out_channels) def forward(self, x): x = self.linear(x) x = self.nonlinear(x) return x class CAMLayer(torch.nn.Module): def __init__( self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2 ): super(CAMLayer, self).__init__() self.linear_local = torch.nn.Conv1d( bn_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, ) self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1) self.relu = torch.nn.ReLU(inplace=True) self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1) self.sigmoid = torch.nn.Sigmoid() def forward(self, x): y = self.linear_local(x) context = x.mean(-1, keepdim=True) + self.seg_pooling(x) context = self.relu(self.linear1(context)) m = self.sigmoid(self.linear2(context)) return y * m def seg_pooling(self, x, seg_len=100, stype="avg"): if stype == "avg": seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) elif stype == "max": seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) else: raise ValueError("Wrong segment pooling type.") shape = seg.shape seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1) seg = seg[..., : x.shape[-1]] return seg class CAMDenseTDNNLayer(torch.nn.Module): def __init__( self, in_channels, out_channels, bn_channels, kernel_size, stride=1, dilation=1, bias=False, config_str="batchnorm-relu", memory_efficient=False, ): super(CAMDenseTDNNLayer, self).__init__() 
assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format( kernel_size ) padding = (kernel_size - 1) // 2 * dilation self.memory_efficient = memory_efficient self.nonlinear1 = get_nonlinear(config_str, in_channels) self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False) self.nonlinear2 = get_nonlinear(config_str, bn_channels) self.cam_layer = CAMLayer( bn_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, ) def bn_function(self, x): return self.linear1(self.nonlinear1(x)) def forward(self, x): if self.training and self.memory_efficient: x = cp.checkpoint(self.bn_function, x) else: x = self.bn_function(x) x = self.cam_layer(self.nonlinear2(x)) return x class CAMDenseTDNNBlock(torch.nn.ModuleList): def __init__( self, num_layers, in_channels, out_channels, bn_channels, kernel_size, stride=1, dilation=1, bias=False, config_str="batchnorm-relu", memory_efficient=False, ): super(CAMDenseTDNNBlock, self).__init__() for i in range(num_layers): layer = CAMDenseTDNNLayer( in_channels=in_channels + i * out_channels, out_channels=out_channels, bn_channels=bn_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, bias=bias, config_str=config_str, memory_efficient=memory_efficient, ) self.add_module("tdnnd%d" % (i + 1), layer) def forward(self, x): for layer in self: x = torch.cat([x, layer(x)], dim=1) return x class TransitLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"): super(TransitLayer, self).__init__() self.nonlinear = get_nonlinear(config_str, in_channels) self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias) def forward(self, x): x = self.nonlinear(x) x = self.linear(x) return x class DenseLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"): super(DenseLayer, self).__init__() self.linear = torch.nn.Conv1d(in_channels, 
out_channels, 1, bias=bias) self.nonlinear = get_nonlinear(config_str, out_channels) def forward(self, x): if len(x.shape) == 2: x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) else: x = self.linear(x) x = self.nonlinear(x) return x # @tables.register("model_classes", "CAMPPlus") class CAMPPlus(torch.nn.Module): def __init__( self, feat_dim=80, embedding_size=192, growth_rate=32, bn_size=4, init_channels=128, config_str="batchnorm-relu", memory_efficient=True, output_level="segment", **kwargs, ): super().__init__() self.head = FCM(feat_dim=feat_dim) channels = self.head.out_channels self.output_level = output_level self.xvector = torch.nn.Sequential( OrderedDict( [ ( "tdnn", TDNNLayer( channels, init_channels, 5, stride=2, dilation=1, padding=-1, config_str=config_str, ), ), ] ) ) channels = init_channels for i, (num_layers, kernel_size, dilation) in enumerate( zip((12, 24, 16), (3, 3, 3), (1, 2, 2)) ): block = CAMDenseTDNNBlock( num_layers=num_layers, in_channels=channels, out_channels=growth_rate, bn_channels=bn_size * growth_rate, kernel_size=kernel_size, dilation=dilation, config_str=config_str, memory_efficient=memory_efficient, ) self.xvector.add_module("block%d" % (i + 1), block) channels = channels + num_layers * growth_rate self.xvector.add_module( "transit%d" % (i + 1), TransitLayer(channels, channels // 2, bias=False, config_str=config_str), ) channels //= 2 self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels)) if self.output_level == "segment": self.xvector.add_module("stats", StatsPool()) self.xvector.add_module( "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_") ) else: assert ( self.output_level == "frame" ), "`output_level` should be set to 'segment' or 'frame'. 
" for m in self.modules(): if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): torch.nn.init.kaiming_normal_(m.weight.data) if m.bias is not None: torch.nn.init.zeros_(m.bias) def forward(self, x): x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = self.head(x) x = self.xvector(x) if self.output_level == "frame": x = x.transpose(1, 2) return x def inference(self, audio_list): speech, speech_lengths, speech_times = extract_feature(audio_list) results = self.forward(speech.to(torch.float32)) return results ================================================ FILE: src/chatterbox/models/s3tokenizer/__init__.py ================================================ from .s3tokenizer import ( S3_SR, S3_HOP, S3_TOKEN_HOP, S3_TOKEN_RATE, SPEECH_VOCAB_SIZE, S3Tokenizer, ) SOS = SPEECH_VOCAB_SIZE EOS = SPEECH_VOCAB_SIZE + 1 def drop_invalid_tokens(x): """Drop SoS and EoS""" assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now" if SOS in x: s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1 else: s = 0 if EOS in x: e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0) else: e = None x = x[s: e] return x ================================================ FILE: src/chatterbox/models/s3tokenizer/s3tokenizer.py ================================================ from typing import List, Tuple import numpy as np import librosa import torch import torch.nn.functional as F from s3tokenizer.utils import padding from s3tokenizer.model_v2 import ( S3TokenizerV2, ModelConfig, ) # Sampling rate of the inputs to S3TokenizerV2 S3_SR = 16_000 S3_HOP = 160 # 100 frames/sec S3_TOKEN_HOP = 640 # 25 tokens/sec S3_TOKEN_RATE = 25 SPEECH_VOCAB_SIZE = 6561 class S3Tokenizer(S3TokenizerV2): """ s3tokenizer.S3TokenizerV2 with the following changes: - a more integrated `forward` - compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers` """ ignore_state_dict_missing = ("_mel_filters", "window") def __init__( self, name: 
str="speech_tokenizer_v2_25hz", config: ModelConfig = ModelConfig() ): super().__init__(name) self.n_fft = 400 _mel_filters = librosa.filters.mel( sr=S3_SR, n_fft=self.n_fft, n_mels=config.n_mels ) self.register_buffer( "_mel_filters", torch.FloatTensor(_mel_filters), ) self.register_buffer( "window", torch.hann_window(self.n_fft), ) def pad(self, wavs, sr) -> List[torch.Tensor]: """ Given a list of wavs with the same `sample_rate`, pad them so that the length is multiple of 40ms (S3 runs at 25 token/sec). """ processed_wavs = [] for wav in wavs: if isinstance(wav, np.ndarray): wav = torch.from_numpy(wav) if wav.dim() == 1: wav = wav.unsqueeze(0) n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE n_tokens = np.ceil(n_tokens) intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE) intended_wav_len = int(intended_wav_len) wav = torch.nn.functional.pad( wav, (0, intended_wav_len - wav.shape[-1]), mode="constant", value=0 ) processed_wavs.append(wav) return processed_wavs def _prepare_audio(self, wavs): """Prepare a list of audios for s3tokenizer processing.""" processed_wavs = [] for wav in wavs: if isinstance(wav, np.ndarray): wav = torch.from_numpy(wav) if wav.dim() == 1: wav = wav.unsqueeze(0) processed_wavs.append(wav) return processed_wavs @torch.no_grad() def forward( self, wavs: torch.Tensor, accelerator: 'Accelerator'=None, max_len: int=None, ) -> Tuple[torch.Tensor, torch.LongTensor]: """ NOTE: mel-spec has a hop size of 160 points (100 frame/sec). FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected. Args ---- - `wavs`: 16 kHz speech audio - `max_len` max length to truncate the output sequence to (25 token/sec). NOTE: please pad the waveform if longer sequence is needed. 
""" processed_wavs = self._prepare_audio(wavs) mels, mel_lens = [], [] for wav in processed_wavs: wav = wav.to(self.device) mel = self.log_mel_spectrogram(wav) # [B=1, F, T] if max_len is not None: mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens mels.append(mel.squeeze(0)) mels, mel_lens = padding(mels) if accelerator is None: tokenizer = self else: tokenizer = accelerator.unwrap_model(self) speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device)) return ( speech_tokens.long().detach(), speech_token_lens.long().detach(), ) def log_mel_spectrogram( self, audio: torch.Tensor, padding: int = 0, ): """ Compute the log-Mel spectrogram of Parameters ---------- audio: torch.Tensor, shape = (*) The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz padding: int Number of zero samples to pad to the right Returns ------- torch.Tensor, shape = (128, n_frames) A Tensor that contains the Mel spectrogram """ if not torch.is_tensor(audio): audio = torch.from_numpy(audio) audio = audio.to(self.device) if padding > 0: audio = F.pad(audio, (0, padding)) stft = torch.stft( audio, self.n_fft, S3_HOP, window=self.window.to(self.device), return_complex=True ) magnitudes = stft[..., :-1].abs()**2 mel_spec = self._mel_filters.to(self.device) @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 return log_spec ================================================ FILE: src/chatterbox/models/t3/__init__.py ================================================ from .t3 import T3 ================================================ FILE: src/chatterbox/models/t3/inference/alignment_stream_analyzer.py ================================================ # Copyright (c) 2025 Resemble AI # Author: John Meade, Jeremy Hsu # MIT License import logging import torch from dataclasses import dataclass from types import MethodType logger = 
logging.getLogger(__name__) LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)] @dataclass class AlignmentAnalysisResult: # was this frame detected as being part of a noisy beginning chunk with potential hallucinations? false_start: bool # was this frame detected as being part of a long tail with potential hallucinations? long_tail: bool # was this frame detected as repeating existing text content? repetition: bool # was the alignment position of this frame too far from the previous frame? discontinuity: bool # has inference reached the end of the text tokens? eg, this remains false if inference stops early complete: bool # approximate position in the text token sequence. Can be used for generating online timestamps. position: int class AlignmentStreamAnalyzer: def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0): """ Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention activation maps. This module exploits this to perform online integrity checks which streaming. A hook is injected into the specified attention layer, and heuristics are used to determine alignment position, repetition, etc. NOTE: currently requires no queues. """ # self.queue = queue self.text_tokens_slice = (i, j) = text_tokens_slice self.eos_idx = eos_idx self.alignment = torch.zeros(0, j-i) # self.alignment_bin = torch.zeros(0, j-i) self.curr_frame_pos = 0 self.text_position = 0 self.started = False self.started_at = None self.complete = False self.completed_at = None # Track generated tokens for repetition detection self.generated_tokens = [] # Using `output_attentions=True` is incompatible with optimized attention kernels, so # using it for all layers slows things down too much. 
We can apply it to just one layer # by intercepting the kwargs and adding a forward hook (credit: jrm) self.last_aligned_attns = [] for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS): self.last_aligned_attns += [None] self._add_attention_spy(tfmr, i, layer_idx, head_idx) def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx): """ Adds a forward hook to a specific attention layer to collect outputs. """ def attention_forward_hook(module, input, output): """ See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`. NOTE: - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`. - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th. """ if isinstance(output, tuple) and len(output) > 1 and output[1] is not None: step_attention = output[1].cpu() # (B, n_heads, T0, Ti) self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti) target_layer = tfmr.layers[layer_idx].self_attn # Register hook and store the handle target_layer.register_forward_hook(attention_forward_hook) if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'): self.original_output_attentions = tfmr.config.output_attentions self.original_attn_implementation = getattr(tfmr.config, '_attn_implementation', None) if getattr(tfmr.config, '_attn_implementation', None) == 'sdpa': tfmr.config._attn_implementation = 'eager' tfmr.config.output_attentions = True def step(self, logits, next_token=None): """ Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS. 
""" # extract approximate alignment matrix chunk (1 frame at a time after the first chunk) aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N) i, j = self.text_tokens_slice if self.curr_frame_pos == 0: # first chunk has conditioning info, text tokens, and BOS token A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S) else: # subsequent chunks have 1 frame due to KV-caching A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S) # TODO: monotonic masking; could have issue b/c spaces are often skipped. A_chunk[:, self.curr_frame_pos + 1:] = 0 self.alignment = torch.cat((self.alignment, A_chunk), dim=0) A = self.alignment T, S = A.shape # update position cur_text_posn = A_chunk[-1].argmax() discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient! if not discontinuity: self.text_position = cur_text_posn # Hallucinations at the start of speech show up as activations at the bottom of the attention maps! # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens, # and there are some strong activations in the first few tokens. false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5) self.started = not false_start if self.started and self.started_at is None: self.started_at = T # Is generation likely complete? self.complete = self.complete or self.text_position >= S - 3 if self.complete and self.completed_at is None: self.completed_at = T # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens. # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens. last_text_token_duration = A[15:, -3:].sum() # Activations for the final token that last too long are likely hallucinations. 
long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms # If there are activations in previous tokens after generation has completed, assume this is a repetition error. alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5) # Track generated tokens for repetition detection if next_token is not None: # Convert tensor to scalar if needed if isinstance(next_token, torch.Tensor): token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item() else: token_id = next_token self.generated_tokens.append(token_id) # Keep only last 8 tokens to prevent memory issues if len(self.generated_tokens) > 8: self.generated_tokens = self.generated_tokens[-8:] # Check for excessive token repetition (3x same token in a row) token_repetition = ( # self.complete and len(self.generated_tokens) >= 3 and len(set(self.generated_tokens[-2:])) == 1 ) if token_repetition: repeated_token = self.generated_tokens[-1] logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}") # Suppress EoS to prevent early termination if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens logits[..., self.eos_idx] = -2**15 # If a bad ending is detected, force emit EOS by modifying logits # NOTE: this means logits may be inconsistent with latents! 
if long_tail or alignment_repetition or token_repetition: logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}") # (±2**15 is safe for all dtypes >= 16bit) logits = -(2**15) * torch.ones_like(logits) logits[..., self.eos_idx] = 2**15 self.curr_frame_pos += 1 return logits ================================================ FILE: src/chatterbox/models/t3/inference/t3_hf_backend.py ================================================ from typing import Optional import torch from torch import nn as nn from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin): """ Override some HuggingFace interface methods so we can use the standard `generate` method with our custom embedding / logit layers. NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights! """ def __init__( self, config: LlamaConfig, llama: LlamaModel, *, speech_enc, speech_head, latents_queue=None, logits_queue=None, alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None, ): super().__init__(config) self.model = llama self.speech_enc = speech_enc self.speech_head = speech_head self._added_cond = False self.alignment_stream_analyzer = alignment_stream_analyzer @torch.inference_mode() def prepare_inputs_for_generation( self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None, # This argument was introduced in some recent version of transformers (>=4.29.1) cache_position=None ): """ This is a method used by huggingface's generate() method. Overridden here to apply our custom speech token embedding layer. :param input_ids: (B, S) int64 tensors of input tokens. 
:param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to ) """ # Make use of the kv cache: only the last input ID is new, we trim away all the ones before if not use_cache: past_key_values = None if past_key_values is not None: input_ids = input_ids[:, -1:] # custom speech token embedding layer inputs_embeds = self.speech_enc(input_ids) # prefix decoder conditioning if applicable if not self._added_cond: assert past_key_values is not None # should be first step if decoder_cond.size(0) != inputs_embeds.size(0): decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1) inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1) self._added_cond = True return { "inputs_embeds": inputs_embeds, "past_key_values": past_key_values, "use_cache": use_cache, } @torch.inference_mode() def forward( self, inputs_embeds: torch.Tensor, past_key_values: Optional[torch.Tensor]=None, use_cache=True, output_attentions=False, output_hidden_states=True, return_dict=True, ): """ This is a method used by huggingface's generate() method. Overridden here to apply our custom layer norm and speech logit projection layers. :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given, S should be 1. 
""" is_large_input = inputs_embeds.size(1) != 1 has_cache = past_key_values is not None and len(past_key_values) > 0 assert not (is_large_input and has_cache) assert return_dict assert output_hidden_states tfmr_out = self.model( inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, ) hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim) logits = self.speech_head(hidden_states) # assert inputs_embeds.size(0) == 1 # (disabled for CFG) # NOTE: hallucination handler may modify logits to force emit an EOS token # logits = self.alignment_stream_analyzer.step(logits) return CausalLMOutputWithCrossAttentions( logits=logits, past_key_values=tfmr_out.past_key_values, hidden_states=tfmr_out.hidden_states, attentions=tfmr_out.attentions, ) ================================================ FILE: src/chatterbox/models/t3/llama_configs.py ================================================ LLAMA_520M_CONFIG_DICT = dict( # Arbitrary small number that won't cause problems when loading. # These param are unused due to custom input layers. 
vocab_size=8, # default params needed for loading most pretrained 1B weights max_position_embeddings=131072, hidden_size=1024, intermediate_size=4096, num_hidden_layers=30, num_attention_heads=16, attn_implementation="sdpa", head_dim=64, tie_word_embeddings=False, hidden_act="silu", attention_bias=False, attention_dropout=0.0, initializer_range=0.02, mlp_bias=False, model_type="llama", num_key_value_heads=16, pretraining_tp=1, rms_norm_eps=1e-05, rope_scaling=dict( factor=8.0, high_freq_factor=4.0, low_freq_factor=1.0, original_max_position_embeddings=8192, rope_type="llama3" ), rope_theta=500000.0, torch_dtype="bfloat16", use_cache=True, ) GPT2_MEDIUM_CONFIG = { "activation_function": "gelu_new", "architectures": [ "GPT2LMHeadModel" ], "attn_pdrop": 0.1, "bos_token_id": 50256, "embd_pdrop": 0.1, "eos_token_id": 50256, "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, "model_type": "gpt2", "n_ctx": 8196, "n_embd": 1024, "hidden_size": 1024, "n_head": 16, "n_layer": 24, "n_positions": 8196, "n_special": 0, "predict_special_tokens": True, "resid_pdrop": 0.1, "summary_activation": None, "summary_first_dropout": 0.1, "summary_proj_to_labels": True, "summary_type": "cls_index", "summary_use_proj": True, "task_specific_params": { "text-generation": { "do_sample": True, "max_length": 50 } }, "vocab_size": 50276, } LLAMA_CONFIGS = { "Llama_520M": LLAMA_520M_CONFIG_DICT, "GPT2_medium": GPT2_MEDIUM_CONFIG, } ================================================ FILE: src/chatterbox/models/t3/modules/cond_enc.py ================================================ from dataclasses import dataclass from typing import Optional import torch from torch import nn, Tensor from .perceiver import Perceiver from .t3_config import T3Config @dataclass class T3Cond: """ Dataclass container for most / all conditioning info. 
TODO: serialization methods aren't used, keeping them around for convenience """ speaker_emb: Tensor clap_emb: Optional[Tensor] = None cond_prompt_speech_tokens: Optional[Tensor] = None cond_prompt_speech_emb: Optional[Tensor] = None emotion_adv: Optional[Tensor] = 0.5 def to(self, *, device=None, dtype=None): "Cast to a device and dtype. Dtype casting is ignored for long/int tensors." for k, v in self.__dict__.items(): if torch.is_tensor(v): is_fp = type(v.view(-1)[0].item()) is not int setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None)) return self def save(self, fpath): torch.save(self.__dict__, fpath) @staticmethod def load(fpath, map_location="cpu"): kwargs = torch.load(fpath, map_location=map_location, weights_only=True) return T3Cond(**kwargs) class T3CondEnc(nn.Module): """ Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc. """ def __init__(self, hp: T3Config): super().__init__() self.hp = hp if hp.encoder_type == "voice_encoder": self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels) else: raise NotImplementedError(str(hp.encoder_type)) # emotion adv self.emotion_adv_fc = None if hp.emotion_adv: self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False) # perceiver resampler self.perceiver = None if hp.use_perceiver_resampler: self.perceiver = Perceiver() def forward(self, cond: T3Cond): # Validate assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \ "no embeddings for cond_prompt_speech_tokens" # Speaker embedding projection cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim) empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim) # TODO CLAP assert cond.clap_emb is None, "clap_embed not implemented" cond_clap = empty # (B, 0, dim) # Cond prompt cond_prompt_speech_emb = cond.cond_prompt_speech_emb if cond_prompt_speech_emb is None: cond_prompt_speech_emb = empty # (B, 0, dim) elif 
self.hp.use_perceiver_resampler: cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb) # Emotion Adv: must provide a value if this model uses emotion conditioning cond_emotion_adv = empty # (B, 0, dim) if self.hp.emotion_adv: assert cond.emotion_adv is not None cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1)) # Concat and return cond_embeds = torch.cat(( cond_spkr, cond_clap, cond_prompt_speech_emb, cond_emotion_adv, ), dim=1) return cond_embeds ================================================ FILE: src/chatterbox/models/t3/modules/learned_pos_emb.py ================================================ from typing import Union import torch from torch import nn, Tensor class LearnedPositionEmbeddings(nn.Module): def __init__(self, seq_len, model_dim, init=.02): super().__init__() self.emb = nn.Embedding(seq_len, model_dim) # Initializing this way is standard for GPT-2 self.emb.weight.data.normal_(mean=0.0, std=init) def forward(self, x): """ Returns positional embeddings for index 0 up to the length of x """ sl = x.shape[1] return self.emb(torch.arange(0, sl, device=x.device)) def get_fixed_embedding(self, idx: 'Union[int, Tensor]'): """ Args: idx: scalar int or an integer tensor of shape (T,) or (B, T) Returns: positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input """ device = self.emb.weight.device idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device) idx = torch.atleast_2d(idx) assert idx.ndim == 2 return self.emb(idx) # (B, T, dim) ================================================ FILE: src/chatterbox/models/t3/modules/perceiver.py ================================================ # Copyright (c) 2025 Resemble AI # Author: Manmay Nakhashi # MIT License import math import torch from torch import nn import torch.nn.functional as F from einops import rearrange class RelativePositionBias(nn.Module): def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, 
heads=8): super().__init__() self.scale = scale self.causal = causal self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): ret = 0 n = -relative_position if not causal: num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) else: n = torch.max(n, torch.zeros_like(n)) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, qk_dots): i, j, device = *qk_dots.shape[-2:], qk_dots.device q_pos = torch.arange(i, dtype=torch.long, device=device) k_pos = torch.arange(j, dtype=torch.long, device=device) rel_pos = k_pos[None, :] - q_pos[:, None] rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance) values = self.relative_attention_bias(rp_bucket) bias = rearrange(values, 'i j h -> () h i j') return qk_dots + (bias * self.scale) class AttentionQKV(nn.Module): def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False): super().__init__() self.n_heads = n_heads self.head_dim = head_dim self.scale = scale if scale is not None else head_dim ** -0.5 self.flash = flash self.dropout_rate = dropout_rate self.dropout = nn.Dropout(dropout_rate) self.flash_config = self.setup_flash_config() if flash else None def setup_flash_config(self): # Setup flash attention configuration flash_config = { 'enable_flash': True, 'enable_math': True, 'enable_mem_efficient': True } return flash_config def forward(self, q, k, v, mask=None): q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]] if self.flash: 
out = self.flash_attention(q, k, v, mask=mask) else: out = self.scaled_dot_product_attention(q, k, v, mask=mask) return self.combine_heads(out) def scaled_dot_product_attention(self, q, k, v, mask=None): sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale if mask is not None: sim = sim.masked_fill(mask == 0, float('-inf')) attn = torch.softmax(sim, dim=-1) attn = self.dropout(attn) return torch.einsum("bhts,bhls->bhlt", attn, v) def flash_attention(self, q, k, v, mask=None): config = self.flash_config if self.flash_config else {} with torch.backends.cuda.sdp_kernel(**config): out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=self.dropout_rate if self.training else 0. ) return out def split_heads(self, x): bs, length, _ = x.shape x = x.view(bs, length, self.n_heads, self.head_dim) return x.permute(0, 2, 1, 3) def combine_heads(self, x): bs, _, length, _ = x.shape x = x.permute(0, 2, 1, 3).contiguous() return x.view(bs, length, -1) class AttentionBlock2(nn.Module): """ An attention block that allows spatial positions to attend to each other, using AttentionQKV and separate linear transformations for Q, K, and V. 
""" def __init__( self, channels, num_heads=1, num_head_channels=-1, relative_pos_embeddings=False, flash_attention=True, dropout_rate=0.2, scale=None ): super().__init__() self.channels = channels if num_head_channels == -1: self.num_heads = num_heads else: assert ( channels % num_head_channels == 0 ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.norm = nn.LayerNorm(channels) # Separate linear layers for Q, K, and V self.to_q = nn.Linear(channels, channels) self.to_k = nn.Linear(channels, channels) self.to_v = nn.Linear(channels, channels) self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale) self.proj_out = nn.Linear(channels, channels) if relative_pos_embeddings: self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) else: self.relative_pos_embeddings = None def forward(self, x1, x2, mask=None): b1, c1, *spatial1 = x1.shape b2, c2, *spatial2 = x2.shape x1_norm = self.norm(x1) x2_norm = self.norm(x2) q = self.to_q(x1_norm) k = self.to_k(x2_norm) v = self.to_v(x2_norm) h = self.attention(q, k, v, mask=mask) h = self.proj_out(h) return (x1 + h).reshape(b1, c1, *spatial1) class Perceiver(nn.Module): """Inspired by https://arxiv.org/abs/2103.03206""" def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4): """ Initialize the perceiver module. 
:param pre_attention_query_token: Number of query tokens for pre-attention :param pre_attention_query_size: Size of each query token :param embedding_dim: Dimension of the embedding space :param num_attn_heads: Number of attention heads """ super().__init__() # Initialize the pre-attention query parameter self.pre_attention_query = torch.nn.Parameter( torch.empty(1, pre_attention_query_token, pre_attention_query_size) ) # Calculate the variance for uniform initialization query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token)) # Initialize the pre-attention query with uniform distribution self.pre_attention_query.data.uniform_(-query_variance, query_variance) # Initialize the attention block self.attn = AttentionBlock2(embedding_dim, num_attn_heads) def forward(self, h): """ Forward pass of the perceiver module. :param h: Input tensor :return: Output after applying attention mechanisms """ # Expand the pre-attention query to match the batch size of the input query_ = self.pre_attention_query.expand(h.shape[0], -1, -1) # Apply the first attention mechanism (cross-attention) pre_att = self.attn(query_, h) # Apply the second attention mechanism (self-attention) attn = self.attn(pre_att, pre_att) return attn ================================================ FILE: src/chatterbox/models/t3/modules/t3_config.py ================================================ from ..llama_configs import LLAMA_CONFIGS class T3Config: def __init__(self, text_tokens_dict_size=704): self.start_text_token = 255 self.stop_text_token = 0 self.text_tokens_dict_size = text_tokens_dict_size self.max_text_tokens = 2048 self.start_speech_token = 6561 self.stop_speech_token = 6562 self.speech_tokens_dict_size = 8194 self.max_speech_tokens = 4096 self.llama_config_name = "Llama_520M" self.input_pos_emb = "learned" self.speech_cond_prompt_len = 150 self.encoder_type = "voice_encoder" self.speaker_embed_size = 256 self.use_perceiver_resampler = True 
self.emotion_adv = True @property def n_channels(self): return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"] @property def is_multilingual(self): return self.text_tokens_dict_size == 2454 @classmethod def english_only(cls): """Create configuration for English-only TTS model.""" return cls(text_tokens_dict_size=704) @classmethod def multilingual(cls): """Create configuration for multilingual TTS model.""" return cls(text_tokens_dict_size=2454) ================================================ FILE: src/chatterbox/models/t3/t3.py ================================================ # Copyright (c) 2025 Resemble AI # MIT License import logging from typing import Union, Optional, List logger = logging.getLogger(__name__) from tqdm import tqdm import torch import torch.nn.functional as F from torch import nn, Tensor from transformers import LlamaModel, LlamaConfig, GPT2Config, GPT2Model from transformers.generation.logits_process import ( LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, MinPLogitsWarper, ) from .modules.learned_pos_emb import LearnedPositionEmbeddings from .modules.cond_enc import T3CondEnc, T3Cond from .modules.t3_config import T3Config from .llama_configs import LLAMA_CONFIGS from .inference.t3_hf_backend import T3HuggingfaceBackend from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer from ..utils import AttrDict logger = logging.getLogger(__name__) def _ensure_BOT_EOT(text_tokens: Tensor, hp): B = text_tokens.size(0) assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token" assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token" class T3(nn.Module): """ Token-To-Token (T3) TTS model using huggingface transformer models as backbones, * tokenization, including start / stop tokens are always added externally to this class * conditioning data like CLAP, emotion, etc are all in a separate file for more 
modularity * careful! this class assumes relative positional encoding -- with absolute PE, we would at least want to reset the position to 0 when speech tokens begin, and optionally use a different PE embedding space for speech. """ def __init__(self, hp=None): if hp is None: hp = T3Config.english_only() super().__init__() self.hp = hp config_dict = LLAMA_CONFIGS[hp.llama_config_name] self.is_gpt = config_dict.get("model_type") == "gpt2" if self.is_gpt: self.cfg = GPT2Config(**config_dict) self.tfmr = GPT2Model(self.cfg) else: self.cfg = LlamaConfig(**config_dict) self.tfmr = LlamaModel(self.cfg) self.dim = self.cfg.hidden_size self.deepspeed_patch_applied = False # conditioning / embedding self.cond_enc = T3CondEnc(hp) self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim) self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim) # custom position embedding self.text_pos_emb = None self.speech_pos_emb = None if hp.input_pos_emb == "learned": max_text_seq_len = hp.max_text_tokens + 2 self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim) max_mel_seq_len = hp.max_speech_tokens + 2 + 2 self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim) # logit projection self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False) self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt) self.compiled = False @property def device(self): return self.speech_head.weight.device def prepare_conditioning(self, t3_cond: T3Cond): """ Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`. 
""" if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None: t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) if not self.is_gpt: t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens) return self.cond_enc(t3_cond) # (B, len_cond, dim) def prepare_input_embeds( self, *, t3_cond: T3Cond, text_tokens: torch.LongTensor, speech_tokens: torch.LongTensor, cfg_weight: float = 0.0, ): # prepare input embeddings (skip backbone tranformer embeddings) cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim) text_emb = self.text_emb(text_tokens) # (B, len_text, dim) if cfg_weight > 0.0 and not self.is_gpt: text_emb[1].zero_() # CFG uncond speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim) if self.hp.input_pos_emb == "learned": text_emb = text_emb + self.text_pos_emb(text_tokens) speech_emb = speech_emb + self.speech_pos_emb(speech_tokens) len_cond = cond_emb.size(1) if cond_emb.size(0) != text_emb.size(0): cond_emb = cond_emb.expand(text_emb.size(0), -1, -1) # concat embeds = torch.stack([ torch.cat((ce, te, se)) for ce, te, se in zip(cond_emb, text_emb, speech_emb) ]) # (B, length, dim) return embeds, len_cond def forward( self, *, t3_cond: T3Cond, text_tokens: torch.LongTensor, text_token_lens: torch.LongTensor, speech_tokens: torch.LongTensor, speech_token_lens: torch.LongTensor, training=False, ): _ensure_BOT_EOT(text_tokens, self.hp) # prepare custom input embeds embeds, len_cond = self.prepare_input_embeds( t3_cond=t3_cond, text_tokens=text_tokens, speech_tokens=speech_tokens, ) # backbone tranformer forward tfmr_out = self.tfmr.forward( input_ids=None, # position_ids=position_ids, # TODO? ROPE should be fine? 
inputs_embeds=embeds, output_hidden_states=True, return_dict=True, use_cache=(not training), ) hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim) # post-processing: splice out text and speech parts of hidden states len_text = text_tokens.size(1) len_speech = speech_tokens.size(1) B, _, dim = hidden_states.shape device, dtype = hidden_states.device, hidden_states.dtype text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device) speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device) ttl, stl = text_token_lens, speech_token_lens for i in range(B): text_end = len_cond + ttl[i].item() speech_start = len_cond + text_tokens.size(1) speech_end = speech_start + stl[i].item() text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end] speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end] # logit projection text_logits = self.text_head(text_latents) speech_logits = self.speech_head(speech_latents) return AttrDict( text_logits=text_logits, text_latents=text_latents, speech_logits=speech_logits, speech_latents=speech_latents, hidden_states=hidden_states, ) def loss( self, *, t3_cond: T3Cond, text_tokens: torch.LongTensor, text_token_lens: torch.LongTensor, speech_tokens: torch.LongTensor, speech_token_lens: torch.LongTensor, ): "training method" len_text = text_tokens.size(1) len_speech = speech_tokens.size(1) assert len_text == text_token_lens.max() assert len_speech == speech_token_lens.max() out = self.forward( t3_cond=t3_cond, text_tokens=text_tokens, text_token_lens=text_token_lens, speech_tokens=speech_tokens, speech_token_lens=speech_token_lens, training=True, ) # (B, seq, vocab_size) # Calc CCE losses IGNORE_ID = -100 device = out.text_logits.device mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text) mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech) masked_text = 
text_tokens.masked_fill(mask_text, IGNORE_ID) masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID) loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID) loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID) return loss_text, loss_speech @torch.inference_mode() def inference( self, *, t3_cond: T3Cond, text_tokens: Tensor, initial_speech_tokens: Optional[Tensor]=None, # misc conditioning prepend_prompt_speech_tokens: Optional[Tensor]=None, # HF generate args num_return_sequences=1, max_new_tokens=None, stop_on_eos=True, do_sample=True, temperature=0.8, top_p=0.95, min_p=0.05, length_penalty=1.0, repetition_penalty=1.2, cfg_weight=0.5, ): """ Args: text_tokens: a 1D (unbatched) or 2D (batched) tensor. """ # Validate / sanitize inputs assert prepend_prompt_speech_tokens is None, "not implemented" _ensure_BOT_EOT(text_tokens, self.hp) text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device) # Default initial speech to a single start-of-speech token if initial_speech_tokens is None: initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1]) # Prepare custom input embeds embeds, len_cond = self.prepare_input_embeds( t3_cond=t3_cond, text_tokens=text_tokens, speech_tokens=initial_speech_tokens, cfg_weight=cfg_weight, ) # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic # Note the llama-specific logic. Other tfmr types can be added later. self.compiled = False # TODO? synchronize the expensive compile function # with self.compile_lock: if not self.compiled: # Default to None for English models, only create for multilingual alignment_stream_analyzer = None if self.hp.is_multilingual: alignment_stream_analyzer = AlignmentStreamAnalyzer( self.tfmr, None, text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)), alignment_layer_idx=9, # TODO: hparam or something? 
eos_idx=self.hp.stop_speech_token, ) assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token patched_model = T3HuggingfaceBackend( config=self.cfg, llama=self.tfmr, speech_enc=self.speech_emb, speech_head=self.speech_head, alignment_stream_analyzer=alignment_stream_analyzer, ) self.patched_model = patched_model self.compiled = True # # Run normal generate method, which calls our custom extended methods # return self.patched_model.generate( # inputs=initial_speech_tokens, # decoder_cond=embeds, # bos_token_id=self.hp.start_speech_token, # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1), # pad_token_id=self.hp.stop_speech_token, # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens, # num_return_sequences=num_return_sequences, # temperature=temperature, # min_p=min_p, # length_penalty=length_penalty, # repetition_penalty=repetition_penalty, # do_sample=do_sample, # # cache_implementation=None if not self.compiled else "static", # ) device = embeds.device bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device) bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim) bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0) # batch_size=2 for CFG bos_embed = torch.cat([bos_embed, bos_embed]) # Combine condition and BOS token for the initial input inputs_embeds = torch.cat([embeds, bos_embed], dim=1) # Track generated token ids; start with the BOS token. generated_ids = bos_token.clone() predicted = [] # To store the predicted tokens # Instantiate the logits processors. 
top_p_warper = TopPLogitsWarper(top_p=top_p) min_p_warper = MinPLogitsWarper(min_p=min_p) top_p_warper = TopPLogitsWarper(top_p=top_p) repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)) # ---- Initial Forward Pass (no kv_cache yet) ---- output = self.patched_model( inputs_embeds=inputs_embeds, past_key_values=None, use_cache=True, output_attentions=True, output_hidden_states=True, return_dict=True, ) # Initialize kv_cache with the full context. past = output.past_key_values # ---- Generation Loop using kv_cache ---- for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True): logits_step = output.logits[:, -1, :] # CFG combine → (1, V) cond = logits_step[0:1, :] uncond = logits_step[1:2, :] cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype) logits = cond + cfg * (cond - uncond) # Apply alignment stream analyzer integrity checks if self.patched_model.alignment_stream_analyzer is not None: if logits.dim() == 1: # guard in case something upstream squeezed logits = logits.unsqueeze(0) # (1, V) # Pass the last generated token for repetition tracking last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V) # Apply repetition penalty ids_for_proc = generated_ids[:1, ...] # batch = 1 logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V) # Apply temperature scaling. if temperature != 1.0: logits = logits / temperature # Apply min_p and top_p filtering logits = min_p_warper(ids_for_proc, logits) logits = top_p_warper(ids_for_proc, logits) # Convert logits to probabilities and sample the next token. probs = torch.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1) predicted.append(next_token) generated_ids = torch.cat([generated_ids, next_token], dim=1) # Check for EOS token. 
if next_token.view(-1) == self.hp.stop_speech_token: logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}") break # Get embedding for the new token. next_token_embed = self.speech_emb(next_token) next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1) # For CFG next_token_embed = torch.cat([next_token_embed, next_token_embed]) # Forward pass with only the new token and the cached past. output = self.patched_model( inputs_embeds=next_token_embed, past_key_values=past, output_attentions=True, output_hidden_states=True, return_dict=True, ) # Update the kv_cache. past = output.past_key_values # Concatenate all predicted tokens along the sequence dimension. predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens) return predicted_tokens @torch.inference_mode() def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2, max_gen_len=1000): logits_processors = LogitsProcessorList() if temperature > 0 and temperature != 1.0: logits_processors.append(TemperatureLogitsWarper(temperature)) if top_k > 0: logits_processors.append(TopKLogitsWarper(top_k)) if top_p < 1.0: logits_processors.append(TopPLogitsWarper(top_p)) if repetition_penalty != 1.0: logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1]) embeds, _ = self.prepare_input_embeds( t3_cond=t3_cond, text_tokens=text_tokens, speech_tokens=speech_start_token, cfg_weight=0.0, ) generated_speech_tokens = [] llm_outputs = self.tfmr( inputs_embeds=embeds, use_cache=True ) hidden_states = llm_outputs[0] past_key_values = llm_outputs.past_key_values speech_hidden = hidden_states[:, -1:] speech_logits = self.speech_head(speech_hidden) processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :]) probs = F.softmax(processed_logits, dim=-1) next_speech_token = torch.multinomial(probs, 
num_samples=1) generated_speech_tokens.append(next_speech_token) current_speech_token = next_speech_token for _ in tqdm(range(max_gen_len)): current_speech_embed = self.speech_emb(current_speech_token) llm_outputs = self.tfmr( inputs_embeds=current_speech_embed, past_key_values=past_key_values, use_cache=True ) hidden_states = llm_outputs[0] past_key_values = llm_outputs.past_key_values speech_logits = self.speech_head(hidden_states) input_ids = torch.cat(generated_speech_tokens, dim=1) processed_logits = logits_processors(input_ids, speech_logits[:, -1, :]) if torch.all(processed_logits == -float("inf")): print("Warning: All logits are -inf") break probs = F.softmax(processed_logits, dim=-1) next_speech_token = torch.multinomial(probs, num_samples=1) generated_speech_tokens.append(next_speech_token) current_speech_token = next_speech_token if torch.all(next_speech_token == self.hp.stop_speech_token): break all_tokens = torch.cat(generated_speech_tokens, dim=1) # Remove EOS token if present if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token: all_tokens = all_tokens[:, :-1] return all_tokens ================================================ FILE: src/chatterbox/models/tokenizers/__init__.py ================================================ from .tokenizer import EnTokenizer, MTLTokenizer ================================================ FILE: src/chatterbox/models/tokenizers/tokenizer.py ================================================ import logging import json import torch from pathlib import Path from unicodedata import category, normalize from tokenizers import Tokenizer from huggingface_hub import hf_hub_download # Special tokens SOT = "[START]" EOT = "[STOP]" UNK = "[UNK]" SPACE = "[SPACE]" SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"] logger = logging.getLogger(__name__) class EnTokenizer: def __init__(self, vocab_file_path): self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path) 
self.check_vocabset_sot_eot() def check_vocabset_sot_eot(self): voc = self.tokenizer.get_vocab() assert SOT in voc assert EOT in voc def text_to_tokens(self, text: str): text_tokens = self.encode(text) text_tokens = torch.IntTensor(text_tokens).unsqueeze(0) return text_tokens def encode(self, txt: str): """ clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer """ txt = txt.replace(' ', SPACE) code = self.tokenizer.encode(txt) ids = code.ids return ids def decode(self, seq): if isinstance(seq, torch.Tensor): seq = seq.cpu().numpy() txt: str = self.tokenizer.decode(seq, skip_special_tokens=False) txt = txt.replace(' ', '') txt = txt.replace(SPACE, ' ') txt = txt.replace(EOT, '') txt = txt.replace(UNK, '') return txt # Model repository REPO_ID = "ResembleAI/chatterbox" # Global instances for optional dependencies _kakasi = None _dicta = None _russian_stresser = None def is_kanji(c: str) -> bool: """Check if character is kanji.""" return 19968 <= ord(c) <= 40959 def is_katakana(c: str) -> bool: """Check if character is katakana.""" return 12449 <= ord(c) <= 12538 def hiragana_normalize(text: str) -> str: """Japanese text normalization: converts kanji to hiragana; katakana remains the same.""" global _kakasi try: if _kakasi is None: import pykakasi _kakasi = pykakasi.kakasi() result = _kakasi.convert(text) out = [] for r in result: inp = r['orig'] hira = r["hira"] # Any kanji in the phrase if any([is_kanji(c) for c in inp]): if hira and hira[0] in ["は", "へ"]: # Safety check for empty hira hira = " " + hira out.append(hira) # All katakana elif all([is_katakana(c) for c in inp]) if inp else False: # Safety check for empty inp out.append(r['orig']) else: out.append(inp) normalized_text = "".join(out) # Decompose Japanese characters for tokenizer compatibility import unicodedata normalized_text = unicodedata.normalize('NFKD', normalized_text) return normalized_text except ImportError: logger.warning("pykakasi not available - Japanese text processing 
skipped") return text def add_hebrew_diacritics(text: str) -> str: """Hebrew text normalization: adds diacritics to Hebrew text.""" global _dicta try: if _dicta is None: from dicta_onnx import Dicta _dicta = Dicta() return _dicta.add_diacritics(text) except ImportError: logger.warning("dicta_onnx not available - Hebrew text processing skipped") return text except Exception as e: logger.warning(f"Hebrew diacritization failed: {e}") return text def korean_normalize(text: str) -> str: """Korean text normalization: decompose syllables into Jamo for tokenization.""" def decompose_hangul(char): """Decompose Korean syllable into Jamo components.""" if not ('\uac00' <= char <= '\ud7af'): return char # Hangul decomposition formula base = ord(char) - 0xAC00 initial = chr(0x1100 + base // (21 * 28)) medial = chr(0x1161 + (base % (21 * 28)) // 28) final = chr(0x11A7 + base % 28) if base % 28 > 0 else '' return initial + medial + final # Decompose syllables and normalize punctuation result = ''.join(decompose_hangul(char) for char in text) return result.strip() class ChineseCangjieConverter: """Converts Chinese characters to Cangjie codes for tokenization.""" def __init__(self, model_dir=None): self.word2cj = {} self.cj2word = {} self.segmenter = None self._load_cangjie_mapping(model_dir) self._init_segmenter() def _load_cangjie_mapping(self, model_dir=None): """Load Cangjie mapping from HuggingFace model repository.""" try: cangjie_file = hf_hub_download( repo_id=REPO_ID, filename="Cangjie5_TC.json", cache_dir=model_dir ) with open(cangjie_file, "r", encoding="utf-8") as fp: data = json.load(fp) for entry in data: word, code = entry.split("\t")[:2] self.word2cj[word] = code if code not in self.cj2word: self.cj2word[code] = [word] else: self.cj2word[code].append(word) except Exception as e: logger.warning(f"Could not load Cangjie mapping: {e}") def _init_segmenter(self): """Initialize pkuseg segmenter.""" try: from spacy_pkuseg import pkuseg self.segmenter = pkuseg() except 
ImportError: logger.warning("pkuseg not available - Chinese segmentation will be skipped") self.segmenter = None def _cangjie_encode(self, glyph: str): """Encode a single Chinese glyph to Cangjie code.""" normed_glyph = glyph code = self.word2cj.get(normed_glyph, None) if code is None: # e.g. Japanese hiragana return None index = self.cj2word[code].index(normed_glyph) index = str(index) if index > 0 else "" return code + str(index) def __call__(self, text): """Convert Chinese characters in text to Cangjie tokens.""" output = [] if self.segmenter is not None: segmented_words = self.segmenter.cut(text) full_text = " ".join(segmented_words) else: full_text = text for t in full_text: if category(t) == "Lo": cangjie = self._cangjie_encode(t) if cangjie is None: output.append(t) continue code = [] for c in cangjie: code.append(f"[cj_{c}]") code.append("[cj_.]") code = "".join(code) output.append(code) else: output.append(t) return "".join(output) def add_russian_stress(text: str) -> str: """Russian text normalization: adds stress marks to Russian text.""" global _russian_stresser try: if _russian_stresser is None: from russian_text_stresser.text_stresser import RussianTextStresser _russian_stresser = RussianTextStresser() return _russian_stresser.stress_text(text) except ImportError: logger.warning("russian_text_stresser not available - Russian stress labeling skipped") return text except Exception as e: logger.warning(f"Russian stress labeling failed: {e}") return text class MTLTokenizer: def __init__(self, vocab_file_path): self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path) model_dir = Path(vocab_file_path).parent self.cangjie_converter = ChineseCangjieConverter(model_dir) self.check_vocabset_sot_eot() def check_vocabset_sot_eot(self): voc = self.tokenizer.get_vocab() assert SOT in voc assert EOT in voc def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True): """ Text preprocessor that handles 
lowercase conversion and NFKD normalization. """ preprocessed_text = raw_text if lowercase: preprocessed_text = preprocessed_text.lower() if nfkd_normalize: preprocessed_text = normalize("NFKD", preprocessed_text) return preprocessed_text def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True): text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize) text_tokens = torch.IntTensor(text_tokens).unsqueeze(0) return text_tokens def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True): txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize) # Language-specific text processing if language_id == 'zh': txt = self.cangjie_converter(txt) elif language_id == 'ja': txt = hiragana_normalize(txt) elif language_id == 'he': txt = add_hebrew_diacritics(txt) elif language_id == 'ko': txt = korean_normalize(txt) elif language_id == 'ru': txt = add_russian_stress(txt) # Prepend language token if language_id: txt = f"[{language_id.lower()}]{txt}" txt = txt.replace(' ', SPACE) return self.tokenizer.encode(txt).ids def decode(self, seq): if isinstance(seq, torch.Tensor): seq = seq.cpu().numpy() txt = self.tokenizer.decode(seq, skip_special_tokens=False) txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '') return txt ================================================ FILE: src/chatterbox/models/utils.py ================================================ class AttrDict(dict): def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) self.__dict__ = self ================================================ FILE: src/chatterbox/models/voice_encoder/__init__.py ================================================ from .voice_encoder import VoiceEncoder, VoiceEncConfig ================================================ FILE: 
src/chatterbox/models/voice_encoder/config.py ================================================ class VoiceEncConfig: num_mels = 40 sample_rate = 16000 speaker_embed_size = 256 ve_hidden_size = 256 flatten_lstm_params = False n_fft = 400 hop_size = 160 win_size = 400 fmax = 8000 fmin = 0 preemphasis = 0. mel_power = 2.0 mel_type = "amp" normalized_mels = False ve_partial_frames = 160 ve_final_relu = True stft_magnitude_min = 1e-4 ================================================ FILE: src/chatterbox/models/voice_encoder/melspec.py ================================================ from functools import lru_cache from scipy import signal import numpy as np import librosa @lru_cache() def mel_basis(hp): assert hp.fmax <= hp.sample_rate // 2 return librosa.filters.mel( sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin, fmax=hp.fmax) # -> (nmel, nfreq) def preemphasis(wav, hp): assert hp.preemphasis != 0 wav = signal.lfilter([1, -hp.preemphasis], [1], wav) wav = np.clip(wav, -1, 1) return wav def melspectrogram(wav, hp, pad=True): # Run through pre-emphasis if hp.preemphasis > 0: wav = preemphasis(wav, hp) assert np.abs(wav).max() - 1 < 1e-07 # Do the stft spec_complex = _stft(wav, hp, pad=pad) # Get the magnitudes spec_magnitudes = np.abs(spec_complex) if hp.mel_power != 1.0: spec_magnitudes **= hp.mel_power # Get the mel and convert magnitudes->db mel = np.dot(mel_basis(hp), spec_magnitudes) if hp.mel_type == "db": mel = _amp_to_db(mel, hp) # Normalise the mel from db to 0,1 if hp.normalized_mels: mel = _normalize(mel, hp).astype(np.float32) assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check return mel # (M, T) def _stft(y, hp, pad=True): # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for # historical consistency and streaming-version consistency return librosa.stft( y, n_fft=hp.n_fft, hop_length=hp.hop_size, win_length=hp.win_size, center=pad, pad_mode="reflect", ) def _amp_to_db(x, hp): return 20 * 
np.log10(np.maximum(hp.stft_magnitude_min, x)) def _db_to_amp(x): return np.power(10.0, x * 0.05) def _normalize(s, hp, headroom_db=15): min_level_db = 20 * np.log10(hp.stft_magnitude_min) s = (s - min_level_db) / (-min_level_db + headroom_db) return s ================================================ FILE: src/chatterbox/models/voice_encoder/voice_encoder.py ================================================ # Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning # MIT License from typing import List, Union, Optional import numpy as np from numpy.lib.stride_tricks import as_strided import librosa import torch import torch.nn.functional as F from torch import nn, Tensor from .config import VoiceEncConfig from .melspec import melspectrogram def pack(arrays, seq_len: int=None, pad_value=0): """ Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of shape (B, T, ...) by padding each individual array on the right. :param arrays: a list of array-like objects of matching shapes except for the first axis. :param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at minimum. Will default to that value if None. :param pad_value: the value to pad the arrays with. :return: a (B, T, ...) 
tensor """ if seq_len is None: seq_len = max(len(array) for array in arrays) else: assert seq_len >= max(len(array) for array in arrays) # Convert lists to np.array if isinstance(arrays[0], list): arrays = [np.array(array) for array in arrays] # Convert to tensor and handle device device = None if isinstance(arrays[0], torch.Tensor): tensors = arrays device = tensors[0].device else: tensors = [torch.as_tensor(array) for array in arrays] # Fill the packed tensor with the array data packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:]) packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device) for i, tensor in enumerate(tensors): packed_tensor[i, :tensor.size(0)] = tensor return packed_tensor def get_num_wins( n_frames: int, step: int, min_coverage: float, hp: VoiceEncConfig, ): assert n_frames > 0 win_size = hp.ve_partial_frames n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step) if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage: n_wins += 1 target_n = win_size + step * (n_wins - 1) return n_wins, target_n def get_frame_step( overlap: float, rate: float, hp: VoiceEncConfig, ): # Compute how many frames separate two partial utterances assert 0 <= overlap < 1 if rate is None: frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap))) else: frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames)) assert 0 < frame_step <= hp.ve_partial_frames return frame_step def stride_as_partials( mel: np.ndarray, hp: VoiceEncConfig, overlap=0.5, rate: float=None, min_coverage=0.8, ): """ Takes unscaled mels in (T, M) format TODO: doc """ assert 0 < min_coverage <= 1 frame_step = get_frame_step(overlap, rate, hp) # Compute how many partials can fit in the mel n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp) # Trim or pad the mel spectrogram to match the number of partials if target_len > len(mel): mel = np.concatenate((mel, np.full((target_len - 
len(mel), hp.num_mels), 0))) elif target_len < len(mel): mel = mel[:target_len] # Ensure the numpy array data is float32 and contiguous in memory mel = mel.astype(np.float32, order="C") # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping eachother, # where N is the number of partials, P is the number of frames of each partial and M the # number of channels of the mel spectrograms. shape = (n_partials, hp.ve_partial_frames, hp.num_mels) strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1]) partials = as_strided(mel, shape, strides) return partials class VoiceEncoder(nn.Module): def __init__(self, hp=VoiceEncConfig()): super().__init__() self.hp = hp # Network definition self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True) if hp.flatten_lstm_params: self.lstm.flatten_parameters() self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size) # Cosine similarity scaling (fixed initial parameter values) self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True) self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True) @property def device(self): return next(self.parameters()).device def forward(self, mels: torch.FloatTensor): """ Computes the embeddings of a batch of partial utterances. :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor of shape (B, T, M) where T is hp.ve_partial_frames :return: the embeddings as a float32 tensor of shape (B, E) where E is hp.speaker_embed_size. Embeddings are L2-normed and thus lay in the range [-1, 1]. """ if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1): raise Exception(f"Mels outside [0, 1]. 
Min={mels.min()}, Max={mels.max()}") # Pass the input through the LSTM layers _, (hidden, _) = self.lstm(mels) # Project the final hidden state raw_embeds = self.proj(hidden[-1]) if self.hp.ve_final_relu: raw_embeds = F.relu(raw_embeds) # L2 normalize the embeddings. return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None): """ Computes the embeddings of a batch of full utterances with gradients. :param mels: (B, T, M) unscaled mels :return: (B, E) embeddings on CPU """ mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens # Compute where to split the utterances into partials frame_step = get_frame_step(overlap, rate, self.hp) n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens)) # Possibly pad the mels to reach the target lengths len_diff = max(target_lens) - mels.size(1) if len_diff > 0: pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32) mels = torch.cat((mels, pad.to(mels.device)), dim=1) # Group all partials together so that we can batch them easily partials = [ mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames] for mel, n_partial in zip(mels, n_partials) for i in range(n_partial) ] assert all(partials[0].shape == partial.shape for partial in partials) partials = torch.stack(partials) # Forward the partials n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials)))) partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu() # Reduce the partial embeds into full embeds and L2-normalize them slices = np.concatenate(([0], np.cumsum(n_partials))) raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])] raw_embeds = torch.stack(raw_embeds) embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) return embeds 
@staticmethod def utt_to_spk_embed(utt_embeds: np.ndarray): """ Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a speaker embedding. """ assert utt_embeds.ndim == 2 utt_embeds = np.mean(utt_embeds, axis=0) return utt_embeds / np.linalg.norm(utt_embeds, 2) @staticmethod def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray): """ Cosine similarity for L2-normalized utterance embeddings or speaker embeddings """ embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x) embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y) return embeds_x @ embeds_y def embeds_from_mels( self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs ): """ Convenience function for deriving utterance or speaker embeddings from mel spectrograms. :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays. 
:param mel_lens: if passing mels as a tensor, individual mel lengths :param as_spk: whether to return utterance embeddings or a single speaker embedding :param kwargs: args for inference() :returns: embeds as a (B, E) float32 numpy array if is False, else as a (E,) array """ # Load mels in memory and pack them if isinstance(mels, List): mels = [np.asarray(mel) for mel in mels] assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format" mel_lens = [mel.shape[0] for mel in mels] mels = pack(mels) # Embed them with torch.inference_mode(): utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy() return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds def embeds_from_wavs( self, wavs: List[np.ndarray], sample_rate, as_spk=False, batch_size=32, trim_top_db: Optional[float]=20, **kwargs ): """ Wrapper around embeds_from_mels :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation """ if sample_rate != self.hp.sample_rate: wavs = [ librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast") for wav in wavs ] if trim_top_db: wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs] if "rate" not in kwargs: kwargs["rate"] = 1.3 # Resemble's default value. 
mels = [melspectrogram(w, self.hp).T for w in wavs] return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs) ================================================ FILE: src/chatterbox/mtl_tts.py ================================================ from dataclasses import dataclass from pathlib import Path import os import librosa import torch import perth import torch.nn.functional as F from safetensors.torch import load_file as load_safetensors from huggingface_hub import snapshot_download from .models.t3 import T3 from .models.t3.modules.t3_config import T3Config from .models.s3tokenizer import S3_SR, drop_invalid_tokens from .models.s3gen import S3GEN_SR, S3Gen from .models.tokenizers import MTLTokenizer from .models.voice_encoder import VoiceEncoder from .models.t3.modules.cond_enc import T3Cond REPO_ID = "ResembleAI/chatterbox" # Supported languages for the multilingual model SUPPORTED_LANGUAGES = { "ar": "Arabic", "da": "Danish", "de": "German", "el": "Greek", "en": "English", "es": "Spanish", "fi": "Finnish", "fr": "French", "he": "Hebrew", "hi": "Hindi", "it": "Italian", "ja": "Japanese", "ko": "Korean", "ms": "Malay", "nl": "Dutch", "no": "Norwegian", "pl": "Polish", "pt": "Portuguese", "ru": "Russian", "sv": "Swedish", "sw": "Swahili", "tr": "Turkish", "zh": "Chinese", } def punc_norm(text: str) -> str: """ Quick cleanup func for punctuation from LLMs or containing chars not seen often in the dataset """ if len(text) == 0: return "You need to add some text for me to talk." 
# Capitalise first letter if text[0].islower(): text = text[0].upper() + text[1:] # Remove multiple space chars text = " ".join(text.split()) # Replace uncommon/llm punc punc_to_replace = [ ("...", ", "), ("…", ", "), (":", ","), (" - ", ", "), (";", ", "), ("—", "-"), ("–", "-"), (" ,", ","), ("“", "\""), ("”", "\""), ("‘", "'"), ("’", "'"), ] for old_char_sequence, new_char in punc_to_replace: text = text.replace(old_char_sequence, new_char) # Add full stop if no ending punc text = text.rstrip(" ") sentence_enders = {".", "!", "?", "-", ",","、",",","。","?","!"} if not any(text.endswith(p) for p in sentence_enders): text += "." return text @dataclass class Conditionals: """ Conditionals for T3 and S3Gen - T3 conditionals: - speaker_emb - clap_emb - cond_prompt_speech_tokens - cond_prompt_speech_emb - emotion_adv - S3Gen conditionals: - prompt_token - prompt_token_len - prompt_feat - prompt_feat_len - embedding """ t3: T3Cond gen: dict def to(self, device): self.t3 = self.t3.to(device=device) for k, v in self.gen.items(): if torch.is_tensor(v): self.gen[k] = v.to(device=device) return self def save(self, fpath: Path): arg_dict = dict( t3=self.t3.__dict__, gen=self.gen ) torch.save(arg_dict, fpath) @classmethod def load(cls, fpath, map_location="cpu"): if isinstance(map_location, str): map_location = torch.device(map_location) kwargs = torch.load(fpath, map_location=map_location, weights_only=True) return cls(T3Cond(**kwargs['t3']), kwargs['gen']) class ChatterboxMultilingualTTS: ENC_COND_LEN = 6 * S3_SR DEC_COND_LEN = 10 * S3GEN_SR def __init__( self, t3: T3, s3gen: S3Gen, ve: VoiceEncoder, tokenizer: MTLTokenizer, device: str, conds: Conditionals = None, ): self.sr = S3GEN_SR # sample rate of synthesized audio self.t3 = t3 self.s3gen = s3gen self.ve = ve self.tokenizer = tokenizer self.device = device self.conds = conds self.watermarker = perth.PerthImplicitWatermarker() @classmethod def get_supported_languages(cls): """Return dictionary of supported language 
codes and names.""" return SUPPORTED_LANGUAGES.copy() @classmethod def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS': ckpt_dir = Path(ckpt_dir) # Always load to CPU first for non-CUDA devices to handle CUDA-saved models if device in ["cpu", "mps"]: map_location = torch.device('cpu') else: map_location = None ve = VoiceEncoder() ve.load_state_dict( torch.load(ckpt_dir / "ve.pt", map_location=map_location, weights_only=True) ) ve.to(device).eval() t3 = T3(T3Config.multilingual()) t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors") if "model" in t3_state.keys(): t3_state = t3_state["model"][0] t3.load_state_dict(t3_state) t3.to(device).eval() s3gen = S3Gen() s3gen.load_state_dict( torch.load(ckpt_dir / "s3gen.pt", map_location=map_location, weights_only=True) ) s3gen.to(device).eval() tokenizer = MTLTokenizer( str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json") ) conds = None if (builtin_voice := ckpt_dir / "conds.pt").exists(): conds = Conditionals.load(builtin_voice, map_location=map_location).to(device) return cls(t3, s3gen, ve, tokenizer, device, conds=conds) @classmethod def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS': # Check if MPS is available on macOS if device == "mps" and not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): print("MPS not available because the current PyTorch install was not built with MPS enabled.") else: print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") device = "cpu" ckpt_dir = Path( snapshot_download( repo_id=REPO_ID, repo_type="model", revision="main", allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"], token=os.getenv("HF_TOKEN"), ) ) return cls.from_local(ckpt_dir, device) def prepare_conditionals(self, wav_fpath, exaggeration=0.5): ## Load reference wav s3gen_ref_wav, _sr 
= librosa.load(wav_fpath, sr=S3GEN_SR) ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR) s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) # Speech cond prompt tokens t3_cond_prompt_tokens = None if plen := self.t3.hp.speech_cond_prompt_len: s3_tokzr = self.s3gen.tokenizer t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen) t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device) # Voice-encoder speaker embedding ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR)) ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device) t3_cond = T3Cond( speaker_emb=ve_embed, cond_prompt_speech_tokens=t3_cond_prompt_tokens, emotion_adv=exaggeration * torch.ones(1, 1, 1), ).to(device=self.device) self.conds = Conditionals(t3_cond, s3gen_ref_dict) def generate( self, text, language_id, audio_prompt_path=None, exaggeration=0.5, cfg_weight=0.5, temperature=0.8, repetition_penalty=2.0, min_p=0.05, top_p=1.0, ): # Validate language_id if language_id and language_id.lower() not in SUPPORTED_LANGUAGES: supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys()) raise ValueError( f"Unsupported language_id '{language_id}'. 
" f"Supported languages: {supported_langs}" ) if audio_prompt_path: self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration) else: assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`" # Update exaggeration if needed if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()): _cond: T3Cond = self.conds.t3 self.conds.t3 = T3Cond( speaker_emb=_cond.speaker_emb, cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens, emotion_adv=exaggeration * torch.ones(1, 1, 1), ).to(device=self.device) # Norm and tokenize text text = punc_norm(text) text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device) text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG sot = self.t3.hp.start_text_token eot = self.t3.hp.stop_text_token text_tokens = F.pad(text_tokens, (1, 0), value=sot) text_tokens = F.pad(text_tokens, (0, 1), value=eot) with torch.inference_mode(): speech_tokens = self.t3.inference( t3_cond=self.conds.t3, text_tokens=text_tokens, max_new_tokens=1000, # TODO: use the value in config temperature=temperature, cfg_weight=cfg_weight, repetition_penalty=repetition_penalty, min_p=min_p, top_p=top_p, ) # Extract only the conditional batch. 
speech_tokens = speech_tokens[0] # TODO: output becomes 1D speech_tokens = drop_invalid_tokens(speech_tokens) speech_tokens = speech_tokens.to(self.device) wav, _ = self.s3gen.inference( speech_tokens=speech_tokens, ref_dict=self.conds.gen, ) wav = wav.squeeze(0).detach().cpu().numpy() watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) return torch.from_numpy(watermarked_wav).unsqueeze(0) ================================================ FILE: src/chatterbox/tts.py ================================================ from dataclasses import dataclass from pathlib import Path import librosa import torch import perth import torch.nn.functional as F from huggingface_hub import hf_hub_download from safetensors.torch import load_file from .models.t3 import T3 from .models.s3tokenizer import S3_SR, drop_invalid_tokens from .models.s3gen import S3GEN_SR, S3Gen from .models.tokenizers import EnTokenizer from .models.voice_encoder import VoiceEncoder from .models.t3.modules.cond_enc import T3Cond REPO_ID = "ResembleAI/chatterbox" def punc_norm(text: str) -> str: """ Quick cleanup func for punctuation from LLMs or containing chars not seen often in the dataset """ if len(text) == 0: return "You need to add some text for me to talk." # Capitalise first letter if text[0].islower(): text = text[0].upper() + text[1:] # Remove multiple space chars text = " ".join(text.split()) # Replace uncommon/llm punc punc_to_replace = [ ("...", ", "), ("…", ", "), (":", ","), (" - ", ", "), (";", ", "), ("—", "-"), ("–", "-"), (" ,", ","), ("“", "\""), ("”", "\""), ("‘", "'"), ("’", "'"), ] for old_char_sequence, new_char in punc_to_replace: text = text.replace(old_char_sequence, new_char) # Add full stop if no ending punc text = text.rstrip(" ") sentence_enders = {".", "!", "?", "-", ","} if not any(text.endswith(p) for p in sentence_enders): text += "." 
return text @dataclass class Conditionals: """ Conditionals for T3 and S3Gen - T3 conditionals: - speaker_emb - clap_emb - cond_prompt_speech_tokens - cond_prompt_speech_emb - emotion_adv - S3Gen conditionals: - prompt_token - prompt_token_len - prompt_feat - prompt_feat_len - embedding """ t3: T3Cond gen: dict def to(self, device): self.t3 = self.t3.to(device=device) for k, v in self.gen.items(): if torch.is_tensor(v): self.gen[k] = v.to(device=device) return self def save(self, fpath: Path): arg_dict = dict( t3=self.t3.__dict__, gen=self.gen ) torch.save(arg_dict, fpath) @classmethod def load(cls, fpath, map_location="cpu"): if isinstance(map_location, str): map_location = torch.device(map_location) kwargs = torch.load(fpath, map_location=map_location, weights_only=True) return cls(T3Cond(**kwargs['t3']), kwargs['gen']) class ChatterboxTTS: ENC_COND_LEN = 6 * S3_SR DEC_COND_LEN = 10 * S3GEN_SR def __init__( self, t3: T3, s3gen: S3Gen, ve: VoiceEncoder, tokenizer: EnTokenizer, device: str, conds: Conditionals = None, ): self.sr = S3GEN_SR # sample rate of synthesized audio self.t3 = t3 self.s3gen = s3gen self.ve = ve self.tokenizer = tokenizer self.device = device self.conds = conds self.watermarker = perth.PerthImplicitWatermarker() @classmethod def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS': ckpt_dir = Path(ckpt_dir) # Always load to CPU first for non-CUDA devices to handle CUDA-saved models if device in ["cpu", "mps"]: map_location = torch.device('cpu') else: map_location = None ve = VoiceEncoder() ve.load_state_dict( load_file(ckpt_dir / "ve.safetensors") ) ve.to(device).eval() t3 = T3() t3_state = load_file(ckpt_dir / "t3_cfg.safetensors") if "model" in t3_state.keys(): t3_state = t3_state["model"][0] t3.load_state_dict(t3_state) t3.to(device).eval() s3gen = S3Gen() s3gen.load_state_dict( load_file(ckpt_dir / "s3gen.safetensors"), strict=False ) s3gen.to(device).eval() tokenizer = EnTokenizer( str(ckpt_dir / "tokenizer.json") ) conds = None if 
(builtin_voice := ckpt_dir / "conds.pt").exists(): conds = Conditionals.load(builtin_voice, map_location=map_location).to(device) return cls(t3, s3gen, ve, tokenizer, device, conds=conds) @classmethod def from_pretrained(cls, device) -> 'ChatterboxTTS': # Check if MPS is available on macOS if device == "mps" and not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): print("MPS not available because the current PyTorch install was not built with MPS enabled.") else: print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") device = "cpu" for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]: local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) return cls.from_local(Path(local_path).parent, device) def prepare_conditionals(self, wav_fpath, exaggeration=0.5): ## Load reference wav s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR) s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) # Speech cond prompt tokens if plen := self.t3.hp.speech_cond_prompt_len: s3_tokzr = self.s3gen.tokenizer t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen) t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device) # Voice-encoder speaker embedding ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR)) ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device) t3_cond = T3Cond( speaker_emb=ve_embed, cond_prompt_speech_tokens=t3_cond_prompt_tokens, emotion_adv=exaggeration * torch.ones(1, 1, 1), ).to(device=self.device) self.conds = Conditionals(t3_cond, s3gen_ref_dict) def generate( self, text, repetition_penalty=1.2, min_p=0.05, top_p=1.0, audio_prompt_path=None, 
exaggeration=0.5, cfg_weight=0.5, temperature=0.8, ): if audio_prompt_path: self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration) else: assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`" # Update exaggeration if needed if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]: _cond: T3Cond = self.conds.t3 self.conds.t3 = T3Cond( speaker_emb=_cond.speaker_emb, cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens, emotion_adv=exaggeration * torch.ones(1, 1, 1), ).to(device=self.device) # Norm and tokenize text text = punc_norm(text) text_tokens = self.tokenizer.text_to_tokens(text).to(self.device) if cfg_weight > 0.0: text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG sot = self.t3.hp.start_text_token eot = self.t3.hp.stop_text_token text_tokens = F.pad(text_tokens, (1, 0), value=sot) text_tokens = F.pad(text_tokens, (0, 1), value=eot) with torch.inference_mode(): speech_tokens = self.t3.inference( t3_cond=self.conds.t3, text_tokens=text_tokens, max_new_tokens=1000, # TODO: use the value in config temperature=temperature, cfg_weight=cfg_weight, repetition_penalty=repetition_penalty, min_p=min_p, top_p=top_p, ) # Extract only the conditional batch. 
speech_tokens = speech_tokens[0] # TODO: output becomes 1D speech_tokens = drop_invalid_tokens(speech_tokens) speech_tokens = speech_tokens[speech_tokens < 6561] speech_tokens = speech_tokens.to(self.device) wav, _ = self.s3gen.inference( speech_tokens=speech_tokens, ref_dict=self.conds.gen, ) wav = wav.squeeze(0).detach().cpu().numpy() watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) return torch.from_numpy(watermarked_wav).unsqueeze(0) ================================================ FILE: src/chatterbox/tts_turbo.py ================================================ import os import math from dataclasses import dataclass from pathlib import Path import librosa import torch import perth import pyloudnorm as ln from safetensors.torch import load_file from huggingface_hub import snapshot_download from transformers import AutoTokenizer from .models.t3 import T3 from .models.s3tokenizer import S3_SR from .models.s3gen import S3GEN_SR, S3Gen from .models.tokenizers import EnTokenizer from .models.voice_encoder import VoiceEncoder from .models.t3.modules.cond_enc import T3Cond from .models.t3.modules.t3_config import T3Config from .models.s3gen.const import S3GEN_SIL import logging logger = logging.getLogger(__name__) REPO_ID = "ResembleAI/chatterbox-turbo" def punc_norm(text: str) -> str: """ Quick cleanup func for punctuation from LLMs or containing chars not seen often in the dataset """ if len(text) == 0: return "You need to add some text for me to talk." 
# Capitalise first letter if text[0].islower(): text = text[0].upper() + text[1:] # Remove multiple space chars text = " ".join(text.split()) # Replace uncommon/llm punc punc_to_replace = [ ("…", ", "), (":", ","), ("—", "-"), ("–", "-"), (" ,", ","), ("“", "\""), ("”", "\""), ("‘", "'"), ("’", "'"), ] for old_char_sequence, new_char in punc_to_replace: text = text.replace(old_char_sequence, new_char) # Add full stop if no ending punc text = text.rstrip(" ") sentence_enders = {".", "!", "?", "-", ","} if not any(text.endswith(p) for p in sentence_enders): text += "." return text @dataclass class Conditionals: """ Conditionals for T3 and S3Gen - T3 conditionals: - speaker_emb - clap_emb - cond_prompt_speech_tokens - cond_prompt_speech_emb - emotion_adv - S3Gen conditionals: - prompt_token - prompt_token_len - prompt_feat - prompt_feat_len - embedding """ t3: T3Cond gen: dict def to(self, device): self.t3 = self.t3.to(device=device) for k, v in self.gen.items(): if torch.is_tensor(v): self.gen[k] = v.to(device=device) return self def save(self, fpath: Path): arg_dict = dict( t3=self.t3.__dict__, gen=self.gen ) torch.save(arg_dict, fpath) @classmethod def load(cls, fpath, map_location="cpu"): if isinstance(map_location, str): map_location = torch.device(map_location) kwargs = torch.load(fpath, map_location=map_location, weights_only=True) return cls(T3Cond(**kwargs['t3']), kwargs['gen']) class ChatterboxTurboTTS: ENC_COND_LEN = 15 * S3_SR DEC_COND_LEN = 10 * S3GEN_SR def __init__( self, t3: T3, s3gen: S3Gen, ve: VoiceEncoder, tokenizer: EnTokenizer, device: str, conds: Conditionals = None, ): self.sr = S3GEN_SR # sample rate of synthesized audio self.t3 = t3 self.s3gen = s3gen self.ve = ve self.tokenizer = tokenizer self.device = device self.conds = conds self.watermarker = perth.PerthImplicitWatermarker() @classmethod def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS': ckpt_dir = Path(ckpt_dir) # Always load to CPU first for non-CUDA devices to handle 
CUDA-saved models if device in ["cpu", "mps"]: map_location = torch.device('cpu') else: map_location = None ve = VoiceEncoder() ve.load_state_dict( load_file(ckpt_dir / "ve.safetensors") ) ve.to(device).eval() # Turbo specific hp hp = T3Config(text_tokens_dict_size=50276) hp.llama_config_name = "GPT2_medium" hp.speech_tokens_dict_size = 6563 hp.input_pos_emb = None hp.speech_cond_prompt_len = 375 hp.use_perceiver_resampler = False hp.emotion_adv = False t3 = T3(hp) t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors") if "model" in t3_state.keys(): t3_state = t3_state["model"][0] t3.load_state_dict(t3_state) del t3.tfmr.wte t3.to(device).eval() s3gen = S3Gen(meanflow=True) weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors") s3gen.load_state_dict( weights, strict=True ) s3gen.to(device).eval() tokenizer = AutoTokenizer.from_pretrained(ckpt_dir) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token if len(tokenizer) != 50276: print(f"WARNING: Tokenizer len {len(tokenizer)} != 50276") conds = None builtin_voice = ckpt_dir / "conds.pt" if builtin_voice.exists(): conds = Conditionals.load(builtin_voice, map_location=map_location).to(device) return cls(t3, s3gen, ve, tokenizer, device, conds=conds) @classmethod def from_pretrained(cls, device) -> 'ChatterboxTurboTTS': # Check if MPS is available on macOS if device == "mps" and not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): print("MPS not available because the current PyTorch install was not built with MPS enabled.") else: print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") device = "cpu" local_path = snapshot_download( repo_id=REPO_ID, token=os.getenv("HF_TOKEN") or None, # Optional: Filter to download only what you need allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"] ) return cls.from_local(local_path, device) def norm_loudness(self, wav, sr, 
target_lufs=-27): try: meter = ln.Meter(sr) loudness = meter.integrated_loudness(wav) gain_db = target_lufs - loudness gain_linear = 10.0 ** (gain_db / 20.0) if math.isfinite(gain_linear) and gain_linear > 0.0: wav = wav * gain_linear except Exception as e: print(f"Warning: Error in norm_loudness, skipping: {e}") return wav def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True): ## Load and norm reference wav s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!" if norm_loudness: s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr) ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR) s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) # Speech cond prompt tokens if plen := self.t3.hp.speech_cond_prompt_len: s3_tokzr = self.s3gen.tokenizer t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen) t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device) # Voice-encoder speaker embedding ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR)) ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device) t3_cond = T3Cond( speaker_emb=ve_embed, cond_prompt_speech_tokens=t3_cond_prompt_tokens, emotion_adv=exaggeration * torch.ones(1, 1, 1), ).to(device=self.device) self.conds = Conditionals(t3_cond, s3gen_ref_dict) def generate( self, text, repetition_penalty=1.2, min_p=0.00, top_p=0.95, audio_prompt_path=None, exaggeration=0.0, cfg_weight=0.0, temperature=0.8, top_k=1000, norm_loudness=True, ): if audio_prompt_path: self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness) else: assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`" if cfg_weight > 0.0 or exaggeration > 0.0 
or min_p > 0.0: logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.") # Norm and tokenize text text = punc_norm(text) text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True) text_tokens = text_tokens.input_ids.to(self.device) speech_tokens = self.t3.inference_turbo( t3_cond=self.conds.t3, text_tokens=text_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, ) # Remove OOV tokens and add silence to end speech_tokens = speech_tokens[speech_tokens < 6561] speech_tokens = speech_tokens.to(self.device) silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device) speech_tokens = torch.cat([speech_tokens, silence]) wav, _ = self.s3gen.inference( speech_tokens=speech_tokens, ref_dict=self.conds.gen, n_cfm_timesteps=2, ) wav = wav.squeeze(0).detach().cpu().numpy() watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) return torch.from_numpy(watermarked_wav).unsqueeze(0) ================================================ FILE: src/chatterbox/vc.py ================================================ from pathlib import Path import librosa import torch import perth from huggingface_hub import hf_hub_download from safetensors.torch import load_file from .models.s3tokenizer import S3_SR from .models.s3gen import S3GEN_SR, S3Gen REPO_ID = "ResembleAI/chatterbox" class ChatterboxVC: ENC_COND_LEN = 6 * S3_SR DEC_COND_LEN = 10 * S3GEN_SR def __init__( self, s3gen: S3Gen, device: str, ref_dict: dict=None, ): self.sr = S3GEN_SR self.s3gen = s3gen self.device = device self.watermarker = perth.PerthImplicitWatermarker() if ref_dict is None: self.ref_dict = None else: self.ref_dict = { k: v.to(device) if torch.is_tensor(v) else v for k, v in ref_dict.items() } @classmethod def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC': ckpt_dir = Path(ckpt_dir) # Always load to CPU first for non-CUDA devices to handle CUDA-saved models 
if device in ["cpu", "mps"]: map_location = torch.device('cpu') else: map_location = None ref_dict = None if (builtin_voice := ckpt_dir / "conds.pt").exists(): states = torch.load(builtin_voice, map_location=map_location) ref_dict = states['gen'] s3gen = S3Gen() s3gen.load_state_dict( load_file(ckpt_dir / "s3gen.safetensors"), strict=False ) s3gen.to(device).eval() return cls(s3gen, device, ref_dict=ref_dict) @classmethod def from_pretrained(cls, device) -> 'ChatterboxVC': # Check if MPS is available on macOS if device == "mps" and not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): print("MPS not available because the current PyTorch install was not built with MPS enabled.") else: print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") device = "cpu" for fpath in ["s3gen.safetensors", "conds.pt"]: local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) return cls.from_local(Path(local_path).parent, device) def set_target_voice(self, wav_fpath): ## Load reference wav s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) def generate( self, audio, target_voice_path=None, ): if target_voice_path: self.set_target_voice(target_voice_path) else: assert self.ref_dict is not None, "Please `prepare_conditionals` first or specify `target_voice_path`" with torch.inference_mode(): audio_16, _ = librosa.load(audio, sr=S3_SR) audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ] s3_tokens, _ = self.s3gen.tokenizer(audio_16) wav, _ = self.s3gen.inference( speech_tokens=s3_tokens, ref_dict=self.ref_dict, ) wav = wav.squeeze(0).detach().cpu().numpy() watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) return torch.from_numpy(watermarked_wav).unsqueeze(0)