Repository: davidbrowne17/csm-streaming Branch: main Commit: 121552fcd68e Files: 25 Total size: 262.0 KB Directory structure: gitextract_5x1lml4p/ ├── .github/ │ └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── finetuned_model/ │ └── config.json ├── generator.py ├── llm_interface.py ├── loadandmergecheckpoint.py ├── lora.py ├── main.py ├── models.py ├── rag_system.py ├── requirements.txt ├── run_csm.py ├── setup.py ├── static/ │ ├── app.js │ ├── chat.js │ └── crud.js ├── templates/ │ ├── chat.html │ ├── crud.html │ ├── index.html │ └── setup.html ├── test.py └── vad.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: davidbrowne17 tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry polar: # Replace with a single Polar username buy_me_a_coffee: # Replace with a single Buy Me a Coffee username thanks_dev: # Replace with a single thanks.dev username custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] ================================================ FILE: .gitignore ================================================ # Python __pycache__/ *.py[cod] *$py.class *.so .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ 
*.egg-info/ .installed.cfg *.egg # Virtual Environment .env .venv env/ venv/ ENV/ # IDE .idea/ .vscode/ *.swp *.swo # Project specific .python-version *.wav output_*/ basic_audio.wav full_conversation.wav context_audio.wav # Model files *.pt *.ckpt ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # CSM - Optimized Streaming/Finetuning Edition --- CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes. Our fork adds **streaming audio generation**, **real-time playback**, and **performance optimizations** to the original implementation. 
## Requirements * A CUDA-compatible GPU * The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions * Similarly, Python 3.10 is recommended, but newer versions may be fine * For some audio operations, `ffmpeg` may be required * For real-time audio playback: `pip install sounddevice` * Access to the following Hugging Face models: * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) * [CSM-1B](https://huggingface.co/sesame/csm-1b) ### Setup ```bash sudo apt-get update && sudo apt-get install -y libportaudio2 libportaudio-dev git clone git@github.com:davidbrowne17/csm-streaming.git cd csm-streaming python3.10 -m venv .venv source .venv/bin/activate pip install -r requirements.txt # Optional speedup pip install flash-attn # You will need access to CSM-1B and Llama-3.2-1B huggingface-cli login ``` ### Windows Setup The `triton` package cannot be installed in Windows. Instead use `pip install triton-windows`. The realtime demo uses VLLM for inference speed. This is currently not supported for windows but you can try with https://github.com/SystemPanic/vllm-windows until support is added. ## Quickstart Generate a sentence with streaming (chunks are processed and output as they're generated): ```python import time from huggingface_hub import hf_hub_download from generator import Generator, Segment, load_csm_1b, generate_streaming_audio import torchaudio # Load the model generator = load_csm_1b("cuda") # Generate audio with streaming and real-time playback generate_streaming_audio( generator=generator, text="Hello, this is streaming audio generation in action!", speaker=0, context=[], # No context needed for basic generation output_file="streaming_audio.wav", play_audio=True # Enable real-time playback ) ``` ## Finetuning To finetune CSM all you need are some wav audio files with the speaker voice you want to train, just the raw wavs. Place them in a folder called audio_data and run lora.py. 
You can configure the exact training params such as batch size, number of epochs and learning rate by modifying the values at the top of lora.py. You will need a CUDA gpu with at least 12gb of vram depending on your dataset size and training params. You can monitor the training metrics via the dynamic png created in /finetuned_model/ folder. This contains various graphs to help you track the training progress. If you want to try a checkpoint you can use the loadandmergecheckpoint.py (make sure to set the same R and Alpha values as you used in the training) ## Realtime chat demo To use the realtime demo run the setup.py to download the required models, and then run main.py. This will open up a setup page at http://localhost:8000 in which you can set the paths for your chosen LLM and setup the CSM paths and reference audio as well as select your headset and mic. When loaded you will be able to chat in realtime with the AI just like the CSM demo. Our demo includes a dynamic RAG system so the AI can remember previous conversations. The demo by default uses whisper-large-v3-turbo for STT and includes Automatic Voice Detection using Silero VAD. ## Usage Our optimized version offers several ways to use CSM with streaming capabilities: ### 1. Basic Streaming Generation Generate audio with streaming and save to a file: ```python from generator import load_csm_1b, generate_streaming_audio generator = load_csm_1b("cuda") # Generate with streaming (writes to file as it generates) generate_streaming_audio( generator=generator, text="This audio will be generated in chunks for faster response times.", speaker=0, context=[], output_file="streaming_output.wav" ) ``` ### 2. 
Real-time Audio Playback Generate and play audio in real-time as it's being generated: ```python from generator import load_csm_1b, generate_streaming_audio generator = load_csm_1b("cuda") # Generate with streaming and play in real-time generate_streaming_audio( generator=generator, text="You'll hear me speaking as I'm being generated!", speaker=0, context=[], output_file="streaming_output.wav", play_audio=True # Enable real-time playback ) ``` ### 3. Low-level Streaming API For more control, use the low-level streaming API: ```python from generator import load_csm_1b, Segment import torchaudio generator = load_csm_1b("cuda") # Process audio chunks as they're generated for audio_chunk in generator.generate_stream( text="This is generated chunk by chunk.", speaker=0, context=[] ): # Do something with each chunk as it's generated print(f"Received chunk of size: {audio_chunk.shape}") # You could process or play each chunk here # For example, write to a file incrementally # Or send over a network connection ``` ### 4. Generate with Context For best results, provide reference audio context: ```python from generator import load_csm_1b, Segment, generate_streaming_audio import torchaudio generator = load_csm_1b("cuda") # Load reference audio def load_audio(audio_path): audio_tensor, sample_rate = torchaudio.load(audio_path) audio_tensor = torchaudio.functional.resample( audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate ) return audio_tensor # Create context segments segments = [ Segment( text="I knew I could trust you.", speaker=0, audio=load_audio("reference.wav") ) ] # Generate with streaming using the context generate_streaming_audio( generator=generator, text="Me too, this is some cool stuff huh?", speaker=0, context=segments, output_file="contextual_streaming.wav", play_audio=True ) ``` ### 5. 
Regular Generation with Internal Streaming Use the original API with streaming enabled internally: ```python from generator import load_csm_1b, Segment import torchaudio generator = load_csm_1b("cuda") # Regular generation but with internal streaming optimization audio = generator.generate( text="This uses internal streaming for faster processing.", speaker=0, context=[], max_audio_length_ms=10_000, stream=True # Enable internal streaming optimization ) torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) ``` ## Performance Optimizations Our optimized version includes several performance enhancements: - **Streaming Generation**: Processes and outputs audio in chunks instead of waiting for the entire generation achieving a Real-time factor (RTF): 0.28x (target: <1.0) on a 4090 (10 seconds of audio takes 2.8 seconds to generate) - **Frame Batching**: Processes multiple frames at once for better GPU utilization - **Half-precision Inference**: Uses bfloat16/float16 for faster processing - **CUDA Optimizations**: Enables cuDNN benchmarking and Flash Attention where available - **Memory Management**: Clears GPU cache before generation to reduce memory pressure ## FAQ **How much faster is the streaming version?** The perceived response time is significantly faster since you get the first audio chunks in milliseconds instead of waiting for the entire generation to complete. The actual total generation time is also improved by 40-60% depending on your hardware. **Does this model come with any voices?** The model is a base generation model capable of producing a variety of voices but hasn't been fine-tuned on any specific voice. Provide reference audio for best results. **Can I converse with the model?** CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. Using a seperate LLM you can converse with the realtime demo via the web ui. 
**Does it support other languages?** The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well. ## Misuse and abuse ⚠️ This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following: - **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent. - **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls. - **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes. By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology. --- ## Original Authors Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team. ## Streaming, Realtime Demo and Finetuning Implementation David Browne ## Support me Support this project on Ko-fi: https://ko-fi.com/davidbrowne17 ## Transformers streaming If you want to use streaming with the Transformers implementation you can find it here: https://github.com/davidbrowne17/csm-streaming-tf ================================================ FILE: config.py ================================================ import os import json import logging from typing import Dict, Any, Optional from pydantic import BaseModel logger = logging.getLogger(__name__) class ConfigManager: """ Manages configuration persistence for the AI Companion app. Saves and loads configuration to avoid re-entering model paths. """ def __init__(self, config_path: str = "config/app_config.json"): """ Initialize the configuration manager. 
Args: config_path: Path to store the configuration file """ self.config_path = config_path self.config_dir = os.path.dirname(config_path) # Create config directory if it doesn't exist if not os.path.exists(self.config_dir): os.makedirs(self.config_dir, exist_ok=True) logger.info(f"Created configuration directory: {self.config_dir}") def save_config(self, config_data: Dict[str, Any]) -> bool: """ Save configuration data to the config file. Args: config_data: Configuration data to save Returns: bool: True if successful, False otherwise """ try: # Ensure directory exists os.makedirs(self.config_dir, exist_ok=True) print(config_data) # Verify all reference paths are included ref_paths = [ "reference_audio_path", "reference_audio_path2", "reference_audio_path3" ] # Log which references are being saved for path_key in ref_paths: if path_key in config_data and config_data[path_key]: logger.info(f"Saving reference path: {path_key}={config_data[path_key]}") else: logger.info(f"No {path_key} provided in configuration") # Save configuration with open(self.config_path, 'w') as f: json.dump(config_data, f, indent=2) logger.info(f"Configuration saved to {self.config_path}") return True except Exception as e: logger.error(f"Failed to save configuration: {e}") return False def load_config(self) -> Optional[Dict[str, Any]]: """ Load configuration data from the config file. 
Returns: Dict or None: Configuration data if successful, None otherwise """ if not os.path.exists(self.config_path): logger.info(f"Configuration file does not exist at {self.config_path}") return None try: with open(self.config_path, 'r') as f: config_data = json.load(f) # Log which references are being loaded ref_paths = [ "reference_audio_path", "reference_audio_path2", "reference_audio_path3" ] for path_key in ref_paths: if path_key in config_data and config_data[path_key]: logger.info(f"Loaded reference path: {path_key}={config_data[path_key]}") logger.info(f"Configuration loaded from {self.config_path}") return config_data except Exception as e: logger.error(f"Failed to load configuration: {e}") return None # Helper function to convert Pydantic model to dict def model_to_dict(model: BaseModel) -> Dict[str, Any]: """Convert a Pydantic model to a dictionary suitable for JSON serialization""" return json.loads(model.json()) ================================================ FILE: finetuned_model/config.json ================================================ { "audio_num_codebooks": 32, "audio_vocab_size": 2051, "backbone_flavor": "llama-1B", "decoder_flavor": "llama-100M", "text_vocab_size": 128256 } ================================================ FILE: generator.py ================================================ from dataclasses import dataclass import math import os from typing import List, Tuple, Generator as PyGenerator, Optional, Callable import time import queue import threading import platform from typing_extensions import OrderedDict import wave import numpy as np import torch import torchaudio from huggingface_hub import hf_hub_download from models import Model, ModelArgs from moshi.models import loaders from tokenizers.processors import TemplateProcessing from transformers import AutoTokenizer import logging logger = logging.getLogger(__name__) @dataclass class Segment: speaker: int text: str sample_rate = 24_000 audio: torch.Tensor def 
load_llama3_tokenizer(): """ https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992 """ tokenizer_name = "unsloth/Llama-3.2-1B" tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) bos = tokenizer.bos_token eos = tokenizer.eos_token tokenizer._tokenizer.post_processor = TemplateProcessing( single=f"{bos}:0 $A:0 {eos}:0", pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1", special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)], ) return tokenizer class Generator: def __init__(self, model: Model): self._model = model self._model.setup_caches(1) self._text_tokenizer = load_llama3_tokenizer() device = next(model.parameters()).device mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME) mimi = loaders.get_mimi(mimi_weight, device=device) num_codebooks = model.config.audio_num_codebooks mimi.set_num_codebooks(num_codebooks) self._num_codebooks = num_codebooks self._audio_tokenizer = mimi self.sample_rate = mimi.sample_rate self.device = device self._stream_buffer_size = 20 self.max_seq_len = 2048 self._cache = OrderedDict() self._text_token_cache = {} torch.set_num_threads(16) torch.cuda.set_per_process_memory_fraction(0.95) def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]: """ Tokenize text segment with caching optimization for reduced latency. 
""" # Check cache first cache_key = f"{speaker}:{text}" if not hasattr(self, '_text_token_cache'): self._text_token_cache = {} if cache_key in self._text_token_cache: return self._text_token_cache[cache_key] text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}") text_frame = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.long, device=self.device) text_frame_mask = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.bool, device=self.device) text_frame[:, -1] = torch.tensor(text_tokens, device=self.device) text_frame_mask[:, -1] = True frame_tokens = [text_frame] frame_masks = [text_frame_mask] result = (torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)) self._text_token_cache[cache_key] = result return result def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: frame_tokens = [] frame_masks = [] # (K, T) audio = audio.to(self.device) audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0] # Limit to the number of codebooks set in MIMI audio_tokens = audio_tokens[:self._num_codebooks, :] # add EOS frame eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device) audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1) audio_frame = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).long().to(self.device) audio_frame_mask = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).bool().to(self.device) audio_frame[:, :self._num_codebooks] = audio_tokens.transpose(0, 1) audio_frame_mask[:, :self._num_codebooks] = True frame_tokens.append(audio_frame) frame_masks.append(audio_frame_mask) return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0) def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]: text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker) audio_tokens, audio_masks = self._tokenize_audio(segment.audio) total_len = text_tokens.size(0) + audio_tokens.size(0) if total_len 
> self.max_seq_len: overflow = total_len - self.max_seq_len if text_tokens.size(0) > overflow: text_tokens = text_tokens[overflow:] text_masks = text_masks[overflow:] else: audio_overflow = overflow - text_tokens.size(0) text_tokens = text_tokens[0:0] text_masks = text_masks[0:0] audio_tokens = audio_tokens[audio_overflow:] audio_masks = audio_masks[audio_overflow:] return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0) @torch.inference_mode() def _decode_frames(self, frames): if not frames: return torch.tensor([]) # Only use first N codebooks for faster decoding frames_reduced = [frame[:, :self._num_codebooks//2] for frame in frames] audio = self._audio_tokenizer.decode(torch.stack(frames_reduced).permute(1, 2, 0)).squeeze(0).squeeze(0) return audio @torch.inference_mode() def generate_stream( self, text: str, speaker: int, context: List[Segment], max_audio_length_ms: float = 90_000, temperature: float = 0.7, topk: int = 30, on_chunk_generated: Optional[Callable[[torch.Tensor], None]] = None, ): """ Generate audio in a streaming fashion, optimized for lower latency to first chunk. 
""" if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.cuda.empty_cache() torch.cuda.synchronize() self._model.reset_caches() max_generation_len = int(max_audio_length_ms / 80) tokens, tokens_mask = [], [] initial_batch_size = 20 normal_batch_size = 20 initial_buffer_size = 20 normal_buffer_size = 20 batch_size = initial_batch_size buffer_size = initial_buffer_size first_chunk_delivered = False context_start = time.time() if context: for segment in context: segment_tokens, segment_tokens_mask = self._tokenize_segment(segment) tokens.append(segment_tokens) tokens_mask.append(segment_tokens_mask) gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker) tokens.append(gen_segment_tokens) tokens_mask.append(gen_segment_tokens_mask) prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device) prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device) max_seq_len = 2048 if prompt_tokens.size(0) > max_seq_len: prompt_tokens = prompt_tokens[-max_seq_len:] prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:] curr_tokens = prompt_tokens.unsqueeze(0) curr_tokens_mask = prompt_tokens_mask.unsqueeze(0) curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device) expected_frame_count = buffer_size frame_buffer = [] zeros_1_1 = torch.zeros(1, 1).long().to(self.device) zeros_mask_1_1 = torch.zeros(1, 1).bool().to(self.device) def update_tokens(sample): nonlocal curr_tokens, curr_tokens_mask, curr_pos ones = torch.ones_like(sample).bool() curr_tokens = torch.cat([sample, zeros_1_1], dim=1).unsqueeze(1) curr_tokens_mask = torch.cat([ones, zeros_mask_1_1], dim=1).unsqueeze(1) curr_pos = curr_pos[:, -1:] + 1 with self._audio_tokenizer.streaming(1): i = 0 generation_start = time.time() while i < max_generation_len: batch_end = min(i + batch_size, max_generation_len) batch_size_actual = batch_end - i batch_samples = [] for _ in 
range(batch_size_actual): with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) if torch.cuda.is_available() and hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available"): try: torch.cuda.synchronize() # Force sync before checking if sample.numel() == 0 or torch.isnan(sample).any(): print("Warning: Generated empty or NaN sample, stopping generation") break except: print("Error checking tensor, stopping generation") break if torch.all(sample == 0): break batch_samples.append(sample) update_tokens(sample) if not batch_samples: break frame_buffer.extend(batch_samples) i += len(batch_samples) if len(frame_buffer) >= buffer_size: frames_to_process = frame_buffer[:expected_frame_count] # If we don't have enough frames, pad with zeros to match expected shape if len(frames_to_process) < expected_frame_count: # Create padding frames (zeros) padding_frames = [ torch.zeros_like(frames_to_process[0]) for _ in range(expected_frame_count - len(frames_to_process)) ] # Combine actual frames with padding frames_to_process = frames_to_process + padding_frames frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0) audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0) # Keep remaining frames for next iteration frame_buffer = frame_buffer[expected_frame_count:] # Process and yield the chunk cpu_chunk = audio_chunk.cpu() if on_chunk_generated: on_chunk_generated(cpu_chunk) # After first chunk is delivered, switch to normal batch and buffer sizes if not first_chunk_delivered: batch_size = normal_batch_size buffer_size = normal_buffer_size expected_frame_count = buffer_size first_chunk_delivered = True yield cpu_chunk # Occasionally print progress and sync GPU if i >= 100 and (i % 100 == 0): if torch.cuda.is_available(): torch.cuda.synchronize() print(f"Generated {i} frames ({i * 0.08:.2f}s of audio)") # Process any remaining frames 
if frame_buffer: # Pad frame buffer if necessary if len(frame_buffer) < expected_frame_count: padding_frames = [ torch.zeros_like(frame_buffer[0]) for _ in range(expected_frame_count - len(frame_buffer)) ] frames_to_process = frame_buffer + padding_frames else: # Otherwise take as many frames as possible that are a multiple of expected_frame_count frames_multiple = (len(frame_buffer) // expected_frame_count) * expected_frame_count frames_to_process = frame_buffer[:frames_multiple] frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0) audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0) # Determine actual audio length (before padding) actual_frames_percentage = min(len(frame_buffer), expected_frame_count) / expected_frame_count actual_samples = int(audio_chunk.shape[0] * actual_frames_percentage) # Return only the non-padded portion of audio if we added padding if len(frame_buffer) < expected_frame_count: audio_chunk = audio_chunk[:actual_samples] cpu_chunk = audio_chunk.cpu() if on_chunk_generated: on_chunk_generated(cpu_chunk) yield cpu_chunk # Print final performance metrics if torch.cuda.is_available(): torch.cuda.synchronize() total_time = time.time() - generation_start frames_generated = i audio_seconds = frames_generated * 0.08 rtf = total_time / audio_seconds if audio_seconds > 0 else float('inf') print(f"Total time: {total_time:.2f}s") print(f"Generated {frames_generated} frames ({audio_seconds:.2f}s of audio)") print(f"Real-time factor: {rtf:.3f}x (target: <1.0)") @torch.inference_mode() def generate( self, text: str, speaker: int, context: List[Segment], max_audio_length_ms: float = 90_000, temperature: float = 0.8, topk: int = 40, stream: bool = False, output_file: Optional[str] = None, ): """ Generate audio with optional streaming and file output. 
Args: text: Text to generate audio for speaker: Speaker ID context: List of context segments max_audio_length_ms: Maximum audio length in milliseconds temperature: Sampling temperature topk: Top-k sampling parameter stream: Whether to use streaming generation output_file: If provided and stream=True, output will be saved to this file Returns: torch.Tensor: Generated audio tensor """ if stream: if output_file: # Setup streaming to file write_chunk, close_wav = stream_audio_to_wav(output_file, self.sample_rate) # Collect chunks while streaming to file audio_chunks = [] t1 = time.time() for i, chunk in enumerate(self.generate_stream( text, speaker, context, max_audio_length_ms, temperature, topk )): # Write to file write_chunk(chunk) # Store for return value audio_chunks.append(chunk) # Occasionally print progress if i % 5 == 0: print(f"Part {i+1} available after {time.time() - t1:.4f}s") t1 = time.time() # Close file close_wav() print(f"Streaming complete, WAV file saved to {output_file}") else: # Just collect chunks without file output audio_chunks = [] for chunk in self.generate_stream(text, speaker, context, max_audio_length_ms, temperature, topk): audio_chunks.append(chunk) if not audio_chunks: return torch.tensor([]) return torch.cat(audio_chunks) # Non-streaming generation remains unchanged if torch.cuda.is_available(): torch.cuda.empty_cache() self._model.reset_caches() max_generation_len = int(max_audio_length_ms / 80) tokens, tokens_mask = [], [] for segment in context: segment_tokens, segment_tokens_mask = self._tokenize_segment(segment) tokens.append(segment_tokens) tokens_mask.append(segment_tokens_mask) gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker) tokens.append(gen_segment_tokens) tokens_mask.append(gen_segment_tokens_mask) prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device) prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device) max_seq_len = 2048 if prompt_tokens.size(0) > 
class AudioStreamWriter:
    """
    Helper class for writing streaming audio to a file.

    Chunks are queued without blocking the producer; a background thread
    batches them into ``audio_chunks``, and ``write_file`` concatenates
    everything and saves a WAV via torchaudio.
    """

    def __init__(self, filename, sample_rate):
        self.filename = filename
        self.sample_rate = sample_rate
        self.audio_chunks = []
        self.lock = threading.Lock()
        self.queue = queue.Queue()
        self.running = True
        # Background thread drains the queue while generation proceeds.
        self.writer_thread = threading.Thread(target=self._writer_worker, daemon=True)
        self.writer_thread.start()

    def _writer_worker(self):
        """Background thread that handles audio chunk processing"""
        pending = []
        last_flush = time.time()

        while self.running or not self.queue.empty():
            try:
                # Short timeout so the running flag is re-checked regularly.
                pending.append(self.queue.get(timeout=0.2))
                now = time.time()
                # Flush either when enough chunks pile up or periodically.
                if len(pending) >= 10 or (now - last_flush > 2.0 and pending):
                    with self.lock:
                        self.audio_chunks.extend(pending)
                    pending = []
                    last_flush = now
            except queue.Empty:
                # Queue drained — move anything buffered into the main list.
                if pending:
                    with self.lock:
                        self.audio_chunks.extend(pending)
                    pending = []
                    last_flush = time.time()

        # Final flush of any remaining chunks after shutdown.
        if pending:
            with self.lock:
                self.audio_chunks.extend(pending)

    def add_chunk(self, chunk):
        """Add an audio chunk to the buffer queue without blocking"""
        try:
            self.queue.put(chunk, timeout=0.1)
        except queue.Full:
            # Fall back to a direct append so no audio is dropped.
            with self.lock:
                self.audio_chunks.append(chunk)

    def write_file(self):
        """Write all collected audio chunks to file and clean up"""
        self.running = False
        # Give the worker a bounded amount of time to finish draining.
        self.writer_thread.join(timeout=3.0)
        with self.lock:
            if not self.audio_chunks:
                return
            audio = torch.cat(self.audio_chunks)
            torchaudio.save(self.filename, audio.unsqueeze(0).cpu(), self.sample_rate)
def load_csm_1b_local(model_path: str, device: str = "cuda", audio_num_codebooks: int = 32):
    """
    Load the CSM-1B model from a local checkpoint with extreme optimizations and warmup.

    Args:
        model_path: Path (or hub id) passed to ``Model.from_pretrained``.
        device: Target device for the compiled model.
        audio_num_codebooks: Kept for API compatibility; the checkpoint's own
            config determines the actual codebook count.

    Returns:
        A warmed-up ``Generator`` wrapping the compiled model.
    """
    from functools import lru_cache
    from generator import Generator

    # Enable all CUDA optimizations
    torch.backends.cuda.matmul.allow_tf32 = True
    if hasattr(torch.backends.cuda, 'enable_flash_sdp'):
        torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print(f"Loading CSM-1B model from local checkpoint '{model_path}' with extreme optimizations...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    # NOTE: a hard-coded ModelArgs config was previously built here but never
    # used — Model.from_pretrained reads the checkpoint's own config.
    model = Model.from_pretrained(model_path)
    model.eval()
    # Prefer bf16 where supported; fall back to fp16.
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True, backend='inductor')
    model.decoder = torch.compile(model.decoder, mode='reduce-overhead', fullgraph=True, backend='inductor')
    model.to(device=device, dtype=dtype)
    print("Model compilation complete. Creating generator...")
    generator = Generator(model)
    generator._stream_buffer_size = 20
    # Setup tokenization caching
    generator._tokenization_cache = {}
    original_tokenize_text = generator._tokenize_text_segment

    @lru_cache(maxsize=2048)
    def cached_tokenize_text_segment(text_str, speaker_int):
        return original_tokenize_text(text_str, speaker_int)

    generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker)
    # Perform warmup so the first real request is fast.
    warmup_generator(generator)
    return generator
Creating generator...") generator = Generator(model) generator._stream_buffer_size = 20 # Setup tokenization caching generator._tokenization_cache = {} original_tokenize_text = generator._tokenize_text_segment @lru_cache(maxsize=2048) def cached_tokenize_text_segment(text_str, speaker_int): return original_tokenize_text(text_str, speaker_int) generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker) # Perform warmup warmup_generator(generator) return generator def warmup_generator(gen: Generator, warmup_text: str = "Hello, this is a comprehensive warmup text that will exercise the model's generation capabilities.", speaker_id: int = 0): """ Perform an extremely aggressive warmup to drastically reduce first-generation latency. """ print("Starting maximum-intensity warmup sequence...") # Directly access and optimize the model's internal state if hasattr(gen._model, 'backbone') and hasattr(gen._model.backbone, 'positional_embedding'): # Force calculation of position embeddings to ensure they're cached with torch.inference_mode(): positions = torch.arange(0, 2048).to(gen.device) _ = gen._model.backbone.positional_embedding(positions) # Pre-allocate CUDA memory to prevent fragmentation during generation if torch.cuda.is_available(): print("Optimizing GPU memory allocation...") # Try to reserve a large chunk of memory try: import math reserved_memory = [] # Reserve multiple blocks of different sizes for size_mb in [128, 256, 512, 256, 128, 64]: size = int(size_mb * 1024 * 1024 / 4) # Convert MB to float32 elements tensor_size = int(math.sqrt(size)) tensor = torch.ones((tensor_size, tensor_size), device=gen.device, dtype=torch.float32) tensor = tensor * 1.0 # Force allocation reserved_memory.append(tensor) torch.cuda.synchronize() # Now free the memory for tensor in reserved_memory: del tensor reserved_memory = [] torch.cuda.empty_cache() torch.cuda.synchronize() except Exception as e: print(f"Memory pre-allocation: {e}") # 
Create multiple dummy audio segments with varying characteristics print("Creating diverse audio contexts...") audio_segments = [] # Create 3 different audio patterns for i in range(3): length = 24000 * (i + 1) # 1s, 2s, 3s audio = torch.zeros(length).to(gen.device) # Add different patterns to each segment if i == 0: # Sine wave pattern import math t = torch.linspace(0, 8 * math.pi, length).to(gen.device) audio = torch.sin(t) * 0.1 elif i == 1: # Random noise pattern audio = torch.randn(length).to(gen.device) * 0.05 else: # Pulse pattern audio[::800] = 0.2 audio[::801] = -0.2 segment = Segment( speaker=speaker_id, text=f"Warmup segment {i+1} with {length/24000:.1f}s of audio.", audio=audio ) audio_segments.append(segment) # Force compilation of critical model components print("Forcing compilation of critical components...") # Directly exercise the audio tokenizer with real data with torch.inference_mode(): for segment in audio_segments: # Force tokenization of both text and audio gen._tokenize_segment(segment) # Exercise the model's generation capabilities directly with torch.inference_mode(): # Generate some sample frames to ensure model is compiled dummy_tokens = torch.ones(1, 10, gen._num_codebooks+1).long().to(gen.device) dummy_mask = torch.ones(1, 10, gen._num_codebooks+1).bool().to(gen.device) dummy_pos = torch.arange(0, 10).unsqueeze(0).to(gen.device) # Generate multiple frames with different parameters for temp in [0.6, 0.7, 0.8]: for topk in [20, 30, 40]: _ = gen._model.generate_frame(dummy_tokens, dummy_mask, dummy_pos, temp, topk) gen._text_token_cache.clear() print("Running final generation with exact same setup as a real request...") final_text = "This is the final warmup that exactly matches a real generation request." 
# First tokenize the text - to fill the cache gen._tokenize_text_segment(final_text, speaker_id) try: # Now run a complete generation with a single context segment generate_streaming_audio( generator=gen, text=final_text, speaker=speaker_id, context=[audio_segments[0]], # Just one context segment output_file="warmup_final.wav", max_audio_length_ms=6000, temperature=0.7, topk=30, play_audio=False ) except Exception as e: print(f"Final warmup run exception (ignorable): {e}") # Force final synchronization and memory optimization if torch.cuda.is_available(): print("Final GPU optimization...") torch.cuda.synchronize() torch.cuda.empty_cache() try: # Allocate a large tensor to force compaction large_tensor = torch.empty(int(1e9//4), dtype=torch.float, device=gen.device) # Immediately delete it del large_tensor except RuntimeError: # Expected if there's not enough memory pass # Final cleanup torch.cuda.empty_cache() torch.cuda.synchronize() print("Maximum-intensity warmup complete. First generation should now be MUCH faster.") def load_csm_1b(device: str = "cuda") -> Generator: """ Load the CSM-1B model with extreme optimizations for real-time performance. """ # Enable all CUDA optimizations torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.enable_flash_sdp(True) torch.backends.cudnn.benchmark = True torch.backends.cudnn.enabled = True print("Loading CSM-1B model with extreme optimizations for real-time performance...") if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() model = Model.from_pretrained("sesame/csm-1b") dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 model.backbone = torch.compile(model.backbone,mode='reduce-overhead', fullgraph=True, backend='inductor') model.decoder = torch.compile(model.decoder,mode='reduce-overhead', fullgraph=True, backend='inductor') model.to(device=device, dtype=dtype) print("Model compilation complete. 
Creating generator...") generator = Generator(model) generator._stream_buffer_size = 20 generator._tokenization_cache = {} from functools import lru_cache # Patch the tokenize method with caching original_tokenize_text = generator._tokenize_text_segment @lru_cache(maxsize=2048) def cached_tokenize_text_segment(text_str, speaker_int): return original_tokenize_text(text_str, speaker_int) generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker) warmup_generator(generator) return generator def stream_audio_to_wav(filename, sample_rate): """ Initialize a WAV writer for streaming audio chunks. Args: filename: Output WAV file path sample_rate: Audio sample rate in Hz Returns: tuple: (write_chunk, close) functions for writing audio data and closing the file """ # Create a WAV file with the proper header wav_file = wave.open(filename, 'wb') wav_file.setnchannels(1) # Mono wav_file.setsampwidth(2) # 16-bit wav_file.setframerate(sample_rate) def write_chunk(audio_chunk): # Convert tensor to numpy and then to int16 PCM format if isinstance(audio_chunk, torch.Tensor): # Ensure it's on CPU and detached before converting to numpy audio_np = audio_chunk.detach().cpu().numpy() else: audio_np = audio_chunk # Normalize if needed (assuming audio is in [-1, 1] range) if audio_np.max() <= 1.0 and audio_np.min() >= -1.0: audio_int = (audio_np * 32767).astype(np.int16) else: audio_int = audio_np.astype(np.int16) # Write to WAV file wav_file.writeframes(audio_int.tobytes()) def close(): wav_file.close() return write_chunk, close def generate_streaming_audio( generator: Generator, text: str, speaker: int, context: List[Segment], output_file: str, max_audio_length_ms: float = 90_000, temperature: float = 1.0, topk: int = 50, play_audio: bool = False, ): """ Generate audio with streaming output and comprehensive timing metrics. Optimized for reduced first-chunk latency. 
""" # Initialize the streaming WAV writer write_chunk, close_wav = stream_audio_to_wav(output_file, generator.sample_rate) # Set up audio playback if requested audio_queue = queue.Queue(maxsize=100) if play_audio else None stop_event = threading.Event() if play_audio: try: import sounddevice as sd # Get available sample rates for default output device to check compatibility device_info = sd.query_devices(kind='output') supported_rate = device_info.get('default_samplerate', 44100) need_resampling = abs(supported_rate - generator.sample_rate) > 100 if need_resampling: try: # Use resampling if sample rate doesn't match import librosa print(f"Resampling from {generator.sample_rate}Hz to {int(supported_rate)}Hz for playback") def audio_playback_worker(): while not stop_event.is_set() or not audio_queue.empty(): try: chunk = audio_queue.get(timeout=0.5) if isinstance(chunk, torch.Tensor) and chunk.numel() == 0: audio_queue.task_done() continue audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk # Skip very short chunks (likely noise) if len(audio_np) < 100: audio_queue.task_done() continue # Resample to device's supported rate resampled = librosa.resample( audio_np, orig_sr=generator.sample_rate, target_sr=int(supported_rate) ) sd.play(resampled, supported_rate, blocking=True) # Add a small delay to ensure audio finishes playing time.sleep(0.05) audio_queue.task_done() except queue.Empty: # If queue empty but not stopping, keep trying if not stop_event.is_set(): continue else: break except Exception as e: print(f"Playback error: {e}") audio_queue.task_done() except ImportError: print("Librosa not found. 
Using direct playback which may cause sample rate warnings.") need_resampling = False if not need_resampling: def audio_playback_worker(): while not stop_event.is_set() or not audio_queue.empty(): try: chunk = audio_queue.get(timeout=0.5) if isinstance(chunk, torch.Tensor) and chunk.numel() == 0: audio_queue.task_done() continue audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk # Skip very short chunks (likely noise) if len(audio_np) < 100: audio_queue.task_done() continue sd.play(audio_np, generator.sample_rate, blocking=True) # Add a small delay to ensure audio finishes playing time.sleep(0.05) audio_queue.task_done() except queue.Empty: # If queue empty but not stopping, keep trying if not stop_event.is_set(): continue else: break except Exception as e: print(f"Playback error: {e}") audio_queue.task_done() # Start playback thread playback_thread = threading.Thread(target=audio_playback_worker, daemon=False) playback_thread.start() except ImportError: print("sounddevice library not found. 
Install with 'pip install sounddevice' for real-time playback.") play_audio = False # Timing metrics chunk_times = [] latency_to_first_chunk = None total_audio_duration = 0 chunk_count = 0 # Function to handle each generated chunk def on_chunk_generated(chunk): nonlocal chunk_count, latency_to_first_chunk, total_audio_duration current_time = time.time() if chunk_count == 0: latency_to_first_chunk = current_time - start_time print(f"First chunk latency: {latency_to_first_chunk*1000:.1f}ms") # Save chunk to WAV file write_chunk(chunk) # Update metrics chunk_count += 1 chunk_duration = len(chunk) / generator.sample_rate total_audio_duration += chunk_duration chunk_times.append(current_time) # Send to audio player if enabled if play_audio and audio_queue is not None: try: audio_queue.put(chunk, timeout=1.0) except queue.Full: pass # Skip if queue is full to avoid blocking if torch.cuda.is_available(): print("Preparing GPU for low-latency generation...") torch.cuda.empty_cache() torch.cuda.synchronize() # Pre-allocate some GPU memory to avoid allocation during generation dummy_tensors = [] for i in range(5): dummy = torch.ones((100, 100), device=generator.device) dummy = dummy + 1.0 # Force computation dummy_tensors.append(dummy) # Keep reference to prevent deallocation torch.cuda.synchronize() # Set process priority to improve performance - use higher priority try: import psutil process = psutil.Process() if platform.system() == 'Windows': process.nice(psutil.HIGH_PRIORITY_CLASS) else: process.nice(-1) except (ImportError, PermissionError, psutil.AccessDenied): pass print(f"Starting audio generation for: '{text[:50]}{'...' 
if len(text) > 50 else ''}'") start_time = time.time() # Generate audio in chunks, catching possible errors frame_count = 0 audio_chunks = [] # Store all chunks for possible use at the end try: for audio_chunk in generator.generate_stream( text=text, speaker=speaker, context=context, max_audio_length_ms=max_audio_length_ms, temperature=temperature, topk=topk, on_chunk_generated=on_chunk_generated ): frame_count += 1 audio_chunks.append(audio_chunk) # Store the chunk # Print timing info less frequently to reduce overhead if frame_count % 10 == 0: current_time = time.time() elapsed = current_time - start_time if total_audio_duration > 0: rtf = elapsed / total_audio_duration remaining_time = (max_audio_length_ms/1000 - total_audio_duration) * rtf print(f"Chunk {chunk_count}: {total_audio_duration:.1f}s audio in {elapsed:.1f}s " f"(RTF: {rtf:.2f}x, Est. remaining: {remaining_time:.1f}s)") except Exception as e: print(f"Error during audio generation: {e}") import traceback traceback.print_exc() # Release dummy tensors to free memory if 'dummy_tensors' in locals(): del dummy_tensors # Ensure all chunks are properly processed if play_audio and audio_queue is not None: print("Waiting for playback queue to finish...") try: timeout_start = time.time() while not audio_queue.empty() and time.time() - timeout_start < 5.0: time.sleep(0.1) except: pass # Add a small delay to ensure everything is processed time.sleep(0.5) # Signal audio worker that generation is complete stop_event.set() # Close WAV file close_wav() # Wait for audio playback to complete if enabled if play_audio and 'playback_thread' in locals(): print("Waiting for audio playback to complete...") # First, ensure the queue is empty try: timeout_start = time.time() while not audio_queue.empty() and time.time() - timeout_start < 5.0: time.sleep(0.1) except: pass # Set a flag to indicate complete audio playback is needed if hasattr(sd, 'wait'): try: sd.wait() except: pass # Join the playback thread with timeout 
class LLMInterface:
    # Sequences that terminate generation; shared by single and batch paths.
    _STOP_SEQUENCES = ["", "<|endoftext|>", "<>", "<>", "<>", "<>", "<>",
                       "<|end_header_id|>", "<>", "<|eot_id|>", "<|im_end|>",
                       "user:", "User:", "user :", "User :"]

    def __init__(self, model_path: str, max_tokens: int = 8192, n_threads: int = 8, gpu_layers: int = -1):
        """Initialize the LLM interface using VLLM with a given model.

        Args:
            model_path (str): Path to the model or HuggingFace model name
            max_tokens (int, optional): Maximum context length. Defaults to 8192.
            n_threads (int, optional): Number of CPU threads. Defaults to 8.
            gpu_layers (int, optional): Not used in VLLM, maintained for API compatibility.
        """
        self.llm = LLM(
            model=model_path,
            tensor_parallel_size=1,  # Adjust based on number of GPUs available
            gpu_memory_utilization=0.6,
            max_model_len=max_tokens,
            swap_space=0,
            trust_remote_code=True,
            dtype=torch.float16,
        )
        # Store configuration for reference
        self.config = {"model_path": model_path, "max_tokens": max_tokens}

    def trim_to_last_sentence(self, text: str) -> str:
        """
        Return *text* truncated at the final full sentence boundary.

        A boundary is any '.', '!' or '?' optionally followed by a closing
        quote/bracket and trailing whitespace at end-of-string. If no sentence
        terminator exists, the original text is returned.
        """
        m = re.match(r"^(.*?[.!?][\"')\]]?)\s*$", text, re.DOTALL)
        if m:
            return m.group(1).strip()
        # Fallback: cut at the right-most terminator anywhere in the text.
        cut = max(text.rfind(ch) for ch in ".!?")
        if cut >= 0:
            return text[:cut + 1].strip()
        return text.strip()

    def generate_response(self, system_prompt: str, user_message: str, conversation_history: str = "") -> str:
        """Generate a response from the LLM using chat-style prompt formatting.

        Args:
            system_prompt (str): The system prompt/instructions
            user_message (str): The user's input message
            conversation_history (str, optional): Any prior conversation context. Defaults to "".

        Returns:
            str: The generated response
        """
        prompt = f"""<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>
{conversation_history}
<|start_header_id|>user<|end_header_id|>\n{user_message}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>\n"""
        sampling_params = SamplingParams(
            temperature=1.0,
            top_p=0.95,
            max_tokens=100,
            repetition_penalty=1.2,
            top_k=200,
            stop=list(self._STOP_SEQUENCES)
        )
        outputs = self.llm.generate(prompt, sampling_params)
        if outputs and len(outputs) > 0:
            return self.trim_to_last_sentence(outputs[0].outputs[0].text)
        return ""

    def tokenize(self, text: str) -> List[int]:
        """Tokenize text using VLLM's tokenizer.

        Args:
            text (str): Text to tokenize

        Returns:
            List[int]: List of token IDs
        """
        # VLLM exposes its tokenizer through the LLM instance.
        return self.llm.get_tokenizer().encode(text)

    def get_token_count(self, text: str) -> int:
        """Return token count of the input text.

        Args:
            text (str): Text to count tokens for

        Returns:
            int: Number of tokens
        """
        return len(self.tokenize(text))

    def batch_generate(self, prompts: List[Dict[str, str]], max_tokens: int = 512, temperature: float = 0.7) -> List[str]:
        """Generate responses for multiple prompts in a batch.

        Args:
            prompts (List[Dict[str, str]]): List of prompt dictionaries, each with
                'system', 'user' and optional 'history' keys
            max_tokens (int, optional): Maximum tokens to generate per response
            temperature (float, optional): Temperature for sampling

        Returns:
            List[str]: Generated responses
        """
        formatted_prompts = []
        for p in prompts:
            system = p.get("system", "")
            user = p.get("user", "")
            history = p.get("history", "")
            formatted_prompt = f"""<|start_header_id|>system<|end_header_id|>\n{system}<|eot_id|>
{history}
<|start_header_id|>user<|end_header_id|>\n{user}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>\n"""
            formatted_prompts.append(formatted_prompt)
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=0.95,
            max_tokens=max_tokens,
            repetition_penalty=1.2,
            top_k=400,
            stop=list(self._STOP_SEQUENCES)
        )
        outputs = self.llm.generate(formatted_prompts, sampling_params)
        results = []
        for output in outputs:
            results.append(output.outputs[0].text.strip() if output.outputs else "")
        return results
def find_latest_checkpoint(dir_path):
    """Return the path of the highest-numbered ``checkpoint-epoch-N`` subdir.

    Args:
        dir_path: Directory containing checkpoint subdirectories.

    Raises:
        FileNotFoundError: If no checkpoint directory is found.
    """
    checkpoints = []
    for d in os.listdir(dir_path):
        full_path = os.path.join(dir_path, d)
        # Match on the regex result rather than a substring check: the old
        # `"checkpoint-epoch" in d` filter let non-numeric dirs (e.g.
        # "checkpoint-epoch-final") through and then crashed on .group(1).
        m = re.search(r"checkpoint-epoch-(\d+)", d)
        if m and os.path.isdir(full_path):
            checkpoints.append((int(m.group(1)), full_path))
    if not checkpoints:
        raise FileNotFoundError("No checkpoints found.")
    latest_epoch, latest_path = max(checkpoints, key=lambda x: x[0])
    print(f"Latest checkpoint: epoch {latest_epoch} -> {latest_path}")
    return latest_path
def load_checkpoint_and_merge():
    """Load the newest LoRA checkpoint, merge it into the base model,
    strip LoRA/bias artifacts, and save the clean merged weights."""
    print("Loading base model...")
    model = Model.from_pretrained(MODEL_NAME).to(DEVICE)

    print("Applying LoRA structure to the model...")
    attention_projections = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
    model = replace_linear_with_lora(
        model, r=R, alpha=APLHA, dropout=0.0,
        target_linear_names=attention_projections,
    )

    checkpoint_path = find_latest_checkpoint(OUTPUT_DIR)
    print(f"Loading state dictionary from safetensors file...")
    state_dict = load_file(os.path.join(checkpoint_path, "model.safetensors"), device=DEVICE)

    print("Loading weights into the model...")
    # strict=False: the checkpoint only carries the LoRA-augmented subset.
    model.load_state_dict(state_dict, strict=False)

    print("Merging LoRA weights into base model...")
    merge_lora_weights(model)

    print("Replacing LoRALinear modules with standard nn.Linear...")
    model = remove_lora_modules(model)

    print("Stripping bias keys for final clean model...")
    merged_state = strip_bias_keys(model.state_dict())
    final_path = os.path.join(OUTPUT_DIR, "model.safetensors")
    save_file(merged_state, final_path)
    print(f"Merged and cleaned model saved to: {final_path}")


if __name__ == "__main__":
    load_checkpoint_and_merge()
matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import torch.nn as nn

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(), logging.FileHandler("finetune.log")]
)
logger = logging.getLogger(__name__)

# Training hyper-parameters and paths.
AUDIO_DIR = "audio_data"
OUTPUT_DIR = "finetuned_model"
NUM_EPOCHS = 5
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 1e-6
MAX_GRAD_NORM = 0.1
NUM_CYCLES = 1.0
USE_WANDB = False
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MIXED_PRECISION = True
WARMUP_STEPS = 50
SPEAKER_ID = 0
MODEL_NAME = "sesame/csm-1b"
TRANSCRIPTION_MODEL = "openai/whisper-large-v3-turbo"
MAX_AUDIO_FILES = 0  # 0 means "no limit" (see transcribe_audio_files)
R=32
APLHA=32  # NOTE(review): typo for "ALPHA"; referenced under this name elsewhere

class TrainingVisualizer:
    """Accumulates loss / learning-rate metrics during finetuning and renders
    progress plots to PNG files under output_dir (matplotlib Agg backend)."""
    def __init__(self, output_dir):
        self.output_dir = output_dir
        self.epochs = []
        self.losses = []
        self.val_losses = []  # Added validation losses
        self.learning_rates = []
        self.steps = []
        self.fig, self.axes = plt.subplots(3, 1, figsize=(10, 15))
        self.fig.suptitle('CSM Finetuning Progress', fontsize=16)
        # Setup training loss plot
        self.axes[0].set_title('Training Loss')
        self.axes[0].set_xlabel('Epoch')
        self.axes[0].set_ylabel('Loss')
        self.axes[0].grid(True, linestyle='--', alpha=0.7)
        # Setup validation loss plot
        self.axes[1].set_title('Training vs Validation Loss')
        self.axes[1].set_xlabel('Epoch')
        self.axes[1].set_ylabel('Loss')
        self.axes[1].grid(True, linestyle='--', alpha=0.7)
        # Setup learning rate plot
        self.axes[2].set_title('Learning Rate')
        self.axes[2].set_xlabel('Epoch')
        self.axes[2].set_ylabel('Learning Rate')
        self.axes[2].grid(True, linestyle='--', alpha=0.7)

    def update(self, epoch, step, loss, lr, val_loss=None):
        """Update the metrics and redraw the plot"""
        self.epochs.append(epoch)
        self.steps.append(step)
        self.losses.append(loss)
        self.learning_rates.append(lr)
        # Add validation loss if provided, otherwise use None
        if val_loss is not None:
            self.val_losses.append(val_loss)
        elif len(self.val_losses) > 0:
            # If we have validation losses but none provided this time, use the last one
            self.val_losses.append(self.val_losses[-1])
        else:
            # If we've never had validation losses, use None
            self.val_losses.append(None)
        # Update training loss plot
        self.axes[0].clear()
        self.axes[0].plot(self.epochs, self.losses, 'b-')
        self.axes[0].set_title('Training Loss')
        self.axes[0].set_xlabel('Epoch')
        self.axes[0].set_ylabel('Loss')
        self.axes[0].grid(True, linestyle='--', alpha=0.7)
        # Update validation loss plot
        self.axes[1].clear()
        self.axes[1].plot(self.epochs, self.losses, 'b-', label='Training')
        # If we have validation losses, plot them
        if any(x is not None for x in self.val_losses):
            # Filter out None values
            val_epochs = [e for e, v in zip(self.epochs, self.val_losses) if v is not None]
            val_loss_values = [v for v in self.val_losses if v is not None]
            if val_epochs:
                self.axes[1].plot(val_epochs, val_loss_values, 'r-', label='Validation')
                self.axes[1].legend()
        self.axes[1].set_title('Training vs Validation Loss')
        self.axes[1].set_xlabel('Epoch')
        self.axes[1].set_ylabel('Loss')
        self.axes[1].grid(True, linestyle='--', alpha=0.7)
        # Update learning rate plot
        self.axes[2].clear()
        self.axes[2].plot(self.epochs, self.learning_rates, 'g-')
        self.axes[2].set_title('Learning Rate')
        self.axes[2].set_xlabel('Epoch')
        self.axes[2].set_ylabel('Learning Rate')
        self.axes[2].grid(True, linestyle='--', alpha=0.7)
        # Calculate convergence metrics
        min_loss = min(self.losses)
        min_loss_epoch = self.epochs[self.losses.index(min_loss)]
        # Check for potential convergence stall
        recent_window = 10  # Look at last 10 steps
        if len(self.losses) > recent_window:
            recent_losses = self.losses[-recent_window:]
            loss_std = np.std(recent_losses)
            loss_change = (recent_losses[0] - recent_losses[-1]) / recent_losses[0] if recent_losses[0] != 0 else 0
            convergence_status = ""
            if loss_std < 0.001 and loss_change < 0.01:
                convergence_status = "STALLED: Loss not improving significantly"
            elif min_loss == self.losses[-1]:
                convergence_status = "IMPROVING: New best loss!"
            elif self.losses[-1] < self.losses[-2]:
                convergence_status = "IMPROVING: Loss decreasing"
            else:
                convergence_status = "FLUCTUATING: Loss increased"
            # Add convergence status to title
            self.fig.suptitle(f'CSM Finetuning Progress - {convergence_status}\n' +
                              f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\n' +
                              f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)
        else:
            self.fig.suptitle(f'CSM Finetuning Progress\n' +
                              f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\n' +
                              f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)
        plt.tight_layout(rect=[0, 0.03, 1, 0.92])  # Adjust for the larger title
        # Save the figure
        plot_path = os.path.join(self.output_dir, 'training_progress.png')
        self.fig.savefig(plot_path)

    def finalize(self):
        """Create a final, more detailed visualization when training completes"""
        # Create a new figure for the final plot
        final_fig = plt.figure(figsize=(12, 16))
        gs = plt.GridSpec(4, 2, figure=final_fig)
        # Plot 1: Loss vs Steps
        ax1 = final_fig.add_subplot(gs[0, :])
        ax1.plot(self.steps, self.losses, 'b-', linewidth=2)
        ax1.set_title('Training Loss vs Steps', fontsize=14)
        ax1.set_xlabel('Steps')
        ax1.set_ylabel('Loss')
        ax1.grid(True, linestyle='--', alpha=0.7)
        # Plot 2: Loss vs Epochs
        ax2 = final_fig.add_subplot(gs[1, 0])
        ax2.plot(self.epochs, self.losses, 'r-', linewidth=2)
        ax2.set_title('Training Loss vs Epochs', fontsize=14)
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Loss')
        ax2.grid(True, linestyle='--', alpha=0.7)
        # Plot 3: Learning Rate vs Steps
        ax3 = final_fig.add_subplot(gs[1, 1])
        ax3.plot(self.steps, self.learning_rates, 'g-', linewidth=2)
        ax3.set_title('Learning Rate Schedule', fontsize=14)
        ax3.set_xlabel('Steps')
        ax3.set_ylabel('Learning Rate')
        ax3.grid(True, linestyle='--', alpha=0.7)
        # Plot 4: Training vs Validation Loss
        ax4 = final_fig.add_subplot(gs[2, :])
        ax4.plot(self.epochs, self.losses, 'b-', linewidth=2, label='Training')
        if any(x is not None for x in self.val_losses):
            # Filter out None values
            val_epochs = [e for e, v in zip(self.epochs, self.val_losses) if v is not None]
            val_loss_values = [v for v in self.val_losses if v is not None]
            if val_epochs:
                ax4.plot(val_epochs, val_loss_values, 'r-', linewidth=2, label='Validation')
                ax4.legend()
        ax4.set_title('Training vs Validation Loss', fontsize=14)
        ax4.set_xlabel('Epochs')
        ax4.set_ylabel('Loss')
        ax4.grid(True, linestyle='--', alpha=0.7)
        # Plot 5: Combined plot with two y-axes
        ax5 = final_fig.add_subplot(gs[3, :])
        color1, color2 = 'blue', 'green'
        # Plot loss on left axis
        line1 = ax5.plot(self.epochs, self.losses, color=color1, linewidth=2.5, label='Loss')
        ax5.set_xlabel('Epochs')
        ax5.set_ylabel('Loss', color=color1)
        ax5.tick_params(axis='y', labelcolor=color1)
        # Plot learning rate on right axis
        ax6 = ax5.twinx()
        line2 = ax6.plot(self.epochs, self.learning_rates, color=color2, linewidth=2.5, label='Learning Rate')
        ax6.set_ylabel('Learning Rate', color=color2)
        ax6.tick_params(axis='y', labelcolor=color2)
        # Combine legends
        lines = line1 + line2
        labels = [l.get_label() for l in lines]
        ax5.legend(lines, labels, loc='upper right')
        ax5.set_title('Loss and Learning Rate vs Epochs', fontsize=14)
        ax5.grid(True, linestyle='--', alpha=0.7)
        # Add training summary
        if self.epochs:
            epoch_count = max(self.epochs)
            step_count = max(self.steps)
            min_loss = min(self.losses)
            min_loss_epoch = self.epochs[self.losses.index(min_loss)]
            min_loss_step = self.steps[self.losses.index(min_loss)]
            # Calculate convergence indicators
            recent_epochs = min(10, len(self.losses))
            recent_losses = self.losses[-recent_epochs:]
            loss_change_pct = ((recent_losses[0] - recent_losses[-1]) / recent_losses[0]) * 100 if recent_losses[0] != 0 else 0
            summary_text = (
                f"Training Summary\n"
                f"Total Epochs: {epoch_count:.2f}\n"
                f"Total Steps: {step_count}\n"
                f"Min Loss: {min_loss:.6f} (Epoch {min_loss_epoch:.2f}, Step {min_loss_step})\n"
                f"Recent {recent_epochs} epochs loss change: {loss_change_pct:.2f}%\n"
            )
            if len(self.losses) > 20:
                # Add convergence assessment
                last_20_losses = self.losses[-20:]
                std_last_20 = np.std(last_20_losses)
                converged = std_last_20 < 0.01 and loss_change_pct < 1.0
                summary_text += f"Convergence status: {'CONVERGED' if converged else 'NOT CONVERGED'}\n"
                if converged:
                    summary_text += f"Loss stabilized with std dev {std_last_20:.6f}"
                else:
                    summary_text += f"Loss still changing significantly (std dev: {std_last_20:.6f})"
            plt.figtext(0.02, 0.02, summary_text, fontsize=10,
                        bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))
        plt.tight_layout(rect=[0, 0.05, 1, 0.97])
        final_fig.suptitle('CSM Model Finetuning Metrics', fontsize=16, fontweight='bold')
        plt.subplots_adjust(top=0.93)
        # Save the final detailed plot
        final_plot_path = os.path.join(self.output_dir, 'training_metrics_final.png')
        final_fig.savefig(final_plot_path, dpi=300, bbox_inches='tight')
        plt.close(final_fig)
        plt.close(self.fig)
        logger.info(f"Final training visualization saved to {final_plot_path}")
        return final_plot_path

class LoRALinear(nn.Module):
    # Linear layer with a frozen base weight plus a trainable low-rank (LoRA) update.
    def __init__(self, in_features, out_features, r=32, alpha=64, dropout=0.0, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        # The base linear (frozen).
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=bias)
        # LoRA trainable matrices
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return base projection plus the scaled low-rank update."""
        # normal forward with frozen weight
        result = F.linear(x, self.weight, self.bias)
        # LoRA forward with trainable A and B
        lora_out = F.linear(self.dropout(x), self.lora_A)  # [*, r]
        lora_out = F.linear(lora_out, self.lora_B)  # [*, out_features]
        return result + self.scaling * lora_out

