Repository: davidbrowne17/csm-streaming
Branch: main
Commit: 121552fcd68e
Files: 25
Total size: 262.0 KB

Directory structure:
gitextract_5x1lml4p/

├── .github/
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── finetuned_model/
│   └── config.json
├── generator.py
├── llm_interface.py
├── loadandmergecheckpoint.py
├── lora.py
├── main.py
├── models.py
├── rag_system.py
├── requirements.txt
├── run_csm.py
├── setup.py
├── static/
│   ├── app.js
│   ├── chat.js
│   └── crud.js
├── templates/
│   ├── chat.html
│   ├── crud.html
│   ├── index.html
│   └── setup.html
├── test.py
└── vad.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms

github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: davidbrowne17
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']


================================================
FILE: .gitignore
================================================
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
.env
.venv
env/
venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo

# Project specific
.python-version
*.wav
output_*/
basic_audio.wav
full_conversation.wav
context_audio.wav

# Model files
*.pt
*.ckpt

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# CSM - Optimized Streaming/Finetuning Edition

---

CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.

Our fork adds **streaming audio generation**, **real-time playback**, and **performance optimizations** to the original implementation.

## Requirements

* A CUDA-compatible GPU
* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions
* Similarly, Python 3.10 is recommended, but newer versions may be fine
* For some audio operations, `ffmpeg` may be required
* For real-time audio playback: `pip install sounddevice`
* Access to the following Hugging Face models:
  * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
  * [CSM-1B](https://huggingface.co/sesame/csm-1b)

### Setup

```bash
sudo apt-get update && sudo apt-get install -y libportaudio2 libportaudio-dev
git clone git@github.com:davidbrowne17/csm-streaming.git
cd csm-streaming
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

# Optional speedup
pip install flash-attn
# You will need access to CSM-1B and Llama-3.2-1B
huggingface-cli login
```

### Windows Setup

The `triton` package cannot be installed on Windows; use `pip install triton-windows` instead.
The realtime demo uses vLLM for fast inference. vLLM does not currently support Windows, but until it does you can try https://github.com/SystemPanic/vllm-windows.

## Quickstart

Generate a sentence with streaming (chunks are processed and output as they're generated):

```python
from generator import load_csm_1b, generate_streaming_audio

# Load the model
generator = load_csm_1b("cuda")

# Generate audio with streaming and real-time playback
generate_streaming_audio(
    generator=generator,
    text="Hello, this is streaming audio generation in action!",
    speaker=0,
    context=[],  # No context needed for basic generation
    output_file="streaming_audio.wav",
    play_audio=True  # Enable real-time playback
)
```
## Finetuning
To finetune CSM, all you need is a set of raw WAV files of the voice you want to train on. Place them in a folder named `audio_data` and run `lora.py`.
You can configure training parameters such as batch size, number of epochs, and learning rate by editing the values at the top of `lora.py`.
You will need a CUDA GPU with at least 12 GB of VRAM, depending on your dataset size and training parameters. You can monitor training metrics via the dynamically updated PNG in the `finetuned_model/` folder, which contains several graphs for tracking training progress. To try a checkpoint, use `loadandmergecheckpoint.py` (make sure to set the same R and alpha values you used during training).
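
Why R and alpha must match: in the standard LoRA formulation the merged weight is W + (alpha / r) · B A, so a different alpha (or rank) changes the scale applied to the trained update. The sketch below is a minimal, dependency-free illustration of that merge, assuming `lora.py` uses the usual alpha / r scaling; `matmul` and `merge_lora` are hypothetical helpers, not functions from this repo.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(x: Sequence[Sequence[float]], y: Sequence[Sequence[float]]) -> Matrix:
    """Plain-Python matrix product x @ y."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def merge_lora(w: Matrix, lora_a: Matrix, lora_b: Matrix, alpha: float, r: int) -> Matrix:
    """Return W + (alpha / r) * (B @ A), the standard LoRA weight merge.

    Shapes (assumed): W is (out, in), B is (out, r), A is (r, in).
    """
    if len(lora_a) != r or any(len(row) != r for row in lora_b):
        raise ValueError("rank r does not match the LoRA matrix shapes")
    scale = alpha / r
    delta = matmul(lora_b, lora_a)  # (out, in) low-rank update
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


w = [[1.0, 0.0], [0.0, 1.0]]
lora_b = [[1.0], [2.0]]  # (out=2, r=1)
lora_a = [[3.0, 4.0]]    # (r=1, in=2)
print(merge_lora(w, lora_a, lora_b, alpha=2.0, r=1))  # [[7.0, 8.0], [12.0, 17.0]]
print(merge_lora(w, lora_a, lora_b, alpha=1.0, r=1))  # [[4.0, 4.0], [6.0, 9.0]]
```

The same trained A and B produce different merged weights when only alpha changes, which is why the merge script needs the training values.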

## Realtime chat demo
To use the realtime demo, run `setup.py` to download the required models, then run `main.py`. This opens a setup page at http://localhost:8000 where you can set the path to your chosen LLM, configure the CSM paths and reference audio, and select your headset and microphone. Once everything is loaded, you can chat with the AI in real time, just like in the CSM demo. The demo includes a dynamic RAG system so the AI can remember previous conversations. By default, it uses whisper-large-v3-turbo for speech-to-text (STT) and Silero VAD for voice activity detection.
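
A VAD such as Silero outputs a speech probability for each short audio window, and the app still has to decide when an utterance starts and ends. The sketch below shows one common way to do that: a hysteresis gate with a minimum-silence window. `segment_speech` is a hypothetical helper, not the logic in `vad.py`, and the thresholds are illustrative values.

```python
from typing import Iterable, List, Tuple


def segment_speech(
    probs: Iterable[float],
    start_threshold: float = 0.5,   # illustrative, not taken from vad.py
    end_threshold: float = 0.35,    # lower than start_threshold for hysteresis
    min_silence_frames: int = 10,
) -> List[Tuple[int, int]]:
    """Turn per-frame speech probabilities into (start, end) frame spans.

    Speech starts when a probability reaches start_threshold and ends only after
    min_silence_frames consecutive frames fall below end_threshold.
    End indices are exclusive.
    """
    segments: List[Tuple[int, int]] = []
    in_speech = False
    start = 0
    silence_run = 0
    i = -1
    for i, p in enumerate(probs):
        if not in_speech:
            if p >= start_threshold:
                in_speech = True
                start = i
                silence_run = 0
        elif p < end_threshold:
            silence_run += 1
            if silence_run >= min_silence_frames:
                # Close the segment at the first frame of the silent run
                segments.append((start, i - silence_run + 1))
                in_speech = False
        else:
            silence_run = 0
    if in_speech:
        # Stream ended mid-utterance: close at the last non-silent frame
        segments.append((start, i - silence_run + 1))
    return segments


print(segment_speech([0.1, 0.6, 0.7, 0.4, 0.2, 0.2, 0.9, 0.1, 0.1], min_silence_frames=2))
# [(1, 4), (6, 7)]
```

If each probability covers a 32 ms window (512 samples at 16 kHz, a common Silero setting), `min_silence_frames=10` waits for roughly 320 ms of silence before ending a turn.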

## Usage

Our optimized version offers several ways to use CSM with streaming capabilities:

### 1. Basic Streaming Generation

Generate audio with streaming and save to a file:

```python
from generator import load_csm_1b, generate_streaming_audio

generator = load_csm_1b("cuda")

# Generate with streaming (writes to file as it generates)
generate_streaming_audio(
    generator=generator,
    text="This audio will be generated in chunks for faster response times.",
    speaker=0,
    context=[],
    output_file="streaming_output.wav"
)
```

### 2. Real-time Audio Playback

Generate and play audio in real-time as it's being generated:

```python
from generator import load_csm_1b, generate_streaming_audio

generator = load_csm_1b("cuda")

# Generate with streaming and play in real-time
generate_streaming_audio(
    generator=generator,
    text="You'll hear me speaking as I'm being generated!",
    speaker=0,
    context=[],
    output_file="streaming_output.wav",
    play_audio=True  # Enable real-time playback
)
```

### 3. Low-level Streaming API

For more control, use the low-level streaming API:

```python
from generator import load_csm_1b

generator = load_csm_1b("cuda")

# Process audio chunks as they're generated
for audio_chunk in generator.generate_stream(
    text="This is generated chunk by chunk.",
    speaker=0,
    context=[]
):
    # Do something with each chunk as it's generated
    print(f"Received chunk of size: {audio_chunk.shape}")
    
    # You could process or play each chunk here
    # For example, write to a file incrementally
    # Or send over a network connection
```
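
One way to write chunks to a file incrementally is Python's standard-library `wave` module. This sketch assumes each chunk has already been converted to a flat sequence of float samples in [-1, 1] (for example with `audio_chunk.cpu().tolist()`, assuming a 1-D float tensor) and that the sample rate is `generator.sample_rate` (24 kHz for Mimi). `write_chunks_to_wav` is a hypothetical helper, not part of `generator.py`.

```python
import array
import io
import wave
from typing import Iterable, Sequence


def write_chunks_to_wav(chunks: Iterable[Sequence[float]], path_or_file, sample_rate: int = 24_000) -> int:
    """Append float chunks in [-1, 1] to a mono 16-bit PCM WAV as they arrive.

    Returns the total number of samples written.
    """
    total = 0
    with wave.open(path_or_file, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 16-bit PCM
        wav.setframerate(sample_rate)
        for chunk in chunks:
            # Clamp each sample to [-1, 1], then scale to the int16 range
            pcm = array.array("h", (int(max(-1.0, min(1.0, s)) * 32767) for s in chunk))
            wav.writeframes(pcm.tobytes())
            total += len(pcm)
    return total


# Demo with two fake chunks written to an in-memory buffer
buffer = io.BytesIO()
written = write_chunks_to_wav([[0.0, 0.5], [-0.5, 1.0, 2.0]], buffer)
print(written)  # 5
```

With the model, passing a generator expression such as `(c.cpu().tolist() for c in generator.generate_stream(...))` writes each chunk as soon as it arrives instead of collecting the whole utterance first.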

### 4. Generate with Context

For best results, provide reference audio context:

```python
from generator import load_csm_1b, Segment, generate_streaming_audio
import torchaudio

generator = load_csm_1b("cuda")

# Load reference audio
def load_audio(audio_path):
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    # Downmix (channels, samples) -> (samples,); squeeze(0) alone breaks on stereo files
    audio_tensor = audio_tensor.mean(dim=0)
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate
    )
    return audio_tensor

# Create context segments
segments = [
    Segment(
        text="I knew I could trust you.",
        speaker=0,
        audio=load_audio("reference.wav")
    )
]

# Generate with streaming using the context
generate_streaming_audio(
    generator=generator,
    text="Me too, this is some cool stuff huh?",
    speaker=0,
    context=segments,
    output_file="contextual_streaming.wav",
    play_audio=True
)
```
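
To judge how much reference audio fits, it helps to count sequence rows. In `generator.py`, each context segment contributes one row per text token (`_tokenize_text_segment`) and one row per Mimi audio frame plus an EOS frame (`_tokenize_audio`), and `Generator.max_seq_len` is 2048. Mimi runs at 12.5 frames per second, so 10 s of audio is about 125 frames. The helpers below are a rough, hypothetical estimator that is not part of the repo (exact frame counts may differ slightly due to codec padding); to leave room for the prompt text and the audio you plan to generate, pass a smaller `budget`.

```python
import math
from typing import List, Tuple

MIMI_FRAME_RATE_HZ = 12.5  # Mimi codec frames per second of audio
MAX_SEQ_LEN = 2048         # Generator.max_seq_len in generator.py


def estimate_segment_rows(audio_seconds: float, num_text_tokens: int) -> int:
    """Approximate sequence rows one context Segment occupies.

    One row per text token, one per Mimi audio frame, plus the EOS frame
    that _tokenize_audio appends after the audio codes.
    """
    return num_text_tokens + math.ceil(audio_seconds * MIMI_FRAME_RATE_HZ) + 1


def most_recent_that_fit(segments: List[Tuple[float, int]], budget: int = MAX_SEQ_LEN) -> int:
    """Count how many of the newest (audio_seconds, num_text_tokens) segments fit in budget."""
    used = 0
    kept = 0
    for seconds, tokens in reversed(segments):  # newest segments first
        rows = estimate_segment_rows(seconds, tokens)
        if used + rows > budget:
            break
        used += rows
        kept += 1
    return kept


history = [(10.0, 12), (20.0, 20), (5.0, 8)]  # (seconds of audio, text tokens), oldest first
print(most_recent_that_fit(history, budget=400))  # 2
```

`num_text_tokens` can be measured the same way `generator.py` tokenizes text, by encoding `f"[{speaker}]{text}"` with the Llama 3 tokenizer.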

### 5. Regular Generation with Internal Streaming

Use the original API with streaming enabled internally:

```python
from generator import load_csm_1b, Segment
import torchaudio

generator = load_csm_1b("cuda")

# Regular generation but with internal streaming optimization
audio = generator.generate(
    text="This uses internal streaming for faster processing.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
    stream=True  # Enable internal streaming optimization
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```
## Performance Optimizations

Our optimized version includes several performance enhancements:

- **Streaming Generation**: Processes and outputs audio in chunks instead of waiting for the entire generation to finish, achieving a real-time factor (RTF) of 0.28x (target: <1.0) on an RTX 4090 (10 seconds of audio takes 2.8 seconds to generate)
- **Frame Batching**: Processes multiple frames at once for better GPU utilization
- **Half-precision Inference**: Uses bfloat16/float16 for faster processing
- **CUDA Optimizations**: Enables cuDNN benchmarking and Flash Attention where available
- **Memory Management**: Clears GPU cache before generation to reduce memory pressure
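
The RTF figure above is generation wall-clock time divided by the duration of the audio produced: 2.8 s / 10 s = 0.28. A small helper to compute it from a sample count, assuming Mimi's 24 kHz output rate (`real_time_factor` is a hypothetical name, not a function in this repo):

```python
def real_time_factor(generation_seconds: float, num_samples: int, sample_rate: int = 24_000) -> float:
    """RTF = wall-clock generation time / duration of the generated audio.

    RTF < 1.0 means audio is produced faster than it plays back.
    """
    if num_samples <= 0 or sample_rate <= 0:
        raise ValueError("num_samples and sample_rate must be positive")
    audio_seconds = num_samples / sample_rate
    return generation_seconds / audio_seconds


# 10 s of 24 kHz audio is 240_000 samples; generated in 2.8 s
print(round(real_time_factor(2.8, 240_000), 2))  # 0.28
```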

## FAQ

**How much faster is the streaming version?**

The perceived latency is much lower, since the first audio chunks arrive within milliseconds instead of after the entire generation completes. Total generation time also improves by 40-60%, depending on your hardware.
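
To check first-chunk latency on your own hardware, you can wrap any chunk iterator (such as `generator.generate_stream(...)`) in a timer. `timed_stream` is a hypothetical helper, not part of the repo. `first_chunk_s` is measured from the moment iteration starts, while `total_s` also includes whatever time your loop spends on each chunk (for example playback), because the clock keeps running while the caller handles a yielded chunk.

```python
import time
from typing import Dict, Iterable, Iterator, TypeVar

T = TypeVar("T")


def timed_stream(chunks: Iterable[T], stats: Dict[str, float]) -> Iterator[T]:
    """Yield chunks unchanged while recording latency figures into `stats`.

    first_chunk_s: seconds from the start of iteration to the first chunk
                   (not set if the stream is empty)
    total_s:       seconds until the stream is exhausted, including time the
                   caller spends handling each chunk between yields
    num_chunks:    number of chunks seen
    """
    start = time.perf_counter()
    stats["num_chunks"] = 0
    for chunk in chunks:
        if stats["num_chunks"] == 0:
            stats["first_chunk_s"] = time.perf_counter() - start
        stats["num_chunks"] += 1
        yield chunk
    stats["total_s"] = time.perf_counter() - start


# Demo with a stand-in iterator; with CSM you would pass generator.generate_stream(...)
stats: Dict[str, float] = {}
for _ in timed_stream(iter(range(3)), stats):
    pass
print(stats["num_chunks"])  # 3
```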

**Does this model come with any voices?**

The model is a base generation model capable of producing a variety of voices but hasn't been fine-tuned on any specific voice. Provide reference audio for best results.

**Can I converse with the model?**

CSM is trained to be an audio generation model, not a general-purpose multimodal LLM, and it cannot generate text. By pairing it with a separate LLM, you can converse with it through the realtime demo's web UI.

**Does it support other languages?**

The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.

## Misuse and abuse ⚠️

This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:

- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.

By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.

---

## Original Authors
Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.

## Streaming, Realtime Demo and Finetuning Implementation
David Browne

## Support me
Support this project on Ko-fi: https://ko-fi.com/davidbrowne17

## Transformers streaming
If you want to use streaming with the Transformers implementation you can find it here: https://github.com/davidbrowne17/csm-streaming-tf


================================================
FILE: config.py
================================================
import os
import json
import logging
from typing import Dict, Any, Optional
from pydantic import BaseModel

logger = logging.getLogger(__name__)

class ConfigManager:
    """
    Manages configuration persistence for the AI Companion app.
    Saves and loads configuration to avoid re-entering model paths.
    """
    def __init__(self, config_path: str = "config/app_config.json"):
        """
        Initialize the configuration manager.
        
        Args:
            config_path: Path to store the configuration file
        """
        self.config_path = config_path
        # Fall back to the current directory when config_path is a bare filename
        self.config_dir = os.path.dirname(config_path) or "."
        
        # Create config directory if it doesn't exist
        if not os.path.exists(self.config_dir):
            os.makedirs(self.config_dir, exist_ok=True)
            logger.info(f"Created configuration directory: {self.config_dir}")
    
    def save_config(self, config_data: Dict[str, Any]) -> bool:
        """
        Save configuration data to the config file.
        
        Args:
            config_data: Configuration data to save
            
        Returns:
            bool: True if successful, False otherwise
        """
        try:
            # Ensure directory exists
            os.makedirs(self.config_dir, exist_ok=True)
            logger.debug(f"Saving configuration: {config_data}")
            # Verify all reference paths are included
            ref_paths = [
                "reference_audio_path", 
                "reference_audio_path2", 
                "reference_audio_path3"
            ]
            
            # Log which references are being saved
            for path_key in ref_paths:
                if path_key in config_data and config_data[path_key]:
                    logger.info(f"Saving reference path: {path_key}={config_data[path_key]}")
                else:
                    logger.info(f"No {path_key} provided in configuration")
            
            # Save configuration
            with open(self.config_path, 'w') as f:
                json.dump(config_data, f, indent=2)
            
            logger.info(f"Configuration saved to {self.config_path}")
            return True
        
        except Exception as e:
            logger.error(f"Failed to save configuration: {e}")
            return False
    
    def load_config(self) -> Optional[Dict[str, Any]]:
        """
        Load configuration data from the config file.
        
        Returns:
            Dict or None: Configuration data if successful, None otherwise
        """
        if not os.path.exists(self.config_path):
            logger.info(f"Configuration file does not exist at {self.config_path}")
            return None
        
        try:
            with open(self.config_path, 'r') as f:
                config_data = json.load(f)
            
            # Log which references are being loaded
            ref_paths = [
                "reference_audio_path", 
                "reference_audio_path2", 
                "reference_audio_path3"
            ]
            
            for path_key in ref_paths:
                if path_key in config_data and config_data[path_key]:
                    logger.info(f"Loaded reference path: {path_key}={config_data[path_key]}")
            
            logger.info(f"Configuration loaded from {self.config_path}")
            return config_data
        
        except Exception as e:
            logger.error(f"Failed to load configuration: {e}")
            return None

# Helper function to convert Pydantic model to dict
def model_to_dict(model: BaseModel) -> Dict[str, Any]:
    """Convert a Pydantic model to a dictionary suitable for JSON serialization"""
    return json.loads(model.json())

================================================
FILE: finetuned_model/config.json
================================================
{
  "audio_num_codebooks": 32,
  "audio_vocab_size": 2051,
  "backbone_flavor": "llama-1B",
  "decoder_flavor": "llama-100M",
  "text_vocab_size": 128256
}

================================================
FILE: generator.py
================================================
from dataclasses import dataclass
import math
import os
from typing import List, Tuple, Generator as PyGenerator, Optional, Callable
import time
import queue
import threading
import platform
from collections import OrderedDict
import wave
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model, ModelArgs
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
import logging

logger = logging.getLogger(__name__)

@dataclass
class Segment:
    speaker: int
    text: str
    sample_rate = 24_000
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "unsloth/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer


class Generator:
    def __init__(self, model: Model):
        self._model = model
        self._model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer()
        device = next(model.parameters()).device

        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        
        num_codebooks = model.config.audio_num_codebooks
        mimi.set_num_codebooks(num_codebooks)
        self._num_codebooks = num_codebooks
        self._audio_tokenizer = mimi

        self.sample_rate = mimi.sample_rate
        self.device = device

        self._stream_buffer_size = 20
        self.max_seq_len = 2048
        self._cache = OrderedDict()
        self._text_token_cache = {}
        torch.set_num_threads(16)
        torch.cuda.set_per_process_memory_fraction(0.95)

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize text segment with caching optimization for reduced latency.
        """
        # Check cache first
        cache_key = f"{speaker}:{text}"
        if not hasattr(self, '_text_token_cache'):
            self._text_token_cache = {}
        
        if cache_key in self._text_token_cache:
            return self._text_token_cache[cache_key]

        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.long, device=self.device)
        text_frame_mask = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.bool, device=self.device)
        text_frame[:, -1] = torch.tensor(text_tokens, device=self.device)
        text_frame_mask[:, -1] = True

        frame_tokens = [text_frame]
        frame_masks = [text_frame_mask]
        
        result = (torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0))
        
        self._text_token_cache[cache_key] = result
        
        return result


    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        
        # Limit to the number of codebooks set in MIMI
        audio_tokens = audio_tokens[:self._num_codebooks, :]
        
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1, dtype=audio_tokens.dtype, device=self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        audio_frame = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).bool().to(self.device)
        audio_frame[:, :self._num_codebooks] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :self._num_codebooks] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        total_len = text_tokens.size(0) + audio_tokens.size(0)

        if total_len > self.max_seq_len:
            overflow = total_len - self.max_seq_len

            if text_tokens.size(0) > overflow:
                text_tokens = text_tokens[overflow:]
                text_masks = text_masks[overflow:]
            else:
                audio_overflow = overflow - text_tokens.size(0)
                text_tokens = text_tokens[0:0]
                text_masks = text_masks[0:0]
                audio_tokens = audio_tokens[audio_overflow:]
                audio_masks = audio_masks[audio_overflow:]

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def _decode_frames(self, frames):
        if not frames:
            return torch.tensor([])
        
        # Decode with only the first half of the codebooks: faster, but lower fidelity than using all of them
        frames_reduced = [frame[:, :self._num_codebooks//2] for frame in frames]
        audio = self._audio_tokenizer.decode(torch.stack(frames_reduced).permute(1, 2, 0)).squeeze(0).squeeze(0)
        return audio

    @torch.inference_mode()
    def generate_stream(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.7,
        topk: int = 30,
        on_chunk_generated: Optional[Callable[[torch.Tensor], None]] = None,
    ):
        """
        Generate audio in a streaming fashion, optimized for lower latency to first chunk.
        """
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        self._model.reset_caches()

        max_generation_len = int(max_audio_length_ms / 80)

        tokens, tokens_mask = [], []

        initial_batch_size = 20
        normal_batch_size = 20  
        initial_buffer_size = 20
        normal_buffer_size = 20
        
        batch_size = initial_batch_size
        buffer_size = initial_buffer_size
        first_chunk_delivered = False

        context_start = time.time()
        if context:
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        max_seq_len = 2048
        if prompt_tokens.size(0) > max_seq_len:
            prompt_tokens = prompt_tokens[-max_seq_len:]
            prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:]

        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        expected_frame_count = buffer_size 
        frame_buffer = []

        zeros_1_1 = torch.zeros(1, 1).long().to(self.device)
        zeros_mask_1_1 = torch.zeros(1, 1).bool().to(self.device)

        def update_tokens(sample):
            nonlocal curr_tokens, curr_tokens_mask, curr_pos
            ones = torch.ones_like(sample).bool()
            curr_tokens = torch.cat([sample, zeros_1_1], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat([ones, zeros_mask_1_1], dim=1).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        with self._audio_tokenizer.streaming(1):
            i = 0
            generation_start = time.time()
            # Set once the model emits an end-of-speech (all-zero) frame or an invalid frame.
            # A `break` only leaves the inner per-frame loop, so without this flag the outer loop
            # would keep sampling past EOS whenever EOS was not the first frame of a batch.
            stop_generation = False

            while i < max_generation_len and not stop_generation:
                batch_end = min(i + batch_size, max_generation_len)
                batch_size_actual = batch_end - i

                batch_samples = []

                for _ in range(batch_size_actual):
                    with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
                        sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)

                    # Sanity-check the frame. Using a CUDA tensor in an `if` already waits for the
                    # result, so no explicit torch.cuda.synchronize() is needed here.
                    try:
                        if sample.numel() == 0 or torch.isnan(sample).any():
                            print("Warning: Generated empty or NaN sample, stopping generation")
                            stop_generation = True
                            break
                    except Exception as e:
                        print(f"Error checking tensor ({e}), stopping generation")
                        stop_generation = True
                        break

                    # An all-zero frame marks the end of speech
                    if torch.all(sample == 0):
                        stop_generation = True
                        break

                    batch_samples.append(sample)
                    update_tokens(sample)

                if not batch_samples:
                    break

                frame_buffer.extend(batch_samples)
                i += len(batch_samples)

                if len(frame_buffer) >= buffer_size:
                    frames_to_process = frame_buffer[:expected_frame_count]
                    
                    # If we don't have enough frames, pad with zeros to match expected shape
                    if len(frames_to_process) < expected_frame_count:
                        # Create padding frames (zeros)
                        padding_frames = [
                            torch.zeros_like(frames_to_process[0]) 
                            for _ in range(expected_frame_count - len(frames_to_process))
                        ]
                        
                        # Combine actual frames with padding
                        frames_to_process = frames_to_process + padding_frames
                    
                    frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0)
                    audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0)
                    
                    # Keep remaining frames for next iteration
                    frame_buffer = frame_buffer[expected_frame_count:]
                    
                    # Process and yield the chunk
                    cpu_chunk = audio_chunk.cpu()
                    if on_chunk_generated:
                        on_chunk_generated(cpu_chunk)
                    
                    # After first chunk is delivered, switch to normal batch and buffer sizes
                    if not first_chunk_delivered:
                        batch_size = normal_batch_size
                        buffer_size = normal_buffer_size
                        expected_frame_count = buffer_size
                        first_chunk_delivered = True
                    
                    yield cpu_chunk

                    # Every 100 generated frames (8 s of audio at 80 ms/frame), sync the GPU and print progress
                    if i >= 100 and (i % 100 == 0):
                        if torch.cuda.is_available():
                            torch.cuda.synchronize()
                        print(f"Generated {i} frames ({i * 0.08:.2f}s of audio)")

            # Process any remaining frames
            if frame_buffer:
                # Pad frame buffer if necessary
                if len(frame_buffer) < expected_frame_count:
                    padding_frames = [
                        torch.zeros_like(frame_buffer[0]) 
                        for _ in range(expected_frame_count - len(frame_buffer))
                    ]
                    frames_to_process = frame_buffer + padding_frames
                else:
                    # Otherwise decode the largest multiple of expected_frame_count frames; any remainder is dropped
                    frames_multiple = (len(frame_buffer) // expected_frame_count) * expected_frame_count
                    frames_to_process = frame_buffer[:frames_multiple]
                    
                frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0)
                audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0)
                
                # Determine actual audio length (before padding)
                actual_frames_percentage = min(len(frame_buffer), expected_frame_count) / expected_frame_count
                actual_samples = int(audio_chunk.shape[0] * actual_frames_percentage)
                
                # Return only the non-padded portion of audio if we added padding
                if len(frame_buffer) < expected_frame_count:
                    audio_chunk = audio_chunk[:actual_samples]
                    
                cpu_chunk = audio_chunk.cpu()
                if on_chunk_generated:
                    on_chunk_generated(cpu_chunk)
                yield cpu_chunk

            # Print final performance metrics
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_time = time.time() - generation_start
            frames_generated = i
            audio_seconds = frames_generated * 0.08
            rtf = total_time / audio_seconds if audio_seconds > 0 else float('inf')
            print(f"Total time: {total_time:.2f}s")
            print(f"Generated {frames_generated} frames ({audio_seconds:.2f}s of audio)")
            print(f"Real-time factor: {rtf:.3f}x (target: <1.0)")

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.8,
        topk: int = 40,
        stream: bool = False,
        output_file: Optional[str] = None,
    ):
        """
        Generate audio with optional streaming and file output.
        
        Args:
            text: Text to generate audio for
            speaker: Speaker ID
            context: List of context segments
            max_audio_length_ms: Maximum audio length in milliseconds
            temperature: Sampling temperature
            topk: Top-k sampling parameter
            stream: Whether to use streaming generation
            output_file: If provided and stream=True, output will be saved to this file
        
        Returns:
            torch.Tensor: Generated audio tensor
        """
        if stream:
            if output_file:
                # Setup streaming to file
                write_chunk, close_wav = stream_audio_to_wav(output_file, self.sample_rate)
                
                # Collect chunks while streaming to file
                audio_chunks = []
                t1 = time.time()
                
                for i, chunk in enumerate(self.generate_stream(
                    text, speaker, context, max_audio_length_ms, temperature, topk
                )):
                    # Write to file
                    write_chunk(chunk)
                    # Store for return value
                    audio_chunks.append(chunk)
                    
                    # Print progress every 5 chunks; the elapsed time covers all chunks since the previous report
                    if i % 5 == 0:
                        print(f"Part {i+1} available after {time.time() - t1:.4f}s")
                        t1 = time.time()
                
                # Close file
                close_wav()
                print(f"Streaming complete, WAV file saved to {output_file}")
            else:
                # Just collect chunks without file output
                audio_chunks = []
                for chunk in self.generate_stream(text, speaker, context, max_audio_length_ms, temperature, topk):
                    audio_chunks.append(chunk)
            
            if not audio_chunks:
                return torch.tensor([])
            return torch.cat(audio_chunks)

        # Non-streaming path: sample every frame first, then decode them all in a single call
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._model.reset_caches()

        max_generation_len = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []

        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        max_seq_len = 2048
        if prompt_tokens.size(0) > max_seq_len:
            prompt_tokens = prompt_tokens[-max_seq_len:]
            prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:]

        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        samples = []
        with self._audio_tokenizer.streaming(1):
            for _ in range(max_generation_len):
                sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                if torch.all(sample == 0):
                    break
                samples.append(sample)

                curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                curr_tokens_mask = torch.cat(
                    [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                ).unsqueeze(1)
                curr_pos = curr_pos[:, -1:] + 1

        if not samples:
            return torch.tensor([])

        return self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
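
# Illustrative sketch (hypothetical helpers, not used elsewhere in this module):
# the frame/sample arithmetic behind generate_stream(), on plain numbers. As in the code above,
# one generated frame is assumed to cover 80 ms of audio. The tail-trimming helper keeps the same
# proportion of decoded samples as the end of generate_stream() (real frames / padded frames),
# but uses integer arithmetic so the result is not subject to float rounding.
import math


def _sketch_frame_budget(max_audio_length_ms: float, frame_ms: float = 80.0) -> int:
    """Maximum number of frames sampled for a given audio-length budget."""
    return int(max_audio_length_ms / frame_ms)


def _sketch_tail_samples(n_real_frames: int, chunk_frames: int, decoded_samples: int) -> int:
    """Decoded samples to keep from a tail chunk that was zero-padded to `chunk_frames` frames."""
    return decoded_samples * min(n_real_frames, chunk_frames) // chunk_frames


def _sketch_real_time_factor(total_time_s: float, n_frames: int, frame_ms: float = 80.0) -> float:
    """Wall-clock seconds spent per second of generated audio (< 1.0 is faster than real time)."""
    audio_seconds = n_frames * frame_ms / 1000.0
    return total_time_s / audio_seconds if audio_seconds > 0 else math.inf


# Example: a 90 s budget allows _sketch_frame_budget(90_000) == 1125 frames, and a 5-frame tail
# that was padded to 20 frames keeps a quarter of its decoded samples.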

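# Illustrative sketch (hypothetical helper, not used elsewhere in this module):
# the overflow policy of _tokenize_segment() above, expressed on plain lengths. When
# text + audio rows exceed max_seq_len, rows are dropped from the *front* of the text first;
# only once the whole text is gone are frames dropped from the front of the audio.
# Returns how many text rows and audio rows survive.
def _sketch_truncate_lengths(text_len: int, audio_len: int, max_seq_len: int) -> tuple:
    overflow = text_len + audio_len - max_seq_len
    if overflow <= 0:
        return text_len, audio_len  # everything fits
    if text_len > overflow:
        return text_len - overflow, audio_len  # trim the start of the text only
    audio_overflow = overflow - text_len
    return 0, audio_len - audio_overflow  # text dropped entirely, audio trimmed from the start


# Example: with max_seq_len=2048, 10 text rows + 2045 audio rows overflow by 7, so 3 text rows
# remain; 10 text rows + 2100 audio rows lose all text and keep the last 2048 audio rows.
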
class AudioStreamWriter:
    """
    Helper class for writing streaming audio to a file.
    """
    def __init__(self, filename, sample_rate):
        self.filename = filename
        self.sample_rate = sample_rate
        self.audio_chunks = []
        self.lock = threading.Lock()
        self.queue = queue.Queue()
        self.running = True
        
        # Start background writer thread
        self.writer_thread = threading.Thread(target=self._writer_worker, daemon=True)
        self.writer_thread.start()
        
    def _writer_worker(self):
        """Background thread that handles audio chunk processing"""
        buffer_chunks = []
        last_flush_time = time.time()
        
        while self.running or not self.queue.empty():
            try:
                # Get chunk with timeout to allow for regular checks
                chunk = self.queue.get(timeout=0.2)
                buffer_chunks.append(chunk)
                
                # Periodically flush the buffer to the main list
                current_time = time.time()
                if len(buffer_chunks) >= 10 or (current_time - last_flush_time > 2.0 and buffer_chunks):
                    with self.lock:
                        self.audio_chunks.extend(buffer_chunks)
                    buffer_chunks = []
                    last_flush_time = current_time
                    
            except queue.Empty:
                # If queue is empty but we have pending chunks, add them
                if buffer_chunks:
                    with self.lock:
                        self.audio_chunks.extend(buffer_chunks)
                    buffer_chunks = []
                    last_flush_time = time.time()
        
        # Final flush of any remaining chunks
        if buffer_chunks:
            with self.lock:
                self.audio_chunks.extend(buffer_chunks)
        
    def add_chunk(self, chunk):
        """Add an audio chunk to the buffer queue without blocking"""
        try:
            self.queue.put(chunk, timeout=0.1)
        except queue.Full:
            # If queue is full, add directly to avoid losing data
            with self.lock:
                self.audio_chunks.append(chunk)
    
    def write_file(self):
        """Write all collected audio chunks to file and clean up"""
        # Signal the background thread to stop
        self.running = False
        # Wait for the thread to finish with a timeout
        self.writer_thread.join(timeout=3.0)
        
        with self.lock:
            if not self.audio_chunks:
                return
                
            # Concatenate all chunks
            audio = torch.cat(self.audio_chunks)
            # Save to file
            torchaudio.save(self.filename, audio.unsqueeze(0).cpu(), self.sample_rate)
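
# Illustrative sketch (hypothetical class, not used elsewhere in this module):
# the same background-writer idea as AudioStreamWriter, but stopped with a sentinel object
# instead of a `running` flag plus a join timeout. Because the queue is FIFO, the worker only
# sees the sentinel after every chunk queued before it, so close() returns only once all chunks
# have been collected. Chunks here are plain lists of samples so the sketch runs without torch.
import queue
import threading


class _SketchChunkCollector:
    _STOP = object()  # sentinel marking the end of the stream

    def __init__(self):
        self._queue = queue.Queue()
        self._chunks = []
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def _worker(self):
        while True:
            chunk = self._queue.get()  # blocks until a chunk or the sentinel arrives
            if chunk is self._STOP:
                break
            self._chunks.append(chunk)

    def add_chunk(self, chunk):
        self._queue.put(chunk)

    def close(self):
        """Stop the worker and return all samples in the order they were added."""
        self._queue.put(self._STOP)
        self._thread.join()  # after join() the worker has finished appending
        return [sample for chunk in self._chunks for sample in chunk]


# Design note: with a sentinel there is no window in which chunks still sitting in the queue are
# lost because a join timeout expired before the worker drained them.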

from safetensors.torch import load_file
import os
import torch
from models import Model, ModelArgs
from generator import Generator

def load_csm_1b_local(model_path: str, device: str = "cuda", audio_num_codebooks: int = 32):
    """
    Load the CSM-1B model from a local checkpoint with extreme optimizations and warmup.
    """
    from functools import lru_cache

    # Enable all CUDA optimizations
    torch.backends.cuda.matmul.allow_tf32 = True
    if hasattr(torch.backends.cuda, 'enable_flash_sdp'):
        torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    print(f"Loading CSM-1B model from local checkpoint '{model_path}' with extreme optimizations...")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

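    # NOTE: `config` is built here but never passed to Model.from_pretrained(), which loads the
    # configuration saved with the checkpoint instead; as written, the `audio_num_codebooks`
    # argument therefore has no effect on the loaded model.
    config = ModelArgs(
        backbone_flavor="llama-1B",
        decoder_flavor="llama-100M",
        text_vocab_size=128256,
        audio_vocab_size=2051,
        audio_num_codebooks=audio_num_codebooks,
    )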

    model = Model.from_pretrained(model_path)
    model.eval()

    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

    model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True, backend='inductor')
    model.decoder = torch.compile(model.decoder, mode='reduce-overhead', fullgraph=True, backend='inductor')

    model.to(device=device, dtype=dtype)

    print("Model compilation complete. Creating generator...")

    generator = Generator(model)
    generator._stream_buffer_size = 20

    # Setup tokenization caching
    generator._tokenization_cache = {}

    original_tokenize_text = generator._tokenize_text_segment

    @lru_cache(maxsize=2048)
    def cached_tokenize_text_segment(text_str, speaker_int):
        return original_tokenize_text(text_str, speaker_int)

    generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker)

    # Perform warmup
    warmup_generator(generator)

    return generator
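
# Illustrative sketch (hypothetical names, not used elsewhere in this module):
# the memoization pattern used above, where functools.lru_cache wraps the generator's text
# tokenizer so repeated (text, speaker) pairs skip re-tokenization. `_sketch_tokenize` is a
# stand-in that returns character codes, not the real text tokenizer. lru_cache requires
# hashable arguments (str and int are). A cache hit returns the *same* object every time, so
# cached results (tensors in the real code) must not be modified in place by callers.
from functools import lru_cache

_sketch_calls = {"count": 0}


def _sketch_tokenize(text: str, speaker: int) -> tuple:
    _sketch_calls["count"] += 1  # counts real (uncached) tokenizations
    return tuple(ord(ch) for ch in f"[{speaker}]{text}")


@lru_cache(maxsize=2048)
def _sketch_cached_tokenize(text: str, speaker: int) -> tuple:
    return _sketch_tokenize(text, speaker)


# Usage: _sketch_cached_tokenize.cache_info() reports hits and misses, and
# _sketch_cached_tokenize.cache_clear() empties the cache.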

def warmup_generator(gen: Generator, warmup_text: str = "Hello, this is a comprehensive warmup text that will exercise the model's generation capabilities.", speaker_id: int = 0):
    """
    Warm up the generator so the first real request does not pay one-time costs: tokenizer setup,
    torch.compile / CUDA-graph capture of the backbone and decoder, and CUDA allocator growth.
    """
    print("Starting maximum-intensity warmup sequence...")
    
    # Directly access and optimize the model's internal state
    if hasattr(gen._model, 'backbone') and hasattr(gen._model.backbone, 'positional_embedding'):
        # Force calculation of position embeddings to ensure they're cached
        with torch.inference_mode():
            positions = torch.arange(0, 2048).to(gen.device)
            _ = gen._model.backbone.positional_embedding(positions)
    
    # Pre-allocate CUDA memory to prevent fragmentation during generation
    if torch.cuda.is_available():
        print("Optimizing GPU memory allocation...")
        # Try to reserve a large chunk of memory
        try:
            import math
            reserved_memory = []
            # Reserve multiple blocks of different sizes
            for size_mb in [128, 256, 512, 256, 128, 64]:
                size = int(size_mb * 1024 * 1024 / 4)  # Convert MB to float32 elements
                tensor_size = int(math.sqrt(size))
                tensor = torch.ones((tensor_size, tensor_size), device=gen.device, dtype=torch.float32)
                tensor = tensor * 1.0  # Force allocation
                reserved_memory.append(tensor)
            torch.cuda.synchronize()
            
            # Now free the memory
            for tensor in reserved_memory:
                del tensor
            reserved_memory = []
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        except Exception as e:
            print(f"Memory pre-allocation: {e}")
    
    # Create multiple dummy audio segments with varying characteristics
    print("Creating diverse audio contexts...")
    audio_segments = []
    
    # Create 3 different audio patterns
    for i in range(3):
        length = 24000 * (i + 1)  # 1s, 2s, 3s
        audio = torch.zeros(length).to(gen.device)
        
        # Add different patterns to each segment
        if i == 0:
            # Sine wave pattern
            import math
            t = torch.linspace(0, 8 * math.pi, length).to(gen.device)
            audio = torch.sin(t) * 0.1
        elif i == 1:
            # Random noise pattern
            audio = torch.randn(length).to(gen.device) * 0.05
        else:
            # Pulse pattern
            audio[::800] = 0.2
            audio[::801] = -0.2
        
        segment = Segment(
            speaker=speaker_id,
            text=f"Warmup segment {i+1} with {length/24000:.1f}s of audio.",
            audio=audio
        )
        audio_segments.append(segment)
    
    # Force compilation of critical model components
    print("Forcing compilation of critical components...")
    
    # Directly exercise the audio tokenizer with real data
    with torch.inference_mode():
        for segment in audio_segments:
            # Force tokenization of both text and audio
            gen._tokenize_segment(segment)
    
    # Exercise the model's generation capabilities directly
    with torch.inference_mode():
        
        # Generate some sample frames to ensure model is compiled
        dummy_tokens = torch.ones(1, 10, gen._num_codebooks+1).long().to(gen.device)
        dummy_mask = torch.ones(1, 10, gen._num_codebooks+1).bool().to(gen.device)
        dummy_pos = torch.arange(0, 10).unsqueeze(0).to(gen.device)
        
        # Generate multiple frames with different parameters
        for temp in [0.6, 0.7, 0.8]:
            for topk in [20, 30, 40]:
                _ = gen._model.generate_frame(dummy_tokens, dummy_mask, dummy_pos, temp, topk)
    
    gen._text_token_cache.clear()
    
    print("Running final generation with exact same setup as a real request...")
    
    final_text = "This is the final warmup that exactly matches a real generation request."
    
    # First tokenize the text - to fill the cache
    gen._tokenize_text_segment(final_text, speaker_id)
    
    try:
        # Now run a complete generation with a single context segment
        generate_streaming_audio(
            generator=gen,
            text=final_text, 
            speaker=speaker_id,
            context=[audio_segments[0]],  # Just one context segment
            output_file="warmup_final.wav",
            max_audio_length_ms=6000,
            temperature=0.7,
            topk=30,
            play_audio=False
        )
    except Exception as e:
        print(f"Final warmup run exception (ignorable): {e}")
    
    # Force final synchronization and memory optimization
    if torch.cuda.is_available():
        print("Final GPU optimization...")
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        
        try:
            # Briefly allocate and free ~1 GB (2.5e8 float32 values); the cached block is
            # released by the empty_cache() call below
            large_tensor = torch.empty(int(1e9//4), dtype=torch.float, device=gen.device)
            # Immediately delete it
            del large_tensor
        except RuntimeError:
            # Expected if there's not enough memory
            pass
            
        # Final cleanup
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    
    print("Maximum-intensity warmup complete. First generation should now be MUCH faster.")

def load_csm_1b(device: str = "cuda") -> Generator:
    """
    Load the CSM-1B model with extreme optimizations for real-time performance.
    """
    # Enable all CUDA optimizations
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    
    print("Loading CSM-1B model with extreme optimizations for real-time performance...")
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    
    model = Model.from_pretrained("sesame/csm-1b")
    
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True, backend='inductor')
    model.decoder = torch.compile(model.decoder, mode='reduce-overhead', fullgraph=True, backend='inductor')

    model.to(device=device, dtype=dtype)
    
    print("Model compilation complete. Creating generator...")
    
    generator = Generator(model)
    
    generator._stream_buffer_size = 20
    
    
    generator._tokenization_cache = {}
    
    from functools import lru_cache

    # Patch the tokenize method with caching
    original_tokenize_text = generator._tokenize_text_segment

    @lru_cache(maxsize=2048)
    def cached_tokenize_text_segment(text_str, speaker_int):
        return original_tokenize_text(text_str, speaker_int)

    generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker)
    
    warmup_generator(generator)

    return generator

def stream_audio_to_wav(filename, sample_rate):
    """
    Initialize a WAV writer for streaming audio chunks.
    
    Args:
        filename: Output WAV file path
        sample_rate: Audio sample rate in Hz
    
    Returns:
        tuple: (write_chunk, close) functions for writing audio data and closing the file
    """
    # Create a WAV file with the proper header
    wav_file = wave.open(filename, 'wb')
    wav_file.setnchannels(1)  # Mono
    wav_file.setsampwidth(2)  # 16-bit
    wav_file.setframerate(sample_rate)
    
    def write_chunk(audio_chunk):
        # Convert the chunk to a NumPy array on the CPU
        if isinstance(audio_chunk, torch.Tensor):
            audio_chunk = audio_chunk.detach().cpu()
            if audio_chunk.is_floating_point():
                audio_chunk = audio_chunk.float()  # NumPy has no bfloat16 dtype, so upcast to float32
            audio_np = audio_chunk.numpy()
        else:
            audio_np = np.asarray(audio_chunk)

        if audio_np.size == 0:
            return  # nothing to write

        if np.issubdtype(audio_np.dtype, np.integer):
            # Already integer PCM: only make sure it is 16-bit
            audio_int = audio_np.astype(np.int16)
        else:
            # Float audio is expected in [-1, 1]: clip so out-of-range peaks saturate, then scale
            audio_int = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16)

        # Append the 16-bit PCM samples to the WAV file
        wav_file.writeframes(audio_int.tobytes())
    
    def close():
        wav_file.close()
        
    return write_chunk, close
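
# Illustrative sketch (hypothetical helpers, not used elsewhere in this module):
# float-to-16-bit PCM conversion with clipping, as done by write_chunk() above, written with the
# standard library (array, wave) so it runs without NumPy or torch. Samples are assumed to be
# floats nominally in [-1, 1]; anything outside is clipped so loud peaks saturate.
import array
import io
import sys
import wave


def _sketch_float_to_pcm16(samples) -> bytes:
    # int() truncates toward zero, matching NumPy's astype(np.int16) on the scaled values
    pcm = array.array("h", (int(max(-1.0, min(1.0, s)) * 32767) for s in samples))
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV stores samples little-endian
    return pcm.tobytes()


def _sketch_write_mono_wav(fileobj, samples, sample_rate: int = 24000) -> None:
    # wave.open accepts a path or a binary file-like object such as io.BytesIO
    with wave.open(fileobj, "wb") as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(_sketch_float_to_pcm16(samples))


# Usage: _sketch_write_mono_wav(io.BytesIO(), [0.0, 0.5, 1.0]) writes a three-sample WAV in memory;
# the 24 kHz default rate is an assumption for illustration.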

def generate_streaming_audio(
    generator: Generator,
    text: str,
    speaker: int,
    context: List[Segment],
    output_file: str,
    max_audio_length_ms: float = 90_000,
    temperature: float = 1.0,
    topk: int = 50,
    play_audio: bool = False,
):
    """
    Generate audio with streaming output and comprehensive timing metrics.
    Optimized for reduced first-chunk latency.
    """
    # Initialize the streaming WAV writer
    write_chunk, close_wav = stream_audio_to_wav(output_file, generator.sample_rate)
    
    # Set up audio playback if requested
    audio_queue = queue.Queue(maxsize=100) if play_audio else None
    stop_event = threading.Event()
    
    if play_audio:
        try:
            import sounddevice as sd
            
            # Get available sample rates for default output device to check compatibility
            device_info = sd.query_devices(kind='output')
            supported_rate = device_info.get('default_samplerate', 44100)
            need_resampling = abs(supported_rate - generator.sample_rate) > 100
            
            if need_resampling:
                try:
                    # Use resampling if sample rate doesn't match
                    import librosa
                    print(f"Resampling from {generator.sample_rate}Hz to {int(supported_rate)}Hz for playback")
                    
                    def audio_playback_worker():
                        while not stop_event.is_set() or not audio_queue.empty():
                            try:
                                chunk = audio_queue.get(timeout=0.5)
                                if isinstance(chunk, torch.Tensor) and chunk.numel() == 0:
                                    audio_queue.task_done()
                                    continue
                                    
                                audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk
                                
                                # Skip very short chunks (likely noise)
                                if len(audio_np) < 100:
                                    audio_queue.task_done()
                                    continue
                                    
                                # Resample to device's supported rate
                                resampled = librosa.resample(
                                    audio_np, 
                                    orig_sr=generator.sample_rate, 
                                    target_sr=int(supported_rate)
                                )
                                sd.play(resampled, supported_rate, blocking=True)
                                # Add a small delay to ensure audio finishes playing
                                time.sleep(0.05)
                                audio_queue.task_done()
                            except queue.Empty:
                                # If queue empty but not stopping, keep trying
                                if not stop_event.is_set():
                                    continue
                                else:
                                    break
                            except Exception as e:
                                print(f"Playback error: {e}")
                                audio_queue.task_done()
                except ImportError:
                    print("Librosa not found. Using direct playback which may cause sample rate warnings.")
                    need_resampling = False
            
            if not need_resampling:
                def audio_playback_worker():
                    while not stop_event.is_set() or not audio_queue.empty():
                        try:
                            chunk = audio_queue.get(timeout=0.5)
                            if isinstance(chunk, torch.Tensor) and chunk.numel() == 0:
                                audio_queue.task_done()
                                continue
                                
                            audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk
                            
                            # Skip very short chunks (likely noise)
                            if len(audio_np) < 100:
                                audio_queue.task_done()
                                continue
                                
                            sd.play(audio_np, generator.sample_rate, blocking=True)
                            # Short pause between chunks (sd.play with blocking=True has already finished)
                            time.sleep(0.05)
                            audio_queue.task_done()
                        except queue.Empty:
                            # If queue empty but not stopping, keep trying
                            if not stop_event.is_set():
                                continue
                            else:
                                break
                        except Exception as e:
                            print(f"Playback error: {e}")
                            audio_queue.task_done()
            
            # Start playback thread
            playback_thread = threading.Thread(target=audio_playback_worker, daemon=False)
            playback_thread.start()
            
        except ImportError:
            print("sounddevice library not found. Install with 'pip install sounddevice' for real-time playback.")
            play_audio = False
    
    # Timing metrics
    chunk_times = []
    latency_to_first_chunk = None
    total_audio_duration = 0
    chunk_count = 0
    
    # Function to handle each generated chunk
    def on_chunk_generated(chunk):
        nonlocal chunk_count, latency_to_first_chunk, total_audio_duration
        
        current_time = time.time()
        if chunk_count == 0:
            latency_to_first_chunk = current_time - start_time
            print(f"First chunk latency: {latency_to_first_chunk*1000:.1f}ms")
            
        # Save chunk to WAV file
        write_chunk(chunk)
        
        # Update metrics
        chunk_count += 1
        chunk_duration = len(chunk) / generator.sample_rate
        total_audio_duration += chunk_duration
        chunk_times.append(current_time)
        
        # Send to audio player if enabled
        if play_audio and audio_queue is not None:
            try:
                audio_queue.put(chunk, timeout=1.0)
            except queue.Full:
                pass  # Skip if queue is full to avoid blocking
    
    if torch.cuda.is_available():
        print("Preparing GPU for low-latency generation...")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        
        # Warm up CUDA with a few tiny ops (about 200 KB kept alive); this mainly
        # initializes the CUDA context and allocator rather than reserving real memory
        dummy_tensors = []
        for i in range(5):
            dummy = torch.ones((100, 100), device=generator.device)
            dummy = dummy + 1.0  # Force computation
            dummy_tensors.append(dummy)  # Keep reference to prevent deallocation
            
        torch.cuda.synchronize()
    
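    # Best effort: raise process priority to reduce scheduling jitter.
    # psutil is imported separately so that a missing package does not turn
    # the except clause's psutil.AccessDenied lookup into a NameError.
    try:
        import psutil
    except ImportError:
        psutil = None
    if psutil is not None:
        try:
            process = psutil.Process()
            if platform.system() == 'Windows':
                process.nice(psutil.HIGH_PRIORITY_CLASS)
            else:
                process.nice(-1)  # lowering niceness usually needs elevated privileges
        except (PermissionError, psutil.AccessDenied):
            pass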
    
    print(f"Starting audio generation for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
    start_time = time.time()
    
    # Generate audio in chunks, catching possible errors
    frame_count = 0
    audio_chunks = []  # Store all chunks for possible use at the end
    
    try:
        for audio_chunk in generator.generate_stream(
            text=text,
            speaker=speaker,
            context=context,
            max_audio_length_ms=max_audio_length_ms,
            temperature=temperature,
            topk=topk,
            on_chunk_generated=on_chunk_generated
        ):
            frame_count += 1
            audio_chunks.append(audio_chunk)  # Store the chunk
            
            # Print timing info less frequently to reduce overhead
            if frame_count % 10 == 0:
                current_time = time.time()
                elapsed = current_time - start_time
                if total_audio_duration > 0:
                    rtf = elapsed / total_audio_duration
                    remaining_time = (max_audio_length_ms/1000 - total_audio_duration) * rtf
                    print(f"Chunk {chunk_count}: {total_audio_duration:.1f}s audio in {elapsed:.1f}s "
                          f"(RTF: {rtf:.2f}x, Est. remaining: {remaining_time:.1f}s)")
    except Exception as e:
        print(f"Error during audio generation: {e}")
        import traceback
        traceback.print_exc()
    
    # Release dummy tensors to free memory
    if 'dummy_tensors' in locals():
        del dummy_tensors
    
    # Ensure all chunks are properly processed
    if play_audio and audio_queue is not None:
        print("Waiting for playback queue to finish...")
        try:
            timeout_start = time.time()
            while not audio_queue.empty() and time.time() - timeout_start < 5.0:
                time.sleep(0.1)
        except Exception:
            pass
    
    # Add a small delay to ensure everything is processed
    time.sleep(0.5)
    
    # Signal audio worker that generation is complete
    stop_event.set()
    
    # Close WAV file
    close_wav()
    
    # Wait for audio playback to complete if enabled
    if play_audio and 'playback_thread' in locals():
        print("Waiting for audio playback to complete...")
        
        # First, ensure the queue is empty
        try:
            timeout_start = time.time()
            while not audio_queue.empty() and time.time() - timeout_start < 5.0:
                time.sleep(0.1)
        except Exception:
            pass
            
        # Block until any playback started by sd.play() has finished
        if hasattr(sd, 'wait'):
            try:
                sd.wait()
            except Exception:
                pass
                
        # Join the playback thread with timeout
        playback_thread.join(timeout=5.0)
        
        # Force sounddevice to stop if it's still playing
        try:
            sd.stop()
        except Exception:
            pass
    
    # Calculate and print detailed performance metrics
    end_time = time.time()
    total_elapsed = end_time - start_time
    
    # Calculate inter-chunk latency (gaps between consecutive chunk arrivals)
    if len(chunk_times) > 1:
        inter_chunk_latencies = [b - a for a, b in zip(chunk_times, chunk_times[1:])]
        avg_inter_chunk_latency = sum(inter_chunk_latencies) / len(inter_chunk_latencies)
        max_inter_chunk_latency = max(inter_chunk_latencies)
        min_inter_chunk_latency = min(inter_chunk_latencies)
    else:
        avg_inter_chunk_latency = max_inter_chunk_latency = min_inter_chunk_latency = 0
    
    rtf = total_elapsed / total_audio_duration if total_audio_duration > 0 else float('inf')
    
    print("\n" + "="*50)
    print("AUDIO GENERATION PERFORMANCE METRICS")
    print("="*50)
    if latency_to_first_chunk is not None:
        print(f"First chunk latency: {latency_to_first_chunk*1000:.1f}ms")
    else:
        print("First chunk latency: n/a (no chunks generated)")
    print(f"Total generation time: {total_elapsed:.2f}s")
    print(f"Audio duration: {total_audio_duration:.2f}s")
    print(f"Real-time factor (RTF): {rtf:.3f}x (target: <1.0)")
    print(f"Number of chunks: {chunk_count}")
    if chunk_count > 0:
        print(f"Average chunk size: {(total_audio_duration/chunk_count)*1000:.1f}ms")
    print(f"Average inter-chunk latency: {avg_inter_chunk_latency*1000:.1f}ms")
    print(f"Min/Max inter-chunk latency: {min_inter_chunk_latency*1000:.1f}ms / {max_inter_chunk_latency*1000:.1f}ms")
    print(f"Chunks per second: {chunk_count/total_elapsed:.2f}")
    print(f"Output file: {output_file}")
    print("="*50)
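
# Sketch (hypothetical helper, not called by the code above): the same
# first-chunk latency, inter-chunk gap and RTF arithmetic as the metrics
# report, written as a pure function so it can be unit-tested without a GPU.
# Assumes chunk_times are wall-clock timestamps (seconds) recorded when each
# chunk arrived, as on_chunk_generated does.
def summarize_stream_timing(start_time, chunk_times, audio_duration_s, end_time):
    """Return first-chunk latency, inter-chunk gap stats and real-time factor."""
    elapsed = end_time - start_time
    gaps = [b - a for a, b in zip(chunk_times, chunk_times[1:])]
    return {
        # None when no chunk was produced, rather than failing on None * 1000
        "first_chunk_latency_s": (chunk_times[0] - start_time) if chunk_times else None,
        "avg_gap_s": sum(gaps) / len(gaps) if gaps else 0.0,
        "min_gap_s": min(gaps) if gaps else 0.0,
        "max_gap_s": max(gaps) if gaps else 0.0,
        # RTF < 1.0 means audio is generated faster than it plays back
        "rtf": elapsed / audio_duration_s if audio_duration_s > 0 else float("inf"),
    }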

================================================
FILE: llm_interface.py
================================================
import re
from typing import List, Dict, Any, Optional
import torch
from vllm import LLM, SamplingParams

class LLMInterface:
    def __init__(self, model_path: str, max_tokens: int = 8192, n_threads: int = 8, gpu_layers: int = -1):
        """Initialize the LLM interface using VLLM with a given model.
        
        Args:
            model_path (str): Path to the model or HuggingFace model name
            max_tokens (int, optional): Maximum context length. Defaults to 8192.
            n_threads (int, optional): Not used by vLLM; kept for API compatibility.
            gpu_layers (int, optional): Not used by vLLM; kept for API compatibility.
        """
        # VLLM configuration
        self.llm = LLM(
            model=model_path,
            tensor_parallel_size=1,  # Adjust based on number of GPUs available
            gpu_memory_utilization=0.6,
            max_model_len=max_tokens,
            swap_space=0,
            trust_remote_code=True,
            dtype=torch.float16,
        )
        
        # Store configuration for reference
        self.config = {
            "model_path": model_path,
            "max_tokens": max_tokens,
        }
        
    def trim_to_last_sentence(self, text: str) -> str:        
        """
        Return *text* truncated at the final full sentence boundary.
        A boundary is considered to be any '.', '!' or '?' followed by
        optional quotes/brackets, optional whitespace, and then end-of-string.

        If no sentence terminator exists, the original text is returned.
        """
        # Regex explanation:
        #   (.*?[.!?]["')\]]?)   any text lazily until a terminator
        #   \s*$                 followed only by whitespace till end-of-string
        m = re.match(r"^(.*?[.!?][\"')\]]?)\s*$", text, re.DOTALL)
        if m:
            return m.group(1).strip()
        # Fall back to manual search (handles cases with additional text)
        for i in range(len(text) - 1, -1, -1):
            if text[i] in ".!?":
                return text[: i + 1].strip()
        return text.strip()
    
    def generate_response(self, system_prompt: str, user_message: str, conversation_history: str = "") -> str:
        """Generate a response from the LLM using chat-style prompt formatting.
        
        Args:
            system_prompt (str): The system prompt/instructions
            user_message (str): The user's input message
            conversation_history (str, optional): Any prior conversation context. Defaults to "".
            
        Returns:
            str: The generated response
        """
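        # Format prompt following the chat template structure. Adjacent string
        # literals keep source indentation out of the prompt text (an indented
        # triple-quoted string would embed those leading spaces).
        prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>\n"
            f"{conversation_history}\n"
            f"<|start_header_id|>user<|end_header_id|>\n{user_message}<|eot_id|>\n"
            f"<|start_header_id|>assistant<|end_header_id|>\n"
        )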
        
        # Sampling parameters; max_tokens=100 caps the length of each reply
        sampling_params = SamplingParams(
            temperature=1.0,
            top_p=0.95,
            max_tokens=100,
            repetition_penalty=1.2,
            top_k=200,
            stop=["</s>", "<|endoftext|>", "<<USR>>", "<</USR>>", "<</SYS>>", 
                  "<</USER>>", "<</ASSISTANT>>", "<|end_header_id|>", "<<ASSISTANT>>", 
                  "<|eot_id|>", "<|im_end|>", "user:", "User:", "user :", "User :"]
        )
        
        # Generate response using VLLM
        outputs = self.llm.generate(prompt, sampling_params)
        
        # Extract and return the generated text
        if outputs and len(outputs) > 0:
            text = outputs[0].outputs[0].text
            return self.trim_to_last_sentence(text)
        return ""
    
    def tokenize(self, text: str) -> List[int]:
        """Tokenize text using VLLM's tokenizer.
        
        Args:
            text (str): Text to tokenize
            
        Returns:
            List[int]: List of token IDs
        """
        # VLLM doesn't expose tokenizer directly in the same way
        # We can access the tokenizer through the LLM instance
        tokenizer = self.llm.get_tokenizer()
        return tokenizer.encode(text)
    
    def get_token_count(self, text: str) -> int:
        """Return token count of the input text.
        
        Args:
            text (str): Text to count tokens for
            
        Returns:
            int: Number of tokens
        """
        return len(self.tokenize(text))
    
    def batch_generate(self, prompts: List[Dict[str, str]], 
                       max_tokens: int = 512, 
                       temperature: float = 0.7) -> List[str]:
        """Generate responses for multiple prompts in a batch.        
        Args:
            prompts (List[Dict[str, str]]): List of prompt dictionaries, each with 
                                           'system', 'user' and optional 'history' keys
            max_tokens (int, optional): Maximum tokens to generate per response
            temperature (float, optional): Temperature for sampling
            
        Returns:
            List[str]: Generated responses
        """
        formatted_prompts = []
        
        # Format each prompt according to the chat template
        for p in prompts:
            system = p.get("system", "")
            user = p.get("user", "")
            history = p.get("history", "")
            
            formatted_prompt = (
                f"<|start_header_id|>system<|end_header_id|>\n{system}<|eot_id|>\n"
                f"{history}\n"
                f"<|start_header_id|>user<|end_header_id|>\n{user}<|eot_id|>\n"
                f"<|start_header_id|>assistant<|end_header_id|>\n"
            )
            
            formatted_prompts.append(formatted_prompt)
        
        # Set up batch sampling parameters
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=0.95,
            max_tokens=max_tokens,
            repetition_penalty=1.2,
            top_k=400,
            stop=["</s>", "<|endoftext|>", "<<USR>>", "<</USR>>", "<</SYS>>", 
                  "<</USER>>", "<</ASSISTANT>>", "<|end_header_id|>", "<<ASSISTANT>>", 
                  "<|eot_id|>", "<|im_end|>", "user:", "User:", "user :", "User :"]
        )
        
        # Generate responses for all prompts in a batch
        outputs = self.llm.generate(formatted_prompts, sampling_params)
        
        # Extract and return the generated texts
        results = []
        for output in outputs:
            if output.outputs:
                results.append(output.outputs[0].text.strip())
            else:
                results.append("")
                
        return results
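
# Sketch (hypothetical helper, not part of LLMInterface): one way to build the
# `conversation_history` / `history` string from (role, text) turns, using the
# same header/<|eot_id|> layout that generate_response uses for the system and
# user turns. Assumes roles are plain strings such as "user" or "assistant".
def format_conversation_history(turns):
    """Join (role, text) pairs into header blocks separated by newlines."""
    return "\n".join(
        f"<|start_header_id|>{role}<|end_header_id|>\n{text}<|eot_id|>"
        for role, text in turns
    )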

================================================
FILE: loadandmergecheckpoint.py
================================================
import os
import re
import torch
from models import Model
from safetensors.torch import save_file, load_file 

from lora import (
    remove_lora_modules,
    merge_lora_weights,
    strip_bias_keys,
    DEVICE,
    OUTPUT_DIR,
    replace_linear_with_lora,
)
MODEL_NAME = "sesame/csm-1b"
R=32
APLHA=32

def find_latest_checkpoint(dir_path):
    checkpoints = []
    for d in os.listdir(dir_path):
        # Only consider directories named like "checkpoint-epoch-<N>"
        match = re.search(r"checkpoint-epoch-(\d+)", d)
        if match and os.path.isdir(os.path.join(dir_path, d)):
            checkpoints.append((int(match.group(1)), os.path.join(dir_path, d)))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint-epoch-* directories found in {dir_path}.")
    latest_epoch, latest_path = max(checkpoints, key=lambda x: x[0])
    print(f"Latest checkpoint: epoch {latest_epoch} -> {latest_path}")
    return latest_path

def load_checkpoint_and_merge():
    print("Loading base model...")
    model = Model.from_pretrained(MODEL_NAME).to(DEVICE)

    print("Applying LoRA structure to the model...")
    target_layers = ['q_proj', 'k_proj', 'v_proj', 'o_proj']

    model = replace_linear_with_lora(model, r=R, alpha=APLHA, dropout=0.0, target_linear_names = target_layers)
    checkpoint_path = find_latest_checkpoint(OUTPUT_DIR)
    
    print(f"Loading state dictionary from {os.path.join(checkpoint_path, 'model.safetensors')}...")
    state_dict = load_file(os.path.join(checkpoint_path, "model.safetensors"), device=DEVICE)

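    print("Loading weights into the model...")
    # strict=False skips any key mismatch silently, so report what was not
    # matched instead of merging LoRA weights that never loaded.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Missing keys: {len(result.missing_keys)} (e.g. {result.missing_keys[:3]})")
    if result.unexpected_keys:
        print(f"Unexpected keys: {len(result.unexpected_keys)} (e.g. {result.unexpected_keys[:3]})")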

    print("Merging LoRA weights into base model...")
    merge_lora_weights(model)

    print("Replacing LoRALinear modules with standard nn.Linear...")
    model = remove_lora_modules(model)

    print("Stripping bias keys for final clean model...")
    merged_state = strip_bias_keys(model.state_dict())

    final_path = os.path.join(OUTPUT_DIR, "model.safetensors")
    save_file(merged_state, final_path)
    print(f"Merged and cleaned model saved to: {final_path}")

if __name__ == "__main__":
    load_checkpoint_and_merge()


================================================
FILE: lora.py
================================================
import json
import os
import glob
import torch
import torchaudio
import logging
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, get_scheduler
import torch.nn.functional as F
from tqdm import tqdm
import wandb
from safetensors.torch import save_file
import csv
from models import Model
from moshi.models import loaders
from huggingface_hub import hf_hub_download
from tokenizers.processors import TemplateProcessing
import matplotlib
matplotlib.use('Agg')  # headless backend; select it before importing pyplot
import matplotlib.pyplot as plt
import torch.nn as nn

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(), logging.FileHandler("finetune.log")]
)
logger = logging.getLogger(__name__)

AUDIO_DIR = "audio_data"
OUTPUT_DIR = "finetuned_model"
NUM_EPOCHS = 5
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 1e-6
MAX_GRAD_NORM = 0.1
NUM_CYCLES = 1.0
USE_WANDB = False
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MIXED_PRECISION = True
WARMUP_STEPS = 50
SPEAKER_ID = 0
MODEL_NAME = "sesame/csm-1b"
TRANSCRIPTION_MODEL = "openai/whisper-large-v3-turbo"
MAX_AUDIO_FILES = 0
R=32
APLHA=32

class TrainingVisualizer:
    def __init__(self, output_dir):
        self.output_dir = output_dir
        self.epochs = []
        self.losses = []
        self.val_losses = []  # Validation loss per update (None until one is reported)
        self.learning_rates = []
        self.steps = []
        
        self.fig, self.axes = plt.subplots(3, 1, figsize=(10, 15))
        self.fig.suptitle('CSM Finetuning Progress', fontsize=16)
        
        # Setup training loss plot
        self.axes[0].set_title('Training Loss')
        self.axes[0].set_xlabel('Epoch')
        self.axes[0].set_ylabel('Loss')
        self.axes[0].grid(True, linestyle='--', alpha=0.7)
        
        # Setup validation loss plot
        self.axes[1].set_title('Training vs Validation Loss')
        self.axes[1].set_xlabel('Epoch')
        self.axes[1].set_ylabel('Loss')
        self.axes[1].grid(True, linestyle='--', alpha=0.7)
        
        # Setup learning rate plot
        self.axes[2].set_title('Learning Rate')
        self.axes[2].set_xlabel('Epoch')
        self.axes[2].set_ylabel('Learning Rate')
        self.axes[2].grid(True, linestyle='--', alpha=0.7)
    
    def update(self, epoch, step, loss, lr, val_loss=None):
        """Update the metrics and redraw the plot"""
        self.epochs.append(epoch)
        self.steps.append(step)
        self.losses.append(loss)
        self.learning_rates.append(lr)
        
        # Add validation loss if provided, otherwise use None
        if val_loss is not None:
            self.val_losses.append(val_loss)
        elif len(self.val_losses) > 0:
            # If we have validation losses but none provided this time, use the last one
            self.val_losses.append(self.val_losses[-1])
        else:
            # If we've never had validation losses, use None
            self.val_losses.append(None)
        
        # Update training loss plot
        self.axes[0].clear()
        self.axes[0].plot(self.epochs, self.losses, 'b-')
        self.axes[0].set_title('Training Loss')
        self.axes[0].set_xlabel('Epoch')
        self.axes[0].set_ylabel('Loss')
        self.axes[0].grid(True, linestyle='--', alpha=0.7)
        
        # Update validation loss plot
        self.axes[1].clear()
        self.axes[1].plot(self.epochs, self.losses, 'b-', label='Training')
        
        # If we have validation losses, plot them
        if any(x is not None for x in self.val_losses):
            # Filter out None values
            val_epochs = [e for e, v in zip(self.epochs, self.val_losses) if v is not None]
            val_loss_values = [v for v in self.val_losses if v is not None]
            if val_epochs:
                self.axes[1].plot(val_epochs, val_loss_values, 'r-', label='Validation')
                self.axes[1].legend()
        
        self.axes[1].set_title('Training vs Validation Loss')
        self.axes[1].set_xlabel('Epoch')
        self.axes[1].set_ylabel('Loss')
        self.axes[1].grid(True, linestyle='--', alpha=0.7)
        
        # Update learning rate plot
        self.axes[2].clear()
        self.axes[2].plot(self.epochs, self.learning_rates, 'g-')
        self.axes[2].set_title('Learning Rate')
        self.axes[2].set_xlabel('Epoch')
        self.axes[2].set_ylabel('Learning Rate')
        self.axes[2].grid(True, linestyle='--', alpha=0.7)
        
        # Calculate convergence metrics
        min_loss = min(self.losses)
        min_loss_epoch = self.epochs[self.losses.index(min_loss)]
        
        # Check for potential convergence stall
        recent_window = 10  # Look at last 10 steps
        if len(self.losses) > recent_window:
            recent_losses = self.losses[-recent_window:]
            loss_std = np.std(recent_losses)
            loss_change = (recent_losses[0] - recent_losses[-1]) / recent_losses[0] if recent_losses[0] != 0 else 0
            
            convergence_status = ""
            if loss_std < 0.001 and loss_change < 0.01:
                convergence_status = "STALLED: Loss not improving significantly"
            elif min_loss == self.losses[-1]:
                convergence_status = "IMPROVING: New best loss!"
            elif self.losses[-1] < self.losses[-2]:
                convergence_status = "IMPROVING: Loss decreasing"
            else:
                convergence_status = "FLUCTUATING: Loss increased"
            
            # Add convergence status to title
            self.fig.suptitle(f'CSM Finetuning Progress - {convergence_status}\n' + 
                            f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\n' + 
                            f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)
        else:
            self.fig.suptitle(f'CSM Finetuning Progress\n' + 
                            f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\n' + 
                            f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)
        
        self.fig.tight_layout(rect=[0, 0.03, 1, 0.92])  # Leave room for the multi-line title
        
        # Save the figure
        plot_path = os.path.join(self.output_dir, 'training_progress.png')
        self.fig.savefig(plot_path)
        
    def finalize(self):
        """Create a final, more detailed visualization when training completes"""
        # Create a new figure for the final plot
        final_fig = plt.figure(figsize=(12, 16))
        gs = plt.GridSpec(4, 2, figure=final_fig)
        
        # Plot 1: Loss vs Steps
        ax1 = final_fig.add_subplot(gs[0, :])
        ax1.plot(self.steps, self.losses, 'b-', linewidth=2)
        ax1.set_title('Training Loss vs Steps', fontsize=14)
        ax1.set_xlabel('Steps')
        ax1.set_ylabel('Loss')
        ax1.grid(True, linestyle='--', alpha=0.7)
        
        # Plot 2: Loss vs Epochs
        ax2 = final_fig.add_subplot(gs[1, 0])
        ax2.plot(self.epochs, self.losses, 'r-', linewidth=2)
        ax2.set_title('Training Loss vs Epochs', fontsize=14)
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Loss')
        ax2.grid(True, linestyle='--', alpha=0.7)
        
        # Plot 3: Learning Rate vs Steps
        ax3 = final_fig.add_subplot(gs[1, 1])
        ax3.plot(self.steps, self.learning_rates, 'g-', linewidth=2)
        ax3.set_title('Learning Rate Schedule', fontsize=14)
        ax3.set_xlabel('Steps')
        ax3.set_ylabel('Learning Rate')
        ax3.grid(True, linestyle='--', alpha=0.7)
        
        # Plot 4: Training vs Validation Loss
        ax4 = final_fig.add_subplot(gs[2, :])
        ax4.plot(self.epochs, self.losses, 'b-', linewidth=2, label='Training')
        
        if any(x is not None for x in self.val_losses):
            # Filter out None values
            val_epochs = [e for e, v in zip(self.epochs, self.val_losses) if v is not None]
            val_loss_values = [v for v in self.val_losses if v is not None]
            if val_epochs:
                ax4.plot(val_epochs, val_loss_values, 'r-', linewidth=2, label='Validation')
                ax4.legend()
                
        ax4.set_title('Training vs Validation Loss', fontsize=14)
        ax4.set_xlabel('Epochs')
        ax4.set_ylabel('Loss')
        ax4.grid(True, linestyle='--', alpha=0.7)
        
        # Plot 5: Combined plot with two y-axes
        ax5 = final_fig.add_subplot(gs[3, :])
        color1, color2 = 'blue', 'green'
        
        # Plot loss on left axis
        line1 = ax5.plot(self.epochs, self.losses, color=color1, linewidth=2.5, label='Loss')
        ax5.set_xlabel('Epochs')
        ax5.set_ylabel('Loss', color=color1)
        ax5.tick_params(axis='y', labelcolor=color1)
        
        # Plot learning rate on right axis
        ax6 = ax5.twinx()
        line2 = ax6.plot(self.epochs, self.learning_rates, color=color2, linewidth=2.5, label='Learning Rate')
        ax6.set_ylabel('Learning Rate', color=color2)
        ax6.tick_params(axis='y', labelcolor=color2)
        
        # Combine legends
        lines = line1 + line2
        labels = [l.get_label() for l in lines]
        ax5.legend(lines, labels, loc='upper right')
        ax5.set_title('Loss and Learning Rate vs Epochs', fontsize=14)
        ax5.grid(True, linestyle='--', alpha=0.7)
        
        # Add training summary
        if self.epochs:
            epoch_count = max(self.epochs)
            step_count = max(self.steps)
            min_loss = min(self.losses)
            min_loss_epoch = self.epochs[self.losses.index(min_loss)]
            min_loss_step = self.steps[self.losses.index(min_loss)]
            
            # Calculate convergence indicators
            recent_epochs = min(10, len(self.losses))
            recent_losses = self.losses[-recent_epochs:]
            loss_change_pct = ((recent_losses[0] - recent_losses[-1]) / recent_losses[0]) * 100 if recent_losses[0] != 0 else 0
            
            summary_text = (
                f"Training Summary\n"
                f"Total Epochs: {epoch_count:.2f}\n"
                f"Total Steps: {step_count}\n"
                f"Min Loss: {min_loss:.6f} (Epoch {min_loss_epoch:.2f}, Step {min_loss_step})\n"
                f"Loss change over last {recent_epochs} logged updates: {loss_change_pct:.2f}%\n"
            )
            
            if len(self.losses) > 20:
                # Add convergence assessment
                last_20_losses = self.losses[-20:]
                std_last_20 = np.std(last_20_losses)
                converged = std_last_20 < 0.01 and loss_change_pct < 1.0
                
                summary_text += f"Convergence status: {'CONVERGED' if converged else 'NOT CONVERGED'}\n"
                if converged:
                    summary_text += f"Loss stabilized with std dev {std_last_20:.6f}"
                else:
                    summary_text += f"Loss still changing significantly (std dev: {std_last_20:.6f})"
            
            plt.figtext(0.02, 0.02, summary_text, fontsize=10, 
                        bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))
        
        plt.tight_layout(rect=[0, 0.05, 1, 0.97])
        final_fig.suptitle('CSM Model Finetuning Metrics', fontsize=16, fontweight='bold')
        plt.subplots_adjust(top=0.93)
        
        # Save the final detailed plot
        final_plot_path = os.path.join(self.output_dir, 'training_metrics_final.png')
        final_fig.savefig(final_plot_path, dpi=300, bbox_inches='tight')
        plt.close(final_fig)
        plt.close(self.fig)
        
        logger.info(f"Final training visualization saved to {final_plot_path}")
        
        return final_plot_path
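
# Sketch (hypothetical helper, not used by TrainingVisualizer): the "STALLED"
# check from update() as a standalone function, with the same default window
# and thresholds. np.std defaults to the population standard deviation, which
# is what statistics.pstdev computes.
def is_loss_stalled(losses, window=10, std_tol=0.001, change_tol=0.01):
    """True if the last `window` losses barely vary and barely improve."""
    import statistics

    if len(losses) <= window:
        return False  # update() only runs the check once more than `window` points exist
    recent = losses[-window:]
    # Relative improvement from the start to the end of the window
    change = (recent[0] - recent[-1]) / recent[0] if recent[0] != 0 else 0
    return statistics.pstdev(recent) < std_tol and change < change_tol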

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, r=32, alpha=64, dropout=0.0, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # The base linear (frozen).
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
        
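        # A bias is always allocated: trainable when the wrapped layer had one,
        # otherwise a frozen zero vector that leaves the output unchanged.
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=bias)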
        
        # LoRA trainable matrices
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        
        nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # normal forward with frozen weight
        result = F.linear(x, self.weight, self.bias)

        # LoRA forward with trainable A and B
        lora_out = F.linear(self.dropout(x), self.lora_A)  # [*, r]
        lora_out = F.linear(lora_out, self.lora_B)         # [*, out_features]
        return result + self.scaling * lora_out
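
# Sketch (hypothetical, pure Python for illustration): why a trained LoRALinear
# can be folded into a plain linear layer. With dropout inactive (eval mode),
# forward() computes W x + b + scaling * B (A x), which equals
# (W + scaling * B A) x + b. This assumes merge_lora_weights applies the same
# fold; its source is not shown here.
def _matmul(m, n):
    """Multiply two matrices given as nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*n)] for row in m]

def merge_lora_sketch(weight, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * B @ A as nested lists."""
    scaling = alpha / r
    delta = _matmul(lora_b, lora_a)  # [out_features, in_features]
    return [[w + scaling * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]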

def replace_linear_with_lora(model: nn.Module, r=R, alpha=APLHA, dropout=0.0, target_linear_names: Optional[List[str]] = None):
    """
    Replace nn.Linear layers whose qualified name contains any of
    `target_linear_names` (substring match) with LoRALinear layers, keeping
    each layer's original device and dtype.
    """
    if target_linear_names is None:
        logger.warning("No target layer names specified for LoRA replacement. No layers will be replaced.")
        return model

    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and any(target_name in name for target_name in target_linear_names):
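            # Locate the parent module so the child attribute can be swapped;
            # top-level layers (no dot in the name) hang directly off `model`.
            if '.' in name:
                parent_name, child_name = name.rsplit('.', 1)
                parent_module = model
                for part in parent_name.split('.'):
                    parent_module = getattr(parent_module, part)
            else:
                parent_module, child_name = model, name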

            original_device = module.weight.device
            original_dtype = module.weight.dtype

            # Create the new LoRA layer
            lora_linear = LoRALinear(
                in_features=module.in_features,
                out_features=module.out_features,
                r=r,
                alpha=alpha,
                dropout=dropout,
                bias=(module.bias is not None)
            )
            # Copy the original weights and bias
            with torch.no_grad():
                lora_linear.weight.copy_(module.weight.data)
                if module.bias is not None:
                    lora_linear.bias.copy_(module.bias.data)
            
            lora_linear.to(device=original_device, dtype=original_dtype)
            
            setattr(parent_module, child_name, lora_linear)
            logger.info(f"Replaced layer: {name} with LoRALinear on device {original_device}")
            
    return model
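
# Parent lookup and in-place swap (illustrative sketch, not called by training):
# replace_linear_with_lora finds matching nn.Linear layers by substring, resolves
# each layer's parent from its dotted name, and rebinds the child with setattr.
# The helper name `_sketch_swap_linears` and its `make_replacement` callback are
# hypothetical. It uses rpartition so top-level layers (no ".") also work.
import torch.nn as nn

def _sketch_swap_linears(model: nn.Module, target_names, make_replacement):
    """Replace matching nn.Linear submodules in place and return their names."""
    replaced = []
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and any(t in name for t in target_names):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, make_replacement(module))
            replaced.append(name)
    return replaced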

def load_llama3_tokenizer():
    tokenizer_name = "unsloth/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(bos, tokenizer.bos_token_id), (eos, tokenizer.eos_token_id)],
    )
    return tokenizer

@dataclass
class AudioTextPair:
    audio_path: str
    text: str
    speaker_id: int
    processed_audio: Optional[torch.Tensor] = None
    
    def load_audio(self, sample_rate=24000) -> torch.Tensor:
        if self.processed_audio is not None:
            return self.processed_audio

        waveform, sr = torchaudio.load(self.audio_path)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sr != sample_rate:
            resampler = torchaudio.transforms.Resample(sr, sample_rate)
            waveform = resampler(waveform)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)

        self.processed_audio = waveform.squeeze(0)
        return self.processed_audio

class CSMDataset(Dataset):
    def __init__(self, data_items, text_tokenizer, audio_tokenizer, device):
        self.data_items = data_items
        self.text_tokenizer = text_tokenizer
        self.audio_tokenizer = audio_tokenizer
        self.device = device
        self.sample_rate = audio_tokenizer.sample_rate
        
    def __len__(self):
        return len(self.data_items)
        
    def tokenize_text_segment(self, text: str, speaker: int):
        text_tokens = self.text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True
        return text_frame, text_frame_mask

    def tokenize_audio(self, audio: torch.Tensor):
        assert audio.ndim == 1, "Audio must be single channel"
        audio_device = next(self.audio_tokenizer.parameters()).device
        audio = audio.to(audio_device)
        
        try:
            audio_tokens = self.audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
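            # Append one all-zero frame as the end-of-audio marker (same dtype as the codes).
            eos_frame = torch.zeros(audio_tokens.size(0), 1, dtype=audio_tokens.dtype, device=audio_device)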
            audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

            audio_frame = torch.zeros(audio_tokens.size(1), 33, device=audio_device).long()
            audio_frame_mask = torch.zeros(audio_tokens.size(1), 33, device=audio_device).bool()
            audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
            audio_frame_mask[:, :-1] = True
        except RuntimeError as e:
            logger.warning(f"Error encoding audio: {e}, using empty frames")
            audio_frame = torch.zeros(1, 33, device=audio_device).long()
            audio_frame_mask = torch.zeros(1, 33, device=audio_device).bool()

        return audio_frame, audio_frame_mask
    
    def __getitem__(self, idx: int):
        item = self.data_items[idx]
        audio = item.load_audio(self.sample_rate)
        
        text_tokens, text_masks = self.tokenize_text_segment(item.text, item.speaker_id)
        audio_tokens, audio_masks = self.tokenize_audio(audio)
        
        device = audio_tokens.device
        text_tokens = text_tokens.to(device)
        text_masks = text_masks.to(device)
        
        input_tokens = text_tokens
        input_masks = text_masks
        
        target_tokens = torch.cat([text_tokens, audio_tokens], dim=0)
        target_masks = torch.cat([text_masks, audio_masks], dim=0)
        
        if device != self.device:
            input_tokens = input_tokens.to(self.device)
            input_masks = input_masks.to(self.device)
            target_tokens = target_tokens.to(self.device)
            target_masks = target_masks.to(self.device)
        
        return {
            "input_tokens": input_tokens,
            "input_masks": input_masks,
            "target_tokens": target_tokens,
            "target_masks": target_masks,
        }
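
# Frame layout used by CSMDataset (illustrative sketch, not called by training):
# every row has 33 slots. Columns 0..31 hold the 32 Mimi codebook indices of one
# audio frame, and column 32 (index -1) holds one Llama-3 text token. Text rows
# fill only the last column. Audio rows fill the first 32 columns and leave the
# text column empty, and the boolean masks record which slots are real.
# tokenize_audio also appends one all-zero audio frame as an end marker.
# The helper name `_sketch_frame_rows` is hypothetical, and it uses plain lists.
def _sketch_frame_rows(text_tokens, audio_codes, num_codebooks=32):
    """Return (rows, masks) as lists of length-(num_codebooks + 1) lists."""
    width = num_codebooks + 1
    rows, masks = [], []
    for tok in text_tokens:  # text rows: last column only
        row = [0] * width
        mask = [False] * width
        row[-1], mask[-1] = tok, True
        rows.append(row)
        masks.append(mask)
    audio_mask = [True] * num_codebooks + [False]
    for frame in list(audio_codes) + [[0] * num_codebooks]:  # audio rows + all-zero end frame
        assert len(frame) == num_codebooks
        rows.append(list(frame) + [0])
        masks.append(list(audio_mask))
    return rows, masks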

def collate_fn(batch):
    max_seq_len = 1024
    device = batch[0]["input_tokens"].device
    
    max_input_len = min(max(item["input_tokens"].size(0) for item in batch), max_seq_len)
    max_target_len = min(max(item["target_tokens"].size(0) for item in batch), max_seq_len)

    batch_input_tokens = []
    batch_input_masks = []
    batch_target_tokens = []
    batch_target_masks = []
    
    for item in batch:
        input_tokens = item["input_tokens"][:max_input_len]
        input_masks = item["input_masks"][:max_input_len]
        target_tokens = item["target_tokens"][:max_target_len]
        target_masks = item["target_masks"][:max_target_len]
        
        input_tokens = F.pad(input_tokens, (0,0,0, max_input_len - input_tokens.size(0)), "constant", 0)
        input_masks = F.pad(input_masks, (0,0,0, max_input_len - input_masks.size(0)), "constant", False)
        
        target_tokens = F.pad(target_tokens, (0,0,0, max_target_len - target_tokens.size(0)), "constant", 0)
        target_masks = F.pad(target_masks, (0,0,0, max_target_len - target_masks.size(0)), "constant", False)
        
        batch_input_tokens.append(input_tokens)
        batch_input_masks.append(input_masks)
        batch_target_tokens.append(target_tokens)
        batch_target_masks.append(target_masks)
    
    return {
        "input_tokens": torch.stack(batch_input_tokens),
        "input_masks": torch.stack(batch_input_masks),
        "target_tokens": torch.stack(batch_target_tokens),
        "target_masks": torch.stack(batch_target_masks),
        "positions": torch.arange(0, max_target_len).unsqueeze(0).repeat(len(batch), 1).to(device)
    }
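
# How collate_fn pads (illustrative sketch): F.pad reads its pad tuple starting
# from the last dimension, so (0, 0, 0, n) leaves the 33 frame columns untouched
# and appends n rows of the fill value at the end of the time axis. Items longer
# than the target length are truncated first. `_sketch_pad_frames` is a
# hypothetical helper that mirrors this logic for a single item.
import torch
import torch.nn.functional as F

def _sketch_pad_frames(frames: torch.Tensor, target_len: int, value=0) -> torch.Tensor:
    frames = frames[:target_len]  # truncate along the time axis
    missing = target_len - frames.size(0)
    return F.pad(frames, (0, 0, 0, missing), "constant", value)  # pad time axis only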

def transcribe_audio_files():
    from transformers import pipeline
    
    # Cache file path
    cache_file = os.path.join(AUDIO_DIR, "transcription_cache.json")
    
    # Load existing cache
    cache = {}
    if os.path.exists(cache_file):
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                cache = json.load(f)
            logger.info(f"Loaded transcription cache with {len(cache)} entries")
        except Exception as e:
            logger.warning(f"Could not load cache file: {e}")
            cache = {}
    
    logger.info(f"Transcribing audio files in: {AUDIO_DIR}")
    transcriber = pipeline("automatic-speech-recognition", model=TRANSCRIPTION_MODEL)
    audio_text_pairs = []
    
    audio_files = glob.glob(os.path.join(AUDIO_DIR, "*.wav")) \
        + glob.glob(os.path.join(AUDIO_DIR, "*.mp3")) \
        + glob.glob(os.path.join(AUDIO_DIR, "*.flac"))
    
    if MAX_AUDIO_FILES > 0 and len(audio_files) > MAX_AUDIO_FILES:
        logger.info(f"Found {len(audio_files)} files, limiting to {MAX_AUDIO_FILES}")
        audio_files = audio_files[:MAX_AUDIO_FILES]
    
    cache_hits = 0
    cache_misses = 0
    
    for audio_file in tqdm(audio_files, desc="Processing audio files"):
        try:
            # Cache key: path + modification time + size, so edited files are re-transcribed
            file_stat = os.stat(audio_file)
            cache_key = f"{audio_file}_{file_stat.st_mtime}_{file_stat.st_size}"
            
            # Check if transcription exists in cache
            if cache_key in cache:
                transcription = cache[cache_key]
                cache_hits += 1
                logger.debug(f"Cache hit: {os.path.basename(audio_file)}")
            else:
                # Transcribe the file
                result = transcriber(audio_file, return_timestamps=True, chunk_length_s=30,
                                   stride_length_s=[6, 0], batch_size=32,
                                   generate_kwargs={"language": "<|en|>", "task": "transcribe"})
                transcription = result["text"].strip()
                
                # Save to cache
                cache[cache_key] = transcription
                cache_misses += 1
                logger.info(f"Transcribed: {os.path.basename(audio_file)} -> {transcription}")
            
            audio_text_pairs.append(
                AudioTextPair(audio_path=audio_file, text=transcription, speaker_id=0)
            )
            
        except Exception as e:
            logger.error(f"Error processing {audio_file}: {e}")
    
    # Save updated cache
    try:
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(cache, f, ensure_ascii=False, indent=2)
        logger.info(f"Saved transcription cache with {len(cache)} entries")
    except Exception as e:
        logger.error(f"Could not save cache file: {e}")
    
    logger.info(f"Processed {len(audio_text_pairs)} audio files (Cache hits: {cache_hits}, Cache misses: {cache_misses})")
    return audio_text_pairs
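
# Transcription cache keys (illustrative sketch): transcribe_audio_files keys each
# entry on path, modification time and size. Editing or replacing a file
# therefore yields a new key and forces re-transcription, while untouched files
# are served from transcription_cache.json. Old keys are never evicted.
# `_sketch_cache_key` is a hypothetical helper that reproduces the same f-string.
import os

def _sketch_cache_key(path: str) -> str:
    st = os.stat(path)
    return f"{path}_{st.st_mtime}_{st.st_size}"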

def prepare_csm_model_for_training():
    logger.info(f"Loading CSM model: {MODEL_NAME}")
    model = Model.from_pretrained(MODEL_NAME).to(DEVICE)

    text_tokenizer = load_llama3_tokenizer()
    mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(mimi_weight, device=DEVICE)
    mimi.set_num_codebooks(32)
    audio_tokenizer = mimi
    try:
        # Initialise the training-only codebook embedding from Mimi's codebook-0 centroids.
        codebook_0_centroids = mimi.quantizer.rvq_first.layers[0].codebook.weight.data

        num_codebook_0_tokens, embedding_dim = codebook_0_centroids.shape
        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)
        model.codebook_embedding.weight.data.copy_(codebook_0_centroids)
        logger.info(f"Successfully initialized codebook_embedding with shape: {codebook_0_centroids.shape}")

    except Exception as e:
        # Fall back to a Xavier-initialised 1024 x 1024 embedding if the centroids are unavailable.
        logger.warning(f"Could not read Mimi codebook-0 centroids ({e}); using a randomly initialised codebook_embedding")
        num_codebook_0_tokens, embedding_dim = 1024, 1024
        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)
        nn.init.xavier_uniform_(model.codebook_embedding.weight)

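    # Give model.config a dict-style get() and a default tie_word_embeddings,
    # for code that expects a Hugging Face-style config object.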
    if not hasattr(model.config, 'get'):
        def get_method(self, key, default=None):
            if hasattr(self, key):
                return getattr(self, key)
            return default
        model.config.__class__.get = get_method
    if not hasattr(model.config, 'tie_word_embeddings'):
        model.config.tie_word_embeddings = False
    # LoRA targets: attention projections and the feed-forward w1/w2/w3 linears.
    target_layers = [
        "q_proj",
        "k_proj",
        "v_proj",
        "output_proj",
        "w1",
        "w2",
        "w3",
    ]
    logger.info("Applying LoRA to model...")
    model = replace_linear_with_lora(
        model,
        r=R,
        alpha=APLHA,
        dropout=0.01,
        target_linear_names=target_layers
    )
    model.to(DEVICE)

    # First, freeze all parameters of the base model
    for param in model.parameters():
        param.requires_grad = False

    # Then, unfreeze only the newly added LoRA parameters.
    # It is also common practice to train the bias parameters.
    for name, param in model.named_parameters():
        if "lora_A" in name or "lora_B" in name or "bias" in name:
            param.requires_grad = True

    return model, text_tokenizer, audio_tokenizer
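
# Freezing pattern (illustrative sketch): every parameter is frozen, and then
# only names containing "lora_A", "lora_B" or "bias" are re-enabled.
# `_sketch_trainable_fraction` is a hypothetical helper for checking what share
# of a model's parameters that leaves trainable.
import torch.nn as nn

def _sketch_trainable_fraction(model: nn.Module) -> float:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable / total if total else 0.0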

def setup_model_caches(model, batch_size):
    try:
        with torch.no_grad():
            model.reset_caches()
            model.backbone.reset_caches()
            model.decoder.reset_caches()
    except Exception as e:
        logger.debug(f"No caches to reset or error: {e}")
    return True

class BridgingModule(nn.Module):
    """Trainable linear projection from the 2048-dim backbone output to the 1024-dim space used by model.audio_head."""
    def __init__(self, in_dim=2048, out_dim=1024):
        super().__init__()
        self.bridge = nn.Linear(in_dim, out_dim, bias=False)
        nn.init.xavier_uniform_(self.bridge.weight)
    def forward(self, x):
        return self.bridge(x)

def compute_loss_for_codebooks_single_pass(
    backbone_out,  # [b, seq_len, 2048]
    decoder_out,   # [b, seq_len, 1024]
    model, 
    target_tokens, # [b, seq_len, codebooks]
    target_masks,  # [b, seq_len, codebooks bool]
    device
):
    bsz, seq_len = target_tokens.size()[:2]
    num_codebooks = model.config.audio_num_codebooks

    c0_logits = model.codebook0_head(backbone_out)
    audio_positions = target_masks[..., :-1].any(dim=-1)  # [b, seq_len] for audio

    total_loss = torch.tensor(0.0, device=device)
    count = 0

    # codebook0
    for b in range(bsz):
        for s in range(seq_len):
            if audio_positions[b, s]:
                token_logits = c0_logits[b, s]
                target_token = target_tokens[b, s, 0]
                if target_token > 0:
                    ce = F.cross_entropy(token_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')
                    total_loss += ce
                    count += 1

    # codebooks [1..N-1] from decoder_out
    for i in range(1, num_codebooks):
        weight_i = model.audio_head[i-1]
        flat_dec = decoder_out.view(bsz * seq_len, -1)
        token_logits_all = flat_dec.mm(weight_i)
        
        for b in range(bsz):
            for s in range(seq_len):
                if audio_positions[b, s]:
                    target_token = target_tokens[b, s, i]
                    if target_token > 0:
                        row_idx = b*seq_len + s
                        row_logits = token_logits_all[row_idx]
                        ce = F.cross_entropy(row_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')
                        total_loss += ce
                        count += 1

    if count > 0:
        total_loss = total_loss / count
    return total_loss
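
# Vectorised form of the per-token loops above (illustrative sketch, not used by
# training): select the positions that pass the same filters (audio position and
# target > 0), then let F.cross_entropy sum over them in a single call. To match
# the loops' average, sum these results across codebooks and divide by the total
# token count. `_sketch_masked_ce_sum` is a hypothetical helper.
import torch
import torch.nn.functional as F

def _sketch_masked_ce_sum(logits: torch.Tensor, targets: torch.Tensor, keep: torch.Tensor):
    """logits [b, s, vocab], targets [b, s] (long), keep [b, s] (bool).
    Returns (summed cross-entropy over kept positions, number of kept positions)."""
    keep = keep & (targets > 0)
    selected_logits = logits[keep]    # [n, vocab]
    selected_targets = targets[keep]  # [n]
    if selected_targets.numel() == 0:
        return logits.sum() * 0.0, 0
    loss = F.cross_entropy(selected_logits, selected_targets, reduction="sum")
    return loss, int(selected_targets.numel())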

def single_pass_forward(model, bridging_module, target_tokens, target_masks, positions):
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    
    embed = model._embed_tokens(target_tokens)
    masked_embed = embed * target_masks.unsqueeze(-1)
    h = masked_embed.sum(dim=2)
    
    backbone_out = model.backbone(h, input_pos=positions, mask=None).to(dtype)
    bridging_out = bridging_module(backbone_out)
    
    codebook0_logits = model.codebook0_head(backbone_out)
    codebook0_tokens = torch.argmax(codebook0_logits, dim=-1).clamp(0, model.codebook_embedding.num_embeddings - 1)
    c0_embed = model.codebook_embedding(codebook0_tokens)
    
    # Get the last hidden state from bridging module
    last_h = bridging_out[:, -1, :].unsqueeze(1)
    
    # Concatenate the last hidden state with the codebook embeddings
    decoder_input = torch.cat([last_h, c0_embed], dim=1)
    
    # Process decoder inputs in parallel
    B, S, D = decoder_input.shape  # Batch, Sequence length, Dimension
    
    # Reshape to (B*S, D) to process all tokens in parallel
    decoder_input_flat = decoder_input.view(-1, D).unsqueeze(1)  # [B*S, 1, D]
    
    # Run decoder on all inputs in parallel
    decoder_out_flat = model.decoder(decoder_input_flat).to(dtype)  # [B*S, 1, output_dim]
    
    # Reshape back to original batch and sequence dimensions
    decoder_out = decoder_out_flat.view(B, S, -1)  # [B, S, output_dim]
    
    # Drop the first position (the output for last_h) so outputs align with the codebook-0 embeddings
    decoder_out = decoder_out[:, 1:, :]  # [B, T, 1024]
    
    # Safety check: handle empty sequences
    if decoder_out.size(1) == 0:
        return torch.tensor(0.0, requires_grad=True, device=device)
    
    loss = compute_loss_for_codebooks_single_pass(
        backbone_out=backbone_out,
        decoder_out=decoder_out,
        model=model,
        target_tokens=target_tokens[..., 1:],  # Drop codebook 0
        target_masks=target_masks[..., 1:],
        device=device
    )
    
    return loss

def calculate_validation_loss(model, bridging_module, dataset, device, max_samples=50):
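    """
    Estimate the mean loss on up to `max_samples` random items from `dataset`.
    The items come from the training set itself (there is no held-out split),
    so this tracks training fit rather than generalization.
    """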
    # Create a small validation dataloader with a subset of data
    val_indices = torch.randperm(len(dataset))[:max_samples].tolist()
    val_samples = [dataset[i] for i in val_indices]
    
    val_loader = DataLoader(
        val_samples, 
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )
    
    model.eval()
    bridging_module.eval()
    
    total_loss = 0.0
    num_batches = 0
    
    with torch.no_grad():
        for batch in val_loader:
            setup_model_caches(model, batch["target_tokens"].size(0))
            
            loss = forward_and_loss(model, bridging_module, batch, device)
            
            total_loss += loss.item()
            num_batches += 1
    
    model.train()
    bridging_module.train()
    
    # Return average loss
    return total_loss / num_batches if num_batches > 0 else 0.0

def strip_bias_keys(state_dict: dict) -> dict:
    """
    Drop every ".bias" entry and the training-only codebook_embedding.weight
    from a state dict before the merged model is saved.
    """
    new_sd = {}
    for k, v in state_dict.items():
        if k == "codebook_embedding.weight":
            logger.info(f"Stripping {k} from checkpoint (training-only layer)")
            continue
        if not k.endswith(".bias"):
            new_sd[k] = v
        else:
            logger.info(f"Stripping {k} from checkpoint")
    return new_sd

def remove_lora_modules(module: nn.Module) -> nn.Module:
    for name, child in list(module.named_children()):
        new_child = remove_lora_modules(child)
        setattr(module, name, new_child)

    if isinstance(module, LoRALinear):
        out_features, in_features = module.out_features, module.in_features

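        # Determine if we actually need a bias
        has_bias = (module.bias is not None)
        # Keep the replacement on the same device and dtype as the LoRA layer.
        new_linear = nn.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )

        # Copy the weight; call merge_lora_weights first, or the LoRA delta is lost.
        new_linear.weight.data.copy_(module.weight.data)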

        # If we had a bias in LoRALinear, copy it too
        if has_bias:
            new_linear.bias.data.copy_(module.bias.data)

        return new_linear

    return module

def merge_lora_layer(lora_module: LoRALinear):
    """
    Fold the LoRA update into the base weight in place:
        W <- W + (alpha / r) * (lora_B @ lora_A)
    and then zero lora_A and lora_B. Afterwards the module computes the same
    function as a plain nn.Linear holding the merged weight.
    """
    merged_delta = lora_module.scaling * (lora_module.lora_B @ lora_module.lora_A)
    lora_module.weight.data += merged_delta

    # Zero the adapters so the LoRA branch no longer contributes; otherwise the
    # delta would be applied twice (once in the weight, once in forward()).
    lora_module.lora_A.data.zero_()
    lora_module.lora_B.data.zero_()

def merge_lora_weights(model: nn.Module):
    for module in model.modules():
        if isinstance(module, LoRALinear):
            merge_lora_layer(module)
    return model
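
# Why merging is exact (illustrative sketch): the LoRALinear forward computes
#   y = x W^T + b + s * (x A^T) B^T,   with s = alpha / r,
# and because (x A^T) B^T = x (B A)^T, this equals x (W + s B A)^T + b.
# Adding s * (lora_B @ lora_A) to the weight and zeroing the adapters therefore
# leaves the output unchanged, up to floating-point rounding. The hypothetical
# helper below checks the identity with plain tensors. It ignores dropout, which
# LoRALinear applies to x on the LoRA branch only.
import torch
import torch.nn.functional as F

def _sketch_check_lora_merge(in_features=8, out_features=6, r=2, alpha=4.0, seed=0) -> bool:
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(3, in_features, generator=g)
    W = torch.randn(out_features, in_features, generator=g)
    b = torch.randn(out_features, generator=g)
    A = torch.randn(r, in_features, generator=g)   # lora_A: [r, in]
    B = torch.randn(out_features, r, generator=g)  # lora_B: [out, r]
    s = alpha / r
    unmerged = F.linear(x, W, b) + s * F.linear(F.linear(x, A), B)
    merged = F.linear(x, W + s * (B @ A), b)
    return torch.allclose(unmerged, merged, atol=1e-5)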

def finetune(model, dataset):
    logger.info("Starting finetuning process")
    csv_file = os.path.join(OUTPUT_DIR, "training_metrics.csv")
    with open(csv_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "step", "global_step", "loss", "learning_rate", "val_loss"])
    
    def log_metrics(epoch, step, global_step, loss, learning_rate, val_loss=None):
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, step, global_step, loss, learning_rate, val_loss if val_loss is not None else ""])
        visualizer.update(epoch, global_step, loss, learning_rate, val_loss)

    visualizer = TrainingVisualizer(OUTPUT_DIR)
    bridging_module = BridgingModule(in_dim=2048, out_dim=1024).to(DEVICE)
    for param in bridging_module.parameters():
        param.requires_grad = True

    dataloader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True,
        collate_fn=collate_fn, num_workers=0, pin_memory=False
    )

    trainable_params = [p for p in model.parameters() if p.requires_grad] + list(bridging_module.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)
    
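    # The loop below also steps on the last (possibly partial) accumulation window
    # of each epoch, so every epoch performs ceil(len(dataloader) / accum) updates.
    steps_per_epoch = (len(dataloader) + GRADIENT_ACCUMULATION_STEPS - 1) // GRADIENT_ACCUMULATION_STEPS
    num_training_steps = steps_per_epoch * NUM_EPOCHS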
    lr_scheduler = get_scheduler(
        "cosine", optimizer=optimizer,
        num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps
    )
    
    if USE_WANDB:
        wandb.init(project="csm-finetuning", name="csm-lora-finetune-fixed")
    
    scaler = torch.amp.GradScaler() if MIXED_PRECISION else None
    global_step = 0
    validation_frequency = max(1, len(dataloader) // (2 * GRADIENT_ACCUMULATION_STEPS))
    
    model.train()
    bridging_module.train()
    
    logger.info("Calculating initial validation loss...")
    initial_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
    logger.info(f"Initial validation loss: {initial_val_loss:.6f}")
    
    current_loss = 0.0
    current_lr = LEARNING_RATE

    for epoch in range(NUM_EPOCHS):
        logger.info(f"Starting epoch {epoch+1}/{NUM_EPOCHS}")
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")
        
        for step, batch in enumerate(dataloader):
            try:
                setup_model_caches(model, batch["target_tokens"].size(0))
                
                with torch.amp.autocast(device_type=DEVICE, dtype=torch.float16, enabled=MIXED_PRECISION):
                    loss = forward_and_loss(model, bridging_module, batch, DEVICE)
                    if GRADIENT_ACCUMULATION_STEPS > 1:
                        loss = loss / GRADIENT_ACCUMULATION_STEPS

                if torch.isnan(loss) or torch.isinf(loss):
                    logger.warning(f"NaN or Inf loss detected at step {step}. Skipping batch.")
                    optimizer.zero_grad()
                    progress_bar.update(1)
                    continue

                if MIXED_PRECISION:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                
                if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (step + 1) == len(dataloader):
                    if MIXED_PRECISION:
                        scaler.unscale_(optimizer)
                    
                    torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
                    
                    if MIXED_PRECISION:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    
                    current_lr = optimizer.param_groups[0]["lr"]
                    current_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS if GRADIENT_ACCUMULATION_STEPS > 1 else loss.item()
                    current_epoch = epoch + (step + 1) / len(dataloader)
                    
                    current_val_loss = None
                    if global_step > 0 and global_step % validation_frequency == 0:
                        logger.info(f"Calculating validation loss at global step {global_step}...")
                        current_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
                        logger.info(f"Validation loss: {current_val_loss:.6f}")
                    
                    log_metrics(current_epoch, step, global_step, current_loss, current_lr, current_val_loss)
                    global_step += 1
                    
                    if USE_WANDB:
                        wandb.log({"loss": current_loss, "learning_rate": current_lr, "epoch": current_epoch, "global_step": global_step, "val_loss": current_val_loss})
                    
                    progress_bar.set_postfix({"loss": f"{current_loss:.4f}", "lr": f"{current_lr:.2e}"})
                
                progress_bar.update(1)
            
            except Exception as e:
                logger.error(f"Error in batch {step}: {e}")
                import traceback
                logger.error(traceback.format_exc())
                try:
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()
                except Exception:
                    pass
                progress_bar.update(1)
                continue
        
        logger.info(f"Calculating validation loss at end of epoch {epoch+1}...")
        epoch_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)
        logger.info(f"Epoch {epoch+1} validation loss: {epoch_val_loss:.6f}")
        log_metrics(epoch + 1.0, len(dataloader), global_step, current_loss, current_lr, epoch_val_loss)
        
        checkpoint_dir = os.path.join(OUTPUT_DIR, f"checkpoint-epoch-{epoch+1}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_tensors = {
            **model.state_dict(),
            **bridging_module.state_dict()
        }
        save_file(checkpoint_tensors, os.path.join(checkpoint_dir, "model.safetensors"))

        logger.info(f"Saved checkpoint to {checkpoint_dir}")
    
    final_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE, max_samples=100)
    logger.info(f"Final validation loss: {final_val_loss:.6f}")
    
    logger.info("Merging LoRA weights into the base model...")
    merge_lora_weights(model)
    model = remove_lora_modules(model)
    merged_state = strip_bias_keys(model.state_dict())

    final_merged_path = os.path.join(OUTPUT_DIR, "model.safetensors")
    save_file(merged_state, final_merged_path)
    logger.info(f"LoRA-merged & replaced model saved to {final_merged_path}")
    
    visualizer.finalize()
    if USE_WANDB:
        wandb.finish()
    
    return model
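
# Optimizer-step arithmetic (illustrative sketch, not called by training):
# finetune() steps the optimizer every GRADIENT_ACCUMULATION_STEPS batches and
# also on the last batch of each epoch. An epoch therefore performs
# ceil(len(dataloader) / accum) updates, not len(dataloader) // accum. The
# hypothetical helper below replays that loop condition so the total can be
# compared with the scheduler's num_training_steps.
def _sketch_count_optimizer_steps(num_batches: int, accum: int, epochs: int) -> int:
    steps = 0
    for _ in range(epochs):
        for step in range(num_batches):
            # Same condition as the accumulation check inside finetune().
            if (step + 1) % accum == 0 or (step + 1) == num_batches:
                steps += 1
    return steps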

def forward_and_loss(model, bridging_module, batch, device):
    target_tokens = batch["target_tokens"].to(device)
    target_masks = batch["target_masks"].to(device)
    positions = batch["positions"].to(device)

    input_tokens = target_tokens[:, :-1]
    input_masks = target_masks[:, :-1]
    input_positions = positions[:, :-1]
    labels = target_tokens[:, 1:]
    label_masks = target_masks[:, 1:]

    if input_tokens.size(1) == 0:
        return torch.tensor(0.0, requires_grad=True, device=device)

    # 1. Embed tokens and apply mask
    embed = model._embed_tokens(input_tokens)
    masked_embed = embed * input_masks.unsqueeze(-1)
    h = masked_embed.sum(dim=2)

    # 2. Pass through the backbone
    backbone_out = model.backbone(h, input_pos=input_positions, mask=None)

    # 3. Calculate loss for all codebooks
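    # Padding is already excluded by the label masks below. Code 0 is a valid
    # Mimi codebook entry (and the appended end-of-audio frame is all zeros),
    # so it is not used as ignore_index.
    loss_fct = nn.CrossEntropyLoss()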
    total_loss = 0.0
    num_codebooks_with_loss = 0

    c0_logits = model.codebook0_head(backbone_out)
    c0_labels = labels[..., 0]
    
    active_mask = label_masks[..., 0].view(-1)
    if active_mask.sum() > 0:
        active_logits = c0_logits.view(-1, c0_logits.size(-1))[active_mask]
        active_labels = c0_labels.view(-1)[active_mask]
        c0_loss = loss_fct(active_logits, active_labels)
        total_loss += c0_loss
        num_codebooks_with_loss += 1

    decoder_states = bridging_module(backbone_out)
    
    num_codebooks = model.config.audio_num_codebooks
    for i in range(1, num_codebooks):
        if hasattr(model, 'audio_head') and len(model.audio_head) >= i:
            weight_i = model.audio_head[i-1] 
            
            logits_i = decoder_states @ weight_i

            labels_i = labels[..., i]
            active_mask_i = label_masks[..., i].view(-1)

            if active_mask_i.sum() > 0:
                active_logits_i = logits_i.view(-1, logits_i.size(-1))[active_mask_i]
                active_labels_i = labels_i.view(-1)[active_mask_i]
                loss_i = loss_fct(active_logits_i, active_labels_i)
                total_loss += loss_i
                num_codebooks_with_loss += 1

    if num_codebooks_with_loss > 0:
        return total_loss / num_codebooks_with_loss
    else:
        return torch.tensor(0.0, requires_grad=True, device=device)
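
# Teacher-forcing shift in forward_and_loss (illustrative sketch): the frame at
# time t is used to predict frame t + 1. Inputs therefore drop the last frame,
# labels drop the first, and position ids are cut to match the inputs. The
# hypothetical helper `_sketch_shift_for_next_frame` shows this on plain lists;
# the real code slices dimension 1 of batched tensors.
def _sketch_shift_for_next_frame(frames, positions):
    inputs = frames[:-1]
    labels = frames[1:]
    input_positions = positions[:-1]
    return inputs, labels, input_positions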

def main():
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    torch.backends.cuda.enable_flash_sdp(True)
    if DEVICE == "cuda":
        torch.backends.cudnn.benchmark = True
    
    model, text_tokenizer, audio_tokenizer = prepare_csm_model_for_training()
    audio_text_pairs = transcribe_audio_files()
    if not audio_text_pairs:
        logger.error(f"No audio files found or transcribed in {AUDIO_DIR}")
        return
    
    dataset = CSMDataset(
        audio_text_pairs,
        text_tokenizer=text_tokenizer,
        audio_tokenizer=audio_tokenizer,
        device=DEVICE
    )
    
    logger.info(f"Dataset created with {len(dataset)} samples")
    
    try:
        finetune(model, dataset)
        logger.info("Finetuning completed successfully!")
    except Exception as e:
        logger.error(f"Error during finetuning: {e}")
        import traceback
        logger.error(traceback.format_exc())
        
        try:
            # If there's an error, at least save a partial state.
            # torch.save writes a pickle, so use .pt rather than .safetensors.
            partial_path = os.path.join(OUTPUT_DIR, "model_partial.pt")
            torch.save(model.state_dict(), partial_path)
            logger.info(f"Saved partial model to {partial_path} despite errors")
        except Exception as save_error:
            logger.error(f"Could not save partial model: {save_error}")

if __name__ == "__main__":
    main()


================================================
FILE: main.py
================================================
import asyncio
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
# Synchronous CUDA kernel launches: easier to debug, but slower.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_DISABLE_CUDA_GRAPHS"] = "1"
import platform
import sqlite3
import time
import threading
import json
import queue
from fastapi.websockets import WebSocketState
import torch
import torchaudio
import sounddevice as sd
import numpy as np
import whisper
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy import create_engine, Column, Integer, String, Text
from sqlalchemy.orm import declarative_base, sessionmaker
from typing import Optional
from generator import Segment, load_csm_1b_local
from llm_interface import LLMInterface
from rag_system import RAGSystem 
from vad import AudioStreamProcessor
from pydantic import BaseModel
import logging
from config import ConfigManager
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import re
speaking_start_time = 0.0  # set every time the AI begins a new turn
MIN_BARGE_LATENCY = 0.9
speaker_counters = {
    0: 0,  # AI
    1: 0   # User
}
current_generation_id = 1
pending_user_inputs = []
user_input_lock = threading.Lock()
audio_fade_duration = 0.3  # seconds for fade-out
last_interrupt_time = 0
interrupt_cooldown = 6.0  # seconds between allowed interrupts
audio_chunk_buffer = []  # Buffer to store the most recent audio chunks for fade-out
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
model_thread = None
model_queue = queue.Queue()
model_result_queue = queue.Queue()
model_thread_running = threading.Event()
llm_lock = threading.Lock()
audio_gen_lock = threading.Lock()
# Database
Base = declarative_base()
engine = create_engine("sqlite:///companion.db")
SessionLocal = sessionmaker(bind=engine)

class Conversation(Base):
    __tablename__ = "conversations"
    id = Column(Integer, primary_key=True, index=True)
    session_id = Column(String, index=True)
    timestamp = Column(String)
    user_message = Column(Text)
    ai_message = Column(Text)
    audio_path = Column(String)

Base.metadata.create_all(bind=engine)

# Pydantic config schema
class CompanionConfig(BaseModel):
    system_prompt: str
    reference_audio_path: str
    reference_text: str
    reference_audio_path2: Optional[str] = None  # optional field
    reference_text2: Optional[str] = None  # optional field
    reference_audio_path3: Optional[str] = None  # optional field
    reference_text3: Optional[str] = None  # optional field
    model_path: str
    llm_path: str
    max_tokens: int = 8192
    voice_speaker_id: int = 0
    vad_enabled: bool = True
    vad_threshold: float = 0.5
    embedding_model: str = "all-MiniLM-L6-v2"

# Global state
conversation_history = []
config = None
audio_queue = queue.Queue()
is_speaking = False
interrupt_flag = threading.Event()
generator = None
llm = None
rag = None
vad_processor = None
reference_segments = []
active_connections = []
message_queue = asyncio.Queue()

# Async event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# FastAPI
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
config_manager = ConfigManager()
model_id = "openai/whisper-large-v3-turbo"
# Whisper
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True
)            
whisper_model.to("cuda")
processor = AutoProcessor.from_pretrained(model_id)
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model=whisper_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch.float16,
    device='cuda',
)
# Background queue
async def process_message_queue():
    while True:
        message = await message_queue.get()
        for client in active_connections[:]:
            try:
                if client.client_state == WebSocketState.CONNECTED:
                    await client.send_json(message)
            except Exception as e:
                logger.error(f"Error in message queue for client: {e}")
                if client in active_connections:
                    active_connections.remove(client)
        message_queue.task_done()
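
# Illustrative sketch (hypothetical helpers, not called by the server):
# process_message_queue() fans each queued message out to every connected client
# and drops a client as soon as a send fails. The same pattern in plain asyncio,
# with the made-up _ExampleClient standing in for a FastAPI WebSocket.
import asyncio


class _ExampleClient:
    """Hypothetical stand-in for a WebSocket: records messages, or fails on send."""

    def __init__(self, fail=False):
        self.fail = fail
        self.received = []

    async def send_json(self, message):
        if self.fail:
            raise ConnectionError("client went away")
        self.received.append(message)


async def _example_broadcast(messages, clients):
    """Deliver every message to every client, removing clients whose send raises."""
    pending = asyncio.Queue()
    for m in messages:
        pending.put_nowait(m)
    while not pending.empty():
        message = await pending.get()
        for client in clients[:]:  # iterate over a copy so removal is safe
            try:
                await client.send_json(message)
            except Exception:
                clients.remove(client)
        pending.task_done()
    return clients

# Example: with ok = _ExampleClient() and bad = _ExampleClient(fail=True),
# asyncio.run(_example_broadcast([{"a": 1}], [ok, bad])) returns [ok] and
# leaves ok.received == [{"a": 1}].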

def load_reference_segments(config_data: CompanionConfig):
    """Load up to three reference clips for voice cloning, as 24 kHz mono segments."""
    global reference_segments
    reference_segments = []

    # The primary reference is required; the second and third are optional
    references = [
        ("primary", config_data.reference_audio_path, config_data.reference_text),
        ("second", config_data.reference_audio_path2, config_data.reference_text2),
        ("third", config_data.reference_audio_path3, config_data.reference_text3),
    ]
    for label, path, text in references:
        if not path or not os.path.isfile(path):
            if label == "primary":
                logger.warning(f"Primary reference audio '{path}' not found.")
            continue
        logger.info(f"Loading {label} reference audio: {path}")
        wav, sr = torchaudio.load(path)
        # Average the channels so stereo files also yield a 1-D tensor
        wav = wav.mean(dim=0)
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=24_000)
        reference_segments.append(Segment(text=text,
                                          speaker=config_data.voice_speaker_id,
                                          audio=wav))

    logger.info(f"Loaded {len(reference_segments)} reference audio segments.")

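def transcribe_audio(audio_data, sample_rate):
    """Transcribe a mono clip with the Whisper pipeline, resampling to 16 kHz first."""
    audio_np = np.asarray(audio_data, dtype=np.float32)
    if sample_rate != 16000:
        try:
            audio_tensor = torch.tensor(audio_np).unsqueeze(0)
            audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=16000)
            audio_np = audio_tensor.squeeze(0).numpy()
        except Exception as e:
            # Fall back to the unresampled audio, but make the failure visible
            logger.warning(f"Resampling {sample_rate} Hz -> 16 kHz failed: {e}")
    try:
        with torch.jit.optimized_execution(False):
            result = whisper_pipe(audio_np, generate_kwargs={"language": "english"})
        return result["text"]
    except Exception as e:
        logger.error(f"Whisper transcription failed: {e}")
        return "[Transcription error]"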

def initialize_models(config_data: CompanionConfig):
    global generator, llm, rag, vad_processor, config
    config = config_data                         

    logger.info("Loading LLM …")
    llm = LLMInterface(config_data.llm_path,
                       config_data.max_tokens)

    logger.info("Loading RAG …")
    rag = RAGSystem("companion.db",
                    model_name=config_data.embedding_model)

    vad_model, vad_utils = torch.hub.load('snakers4/silero-vad',
                                          model='silero_vad',
                                          force_reload=False)
    vad_processor = AudioStreamProcessor(
        model=vad_model,
        utils=vad_utils,
        sample_rate=16_000,
        vad_threshold=config_data.vad_threshold,
        callbacks={"on_speech_start": on_speech_start,
                   "on_speech_end":   on_speech_end},
    )

    load_reference_segments(config_data)

    start_model_thread()

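    logger.info("Compiling / warming up voice model …")
    t0 = time.time()

    # Send a dummy request (at most 0.5 s of audio); the result is discarded
    model_queue.put((
        "warm-up.",                          # text
        config_data.voice_speaker_id,        # speaker
        [],                                  # no context
        500,                                 # max_ms
        0.7,                                 # temperature
        40,                                  # top-k
    ))

    # Block until the worker signals end of stream (None). On failure the worker
    # sends an Exception object and no None marker, so stop on that as well.
    while True:
        r = model_result_queue.get()
        if r is None:
            break
        if isinstance(r, Exception):
            logger.error(f"Voice model warm-up failed: {r}")
            break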

    logger.info(f"Voice model ready in {time.time() - t0:.1f}s")
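
# Illustrative sketch (hypothetical helper, not called by the server):
# the warm-up above and audio_generation_thread() read model_result_queue using
# the protocol model_worker() writes: audio chunks, then None as end-of-stream,
# or an Exception object if generation failed. This consumer honours both
# outcomes and, as an added assumption, gives up after a deadline instead of
# blocking forever.
import queue
import time


def _example_drain_until_eos(result_queue, timeout_s=30.0):
    """Collect items until None; re-raise worker exceptions; raise TimeoutError at the deadline."""
    chunks = []
    deadline = time.monotonic() + timeout_s
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError("no end-of-stream marker before the deadline")
        try:
            item = result_queue.get(timeout=remaining)
        except queue.Empty:
            continue  # the deadline check at the top of the loop fires next
        if item is None:
            return chunks
        if isinstance(item, Exception):
            raise item
        chunks.append(item)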


def on_speech_start():
    asyncio.run_coroutine_threadsafe(
        message_queue.put(
            {
                "type":   "vad_status",
                "status": "speech_started",
                "should_interrupt": False,  # always False – UI never barges-in here
            }
        ),
        loop,
    )

def on_speech_end(audio_data, sample_rate):
    try:
        logger.info("Transcription starting")
        user_text = transcribe_audio(audio_data, sample_rate)
        logger.info(f"Transcription completed: '{user_text}'")

        session_id = "default"
        speaker_id = 1
        index = speaker_counters[speaker_id]
        user_audio_path = f"audio/user/{session_id}_user_{index}.wav"
        os.makedirs(os.path.dirname(user_audio_path), exist_ok=True)

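        audio_tensor = torch.tensor(audio_data, dtype=torch.float32)
        save_audio_and_trim(user_audio_path, session_id, speaker_id, audio_tensor, sample_rate)

        # The VAD is configured for 16 kHz in initialize_models(), while the voice
        # model's context clips are 24 kHz (see load_reference_segments), so
        # resample before adding the user turn to the context
        context_audio = audio_tensor
        if sample_rate != 24_000:
            context_audio = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=24_000)
        add_segment(user_text, speaker_id, context_audio)

        logger.info(f"User audio saved and segment appended: {user_audio_path}")

        speaker_counters[speaker_id] += 1

        # Send transcription to clients
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "transcription", "text": user_text}),
            loop
        )

        threading.Thread(target=process_user_input, args=(user_text, session_id), daemon=True).start()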
    except Exception as e:
        logger.error(f"VAD callback failed: {e}")
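
# Illustrative sketch (hypothetical helper, not called by the server):
# on_speech_start() and on_speech_end() are written to be called from outside the
# event loop, so they hand messages over with asyncio.run_coroutine_threadsafe().
# This self-contained version runs a private loop in a background thread and lets
# a plain worker thread put items onto an asyncio.Queue owned by that loop.
import asyncio
import threading


def _example_thread_to_loop_handoff(messages):
    """Post `messages` from a worker thread into an asyncio.Queue; return what the loop received."""
    private_loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target=private_loop.run_forever, daemon=True)
    loop_thread.start()

    async def make_queue():
        return asyncio.Queue()

    inbox = asyncio.run_coroutine_threadsafe(make_queue(), private_loop).result()

    def producer():
        for m in messages:
            # .result() waits until the put has actually run on the loop thread
            asyncio.run_coroutine_threadsafe(inbox.put(m), private_loop).result()

    worker = threading.Thread(target=producer)
    worker.start()
    worker.join()

    async def drain():
        items = []
        while not inbox.empty():
            items.append(inbox.get_nowait())
        return items

    received = asyncio.run_coroutine_threadsafe(drain(), private_loop).result()
    private_loop.call_soon_threadsafe(private_loop.stop)
    loop_thread.join()
    private_loop.close()
    return received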

def process_pending_inputs():
    """Process only the latest user input after an interruption"""
    global pending_user_inputs, is_speaking, interrupt_flag
    time.sleep(0.2)
    is_speaking = False
    interrupt_flag.clear()
    
    with user_input_lock:
        if not pending_user_inputs:
            logger.info("No pending user inputs to process")
            return

        # Take only the most recent input and discard the rest
        user_text, session_id = pending_user_inputs[-1]
        pending_user_inputs = []

    logger.info(f"Processing only latest input: '{user_text}'")
    # Run the slow LLM and TTS pipeline outside the lock so other threads
    # can still queue input while this one is handled
    process_user_input(user_text, session_id)
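
# Illustrative sketch (hypothetical class, not used by the server):
# process_user_input() and process_pending_inputs() keep at most one pending
# utterance. A newer input replaces an older one, and the consumer takes it
# atomically, then does the slow work outside the lock. That "latest wins" slot
# in isolation:
import threading


class _ExampleLatestInputSlot:
    def __init__(self):
        self._lock = threading.Lock()
        self._item = None

    def offer(self, item):
        """Store `item`, discarding anything not yet consumed."""
        with self._lock:
            self._item = item

    def take(self):
        """Return and clear the stored item (or None), holding the lock only briefly."""
        with self._lock:
            item, self._item = self._item, None
        return item

# Example: after slot.offer(("hello", "default")) and slot.offer(("hi again", "default")),
# slot.take() returns ("hi again", "default") and a second take() returns None.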

def process_user_input(user_text, session_id="default"):
    global config, is_speaking, pending_user_inputs, interrupt_flag
    
    # Skip empty messages
    if not user_text or user_text.strip() == "":
        logger.warning("Empty user input received, ignoring")
        return
    
    # Check if we're currently supposed to be speaking
    if is_speaking:
        logger.info(f"AI is currently speaking, adding input to pending queue: '{user_text}'")
        
        with user_input_lock:
            # Only keep the most recent input, replacing any existing ones
            pending_user_inputs = [(user_text, session_id)]
            logger.info(f"Added user input as the only pending input: '{user_text}'")
        
        # Request interruption if not already interrupted
        if not interrupt_flag.is_set():
            logger.info("Automatically interrupting current speech for new input")
            interrupt_flag.set()
            # Notify clients of interruption
            asyncio.run_coroutine_threadsafe(
                message_queue.put({"type": "audio_status", "status": "interrupted"}),
                loop
            )
            
            # Allow a short delay before processing the new input
            time.sleep(0.3)
            
            # Process the pending input after interruption
            process_pending_inputs()
        
        return
    
    interrupt_flag.clear()
    
    # Normal processing continues...
    logger.info(f"Processing user input: '{user_text}'")
    context = "\n".join([f"User: {msg['user']}\nAI: {msg['ai']}" for msg in conversation_history[-5:]])
    rag_context = rag.query(user_text)
    system_prompt = config.system_prompt
    if rag_context:
        system_prompt += f"\n\nRelevant context:\n{rag_context}"

    # Notify clients that we're thinking
    asyncio.run_coroutine_threadsafe(
        message_queue.put({"type": "status", "message": "Thinking..."}),
        loop
    )
    
    try:
        with llm_lock: 
            ai_response = llm.generate_response(system_prompt, user_text, context)
        
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        conversation_history.append({
            "timestamp": timestamp,
            "user": user_text,
            "ai": ai_response
        })
        
        # Reserve the AI audio path before touching the database, so that
        # output_file is defined and synthesis still runs if the DB write fails
        index = speaker_counters[0]
        output_file = f"audio/ai/{session_id}_response_{index}.wav"
        speaker_counters[0] += 1

        db = SessionLocal()
        try:
            db.add(Conversation(
                session_id=session_id,
                timestamp=timestamp,
                user_message=user_text,
                ai_message=ai_response,
                audio_path=output_file,
            ))
            db.commit()
        except Exception as e:
            db.rollback()
            logger.error(f"Database error: {e}")
        finally:
            db.close()
        
        threading.Thread(target=lambda: rag.add_conversation(user_text, ai_response), daemon=True).start()
        
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "audio_status", "status": "preparing"}),
            loop
        )
        
        # Small delay to ensure client is ready
        time.sleep(0.2)
        
        # Send the response to clients
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "response", "text": ai_response}),
            loop
        )

        time.sleep(0.5)
        
        if is_speaking:
            logger.warning("Still speaking when trying to start new audio - forcing interrupt")
            interrupt_flag.set()
            is_speaking = False
            time.sleep(0.5)  # Give time for cleanup
        
        interrupt_flag.clear()  # Make absolutely sure
        is_speaking = False    # Reset for audio thread to take over
        
        # Start audio generation in a new thread
        threading.Thread(target=audio_generation_thread, args=(ai_response, output_file), daemon=True).start()

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "error", "message": "Failed to generate response"}),
            loop
        )
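
# Illustrative sketch (hypothetical helpers, not called by the server):
# process_user_input() gives the LLM the last five exchanges as plain text
# ("User: ...\nAI: ...") and appends any RAG hits to the system prompt.
# The same string building, pulled out into pure functions:

def _example_build_context(history, max_turns=5):
    """Join the most recent `max_turns` exchanges into the User/AI transcript format."""
    return "\n".join(f"User: {m['user']}\nAI: {m['ai']}" for m in history[-max_turns:])


def _example_build_system_prompt(base_prompt, rag_context):
    """Append retrieved context only when the RAG lookup returned something."""
    if rag_context:
        return f"{base_prompt}\n\nRelevant context:\n{rag_context}"
    return base_prompt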

def model_worker(cfg: CompanionConfig):
    global generator, model_thread_running

    logger.info("Model worker thread started")

    if generator is None:
        torch._inductor.config.triton.cudagraphs = False  # Disable cudagraphs
        torch._inductor.config.fx_graph_cache = False  # Disable graph caching
        logger.info("Loading voice model inside worker thread …")
        generator = load_csm_1b_local(cfg.model_path, "cuda")
        logger.info("Voice model ready (cudagraphs disabled)")

    while model_thread_running.is_set():
        try:
            request = model_queue.get(timeout=0.1)
            if request is None:
                break

            text, speaker_id, context, max_ms, temperature, topk = request

            for chunk in generator.generate_stream(
                    text=text,
                    speaker=speaker_id,
                    context=context,
                    max_audio_length_ms=max_ms,
                    temperature=temperature,
                    topk=topk):
                model_result_queue.put(chunk)

                if not model_thread_running.is_set():
                    break

            model_result_queue.put(None) # EOS marker

        except queue.Empty:
            continue
        except Exception as e:
            import traceback
            logger.error(f"Error in model worker: {e}\n{traceback.format_exc()}")
            model_result_queue.put(Exception(f"Generation error: {e}"))

    logger.info("Model worker thread exiting")
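
# Illustrative sketch (hypothetical alternative, not what the server does):
# model_worker() stops streaming when model_thread_running is cleared, and the
# interrupt path in audio_generation_thread() clears it, sleeps 0.1 s and sets it
# again, so a worker busy producing a chunk may never observe the cleared flag.
# One common alternative is a cancel Event scoped to each request. Here a plain
# list of chunks stands in for generator.generate_stream().
import queue
import threading


def _example_stream_worker(requests, results):
    """Hypothetical worker: each request is (chunks, cancel_event); emits chunks, then None."""
    while True:
        request = requests.get()
        if request is None:  # shutdown sentinel
            break
        chunks, cancel = request
        for chunk in chunks:
            if cancel.is_set():
                break
            results.put(chunk)
        results.put(None)  # end-of-stream marker, sent even when cancelled

# Because each cancel flag belongs to one request, the interrupting side can set it
# and wait for the None marker without affecting later requests.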

def start_model_thread():
    global model_thread, model_thread_running

    if model_thread is not None and model_thread.is_alive():
        return                        

    model_thread_running.set()
    model_thread = threading.Thread(target=model_worker,
                                    args=(config,),
                                    daemon=True,
                                    name="model_worker")
    model_thread.start()
    logger.info("Started dedicated model worker thread")

async def run_audio_generation(text, output_file):
    """Run the blocking audio_generation_thread() in a worker thread so the event loop is not stalled."""
    await asyncio.to_thread(audio_generation_thread, text, output_file)
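
# Illustrative sketch (hypothetical helpers, not called by the server):
# a blocking function called directly inside a coroutine stalls the whole event
# loop. asyncio.to_thread() (Python 3.9+) runs it in the default executor, so
# other tasks keep running meanwhile. Here _example_blocking_job stands in for
# audio_generation_thread(); the ticker finishes first because the loop stays free.
import asyncio
import time


def _example_blocking_job(seconds):
    time.sleep(seconds)
    return "done"


async def _example_offload_demo():
    events = []

    async def job():
        result = await asyncio.to_thread(_example_blocking_job, 0.3)
        events.append("job-done")
        return result

    async def ticker():
        for _ in range(5):
            await asyncio.sleep(0.01)
        events.append("ticks-done")

    result, _ = await asyncio.gather(job(), ticker())
    return result, events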

def send_to_all_clients(message: dict):
    """Send a message to all connected WebSocket clients"""
    for client in active_connections[:]:
        try:
            if client.client_state == WebSocketState.CONNECTED:
                asyncio.run_coroutine_threadsafe(client.send_json(message), loop)
                logger.info(f"Sent message to client: {message}")
            else:
                logger.warning("Detected non-connected client; removing from active_connections")
                active_connections.remove(client)
        except Exception as e:
            logger.error(f"Error sending message to client: {e}")
            if client in active_connections:
                active_connections.remove(client)

saved_audio_paths = {
    "default": {
        0: [],  # AI
        1: []   # User
    }
}
MAX_AUDIO_FILES = 8

def save_audio_and_trim(path, session_id, speaker_id, tensor, sample_rate):
    """
    Save audio file and trim old audio files for both AI and user to maintain storage limits.
    
    Args:
        path: Path to save the audio file
        session_id: Conversation session ID
        speaker_id: 0 for AI, 1 for user
        tensor: Audio tensor to save
        sample_rate: Audio sample rate
    """
    # torchaudio.save expects a CPU tensor; generated audio may still be on the GPU
    torchaudio.save(path, tensor.detach().cpu().unsqueeze(0), sample_rate)
    
    saved_audio_paths.setdefault(session_id, {}).setdefault(speaker_id, []).append(path)
    
    paths = saved_audio_paths[session_id][speaker_id]
    while len(paths) > MAX_AUDIO_FILES:
        old_path = paths.pop(0)
        if os.path.exists(old_path):
            os.remove(old_path)
            logger.info(f"Removed old audio file: {old_path}")
    
    other_speaker_id = 1 if speaker_id == 0 else 0
    if other_speaker_id in saved_audio_paths[session_id]:
        other_paths = saved_audio_paths[session_id][other_speaker_id]
        while len(other_paths) > MAX_AUDIO_FILES:
            old_path = other_paths.pop(0)
            if os.path.exists(old_path):
                os.remove(old_path)
                logger.info(f"Removed old audio file from other speaker: {old_path}")
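
# Illustrative sketch (hypothetical helper, not called by the server):
# save_audio_and_trim() keeps only the newest MAX_AUDIO_FILES recordings per
# speaker by deleting the oldest path once the list grows past the limit.
# The same bounded rotation with a collections.deque, whose popleft() is O(1)
# where list.pop(0) is O(n); with a limit of 8 either choice is fine.
import os
from collections import deque


def _example_record_and_rotate(paths, new_path, max_files):
    """Append `new_path` to the deque `paths`; delete and return files evicted beyond `max_files`."""
    paths.append(new_path)
    removed = []
    while len(paths) > max_files:
        old = paths.popleft()
        if os.path.exists(old):
            os.remove(old)
        removed.append(old)
    return removed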

MAX_SEGMENTS = 8

def add_segment(text, speaker_id, audio_tensor):
    """
    Append a new segment to the generation context and trim the context to fit.

    The leading segments (the voice-reference clips that were actually loaded) are
    protected and never removed. The conversational segments after them are trimmed
    oldest-first: first so the total count stays within MAX_SEGMENTS, then so the
    estimated token count stays within budget (4096 with the text tokenizer, 2048
    with the word-based fallback).

    Args:
        text: Text content of the segment
        speaker_id: ID of the speaker (0 for AI, 1 for user)
        audio_tensor: Audio data as a tensor
    """
    global reference_segments, generator, config

    # Count the protected reference segments, using the same checks as load_reference_segments()
    num_protected_segments = 0
    for ref_path in (config.reference_audio_path, config.reference_audio_path2, config.reference_audio_path3):
        if ref_path and os.path.isfile(ref_path):
            num_protected_segments += 1

    # Separate protected from dynamic segments from the current global state
    protected_segments = reference_segments[:num_protected_segments]
    dynamic_segments = reference_segments[num_protected_segments:]

    # Add the new segment to the dynamic list
    new_segment = Segment(text=text, speaker=speaker_id, audio=audio_tensor)
    dynamic_segments.append(new_segment)

    # First, trim by MAX_SEGMENTS count. The oldest dynamic segments are removed.
    max_dynamic_allowed = MAX_SEGMENTS - len(protected_segments)
    if len(dynamic_segments) > max_dynamic_allowed:
        # Keep only the most recent dynamic segments
        dynamic_segments = dynamic_segments[-max_dynamic_allowed:]

    # Then, check and trim by token count if necessary.
    # This loop will trim the oldest dynamic segments until the token count is acceptable.
    if hasattr(generator, '_text_tokenizer'):
        while dynamic_segments:
            # Tentatively combine for token calculation
            temp_full_list = protected_segments + dynamic_segments
            total_tokens = 0

            # Calculate total tokens for the current combination
            for segment in temp_full_list:
                tokens = generator._text_tokenizer.encode(f"[{segment.speaker}]{segment.text}")
                total_tokens += len(tokens)
                if segment.audio is not None:
                    # Approximate frame count to token conversion
                    audio_frames = segment.audio.size(0) // 6094
                    total_tokens += audio_frames

            # If we are within limits, the trimming is done.
            if total_tokens <= 4096:
                break

            # Otherwise, remove the oldest dynamic segment and re-check in the next loop iteration.
            dynamic_segments.pop(0)

    else:
        # Fallback if tokenizer is not available
        logger.warning("Unable to access tokenizer - falling back to word-based estimation for context trimming")

        def estimate_tokens(segment):
            words = segment.text.split()
            punctuation = sum(1 for char in segment.text if char in ".,!?;:\"'()[]{}")
            text_tokens = len(words) + punctuation
            audio_tokens = 0
            if segment.audio is not None:
                audio_frames = segment.audio.size(0) // 6094
                audio_tokens = audio_frames
            return text_tokens + audio_tokens

        while dynamic_segments:
            total_estimated_tokens = sum(estimate_tokens(s) for s in protected_segments) + \
                                     sum(estimate_tokens(s) for s in dynamic_segments)
            if total_estimated_tokens <= 2048:
                break
            dynamic_segments.pop(0)

    # Rebuild the global context: protected references first, then the trimmed conversation
    reference_segments = protected_segments + dynamic_segments

    # Log the final state for debugging
    logger.info(f"Context updated. Segments: {len(reference_segments)} total " +
                f"({len(protected_segments)} protected, {len(dynamic_segments)} dynamic).")
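
# Illustrative sketch (hypothetical helper, not called by the server):
# add_segment() keeps the voice-reference clips as a fixed prefix and trims only
# the conversational segments, oldest first: first to a maximum count, then until
# an estimated cost fits the budget. The policy in isolation, with the cost
# function passed in (the server uses the text tokenizer plus an audio-length term).

def _example_trim_context(protected, dynamic, max_segments, budget, cost):
    """Return protected + the newest dynamic items that satisfy both limits."""
    dynamic = list(dynamic)
    max_dynamic = max(max_segments - len(protected), 0)
    if len(dynamic) > max_dynamic:
        # An explicit start index avoids the lst[-0:] pitfall, which keeps the whole list
        dynamic = dynamic[len(dynamic) - max_dynamic:]
    while dynamic and sum(cost(s) for s in protected + dynamic) > budget:
        dynamic.pop(0)  # drop the oldest conversational segment
    return protected + dynamic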

def preprocess_text_for_tts(text):
    """
    Remove punctuation from the input text to create cleaner speech output, keeping
    only periods, commas, exclamation points, question marks, and apostrophes (so
    contractions such as "don't" survive). The kept marks preserve intonation.

    Args:
        text (str): Input text with potential punctuation
    Returns:
        str: Cleaned text with only the allowed punctuation
    """
    # Match any character that is not a word character (letters, digits, underscore),
    # whitespace, or one of . , ! ? '
    # This removes e.g. ; : " ~ @ # $ % ^ & * ( ) - + = [ ] { } \ | / < >
    pattern = r'[^\w\s.,!?\']'
    # Delete every matched character
    cleaned_text = re.sub(pattern, '', text)
    # normalize multiple spaces to single space
    cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
    # ensure there's a space after punctuation for better speech pacing
    cleaned_text = re.sub(r'([.,!?])(\S)', r'\1 \2', cleaned_text)
    return cleaned_text.strip()
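
# Illustrative sketch (hypothetical helper, not called by the server):
# a quick check of which ASCII punctuation survives the character class used in
# preprocess_text_for_tts(). Anything the pattern does NOT match is kept, so
# apostrophes and underscores (underscore is part of \w) survive alongside . , ! ?
import re
import string


def _example_kept_punctuation(pattern=r"[^\w\s.,!?\']"):
    """Return the characters of string.punctuation that the removal pattern leaves alone."""
    return "".join(c for c in string.punctuation if not re.fullmatch(pattern, c))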

def audio_generation_thread(text, output_file):
    global is_speaking, interrupt_flag, audio_queue, model_thread_running, current_generation_id, speaking_start_time
    
    current_generation_id += 1
    this_id = current_generation_id
    
    interrupt_flag.clear()
    
    # Log the start of generation
    logger.info(f"Starting audio generation for ID: {this_id}")
    
    # Try to acquire the lock, but don't block if it's busy
    if not audio_gen_lock.acquire(blocking=False):
        logger.warning(f"Audio generation {this_id} - lock acquisition failed, another generation is in progress")
        asyncio.run_coroutine_threadsafe(
            message_queue.put({
                "type": "error", 
                "message": "Audio generation busy, skipping synthesis",
                "gen_id": this_id
            }),
            loop
        )
        return
    
    try:
        # Start the model thread if it's not already running
        start_model_thread()
        
        interrupt_flag.clear()
        is_speaking = True
        speaking_start_time = time.time()
        
        # Create output directory
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        all_audio_chunks = []
        
        # Prepare text
        text_lower = text.lower()
        text_lower = preprocess_text_for_tts(text_lower)
        
        asyncio.run_coroutine_threadsafe(
            message_queue.put({
                "type": "audio_status", 
                "status": "preparing_generation",
                "gen_id": this_id
            }),
            loop
        )
        
        # Give client a moment to process
        time.sleep(0.2)
        
        logger.info(f"Sending generating status with ID {this_id}")
        asyncio.run_coroutine_threadsafe(
            message_queue.put({
                "type": "audio_status", 
                "status": "generating",
                "gen_id": this_id  # Include generation ID
            }),
            loop
        )
        
        # Small delay to ensure client gets the signal
        time.sleep(0.2)
        
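        # Cap generation length at the time needed to speak the text at a deliberately
        # slow 100 words per minute, leaving headroom for normal-paced speech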
        words = text.split()
        avg_wpm = 100
        words_per_second = avg_wpm / 60
        estimated_seconds = len(words) / words_per_second
        max_audio_length_ms = int(estimated_seconds * 1000)
        
        # Send request to model thread
        logger.info(f"Audio generation {this_id} - sending request to model thread")
        model_queue.put((
            text_lower,
            config.voice_speaker_id,
            reference_segments,
            max_audio_length_ms,
            0.8,  # temperature
            50    # topk
        ))
        
        # Start timing
        generation_start = time.time()
        chunk_counter = 0
        
        # Process results as they come
        while True:
            try:
                # Check for interruption FIRST before getting more results
                if interrupt_flag.is_set():
                    logger.info(f"Audio generation {this_id} - interrupt detected, stopping")
                    
                    # Signal model thread to exit and restart
                    model_thread_running.clear()
                    time.sleep(0.1)
                    model_thread_running.set()
                    start_model_thread()
                    
                    # Clear any remaining items in the result queue
                    while not model_result_queue.empty():
                        try:
                            model_result_queue.get_nowait()
                        except queue.Empty:
                            pass
                    
                    # Break out of the processing loop
                    break
                
                # Get result with timeout to allow checking interrupt
                result = model_result_queue.get(timeout=0.1)
                
                # Check for end of generation or error
                if result is None:
                    logger.info(f"Audio generation {this_id} - complete")
                    break
                    
                if isinstance(result, Exception):
                    logger.error(f"Audio generation {this_id} - error: {result}")
                    raise result
                
                # Track timing for first chunk
                if chunk_counter == 0:
                    first_chunk_time = time.time() - generation_start
                    logger.info(f"Audio generation {this_id} - first chunk latency: {first_chunk_time*1000:.1f}ms")
                
                chunk_counter += 1
                
                # One more interrupt check before processing chunk
                if interrupt_flag.is_set():
                    logger.info(f"Audio generation {this_id} - interrupt flag set during chunk processing")
                    break
                
                # Process this audio chunk
                audio_chunk = result
                all_audio_chunks.append(audio_chunk)
                
                # Convert to numpy and send to audio queue
                chunk_array = audio_chunk.cpu().numpy().astype(np.float32)
                audio_queue.put(chunk_array)
                
                if chunk_counter == 1:
                    logger.info(f"Sending first audio chunk with ID {this_id}")
                    # Notify client we're sending the first chunk
                    asyncio.run_coroutine_threadsafe(
                        message_queue.put({
                            "type": "audio_status", 
                            "status": "first_chunk",
                            "gen_id": this_id
                        }),
                        loop
                    )
                    # Small delay
                    time.sleep(0.1)
                
                # Send chunk with generation ID
                asyncio.run_coroutine_threadsafe(
                    message_queue.put({
                        "type": "audio_chunk",
                        "audio": chunk_array.tolist(),
                        "sample_rate": generator.sample_rate,
                        "gen_id": this_id,
                        "chunk_num": chunk_counter  # Include chunk number
                    }),
                    loop
                )
                
            except queue.Empty:
                # No results yet, keep checking
                continue
            except Exception as e:
                logger.error(f"Audio generation {this_id} - error processing result: {e}")
                break
        
        # Save complete audio if available
        if all_audio_chunks and not interrupt_flag.is_set():
            try:
                complete_audio = torch.cat(all_audio_chunks)
                save_audio_and_trim(output_file, "default", config.voice_speaker_id, complete_audio, generator.sample_rate)
                add_segment(text.lower(), config.voice_speaker_id, complete_audio)
                
                # Log statistics
                total_time = time.time() - generation_start
                total_audio_seconds = complete_audio.size(0) / generator.sample_rate
                rtf = total_time / total_audio_seconds
                logger.info(f"Audio generation {this_id} - completed in {total_time:.2f}s, RTF: {rtf:.2f}x")
            except Exception as e:
                logger.error(f"Audio generation {this_id} - error saving complete audio: {e}")
                
    except Exception as e:
        import traceback
        logger.error(f"Audio generation {this_id} - unexpected error: {e}\n{traceback.format_exc()}")
    finally:
        is_speaking = False
        
        # Signal end of audio
        audio_queue.put(None)
        
        try:
            logger.info(f"Audio generation {this_id} - sending completion status")
            asyncio.run_coroutine_threadsafe(
                message_queue.put({
                    "type": "audio_status", 
                    "status": "complete",
                    "gen_id": this_id
                }),
                loop
            )
        except Exception as e:
            logger.error(f"Audio generation {this_id} - failed to send completion status: {e}")
            
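        # Release the generation lock before handling queued input, so that the
        # follow-up audio_generation_thread() can acquire it
        logger.info(f"Audio generation {this_id} - releasing lock")
        audio_gen_lock.release()

        # Only peek under user_input_lock: process_pending_inputs() acquires it
        # itself, and threading.Lock is not re-entrant
        with user_input_lock:
            has_pending = bool(pending_user_inputs)
        if has_pending:
            logger.info(f"Audio generation {this_id} - processing pending inputs")
            process_pending_inputs()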
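
# Illustrative sketch (hypothetical helper, not called by the server):
# audio_generation_thread() caps generation at the time needed to speak the reply
# at 100 words per minute. The arithmetic on its own; a higher words_per_minute
# tightens the cap, a lower one leaves more headroom.

def _example_max_audio_length_ms(text, words_per_minute=100):
    """Milliseconds needed to speak `text` at `words_per_minute`."""
    words_per_second = words_per_minute / 60
    return int(len(text.split()) / words_per_second * 1000)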
    
def handle_interrupt(websocket):
    global is_speaking, last_interrupt_time, interrupt_flag, model_thread_running, speaking_start_time
    
    # Log the current state
    logger.info(f"Interrupt requested. Current state: is_speaking={is_speaking}")
    
    current_time = time.time()
    time_since_speech_start = current_time - speaking_start_time if speaking_start_time > 0 else 999
    time_since_last_interrupt = current_time - last_interrupt_time
    
    # Only apply cooldown for established speech, not for new speech
    if time_since_last_interrupt < interrupt_cooldown and time_since_speech_start > 3.0:
        logger.info(f"Ignoring interrupt: too soon after previous interrupt ({time_since_last_interrupt:.1f}s < {interrupt_cooldown}s)")
        # Let the client know we're not interrupting
        asyncio.run_coroutine_threadsafe(
            websocket.send_json({
                "type": "audio_status",
                "status": "interrupt_acknowledged",
                "success": False,
                "reason": "cooldown"
            }),
            loop
        )
        return False
    
    # Update the last interrupt time
    last_interrupt_time = current_time
    
    # We should interrupt if we're speaking OR if model generation is in progress
    if is_speaking or not model_result_queue.empty():
        logger.info("Interruption processing: we are speaking or generating")
        
        interrupt_flag.set()
        
        # Notify clients
        asyncio.run_coroutine_threadsafe(
            message_queue.put({"type": "audio_status", "status": "interrupted"}),
            loop
        )
        
        asyncio.run_coroutine_threadsafe(
            websocket.send_json({
                "type": "audio_status",
                "status": "interrupt_acknowledged"
            }),
            loop
        )
        
        # Clear the audio queue to stop additional audio from being processed
        try:
            # Drain the existing queue
            while not audio_queue.empty():
                try:
                    audio_queue.get_nowait()
                except queue.Empty:
                    break
                    
            # Add end signal
            audio_queue.put(None)
            logger.info("Audio queue cleared")
        except Exception as e:
            logger.error(f"Error clearing audio queue: {e}")
        
        # Reset VAD to prepare for new input
        if vad_processor:
            try:
                vad_processor.reset()
                logger.info("VAD processor reset")
            except Exception as e:
                logger.error(f"Error resetting VAD: {e}")
        
        # Stop current model worker if needed
        if model_thread and model_thread.is_alive():
            try:
                # Clear the thread running flag to stop generation
                model_thread_running.clear()
                
                # Wait a brief moment for thread to notice and exit
                time.sleep(0.1)
                
                # Now restart the thread state flag
                model_thread_running.set()
                
                # And restart the thread
                start_model_thread()
                logger.info("Model thread restarted")
            except Exception as e:
                logger.error(f"Error restarting model thread: {e}")
        
        return True
    
    logger.info("No active speech to interrupt")
    return False
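

# Illustrative sketch (hypothetical helpers, not used by handle_interrupt itself) of two
# steps performed above, written as pure functions so they can be tested in isolation.
#
# should_accept_interrupt() restates the cooldown rule: an interrupt is ignored only when
# it arrives within `cooldown` seconds of the previous one AND speech has been running for
# more than `grace` seconds (3.0 in handle_interrupt). A speaking_start of 0 means "not
# speaking", which handle_interrupt treats as a very long time since speech started.
#
# drain_and_terminate() mirrors the audio-queue clearing: empty the queue without
# blocking, then push a None sentinel so the consumer knows the stream has ended.
import queue


def should_accept_interrupt(now, last_interrupt, speaking_start, cooldown, grace=3.0):
    since_speech = now - speaking_start if speaking_start > 0 else float("inf")
    since_last = now - last_interrupt
    return not (since_last < cooldown and since_speech > grace)


def drain_and_terminate(q):
    """Discard everything currently in q, enqueue a None end marker, return the drop count."""
    dropped = 0
    while True:
        try:
            q.get_nowait()
            dropped += 1
        except queue.Empty:
            break
    q.put(None)
    return dropped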

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    global is_speaking, audio_queue
    
    await websocket.accept()
    active_connections.append(websocket)
    
    saved = config_manager.load_config()
    if saved:
        await websocket.send_json({"type": "saved_config", "config": saved})
        
    try:
        while True:
            data = await websocket.receive_json()
            
            if data["type"] == "config":
                # Config handling
                try:
                    config_data = data["config"]
                    
                    logger.info(f"Received config data keys: {config_data.keys()}")

                    for key in ["reference_audio_path", "reference_audio_path2", "reference_audio_path3",
                               "reference_text", "reference_text2", "reference_text3"]:
                        if key in config_data:
                            logger.info(f"Config includes {key}: {config_data[key]}")
                        else:
                            logger.warning(f"Config missing {key}")
                    
                    conf = CompanionConfig(**config_data)
                    
                    saved = config_manager.save_config(config_data)
                    
                    if saved:
                        initialize_models(conf)
                        await websocket.send_json({"type": "status", "message": "Models initialized and configuration saved"})
                    else:
                        await websocket.send_json({"type": "error", "message": "Failed to save configuration"})
                        
                except Exception as e:
                    logger.error(f"Error processing config: {str(e)}")
                    await websocket.send_json({"type": "error", "message": f"Configuration error: {str(e)}"})
                
                
            elif data["type"] == "request_saved_config":
                saved = config_manager.load_config()
                await websocket.send_json({"type": "saved_config", "config": saved})
            
            elif data["type"] == "text_message":
                user_text   = data["text"]
                session_id  = data.get("session_id", "default")
                logger.info(f"TEXT-MSG from client: {user_text!r}")

                # If the model is already talking, queue the request (at most 3 pending) and answer later
                if is_speaking:
                    with user_input_lock:
                        if len(pending_user_inputs) >= 3:
                            # Trim in place: rebinding the name here would make it a local
                            # variable and raise UnboundLocalError on every earlier read
                            del pending_user_inputs[:-2]
                        pending_user_inputs.append((user_text, session_id))
                    await websocket.send_json(
                        {"type":"status","message":"Queued – I’ll answer in a moment"})
                    continue

                await message_queue.put({"type":"transcription","text":user_text})
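                # Pass arguments explicitly; a lambda would read user_text only when the
                # thread runs, after the next message may have rebound it
                threading.Thread(
                    target=process_user_input,
                    args=(user_text, session_id),
                    daemon=True).start()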
                
            elif data["type"] == "audio":
                audio_data = np.asarray(data["audio"], dtype=np.float32)
                sample_rate = data["sample_rate"]

                if sample_rate != 16000:
                    audio_tensor = torch.tensor(audio_data).unsqueeze(0)
                    audio_tensor = torchaudio.functional.resample(
                        audio_tensor, orig_freq=sample_rate, new_freq=16000
                    )
                    audio_data  = audio_tensor.squeeze(0).numpy()
                    sample_rate = 16000

                if config and config.vad_enabled:
                    vad_processor.process_audio(audio_data)  
                else:
                    text = transcribe_audio(audio_data, sample_rate)
                    await websocket.send_json({"type": "transcription", "text": text})
                    await message_queue.put({"type": "transcription", "text": text})

                    if is_speaking:
                        with user_input_lock:
                            pending_user_inputs.append((text, "default"))
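                    else:
                        # Run in a worker thread so the event loop keeps serving this socket
                        threading.Thread(
                            target=process_user_input, args=(text,), daemon=True).start()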

                        
            elif data["type"] == "interrupt":
                logger.info("Explicit interrupt request received")
                
                # Always acknowledge receipt of interrupt request
                await websocket.send_json({
                    "type": "audio_status", 
                    "status": "interrupt_acknowledged"
                })
                
                # Then try to handle the actual interrupt
                success = handle_interrupt(websocket)
                
                # If successful, allow a brief delay for clearing everything
                if success:
                    await asyncio.sleep(0.3)  # Short delay to allow complete clearing
                    
                    # Force process pending inputs after interrupt
                    with user_input_lock:
                        if pending_user_inputs:
                            user_text, session_id = pending_user_inputs.pop(0)
                            pending_user_inputs.clear()  # Clear any backup to avoid multiple responses
                            
                            # Process in a new thread to avoid blocking
                            threading.Thread(
                                target=process_user_input,
                                args=(user_text, session_id),
                                daemon=True
                            ).start()
                
                # Send final status update about the interrupt
                await websocket.send_json({
                    "type": "audio_status", 
                    "status": "interrupted",
                    "success": success
                })
                
            elif data["type"] == "mute":
                await websocket.send_json({"type": "mute_status", "muted": data["muted"]})
                if not data["muted"] and config and config.vad_enabled:
                    vad_processor.reset()
                    
    except WebSocketDisconnect:
        if websocket in active_connections:
            active_connections.remove(websocket)
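

# Illustrative sketch (an alternative, not what websocket_endpoint does): the text_message
# branch keeps at most three queued inputs by trimming the list by hand. A
# collections.deque with maxlen gives the same "keep the newest N" behaviour, because
# appending to a full deque discards the oldest item. The lock is still needed for
# compound operations such as "take the first item, then clear the rest".
# The class name PendingInputs is hypothetical.
import threading
from collections import deque


class PendingInputs:
    def __init__(self, maxlen=3):
        self._items = deque(maxlen=maxlen)
        self._lock = threading.Lock()

    def add(self, text, session_id="default"):
        with self._lock:
            self._items.append((text, session_id))

    def pop_oldest_and_clear(self):
        """Return the oldest queued input (or None) and drop the rest, as the interrupt branch does."""
        with self._lock:
            first = self._items.popleft() if self._items else None
            self._items.clear()
            return first

    def __len__(self):
        with self._lock:
            return len(self._items)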

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

@app.get("/setup", response_class=HTMLResponse)
async def setup_page(request: Request):
    return templates.TemplateResponse("setup.html", {"request": request})

@app.get("/chat", response_class=HTMLResponse)
async def chat_page(request: Request):
    return templates.TemplateResponse("chat.html", {"request": request})
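

# Illustrative client-side sketch (hypothetical helper): the JSON messages that
# websocket_endpoint accepts on /ws, with the field names each branch above reads.
# "audio" carries float samples plus their sample rate; the server resamples anything
# other than 16000 Hz before VAD or transcription.
import json


def build_client_message(kind, **fields):
    required = {
        "config": {"config"},
        "request_saved_config": set(),
        "text_message": {"text"},  # "session_id" is optional; the server defaults it to "default"
        "audio": {"audio", "sample_rate"},
        "interrupt": set(),
        "mute": {"muted"},
    }
    if kind not in required:
        raise ValueError(f"unknown message type: {kind}")
    missing = required[kind] - set(fields)
    if missing:
        raise ValueError(f"{kind} message is missing {sorted(missing)}")
    return json.dumps({"type": kind, **fields})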

@app.on_event("startup")
async def startup_event():
    os.makedirs("static", exist_ok=True)
    os.makedirs("audio/user", exist_ok=True)
    os.makedirs("audio/ai", exist_ok=True)
    os.makedirs("embeddings_cache", exist_ok=True)
    os.makedirs("templates", exist_ok=True)
    with open("templates/index.html", "w") as f:
        f.write("""<meta http-equiv="refresh" content="0; url=/setup" />""")
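    try:
        # Load Silero VAD once at startup so the model is already in the torch.hub cache
        torch.hub.load('snakers4/silero-vad', model='silero_vad', force_reload=False)
    except Exception as e:
        logger.warning(f"Could not pre-load Silero VAD: {e}")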
    asyncio.create_task(process_message_queue())

@app.on_event("shutdown")
async def shutdown_event():
    logger.info("Server shutting down...")

from fastapi import Response

@app.get("/api/conversations")
async def get_conversations(request: Request):
    conn = sqlite3.connect("companion.db")
    cur = conn.cursor()
    cur.execute("SELECT id, user_message, ai_message FROM conversations ORDER BY id DESC")
    data = [{"id": row[0], "user_message": row[1], "ai_message": row[2]} for row in cur.fetchall()]
    conn.close()
    return JSONResponse(content=data)

@app.put("/api/conversations/{conv_id}")
async def update_conversation(conv_id: int, request: Request):
    # Expects a JSON body with "user_message" and "ai_message"
    data = await request.json()
    conn = sqlite3.connect("companion.db")
    cur = conn.cursor()
    cur.execute("UPDATE conversations SET user_message=?, ai_message=? WHERE id=?",
                (data["user_message"], data["ai_message"], conv_id))
    conn.commit()
    conn.close()
    return Response(status_code=204)

@app.delete("/api/conversations")
async def delete_all_conversations():
    try:
        conn = sqlite3.connect("companion.db")
        cur = conn.cursor()
        cur.execute("DELETE FROM conversations")
        conn.commit()
        conn.close()
        return {"status": "all deleted"}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)

@app.delete("/api/conversations/{conv_id}")
async def delete_conversation(conv_id: int):
    try:
        conn = sqlite3.connect("companion.db")
        cur = conn.cursor()
        cur.execute("DELETE FROM conversations WHERE id = ?", (conv_id,))
        conn.commit()
        conn.close()
        return JSONResponse(content={"status": "deleted", "id": conv_id})
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
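

# Illustrative sketch (hypothetical helper, not wired to any route): the CRUD handlers
# above open and close connections by hand, so an exception between connect() and close()
# leaks the connection. contextlib.closing() guarantees close(), and using the connection
# as a context manager commits on success and rolls back on error. Returning rowcount lets
# a caller answer 404 when no row matched conv_id.
import sqlite3
from contextlib import closing


def update_conversation_row(db_path, conv_id, user_message, ai_message):
    with closing(sqlite3.connect(db_path)) as conn:
        with conn:  # commit on success, rollback on exception
            cur = conn.execute(
                "UPDATE conversations SET user_message=?, ai_message=? WHERE id=?",
                (user_message, ai_message, conv_id),
            )
        return cur.rowcount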

app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

@app.get("/crud", response_class=HTMLResponse)
async def crud_ui(request: Request):
    return templates.TemplateResponse("crud.html", {"request": request})
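

# Illustrative sketch (hypothetical helpers): this module hands coroutines to the event
# loop `loop` from other threads via asyncio.run_coroutine_threadsafe(). The usual pattern
# is to run that loop's run_forever() in a daemon thread and stop it with
# call_soon_threadsafe(); run_forever() blocks, so it is the thread target itself.
import asyncio
import threading


def start_background_loop():
    bg_loop = asyncio.new_event_loop()
    thread = threading.Thread(target=bg_loop.run_forever, daemon=True)
    thread.start()
    return bg_loop, thread


def stop_background_loop(bg_loop, thread, timeout=5.0):
    bg_loop.call_soon_threadsafe(bg_loop.stop)
    thread.join(timeout)
    bg_loop.close()


async def _add(a, b):
    await asyncio.sleep(0)
    return a + b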

if __name__ == "__main__":
    import uvicorn
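    # run_forever() blocks and returns None, so it is the thread target itself;
    # wrapping it in asyncio.run() would fail once the loop stops
    threading.Thread(target=loop.run_forever, daemon=True).start()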
    uvicorn.run(app, host="0.0.0.0", port=8000)


================================================
FILE: models.py
================================================
import logging
from dataclasses import dataclass

import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2

logger = logging.getLogger(__name__)

def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )

def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=4,
        num_heads=8,
        num_kv_heads=2,
        embed_dim=1024,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )

FLAVORS = {
    "llama-1B": llama3_2_1B,
    "llama-100M": llama3_2_100M,
}
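
# Illustrative arithmetic (hypothetical helper): both flavors use grouped-query attention.
# Per-head width is embed_dim / num_heads, and each key/value head is shared by
# num_heads / num_kv_heads query heads:
#   llama-1B:   2048 / 32 = 64-dim heads,  32 / 8 = 4 query heads per KV head
#   llama-100M: 1024 / 8  = 128-dim heads,  8 / 2 = 4 query heads per KV head
def attention_geometry(embed_dim, num_heads, num_kv_heads):
    if embed_dim % num_heads or num_heads % num_kv_heads:
        raise ValueError("embed_dim must divide by num_heads and num_heads by num_kv_heads")
    return embed_dim // num_heads, num_heads // num_kv_heads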

def _prepare_transformer(model):
    embed_dim = model.tok_embeddings.embedding_dim
    model.tok_embeddings = nn.Identity()
    model.output = nn.Identity()
    return model, embed_dim

def _create_causal_mask(seq_len: int, device: torch.device):
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """
    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    r = mask[input_pos, :]
    return r
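
# Illustrative plain-Python version (no torch; hypothetical helper names) of the two mask
# helpers above. The causal mask is lower-triangular: query position i may attend to key
# positions 0..i. Indexing its rows by input_pos gives the mask slice for the tokens fed in
# the current step, which is how generate_frame builds its per-step masks while the KV
# cache holds earlier positions.
def causal_mask_rows(seq_len):
    return [[col <= row for col in range(seq_len)] for row in range(seq_len)]


def index_rows(mask, input_pos):
    # input_pos: (batch, seq_len) positions -> result: (batch, seq_len, max_seq_len)
    return [[mask[p] for p in positions] for positions in input_pos]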

def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)

def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    logits = logits / temperature

    filter_value: float = -float("Inf")
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token
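
# Illustrative plain-Python sketch (no torch; hypothetical name sample_topk_py) of
# sample_topk and _multinomial_sample_one_no_sync above. Logits are divided by the
# temperature, values below the k-th largest are masked to -inf, and the rest is
# softmax-normalised (the extra log_softmax in the torch code does not change the final
# probabilities). Taking argmax(p_i / E_i) with E_i ~ Exp(1) selects index i with
# probability p_i, since it is the index with the smallest E_i / p_i in a race of
# exponentials with rates p_i; per the comment above, this avoids a CUDA synchronization.
import math
import random


def sample_topk_py(logits, topk, temperature, rng=random):
    scaled = [x / temperature for x in logits]
    kth_largest = sorted(scaled, reverse=True)[topk - 1]
    # Same rule as the torch code: drop values strictly below the k-th largest
    kept = [x if x >= kth_largest else float("-inf") for x in scaled]
    peak = max(kept)
    weights = [math.exp(x - peak) for x in kept]  # math.exp(-inf) == 0.0
    total = sum(weights)
    probs = [w / total for w in weights]
    # Exponential race: pick argmax p_i / E_i with E_i ~ Exp(1)
    best_index, best_score = 0, -1.0
    for i, p in enumerate(probs):
        e = rng.expovariate(1.0) or 1e-300  # guard against a (very rare) zero draw
        score = p / e
        if score > best_score:
            best_index, best_score = i, score
    return best_index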

@dataclass
class ModelArgs:
    backbone_flavor: str
    decoder_flavor: str
    text_vocab_size: int
    audio_vocab_size: int
    audio_num_codebooks: int


class Model(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/SesameAILabs/csm",
    pipeline_tag="text-to-speech",
    license="apache-2.0",
):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config

        self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
        self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
        self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))

    def setup_caches(self, max_batch_size: int) -> None:
        """Set up the backbone and decoder KV caches and register their causal-mask buffers."""
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)

        self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
        self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
    ) -> torch.Tensor:
        """
        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)
            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
            input_pos: (batch_size, seq_len) positions for each token
            temperature: sampling temperature applied to the logits
            topk: number of highest-scoring tokens kept when sampling

        Returns:
            (batch_size, audio_num_codebooks) sampled tokens
        """
        dtype = next(self.parameters()).dtype
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
        embeds = self._embed_tokens(tokens)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)

        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = sample_topk(c0_logits, topk, temperature)
        c0_embed = self._embed_audio(0, c0_sample)

        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)

        # Decoder caches must be reset every frame.
        self.decoder.reset_caches()
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
                dtype=dtype
            )
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)

            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )

        return torch.cat([audio_embeds, text_embeds], dim=-2)
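

# Illustrative plain-Python sketch (hypothetical helper names) of the index arithmetic in
# _embed_audio and _embed_tokens: all codebooks share one embedding table with
# audio_vocab_size * audio_num_codebooks rows, and token t of codebook c lives at row
# t + c * audio_vocab_size. The last column of each frame is the text token, which is
# looked up in text_embeddings instead.
def audio_embedding_row(codebook, token, audio_vocab_size):
    if not 0 <= token < audio_vocab_size:
        raise ValueError("token out of range for this codebook")
    return token + codebook * audio_vocab_size


def frame_rows(frame, audio_vocab_size):
    """frame = [c0, c1, ..., c_{N-1}, text_token]; returns (audio table rows, text token)."""
    *audio_tokens, text_token = frame
    rows = [audio_embedding_row(c, t, audio_vocab_size) for c, t in enumerate(audio_tokens)]
    return rows, text_token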
    


================================================
FILE: rag_system.py
================================================
import sqlite3
import numpy as np
import json
from pathlib import Path
import time
from typing import List, Dict, Any, Tuple, Optional
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch

class RAGSystem:
    def __init__(self, db_path: str, model_name: str = "all-MiniLM-L6-v2", cache_dir: str = "./embeddings_cache"):
        """
        Initialize the enhanced RAG system with embeddings.
        
        Args:
            db_path: Path to the SQLite database
            model_name: Name of the sentence-transformer model to use
            cache_dir: Directory to cache embeddings
        """
        self.db_path = db_path
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        
        # Load embedding model
        print(f"Loading embedding model: {model_name}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"Embedding model loaded on {self.device}")
        
        # Cache for embeddings
        self.embedding_cache = self._load_embedding_cache()
        
        # Initialize database tables if needed
        self._initialize_db()
        
        # Load existing conversations and cache embeddings
        self._load_conversations()
    
    def _initialize_db(self):
        """Create necessary tables if they don't exist."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        # Create conversations table if it doesn't exist
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS conversations (
            id INTEGER PRIMARY KEY,
            user_message TEXT,
            ai_message TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        """)
        
        # Create embeddings table if it doesn't exist
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS embeddings (
            id INTEGER PRIMARY KEY,
            conversation_id INTEGER,
            text TEXT,
            embedding_file TEXT,
            chunk_id TEXT,
            FOREIGN KEY (conversation_id) REFERENCES conversations(id)
        )
        """)
        
        conn.commit()
        conn.close()
    
    def _load_embedding_cache(self) -> Dict[str, np.ndarray]:
        """Load cached embeddings from disk."""
        cache = {}
        
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r") as f:
                    cache_data = json.load(f)
                    for chunk_id, embedding_data in cache_data.items():
                        cache[chunk_id] = np.array(embedding_data)
            except Exception as e:
                print(f"Error loading cache file {cache_file}: {e}")
        
        print(f"Loaded {len(cache)} cached embeddings")
        return cache
    
    def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray):
        """Save an embedding to the cache."""
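        # Cache files are sharded by the first two characters of the chunk id. Chunk ids
        # have the form "conv_<id>", so in practice every embedding lands in "co.json".
        cache_file = self.cache_dir / f"{chunk_id[:2]}.json"
        
        # Load existing cache file or create new one
        if cache_file.exists():
            try:
                with open(cache_file, "r") as f:
                    cache_data = json.load(f)
            except (json.JSONDecodeError, OSError):
                cache_data = {}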
        else:
            cache_data = {}
        
        # Add new embedding
        cache_data[chunk_id] = embedding.tolist()
        
        # Save cache file
        with open(cache_file, "w") as f:
            json.dump(cache_data, f)
    
    def _load_conversations(self):
        """Load existing conversations from the database and cache their embeddings."""
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            
            # First check if the conversations table exists
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'")
            if not cursor.fetchone():
                print("Conversations table does not exist yet")
                conn.close()
                return
            
            # Get all conversations not yet in the embeddings table
            cursor.execute("""
            SELECT c.id, c.user_message, c.ai_message 
            FROM conversations c
            LEFT JOIN embeddings e ON c.id = e.conversation_id
            WHERE e.id IS NULL
            """)
            
            conversations = cursor.fetchall()
            if not conversations:
                conn.close()
                return
            
            print(f"Processing embeddings for {len(conversations)} new conversations")
            
            for conv_id, user_message, ai_message in conversations:
                # Create chunks for indexing
                if user_message is not None and ai_message is not None:  # Ensure neither is None
                    self._process_conversation(conv_id, user_message, ai_message, conn)
            
            conn.close()
            print("Finished processing conversation embeddings")
        except Exception as e:
            print(f"Error loading conversations: {e}")
    
    def _process_conversation(self, conv_id: int, user_message: str, ai_message: str, conn: sqlite3.Connection):
        """Process a conversation and store its embeddings."""
        try:
            cursor = conn.cursor()
            
            # Combine user and AI messages
            full_text = f"User: {user_message}\nAI: {ai_message}"
            
            # For simplicity, we're using the entire message as a chunk
            # In a more sophisticated system, you might split long messages into smaller chunks
            chunk_id = f"conv_{conv_id}"
            
            # Check if we already have this embedding cached
            if chunk_id not in self.embedding_cache:
                # Generate embedding
                embedding = self.model.encode(full_text)
                self.embedding_cache[chunk_id] = embedding
                
                # Save to cache
                self._save_embedding_to_cache(chunk_id, embedding)
            else:
                embedding = self.embedding_cache[chunk_id]
            
            # Store reference in database
            embedding_file = f"{chunk_id[:2]}.json"
            cursor.execute(
                "INSERT INTO embeddings (conversation_id, text, embedding_file, chunk_id) VALUES (?, ?, ?, ?)",
                (conv_id, full_text, embedding_file, chunk_id)
            )
            
            conn.commit()
        except Exception as e:
            print(f"Error processing conversation {conv_id}: {e}")
    
    def add_conversation(self, user_message: str, ai_message: str) -> int:
        """
        Add a new conversation to the RAG system.
        
        Returns:
            The id of the newly added conversation
        """
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            
            # Insert the conversation first
            cursor.execute(
                "INSERT INTO conversations (user_message, ai_message) VALUES (?, ?)",
                (user_message, ai_message)
            )
            
            # Get the ID of the new conversation
            conv_id = cursor.lastrowid
            
            # Process the conversation for embeddings
            self._process_conversation(conv_id, user_message, ai_message, conn)
            
            conn.commit()
            conn.close()
            
            return conv_id
        except Exception as e:
            print(f"Error adding conversation: {e}")
            return -1
    
    def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """
        Query the RAG system for relevant context.
        
        Args:
            query_text: The query text
            top_k: Number of top results to return
            
        Returns:
            List of tuples with (text, similarity_score)
        """
        if query_text is None or query_text.strip() == "":
            print("Error: Empty query text")
            return []
            
        try:
            # Generate query embedding
            query_embedding = self.model.encode(query_text)
            
            # Find most similar conversations
            results = self._find_similar(query_embedding, top_k)
            
            return results
        except Exception as e:
            print(f"Error during query: {e}")
            return []
    
    def get_context(self, query_text: str, top_k: int = 3, threshold: float = 0.6) -> str:
        """
        Get formatted context from the RAG system.
        
        Args:
            query_text: The query text
            top_k: Number of top results to return
            threshold: Minimum similarity score to include
            
        Returns:
            String with relevant context
        """
        results = self.query(query_text, top_k)
        
        if not results:
            return ""
        
        # Format results
        context_parts = []
        for text, score in results:
            # Only include really relevant results
            if score < threshold:  # Threshold for relevance
                continue
            context_parts.append(f"Relevance: {score:.2f}\n{text}")
        
        return "\n---\n".join(context_parts)
    
    def _find_similar(self, query_embedding: np.ndarray, top_k: int) -> List[Tuple[str, float]]:
        """Find the most similar conversations to the query."""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()

            # Check if the embeddings table exists
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='embeddings'")
            if not cursor.fetchone():
                print("Embeddings table does not exist yet")
                return []

            # Get all embeddings from the database
            cursor.execute("SELECT id, text, embedding_file, chunk_id FROM embeddings")
            results = cursor.fetchall()

            if not results:
                return []

            # Calculate similarities
            similarities = []
            for db_id, text, embedding_file, chunk_id in results:
                # Prefer the in-memory embedding cache
                if chunk_id in self.embedding_cache:
                    embedding = self.embedding_cache[chunk_id]
                else:
                    # Cache miss (should be rare): reload the vector from its
                    # JSON cache file and remember it; skip the row if absent
                    cache_file = self.cache_dir / embedding_file
                    if cache_file.exists():
                        with open(cache_file, "r") as f:
                            cache_data = json.load(f)
                            if chunk_id in cache_data:
                                embedding = np.array(cache_data[chunk_id])
                                self.embedding_cache[chunk_id] = embedding
                            else:
                                continue
                    else:
                        continue

                # Cosine similarity between the query and this chunk
                similarity = cosine_similarity(
                    query_embedding.reshape(1, -1),
                    embedding.reshape(1, -1)
                )[0][0]

                similarities.append((text, similarity))

            # Sort by similarity and return top_k
            similarities.sort(key=lambda x: x[1], reverse=True)
            return similarities[:top_k]
        except Exception as e:
            print(f"Error finding similar documents: {e}")
            return []
        finally:
            # Close the connection on every path, including early returns and errors
            if conn is not None:
                conn.close()
    
    def refresh(self):
        """Refresh embeddings from the database."""
        self._load_conversations()

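# Example usage
if __name__ == "__main__":
    # Initialize the RAG system
    rag = RAGSystem("conversations.db")
    # Print the three stored chunks most similar to a sample query
    for text, score in rag.query("What did we talk about last time?", top_k=3):
        print(f"{score:.3f}  {text}")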
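`_find_similar` scores each stored chunk with its own `cosine_similarity` call and then sorts the whole list. A minimal sketch of a batched alternative, assuming the cached embeddings all share one dimension and can be stacked into a single 2-D NumPy array; `top_k_cosine` is a hypothetical helper, not part of `rag_system.py`:

```python
import numpy as np


def top_k_cosine(query: np.ndarray, matrix: np.ndarray, texts: list, top_k: int) -> list:
    """Return (text, score) pairs for the top_k rows of `matrix` most similar to `query`.

    Hypothetical helper: `matrix` has shape (n_chunks, dim) and `texts[i]`
    is the text stored for row i.
    """
    if matrix.shape[0] == 0 or top_k <= 0:
        return []
    # Normalise the query and every row; eps guards against zero vectors.
    eps = 1e-12
    q = query / max(np.linalg.norm(query), eps)
    m = matrix / np.maximum(np.linalg.norm(matrix, axis=1, keepdims=True), eps)
    # One matrix-vector product yields every cosine similarity at once.
    scores = m @ q
    k = min(top_k, scores.shape[0])
    # argpartition selects the k best in linear time; only those k get sorted.
    idx = np.argpartition(-scores, k - 1)[:k]
    idx = idx[np.argsort(-scores[idx])]
    return [(texts[i], float(scores[i])) for i in idx]
```

Only the `k` best candidates are fully sorted, which matters mainly once many chunks are stored; for a handful of conversations the per-row loop is adequate.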

================================================
FILE: requirements.txt
================================================
--extra-index-url=https://download.pytorch.org/whl/cu128
vllm==0.8.0
torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
tokenizers==0.21.0
transformers==4.49.0
huggingface_hub==0.28.1
moshi==0.2.2
sounddevice
torchtune==0.4.0
torchao==0.9.0
bitsandbytes
peft
wandb
silero_vad
python-multipart>=0.0.6
aiofiles>=23.1.0
sentence-transformers>=2.2.2
ctransformers>=0.2.24
sqlalchemy>=2.0.0
pydantic>=2.0.0
fastapi>=0.95.0
uvicorn>=0.22.0
websockets>=11.0.3
jinja2>=3.0.0
speechbrain>=0.5.15
matplotlib
whisper-openai
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
numpy==1.26.0

================================================
FILE: run_csm.py
================================================
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass


# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_b.wav"
)

SPEAKER_PROMPTS = {
    "conversational_a": {
        "text": (
            "like revising for an exam I'd have to try and like keep up the momentum because I'd "
            "start really early I'd be like okay I'm gonna start revising now and then like "
            "you're revising for ages and then I just like start losing steam I didn't do that "
            "for the exam we had recently to be fair that was a more of a last minute scenario "
            "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
            "sort of start the day with this not like a panic but like a"
        ),
        "audio": prompt_filepath_conversational_a
    },
    "conversational_b": {
        "text": (
            "like a super Mario level. Like it's very like high detail. And like, once you get "
            "into the park, it just like, everything looks like a computer game and they have all "
            "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
            "will have like a question block. And if you like, you know, punch it, a coin will "
            "come out. So like everyone, when they come into the park, they get like this little "
            "bracelet and then you can go punching question blocks around."
        ),
        "audio": prompt_filepath_conversational_b
    }
}

def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = audio_tensor.squeeze(0)
    # resample() returns the input unchanged when the rates already match,
    # so it is safe to call unconditionally
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
    )
    return audio_tensor

def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    audio_tensor = load_prompt_audio(audio_path, sample_rate)
    return Segment(text=text, speaker=speaker, audio=audio_tensor)

def main():
    # Select the best available device, skipping MPS due to float64 limitations
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts
    prompt_a = prepare_prompt(
        SPEAKER_PROMPTS["conversational_a"]["text"],
        0,
        SPEAKER_PROMPTS["conversational_a"]["audio"],
        generator.sample_rate
    )

    prompt_b = prepare_prompt(
        SPEAKER_PROMPTS["conversational_b"]["text"],
        1,
        SPEAKER_PROMPTS["conversational_b"]["audio"],
        generator.sample_rate
    )

    # Generate conversation
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance
    generated_segments = []
    prompt_segments = [prompt_a, prompt_b]

    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))

    # Concatenate all generations
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")

if __name__ == "__main__":
    main() 
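In `main()` every `generate` call receives `prompt_segments + generated_segments`, so the context grows by one segment per utterance, and, assuming each context segment is tokenized into the model input (as `_tokenize_segment` in `generator.py` suggests), so does the input length. A minimal sketch of capping that history while always keeping the speaker prompts; `build_context` and `max_history` are hypothetical, and `Segment` here is a stand-in dataclass with the same three fields as the one imported from `generator`, not the real class:

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Segment:
    # Stand-in for generator.Segment (text, speaker, audio); audio is untyped here.
    text: str
    speaker: int
    audio: Any


def build_context(prompts: List[Segment], history: List[Segment], max_history: int) -> List[Segment]:
    """Keep every prompt segment plus at most the `max_history` newest generated ones."""
    if max_history <= 0:
        # history[-0:] would return the whole list, so handle 0 explicitly.
        return list(prompts)
    return list(prompts) + list(history[-max_history:])
```

Inside the loop this would replace the inline concatenation, e.g. `context=build_context(prompt_segments, generated_segments, max_history=4)`; the limit of 4 is an arbitrary illustration.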

================================================
FILE: setup.py
================================================
import os
import sys
import subprocess
import logging
import urllib.request
import torch
import time
import shutil
from pathlib import Path

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def check_requirements():
    """Check if all required Python packages are installed"""
    logger.info("Checking requirements...")
    
    # These are import names, not pip package names (scikit-learn imports as sklearn)
    requirements = [
        "torch", "torchaudio", "fastapi", "uvicorn", "websockets", "numpy",
        "sklearn", "sqlalchemy", "pydantic", "jinja2", "whisper",
        "sounddevice", "soundfile", "sentence_transformers", "ctransformers"
    ]
    
    missing = []
    for req in requirements:
        try:
            __import__(req)
        except ImportError:
            missing.append(req)
    
    if missing:
        logger.warning(f"Missing required packages: {', '.join(missing)}")
        logger.info("Installing missing requirements...")
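        result = subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
        if result.returncode == 0:
            logger.info("Requirements installed successfully")
        else:
            logger.error(f"pip install failed with exit code {result.returncode}")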
    else:
        logger.info("All requirements are satisfied")
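`check_requirements` passes each name to `__import__`, so every entry must be an import name rather than a pip distribution name (`scikit-learn`, for example, installs the `sklearn` module). A minimal sketch of keeping both names in an explicit mapping; `REQUIRED_MODULES` and `find_missing` are hypothetical, the pairs shown are only illustrative, and `importlib.util.find_spec` checks whether a top-level module can be located without importing it:

```python
import importlib.util
from typing import Dict, List

# Hypothetical mapping: import name -> pip distribution name.
REQUIRED_MODULES: Dict[str, str] = {
    "torch": "torch",
    "torchaudio": "torchaudio",
    "sklearn": "scikit-learn",
    "sentence_transformers": "sentence-transformers",
    "soundfile": "soundfile",
}


def find_missing(modules: Dict[str, str]) -> List[str]:
    """Return the pip names of modules that cannot be located."""
    missing = []
    for import_name, pip_name in modules.items():
        # find_spec returns None when no top-level module of that name exists.
        if importlib.util.find_spec(import_name) is None:
            missing.append(pip_name)
    return missing
```

The returned pip names could be installed individually instead of reinstalling everything listed in `requirements.txt`.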

def download_vad_model():
    """Download the Silero VAD model using PyTorch Hub instead of direct URL"""
    model_path = "silero_vad.jit"
    
    if os.path.exists(model_path):
        logger.info(f"Silero VAD model already exists at {model_path}")
        return
    
    logger.info("Downloading Silero VAD model using PyTorch Hub...")
    try:
        # Use torch.hub to download the model instead of direct URL
        torch.hub.set_dir("./models")
        model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad",
                                     model="silero_vad",
                                     force_reload=True,
                                     onnx=False)
        
        # Save the model
        torch.jit.save(model, model_path)
        logger.info(f"Model downloaded and saved to {model_path}")
        
    except Exception as e:
        logger.error(f"Failed to download Silero VAD model using PyTorch Hub: {e}")
        logger.info("Falling back to energy-based VAD - the system will still work but with simpler voice detection")

def download_embedding_models():
    """Download the sentence transformer models for RAG"""
    logger.info("Setting up sentence transformer models...")
    
    try:
        from sentence_transformers import SentenceTransformer
        
        # Download lightweight model for embeddings
        logger.info("Downloading embedding models (this may take a few minutes)...
SYMBOL INDEX (171 symbols across 15 files)

FILE: config.py
  class ConfigManager (line 9) | class ConfigManager:
    method __init__ (line 14) | def __init__(self, config_path: str = "config/app_config.json"):
    method save_config (line 29) | def save_config(self, config_data: Dict[str, Any]) -> bool:
    method load_config (line 68) | def load_config(self) -> Optional[Dict[str, Any]]:
  function model_to_dict (line 102) | def model_to_dict(model: BaseModel) -> Dict[str, Any]:

FILE: generator.py
  class Segment (line 24) | class Segment:
  function load_llama3_tokenizer (line 31) | def load_llama3_tokenizer():
  class Generator (line 48) | class Generator:
    method __init__ (line 49) | def __init__(self, model: Model):
    method _tokenize_text_segment (line 74) | def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[tor...
    method _tokenize_audio (line 102) | def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, ...
    method _tokenize_segment (line 128) | def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, t...
    method _decode_frames (line 150) | def _decode_frames(self, frames):
    method generate_stream (line 160) | def generate_stream(
    method generate (line 346) | def generate(
  class AudioStreamWriter (line 457) | class AudioStreamWriter:
    method __init__ (line 461) | def __init__(self, filename, sample_rate):
    method _writer_worker (line 473) | def _writer_worker(self):
    method add_chunk (line 505) | def add_chunk(self, chunk):
    method write_file (line 514) | def write_file(self):
  function load_csm_1b_local (line 536) | def load_csm_1b_local(model_path: str, device: str = "cuda", audio_num_c...
  function warmup_generator (line 597) | def warmup_generator(gen: Generator, warmup_text: str = "Hello, this is ...
  function load_csm_1b (line 733) | def load_csm_1b(device: str = "cuda") -> Generator:
  function stream_audio_to_wav (line 781) | def stream_audio_to_wav(filename, sample_rate):
  function generate_streaming_audio (line 820) | def generate_streaming_audio(

FILE: llm_interface.py
  class LLMInterface (line 6) | class LLMInterface:
    method __init__ (line 7) | def __init__(self, model_path: str, max_tokens: int = 8192, n_threads:...
    method trim_to_last_sentence (line 33) | def trim_to_last_sentence(self, text: str) -> str:
    method generate_response (line 53) | def generate_response(self, system_prompt: str, user_message: str, con...
    method tokenize (line 91) | def tokenize(self, text: str) -> List[int]:
    method get_token_count (line 105) | def get_token_count(self, text: str) -> int:
    method batch_generate (line 116) | def batch_generate(self, prompts: List[Dict[str, str]],

FILE: loadandmergecheckpoint.py
  function find_latest_checkpoint (line 19) | def find_latest_checkpoint(dir_path):
  function load_checkpoint_and_merge (line 31) | def load_checkpoint_and_merge():

FILE: lora.py
  class TrainingVisualizer (line 54) | class TrainingVisualizer:
    method __init__ (line 55) | def __init__(self, output_dir):
    method update (line 84) | def update(self, epoch, step, loss, lr, val_loss=None):
    method finalize (line 171) | def finalize(self):
  class LoRALinear (line 291) | class LoRALinear(nn.Module):
    method __init__ (line 292) | def __init__(self, in_features, out_features, r=32, alpha=64, dropout=...
    method forward (line 314) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function replace_linear_with_lora (line 323) | def replace_linear_with_lora(model: nn.Module, r=R, alpha=APLHA, dropout...
  function load_llama3_tokenizer (line 364) | def load_llama3_tokenizer():
  class AudioTextPair (line 377) | class AudioTextPair:
    method load_audio (line 383) | def load_audio(self, sample_rate=24000) -> torch.Tensor:
  class CSMDataset (line 398) | class CSMDataset(Dataset):
    method __init__ (line 399) | def __init__(self, data_items, text_tokenizer, audio_tokenizer, device):
    method __len__ (line 406) | def __len__(self):
    method tokenize_text_segment (line 409) | def tokenize_text_segment(self, text: str, speaker: int):
    method tokenize_audio (line 417) | def tokenize_audio(self, audio: torch.Tensor):
    method __getitem__ (line 438) | def __getitem__(self, idx: int):
  function collate_fn (line 468) | def collate_fn(batch):
  function transcribe_audio_files (line 505) | def transcribe_audio_files():
  function prepare_csm_model_for_training (line 578) | def prepare_csm_model_for_training():
  function setup_model_caches (line 647) | def setup_model_caches(model, batch_size):
  class BridgingModule (line 657) | class BridgingModule(nn.Module):
    method __init__ (line 659) | def __init__(self, in_dim=2048, out_dim=1024):
    method forward (line 663) | def forward(self, x):
  function compute_loss_for_codebooks_single_pass (line 666) | def compute_loss_for_codebooks_single_pass(
  function single_pass_forward (line 715) | def single_pass_forward(model, bridging_module, target_tokens, target_ma...
  function calculate_validation_loss (line 766) | def calculate_validation_loss(model, bridging_module, dataset, device, m...
  function strip_bias_keys (line 804) | def strip_bias_keys(state_dict: dict) -> dict:
  function remove_lora_modules (line 816) | def remove_lora_modules(module: nn.Module) -> nn.Module:
  function merge_lora_layer (line 843) | def merge_lora_layer(lora_module: LoRALinear):
  function merge_lora_weights (line 856) | def merge_lora_weights(model: nn.Module):
  function finetune (line 862) | def finetune(model, dataset):
  function forward_and_loss (line 1015) | def forward_and_loss(model, bridging_module, batch, device):
  function main (line 1077) | def main():

FILE: main.py
  class Conversation (line 63) | class Conversation(Base):
  class CompanionConfig (line 75) | class CompanionConfig(BaseModel):
  function process_message_queue (line 130) | async def process_message_queue():
  function load_reference_segments (line 143) | def load_reference_segments(config_data: CompanionConfig):
  function transcribe_audio (line 185) | def transcribe_audio(audio_data, sample_rate):
  function initialize_models (line 201) | def initialize_models(config_data: CompanionConfig):
  function on_speech_start (line 251) | def on_speech_start():
  function on_speech_end (line 263) | def on_speech_end(audio_data, sample_rate):
  function process_pending_inputs (line 293) | def process_pending_inputs():
  function process_user_input (line 316) | def process_user_input(user_text, session_id="default"):
  function model_worker (line 438) | def model_worker(cfg: CompanionConfig):
  function start_model_thread (line 481) | def start_model_thread():
  function run_audio_generation (line 495) | async def run_audio_generation(text, output_file):
  function send_to_all_clients (line 499) | def send_to_all_clients(message: dict):
  function save_audio_and_trim (line 522) | def save_audio_and_trim(path, session_id, speaker_id, tensor, sample_rate):
  function add_segment (line 555) | def add_segment(text, speaker_id, audio_tensor):
  function preprocess_text_for_tts (line 644) | def preprocess_text_for_tts(text):
  function audio_generation_thread (line 664) | def audio_generation_thread(text, output_file):
  function handle_interrupt (line 887) | def handle_interrupt(websocket):
  function websocket_endpoint (line 982) | async def websocket_endpoint(websocket: WebSocket):
  function index (line 1120) | async def index(request: Request):
  function setup_page (line 1124) | async def setup_page(request: Request):
  function chat_page (line 1128) | async def chat_page(request: Request):
  function startup_event (line 1132) | async def startup_event():
  function shutdown_event (line 1146) | async def shutdown_event():
  function get_conversations (line 1152) | async def get_conversations(request: Request):
  function update_conversation (line 1161) | def update_conversation(conv_id):
  function delete_all_conversations (line 1172) | async def delete_all_conversations():
  function delete_conversation (line 1184) | async def delete_conversation(conv_id: int):
  function crud_ui (line 1199) | async def crud_ui(request: Request):

FILE: models.py
  function llama3_2_1B (line 12) | def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
  function llama3_2_100M (line 27) | def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
  function _prepare_transformer (line 47) | def _prepare_transformer(model):
  function _create_causal_mask (line 53) | def _create_causal_mask(seq_len: int, device: torch.device):
  function _index_causal_mask (line 56) | def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
  function _multinomial_sample_one_no_sync (line 68) | def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling...
  function sample_topk (line 72) | def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
  class ModelArgs (line 85) | class ModelArgs:
  class Model (line 93) | class Model(
    method __init__ (line 100) | def __init__(self, config: ModelArgs):
    method setup_caches (line 114) | def setup_caches(self, max_batch_size: int) -> torch.Tensor:
    method generate_frame (line 126) | def generate_frame(
    method reset_caches (line 180) | def reset_caches(self):
    method _embed_audio (line 184) | def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.T...
    method _embed_tokens (line 187) | def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:

FILE: rag_system.py
  class RAGSystem (line 11) | class RAGSystem:
    method __init__ (line 12) | def __init__(self, db_path: str, model_name: str = "all-MiniLM-L6-v2",...
    method _initialize_db (line 40) | def _initialize_db(self):
    method _load_embedding_cache (line 70) | def _load_embedding_cache(self) -> Dict[str, np.ndarray]:
    method _save_embedding_to_cache (line 86) | def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray):
    method _load_conversations (line 107) | def _load_conversations(self):
    method _process_conversation (line 145) | def _process_conversation(self, conv_id: int, user_message: str, ai_me...
    method add_conversation (line 179) | def add_conversation(self, user_message: str, ai_message: str) -> int:
    method query (line 210) | def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, fl...
    method get_context (line 237) | def get_context(self, query_text: str, top_k: int = 3, threshold: floa...
    method _find_similar (line 264) | def _find_similar(self, query_embedding: np.ndarray, top_k: int) -> Li...
    method refresh (line 323) | def refresh(self):

FILE: run_csm.py
  function load_prompt_audio (line 44) | def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch...
  function prepare_prompt (line 53) | def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate...
  function main (line 57) | def main():

FILE: setup.py
  function check_requirements (line 16) | def check_requirements():
  function download_vad_model (line 41) | def download_vad_model():
  function download_embedding_models (line 66) | def download_embedding_models():
  function setup_directories (line 90) | def setup_directories():
  function setup_database (line 116) | def setup_database():
  function check_cuda (line 144) | def check_cuda():
  function main (line 154) | def main():

FILE: static/app.js
  function showLoading (line 128) | function showLoading(show, message) {
  function getSelectedMic (line 154) | function getSelectedMic() {
  function getSelectedOutput (line 158) | function getSelectedOutput() {
  function populateAudioDevices (line 162) | async function populateAudioDevices() {
  function visualizeMic (line 193) | function visualizeMic(analyser, canvasId) {

FILE: static/chat.js
  constant SESSION_ID (line 17) | const SESSION_ID = "default";
  function createPermanentVoiceCircle (line 29) | function createPermanentVoiceCircle() {
  function showVoiceCircle (line 56) | function showVoiceCircle() {
  function hideVoiceCircle (line 61) | function hideVoiceCircle() {
  function showNotification (line 66) | function showNotification(msg, type='info'){
  function addMessageToConversation (line 77) | function addMessageToConversation(sender,text){
  function connectWebSocket (line 102) | function connectWebSocket() {
  function sendTextMessage (line 185) | function sendTextMessage(txt) {
  function resetAudioState (line 251) | function resetAudioState() {
  function clearAudioPlayback (line 271) | function clearAudioPlayback() {
  function requestInterrupt (line 344) | function requestInterrupt() {
  function handleWebSocketMessage (line 403) | function handleWebSocketMessage(d) {
  function queueAudioForPlayback (line 557) | function queueAudioForPlayback(arr, sr, genId = 0) {
  function queueAudioForPlayback (line 578) | function queueAudioForPlayback(arr, sr, genId = 0) {
  function processAudioPlaybackQueue (line 615) | function processAudioPlaybackQueue() {
  function playAudioChunk (line 684) | async function playAudioChunk(audioArr, sampleRate) {
  function startRecording (line 763) | async function startRecording() {
  function stopRecording (line 803) | function stopRecording() {
  function setupChatUI (line 826) | async function setupChatUI() {
  function initAudioLevelsChart (line 988) | function initAudioLevelsChart() {

FILE: static/crud.js
  function loadConversations (line 22) | async function loadConversations() {
  function renderConversations (line 28) | function renderConversations(list) {

FILE: test.py
  function load_audio (line 31) | def load_audio(audio_path):

FILE: vad.py
  class VoiceActivityDetector (line 5) | class VoiceActivityDetector:
    method __init__ (line 6) | def __init__(
    method reset (line 28) | def reset(self) -> None:
    method process_audio_chunk (line 41) | def process_audio_chunk(self, audio_chunk: np.ndarray) -> bool:
  class AudioStreamProcessor (line 99) | class AudioStreamProcessor:
    method __init__ (line 100) | def __init__(
    method process_audio (line 133) | def process_audio(self, audio_chunk: np.ndarray):
    method reset (line 186) | def reset(self):
Condensed preview: 25 files, each showing path, character count, and a content snippet (full structured content: 280K chars).
[
  {
    "path": ".github/FUNDING.yml",
    "chars": 896,
    "preview": "# These are supported funding model platforms\n\ngithub: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [u"
  },
  {
    "path": ".gitignore",
    "chars": 398,
    "preview": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\np"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 9667,
    "preview": "# CSM - Optimized Streaming/Finetuning Edition\n\n---\n\nCSM (Conversational Speech Model) is a speech generation model from"
  },
  {
    "path": "config.py",
    "chars": 3713,
    "preview": "import os\nimport json\nimport logging\nfrom typing import Dict, Any, Optional\nfrom pydantic import BaseModel\n\nlogger = log"
  },
  {
    "path": "finetuned_model/config.json",
    "chars": 155,
    "preview": "{\n  \"audio_num_codebooks\": 32,\n  \"audio_vocab_size\": 2051,\n  \"backbone_flavor\": \"llama-1B\",\n  \"decoder_flavor\": \"llama-1"
  },
  {
    "path": "generator.py",
    "chars": 44034,
    "preview": "from dataclasses import dataclass\nimport math\nimport os\nfrom typing import List, Tuple, Generator as PyGenerator, Option"
  },
  {
    "path": "llm_interface.py",
    "chars": 6717,
    "preview": "import re\nfrom typing import List, Dict, Any, Optional\nimport torch\nfrom vllm import LLM, SamplingParams\n\nclass LLMInter"
  },
  {
    "path": "loadandmergecheckpoint.py",
    "chars": 2028,
    "preview": "import os\nimport re\nimport torch\nfrom models import Model\nfrom safetensors.torch import save_file, load_file \n\nfrom lora"
  },
  {
    "path": "lora.py",
    "chars": 44719,
    "preview": "import json\nimport os\nimport glob\nimport torch\nimport torchaudio\nimport logging\nimport numpy as np\nfrom dataclasses impo"
  },
  {
    "path": "main.py",
    "chars": 47270,
    "preview": "import asyncio\nimport os\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"MKL_NUM_THREADS\"] = \"1\"  \nos.environ[\"CUDA_LAUN"
  },
  {
    "path": "models.py",
    "chars": 7230,
    "preview": "import logging\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn as nn\nimport torchtune\nfrom huggingface_h"
  },
  {
    "path": "rag_system.py",
    "chars": 12234,
    "preview": "import sqlite3\nimport numpy as np\nimport json\nfrom pathlib import Path\nimport time\nfrom typing import List, Dict, Any, T"
  },
  {
    "path": "requirements.txt",
    "chars": 621,
    "preview": "--extra-index-url=https://download.pytorch.org/whl/cu128\nvllm==0.8.0\ntorch==2.6.0\ntorchaudio==2.6.0\ntorchvision==0.21.0\n"
  },
  {
    "path": "run_csm.py",
    "chars": 4371,
    "preview": "import os\nimport torch\nimport torchaudio\nfrom huggingface_hub import hf_hub_download\nfrom generator import load_csm_1b, "
  },
  {
    "path": "setup.py",
    "chars": 6066,
    "preview": "import os\nimport sys\nimport subprocess\nimport logging\nimport urllib.request\nimport torch\nimport time\nimport shutil\nfrom "
  },
  {
    "path": "static/app.js",
    "chars": 7899,
    "preview": "let ws;\nlet micAnalyser, micContext, micSource, micStream;\nlet outputAnalyser, outputAudioCtx;\nlet lastConfig = null;\nle"
  },
  {
    "path": "static/chat.js",
    "chars": 31614,
    "preview": "let ws;\nlet sessionStartTime = null;\nlet messageCount = 0;\nlet audioLevelsChart = null;\nlet isRecording = false;\nlet isA"
  },
  {
    "path": "static/crud.js",
    "chars": 2666,
    "preview": "let allConversations = [];\n\ndocument.addEventListener('DOMContentLoaded', async () => {\n  await loadConversations();\n\n  "
  },
  {
    "path": "templates/chat.html",
    "chars": 9043,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n  <meta charset=\"UTF-8\">\n  <meta name=\"viewport\" content=\"width=device-width, in"
  },
  {
    "path": "templates/crud.html",
    "chars": 1014,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\" class=\"dark\">\n<head>\n  <meta charset=\"UTF-8\" />\n  <title>Conversation Manager</title>\n  "
  },
  {
    "path": "templates/index.html",
    "chars": 53,
    "preview": "<meta http-equiv=\"refresh\" content=\"0; url=/setup\" />"
  },
  {
    "path": "templates/setup.html",
    "chars": 4444,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\" class=\"dark\">\n<head>\n<meta charset=\"UTF-8\" />\n<meta name=\"viewport\" content=\"width=devic"
  },
  {
    "path": "test.py",
    "chars": 2735,
    "preview": "import time\nfrom generator import Segment, load_csm_1b, generate_streaming_audio\nimport torchaudio\n\nprint(f\"Starting scr"
  },
  {
    "path": "vad.py",
    "chars": 7386,
    "preview": "import numpy as np\nimport torch\nfrom typing import Callable, Dict, List\nfrom collections import deque\nclass VoiceActivit"
  }
]

About this extraction

This page contains the full source code of the davidbrowne17/csm-streaming GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 25 files (262.0 KB), approximately 60.2k tokens, and a symbol index with 171 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
