Repository: davidbrowne17/csm-streaming
Branch: main
Commit: 121552fcd68e
Files: 25
Total size: 262.0 KB
Directory structure:
gitextract_5x1lml4p/
├── .github/
│ └── FUNDING.yml
├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── finetuned_model/
│ └── config.json
├── generator.py
├── llm_interface.py
├── loadandmergecheckpoint.py
├── lora.py
├── main.py
├── models.py
├── rag_system.py
├── requirements.txt
├── run_csm.py
├── setup.py
├── static/
│ ├── app.js
│ ├── chat.js
│ └── crud.js
├── templates/
│ ├── chat.html
│ ├── crud.html
│ ├── index.html
│ └── setup.html
├── test.py
└── vad.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: davidbrowne17
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
================================================
FILE: .gitignore
================================================
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual Environment
.env
.venv
env/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# Project specific
.python-version
*.wav
output_*/
basic_audio.wav
full_conversation.wav
context_audio.wav
# Model files
*.pt
*.ckpt
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# CSM - Optimized Streaming/Finetuning Edition
---
CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.
Our fork adds **streaming audio generation**, **real-time playback**, and **performance optimizations** to the original implementation.
## Requirements
* A CUDA-compatible GPU
* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions
* Similarly, Python 3.10 is recommended, but newer versions may be fine
* For some audio operations, `ffmpeg` may be required
* For real-time audio playback: `pip install sounddevice`
* Access to the following Hugging Face models:
* [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
* [CSM-1B](https://huggingface.co/sesame/csm-1b)
### Setup
```bash
sudo apt-get update && sudo apt-get install -y libportaudio2 libportaudio-dev
git clone git@github.com:davidbrowne17/csm-streaming.git
cd csm-streaming
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
# Optional speedup
pip install flash-attn
# You will need access to CSM-1B and Llama-3.2-1B
huggingface-cli login
```
### Windows Setup
The `triton` package cannot be installed on Windows. Use `pip install triton-windows` instead.
The realtime demo uses vLLM for inference speed. vLLM is not currently supported on Windows, but you can try https://github.com/SystemPanic/vllm-windows until support is added.
## Quickstart
Generate a sentence with streaming (chunks are processed and output as they're generated):
```python
import time
from huggingface_hub import hf_hub_download
from generator import Generator, Segment, load_csm_1b, generate_streaming_audio
import torchaudio
# Load the model
generator = load_csm_1b("cuda")
# Generate audio with streaming and real-time playback
generate_streaming_audio(
generator=generator,
text="Hello, this is streaming audio generation in action!",
speaker=0,
context=[], # No context needed for basic generation
output_file="streaming_audio.wav",
play_audio=True # Enable real-time playback
)
```
## Finetuning
To finetune CSM, all you need are raw WAV files of the speaker voice you want to train on. Place them in a folder called `audio_data` and run `lora.py`.
You can configure the training parameters, such as batch size, number of epochs, and learning rate, by modifying the values at the top of `lora.py`.
You will need a CUDA GPU with at least 12 GB of VRAM, depending on your dataset size and training parameters. You can monitor training via the dynamic PNG created in the `/finetuned_model/` folder, which contains various graphs to help you track progress. If you want to try a checkpoint, use `loadandmergecheckpoint.py` (make sure to set the same R and alpha values that you used in training).
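A minimal workflow sketch, assuming the defaults at the top of `lora.py` (the WAV source path below is a placeholder):
```bash
# Put the raw speaker WAVs where lora.py expects them, then start finetuning.
mkdir -p audio_data
cp /path/to/your/voice/*.wav audio_data/
python lora.py
# Training metrics are plotted to a PNG inside finetuned_model/.
# To try an intermediate checkpoint, merge it with the same R/alpha used in training:
python loadandmergecheckpoint.py
```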
## Realtime chat demo
To use the realtime demo, run `setup.py` to download the required models, then run `main.py`. This opens a setup page at http://localhost:8000 where you can set the paths for your chosen LLM, configure the CSM paths and reference audio, and select your headset and mic. Once loaded, you can chat with the AI in real time, just like the CSM demo. Our demo includes a dynamic RAG system so the AI can remember previous conversations. By default, the demo uses whisper-large-v3-turbo for STT and includes automatic voice detection using Silero VAD.
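In practice that boils down to the following (a sketch; both scripts live in the repository root):
```bash
# One-time: download the models the realtime demo needs.
python setup.py
# Start the demo server, then open the setup page in a browser.
python main.py
# -> http://localhost:8000 (set LLM/CSM paths, reference audio, headset and mic)
```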
## Usage
Our optimized version offers several ways to use CSM with streaming capabilities:
### 1. Basic Streaming Generation
Generate audio with streaming and save to a file:
```python
from generator import load_csm_1b, generate_streaming_audio
generator = load_csm_1b("cuda")
# Generate with streaming (writes to file as it generates)
generate_streaming_audio(
generator=generator,
text="This audio will be generated in chunks for faster response times.",
speaker=0,
context=[],
output_file="streaming_output.wav"
)
```
### 2. Real-time Audio Playback
Generate and play audio in real-time as it's being generated:
```python
from generator import load_csm_1b, generate_streaming_audio
generator = load_csm_1b("cuda")
# Generate with streaming and play in real-time
generate_streaming_audio(
generator=generator,
text="You'll hear me speaking as I'm being generated!",
speaker=0,
context=[],
output_file="streaming_output.wav",
play_audio=True # Enable real-time playback
)
```
### 3. Low-level Streaming API
For more control, use the low-level streaming API:
```python
from generator import load_csm_1b, Segment
import torchaudio
generator = load_csm_1b("cuda")
# Process audio chunks as they're generated
for audio_chunk in generator.generate_stream(
text="This is generated chunk by chunk.",
speaker=0,
context=[]
):
# Do something with each chunk as it's generated
print(f"Received chunk of size: {audio_chunk.shape}")
# You could process or play each chunk here
# For example, write to a file incrementally
# Or send over a network connection
```
### 4. Generate with Context
For best results, provide reference audio context:
```python
from generator import load_csm_1b, Segment, generate_streaming_audio
import torchaudio
generator = load_csm_1b("cuda")
# Load reference audio
def load_audio(audio_path):
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = torchaudio.functional.resample(
audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
)
return audio_tensor
# Create context segments
segments = [
Segment(
text="I knew I could trust you.",
speaker=0,
audio=load_audio("reference.wav")
)
]
# Generate with streaming using the context
generate_streaming_audio(
generator=generator,
text="Me too, this is some cool stuff huh?",
speaker=0,
context=segments,
output_file="contextual_streaming.wav",
play_audio=True
)
```
### 5. Regular Generation with Internal Streaming
Use the original API with streaming enabled internally:
```python
from generator import load_csm_1b, Segment
import torchaudio
generator = load_csm_1b("cuda")
# Regular generation but with internal streaming optimization
audio = generator.generate(
text="This uses internal streaming for faster processing.",
speaker=0,
context=[],
max_audio_length_ms=10_000,
stream=True # Enable internal streaming optimization
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```
## Performance Optimizations
Our optimized version includes several performance enhancements:
- **Streaming Generation**: Processes and outputs audio in chunks instead of waiting for the entire generation to finish, achieving a real-time factor (RTF) of 0.28x (target: <1.0) on an RTX 4090, i.e. 10 seconds of audio takes about 2.8 seconds to generate (see the sketch after this list)
- **Frame Batching**: Processes multiple frames at once for better GPU utilization
- **Half-precision Inference**: Uses bfloat16/float16 for faster processing
- **CUDA Optimizations**: Enables cuDNN benchmarking and Flash Attention where available
- **Memory Management**: Clears GPU cache before generation to reduce memory pressure
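For reference, the real-time factor quoted above is just wall-clock generation time divided by the duration of audio produced; here is a quick sketch using the illustrative RTX 4090 numbers from the list:
```python
# Real-time factor (RTF) = wall-clock generation time / duration of audio produced.
# Values below 1.0 mean audio is generated faster than it plays back.
generation_time_s = 2.8    # illustrative figure from the RTX 4090 example above
audio_duration_s = 10.0
rtf = generation_time_s / audio_duration_s
print(f"RTF: {rtf:.2f}x")  # -> RTF: 0.28x
```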
## FAQ
**How much faster is the streaming version?**
The perceived response time is significantly faster since you get the first audio chunks in milliseconds instead of waiting for the entire generation to complete. The actual total generation time is also improved by 40-60% depending on your hardware.
**Does this model come with any voices?**
The model is a base generation model capable of producing a variety of voices but hasn't been fine-tuned on any specific voice. Provide reference audio for best results.
**Can I converse with the model?**
CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. Using a separate LLM, you can converse with the realtime demo via the web UI.
**Does it support other languages?**
The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.
## Misuse and abuse ⚠️
This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:
- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.
By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.
---
## Original Authors
Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.
## Streaming, Realtime Demo and Finetuning Implementation
David Browne
## Support me
Support this project on Ko-fi: https://ko-fi.com/davidbrowne17
## Transformers streaming
If you want to use streaming with the Transformers implementation you can find it here: https://github.com/davidbrowne17/csm-streaming-tf
================================================
FILE: config.py
================================================
import os
import json
import logging
from typing import Dict, Any, Optional
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class ConfigManager:
"""
Manages configuration persistence for the AI Companion app.
Saves and loads configuration to avoid re-entering model paths.
"""
def __init__(self, config_path: str = "config/app_config.json"):
"""
Initialize the configuration manager.
Args:
config_path: Path to store the configuration file
"""
self.config_path = config_path
self.config_dir = os.path.dirname(config_path)
# Create config directory if it doesn't exist
if not os.path.exists(self.config_dir):
os.makedirs(self.config_dir, exist_ok=True)
logger.info(f"Created configuration directory: {self.config_dir}")
def save_config(self, config_data: Dict[str, Any]) -> bool:
"""
Save configuration data to the config file.
Args:
config_data: Configuration data to save
Returns:
bool: True if successful, False otherwise
"""
try:
# Ensure directory exists
os.makedirs(self.config_dir, exist_ok=True)
print(config_data)
# Verify all reference paths are included
ref_paths = [
"reference_audio_path",
"reference_audio_path2",
"reference_audio_path3"
]
# Log which references are being saved
for path_key in ref_paths:
if path_key in config_data and config_data[path_key]:
logger.info(f"Saving reference path: {path_key}={config_data[path_key]}")
else:
logger.info(f"No {path_key} provided in configuration")
# Save configuration
with open(self.config_path, 'w') as f:
json.dump(config_data, f, indent=2)
logger.info(f"Configuration saved to {self.config_path}")
return True
except Exception as e:
logger.error(f"Failed to save configuration: {e}")
return False
def load_config(self) -> Optional[Dict[str, Any]]:
"""
Load configuration data from the config file.
Returns:
Dict or None: Configuration data if successful, None otherwise
"""
if not os.path.exists(self.config_path):
logger.info(f"Configuration file does not exist at {self.config_path}")
return None
try:
with open(self.config_path, 'r') as f:
config_data = json.load(f)
# Log which references are being loaded
ref_paths = [
"reference_audio_path",
"reference_audio_path2",
"reference_audio_path3"
]
for path_key in ref_paths:
if path_key in config_data and config_data[path_key]:
logger.info(f"Loaded reference path: {path_key}={config_data[path_key]}")
logger.info(f"Configuration loaded from {self.config_path}")
return config_data
except Exception as e:
logger.error(f"Failed to load configuration: {e}")
return None
# Helper function to convert Pydantic model to dict
def model_to_dict(model: BaseModel) -> Dict[str, Any]:
"""Convert a Pydantic model to a dictionary suitable for JSON serialization"""
return json.loads(model.json())
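# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): persist a config
# and read it back. The "reference_audio_path" key mirrors the reference paths
# handled above; the file path value is a hypothetical example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = ConfigManager()  # defaults to config/app_config.json
    manager.save_config({"reference_audio_path": "reference.wav"})
    print(manager.load_config())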
================================================
FILE: finetuned_model/config.json
================================================
{
"audio_num_codebooks": 32,
"audio_vocab_size": 2051,
"backbone_flavor": "llama-1B",
"decoder_flavor": "llama-100M",
"text_vocab_size": 128256
}
================================================
FILE: generator.py
================================================
from dataclasses import dataclass
import math
import os
from typing import List, Tuple, Generator as PyGenerator, Optional, Callable
import time
import queue
import threading
import platform
from typing_extensions import OrderedDict
import wave
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model, ModelArgs
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
import logging
logger = logging.getLogger(__name__)
@dataclass
class Segment:
speaker: int
text: str
sample_rate = 24_000
audio: torch.Tensor
def load_llama3_tokenizer():
"""
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
"""
tokenizer_name = "unsloth/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
)
return tokenizer
class Generator:
def __init__(self, model: Model):
self._model = model
self._model.setup_caches(1)
self._text_tokenizer = load_llama3_tokenizer()
device = next(model.parameters()).device
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)
num_codebooks = model.config.audio_num_codebooks
mimi.set_num_codebooks(num_codebooks)
self._num_codebooks = num_codebooks
self._audio_tokenizer = mimi
self.sample_rate = mimi.sample_rate
self.device = device
self._stream_buffer_size = 20
self.max_seq_len = 2048
self._cache = OrderedDict()
self._text_token_cache = {}
torch.set_num_threads(16)
torch.cuda.set_per_process_memory_fraction(0.95)
def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Tokenize text segment with caching optimization for reduced latency.
"""
# Check cache first
cache_key = f"{speaker}:{text}"
if not hasattr(self, '_text_token_cache'):
self._text_token_cache = {}
if cache_key in self._text_token_cache:
return self._text_token_cache[cache_key]
text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
text_frame = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.long, device=self.device)
text_frame_mask = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.bool, device=self.device)
text_frame[:, -1] = torch.tensor(text_tokens, device=self.device)
text_frame_mask[:, -1] = True
frame_tokens = [text_frame]
frame_masks = [text_frame_mask]
result = (torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0))
self._text_token_cache[cache_key] = result
return result
def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
frame_tokens = []
frame_masks = []
# (K, T)
audio = audio.to(self.device)
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
# Limit to the number of codebooks set in MIMI
audio_tokens = audio_tokens[:self._num_codebooks, :]
# add EOS frame
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
audio_frame = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).long().to(self.device)
audio_frame_mask = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).bool().to(self.device)
audio_frame[:, :self._num_codebooks] = audio_tokens.transpose(0, 1)
audio_frame_mask[:, :self._num_codebooks] = True
frame_tokens.append(audio_frame)
frame_masks.append(audio_frame_mask)
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
total_len = text_tokens.size(0) + audio_tokens.size(0)
if total_len > self.max_seq_len:
overflow = total_len - self.max_seq_len
if text_tokens.size(0) > overflow:
text_tokens = text_tokens[overflow:]
text_masks = text_masks[overflow:]
else:
audio_overflow = overflow - text_tokens.size(0)
text_tokens = text_tokens[0:0]
text_masks = text_masks[0:0]
audio_tokens = audio_tokens[audio_overflow:]
audio_masks = audio_masks[audio_overflow:]
return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
@torch.inference_mode()
def _decode_frames(self, frames):
if not frames:
return torch.tensor([])
# Only use first N codebooks for faster decoding
frames_reduced = [frame[:, :self._num_codebooks//2] for frame in frames]
audio = self._audio_tokenizer.decode(torch.stack(frames_reduced).permute(1, 2, 0)).squeeze(0).squeeze(0)
return audio
@torch.inference_mode()
def generate_stream(
self,
text: str,
speaker: int,
context: List[Segment],
max_audio_length_ms: float = 90_000,
temperature: float = 0.7,
topk: int = 30,
on_chunk_generated: Optional[Callable[[torch.Tensor], None]] = None,
):
"""
Generate audio in a streaming fashion, optimized for lower latency to first chunk.
"""
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()
torch.cuda.synchronize()
self._model.reset_caches()
max_generation_len = int(max_audio_length_ms / 80)
tokens, tokens_mask = [], []
initial_batch_size = 20
normal_batch_size = 20
initial_buffer_size = 20
normal_buffer_size = 20
batch_size = initial_batch_size
buffer_size = initial_buffer_size
first_chunk_delivered = False
context_start = time.time()
if context:
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
max_seq_len = 2048
if prompt_tokens.size(0) > max_seq_len:
prompt_tokens = prompt_tokens[-max_seq_len:]
prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:]
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
expected_frame_count = buffer_size
frame_buffer = []
zeros_1_1 = torch.zeros(1, 1).long().to(self.device)
zeros_mask_1_1 = torch.zeros(1, 1).bool().to(self.device)
def update_tokens(sample):
nonlocal curr_tokens, curr_tokens_mask, curr_pos
ones = torch.ones_like(sample).bool()
curr_tokens = torch.cat([sample, zeros_1_1], dim=1).unsqueeze(1)
curr_tokens_mask = torch.cat([ones, zeros_mask_1_1], dim=1).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
with self._audio_tokenizer.streaming(1):
i = 0
generation_start = time.time()
while i < max_generation_len:
batch_end = min(i + batch_size, max_generation_len)
batch_size_actual = batch_end - i
batch_samples = []
for _ in range(batch_size_actual):
with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
if torch.cuda.is_available() and hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available"):
try:
torch.cuda.synchronize() # Force sync before checking
if sample.numel() == 0 or torch.isnan(sample).any():
print("Warning: Generated empty or NaN sample, stopping generation")
break
except:
print("Error checking tensor, stopping generation")
break
if torch.all(sample == 0):
break
batch_samples.append(sample)
update_tokens(sample)
if not batch_samples:
break
frame_buffer.extend(batch_samples)
i += len(batch_samples)
if len(frame_buffer) >= buffer_size:
frames_to_process = frame_buffer[:expected_frame_count]
# If we don't have enough frames, pad with zeros to match expected shape
if len(frames_to_process) < expected_frame_count:
# Create padding frames (zeros)
padding_frames = [
torch.zeros_like(frames_to_process[0])
for _ in range(expected_frame_count - len(frames_to_process))
]
# Combine actual frames with padding
frames_to_process = frames_to_process + padding_frames
frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0)
audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0)
# Keep remaining frames for next iteration
frame_buffer = frame_buffer[expected_frame_count:]
# Process and yield the chunk
cpu_chunk = audio_chunk.cpu()
if on_chunk_generated:
on_chunk_generated(cpu_chunk)
# After first chunk is delivered, switch to normal batch and buffer sizes
if not first_chunk_delivered:
batch_size = normal_batch_size
buffer_size = normal_buffer_size
expected_frame_count = buffer_size
first_chunk_delivered = True
yield cpu_chunk
# Occasionally print progress and sync GPU
if i >= 100 and (i % 100 == 0):
if torch.cuda.is_available():
torch.cuda.synchronize()
print(f"Generated {i} frames ({i * 0.08:.2f}s of audio)")
# Process any remaining frames
if frame_buffer:
# Pad frame buffer if necessary
if len(frame_buffer) < expected_frame_count:
padding_frames = [
torch.zeros_like(frame_buffer[0])
for _ in range(expected_frame_count - len(frame_buffer))
]
frames_to_process = frame_buffer + padding_frames
else:
# Otherwise take as many frames as possible that are a multiple of expected_frame_count
frames_multiple = (len(frame_buffer) // expected_frame_count) * expected_frame_count
frames_to_process = frame_buffer[:frames_multiple]
frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0)
audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0)
# Determine actual audio length (before padding)
actual_frames_percentage = min(len(frame_buffer), expected_frame_count) / expected_frame_count
actual_samples = int(audio_chunk.shape[0] * actual_frames_percentage)
# Return only the non-padded portion of audio if we added padding
if len(frame_buffer) < expected_frame_count:
audio_chunk = audio_chunk[:actual_samples]
cpu_chunk = audio_chunk.cpu()
if on_chunk_generated:
on_chunk_generated(cpu_chunk)
yield cpu_chunk
# Print final performance metrics
if torch.cuda.is_available():
torch.cuda.synchronize()
total_time = time.time() - generation_start
frames_generated = i
audio_seconds = frames_generated * 0.08
rtf = total_time / audio_seconds if audio_seconds > 0 else float('inf')
print(f"Total time: {total_time:.2f}s")
print(f"Generated {frames_generated} frames ({audio_seconds:.2f}s of audio)")
print(f"Real-time factor: {rtf:.3f}x (target: <1.0)")
@torch.inference_mode()
def generate(
self,
text: str,
speaker: int,
context: List[Segment],
max_audio_length_ms: float = 90_000,
temperature: float = 0.8,
topk: int = 40,
stream: bool = False,
output_file: Optional[str] = None,
):
"""
Generate audio with optional streaming and file output.
Args:
text: Text to generate audio for
speaker: Speaker ID
context: List of context segments
max_audio_length_ms: Maximum audio length in milliseconds
temperature: Sampling temperature
topk: Top-k sampling parameter
stream: Whether to use streaming generation
output_file: If provided and stream=True, output will be saved to this file
Returns:
torch.Tensor: Generated audio tensor
"""
if stream:
if output_file:
# Setup streaming to file
write_chunk, close_wav = stream_audio_to_wav(output_file, self.sample_rate)
# Collect chunks while streaming to file
audio_chunks = []
t1 = time.time()
for i, chunk in enumerate(self.generate_stream(
text, speaker, context, max_audio_length_ms, temperature, topk
)):
# Write to file
write_chunk(chunk)
# Store for return value
audio_chunks.append(chunk)
# Occasionally print progress
if i % 5 == 0:
print(f"Part {i+1} available after {time.time() - t1:.4f}s")
t1 = time.time()
# Close file
close_wav()
print(f"Streaming complete, WAV file saved to {output_file}")
else:
# Just collect chunks without file output
audio_chunks = []
for chunk in self.generate_stream(text, speaker, context, max_audio_length_ms, temperature, topk):
audio_chunks.append(chunk)
if not audio_chunks:
return torch.tensor([])
return torch.cat(audio_chunks)
# Non-streaming generation remains unchanged
if torch.cuda.is_available():
torch.cuda.empty_cache()
self._model.reset_caches()
max_generation_len = int(max_audio_length_ms / 80)
tokens, tokens_mask = [], []
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
max_seq_len = 2048
if prompt_tokens.size(0) > max_seq_len:
prompt_tokens = prompt_tokens[-max_seq_len:]
prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:]
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
samples = []
with self._audio_tokenizer.streaming(1):
for _ in range(max_generation_len):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
if torch.all(sample == 0):
break
samples.append(sample)
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
curr_tokens_mask = torch.cat(
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
if not samples:
return torch.tensor([])
return self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
class AudioStreamWriter:
"""
Helper class for writing streaming audio to a file.
"""
def __init__(self, filename, sample_rate):
self.filename = filename
self.sample_rate = sample_rate
self.audio_chunks = []
self.lock = threading.Lock()
self.queue = queue.Queue()
self.running = True
# Start background writer thread
self.writer_thread = threading.Thread(target=self._writer_worker, daemon=True)
self.writer_thread.start()
def _writer_worker(self):
"""Background thread that handles audio chunk processing"""
buffer_chunks = []
last_flush_time = time.time()
while self.running or not self.queue.empty():
try:
# Get chunk with timeout to allow for regular checks
chunk = self.queue.get(timeout=0.2)
buffer_chunks.append(chunk)
# Periodically flush the buffer to the main list
current_time = time.time()
if len(buffer_chunks) >= 10 or (current_time - last_flush_time > 2.0 and buffer_chunks):
with self.lock:
self.audio_chunks.extend(buffer_chunks)
buffer_chunks = []
last_flush_time = current_time
except queue.Empty:
# If queue is empty but we have pending chunks, add them
if buffer_chunks:
with self.lock:
self.audio_chunks.extend(buffer_chunks)
buffer_chunks = []
last_flush_time = time.time()
# Final flush of any remaining chunks
if buffer_chunks:
with self.lock:
self.audio_chunks.extend(buffer_chunks)
def add_chunk(self, chunk):
"""Add an audio chunk to the buffer queue without blocking"""
try:
self.queue.put(chunk, timeout=0.1)
except queue.Full:
# If queue is full, add directly to avoid losing data
with self.lock:
self.audio_chunks.append(chunk)
def write_file(self):
"""Write all collected audio chunks to file and clean up"""
# Signal the background thread to stop
self.running = False
# Wait for the thread to finish with a timeout
self.writer_thread.join(timeout=3.0)
with self.lock:
if not self.audio_chunks:
return
# Concatenate all chunks
audio = torch.cat(self.audio_chunks)
# Save to file
torchaudio.save(self.filename, audio.unsqueeze(0).cpu(), self.sample_rate)
from safetensors.torch import load_file
import os
import torch
from models import Model, ModelArgs
from generator import Generator
def load_csm_1b_local(model_path: str, device: str = "cuda", audio_num_codebooks: int = 32):
"""
Load the CSM-1B model from a local checkpoint with extreme optimizations and warmup.
"""
import torch
import platform
from functools import lru_cache
from generator import Generator, Model, ModelArgs
# Enable all CUDA optimizations
torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch.backends.cuda, 'enable_flash_sdp'):
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
print(f"Loading CSM-1B model from local checkpoint '{model_path}' with extreme optimizations...")
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
config = ModelArgs(
backbone_flavor="llama-1B",
decoder_flavor="llama-100M",
text_vocab_size=128256,
audio_vocab_size=2051,
audio_num_codebooks=audio_num_codebooks,
)
model = Model.from_pretrained(model_path)
model.eval()
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
model.backbone = torch.compile(model.backbone,mode='reduce-overhead', fullgraph=True, backend='inductor')
model.decoder = torch.compile(model.decoder,mode='reduce-overhead', fullgraph=True, backend='inductor')
model.to(device=device, dtype=dtype)
print("Model compilation complete. Creating generator...")
generator = Generator(model)
generator._stream_buffer_size = 20
# Setup tokenization caching
generator._tokenization_cache = {}
original_tokenize_text = generator._tokenize_text_segment
@lru_cache(maxsize=2048)
def cached_tokenize_text_segment(text_str, speaker_int):
return original_tokenize_text(text_str, speaker_int)
generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker)
# Perform warmup
warmup_generator(generator)
return generator
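# Illustrative usage sketch (the checkpoint path is an assumption, not from the
# original source):
#   generator = load_csm_1b_local("finetuned_model", device="cuda")
#   audio = generator.generate(text="Hello!", speaker=0, context=[])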
def warmup_generator(gen: Generator, warmup_text: str = "Hello, this is a comprehensive warmup text that will exercise the model's generation capabilities.", speaker_id: int = 0):
"""
Perform an extremely aggressive warmup to drastically reduce first-generation latency.
"""
print("Starting maximum-intensity warmup sequence...")
# Directly access and optimize the model's internal state
if hasattr(gen._model, 'backbone') and hasattr(gen._model.backbone, 'positional_embedding'):
# Force calculation of position embeddings to ensure they're cached
with torch.inference_mode():
positions = torch.arange(0, 2048).to(gen.device)
_ = gen._model.backbone.positional_embedding(positions)
# Pre-allocate CUDA memory to prevent fragmentation during generation
if torch.cuda.is_available():
print("Optimizing GPU memory allocation...")
# Try to reserve a large chunk of memory
try:
import math
reserved_memory = []
# Reserve multiple blocks of different sizes
for size_mb in [128, 256, 512, 256, 128, 64]:
size = int(size_mb * 1024 * 1024 / 4) # Convert MB to float32 elements
tensor_size = int(math.sqrt(size))
tensor = torch.ones((tensor_size, tensor_size), device=gen.device, dtype=torch.float32)
tensor = tensor * 1.0 # Force allocation
reserved_memory.append(tensor)
torch.cuda.synchronize()
# Now free the memory
for tensor in reserved_memory:
del tensor
reserved_memory = []
torch.cuda.empty_cache()
torch.cuda.synchronize()
except Exception as e:
print(f"Memory pre-allocation: {e}")
# Create multiple dummy audio segments with varying characteristics
print("Creating diverse audio contexts...")
audio_segments = []
# Create 3 different audio patterns
for i in range(3):
length = 24000 * (i + 1) # 1s, 2s, 3s
audio = torch.zeros(length).to(gen.device)
# Add different patterns to each segment
if i == 0:
# Sine wave pattern
import math
t = torch.linspace(0, 8 * math.pi, length).to(gen.device)
audio = torch.sin(t) * 0.1
elif i == 1:
# Random noise pattern
audio = torch.randn(length).to(gen.device) * 0.05
else:
# Pulse pattern
audio[::800] = 0.2
audio[::801] = -0.2
segment = Segment(
speaker=speaker_id,
text=f"Warmup segment {i+1} with {length/24000:.1f}s of audio.",
audio=audio
)
audio_segments.append(segment)
# Force compilation of critical model components
print("Forcing compilation of critical components...")
# Directly exercise the audio tokenizer with real data
with torch.inference_mode():
for segment in audio_segments:
# Force tokenization of both text and audio
gen._tokenize_segment(segment)
# Exercise the model's generation capabilities directly
with torch.inference_mode():
# Generate some sample frames to ensure model is compiled
dummy_tokens = torch.ones(1, 10, gen._num_codebooks+1).long().to(gen.device)
dummy_mask = torch.ones(1, 10, gen._num_codebooks+1).bool().to(gen.device)
dummy_pos = torch.arange(0, 10).unsqueeze(0).to(gen.device)
# Generate multiple frames with different parameters
for temp in [0.6, 0.7, 0.8]:
for topk in [20, 30, 40]:
_ = gen._model.generate_frame(dummy_tokens, dummy_mask, dummy_pos, temp, topk)
gen._text_token_cache.clear()
print("Running final generation with exact same setup as a real request...")
final_text = "This is the final warmup that exactly matches a real generation request."
# First tokenize the text - to fill the cache
gen._tokenize_text_segment(final_text, speaker_id)
try:
# Now run a complete generation with a single context segment
generate_streaming_audio(
generator=gen,
text=final_text,
speaker=speaker_id,
context=[audio_segments[0]], # Just one context segment
output_file="warmup_final.wav",
max_audio_length_ms=6000,
temperature=0.7,
topk=30,
play_audio=False
)
except Exception as e:
print(f"Final warmup run exception (ignorable): {e}")
# Force final synchronization and memory optimization
if torch.cuda.is_available():
print("Final GPU optimization...")
torch.cuda.synchronize()
torch.cuda.empty_cache()
try:
# Allocate a large tensor to force compaction
large_tensor = torch.empty(int(1e9//4), dtype=torch.float, device=gen.device)
# Immediately delete it
del large_tensor
except RuntimeError:
# Expected if there's not enough memory
pass
# Final cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize()
print("Maximum-intensity warmup complete. First generation should now be MUCH faster.")
def load_csm_1b(device: str = "cuda") -> Generator:
"""
Load the CSM-1B model with extreme optimizations for real-time performance.
"""
# Enable all CUDA optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
print("Loading CSM-1B model with extreme optimizations for real-time performance...")
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
model = Model.from_pretrained("sesame/csm-1b")
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
model.backbone = torch.compile(model.backbone,mode='reduce-overhead', fullgraph=True, backend='inductor')
model.decoder = torch.compile(model.decoder,mode='reduce-overhead', fullgraph=True, backend='inductor')
model.to(device=device, dtype=dtype)
print("Model compilation complete. Creating generator...")
generator = Generator(model)
generator._stream_buffer_size = 20
generator._tokenization_cache = {}
from functools import lru_cache
# Patch the tokenize method with caching
original_tokenize_text = generator._tokenize_text_segment
@lru_cache(maxsize=2048)
def cached_tokenize_text_segment(text_str, speaker_int):
return original_tokenize_text(text_str, speaker_int)
generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker)
warmup_generator(generator)
return generator
def stream_audio_to_wav(filename, sample_rate):
"""
Initialize a WAV writer for streaming audio chunks.
Args:
filename: Output WAV file path
sample_rate: Audio sample rate in Hz
Returns:
tuple: (write_chunk, close) functions for writing audio data and closing the file
"""
# Create a WAV file with the proper header
wav_file = wave.open(filename, 'wb')
wav_file.setnchannels(1) # Mono
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(sample_rate)
def write_chunk(audio_chunk):
# Convert tensor to numpy and then to int16 PCM format
if isinstance(audio_chunk, torch.Tensor):
# Ensure it's on CPU and detached before converting to numpy
audio_np = audio_chunk.detach().cpu().numpy()
else:
audio_np = audio_chunk
# Normalize if needed (assuming audio is in [-1, 1] range)
if audio_np.max() <= 1.0 and audio_np.min() >= -1.0:
audio_int = (audio_np * 32767).astype(np.int16)
else:
audio_int = audio_np.astype(np.int16)
# Write to WAV file
wav_file.writeframes(audio_int.tobytes())
def close():
wav_file.close()
return write_chunk, close
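# Illustrative usage sketch (comments only, not part of the original source):
#   write_chunk, close = stream_audio_to_wav("out.wav", generator.sample_rate)
#   for chunk in generator.generate_stream(text="Hi there.", speaker=0, context=[]):
#       write_chunk(chunk)
#   close()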
def generate_streaming_audio(
generator: Generator,
text: str,
speaker: int,
context: List[Segment],
output_file: str,
max_audio_length_ms: float = 90_000,
temperature: float = 1.0,
topk: int = 50,
play_audio: bool = False,
):
"""
Generate audio with streaming output and comprehensive timing metrics.
Optimized for reduced first-chunk latency.
"""
# Initialize the streaming WAV writer
write_chunk, close_wav = stream_audio_to_wav(output_file, generator.sample_rate)
# Set up audio playback if requested
audio_queue = queue.Queue(maxsize=100) if play_audio else None
stop_event = threading.Event()
if play_audio:
try:
import sounddevice as sd
# Get available sample rates for default output device to check compatibility
device_info = sd.query_devices(kind='output')
supported_rate = device_info.get('default_samplerate', 44100)
need_resampling = abs(supported_rate - generator.sample_rate) > 100
if need_resampling:
try:
# Use resampling if sample rate doesn't match
import librosa
print(f"Resampling from {generator.sample_rate}Hz to {int(supported_rate)}Hz for playback")
def audio_playback_worker():
while not stop_event.is_set() or not audio_queue.empty():
try:
chunk = audio_queue.get(timeout=0.5)
if isinstance(chunk, torch.Tensor) and chunk.numel() == 0:
audio_queue.task_done()
continue
audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk
# Skip very short chunks (likely noise)
if len(audio_np) < 100:
audio_queue.task_done()
continue
# Resample to device's supported rate
resampled = librosa.resample(
audio_np,
orig_sr=generator.sample_rate,
target_sr=int(supported_rate)
)
sd.play(resampled, supported_rate, blocking=True)
# Add a small delay to ensure audio finishes playing
time.sleep(0.05)
audio_queue.task_done()
except queue.Empty:
# If queue empty but not stopping, keep trying
if not stop_event.is_set():
continue
else:
break
except Exception as e:
print(f"Playback error: {e}")
audio_queue.task_done()
except ImportError:
print("Librosa not found. Using direct playback which may cause sample rate warnings.")
need_resampling = False
if not need_resampling:
def audio_playback_worker():
while not stop_event.is_set() or not audio_queue.empty():
try:
chunk = audio_queue.get(timeout=0.5)
if isinstance(chunk, torch.Tensor) and chunk.numel() == 0:
audio_queue.task_done()
continue
audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk
# Skip very short chunks (likely noise)
if len(audio_np) < 100:
audio_queue.task_done()
continue
sd.play(audio_np, generator.sample_rate, blocking=True)
# Add a small delay to ensure audio finishes playing
time.sleep(0.05)
audio_queue.task_done()
except queue.Empty:
# If queue empty but not stopping, keep trying
if not stop_event.is_set():
continue
else:
break
except Exception as e:
print(f"Playback error: {e}")
audio_queue.task_done()
# Start playback thread
playback_thread = threading.Thread(target=audio_playback_worker, daemon=False)
playback_thread.start()
except ImportError:
print("sounddevice library not found. Install with 'pip install sounddevice' for real-time playback.")
play_audio = False
# Timing metrics
chunk_times = []
latency_to_first_chunk = None
total_audio_duration = 0
chunk_count = 0
# Function to handle each generated chunk
def on_chunk_generated(chunk):
nonlocal chunk_count, latency_to_first_chunk, total_audio_duration
current_time = time.time()
if chunk_count == 0:
latency_to_first_chunk = current_time - start_time
print(f"First chunk latency: {latency_to_first_chunk*1000:.1f}ms")
# Save chunk to WAV file
write_chunk(chunk)
# Update metrics
chunk_count += 1
chunk_duration = len(chunk) / generator.sample_rate
total_audio_duration += chunk_duration
chunk_times.append(current_time)
# Send to audio player if enabled
if play_audio and audio_queue is not None:
try:
audio_queue.put(chunk, timeout=1.0)
except queue.Full:
pass # Skip if queue is full to avoid blocking
if torch.cuda.is_available():
print("Preparing GPU for low-latency generation...")
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Pre-allocate some GPU memory to avoid allocation during generation
dummy_tensors = []
for i in range(5):
dummy = torch.ones((100, 100), device=generator.device)
dummy = dummy + 1.0 # Force computation
dummy_tensors.append(dummy) # Keep reference to prevent deallocation
torch.cuda.synchronize()
# Set process priority to improve performance - use higher priority
try:
import psutil
process = psutil.Process()
if platform.system() == 'Windows':
process.nice(psutil.HIGH_PRIORITY_CLASS)
else:
process.nice(-1)
except (ImportError, PermissionError, psutil.AccessDenied):
pass
print(f"Starting audio generation for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
start_time = time.time()
# Generate audio in chunks, catching possible errors
frame_count = 0
audio_chunks = [] # Store all chunks for possible use at the end
try:
for audio_chunk in generator.generate_stream(
text=text,
speaker=speaker,
context=context,
max_audio_length_ms=max_audio_length_ms,
temperature=temperature,
topk=topk,
on_chunk_generated=on_chunk_generated
):
frame_count += 1
audio_chunks.append(audio_chunk) # Store the chunk
# Print timing info less frequently to reduce overhead
if frame_count % 10 == 0:
current_time = time.time()
elapsed = current_time - start_time
if total_audio_duration > 0:
rtf = elapsed / total_audio_duration
remaining_time = (max_audio_length_ms/1000 - total_audio_duration) * rtf
print(f"Chunk {chunk_count}: {total_audio_duration:.1f}s audio in {elapsed:.1f}s "
f"(RTF: {rtf:.2f}x, Est. remaining: {remaining_time:.1f}s)")
except Exception as e:
print(f"Error during audio generation: {e}")
import traceback
traceback.print_exc()
# Release dummy tensors to free memory
if 'dummy_tensors' in locals():
del dummy_tensors
# Ensure all chunks are properly processed
if play_audio and audio_queue is not None:
print("Waiting for playback queue to finish...")
try:
timeout_start = time.time()
while not audio_queue.empty() and time.time() - timeout_start < 5.0:
time.sleep(0.1)
except:
pass
# Add a small delay to ensure everything is processed
time.sleep(0.5)
# Signal audio worker that generation is complete
stop_event.set()
# Close WAV file
close_wav()
# Wait for audio playback to complete if enabled
if play_audio and 'playback_thread' in locals():
print("Waiting for audio playback to complete...")
# First, ensure the queue is empty
try:
timeout_start = time.time()
while not audio_queue.empty() and time.time() - timeout_start < 5.0:
time.sleep(0.1)
except:
pass
# Set a flag to indicate complete audio playback is needed
if hasattr(sd, 'wait'):
try:
sd.wait()
except:
pass
# Join the playback thread with timeout
playback_thread.join(timeout=5.0)
# Force sounddevice to stop if it's still playing
try:
sd.stop()
except:
pass
# Calculate and print detailed performance metrics
end_time = time.time()
total_elapsed = end_time - start_time
# Calculate inter-chunk latency
if len(chunk_times) > 1:
inter_chunk_latencies = [chunk_times[i] - chunk_times[i-1] for i in range(1, len(chunk_times))]
avg_inter_chunk_latency = sum(inter_chunk_latencies) / len(inter_chunk_latencies)
max_inter_chunk_latency = max(inter_chunk_latencies) if inter_chunk_latencies else 0
min_inter_chunk_latency = min(inter_chunk_latencies) if inter_chunk_latencies else 0
else:
avg_inter_chunk_latency = max_inter_chunk_latency = min_inter_chunk_latency = 0
rtf = total_elapsed / total_audio_duration if total_audio_duration > 0 else float('inf')
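# RTF = wall-clock generation time / seconds of audio produced; values below 1.0 mean the model
# generates speech faster than real time (e.g. 10 s of audio in 5 s gives an RTF of 0.5).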
print("\n" + "="*50)
print("AUDIO GENERATION PERFORMANCE METRICS")
print("="*50)
print(f"First chunk latency: {latency_to_first_chunk*1000:.1f}ms")
print(f"Total generation time: {total_elapsed:.2f}s")
print(f"Audio duration: {total_audio_duration:.2f}s")
print(f"Real-time factor (RTF): {rtf:.3f}x (target: <1.0)")
print(f"Number of chunks: {chunk_count}")
print(f"Average chunk size: {(total_audio_duration/chunk_count)*1000:.1f}ms") if chunk_count > 0 else None
print(f"Average inter-chunk latency: {avg_inter_chunk_latency*1000:.1f}ms")
print(f"Min/Max inter-chunk latency: {min_inter_chunk_latency*1000:.1f}ms / {max_inter_chunk_latency*1000:.1f}ms")
print(f"Chunks per second: {chunk_count/total_elapsed:.2f}")
print(f"Output file: {output_file}")
print("="*50)
================================================
FILE: llm_interface.py
================================================
import re
from typing import List, Dict, Any, Optional
import torch
from vllm import LLM, SamplingParams
class LLMInterface:
def __init__(self, model_path: str, max_tokens: int = 8192, n_threads: int = 8, gpu_layers: int = -1):
"""Initialize the LLM interface using VLLM with a given model.
Args:
model_path (str): Path to the model or HuggingFace model name
max_tokens (int, optional): Maximum context length. Defaults to 8192.
n_threads (int, optional): Number of CPU threads. Unused by vLLM; kept for API compatibility. Defaults to 8.
gpu_layers (int, optional): Unused by vLLM; kept for API compatibility.
"""
# VLLM configuration
self.llm = LLM(
model=model_path,
tensor_parallel_size=1, # Adjust based on number of GPUs available
gpu_memory_utilization=0.6,
max_model_len=max_tokens,
swap_space=0,
trust_remote_code=True,
dtype=torch.float16,
)
# Store configuration for reference
self.config = {
"model_path": model_path,
"max_tokens": max_tokens,
}
def trim_to_last_sentence(self, text: str) -> str:
"""
Return *text* truncated at the final full sentence boundary.
A boundary is considered to be any '.', '!' or '?' followed by
optional quotes/brackets, optional whitespace, and then end-of-string.
If no sentence terminator exists, the original text is returned.
"""
# Regex explanation:
# (.*?[.!?]["')\]]?) any text lazily until a terminator
# \s*$ followed only by whitespace till end-of-string
m = re.match(r"^(.*?[.!?][\"')\]]?)\s*$", text, re.DOTALL)
if m:
return m.group(1).strip()
# Fall back to manual search (handles cases with additional text)
for i in range(len(text) - 1, -1, -1):
if text[i] in ".!?":
return text[: i + 1].strip()
return text.strip()
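# Illustrative examples (hypothetical inputs):
#   trim_to_last_sentence("All good here.")              -> "All good here."
#   trim_to_last_sentence("It works. But not yet fini")  -> "It works."
#   trim_to_last_sentence("no terminator at all")        -> "no terminator at all"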
def generate_response(self, system_prompt: str, user_message: str, conversation_history: str = "") -> str:
"""Generate a response from the LLM using chat-style prompt formatting.
Args:
system_prompt (str): The system prompt/instructions
user_message (str): The user's input message
conversation_history (str, optional): Any prior conversation context. Defaults to "".
Returns:
str: The generated response
"""
# Format prompt following chat template structure
prompt = f"""<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>
{conversation_history}
<|start_header_id|>user<|end_header_id|>\n{user_message}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>\n"""
# Define sampling parameters for single-response generation
sampling_params = SamplingParams(
temperature=1.0,
top_p=0.95,
max_tokens=100,
repetition_penalty=1.2,
top_k=200,
stop=["</s>", "<|endoftext|>", "<<USR>>", "<</USR>>", "<</SYS>>",
"<</USER>>", "<</ASSISTANT>>", "<|end_header_id|>", "<<ASSISTANT>>",
"<|eot_id|>", "<|im_end|>", "user:", "User:", "user :", "User :"]
)
# Generate response using VLLM
outputs = self.llm.generate(prompt, sampling_params)
# Extract and return the generated text
if outputs and len(outputs) > 0:
text = outputs[0].outputs[0].text
return self.trim_to_last_sentence(text)
return ""
def tokenize(self, text: str) -> List[int]:
"""Tokenize text using VLLM's tokenizer.
Args:
text (str): Text to tokenize
Returns:
List[int]: List of token IDs
"""
# vLLM does not expose the tokenizer as a plain attribute;
# fetch it from the LLM instance instead
tokenizer = self.llm.get_tokenizer()
return tokenizer.encode(text)
def get_token_count(self, text: str) -> int:
"""Return token count of the input text.
Args:
text (str): Text to count tokens for
Returns:
int: Number of tokens
"""
return len(self.tokenize(text))
def batch_generate(self, prompts: List[Dict[str, str]],
max_tokens: int = 512,
temperature: float = 0.7) -> List[str]:
"""Generate responses for multiple prompts in a batch.
Args:
prompts (List[Dict[str, str]]): List of prompt dictionaries, each with
'system', 'user' and optional 'history' keys
max_tokens (int, optional): Maximum tokens to generate per response
temperature (float, optional): Temperature for sampling
Returns:
List[str]: Generated responses
"""
formatted_prompts = []
# Format each prompt according to the chat template
for p in prompts:
system = p.get("system", "")
user = p.get("user", "")
history = p.get("history", "")
formatted_prompt = f"""<|start_header_id|>system<|end_header_id|>\n{system}<|eot_id|>
{history}
<|start_header_id|>user<|end_header_id|>\n{user}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>\n"""
formatted_prompts.append(formatted_prompt)
# Set up batch sampling parameters
sampling_params = SamplingParams(
temperature=temperature,
top_p=0.95,
max_tokens=max_tokens,
repetition_penalty=1.2,
top_k=400,
stop=["</s>", "<|endoftext|>", "<<USR>>", "<</USR>>", "<</SYS>>",
"<</USER>>", "<</ASSISTANT>>", "<|end_header_id|>", "<<ASSISTANT>>",
"<|eot_id|>", "<|im_end|>", "user:", "User:", "user :", "User :"]
)
# Generate responses for all prompts in a batch
outputs = self.llm.generate(formatted_prompts, sampling_params)
# Extract and return the generated texts
results = []
for output in outputs:
if output.outputs:
results.append(output.outputs[0].text.strip())
else:
results.append("")
return results
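# Illustrative usage sketch (not executed anywhere in this module; the model path below is a
# placeholder, substitute any vLLM-compatible model directory or HuggingFace name):
#
#   llm = LLMInterface("path/to/llama-model")
#   reply = llm.generate_response("You are a friendly companion.", "Hello, how are you?")
#   replies = llm.batch_generate([{"system": "Be brief.", "user": "Ping?"}], max_tokens=64)
#   n_tokens = llm.get_token_count(reply)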
================================================
FILE: loadandmergecheckpoint.py
================================================
import os
import re
import torch
from models import Model
from safetensors.torch import save_file, load_file
from lora import (
remove_lora_modules,
merge_lora_weights,
strip_bias_keys,
DEVICE,
OUTPUT_DIR,
replace_linear_with_lora,
)
MODEL_NAME = "sesame/csm-1b"
R = 32
ALPHA = 32
def find_latest_checkpoint(dir_path):
checkpoints = [
(int(re.search(r"checkpoint-epoch-(\d+)", d).group(1)), os.path.join(dir_path, d))
for d in os.listdir(dir_path)
if os.path.isdir(os.path.join(dir_path, d)) and "checkpoint-epoch" in d
]
if not checkpoints:
raise FileNotFoundError("No checkpoints found.")
latest_epoch, latest_path = max(checkpoints, key=lambda x: x[0])
print(f"Latest checkpoint: epoch {latest_epoch} -> {latest_path}")
return latest_path
def load_checkpoint_and_merge():
print("Loading base model...")
model = Model.from_pretrained(MODEL_NAME).to(DEVICE)
print("Applying LoRA structure to the model...")
target_layers = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
model = replace_linear_with_lora(model, r=R, alpha=ALPHA, dropout=0.0, target_linear_names=target_layers)
checkpoint_path = find_latest_checkpoint(OUTPUT_DIR)
print(f"Loading state dictionary from safetensors file...")
state_dict = load_file(os.path.join(checkpoint_path, "model.safetensors"), device=DEVICE)
print("Loading weights into the model...")
model.load_state_dict(state_dict, strict=False)
print("Merging LoRA weights into base model...")
merge_lora_weights(model)
print("Replacing LoRALinear modules with standard nn.Linear...")
model = remove_lora_modules(model)
print("Stripping bias keys for final clean model...")
merged_state = strip_bias_keys(model.state_dict())
final_path = os.path.join(OUTPUT_DIR, "model.safetensors")
save_file(merged_state, final_path)
print(f"Merged and cleaned model saved to: {final_path}")
if __name__ == "__main__":
load_checkpoint_and_merge()
================================================
FILE: lora.py
================================================
import json
import os
import glob
import torch
import torchaudio
import logging
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, get_scheduler
import torch.nn.functional as F
from tqdm import tqdm
import wandb
from safetensors.torch import save_file
import csv
from models import Model
from moshi.models import loaders
from huggingface_hub import hf_hub_download
from tokenizers.processors import TemplateProcessing
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import torch.nn as nn
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(), logging.FileHandler("finetune.log")]
)
logger = logging.getLogger(__name__)
AUDIO_DIR = "audio_data"
OUTPUT_DIR = "finetuned_model"
NUM_EPOCHS = 5
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 1e-6
MAX_GRAD_NORM = 0.1
NUM_CYCLES = 1.0
USE_WANDB = False
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MIXED_PRECISION = True
WARMUP_STEPS = 50
SPEAKER_ID = 0
MODEL_NAME = "sesame/csm-1b"
TRANSCRIPTION_MODEL = "openai/whisper-large-v3-turbo"
MAX_AUDIO_FILES = 0
R = 32
ALPHA = 32
class TrainingVisualizer:
def __init__(self, output_dir):
self.output_dir = output_dir
self.epochs = []
self.losses = []
self.val_losses = []  # Validation losses (entries may be None when no validation ran at that step)
self.learning_rates = []
self.steps = []
self.fig, self.axes = plt.subplots(3, 1, figsize=(10, 15))
self.fig.suptitle('CSM Finetuning Progress', fontsize=16)
# Setup training loss plot
self.axes[0].set_title('Training Loss')
self.axes[0].set_xlabel('Epoch')
self.axes[0].set_ylabel('Loss')
self.axes[0].grid(True, linestyle='--', alpha=0.7)
# Setup validation loss plot
self.axes[1].set_title('Training vs Validation Loss')
self.axes[1].set_xlabel('Epoch')
self.axes[1].set_ylabel('Loss')
self.axes[1].grid(True, linestyle='--', alpha=0.7)
# Setup learning rate plot
self.axes[2].set_title('Learning Rate')
self.axes[2].set_xlabel('Epoch')
self.axes[2].set_ylabel('Learning Rate')
self.axes[2].grid(True, linestyle='--', alpha=0.7)
def update(self, epoch, step, loss, lr, val_loss=None):
"""Update the metrics and redraw the plot"""
self.epochs.append(epoch)
self.steps.append(step)
self.losses.append(loss)
self.learning_rates.append(lr)
# Add validation loss if provided, otherwise use None
if val_loss is not None:
self.val_losses.append(val_loss)
elif len(self.val_losses) > 0:
# If we have validation losses but none provided this time, use the last one
self.val_losses.append(self.val_losses[-1])
else:
# If we've never had validation losses, use None
self.val_losses.append(None)
# Update training loss plot
self.axes[0].clear()
self.axes[0].plot(self.epochs, self.losses, 'b-')
self.axes[0].set_title('Training Loss')
self.axes[0].set_xlabel('Epoch')
self.axes[0].set_ylabel('Loss')
self.axes[0].grid(True, linestyle='--', alpha=0.7)
# Update validation loss plot
self.axes[1].clear()
self.axes[1].plot(self.epochs, self.losses, 'b-', label='Training')
# If we have validation losses, plot them
if any(x is not None for x in self.val_losses):
# Filter out None values
val_epochs = [e for e, v in zip(self.epochs, self.val_losses) if v is not None]
val_loss_values = [v for v in self.val_losses if v is not None]
if val_epochs:
self.axes[1].plot(val_epochs, val_loss_values, 'r-', label='Validation')
self.axes[1].legend()
self.axes[1].set_title('Training vs Validation Loss')
self.axes[1].set_xlabel('Epoch')
self.axes[1].set_ylabel('Loss')
self.axes[1].grid(True, linestyle='--', alpha=0.7)
# Update learning rate plot
self.axes[2].clear()
self.axes[2].plot(self.epochs, self.learning_rates, 'g-')
self.axes[2].set_title('Learning Rate')
self.axes[2].set_xlabel('Epoch')
self.axes[2].set_ylabel('Learning Rate')
self.axes[2].grid(True, linestyle='--', alpha=0.7)
# Calculate convergence metrics
min_loss = min(self.losses)
min_loss_epoch = self.epochs[self.losses.index(min_loss)]
# Check for potential convergence stall
recent_window = 10 # Look at last 10 steps
if len(self.losses) > recent_window:
recent_losses = self.losses[-recent_window:]
loss_std = np.std(recent_losses)
loss_change = (recent_losses[0] - recent_losses[-1]) / recent_losses[0] if recent_losses[0] != 0 else 0
convergence_status = ""
if loss_std < 0.001 and loss_change < 0.01:
convergence_status = "STALLED: Loss not improving significantly"
elif min_loss == self.losses[-1]:
convergence_status = "IMPROVING: New best loss!"
elif self.losses[-1] < self.losses[-2]:
convergence_status = "IMPROVING: Loss decreasing"
else:
convergence_status = "FLUCTUATING: Loss increased"
# Add convergence status to title
self.fig.suptitle(f'CSM Finetuning Progress - {convergence_status}\n' +
f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\n' +
f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)
else:
self.fig.suptitle(f'CSM Finetuning Progress\n' +
f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\n' +
f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)
plt.tight_layout(rect=[0, 0.03, 1, 0.92]) # Adjust for the larger title
# Save the figure
plot_path = os.path.join(self.output_dir, 'training_progress.png')
self.fig.savefig(plot_path)
def finalize(self):
"""Create a final, more detailed visualization when training completes"""
# Create a new figure for the final plot
final_fig = plt.figure(figsize=(12, 16))
gs = plt.GridSpec(4, 2, figure=final_fig)
# Plot 1: Loss vs Steps
ax1 = final_fig.add_subplot(gs[0, :])
ax1.plot(self.steps, self.losses, 'b-', linewidth=2)
ax1.set_title('Training Loss vs Steps', fontsize=14)
ax1.set_xlabel('Steps')
ax1.set_ylabel('Loss')
ax1.grid(True, linestyle='--', alpha=0.7)
# Plot 2: Loss vs Epochs
ax2 = final_fig.add_subplot(gs[1, 0])
ax2.plot(self.epochs, self.losses, 'r-', linewidth=2)
ax2.set_title('Training Loss vs Epochs', fontsize=14)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Loss')
ax2.grid(True, linestyle='--', alpha=0.7)
# Plot 3: Learning Rate vs Steps
ax3 = final_fig.add_subplot(gs[1, 1])
ax3.plot(self.steps, self.learning_rates, 'g-', linewidth=2)
ax3.set_title('Learning Rate Schedule', fontsize=14)
ax3.set_xlabel('Steps')
ax3.set_ylabel('Learning Rate')
ax3.grid(True, linestyle='--', alpha=0.7)
# Plot 4: Training vs Validation Loss
ax4 = final_fig.add_subplot(gs[2, :])
ax4.plot(self.epochs, self.losses, 'b-', linewidth=2, label='Training')
if any(x is not None for x in self.val_losses):
# Filter out None values
val_epochs = [e for e, v in zip(self.epochs, self.val_losses) if v is not None]
val_loss_values = [v for v in self.val_losses if v is not None]
if val_epochs:
ax4.plot(val_epochs, val_loss_values, 'r-', linewidth=2, label='Validation')
ax4.legend()
ax4.set_title('Training vs Validation Loss', fontsize=14)
ax4.set_xlabel('Epochs')
ax4.set_ylabel('Loss')
ax4.grid(True, linestyle='--', alpha=0.7)
# Plot 5: Combined plot with two y-axes
ax5 = final_fig.add_subplot(gs[3, :])
color1, color2 = 'blue', 'green'
# Plot loss on left axis
line1 = ax5.plot(self.epochs, self.losses, color=color1, linewidth=2.5, label='Loss')
ax5.set_xlabel('Epochs')
ax5.set_ylabel('Loss', color=color1)
ax5.tick_params(axis='y', labelcolor=color1)
# Plot learning rate on right axis
ax6 = ax5.twinx()
line2 = ax6.plot(self.epochs, self.learning_rates, color=color2, linewidth=2.5, label='Learning Rate')
ax6.set_ylabel('Learning Rate', color=color2)
ax6.tick_params(axis='y', labelcolor=color2)
# Combine legends
lines = line1 + line2
labels = [l.get_label() for l in lines]
ax5.legend(lines, labels, loc='upper right')
ax5.set_title('Loss and Learning Rate vs Epochs', fontsize=14)
ax5.grid(True, linestyle='--', alpha=0.7)
# Add training summary
if self.epochs:
epoch_count = max(self.epochs)
step_count = max(self.steps)
min_loss = min(self.losses)
min_loss_epoch = self.epochs[self.losses.index(min_loss)]
min_loss_step = self.steps[self.losses.index(min_loss)]
# Calculate convergence indicators
recent_epochs = min(10, len(self.losses))
recent_losses = self.losses[-recent_epochs:]
loss_change_pct = ((recent_losses[0] - recent_losses[-1]) / recent_losses[0]) * 100 if recent_losses[0] != 0 else 0
summary_text = (
f"Training Summary\n"
f"Total Epochs: {epoch_count:.2f}\n"
f"Total Steps: {step_count}\n"
f"Min Loss: {min_loss:.6f} (Epoch {min_loss_epoch:.2f}, Step {min_loss_step})\n"
f"Recent {recent_epochs} epochs loss change: {loss_change_pct:.2f}%\n"
)
if len(self.losses) > 20:
# Add convergence assessment
last_20_losses = self.losses[-20:]
std_last_20 = np.std(last_20_losses)
converged = std_last_20 < 0.01 and loss_change_pct < 1.0
summary_text += f"Convergence status: {'CONVERGED' if converged else 'NOT CONVERGED'}\n"
if converged:
summary_text += f"Loss stabilized with std dev {std_last_20:.6f}"
else:
summary_text += f"Loss still changing significantly (std dev: {std_last_20:.6f})"
plt.figtext(0.02, 0.02, summary_text, fontsize=10,
bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))
plt.tight_layout(rect=[0, 0.05, 1, 0.97])
final_fig.suptitle('CSM Model Finetuning Metrics', fontsize=16, fontweight='bold')
plt.subplots_adjust(top=0.93)
# Save the final detailed plot
final_plot_path = os.path.join(self.output_dir, 'training_metrics_final.png')
final_fig.savefig(final_plot_path, dpi=300, bbox_inches='tight')
plt.close(final_fig)
plt.close(self.fig)
logger.info(f"Final training visualization saved to {final_plot_path}")
return final_plot_path
class LoRALinear(nn.Module):
def __init__(self, in_features, out_features, r=32, alpha=64, dropout=0.0, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.r = r
self.alpha = alpha
self.scaling = alpha / r
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
# The base linear (frozen).
self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=bias)
# LoRA trainable matrices
self.lora_A = nn.Parameter(torch.zeros(r, in_features))
self.lora_B = nn.Parameter(torch.zeros(out_features, r))
nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
nn.init.zeros_(self.lora_B)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# normal forward with frozen weight
result = F.linear(x, self.weight, self.bias)
# LoRA forward with trainable A and B
lora_out = F.linear(self.dropout(x), self.lora_A) # [*, r]
lora_out = F.linear(lora_out, self.lora_B) # [*, out_features]
return result + self.scaling * lora_out
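# Effectively this computes x @ (W + (alpha/r) * B @ A)^T + bias with W frozen and only the
# low-rank factors A (r x in_features) and B (out_features x r) trainable; merge_lora_weights()
# later folds the same scaled B @ A delta back into W so inference needs no extra parameters.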
def replace_linear_with_lora(model: nn.Module, r=R, alpha=ALPHA, dropout=0.0, target_linear_names: List[str] = None):
"""
Replaces specified nn.Linear layers with LoRALinear layers within a model, ensuring device consistency.
"""
if target_linear_names is None:
logger.warning("No target layer names specified for LoRA replacement. No layers will be replaced.")
return model
for name, module in list(model.named_modules()):
if isinstance(module, nn.Linear) and any(target_name in name for target_name in target_linear_names):
parent_name, child_name = name.rsplit('.', 1)
parent_module = model
for part in parent_name.split('.'):
parent_module = getattr(parent_module, part)
original_device = module.weight.device
original_dtype = module.weight.dtype
# Create the new LoRA layer
lora_linear = LoRALinear(
in_features=module.in_features,
out_features=module.out_features,
r=r,
alpha=alpha,
dropout=dropout,
bias=(module.bias is not None)
)
# Copy the original weights and bias
with torch.no_grad():
lora_linear.weight.copy_(module.weight.data)
if module.bias is not None:
lora_linear.bias.copy_(module.bias.data)
lora_linear.to(device=original_device, dtype=original_dtype)
setattr(parent_module, child_name, lora_linear)
logger.info(f"Replaced layer: {name} with LoRALinear on device {original_device}")
return model
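# Matching is by substring over the fully qualified module name, so a target such as 'q_proj'
# replaces every nested linear whose dotted path contains that name.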
def load_llama3_tokenizer():
tokenizer_name = "unsloth/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(bos, tokenizer.bos_token_id), (eos, tokenizer.eos_token_id)],
)
return tokenizer
@dataclass
class AudioTextPair:
audio_path: str
text: str
speaker_id: int
processed_audio: Optional[torch.Tensor] = None
def load_audio(self, sample_rate=24000) -> torch.Tensor:
if self.processed_audio is not None:
return self.processed_audio
waveform, sr = torchaudio.load(self.audio_path)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
if sr != sample_rate:
resampler = torchaudio.transforms.Resample(sr, sample_rate)
waveform = resampler(waveform)
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
self.processed_audio = waveform.squeeze(0)
return self.processed_audio
class CSMDataset(Dataset):
def __init__(self, data_items, text_tokenizer, audio_tokenizer, device):
self.data_items = data_items
self.text_tokenizer = text_tokenizer
self.audio_tokenizer = audio_tokenizer
self.device = device
self.sample_rate = audio_tokenizer.sample_rate
def __len__(self):
return len(self.data_items)
def tokenize_text_segment(self, text: str, speaker: int):
text_tokens = self.text_tokenizer.encode(f"[{speaker}]{text}")
text_frame = torch.zeros(len(text_tokens), 33).long()
text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
text_frame[:, -1] = torch.tensor(text_tokens)
text_frame_mask[:, -1] = True
return text_frame, text_frame_mask
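# Frames are [seq_len, 33]: columns 0-31 hold the 32 Mimi audio codebooks and the last column
# holds the text token, so text-only frames populate (and mask) just that final column.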
def tokenize_audio(self, audio: torch.Tensor):
assert audio.ndim == 1, "Audio must be single channel"
audio_device = next(self.audio_tokenizer.parameters()).device
audio = audio.to(audio_device)
try:
audio_tokens = self.audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
eos_frame = torch.zeros(audio_tokens.size(0), 1, device=audio_device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
audio_frame = torch.zeros(audio_tokens.size(1), 33, device=audio_device).long()
audio_frame_mask = torch.zeros(audio_tokens.size(1), 33, device=audio_device).bool()
audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
audio_frame_mask[:, :-1] = True
except RuntimeError as e:
logger.warning(f"Error encoding audio: {e}, using empty frames")
audio_frame = torch.zeros(1, 33, device=audio_device).long()
audio_frame_mask = torch.zeros(1, 33, device=audio_device).bool()
return audio_frame, audio_frame_mask
def __getitem__(self, idx: int):
item = self.data_items[idx]
audio = item.load_audio(self.sample_rate)
text_tokens, text_masks = self.tokenize_text_segment(item.text, item.speaker_id)
audio_tokens, audio_masks = self.tokenize_audio(audio)
device = audio_tokens.device
text_tokens = text_tokens.to(device)
text_masks = text_masks.to(device)
input_tokens = text_tokens
input_masks = text_masks
target_tokens = torch.cat([text_tokens, audio_tokens], dim=0)
target_masks = torch.cat([text_masks, audio_masks], dim=0)
if device != self.device:
input_tokens = input_tokens.to(self.device)
input_masks = input_masks.to(self.device)
target_tokens = target_tokens.to(self.device)
target_masks = target_masks.to(self.device)
return {
"input_tokens": input_tokens,
"input_masks": input_masks,
"target_tokens": target_tokens,
"target_masks": target_masks,
}
def collate_fn(batch):
max_seq_len = 1024
device = batch[0]["input_tokens"].device
max_input_len = min(max(item["input_tokens"].size(0) for item in batch), max_seq_len)
max_target_len = min(max(item["target_tokens"].size(0) for item in batch), max_seq_len)
batch_input_tokens = []
batch_input_masks = []
batch_target_tokens = []
batch_target_masks = []
for item in batch:
input_tokens = item["input_tokens"][:max_input_len]
input_masks = item["input_masks"][:max_input_len]
target_tokens = item["target_tokens"][:max_target_len]
target_masks = item["target_masks"][:max_target_len]
input_tokens = F.pad(input_tokens, (0,0,0, max_input_len - input_tokens.size(0)), "constant", 0)
input_masks = F.pad(input_masks, (0,0,0, max_input_len - input_masks.size(0)), "constant", False)
target_tokens = F.pad(target_tokens, (0,0,0, max_target_len - target_tokens.size(0)), "constant", 0)
target_masks = F.pad(target_masks, (0,0,0, max_target_len - target_masks.size(0)), "constant", False)
batch_input_tokens.append(input_tokens)
batch_input_masks.append(input_masks)
batch_target_tokens.append(target_tokens)
batch_target_masks.append(target_masks)
return {
"input_tokens": torch.stack(batch_input_tokens),
"input_masks": torch.stack(batch_input_masks),
"target_tokens": torch.stack(batch_target_tokens),
"target_masks": torch.stack(batch_target_masks),
"positions": torch.arange(0, max_target_len).unsqueeze(0).repeat(len(batch), 1).to(device)
}
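# collate_fn pads (or truncates) every item in the batch to a common length capped at 1024 frames
# and emits matching position indices for the backbone.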
def transcribe_audio_files():
from transformers import pipeline
# Cache file path
cache_file = os.path.join(AUDIO_DIR, "transcription_cache.json")
# Load existing cache
cache = {}
if os.path.exists(cache_file):
try:
with open(cache_file, 'r', encoding='utf-8') as f:
cache = json.load(f)
logger.info(f"Loaded transcription cache with {len(cache)} entries")
except Exception as e:
logger.warning(f"Could not load cache file: {e}")
cache = {}
logger.info(f"Transcribing audio files in: {AUDIO_DIR}")
transcriber = pipeline("automatic-speech-recognition", model=TRANSCRIPTION_MODEL)
audio_text_pairs = []
audio_files = glob.glob(os.path.join(AUDIO_DIR, "*.wav")) \
+ glob.glob(os.path.join(AUDIO_DIR, "*.mp3")) \
+ glob.glob(os.path.join(AUDIO_DIR, "*.flac"))
if MAX_AUDIO_FILES > 0 and len(audio_files) > MAX_AUDIO_FILES:
logger.info(f"Found {len(audio_files)} files, limiting to {MAX_AUDIO_FILES}")
audio_files = audio_files[:MAX_AUDIO_FILES]
cache_hits = 0
cache_misses = 0
for audio_file in tqdm(audio_files, desc="Processing audio files"):
try:
# Create cache key using file path and modification time
file_stat = os.stat(audio_file)
cache_key = f"{audio_file}_{file_stat.st_mtime}_{file_stat.st_size}"
# Check if transcription exists in cache
if cache_key in cache:
transcription = cache[cache_key]
cache_hits += 1
logger.debug(f"Cache hit: {os.path.basename(audio_file)}")
else:
# Transcribe the file
result = transcriber(audio_file, return_timestamps=True, chunk_length_s=30,
stride_length_s=[6, 0], batch_size=32,
generate_kwargs={"language": "<|en|>", "task": "transcribe"})
transcription = result["text"].strip()
# Save to cache
cache[cache_key] = transcription
cache_misses += 1
logger.info(f"Transcribed: {os.path.basename(audio_file)} -> {transcription}")
audio_text_pairs.append(
AudioTextPair(audio_path=audio_file, text=transcription, speaker_id=0)
)
except Exception as e:
logger.error(f"Error processing {audio_file}: {e}")
# Save updated cache
try:
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(cache, f, ensure_ascii=False, indent=2)
logger.info(f"Saved transcription cache with {len(cache)} entries")
except Exception as e:
logger.error(f"Could not save cache file: {e}")
logger.info(f"Processed {len(audio_text_pairs)} audio files (Cache hits: {cache_hits}, Cache misses: {cache_misses})")
return audio_text_pairs
def prepare_csm_model_for_training():
logger.info(f"Loading CSM model: {MODEL_NAME}")
model = Model.from_pretrained(MODEL_NAME).to(DEVICE)
text_tokenizer = load_llama3_tokenizer()
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=DEVICE)
mimi.set_num_codebooks(32)
audio_tokenizer = mimi
try:
codebook_0_centroids = mimi.quantizer.rvq_first.layers[0].codebook.weight.data
num_codebook_0_tokens, embedding_dim = codebook_0_centroids.shape
model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)
model.codebook_embedding.weight.data.copy_(codebook_0_centroids)
logger.info(f"Successfully initialized codebook_embedding with shape: {codebook_0_centroids.shape}")
except Exception:
    # Fallback: the mimi quantizer layout differs, so initialise a fresh 1024x1024 embedding instead
    num_codebook_0_tokens, embedding_dim = 1024, 1024
    model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)
    nn.init.xavier_uniform_(model.codebook_embedding.weight)
# Some fallback logic for config
if not hasattr(model.config, 'get'):
def get_method(self, key, default=None):
if hasattr(self, key):
return getattr(self, key)
return default
model.config.__class__.get = get_method
if not hasattr(model.config, 'tie_word_embeddings'):
model.config.tie_word_embeddings = False
target_layers = [
"q_proj",
"k_proj",
"v_proj",
"output_proj",
"w1",
"w2",
"w3"
]
logger.info("Applying LoRA to model...")
model = replace_linear_with_lora(
model,
r=R,
alpha=ALPHA,
dropout=0.01,
target_linear_names=target_layers
)
model.cuda()
# First, freeze all parameters of the base model
for param in model.parameters():
param.requires_grad = False
# Then, unfreeze only the newly added LoRA parameters.
# It is also common practice to train the bias parameters.
for name, param in model.named_parameters():
if "lora_A" in name or "lora_B" in name or "bias" in name:
param.requires_grad = True
return model, text_tokenizer, audio_tokenizer
def setup_model_caches(model, batch_size):
try:
with torch.no_grad():
model.reset_caches()
model.backbone.reset_caches()
model.decoder.reset_caches()
except Exception as e:
logger.debug(f"No caches to reset or error: {e}")
return True
class BridgingModule(nn.Module):
"""For a 2048->1024 bridging if needed."""
def __init__(self, in_dim=2048, out_dim=1024):
super().__init__()
self.bridge = nn.Linear(in_dim, out_dim, bias=False)
nn.init.xavier_uniform_(self.bridge.weight)
def forward(self, x):
return self.bridge(x)
def compute_loss_for_codebooks_single_pass(
backbone_out, # [b, seq_len, 2048]
decoder_out, # [b, seq_len, 1024]
model,
target_tokens, # [b, seq_len, codebooks]
target_masks, # [b, seq_len, codebooks bool]
device
):
bsz, seq_len = target_tokens.size()[:2]
num_codebooks = model.config.audio_num_codebooks
c0_logits = model.codebook0_head(backbone_out)
audio_positions = target_masks[..., :-1].any(dim=-1) # [b, seq_len] for audio
total_loss = torch.tensor(0.0, device=device)
count = 0
# codebook0
for b in range(bsz):
for s in range(seq_len):
if audio_positions[b, s]:
token_logits = c0_logits[b, s]
target_token = target_tokens[b, s, 0]
if target_token > 0:
ce = F.cross_entropy(token_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')
total_loss += ce
count += 1
# codebooks [1..N-1] from decoder_out
for i in range(1, num_codebooks):
weight_i = model.audio_head[i-1]
flat_dec = decoder_out.view(bsz * seq_len, -1)
token_logits_all = flat_dec.mm(weight_i)
for b in range(bsz):
for s in range(seq_len):
if audio_positions[b, s]:
target_token = target_tokens[b, s, i]
if target_token > 0:
row_idx = b*seq_len + s
row_logits = token_logits_all[row_idx]
ce = F.cross_entropy(row_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')
total_loss += ce
count += 1
if count > 0:
total_loss = total_loss / count
return total_loss
def single_pass_forward(model, bridging_module, target_tokens, target_masks, positions):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
embed = model._embed_tokens(target_tokens)
masked_embed = embed * target_masks.unsqueeze(-1)
h = masked_embed.sum(dim=2)
backbone_out = model.backbone(h, input_pos=positions, mask=None).to(dtype)
bridging_out = bridging_module(backbone_out)
codebook0_logits = model.codebook0_head(backbone_out)
codebook0_tokens = torch.argmax(codebook0_logits, dim=-1).clamp(0, model.codebook_embedding.num_embeddings - 1)
c0_embed = model.codebook_embedding(codebook0_tokens)
# Get the last hidden state from bridging module
last_h = bridging_out[:, -1, :].unsqueeze(1)
# Concatenate the last hidden state with the codebook embeddings
decoder_input = torch.cat([last_h, c0_embed], dim=1)
# Process decoder inputs in parallel
B, S, D = decoder_input.shape # Batch, Sequence length, Dimension
# Reshape to (B*S, D) to process all tokens in parallel
decoder_input_flat = decoder_input.view(-1, D).unsqueeze(1) # [B*S, 1, D]
# Run decoder on all inputs in parallel
decoder_out_flat = model.decoder(decoder_input_flat).to(dtype) # [B*S, 1, output_dim]
# Reshape back to original batch and sequence dimensions
decoder_out = decoder_out_flat.view(B, S, -1) # [B, S, output_dim]
# Remove the first token (corresponding to last_h) as in original code
decoder_out = decoder_out[:, 1:, :] # [B, T, 1024]
# Safety check: handle empty sequences
if decoder_out.size(1) == 0:
return torch.tensor(0.0, requires_grad=True, device=device)
loss = compute_loss_for_codebooks_single_pass(
backbone_out=backbone_out,
decoder_out=decoder_out,
model=model,
target_tokens=target_tokens[..., 1:], # Drop codebook 0
target_masks=target_masks[..., 1:],
device=device
)
return loss
def calculate_validation_loss(model, bridging_module, dataset, device, max_samples=50):
"""
Calculate validation loss on a subset of the dataset
"""
# Create a small validation dataloader with a subset of data
val_indices = torch.randperm(len(dataset))[:max_samples].tolist()
val_samples = [dataset[i] for i in val_indices]
val_loader = DataLoader(
val_samples,
batch_size=1,
shuffle=False,
collate_fn=collate_fn,
num_workers=0,
pin_memory=False
)
model.eval()
bridging_module.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in val_loader:
setup_model_caches(model, batch["target_tokens"].size(0))
loss = forward_and_loss(model, bridging_module, batch, device)
total_loss += loss.item()
num_batches += 1
model.train()
bridging_module.train()
# Return average loss
return total_loss / num_batches if num_batches > 0 else 0.0
def strip_bias_keys(state_dict: dict) -> dict:
new_sd = {}
for k, v in state_dict.items():
if k == "codebook_embedding.weight":
print(f"Stripping {k} from checkpoint (training-only layer)")
continue
if not k.endswith(".bias"):
new_sd[k] = v
else:
print(f"Stripping {k} from checkpoint")
return new_sd
def remove_lora_modules(module: nn.Module) -> nn.Module:
for name, child in list(module.named_children()):
new_child = remove_lora_modules(child)
setattr(module, name, new_child)
if isinstance(module, LoRALinear):
out_features, in_features = module.out_features, module.in_features
# Determine if we actually need a bias
has_bias = (module.bias is not None)
new_linear = nn.Linear(
in_features=in_features,
out_features=out_features,
bias=has_bias
)
# Copy over the merged weight
new_linear.weight.data.copy_(module.weight.data)
# If we had a bias in LoRALinear, copy it too
if has_bias:
new_linear.bias.data.copy_(module.bias.data)
return new_linear
return module
def merge_lora_layer(lora_module: LoRALinear):
"""
Merge the LoRA params (lora_A, lora_B) into the base weight in-place.
This transforms the LoRALinear into a standard Linear equivalent.
"""
# W = W + (alpha/r) * (lora_B @ lora_A)
merged_delta = lora_module.scaling * (lora_module.lora_B @ lora_module.lora_A)
lora_module.weight.data += merged_delta
# Optionally zero out LoRA parameters so they no longer affect anything
lora_module.lora_A.data.zero_()
lora_module.lora_B.data.zero_()
def merge_lora_weights(model: nn.Module):
for module in model.modules():
if isinstance(module, LoRALinear):
merge_lora_layer(module)
return model
def finetune(model, dataset):
logger.info("Starting finetuning process")
csv_file = os.path.join(OUTPUT_DIR, "training_metrics.csv")
with open(csv_file, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["epoch", "step", "global_step", "loss", "learning_rate", "val_loss"])
def log_metrics(epoch, step, global_step, loss, learning_rate, val_loss=None):
with open(csv_file, "a", newline="") as f:
writer = csv.writer(f)
writer.writerow([epoch, step, global_step, loss, learning_rate, val_loss if val_loss is not None else ""])
visualizer.update(epoch, global_step, loss, learning_rate, val_loss)
visualizer = TrainingVisualizer(OUTPUT_DIR)
bridging_module = BridgingModule(in_dim=2048, out_dim=1024).to(DEVICE)
for param in bridging_module.parameters():
param.requires_grad = True
dataloader = DataLoader(
dataset, batch_size=BATCH_SIZE, shuffle=True,
collate_fn=collate_fn, num_workers=0, pin_memory=False
)
trainable_params = [p for p in model.parameters() if p.requires_grad] + list(bridging_module.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)
num_training_steps = len(dataloader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
lr_scheduler = get_scheduler(
"cosine", optimizer=optimizer,
num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps
)
if USE_WANDB:
wandb.init(project="csm-finetuning", name="csm-lora-finetune-fixed")
scaler = torch.amp.GradScaler() if MIXED_PRECISION else None
global_step = 0
validation_frequency = max(1, len(dataloader) // (2 * GRADIENT_ACCUMULATION_STEPS))
model.train()
bridging_module.train()
logger.info("Calculating initial validation loss...")
initial_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
logger.info(f"Initial validation loss: {initial_val_loss:.6f}")
current_loss = 0.0
current_lr = LEARNING_RATE
for epoch in range(NUM_EPOCHS):
logger.info(f"Starting epoch {epoch+1}/{NUM_EPOCHS}")
progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")
for step, batch in enumerate(dataloader):
try:
setup_model_caches(model, batch["target_tokens"].size(0))
with torch.amp.autocast(device_type=DEVICE, dtype=torch.float16, enabled=MIXED_PRECISION):
loss = forward_and_loss(model, bridging_module, batch, DEVICE)
if GRADIENT_ACCUMULATION_STEPS > 1:
loss = loss / GRADIENT_ACCUMULATION_STEPS
if torch.isnan(loss) or torch.isinf(loss):
logger.warning(f"NaN or Inf loss detected at step {step}. Skipping batch.")
optimizer.zero_grad()
progress_bar.update(1)
continue
if MIXED_PRECISION:
scaler.scale(loss).backward()
else:
loss.backward()
if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (step + 1) == len(dataloader):
if MIXED_PRECISION:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
if MIXED_PRECISION:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
current_lr = optimizer.param_groups[0]["lr"]
current_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS if GRADIENT_ACCUMULATION_STEPS > 1 else loss.item()
current_epoch = epoch + (step + 1) / len(dataloader)
current_val_loss = None
if global_step > 0 and global_step % validation_frequency == 0:
logger.info(f"Calculating validation loss at global step {global_step}...")
current_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
logger.info(f"Validation loss: {current_val_loss:.6f}")
log_metrics(current_epoch, step, global_step, current_loss, current_lr, current_val_loss)
global_step += 1
if USE_WANDB:
wandb.log({"loss": current_loss, "learning_rate": current_lr, "epoch": current_epoch, "global_step": global_step, "val_loss": current_val_loss})
progress_bar.set_postfix({"loss": f"{current_loss:.4f}", "lr": f"{current_lr:.2e}"})
progress_bar.update(1)
except Exception as e:
logger.error(f"Error in batch {step}: {e}")
import traceback
logger.error(traceback.format_exc())
try:
optimizer.zero_grad()
torch.cuda.empty_cache()
except: pass
progress_bar.update(1)
continue
logger.info(f"Calculating validation loss at end of epoch {epoch+1}...")
epoch_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
logger.info(f"Epoch {epoch+1} validation loss: {epoch_val_loss:.6f}")
log_metrics(epoch + 1.0, len(dataloader), global_step, current_loss, current_lr, epoch_val_loss)
checkpoint_dir = os.path.join(OUTPUT_DIR, f"checkpoint-epoch-{epoch+1}")
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_tensors = {
**model.state_dict(),
**bridging_module.state_dict()
}
save_file(checkpoint_tensors, os.path.join(checkpoint_dir, "model.safetensors"))
logger.info(f"Saved checkpoint to {checkpoint_dir}")
final_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE, max_samples=100)
logger.info(f"Final validation loss: {final_val_loss:.6f}")
logger.info("Merging LoRA weights into the base model...")
merge_lora_weights(model)
model = remove_lora_modules(model)
merged_state = strip_bias_keys(model.state_dict())
final_merged_path = os.path.join(OUTPUT_DIR, "model.safetensors")
save_file(merged_state, final_merged_path)
logger.info(f"LoRA-merged & replaced model saved to {final_merged_path}")
visualizer.finalize()
if USE_WANDB:
wandb.finish()
return model
def forward_and_loss(model, bridging_module, batch, device):
target_tokens = batch["target_tokens"].to(device)
target_masks = batch["target_masks"].to(device)
positions = batch["positions"].to(device)
input_tokens = target_tokens[:, :-1]
input_masks = target_masks[:, :-1]
input_positions = positions[:, :-1]
labels = target_tokens[:, 1:]
label_masks = target_masks[:, 1:]
if input_tokens.size(1) == 0:
return torch.tensor(0.0, requires_grad=True, device=device)
# 1. Embed tokens and apply mask
embed = model._embed_tokens(input_tokens)
masked_embed = embed * input_masks.unsqueeze(-1)
h = masked_embed.sum(dim=2)
# 2. Pass through the backbone
backbone_out = model.backbone(h, input_pos=input_positions, mask=None)
# 3. Calculate loss for all codebooks
loss_fct = nn.CrossEntropyLoss(ignore_index=0)
total_loss = 0.0
num_codebooks_with_loss = 0
c0_logits = model.codebook0_head(backbone_out)
c0_labels = labels[..., 0]
active_mask = label_masks[..., 0].view(-1)
if active_mask.sum() > 0:
active_logits = c0_logits.view(-1, c0_logits.size(-1))[active_mask]
active_labels = c0_labels.view(-1)[active_mask]
c0_loss = loss_fct(active_logits, active_labels)
total_loss += c0_loss
num_codebooks_with_loss += 1
decoder_states = bridging_module(backbone_out)
num_codebooks = model.config.audio_num_codebooks
for i in range(1, num_codebooks):
if hasattr(model, 'audio_head') and len(model.audio_head) >= i:
weight_i = model.audio_head[i-1]
logits_i = decoder_states @ weight_i
labels_i = labels[..., i]
active_mask_i = label_masks[..., i].view(-1)
if active_mask_i.sum() > 0:
active_logits_i = logits_i.view(-1, logits_i.size(-1))[active_mask_i]
active_labels_i = labels_i.view(-1)[active_mask_i]
loss_i = loss_fct(active_logits_i, active_labels_i)
total_loss += loss_i
num_codebooks_with_loss += 1
if num_codebooks_with_loss > 0:
return total_loss / num_codebooks_with_loss
else:
return torch.tensor(0.0, requires_grad=True, device=device)
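# In short, forward_and_loss() does teacher forcing: frames are shifted by one position, codebook 0
# is scored with codebook0_head on the backbone output, codebooks 1..N-1 are scored against the
# audio_head projections of the bridged states, and the per-codebook cross-entropy losses
# (ignore_index=0 skips padded/zero targets) are averaged.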
def main():
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
os.makedirs(OUTPUT_DIR, exist_ok=True)
torch.backends.cuda.enable_flash_sdp(True)
if DEVICE == "cuda":
torch.backends.cudnn.benchmark = True
model, text_tokenizer, audio_tokenizer = prepare_csm_model_for_training()
audio_text_pairs = transcribe_audio_files()
if not audio_text_pairs:
logger.error(f"No audio files found or transcribed in {AUDIO_DIR}")
return
dataset = CSMDataset(
audio_text_pairs,
text_tokenizer=text_tokenizer,
audio_tokenizer=audio_tokenizer,
device=DEVICE
)
logger.info(f"Dataset created with {len(dataset)} samples")
try:
finetune(model, dataset)
logger.info("Finetuning completed successfully!")
except Exception as e:
logger.error(f"Error during finetuning: {e}")
import traceback
logger.error(traceback.format_exc())
try:
# If there's an error, at least save a partial state
partial_path = os.path.join(OUTPUT_DIR, "model_partial.pt")
torch.save(model.state_dict(), partial_path)  # torch.save writes a pickle, hence the .pt extension
logger.info(f"Saved partial model to {partial_path} despite errors")
except Exception as save_error:
logger.error(f"Could not save partial model: {save_error}")
if __name__ == "__main__":
main()
================================================
FILE: main.py
================================================
import asyncio
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_DISABLE_CUDA_GRAPHS"] = "1"
import platform
import sqlite3
import time
import threading
import json
import queue
from fastapi.websockets import WebSocketState
import torch
import torchaudio
import sounddevice as sd
import numpy as np
import whisper
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy import create_engine, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from typing import Optional
from generator import Segment, load_csm_1b_local
from llm_interface import LLMInterface
from rag_system import RAGSystem
from vad import AudioStreamProcessor
from pydantic import BaseModel
import logging
from config import ConfigManager
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import re
speaking_start_time = 0.0 # set every time the AI begins a new turn
MIN_BARGE_LATENCY = 0.9
speaker_counters = {
0: 0, # AI
1: 0 # User
}
current_generation_id = 1
pending_user_inputs = []
user_input_lock = threading.Lock()
audio_fade_duration = 0.3 # seconds for fade-out
last_interrupt_time = 0
interrupt_cooldown = 6.0 # seconds between allowed interrupts
audio_chunk_buffer = [] # Buffer to store the most recent audio chunks for fade-out
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
model_thread = None
model_queue = queue.Queue()
model_result_queue = queue.Queue()
model_thread_running = threading.Event()
llm_lock = threading.Lock()
audio_gen_lock = threading.Lock()
# Database
Base = declarative_base()
engine = create_engine("sqlite:///companion.db")
SessionLocal = sessionmaker(bind=engine)
class Conversation(Base):
__tablename__ = "conversations"
id = Column(Integer, primary_key=True, index=True)
session_id = Column(String, index=True)
timestamp = Column(String)
user_message = Column(Text)
ai_message = Column(Text)
audio_path = Column(String)
Base.metadata.create_all(bind=engine)
# Pydantic config schema
class CompanionConfig(BaseModel):
system_prompt: str
reference_audio_path: str
reference_text: str
reference_audio_path2: Optional[str] = None # optional field
reference_text2: Optional[str] = None # optional field
reference_audio_path3: Optional[str] = None # optional field
reference_text3: Optional[str] = None # optional field
model_path: str
llm_path: str
max_tokens: int = 8192
voice_speaker_id: int = 0
vad_enabled: bool = True
vad_threshold: float = 0.5
embedding_model: str = "all-MiniLM-L6-v2"
# Global state
conversation_history = []
config = None
audio_queue = queue.Queue()
is_speaking = False
interrupt_flag = threading.Event()
generator = None
llm = None
rag = None
vad_processor = None
reference_segments = []
active_connections = []
message_queue = asyncio.Queue()
# Async event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# FastAPI
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
config_manager = ConfigManager()
model_id = "openai/whisper-large-v3-turbo"
# Whisper
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True
)
whisper_model.to("cuda")
processor = AutoProcessor.from_pretrained(model_id)
whisper_pipe = pipeline(
"automatic-speech-recognition",
model=whisper_model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch.float16,
device='cuda',
)
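# whisper_pipe accepts raw float32 numpy audio; transcribe_audio() below resamples captured speech
# to 16 kHz before passing it in.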
# Background queue
async def process_message_queue():
while True:
message = await message_queue.get()
for client in active_connections[:]:
try:
if client.client_state == WebSocketState.CONNECTED:
await client.send_json(message)
except Exception as e:
logger.error(f"Error in message queue for client: {e}")
if client in active_connections:
active_connections.remove(client)
message_queue.task_done()
def load_reference_segments(config_data: CompanionConfig):
"""Load multiple reference clips for voice‑cloning."""
global reference_segments
reference_segments = []
# Load primary reference (required)
if os.path.isfile(config_data.reference_audio_path):
logger.info(f"Loading primary reference audio: {config_data.reference_audio_path}")
wav, sr = torchaudio.load(config_data.reference_audio_path)
wav = torchaudio.functional.resample(wav.squeeze(0),
orig_freq=sr,
new_freq=24_000)
reference_segments.append(Segment(text=config_data.reference_text,
speaker=config_data.voice_speaker_id,
audio=wav))
else:
logger.warning(f"Primary reference audio '{config_data.reference_audio_path}' not found.")
# Load second reference (optional)
if config_data.reference_audio_path2 and os.path.isfile(config_data.reference_audio_path2):
logger.info(f"Loading second reference audio: {config_data.reference_audio_path2}")
wav, sr = torchaudio.load(config_data.reference_audio_path2)
wav = torchaudio.functional.resample(wav.squeeze(0),
orig_freq=sr,
new_freq=24_000)
reference_segments.append(Segment(text=config_data.reference_text2,
speaker=config_data.voice_speaker_id,
audio=wav))
# Load third reference (optional)
if config_data.reference_audio_path3 and os.path.isfile(config_data.reference_audio_path3):
logger.info(f"Loading third reference audio: {config_data.reference_audio_path3}")
wav, sr = torchaudio.load(config_data.reference_audio_path3)
wav = torchaudio.functional.resample(wav.squeeze(0),
orig_freq=sr,
new_freq=24_000)
reference_segments.append(Segment(text=config_data.reference_text3,
speaker=config_data.voice_speaker_id,
audio=wav))
logger.info(f"Loaded {len(reference_segments)} reference audio segments.")
def transcribe_audio(audio_data, sample_rate):
global whisper_model
audio_np = np.array(audio_data).astype(np.float32)
if sample_rate != 16000:
try:
audio_tensor = torch.tensor(audio_np).unsqueeze(0)
audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=16000)
audio_np = audio_tensor.squeeze(0).numpy()
except Exception: pass
try:
with torch.jit.optimized_execution(False):
result = whisper_pipe(audio_np, generate_kwargs={"language": "english"})
return result["text"]
except Exception:
return "[Transcription error]"
def initialize_models(config_data: CompanionConfig):
global generator, llm, rag, vad_processor, config
config = config_data
logger.info("Loading LLM …")
llm = LLMInterface(config_data.llm_path,
config_data.max_tokens)
logger.info("Loading RAG …")
rag = RAGSystem("companion.db",
model_name=config_data.embedding_model)
vad_model, vad_utils = torch.hub.load('snakers4/silero-vad',
model='silero_vad',
force_reload=False)
vad_processor = AudioStreamProcessor(
model=vad_model,
utils=vad_utils,
sample_rate=16_000,
vad_threshold=config_data.vad_threshold,
callbacks={"on_speech_start": on_speech_start,
"on_speech_end": on_speech_end},
)
load_reference_segments(config_data)
start_model_thread()
logger.info("Compiling / warming‑up voice model …")
t0 = time.time()
# send a dummy request; max 0.5 s of audio, result discarded
model_queue.put((
"warm‑up.", # text
config_data.voice_speaker_id, # speaker
[], # no context
500, # max_ms
0.7, # temperature
40, # top‑k
))
# block until worker signals EOS (None marker)
while True:
r = model_result_queue.get()
if r is None:
break
logger.info(f"Voice model ready in {time.time() - t0:.1f}s")
def on_speech_start():
asyncio.run_coroutine_threadsafe(
message_queue.put(
{
"type": "vad_status",
"status": "speech_started",
"should_interrupt": False, # always False – UI never barges-in here
}
),
loop,
)
def on_speech_end(audio_data, sample_rate):
try:
logger.info("Transcription starting")
user_text = transcribe_audio(audio_data, sample_rate)
logger.info(f"Transcription completed: '{user_text}'")
session_id = "default"
speaker_id = 1
index = speaker_counters[speaker_id]
user_audio_path = f"audio/user/{session_id}_user_{index}.wav"
os.makedirs(os.path.dirname(user_audio_path), exist_ok=True)
audio_tensor = torch.tensor(audio_data).unsqueeze(0)
save_audio_and_trim(user_audio_path, session_id, speaker_id, audio_tensor.squeeze(0), sample_rate)
add_segment(user_text, speaker_id, audio_tensor.squeeze(0))
logger.info(f"User audio saved and segment appended: {user_audio_path}")
speaker_counters[speaker_id] += 1
# Send transcription to clients
asyncio.run_coroutine_threadsafe(
message_queue.put({"type": "transcription", "text": user_text}),
loop
)
threading.Thread(target=lambda: process_user_input(user_text, session_id), daemon=True).start()
except Exception as e:
logger.error(f"VAD callback failed: {e}")
def process_pending_inputs():
"""Process only the latest user input after an interruption"""
global pending_user_inputs, is_speaking, interrupt_flag
time.sleep(0.2)
is_speaking = False
interrupt_flag.clear()
with user_input_lock:
if not pending_user_inputs:
logger.info("No pending user inputs to process")
return
# Only take the most recent input and ignore others
latest_input = pending_user_inputs[-1]
logger.info(f"Processing only latest input: '{latest_input[0]}'")
# Clear all pending inputs
pending_user_inputs = []
# Process only the latest input
user_text, session_id = latest_input
process_user_input(user_text, session_id)
def process_user_input(user_text, session_id="default"):
global config, is_speaking, pending_user_inputs, interrupt_flag
# Skip empty messages
if not user_text or user_text.strip() == "":
logger.warning("Empty user input received, ignoring")
return
interrupt_flag.clear()
is_speaking = False
# Check if we're currently supposed to be speaking
if is_speaking:
logger.info(f"AI is currently speaking, adding input to pending queue: '{user_text}'")
with user_input_lock:
# Only keep the most recent input, replacing any existing ones
pending_user_inputs = [(user_text, session_id)]
logger.info(f"Added user input as the only pending input: '{user_text}'")
# Request interruption if not already interrupted
if not interrupt_flag.is_set():
logger.info("Automatically interrupting current speech for new input")
interrupt_flag.set()
# Notify clients of interruption
asyncio.run_coroutine_threadsafe(
message_queue.put({"type": "audio_status", "status": "interrupted"}),
loop
)
# Allow a short delay before processing the new input
time.sleep(0.3)
# Process the pending input after interruption
process_pending_inputs()
return
interrupt_flag.clear()
# Normal processing continues...
logger.info(f"Processing user input: '{user_text}'")
context = "\n".join([f"User: {msg['user']}\nAI: {msg['ai']}" for msg in conversation_history[-5:]])
rag_context = rag.query(user_text)
system_prompt = config.system_prompt
if rag_context:
system_prompt += f"\n\nRelevant context:\n{rag_context}"
# Notify clients that we're thinking
asyncio.run_coroutine_threadsafe(
message_queue.put({"type": "status", "message": "Thinking..."}),
loop
)
try:
with llm_lock:
ai_response = llm.generate_response(system_prompt, user_text, context)
timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
conversation_history.append({
"timestamp": timestamp,
"user": user_text,
"ai": ai_response
})
try:
db = SessionLocal()
conv = Conversation(
session_id=session_id,
timestamp=timestamp,
user_message=user_text,
ai_message=ai_response,
audio_path=""
)
db.add(conv)
db.commit()
index = speaker_counters[0]
output_file = f"audio/ai/{session_id}_response_{index}.wav"
speaker_counters[0] += 1
conv.audio_path = output_file
db.commit()
db.close()
except Exception as e:
logger.error(f"Database error: {e}")
threading.Thread(target=lambda: rag.add_conversation(user_text, ai_response), daemon=True).start()
asyncio.run_coroutine_threadsafe(
message_queue.put({"type": "audio_status", "status": "preparing"}),
loop
)
# Small delay to ensure client is ready
time.sleep(0.2)
# Send the response to clients
asyncio.run_coroutine_threadsafe(
message_queue.put({"type": "response", "text": ai_response}),
loop
)
time.sleep(0.5)
if is_speaking:
logger.warning("Still speaking when trying to start new audio - forcing interrupt")
interrupt_flag.set()
is_speaking = False
time.sleep(0.5) # Give time for cleanup
interrupt_flag.clear() # Make absolutely sure
is_speaking = False # Reset for audio thread to take over
# Start audio generation in a new thread
threading.Thread(target=audio_generation_thread, args=(ai_response, output_file), daemon=True).start()
except Exception as e:
logger.error(f"Error generating response: {e}")
asyncio.run_coroutine_threadsafe(
message_queue.put({"type": "error", "message": "Failed to generate response"}),
loop
)
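# For reference, a successful turn emits this event sequence to WebSocket clients:
# status("Thinking...") -> audio_status("preparing") -> response(text) ->
# audio_status("preparing_generation") -> audio_status("generating") ->
# audio_status("first_chunk") -> repeated audio_chunk messages -> audio_status("complete")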
def model_worker(cfg: CompanionConfig):
global generator, model_thread_running
logger.info("Model worker thread started")
if generator is None:
torch._inductor.config.triton.cudagraphs = False # Disable cudagraphs
torch._inductor.config.fx_graph_cache = False # Disable graph caching
logger.info("Loading voice model inside worker thread …")
generator = load_csm_1b_local(cfg.model_path, "cuda")
logger.info("Voice model ready (compiled with cudagraphs)")
while model_thread_running.is_set():
try:
request = model_queue.get(timeout=0.1)
if request is None:
break
text, speaker_id, context, max_ms, temperature, topk = request
for chunk in generator.generate_stream(
text=text,
speaker=speaker_id,
context=context,
max_audio_length_ms=max_ms,
temperature=temperature,
topk=topk):
model_result_queue.put(chunk)
if not model_thread_running.is_set():
break
model_result_queue.put(None) # EOS marker
except queue.Empty:
continue
except Exception as e:
import traceback
logger.error(f"Error in model worker: {e}\n{traceback.format_exc()}")
model_result_queue.put(Exception(f"Generation error: {e}"))
logger.info("Model worker thread exiting")
def start_model_thread():
global model_thread, model_thread_running
if model_thread is not None and model_thread.is_alive():
return
model_thread_running.set()
model_thread = threading.Thread(target=model_worker,
args=(config,),
daemon=True,
name="model_worker")
model_thread.start()
logger.info("Started dedicated model worker thread")
async def run_audio_generation(text, output_file):
"""Async wrapper for audio generation that runs in the event loop thread"""
audio_generation_thread(text, output_file)
def send_to_all_clients(message: dict):
"""Send a message to all connected WebSocket clients"""
for client in active_connections[:]:
try:
if client.client_state == WebSocketState.CONNECTED:
asyncio.run_coroutine_threadsafe(client.send_json(message), loop)
logger.info(f"Sent message to client: {message}")
else:
logger.warning("Detected non-connected client; removing from active_connections")
active_connections.remove(client)
except Exception as e:
logger.error(f"Error sending message to client: {e}")
if client in active_connections:
active_connections.remove(client)
saved_audio_paths = {
"default": {
0: [], # AI
1: [] # User
}
}
MAX_AUDIO_FILES = 8
def save_audio_and_trim(path, session_id, speaker_id, tensor, sample_rate):
"""
Save audio file and trim old audio files for both AI and user to maintain storage limits.
Args:
path: Path to save the audio file
session_id: Conversation session ID
speaker_id: 0 for AI, 1 for user
tensor: Audio tensor to save
sample_rate: Audio sample rate
"""
torchaudio.save(path, tensor.unsqueeze(0), sample_rate)
saved_audio_paths.setdefault(session_id, {}).setdefault(speaker_id, []).append(path)
paths = saved_audio_paths[session_id][speaker_id]
while len(paths) > MAX_AUDIO_FILES:
old_path = paths.pop(0)
if os.path.exists(old_path):
os.remove(old_path)
logger.info(f"Removed old audio file: {old_path}")
other_speaker_id = 1 if speaker_id == 0 else 0
if other_speaker_id in saved_audio_paths[session_id]:
other_paths = saved_audio_paths[session_id][other_speaker_id]
while len(other_paths) > MAX_AUDIO_FILES:
old_path = other_paths.pop(0)
if os.path.exists(old_path):
os.remove(old_path)
logger.info(f"Removed old audio file from other speaker: {old_path}")
MAX_SEGMENTS = 8
def add_segment(text, speaker_id, audio_tensor):
"""
Add a new segment and ensure the total context stays within token limits.
This version correctly separates protected and dynamic segments, performs trimming
on the dynamic list, and rebuilds the global context list at the end.
Args:
text: Text content of the segment
speaker_id: ID of the speaker (0 for AI, 1 for user)
audio_tensor: Audio data as a tensor
"""
global reference_segments, generator, config
# Determine the number of protected, initial reference segments based on what was actually loaded.
num_protected_segments = 0
if config.reference_audio_path and os.path.exists(config.reference_audio_path):
num_protected_segments += 1
if config.reference_audio_path2 and os.path.exists(config.reference_audio_path2):
num_protected_segments += 1
if config.reference_audio_path3 and os.path.exists(config.reference_audio_path3):
num_protected_segments += 1
# Separate protected from dynamic segments from the current global state
protected_segments = reference_segments[:num_protected_segments]
dynamic_segments = reference_segments[num_protected_segments:]
# Add the new segment to the dynamic list
new_segment = Segment(text=text, speaker=speaker_id, audio=audio_tensor)
dynamic_segments.append(new_segment)
# First, trim by MAX_SEGMENTS count. The oldest dynamic segments are removed.
max_dynamic_allowed = MAX_SEGMENTS - len(protected_segments)
if len(dynamic_segments) > max_dynamic_allowed:
# Keep only the most recent dynamic segments
dynamic_segments = dynamic_segments[-max_dynamic_allowed:]
# Then, check and trim by token count if necessary.
# This loop will trim the oldest dynamic segments until the token count is acceptable.
if hasattr(generator, '_text_tokenizer'):
while dynamic_segments:
# Tentatively combine for token calculation
temp_full_list = protected_segments + dynamic_segments
total_tokens = 0
# Calculate total tokens for the current combination
for segment in temp_full_list:
tokens = generator._text_tokenizer.encode(f"[{segment.speaker}]{segment.text}")
total_tokens += len(tokens)
if segment.audio is not None:
# Approximate frame count to token conversion
audio_frames = segment.audio.size(0) // 6094
total_tokens += audio_frames
# If we are within limits, the trimming is done.
if total_tokens <= 4096:
break
# Otherwise, remove the oldest dynamic segment and re-check in the next loop iteration.
dynamic_segments.pop(0)
else:
# Fallback if tokenizer is not available
logger.warning("Unable to access tokenizer - falling back to word-based estimation for context trimming")
def estimate_tokens(segment):
words = segment.text.split()
punctuation = sum(1 for char in segment.text if char in ".,!?;:\"'()[]{}")
text_tokens = len(words) + punctuation
audio_tokens = 0
if segment.audio is not None:
audio_frames = segment.audio.size(0) // 6094
audio_tokens = audio_frames
return text_tokens + audio_tokens
while dynamic_segments:
total_estimated_tokens = sum(estimate_tokens(s) for s in protected_segments) + \
sum(estimate_tokens(s) for s in dynamic_segments)
if total_estimated_tokens <= 2048:
break
dynamic_segments.pop(0)
# Finally, overwrite the global variable with the new, correctly-trimmed list.
# This is the single source of truth for the update.
reference_segments = protected_segments + dynamic_segments
# Log the final state for debugging
logger.info(f"Context updated. Segments: {len(reference_segments)} total " +
f"({len(protected_segments)} protected, {len(dynamic_segments)} dynamic).")
def preprocess_text_for_tts(text):
"""
Removes all punctuation except periods, commas, exclamation points, question marks,
and apostrophes from the input text to create cleaner speech output while preserving intonation.
Args:
text (str): Input text with potential punctuation
Returns:
str: Cleaned text with only allowed punctuation
"""
# Define a regex pattern that matches all punctuation except periods, commas, exclamation points,
# question marks, and apostrophes (apostrophes are kept so contractions read naturally).
# Removed characters include: ; : " ~ @ # $ % ^ & * ( ) _ - + = [ ] { } \ | / < >
pattern = r'[^\w\s.,!?\']'
# Replace matched punctuation with empty string
cleaned_text = re.sub(pattern, '', text)
# normalize multiple spaces to single space
cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
# ensure there's a space after punctuation for better speech pacing
cleaned_text = re.sub(r'([.,!?])(\S)', r'\1 \2', cleaned_text)
return cleaned_text.strip()
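# Illustrative behaviour of the cleanup above (worked out by hand, not a stored test case):
#   preprocess_text_for_tts('Hello; "world" - nice!')  ->  'Hello world nice!'
#   preprocess_text_for_tts("It's done.Next step")     ->  "It's done. Next step"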
def audio_generation_thread(text, output_file):
global is_speaking, interrupt_flag, audio_queue, model_thread_running, current_generation_id, speaking_start_time
current_generation_id += 1
this_id = current_generation_id
interrupt_flag.clear()
# Log the start of generation
logger.info(f"Starting audio generation for ID: {this_id}")
# Try to acquire the lock, but don't block if it's busy
if not audio_gen_lock.acquire(blocking=False):
logger.warning(f"Audio generation {this_id} - lock acquisition failed, another generation is in progress")
asyncio.run_coroutine_threadsafe(
message_queue.put({
"type": "error",
"message": "Audio generation busy, skipping synthesis",
"gen_id": this_id
}),
loop
)
return
try:
# Start the model thread if it's not already running
start_model_thread()
interrupt_flag.clear()
is_speaking = True
speaking_start_time = time.time()
# Create output directory
os.makedirs(os.path.dirname(output_file), exist_ok=True)
all_audio_chunks = []
# Prepare text
text_lower = text.lower()
text_lower = preprocess_text_for_tts(text_lower)
asyncio.run_coroutine_threadsafe(
message_queue.put({
"type": "audio_status",
"status": "preparing_generation",
"gen_id": this_id
}),
loop
)
# Give client a moment to process
time.sleep(0.2)
logger.info(f"Sending generating status with ID {this_id}")
asyncio.run_coroutine_threadsafe(
message_queue.put({
"type": "audio_status",
"status": "generating",
"gen_id": this_id # Include generation ID
}),
loop
)
# Small delay to ensure client gets the signal
time.sleep(0.2)
# Estimate audio length
words = text.split()
avg_wpm = 100
words_per_second = avg_wpm / 60
estimated_seconds = len(words) / words_per_second
max_audio_length_ms = int(estimated_seconds * 1000)
# Send request to model thread
logger.info(f"Audio generation {this_id} - sending request to model thread")
model_queue.put((
text_lower,
config.voice_speaker_id,
reference_segments,
max_audio_length_ms,
0.8, # temperature
50 # topk
))
# Start timing
generation_start = time.time()
chunk_counter = 0
# Process results as they come
while True:
try:
# Check for interruption FIRST before getting more results
if interrupt_flag.is_set():
logger.info(f"Audio generation {this_id} - interrupt detected, stopping")
# Signal model thread to exit and restart
model_thread_running.clear()
time.sleep(0.1)
model_thread_running.set()
start_model_thread()
# Clear any remaining items in the result queue
while not model_result_queue.empty():
try:
model_result_queue.get_nowait()
except queue.Empty:
pass
# Break out of the processing loop
break
# Get result with timeout to allow checking interrupt
result = model_result_queue.get(timeout=0.1)
# Check for end of generation or error
if result is None:
logger.info(f"Audio generation {this_id} - complete")
break
if isinstance(result, Exception):
logger.error(f"Audio generation {this_id} - error: {result}")
raise result
# Track timing for first chunk
if chunk_counter == 0:
first_chunk_time = time.time() - generation_start
logger.info(f"Audio generation {this_id} - first chunk latency: {first_chunk_time*1000:.1f}ms")
chunk_counter += 1
# One more interrupt check before processing chunk
if interrupt_flag.is_set():
logger.info(f"Audio generation {this_id} - interrupt flag set during chunk processing")
break
# Process this audio chunk
audio_chunk = result
all_audio_chunks.append(audio_chunk)
# Convert to numpy and send to audio queue
chunk_array = audio_chunk.cpu().numpy().astype(np.float32)
audio_queue.put(chunk_array)
if chunk_counter == 1:
logger.info(f"Sending first audio chunk with ID {this_id}")
# Notify client we're sending the first chunk
asyncio.run_coroutine_threadsafe(
message_queue.put({
"type": "audio_status",
"status": "first_chunk",
"gen_id": this_id
}),
loop
)
# Small delay
time.sleep(0.1)
# Send chunk with generation ID
asyncio.run_coroutine_threadsafe(
message_queue.put({
"type": "audio_chunk",
"audio": chunk_array.tolist(),
"sample_rate": generator.sample_rate,
"gen_id": this_id,
"chunk_num": chunk_counter # Include chunk number
}),
loop
)
except queue.Empty:
# No results yet, keep checking
continue
except Exception as e:
logger.error(f"Audio generation {this_id} - error processing result: {e}")
break
# Save complete audio if available
if all_audio_chunks and not interrupt_flag.is_set():
try:
complete_audio = torch.cat(all_audio_chunks)
save_audio_and_trim(output_file, "default", config.voice_speaker_id, complete_audio, generator.sample_rate)
add_segment(text.lower(), config.voice_speaker_id, complete_audio)
# Log statistics
total_time = time.time() - generation_start
total_audio_seconds = complete_audio.size(0) / generator.sample_rate
rtf = total_time / total_audio_seconds
logger.info(f"Audio generation {this_id} - completed in {total_time:.2f}s, RTF: {rtf:.2f}x")
except Exception as e:
logger.error(f"Audio generation {this_id} - error saving complete audio: {e}")
except Exception as e:
import traceback
logger.error(f"Audio generation {this_id} - unexpected error: {e}\n{traceback.format_exc()}")
finally:
is_speaking = False
# Signal end of audio
audio_queue.put(None)
try:
logger.info(f"Audio generation {this_id} - sending completion status")
asyncio.run_coroutine_threadsafe(
message_queue.put({
"type": "audio_status",
"status": "complete",
"gen_id": this_id
}),
loop
)
except Exception as e:
logger.error(f"Audio generation {this_id} - failed to send completion status: {e}")
# Process any pending inputs
with user_input_lock:
if pending_user_inputs:
# Process pending inputs
logger.info(f"Audio generation {this_id} - processing pending inputs")
process_pending_inputs()
# Release the lock
logger.info(f"Audio generation {this_id} - releasing lock")
audio_gen_lock.release()
def handle_interrupt(websocket):
global is_speaking, last_interrupt_time, interrupt_flag, model_thread_running, speaking_start_time
# Log the current state
logger.info(f"Interrupt requested. Current state: is_speaking={is_speaking}")
current_time = time.time()
time_since_speech_start = current_time - speaking_start_time if speaking_start_time > 0 else 999
time_since_last_interrupt = current_time - last_interrupt_time
# Only apply cooldown for established speech, not for new speech
if time_since_last_interrupt < interrupt_cooldown and time_since_speech_start > 3.0:
logger.info(f"Ignoring interrupt: too soon after previous interrupt ({time_since_last_interrupt:.1f}s < {interrupt_cooldown}s)")
# Let the client know we're not interrupting
asyncio.run_coroutine_threadsafe(
websocket.send_json({
"type": "audio_status",
"status": "interrupt_acknowledged",
"success": False,
"reason": "cooldown"
}),
loop
)
return False
# Update the last interrupt time
last_interrupt_time = current_time
# We should interrupt if we're speaking OR if model generation is in progress
if is_speaking or not model_result_queue.empty():
logger.info("Interruption processing: we are speaking or generating")
interrupt_flag.set()
# Notify clients
asyncio.run_coroutine_threadsafe(
message_queue.put({"type": "audio_status", "status": "interrupted"}),
loop
)
asyncio.run_coroutine_threadsafe(
websocket.send_json({
"type": "audio_status",
"status": "interrupt_acknowledged"
}),
loop
)
# Clear the audio queue to stop additional audio from being processed
try:
# Drain the existing queue
while not audio_queue.empty():
try:
audio_queue.get_nowait()
except queue.Empty:
break
# Add end signal
audio_queue.put(None)
logger.info("Audio queue cleared")
except Exception as e:
logger.error(f"Error clearing audio queue: {e}")
# Reset VAD to prepare for new input
if vad_processor:
try:
vad_processor.reset()
logger.info("VAD processor reset")
except Exception as e:
logger.error(f"Error resetting VAD: {e}")
# Stop current model worker if needed
if model_thread and model_thread.is_alive():
try:
# Clear the thread running flag to stop generation
model_thread_running.clear()
# Wait a brief moment for thread to notice and exit
time.sleep(0.1)
# Now restart the thread state flag
model_thread_running.set()
# And restart the thread
start_model_thread()
logger.info("Model thread restarted")
except Exception as e:
logger.error(f"Error restarting model thread: {e}")
return True
logger.info("No active speech to interrupt")
return False
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
global is_speaking, audio_queue
await websocket.accept()
active_connections.append(websocket)
saved = config_manager.load_config()
if saved:
await websocket.send_json({"type": "saved_config", "config": saved})
try:
while True:
data = await websocket.receive_json()
if data["type"] == "config":
# Config handling
try:
config_data = data["config"]
logger.info(f"Received config data keys: {config_data.keys()}")
for key in ["reference_audio_path", "reference_audio_path2", "reference_audio_path3",
"reference_text", "reference_text2", "reference_text3"]:
if key in config_data:
logger.info(f"Config includes {key}: {config_data[key]}")
else:
logger.warning(f"Config missing {key}")
conf = CompanionConfig(**config_data)
saved = config_manager.save_config(config_data)
if saved:
initialize_models(conf)
await websocket.send_json({"type": "status", "message": "Models initialized and configuration saved"})
else:
await websocket.send_json({"type": "error", "message": "Failed to save configuration"})
except Exception as e:
logger.error(f"Error processing config: {str(e)}")
await websocket.send_json({"type": "error", "message": f"Configuration error: {str(e)}"})
elif data["type"] == "request_saved_config":
saved = config_manager.load_config()
await websocket.send_json({"type": "saved_config", "config": saved})
elif data["type"] == "text_message":
user_text = data["text"]
session_id = data.get("session_id", "default")
logger.info(f"TEXT-MSG from client: {user_text!r}")
# If the model is already talking, queue the request and answer once the current speech finishes
if is_speaking:
with user_input_lock:
if len(pending_user_inputs) >= 3:
pending_user_inputs = pending_user_inputs[-2:]
pending_user_inputs.append((user_text, session_id))
await websocket.send_json(
{"type":"status","message":"Queued – I’ll answer in a moment"})
continue
await message_queue.put({"type":"transcription","text":user_text})
threading.Thread(
target=lambda: process_user_input(user_text, session_id),
daemon=True).start()
elif data["type"] == "audio":
audio_data = np.asarray(data["audio"], dtype=np.float32)
sample_rate = data["sample_rate"]
if sample_rate != 16000:
audio_tensor = torch.tensor(audio_data).unsqueeze(0)
audio_tensor = torchaudio.functional.resample(
audio_tensor, orig_freq=sample_rate, new_freq=16000
)
audio_data = audio_tensor.squeeze(0).numpy()
sample_rate = 16000
if config and config.vad_enabled:
vad_processor.process_audio(audio_data)
else:
text = transcribe_audio(audio_data, sample_rate)
await websocket.send_json({"type": "transcription", "text": text})
await message_queue.put({"type": "transcription", "text": text})
if is_speaking:
with user_input_lock:
pending_user_inputs.append((text, "default"))
else:
process_user_input(text)
elif data["type"] == "interrupt":
logger.info("Explicit interrupt request received")
# Always acknowledge receipt of interrupt request
await websocket.send_json({
"type": "audio_status",
"status": "interrupt_acknowledged"
})
# Then try to handle the actual interrupt
success = handle_interrupt(websocket)
# If successful, allow a brief delay for clearing everything
if success:
await asyncio.sleep(0.3) # Short delay to allow complete clearing
# Force process pending inputs after interrupt
with user_input_lock:
if pending_user_inputs:
user_text, session_id = pending_user_inputs.pop(0)
pending_user_inputs.clear() # Clear any backup to avoid multiple responses
# Process in a new thread to avoid blocking
threading.Thread(
target=lambda: process_user_input(user_text, session_id),
daemon=True
).start()
# Send final status update about the interrupt
await websocket.send_json({
"type": "audio_status",
"status": "interrupted",
"success": success
})
elif data["type"] == "mute":
await websocket.send_json({"type": "mute_status", "muted": data["muted"]})
if not data["muted"] and config and config.vad_enabled:
vad_processor.reset()
except WebSocketDisconnect:
if websocket in active_connections:
active_connections.remove(websocket)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/setup", response_class=HTMLResponse)
async def setup_page(request: Request):
return templates.TemplateResponse("setup.html", {"request": request})
@app.get("/chat", response_class=HTMLResponse)
async def chat_page(request: Request):
return templates.TemplateResponse("chat.html", {"request": request})
@app.on_event("startup")
async def startup_event():
os.makedirs("static", exist_ok=True)
os.makedirs("audio/user", exist_ok=True)
os.makedirs("audio/ai", exist_ok=True)
os.makedirs("embeddings_cache", exist_ok=True)
os.makedirs("templates", exist_ok=True)
with open("templates/index.html", "w") as f:
f.write("""<meta http-equiv="refresh" content="0; url=/setup" />""")
try:
# Pre-fetch the Silero VAD weights so the first real VAD call doesn't block
torch.hub.load('snakers4/silero-vad', model='silero_vad', force_reload=False)
except Exception:
pass
asyncio.create_task(process_message_queue())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Server shutting down...")
@app.get("/api/conversations")
async def get_conversations(request: Request):
conn = sqlite3.connect("companion.db")
cur = conn.cursor()
cur.execute("SELECT id, user_message, ai_message FROM conversations ORDER BY id DESC")
data = [{"id": row[0], "user_message": row[1], "ai_message": row[2]} for row in cur.fetchall()]
conn.close()
return JSONResponse(content=data)
@app.route("/api/conversations/<int:conv_id>", methods=["PUT"])
def update_conversation(conv_id):
data = request.get_json()
conn = sqlite3.connect("companion.db")
cur = conn.cursor()
cur.execute("UPDATE conversations SET user_message=?, ai_message=? WHERE id=?",
(data["user_message"], data["ai_message"], conv_id))
conn.commit()
conn.close()
return "", 204
@app.delete("/api/conversations")
async def delete_all_conversations():
try:
conn = sqlite3.connect("companion.db")
cur = conn.cursor()
cur.execute("DELETE FROM conversations")
conn.commit()
conn.close()
return {"status": "all deleted"}
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
@app.delete("/api/conversations/{conv_id}")
async def delete_conversation(conv_id: int):
try:
conn = sqlite3.connect("companion.db")
cur = conn.cursor()
cur.execute("DELETE FROM conversations WHERE id = ?", (conv_id,))
conn.commit()
conn.close()
return JSONResponse(content={"status": "deleted", "id": conv_id})
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
@app.get("/crud", response_class=HTMLResponse)
async def crud_ui(request: Request):
return templates.TemplateResponse("crud.html", {"request": request})
if __name__ == "__main__":
import uvicorn
threading.Thread(target=loop.run_forever, daemon=True).start()
uvicorn.run(app, host="0.0.0.0", port=8000)
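A minimal way to exercise the /ws endpoint defined above is a small client script using the websockets package already listed in requirements.txt. The sketch below is illustrative and not part of the repository: it assumes the server is running locally on port 8000 and that models have already been initialized via the setup page, then sends a single text_message and prints the JSON events that come back until the audio stream reports completion.

import asyncio
import json

import websockets  # already listed in requirements.txt


async def demo_client():
    # Connect to the chat WebSocket served by main.py
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # Send one text turn in the shape handled by websocket_endpoint()
        await ws.send(json.dumps({
            "type": "text_message",
            "text": "Hello there!",
            "session_id": "default",
        }))
        # Print server events (saved_config, status, response, audio_chunk, audio_status, ...)
        while True:
            event = json.loads(await ws.recv())
            print(event.get("type"), {k: v for k, v in event.items() if k != "audio"})
            if event.get("type") == "audio_status" and event.get("status") == "complete":
                break


if __name__ == "__main__":
    asyncio.run(demo_client())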
================================================
FILE: models.py
================================================
import logging
from dataclasses import dataclass
import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2
logger = logging.getLogger(__name__)
def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=16,
num_heads=32,
num_kv_heads=8,
embed_dim=2048,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=4,
num_heads=8,
num_kv_heads=2,
embed_dim=1024,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
FLAVORS = {
"llama-1B": llama3_2_1B,
"llama-100M": llama3_2_100M,
}
def _prepare_transformer(model):
embed_dim = model.tok_embeddings.embedding_dim
model.tok_embeddings = nn.Identity()
model.output = nn.Identity()
return model, embed_dim
def _create_causal_mask(seq_len: int, device: torch.device):
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
"""
Args:
mask: (max_seq_len, max_seq_len)
input_pos: (batch_size, seq_len)
Returns:
(batch_size, seq_len, max_seq_len)
"""
r = mask[input_pos, :]
return r
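# Illustrative shapes: with mask = _create_causal_mask(4, torch.device("cpu")) and
# input_pos = torch.tensor([[0, 1]]), _index_causal_mask(mask, input_pos) has shape (1, 2, 4).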
def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
logits = logits / temperature
filter_value: float = -float("Inf")
indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
scores_processed = logits.masked_fill(indices_to_remove, filter_value)
scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
probs = torch.nn.functional.softmax(scores_processed, dim=-1)
sample_token = _multinomial_sample_one_no_sync(probs)
return sample_token
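# Illustrative: sample_topk(torch.randn(1, 2051), topk=50, temperature=0.8) returns a (1, 1)
# integer tensor holding the sampled index (2051 is just a stand-in vocabulary size here).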
@dataclass
class ModelArgs:
backbone_flavor: str
decoder_flavor: str
text_vocab_size: int
audio_vocab_size: int
audio_num_codebooks: int
class Model(
nn.Module,
PyTorchModelHubMixin,
repo_url="https://github.com/SesameAILabs/csm",
pipeline_tag="text-to-speech",
license="apache-2.0",
):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
def setup_caches(self, max_batch_size: int) -> torch.Tensor:
"""Setup KV caches and return a causal mask."""
dtype = next(self.parameters()).dtype
device = next(self.parameters()).device
with device:
self.backbone.setup_caches(max_batch_size, dtype)
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
def generate_frame(
self,
tokens: torch.Tensor,
tokens_mask: torch.Tensor,
input_pos: torch.Tensor,
temperature: float,
topk: int,
) -> torch.Tensor:
"""
Args:
tokens: (batch_size, seq_len, audio_num_codebooks+1)
tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
input_pos: (batch_size, seq_len) positions for each token
Returns:
(batch_size, audio_num_codebooks) sampled tokens
"""
dtype = next(self.parameters()).dtype
b, s, _ = tokens.size()
assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
embeds = self._embed_tokens(tokens)
masked_embeds = embeds * tokens_mask.unsqueeze(-1)
h = masked_embeds.sum(dim=2)
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
last_h = h[:, -1, :]
c0_logits = self.codebook0_head(last_h)
c0_sample = sample_topk(c0_logits, topk, temperature)
c0_embed = self._embed_audio(0, c0_sample)
curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
curr_sample = c0_sample.clone()
curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
# Decoder caches must be reset every frame.
self.decoder.reset_caches()
for i in range(1, self.config.audio_num_codebooks):
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
dtype=dtype
)
ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
ci_sample = sample_topk(ci_logits, topk, temperature)
ci_embed = self._embed_audio(i, ci_sample)
curr_h = ci_embed
curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
curr_pos = curr_pos[:, -1:] + 1
return curr_sample
def reset_caches(self):
self.backbone.reset_caches()
self.decoder.reset_caches()
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
audio_tokens = tokens[:, :, :-1] + (
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
)
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
)
return torch.cat([audio_embeds, text_embeds], dim=-2)
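# Layout reminder: each position carries audio_num_codebooks codebook ids followed by one text
# token id, so `tokens` is (batch, seq_len, audio_num_codebooks + 1) and _embed_tokens returns
# embeddings of shape (batch, seq_len, audio_num_codebooks + 1, embed_dim).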
================================================
FILE: rag_system.py
================================================
import sqlite3
import numpy as np
import json
from pathlib import Path
import time
from typing import List, Dict, Any, Tuple, Optional
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch
class RAGSystem:
def __init__(self, db_path: str, model_name: str = "all-MiniLM-L6-v2", cache_dir: str = "./embeddings_cache"):
"""
Initialize the enhanced RAG system with embeddings.
Args:
db_path: Path to the SQLite database
model_name: Name of the sentence-transformer model to use
cache_dir: Directory to cache embeddings
"""
self.db_path = db_path
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
# Load embedding model
print(f"Loading embedding model: {model_name}")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = SentenceTransformer(model_name, device=self.device)
print(f"Embedding model loaded on {self.device}")
# Cache for embeddings
self.embedding_cache = self._load_embedding_cache()
# Initialize database tables if needed
self._initialize_db()
# Load existing conversations and cache embeddings
self._load_conversations()
def _initialize_db(self):
"""Create necessary tables if they don't exist."""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Create conversations table if it doesn't exist
cursor.execute("""
CREATE TABLE IF NOT EXISTS conversations (
id INTEGER PRIMARY KEY,
user_message TEXT,
ai_message TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
""")
# Create embeddings table if it doesn't exist
cursor.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
id INTEGER PRIMARY KEY,
conversation_id INTEGER,
text TEXT,
embedding_file TEXT,
chunk_id TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversations(id)
)
""")
conn.commit()
conn.close()
def _load_embedding_cache(self) -> Dict[str, np.ndarray]:
"""Load cached embeddings from disk."""
cache = {}
for cache_file in self.cache_dir.glob("*.json"):
try:
with open(cache_file, "r") as f:
cache_data = json.load(f)
for chunk_id, embedding_data in cache_data.items():
cache[chunk_id] = np.array(embedding_data)
except Exception as e:
print(f"Error loading cache file {cache_file}: {e}")
print(f"Loaded {len(cache)} cached embeddings")
return cache
def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray):
"""Save an embedding to the cache."""
cache_file = self.cache_dir / f"{chunk_id[:2]}.json"
# Load existing cache file or create new one
if cache_file.exists():
try:
with open(cache_file, "r") as f:
cache_data = json.load(f)
except Exception:
cache_data = {}
else:
cache_data = {}
# Add new embedding
cache_data[chunk_id] = embedding.tolist()
# Save cache file
with open(cache_file, "w") as f:
json.dump(cache_data, f)
def _load_conversations(self):
"""Load existing conversations from the database and cache their embeddings."""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# First check if the conversations table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'")
if not cursor.fetchone():
print("Conversations table does not exist yet")
conn.close()
return
# Get all conversations not yet in the embeddings table
cursor.execute("""
SELECT c.id, c.user_message, c.ai_message
FROM conversations c
LEFT JOIN embeddings e ON c.id = e.conversation_id
WHERE e.id IS NULL
""")
conversations = cursor.fetchall()
if not conversations:
conn.close()
return
print(f"Processing embeddings for {len(conversations)} new conversations")
for conv_id, user_message, ai_message in conversations:
# Create chunks for indexing
if user_message is not None and ai_message is not None: # Ensure neither is None
self._process_conversation(conv_id, user_message, ai_message, conn)
conn.close()
print("Finished processing conversation embeddings")
except Exception as e:
print(f"Error loading conversations: {e}")
def _process_conversation(self, conv_id: int, user_message: str, ai_message: str, conn: sqlite3.Connection):
"""Process a conversation and store its embeddings."""
try:
cursor = conn.cursor()
# Combine user and AI messages
full_text = f"User: {user_message}\nAI: {ai_message}"
# For simplicity, we're using the entire message as a chunk
# In a more sophisticated system, you might split long messages into smaller chunks
chunk_id = f"conv_{conv_id}"
# Check if we already have this embedding cached
if chunk_id not in self.embedding_cache:
# Generate embedding
embedding = self.model.encode(full_text)
self.embedding_cache[chunk_id] = embedding
# Save to cache
self._save_embedding_to_cache(chunk_id, embedding)
else:
embedding = self.embedding_cache[chunk_id]
# Store reference in database
embedding_file = f"{chunk_id[:2]}.json"
cursor.execute(
"INSERT INTO embeddings (conversation_id, text, embedding_file, chunk_id) VALUES (?, ?, ?, ?)",
(conv_id, full_text, embedding_file, chunk_id)
)
conn.commit()
except Exception as e:
print(f"Error processing conversation {conv_id}: {e}")
def add_conversation(self, user_message: str, ai_message: str) -> int:
"""
Add a new conversation to the RAG system.
Returns:
The id of the newly added conversation
"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Insert the conversation first
cursor.execute(
"INSERT INTO conversations (user_message, ai_message) VALUES (?, ?)",
(user_message, ai_message)
)
# Get the ID of the new conversation
conv_id = cursor.lastrowid
# Process the conversation for embeddings
self._process_conversation(conv_id, user_message, ai_message, conn)
conn.commit()
conn.close()
return conv_id
except Exception as e:
print(f"Error adding conversation: {e}")
return -1
def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, float]]:
"""
Query the RAG system for relevant context.
Args:
query_text: The query text
top_k: Number of top results to return
Returns:
List of tuples with (text, similarity_score)
"""
if query_text is None or query_text.strip() == "":
print("Error: Empty query text")
return []
try:
# Generate query embedding
query_embedding = self.model.encode(query_text)
# Find most similar conversations
results = self._find_similar(query_embedding, top_k)
return results
except Exception as e:
print(f"Error during query: {e}")
return []
def get_context(self, query_text: str, top_k: int = 3, threshold: float = 0.6) -> str:
"""
Get formatted context from the RAG system.
Args:
query_text: The query text
top_k: Number of top results to return
threshold: Minimum similarity score to include
Returns:
String with relevant context
"""
results = self.query(query_text, top_k)
if not results:
return ""
# Format results
context_parts = []
for text, score in results:
# Only include really relevant results
if score < threshold: # Threshold for relevance
continue
context_parts.append(f"Relevance: {score:.2f}\n{text}")
return "\n---\n".join(context_parts)
def _find_similar(self, query_embedding: np.ndarray, top_k: int) -> List[Tuple[str, float]]:
"""Find the most similar conversations to the query."""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Check if the embeddings table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='embeddings'")
if not cursor.fetchone():
print("Embeddings table does not exist yet")
conn.close()
return []
# Get all embeddings from the database
cursor.execute("SELECT id, text, embedding_file, chunk_id FROM embeddings")
results = cursor.fetchall()
if not results:
conn.close()
return []
# Calculate similarities
similarities = []
for db_id, text, embedding_file, chunk_id in results:
# Get embedding from cache
if chunk_id in self.embedding_cache:
embedding = self.embedding_cache[chunk_id]
else:
# This should not happen, but just in case
# We'll reload from the cache file
cache_file = self.cache_dir / embedding_file
if cache_file.exists():
with open(cache_file, "r") as f:
cache_data = json.load(f)
if chunk_id in cache_data:
embedding = np.array(cache_data[chunk_id])
self.embedding_cache[chunk_id] = embedding
else:
continue
else:
continue
# Calculate similarity
similarity = cosine_similarity(
query_embedding.reshape(1, -1),
embedding.reshape(1, -1)
)[0][0]
similarities.append((text, similarity))
conn.close()
# Sort by similarity and return top_k
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
except Exception as e:
print(f"Error finding similar documents: {e}")
return []
def refresh(self):
"""Refresh embeddings from the database."""
self._load_conversations()
# Example usage
if __name__ == "__main__":
# Initialize the RAG system
rag = RAGSystem("conversations.db")
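# Illustrative continuation of the example (values are placeholders):
#   conv_id = rag.add_conversation("What's my dog's name?", "You told me earlier it's Biscuit.")
#   print(rag.get_context("dog's name", top_k=3, threshold=0.6))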
================================================
FILE: requirements.txt
================================================
--extra-index-url=https://download.pytorch.org/whl/cu128
vllm==0.8.0
torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
tokenizers==0.21.0
transformers==4.49.0
huggingface_hub==0.28.1
moshi==0.2.2
sounddevice
torchtune==0.4.0
torchao==0.9.0
bitsandbytes
peft
wandb
silero_vad
python-multipart>=0.0.6
aiofiles>=23.1.0
sentence-transformers>=2.2.2
ctransformers>=0.2.24
sqlalchemy>=2.0.0
pydantic>=2.0.0
fastapi>=0.95.0
uvicorn>=0.22.0
websockets>=11.0.3
jinja2>=3.0.0
speechbrain>=0.5.15
matplotlib
whisper-openai
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
numpy==1.26.0
================================================
FILE: run_csm.py
================================================
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass
# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_b.wav"
)
SPEAKER_PROMPTS = {
"conversational_a": {
"text": (
"like revising for an exam I'd have to try and like keep up the momentum because I'd "
"start really early I'd be like okay I'm gonna start revising now and then like "
"you're revising for ages and then I just like start losing steam I didn't do that "
"for the exam we had recently to be fair that was a more of a last minute scenario "
"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
"sort of start the day with this not like a panic but like a"
),
"audio": prompt_filepath_conversational_a
},
"conversational_b": {
"text": (
"like a super Mario level. Like it's very like high detail. And like, once you get "
"into the park, it just like, everything looks like a computer game and they have all "
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
"will have like a question block. And if you like, you know, punch it, a coin will "
"come out. So like everyone, when they come into the park, they get like this little "
"bracelet and then you can go punching question blocks around."
),
"audio": prompt_filepath_conversational_b
}
}
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = audio_tensor.squeeze(0)
# Resampling is a no-op when the rates already match, so it's safe to always call it
audio_tensor = torchaudio.functional.resample(
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
)
return audio_tensor
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
audio_tensor = load_prompt_audio(audio_path, sample_rate)
return Segment(text=text, speaker=speaker, audio=audio_tensor)
def main():
# Select the best available device, skipping MPS due to float64 limitations
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print(f"Using device: {device}")
# Load model
generator = load_csm_1b(device)
# Prepare prompts
prompt_a = prepare_prompt(
SPEAKER_PROMPTS["conversational_a"]["text"],
0,
SPEAKER_PROMPTS["conversational_a"]["audio"],
generator.sample_rate
)
prompt_b = prepare_prompt(
SPEAKER_PROMPTS["conversational_b"]["text"],
1,
SPEAKER_PROMPTS["conversational_b"]["audio"],
generator.sample_rate
)
# Generate conversation
conversation = [
{"text": "Hey how are you doing?", "speaker_id": 0},
{"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
{"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
{"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
]
# Generate each utterance
generated_segments = []
prompt_segments = [prompt_a, prompt_b]
for utterance in conversation:
print(f"Generating: {utterance['text']}")
audio_tensor = generator.generate(
text=utterance['text'],
speaker=utterance['speaker_id'],
context=prompt_segments + generated_segments,
max_audio_length_ms=10_000,
)
generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
# Concatenate all generations
all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
torchaudio.save(
"full_conversation.wav",
all_audio.unsqueeze(0).cpu(),
generator.sample_rate
)
print("Successfully generated full_conversation.wav")
if __name__ == "__main__":
main()
================================================
FILE: setup.py
================================================
import os
import sys
import subprocess
import logging
import urllib.request
import torch
import time
import shutil
from pathlib import Path
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def check_requirements():
"""Check if all required Python packages are installed"""
logger.info("Checking requirements...")
requirements = [
"torch", "torchaudio", "fastapi", "uvicorn", "websockets", "numpy",
"scikit-learn", "sqlalchemy", "pydantic", "jinja2", "whisper",
"sounddevice", "soundfile", "sentence_transformers", "ctransformers"
]
missing = []
for req in requirements:
try:
__import__(req)
except ImportError:
missing.append(req)
if missing:
logger.warning(f"Missing required packages: {', '.join(missing)}")
logger.info("Installing missing requirements...")
subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
logger.info("Requirements installed successfully")
else:
logger.info("All requirements are satisfied")
def download_vad_model():
"""Download the Silero VAD model using PyTorch Hub instead of direct URL"""
model_path = "silero_vad.jit"
if os.path.exists(model_path):
logger.info(f"Silero VAD model already exists at {model_path}")
return
logger.info("Downloading Silero VAD model using PyTorch Hub...")
try:
# Use torch.hub to download the model instead of direct URL
torch.hub.set_dir("./models")
model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad",
model="silero_vad",
force_reload=True,
onnx=False)
# Save the model
torch.jit.save(model, model_path)
logger.info(f"Model downloaded and saved to {model_path}")
except Exception as e:
logger.error(f"Failed to download Silero VAD model using PyTorch Hub: {e}")
logger.info("Falling back to energy-based VAD - the system will still work but with simpler voice detection")
def download_embedding_models():
"""Download the sentence transformer models for RAG"""
logger.info("Setting up sentence transformer models...")
try:
from sentence_transformers import SentenceTransformer
# Download lightweight model for embeddings
logger.info("Downloading embedding models (this may take a few minutes)...
SYMBOL INDEX (171 symbols across 15 files)
FILE: config.py
class ConfigManager (line 9) | class ConfigManager:
method __init__ (line 14) | def __init__(self, config_path: str = "config/app_config.json"):
method save_config (line 29) | def save_config(self, config_data: Dict[str, Any]) -> bool:
method load_config (line 68) | def load_config(self) -> Optional[Dict[str, Any]]:
function model_to_dict (line 102) | def model_to_dict(model: BaseModel) -> Dict[str, Any]:
FILE: generator.py
class Segment (line 24) | class Segment:
function load_llama3_tokenizer (line 31) | def load_llama3_tokenizer():
class Generator (line 48) | class Generator:
method __init__ (line 49) | def __init__(self, model: Model):
method _tokenize_text_segment (line 74) | def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[tor...
method _tokenize_audio (line 102) | def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, ...
method _tokenize_segment (line 128) | def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, t...
method _decode_frames (line 150) | def _decode_frames(self, frames):
method generate_stream (line 160) | def generate_stream(
method generate (line 346) | def generate(
class AudioStreamWriter (line 457) | class AudioStreamWriter:
method __init__ (line 461) | def __init__(self, filename, sample_rate):
method _writer_worker (line 473) | def _writer_worker(self):
method add_chunk (line 505) | def add_chunk(self, chunk):
method write_file (line 514) | def write_file(self):
function load_csm_1b_local (line 536) | def load_csm_1b_local(model_path: str, device: str = "cuda", audio_num_c...
function warmup_generator (line 597) | def warmup_generator(gen: Generator, warmup_text: str = "Hello, this is ...
function load_csm_1b (line 733) | def load_csm_1b(device: str = "cuda") -> Generator:
function stream_audio_to_wav (line 781) | def stream_audio_to_wav(filename, sample_rate):
function generate_streaming_audio (line 820) | def generate_streaming_audio(
FILE: llm_interface.py
class LLMInterface (line 6) | class LLMInterface:
method __init__ (line 7) | def __init__(self, model_path: str, max_tokens: int = 8192, n_threads:...
method trim_to_last_sentence (line 33) | def trim_to_last_sentence(self, text: str) -> str:
method generate_response (line 53) | def generate_response(self, system_prompt: str, user_message: str, con...
method tokenize (line 91) | def tokenize(self, text: str) -> List[int]:
method get_token_count (line 105) | def get_token_count(self, text: str) -> int:
method batch_generate (line 116) | def batch_generate(self, prompts: List[Dict[str, str]],
FILE: loadandmergecheckpoint.py
function find_latest_checkpoint (line 19) | def find_latest_checkpoint(dir_path):
function load_checkpoint_and_merge (line 31) | def load_checkpoint_and_merge():
FILE: lora.py
class TrainingVisualizer (line 54) | class TrainingVisualizer:
method __init__ (line 55) | def __init__(self, output_dir):
method update (line 84) | def update(self, epoch, step, loss, lr, val_loss=None):
method finalize (line 171) | def finalize(self):
class LoRALinear (line 291) | class LoRALinear(nn.Module):
method __init__ (line 292) | def __init__(self, in_features, out_features, r=32, alpha=64, dropout=...
method forward (line 314) | def forward(self, x: torch.Tensor) -> torch.Tensor:
function replace_linear_with_lora (line 323) | def replace_linear_with_lora(model: nn.Module, r=R, alpha=APLHA, dropout...
function load_llama3_tokenizer (line 364) | def load_llama3_tokenizer():
class AudioTextPair (line 377) | class AudioTextPair:
method load_audio (line 383) | def load_audio(self, sample_rate=24000) -> torch.Tensor:
class CSMDataset (line 398) | class CSMDataset(Dataset):
method __init__ (line 399) | def __init__(self, data_items, text_tokenizer, audio_tokenizer, device):
method __len__ (line 406) | def __len__(self):
method tokenize_text_segment (line 409) | def tokenize_text_segment(self, text: str, speaker: int):
method tokenize_audio (line 417) | def tokenize_audio(self, audio: torch.Tensor):
method __getitem__ (line 438) | def __getitem__(self, idx: int):
function collate_fn (line 468) | def collate_fn(batch):
function transcribe_audio_files (line 505) | def transcribe_audio_files():
function prepare_csm_model_for_training (line 578) | def prepare_csm_model_for_training():
function setup_model_caches (line 647) | def setup_model_caches(model, batch_size):
class BridgingModule (line 657) | class BridgingModule(nn.Module):
method __init__ (line 659) | def __init__(self, in_dim=2048, out_dim=1024):
method forward (line 663) | def forward(self, x):
function compute_loss_for_codebooks_single_pass (line 666) | def compute_loss_for_codebooks_single_pass(
function single_pass_forward (line 715) | def single_pass_forward(model, bridging_module, target_tokens, target_ma...
function calculate_validation_loss (line 766) | def calculate_validation_loss(model, bridging_module, dataset, device, m...
function strip_bias_keys (line 804) | def strip_bias_keys(state_dict: dict) -> dict:
function remove_lora_modules (line 816) | def remove_lora_modules(module: nn.Module) -> nn.Module:
function merge_lora_layer (line 843) | def merge_lora_layer(lora_module: LoRALinear):
function merge_lora_weights (line 856) | def merge_lora_weights(model: nn.Module):
function finetune (line 862) | def finetune(model, dataset):
function forward_and_loss (line 1015) | def forward_and_loss(model, bridging_module, batch, device):
function main (line 1077) | def main():
FILE: main.py
class Conversation (line 63) | class Conversation(Base):
class CompanionConfig (line 75) | class CompanionConfig(BaseModel):
function process_message_queue (line 130) | async def process_message_queue():
function load_reference_segments (line 143) | def load_reference_segments(config_data: CompanionConfig):
function transcribe_audio (line 185) | def transcribe_audio(audio_data, sample_rate):
function initialize_models (line 201) | def initialize_models(config_data: CompanionConfig):
function on_speech_start (line 251) | def on_speech_start():
function on_speech_end (line 263) | def on_speech_end(audio_data, sample_rate):
function process_pending_inputs (line 293) | def process_pending_inputs():
function process_user_input (line 316) | def process_user_input(user_text, session_id="default"):
function model_worker (line 438) | def model_worker(cfg: CompanionConfig):
function start_model_thread (line 481) | def start_model_thread():
function run_audio_generation (line 495) | async def run_audio_generation(text, output_file):
function send_to_all_clients (line 499) | def send_to_all_clients(message: dict):
function save_audio_and_trim (line 522) | def save_audio_and_trim(path, session_id, speaker_id, tensor, sample_rate):
function add_segment (line 555) | def add_segment(text, speaker_id, audio_tensor):
function preprocess_text_for_tts (line 644) | def preprocess_text_for_tts(text):
function audio_generation_thread (line 664) | def audio_generation_thread(text, output_file):
function handle_interrupt (line 887) | def handle_interrupt(websocket):
function websocket_endpoint (line 982) | async def websocket_endpoint(websocket: WebSocket):
function index (line 1120) | async def index(request: Request):
function setup_page (line 1124) | async def setup_page(request: Request):
function chat_page (line 1128) | async def chat_page(request: Request):
function startup_event (line 1132) | async def startup_event():
function shutdown_event (line 1146) | async def shutdown_event():
function get_conversations (line 1152) | async def get_conversations(request: Request):
function update_conversation (line 1161) | def update_conversation(conv_id):
function delete_all_conversations (line 1172) | async def delete_all_conversations():
function delete_conversation (line 1184) | async def delete_conversation(conv_id: int):
function crud_ui (line 1199) | async def crud_ui(request: Request):
FILE: models.py
function llama3_2_1B (line 12) | def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
function llama3_2_100M (line 27) | def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
function _prepare_transformer (line 47) | def _prepare_transformer(model):
function _create_causal_mask (line 53) | def _create_causal_mask(seq_len: int, device: torch.device):
function _index_causal_mask (line 56) | def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
function _multinomial_sample_one_no_sync (line 68) | def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling...
function sample_topk (line 72) | def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
class ModelArgs (line 85) | class ModelArgs:
class Model (line 93) | class Model(
method __init__ (line 100) | def __init__(self, config: ModelArgs):
method setup_caches (line 114) | def setup_caches(self, max_batch_size: int) -> torch.Tensor:
method generate_frame (line 126) | def generate_frame(
method reset_caches (line 180) | def reset_caches(self):
method _embed_audio (line 184) | def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.T...
method _embed_tokens (line 187) | def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
FILE: rag_system.py
class RAGSystem (line 11) | class RAGSystem:
method __init__ (line 12) | def __init__(self, db_path: str, model_name: str = "all-MiniLM-L6-v2",...
method _initialize_db (line 40) | def _initialize_db(self):
method _load_embedding_cache (line 70) | def _load_embedding_cache(self) -> Dict[str, np.ndarray]:
method _save_embedding_to_cache (line 86) | def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray):
method _load_conversations (line 107) | def _load_conversations(self):
method _process_conversation (line 145) | def _process_conversation(self, conv_id: int, user_message: str, ai_me...
method add_conversation (line 179) | def add_conversation(self, user_message: str, ai_message: str) -> int:
method query (line 210) | def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, fl...
method get_context (line 237) | def get_context(self, query_text: str, top_k: int = 3, threshold: floa...
method _find_similar (line 264) | def _find_similar(self, query_embedding: np.ndarray, top_k: int) -> Li...
method refresh (line 323) | def refresh(self):
FILE: run_csm.py
function load_prompt_audio (line 44) | def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch...
function prepare_prompt (line 53) | def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate...
function main (line 57) | def main():
FILE: setup.py
function check_requirements (line 16) | def check_requirements():
function download_vad_model (line 41) | def download_vad_model():
function download_embedding_models (line 66) | def download_embedding_models():
function setup_directories (line 90) | def setup_directories():
function setup_database (line 116) | def setup_database():
function check_cuda (line 144) | def check_cuda():
function main (line 154) | def main():
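setup.py exposes the environment-preparation steps as separate functions. A plausible call order is sketched below; the repository's own main() may sequence or parameterize these steps differently.

```python
# Hedged sketch: the real setup.main() may differ; all of these appear as
# zero-argument helpers in the index above.
import setup

setup.check_requirements()
setup.check_cuda()
setup.setup_directories()
setup.setup_database()
setup.download_vad_model()
setup.download_embedding_models()
```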
FILE: static/app.js
function showLoading (line 128) | function showLoading(show, message) {
function getSelectedMic (line 154) | function getSelectedMic() {
function getSelectedOutput (line 158) | function getSelectedOutput() {
function populateAudioDevices (line 162) | async function populateAudioDevices() {
function visualizeMic (line 193) | function visualizeMic(analyser, canvasId) {
FILE: static/chat.js
constant SESSION_ID (line 17) | const SESSION_ID = "default";
function createPermanentVoiceCircle (line 29) | function createPermanentVoiceCircle() {
function showVoiceCircle (line 56) | function showVoiceCircle() {
function hideVoiceCircle (line 61) | function hideVoiceCircle() {
function showNotification (line 66) | function showNotification(msg, type='info'){
function addMessageToConversation (line 77) | function addMessageToConversation(sender,text){
function connectWebSocket (line 102) | function connectWebSocket() {
function sendTextMessage (line 185) | function sendTextMessage(txt) {
function resetAudioState (line 251) | function resetAudioState() {
function clearAudioPlayback (line 271) | function clearAudioPlayback() {
function requestInterrupt (line 344) | function requestInterrupt() {
function handleWebSocketMessage (line 403) | function handleWebSocketMessage(d) {
function queueAudioForPlayback (line 557) | function queueAudioForPlayback(arr, sr, genId = 0) {
function queueAudioForPlayback (line 578) | function queueAudioForPlayback(arr, sr, genId = 0) {
function processAudioPlaybackQueue (line 615) | function processAudioPlaybackQueue() {
function playAudioChunk (line 684) | async function playAudioChunk(audioArr, sampleRate) {
function startRecording (line 763) | async function startRecording() {
function stopRecording (line 803) | function stopRecording() {
function setupChatUI (line 826) | async function setupChatUI() {
function initAudioLevelsChart (line 988) | function initAudioLevelsChart() {
FILE: static/crud.js
function loadConversations (line 22) | async function loadConversations() {
function renderConversations (line 28) | function renderConversations(list) {
FILE: test.py
function load_audio (line 31) | def load_audio(audio_path):
FILE: vad.py
class VoiceActivityDetector (line 5) | class VoiceActivityDetector:
method __init__ (line 6) | def __init__(
method reset (line 28) | def reset(self) -> None:
method process_audio_chunk (line 41) | def process_audio_chunk(self, audio_chunk: np.ndarray) -> bool:
class AudioStreamProcessor (line 99) | class AudioStreamProcessor:
method __init__ (line 100) | def __init__(
method process_audio (line 133) | def process_audio(self, audio_chunk: np.ndarray):
method reset (line 186) | def reset(self):
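A minimal sketch of feeding audio through the VAD classes indexed above; the constructor arguments are not visible in the index, so default construction, the chunk size, and the dtype are all assumptions.

```python
# Sketch only: assumes VoiceActivityDetector() can be built with defaults and
# accepts float32 chunks; process_audio_chunk returns a bool per the indexed signature.
import numpy as np
from vad import VoiceActivityDetector

vad = VoiceActivityDetector()
chunk = np.zeros(512, dtype=np.float32)      # one hypothetical 512-sample chunk of silence
is_speech = vad.process_audio_chunk(chunk)   # -> bool
vad.reset()
```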