def replace_linear_with_lora(model: nn.Module, r=R, alpha=APLHA, dropout=0.0, target_linear_names: List[str] = None):
    """
    Replaces specified nn.Linear layers with LoRALinear layers within a model,
    ensuring device consistency.
    """
    if target_linear_names is None:
        logger.warning("No target layer names specified for LoRA replacement. No layers will be replaced.")
        return model
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and any(target_name in name for target_name in target_linear_names):
            # NOTE(review): rsplit('.', 1) raises ValueError for a matching module at
            # top level (name without a dot) -- presumed not to occur for these targets.
            parent_name, child_name = name.rsplit('.', 1)
            parent_module = model
            for part in parent_name.split('.'):
                parent_module = getattr(parent_module, part)
            original_device = module.weight.device
            original_dtype = module.weight.dtype
            # Create the new LoRA layer
            lora_linear = LoRALinear(
                in_features=module.in_features,
                out_features=module.out_features,
                r=r,
                alpha=alpha,
                dropout=dropout,
                bias=(module.bias is not None)
            )
            # Copy the original weights and bias
            with torch.no_grad():
                lora_linear.weight.copy_(module.weight.data)
                if module.bias is not None:
                    lora_linear.bias.copy_(module.bias.data)
            lora_linear.to(device=original_device, dtype=original_dtype)
            setattr(parent_module, child_name, lora_linear)
            logger.info(f"Replaced layer: {name} with LoRALinear on device {original_device}")
    return model

def load_llama3_tokenizer():
    """Load the Llama-3.2 tokenizer and install a BOS/EOS-wrapping post-processor."""
    tokenizer_name = "unsloth/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(bos, tokenizer.bos_token_id), (eos, tokenizer.eos_token_id)],
    )
    return tokenizer

@dataclass
class AudioTextPair:
    # One training sample: an audio file path, its transcript, and the speaker id.
    audio_path: str
    text: str
    speaker_id: int
    processed_audio: Optional[torch.Tensor] = None  # lazy cache filled by load_audio()

    def load_audio(self, sample_rate=24000) -> torch.Tensor:
        """Load, downmix to mono, resample, and peak-normalize the audio; result is cached."""
        if self.processed_audio is not None:
            return self.processed_audio
        waveform, sr = torchaudio.load(self.audio_path)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sr != sample_rate:
            resampler = torchaudio.transforms.Resample(sr, sample_rate)
            waveform = resampler(waveform)
        # +1e-8 guards against division by zero on a silent clip
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
        self.processed_audio = waveform.squeeze(0)
        return
self.processed_audio

class CSMDataset(Dataset):
    """Dataset yielding CSM token frames (32 audio codebooks + 1 text column) for finetuning."""
    def __init__(self, data_items, text_tokenizer, audio_tokenizer, device):
        self.data_items = data_items
        self.text_tokenizer = text_tokenizer
        self.audio_tokenizer = audio_tokenizer
        self.device = device
        self.sample_rate = audio_tokenizer.sample_rate

    def __len__(self):
        return len(self.data_items)

    def tokenize_text_segment(self, text: str, speaker: int):
        # Text tokens occupy the last of the 33 columns; audio codebooks fill the first 32.
        text_tokens = self.text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True
        return text_frame, text_frame_mask

    def tokenize_audio(self, audio: torch.Tensor):
        """Encode audio to codebook frames; fall back to a single empty frame on encoder errors."""
        assert audio.ndim == 1, "Audio must be single channel"
        audio_device = next(self.audio_tokenizer.parameters()).device
        audio = audio.to(audio_device)
        try:
            audio_tokens = self.audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
            # Append an all-zero EOS frame after the encoded audio.
            eos_frame = torch.zeros(audio_tokens.size(0), 1, device=audio_device)
            audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
            audio_frame = torch.zeros(audio_tokens.size(1), 33, device=audio_device).long()
            audio_frame_mask = torch.zeros(audio_tokens.size(1), 33, device=audio_device).bool()
            audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
            audio_frame_mask[:, :-1] = True
        except RuntimeError as e:
            logger.warning(f"Error encoding audio: {e}, using empty frames")
            audio_frame = torch.zeros(1, 33, device=audio_device).long()
            audio_frame_mask = torch.zeros(1, 33, device=audio_device).bool()
        return audio_frame, audio_frame_mask

    def __getitem__(self, idx: int):
        item = self.data_items[idx]
        audio = item.load_audio(self.sample_rate)
        text_tokens, text_masks = self.tokenize_text_segment(item.text, item.speaker_id)
        audio_tokens, audio_masks = self.tokenize_audio(audio)
        device = audio_tokens.device
        text_tokens = text_tokens.to(device)
        text_masks = text_masks.to(device)
        # Inputs are text-only; targets are text followed by audio frames.
        input_tokens = text_tokens
        input_masks = text_masks
        target_tokens = torch.cat([text_tokens, audio_tokens], dim=0)
        target_masks = torch.cat([text_masks, audio_masks], dim=0)
        if device != self.device:
            input_tokens = input_tokens.to(self.device)
            input_masks = input_masks.to(self.device)
            target_tokens = target_tokens.to(self.device)
            target_masks = target_masks.to(self.device)
        return {
            "input_tokens": input_tokens,
            "input_masks": input_masks,
            "target_tokens": target_tokens,
            "target_masks": target_masks,
        }

def collate_fn(batch):
    """Pad/truncate variable-length samples to a shared length (capped at 1024 frames)."""
    max_seq_len = 1024
    device = batch[0]["input_tokens"].device
    max_input_len = min(max(item["input_tokens"].size(0) for item in batch), max_seq_len)
    max_target_len = min(max(item["target_tokens"].size(0) for item in batch), max_seq_len)
    batch_input_tokens = []
    batch_input_masks = []
    batch_target_tokens = []
    batch_target_masks = []
    for item in batch:
        input_tokens = item["input_tokens"][:max_input_len]
        input_masks = item["input_masks"][:max_input_len]
        target_tokens = item["target_tokens"][:max_target_len]
        target_masks = item["target_masks"][:max_target_len]
        # Pad along the sequence dimension only (last two pad args of F.pad).
        input_tokens = F.pad(input_tokens, (0,0,0, max_input_len - input_tokens.size(0)), "constant", 0)
        input_masks = F.pad(input_masks, (0,0,0, max_input_len - input_masks.size(0)), "constant", False)
        target_tokens = F.pad(target_tokens, (0,0,0, max_target_len - target_tokens.size(0)), "constant", 0)
        target_masks = F.pad(target_masks, (0,0,0, max_target_len - target_masks.size(0)), "constant", False)
        batch_input_tokens.append(input_tokens)
        batch_input_masks.append(input_masks)
        batch_target_tokens.append(target_tokens)
        batch_target_masks.append(target_masks)
    return {
        "input_tokens": torch.stack(batch_input_tokens),
        "input_masks": torch.stack(batch_input_masks),
        "target_tokens": torch.stack(batch_target_tokens),
        "target_masks": torch.stack(batch_target_masks),
        "positions": torch.arange(0, max_target_len).unsqueeze(0).repeat(len(batch), 1).to(device)
    }

def transcribe_audio_files():
    """Transcribe all audio in AUDIO_DIR with Whisper, using an mtime/size-keyed JSON cache."""
    from transformers import pipeline
    # Cache file path
    cache_file =
os.path.join(AUDIO_DIR, "transcription_cache.json")
    # Load existing cache
    cache = {}
    if os.path.exists(cache_file):
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                cache = json.load(f)
            logger.info(f"Loaded transcription cache with {len(cache)} entries")
        except Exception as e:
            logger.warning(f"Could not load cache file: {e}")
            cache = {}
    logger.info(f"Transcribing audio files in: {AUDIO_DIR}")
    transcriber = pipeline("automatic-speech-recognition", model=TRANSCRIPTION_MODEL)
    audio_text_pairs = []
    audio_files = glob.glob(os.path.join(AUDIO_DIR, "*.wav")) \
        + glob.glob(os.path.join(AUDIO_DIR, "*.mp3")) \
        + glob.glob(os.path.join(AUDIO_DIR, "*.flac"))
    if MAX_AUDIO_FILES > 0 and len(audio_files) > MAX_AUDIO_FILES:
        logger.info(f"Found {len(audio_files)} files, limiting to {MAX_AUDIO_FILES}")
        audio_files = audio_files[:MAX_AUDIO_FILES]
    cache_hits = 0
    cache_misses = 0
    for audio_file in tqdm(audio_files, desc="Processing audio files"):
        try:
            # Create cache key using file path and modification time
            file_stat = os.stat(audio_file)
            cache_key = f"{audio_file}_{file_stat.st_mtime}_{file_stat.st_size}"
            # Check if transcription exists in cache
            if cache_key in cache:
                transcription = cache[cache_key]
                cache_hits += 1
                logger.debug(f"Cache hit: {os.path.basename(audio_file)}")
            else:
                # Transcribe the file
                result = transcriber(audio_file, return_timestamps=True, chunk_length_s=30,
                                     stride_length_s=[6, 0], batch_size=32,
                                     generate_kwargs={"language": "<|en|>", "task": "transcribe"})
                transcription = result["text"].strip()
                # Save to cache
                cache[cache_key] = transcription
                cache_misses += 1
                logger.info(f"Transcribed: {os.path.basename(audio_file)} -> {transcription}")
            audio_text_pairs.append(
                AudioTextPair(audio_path=audio_file, text=transcription, speaker_id=0)
            )
        except Exception as e:
            logger.error(f"Error processing {audio_file}: {e}")
    # Save updated cache
    try:
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(cache, f, ensure_ascii=False, indent=2)
        logger.info(f"Saved transcription cache with {len(cache)} entries")
    except Exception as e:
        logger.error(f"Could not save cache file: {e}")
    logger.info(f"Processed {len(audio_text_pairs)} audio files (Cache hits: {cache_hits}, Cache misses: {cache_misses})")
    return audio_text_pairs

def prepare_csm_model_for_training():
    """Load CSM + tokenizers, attach a codebook-0 embedding, and apply LoRA adapters."""
    logger.info(f"Loading CSM model: {MODEL_NAME}")
    model = Model.from_pretrained(MODEL_NAME).to(DEVICE)
    text_tokenizer = load_llama3_tokenizer()
    mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(mimi_weight, device=DEVICE)
    mimi.set_num_codebooks(32)
    audio_tokenizer = mimi
    try:
        # Seed the training-only codebook embedding from the first RVQ codebook centroids;
        # fall back to a randomly initialized 1024x1024 table if the attribute path differs.
        codebook_0_centroids = mimi.quantizer.rvq_first.layers[0].codebook.weight.data
        num_codebook_0_tokens, embedding_dim = codebook_0_centroids.shape
        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)
        model.codebook_embedding.weight.data.copy_(codebook_0_centroids)
        logger.info(f"Successfully initialized codebook_embedding with shape: {codebook_0_centroids.shape}")
    except AttributeError:
        num_codebook_0_tokens, embedding_dim = 1024, 1024
        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)
        nn.init.xavier_uniform_(model.codebook_embedding.weight)
    except Exception as e:
        num_codebook_0_tokens, embedding_dim = 1024, 1024
        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)
        nn.init.xavier_uniform_(model.codebook_embedding.weight)
    # Some fallback logic for config
    if not hasattr(model.config, 'get'):
        def get_method(self, key, default=None):
            if hasattr(self, key):
                return getattr(self, key)
            return default
        model.config.__class__.get = get_method
    if not hasattr(model.config, 'tie_word_embeddings'):
        model.config.tie_word_embeddings = False
    # NOTE(review): this layer set differs from loadandmergecheckpoint.py's
    # ['q_proj','k_proj','v_proj','o_proj'] -- verify the two stay in sync.
    target_layers = [
        "q_proj", "k_proj", "v_proj", "output_proj",
        "w1", "w2", "w3"
    ]
    logger.info("Applying LoRA to model...")
    model = replace_linear_with_lora(
        model, r=R, alpha=APLHA, dropout=0.01,
        target_linear_names=target_layers
    )
    model.cuda()
    # First, freeze all parameters of the base model
    for param in model.parameters():
        param.requires_grad = False
    # Then, unfreeze only the newly added LoRA parameters.
    # It is also common practice to train the bias parameters.
    for name, param in model.named_parameters():
        if "lora_A" in name or "lora_B" in name or "bias" in name:
            param.requires_grad = True
    return model, text_tokenizer, audio_tokenizer

def setup_model_caches(model, batch_size):
    """Best-effort reset of model/backbone/decoder KV caches before a forward pass."""
    try:
        with torch.no_grad():
            model.reset_caches()
            model.backbone.reset_caches()
            model.decoder.reset_caches()
    except Exception as e:
        logger.debug(f"No caches to reset or error: {e}")
    return True

class BridgingModule(nn.Module):
    """For a 2048->1024 bridging if needed."""
    def __init__(self, in_dim=2048, out_dim=1024):
        super().__init__()
        self.bridge = nn.Linear(in_dim, out_dim, bias=False)
        nn.init.xavier_uniform_(self.bridge.weight)

    def forward(self, x):
        return self.bridge(x)

def compute_loss_for_codebooks_single_pass(
    backbone_out,   # [b, seq_len, 2048]
    decoder_out,    # [b, seq_len, 1024]
    model,
    target_tokens,  # [b, seq_len, codebooks]
    target_masks,   # [b, seq_len, codebooks bool]
    device
):
    """Mean cross-entropy over all codebooks at audio positions; token 0 treated as padding."""
    bsz, seq_len = target_tokens.size()[:2]
    num_codebooks = model.config.audio_num_codebooks
    c0_logits = model.codebook0_head(backbone_out)
    audio_positions = target_masks[..., :-1].any(dim=-1)  # [b, seq_len] for audio
    total_loss = torch.tensor(0.0, device=device)
    count = 0
    # codebook0
    # NOTE(review): the per-position Python loops below are O(b*s*codebooks) with a
    # cross_entropy call per token -- very slow, though functionally a masked CE.
    for b in range(bsz):
        for s in range(seq_len):
            if audio_positions[b, s]:
                token_logits = c0_logits[b, s]
                target_token = target_tokens[b, s, 0]
                if target_token > 0:
                    ce = F.cross_entropy(token_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')
                    total_loss += ce
                    count += 1
    # codebooks [1..N-1] from decoder_out
    for i in range(1, num_codebooks):
        weight_i = model.audio_head[i-1]
        flat_dec = decoder_out.view(bsz * seq_len, -1)
        token_logits_all = flat_dec.mm(weight_i)
        for b in range(bsz):
            for s in range(seq_len):
                if audio_positions[b, s]:
                    target_token = target_tokens[b, s, i]
                    if target_token > 0:
                        row_idx = b*seq_len + s
                        row_logits = token_logits_all[row_idx]
                        ce = F.cross_entropy(row_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')
                        total_loss += ce
                        count += 1
    if count > 0:
        total_loss = total_loss / count
    return total_loss

def single_pass_forward(model, bridging_module, target_tokens, target_masks, positions):
    """One combined backbone + decoder pass returning the multi-codebook loss."""
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    embed = model._embed_tokens(target_tokens)
    masked_embed = embed * target_masks.unsqueeze(-1)
    h = masked_embed.sum(dim=2)
    backbone_out = model.backbone(h, input_pos=positions, mask=None).to(dtype)
    bridging_out = bridging_module(backbone_out)
    codebook0_logits = model.codebook0_head(backbone_out)
    codebook0_tokens = torch.argmax(codebook0_logits, dim=-1).clamp(0, model.codebook_embedding.num_embeddings - 1)
    c0_embed = model.codebook_embedding(codebook0_tokens)
    # Get the last hidden state from bridging module
    last_h = bridging_out[:, -1, :].unsqueeze(1)
    # Concatenate the last hidden state with the codebook embeddings
    decoder_input = torch.cat([last_h, c0_embed], dim=1)
    # Process decoder inputs in parallel
    B, S, D = decoder_input.shape  # Batch, Sequence length, Dimension
    # Reshape to (B*S, D) to process all tokens in parallel
    decoder_input_flat = decoder_input.view(-1, D).unsqueeze(1)  # [B*S, 1, D]
    # Run decoder on all inputs in parallel
    decoder_out_flat = model.decoder(decoder_input_flat).to(dtype)  # [B*S, 1, output_dim]
    # Reshape back to original batch and sequence dimensions
    decoder_out = decoder_out_flat.view(B, S, -1)  # [B, S, output_dim]
    # Remove the first token (corresponding to last_h) as in original code
    decoder_out = decoder_out[:, 1:, :]  # [B, T, 1024]
    # Safety check: handle empty sequences
    if decoder_out.size(1) == 0:
        return torch.tensor(0.0, requires_grad=True, device=device)
    loss = compute_loss_for_codebooks_single_pass(
        backbone_out=backbone_out,
        decoder_out=decoder_out,
        model=model,
        target_tokens=target_tokens[..., 1:],  # Drop codebook 0
        target_masks=target_masks[..., 1:],
        device=device
    )
    return loss

def calculate_validation_loss(model, bridging_module, dataset, device, max_samples=50):
    """
    Calculate validation loss on a subset of the dataset
    """
    # Create a small validation dataloader with a subset of data
    # NOTE(review): samples are drawn from the training set itself, so this is not
    # a held-out validation metric.
    val_indices = torch.randperm(len(dataset))[:max_samples].tolist()
    val_samples = [dataset[i] for i in val_indices]
    val_loader = DataLoader(
        val_samples,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )
    model.eval()
    bridging_module.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            setup_model_caches(model, batch["target_tokens"].size(0))
            loss = forward_and_loss(model, bridging_module, batch, device)
            total_loss += loss.item()
            num_batches += 1
    model.train()
    bridging_module.train()
    # Return average loss
    return total_loss / num_batches if num_batches > 0 else 0.0

def strip_bias_keys(state_dict: dict) -> dict:
    """Drop training-only keys (biases and the codebook_embedding) from a state dict."""
    new_sd = {}
    for k, v in state_dict.items():
        if k == "codebook_embedding.weight":
            print(f"Stripping {k} from checkpoint (training-only layer)")
            continue
        if not k.endswith(".bias"):
            new_sd[k] = v
        else:
            print(f"Stripping {k} from checkpoint")
    return new_sd

def remove_lora_modules(module: nn.Module) -> nn.Module:
    """Recursively replace LoRALinear modules with plain nn.Linear carrying the merged weight."""
    for name, child in list(module.named_children()):
        new_child = remove_lora_modules(child)
        setattr(module, name, new_child)
    if isinstance(module, LoRALinear):
        out_features, in_features = module.out_features, module.in_features
        # Determine if we actually need a bias
        has_bias = (module.bias is not None)
        new_linear = nn.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias
        )
        # Copy over the merged weight
        new_linear.weight.data.copy_(module.weight.data)
        # If we had a bias in LoRALinear, copy it too
        if has_bias:
            new_linear.bias.data.copy_(module.bias.data)
        return new_linear
    return module

def merge_lora_layer(lora_module: LoRALinear):
    """
    Merge the LoRA params (lora_A, lora_B) into the base weight in-place.
    This transforms the LoRALinear into a standard Linear equivalent.
    """
    # W = W + (alpha/r) * (lora_B @ lora_A)
    merged_delta = lora_module.scaling * (lora_module.lora_B @ lora_module.lora_A)
    lora_module.weight.data += merged_delta
    # Optionally zero out LoRA parameters so they no longer affect anything
    lora_module.lora_A.data.zero_()
    lora_module.lora_B.data.zero_()

def merge_lora_weights(model: nn.Module):
    """Merge every LoRALinear's low-rank update into its frozen base weight."""
    for module in model.modules():
        if isinstance(module, LoRALinear):
            merge_lora_layer(module)
    return model

def finetune(model, dataset):
    """Run the LoRA finetuning loop: train, checkpoint per epoch, then merge and save final weights."""
    logger.info("Starting finetuning process")
    csv_file = os.path.join(OUTPUT_DIR, "training_metrics.csv")
    with open(csv_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "step", "global_step", "loss", "learning_rate", "val_loss"])

    def log_metrics(epoch, step, global_step, loss, learning_rate, val_loss=None):
        # Append one CSV row and refresh the progress plots.
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, step, global_step, loss, learning_rate, val_loss if val_loss is not None else ""])
        visualizer.update(epoch, global_step, loss, learning_rate, val_loss)

    visualizer = TrainingVisualizer(OUTPUT_DIR)
    bridging_module = BridgingModule(in_dim=2048, out_dim=1024).to(DEVICE)
    for param in bridging_module.parameters():
        param.requires_grad = True
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )
    trainable_params = [p for p in model.parameters() if p.requires_grad] + list(bridging_module.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)
    num_training_steps = len(dataloader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
    lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=num_training_steps
    )
    if USE_WANDB:
        wandb.init(project="csm-finetuning", name="csm-lora-finetune-fixed")
    scaler = torch.amp.GradScaler() if
MIXED_PRECISION else None
    global_step = 0
    # Run validation roughly twice per epoch (in optimizer-step units).
    validation_frequency = max(1, len(dataloader) // (2 * GRADIENT_ACCUMULATION_STEPS))
    model.train()
    bridging_module.train()
    logger.info("Calculating initial validation loss...")
    initial_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
    logger.info(f"Initial validation loss: {initial_val_loss:.6f}")
    current_loss = 0.0
    current_lr = LEARNING_RATE
    for epoch in range(NUM_EPOCHS):
        logger.info(f"Starting epoch {epoch+1}/{NUM_EPOCHS}")
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")
        for step, batch in enumerate(dataloader):
            try:
                setup_model_caches(model, batch["target_tokens"].size(0))
                with torch.amp.autocast(device_type=DEVICE, dtype=torch.float16, enabled=MIXED_PRECISION):
                    loss = forward_and_loss(model, bridging_module, batch, DEVICE)
                    if GRADIENT_ACCUMULATION_STEPS > 1:
                        loss = loss / GRADIENT_ACCUMULATION_STEPS
                if torch.isnan(loss) or torch.isinf(loss):
                    logger.warning(f"NaN or Inf loss detected at step {step}. Skipping batch.")
                    optimizer.zero_grad()
                    progress_bar.update(1)
                    continue
                if MIXED_PRECISION:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                # Step the optimizer every GRADIENT_ACCUMULATION_STEPS batches (or at epoch end).
                if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (step + 1) == len(dataloader):
                    if MIXED_PRECISION:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
                    if MIXED_PRECISION:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    current_lr = optimizer.param_groups[0]["lr"]
                    # Undo the accumulation scaling so the logged loss is per-batch.
                    current_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS if GRADIENT_ACCUMULATION_STEPS > 1 else loss.item()
                    current_epoch = epoch + (step + 1) / len(dataloader)
                    current_val_loss = None
                    if global_step > 0 and global_step % validation_frequency == 0:
                        logger.info(f"Calculating validation loss at global step {global_step}...")
                        current_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
                        logger.info(f"Validation loss: {current_val_loss:.6f}")
                    log_metrics(current_epoch, step, global_step, current_loss, current_lr, current_val_loss)
                    global_step += 1
                    if USE_WANDB:
                        wandb.log({"loss": current_loss, "learning_rate": current_lr, "epoch": current_epoch, "global_step": global_step, "val_loss": current_val_loss})
                progress_bar.set_postfix({"loss": f"{current_loss:.4f}", "lr": f"{current_lr:.2e}"})
                progress_bar.update(1)
            except Exception as e:
                # Per-batch recovery: log, clear grads/cache, and continue training.
                logger.error(f"Error in batch {step}: {e}")
                import traceback
                logger.error(traceback.format_exc())
                try:
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()
                except:
                    pass
                progress_bar.update(1)
                continue
        logger.info(f"Calculating validation loss at end of epoch {epoch+1}...")
        epoch_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
        logger.info(f"Epoch {epoch+1} validation loss: {epoch_val_loss:.6f}")
        log_metrics(epoch + 1.0, len(dataloader), global_step, current_loss, current_lr, epoch_val_loss)
        # Save a per-epoch checkpoint containing model + bridging module tensors.
        checkpoint_dir = os.path.join(OUTPUT_DIR, f"checkpoint-epoch-{epoch+1}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_tensors = {
            **model.state_dict(),
            **bridging_module.state_dict()
        }
        save_file(checkpoint_tensors, os.path.join(checkpoint_dir, "model.safetensors"))
        logger.info(f"Saved checkpoint to {checkpoint_dir}")
    final_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE, max_samples=100)
    logger.info(f"Final validation loss: {final_val_loss:.6f}")
    logger.info("Merging LoRA weights into the base model...")
    merge_lora_weights(model)
    model = remove_lora_modules(model)
    merged_state = strip_bias_keys(model.state_dict())
    final_merged_path = os.path.join(OUTPUT_DIR, "model.safetensors")
    save_file(merged_state, final_merged_path)
    logger.info(f"LoRA-merged & replaced model saved to {final_merged_path}")
    visualizer.finalize()
    if USE_WANDB:
        wandb.finish()
    return model

def forward_and_loss(model, bridging_module, batch, device):
    """Teacher-forced forward pass; returns cross-entropy averaged over active codebooks."""
    target_tokens = batch["target_tokens"].to(device)
    target_masks = batch["target_masks"].to(device)
    positions = batch["positions"].to(device)
    # Shift by one frame: inputs predict the next frame's tokens.
    input_tokens = target_tokens[:, :-1]
    input_masks = target_masks[:, :-1]
    input_positions = positions[:, :-1]
    labels = target_tokens[:, 1:]
    label_masks = target_masks[:, 1:]
    if input_tokens.size(1) == 0:
        return torch.tensor(0.0, requires_grad=True, device=device)
    # 1. Embed tokens and apply mask
    embed = model._embed_tokens(input_tokens)
    masked_embed = embed * input_masks.unsqueeze(-1)
    h = masked_embed.sum(dim=2)
    # 2. Pass through the backbone
    backbone_out = model.backbone(h, input_pos=input_positions, mask=None)
    # 3.
Calculate loss for all codebooks loss_fct = nn.CrossEntropyLoss(ignore_index=0) total_loss = 0.0 num_codebooks_with_loss = 0 c0_logits = model.codebook0_head(backbone_out) c0_labels = labels[..., 0] active_mask = label_masks[..., 0].view(-1) if active_mask.sum() > 0: active_logits = c0_logits.view(-1, c0_logits.size(-1))[active_mask] active_labels = c0_labels.view(-1)[active_mask] c0_loss = loss_fct(active_logits, active_labels) total_loss += c0_loss num_codebooks_with_loss += 1 decoder_states = bridging_module(backbone_out) num_codebooks = model.config.audio_num_codebooks for i in range(1, num_codebooks): if hasattr(model, 'audio_head') and len(model.audio_head) >= i: weight_i = model.audio_head[i-1] logits_i = decoder_states @ weight_i labels_i = labels[..., i] active_mask_i = label_masks[..., i].view(-1) if active_mask_i.sum() > 0: active_logits_i = logits_i.view(-1, logits_i.size(-1))[active_mask_i] active_labels_i = labels_i.view(-1)[active_mask_i] loss_i = loss_fct(active_logits_i, active_labels_i) total_loss += loss_i num_codebooks_with_loss += 1 if num_codebooks_with_loss > 0: return total_loss / num_codebooks_with_loss else: return torch.tensor(0.0, requires_grad=True, device=device) def main(): torch.manual_seed(SEED) np.random.seed(SEED) if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED) os.makedirs(OUTPUT_DIR, exist_ok=True) torch.backends.cuda.enable_flash_sdp(True) if DEVICE == "cuda": torch.backends.cudnn.benchmark = True model, text_tokenizer, audio_tokenizer = prepare_csm_model_for_training() audio_text_pairs = transcribe_audio_files() if not audio_text_pairs: logger.error(f"No audio files found or transcribed in {AUDIO_DIR}") return dataset = CSMDataset( audio_text_pairs, text_tokenizer=text_tokenizer, audio_tokenizer=audio_tokenizer, device=DEVICE ) logger.info(f"Dataset created with {len(dataset)} samples") try: finetune(model, dataset) logger.info("Finetuning completed successfully!") except Exception as e: logger.error(f"Error 
during finetuning: {e}") import traceback logger.error(traceback.format_exc()) try: # If there's an error, at least save a partial state partial_path = os.path.join(OUTPUT_DIR, "model_partial.safetensors") torch.save(model.state_dict(), partial_path) logger.info(f"Saved partial model to {partial_path} despite errors") except Exception as save_error: logger.error(f"Could not save partial model: {save_error}") if __name__ == "__main__": main() ================================================ FILE: main.py ================================================ import asyncio import os os.environ["OMP_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["PYTORCH_DISABLE_CUDA_GRAPHS"] = "1" import platform import sqlite3 import time import threading import json import queue from fastapi.websockets import WebSocketState import torch import torchaudio import sounddevice as sd import numpy as np import whisper from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates from sqlalchemy import create_engine, Column, Integer, String, Text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from typing import Optional from generator import Segment, load_csm_1b_local from llm_interface import LLMInterface from rag_system import RAGSystem from vad import AudioStreamProcessor from pydantic import BaseModel import logging from config import ConfigManager from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline import re speaking_start_time = 0.0 # set every time the AI begins a new turn MIN_BARGE_LATENCY = 0.9 speaker_counters = { 0: 0, # AI 1: 0 # User } current_generation_id = 1 pending_user_inputs = [] user_input_lock = threading.Lock() audio_fade_duration = 0.3 # seconds for fade-out last_interrupt_time = 0 
interrupt_cooldown = 6.0 # seconds between allowed interrupts audio_chunk_buffer = [] # Buffer to store the most recent audio chunks for fade-out # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) model_thread = None model_queue = queue.Queue() model_result_queue = queue.Queue() model_thread_running = threading.Event() llm_lock = threading.Lock() audio_gen_lock = threading.Lock() # Database Base = declarative_base() engine = create_engine("sqlite:///companion.db") SessionLocal = sessionmaker(bind=engine) class Conversation(Base): __tablename__ = "conversations" id = Column(Integer, primary_key=True, index=True) session_id = Column(String, index=True) timestamp = Column(String) user_message = Column(Text) ai_message = Column(Text) audio_path = Column(String) Base.metadata.create_all(bind=engine) # Pydantic config schema class CompanionConfig(BaseModel): system_prompt: str reference_audio_path: str reference_text: str reference_audio_path2: Optional[str] = None # optional field reference_text2: Optional[str] = None # optional field reference_audio_path3: Optional[str] = None # optional field reference_text3: Optional[str] = None # optional field model_path: str llm_path: str max_tokens: int = 8192 voice_speaker_id: int = 0 vad_enabled: bool = True vad_threshold: float = 0.5 embedding_model: str = "all-MiniLM-L6-v2" # Global state conversation_history = [] config = None audio_queue = queue.Queue() is_speaking = False interrupt_flag = threading.Event() generator = None llm = None rag = None vad_processor = None reference_segments = [] active_connections = [] message_queue = asyncio.Queue() # Async event loop loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) # FastAPI app = FastAPI() app.mount("/static", StaticFiles(directory="static"), name="static") templates = Jinja2Templates(directory="templates") config_manager = ConfigManager() model_id = 
"openai/whisper-large-v3-turbo" # Whisper whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True ) whisper_model.to("cuda") processor = AutoProcessor.from_pretrained(model_id) whisper_pipe = pipeline( "automatic-speech-recognition", model=whisper_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, torch_dtype=torch.float16, device='cuda', ) # Background queue async def process_message_queue(): while True: message = await message_queue.get() for client in active_connections[:]: try: if client.client_state == WebSocketState.CONNECTED: await client.send_json(message) except Exception as e: logger.error(f"Error in message queue for client: {e}") if client in active_connections: active_connections.remove(client) message_queue.task_done() def load_reference_segments(config_data: CompanionConfig): """Load multiple reference clips for voice‑cloning.""" global reference_segments reference_segments = [] # Load primary reference (required) if os.path.isfile(config_data.reference_audio_path): logger.info(f"Loading primary reference audio: {config_data.reference_audio_path}") wav, sr = torchaudio.load(config_data.reference_audio_path) wav = torchaudio.functional.resample(wav.squeeze(0), orig_freq=sr, new_freq=24_000) reference_segments.append(Segment(text=config_data.reference_text, speaker=config_data.voice_speaker_id, audio=wav)) else: logger.warning(f"Primary reference audio '{config_data.reference_audio_path}' not found.") # Load second reference (optional) if config_data.reference_audio_path2 and os.path.isfile(config_data.reference_audio_path2): logger.info(f"Loading second reference audio: {config_data.reference_audio_path2}") wav, sr = torchaudio.load(config_data.reference_audio_path2) wav = torchaudio.functional.resample(wav.squeeze(0), orig_freq=sr, new_freq=24_000) reference_segments.append(Segment(text=config_data.reference_text2, 
def transcribe_audio(audio_data, sample_rate):
    """Transcribe raw audio samples to English text with the Whisper pipeline.

    Args:
        audio_data: Sequence of float samples (list or ndarray).
        sample_rate: Sample rate of `audio_data`; anything other than 16 kHz
            is resampled first (Whisper expects 16 kHz input).

    Returns:
        The transcribed text, or the sentinel string "[Transcription error]"
        when transcription fails for any reason.
    """
    audio_np = np.asarray(audio_data, dtype=np.float32)
    if sample_rate != 16000:
        # Best-effort resample to 16 kHz; on failure fall through and let
        # Whisper try the original-rate audio rather than aborting.
        try:
            audio_tensor = torch.tensor(audio_np).unsqueeze(0)
            audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=16000)
            audio_np = audio_tensor.squeeze(0).numpy()
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; deliberately best-effort, so no re-raise.
            pass
    try:
        with torch.jit.optimized_execution(False):
            result = whisper_pipe(audio_np, generate_kwargs={"language": "english"})
        return result["text"]
    except Exception:
        # Narrowed from a bare `except:`; callers treat this sentinel as a
        # failed transcription.
        return "[Transcription error]"
config_data.voice_speaker_id, # speaker [], # no context 500, # max_ms 0.7, # temperature 40, # top‑k )) # block until worker signals EOS (None marker) while True: r = model_result_queue.get() if r is None: break logger.info(f"Voice model ready in {time.time() - t0:.1f}s") def on_speech_start(): asyncio.run_coroutine_threadsafe( message_queue.put( { "type": "vad_status", "status": "speech_started", "should_interrupt": False, # always False – UI never barges-in here } ), loop, ) def on_speech_end(audio_data, sample_rate): try: logger.info("Transcription starting") user_text = transcribe_audio(audio_data, sample_rate) logger.info(f"Transcription completed: '{user_text}'") session_id = "default" speaker_id = 1 index = speaker_counters[speaker_id] user_audio_path = f"audio/user/{session_id}_user_{index}.wav" os.makedirs(os.path.dirname(user_audio_path), exist_ok=True) audio_tensor = torch.tensor(audio_data).unsqueeze(0) save_audio_and_trim(user_audio_path, session_id, speaker_id, audio_tensor.squeeze(0), sample_rate) add_segment(user_text, speaker_id, audio_tensor.squeeze(0)) logger.info(f"User audio saved and segment appended: {user_audio_path}") speaker_counters[speaker_id] += 1 # Send transcription to clients asyncio.run_coroutine_threadsafe( message_queue.put({"type": "transcription", "text": user_text}), loop ) threading.Thread(target=lambda: process_user_input(user_text, session_id), daemon=True).start() except Exception as e: logger.error(f"VAD callback failed: {e}") def process_pending_inputs(): """Process only the latest user input after an interruption""" global pending_user_inputs, is_speaking, interrupt_flag time.sleep(0.2) is_speaking = False interrupt_flag.clear() with user_input_lock: if not pending_user_inputs: logger.info("No pending user inputs to process") return # Only take the most recent input and ignore others latest_input = pending_user_inputs[-1] logger.info(f"Processing only latest input: '{latest_input[0]}'") # Clear all pending inputs 
def process_user_input(user_text, session_id="default"):
    """Turn one user utterance into an LLM reply and spoken audio.

    Pipeline: build conversation + RAG context, query the LLM, persist the
    exchange to SQLite, notify websocket clients, then hand the reply text to
    `audio_generation_thread` on a daemon thread.

    Args:
        user_text: Transcribed or typed user message; empty input is ignored.
        session_id: Conversation session key (defaults to "default").
    """
    global config, is_speaking, pending_user_inputs, interrupt_flag
    # Skip empty messages
    if not user_text or user_text.strip() == "":
        logger.warning("Empty user input received, ignoring")
        return
    interrupt_flag.clear()
    is_speaking = False
    # Check if we're currently supposed to be speaking
    # NOTE(review): `is_speaking` was force-set to False two lines above, so
    # this whole branch is unreachable dead code — either the reset above or
    # this branch is a leftover; confirm intent before removing either.
    if is_speaking:
        logger.info(f"AI is currently speaking, adding input to pending queue: '{user_text}'")
        with user_input_lock:
            # Only keep the most recent input, replacing any existing ones
            pending_user_inputs = [(user_text, session_id)]
            logger.info(f"Added user input as the only pending input: '{user_text}'")
        # Request interruption if not already interrupted
        if not interrupt_flag.is_set():
            logger.info("Automatically interrupting current speech for new input")
            interrupt_flag.set()
            # Notify clients of interruption
            asyncio.run_coroutine_threadsafe(
                message_queue.put({"type": "audio_status", "status": "interrupted"}),
                loop
            )
        # Allow a short delay before processing the new input
        time.sleep(0.3)
        # Process the pending input after interruption
        process_pending_inputs()
        return
    interrupt_flag.clear()
    # Normal processing continues...
    logger.info(f"Processing user input: '{user_text}'")
    # Last five exchanges become the rolling conversational context.
    context = "\n".join([f"User: {msg['user']}\nAI: {msg['ai']}" for msg in conversation_history[-5:]])
    rag_context = rag.query(user_text)
    system_prompt = config.system_prompt
    if rag_context:
        system_prompt += f"\n\nRelevant context:\n{rag_context}"
    # Notify clients that we're thinking
    asyncio.run_coroutine_threadsafe(
        message_queue.put({"type": "status", "message": "Thinking..."}),
        loop
    )
    try:
        with llm_lock:
            ai_response = llm.generate_response(system_prompt, user_text, context)
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        conversation_history.append({
            "timestamp": timestamp,
            "user": user_text,
            "ai": ai_response
        })
        try:
            db = SessionLocal()
            conv = Conversation(
                session_id=session_id,
                timestamp=timestamp,
                user_message=user_text,
                ai_message=ai_response,
                audio_path=""
            )
            db.add(conv)
            db.commit()
            index = speaker_counters[0]
            # NOTE(review): `output_file` is only bound inside this DB block;
            # if anything above in it raises, the audio thread below hits an
            # unbound name. Consider computing the path before the try.
            output_file = f"audio/ai/{session_id}_response_{index}.wav"
            speaker_counters[0] += 1
            conv.audio_path = output_file
            db.commit()
            db.close()
        except Exception as e:
            logger.error(f"Database error: {e}")
        # Index the exchange for future RAG lookups without blocking replies.
        threading.Thread(target=lambda: rag.add_conversation(user_text, ai_response), daemon=True).start()
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "audio_status", "status": "preparing"}),
            loop
        )
        # Small delay to ensure client is ready
        time.sleep(0.2)
        # Send the response to clients
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "response", "text": ai_response}),
            loop
        )
        time.sleep(0.5)
        if is_speaking:
            logger.warning("Still speaking when trying to start new audio - forcing interrupt")
            interrupt_flag.set()
            is_speaking = False
            time.sleep(0.5)  # Give time for cleanup
            interrupt_flag.clear()  # Make absolutely sure
        is_speaking = False  # Reset for audio thread to take over
        # Start audio generation in a new thread
        threading.Thread(target=audio_generation_thread, args=(ai_response, output_file), daemon=True).start()
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "error", "message": "Failed to generate response"}),
            loop
        )
def save_audio_and_trim(path, session_id, speaker_id, tensor, sample_rate):
    """
    Persist an audio clip and evict the oldest saved clips — for this speaker
    and the opposite one — so each keeps at most MAX_AUDIO_FILES on disk.

    Args:
        path: Destination file for the new clip
        session_id: Conversation session ID
        speaker_id: 0 for AI, 1 for user
        tensor: 1-D audio tensor to write
        sample_rate: Audio sample rate
    """
    torchaudio.save(path, tensor.unsqueeze(0), sample_rate)
    session_store = saved_audio_paths.setdefault(session_id, {})
    session_store.setdefault(speaker_id, []).append(path)

    def _evict_oldest(tracked, message_template):
        # Pop from the front (oldest first) until we are back under the cap,
        # deleting files that still exist on disk.
        while len(tracked) > MAX_AUDIO_FILES:
            stale = tracked.pop(0)
            if os.path.exists(stale):
                os.remove(stale)
                logger.info(message_template.format(stale))

    _evict_oldest(session_store[speaker_id], "Removed old audio file: {}")

    counterpart = 1 if speaker_id == 0 else 0
    if counterpart in session_store:
        _evict_oldest(session_store[counterpart],
                      "Removed old audio file from other speaker: {}")
Args: text: Text content of the segment speaker_id: ID of the speaker (0 for AI, 1 for user) audio_tensor: Audio data as a tensor """ global reference_segments, generator, config # Determine the number of protected, initial reference segments based on what was actually loaded. num_protected_segments = 0 if config.reference_audio_path and os.path.exists(config.reference_audio_path): num_protected_segments += 1 if config.reference_audio_path2 and os.path.exists(config.reference_audio_path2): num_protected_segments += 1 if config.reference_audio_path3 and os.path.exists(config.reference_audio_path3): num_protected_segments += 1 # Separate protected from dynamic segments from the current global state protected_segments = reference_segments[:num_protected_segments] dynamic_segments = reference_segments[num_protected_segments:] # Add the new segment to the dynamic list new_segment = Segment(text=text, speaker=speaker_id, audio=audio_tensor) dynamic_segments.append(new_segment) # First, trim by MAX_SEGMENTS count. The oldest dynamic segments are removed. max_dynamic_allowed = MAX_SEGMENTS - len(protected_segments) if len(dynamic_segments) > max_dynamic_allowed: # Keep only the most recent dynamic segments dynamic_segments = dynamic_segments[-max_dynamic_allowed:] # Then, check and trim by token count if necessary. # This loop will trim the oldest dynamic segments until the token count is acceptable. if hasattr(generator, '_text_tokenizer'): while dynamic_segments: # Tentatively combine for token calculation temp_full_list = protected_segments + dynamic_segments total_tokens = 0 # Calculate total tokens for the current combination for segment in temp_full_list: tokens = generator._text_tokenizer.encode(f"[{segment.speaker}]{segment.text}") total_tokens += len(tokens) if segment.audio is not None: # Approximate frame count to token conversion audio_frames = segment.audio.size(0) // 6094 total_tokens += audio_frames # If we are within limits, the trimming is done. 
def preprocess_text_for_tts(text):
    """
    Normalise text before speech synthesis.

    Strips every punctuation mark except periods, commas, exclamation points,
    question marks (and apostrophes), collapses whitespace runs, and ensures a
    space follows sentence punctuation so pacing sounds natural.

    Args:
        text (str): Input text with potential punctuation

    Returns:
        str: Cleaned text with only allowed punctuation
    """
    # Keep word characters, whitespace, apostrophes and . , ! ? — drop the
    # rest (e.g. ; : " ~ @ # $ % ^ & * ( ) _ - + = [ ] { } \ | / < >).
    kept = re.sub(r"[^\w\s.,!?']", "", text)
    # Collapse any run of whitespace into a single space.
    kept = re.sub(r"\s+", " ", kept)
    # Insert a breathing space after . , ! ? when text follows immediately.
    spaced = re.sub(r"([.,!?])(\S)", r"\1 \2", kept)
    return spaced.strip()
asyncio.run_coroutine_threadsafe( message_queue.put({ "type": "audio_status", "status": "generating", "gen_id": this_id # Include generation ID }), loop ) # Small delay to ensure client gets the signal time.sleep(0.2) # Estimate audio length words = text.split() avg_wpm = 100 words_per_second = avg_wpm / 60 estimated_seconds = len(words) / words_per_second max_audio_length_ms = int(estimated_seconds * 1000) # Send request to model thread logger.info(f"Audio generation {this_id} - sending request to model thread") model_queue.put(( text_lower, config.voice_speaker_id, reference_segments, max_audio_length_ms, 0.8, # temperature 50 # topk )) # Start timing generation_start = time.time() chunk_counter = 0 # Process results as they come while True: try: # Check for interruption FIRST before getting more results if interrupt_flag.is_set(): logger.info(f"Audio generation {this_id} - interrupt detected, stopping") # Signal model thread to exit and restart model_thread_running.clear() time.sleep(0.1) model_thread_running.set() start_model_thread() # Clear any remaining items in the result queue while not model_result_queue.empty(): try: model_result_queue.get_nowait() except queue.Empty: pass # Break out of the processing loop break # Get result with timeout to allow checking interrupt result = model_result_queue.get(timeout=0.1) # Check for end of generation or error if result is None: logger.info(f"Audio generation {this_id} - complete") break if isinstance(result, Exception): logger.error(f"Audio generation {this_id} - error: {result}") raise result # Track timing for first chunk if chunk_counter == 0: first_chunk_time = time.time() - generation_start logger.info(f"Audio generation {this_id} - first chunk latency: {first_chunk_time*1000:.1f}ms") chunk_counter += 1 # One more interrupt check before processing chunk if interrupt_flag.is_set(): logger.info(f"Audio generation {this_id} - interrupt flag set during chunk processing") break # Process this audio chunk 
audio_chunk = result all_audio_chunks.append(audio_chunk) # Convert to numpy and send to audio queue chunk_array = audio_chunk.cpu().numpy().astype(np.float32) audio_queue.put(chunk_array) if chunk_counter == 1: logger.info(f"Sending first audio chunk with ID {this_id}") # Notify client we're sending the first chunk asyncio.run_coroutine_threadsafe( message_queue.put({ "type": "audio_status", "status": "first_chunk", "gen_id": this_id }), loop ) # Small delay time.sleep(0.1) # Send chunk with generation ID asyncio.run_coroutine_threadsafe( message_queue.put({ "type": "audio_chunk", "audio": chunk_array.tolist(), "sample_rate": generator.sample_rate, "gen_id": this_id, "chunk_num": chunk_counter # Include chunk number }), loop ) except queue.Empty: # No results yet, keep checking continue except Exception as e: logger.error(f"Audio generation {this_id} - error processing result: {e}") break # Save complete audio if available if all_audio_chunks and not interrupt_flag.is_set(): try: complete_audio = torch.cat(all_audio_chunks) save_audio_and_trim(output_file, "default", config.voice_speaker_id, complete_audio, generator.sample_rate) add_segment(text.lower(), config.voice_speaker_id, complete_audio) # Log statistics total_time = time.time() - generation_start total_audio_seconds = complete_audio.size(0) / generator.sample_rate rtf = total_time / total_audio_seconds logger.info(f"Audio generation {this_id} - completed in {total_time:.2f}s, RTF: {rtf:.2f}x") except Exception as e: logger.error(f"Audio generation {this_id} - error saving complete audio: {e}") except Exception as e: import traceback logger.error(f"Audio generation {this_id} - unexpected error: {e}\n{traceback.format_exc()}") finally: is_speaking = False # Signal end of audio audio_queue.put(None) try: logger.info(f"Audio generation {this_id} - sending completion status") asyncio.run_coroutine_threadsafe( message_queue.put({ "type": "audio_status", "status": "complete", "gen_id": this_id }), loop ) except 
def handle_interrupt(websocket):
    """Handle an explicit interrupt request from a websocket client.

    Applies a cooldown (except for just-started speech), then — if the AI is
    speaking or audio is still being generated — sets the interrupt flag,
    drains the audio queue, resets VAD, and restarts the model worker thread.

    Args:
        websocket: The requesting client connection; receives an
            acknowledgement (with `success: False` when the request is
            ignored due to cooldown).

    Returns:
        True if an interruption was actually performed, False otherwise.
    """
    global is_speaking, last_interrupt_time, interrupt_flag, model_thread_running, speaking_start_time
    # Log the current state
    logger.info(f"Interrupt requested. Current state: is_speaking={is_speaking}")
    current_time = time.time()
    # 999 acts as "a very long time ago" when speech never started.
    time_since_speech_start = current_time - speaking_start_time if speaking_start_time > 0 else 999
    time_since_last_interrupt = current_time - last_interrupt_time
    # Only apply cooldown for established speech, not for new speech
    if time_since_last_interrupt < interrupt_cooldown and time_since_speech_start > 3.0:
        logger.info(f"Ignoring interrupt: too soon after previous interrupt ({time_since_last_interrupt:.1f}s < {interrupt_cooldown}s)")
        # Let the client know we're not interrupting
        asyncio.run_coroutine_threadsafe(
            websocket.send_json({
                "type": "audio_status",
                "status": "interrupt_acknowledged",
                "success": False,
                "reason": "cooldown"
            }),
            loop
        )
        return False
    # Update the last interrupt time
    last_interrupt_time = current_time
    # We should interrupt if we're speaking OR if model generation is in progress
    if is_speaking or not model_result_queue.empty():
        logger.info("Interruption processing: we are speaking or generating")
        interrupt_flag.set()
        # Notify clients
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "audio_status", "status": "interrupted"}),
            loop
        )
        asyncio.run_coroutine_threadsafe(
            websocket.send_json({
                "type": "audio_status",
                "status": "interrupt_acknowledged"
            }),
            loop
        )
        # Clear the audio queue to stop additional audio from being processed
        try:
            # Drain the existing queue
            while not audio_queue.empty():
                try:
                    audio_queue.get_nowait()
                except queue.Empty:
                    break
            # Add end signal
            audio_queue.put(None)
            logger.info("Audio queue cleared")
        except Exception as e:
            logger.error(f"Error clearing audio queue: {e}")
        # Reset VAD to prepare for new input
        if vad_processor:
            try:
                vad_processor.reset()
                logger.info("VAD processor reset")
            except Exception as e:
                logger.error(f"Error resetting VAD: {e}")
        # Stop current model worker if needed
        # NOTE(review): the clear/sleep/set dance races with the worker's
        # `model_thread_running.is_set()` checks — a worker mid-generation may
        # miss the cleared flag entirely; confirm this window is acceptable.
        if model_thread and model_thread.is_alive():
            try:
                # Clear the thread running flag to stop generation
                model_thread_running.clear()
                # Wait a brief moment for thread to notice and exit
                time.sleep(0.1)
                # Now restart the thread state flag
                model_thread_running.set()
                # And restart the thread
                start_model_thread()
                logger.info("Model thread restarted")
            except Exception as e:
                logger.error(f"Error restarting model thread: {e}")
        return True
    logger.info("No active speech to interrupt")
    return False
as e: logger.error(f"Error processing config: {str(e)}") await websocket.send_json({"type": "error", "message": f"Configuration error: {str(e)}"}) elif data["type"] == "request_saved_config": saved = config_manager.load_config() await websocket.send_json({"type": "saved_config", "config": saved}) elif data["type"] == "text_message": user_text = data["text"] session_id = data.get("session_id", "default") logger.info(f"TEXT-MSG from client: {user_text!r}") # If the model is already talking, queue the request but if is_speaking: with user_input_lock: if len(pending_user_inputs) >= 3: pending_user_inputs = pending_user_inputs[-2:] pending_user_inputs.append((user_text, session_id)) await websocket.send_json( {"type":"status","message":"Queued – I’ll answer in a moment"}) continue await message_queue.put({"type":"transcription","text":user_text}) threading.Thread( target=lambda: process_user_input(user_text, session_id), daemon=True).start() elif data["type"] == "audio": audio_data = np.asarray(data["audio"], dtype=np.float32) sample_rate = data["sample_rate"] if sample_rate != 16000: audio_tensor = torch.tensor(audio_data).unsqueeze(0) audio_tensor = torchaudio.functional.resample( audio_tensor, orig_freq=sample_rate, new_freq=16000 ) audio_data = audio_tensor.squeeze(0).numpy() sample_rate = 16000 if config and config.vad_enabled: vad_processor.process_audio(audio_data) else: text = transcribe_audio(audio_data, sample_rate) await websocket.send_json({"type": "transcription", "text": text}) await message_queue.put({"type": "transcription", "text": text}) if is_speaking: with user_input_lock: pending_user_inputs.append((text, "default")) else: process_user_input(text) elif data["type"] == "interrupt": logger.info("Explicit interrupt request received") # Always acknowledge receipt of interrupt request await websocket.send_json({ "type": "audio_status", "status": "interrupt_acknowledged" }) # Then try to handle the actual interrupt success = handle_interrupt(websocket) # 
If successful, allow a brief delay for clearing everything if success: await asyncio.sleep(0.3) # Short delay to allow complete clearing # Force process pending inputs after interrupt with user_input_lock: if pending_user_inputs: user_text, session_id = pending_user_inputs.pop(0) pending_user_inputs.clear() # Clear any backup to avoid multiple responses # Process in a new thread to avoid blocking threading.Thread( target=lambda: process_user_input(user_text, session_id), daemon=True ).start() # Send final status update about the interrupt await websocket.send_json({ "type": "audio_status", "status": "interrupted", "success": success }) elif data["type"] == "mute": await websocket.send_json({"type": "mute_status", "muted": data["muted"]}) if not data["muted"] and config and config.vad_enabled: vad_processor.reset() except WebSocketDisconnect: if websocket in active_connections: active_connections.remove(websocket) @app.get("/", response_class=HTMLResponse) async def index(request: Request): return templates.TemplateResponse("index.html", {"request": request}) @app.get("/setup", response_class=HTMLResponse) async def setup_page(request: Request): return templates.TemplateResponse("setup.html", {"request": request}) @app.get("/chat", response_class=HTMLResponse) async def chat_page(request: Request): return templates.TemplateResponse("chat.html", {"request": request}) @app.on_event("startup") async def startup_event(): os.makedirs("static", exist_ok=True) os.makedirs("audio/user", exist_ok=True) os.makedirs("audio/ai", exist_ok=True) os.makedirs("embeddings_cache", exist_ok=True) os.makedirs("templates", exist_ok=True) with open("templates/index.html", "w") as f: f.write("""""") try: torch.hub.load('snakers4/silero-vad', model='silero_vad', force_reload=False) except: pass asyncio.create_task(process_message_queue()) @app.on_event("shutdown") async def shutdown_event(): logger.info("Server shutting down...") from flask import Flask, jsonify, request, send_file 
@app.get("/api/conversations") async def get_conversations(request: Request): conn = sqlite3.connect("companion.db") cur = conn.cursor() cur.execute("SELECT id, user_message, ai_message FROM conversations ORDER BY id DESC") data = [{"id": row[0], "user_message": row[1], "ai_message": row[2]} for row in cur.fetchall()] conn.close() return JSONResponse(content=data) @app.route("/api/conversations/", methods=["PUT"]) def update_conversation(conv_id): data = request.get_json() conn = sqlite3.connect("companion.db") cur = conn.cursor() cur.execute("UPDATE conversations SET user_message=?, ai_message=? WHERE id=?", (data["user_message"], data["ai_message"], conv_id)) conn.commit() conn.close() return "", 204 @app.delete("/api/conversations") async def delete_all_conversations(): try: conn = sqlite3.connect("companion.db") cur = conn.cursor() cur.execute("DELETE FROM conversations") conn.commit() conn.close() return {"status": "all deleted"} except Exception as e: return JSONResponse(content={"error": str(e)}, status_code=500) @app.delete("/api/conversations/{conv_id}") async def delete_conversation(conv_id: int): try: conn = sqlite3.connect("companion.db") cur = conn.cursor() cur.execute("DELETE FROM conversations WHERE id = ?", (conv_id,)) conn.commit() conn.close() return JSONResponse(content={"status": "deleted", "id": conv_id}) except Exception as e: return JSONResponse(content={"error": str(e)}, status_code=500) app.mount("/static", StaticFiles(directory="static"), name="static") templates = Jinja2Templates(directory="templates") @app.get("/crud", response_class=HTMLResponse) async def crud_ui(request: Request): return templates.TemplateResponse("crud.html", {"request": request}) if __name__ == "__main__": import uvicorn threading.Thread(target=lambda: asyncio.run(loop.run_forever()), daemon=True).start() uvicorn.run(app, host="0.0.0.0", port=8000) ================================================ FILE: models.py ================================================ 
import logging from dataclasses import dataclass import torch import torch.nn as nn import torchtune from huggingface_hub import PyTorchModelHubMixin from torchtune.models import llama3_2 logger = logging.getLogger(__name__) def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder: return llama3_2.llama3_2( vocab_size=128_256, num_layers=16, num_heads=32, num_kv_heads=8, embed_dim=2048, max_seq_len=2048, intermediate_dim=8192, attn_dropout=0.0, norm_eps=1e-5, rope_base=500_000, scale_factor=32, ) def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder: return llama3_2.llama3_2( vocab_size=128_256, num_layers=4, num_heads=8, num_kv_heads=2, embed_dim=1024, max_seq_len=2048, intermediate_dim=8192, attn_dropout=0.0, norm_eps=1e-5, rope_base=500_000, scale_factor=32, ) FLAVORS = { "llama-1B": llama3_2_1B, "llama-100M": llama3_2_100M, } def _prepare_transformer(model): embed_dim = model.tok_embeddings.embedding_dim model.tok_embeddings = nn.Identity() model.output = nn.Identity() return model, embed_dim def _create_causal_mask(seq_len: int, device: torch.device): return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)) def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor): """ Args: mask: (max_seq_len, max_seq_len) input_pos: (batch_size, seq_len) Returns: (batch_size, seq_len, max_seq_len) """ r = mask[input_pos, :] return r def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization q = torch.empty_like(probs).exponential_(1) return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) def sample_topk(logits: torch.Tensor, topk: int, temperature: float): logits = logits / temperature filter_value: float = -float("Inf") indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None] scores_processed = logits.masked_fill(indices_to_remove, filter_value) scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1) probs = 
torch.nn.functional.softmax(scores_processed, dim=-1) sample_token = _multinomial_sample_one_no_sync(probs) return sample_token @dataclass class ModelArgs: backbone_flavor: str decoder_flavor: str text_vocab_size: int audio_vocab_size: int audio_num_codebooks: int class Model( nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/SesameAILabs/csm", pipeline_tag="text-to-speech", license="apache-2.0", ): def __init__(self, config: ModelArgs): super().__init__() self.config = config self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]()) self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]()) self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim) self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim) self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False) self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False) self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) def setup_caches(self, max_batch_size: int) -> torch.Tensor: """Setup KV caches and return a causal mask.""" dtype = next(self.parameters()).dtype device = next(self.parameters()).device with device: self.backbone.setup_caches(max_batch_size, dtype) self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks) self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device)) self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device)) def generate_frame( self, tokens: torch.Tensor, tokens_mask: torch.Tensor, input_pos: torch.Tensor, temperature: float, topk: int, ) -> torch.Tensor: """ Args: tokens: (batch_size, seq_len, audio_num_codebooks+1) tokens_mask: (batch_size, seq_len, audio_num_codebooks+1) input_pos: (batch_size, seq_len) positions for each token mask: 
(batch_size, seq_len, max_seq_len Returns: (batch_size, audio_num_codebooks) sampled tokens """ dtype = next(self.parameters()).dtype b, s, _ = tokens.size() assert self.backbone.caches_are_enabled(), "backbone caches are not enabled" curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos) embeds = self._embed_tokens(tokens) masked_embeds = embeds * tokens_mask.unsqueeze(-1) h = masked_embeds.sum(dim=2) h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype) last_h = h[:, -1, :] c0_logits = self.codebook0_head(last_h) c0_sample = sample_topk(c0_logits, topk, temperature) c0_embed = self._embed_audio(0, c0_sample) curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1) curr_sample = c0_sample.clone() curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1) # Decoder caches must be reset every frame. self.decoder.reset_caches() for i in range(1, self.config.audio_num_codebooks): curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos) decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to( dtype=dtype ) ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1]) ci_sample = sample_topk(ci_logits, topk, temperature) ci_embed = self._embed_audio(i, ci_sample) curr_h = ci_embed curr_sample = torch.cat([curr_sample, ci_sample], dim=1) curr_pos = curr_pos[:, -1:] + 1 return curr_sample def reset_caches(self): self.backbone.reset_caches() self.decoder.reset_caches() def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor: return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size) def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2) audio_tokens = tokens[:, :, :-1] + ( self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device) ) audio_embeds = 
self.audio_embeddings(audio_tokens.view(-1)).reshape( tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1 ) return torch.cat([audio_embeds, text_embeds], dim=-2) ================================================ FILE: rag_system.py ================================================ import sqlite3 import numpy as np import json from pathlib import Path import time from typing import List, Dict, Any, Tuple, Optional from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity import torch class RAGSystem: def __init__(self, db_path: str, model_name: str = "all-MiniLM-L6-v2", cache_dir: str = "./embeddings_cache"): """ Initialize the enhanced RAG system with embeddings. Args: db_path: Path to the SQLite database model_name: Name of the sentence-transformer model to use cache_dir: Directory to cache embeddings """ self.db_path = db_path self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(exist_ok=True) # Load embedding model print(f"Loading embedding model: {model_name}") self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model = SentenceTransformer(model_name, device=self.device) print(f"Embedding model loaded on {self.device}") # Cache for embeddings self.embedding_cache = self._load_embedding_cache() # Initialize database tables if needed self._initialize_db() # Load existing conversations and cache embeddings self._load_conversations() def _initialize_db(self): """Create necessary tables if they don't exist.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # Create conversations table if it doesn't exist cursor.execute(""" CREATE TABLE IF NOT EXISTS conversations ( id INTEGER PRIMARY KEY, user_message TEXT, ai_message TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) """) # Create embeddings table if it doesn't exist cursor.execute(""" CREATE TABLE IF NOT EXISTS embeddings ( id INTEGER PRIMARY KEY, conversation_id INTEGER, text TEXT, embedding_file TEXT, chunk_id TEXT, 
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ) """) conn.commit() conn.close() def _load_embedding_cache(self) -> Dict[str, np.ndarray]: """Load cached embeddings from disk.""" cache = {} for cache_file in self.cache_dir.glob("*.json"): try: with open(cache_file, "r") as f: cache_data = json.load(f) for chunk_id, embedding_data in cache_data.items(): cache[chunk_id] = np.array(embedding_data) except Exception as e: print(f"Error loading cache file {cache_file}: {e}") print(f"Loaded {len(cache)} cached embeddings") return cache def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray): """Save an embedding to the cache.""" cache_file = self.cache_dir / f"{chunk_id[:2]}.json" # Load existing cache file or create new one if cache_file.exists(): try: with open(cache_file, "r") as f: cache_data = json.load(f) except: cache_data = {} else: cache_data = {} # Add new embedding cache_data[chunk_id] = embedding.tolist() # Save cache file with open(cache_file, "w") as f: json.dump(cache_data, f) def _load_conversations(self): """Load existing conversations from the database and cache their embeddings.""" try: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # First check if the conversations table exists cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'") if not cursor.fetchone(): print("Conversations table does not exist yet") conn.close() return # Get all conversations not yet in the embeddings table cursor.execute(""" SELECT c.id, c.user_message, c.ai_message FROM conversations c LEFT JOIN embeddings e ON c.id = e.conversation_id WHERE e.id IS NULL """) conversations = cursor.fetchall() if not conversations: conn.close() return print(f"Processing embeddings for {len(conversations)} new conversations") for conv_id, user_message, ai_message in conversations: # Create chunks for indexing if user_message is not None and ai_message is not None: # Ensure neither is None 
self._process_conversation(conv_id, user_message, ai_message, conn) conn.close() print("Finished processing conversation embeddings") except Exception as e: print(f"Error loading conversations: {e}") def _process_conversation(self, conv_id: int, user_message: str, ai_message: str, conn: sqlite3.Connection): """Process a conversation and store its embeddings.""" try: cursor = conn.cursor() # Combine user and AI messages full_text = f"User: {user_message}\nAI: {ai_message}" # For simplicity, we're using the entire message as a chunk # In a more sophisticated system, you might split long messages into smaller chunks chunk_id = f"conv_{conv_id}" # Check if we already have this embedding cached if chunk_id not in self.embedding_cache: # Generate embedding embedding = self.model.encode(full_text) self.embedding_cache[chunk_id] = embedding # Save to cache self._save_embedding_to_cache(chunk_id, embedding) else: embedding = self.embedding_cache[chunk_id] # Store reference in database embedding_file = f"{chunk_id[:2]}.json" cursor.execute( "INSERT INTO embeddings (conversation_id, text, embedding_file, chunk_id) VALUES (?, ?, ?, ?)", (conv_id, full_text, embedding_file, chunk_id) ) conn.commit() except Exception as e: print(f"Error processing conversation {conv_id}: {e}") def add_conversation(self, user_message: str, ai_message: str) -> int: """ Add a new conversation to the RAG system. 
Returns: The id of the newly added conversation """ try: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # Insert the conversation first cursor.execute( "INSERT INTO conversations (user_message, ai_message) VALUES (?, ?)", (user_message, ai_message) ) # Get the ID of the new conversation conv_id = cursor.lastrowid # Process the conversation for embeddings self._process_conversation(conv_id, user_message, ai_message, conn) conn.commit() conn.close() return conv_id except Exception as e: print(f"Error adding conversation: {e}") return -1 def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, float]]: """ Query the RAG system for relevant context. Args: query_text: The query text top_k: Number of top results to return Returns: List of tuples with (text, similarity_score) """ if query_text is None or query_text.strip() == "": print("Error: Empty query text") return [] try: # Generate query embedding query_embedding = self.model.encode(query_text) # Find most similar conversations results = self._find_similar(query_embedding, top_k) return results except Exception as e: print(f"Error during query: {e}") return [] def get_context(self, query_text: str, top_k: int = 3, threshold: float = 0.6) -> str: """ Get formatted context from the RAG system. 
Args: query_text: The query text top_k: Number of top results to return threshold: Minimum similarity score to include Returns: String with relevant context """ results = self.query(query_text, top_k) if not results: return "" # Format results context_parts = [] for text, score in results: # Only include really relevant results if score < threshold: # Threshold for relevance continue context_parts.append(f"Relevance: {score:.2f}\n{text}") return "\n---\n".join(context_parts) def _find_similar(self, query_embedding: np.ndarray, top_k: int) -> List[Tuple[str, float]]: """Find the most similar conversations to the query.""" try: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # Check if the embeddings table exists cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='embeddings'") if not cursor.fetchone(): print("Embeddings table does not exist yet") conn.close() return [] # Get all embeddings from the database cursor.execute("SELECT id, text, embedding_file, chunk_id FROM embeddings") results = cursor.fetchall() if not results: conn.close() return [] # Calculate similarities similarities = [] for db_id, text, embedding_file, chunk_id in results: # Get embedding from cache if chunk_id in self.embedding_cache: embedding = self.embedding_cache[chunk_id] else: # This should not happen, but just in case # We'll reload from the cache file cache_file = self.cache_dir / embedding_file if cache_file.exists(): with open(cache_file, "r") as f: cache_data = json.load(f) if chunk_id in cache_data: embedding = np.array(cache_data[chunk_id]) self.embedding_cache[chunk_id] = embedding else: continue else: continue # Calculate similarity similarity = cosine_similarity( query_embedding.reshape(1, -1), embedding.reshape(1, -1) )[0][0] similarities.append((text, similarity)) conn.close() # Sort by similarity and return top_k similarities.sort(key=lambda x: x[1], reverse=True) return similarities[:top_k] except Exception as e: print(f"Error finding 
similar documents: {e}") return [] def refresh(self): """Refresh embeddings from the database.""" self._load_conversations() # Example usage if __name__ == "__main__": # Initialize the RAG system rag = RAGSystem("conversations.db") ================================================ FILE: requirements.txt ================================================ --extra-index-url=https://download.pytorch.org/whl/cu128 vllm==0.8.0 torch==2.6.0 torchaudio==2.6.0 torchvision==0.21.0 tokenizers==0.21.0 transformers==4.49.0 huggingface_hub==0.28.1 moshi==0.2.2 sounddevice torchtune==0.4.0 torchao==0.9.0 bitsandbytes peft wandb silero_vad python-multipart>=0.0.6 aiofiles>=23.1.0 sentence-transformers>=2.2.2 ctransformers>=0.2.24 python-multipart>=0.0.6 sqlalchemy>=2.0.0 pydantic>=2.0.0 fastapi>=0.95.0 uvicorn>=0.22.0 websockets>=11.0.3 jinja2>=3.0.0 speechbrain>=0.5.15 matplotlib whisper-openai silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master numpy==1.26.0 ================================================ FILE: run_csm.py ================================================ import os import torch import torchaudio from huggingface_hub import hf_hub_download from generator import load_csm_1b, Segment from dataclasses import dataclass # Default prompts are available at https://hf.co/sesame/csm-1b prompt_filepath_conversational_a = hf_hub_download( repo_id="sesame/csm-1b", filename="prompts/conversational_a.wav" ) prompt_filepath_conversational_b = hf_hub_download( repo_id="sesame/csm-1b", filename="prompts/conversational_b.wav" ) SPEAKER_PROMPTS = { "conversational_a": { "text": ( "like revising for an exam I'd have to try and like keep up the momentum because I'd " "start really early I'd be like okay I'm gonna start revising now and then like " "you're revising for ages and then I just like start losing steam I didn't do that " "for the exam we had recently to be fair that was a more of a last minute scenario " "but like yeah I'm trying to like yeah I noticed this 
yesterday that like Mondays I " "sort of start the day with this not like a panic but like a" ), "audio": prompt_filepath_conversational_a }, "conversational_b": { "text": ( "like a super Mario level. Like it's very like high detail. And like, once you get " "into the park, it just like, everything looks like a computer game and they have all " "these, like, you know, if, if there's like a, you know, like in a Mario game, they " "will have like a question block. And if you like, you know, punch it, a coin will " "come out. So like everyone, when they come into the park, they get like this little " "bracelet and then you can go punching question blocks around." ), "audio": prompt_filepath_conversational_b } } def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: audio_tensor, sample_rate = torchaudio.load(audio_path) audio_tensor = audio_tensor.squeeze(0) # Resample is lazy so we can always call it audio_tensor = torchaudio.functional.resample( audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate ) return audio_tensor def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: audio_tensor = load_prompt_audio(audio_path, sample_rate) return Segment(text=text, speaker=speaker, audio=audio_tensor) def main(): # Select the best available device, skipping MPS due to float64 limitations if torch.cuda.is_available(): device = "cuda" else: device = "cpu" print(f"Using device: {device}") # Load model generator = load_csm_1b(device) # Prepare prompts prompt_a = prepare_prompt( SPEAKER_PROMPTS["conversational_a"]["text"], 0, SPEAKER_PROMPTS["conversational_a"]["audio"], generator.sample_rate ) prompt_b = prepare_prompt( SPEAKER_PROMPTS["conversational_b"]["text"], 1, SPEAKER_PROMPTS["conversational_b"]["audio"], generator.sample_rate ) # Generate conversation conversation = [ {"text": "Hey how are you doing?", "speaker_id": 0}, {"text": "Pretty good, pretty good. 
How about you?", "speaker_id": 1}, {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1} ] # Generate each utterance generated_segments = [] prompt_segments = [prompt_a, prompt_b] for utterance in conversation: print(f"Generating: {utterance['text']}") audio_tensor = generator.generate( text=utterance['text'], speaker=utterance['speaker_id'], context=prompt_segments + generated_segments, max_audio_length_ms=10_000, ) generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) # Concatenate all generations all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) torchaudio.save( "full_conversation.wav", all_audio.unsqueeze(0).cpu(), generator.sample_rate ) print("Successfully generated full_conversation.wav") if __name__ == "__main__": main() ================================================ FILE: setup.py ================================================ import os import sys import subprocess import logging import urllib.request import torch import time import shutil from pathlib import Path # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def check_requirements(): """Check if all required Python packages are installed""" logger.info("Checking requirements...") requirements = [ "torch", "torchaudio", "fastapi", "uvicorn", "websockets", "numpy", "scikit-learn", "sqlalchemy", "pydantic", "jinja2", "whisper", "sounddevice", "soundfile", "sentence_transformers", "ctransformers" ] missing = [] for req in requirements: try: __import__(req) except ImportError: missing.append(req) if missing: logger.warning(f"Missing required packages: {', '.join(missing)}") logger.info("Installing missing requirements...") subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]) 
logger.info("Requirements installed successfully") else: logger.info("All requirements are satisfied") def download_vad_model(): """Download the Silero VAD model using PyTorch Hub instead of direct URL""" model_path = "silero_vad.jit" if os.path.exists(model_path): logger.info(f"Silero VAD model already exists at {model_path}") return logger.info("Downloading Silero VAD model using PyTorch Hub...") try: # Use torch.hub to download the model instead of direct URL torch.hub.set_dir("./models") model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=False) # Save the model torch.jit.save(model, model_path) logger.info(f"Model downloaded and saved to {model_path}") except Exception as e: logger.error(f"Failed to download Silero VAD model using PyTorch Hub: {e}") logger.info("Falling back to energy-based VAD - the system will still work but with simpler voice detection") def download_embedding_models(): """Download the sentence transformer models for RAG""" logger.info("Setting up sentence transformer models...") try: from sentence_transformers import SentenceTransformer # Download lightweight model for embeddings logger.info("Downloading embedding models (this may take a few minutes)...") models = [ "all-MiniLM-L6-v2", # Fast "all-mpnet-base-v2", # Balanced "multi-qa-mpnet-base-dot-v1" # Best for Q&A ] for model_name in models: logger.info(f"Setting up model: {model_name}") _ = SentenceTransformer(model_name) logger.info(f"Model {model_name} is ready") except Exception as e: logger.error(f"Failed to download embedding models: {e}") logger.error("Please try running the script again or download models manually") def setup_directories(): """Create necessary directories for the application""" directories = ["static", "responses", "embeddings_cache", "templates"] for directory in directories: os.makedirs(directory, exist_ok=True) logger.info(f"Directory {directory} is ready") # Create template redirect file template_dir = 
Path("templates") index_html = template_dir / "index.html" with open(index_html, "w") as f: f.write("""

Redirecting to AI Companion...

""") logger.info("Created index template for redirection") def setup_database(): """Initialize the SQLite database""" logger.info("Setting up database...") try: from sqlalchemy import create_engine, Column, Integer, String, Text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker Base = declarative_base() engine = create_engine("sqlite:///companion.db") class Conversation(Base): __tablename__ = "conversations" id = Column(Integer, primary_key=True, index=True) session_id = Column(String, index=True) timestamp = Column(String) user_message = Column(Text) ai_message = Column(Text) audio_path = Column(String) # Create tables Base.metadata.create_all(bind=engine) logger.info("Database initialized successfully") except Exception as e: logger.error(f"Failed to set up database: {e}") def check_cuda(): """Check if CUDA is available for PyTorch""" if torch.cuda.is_available(): device_name = torch.cuda.get_device_name(0) logger.info(f"CUDA is available: {device_name}") logger.info(f"CUDA version: {torch.version.cuda}") else: logger.warning("CUDA is not available. 
The application will run on CPU, which may be very slow") logger.warning("For optimal performance, a CUDA-capable GPU is recommended") def main(): """Main setup function""" logger.info("Starting AI Companion setup...") # Check for CUDA availability check_cuda() # Check and install requirements #check_requirements() # Create directories setup_directories() # Set up database setup_database() # Download models download_vad_model() download_embedding_models() logger.info("Setup completed successfully!") logger.info("You can now start the application with:") logger.info(" python main.py") if __name__ == "__main__": main() ================================================ FILE: static/app.js ================================================ let ws; let micAnalyser, micContext, micSource, micStream; let outputAnalyser, outputAudioCtx; let lastConfig = null; let isLoading = false; document.addEventListener('DOMContentLoaded', async () => { await populateAudioDevices(); ws = new WebSocket(`ws://${window.location.host}/ws`); ws.onopen = () => { console.log("WebSocket connected, requesting saved config..."); ws.send(JSON.stringify({ type: "request_saved_config" })); }; ws.onmessage = async (event) => { const data = JSON.parse(event.data); if (data.type === "saved_config" && data.config) { document.getElementById('systemPrompt').value = data.config.system_prompt || ""; document.getElementById('modelPath').value = data.config.model_path || ""; document.getElementById('llmPath').value = data.config.llm_path || ""; document.getElementById('referenceAudio').value = data.config.reference_audio_path || ""; document.getElementById('referenceText').value = data.config.reference_text || ""; document.getElementById('referenceAudio2').value = data.config.reference_audio_path2 || ""; document.getElementById('referenceText2').value = data.config.reference_text2 || ""; document.getElementById('referenceAudio3').value = data.config.reference_audio_path3 || ""; 
document.getElementById('referenceText3').value = data.config.reference_text3 || ""; setTimeout(() => { if (data.config.mic_id) document.getElementById('micSelect').value = data.config.mic_id; if (data.config.output_id) document.getElementById('outputSelect').value = data.config.output_id; }, 500); } if (data.type === "status") { if (data.message.includes("Models initialized")) { console.log("Model initialization confirmed. Redirecting..."); // Save config again just to be safe localStorage.setItem('ai_config', JSON.stringify(lastConfig)); // Close WebSocket before navigating if (ws && ws.readyState === WebSocket.OPEN) { ws.close(); } // Wait briefly to let server clean up, then redirect setTimeout(() => { window.location.href = "/chat"; }, 100); } else if (data.message.includes("Initializing") || data.message.includes("Loading")) { // Show that models are being loaded showLoading(true, data.message); } } }; document.getElementById('testMicBtn').addEventListener('click', async () => { const micId = getSelectedMic(); micStream = await navigator.mediaDevices.getUserMedia({ audio: { deviceId: micId } }); micContext = new AudioContext(); micSource = micContext.createMediaStreamSource(micStream); micAnalyser = micContext.createAnalyser(); micSource.connect(micAnalyser); visualizeMic(micAnalyser, 'micCanvas'); const recorder = new MediaRecorder(micStream); const chunks = []; recorder.ondataavailable = e => { if (e.data.size > 0) chunks.push(e.data); }; recorder.onstop = () => { const blob = new Blob(chunks, { type: 'audio/webm' }); const url = URL.createObjectURL(blob); const audio = new Audio(url); audio.play(); micStream.getTracks().forEach(track => track.stop()); micContext.close(); }; recorder.start(); setTimeout(() => recorder.stop(), 3000); }); document.getElementById('testOutputBtn').addEventListener('click', () => { const audio = new Audio('/static/test.mp3'); audio.setSinkId(getSelectedOutput()).then(() => { outputAudioCtx = new AudioContext(); const 
outputSource = outputAudioCtx.createMediaElementSource(audio); outputAnalyser = outputAudioCtx.createAnalyser(); outputSource.connect(outputAnalyser); outputAnalyser.connect(outputAudioCtx.destination); visualizeMic(outputAnalyser, 'outputCanvas'); audio.play(); }).catch(err => { console.warn("Failed to route output:", err); }); }); document.getElementById('saveAndStart').addEventListener('click', () => { lastConfig = { system_prompt: document.getElementById('systemPrompt').value, model_path: document.getElementById('modelPath').value, llm_path: document.getElementById('llmPath').value, reference_audio_path: document.getElementById('referenceAudio').value, reference_text: document.getElementById('referenceText').value, reference_audio_path2: document.getElementById('referenceAudio2').value, reference_text2: document.getElementById('referenceText2').value, reference_audio_path3: document.getElementById('referenceAudio3').value, reference_text3: document.getElementById('referenceText3').value, mic_id: getSelectedMic(), output_id: getSelectedOutput(), }; console.log("Sending config to backend..."); console.log(lastConfig) showLoading(true, "Initializing models, please wait..."); ws.send(JSON.stringify({ type: "config", config: lastConfig })); // we wait for the backend to reply with model status before navigating }); }); function showLoading(show, message) { const saveButton = document.getElementById('saveAndStart'); const loadingContainer = document.getElementById('loadingContainer'); const loadingSpinner = document.getElementById('loadingSpinner'); const loadingText = document.getElementById('loadingText'); isLoading = show; if (show) { saveButton.disabled = true; saveButton.classList.add('opacity-50', 'cursor-not-allowed'); saveButton.classList.remove('hover:bg-green-500'); loadingContainer.classList.remove('hidden'); loadingSpinner.style.display = 'block'; if (message) { loadingText.textContent = message; } } else { saveButton.disabled = false; 
saveButton.classList.remove('opacity-50', 'cursor-not-allowed'); saveButton.classList.add('hover:bg-green-500'); loadingContainer.classList.add('hidden'); loadingSpinner.style.display = 'none'; } } function getSelectedMic() { return document.getElementById('micSelect').value; } function getSelectedOutput() { return document.getElementById('outputSelect').value; } async function populateAudioDevices() { try { await navigator.mediaDevices.getUserMedia({ audio: true }); } catch (err) { console.warn("Microphone permission denied or not granted."); return; } const devices = await navigator.mediaDevices.enumerateDevices(); const micSelect = document.getElementById('micSelect'); const outputSelect = document.getElementById('outputSelect'); micSelect.innerHTML = ''; outputSelect.innerHTML = ''; devices.forEach(device => { const option = new Option(device.label || `${device.kind}`, device.deviceId); if (device.kind === 'audioinput') micSelect.add(option.cloneNode(true)); if (device.kind === 'audiooutput') { outputSelect.add(option.cloneNode(true)); } }); if (micSelect.options.length === 0) { micSelect.add(new Option("No mic devices found", "")); } if (outputSelect.options.length === 0) { outputSelect.add(new Option("Default Output", "default")); } } function visualizeMic(analyser, canvasId) { const canvas = document.getElementById(canvasId); const ctx = canvas.getContext("2d"); analyser.fftSize = 256; const bufferLength = analyser.frequencyBinCount; const dataArray = new Uint8Array(bufferLength); function draw() { requestAnimationFrame(draw); analyser.getByteFrequencyData(dataArray); ctx.fillStyle = "#1f2937"; ctx.fillRect(0, 0, canvas.width, canvas.height); const barWidth = canvas.width / bufferLength; for (let i = 0; i < bufferLength; i++) { const barHeight = dataArray[i]; ctx.fillStyle = "#4ade80"; ctx.fillRect(i * barWidth, canvas.height - barHeight / 2, barWidth - 1, barHeight / 2); } } draw(); } ================================================ FILE: static/chat.js 
================================================ let ws; let sessionStartTime = null; let messageCount = 0; let audioLevelsChart = null; let isRecording = false; let isAudioCurrentlyPlaying = false; let configSaved = false; let currentAudioSource = null; let interruptRequested = false; let interruptInProgress = false; let audioContext = null; let lastSeenGenId = 0; let reconnecting = false; let reconnectAttempts = 0; let maxReconnectAttempts = 10; const SESSION_ID = "default"; console.log("chat.js loaded"); let micStream; let selectedMicId = null; let selectedOutputId = null; let audioPlaybackQueue = []; let audioDataHistory = []; let micAnalyser, micContext; let activeGenId = 0; function createPermanentVoiceCircle() { if (document.getElementById('voice-circle')) return; const style = document.createElement('style'); style.textContent = ` #voice-circle{ position:fixed;top:50%;left:50%; width:180px;height:180px;border-radius:50%; background:rgba(99,102,241,.20); transform:translate(-50%,-50%) scale(var(--dynamic-scale,1)); pointer-events:none;z-index:50; transition:background-color .35s ease; } #voice-circle.active{ animation:pulse-circle 2s infinite alternate ease-in-out; } @keyframes pulse-circle{ 0%{background:rgba(99,102,241,.55)} 100%{background:rgba(99,102,241,.20)} }`; document.head.appendChild(style); const c = document.createElement('div'); c.id='voice-circle'; document.body.appendChild(c); console.log("Created permanent voice circle"); } function showVoiceCircle() { const c=document.getElementById('voice-circle')||createPermanentVoiceCircle(); c.classList.add('active'); } function hideVoiceCircle() { const c=document.getElementById('voice-circle'); if (c){ c.classList.remove('active'); c.style.setProperty('--dynamic-scale',1); } } function showNotification(msg, type='info'){ const n=document.createElement('div'); n.className=`fixed bottom-4 right-4 px-4 py-3 rounded-lg shadow-lg z-50 ${type==='success'?'bg-green-600': type==='error' 
?'bg-red-600':'bg-indigo-600'}`; n.textContent=msg; document.body.appendChild(n); setTimeout(()=>{n.classList.add('opacity-0'); setTimeout(()=>n.remove(),500)},3000); } function addMessageToConversation(sender,text){ const pane=document.getElementById('conversationHistory'); if(!pane) return; const box=document.createElement('div'); box.className=`p-3 mb-3 rounded-lg text-sm ${ sender==='user'?'bg-gray-800 ml-2':'bg-indigo-900 mr-2'}`; box.innerHTML=`
${sender==='user'?'U':'AI'}
${new Date().toLocaleTimeString()}
${text .replace(/&/g,'&').replace(/$1') .replace(/\*(.*?)\*/g,'$1') .replace(/```([^`]+)```/g,'
$1
') .replace(/`([^`]+)`/g,'$1') .replace(/\n/g,'
')}
`; pane.appendChild(box); pane.scrollTop=pane.scrollHeight; } function connectWebSocket() { if (reconnecting && reconnectAttempts >= maxReconnectAttempts) { console.error("Maximum reconnect attempts reached. Please refresh the page."); showNotification("Connection lost. Please refresh the page.", "error"); return; } if (ws && ws.readyState !== WebSocket.CLOSED && ws.readyState !== WebSocket.CLOSING) { try { ws.close(); } catch (e) { console.warn("Error closing existing WebSocket:", e); } } const proto = location.protocol === 'https:' ? 'wss:' : 'ws:'; ws = new WebSocket(`${proto}//${location.host}/ws`); window.ws = ws; const connLbl = document.getElementById('connectionStatus'); if (connLbl) { connLbl.textContent = reconnecting ? 'Reconnecting…' : 'Connecting…'; connLbl.className = 'text-yellow-500'; } ws.onopen = () => { if (connLbl) { connLbl.textContent = 'Connected'; connLbl.className = 'text-green-500'; } reconnecting = false; reconnectAttempts = 0; ws.send(JSON.stringify({type: 'request_saved_config'})); if (!reconnecting) { addMessageToConversation('ai', 'WebSocket connected. 
Ready for voice or text.'); } else { showNotification("Reconnected successfully", "success"); } }; ws.onclose = (event) => { console.log("WebSocket closed with code:", event.code); if (connLbl) { connLbl.textContent = 'Disconnected'; connLbl.className = 'text-red-500'; } // Clear audio state on disconnection clearAudioPlayback(); // Don't auto-reconnect if this was a normal closure if (event.code !== 1000 && event.code !== 1001) { reconnecting = true; reconnectAttempts++; const delay = Math.min(1000 * Math.pow(1.5, reconnectAttempts), 1000); console.log(`Reconnecting in ${delay}ms (attempt ${reconnectAttempts})`); setTimeout(connectWebSocket, delay); } }; ws.onerror = (error) => { console.error("WebSocket error:", error); if (connLbl) { connLbl.textContent = 'Error'; connLbl.className = 'text-red-500'; } }; ws.onmessage = (e) => { try { const data = JSON.parse(e.data); handleWebSocketMessage(data); } catch (err) { console.error("Error handling WebSocket message:", err); } }; } function sendTextMessage(txt) { if (!txt.trim()) return; if (!ws || ws.readyState !== WebSocket.OPEN) { showNotification("Not connected", "error"); return; } console.log("Force clearing all audio state before sending text message"); // Stop any playing audio if (isAudioCurrentlyPlaying) { if (currentAudioSource) { try { if (currentAudioSource.disconnect) currentAudioSource.disconnect(); if (currentAudioSource.stop) currentAudioSource.stop(0); } catch (e) { console.warn("Error stopping audio:", e); } currentAudioSource = null; } isAudioCurrentlyPlaying = false; } // Clear all flags and queues interruptRequested = false; interruptInProgress = false; activeGenId = 0; audioPlaybackQueue = []; // Always force interruption to be absolutely sure if (ws && ws.readyState === WebSocket.OPEN) { try { ws.send(JSON.stringify({type: 'interrupt', immediate: true})); } catch (e) { console.warn("Error sending interrupt:", e); } } // Wait a bit before sending the actual message setTimeout(() => { try { // Show 
visual feedback showVoiceCircle(); // Send the message ws.send(JSON.stringify({ type: 'text_message', text: txt, session_id: SESSION_ID })); const cnt = document.getElementById('messageCount'); if (cnt) cnt.textContent = ++messageCount; document.getElementById('textInput').value = ''; console.log("Text message sent successfully"); } catch (error) { console.error("Error sending message:", error); showNotification("Error sending message", "error"); } }, 300); } // Reset all audio state to ensure clean state for new interactions function resetAudioState() { console.log("Resetting audio state"); // Clear any stale generation information activeGenId = 0; lastSeenGenId = 0; // Clear any remaining flags interruptRequested = false; interruptInProgress = false; // Make sure we don't have any playing audio if (isAudioCurrentlyPlaying) { clearAudioPlayback(); } // Clear any queued audio audioPlaybackQueue = []; } function clearAudioPlayback() { console.log("FORCEFULLY CLEARING AUDIO PLAYBACK"); interruptRequested = true; interruptInProgress = true; try { // Empty the queue first - do this before stopping current source console.log(`Clearing queue with ${audioPlaybackQueue.length} items`); audioPlaybackQueue = []; activeGenId = 0; // Stop any currently playing audio if (currentAudioSource) { console.log("Stopping active audio source"); try { if (currentAudioSource.disconnect) { currentAudioSource.disconnect(); } } catch (e) { console.warn("Error disconnecting audio source:", e); } try { if (currentAudioSource.stop) { currentAudioSource.stop(0); } } catch (e) { console.warn("Error stopping audio source:", e); } currentAudioSource = null; } try { if (audioContext) { const oldContext = audioContext; audioContext = new (window.AudioContext || window.webkitAudioContext)(); window.audioContext = audioContext; try { oldContext.close(); } catch (closeError) { console.warn("Error closing old audio context:", closeError); } } else { audioContext = new (window.AudioContext || 
window.webkitAudioContext)(); window.audioContext = audioContext; } } catch (contextError) { console.error("Error recreating audio context:", contextError); } } catch (err) { console.error("Error clearing audio:", err); } // Reset state isAudioCurrentlyPlaying = false; hideVoiceCircle(); console.log("Audio playback cleared successfully"); // After a short delay, reset the interrupt flags to accept new audio setTimeout(() => { interruptInProgress = false; // Keep interruptRequested true until we get a new generation }, 300); } // Handle interruption request from user function requestInterrupt() { console.log("User requested interruption"); if (interruptInProgress) { console.log("Interrupt already in progress - force clearing again"); clearAudioPlayback(); return false; } // Set the flags immediately interruptRequested = true; interruptInProgress = true; // Show visual feedback showNotification("Interrupting...", "info"); // Force clear all audio immediately on client side clearAudioPlayback(); // Show visual feedback for the button const interruptBtn = document.getElementById('interruptBtn'); if (interruptBtn) { interruptBtn.classList.add('bg-red-800'); setTimeout(() => { interruptBtn.classList.remove('bg-red-800'); }, 300); } // Then notify the server if (ws && ws.readyState === WebSocket.OPEN) { console.log("Sending interrupt request to server"); try { ws.send(JSON.stringify({ type: 'interrupt', immediate: true })); } catch (error) { console.error("Error sending interrupt request:", error); } // Set a timeout to reset interrupt flags if we don't get server confirmation setTimeout(() => { if (interruptInProgress) { console.log("No interrupt confirmation received from server, resetting state"); interruptInProgress = false; } }, 2000); return true; } else { console.warn("WebSocket not available for interrupt request"); // Reset flag after brief delay if we couldn't send to server setTimeout(() => { interruptInProgress = false; }, 500); return false; } } function 
handleWebSocketMessage(d) { console.log("Received message:", d.type, d); switch(d.type) { case 'transcription': addMessageToConversation('user', d.text); showVoiceCircle(); break; case 'response': addMessageToConversation('ai', d.text); showVoiceCircle(); console.log("NEW RESPONSE RECEIVED - FORCE RESETTING ALL AUDIO STATE"); if (isAudioCurrentlyPlaying) { if (currentAudioSource) { try { if (currentAudioSource.disconnect) currentAudioSource.disconnect(); if (currentAudioSource.stop) currentAudioSource.stop(0); } catch (e) { console.warn("Error stopping current audio:", e); } currentAudioSource = null; } isAudioCurrentlyPlaying = false; } interruptRequested = false; interruptInProgress = false; activeGenId = 0; audioPlaybackQueue = []; try { if (audioContext) { if (audioContext.state === 'suspended') { audioContext.resume().catch(e => console.warn("Error resuming audio context:", e)); } } else { audioContext = new (window.AudioContext || window.webkitAudioContext)(); window.audioContext = audioContext; } } catch (e) { console.warn("Error with audio context:", e); audioContext = new (window.AudioContext || window.webkitAudioContext)(); window.audioContext = audioContext; } console.log("Audio state fully reset and ready for new audio"); break; case 'audio_chunk': console.log("Audio chunk received, flags:", "interruptRequested:", interruptRequested, "interruptInProgress:", interruptInProgress, "genId:", d.gen_id, "activeGenId:", activeGenId); if (!isAudioCurrentlyPlaying && activeGenId === 0) { console.log("FIRST AUDIO CHUNK - FORCING FLAGS RESET"); interruptRequested = false; interruptInProgress = false; } // Don't queue new audio if an interrupt was requested if (interruptRequested || interruptInProgress) { console.log("Interrupt active - ignoring new audio chunk"); return; } // Set active generation ID on first chunk if (activeGenId === 0) { activeGenId = d.gen_id || 1; console.log("!!! 
Setting activeGenId to:", activeGenId); } // Only accept chunks that match our active generation if ((d.gen_id === activeGenId) || (activeGenId === 0)) { queueAudioForPlayback(d.audio, d.sample_rate, d.gen_id || 0); showVoiceCircle(); } else { console.log(`Ignored stale chunk - current gen: ${activeGenId}, received: ${d.gen_id}`); } break; case 'audio_status': console.log("Audio status update:", d.status); if (d.status === 'generating') { console.log("GOT GENERATING STATUS - IMMEDIATELY CLEARING ALL INTERRUPT FLAGS"); interruptRequested = false; interruptInProgress = false; // Capture the generation ID for new generations if (d.gen_id) { console.log(`New generation starting with ID: ${d.gen_id}`); activeGenId = d.gen_id; } showVoiceCircle(); } else if (d.status === 'complete') { console.log("Audio generation complete"); if (!d.gen_id || d.gen_id === activeGenId) { activeGenId = 0; // Reset for next generation } if (!isAudioCurrentlyPlaying) { hideVoiceCircle(); } } else if (d.status === 'interrupted' || d.status === 'interrupt_acknowledged') { console.log("Server confirmed interrupt - clearing audio"); clearAudioPlayback(); setTimeout(() => { console.log("Resetting interrupt flags after server confirmation"); interruptRequested = false; interruptInProgress = false; }, 300); } break; case 'status': if (d.message === 'Thinking...') { showVoiceCircle(); interruptRequested = false; interruptInProgress = false; activeGenId = 0; } break; case 'error': showNotification(d.message, 'error'); hideVoiceCircle(); break; case 'vad_status': if (d.status === 'speech_started') { console.log(`[VAD] speech_started | should_interrupt=${d.should_interrupt}`); if (d.should_interrupt && isAudioCurrentlyPlaying) { console.log('[VAD] confirmed – sending interrupt'); requestInterrupt(); } else { console.log('[VAD] ignored (echo / early AI audio)'); } } break; } } function queueAudioForPlayback(arr, sr, genId = 0) { if (activeGenId !== 0 && genId !== activeGenId) { console.log(`Stale chunk 
ignored (genId mismatch): ${genId} vs ${activeGenId}`); return; } // Don't queue if interrupting if (interruptRequested || interruptInProgress) { console.log("Interrupt active - skipping audio chunk"); return; } console.log("Queueing audio chunk for playback"); audioPlaybackQueue.push({arr, sr, genId}); if (!isAudioCurrentlyPlaying) { console.log("▶Starting audio playback"); processAudioPlaybackQueue(); } } function queueAudioForPlayback(arr, sr, genId = 0) { // Extra logging for the first audio chunk if (!isAudioCurrentlyPlaying) { console.log("Queueing first audio chunk", "interruptRequested:", interruptRequested, "interruptInProgress:", interruptInProgress); } if (!isAudioCurrentlyPlaying && audioPlaybackQueue.length === 0) { console.log("First audio chunk - forcing clear of interrupt flags"); interruptRequested = false; interruptInProgress = false; } // Don't queue audio from a different generation than our active one if (activeGenId !== 0 && genId !== activeGenId) { console.log(`Stale chunk ignored (genId mismatch): ${genId} vs ${activeGenId}`); return; } // Don't queue if interrupting - BUT CHECK AGAIN THAT FLAGS ARE VALID if (interruptRequested || interruptInProgress) { console.log("Interrupt active - skipping audio chunk"); return; } console.log("Queueing audio chunk for playback"); audioPlaybackQueue.push({arr, sr, genId}); if (!isAudioCurrentlyPlaying) { console.log("STARTING AUDIO PLAYBACK - FIRST CHUNK"); processAudioPlaybackQueue(); } } // Modified to ensure first audio actually plays function processAudioPlaybackQueue() { if (!isAudioCurrentlyPlaying && audioPlaybackQueue.length > 0) { console.log("Starting first audio chunk - force clearing interrupt flags"); interruptRequested = false; interruptInProgress = false; } // Double-check interrupt flags AFTER clearling them if (interruptRequested || interruptInProgress) { console.log("Interrupt active - not processing audio queue"); isAudioCurrentlyPlaying = false; hideVoiceCircle(); return; } // Check if 
queue is empty if (!audioPlaybackQueue.length) { console.log("📭 Audio queue empty, stopping playback"); isAudioCurrentlyPlaying = false; hideVoiceCircle(); currentAudioSource = null; return; } // Enable the interrupt button when audio is playing const interruptBtn = document.getElementById('interruptBtn'); if (interruptBtn) { interruptBtn.disabled = false; interruptBtn.classList.remove('opacity-50'); } console.log("Processing next audio chunk"); isAudioCurrentlyPlaying = true; // Get the genId from the chunk const {arr, sr, genId} = audioPlaybackQueue.shift(); // Skip if it's a stale chunk if (activeGenId !== 0 && genId !== activeGenId) { console.log(`Skipping stale chunk playback (gen ${genId} vs active ${activeGenId})`); processAudioPlaybackQueue(); // Continue with next chunk return; } playAudioChunk(arr, sr) .then(() => { // Check interrupt status again after playback if (!interruptRequested && !interruptInProgress) { processAudioPlaybackQueue(); } else { console.log("interrupt active - stopping queue processing"); isAudioCurrentlyPlaying = false; hideVoiceCircle(); } }) .catch(err => { console.error("Error in audio playback:", err); isAudioCurrentlyPlaying = false; hideVoiceCircle(); // Try to continue with next chunk despite errors setTimeout(() => { if (audioPlaybackQueue.length > 0 && !interruptRequested) { processAudioPlaybackQueue(); } }, 200); }); } async function playAudioChunk(audioArr, sampleRate) { // Skip playback if interrupt was requested if (interruptRequested || interruptInProgress) { console.log("Interrupt active - not playing audio chunk"); return Promise.resolve(); } try { // Ensure we have a valid audio context if (!audioContext) { audioContext = new (window.AudioContext || window.webkitAudioContext)(); window.audioContext = audioContext; } // Make sure context is resumed if (audioContext.state === 'suspended') { await audioContext.resume(); } const buf = audioContext.createBuffer(1, audioArr.length, sampleRate); buf.copyToChannel(new 
Float32Array(audioArr), 0); const src = audioContext.createBufferSource(); src.buffer = buf; // Store reference to current source for potential interruption currentAudioSource = src; const an = audioContext.createAnalyser(); an.fftSize = 256; src.connect(an); an.connect(audioContext.destination); src.start(); console.log("🎵 Started playing audio chunk"); const arr = new Uint8Array(an.frequencyBinCount); const circle = document.getElementById('voice-circle'); // Animation function that respects interruption function pump() { // Stop animation if source is no longer current or interrupt requested if (src !== currentAudioSource || interruptRequested || interruptInProgress) { return; } try { an.getByteFrequencyData(arr); const avg = arr.reduce((a,b) => a+b, 0) / arr.length; if (circle) { circle.style.setProperty('--dynamic-scale', (1+avg/255*1.5).toFixed(3)); } } catch (e) { console.warn("Error in animation pump:", e); return; } if (src.playbackState !== src.FINISHED_STATE) { requestAnimationFrame(pump); } } pump(); return new Promise(resolve => { src.onended = () => { // Only resolve if this is still the current source and no interrupt if (src === currentAudioSource && !interruptRequested && !interruptInProgress) { resolve(); } else { resolve(); // Still resolve to maintain chain } }; }); } catch (error) { console.error("Error playing audio chunk:", error); return Promise.resolve(); // Resolve anyway to keep chain going } } async function startRecording() { if (isRecording) return; try { const constraints = { audio: selectedMicId ? 
{deviceId:{exact:selectedMicId}} : true }; micStream = await navigator.mediaDevices.getUserMedia(constraints); if (!audioContext) audioContext = new (AudioContext||webkitAudioContext)(); const src = audioContext.createMediaStreamSource(micStream); const proc = audioContext.createScriptProcessor(4096,1,1); src.connect(proc); proc.connect(audioContext.destination); proc.onaudioprocess = e => { const samples = Array.from(e.inputBuffer.getChannelData(0)); if (ws && ws.readyState === WebSocket.OPEN) { try { ws.send(JSON.stringify({ type:'audio', audio:samples, sample_rate:audioContext.sampleRate, session_id:SESSION_ID })); } catch (error) { console.error("Error sending audio data:", error); stopRecording(); } } }; window._micProcessor = proc; isRecording = true; document.getElementById('micStatus').textContent = 'Listening…'; showVoiceCircle(); } catch (err) { console.error("Microphone access error:", err); showNotification('Microphone access denied','error'); } } function stopRecording() { if (!isRecording) return; try { if (window._micProcessor) { window._micProcessor.disconnect(); window._micProcessor = null; } if (micStream) { micStream.getTracks().forEach(t => t.stop()); micStream = null; } } catch(e) { console.warn("Error stopping recording:", e); } isRecording = false; const micStatus = document.getElementById('micStatus'); if (micStatus) { micStatus.textContent = 'Click to speak'; } hideVoiceCircle(); } async function setupChatUI() { document.documentElement.classList.add('bg-gray-950'); document.documentElement.style.backgroundColor = '#030712'; createPermanentVoiceCircle(); connectWebSocket(); initAudioLevelsChart(); const txt = document.getElementById('textInput'); const btn = document.getElementById('sendTextBtn'); // Setup enhanced interrupt button const interruptBtn = document.createElement('button'); interruptBtn.id = 'interruptBtn'; interruptBtn.className = 'px-3 py-2 ml-2 bg-red-600 text-white rounded hover:bg-red-700 flex items-center transition 
duration-150'; interruptBtn.innerHTML = ' Stop'; interruptBtn.onclick = (e) => { e.preventDefault(); try { requestInterrupt(); interruptBtn.classList.add('bg-red-800', 'scale-95'); setTimeout(() => interruptBtn.classList.remove('bg-red-800', 'scale-95'), 150); } catch (error) { console.error("Error in interrupt button handler:", error); } }; interruptBtn.title = "Stop AI speech (Space or Esc)"; interruptBtn.disabled = true; // Disabled by default interruptBtn.classList.add('opacity-50', 'cursor-not-allowed'); if (btn && btn.parentElement) { btn.parentElement.appendChild(interruptBtn); } // Add debug button for easier debugging of interrupt issues const debugBtn = document.createElement('button'); debugBtn.innerText = "Debug Audio"; debugBtn.className = "px-3 py-2 ml-2 bg-blue-600 text-white rounded text-xs"; debugBtn.onclick = () => { console.log("- Debug info:"); console.log("- Audio playing:", isAudioCurrentlyPlaying); console.log("- Interrupt requested:", interruptRequested); console.log("- Interrupt in progress:", interruptInProgress); console.log("- Current source:", currentAudioSource); console.log("- Queue length:", audioPlaybackQueue.length); console.log("- Audio context state:", audioContext?.state); console.log("- Active generation ID:", activeGenId); console.log("- Last seen generation ID:", lastSeenGenId); console.log("- WebSocket state:", ws ? 
ws.readyState : "no websocket"); showNotification("Debug info in console", "info"); }; if (btn && btn.parentElement) { btn.parentElement.appendChild(debugBtn); } // Run the update function periodically setInterval(() => { const interruptBtn = document.getElementById('interruptBtn'); if (interruptBtn) { if (isAudioCurrentlyPlaying && !interruptRequested && !interruptInProgress) { interruptBtn.disabled = false; interruptBtn.classList.remove('opacity-50', 'cursor-not-allowed'); } else { interruptBtn.disabled = true; interruptBtn.classList.add('opacity-50', 'cursor-not-allowed'); } } }, 300); if (btn) { btn.onclick = () => { try { sendTextMessage(txt.value); } catch (error) { console.error("Error in send button handler:", error); } }; } if (txt) { txt.addEventListener('keydown', e => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); try { sendTextMessage(txt.value); } catch (error) { console.error("Error in text input handler:", error); } } }); } const micBtn = document.getElementById('micToggleBtn'); if (micBtn) { micBtn.addEventListener('click', () => { try { if (isRecording) stopRecording(); else startRecording(); } catch (error) { console.error("Error in mic button handler:", error); } }); } // Add event listeners to detect keyboard interruptions document.addEventListener('keydown', e => { // Allow space or escape to interrupt if ((e.code === 'Space' || e.code === 'Escape') && isAudioCurrentlyPlaying) { e.preventDefault(); try { requestInterrupt(); // Add visual feedback const interruptBtn = document.getElementById('interruptBtn'); if (interruptBtn) { interruptBtn.classList.add('bg-red-800'); setTimeout(() => { interruptBtn.classList.remove('bg-red-800'); }, 200); } } catch (error) { console.error("Error in keyboard interrupt handler:", error); } } }); // Initialize audio context if (!audioContext) { try { audioContext = new (window.AudioContext || window.webkitAudioContext)(); window.audioContext = audioContext; } catch (error) { console.error("Error 
creating audio context:", error); showNotification("Audio initialization failed. Please refresh the page.", "error"); } } // Try to unlock audio context on user interaction ['click', 'touchstart', 'keydown'].forEach(ev => document.addEventListener(ev, function unlock() { if (audioContext && audioContext.state === 'suspended') { try { audioContext.resume(); } catch (error) { console.warn("Error resuming audio context:", error); } } document.removeEventListener(ev, unlock); }) ); console.log("Chat UI ready with enhanced interruption support"); } if (document.readyState === 'loading') { document.addEventListener('DOMContentLoaded', setupChatUI); } else { setupChatUI(); } function initAudioLevelsChart() { const ctx = document.getElementById('audioLevels'); if (!ctx) return; try { if (audioLevelsChart) audioLevelsChart.destroy(); const grad = ctx.getContext('2d').createLinearGradient(0, 0, 0, 100); grad.addColorStop(0, 'rgba(79,70,229,.6)'); grad.addColorStop(1, 'rgba(79,70,229,.1)'); audioLevelsChart = new Chart(ctx, { type: 'line', data: { labels: Array(30).fill(''), datasets: [{ data: Array(30).fill(0), backgroundColor: grad, borderColor: 'rgba(99,102,241,1)', borderWidth: 2, tension: .4, fill: true, pointRadius: 0 }] }, options: { animation: false, responsive: true, scales: { y: { beginAtZero: true, max: 100, ticks: {display: false}, grid: {color: 'rgba(255,255,255,.1)'} }, x: {display: false, grid: {display: false}} }, plugins: { legend: {display: false}, tooltip: {enabled: false} }, elements: {point: {radius: 0}} } }); } catch (error) { console.error("Error initializing audio chart:", error); } } ================================================ FILE: static/crud.js ================================================ let allConversations = []; document.addEventListener('DOMContentLoaded', async () => { await loadConversations(); document.getElementById('searchInput').addEventListener('input', () => { const query = 
document.getElementById('searchInput').value.toLowerCase(); const filtered = allConversations.filter(c => c.user_message.toLowerCase().includes(query) || c.ai_message.toLowerCase().includes(query) ); renderConversations(filtered); }); document.getElementById('deleteAllBtn').addEventListener('click', async () => { if (!confirm("Are you sure you want to delete ALL conversations?")) return; await fetch('/api/conversations', { method: 'DELETE' }); await loadConversations(); }); }); async function loadConversations() { const res = await fetch('/api/conversations'); allConversations = await res.json(); renderConversations(allConversations); } function renderConversations(list) { const container = document.getElementById('conversationList'); container.innerHTML = ''; if (list.length === 0) { container.innerHTML = '

No conversations found.

'; return; } list.forEach(conv => { const div = document.createElement('div'); div.className = "bg-gray-800 p-4 rounded shadow"; div.innerHTML = `
User:
AI:
`; container.appendChild(div); div.querySelector('.saveBtn').addEventListener('click', async () => { const id = conv.id; const user = div.querySelector('textarea[data-field="user"]').value; const ai = div.querySelector('textarea[data-field="ai"]').value; await fetch(`/api/conversations/${id}`, { method: 'PUT', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ user_message: user, ai_message: ai }) }); alert("Saved."); }); div.querySelector('.deleteBtn').addEventListener('click', async () => { const id = conv.id; if (!confirm("Delete this conversation?")) return; await fetch(`/api/conversations/${id}`, { method: 'DELETE' }); await loadConversations(); }); }); } ================================================ FILE: templates/chat.html ================================================ AI Companion - Chat

AI Companion

================================================ FILE: templates/crud.html ================================================ Conversation Manager
Return to Setup

Memory Manager

================================================ FILE: templates/index.html ================================================ ================================================ FILE: templates/setup.html ================================================ AI Companion Setup

AI Companion Setup

Open Conversation DB

Primary Reference Audio (Required)

Secondary Reference (Optional)

For better voice quality

Tertiary Reference (Optional)

For even better voice quality
================================================ FILE: test.py ================================================ import time from generator import Segment, load_csm_1b, generate_streaming_audio import torchaudio print(f"Starting script at: {time.strftime('%H:%M:%S')}") start_time = time.time() print("Downloading model...") model_start = time.time() print(f"Model download completed in {time.time() - model_start:.2f} seconds") print("Loading model to CUDA...") load_start = time.time() generator = load_csm_1b("cuda") print(f"Model loaded in {time.time() - load_start:.2f} seconds") speakers = [0, 1, 0, 0] transcripts = [ "Hey how are you doing.", "Pretty good, pretty good.", "I'm great.", "So happy to be speaking to you.", ] audio_paths = [ "utterance_0.wav", "utterance_1.wav", "utterance_2.wav", "utterance_3.wav", ] def load_audio(audio_path): print(f"Loading reference audio: {audio_path}") audio_load_start = time.time() audio_tensor, sample_rate = torchaudio.load(audio_path) audio_tensor = torchaudio.functional.resample( audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate ) print(f"Audio loaded and resampled in {time.time() - audio_load_start:.2f} seconds") return audio_tensor print("Creating segments with reference audio...") segments_start = time.time() segments = [ Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path)) for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths) ] print(f"Segments created in {time.time() - segments_start:.2f} seconds") # Option 1: Regular generation with streaming internally enabled print("Generating audio (with internal streaming)...") gen_start = time.time() audio = generator.generate( text="Me too, this is some cool stuff huh?", speaker=0, context=segments, max_audio_length_ms=10_000, stream=True # Enable internal streaming ) print(f"Audio generation completed in {time.time() - gen_start:.2f} seconds") print("Saving audio file...") save_start = time.time() 
torchaudio.save("audio_regular.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) print(f"Audio saved in {time.time() - save_start:.2f} seconds") # Option 2: Use the streaming helper function that saves as it goes print("Generating audio using streaming API...") generate_streaming_audio( generator=generator, text="Me too, this is some cool stuff huh?", speaker=0, context=segments, output_file="audio_streamed.wav", max_audio_length_ms=10_000, play_audio=True # Set to True to play audio in real-time (requires sounddevice package) ) total_time = time.time() - start_time print(f"Total execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)") print(f"Script completed at: {time.strftime('%H:%M:%S')}") ================================================ FILE: vad.py ================================================ import numpy as np import torch from typing import Callable, Dict, List from collections import deque class VoiceActivityDetector: def __init__( self, model, utils, sample_rate: int = 16000, threshold: float = 0.3, silence_duration: int = 45 ): self.model = model self.sample_rate = sample_rate self.threshold = threshold self.silence_duration = silence_duration # Get functions from utils self.get_speech_timestamps = utils[0] self.is_speaking = False self.silent_frames = 0 self.frame_size = 512 if sample_rate == 16000 else 256 # Required by Silero VAD print(f"VAD initialized with threshold {threshold}, frame size {self.frame_size}, silence duration {silence_duration}") def reset(self) -> None: self.is_speaking = False self.silent_frames = 0 if hasattr(self.model, "reset_states"): self.model.reset_states() elif hasattr(self.model, "reset_state"): self.model.reset_state() else: for buf in ("h", "c"): if hasattr(self.model, buf): getattr(self.model, buf).zero_() def process_audio_chunk(self, audio_chunk: np.ndarray) -> bool: # Prepare audio chunk if audio_chunk.ndim > 1: audio_chunk = np.mean(audio_chunk, axis=1) if audio_chunk.dtype != np.float32: 
audio_chunk = audio_chunk.astype(np.float32) # Process in chunks of the correct size speech_detected = False turn_ended = False speech_probs = [] # Process audio in correct sized chunks for Silero VAD for i in range(0, len(audio_chunk), self.frame_size): # Get chunk of correct size chunk = audio_chunk[i:i+self.frame_size] # If we don't have enough samples, pad with zeros if len(chunk) < self.frame_size: chunk = np.pad(chunk, (0, self.frame_size - len(chunk))) # Convert to tensor audio_tensor = torch.tensor(chunk).to('cpu') # Get speech probability speech_prob = self.model(audio_tensor, self.sample_rate).item() speech_probs.append(speech_prob) # Update speaking state if speech_prob >= self.threshold: speech_detected = True self.silent_frames = 0 else: if self.is_speaking: self.silent_frames += 1 # Print detailed speech detection information # print(f"Speech probabilities: {speech_probs}") # print(f"Speech detected: {speech_detected}, Current state: {self.is_speaking}") # print(f"Silent frames: {self.silent_frames}, Threshold: {self.silence_duration}") # Update speaking state based on all chunks if speech_detected: self.is_speaking = True self.silent_frames = 0 elif self.is_speaking and self.silent_frames >= self.silence_duration: # Transition to not speaking if we've had enough silent frames self.is_speaking = False turn_ended = True print(f"Turn ended after {self.silent_frames} silent frames") self.silent_frames = 0 return turn_ended class AudioStreamProcessor: def __init__( self, model, utils, sample_rate: int = 16000, chunk_size: int = 512, vad_threshold: float = 0.3, callbacks: Dict[str, Callable] = None, pre_speech_buffer_size: int = 10 ): self.sample_rate = sample_rate self.chunk_size = chunk_size self.pre_speech_buffer = deque(maxlen=pre_speech_buffer_size) # Ensure model is on CPU if hasattr(model, 'to'): model = model.to('cpu') self.vad = VoiceActivityDetector( model=model, utils=utils, sample_rate=sample_rate, threshold=vad_threshold, silence_duration=45 # 
Increased for better end detection ) self.audio_buffer = [] self.is_collecting = False self.callbacks = callbacks or {} self.silent_chunk_count = 0 self.max_silent_chunks = 30 # Force end after this many silent chunks print(f"AudioStreamProcessor initialized with threshold: {vad_threshold}") def process_audio(self, audio_chunk: np.ndarray): # Always add to pre-speech buffer self.pre_speech_buffer.append(audio_chunk) if self.is_collecting: self.audio_buffer.append(audio_chunk) # Process with VAD is_turn_end = self.vad.process_audio_chunk(audio_chunk) # Start collecting on speech detection if self.vad.is_speaking and not self.is_collecting: self.is_collecting = True self.silent_chunk_count = 0 # Include pre-speech buffer in the audio buffer self.audio_buffer = list(self.pre_speech_buffer) print(f"Speech started, beginning collection with {len(self.pre_speech_buffer)} pre-speech chunks") if "on_speech_start" in self.callbacks: self.callbacks["on_speech_start"]() # Count silent chunks when collecting but not speaking if self.is_collecting and not self.vad.is_speaking: self.silent_chunk_count += 1 print(f"Silent chunk count: {self.silent_chunk_count}, max: {self.max_silent_chunks}") # Force end after too many silent chunks if self.silent_chunk_count >= self.max_silent_chunks: is_turn_end = True print(f"Forcing speech end after {self.silent_chunk_count} silent chunks") else: self.silent_chunk_count = 0 # End collection on turn end if is_turn_end and self.is_collecting: print("Turn end detected, processing collected audio") self.is_collecting = False if self.audio_buffer: print(f"Audio buffer length: {len(self.audio_buffer)} chunks") print("Speech ended, processing collected audio") complete_audio = np.concatenate(self.audio_buffer) print(f"Complete audio length: {len(complete_audio)}") if "on_speech_end" in self.callbacks: try: print("Calling on_speech_end callback") self.callbacks["on_speech_end"](complete_audio, self.sample_rate) print("on_speech_end callback completed 
successfully") except Exception as e: print(f"Error in on_speech_end callback: {e}") # Clear buffer after processing self.audio_buffer = [] self.silent_chunk_count = 0 def reset(self): self.vad.reset() self.audio_buffer = [] self.is_collecting = False self.silent_chunk_count = 0 print("AudioStreamProcessor reset")