[
  {
    "path": ".github/FUNDING.yml",
    "content": "# These are supported funding model platforms\n\ngithub: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]\npatreon: # Replace with a single Patreon username\nopen_collective: # Replace with a single Open Collective username\nko_fi: davidbrowne17\ntidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel\ncommunity_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry\nliberapay: # Replace with a single Liberapay username\nissuehunt: # Replace with a single IssueHunt username\nlfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry\npolar: # Replace with a single Polar username\nbuy_me_a_coffee: # Replace with a single Buy Me a Coffee username\nthanks_dev: # Replace with a single thanks.dev username\ncustom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']\n"
  },
  {
    "path": ".gitignore",
    "content": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# Virtual Environment\n.env\n.venv\nenv/\nvenv/\nENV/\n\n# IDE\n.idea/\n.vscode/\n*.swp\n*.swo\n\n# Project specific\n.python-version\n*.wav\noutput_*/\nbasic_audio.wav\nfull_conversation.wav\ncontext_audio.wav\n\n# Model files\n*.pt\n*.ckpt"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# CSM - Optimized Streaming/Finetuning Edition\n\n---\n\nCSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.\n\nOur fork adds **streaming audio generation**, **real-time playback**, and **performance optimizations** to the original implementation.\n\n## Requirements\n\n* A CUDA-compatible GPU\n* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions\n* Similarly, Python 3.10 is recommended, but newer versions may be fine\n* For some audio operations, `ffmpeg` may be required\n* For real-time audio playback: `pip install sounddevice`\n* Access to the following Hugging Face models:\n  * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)\n  * [CSM-1B](https://huggingface.co/sesame/csm-1b)\n\n### Setup\n\n```bash\nsudo apt-get update && sudo apt-get install -y libportaudio2 libportaudio-dev\ngit clone git@github.com:davidbrowne17/csm-streaming.git\ncd csm-streaming\npython3.10 -m venv .venv\nsource .venv/bin/activate\npip install -r requirements.txt\n\n# Optional speedup\npip install flash-attn\n# You will need access to CSM-1B and Llama-3.2-1B\nhuggingface-cli login\n```\n\n### Windows Setup\n\nThe `triton` package cannot be installed in Windows. Instead use `pip install triton-windows`.\nThe realtime demo uses VLLM for inference speed. 
This is currently not supported for Windows but you can try with https://github.com/SystemPanic/vllm-windows until support is added.\n\n## Quickstart\n\nGenerate a sentence with streaming (chunks are processed and output as they're generated):\n\n```python\nimport time\nfrom huggingface_hub import hf_hub_download\nfrom generator import Generator, Segment, load_csm_1b, generate_streaming_audio\nimport torchaudio\n\n# Load the model\ngenerator = load_csm_1b(\"cuda\")\n\n# Generate audio with streaming and real-time playback\ngenerate_streaming_audio(\n    generator=generator,\n    text=\"Hello, this is streaming audio generation in action!\",\n    speaker=0,\n    context=[],  # No context needed for basic generation\n    output_file=\"streaming_audio.wav\",\n    play_audio=True  # Enable real-time playback\n)\n```\n## Finetuning\nTo finetune CSM all you need are some wav audio files with the speaker voice you want to train, just the raw wavs. Place them in a folder called audio_data and run lora.py.\nYou can configure the exact training params such as batch size, number of epochs and learning rate by modifying the values at the top of lora.py.\nYou will need a CUDA GPU with at least 12GB of VRAM depending on your dataset size and training params. You can monitor the training metrics via the dynamic png created in /finetuned_model/ folder. This contains various graphs to help you track the training progress. If you want to try a checkpoint you can use the loadandmergecheckpoint.py (make sure to set the same R and Alpha values as you used in the training).\n\n## Realtime chat demo\nTo use the realtime demo run the setup.py to download the required models, and then run main.py. This will open up a setup page at http://localhost:8000 in which you can set the paths for your chosen LLM and setup the CSM paths and reference audio as well as select your headset and mic. When loaded you will be able to chat in realtime with the AI just like the CSM demo. 
Our demo includes a dynamic RAG system so the AI can remember previous conversations. The demo by default uses whisper-large-v3-turbo for STT and includes Automatic Voice Detection using Silero VAD.\n\n## Usage\n\nOur optimized version offers several ways to use CSM with streaming capabilities:\n\n### 1. Basic Streaming Generation\n\nGenerate audio with streaming and save to a file:\n\n```python\nfrom generator import load_csm_1b, generate_streaming_audio\n\ngenerator = load_csm_1b(\"cuda\")\n\n# Generate with streaming (writes to file as it generates)\ngenerate_streaming_audio(\n    generator=generator,\n    text=\"This audio will be generated in chunks for faster response times.\",\n    speaker=0,\n    context=[],\n    output_file=\"streaming_output.wav\"\n)\n```\n\n### 2. Real-time Audio Playback\n\nGenerate and play audio in real-time as it's being generated:\n\n```python\nfrom generator import load_csm_1b, generate_streaming_audio\n\ngenerator = load_csm_1b(\"cuda\")\n\n# Generate with streaming and play in real-time\ngenerate_streaming_audio(\n    generator=generator,\n    text=\"You'll hear me speaking as I'm being generated!\",\n    speaker=0,\n    context=[],\n    output_file=\"streaming_output.wav\",\n    play_audio=True  # Enable real-time playback\n)\n```\n\n### 3. Low-level Streaming API\n\nFor more control, use the low-level streaming API:\n\n```python\nfrom generator import load_csm_1b, Segment\nimport torchaudio\n\ngenerator = load_csm_1b(\"cuda\")\n\n# Process audio chunks as they're generated\nfor audio_chunk in generator.generate_stream(\n    text=\"This is generated chunk by chunk.\",\n    speaker=0,\n    context=[]\n):\n    # Do something with each chunk as it's generated\n    print(f\"Received chunk of size: {audio_chunk.shape}\")\n    \n    # You could process or play each chunk here\n    # For example, write to a file incrementally\n    # Or send over a network connection\n```\n\n### 4. 
Generate with Context\n\nFor best results, provide reference audio context:\n\n```python\nfrom generator import load_csm_1b, Segment, generate_streaming_audio\nimport torchaudio\n\ngenerator = load_csm_1b(\"cuda\")\n\n# Load reference audio\ndef load_audio(audio_path):\n    audio_tensor, sample_rate = torchaudio.load(audio_path)\n    audio_tensor = torchaudio.functional.resample(\n        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate\n    )\n    return audio_tensor\n\n# Create context segments\nsegments = [\n    Segment(\n        text=\"I knew I could trust you.\",\n        speaker=0,\n        audio=load_audio(\"reference.wav\")\n    )\n]\n\n# Generate with streaming using the context\ngenerate_streaming_audio(\n    generator=generator,\n    text=\"Me too, this is some cool stuff huh?\",\n    speaker=0,\n    context=segments,\n    output_file=\"contextual_streaming.wav\",\n    play_audio=True\n)\n```\n\n### 5. Regular Generation with Internal Streaming\n\nUse the original API with streaming enabled internally:\n\n```python\nfrom generator import load_csm_1b, Segment\nimport torchaudio\n\ngenerator = load_csm_1b(\"cuda\")\n\n# Regular generation but with internal streaming optimization\naudio = generator.generate(\n    text=\"This uses internal streaming for faster processing.\",\n    speaker=0,\n    context=[],\n    max_audio_length_ms=10_000,\n    stream=True  # Enable internal streaming optimization\n)\n\ntorchaudio.save(\"audio.wav\", audio.unsqueeze(0).cpu(), generator.sample_rate)\n```\n## Performance Optimizations\n\nOur optimized version includes several performance enhancements:\n\n- **Streaming Generation**: Processes and outputs audio in chunks instead of waiting for the entire generation achieving a Real-time factor (RTF): 0.28x (target: <1.0) on a 4090 (10 seconds of audio takes 2.8 seconds to generate)\n- **Frame Batching**: Processes multiple frames at once for better GPU utilization\n- **Half-precision Inference**: 
Uses bfloat16/float16 for faster processing\n- **CUDA Optimizations**: Enables cuDNN benchmarking and Flash Attention where available\n- **Memory Management**: Clears GPU cache before generation to reduce memory pressure\n\n## FAQ\n\n**How much faster is the streaming version?**\n\nThe perceived response time is significantly faster since you get the first audio chunks in milliseconds instead of waiting for the entire generation to complete. The actual total generation time is also improved by 40-60% depending on your hardware.\n\n**Does this model come with any voices?**\n\nThe model is a base generation model capable of producing a variety of voices but hasn't been fine-tuned on any specific voice. Provide reference audio for best results.\n\n**Can I converse with the model?**\n\nCSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. Using a separate LLM you can converse with the realtime demo via the web UI.\n\n**Does it support other languages?**\n\nThe model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.\n\n## Misuse and abuse ⚠️\n\nThis project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:\n\n- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.\n- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.\n- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.\n\nBy using this model, you agree to comply with all applicable laws and ethical guidelines. 
We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.\n\n---\n\n## Original Authors\nJohan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.\n\n## Streaming, Realtime Demo and Finetuning Implementation\nDavid Browne\n\n## Support me\nSupport this project on Ko-fi: https://ko-fi.com/davidbrowne17\n\n## Transformers streaming\nIf you want to use streaming with the Transformers implementation you can find it here: https://github.com/davidbrowne17/csm-streaming-tf\n"
  },
  {
    "path": "config.py",
    "content": "import os\nimport json\nimport logging\nfrom typing import Dict, Any, Optional\nfrom pydantic import BaseModel\n\nlogger = logging.getLogger(__name__)\n\nclass ConfigManager:\n    \"\"\"\n    Manages configuration persistence for the AI Companion app.\n    Saves and loads configuration to avoid re-entering model paths.\n    \"\"\"\n    def __init__(self, config_path: str = \"config/app_config.json\"):\n        \"\"\"\n        Initialize the configuration manager.\n        \n        Args:\n            config_path: Path to store the configuration file\n        \"\"\"\n        self.config_path = config_path\n        self.config_dir = os.path.dirname(config_path)\n        \n        # Create config directory if it doesn't exist\n        if not os.path.exists(self.config_dir):\n            os.makedirs(self.config_dir, exist_ok=True)\n            logger.info(f\"Created configuration directory: {self.config_dir}\")\n    \n    def save_config(self, config_data: Dict[str, Any]) -> bool:\n        \"\"\"\n        Save configuration data to the config file.\n        \n        Args:\n            config_data: Configuration data to save\n            \n        Returns:\n            bool: True if successful, False otherwise\n        \"\"\"\n        try:\n            # Ensure directory exists\n            os.makedirs(self.config_dir, exist_ok=True)\n            print(config_data)\n            # Verify all reference paths are included\n            ref_paths = [\n                \"reference_audio_path\", \n                \"reference_audio_path2\", \n                \"reference_audio_path3\"\n            ]\n            \n            # Log which references are being saved\n            for path_key in ref_paths:\n                if path_key in config_data and config_data[path_key]:\n                    logger.info(f\"Saving reference path: {path_key}={config_data[path_key]}\")\n                else:\n                    logger.info(f\"No {path_key} provided in 
configuration\")\n            \n            # Save configuration\n            with open(self.config_path, 'w') as f:\n                json.dump(config_data, f, indent=2)\n            \n            logger.info(f\"Configuration saved to {self.config_path}\")\n            return True\n        \n        except Exception as e:\n            logger.error(f\"Failed to save configuration: {e}\")\n            return False\n    \n    def load_config(self) -> Optional[Dict[str, Any]]:\n        \"\"\"\n        Load configuration data from the config file.\n        \n        Returns:\n            Dict or None: Configuration data if successful, None otherwise\n        \"\"\"\n        if not os.path.exists(self.config_path):\n            logger.info(f\"Configuration file does not exist at {self.config_path}\")\n            return None\n        \n        try:\n            with open(self.config_path, 'r') as f:\n                config_data = json.load(f)\n            \n            # Log which references are being loaded\n            ref_paths = [\n                \"reference_audio_path\", \n                \"reference_audio_path2\", \n                \"reference_audio_path3\"\n            ]\n            \n            for path_key in ref_paths:\n                if path_key in config_data and config_data[path_key]:\n                    logger.info(f\"Loaded reference path: {path_key}={config_data[path_key]}\")\n            \n            logger.info(f\"Configuration loaded from {self.config_path}\")\n            return config_data\n        \n        except Exception as e:\n            logger.error(f\"Failed to load configuration: {e}\")\n            return None\n\n# Helper function to convert Pydantic model to dict\ndef model_to_dict(model: BaseModel) -> Dict[str, Any]:\n    \"\"\"Convert a Pydantic model to a dictionary suitable for JSON serialization\"\"\"\n    return json.loads(model.json())"
  },
  {
    "path": "finetuned_model/config.json",
    "content": "{\n  \"audio_num_codebooks\": 32,\n  \"audio_vocab_size\": 2051,\n  \"backbone_flavor\": \"llama-1B\",\n  \"decoder_flavor\": \"llama-100M\",\n  \"text_vocab_size\": 128256\n}"
  },
  {
    "path": "generator.py",
    "content": "from dataclasses import dataclass\nimport math\nimport os\nfrom typing import List, Tuple, Generator as PyGenerator, Optional, Callable\nimport time\nimport queue\nimport threading\nimport platform\nfrom typing_extensions import OrderedDict\nimport wave\nimport numpy as np\nimport torch\nimport torchaudio\nfrom huggingface_hub import hf_hub_download\nfrom models import Model, ModelArgs\nfrom moshi.models import loaders\nfrom tokenizers.processors import TemplateProcessing\nfrom transformers import AutoTokenizer\nimport logging\n\nlogger = logging.getLogger(__name__)\n\n@dataclass\nclass Segment:\n    speaker: int\n    text: str\n    sample_rate = 24_000\n    audio: torch.Tensor\n\n\ndef load_llama3_tokenizer():\n    \"\"\"\n    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992\n    \"\"\"\n    tokenizer_name = \"unsloth/Llama-3.2-1B\"\n    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n    bos = tokenizer.bos_token\n    eos = tokenizer.eos_token\n    tokenizer._tokenizer.post_processor = TemplateProcessing(\n        single=f\"{bos}:0 $A:0 {eos}:0\",\n        pair=f\"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1\",\n        special_tokens=[(f\"{bos}\", tokenizer.bos_token_id), (f\"{eos}\", tokenizer.eos_token_id)],\n    )\n\n    return tokenizer\n\n\nclass Generator:\n    def __init__(self, model: Model):\n        self._model = model\n        self._model.setup_caches(1)\n\n        self._text_tokenizer = load_llama3_tokenizer()\n        device = next(model.parameters()).device\n\n        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)\n        mimi = loaders.get_mimi(mimi_weight, device=device)\n        \n        num_codebooks = model.config.audio_num_codebooks\n        mimi.set_num_codebooks(num_codebooks)\n        self._num_codebooks = num_codebooks\n        self._audio_tokenizer = mimi\n\n        self.sample_rate = mimi.sample_rate\n        self.device = device\n\n        
self._stream_buffer_size = 20\n        self.max_seq_len = 2048\n        self._cache = OrderedDict()\n        self._text_token_cache = {}\n        torch.set_num_threads(16)\n        torch.cuda.set_per_process_memory_fraction(0.95)\n\n    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Tokenize text segment with caching optimization for reduced latency.\n        \"\"\"\n        # Check cache first\n        cache_key = f\"{speaker}:{text}\"\n        if not hasattr(self, '_text_token_cache'):\n            self._text_token_cache = {}\n        \n        if cache_key in self._text_token_cache:\n            return self._text_token_cache[cache_key]\n\n        text_tokens = self._text_tokenizer.encode(f\"[{speaker}]{text}\")\n        text_frame = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.long, device=self.device)\n        text_frame_mask = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.bool, device=self.device)\n        text_frame[:, -1] = torch.tensor(text_tokens, device=self.device)\n        text_frame_mask[:, -1] = True\n\n        frame_tokens = [text_frame]\n        frame_masks = [text_frame_mask]\n        \n        result = (torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0))\n        \n        self._text_token_cache[cache_key] = result\n        \n        return result\n\n\n    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n\n        frame_tokens = []\n        frame_masks = []\n\n        # (K, T)\n        audio = audio.to(self.device)\n        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]\n        \n        # Limit to the number of codebooks set in MIMI\n        audio_tokens = audio_tokens[:self._num_codebooks, :]\n        \n        # add EOS frame\n        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)\n        audio_tokens = torch.cat([audio_tokens, 
eos_frame], dim=1)\n\n        audio_frame = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).long().to(self.device)\n        audio_frame_mask = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).bool().to(self.device)\n        audio_frame[:, :self._num_codebooks] = audio_tokens.transpose(0, 1)\n        audio_frame_mask[:, :self._num_codebooks] = True\n\n        frame_tokens.append(audio_frame)\n        frame_masks.append(audio_frame_mask)\n\n        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)\n\n    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:\n        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)\n        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)\n\n        total_len = text_tokens.size(0) + audio_tokens.size(0)\n\n        if total_len > self.max_seq_len:\n            overflow = total_len - self.max_seq_len\n\n            if text_tokens.size(0) > overflow:\n                text_tokens = text_tokens[overflow:]\n                text_masks = text_masks[overflow:]\n            else:\n                audio_overflow = overflow - text_tokens.size(0)\n                text_tokens = text_tokens[0:0]\n                text_masks = text_masks[0:0]\n                audio_tokens = audio_tokens[audio_overflow:]\n                audio_masks = audio_masks[audio_overflow:]\n\n        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)\n\n    @torch.inference_mode()\n    def _decode_frames(self, frames):\n        if not frames:\n            return torch.tensor([])\n        \n        # Only use first N codebooks for faster decoding\n        frames_reduced = [frame[:, :self._num_codebooks//2] for frame in frames]\n        audio = self._audio_tokenizer.decode(torch.stack(frames_reduced).permute(1, 2, 0)).squeeze(0).squeeze(0)\n        return audio\n\n    @torch.inference_mode()\n    def generate_stream(\n  
      self,\n        text: str,\n        speaker: int,\n        context: List[Segment],\n        max_audio_length_ms: float = 90_000,\n        temperature: float = 0.7,\n        topk: int = 30,\n        on_chunk_generated: Optional[Callable[[torch.Tensor], None]] = None,\n    ):\n        \"\"\"\n        Generate audio in a streaming fashion, optimized for lower latency to first chunk.\n        \"\"\"\n        if torch.cuda.is_available():\n            torch.backends.cuda.matmul.allow_tf32 = True\n            torch.backends.cudnn.benchmark = True\n            torch.cuda.empty_cache()\n            torch.cuda.synchronize()\n\n        self._model.reset_caches()\n\n        max_generation_len = int(max_audio_length_ms / 80)\n\n        tokens, tokens_mask = [], []\n\n        initial_batch_size = 20\n        normal_batch_size = 20  \n        initial_buffer_size = 20\n        normal_buffer_size = 20\n        \n        batch_size = initial_batch_size\n        buffer_size = initial_buffer_size\n        first_chunk_delivered = False\n\n        context_start = time.time()\n        if context:\n            for segment in context:\n                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)\n                tokens.append(segment_tokens)\n                tokens_mask.append(segment_tokens_mask)\n\n        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)\n        tokens.append(gen_segment_tokens)\n        tokens_mask.append(gen_segment_tokens_mask)\n\n        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)\n        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)\n\n        max_seq_len = 2048\n        if prompt_tokens.size(0) > max_seq_len:\n            prompt_tokens = prompt_tokens[-max_seq_len:]\n            prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:]\n\n        curr_tokens = prompt_tokens.unsqueeze(0)\n        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)\n   
     curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)\n\n        expected_frame_count = buffer_size \n        frame_buffer = []\n\n        zeros_1_1 = torch.zeros(1, 1).long().to(self.device)\n        zeros_mask_1_1 = torch.zeros(1, 1).bool().to(self.device)\n\n        def update_tokens(sample):\n            nonlocal curr_tokens, curr_tokens_mask, curr_pos\n            ones = torch.ones_like(sample).bool()\n            curr_tokens = torch.cat([sample, zeros_1_1], dim=1).unsqueeze(1)\n            curr_tokens_mask = torch.cat([ones, zeros_mask_1_1], dim=1).unsqueeze(1)\n            curr_pos = curr_pos[:, -1:] + 1\n\n        with self._audio_tokenizer.streaming(1):\n            i = 0\n            generation_start = time.time()\n\n            while i < max_generation_len:\n                batch_end = min(i + batch_size, max_generation_len)\n                batch_size_actual = batch_end - i\n\n                batch_samples = []\n\n                for _ in range(batch_size_actual):\n                    with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):\n                        sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)\n                        if torch.cuda.is_available() and hasattr(torch, \"cuda\") and hasattr(torch.cuda, \"is_available\"):\n                            try:\n                                torch.cuda.synchronize()  # Force sync before checking\n                                if sample.numel() == 0 or torch.isnan(sample).any():\n                                    print(\"Warning: Generated empty or NaN sample, stopping generation\")\n                                    break\n                            except:\n                                print(\"Error checking tensor, stopping generation\")\n                                break\n                    if torch.all(sample == 0):\n                        break\n\n                    
batch_samples.append(sample)\n                    update_tokens(sample)\n\n                if not batch_samples:\n                    break\n\n                frame_buffer.extend(batch_samples)\n                i += len(batch_samples)\n\n                if len(frame_buffer) >= buffer_size:\n                    frames_to_process = frame_buffer[:expected_frame_count]\n                    \n                    # If we don't have enough frames, pad with zeros to match expected shape\n                    if len(frames_to_process) < expected_frame_count:\n                        # Create padding frames (zeros)\n                        padding_frames = [\n                            torch.zeros_like(frames_to_process[0]) \n                            for _ in range(expected_frame_count - len(frames_to_process))\n                        ]\n                        \n                        # Combine actual frames with padding\n                        frames_to_process = frames_to_process + padding_frames\n                    \n                    frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0)\n                    audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0)\n                    \n                    # Keep remaining frames for next iteration\n                    frame_buffer = frame_buffer[expected_frame_count:]\n                    \n                    # Process and yield the chunk\n                    cpu_chunk = audio_chunk.cpu()\n                    if on_chunk_generated:\n                        on_chunk_generated(cpu_chunk)\n                    \n                    # After first chunk is delivered, switch to normal batch and buffer sizes\n                    if not first_chunk_delivered:\n                        batch_size = normal_batch_size\n                        buffer_size = normal_buffer_size\n                        expected_frame_count = buffer_size\n                        first_chunk_delivered = True\n 
                   \n                    yield cpu_chunk\n\n                    # Occasionally print progress and sync GPU\n                    if i >= 100 and (i % 100 == 0):\n                        if torch.cuda.is_available():\n                            torch.cuda.synchronize()\n                        print(f\"Generated {i} frames ({i * 0.08:.2f}s of audio)\")\n\n            # Process any remaining frames\n            if frame_buffer:\n                # Pad frame buffer if necessary\n                if len(frame_buffer) < expected_frame_count:\n                    padding_frames = [\n                        torch.zeros_like(frame_buffer[0]) \n                        for _ in range(expected_frame_count - len(frame_buffer))\n                    ]\n                    frames_to_process = frame_buffer + padding_frames\n                else:\n                    # Otherwise take as many frames as possible that are a multiple of expected_frame_count\n                    frames_multiple = (len(frame_buffer) // expected_frame_count) * expected_frame_count\n                    frames_to_process = frame_buffer[:frames_multiple]\n                    \n                frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0)\n                audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0)\n                \n                # Determine actual audio length (before padding)\n                actual_frames_percentage = min(len(frame_buffer), expected_frame_count) / expected_frame_count\n                actual_samples = int(audio_chunk.shape[0] * actual_frames_percentage)\n                \n                # Return only the non-padded portion of audio if we added padding\n                if len(frame_buffer) < expected_frame_count:\n                    audio_chunk = audio_chunk[:actual_samples]\n                    \n                cpu_chunk = audio_chunk.cpu()\n                if on_chunk_generated:\n                    
on_chunk_generated(cpu_chunk)\n                yield cpu_chunk\n\n            # Print final performance metrics\n            if torch.cuda.is_available():\n                torch.cuda.synchronize()\n            total_time = time.time() - generation_start\n            frames_generated = i\n            audio_seconds = frames_generated * 0.08\n            rtf = total_time / audio_seconds if audio_seconds > 0 else float('inf')\n            print(f\"Total time: {total_time:.2f}s\")\n            print(f\"Generated {frames_generated} frames ({audio_seconds:.2f}s of audio)\")\n            print(f\"Real-time factor: {rtf:.3f}x (target: <1.0)\")\n\n    @torch.inference_mode()\n    def generate(\n        self,\n        text: str,\n        speaker: int,\n        context: List[Segment],\n        max_audio_length_ms: float = 90_000,\n        temperature: float = 0.8,\n        topk: int = 40,\n        stream: bool = False,\n        output_file: Optional[str] = None,\n    ):\n        \"\"\"\n        Generate audio with optional streaming and file output.\n        \n        Args:\n            text: Text to generate audio for\n            speaker: Speaker ID\n            context: List of context segments\n            max_audio_length_ms: Maximum audio length in milliseconds\n            temperature: Sampling temperature\n            topk: Top-k sampling parameter\n            stream: Whether to use streaming generation\n            output_file: If provided and stream=True, output will be saved to this file\n        \n        Returns:\n            torch.Tensor: Generated audio tensor\n        \"\"\"\n        if stream:\n            if output_file:\n                # Setup streaming to file\n                write_chunk, close_wav = stream_audio_to_wav(output_file, self.sample_rate)\n                \n                # Collect chunks while streaming to file\n                audio_chunks = []\n                t1 = time.time()\n                \n                for i, chunk in 
enumerate(self.generate_stream(\n                    text, speaker, context, max_audio_length_ms, temperature, topk\n                )):\n                    # Write to file\n                    write_chunk(chunk)\n                    # Store for return value\n                    audio_chunks.append(chunk)\n                    \n                    # Occasionally print progress\n                    if i % 5 == 0:\n                        print(f\"Part {i+1} available after {time.time() - t1:.4f}s\")\n                        t1 = time.time()\n                \n                # Close file\n                close_wav()\n                print(f\"Streaming complete, WAV file saved to {output_file}\")\n            else:\n                # Just collect chunks without file output\n                audio_chunks = []\n                for chunk in self.generate_stream(text, speaker, context, max_audio_length_ms, temperature, topk):\n                    audio_chunks.append(chunk)\n            \n            if not audio_chunks:\n                return torch.tensor([])\n            return torch.cat(audio_chunks)\n\n        # Non-streaming generation remains unchanged\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n        self._model.reset_caches()\n\n        max_generation_len = int(max_audio_length_ms / 80)\n        tokens, tokens_mask = [], []\n\n        for segment in context:\n            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)\n            tokens.append(segment_tokens)\n            tokens_mask.append(segment_tokens_mask)\n\n        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)\n        tokens.append(gen_segment_tokens)\n        tokens_mask.append(gen_segment_tokens_mask)\n\n        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)\n        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)\n\n        max_seq_len = 2048\n        if 
prompt_tokens.size(0) > max_seq_len:\n            prompt_tokens = prompt_tokens[-max_seq_len:]\n            prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:]\n\n        curr_tokens = prompt_tokens.unsqueeze(0)\n        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)\n        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)\n\n        samples = []\n        with self._audio_tokenizer.streaming(1):\n            for _ in range(max_generation_len):\n                sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)\n                if torch.all(sample == 0):\n                    break\n                samples.append(sample)\n\n                curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)\n                curr_tokens_mask = torch.cat(\n                    [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1\n                ).unsqueeze(1)\n                curr_pos = curr_pos[:, -1:] + 1\n\n        if not samples:\n            return torch.tensor([])\n\n        return self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)\n\nclass AudioStreamWriter:\n    \"\"\"\n    Helper class for writing streaming audio to a file.\n    \"\"\"\n    def __init__(self, filename, sample_rate):\n        self.filename = filename\n        self.sample_rate = sample_rate\n        self.audio_chunks = []\n        self.lock = threading.Lock()\n        self.queue = queue.Queue()\n        self.running = True\n        \n        # Start background writer thread\n        self.writer_thread = threading.Thread(target=self._writer_worker, daemon=True)\n        self.writer_thread.start()\n        \n    def _writer_worker(self):\n        \"\"\"Background thread that handles audio chunk processing\"\"\"\n        buffer_chunks = []\n        last_flush_time = time.time()\n        \n        while self.running 
or not self.queue.empty():\n            try:\n                # Get chunk with timeout to allow for regular checks\n                chunk = self.queue.get(timeout=0.2)\n                buffer_chunks.append(chunk)\n                \n                # Periodically flush the buffer to the main list\n                current_time = time.time()\n                if len(buffer_chunks) >= 10 or (current_time - last_flush_time > 2.0 and buffer_chunks):\n                    with self.lock:\n                        self.audio_chunks.extend(buffer_chunks)\n                    buffer_chunks = []\n                    last_flush_time = current_time\n                    \n            except queue.Empty:\n                # If queue is empty but we have pending chunks, add them\n                if buffer_chunks:\n                    with self.lock:\n                        self.audio_chunks.extend(buffer_chunks)\n                    buffer_chunks = []\n                    last_flush_time = time.time()\n        \n        # Final flush of any remaining chunks\n        if buffer_chunks:\n            with self.lock:\n                self.audio_chunks.extend(buffer_chunks)\n        \n    def add_chunk(self, chunk):\n        \"\"\"Add an audio chunk to the buffer queue without blocking\"\"\"\n        try:\n            self.queue.put(chunk, timeout=0.1)\n        except queue.Full:\n            # If queue is full, add directly to avoid losing data\n            with self.lock:\n                self.audio_chunks.append(chunk)\n    \n    def write_file(self):\n        \"\"\"Write all collected audio chunks to file and clean up\"\"\"\n        # Signal the background thread to stop\n        self.running = False\n        # Wait for the thread to finish with a timeout\n        self.writer_thread.join(timeout=3.0)\n        \n        with self.lock:\n            if not self.audio_chunks:\n                return\n                \n            # Concatenate all chunks\n            audio = 
torch.cat(self.audio_chunks)\n            # Save to file\n            torchaudio.save(self.filename, audio.unsqueeze(0).cpu(), self.sample_rate)\n\nfrom safetensors.torch import load_file\nimport os\nimport torch\nfrom models import Model, ModelArgs\nfrom generator import Generator\n\ndef load_csm_1b_local(model_path: str, device: str = \"cuda\", audio_num_codebooks: int = 32):\n    \"\"\"\n    Load the CSM-1B model from a local checkpoint with extreme optimizations and warmup.\n    \"\"\"\n    import torch\n    import platform\n    from functools import lru_cache\n    from generator import Generator, Model, ModelArgs\n\n    # Enable all CUDA optimizations\n    torch.backends.cuda.matmul.allow_tf32 = True\n    if hasattr(torch.backends.cuda, 'enable_flash_sdp'):\n        torch.backends.cuda.enable_flash_sdp(True)\n    torch.backends.cudnn.benchmark = True\n    torch.backends.cudnn.enabled = True\n\n    print(f\"Loading CSM-1B model from local checkpoint '{model_path}' with extreme optimizations...\")\n\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n        torch.cuda.synchronize()\n\n    config = ModelArgs(\n        backbone_flavor=\"llama-1B\",\n        decoder_flavor=\"llama-100M\",\n        text_vocab_size=128256,\n        audio_vocab_size=2051,\n        audio_num_codebooks=audio_num_codebooks,\n    )\n\n    model = Model.from_pretrained(model_path)\n    model.eval()\n\n    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16\n\n    model.backbone = torch.compile(model.backbone,mode='reduce-overhead', fullgraph=True, backend='inductor')\n    model.decoder = torch.compile(model.decoder,mode='reduce-overhead', fullgraph=True, backend='inductor')\n\n    model.to(device=device, dtype=dtype)\n\n    print(\"Model compilation complete. 
Creating generator...\")\n\n    generator = Generator(model)\n    generator._stream_buffer_size = 20\n\n    # Setup tokenization caching\n    generator._tokenization_cache = {}\n\n    original_tokenize_text = generator._tokenize_text_segment\n\n    @lru_cache(maxsize=2048)\n    def cached_tokenize_text_segment(text_str, speaker_int):\n        return original_tokenize_text(text_str, speaker_int)\n\n    generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker)\n\n    # Perform warmup\n    warmup_generator(generator)\n\n    return generator\n\ndef warmup_generator(gen: Generator, warmup_text: str = \"Hello, this is a comprehensive warmup text that will exercise the model's generation capabilities.\", speaker_id: int = 0):\n    \"\"\"\n    Perform an extremely aggressive warmup to drastically reduce first-generation latency.\n    \"\"\"\n    print(\"Starting maximum-intensity warmup sequence...\")\n    \n    # Directly access and optimize the model's internal state\n    if hasattr(gen._model, 'backbone') and hasattr(gen._model.backbone, 'positional_embedding'):\n        # Force calculation of position embeddings to ensure they're cached\n        with torch.inference_mode():\n            positions = torch.arange(0, 2048).to(gen.device)\n            _ = gen._model.backbone.positional_embedding(positions)\n    \n    # Pre-allocate CUDA memory to prevent fragmentation during generation\n    if torch.cuda.is_available():\n        print(\"Optimizing GPU memory allocation...\")\n        # Try to reserve a large chunk of memory\n        try:\n            import math\n            reserved_memory = []\n            # Reserve multiple blocks of different sizes\n            for size_mb in [128, 256, 512, 256, 128, 64]:\n                size = int(size_mb * 1024 * 1024 / 4)  # Convert MB to float32 elements\n                tensor_size = int(math.sqrt(size))\n                tensor = torch.ones((tensor_size, tensor_size), device=gen.device, 
dtype=torch.float32)\n                tensor = tensor * 1.0  # Force allocation\n                reserved_memory.append(tensor)\n            torch.cuda.synchronize()\n            \n            # Now free the memory\n            for tensor in reserved_memory:\n                del tensor\n            reserved_memory = []\n            torch.cuda.empty_cache()\n            torch.cuda.synchronize()\n        except Exception as e:\n            print(f\"Memory pre-allocation: {e}\")\n    \n    # Create multiple dummy audio segments with varying characteristics\n    print(\"Creating diverse audio contexts...\")\n    audio_segments = []\n    \n    # Create 3 different audio patterns\n    for i in range(3):\n        length = 24000 * (i + 1)  # 1s, 2s, 3s\n        audio = torch.zeros(length).to(gen.device)\n        \n        # Add different patterns to each segment\n        if i == 0:\n            # Sine wave pattern\n            import math\n            t = torch.linspace(0, 8 * math.pi, length).to(gen.device)\n            audio = torch.sin(t) * 0.1\n        elif i == 1:\n            # Random noise pattern\n            audio = torch.randn(length).to(gen.device) * 0.05\n        else:\n            # Pulse pattern\n            audio[::800] = 0.2\n            audio[::801] = -0.2\n        \n        segment = Segment(\n            speaker=speaker_id,\n            text=f\"Warmup segment {i+1} with {length/24000:.1f}s of audio.\",\n            audio=audio\n        )\n        audio_segments.append(segment)\n    \n    # Force compilation of critical model components\n    print(\"Forcing compilation of critical components...\")\n    \n    # Directly exercise the audio tokenizer with real data\n    with torch.inference_mode():\n        for segment in audio_segments:\n            # Force tokenization of both text and audio\n            gen._tokenize_segment(segment)\n    \n    # Exercise the model's generation capabilities directly\n    with torch.inference_mode():\n        \n        # 
Generate some sample frames to ensure model is compiled\n        dummy_tokens = torch.ones(1, 10, gen._num_codebooks+1).long().to(gen.device)\n        dummy_mask = torch.ones(1, 10, gen._num_codebooks+1).bool().to(gen.device)\n        dummy_pos = torch.arange(0, 10).unsqueeze(0).to(gen.device)\n        \n        # Generate multiple frames with different parameters\n        for temp in [0.6, 0.7, 0.8]:\n            for topk in [20, 30, 40]:\n                _ = gen._model.generate_frame(dummy_tokens, dummy_mask, dummy_pos, temp, topk)\n    \n    gen._text_token_cache.clear()\n    \n    print(\"Running final generation with exact same setup as a real request...\")\n    \n    final_text = \"This is the final warmup that exactly matches a real generation request.\"\n    \n    # First tokenize the text - to fill the cache\n    gen._tokenize_text_segment(final_text, speaker_id)\n    \n    try:\n        # Now run a complete generation with a single context segment\n        generate_streaming_audio(\n            generator=gen,\n            text=final_text, \n            speaker=speaker_id,\n            context=[audio_segments[0]],  # Just one context segment\n            output_file=\"warmup_final.wav\",\n            max_audio_length_ms=6000,\n            temperature=0.7,\n            topk=30,\n            play_audio=False\n        )\n    except Exception as e:\n        print(f\"Final warmup run exception (ignorable): {e}\")\n    \n    # Force final synchronization and memory optimization\n    if torch.cuda.is_available():\n        print(\"Final GPU optimization...\")\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        \n        try:\n            # Allocate a large tensor to force compaction\n            large_tensor = torch.empty(int(1e9//4), dtype=torch.float, device=gen.device)\n            # Immediately delete it\n            del large_tensor\n        except RuntimeError:\n            # Expected if there's not enough memory\n            pass\n  
          \n        # Final cleanup\n        torch.cuda.empty_cache()\n        torch.cuda.synchronize()\n    \n    print(\"Maximum-intensity warmup complete. First generation should now be MUCH faster.\")\n\ndef load_csm_1b(device: str = \"cuda\") -> Generator:\n    \"\"\"\n    Load the CSM-1B model with extreme optimizations for real-time performance.\n    \"\"\"\n    # Enable all CUDA optimizations\n    torch.backends.cuda.matmul.allow_tf32 = True\n    torch.backends.cuda.enable_flash_sdp(True)\n    torch.backends.cudnn.benchmark = True\n    torch.backends.cudnn.enabled = True\n    \n    print(\"Loading CSM-1B model with extreme optimizations for real-time performance...\")\n    \n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n        torch.cuda.synchronize()\n    \n    model = Model.from_pretrained(\"sesame/csm-1b\")\n    \n    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16\n    model.backbone = torch.compile(model.backbone,mode='reduce-overhead', fullgraph=True, backend='inductor')\n    model.decoder = torch.compile(model.decoder,mode='reduce-overhead', fullgraph=True, backend='inductor')\n\n    model.to(device=device, dtype=dtype)\n    \n    print(\"Model compilation complete. 
Creating generator...\")\n    \n    generator = Generator(model)\n    \n    generator._stream_buffer_size = 20\n    \n    \n    generator._tokenization_cache = {}\n    \n    from functools import lru_cache\n\n    # Patch the tokenize method with caching\n    original_tokenize_text = generator._tokenize_text_segment\n\n    @lru_cache(maxsize=2048)\n    def cached_tokenize_text_segment(text_str, speaker_int):\n        return original_tokenize_text(text_str, speaker_int)\n\n    generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker)\n    \n    warmup_generator(generator)\n\n    return generator\n\ndef stream_audio_to_wav(filename, sample_rate):\n    \"\"\"\n    Initialize a WAV writer for streaming audio chunks.\n    \n    Args:\n        filename: Output WAV file path\n        sample_rate: Audio sample rate in Hz\n    \n    Returns:\n        tuple: (write_chunk, close) functions for writing audio data and closing the file\n    \"\"\"\n    # Create a WAV file with the proper header\n    wav_file = wave.open(filename, 'wb')\n    wav_file.setnchannels(1)  # Mono\n    wav_file.setsampwidth(2)  # 16-bit\n    wav_file.setframerate(sample_rate)\n    \n    def write_chunk(audio_chunk):\n        # Convert tensor to numpy and then to int16 PCM format\n        if isinstance(audio_chunk, torch.Tensor):\n            # Ensure it's on CPU and detached before converting to numpy\n            audio_np = audio_chunk.detach().cpu().numpy()\n        else:\n            audio_np = audio_chunk\n            \n        # Normalize if needed (assuming audio is in [-1, 1] range)\n        if audio_np.max() <= 1.0 and audio_np.min() >= -1.0:\n            audio_int = (audio_np * 32767).astype(np.int16)\n        else:\n            audio_int = audio_np.astype(np.int16)\n            \n        # Write to WAV file\n        wav_file.writeframes(audio_int.tobytes())\n    \n    def close():\n        wav_file.close()\n        \n    return write_chunk, 
close\n\ndef generate_streaming_audio(\n    generator: Generator,\n    text: str,\n    speaker: int,\n    context: List[Segment],\n    output_file: str,\n    max_audio_length_ms: float = 90_000,\n    temperature: float = 1.0,\n    topk: int = 50,\n    play_audio: bool = False,\n):\n    \"\"\"\n    Generate audio with streaming output and comprehensive timing metrics.\n    Optimized for reduced first-chunk latency.\n    \"\"\"\n    # Initialize the streaming WAV writer\n    write_chunk, close_wav = stream_audio_to_wav(output_file, generator.sample_rate)\n    \n    # Set up audio playback if requested\n    audio_queue = queue.Queue(maxsize=100) if play_audio else None\n    stop_event = threading.Event()\n    \n    if play_audio:\n        try:\n            import sounddevice as sd\n            \n            # Get available sample rates for default output device to check compatibility\n            device_info = sd.query_devices(kind='output')\n            supported_rate = device_info.get('default_samplerate', 44100)\n            need_resampling = abs(supported_rate - generator.sample_rate) > 100\n            \n            if need_resampling:\n                try:\n                    # Use resampling if sample rate doesn't match\n                    import librosa\n                    print(f\"Resampling from {generator.sample_rate}Hz to {int(supported_rate)}Hz for playback\")\n                    \n                    def audio_playback_worker():\n                        while not stop_event.is_set() or not audio_queue.empty():\n                            try:\n                                chunk = audio_queue.get(timeout=0.5)\n                                if isinstance(chunk, torch.Tensor) and chunk.numel() == 0:\n                                    audio_queue.task_done()\n                                    continue\n                                    \n                                audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk\n   
                             \n                                # Skip very short chunks (likely noise)\n                                if len(audio_np) < 100:\n                                    audio_queue.task_done()\n                                    continue\n                                    \n                                # Resample to device's supported rate\n                                resampled = librosa.resample(\n                                    audio_np, \n                                    orig_sr=generator.sample_rate, \n                                    target_sr=int(supported_rate)\n                                )\n                                sd.play(resampled, supported_rate, blocking=True)\n                                # Add a small delay to ensure audio finishes playing\n                                time.sleep(0.05)\n                                audio_queue.task_done()\n                            except queue.Empty:\n                                # If queue empty but not stopping, keep trying\n                                if not stop_event.is_set():\n                                    continue\n                                else:\n                                    break\n                            except Exception as e:\n                                print(f\"Playback error: {e}\")\n                                audio_queue.task_done()\n                except ImportError:\n                    print(\"Librosa not found. 
Using direct playback which may cause sample rate warnings.\")\n                    need_resampling = False\n            \n            if not need_resampling:\n                def audio_playback_worker():\n                    while not stop_event.is_set() or not audio_queue.empty():\n                        try:\n                            chunk = audio_queue.get(timeout=0.5)\n                            if isinstance(chunk, torch.Tensor) and chunk.numel() == 0:\n                                audio_queue.task_done()\n                                continue\n                                \n                            audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk\n                            \n                            # Skip very short chunks (likely noise)\n                            if len(audio_np) < 100:\n                                audio_queue.task_done()\n                                continue\n                                \n                            sd.play(audio_np, generator.sample_rate, blocking=True)\n                            # Add a small delay to ensure audio finishes playing\n                            time.sleep(0.05)\n                            audio_queue.task_done()\n                        except queue.Empty:\n                            # If queue empty but not stopping, keep trying\n                            if not stop_event.is_set():\n                                continue\n                            else:\n                                break\n                        except Exception as e:\n                            print(f\"Playback error: {e}\")\n                            audio_queue.task_done()\n            \n            # Start playback thread\n            playback_thread = threading.Thread(target=audio_playback_worker, daemon=False)\n            playback_thread.start()\n            \n        except ImportError:\n            print(\"sounddevice library not found. 
Install with 'pip install sounddevice' for real-time playback.\")\n            play_audio = False\n    \n    # Timing metrics\n    chunk_times = []\n    latency_to_first_chunk = None\n    total_audio_duration = 0\n    chunk_count = 0\n    \n    # Function to handle each generated chunk\n    def on_chunk_generated(chunk):\n        nonlocal chunk_count, latency_to_first_chunk, total_audio_duration\n        \n        current_time = time.time()\n        if chunk_count == 0:\n            latency_to_first_chunk = current_time - start_time\n            print(f\"First chunk latency: {latency_to_first_chunk*1000:.1f}ms\")\n            \n        # Save chunk to WAV file\n        write_chunk(chunk)\n        \n        # Update metrics\n        chunk_count += 1\n        chunk_duration = len(chunk) / generator.sample_rate\n        total_audio_duration += chunk_duration\n        chunk_times.append(current_time)\n        \n        # Send to audio player if enabled\n        if play_audio and audio_queue is not None:\n            try:\n                audio_queue.put(chunk, timeout=1.0)\n            except queue.Full:\n                pass  # Skip if queue is full to avoid blocking\n    \n    if torch.cuda.is_available():\n        print(\"Preparing GPU for low-latency generation...\")\n        torch.cuda.empty_cache()\n        torch.cuda.synchronize()\n        \n        # Pre-allocate some GPU memory to avoid allocation during generation\n        dummy_tensors = []\n        for i in range(5):\n            dummy = torch.ones((100, 100), device=generator.device)\n            dummy = dummy + 1.0  # Force computation\n            dummy_tensors.append(dummy)  # Keep reference to prevent deallocation\n            \n        torch.cuda.synchronize()\n    \n    # Set process priority to improve performance - use higher priority\n    try:\n        import psutil\n        process = psutil.Process()\n        if platform.system() == 'Windows':\n            
process.nice(psutil.HIGH_PRIORITY_CLASS)\n        else:\n            process.nice(-1)\n    except (ImportError, PermissionError, psutil.AccessDenied):\n        pass\n    \n    print(f\"Starting audio generation for: '{text[:50]}{'...' if len(text) > 50 else ''}'\")\n    start_time = time.time()\n    \n    # Generate audio in chunks, catching possible errors\n    frame_count = 0\n    audio_chunks = []  # Store all chunks for possible use at the end\n    \n    try:\n        for audio_chunk in generator.generate_stream(\n            text=text,\n            speaker=speaker,\n            context=context,\n            max_audio_length_ms=max_audio_length_ms,\n            temperature=temperature,\n            topk=topk,\n            on_chunk_generated=on_chunk_generated\n        ):\n            frame_count += 1\n            audio_chunks.append(audio_chunk)  # Store the chunk\n            \n            # Print timing info less frequently to reduce overhead\n            if frame_count % 10 == 0:\n                current_time = time.time()\n                elapsed = current_time - start_time\n                if total_audio_duration > 0:\n                    rtf = elapsed / total_audio_duration\n                    remaining_time = (max_audio_length_ms/1000 - total_audio_duration) * rtf\n                    print(f\"Chunk {chunk_count}: {total_audio_duration:.1f}s audio in {elapsed:.1f}s \"\n                          f\"(RTF: {rtf:.2f}x, Est. 
remaining: {remaining_time:.1f}s)\")\n    except Exception as e:\n        print(f\"Error during audio generation: {e}\")\n        import traceback\n        traceback.print_exc()\n    \n    # Release dummy tensors to free memory\n    if 'dummy_tensors' in locals():\n        del dummy_tensors\n    \n    # Ensure all chunks are properly processed\n    if play_audio and audio_queue is not None:\n        print(\"Waiting for playback queue to finish...\")\n        try:\n            timeout_start = time.time()\n            while not audio_queue.empty() and time.time() - timeout_start < 5.0:\n                time.sleep(0.1)\n        except:\n            pass\n    \n    # Add a small delay to ensure everything is processed\n    time.sleep(0.5)\n    \n    # Signal audio worker that generation is complete\n    stop_event.set()\n    \n    # Close WAV file\n    close_wav()\n    \n    # Wait for audio playback to complete if enabled\n    if play_audio and 'playback_thread' in locals():\n        print(\"Waiting for audio playback to complete...\")\n        \n        # First, ensure the queue is empty\n        try:\n            timeout_start = time.time()\n            while not audio_queue.empty() and time.time() - timeout_start < 5.0:\n                time.sleep(0.1)\n        except:\n            pass\n            \n        # Set a flag to indicate complete audio playback is needed\n        if hasattr(sd, 'wait'):\n            try:\n                sd.wait()\n            except:\n                pass\n                \n        # Join the playback thread with timeout\n        playback_thread.join(timeout=5.0)\n        \n        # Force sounddevice to stop if it's still playing\n        try:\n            sd.stop()\n        except:\n            pass\n    \n    # Calculate and print detailed performance metrics\n    end_time = time.time()\n    total_elapsed = end_time - start_time\n    \n    # Calculate inter-chunk latency\n    if len(chunk_times) > 1:\n        inter_chunk_latencies 
= [chunk_times[i] - chunk_times[i-1] for i in range(1, len(chunk_times))]\n        avg_inter_chunk_latency = sum(inter_chunk_latencies) / len(inter_chunk_latencies)\n        max_inter_chunk_latency = max(inter_chunk_latencies) if inter_chunk_latencies else 0\n        min_inter_chunk_latency = min(inter_chunk_latencies) if inter_chunk_latencies else 0\n    else:\n        avg_inter_chunk_latency = max_inter_chunk_latency = min_inter_chunk_latency = 0\n    \n    rtf = total_elapsed / total_audio_duration if total_audio_duration > 0 else float('inf')\n    \n    print(\"\\n\" + \"=\"*50)\n    print(\"AUDIO GENERATION PERFORMANCE METRICS\")\n    print(\"=\"*50)\n    print(f\"First chunk latency: {latency_to_first_chunk*1000:.1f}ms\")\n    print(f\"Total generation time: {total_elapsed:.2f}s\")\n    print(f\"Audio duration: {total_audio_duration:.2f}s\")\n    print(f\"Real-time factor (RTF): {rtf:.3f}x (target: <1.0)\")\n    print(f\"Number of chunks: {chunk_count}\")\n    print(f\"Average chunk size: {(total_audio_duration/chunk_count)*1000:.1f}ms\") if chunk_count > 0 else None\n    print(f\"Average inter-chunk latency: {avg_inter_chunk_latency*1000:.1f}ms\")\n    print(f\"Min/Max inter-chunk latency: {min_inter_chunk_latency*1000:.1f}ms / {max_inter_chunk_latency*1000:.1f}ms\")\n    print(f\"Chunks per second: {chunk_count/total_elapsed:.2f}\")\n    print(f\"Output file: {output_file}\")\n    print(\"=\"*50)"
  },
  {
    "path": "llm_interface.py",
    "content": "import re\nfrom typing import List, Dict, Any, Optional\nimport torch\nfrom vllm import LLM, SamplingParams\n\nclass LLMInterface:\n    def __init__(self, model_path: str, max_tokens: int = 8192, n_threads: int = 8, gpu_layers: int = -1):\n        \"\"\"Initialize the LLM interface using VLLM with a given model.\n        \n        Args:\n            model_path (str): Path to the model or HuggingFace model name\n            max_tokens (int, optional): Maximum context length. Defaults to 8192.\n            n_threads (int, optional): Number of CPU threads. Defaults to 8.\n            gpu_layers (int, optional): Not used in VLLM, maintained for API compatibility.\n        \"\"\"\n        # VLLM configuration\n        self.llm = LLM(\n            model=model_path,\n            tensor_parallel_size=1,  # Adjust based on number of GPUs available\n            gpu_memory_utilization=0.6,\n            max_model_len=max_tokens,\n            swap_space=0,\n            trust_remote_code=True,\n            dtype=torch.float16,\n        )\n        \n        # Store configuration for reference\n        self.config = {\n            \"model_path\": model_path,\n            \"max_tokens\": max_tokens,\n        }\n        \n    def trim_to_last_sentence(self, text: str) -> str:        \n        \"\"\"\n        Return *text* truncated at the final full sentence boundary.\n        A boundary is considered to be any '.', '!' or '?' followed by\n        optional quotes/brackets, optional whitespace, and then end-of-string.\n\n        If no sentence terminator exists, the original text is returned.\n        \"\"\"\n        # Regex explanation:\n        #   (.*?[.!?][\"')\\]]?)   
any text lazily until a terminator\n        #   \\s*$                 followed only by whitespace till end-of-string\n        m = re.match(r\"^(.*?[.!?][\\\"')\\]]?)\\s*$\", text, re.DOTALL)\n        if m:\n            return m.group(1).strip()\n        # Fall back to manual search (handles cases with additional text)\n        for i in range(len(text) - 1, -1, -1):\n            if text[i] in \".!?\":\n                return text[: i + 1].strip()\n        return text.strip()\n    \n    def generate_response(self, system_prompt: str, user_message: str, conversation_history: str = \"\") -> str:\n        \"\"\"Generate a response from the LLM using chat-style prompt formatting.\n        \n        Args:\n            system_prompt (str): The system prompt/instructions\n            user_message (str): The user's input message\n            conversation_history (str, optional): Any prior conversation context. Defaults to \"\".\n            \n        Returns:\n            str: The generated response\n        \"\"\"\n        # Format prompt following chat template structure\n        prompt = f\"\"\"<|start_header_id|>system<|end_header_id|>\\n{system_prompt}<|eot_id|>\n        {conversation_history}\n        <|start_header_id|>user<|end_header_id|>\\n{user_message}<|eot_id|>\n        <|start_header_id|>assistant<|end_header_id|>\\n\"\"\"\n        \n        # Define sampling parameters (equivalent to the previous implementation)\n        sampling_params = SamplingParams(\n            temperature=1.0,\n            top_p=0.95,\n            max_tokens=100,\n            repetition_penalty=1.2,\n            top_k=200,\n            stop=[\"</s>\", \"<|endoftext|>\", \"<<USR>>\", \"<</USR>>\", \"<</SYS>>\", \n                  \"<</USER>>\", \"<</ASSISTANT>>\", \"<|end_header_id|>\", \"<<ASSISTANT>>\", \n                  \"<|eot_id|>\", \"<|im_end|>\", \"user:\", \"User:\", \"user :\", \"User :\"]\n        )\n        \n        # Generate response using VLLM\n        outputs = 
self.llm.generate(prompt, sampling_params)\n        \n        # Extract and return the generated text\n        if outputs and len(outputs) > 0:\n            text = outputs[0].outputs[0].text\n            return self.trim_to_last_sentence(text)\n        return \"\"\n    \n    def tokenize(self, text: str) -> List[int]:\n        \"\"\"Tokenize text using VLLM's tokenizer.\n        \n        Args:\n            text (str): Text to tokenize\n            \n        Returns:\n            List[int]: List of token IDs\n        \"\"\"\n        # VLLM doesn't expose tokenizer directly in the same way\n        # We can access the tokenizer through the LLM instance\n        tokenizer = self.llm.get_tokenizer()\n        return tokenizer.encode(text)\n    \n    def get_token_count(self, text: str) -> int:\n        \"\"\"Return token count of the input text.\n        \n        Args:\n            text (str): Text to count tokens for\n            \n        Returns:\n            int: Number of tokens\n        \"\"\"\n        return len(self.tokenize(text))\n    \n    def batch_generate(self, prompts: List[Dict[str, str]], \n                       max_tokens: int = 512, \n                       temperature: float = 0.7) -> List[str]:\n        \"\"\"Generate responses for multiple prompts in a batch.        
\n        Args:\n            prompts (List[Dict[str, str]]): List of prompt dictionaries, each with \n                                           'system', 'user' and optional 'history' keys\n            max_tokens (int, optional): Maximum tokens to generate per response\n            temperature (float, optional): Temperature for sampling\n            \n        Returns:\n            List[str]: Generated responses\n        \"\"\"\n        formatted_prompts = []\n        \n        # Format each prompt according to the chat template\n        for p in prompts:\n            system = p.get(\"system\", \"\")\n            user = p.get(\"user\", \"\")\n            history = p.get(\"history\", \"\")\n            \n            formatted_prompt = f\"\"\"<|start_header_id|>system<|end_header_id|>\\n{system}<|eot_id|>\n            {history}\n            <|start_header_id|>user<|end_header_id|>\\n{user}<|eot_id|>\n            <|start_header_id|>assistant<|end_header_id|>\\n\"\"\"\n            \n            formatted_prompts.append(formatted_prompt)\n        \n        # Set up batch sampling parameters\n        sampling_params = SamplingParams(\n            temperature=temperature,\n            top_p=0.95,\n            max_tokens=max_tokens,\n            repetition_penalty=1.2,\n            top_k=400,\n            stop=[\"</s>\", \"<|endoftext|>\", \"<<USR>>\", \"<</USR>>\", \"<</SYS>>\", \n                  \"<</USER>>\", \"<</ASSISTANT>>\", \"<|end_header_id|>\", \"<<ASSISTANT>>\", \n                  \"<|eot_id|>\", \"<|im_end|>\", \"user:\", \"User:\", \"user :\", \"User :\"]\n        )\n        \n        # Generate responses for all prompts in a batch\n        outputs = self.llm.generate(formatted_prompts, sampling_params)\n        \n        # Extract and return the generated texts\n        results = []\n        for output in outputs:\n            if output.outputs:\n                results.append(output.outputs[0].text.strip())\n            else:\n                
results.append(\"\")\n                \n        return results"
  },
  {
    "path": "loadandmergecheckpoint.py",
    "content": "import os\nimport re\nimport torch\nfrom models import Model\nfrom safetensors.torch import save_file, load_file \n\nfrom lora import (\n    remove_lora_modules,\n    merge_lora_weights,\n    strip_bias_keys,\n    DEVICE,\n    OUTPUT_DIR,\n    replace_linear_with_lora,\n)\nMODEL_NAME = \"sesame/csm-1b\"\nR=32\nALPHA=32\n\ndef find_latest_checkpoint(dir_path):\n    checkpoints = [\n        (int(re.search(r\"checkpoint-epoch-(\\d+)\", d).group(1)), os.path.join(dir_path, d))\n        for d in os.listdir(dir_path)\n        if os.path.isdir(os.path.join(dir_path, d)) and \"checkpoint-epoch\" in d\n    ]\n    if not checkpoints:\n        raise FileNotFoundError(\"No checkpoints found.\")\n    latest_epoch, latest_path = max(checkpoints, key=lambda x: x[0])\n    print(f\"Latest checkpoint: epoch {latest_epoch} -> {latest_path}\")\n    return latest_path\n\ndef load_checkpoint_and_merge():\n    print(\"Loading base model...\")\n    model = Model.from_pretrained(MODEL_NAME).to(DEVICE)\n\n    print(\"Applying LoRA structure to the model...\")\n    target_layers = ['q_proj', 'k_proj', 'v_proj', 'o_proj']\n\n    model = replace_linear_with_lora(model, r=R, alpha=ALPHA, dropout=0.0, target_linear_names = target_layers)\n    checkpoint_path = find_latest_checkpoint(OUTPUT_DIR)\n    \n    print(f\"Loading state dictionary from safetensors file...\")\n    state_dict = load_file(os.path.join(checkpoint_path, \"model.safetensors\"), device=DEVICE)\n\n    print(\"Loading weights into the model...\")\n    model.load_state_dict(state_dict, strict=False)\n\n    print(\"Merging LoRA weights into base model...\")\n    merge_lora_weights(model)\n\n    print(\"Replacing LoRALinear modules with standard nn.Linear...\")\n    model = remove_lora_modules(model)\n\n    print(\"Stripping bias keys for final clean model...\")\n    merged_state = strip_bias_keys(model.state_dict())\n\n    final_path = os.path.join(OUTPUT_DIR, \"model.safetensors\")\n    save_file(merged_state, 
final_path)\n    print(f\"Merged and cleaned model saved to: {final_path}\")\n\nif __name__ == \"__main__\":\n    load_checkpoint_and_merge()\n"
  },
  {
    "path": "lora.py",
    "content": "import json\nimport os\nimport glob\nimport torch\nimport torchaudio\nimport logging\nimport numpy as np\nfrom dataclasses import dataclass\nfrom typing import List, Dict, Optional, Tuple\nfrom torch.utils.data import Dataset, DataLoader\nfrom transformers import AutoTokenizer, get_scheduler\nimport torch.nn.functional as F\nfrom tqdm import tqdm\nimport wandb\nfrom safetensors.torch import save_file\nimport csv\nfrom models import Model\nfrom moshi.models import loaders\nfrom huggingface_hub import hf_hub_download\nfrom tokenizers.processors import TemplateProcessing\nimport matplotlib.pyplot as plt\nimport matplotlib\nmatplotlib.use('Agg') \nimport torch.nn as nn\n\n# Setup logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler(), logging.FileHandler(\"finetune.log\")]\n)\nlogger = logging.getLogger(__name__)\n\nAUDIO_DIR = \"audio_data\"\nOUTPUT_DIR = \"finetuned_model\"\nNUM_EPOCHS = 5\nBATCH_SIZE = 1\nGRADIENT_ACCUMULATION_STEPS = 8\nLEARNING_RATE = 1e-6\nMAX_GRAD_NORM = 0.1\nNUM_CYCLES = 1.0\nUSE_WANDB = False\nSEED = 42\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nMIXED_PRECISION = True\nWARMUP_STEPS = 50\nSPEAKER_ID = 0\nMODEL_NAME = \"sesame/csm-1b\"\nTRANSCRIPTION_MODEL = \"openai/whisper-large-v3-turbo\"\nMAX_AUDIO_FILES = 0\nR=32\nAPLHA=32\n\nclass TrainingVisualizer:\n    def __init__(self, output_dir):\n        self.output_dir = output_dir\n        self.epochs = []\n        self.losses = []\n        self.val_losses = []  # Added validation losses\n        self.learning_rates = []\n        self.steps = []\n        \n        self.fig, self.axes = plt.subplots(3, 1, figsize=(10, 15))\n        self.fig.suptitle('CSM Finetuning Progress', fontsize=16)\n        \n        # Setup training loss plot\n        self.axes[0].set_title('Training Loss')\n        self.axes[0].set_xlabel('Epoch')\n        
self.axes[0].set_ylabel('Loss')\n        self.axes[0].grid(True, linestyle='--', alpha=0.7)\n        \n        # Setup validation loss plot\n        self.axes[1].set_title('Training vs Validation Loss')\n        self.axes[1].set_xlabel('Epoch')\n        self.axes[1].set_ylabel('Loss')\n        self.axes[1].grid(True, linestyle='--', alpha=0.7)\n        \n        # Setup learning rate plot\n        self.axes[2].set_title('Learning Rate')\n        self.axes[2].set_xlabel('Epoch')\n        self.axes[2].set_ylabel('Learning Rate')\n        self.axes[2].grid(True, linestyle='--', alpha=0.7)\n    \n    def update(self, epoch, step, loss, lr, val_loss=None):\n        \"\"\"Update the metrics and redraw the plot\"\"\"\n        self.epochs.append(epoch)\n        self.steps.append(step)\n        self.losses.append(loss)\n        self.learning_rates.append(lr)\n        \n        # Add validation loss if provided, otherwise use None\n        if val_loss is not None:\n            self.val_losses.append(val_loss)\n        elif len(self.val_losses) > 0:\n            # If we have validation losses but none provided this time, use the last one\n            self.val_losses.append(self.val_losses[-1])\n        else:\n            # If we've never had validation losses, use None\n            self.val_losses.append(None)\n        \n        # Update training loss plot\n        self.axes[0].clear()\n        self.axes[0].plot(self.epochs, self.losses, 'b-')\n        self.axes[0].set_title('Training Loss')\n        self.axes[0].set_xlabel('Epoch')\n        self.axes[0].set_ylabel('Loss')\n        self.axes[0].grid(True, linestyle='--', alpha=0.7)\n        \n        # Update validation loss plot\n        self.axes[1].clear()\n        self.axes[1].plot(self.epochs, self.losses, 'b-', label='Training')\n        \n        # If we have validation losses, plot them\n        if any(x is not None for x in self.val_losses):\n            # Filter out None values\n            val_epochs = [e for e, v 
in zip(self.epochs, self.val_losses) if v is not None]\n            val_loss_values = [v for v in self.val_losses if v is not None]\n            if val_epochs:\n                self.axes[1].plot(val_epochs, val_loss_values, 'r-', label='Validation')\n                self.axes[1].legend()\n        \n        self.axes[1].set_title('Training vs Validation Loss')\n        self.axes[1].set_xlabel('Epoch')\n        self.axes[1].set_ylabel('Loss')\n        self.axes[1].grid(True, linestyle='--', alpha=0.7)\n        \n        # Update learning rate plot\n        self.axes[2].clear()\n        self.axes[2].plot(self.epochs, self.learning_rates, 'g-')\n        self.axes[2].set_title('Learning Rate')\n        self.axes[2].set_xlabel('Epoch')\n        self.axes[2].set_ylabel('Learning Rate')\n        self.axes[2].grid(True, linestyle='--', alpha=0.7)\n        \n        # Calculate convergence metrics\n        min_loss = min(self.losses)\n        min_loss_epoch = self.epochs[self.losses.index(min_loss)]\n        \n        # Check for potential convergence stall\n        recent_window = 10  # Look at last 10 steps\n        if len(self.losses) > recent_window:\n            recent_losses = self.losses[-recent_window:]\n            loss_std = np.std(recent_losses)\n            loss_change = (recent_losses[0] - recent_losses[-1]) / recent_losses[0] if recent_losses[0] != 0 else 0\n            \n            convergence_status = \"\"\n            if loss_std < 0.001 and loss_change < 0.01:\n                convergence_status = \"STALLED: Loss not improving significantly\"\n            elif min_loss == self.losses[-1]:\n                convergence_status = \"IMPROVING: New best loss!\"\n            elif self.losses[-1] < self.losses[-2]:\n                convergence_status = \"IMPROVING: Loss decreasing\"\n            else:\n                convergence_status = \"FLUCTUATING: Loss increased\"\n            \n            # Add convergence status to title\n            
self.fig.suptitle(f'CSM Finetuning Progress - {convergence_status}\\n' + \n                            f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\\n' + \n                            f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)\n        else:\n            self.fig.suptitle(f'CSM Finetuning Progress\\n' + \n                            f'Epoch: {epoch:.2f}, Loss: {loss:.4f}, LR: {lr:.8f}\\n' + \n                            f'Best: {min_loss:.4f} at epoch {min_loss_epoch:.2f}', fontsize=12)\n        \n        plt.tight_layout(rect=[0, 0.03, 1, 0.92])  # Adjust for the larger title\n        \n        # Save the figure\n        plot_path = os.path.join(self.output_dir, 'training_progress.png')\n        self.fig.savefig(plot_path)\n        \n    def finalize(self):\n        \"\"\"Create a final, more detailed visualization when training completes\"\"\"\n        # Create a new figure for the final plot\n        final_fig = plt.figure(figsize=(12, 16))\n        gs = plt.GridSpec(4, 2, figure=final_fig)\n        \n        # Plot 1: Loss vs Steps\n        ax1 = final_fig.add_subplot(gs[0, :])\n        ax1.plot(self.steps, self.losses, 'b-', linewidth=2)\n        ax1.set_title('Training Loss vs Steps', fontsize=14)\n        ax1.set_xlabel('Steps')\n        ax1.set_ylabel('Loss')\n        ax1.grid(True, linestyle='--', alpha=0.7)\n        \n        # Plot 2: Loss vs Epochs\n        ax2 = final_fig.add_subplot(gs[1, 0])\n        ax2.plot(self.epochs, self.losses, 'r-', linewidth=2)\n        ax2.set_title('Training Loss vs Epochs', fontsize=14)\n        ax2.set_xlabel('Epochs')\n        ax2.set_ylabel('Loss')\n        ax2.grid(True, linestyle='--', alpha=0.7)\n        \n        # Plot 3: Learning Rate vs Steps\n        ax3 = final_fig.add_subplot(gs[1, 1])\n        ax3.plot(self.steps, self.learning_rates, 'g-', linewidth=2)\n        ax3.set_title('Learning Rate Schedule', fontsize=14)\n        ax3.set_xlabel('Steps')\n        
ax3.set_ylabel('Learning Rate')\n        ax3.grid(True, linestyle='--', alpha=0.7)\n        \n        # Plot 4: Training vs Validation Loss\n        ax4 = final_fig.add_subplot(gs[2, :])\n        ax4.plot(self.epochs, self.losses, 'b-', linewidth=2, label='Training')\n        \n        if any(x is not None for x in self.val_losses):\n            # Filter out None values\n            val_epochs = [e for e, v in zip(self.epochs, self.val_losses) if v is not None]\n            val_loss_values = [v for v in self.val_losses if v is not None]\n            if val_epochs:\n                ax4.plot(val_epochs, val_loss_values, 'r-', linewidth=2, label='Validation')\n                ax4.legend()\n                \n        ax4.set_title('Training vs Validation Loss', fontsize=14)\n        ax4.set_xlabel('Epochs')\n        ax4.set_ylabel('Loss')\n        ax4.grid(True, linestyle='--', alpha=0.7)\n        \n        # Plot 5: Combined plot with two y-axes\n        ax5 = final_fig.add_subplot(gs[3, :])\n        color1, color2 = 'blue', 'green'\n        \n        # Plot loss on left axis\n        line1 = ax5.plot(self.epochs, self.losses, color=color1, linewidth=2.5, label='Loss')\n        ax5.set_xlabel('Epochs')\n        ax5.set_ylabel('Loss', color=color1)\n        ax5.tick_params(axis='y', labelcolor=color1)\n        \n        # Plot learning rate on right axis\n        ax6 = ax5.twinx()\n        line2 = ax6.plot(self.epochs, self.learning_rates, color=color2, linewidth=2.5, label='Learning Rate')\n        ax6.set_ylabel('Learning Rate', color=color2)\n        ax6.tick_params(axis='y', labelcolor=color2)\n        \n        # Combine legends\n        lines = line1 + line2\n        labels = [l.get_label() for l in lines]\n        ax5.legend(lines, labels, loc='upper right')\n        ax5.set_title('Loss and Learning Rate vs Epochs', fontsize=14)\n        ax5.grid(True, linestyle='--', alpha=0.7)\n        \n        # Add training summary\n        if self.epochs:\n            
epoch_count = max(self.epochs)\n            step_count = max(self.steps)\n            min_loss = min(self.losses)\n            min_loss_epoch = self.epochs[self.losses.index(min_loss)]\n            min_loss_step = self.steps[self.losses.index(min_loss)]\n            \n            # Calculate convergence indicators\n            recent_epochs = min(10, len(self.losses))\n            recent_losses = self.losses[-recent_epochs:]\n            loss_change_pct = ((recent_losses[0] - recent_losses[-1]) / recent_losses[0]) * 100 if recent_losses[0] != 0 else 0\n            \n            summary_text = (\n                f\"Training Summary\\n\"\n                f\"Total Epochs: {epoch_count:.2f}\\n\"\n                f\"Total Steps: {step_count}\\n\"\n                f\"Min Loss: {min_loss:.6f} (Epoch {min_loss_epoch:.2f}, Step {min_loss_step})\\n\"\n                f\"Recent {recent_epochs} epochs loss change: {loss_change_pct:.2f}%\\n\"\n            )\n            \n            if len(self.losses) > 20:\n                # Add convergence assessment\n                last_20_losses = self.losses[-20:]\n                std_last_20 = np.std(last_20_losses)\n                converged = std_last_20 < 0.01 and loss_change_pct < 1.0\n                \n                summary_text += f\"Convergence status: {'CONVERGED' if converged else 'NOT CONVERGED'}\\n\"\n                if converged:\n                    summary_text += f\"Loss stabilized with std dev {std_last_20:.6f}\"\n                else:\n                    summary_text += f\"Loss still changing significantly (std dev: {std_last_20:.6f})\"\n            \n            plt.figtext(0.02, 0.02, summary_text, fontsize=10, \n                        bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))\n        \n        plt.tight_layout(rect=[0, 0.05, 1, 0.97])\n        final_fig.suptitle('CSM Model Finetuning Metrics', fontsize=16, fontweight='bold')\n        plt.subplots_adjust(top=0.93)\n        \n        # 
Save the final detailed plot\n        final_plot_path = os.path.join(self.output_dir, 'training_metrics_final.png')\n        final_fig.savefig(final_plot_path, dpi=300, bbox_inches='tight')\n        plt.close(final_fig)\n        plt.close(self.fig)\n        \n        logger.info(f\"Final training visualization saved to {final_plot_path}\")\n        \n        return final_plot_path\n\nclass LoRALinear(nn.Module):\n    def __init__(self, in_features, out_features, r=32, alpha=64, dropout=0.0, bias=True):\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.r = r\n        self.alpha = alpha\n        self.scaling = alpha / r\n        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n\n        # The base linear (frozen).\n        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)\n        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))\n        \n        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=bias)\n        \n        # LoRA trainable matrices\n        self.lora_A = nn.Parameter(torch.zeros(r, in_features))\n        self.lora_B = nn.Parameter(torch.zeros(out_features, r))\n        \n        nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))\n        nn.init.zeros_(self.lora_B)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # normal forward with frozen weight\n        result = F.linear(x, self.weight, self.bias)\n\n        # LoRA forward with trainable A and B\n        lora_out = F.linear(self.dropout(x), self.lora_A)  # [*, r]\n        lora_out = F.linear(lora_out, self.lora_B)         # [*, out_features]\n        return result + self.scaling * lora_out\n\ndef replace_linear_with_lora(model: nn.Module, r=R, alpha=APLHA, dropout=0.0, target_linear_names: List[str] = None):\n    \"\"\"\n    Replaces specified nn.Linear layers with LoRALinear layers within a model, ensuring device consistency.\n  
  \"\"\"\n    if target_linear_names is None:\n        logger.warning(\"No target layer names specified for LoRA replacement. No layers will be replaced.\")\n        return model\n\n    for name, module in list(model.named_modules()):\n        if isinstance(module, nn.Linear) and any(target_name in name for target_name in target_linear_names):\n            parent_name, child_name = name.rsplit('.', 1)\n            \n            parent_module = model\n            for part in parent_name.split('.'):\n                parent_module = getattr(parent_module, part)\n\n            original_device = module.weight.device\n            original_dtype = module.weight.dtype\n\n            # Create the new LoRA layer\n            lora_linear = LoRALinear(\n                in_features=module.in_features,\n                out_features=module.out_features,\n                r=r,\n                alpha=alpha,\n                dropout=dropout,\n                bias=(module.bias is not None)\n            )\n            # Copy the original weights and bias\n            with torch.no_grad():\n                lora_linear.weight.copy_(module.weight.data)\n                if module.bias is not None:\n                    lora_linear.bias.copy_(module.bias.data)\n            \n            lora_linear.to(device=original_device, dtype=original_dtype)\n            \n            setattr(parent_module, child_name, lora_linear)\n            logger.info(f\"Replaced layer: {name} with LoRALinear on device {original_device}\")\n            \n    return model\n\ndef load_llama3_tokenizer():\n    tokenizer_name = \"unsloth/Llama-3.2-1B\"\n    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n    bos = tokenizer.bos_token\n    eos = tokenizer.eos_token\n    tokenizer._tokenizer.post_processor = TemplateProcessing(\n        single=f\"{bos}:0 $A:0 {eos}:0\",\n        pair=f\"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1\",\n        special_tokens=[(bos, tokenizer.bos_token_id), (eos, 
tokenizer.eos_token_id)],\n    )\n    return tokenizer\n\n@dataclass\nclass AudioTextPair:\n    audio_path: str\n    text: str\n    speaker_id: int\n    processed_audio: Optional[torch.Tensor] = None\n    \n    def load_audio(self, sample_rate=24000) -> torch.Tensor:\n        if self.processed_audio is not None:\n            return self.processed_audio\n\n        waveform, sr = torchaudio.load(self.audio_path)\n        if waveform.shape[0] > 1:\n            waveform = torch.mean(waveform, dim=0, keepdim=True)\n        if sr != sample_rate:\n            resampler = torchaudio.transforms.Resample(sr, sample_rate)\n            waveform = resampler(waveform)\n        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)\n\n        self.processed_audio = waveform.squeeze(0)\n        return self.processed_audio\n\nclass CSMDataset(Dataset):\n    def __init__(self, data_items, text_tokenizer, audio_tokenizer, device):\n        self.data_items = data_items\n        self.text_tokenizer = text_tokenizer\n        self.audio_tokenizer = audio_tokenizer\n        self.device = device\n        self.sample_rate = audio_tokenizer.sample_rate\n        \n    def __len__(self):\n        return len(self.data_items)\n        \n    def tokenize_text_segment(self, text: str, speaker: int):\n        text_tokens = self.text_tokenizer.encode(f\"[{speaker}]{text}\")\n        text_frame = torch.zeros(len(text_tokens), 33).long()\n        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()\n        text_frame[:, -1] = torch.tensor(text_tokens)\n        text_frame_mask[:, -1] = True\n        return text_frame, text_frame_mask\n\n    def tokenize_audio(self, audio: torch.Tensor):\n        assert audio.ndim == 1, \"Audio must be single channel\"\n        audio_device = next(self.audio_tokenizer.parameters()).device\n        audio = audio.to(audio_device)\n        \n        try:\n            audio_tokens = self.audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]\n            
eos_frame = torch.zeros(audio_tokens.size(0), 1, device=audio_device)\n            audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)\n\n            audio_frame = torch.zeros(audio_tokens.size(1), 33, device=audio_device).long()\n            audio_frame_mask = torch.zeros(audio_tokens.size(1), 33, device=audio_device).bool()\n            audio_frame[:, :-1] = audio_tokens.transpose(0, 1)\n            audio_frame_mask[:, :-1] = True\n        except RuntimeError as e:\n            logger.warning(f\"Error encoding audio: {e}, using empty frames\")\n            audio_frame = torch.zeros(1, 33, device=audio_device).long()\n            audio_frame_mask = torch.zeros(1, 33, device=audio_device).bool()\n\n        return audio_frame, audio_frame_mask\n    \n    def __getitem__(self, idx: int):\n        item = self.data_items[idx]\n        audio = item.load_audio(self.sample_rate)\n        \n        text_tokens, text_masks = self.tokenize_text_segment(item.text, item.speaker_id)\n        audio_tokens, audio_masks = self.tokenize_audio(audio)\n        \n        device = audio_tokens.device\n        text_tokens = text_tokens.to(device)\n        text_masks = text_masks.to(device)\n        \n        input_tokens = text_tokens\n        input_masks = text_masks\n        \n        target_tokens = torch.cat([text_tokens, audio_tokens], dim=0)\n        target_masks = torch.cat([text_masks, audio_masks], dim=0)\n        \n        if device != self.device:\n            input_tokens = input_tokens.to(self.device)\n            input_masks = input_masks.to(self.device)\n            target_tokens = target_tokens.to(self.device)\n            target_masks = target_masks.to(self.device)\n        \n        return {\n            \"input_tokens\": input_tokens,\n            \"input_masks\": input_masks,\n            \"target_tokens\": target_tokens,\n            \"target_masks\": target_masks,\n        }\n\ndef collate_fn(batch):\n    max_seq_len = 1024\n    device = 
batch[0][\"input_tokens\"].device\n    \n    max_input_len = min(max(item[\"input_tokens\"].size(0) for item in batch), max_seq_len)\n    max_target_len = min(max(item[\"target_tokens\"].size(0) for item in batch), max_seq_len)\n\n    batch_input_tokens = []\n    batch_input_masks = []\n    batch_target_tokens = []\n    batch_target_masks = []\n    \n    for item in batch:\n        input_tokens = item[\"input_tokens\"][:max_input_len]\n        input_masks = item[\"input_masks\"][:max_input_len]\n        target_tokens = item[\"target_tokens\"][:max_target_len]\n        target_masks = item[\"target_masks\"][:max_target_len]\n        \n        input_tokens = F.pad(input_tokens, (0,0,0, max_input_len - input_tokens.size(0)), \"constant\", 0)\n        input_masks = F.pad(input_masks, (0,0,0, max_input_len - input_masks.size(0)), \"constant\", False)\n        \n        target_tokens = F.pad(target_tokens, (0,0,0, max_target_len - target_tokens.size(0)), \"constant\", 0)\n        target_masks = F.pad(target_masks, (0,0,0, max_target_len - target_masks.size(0)), \"constant\", False)\n        \n        batch_input_tokens.append(input_tokens)\n        batch_input_masks.append(input_masks)\n        batch_target_tokens.append(target_tokens)\n        batch_target_masks.append(target_masks)\n    \n    return {\n        \"input_tokens\": torch.stack(batch_input_tokens),\n        \"input_masks\": torch.stack(batch_input_masks),\n        \"target_tokens\": torch.stack(batch_target_tokens),\n        \"target_masks\": torch.stack(batch_target_masks),\n        \"positions\": torch.arange(0, max_target_len).unsqueeze(0).repeat(len(batch), 1).to(device)\n    }\n\ndef transcribe_audio_files():\n    from transformers import pipeline\n    \n    # Cache file path\n    cache_file = os.path.join(AUDIO_DIR, \"transcription_cache.json\")\n    \n    # Load existing cache\n    cache = {}\n    if os.path.exists(cache_file):\n        try:\n            with open(cache_file, 'r', encoding='utf-8') as 
f:\n                cache = json.load(f)\n            logger.info(f\"Loaded transcription cache with {len(cache)} entries\")\n        except Exception as e:\n            logger.warning(f\"Could not load cache file: {e}\")\n            cache = {}\n    \n    logger.info(f\"Transcribing audio files in: {AUDIO_DIR}\")\n    transcriber = pipeline(\"automatic-speech-recognition\", model=TRANSCRIPTION_MODEL)\n    audio_text_pairs = []\n    \n    audio_files = glob.glob(os.path.join(AUDIO_DIR, \"*.wav\")) \\\n        + glob.glob(os.path.join(AUDIO_DIR, \"*.mp3\")) \\\n        + glob.glob(os.path.join(AUDIO_DIR, \"*.flac\"))\n    \n    if MAX_AUDIO_FILES > 0 and len(audio_files) > MAX_AUDIO_FILES:\n        logger.info(f\"Found {len(audio_files)} files, limiting to {MAX_AUDIO_FILES}\")\n        audio_files = audio_files[:MAX_AUDIO_FILES]\n    \n    cache_hits = 0\n    cache_misses = 0\n    \n    for audio_file in tqdm(audio_files, desc=\"Processing audio files\"):\n        try:\n            # Create cache key using file path and modification time\n            file_stat = os.stat(audio_file)\n            cache_key = f\"{audio_file}_{file_stat.st_mtime}_{file_stat.st_size}\"\n            \n            # Check if transcription exists in cache\n            if cache_key in cache:\n                transcription = cache[cache_key]\n                cache_hits += 1\n                logger.debug(f\"Cache hit: {os.path.basename(audio_file)}\")\n            else:\n                # Transcribe the file\n                result = transcriber(audio_file, return_timestamps=True, chunk_length_s=30,\n                                   stride_length_s=[6, 0], batch_size=32,\n                                   generate_kwargs={\"language\": \"<|en|>\", \"task\": \"transcribe\"})\n                transcription = result[\"text\"].strip()\n                \n                # Save to cache\n                cache[cache_key] = transcription\n                cache_misses += 1\n                
logger.info(f\"Transcribed: {os.path.basename(audio_file)} -> {transcription}\")\n            \n            audio_text_pairs.append(\n                AudioTextPair(audio_path=audio_file, text=transcription, speaker_id=0)\n            )\n            \n        except Exception as e:\n            logger.error(f\"Error processing {audio_file}: {e}\")\n    \n    # Save updated cache\n    try:\n        with open(cache_file, 'w', encoding='utf-8') as f:\n            json.dump(cache, f, ensure_ascii=False, indent=2)\n        logger.info(f\"Saved transcription cache with {len(cache)} entries\")\n    except Exception as e:\n        logger.error(f\"Could not save cache file: {e}\")\n    \n    logger.info(f\"Processed {len(audio_text_pairs)} audio files (Cache hits: {cache_hits}, Cache misses: {cache_misses})\")\n    return audio_text_pairs\n\ndef prepare_csm_model_for_training():\n    logger.info(f\"Loading CSM model: {MODEL_NAME}\")\n    model = Model.from_pretrained(MODEL_NAME).to(DEVICE)\n\n    text_tokenizer = load_llama3_tokenizer()\n    mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)\n    mimi = loaders.get_mimi(mimi_weight, device=DEVICE)\n    mimi.set_num_codebooks(32)\n    audio_tokenizer = mimi\n    try:\n\n        codebook_0_centroids = mimi.quantizer.rvq_first.layers[0].codebook.weight.data\n        \n        num_codebook_0_tokens, embedding_dim = codebook_0_centroids.shape\n        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)\n        model.codebook_embedding.weight.data.copy_(codebook_0_centroids)\n        logger.info(f\"Successfully initialized codebook_embedding with shape: {codebook_0_centroids.shape}\")\n\n    except AttributeError:\n        num_codebook_0_tokens, embedding_dim = 1024, 1024\n        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)\n        nn.init.xavier_uniform_(model.codebook_embedding.weight)\n        \n    except Exception as e:\n   
     num_codebook_0_tokens, embedding_dim = 1024, 1024\n        model.codebook_embedding = nn.Embedding(num_codebook_0_tokens, embedding_dim).to(DEVICE)\n        nn.init.xavier_uniform_(model.codebook_embedding.weight)\n\n    # Some fallback logic for config\n    if not hasattr(model.config, 'get'):\n        def get_method(self, key, default=None):\n            if hasattr(self, key):\n                return getattr(self, key)\n            return default\n        model.config.__class__.get = get_method\n    if not hasattr(model.config, 'tie_word_embeddings'):\n        model.config.tie_word_embeddings = False\n    target_layers = [\n        \"q_proj\", \n        \"k_proj\", \n        \"v_proj\", \n        \"output_proj\",\n        \"w1\",           \n        \"w2\",           \n        \"w3\"            \n    ]\n    logger.info(\"Applying LoRA to model...\")\n    model = replace_linear_with_lora(\n        model,\n        r=R,\n        alpha=APLHA,\n        dropout=0.01,\n        target_linear_names=target_layers\n    )\n    model.cuda()\n    \n\n    # First, freeze all parameters of the base model\n    for param in model.parameters():\n        param.requires_grad = False\n\n    # Then, unfreeze only the newly added LoRA parameters.\n    # It is also common practice to train the bias parameters.\n    for name, param in model.named_parameters():\n        if \"lora_A\" in name or \"lora_B\" in name or \"bias\" in name:\n            param.requires_grad = True\n\n    return model, text_tokenizer, audio_tokenizer\n\ndef setup_model_caches(model, batch_size):\n    try:\n        with torch.no_grad():\n            model.reset_caches()\n            model.backbone.reset_caches()\n            model.decoder.reset_caches()\n    except Exception as e:\n        logger.debug(f\"No caches to reset or error: {e}\")\n    return True\n\nclass BridgingModule(nn.Module):\n    \"\"\"For a 2048->1024 bridging if needed.\"\"\"\n    def __init__(self, in_dim=2048, out_dim=1024):\n        
super().__init__()\n        self.bridge = nn.Linear(in_dim, out_dim, bias=False)\n        nn.init.xavier_uniform_(self.bridge.weight)\n    def forward(self, x):\n        return self.bridge(x)\n\ndef compute_loss_for_codebooks_single_pass(\n    backbone_out,  # [b, seq_len, 2048]\n    decoder_out,   # [b, seq_len, 1024]\n    model, \n    target_tokens, # [b, seq_len, codebooks]\n    target_masks,  # [b, seq_len, codebooks bool]\n    device\n):\n    bsz, seq_len = target_tokens.size()[:2]\n    num_codebooks = model.config.audio_num_codebooks\n\n    c0_logits = model.codebook0_head(backbone_out)\n    audio_positions = target_masks[..., :-1].any(dim=-1)  # [b, seq_len] for audio\n\n    total_loss = torch.tensor(0.0, device=device)\n    count = 0\n\n    # codebook0\n    for b in range(bsz):\n        for s in range(seq_len):\n            if audio_positions[b, s]:\n                token_logits = c0_logits[b, s]\n                target_token = target_tokens[b, s, 0]\n                if target_token > 0:\n                    ce = F.cross_entropy(token_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')\n                    total_loss += ce\n                    count += 1\n\n    # codebooks [1..N-1] from decoder_out\n    for i in range(1, num_codebooks):\n        weight_i = model.audio_head[i-1]\n        flat_dec = decoder_out.view(bsz * seq_len, -1)\n        token_logits_all = flat_dec.mm(weight_i)\n        \n        for b in range(bsz):\n            for s in range(seq_len):\n                if audio_positions[b, s]:\n                    target_token = target_tokens[b, s, i]\n                    if target_token > 0:\n                        row_idx = b*seq_len + s\n                        row_logits = token_logits_all[row_idx]\n                        ce = F.cross_entropy(row_logits.unsqueeze(0), target_token.unsqueeze(0), reduction='sum')\n                        total_loss += ce\n                        count += 1\n\n    if count > 0:\n        total_loss = 
total_loss / count\n    return total_loss\n\ndef single_pass_forward(model, bridging_module, target_tokens, target_masks, positions):\n    device = next(model.parameters()).device\n    dtype = next(model.parameters()).dtype\n    \n    embed = model._embed_tokens(target_tokens)\n    masked_embed = embed * target_masks.unsqueeze(-1)\n    h = masked_embed.sum(dim=2)\n    \n    backbone_out = model.backbone(h, input_pos=positions, mask=None).to(dtype)\n    bridging_out = bridging_module(backbone_out)\n    \n    codebook0_logits = model.codebook0_head(backbone_out)\n    codebook0_tokens = torch.argmax(codebook0_logits, dim=-1).clamp(0, model.codebook_embedding.num_embeddings - 1)\n    c0_embed = model.codebook_embedding(codebook0_tokens)\n    \n    # Get the last hidden state from bridging module\n    last_h = bridging_out[:, -1, :].unsqueeze(1)\n    \n    # Concatenate the last hidden state with the codebook embeddings\n    decoder_input = torch.cat([last_h, c0_embed], dim=1)\n    \n    # Process decoder inputs in parallel\n    B, S, D = decoder_input.shape  # Batch, Sequence length, Dimension\n    \n    # Reshape to (B*S, D) to process all tokens in parallel\n    decoder_input_flat = decoder_input.view(-1, D).unsqueeze(1)  # [B*S, 1, D]\n    \n    # Run decoder on all inputs in parallel\n    decoder_out_flat = model.decoder(decoder_input_flat).to(dtype)  # [B*S, 1, output_dim]\n    \n    # Reshape back to original batch and sequence dimensions\n    decoder_out = decoder_out_flat.view(B, S, -1)  # [B, S, output_dim]\n    \n    # Remove the first token (corresponding to last_h) as in original code\n    decoder_out = decoder_out[:, 1:, :]  # [B, T, 1024]\n    \n    # Safety check: handle empty sequences\n    if decoder_out.size(1) == 0:\n        return torch.tensor(0.0, requires_grad=True, device=device)\n    \n    loss = compute_loss_for_codebooks_single_pass(\n        backbone_out=backbone_out,\n        decoder_out=decoder_out,\n        model=model,\n        
target_tokens=target_tokens[..., 1:],  # Drop codebook 0\n        target_masks=target_masks[..., 1:],\n        device=device\n    )\n    \n    return loss\n\ndef calculate_validation_loss(model, bridging_module, dataset, device, max_samples=50):\n    \"\"\"\n    Calculate validation loss on a subset of the dataset\n    \"\"\"\n    # Create a small validation dataloader with a subset of data\n    val_indices = torch.randperm(len(dataset))[:max_samples].tolist()\n    val_samples = [dataset[i] for i in val_indices]\n    \n    val_loader = DataLoader(\n        val_samples, \n        batch_size=1,\n        shuffle=False,\n        collate_fn=collate_fn,\n        num_workers=0,\n        pin_memory=False\n    )\n    \n    model.eval()\n    bridging_module.eval()\n    \n    total_loss = 0.0\n    num_batches = 0\n    \n    with torch.no_grad():\n        for batch in val_loader:\n            setup_model_caches(model, batch[\"target_tokens\"].size(0))\n            \n            loss = forward_and_loss(model, bridging_module, batch, device)\n            \n            total_loss += loss.item()\n            num_batches += 1\n    \n    model.train()\n    bridging_module.train()\n    \n    # Return average loss\n    return total_loss / num_batches if num_batches > 0 else 0.0\n\ndef strip_bias_keys(state_dict: dict) -> dict:\n    new_sd = {}\n    for k, v in state_dict.items():\n        if k == \"codebook_embedding.weight\":\n            print(f\"Stripping {k} from checkpoint (training-only layer)\")\n            continue\n        if not k.endswith(\".bias\"):\n            new_sd[k] = v\n        else:\n            print(f\"Stripping {k} from checkpoint\")\n    return new_sd\n\ndef remove_lora_modules(module: nn.Module) -> nn.Module:\n    for name, child in list(module.named_children()):\n        new_child = remove_lora_modules(child)\n        setattr(module, name, new_child)\n\n    if isinstance(module, LoRALinear):\n        out_features, in_features = module.out_features, 
module.in_features\n\n        # Determine if we actually need a bias\n        has_bias = (module.bias is not None)\n        new_linear = nn.Linear(\n            in_features=in_features,\n            out_features=out_features,\n            bias=has_bias\n        )\n\n        # Copy over the merged weight\n        new_linear.weight.data.copy_(module.weight.data)\n\n        # If we had a bias in LoRALinear, copy it too\n        if has_bias:\n            new_linear.bias.data.copy_(module.bias.data)\n\n        return new_linear\n\n    return module\n\ndef merge_lora_layer(lora_module: LoRALinear):\n    \"\"\"\n    Merge the LoRA params (lora_A, lora_B) into the base weight in-place.\n    This transforms the LoRALinear into a standard Linear equivalent.\n    \"\"\"\n    # W = W + (alpha/r) * (lora_B @ lora_A)\n    merged_delta = lora_module.scaling * (lora_module.lora_B @ lora_module.lora_A)\n    lora_module.weight.data += merged_delta\n\n    # Optionally zero out LoRA parameters so they no longer affect anything\n    lora_module.lora_A.data.zero_()\n    lora_module.lora_B.data.zero_()\n\ndef merge_lora_weights(model: nn.Module):\n    for module in model.modules():\n        if isinstance(module, LoRALinear):\n            merge_lora_layer(module)\n    return model\n\ndef finetune(model, dataset):\n    logger.info(\"Starting finetuning process\")\n    csv_file = os.path.join(OUTPUT_DIR, \"training_metrics.csv\")\n    with open(csv_file, \"w\", newline=\"\") as f:\n        writer = csv.writer(f)\n        writer.writerow([\"epoch\", \"step\", \"global_step\", \"loss\", \"learning_rate\", \"val_loss\"])\n    \n    def log_metrics(epoch, step, global_step, loss, learning_rate, val_loss=None):\n        with open(csv_file, \"a\", newline=\"\") as f:\n            writer = csv.writer(f)\n            writer.writerow([epoch, step, global_step, loss, learning_rate, val_loss if val_loss is not None else \"\"])\n        visualizer.update(epoch, global_step, loss, learning_rate, 
val_loss)\n\n    visualizer = TrainingVisualizer(OUTPUT_DIR)\n    bridging_module = BridgingModule(in_dim=2048, out_dim=1024).to(DEVICE)\n    for param in bridging_module.parameters():\n        param.requires_grad = True\n\n    dataloader = DataLoader(\n        dataset, batch_size=BATCH_SIZE, shuffle=True,\n        collate_fn=collate_fn, num_workers=0, pin_memory=False\n    )\n\n    trainable_params = [p for p in model.parameters() if p.requires_grad] + list(bridging_module.parameters())\n    optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)\n    \n    num_training_steps = len(dataloader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS\n    lr_scheduler = get_scheduler(\n        \"cosine\", optimizer=optimizer,\n        num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps\n    )\n    \n    if USE_WANDB:\n        wandb.init(project=\"csm-finetuning\", name=\"csm-lora-finetune-fixed\")\n    \n    scaler = torch.amp.GradScaler() if MIXED_PRECISION else None\n    global_step = 0\n    validation_frequency = max(1, len(dataloader) // (2 * GRADIENT_ACCUMULATION_STEPS))\n    \n    model.train()\n    bridging_module.train()\n    \n    logger.info(\"Calculating initial validation loss...\")\n    initial_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)\n    logger.info(f\"Initial validation loss: {initial_val_loss:.6f}\")\n    \n    current_loss = 0.0\n    current_lr = LEARNING_RATE\n\n    for epoch in range(NUM_EPOCHS):\n        logger.info(f\"Starting epoch {epoch+1}/{NUM_EPOCHS}\")\n        progress_bar = tqdm(total=len(dataloader), desc=f\"Epoch {epoch+1}\")\n        \n        for step, batch in enumerate(dataloader):\n            try:\n                setup_model_caches(model, batch[\"target_tokens\"].size(0))\n                \n                with torch.amp.autocast(device_type=DEVICE, dtype=torch.float16, enabled=MIXED_PRECISION):\n                    loss = forward_and_loss(model, bridging_module, batch, 
DEVICE)\n                    if GRADIENT_ACCUMULATION_STEPS > 1:\n                        loss = loss / GRADIENT_ACCUMULATION_STEPS\n\n                if torch.isnan(loss) or torch.isinf(loss):\n                    logger.warning(f\"NaN or Inf loss detected at step {step}. Skipping batch.\")\n                    optimizer.zero_grad()\n                    progress_bar.update(1)\n                    continue\n\n                if MIXED_PRECISION:\n                    scaler.scale(loss).backward()\n                else:\n                    loss.backward()\n                \n                if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (step + 1) == len(dataloader):\n                    if MIXED_PRECISION:\n                        scaler.unscale_(optimizer)\n                    \n                    torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)\n                    \n                    if MIXED_PRECISION:\n                        scaler.step(optimizer)\n                        scaler.update()\n                    else:\n                        optimizer.step()\n                    \n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n                    \n                    current_lr = optimizer.param_groups[0][\"lr\"]\n                    current_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS if GRADIENT_ACCUMULATION_STEPS > 1 else loss.item()\n                    current_epoch = epoch + (step + 1) / len(dataloader)\n                    \n                    current_val_loss = None\n                    if global_step > 0 and global_step % validation_frequency == 0:\n                        logger.info(f\"Calculating validation loss at global step {global_step}...\")\n                        current_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)\n                        logger.info(f\"Validation loss: {current_val_loss:.6f}\")\n                    \n                    
log_metrics(current_epoch, step, global_step, current_loss, current_lr, current_val_loss)\n                    global_step += 1\n                    \n                    if USE_WANDB:\n                        wandb.log({\"loss\": current_loss, \"learning_rate\": current_lr, \"epoch\": current_epoch, \"global_step\": global_step, \"val_loss\": current_val_loss})\n                    \n                    progress_bar.set_postfix({\"loss\": f\"{current_loss:.4f}\", \"lr\": f\"{current_lr:.2e}\"})\n                \n                progress_bar.update(1)\n            \n            except Exception as e:\n                logger.error(f\"Error in batch {step}: {e}\")\n                import traceback\n                logger.error(traceback.format_exc())\n                try:\n                    optimizer.zero_grad()\n                    torch.cuda.empty_cache()\n                except: pass\n                progress_bar.update(1)\n                continue\n        \n        logger.info(f\"Calculating validation loss at end of epoch {epoch+1}...\")\n        epoch_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE)\n        logger.info(f\"Epoch {epoch+1} validation loss: {epoch_val_loss:.6f}\")\n        log_metrics(epoch + 1.0, len(dataloader), global_step, current_loss, current_lr, epoch_val_loss)\n        \n        checkpoint_dir = os.path.join(OUTPUT_DIR, f\"checkpoint-epoch-{epoch+1}\")\n        os.makedirs(checkpoint_dir, exist_ok=True)\n\n        checkpoint_tensors = {\n            **model.state_dict(),\n            **bridging_module.state_dict()\n        }\n        save_file(checkpoint_tensors, os.path.join(checkpoint_dir, \"model.safetensors\"))\n\n        logger.info(f\"Saved checkpoint to {checkpoint_dir}\")\n    \n    final_val_loss = calculate_validation_loss(model, bridging_module, dataset, DEVICE, max_samples=100)\n    logger.info(f\"Final validation loss: {final_val_loss:.6f}\")\n    \n    logger.info(\"Merging LoRA weights into 
the base model...\")\n    merge_lora_weights(model)\n    model = remove_lora_modules(model)\n    merged_state = strip_bias_keys(model.state_dict())\n\n    final_merged_path = os.path.join(OUTPUT_DIR, \"model.safetensors\")\n    save_file(merged_state, final_merged_path)\n    logger.info(f\"LoRA-merged & replaced model saved to {final_merged_path}\")\n    \n    visualizer.finalize()\n    if USE_WANDB:\n        wandb.finish()\n    \n    return model\n\ndef forward_and_loss(model, bridging_module, batch, device):\n    target_tokens = batch[\"target_tokens\"].to(device)\n    target_masks = batch[\"target_masks\"].to(device)\n    positions = batch[\"positions\"].to(device)\n\n    input_tokens = target_tokens[:, :-1]\n    input_masks = target_masks[:, :-1]\n    input_positions = positions[:, :-1]\n    labels = target_tokens[:, 1:]\n    label_masks = target_masks[:, 1:]\n\n    if input_tokens.size(1) == 0:\n        return torch.tensor(0.0, requires_grad=True, device=device)\n\n    # 1. Embed tokens and apply mask\n    embed = model._embed_tokens(input_tokens)\n    masked_embed = embed * input_masks.unsqueeze(-1)\n    h = masked_embed.sum(dim=2)\n\n    # 2. Pass through the backbone\n    backbone_out = model.backbone(h, input_pos=input_positions, mask=None)\n\n    # 3. 
Calculate loss for all codebooks\n    loss_fct = nn.CrossEntropyLoss(ignore_index=0)\n    total_loss = 0.0\n    num_codebooks_with_loss = 0\n\n    c0_logits = model.codebook0_head(backbone_out)\n    c0_labels = labels[..., 0]\n    \n    active_mask = label_masks[..., 0].view(-1)\n    if active_mask.sum() > 0:\n        active_logits = c0_logits.view(-1, c0_logits.size(-1))[active_mask]\n        active_labels = c0_labels.view(-1)[active_mask]\n        c0_loss = loss_fct(active_logits, active_labels)\n        total_loss += c0_loss\n        num_codebooks_with_loss += 1\n\n    decoder_states = bridging_module(backbone_out)\n    \n    num_codebooks = model.config.audio_num_codebooks\n    for i in range(1, num_codebooks):\n        if hasattr(model, 'audio_head') and len(model.audio_head) >= i:\n            weight_i = model.audio_head[i-1] \n            \n            logits_i = decoder_states @ weight_i\n\n            labels_i = labels[..., i]\n            active_mask_i = label_masks[..., i].view(-1)\n\n            if active_mask_i.sum() > 0:\n                active_logits_i = logits_i.view(-1, logits_i.size(-1))[active_mask_i]\n                active_labels_i = labels_i.view(-1)[active_mask_i]\n                loss_i = loss_fct(active_logits_i, active_labels_i)\n                total_loss += loss_i\n                num_codebooks_with_loss += 1\n\n    if num_codebooks_with_loss > 0:\n        return total_loss / num_codebooks_with_loss\n    else:\n        return torch.tensor(0.0, requires_grad=True, device=device)\n\ndef main():\n    torch.manual_seed(SEED)\n    np.random.seed(SEED)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(SEED)\n    \n    os.makedirs(OUTPUT_DIR, exist_ok=True)\n    torch.backends.cuda.enable_flash_sdp(True)\n    if DEVICE == \"cuda\":\n        torch.backends.cudnn.benchmark = True\n    \n    model, text_tokenizer, audio_tokenizer = prepare_csm_model_for_training()\n    audio_text_pairs = transcribe_audio_files()\n    if not 
audio_text_pairs:\n        logger.error(f\"No audio files found or transcribed in {AUDIO_DIR}\")\n        return\n    \n    dataset = CSMDataset(\n        audio_text_pairs,\n        text_tokenizer=text_tokenizer,\n        audio_tokenizer=audio_tokenizer,\n        device=DEVICE\n    )\n    \n    logger.info(f\"Dataset created with {len(dataset)} samples\")\n    \n    try:\n        finetune(model, dataset)\n        logger.info(\"Finetuning completed successfully!\")\n    except Exception as e:\n        logger.error(f\"Error during finetuning: {e}\")\n        import traceback\n        logger.error(traceback.format_exc())\n        \n        try:\n            # If there's an error, at least save a partial state\n            partial_path = os.path.join(OUTPUT_DIR, \"model_partial.safetensors\")\n            torch.save(model.state_dict(), partial_path)\n            logger.info(f\"Saved partial model to {partial_path} despite errors\")\n        except Exception as save_error:\n            logger.error(f\"Could not save partial model: {save_error}\")\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "main.py",
    "content": "import asyncio\nimport os\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"MKL_NUM_THREADS\"] = \"1\"  \nos.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\" \nos.environ[\"PYTORCH_DISABLE_CUDA_GRAPHS\"] = \"1\"  \nimport platform\nimport sqlite3\nimport time\nimport threading\nimport json\nimport queue\nfrom fastapi.websockets import WebSocketState\nimport torch\nimport torchaudio\nimport sounddevice as sd\nimport numpy as np\nimport whisper\nfrom fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request\nfrom fastapi.staticfiles import StaticFiles\nfrom fastapi.responses import HTMLResponse, JSONResponse\nfrom fastapi.templating import Jinja2Templates\nfrom sqlalchemy import create_engine, Column, Integer, String, Text\nfrom sqlalchemy.ext.declarative import declarative_base\nfrom sqlalchemy.orm import sessionmaker\nfrom typing import Optional\nfrom generator import Segment, load_csm_1b_local\nfrom llm_interface import LLMInterface\nfrom rag_system import RAGSystem \nfrom vad import AudioStreamProcessor\nfrom pydantic import BaseModel\nimport logging\nfrom config import ConfigManager\nfrom transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline\nimport re\nspeaking_start_time = 0.0          # set every time the AI begins a new turn\nMIN_BARGE_LATENCY   = 0.9   \nspeaker_counters = {\n    0: 0,  # AI\n    1: 0   # User\n}\ncurrent_generation_id = 1\npending_user_inputs = []\nuser_input_lock = threading.Lock()\naudio_fade_duration = 0.3  # seconds for fade-out\nlast_interrupt_time = 0\ninterrupt_cooldown = 6.0  # seconds between allowed interrupts\naudio_chunk_buffer = []  # Buffer to store the most recent audio chunks for fade-out\n# Setup logging\nlogging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')\nlogger = logging.getLogger(__name__)\nmodel_thread = None\nmodel_queue = queue.Queue()\nmodel_result_queue = queue.Queue()\nmodel_thread_running = threading.Event()\nllm_lock = 
threading.Lock()\naudio_gen_lock = threading.Lock()\n# Database\nBase = declarative_base()\nengine = create_engine(\"sqlite:///companion.db\")\nSessionLocal = sessionmaker(bind=engine)\n\nclass Conversation(Base):\n    __tablename__ = \"conversations\"\n    id = Column(Integer, primary_key=True, index=True)\n    session_id = Column(String, index=True)\n    timestamp = Column(String)\n    user_message = Column(Text)\n    ai_message = Column(Text)\n    audio_path = Column(String)\n\nBase.metadata.create_all(bind=engine)\n\n# Pydantic config schema\nclass CompanionConfig(BaseModel):\n    system_prompt: str\n    reference_audio_path: str\n    reference_text: str\n    reference_audio_path2: Optional[str] = None  # optional field\n    reference_text2: Optional[str] = None  # optional field\n    reference_audio_path3: Optional[str] = None  # optional field\n    reference_text3: Optional[str] = None  # optional field\n    model_path: str\n    llm_path: str\n    max_tokens: int = 8192\n    voice_speaker_id: int = 0\n    vad_enabled: bool = True\n    vad_threshold: float = 0.5\n    embedding_model: str = \"all-MiniLM-L6-v2\"\n\n# Global state\nconversation_history = []\nconfig = None\naudio_queue = queue.Queue()\nis_speaking = False\ninterrupt_flag = threading.Event()\ngenerator = None\nllm = None\nrag = None\nvad_processor = None\nreference_segments = []\nactive_connections = []\nmessage_queue = asyncio.Queue()\n\n# Async event loop\nloop = asyncio.new_event_loop()\nasyncio.set_event_loop(loop)\n\n# FastAPI\napp = FastAPI()\napp.mount(\"/static\", StaticFiles(directory=\"static\"), name=\"static\")\ntemplates = Jinja2Templates(directory=\"templates\")\nconfig_manager = ConfigManager()\nmodel_id = \"openai/whisper-large-v3-turbo\"\n# Whisper\nwhisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(\n    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True\n)            \nwhisper_model.to(\"cuda\")\nprocessor = 
AutoProcessor.from_pretrained(model_id)\nwhisper_pipe = pipeline(\n    \"automatic-speech-recognition\",\n    model=whisper_model,\n    tokenizer=processor.tokenizer,\n    feature_extractor=processor.feature_extractor,\n    torch_dtype=torch.float16,\n    device='cuda',\n)\n# Background queue\nasync def process_message_queue():\n    while True:\n        message = await message_queue.get()\n        for client in active_connections[:]:\n            try:\n                if client.client_state == WebSocketState.CONNECTED:\n                    await client.send_json(message)\n            except Exception as e:\n                logger.error(f\"Error in message queue for client: {e}\")\n                if client in active_connections:\n                    active_connections.remove(client)\n        message_queue.task_done()\n\ndef load_reference_segments(config_data: CompanionConfig):\n    \"\"\"Load multiple reference clips for voice‑cloning.\"\"\"\n    global reference_segments\n    reference_segments = []\n    \n    # Load primary reference (required)\n    if os.path.isfile(config_data.reference_audio_path):\n        logger.info(f\"Loading primary reference audio: {config_data.reference_audio_path}\")\n        wav, sr = torchaudio.load(config_data.reference_audio_path)\n        wav = torchaudio.functional.resample(wav.squeeze(0),\n                                         orig_freq=sr,\n                                         new_freq=24_000)\n        reference_segments.append(Segment(text=config_data.reference_text,\n                                  speaker=config_data.voice_speaker_id,\n                                  audio=wav))\n    else:\n        logger.warning(f\"Primary reference audio '{config_data.reference_audio_path}' not found.\")\n    \n    # Load second reference (optional)\n    if config_data.reference_audio_path2 and os.path.isfile(config_data.reference_audio_path2):\n        logger.info(f\"Loading second reference audio: 
{config_data.reference_audio_path2}\")\n        wav, sr = torchaudio.load(config_data.reference_audio_path2)\n        wav = torchaudio.functional.resample(wav.squeeze(0),\n                                         orig_freq=sr,\n                                         new_freq=24_000)\n        reference_segments.append(Segment(text=config_data.reference_text2,\n                                  speaker=config_data.voice_speaker_id,\n                                  audio=wav))\n    \n    # Load third reference (optional)\n    if config_data.reference_audio_path3 and os.path.isfile(config_data.reference_audio_path3):\n        logger.info(f\"Loading third reference audio: {config_data.reference_audio_path3}\")\n        wav, sr = torchaudio.load(config_data.reference_audio_path3)\n        wav = torchaudio.functional.resample(wav.squeeze(0),\n                                         orig_freq=sr,\n                                         new_freq=24_000)\n        reference_segments.append(Segment(text=config_data.reference_text3,\n                                  speaker=config_data.voice_speaker_id,\n                                  audio=wav))\n    \n    logger.info(f\"Loaded {len(reference_segments)} reference audio segments.\")\n\ndef transcribe_audio(audio_data, sample_rate):\n    global whisper_model\n    audio_np = np.array(audio_data).astype(np.float32)\n    if sample_rate != 16000:\n        try:\n            audio_tensor = torch.tensor(audio_np).unsqueeze(0)\n            audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=16000)\n            audio_np = audio_tensor.squeeze(0).numpy()\n        except: pass\n    try:\n        with torch.jit.optimized_execution(False):\n            result = whisper_pipe(audio_np, generate_kwargs={\"language\": \"english\"}) \n            return result[\"text\"]\n    except:\n        return \"[Transcription error]\"\n\ndef initialize_models(config_data: CompanionConfig):\n    global 
generator, llm, rag, vad_processor, config\n    config = config_data                         \n\n    logger.info(\"Loading LLM …\")\n    llm = LLMInterface(config_data.llm_path,\n                       config_data.max_tokens)\n\n    logger.info(\"Loading RAG …\")\n    rag = RAGSystem(\"companion.db\",\n                    model_name=config_data.embedding_model)\n\n    vad_model, vad_utils = torch.hub.load('snakers4/silero-vad',\n                                          model='silero_vad',\n                                          force_reload=False)\n    vad_processor = AudioStreamProcessor(\n        model=vad_model,\n        utils=vad_utils,\n        sample_rate=16_000,\n        vad_threshold=config_data.vad_threshold,\n        callbacks={\"on_speech_start\": on_speech_start,\n                   \"on_speech_end\":   on_speech_end},\n    )\n\n    load_reference_segments(config_data)\n\n    start_model_thread()\n\n    logger.info(\"Compiling / warming‑up voice model …\")\n    t0 = time.time()\n\n    # send a dummy request; max 0.5 s of audio, result discarded\n    model_queue.put((\n        \"warm‑up.\",                          # text\n        config_data.voice_speaker_id,        # speaker\n        [],                                  # no context\n        500,                                 # max_ms\n        0.7,                                 # temperature\n        40,                                  # top‑k\n    ))\n\n    # block until worker signals EOS (None marker)\n    while True:\n        r = model_result_queue.get()\n        if r is None:\n            break\n\n    logger.info(f\"Voice model ready in {time.time() - t0:.1f}s\")\n\n\ndef on_speech_start():\n    asyncio.run_coroutine_threadsafe(\n        message_queue.put(\n            {\n                \"type\":   \"vad_status\",\n                \"status\": \"speech_started\",\n                \"should_interrupt\": False,  # always False – UI never barges-in here\n            }\n        ),\n        
loop,\n    )\n\ndef on_speech_end(audio_data, sample_rate):\n    try:\n        logger.info(\"Transcription starting\")\n        user_text = transcribe_audio(audio_data, sample_rate)\n        logger.info(f\"Transcription completed: '{user_text}'\")\n\n        session_id = \"default\"\n        speaker_id = 1\n        index = speaker_counters[speaker_id]\n        user_audio_path = f\"audio/user/{session_id}_user_{index}.wav\"\n        os.makedirs(os.path.dirname(user_audio_path), exist_ok=True)\n\n        audio_tensor = torch.tensor(audio_data).unsqueeze(0)\n        save_audio_and_trim(user_audio_path, session_id, speaker_id, audio_tensor.squeeze(0), sample_rate)\n        add_segment(user_text, speaker_id, audio_tensor.squeeze(0))\n\n        logger.info(f\"User audio saved and segment appended: {user_audio_path}\")\n\n        speaker_counters[speaker_id] += 1\n\n        # Send transcription to clients\n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\"type\": \"transcription\", \"text\": user_text}),\n            loop\n        )\n        \n        threading.Thread(target=lambda: process_user_input(user_text, session_id), daemon=True).start()\n    except Exception as e:\n        logger.error(f\"VAD callback failed: {e}\")\n\ndef process_pending_inputs():\n    \"\"\"Process only the latest user input after an interruption\"\"\"\n    global pending_user_inputs, is_speaking, interrupt_flag\n    time.sleep(0.2)\n    is_speaking = False\n    interrupt_flag.clear()\n    \n    with user_input_lock:\n        if not pending_user_inputs:\n            logger.info(\"No pending user inputs to process\")\n            return\n        \n        # Only take the most recent input and ignore others\n        latest_input = pending_user_inputs[-1]\n        logger.info(f\"Processing only latest input: '{latest_input[0]}'\")\n        \n        # Clear all pending inputs\n        pending_user_inputs = []\n        \n        # Process only the latest input\n        
user_text, session_id = latest_input\n        process_user_input(user_text, session_id)\n\ndef process_user_input(user_text, session_id=\"default\"):\n    global config, is_speaking, pending_user_inputs, interrupt_flag\n    \n    # Skip empty messages\n    if not user_text or user_text.strip() == \"\":\n        logger.warning(\"Empty user input received, ignoring\")\n        return\n    \n    interrupt_flag.clear()\n    is_speaking = False\n    \n    # Check if we're currently supposed to be speaking\n    if is_speaking:\n        logger.info(f\"AI is currently speaking, adding input to pending queue: '{user_text}'\")\n        \n        with user_input_lock:\n            # Only keep the most recent input, replacing any existing ones\n            pending_user_inputs = [(user_text, session_id)]\n            logger.info(f\"Added user input as the only pending input: '{user_text}'\")\n        \n        # Request interruption if not already interrupted\n        if not interrupt_flag.is_set():\n            logger.info(\"Automatically interrupting current speech for new input\")\n            interrupt_flag.set()\n            # Notify clients of interruption\n            asyncio.run_coroutine_threadsafe(\n                message_queue.put({\"type\": \"audio_status\", \"status\": \"interrupted\"}),\n                loop\n            )\n            \n            # Allow a short delay before processing the new input\n            time.sleep(0.3)\n            \n            # Process the pending input after interruption\n            process_pending_inputs()\n        \n        return\n    \n    interrupt_flag.clear()\n    \n    # Normal processing continues...\n    logger.info(f\"Processing user input: '{user_text}'\")\n    context = \"\\n\".join([f\"User: {msg['user']}\\nAI: {msg['ai']}\" for msg in conversation_history[-5:]])\n    rag_context = rag.query(user_text)\n    system_prompt = config.system_prompt\n    if rag_context:\n        system_prompt += f\"\\n\\nRelevant 
context:\\n{rag_context}\"\n\n    # Notify clients that we're thinking\n    asyncio.run_coroutine_threadsafe(\n        message_queue.put({\"type\": \"status\", \"message\": \"Thinking...\"}),\n        loop\n    )\n    \n    try:\n        with llm_lock: \n            ai_response = llm.generate_response(system_prompt, user_text, context)\n        \n        timestamp = time.strftime(\"%Y-%m-%d %H:%M:%S\")\n        conversation_history.append({\n            \"timestamp\": timestamp,\n            \"user\": user_text,\n            \"ai\": ai_response\n        })\n        \n        try:\n            db = SessionLocal()\n            conv = Conversation(\n                session_id=session_id,\n                timestamp=timestamp,\n                user_message=user_text,\n                ai_message=ai_response,\n                audio_path=\"\"\n            )\n            db.add(conv)\n            db.commit()\n            index = speaker_counters[0]\n            output_file = f\"audio/ai/{session_id}_response_{index}.wav\"\n            speaker_counters[0] += 1\n            conv.audio_path = output_file\n            db.commit()\n            db.close()\n        except Exception as e:\n            logger.error(f\"Database error: {e}\")\n        \n        threading.Thread(target=lambda: rag.add_conversation(user_text, ai_response), daemon=True).start()\n        \n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\"type\": \"audio_status\", \"status\": \"preparing\"}),\n            loop\n        )\n        \n        # Small delay to ensure client is ready\n        time.sleep(0.2)\n        \n        # Send the response to clients\n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\"type\": \"response\", \"text\": ai_response}),\n            loop\n        )\n\n        time.sleep(0.5)\n        \n        if is_speaking:\n            logger.warning(\"Still speaking when trying to start new audio - forcing interrupt\")\n            
interrupt_flag.set()\n            is_speaking = False\n            time.sleep(0.5)  # Give time for cleanup\n        \n        interrupt_flag.clear()  # Make absolutely sure\n        is_speaking = False    # Reset for audio thread to take over\n        \n        # Start audio generation in a new thread\n        threading.Thread(target=audio_generation_thread, args=(ai_response, output_file), daemon=True).start()\n\n    except Exception as e:\n        logger.error(f\"Error generating response: {e}\")\n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\"type\": \"error\", \"message\": \"Failed to generate response\"}),\n            loop\n        )\n\ndef model_worker(cfg: CompanionConfig):\n    global generator, model_thread_running\n\n    logger.info(\"Model worker thread started\")\n\n    if generator is None:\n        torch._inductor.config.triton.cudagraphs = False  # Disable cudagraphs\n        torch._inductor.config.fx_graph_cache = False  # Disable graph caching\n        logger.info(\"Loading voice model inside worker thread …\")\n        generator = load_csm_1b_local(cfg.model_path, \"cuda\")\n        logger.info(\"Voice model ready (compiled with cudagraphs)\")\n\n    while model_thread_running.is_set():\n        try:\n            request = model_queue.get(timeout=0.1)\n            if request is None:\n                break\n\n            text, speaker_id, context, max_ms, temperature, topk = request\n\n            for chunk in generator.generate_stream(\n                    text=text,\n                    speaker=speaker_id,\n                    context=context,\n                    max_audio_length_ms=max_ms,\n                    temperature=temperature,\n                    topk=topk):\n                model_result_queue.put(chunk)\n\n                if not model_thread_running.is_set():\n                    break\n\n            model_result_queue.put(None) # EOS marker\n\n        except queue.Empty:\n            continue\n        
except Exception as e:\n            import traceback\n            logger.error(f\"Error in model worker: {e}\\n{traceback.format_exc()}\")\n            model_result_queue.put(Exception(f\"Generation error: {e}\"))\n\n    logger.info(\"Model worker thread exiting\")\n\ndef start_model_thread():\n    global model_thread, model_thread_running\n\n    if model_thread is not None and model_thread.is_alive():\n        return                        \n\n    model_thread_running.set()\n    model_thread = threading.Thread(target=model_worker,\n                                    args=(config,),\n                                    daemon=True,\n                                    name=\"model_worker\")\n    model_thread.start()\n    logger.info(\"Started dedicated model worker thread\")\n\nasync def run_audio_generation(text, output_file):\n    \"\"\"Async wrapper for audio generation that runs in the event loop thread\"\"\"\n    audio_generation_thread(text, output_file)\n\ndef send_to_all_clients(message: dict):\n    \"\"\"Send a message to all connected WebSocket clients\"\"\"\n    for client in active_connections[:]:\n        try:\n            if client.client_state == WebSocketState.CONNECTED:\n                asyncio.run_coroutine_threadsafe(client.send_json(message), loop)\n                logger.info(f\"Sent message to client: {message}\")\n            else:\n                logger.warning(\"Detected non-connected client; removing from active_connections\")\n                active_connections.remove(client)\n        except Exception as e:\n            logger.error(f\"Error sending message to client: {e}\")\n            if client in active_connections:\n                active_connections.remove(client)\n\nsaved_audio_paths = {\n    \"default\": {\n        0: [],  # AI\n        1: []   # User\n    }\n}\nMAX_AUDIO_FILES = 8\n\ndef save_audio_and_trim(path, session_id, speaker_id, tensor, sample_rate):\n    \"\"\"\n    Save audio file and trim old audio files for both AI 
and user to maintain storage limits.\n    \n    Args:\n        path: Path to save the audio file\n        session_id: Conversation session ID\n        speaker_id: 0 for AI, 1 for user\n        tensor: Audio tensor to save\n        sample_rate: Audio sample rate\n    \"\"\"\n    torchaudio.save(path, tensor.unsqueeze(0), sample_rate)\n    \n    saved_audio_paths.setdefault(session_id, {}).setdefault(speaker_id, []).append(path)\n    \n    paths = saved_audio_paths[session_id][speaker_id]\n    while len(paths) > MAX_AUDIO_FILES:\n        old_path = paths.pop(0)\n        if os.path.exists(old_path):\n            os.remove(old_path)\n            logger.info(f\"Removed old audio file: {old_path}\")\n    \n    other_speaker_id = 1 if speaker_id == 0 else 0\n    if other_speaker_id in saved_audio_paths[session_id]:\n        other_paths = saved_audio_paths[session_id][other_speaker_id]\n        while len(other_paths) > MAX_AUDIO_FILES:\n            old_path = other_paths.pop(0)\n            if os.path.exists(old_path):\n                os.remove(old_path)\n                logger.info(f\"Removed old audio file from other speaker: {old_path}\")\n\nMAX_SEGMENTS = 8\n\ndef add_segment(text, speaker_id, audio_tensor):\n    \"\"\"\n    Add a new segment and ensure the total context stays within token limits.\n    This version correctly separates protected and dynamic segments, performs trimming\n    on the dynamic list, and rebuilds the global context list at the end.\n\n    Args:\n        text: Text content of the segment\n        speaker_id: ID of the speaker (0 for AI, 1 for user)\n        audio_tensor: Audio data as a tensor\n    \"\"\"\n    global reference_segments, generator, config\n\n    # Determine the number of protected, initial reference segments based on what was actually loaded.\n    num_protected_segments = 0\n    if config.reference_audio_path and os.path.exists(config.reference_audio_path):\n        num_protected_segments += 1\n    if 
config.reference_audio_path2 and os.path.exists(config.reference_audio_path2):\n        num_protected_segments += 1\n    if config.reference_audio_path3 and os.path.exists(config.reference_audio_path3):\n        num_protected_segments += 1\n\n    # Separate protected from dynamic segments from the current global state\n    protected_segments = reference_segments[:num_protected_segments]\n    dynamic_segments = reference_segments[num_protected_segments:]\n\n    # Add the new segment to the dynamic list\n    new_segment = Segment(text=text, speaker=speaker_id, audio=audio_tensor)\n    dynamic_segments.append(new_segment)\n\n    # First, trim by MAX_SEGMENTS count. The oldest dynamic segments are removed.\n    max_dynamic_allowed = MAX_SEGMENTS - len(protected_segments)\n    if len(dynamic_segments) > max_dynamic_allowed:\n        # Keep only the most recent dynamic segments\n        dynamic_segments = dynamic_segments[-max_dynamic_allowed:]\n\n    # Then, check and trim by token count if necessary.\n    # This loop will trim the oldest dynamic segments until the token count is acceptable.\n    if hasattr(generator, '_text_tokenizer'):\n        while dynamic_segments:\n            # Tentatively combine for token calculation\n            temp_full_list = protected_segments + dynamic_segments\n            total_tokens = 0\n\n            # Calculate total tokens for the current combination\n            for segment in temp_full_list:\n                tokens = generator._text_tokenizer.encode(f\"[{segment.speaker}]{segment.text}\")\n                total_tokens += len(tokens)\n                if segment.audio is not None:\n                    # Approximate frame count to token conversion\n                    audio_frames = segment.audio.size(0) // 6094\n                    total_tokens += audio_frames\n\n            # If we are within limits, the trimming is done.\n            if total_tokens <= 4096:\n                break\n\n            # Otherwise, remove the oldest 
dynamic segment and re-check in the next loop iteration.\n            dynamic_segments.pop(0)\n\n    else:\n        # Fallback if tokenizer is not available\n        logger.warning(\"Unable to access tokenizer - falling back to word-based estimation for context trimming\")\n\n        def estimate_tokens(segment):\n            words = segment.text.split()\n            punctuation = sum(1 for char in segment.text if char in \".,!?;:\\\"'()[]{}\")\n            text_tokens = len(words) + punctuation\n            audio_tokens = 0\n            if segment.audio is not None:\n                audio_frames = segment.audio.size(0) // 6094\n                audio_tokens = audio_frames\n            return text_tokens + audio_tokens\n\n        while dynamic_segments:\n            total_estimated_tokens = sum(estimate_tokens(s) for s in protected_segments) + \\\n                                     sum(estimate_tokens(s) for s in dynamic_segments)\n            if total_estimated_tokens <= 2048:\n                break\n            dynamic_segments.pop(0)\n\n    # Finally, overwrite the global variable with the new, correctly-trimmed list.\n    # This is the single source of truth for the update.\n    reference_segments = protected_segments + dynamic_segments\n\n    # Log the final state for debugging\n    logger.info(f\"Context updated. 
Segments: {len(reference_segments)} total \" +\n                f\"({len(protected_segments)} protected, {len(dynamic_segments)} dynamic).\")\n\ndef preprocess_text_for_tts(text):\n    \"\"\"\n    Removes all punctuation except periods, commas, exclamation points, and question marks\n    from the input text to create cleaner speech output while preserving intonation.\n    Args:\n    text (str): Input text with potential punctuation\n    Returns:\n    str: Cleaned text with only allowed punctuation\n    \"\"\"\n    # Define a regex pattern that matches all punctuation except periods, commas, exclamation points, and question marks\n    # This includes: ; : \" '  ~ @ # $ % ^ & * ( ) _ - + = [ ] { } \\ | / < >\n    pattern = r'[^\\w\\s.,!?\\']'\n    # Replace matched punctuation with empty string\n    cleaned_text = re.sub(pattern, '', text)\n    # normalize multiple spaces to single space\n    cleaned_text = re.sub(r'\\s+', ' ', cleaned_text)\n    # ensure there's a space after punctuation for better speech pacing\n    cleaned_text = re.sub(r'([.,!?])(\\S)', r'\\1 \\2', cleaned_text)\n    return cleaned_text.strip()\n\ndef audio_generation_thread(text, output_file):\n    global is_speaking, interrupt_flag, audio_queue, model_thread_running, current_generation_id, speaking_start_time\n    \n    current_generation_id += 1\n    this_id = current_generation_id\n    \n    interrupt_flag.clear()\n    \n    # Log the start of generation\n    logger.info(f\"Starting audio generation for ID: {this_id}\")\n    \n    # Try to acquire the lock, but don't block if it's busy\n    if not audio_gen_lock.acquire(blocking=False):\n        logger.warning(f\"Audio generation {this_id} - lock acquisition failed, another generation is in progress\")\n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\n                \"type\": \"error\", \n                \"message\": \"Audio generation busy, skipping synthesis\",\n                \"gen_id\": this_id\n            
}),\n            loop\n        )\n        return\n    \n    try:\n        # Start the model thread if it's not already running\n        start_model_thread()\n        \n        interrupt_flag.clear()\n        is_speaking = True\n        speaking_start_time = time.time()\n        \n        # Create output directory\n        os.makedirs(os.path.dirname(output_file), exist_ok=True)\n        all_audio_chunks = []\n        \n        # Prepare text\n        text_lower = text.lower()\n        text_lower = preprocess_text_for_tts(text_lower)\n        \n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\n                \"type\": \"audio_status\", \n                \"status\": \"preparing_generation\",\n                \"gen_id\": this_id\n            }),\n            loop\n        )\n        \n        # Give client a moment to process\n        time.sleep(0.2)\n        \n        logger.info(f\"Sending generating status with ID {this_id}\")\n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\n                \"type\": \"audio_status\", \n                \"status\": \"generating\",\n                \"gen_id\": this_id  # Include generation ID\n            }),\n            loop\n        )\n        \n        # Small delay to ensure client gets the signal\n        time.sleep(0.2)\n        \n        # Estimate audio length\n        words = text.split()\n        avg_wpm = 100\n        words_per_second = avg_wpm / 60\n        estimated_seconds = len(words) / words_per_second\n        max_audio_length_ms = int(estimated_seconds * 1000)\n        \n        # Send request to model thread\n        logger.info(f\"Audio generation {this_id} - sending request to model thread\")\n        model_queue.put((\n            text_lower,\n            config.voice_speaker_id,\n            reference_segments,\n            max_audio_length_ms,\n            0.8,  # temperature\n            50    # topk\n        ))\n        \n        # Start timing\n        
generation_start = time.time()\n        chunk_counter = 0\n        \n        # Process results as they come\n        while True:\n            try:\n                # Check for interruption FIRST before getting more results\n                if interrupt_flag.is_set():\n                    logger.info(f\"Audio generation {this_id} - interrupt detected, stopping\")\n                    \n                    # Signal model thread to exit and restart\n                    model_thread_running.clear()\n                    time.sleep(0.1)\n                    model_thread_running.set()\n                    start_model_thread()\n                    \n                    # Clear any remaining items in the result queue\n                    while not model_result_queue.empty():\n                        try:\n                            model_result_queue.get_nowait()\n                        except queue.Empty:\n                            pass\n                    \n                    # Break out of the processing loop\n                    break\n                \n                # Get result with timeout to allow checking interrupt\n                result = model_result_queue.get(timeout=0.1)\n                \n                # Check for end of generation or error\n                if result is None:\n                    logger.info(f\"Audio generation {this_id} - complete\")\n                    break\n                    \n                if isinstance(result, Exception):\n                    logger.error(f\"Audio generation {this_id} - error: {result}\")\n                    raise result\n                \n                # Track timing for first chunk\n                if chunk_counter == 0:\n                    first_chunk_time = time.time() - generation_start\n                    logger.info(f\"Audio generation {this_id} - first chunk latency: {first_chunk_time*1000:.1f}ms\")\n                \n                chunk_counter += 1\n                \n                # One 
more interrupt check before processing chunk\n                if interrupt_flag.is_set():\n                    logger.info(f\"Audio generation {this_id} - interrupt flag set during chunk processing\")\n                    break\n                \n                # Process this audio chunk\n                audio_chunk = result\n                all_audio_chunks.append(audio_chunk)\n                \n                # Convert to numpy and send to audio queue\n                chunk_array = audio_chunk.cpu().numpy().astype(np.float32)\n                audio_queue.put(chunk_array)\n                \n                if chunk_counter == 1:\n                    logger.info(f\"Sending first audio chunk with ID {this_id}\")\n                    # Notify client we're sending the first chunk\n                    asyncio.run_coroutine_threadsafe(\n                        message_queue.put({\n                            \"type\": \"audio_status\", \n                            \"status\": \"first_chunk\",\n                            \"gen_id\": this_id\n                        }),\n                        loop\n                    )\n                    # Small delay\n                    time.sleep(0.1)\n                \n                # Send chunk with generation ID\n                asyncio.run_coroutine_threadsafe(\n                    message_queue.put({\n                        \"type\": \"audio_chunk\",\n                        \"audio\": chunk_array.tolist(),\n                        \"sample_rate\": generator.sample_rate,\n                        \"gen_id\": this_id,\n                        \"chunk_num\": chunk_counter  # Include chunk number\n                    }),\n                    loop\n                )\n                \n            except queue.Empty:\n                # No results yet, keep checking\n                continue\n            except Exception as e:\n                logger.error(f\"Audio generation {this_id} - error processing result: {e}\")\n      
          break\n        \n        # Save complete audio if available\n        if all_audio_chunks and not interrupt_flag.is_set():\n            try:\n                complete_audio = torch.cat(all_audio_chunks)\n                save_audio_and_trim(output_file, \"default\", config.voice_speaker_id, complete_audio, generator.sample_rate)\n                add_segment(text.lower(), config.voice_speaker_id, complete_audio)\n                \n                # Log statistics\n                total_time = time.time() - generation_start\n                total_audio_seconds = complete_audio.size(0) / generator.sample_rate\n                rtf = total_time / total_audio_seconds\n                logger.info(f\"Audio generation {this_id} - completed in {total_time:.2f}s, RTF: {rtf:.2f}x\")\n            except Exception as e:\n                logger.error(f\"Audio generation {this_id} - error saving complete audio: {e}\")\n                \n    except Exception as e:\n        import traceback\n        logger.error(f\"Audio generation {this_id} - unexpected error: {e}\\n{traceback.format_exc()}\")\n    finally:\n        is_speaking = False\n        \n        # Signal end of audio\n        audio_queue.put(None)\n        \n        try:\n            logger.info(f\"Audio generation {this_id} - sending completion status\")\n            asyncio.run_coroutine_threadsafe(\n                message_queue.put({\n                    \"type\": \"audio_status\", \n                    \"status\": \"complete\",\n                    \"gen_id\": this_id\n                }),\n                loop\n            )\n        except Exception as e:\n            logger.error(f\"Audio generation {this_id} - failed to send completion status: {e}\")\n            \n        # Process any pending inputs\n        with user_input_lock:\n            if pending_user_inputs:\n                # Process pending inputs\n                logger.info(f\"Audio generation {this_id} - processing pending inputs\")\n         
       process_pending_inputs()\n            \n        # Release the lock\n        logger.info(f\"Audio generation {this_id} - releasing lock\")\n        audio_gen_lock.release()\n    \ndef handle_interrupt(websocket):\n    global is_speaking, last_interrupt_time, interrupt_flag, model_thread_running, speaking_start_time\n    \n    # Log the current state\n    logger.info(f\"Interrupt requested. Current state: is_speaking={is_speaking}\")\n    \n    current_time = time.time()\n    time_since_speech_start = current_time - speaking_start_time if speaking_start_time > 0 else 999\n    time_since_last_interrupt = current_time - last_interrupt_time\n    \n    # Only apply cooldown for established speech, not for new speech\n    if time_since_last_interrupt < interrupt_cooldown and time_since_speech_start > 3.0:\n        logger.info(f\"Ignoring interrupt: too soon after previous interrupt ({time_since_last_interrupt:.1f}s < {interrupt_cooldown}s)\")\n        # Let the client know we're not interrupting\n        asyncio.run_coroutine_threadsafe(\n           websocket.send_json({\n               \"type\": \"audio_status\",\n               \"status\": \"interrupt_acknowledged\",\n               \"success\": False,\n               \"reason\": \"cooldown\"\n           }),\n           loop\n        )\n        return False\n    \n    # Update the last interrupt time\n    last_interrupt_time = current_time\n    \n    # We should interrupt if we're speaking OR if model generation is in progress\n    if is_speaking or not model_result_queue.empty():\n        logger.info(\"Interruption processing: we are speaking or generating\")\n        \n        interrupt_flag.set()\n        \n        # Notify clients\n        asyncio.run_coroutine_threadsafe(\n            message_queue.put({\"type\": \"audio_status\", \"status\": \"interrupted\"}),\n            loop\n        )\n        \n        asyncio.run_coroutine_threadsafe(\n           websocket.send_json({\n               \"type\": 
\"audio_status\",\n               \"status\": \"interrupt_acknowledged\"\n           }),\n           loop\n        )\n        \n        # Clear the audio queue to stop additional audio from being processed\n        try:\n            # Drain the existing queue\n            while not audio_queue.empty():\n                try:\n                    audio_queue.get_nowait()\n                except queue.Empty:\n                    break\n                    \n            # Add end signal\n            audio_queue.put(None)\n            logger.info(\"Audio queue cleared\")\n        except Exception as e:\n            logger.error(f\"Error clearing audio queue: {e}\")\n        \n        # Reset VAD to prepare for new input\n        if vad_processor:\n            try:\n                vad_processor.reset()\n                logger.info(\"VAD processor reset\")\n            except Exception as e:\n                logger.error(f\"Error resetting VAD: {e}\")\n        \n        # Stop current model worker if needed\n        if model_thread and model_thread.is_alive():\n            try:\n                # Clear the thread running flag to stop generation\n                model_thread_running.clear()\n                \n                # Wait a brief moment for thread to notice and exit\n                time.sleep(0.1)\n                \n                # Now restart the thread state flag\n                model_thread_running.set()\n                \n                # And restart the thread\n                start_model_thread()\n                logger.info(\"Model thread restarted\")\n            except Exception as e:\n                logger.error(f\"Error restarting model thread: {e}\")\n        \n        return True\n    \n    logger.info(\"No active speech to interrupt\")\n    return False\n\n@app.websocket(\"/ws\")\nasync def websocket_endpoint(websocket: WebSocket):\n    global is_speaking, audio_queue\n    \n    await websocket.accept()\n    
active_connections.append(websocket)\n    \n    saved = config_manager.load_config()\n    if saved:\n        await websocket.send_json({\"type\": \"saved_config\", \"config\": saved})\n        \n    try:\n        while True:\n            data = await websocket.receive_json()\n            \n            if data[\"type\"] == \"config\":\n                # Config handling\n                try:\n                    config_data = data[\"config\"]\n                    \n                    logger.info(f\"Received config data keys: {config_data.keys()}\")\n\n                    for key in [\"reference_audio_path\", \"reference_audio_path2\", \"reference_audio_path3\",\n                               \"reference_text\", \"reference_text2\", \"reference_text3\"]:\n                        if key in config_data:\n                            logger.info(f\"Config includes {key}: {config_data[key]}\")\n                        else:\n                            logger.warning(f\"Config missing {key}\")\n                    \n                    conf = CompanionConfig(**config_data)\n                    \n                    saved = config_manager.save_config(config_data)\n                    \n                    if saved:\n                        initialize_models(conf)\n                        await websocket.send_json({\"type\": \"status\", \"message\": \"Models initialized and configuration saved\"})\n                    else:\n                        await websocket.send_json({\"type\": \"error\", \"message\": \"Failed to save configuration\"})\n                        \n                except Exception as e:\n                    logger.error(f\"Error processing config: {str(e)}\")\n                    await websocket.send_json({\"type\": \"error\", \"message\": f\"Configuration error: {str(e)}\"})\n                \n                \n            elif data[\"type\"] == \"request_saved_config\":\n                saved = config_manager.load_config()\n                await 
websocket.send_json({\"type\": \"saved_config\", \"config\": saved})\n            \n            elif data[\"type\"] == \"text_message\":\n                user_text   = data[\"text\"]\n                session_id  = data.get(\"session_id\", \"default\")\n                logger.info(f\"TEXT-MSG from client: {user_text!r}\")\n\n                # If the model is already talking, queue the request instead\n                if is_speaking:\n                    with user_input_lock:\n                        if len(pending_user_inputs) >= 3:\n                            # Mutate in place: rebinding the name here would make it a\n                            # function-local and break every other reference in this handler\n                            pending_user_inputs[:] = pending_user_inputs[-2:]\n                        pending_user_inputs.append((user_text, session_id))\n                    await websocket.send_json(\n                        {\"type\":\"status\",\"message\":\"Queued – I’ll answer in a moment\"})\n                    continue\n\n                await message_queue.put({\"type\":\"transcription\",\"text\":user_text})\n                threading.Thread(\n                    target=lambda: process_user_input(user_text, session_id),\n                    daemon=True).start()\n                \n            elif data[\"type\"] == \"audio\":\n                audio_data = np.asarray(data[\"audio\"], dtype=np.float32)\n                sample_rate = data[\"sample_rate\"]\n\n                if sample_rate != 16000:\n                    audio_tensor = torch.tensor(audio_data).unsqueeze(0)\n                    audio_tensor = torchaudio.functional.resample(\n                        audio_tensor, orig_freq=sample_rate, new_freq=16000\n                    )\n                    audio_data  = audio_tensor.squeeze(0).numpy()\n                    sample_rate = 16000\n\n                if config and config.vad_enabled:\n                    vad_processor.process_audio(audio_data)  \n                else:\n                    text = transcribe_audio(audio_data, sample_rate)\n                    await websocket.send_json({\"type\": 
\"transcription\", \"text\": text})\n                    await message_queue.put({\"type\": \"transcription\", \"text\": text})\n\n                    if is_speaking:\n                        with user_input_lock:\n                            pending_user_inputs.append((text, \"default\"))\n                    else:\n                        process_user_input(text)\n\n                        \n            elif data[\"type\"] == \"interrupt\":\n                logger.info(\"Explicit interrupt request received\")\n                \n                # Always acknowledge receipt of interrupt request\n                await websocket.send_json({\n                    \"type\": \"audio_status\", \n                    \"status\": \"interrupt_acknowledged\"\n                })\n                \n                # Then try to handle the actual interrupt\n                success = handle_interrupt(websocket)\n                \n                # If successful, allow a brief delay for clearing everything\n                if success:\n                    await asyncio.sleep(0.3)  # Short delay to allow complete clearing\n                    \n                    # Force process pending inputs after interrupt\n                    with user_input_lock:\n                        if pending_user_inputs:\n                            user_text, session_id = pending_user_inputs.pop(0)\n                            pending_user_inputs.clear()  # Clear any backup to avoid multiple responses\n                            \n                            # Process in a new thread to avoid blocking\n                            threading.Thread(\n                                target=lambda: process_user_input(user_text, session_id),\n                                daemon=True\n                            ).start()\n                \n                # Send final status update about the interrupt\n                await websocket.send_json({\n                    \"type\": \"audio_status\", \n        
            \"status\": \"interrupted\",\n                    \"success\": success\n                })\n                \n            elif data[\"type\"] == \"mute\":\n                await websocket.send_json({\"type\": \"mute_status\", \"muted\": data[\"muted\"]})\n                if not data[\"muted\"] and config and config.vad_enabled:\n                    vad_processor.reset()\n                    \n    except WebSocketDisconnect:\n        if websocket in active_connections:\n            active_connections.remove(websocket)\n\n@app.get(\"/\", response_class=HTMLResponse)\nasync def index(request: Request):\n    return templates.TemplateResponse(\"index.html\", {\"request\": request})\n\n@app.get(\"/setup\", response_class=HTMLResponse)\nasync def setup_page(request: Request):\n    return templates.TemplateResponse(\"setup.html\", {\"request\": request})\n\n@app.get(\"/chat\", response_class=HTMLResponse)\nasync def chat_page(request: Request):\n    return templates.TemplateResponse(\"chat.html\", {\"request\": request})\n\n@app.on_event(\"startup\")\nasync def startup_event():\n    os.makedirs(\"static\", exist_ok=True)\n    os.makedirs(\"audio/user\", exist_ok=True)\n    os.makedirs(\"audio/ai\", exist_ok=True)\n    os.makedirs(\"embeddings_cache\", exist_ok=True)\n    os.makedirs(\"templates\", exist_ok=True)\n    with open(\"templates/index.html\", \"w\") as f:\n        f.write(\"\"\"<meta http-equiv=\"refresh\" content=\"0; url=/setup\" />\"\"\")\n    try:\n        torch.hub.load('snakers4/silero-vad', model='silero_vad', force_reload=False)\n    except: pass\n    asyncio.create_task(process_message_queue())\n\n@app.on_event(\"shutdown\")\nasync def shutdown_event():\n    logger.info(\"Server shutting down...\")\n\nfrom flask import Flask, jsonify, request, send_file\n\n@app.get(\"/api/conversations\")\nasync def get_conversations(request: Request):\n    conn = sqlite3.connect(\"companion.db\")\n    cur = conn.cursor()\n    cur.execute(\"SELECT id, 
user_message, ai_message FROM conversations ORDER BY id DESC\")\n    data = [{\"id\": row[0], \"user_message\": row[1], \"ai_message\": row[2]} for row in cur.fetchall()]\n    conn.close()\n    return JSONResponse(content=data)\n\n@app.put(\"/api/conversations/{conv_id}\")\nasync def update_conversation(conv_id: int, request: Request):\n    data = await request.json()\n    conn = sqlite3.connect(\"companion.db\")\n    cur = conn.cursor()\n    cur.execute(\"UPDATE conversations SET user_message=?, ai_message=? WHERE id=?\",\n                (data[\"user_message\"], data[\"ai_message\"], conv_id))\n    conn.commit()\n    conn.close()\n    return JSONResponse(content={\"status\": \"updated\", \"id\": conv_id})\n\n@app.delete(\"/api/conversations\")\nasync def delete_all_conversations():\n    try:\n        conn = sqlite3.connect(\"companion.db\")\n        cur = conn.cursor()\n        cur.execute(\"DELETE FROM conversations\")\n        conn.commit()\n        conn.close()\n        return {\"status\": \"all deleted\"}\n    except Exception as e:\n        return JSONResponse(content={\"error\": str(e)}, status_code=500)\n\n@app.delete(\"/api/conversations/{conv_id}\")\nasync def delete_conversation(conv_id: int):\n    try:\n        conn = sqlite3.connect(\"companion.db\")\n        cur = conn.cursor()\n        cur.execute(\"DELETE FROM conversations WHERE id = ?\", (conv_id,))\n        conn.commit()\n        conn.close()\n        return JSONResponse(content={\"status\": \"deleted\", \"id\": conv_id})\n    except Exception as e:\n        return JSONResponse(content={\"error\": str(e)}, status_code=500)\n\napp.mount(\"/static\", StaticFiles(directory=\"static\"), name=\"static\")\ntemplates = Jinja2Templates(directory=\"templates\")\n\n@app.get(\"/crud\", response_class=HTMLResponse)\nasync def crud_ui(request: Request):\n    return templates.TemplateResponse(\"crud.html\", {\"request\": request})\n\nif __name__ == \"__main__\":\n    import uvicorn\n    threading.Thread(target=loop.run_forever, 
daemon=True).start()\n    uvicorn.run(app, host=\"0.0.0.0\", port=8000)\n"
  },
  {
    "path": "models.py",
    "content": "import logging\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn as nn\nimport torchtune\nfrom huggingface_hub import PyTorchModelHubMixin\nfrom torchtune.models import llama3_2\n\nlogger = logging.getLogger(__name__)\n\ndef llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:\n    return llama3_2.llama3_2(\n        vocab_size=128_256,\n        num_layers=16,\n        num_heads=32,\n        num_kv_heads=8,\n        embed_dim=2048,\n        max_seq_len=2048,\n        intermediate_dim=8192,\n        attn_dropout=0.0,\n        norm_eps=1e-5,\n        rope_base=500_000,\n        scale_factor=32,\n    )\n\ndef llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:\n    return llama3_2.llama3_2(\n        vocab_size=128_256,\n        num_layers=4,\n        num_heads=8,\n        num_kv_heads=2,\n        embed_dim=1024,\n        max_seq_len=2048,\n        intermediate_dim=8192,\n        attn_dropout=0.0,\n        norm_eps=1e-5,\n        rope_base=500_000,\n        scale_factor=32,\n    )\n\nFLAVORS = {\n    \"llama-1B\": llama3_2_1B,\n    \"llama-100M\": llama3_2_100M,\n}\n\ndef _prepare_transformer(model):\n    embed_dim = model.tok_embeddings.embedding_dim\n    model.tok_embeddings = nn.Identity()\n    model.output = nn.Identity()\n    return model, embed_dim\n\ndef _create_causal_mask(seq_len: int, device: torch.device):\n    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))\n\ndef _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):\n    \"\"\"\n    Args:\n        mask: (max_seq_len, max_seq_len)\n        input_pos: (batch_size, seq_len)\n\n    Returns:\n        (batch_size, seq_len, max_seq_len)\n    \"\"\"\n    r = mask[input_pos, :]\n    return r\n\ndef _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a cuda synchronization\n    q = torch.empty_like(probs).exponential_(1)\n    return torch.argmax(probs / q, dim=-1, 
keepdim=True).to(dtype=torch.int)\n\ndef sample_topk(logits: torch.Tensor, topk: int, temperature: float):\n    logits = logits / temperature\n\n    filter_value: float = -float(\"Inf\")\n    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]\n    scores_processed = logits.masked_fill(indices_to_remove, filter_value)\n    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)\n    probs = torch.nn.functional.softmax(scores_processed, dim=-1)\n\n    sample_token = _multinomial_sample_one_no_sync(probs)\n    return sample_token\n\n@dataclass\nclass ModelArgs:\n    backbone_flavor: str\n    decoder_flavor: str\n    text_vocab_size: int\n    audio_vocab_size: int\n    audio_num_codebooks: int\n\n\nclass Model(\n    nn.Module,\n    PyTorchModelHubMixin,\n    repo_url=\"https://github.com/SesameAILabs/csm\",\n    pipeline_tag=\"text-to-speech\",\n    license=\"apache-2.0\",\n):\n    def __init__(self, config: ModelArgs):\n        super().__init__()\n        self.config = config\n\n        self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())\n        self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())\n\n        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)\n        self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)\n\n        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)\n        self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)\n        self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))\n\n    def setup_caches(self, max_batch_size: int) -> torch.Tensor:\n        \"\"\"Setup KV caches and return a causal mask.\"\"\"\n        dtype = next(self.parameters()).dtype\n        device = next(self.parameters()).device\n\n        with device:\n            
self.backbone.setup_caches(max_batch_size, dtype)\n            self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)\n\n        self.register_buffer(\"backbone_causal_mask\", _create_causal_mask(self.backbone.max_seq_len, device))\n        self.register_buffer(\"decoder_causal_mask\", _create_causal_mask(self.config.audio_num_codebooks, device))\n\n    def generate_frame(\n        self,\n        tokens: torch.Tensor,\n        tokens_mask: torch.Tensor,\n        input_pos: torch.Tensor,\n        temperature: float,\n        topk: int,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            tokens: (batch_size, seq_len, audio_num_codebooks+1)\n            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)\n            input_pos: (batch_size, seq_len) positions for each token\n            mask: (batch_size, seq_len, max_seq_len\n\n        Returns:\n            (batch_size, audio_num_codebooks) sampled tokens\n        \"\"\"\n        dtype = next(self.parameters()).dtype\n        b, s, _ = tokens.size()\n\n        assert self.backbone.caches_are_enabled(), \"backbone caches are not enabled\"\n        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)\n        embeds = self._embed_tokens(tokens)\n        masked_embeds = embeds * tokens_mask.unsqueeze(-1)\n        h = masked_embeds.sum(dim=2)\n        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)\n\n        last_h = h[:, -1, :]\n        c0_logits = self.codebook0_head(last_h)\n        c0_sample = sample_topk(c0_logits, topk, temperature)\n        c0_embed = self._embed_audio(0, c0_sample)\n\n        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)\n        curr_sample = c0_sample.clone()\n        curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)\n\n        # Decoder caches must be reset every frame.\n        
self.decoder.reset_caches()\n        for i in range(1, self.config.audio_num_codebooks):\n            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)\n            decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(\n                dtype=dtype\n            )\n            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])\n            ci_sample = sample_topk(ci_logits, topk, temperature)\n            ci_embed = self._embed_audio(i, ci_sample)\n\n            curr_h = ci_embed\n            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)\n            curr_pos = curr_pos[:, -1:] + 1\n\n        return curr_sample\n\n    def reset_caches(self):\n        self.backbone.reset_caches()\n        self.decoder.reset_caches()\n\n    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:\n        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)\n\n    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:\n        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)\n\n        audio_tokens = tokens[:, :, :-1] + (\n            self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)\n        )\n        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(\n            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1\n        )\n\n        return torch.cat([audio_embeds, text_embeds], dim=-2)\n    \n"
  },
  {
    "path": "rag_system.py",
    "content": "import sqlite3\nimport numpy as np\nimport json\nfrom pathlib import Path\nimport time\nfrom typing import List, Dict, Any, Tuple, Optional\nfrom sentence_transformers import SentenceTransformer\nfrom sklearn.metrics.pairwise import cosine_similarity\nimport torch\n\nclass RAGSystem:\n    def __init__(self, db_path: str, model_name: str = \"all-MiniLM-L6-v2\", cache_dir: str = \"./embeddings_cache\"):\n        \"\"\"\n        Initialize the enhanced RAG system with embeddings.\n        \n        Args:\n            db_path: Path to the SQLite database\n            model_name: Name of the sentence-transformer model to use\n            cache_dir: Directory to cache embeddings\n        \"\"\"\n        self.db_path = db_path\n        self.cache_dir = Path(cache_dir)\n        self.cache_dir.mkdir(exist_ok=True)\n        \n        # Load embedding model\n        print(f\"Loading embedding model: {model_name}\")\n        self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        self.model = SentenceTransformer(model_name, device=self.device)\n        print(f\"Embedding model loaded on {self.device}\")\n        \n        # Cache for embeddings\n        self.embedding_cache = self._load_embedding_cache()\n        \n        # Initialize database tables if needed\n        self._initialize_db()\n        \n        # Load existing conversations and cache embeddings\n        self._load_conversations()\n    \n    def _initialize_db(self):\n        \"\"\"Create necessary tables if they don't exist.\"\"\"\n        conn = sqlite3.connect(self.db_path)\n        cursor = conn.cursor()\n        \n        # Create conversations table if it doesn't exist\n        cursor.execute(\"\"\"\n        CREATE TABLE IF NOT EXISTS conversations (\n            id INTEGER PRIMARY KEY,\n            user_message TEXT,\n            ai_message TEXT,\n            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP\n        )\n        \"\"\")\n        \n        # Create embeddings 
table if it doesn't exist\n        cursor.execute(\"\"\"\n        CREATE TABLE IF NOT EXISTS embeddings (\n            id INTEGER PRIMARY KEY,\n            conversation_id INTEGER,\n            text TEXT,\n            embedding_file TEXT,\n            chunk_id TEXT,\n            FOREIGN KEY (conversation_id) REFERENCES conversations(id)\n        )\n        \"\"\")\n        \n        conn.commit()\n        conn.close()\n    \n    def _load_embedding_cache(self) -> Dict[str, np.ndarray]:\n        \"\"\"Load cached embeddings from disk.\"\"\"\n        cache = {}\n        \n        for cache_file in self.cache_dir.glob(\"*.json\"):\n            try:\n                with open(cache_file, \"r\") as f:\n                    cache_data = json.load(f)\n                    for chunk_id, embedding_data in cache_data.items():\n                        cache[chunk_id] = np.array(embedding_data)\n            except Exception as e:\n                print(f\"Error loading cache file {cache_file}: {e}\")\n        \n        print(f\"Loaded {len(cache)} cached embeddings\")\n        return cache\n    \n    def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray):\n        \"\"\"Save an embedding to the cache.\"\"\"\n        cache_file = self.cache_dir / f\"{chunk_id[:2]}.json\"\n        \n        # Load existing cache file or create new one\n        if cache_file.exists():\n            try:\n                with open(cache_file, \"r\") as f:\n                    cache_data = json.load(f)\n            except:\n                cache_data = {}\n        else:\n            cache_data = {}\n        \n        # Add new embedding\n        cache_data[chunk_id] = embedding.tolist()\n        \n        # Save cache file\n        with open(cache_file, \"w\") as f:\n            json.dump(cache_data, f)\n    \n    def _load_conversations(self):\n        \"\"\"Load existing conversations from the database and cache their embeddings.\"\"\"\n        try:\n            conn = 
sqlite3.connect(self.db_path)\n            cursor = conn.cursor()\n            \n            # First check if the conversations table exists\n            cursor.execute(\"SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'\")\n            if not cursor.fetchone():\n                print(\"Conversations table does not exist yet\")\n                conn.close()\n                return\n            \n            # Get all conversations not yet in the embeddings table\n            cursor.execute(\"\"\"\n            SELECT c.id, c.user_message, c.ai_message \n            FROM conversations c\n            LEFT JOIN embeddings e ON c.id = e.conversation_id\n            WHERE e.id IS NULL\n            \"\"\")\n            \n            conversations = cursor.fetchall()\n            if not conversations:\n                conn.close()\n                return\n            \n            print(f\"Processing embeddings for {len(conversations)} new conversations\")\n            \n            for conv_id, user_message, ai_message in conversations:\n                # Create chunks for indexing\n                if user_message is not None and ai_message is not None:  # Ensure neither is None\n                    self._process_conversation(conv_id, user_message, ai_message, conn)\n            \n            conn.close()\n            print(\"Finished processing conversation embeddings\")\n        except Exception as e:\n            print(f\"Error loading conversations: {e}\")\n    \n    def _process_conversation(self, conv_id: int, user_message: str, ai_message: str, conn: sqlite3.Connection):\n        \"\"\"Process a conversation and store its embeddings.\"\"\"\n        try:\n            cursor = conn.cursor()\n            \n            # Combine user and AI messages\n            full_text = f\"User: {user_message}\\nAI: {ai_message}\"\n            \n            # For simplicity, we're using the entire message as a chunk\n            # In a more sophisticated 
system, you might split long messages into smaller chunks\n            chunk_id = f\"conv_{conv_id}\"\n            \n            # Check if we already have this embedding cached\n            if chunk_id not in self.embedding_cache:\n                # Generate embedding\n                embedding = self.model.encode(full_text)\n                self.embedding_cache[chunk_id] = embedding\n                \n                # Save to cache\n                self._save_embedding_to_cache(chunk_id, embedding)\n            else:\n                embedding = self.embedding_cache[chunk_id]\n            \n            # Store reference in database\n            embedding_file = f\"{chunk_id[:2]}.json\"\n            cursor.execute(\n                \"INSERT INTO embeddings (conversation_id, text, embedding_file, chunk_id) VALUES (?, ?, ?, ?)\",\n                (conv_id, full_text, embedding_file, chunk_id)\n            )\n            \n            conn.commit()\n        except Exception as e:\n            print(f\"Error processing conversation {conv_id}: {e}\")\n    \n    def add_conversation(self, user_message: str, ai_message: str) -> int:\n        \"\"\"\n        Add a new conversation to the RAG system.\n        \n        Returns:\n            The id of the newly added conversation\n        \"\"\"\n        try:\n            conn = sqlite3.connect(self.db_path)\n            cursor = conn.cursor()\n            \n            # Insert the conversation first\n            cursor.execute(\n                \"INSERT INTO conversations (user_message, ai_message) VALUES (?, ?)\",\n                (user_message, ai_message)\n            )\n            \n            # Get the ID of the new conversation\n            conv_id = cursor.lastrowid\n            \n            # Process the conversation for embeddings\n            self._process_conversation(conv_id, user_message, ai_message, conn)\n            \n            conn.commit()\n            conn.close()\n            \n            return 
conv_id\n        except Exception as e:\n            print(f\"Error adding conversation: {e}\")\n            return -1\n    \n    def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, float]]:\n        \"\"\"\n        Query the RAG system for relevant context.\n        \n        Args:\n            query_text: The query text\n            top_k: Number of top results to return\n            \n        Returns:\n            List of tuples with (text, similarity_score)\n        \"\"\"\n        if query_text is None or query_text.strip() == \"\":\n            print(\"Error: Empty query text\")\n            return []\n            \n        try:\n            # Generate query embedding\n            query_embedding = self.model.encode(query_text)\n            \n            # Find most similar conversations\n            results = self._find_similar(query_embedding, top_k)\n            \n            return results\n        except Exception as e:\n            print(f\"Error during query: {e}\")\n            return []\n    \n    def get_context(self, query_text: str, top_k: int = 3, threshold: float = 0.6) -> str:\n        \"\"\"\n        Get formatted context from the RAG system.\n        \n        Args:\n            query_text: The query text\n            top_k: Number of top results to return\n            threshold: Minimum similarity score to include\n            \n        Returns:\n            String with relevant context\n        \"\"\"\n        results = self.query(query_text, top_k)\n        \n        if not results:\n            return \"\"\n        \n        # Format results\n        context_parts = []\n        for text, score in results:\n            # Only include really relevant results\n            if score < threshold:  # Threshold for relevance\n                continue\n            context_parts.append(f\"Relevance: {score:.2f}\\n{text}\")\n        \n        return \"\\n---\\n\".join(context_parts)\n    \n    def _find_similar(self, query_embedding: 
np.ndarray, top_k: int) -> List[Tuple[str, float]]:\n        \"\"\"Find the most similar conversations to the query.\"\"\"\n        try:\n            conn = sqlite3.connect(self.db_path)\n            cursor = conn.cursor()\n            \n            # Check if the embeddings table exists\n            cursor.execute(\"SELECT name FROM sqlite_master WHERE type='table' AND name='embeddings'\")\n            if not cursor.fetchone():\n                print(\"Embeddings table does not exist yet\")\n                conn.close()\n                return []\n            \n            # Get all embeddings from the database\n            cursor.execute(\"SELECT id, text, embedding_file, chunk_id FROM embeddings\")\n            results = cursor.fetchall()\n            \n            if not results:\n                conn.close()\n                return []\n            \n            # Calculate similarities\n            similarities = []\n            for db_id, text, embedding_file, chunk_id in results:\n                # Get embedding from cache\n                if chunk_id in self.embedding_cache:\n                    embedding = self.embedding_cache[chunk_id]\n                else:\n                    # This should not happen, but just in case\n                    # We'll reload from the cache file\n                    cache_file = self.cache_dir / embedding_file\n                    if cache_file.exists():\n                        with open(cache_file, \"r\") as f:\n                            cache_data = json.load(f)\n                            if chunk_id in cache_data:\n                                embedding = np.array(cache_data[chunk_id])\n                                self.embedding_cache[chunk_id] = embedding\n                            else:\n                                continue\n                    else:\n                        continue\n                \n                # Calculate similarity\n                similarity = cosine_similarity(\n             
       query_embedding.reshape(1, -1),\n                    embedding.reshape(1, -1)\n                )[0][0]\n                \n                similarities.append((text, similarity))\n            \n            conn.close()\n            \n            # Sort by similarity and return top_k\n            similarities.sort(key=lambda x: x[1], reverse=True)\n            return similarities[:top_k]\n        except Exception as e:\n            print(f\"Error finding similar documents: {e}\")\n            return []\n    \n    def refresh(self):\n        \"\"\"Refresh embeddings from the database.\"\"\"\n        self._load_conversations()\n\n# Example usage\nif __name__ == \"__main__\":\n    # Initialize the RAG system\n    rag = RAGSystem(\"conversations.db\")"
  },
  {
    "path": "requirements.txt",
    "content": "--extra-index-url=https://download.pytorch.org/whl/cu128\nvllm==0.8.0\ntorch==2.6.0\ntorchaudio==2.6.0\ntorchvision==0.21.0\ntokenizers==0.21.0\ntransformers==4.49.0\nhuggingface_hub==0.28.1\nmoshi==0.2.2\nsounddevice\ntorchtune==0.4.0\ntorchao==0.9.0\nbitsandbytes\npeft\nwandb\nsilero_vad\npython-multipart>=0.0.6\naiofiles>=23.1.0\nsentence-transformers>=2.2.2\nctransformers>=0.2.24\nsqlalchemy>=2.0.0\npydantic>=2.0.0\nfastapi>=0.95.0\nuvicorn>=0.22.0\nwebsockets>=11.0.3\njinja2>=3.0.0\nspeechbrain>=0.5.15\nmatplotlib\nwhisper-openai\nsilentcipher @ git+https://github.com/SesameAILabs/silentcipher@master\nnumpy==1.26.0"
  },
  {
    "path": "run_csm.py",
    "content": "import os\nimport torch\nimport torchaudio\nfrom huggingface_hub import hf_hub_download\nfrom generator import load_csm_1b, Segment\nfrom dataclasses import dataclass\n\n\n# Default prompts are available at https://hf.co/sesame/csm-1b\nprompt_filepath_conversational_a = hf_hub_download(\n    repo_id=\"sesame/csm-1b\",\n    filename=\"prompts/conversational_a.wav\"\n)\nprompt_filepath_conversational_b = hf_hub_download(\n    repo_id=\"sesame/csm-1b\",\n    filename=\"prompts/conversational_b.wav\"\n)\n\nSPEAKER_PROMPTS = {\n    \"conversational_a\": {\n        \"text\": (\n            \"like revising for an exam I'd have to try and like keep up the momentum because I'd \"\n            \"start really early I'd be like okay I'm gonna start revising now and then like \"\n            \"you're revising for ages and then I just like start losing steam I didn't do that \"\n            \"for the exam we had recently to be fair that was a more of a last minute scenario \"\n            \"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I \"\n            \"sort of start the day with this not like a panic but like a\"\n        ),\n        \"audio\": prompt_filepath_conversational_a\n    },\n    \"conversational_b\": {\n        \"text\": (\n            \"like a super Mario level. Like it's very like high detail. And like, once you get \"\n            \"into the park, it just like, everything looks like a computer game and they have all \"\n            \"these, like, you know, if, if there's like a, you know, like in a Mario game, they \"\n            \"will have like a question block. And if you like, you know, punch it, a coin will \"\n            \"come out. 
So like everyone, when they come into the park, they get like this little \"\n            \"bracelet and then you can go punching question blocks around.\"\n        ),\n        \"audio\": prompt_filepath_conversational_b\n    }\n}\n\ndef load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:\n    audio_tensor, sample_rate = torchaudio.load(audio_path)\n    audio_tensor = audio_tensor.squeeze(0)\n    # Resample is lazy so we can always call it\n    audio_tensor = torchaudio.functional.resample(\n        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate\n    )\n    return audio_tensor\n\ndef prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:\n    audio_tensor = load_prompt_audio(audio_path, sample_rate)\n    return Segment(text=text, speaker=speaker, audio=audio_tensor)\n\ndef main():\n    # Select the best available device, skipping MPS due to float64 limitations\n    if torch.cuda.is_available():\n        device = \"cuda\"\n    else:\n        device = \"cpu\"\n    print(f\"Using device: {device}\")\n\n    # Load model\n    generator = load_csm_1b(device)\n\n    # Prepare prompts\n    prompt_a = prepare_prompt(\n        SPEAKER_PROMPTS[\"conversational_a\"][\"text\"],\n        0,\n        SPEAKER_PROMPTS[\"conversational_a\"][\"audio\"],\n        generator.sample_rate\n    )\n\n    prompt_b = prepare_prompt(\n        SPEAKER_PROMPTS[\"conversational_b\"][\"text\"],\n        1,\n        SPEAKER_PROMPTS[\"conversational_b\"][\"audio\"],\n        generator.sample_rate\n    )\n\n    # Generate conversation\n    conversation = [\n        {\"text\": \"Hey how are you doing?\", \"speaker_id\": 0},\n        {\"text\": \"Pretty good, pretty good. How about you?\", \"speaker_id\": 1},\n        {\"text\": \"I'm great! So happy to be speaking with you today.\", \"speaker_id\": 0},\n        {\"text\": \"Me too! 
This is some cool stuff, isn't it?\", \"speaker_id\": 1}\n    ]\n\n    # Generate each utterance\n    generated_segments = []\n    prompt_segments = [prompt_a, prompt_b]\n\n    for utterance in conversation:\n        print(f\"Generating: {utterance['text']}\")\n        audio_tensor = generator.generate(\n            text=utterance['text'],\n            speaker=utterance['speaker_id'],\n            context=prompt_segments + generated_segments,\n            max_audio_length_ms=10_000,\n        )\n        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))\n\n    # Concatenate all generations\n    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)\n    torchaudio.save(\n        \"full_conversation.wav\",\n        all_audio.unsqueeze(0).cpu(),\n        generator.sample_rate\n    )\n    print(\"Successfully generated full_conversation.wav\")\n\nif __name__ == \"__main__\":\n    main() "
  },
  {
    "path": "setup.py",
    "content": "import os\nimport sys\nimport subprocess\nimport logging\nimport urllib.request\nimport torch\nimport time\nimport shutil\nfrom pathlib import Path\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO,\n                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')\nlogger = logging.getLogger(__name__)\n\ndef check_requirements():\n    \"\"\"Check if all required Python packages are installed\"\"\"\n    logger.info(\"Checking requirements...\")\n    \n    requirements = [\n        \"torch\", \"torchaudio\", \"fastapi\", \"uvicorn\", \"websockets\", \"numpy\",\n        \"scikit-learn\", \"sqlalchemy\", \"pydantic\", \"jinja2\", \"whisper\",\n        \"sounddevice\", \"soundfile\", \"sentence_transformers\", \"ctransformers\"\n    ]\n    \n    missing = []\n    for req in requirements:\n        try:\n            __import__(req)\n        except ImportError:\n            missing.append(req)\n    \n    if missing:\n        logger.warning(f\"Missing required packages: {', '.join(missing)}\")\n        logger.info(\"Installing missing requirements...\")\n        subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-r\", \"requirements.txt\"])\n        logger.info(\"Requirements installed successfully\")\n    else:\n        logger.info(\"All requirements are satisfied\")\n\ndef download_vad_model():\n    \"\"\"Download the Silero VAD model using PyTorch Hub instead of direct URL\"\"\"\n    model_path = \"silero_vad.jit\"\n    \n    if os.path.exists(model_path):\n        logger.info(f\"Silero VAD model already exists at {model_path}\")\n        return\n    \n    logger.info(\"Downloading Silero VAD model using PyTorch Hub...\")\n    try:\n        # Use torch.hub to download the model instead of direct URL\n        torch.hub.set_dir(\"./models\")\n        model, utils = torch.hub.load(repo_or_dir=\"snakers4/silero-vad\",\n                                     model=\"silero_vad\",\n                                     
force_reload=True,\n                                     onnx=False)\n        \n        # Save the model\n        torch.jit.save(model, model_path)\n        logger.info(f\"Model downloaded and saved to {model_path}\")\n        \n    except Exception as e:\n        logger.error(f\"Failed to download Silero VAD model using PyTorch Hub: {e}\")\n        logger.info(\"Falling back to energy-based VAD - the system will still work but with simpler voice detection\")\n\ndef download_embedding_models():\n    \"\"\"Download the sentence transformer models for RAG\"\"\"\n    logger.info(\"Setting up sentence transformer models...\")\n    \n    try:\n        from sentence_transformers import SentenceTransformer\n        \n        # Download lightweight model for embeddings\n        logger.info(\"Downloading embedding models (this may take a few minutes)...\")\n        models = [\n            \"all-MiniLM-L6-v2\",  # Fast\n            \"all-mpnet-base-v2\",  # Balanced\n            \"multi-qa-mpnet-base-dot-v1\"  # Best for Q&A\n        ]\n        \n        for model_name in models:\n            logger.info(f\"Setting up model: {model_name}\")\n            _ = SentenceTransformer(model_name)\n            logger.info(f\"Model {model_name} is ready\")\n            \n    except Exception as e:\n        logger.error(f\"Failed to download embedding models: {e}\")\n        logger.error(\"Please try running the script again or download models manually\")\n\ndef setup_directories():\n    \"\"\"Create necessary directories for the application\"\"\"\n    directories = [\"static\", \"responses\", \"embeddings_cache\", \"templates\"]\n    \n    for directory in directories:\n        os.makedirs(directory, exist_ok=True)\n        logger.info(f\"Directory {directory} is ready\")\n    \n    # Create template redirect file\n    template_dir = Path(\"templates\")\n    index_html = template_dir / \"index.html\"\n    \n    with open(index_html, \"w\") as f:\n        f.write(\"\"\"\n<!DOCTYPE 
html>\n<html>\n<head>\n    <meta http-equiv=\"refresh\" content=\"0; url=/static/index.html\">\n</head>\n<body>\n    <p>Redirecting to <a href=\"/static/index.html\">AI Companion</a>...</p>\n</body>\n</html>\n        \"\"\")\n    logger.info(\"Created index template for redirection\")\n\ndef setup_database():\n    \"\"\"Initialize the SQLite database\"\"\"\n    logger.info(\"Setting up database...\")\n    \n    try:\n        from sqlalchemy import create_engine, Column, Integer, String, Text\n        from sqlalchemy.ext.declarative import declarative_base\n        from sqlalchemy.orm import sessionmaker\n        \n        Base = declarative_base()\n        engine = create_engine(\"sqlite:///companion.db\")\n        \n        class Conversation(Base):\n            __tablename__ = \"conversations\"\n            id = Column(Integer, primary_key=True, index=True)\n            session_id = Column(String, index=True)\n            timestamp = Column(String)\n            user_message = Column(Text)\n            ai_message = Column(Text)\n            audio_path = Column(String)\n        \n        # Create tables\n        Base.metadata.create_all(bind=engine)\n        logger.info(\"Database initialized successfully\")\n        \n    except Exception as e:\n        logger.error(f\"Failed to set up database: {e}\")\n\ndef check_cuda():\n    \"\"\"Check if CUDA is available for PyTorch\"\"\"\n    if torch.cuda.is_available():\n        device_name = torch.cuda.get_device_name(0)\n        logger.info(f\"CUDA is available: {device_name}\")\n        logger.info(f\"CUDA version: {torch.version.cuda}\")\n    else:\n        logger.warning(\"CUDA is not available. 
The application will run on CPU, which may be very slow\")\n        logger.warning(\"For optimal performance, a CUDA-capable GPU is recommended\")\n\ndef main():\n    \"\"\"Main setup function\"\"\"\n    logger.info(\"Starting AI Companion setup...\")\n    \n    # Check for CUDA availability\n    check_cuda()\n    \n    # Check and install requirements\n    #check_requirements()\n    \n    # Create directories\n    setup_directories()\n    \n    # Set up database\n    setup_database()\n    \n    # Download models\n    download_vad_model()\n    download_embedding_models()\n    \n    logger.info(\"Setup completed successfully!\")\n    logger.info(\"You can now start the application with:\")\n    logger.info(\"   python main.py\")\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "static/app.js",
    "content": "let ws;\nlet micAnalyser, micContext, micSource, micStream;\nlet outputAnalyser, outputAudioCtx;\nlet lastConfig = null;\nlet isLoading = false;\n\ndocument.addEventListener('DOMContentLoaded', async () => {\n  await populateAudioDevices();\n\n  ws = new WebSocket(`ws://${window.location.host}/ws`);\n\n  ws.onopen = () => {\n    console.log(\"WebSocket connected, requesting saved config...\");\n    ws.send(JSON.stringify({ type: \"request_saved_config\" }));\n  };\n\n  ws.onmessage = async (event) => {\n    const data = JSON.parse(event.data);\n\n    if (data.type === \"saved_config\" && data.config) {\n      document.getElementById('systemPrompt').value = data.config.system_prompt || \"\";\n      document.getElementById('modelPath').value = data.config.model_path || \"\";\n      document.getElementById('llmPath').value = data.config.llm_path || \"\";\n      document.getElementById('referenceAudio').value = data.config.reference_audio_path || \"\";\n      document.getElementById('referenceText').value = data.config.reference_text || \"\";\n      document.getElementById('referenceAudio2').value = data.config.reference_audio_path2 || \"\";\n      document.getElementById('referenceText2').value  = data.config.reference_text2  || \"\";\n      document.getElementById('referenceAudio3').value = data.config.reference_audio_path3 || \"\";\n      document.getElementById('referenceText3').value  = data.config.reference_text3  || \"\";\n\n      setTimeout(() => {\n        if (data.config.mic_id) document.getElementById('micSelect').value = data.config.mic_id;\n        if (data.config.output_id) document.getElementById('outputSelect').value = data.config.output_id;\n      }, 500);\n    }\n\n    if (data.type === \"status\") {\n      if (data.message.includes(\"Models initialized\")) {\n        console.log(\"Model initialization confirmed. 
Redirecting...\");\n      \n        // Save config again just to be safe\n        localStorage.setItem('ai_config', JSON.stringify(lastConfig));\n      \n        // Close WebSocket before navigating\n        if (ws && ws.readyState === WebSocket.OPEN) {\n          ws.close();\n        }\n      \n        // Wait briefly to let server clean up, then redirect\n        setTimeout(() => {\n          window.location.href = \"/chat\";\n        }, 100);\n      } else if (data.message.includes(\"Initializing\") || data.message.includes(\"Loading\")) {\n        // Show that models are being loaded\n        showLoading(true, data.message);\n      }\n    }\n  };\n\n  document.getElementById('testMicBtn').addEventListener('click', async () => {\n    const micId = getSelectedMic();\n    micStream = await navigator.mediaDevices.getUserMedia({ audio: { deviceId: micId } });\n\n    micContext = new AudioContext();\n    micSource = micContext.createMediaStreamSource(micStream);\n    micAnalyser = micContext.createAnalyser();\n    micSource.connect(micAnalyser);\n    visualizeMic(micAnalyser, 'micCanvas');\n\n    const recorder = new MediaRecorder(micStream);\n    const chunks = [];\n\n    recorder.ondataavailable = e => {\n      if (e.data.size > 0) chunks.push(e.data);\n    };\n\n    recorder.onstop = () => {\n      const blob = new Blob(chunks, { type: 'audio/webm' });\n      const url = URL.createObjectURL(blob);\n      const audio = new Audio(url);\n      audio.play();\n\n      micStream.getTracks().forEach(track => track.stop());\n      micContext.close();\n    };\n\n    recorder.start();\n    setTimeout(() => recorder.stop(), 3000);\n  });\n\n  document.getElementById('testOutputBtn').addEventListener('click', () => {\n    const audio = new Audio('/static/test.mp3');\n    audio.setSinkId(getSelectedOutput()).then(() => {\n      outputAudioCtx = new AudioContext();\n      const outputSource = outputAudioCtx.createMediaElementSource(audio);\n      outputAnalyser = 
outputAudioCtx.createAnalyser();\n      outputSource.connect(outputAnalyser);\n      outputAnalyser.connect(outputAudioCtx.destination);\n      visualizeMic(outputAnalyser, 'outputCanvas');\n      audio.play();\n    }).catch(err => {\n      console.warn(\"Failed to route output:\", err);\n    });\n  });\n\n  document.getElementById('saveAndStart').addEventListener('click', () => {\n    lastConfig = {\n      system_prompt: document.getElementById('systemPrompt').value,\n      model_path: document.getElementById('modelPath').value,\n      llm_path: document.getElementById('llmPath').value,\n      reference_audio_path: document.getElementById('referenceAudio').value,\n      reference_text: document.getElementById('referenceText').value,\n      reference_audio_path2: document.getElementById('referenceAudio2').value,\n      reference_text2: document.getElementById('referenceText2').value,\n      reference_audio_path3: document.getElementById('referenceAudio3').value,\n      reference_text3: document.getElementById('referenceText3').value,\n      mic_id: getSelectedMic(),\n      output_id: getSelectedOutput(),\n    };\n    console.log(\"Sending config to backend...\");\n    console.log(lastConfig)\n    showLoading(true, \"Initializing models, please wait...\");\n    ws.send(JSON.stringify({ type: \"config\", config: lastConfig }));\n    // we wait for the backend to reply with model status before navigating\n  });\n});\n\nfunction showLoading(show, message) {\n  const saveButton = document.getElementById('saveAndStart');\n  const loadingContainer = document.getElementById('loadingContainer');\n  const loadingSpinner = document.getElementById('loadingSpinner');\n  const loadingText = document.getElementById('loadingText');\n  \n  isLoading = show;\n  \n  if (show) {\n    saveButton.disabled = true;\n    saveButton.classList.add('opacity-50', 'cursor-not-allowed');\n    saveButton.classList.remove('hover:bg-green-500');\n    loadingContainer.classList.remove('hidden');\n   
 loadingSpinner.style.display = 'block';\n    if (message) {\n      loadingText.textContent = message;\n    }\n  } else {\n    saveButton.disabled = false;\n    saveButton.classList.remove('opacity-50', 'cursor-not-allowed');\n    saveButton.classList.add('hover:bg-green-500');\n    loadingContainer.classList.add('hidden');\n    loadingSpinner.style.display = 'none';\n  }\n}\n\nfunction getSelectedMic() {\n  return document.getElementById('micSelect').value;\n}\n\nfunction getSelectedOutput() {\n  return document.getElementById('outputSelect').value;\n}\n\nasync function populateAudioDevices() {\n  try {\n    await navigator.mediaDevices.getUserMedia({ audio: true });\n  } catch (err) {\n    console.warn(\"Microphone permission denied or not granted.\");\n    return;\n  }\n\n  const devices = await navigator.mediaDevices.enumerateDevices();\n  const micSelect = document.getElementById('micSelect');\n  const outputSelect = document.getElementById('outputSelect');\n\n  micSelect.innerHTML = '';\n  outputSelect.innerHTML = '';\n\n  devices.forEach(device => {\n    const option = new Option(device.label || `${device.kind}`, device.deviceId);\n    if (device.kind === 'audioinput') micSelect.add(option.cloneNode(true));\n    if (device.kind === 'audiooutput') {\n      outputSelect.add(option.cloneNode(true));\n    }\n  });\n\n  if (micSelect.options.length === 0) {\n    micSelect.add(new Option(\"No mic devices found\", \"\"));\n  }\n  if (outputSelect.options.length === 0) {\n    outputSelect.add(new Option(\"Default Output\", \"default\"));\n  }\n}\n\nfunction visualizeMic(analyser, canvasId) {\n  const canvas = document.getElementById(canvasId);\n  const ctx = canvas.getContext(\"2d\");\n  analyser.fftSize = 256;\n  const bufferLength = analyser.frequencyBinCount;\n  const dataArray = new Uint8Array(bufferLength);\n\n  function draw() {\n    requestAnimationFrame(draw);\n    analyser.getByteFrequencyData(dataArray);\n    ctx.fillStyle = \"#1f2937\";\n    
ctx.fillRect(0, 0, canvas.width, canvas.height);\n    const barWidth = canvas.width / bufferLength;\n    for (let i = 0; i < bufferLength; i++) {\n      const barHeight = dataArray[i];\n      ctx.fillStyle = \"#4ade80\";\n      ctx.fillRect(i * barWidth, canvas.height - barHeight / 2, barWidth - 1, barHeight / 2);\n    }\n  }\n  draw();\n}"
  },
  {
    "path": "static/chat.js",
    "content": "let ws;\nlet sessionStartTime = null;\nlet messageCount = 0;\nlet audioLevelsChart = null;\nlet isRecording = false;\nlet isAudioCurrentlyPlaying = false;\nlet configSaved = false;\nlet currentAudioSource = null; \nlet interruptRequested = false; \nlet interruptInProgress = false;\nlet audioContext = null;\nlet lastSeenGenId = 0;\nlet reconnecting = false;\nlet reconnectAttempts = 0;\nlet maxReconnectAttempts = 10;\n\nconst SESSION_ID = \"default\";\nconsole.log(\"chat.js loaded\");\n\nlet micStream;\nlet selectedMicId = null;\nlet selectedOutputId = null;\n\nlet audioPlaybackQueue = [];\nlet audioDataHistory = [];\nlet micAnalyser, micContext;\nlet activeGenId = 0;\n\nfunction createPermanentVoiceCircle() {\n  if (document.getElementById('voice-circle')) return;\n  const style = document.createElement('style');\n  style.textContent = `\n    #voice-circle{\n      position:fixed;top:50%;left:50%;\n      width:180px;height:180px;border-radius:50%;\n      background:rgba(99,102,241,.20);\n      transform:translate(-50%,-50%) scale(var(--dynamic-scale,1));\n      pointer-events:none;z-index:50;\n      transition:background-color .35s ease;\n    }\n    #voice-circle.active{\n      animation:pulse-circle 2s infinite alternate ease-in-out;\n    }\n    @keyframes pulse-circle{\n      0%{background:rgba(99,102,241,.55)}\n      100%{background:rgba(99,102,241,.20)}\n    }`;\n  document.head.appendChild(style);\n\n  const c = document.createElement('div');\n  c.id='voice-circle';\n  document.body.appendChild(c);\n  console.log(\"Created permanent voice circle\");\n}\n\nfunction showVoiceCircle() {\n  const c=document.getElementById('voice-circle')||createPermanentVoiceCircle();\n  c.classList.add('active');\n}\n\nfunction hideVoiceCircle() {\n  const c=document.getElementById('voice-circle');\n  if (c){ c.classList.remove('active'); c.style.setProperty('--dynamic-scale',1); }\n}\n\nfunction showNotification(msg, type='info'){\n  const 
n=document.createElement('div');\n  n.className=`fixed bottom-4 right-4 px-4 py-3 rounded-lg shadow-lg z-50\n               ${type==='success'?'bg-green-600':\n                 type==='error'  ?'bg-red-600':'bg-indigo-600'}`;\n  n.textContent=msg;\n  document.body.appendChild(n);\n  setTimeout(()=>{n.classList.add('opacity-0');\n                  setTimeout(()=>n.remove(),500)},3000);\n}\n\nfunction addMessageToConversation(sender,text){\n  const pane=document.getElementById('conversationHistory');\n  if(!pane) return;\n  const box=document.createElement('div');\n  box.className=`p-3 mb-3 rounded-lg text-sm ${\n            sender==='user'?'bg-gray-800 ml-2':'bg-indigo-900 mr-2'}`;\n  box.innerHTML=`\n      <div class=\"flex items-start mb-2\">\n        <div class=\"w-6 h-6 rounded-full flex items-center justify-center\n             ${sender==='user'?'bg-gray-300 text-gray-800':'bg-indigo-500 text-white'}\">\n             ${sender==='user'?'U':'AI'}\n        </div>\n        <span class=\"text-xs text-gray-400 ml-2\">${new Date().toLocaleTimeString()}</span>\n      </div>\n      <div class=\"text-white mt-1 text-sm\">${text\n            .replace(/&/g,'&amp;').replace(/</g,'&lt;')\n            .replace(/\\*\\*(.*?)\\*\\*/g,'<strong>$1</strong>')\n            .replace(/\\*(.*?)\\*/g,'<em>$1</em>')\n            .replace(/```([^`]+)```/g,'<pre><code>$1</code></pre>')\n            .replace(/`([^`]+)`/g,'<code>$1</code>')\n            .replace(/\\n/g,'<br>')}</div>`;\n  pane.appendChild(box);\n  pane.scrollTop=pane.scrollHeight;\n}\n\nfunction connectWebSocket() {\n  if (reconnecting && reconnectAttempts >= maxReconnectAttempts) {\n    console.error(\"Maximum reconnect attempts reached. Please refresh the page.\");\n    showNotification(\"Connection lost. 
Please refresh the page.\", \"error\");\n    return;\n  }\n\n  if (ws && ws.readyState !== WebSocket.CLOSED && ws.readyState !== WebSocket.CLOSING) {\n    try {\n      ws.close();\n    } catch (e) {\n      console.warn(\"Error closing existing WebSocket:\", e);\n    }\n  }\n\n  const proto = location.protocol === 'https:' ? 'wss:' : 'ws:';\n  ws = new WebSocket(`${proto}//${location.host}/ws`);\n  window.ws = ws;\n\n  const connLbl = document.getElementById('connectionStatus');\n  if (connLbl) {\n    connLbl.textContent = reconnecting ? 'Reconnecting…' : 'Connecting…';\n    connLbl.className = 'text-yellow-500';\n  }\n\n  ws.onopen = () => {\n    if (connLbl) {\n      connLbl.textContent = 'Connected';\n      connLbl.className = 'text-green-500';\n    }\n    \n    const wasReconnecting = reconnecting;\n    reconnecting = false;\n    reconnectAttempts = 0;\n    \n    ws.send(JSON.stringify({type: 'request_saved_config'}));\n    \n    if (!wasReconnecting) {\n      addMessageToConversation('ai', 'WebSocket connected. Ready for voice or text.');\n    } else {\n      showNotification(\"Reconnected successfully\", \"success\");\n    }\n  };\n\n  ws.onclose = (event) => {\n    console.log(\"WebSocket closed with code:\", event.code);\n    if (connLbl) {\n      connLbl.textContent = 'Disconnected';\n      connLbl.className = 'text-red-500';\n    }\n\n    // Clear audio state on disconnection\n    clearAudioPlayback();\n    \n    // Don't auto-reconnect if this was a normal closure\n    if (event.code !== 1000 && event.code !== 1001) {\n      reconnecting = true;\n      reconnectAttempts++;\n      \n      const delay = Math.min(1000 * Math.pow(1.5, reconnectAttempts), 30000);\n      console.log(`Reconnecting in ${delay}ms (attempt ${reconnectAttempts})`);\n      \n      setTimeout(connectWebSocket, delay);\n    }\n  };\n\n  ws.onerror = (error) => {\n    console.error(\"WebSocket error:\", error);\n    if (connLbl) {\n      connLbl.textContent = 'Error';\n      connLbl.className = 'text-red-500';\n    }\n  };\n\n  
ws.onmessage = (e) => {\n    try {\n      const data = JSON.parse(e.data);\n      handleWebSocketMessage(data);\n    } catch (err) {\n      console.error(\"Error handling WebSocket message:\", err);\n    }\n  };\n}\n\nfunction sendTextMessage(txt) {\n  if (!txt.trim()) return;\n  \n  if (!ws || ws.readyState !== WebSocket.OPEN) {\n    showNotification(\"Not connected\", \"error\");\n    return;\n  }\n  \n  console.log(\"Force clearing all audio state before sending text message\");\n  \n  // Stop any playing audio\n  if (isAudioCurrentlyPlaying) {\n    if (currentAudioSource) {\n      try {\n        if (currentAudioSource.disconnect) currentAudioSource.disconnect();\n        if (currentAudioSource.stop) currentAudioSource.stop(0);\n      } catch (e) {\n        console.warn(\"Error stopping audio:\", e);\n      }\n      currentAudioSource = null;\n    }\n    isAudioCurrentlyPlaying = false;\n  }\n  \n  // Clear all flags and queues\n  interruptRequested = false;\n  interruptInProgress = false;\n  activeGenId = 0;\n  audioPlaybackQueue = [];\n  \n  // Always force interruption to be absolutely sure\n  if (ws && ws.readyState === WebSocket.OPEN) {\n    try {\n      ws.send(JSON.stringify({type: 'interrupt', immediate: true}));\n    } catch (e) {\n      console.warn(\"Error sending interrupt:\", e);\n    }\n  }\n  \n  // Wait a bit before sending the actual message\n  setTimeout(() => {\n    try {\n      // Show visual feedback\n      showVoiceCircle();\n      \n      // Send the message\n      ws.send(JSON.stringify({\n        type: 'text_message',\n        text: txt,\n        session_id: SESSION_ID\n      }));\n      \n      const cnt = document.getElementById('messageCount');\n      if (cnt) cnt.textContent = ++messageCount;\n      \n      document.getElementById('textInput').value = '';\n      \n      console.log(\"Text message sent successfully\");\n    } catch (error) {\n      console.error(\"Error sending message:\", error);\n      showNotification(\"Error 
sending message\", \"error\");\n    }\n  }, 300);\n}\n\n// Reset all audio state to ensure clean state for new interactions\nfunction resetAudioState() {\n  console.log(\"Resetting audio state\");\n  \n  // Clear any stale generation information\n  activeGenId = 0;\n  lastSeenGenId = 0;\n  \n  // Clear any remaining flags\n  interruptRequested = false;\n  interruptInProgress = false;\n  \n  // Make sure we don't have any playing audio\n  if (isAudioCurrentlyPlaying) {\n    clearAudioPlayback();\n  }\n  \n  // Clear any queued audio\n  audioPlaybackQueue = [];\n}\n\nfunction clearAudioPlayback() {\n  console.log(\"FORCEFULLY CLEARING AUDIO PLAYBACK\");\n  \n  interruptRequested = true;\n  interruptInProgress = true;\n  \n  try {\n    // Empty the queue first - do this before stopping current source\n    console.log(`Clearing queue with ${audioPlaybackQueue.length} items`);\n    audioPlaybackQueue = [];\n    \n    activeGenId = 0;\n    \n    // Stop any currently playing audio\n    if (currentAudioSource) {\n      console.log(\"Stopping active audio source\");\n      \n      try {\n        if (currentAudioSource.disconnect) {\n          currentAudioSource.disconnect();\n        }\n      } catch (e) {\n        console.warn(\"Error disconnecting audio source:\", e);\n      }\n      \n      try {\n        if (currentAudioSource.stop) {\n          currentAudioSource.stop(0);\n        }\n      } catch (e) {\n        console.warn(\"Error stopping audio source:\", e);\n      }\n      \n      currentAudioSource = null;\n    }\n    \n    try {\n      if (audioContext) {\n        const oldContext = audioContext;\n        audioContext = new (window.AudioContext || window.webkitAudioContext)();\n        window.audioContext = audioContext;\n        \n        try {\n          oldContext.close();\n        } catch (closeError) {\n          console.warn(\"Error closing old audio context:\", closeError);\n        }\n      } else {\n        audioContext = new (window.AudioContext || 
window.webkitAudioContext)();\n        window.audioContext = audioContext;\n      }\n    } catch (contextError) {\n      console.error(\"Error recreating audio context:\", contextError);\n    }\n  } catch (err) {\n    console.error(\"Error clearing audio:\", err);\n  }\n  \n  // Reset state\n  isAudioCurrentlyPlaying = false;\n  hideVoiceCircle();\n  \n  console.log(\"Audio playback cleared successfully\");\n  \n  // After a short delay, reset the interrupt flags to accept new audio\n  setTimeout(() => {\n    interruptInProgress = false;\n    // Keep interruptRequested true until we get a new generation\n  }, 300);\n}\n\n\n// Handle interruption request from user\nfunction requestInterrupt() {\n  console.log(\"User requested interruption\");\n  \n  if (interruptInProgress) {\n    console.log(\"Interrupt already in progress - force clearing again\");\n    clearAudioPlayback();\n    return false;\n  }\n  \n  // Set the flags immediately\n  interruptRequested = true;\n  interruptInProgress = true;\n  \n  // Show visual feedback\n  showNotification(\"Interrupting...\", \"info\");\n  \n  // Force clear all audio immediately on client side\n  clearAudioPlayback();\n  \n  // Show visual feedback for the button\n  const interruptBtn = document.getElementById('interruptBtn');\n  if (interruptBtn) {\n    interruptBtn.classList.add('bg-red-800');\n    setTimeout(() => {\n      interruptBtn.classList.remove('bg-red-800');\n    }, 300);\n  }\n  \n  // Then notify the server\n  if (ws && ws.readyState === WebSocket.OPEN) {\n    console.log(\"Sending interrupt request to server\");\n    try {\n      ws.send(JSON.stringify({\n        type: 'interrupt',\n        immediate: true\n      }));\n    } catch (error) {\n      console.error(\"Error sending interrupt request:\", error);\n    }\n    \n    // Set a timeout to reset interrupt flags if we don't get server confirmation\n    setTimeout(() => {\n      if (interruptInProgress) {\n        console.log(\"No interrupt confirmation 
received from server, resetting state\");\n        interruptInProgress = false;\n      }\n    }, 2000);\n    \n    return true;\n  } else {\n    console.warn(\"WebSocket not available for interrupt request\");\n    // Reset flag after brief delay if we couldn't send to server\n    setTimeout(() => {\n      interruptInProgress = false;\n    }, 500);\n    return false;\n  }\n}\n\nfunction handleWebSocketMessage(d) {\n  console.log(\"Received message:\", d.type, d);\n  \n  switch(d.type) {\n    case 'transcription':\n      addMessageToConversation('user', d.text);\n      showVoiceCircle();\n      break;\n      \n    case 'response':\n      addMessageToConversation('ai', d.text);\n      showVoiceCircle();\n      \n      console.log(\"NEW RESPONSE RECEIVED - FORCE RESETTING ALL AUDIO STATE\");\n      \n      if (isAudioCurrentlyPlaying) {\n        if (currentAudioSource) {\n          try {\n            if (currentAudioSource.disconnect) currentAudioSource.disconnect();\n            if (currentAudioSource.stop) currentAudioSource.stop(0);\n          } catch (e) {\n            console.warn(\"Error stopping current audio:\", e);\n          }\n          currentAudioSource = null;\n        }\n        isAudioCurrentlyPlaying = false;\n      }\n      \n      interruptRequested = false;\n      interruptInProgress = false;\n      \n      activeGenId = 0;\n      \n      audioPlaybackQueue = [];\n      \n      try {\n        if (audioContext) {\n          if (audioContext.state === 'suspended') {\n            audioContext.resume().catch(e => console.warn(\"Error resuming audio context:\", e));\n          }\n        } else {\n          audioContext = new (window.AudioContext || window.webkitAudioContext)();\n          window.audioContext = audioContext;\n        }\n      } catch (e) {\n        console.warn(\"Error with audio context:\", e);\n        audioContext = new (window.AudioContext || window.webkitAudioContext)();\n        window.audioContext = audioContext;\n      }\n      
\n      console.log(\"Audio state fully reset and ready for new audio\");\n      break;\n      \n    case 'audio_chunk':\n      console.log(\"Audio chunk received, flags:\", \n                 \"interruptRequested:\", interruptRequested, \n                 \"interruptInProgress:\", interruptInProgress,\n                 \"genId:\", d.gen_id,\n                 \"activeGenId:\", activeGenId);\n      \n      if (!isAudioCurrentlyPlaying && activeGenId === 0) {\n        console.log(\"FIRST AUDIO CHUNK - FORCING FLAGS RESET\");\n        interruptRequested = false;\n        interruptInProgress = false;\n      }\n      \n      // Don't queue new audio if an interrupt was requested\n      if (interruptRequested || interruptInProgress) {\n        console.log(\"Interrupt active - ignoring new audio chunk\");\n        return;\n      }\n      \n      // Set active generation ID on first chunk\n      if (activeGenId === 0) {\n        activeGenId = d.gen_id || 1;\n        console.log(\"!!! Setting activeGenId to:\", activeGenId);\n      }\n      \n      // Only accept chunks that match our active generation\n      if ((d.gen_id === activeGenId) || (activeGenId === 0)) {\n        queueAudioForPlayback(d.audio, d.sample_rate, d.gen_id || 0);\n        showVoiceCircle();\n      } else {\n        console.log(`Ignored stale chunk - current gen: ${activeGenId}, received: ${d.gen_id}`);\n      }\n      break;\n      \n    case 'audio_status':\n      console.log(\"Audio status update:\", d.status);\n      \n      if (d.status === 'generating') {\n        console.log(\"GOT GENERATING STATUS - IMMEDIATELY CLEARING ALL INTERRUPT FLAGS\");\n        interruptRequested = false;\n        interruptInProgress = false;\n        \n        // Capture the generation ID for new generations\n        if (d.gen_id) {\n          console.log(`New generation starting with ID: ${d.gen_id}`);\n          activeGenId = d.gen_id;\n        }\n        \n        showVoiceCircle();\n      } \n      else if (d.status 
=== 'complete') {\n        console.log(\"Audio generation complete\");\n        if (!d.gen_id || d.gen_id === activeGenId) {\n          activeGenId = 0; // Reset for next generation\n        }\n        if (!isAudioCurrentlyPlaying) {\n          hideVoiceCircle();\n        }\n      } \n      else if (d.status === 'interrupted' || d.status === 'interrupt_acknowledged') {\n        console.log(\"Server confirmed interrupt - clearing audio\");\n        clearAudioPlayback();\n        \n        setTimeout(() => {\n          console.log(\"Resetting interrupt flags after server confirmation\");\n          interruptRequested = false;\n          interruptInProgress = false;\n        }, 300);\n      }\n      break;\n      \n    case 'status':\n      if (d.message === 'Thinking...') {\n        showVoiceCircle();\n        \n        interruptRequested = false;\n        interruptInProgress = false;\n        activeGenId = 0;\n      }\n      break;\n      \n    case 'error':\n      showNotification(d.message, 'error');\n      hideVoiceCircle();\n      break;\n      \n    case 'vad_status':\n      if (d.status === 'speech_started') {\n        console.log(`[VAD] speech_started | should_interrupt=${d.should_interrupt}`);\n\n        if (d.should_interrupt && isAudioCurrentlyPlaying) {\n          console.log('[VAD] confirmed – sending interrupt');\n          requestInterrupt();\n        } else {\n          console.log('[VAD] ignored (echo / early AI audio)');\n        }\n      }\n      break;\n  }\n}\n\nfunction queueAudioForPlayback(arr, sr, genId = 0) {\n  if (activeGenId !== 0 && genId !== activeGenId) {\n    console.log(`Stale chunk ignored (genId mismatch): ${genId} vs ${activeGenId}`);\n    return;\n  }\n  \n  // Don't queue if interrupting\n  if (interruptRequested || interruptInProgress) {\n    console.log(\"Interrupt active - skipping audio chunk\");\n    return;\n  }\n  \n  console.log(\"Queueing audio chunk for playback\");\n  audioPlaybackQueue.push({arr, sr, genId});\n  \n  
if (!isAudioCurrentlyPlaying) {\n    console.log(\"▶Starting audio playback\");\n    processAudioPlaybackQueue();\n  }\n}\n\nfunction queueAudioForPlayback(arr, sr, genId = 0) {\n  // Extra logging for the first audio chunk\n  if (!isAudioCurrentlyPlaying) {\n    console.log(\"Queueing first audio chunk\", \n               \"interruptRequested:\", interruptRequested, \n               \"interruptInProgress:\", interruptInProgress);\n  }\n  \n  if (!isAudioCurrentlyPlaying && audioPlaybackQueue.length === 0) {\n    console.log(\"First audio chunk - forcing clear of interrupt flags\");\n    interruptRequested = false;\n    interruptInProgress = false;\n  }\n  \n  // Don't queue audio from a different generation than our active one\n  if (activeGenId !== 0 && genId !== activeGenId) {\n    console.log(`Stale chunk ignored (genId mismatch): ${genId} vs ${activeGenId}`);\n    return;\n  }\n  \n  // Don't queue if interrupting - BUT CHECK AGAIN THAT FLAGS ARE VALID\n  if (interruptRequested || interruptInProgress) {\n    console.log(\"Interrupt active - skipping audio chunk\");\n    return;\n  }\n  \n  console.log(\"Queueing audio chunk for playback\");\n  audioPlaybackQueue.push({arr, sr, genId});\n  \n  if (!isAudioCurrentlyPlaying) {\n    console.log(\"STARTING AUDIO PLAYBACK - FIRST CHUNK\");\n    processAudioPlaybackQueue();\n  }\n}\n\n\n// Modified to ensure first audio actually plays\nfunction processAudioPlaybackQueue() {\n  if (!isAudioCurrentlyPlaying && audioPlaybackQueue.length > 0) {\n    console.log(\"Starting first audio chunk - force clearing interrupt flags\");\n    interruptRequested = false;\n    interruptInProgress = false;\n  }\n  \n  // Double-check interrupt flags AFTER clearing them\n  if (interruptRequested || interruptInProgress) {\n    console.log(\"Interrupt active - not processing audio queue\");\n    isAudioCurrentlyPlaying = false;\n    hideVoiceCircle();\n    return;\n  }\n  \n  // Check if queue is empty\n  if (!audioPlaybackQueue.length) 
{\n    console.log(\"📭 Audio queue empty, stopping playback\");\n    isAudioCurrentlyPlaying = false;\n    hideVoiceCircle();\n    currentAudioSource = null;\n    return;\n  }\n  \n  // Enable the interrupt button when audio is playing\n  const interruptBtn = document.getElementById('interruptBtn');\n  if (interruptBtn) {\n    interruptBtn.disabled = false;\n    interruptBtn.classList.remove('opacity-50');\n  }\n  \n  console.log(\"Processing next audio chunk\");\n  isAudioCurrentlyPlaying = true;\n  \n  // Get the genId from the chunk\n  const {arr, sr, genId} = audioPlaybackQueue.shift();\n  \n  // Skip if it's a stale chunk\n  if (activeGenId !== 0 && genId !== activeGenId) {\n    console.log(`Skipping stale chunk playback (gen ${genId} vs active ${activeGenId})`);\n    processAudioPlaybackQueue(); // Continue with next chunk\n    return;\n  }\n  \n  playAudioChunk(arr, sr)\n    .then(() => {\n      // Check interrupt status again after playback\n      if (!interruptRequested && !interruptInProgress) {\n        processAudioPlaybackQueue();\n      } else {\n        console.log(\"interrupt active - stopping queue processing\");\n        isAudioCurrentlyPlaying = false;\n        hideVoiceCircle();\n      }\n    })\n    .catch(err => {\n      console.error(\"Error in audio playback:\", err);\n      isAudioCurrentlyPlaying = false;\n      hideVoiceCircle();\n      \n      // Try to continue with next chunk despite errors\n      setTimeout(() => {\n        if (audioPlaybackQueue.length > 0 && !interruptRequested) {\n          processAudioPlaybackQueue();\n        }\n      }, 200);\n    });\n}\n\nasync function playAudioChunk(audioArr, sampleRate) {\n  // Skip playback if interrupt was requested\n  if (interruptRequested || interruptInProgress) {\n    console.log(\"Interrupt active - not playing audio chunk\");\n    return Promise.resolve();\n  }\n  \n  try {\n    // Ensure we have a valid audio context\n    if (!audioContext) {\n      audioContext = new 
(window.AudioContext || window.webkitAudioContext)();\n      window.audioContext = audioContext;\n    }\n    \n    // Make sure context is resumed\n    if (audioContext.state === 'suspended') {\n      await audioContext.resume();\n    }\n    \n    const buf = audioContext.createBuffer(1, audioArr.length, sampleRate);\n    buf.copyToChannel(new Float32Array(audioArr), 0);\n    \n    const src = audioContext.createBufferSource();\n    src.buffer = buf;\n    \n    // Store reference to current source for potential interruption\n    currentAudioSource = src;\n    \n    const an = audioContext.createAnalyser(); \n    an.fftSize = 256;\n    src.connect(an); \n    an.connect(audioContext.destination); \n    src.start();\n    \n    console.log(\"🎵 Started playing audio chunk\");\n\n    const arr = new Uint8Array(an.frequencyBinCount);\n    const circle = document.getElementById('voice-circle');\n    \n    // Animation function that respects interruption\n    function pump() {\n      // Stop animation if source is no longer current or interrupt requested\n      if (src !== currentAudioSource || interruptRequested || interruptInProgress) {\n        return;\n      }\n      \n      try {\n        an.getByteFrequencyData(arr);\n        const avg = arr.reduce((a,b) => a+b, 0) / arr.length;\n        if (circle) {\n          circle.style.setProperty('--dynamic-scale', (1+avg/255*1.5).toFixed(3));\n        }\n      } catch (e) {\n        console.warn(\"Error in animation pump:\", e);\n        return;\n      }\n      \n      if (src.playbackState !== src.FINISHED_STATE) {\n        requestAnimationFrame(pump);\n      }\n    }\n    pump();\n    \n    return new Promise(resolve => {\n      src.onended = () => {\n        // Only resolve if this is still the current source and no interrupt\n        if (src === currentAudioSource && !interruptRequested && !interruptInProgress) {\n          resolve();\n        } else {\n          resolve(); // Still resolve to maintain chain\n        }\n   
   };\n    });\n  } catch (error) {\n    console.error(\"Error playing audio chunk:\", error);\n    return Promise.resolve(); // Resolve anyway to keep chain going\n  }\n}\n\nasync function startRecording() {\n  if (isRecording) return;\n  try {\n    const constraints = {\n      audio: selectedMicId ? {deviceId:{exact:selectedMicId}} : true\n    };\n    micStream = await navigator.mediaDevices.getUserMedia(constraints);\n\n    if (!audioContext) audioContext = new (AudioContext||webkitAudioContext)();\n    const src = audioContext.createMediaStreamSource(micStream);\n    const proc = audioContext.createScriptProcessor(4096,1,1);\n    src.connect(proc); proc.connect(audioContext.destination);\n\n    proc.onaudioprocess = e => {\n      const samples = Array.from(e.inputBuffer.getChannelData(0));\n      if (ws && ws.readyState === WebSocket.OPEN) {\n        try {\n          ws.send(JSON.stringify({\n            type:'audio',\n            audio:samples,\n            sample_rate:audioContext.sampleRate,\n            session_id:SESSION_ID\n          }));\n        } catch (error) {\n          console.error(\"Error sending audio data:\", error);\n          stopRecording();\n        }\n      }\n    };\n\n    window._micProcessor = proc;        \n    isRecording = true;\n    document.getElementById('micStatus').textContent = 'Listening…';\n    showVoiceCircle();\n  } catch (err) {\n    console.error(\"Microphone access error:\", err);\n    showNotification('Microphone access denied','error');\n  }\n}\n\nfunction stopRecording() {\n  if (!isRecording) return;\n  try {\n    if (window._micProcessor) {\n      window._micProcessor.disconnect();\n      window._micProcessor = null;\n    }\n    if (micStream) {\n      micStream.getTracks().forEach(t => t.stop());\n      micStream = null;\n    }\n  } catch(e) {\n    console.warn(\"Error stopping recording:\", e);\n  }\n  isRecording = false;\n  \n  const micStatus = document.getElementById('micStatus');\n  if (micStatus) {\n    
micStatus.textContent = 'Click to speak';\n  }\n  hideVoiceCircle();\n}\n\nasync function setupChatUI() {\n  document.documentElement.classList.add('bg-gray-950');\n  document.documentElement.style.backgroundColor = '#030712';\n\n  createPermanentVoiceCircle();\n  connectWebSocket();\n\n  initAudioLevelsChart();\n\n  const txt = document.getElementById('textInput');\n  const btn = document.getElementById('sendTextBtn');\n  \n  // Setup enhanced interrupt button\n  const interruptBtn = document.createElement('button');\n  interruptBtn.id = 'interruptBtn';\n  interruptBtn.className = 'px-3 py-2 ml-2 bg-red-600 text-white rounded hover:bg-red-700 flex items-center transition duration-150';\n  interruptBtn.innerHTML = '<svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-5 w-5 mr-1\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path fill-rule=\"evenodd\" d=\"M10 18a8 8 0 100-16 8 8 0 000 16zM8 7a1 1 0 00-1 1v4a1 1 0 001 1h4a1 1 0 001-1V8a1 1 0 00-1-1H8z\" clip-rule=\"evenodd\" /></svg> Stop';\n  interruptBtn.onclick = (e) => {\n    e.preventDefault();\n    try {\n      requestInterrupt();\n      interruptBtn.classList.add('bg-red-800', 'scale-95');\n      setTimeout(() => interruptBtn.classList.remove('bg-red-800', 'scale-95'), 150);\n    } catch (error) {\n      console.error(\"Error in interrupt button handler:\", error);\n    }\n  };\n  interruptBtn.title = \"Stop AI speech (Space or Esc)\";\n  interruptBtn.disabled = true; // Disabled by default\n  interruptBtn.classList.add('opacity-50', 'cursor-not-allowed');\n  \n  if (btn && btn.parentElement) {\n    btn.parentElement.appendChild(interruptBtn);\n  }\n  \n  // Add debug button for easier debugging of interrupt issues\n  const debugBtn = document.createElement('button');\n  debugBtn.innerText = \"Debug Audio\";\n  debugBtn.className = \"px-3 py-2 ml-2 bg-blue-600 text-white rounded text-xs\";\n  debugBtn.onclick = () => {\n    console.log(\"- Debug info:\");\n    console.log(\"- Audio playing:\", 
isAudioCurrentlyPlaying);\n    console.log(\"- Interrupt requested:\", interruptRequested);\n    console.log(\"- Interrupt in progress:\", interruptInProgress);\n    console.log(\"- Current source:\", currentAudioSource);\n    console.log(\"- Queue length:\", audioPlaybackQueue.length);\n    console.log(\"- Audio context state:\", audioContext?.state);\n    console.log(\"- Active generation ID:\", activeGenId);\n    console.log(\"- Last seen generation ID:\", lastSeenGenId);\n    console.log(\"- WebSocket state:\", ws ? ws.readyState : \"no websocket\");\n    showNotification(\"Debug info in console\", \"info\");\n  };\n  \n  if (btn && btn.parentElement) {\n    btn.parentElement.appendChild(debugBtn);\n  }\n  \n  // Run the update function periodically\n  setInterval(() => {\n    const interruptBtn = document.getElementById('interruptBtn');\n    if (interruptBtn) {\n      if (isAudioCurrentlyPlaying && !interruptRequested && !interruptInProgress) {\n        interruptBtn.disabled = false;\n        interruptBtn.classList.remove('opacity-50', 'cursor-not-allowed');\n      } else {\n        interruptBtn.disabled = true;\n        interruptBtn.classList.add('opacity-50', 'cursor-not-allowed');\n      }\n    }\n  }, 300);\n  \n  if (btn) {\n    btn.onclick = () => {\n      try {\n        sendTextMessage(txt.value);\n      } catch (error) {\n        console.error(\"Error in send button handler:\", error);\n      }\n    };\n  }\n  \n  if (txt) {\n    txt.addEventListener('keydown', e => {\n      if (e.key === 'Enter' && !e.shiftKey) {\n        e.preventDefault();\n        try {\n          sendTextMessage(txt.value);\n        } catch (error) {\n          console.error(\"Error in text input handler:\", error);\n        }\n      }\n    });\n  }\n  \n  const micBtn = document.getElementById('micToggleBtn');\n  if (micBtn) {\n    micBtn.addEventListener('click', () => {\n      try {\n        if (isRecording) stopRecording();\n        else startRecording();\n      } catch 
(error) {\n        console.error(\"Error in mic button handler:\", error);\n      }\n    });\n  }\n  \n  // Add event listeners to detect keyboard interruptions\n  document.addEventListener('keydown', e => {\n    // Allow space or escape to interrupt\n    if ((e.code === 'Space' || e.code === 'Escape') && isAudioCurrentlyPlaying) {\n      e.preventDefault();\n      try {\n        requestInterrupt();\n        \n        // Add visual feedback\n        const interruptBtn = document.getElementById('interruptBtn');\n        if (interruptBtn) {\n          interruptBtn.classList.add('bg-red-800');\n          setTimeout(() => {\n            interruptBtn.classList.remove('bg-red-800');\n          }, 200);\n        }\n      } catch (error) {\n        console.error(\"Error in keyboard interrupt handler:\", error);\n      }\n    }\n  });\n  \n  // Initialize audio context\n  if (!audioContext) {\n    try {\n      audioContext = new (window.AudioContext || window.webkitAudioContext)();\n      window.audioContext = audioContext;\n    } catch (error) {\n      console.error(\"Error creating audio context:\", error);\n      showNotification(\"Audio initialization failed. 
Please refresh the page.\", \"error\");\n    }\n  }\n  \n  // Try to unlock audio context on user interaction\n  ['click', 'touchstart', 'keydown'].forEach(ev =>\n    document.addEventListener(ev, function unlock() {\n      if (audioContext && audioContext.state === 'suspended') {\n        try {\n          audioContext.resume();\n        } catch (error) {\n          console.warn(\"Error resuming audio context:\", error);\n        }\n      }\n      document.removeEventListener(ev, unlock);\n    })\n  );\n\n  console.log(\"Chat UI ready with enhanced interruption support\");\n}\n\nif (document.readyState === 'loading') {\n  document.addEventListener('DOMContentLoaded', setupChatUI);\n} else {\n  setupChatUI();\n}\n\nfunction initAudioLevelsChart() {\n  const ctx = document.getElementById('audioLevels');\n  if (!ctx) return;\n  \n  try {\n    if (audioLevelsChart) audioLevelsChart.destroy();\n    \n    const grad = ctx.getContext('2d').createLinearGradient(0, 0, 0, 100);\n    grad.addColorStop(0, 'rgba(79,70,229,.6)');\n    grad.addColorStop(1, 'rgba(79,70,229,.1)');\n    \n    audioLevelsChart = new Chart(ctx, {\n      type: 'line',\n      data: {\n        labels: Array(30).fill(''),\n        datasets: [{\n          data: Array(30).fill(0),\n          backgroundColor: grad,\n          borderColor: 'rgba(99,102,241,1)',\n          borderWidth: 2,\n          tension: .4,\n          fill: true,\n          pointRadius: 0\n        }]\n      },\n      options: {\n        animation: false,\n        responsive: true,\n        scales: {\n          y: {\n            beginAtZero: true,\n            max: 100,\n            ticks: {display: false},\n            grid: {color: 'rgba(255,255,255,.1)'}\n          },\n          x: {display: false, grid: {display: false}}\n        },\n        plugins: {\n          legend: {display: false},\n          tooltip: {enabled: false}\n        },\n        elements: {point: {radius: 0}}\n      }\n    });\n  } catch (error) {\n    
console.error(\"Error initializing audio chart:\", error);\n  }\n}"
  },
  {
    "path": "static/crud.js",
    "content": "let allConversations = [];\n\ndocument.addEventListener('DOMContentLoaded', async () => {\n  await loadConversations();\n\n  document.getElementById('searchInput').addEventListener('input', () => {\n    const query = document.getElementById('searchInput').value.toLowerCase();\n    const filtered = allConversations.filter(c =>\n      c.user_message.toLowerCase().includes(query) ||\n      c.ai_message.toLowerCase().includes(query)\n    );\n    renderConversations(filtered);\n  });\n\n  document.getElementById('deleteAllBtn').addEventListener('click', async () => {\n    if (!confirm(\"Are you sure you want to delete ALL conversations?\")) return;\n    await fetch('/api/conversations', { method: 'DELETE' });\n    await loadConversations();\n  });\n});\n\nasync function loadConversations() {\n  const res = await fetch('/api/conversations');\n  allConversations = await res.json();\n  renderConversations(allConversations);\n}\n\nfunction renderConversations(list) {\n  const container = document.getElementById('conversationList');\n  container.innerHTML = '';\n\n  if (list.length === 0) {\n    container.innerHTML = '<p class=\"text-gray-400\">No conversations found.</p>';\n    return;\n  }\n\n  list.forEach(conv => {\n    const div = document.createElement('div');\n    div.className = \"bg-gray-800 p-4 rounded shadow\";\n    div.innerHTML = `\n      <div><strong>User:</strong></div>\n      <textarea data-id=\"${conv.id}\" data-field=\"user\" class=\"w-full p-2 rounded bg-gray-700 mt-1\">${conv.user_message}</textarea>\n      <div class=\"mt-2\"><strong>AI:</strong></div>\n      <textarea data-id=\"${conv.id}\" data-field=\"ai\" class=\"w-full p-2 rounded bg-gray-700 mt-1\">${conv.ai_message}</textarea>\n      <button class=\"saveBtn mt-2 bg-green-600 hover:bg-green-500 px-4 py-1 rounded\">Save</button>\n      <button class=\"deleteBtn mt-2 ml-2 bg-red-600 hover:bg-red-500 px-4 py-1 rounded\">Delete</button>\n    `;\n    container.appendChild(div);\n\n    
div.querySelector('.saveBtn').addEventListener('click', async () => {\n      const id = conv.id;\n      const user = div.querySelector('textarea[data-field=\"user\"]').value;\n      const ai = div.querySelector('textarea[data-field=\"ai\"]').value;\n      await fetch(`/api/conversations/${id}`, {\n        method: 'PUT',\n        headers: { 'Content-Type': 'application/json' },\n        body: JSON.stringify({ user_message: user, ai_message: ai })\n      });\n      alert(\"Saved.\");\n    });\n\n    div.querySelector('.deleteBtn').addEventListener('click', async () => {\n      const id = conv.id;\n      if (!confirm(\"Delete this conversation?\")) return;\n      await fetch(`/api/conversations/${id}`, { method: 'DELETE' });\n      await loadConversations();\n    });\n  });\n}\n"
  },
  {
    "path": "templates/chat.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n  <meta charset=\"UTF-8\">\n  <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n  <title>AI Companion - Chat</title>\n  <link href=\"https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css\" rel=\"stylesheet\">\n  <script src=\"https://cdn.jsdelivr.net/npm/chart.js\"></script>\n  <style>\n    body, html {\n      background-color: #030712 !important;\n      color: white !important;\n      overflow-x: hidden;\n      margin: 0;\n      padding: 0;\n    }\n    \n    .pulse {\n      animation: pulse 1.5s infinite;\n    }\n    @keyframes pulse {\n      0% { box-shadow: 0 0 0 0 rgba(99, 102, 241, 0.7); }\n      70% { box-shadow: 0 0 0 10px rgba(99, 102, 241, 0); }\n      100% { box-shadow: 0 0 0 0 rgba(99, 102, 241, 0); }\n    }\n    \n    #voice-circle {\n      position: fixed;\n      top: 50%;\n      left: 50%;\n      transform: translate(-50%, -50%) scale(0.7);\n      width: 180px;\n      height: 180px;\n      border-radius: 50%;\n      background-color: rgba(99, 102, 241, 0.2);\n      z-index: 50;\n      pointer-events: none;\n      display: block !important;\n      transition: transform 0.5s ease, background-color 0.5s ease;\n    }\n    \n    #voice-circle.active {\n      animation: pulse-circle 2s infinite alternate;\n    }\n    \n    @keyframes pulse-circle {\n      0% { transform: translate(-50%, -50%) scale(1); background-color: rgba(99, 102, 241, 0.5); }\n      100% { transform: translate(-50%, -50%) scale(2); background-color: rgba(99, 102, 241, 0.1); }\n    }\n    \n    .header-container {\n      position: fixed;\n      top: 0;\n      left: 0;\n      right: 0;\n      height: 64px;\n      z-index: 40;\n    }\n    \n    .status-container {\n      position: fixed;\n      left: 0;\n      top: 64px;\n      bottom: 0;\n      width: 25%;\n      overflow-y: auto;\n      padding: 10px;\n    }\n    \n    .conversation-container {\n      position: fixed;\n      right: 0;\n 
     top: 64px;\n      bottom: 0;\n      width: 25%;\n      overflow-y: auto;\n      padding: 10px;\n    }\n    \n    .center-area {\n      position: fixed;\n      left: 25%;\n      right: 25%;\n      top: 64px;\n      bottom: 0;\n    }\n    \n    /* Mobile layout */\n    @media (max-width: 768px) {\n      .status-container {\n        width: 50%;\n        left: 0;\n      }\n      \n      .conversation-container {\n        width: 50%;\n        right: 0;\n      }\n      \n      .center-area {\n        display: none;\n      }\n    }\n  </style>\n</head>\n<body class=\"bg-gray-950 text-white\">\n\n  <div id=\"voice-circle\"></div>\n\n  <header class=\"header-container bg-gray-900 p-4 flex justify-between items-center shadow-md\">\n    <h1 class=\"text-2xl font-bold text-indigo-400\">AI Companion</h1>\n    <button id=\"settingsBtn\" class=\"bg-indigo-700 hover:bg-indigo-800 text-white px-4 py-2 rounded shadow\">\n      Settings\n    </button>\n  </header>\n\n  <div class=\"flex flex-row\">\n    \n    <aside class=\"status-container bg-gray-900 rounded-lg shadow p-4 flex flex-col\">\n      <h2 class=\"text-xl font-semibold text-indigo-400 mb-4\">Status</h2>\n\n      <div class=\"space-y-4\">\n        <div>\n          <h3 class=\"text-lg font-medium mb-2\">System</h3>\n          <div class=\"text-sm space-y-1\">\n            <div class=\"flex justify-between\"><span>Connection:</span><span id=\"connectionStatus\" class=\"text-yellow-400\">Connecting...</span></div>\n            <div class=\"flex justify-between\"><span>Models:</span><span id=\"modelStatus\" class=\"text-yellow-400\">Not Loaded</span></div>\n            <div class=\"flex justify-between\"><span>Audio:</span><span id=\"audioStatus\" class=\"text-yellow-400\">Idle</span></div>\n          </div>\n        </div>\n\n        <div>\n          <h3 class=\"text-lg font-medium mb-2\">Session</h3>\n          <div class=\"text-sm space-y-1\">\n            <div class=\"flex justify-between\"><span>Duration:</span><span 
id=\"sessionDuration\">00:00:00</span></div>\n            <div class=\"flex justify-between\"><span>Messages:</span><span id=\"messageCount\">0</span></div>\n          </div>\n        </div>\n        \n        <div class=\"mt-4 p-4 bg-gray-800 rounded-lg\">\n          <div class=\"flex items-center justify-between\">\n            <button id=\"micToggleBtn\" class=\"w-16 h-16 bg-indigo-600 hover:bg-indigo-700 text-white rounded-full flex items-center justify-center focus:outline-none\">\n              <svg xmlns=\"http://www.w3.org/2000/svg\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n              class=\"h-8 w-8\">\n              <path stroke-linecap=\"round\" stroke-linejoin=\"round\"\n                    d=\"M12 3a4 4 0 0 1 4 4v6a4 4 0 0 1-8 0V7a4 4 0 0 1 4-4z\"/>\n              <path stroke-linecap=\"round\" stroke-linejoin=\"round\"\n                    d=\"M19 11a7 7 0 0 1-14 0\"/>\n              <path stroke-linecap=\"round\" stroke-linejoin=\"round\"\n                    d=\"M12 21v-4\"/>\n            </svg>\n            </button>\n            <div class=\"text-center\">\n              <span id=\"micStatus\" class=\"text-sm text-gray-300 block\">Click to speak</span>\n            </div>\n          </div>\n        </div>\n      </div>\n    </aside>\n\n    <div class=\"center-area\"></div>\n\n    <section class=\"conversation-container bg-gray-900 rounded-lg shadow p-4 flex flex-col\">\n      <div id=\"conversationHistory\" class=\"flex-grow overflow-y-auto space-y-4 mb-4\" style=\"max-height: calc(100vh - 15rem);\">\n      </div>\n\n      <div class=\"p-4 bg-gray-800 rounded-lg\">\n        <div class=\"flex items-center\">\n          <input \n            type=\"text\" \n            id=\"textInput\" \n            class=\"flex-grow bg-gray-700 text-white p-3 rounded-l border border-gray-600 focus:outline-none focus:ring-2 focus:ring-indigo-500\" \n            
placeholder=\"Type your message here...\"\n          >\n          <button \n            id=\"sendTextBtn\" \n            class=\"bg-indigo-600 hover:bg-indigo-700 text-white p-3 rounded-r border border-indigo-700\"\n          >\n            <svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-6 w-6\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\">\n              <path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"2\" d=\"M14 5l7 7m0 0l-7 7m7-7H3\" />\n            </svg>\n          </button>\n        </div>\n      </div>\n    </section>\n  </div>\n\n  <div id=\"settingsModal\" class=\"fixed inset-0 bg-black bg-opacity-80 hidden z-50 flex items-center justify-center\">\n    <div class=\"bg-gray-900 text-white p-6 rounded-lg shadow-lg w-full max-w-3xl max-h-screen overflow-y-auto\">\n      <div class=\"flex justify-between items-center mb-6\">\n        <h2 class=\"text-2xl font-bold text-indigo-400\">Audio Settings</h2>\n        <button id=\"closeSettingsBtn\" class=\"text-gray-400 hover:text-white\">\n          ✕\n        </button>\n      </div>\n\n      <div class=\"grid grid-cols-1 md:grid-cols-2 gap-4\">\n        <div>\n          <label for=\"micSelect\" class=\"block text-sm font-medium\">Microphone</label>\n          <select id=\"micSelect\" class=\"w-full bg-gray-800 text-white p-2 rounded border border-gray-700 mt-1\"></select>\n          <canvas id=\"micCanvas\" class=\"w-full h-16 mt-2 rounded\"></canvas>\n        </div>\n\n        <div>\n          <label for=\"outputSelect\" class=\"block text-sm font-medium\">Speaker</label>\n          <select id=\"outputSelect\" class=\"w-full bg-gray-800 text-white p-2 rounded border border-gray-700 mt-1\"></select>\n          <canvas id=\"outputCanvas\" class=\"w-full h-16 mt-2 rounded\"></canvas>\n        </div>\n\n        <div>\n          <label for=\"vadEnabled\" class=\"flex items-center space-x-2 text-sm font-medium\">\n            <input type=\"checkbox\" id=\"vadEnabled\" 
class=\"rounded text-indigo-600 focus:ring-indigo-500\" checked />\n            <span>Voice Activity Detection</span>\n          </label>\n        </div>\n\n        <div>\n          <label for=\"vadThreshold\" class=\"block text-sm font-medium\">VAD Sensitivity: <span id=\"vadThresholdValue\">0.5</span></label>\n          <input type=\"range\" id=\"vadThreshold\" min=\"0.1\" max=\"0.9\" step=\"0.05\" value=\"0.5\" class=\"w-full mt-1\" />\n        </div>\n\n        <div>\n          <label for=\"volumeLevel\" class=\"block text-sm font-medium\">Microphone Volume: <span id=\"volumeLevelValue\">1.0</span></label>\n          <input type=\"range\" id=\"volumeLevel\" min=\"0.1\" max=\"2.0\" step=\"0.1\" value=\"1.0\" class=\"w-full mt-1\" />\n        </div>\n\n        <div>\n          <label for=\"speakerVolume\" class=\"block text-sm font-medium\">Speaker Volume: <span id=\"speakerVolumeValue\">1.0</span></label>\n          <input type=\"range\" id=\"speakerVolume\" min=\"0.1\" max=\"2.0\" step=\"0.1\" value=\"1.0\" class=\"w-full mt-1\" />\n        </div>\n      </div>\n\n      <div class=\"mt-6 flex justify-end space-x-4\">\n        <button id=\"testMicBtn\" class=\"bg-indigo-600 px-4 py-2 rounded hover:bg-indigo-700\">Test Mic</button>\n        <button id=\"testAudioBtn\" class=\"bg-indigo-600 px-4 py-2 rounded hover:bg-indigo-700\">Test Audio</button>\n        <button id=\"saveAudioSettingsBtn\" class=\"bg-green-600 px-4 py-2 rounded hover:bg-green-700\">Save Settings</button>\n      </div>\n    </div>\n  </div>\n  \n  <script src=\"/static/chat.js\"></script>\n</body>\n</html>"
  },
  {
    "path": "templates/crud.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\" class=\"dark\">\n<head>\n  <meta charset=\"UTF-8\" />\n  <title>Conversation Manager</title>\n  <link href=\"https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css\" rel=\"stylesheet\">\n  <script defer src=\"/static/crud.js\"></script>\n</head>\n<body class=\"bg-gray-900 text-white\">\n  <div class=\"p-6\">\n    <a href=\"/setup\" class=\"absolute top-4 right-4 bg-blue-700 hover:bg-blue-600 text-white px-4 py-2 rounded\">\n      Return to Setup\n    </a>\n    <h1 class=\"text-2xl font-bold text-indigo-400 mb-4\">Memory Manager</h1>\n    <div class=\"flex items-center justify-between mb-4\">\n        <input\n          type=\"text\"\n          id=\"searchInput\"\n          placeholder=\"Search...\"\n          class=\"p-2 w-1/2 bg-gray-700 rounded text-white placeholder-gray-400\"\n        />\n        <button id=\"deleteAllBtn\" class=\"bg-red-700 hover:bg-red-600 px-4 py-2 rounded ml-4\">Delete All</button>\n      </div>\n      \n    <div id=\"conversationList\" class=\"space-y-4\"></div>\n  </div>\n</body>\n</html>\n"
  },
  {
    "path": "templates/index.html",
    "content": "<meta http-equiv=\"refresh\" content=\"0; url=/setup\" />"
  },
  {
    "path": "templates/setup.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\" class=\"dark\">\n<head>\n<meta charset=\"UTF-8\" />\n<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\" />\n<title>AI Companion Setup</title>\n<script defer src=\"/static/app.js\"></script>\n<link href=\"https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css\" rel=\"stylesheet\">\n<style>\ncanvas {\nwidth: 100%;\nmax-width: 300px;\nheight: 50px;\nbackground-color: #374151;\nborder-radius: 0.375rem;\n}\n.spinner {\nborder: 4px solid rgba(255, 255, 255, 0.3);\nborder-radius: 50%;\nborder-top: 4px solid #4ade80;\nwidth: 30px;\nheight: 30px;\nanimation: spin 1s linear infinite;\nmargin: 0 auto;\ndisplay: none;\n}\n@keyframes spin {\n 0% { transform: rotate(0deg); }\n 100% { transform: rotate(360deg); }\n}\n.reference-section {\n  border: 1px solid #4b5563;\n  border-radius: 0.5rem;\n  padding: 1rem;\n  margin-bottom: 1rem;\n}\n.reference-heading {\n  display: flex;\n  justify-content: space-between;\n  align-items: center;\n  margin-bottom: 0.5rem;\n}\n</style>\n</head>\n<body class=\"bg-gray-900 text-white\">\n<div class=\"min-h-screen flex items-center justify-center\">\n<div class=\"bg-gray-800 p-8 rounded-lg shadow-xl w-full max-w-3xl\">\n<h1 class=\"text-2xl font-bold mb-4 text-indigo-400\">AI Companion Setup</h1>\n<a href=\"/crud\" class=\"absolute top-4 right-4 bg-blue-700 hover:bg-blue-600 text-white px-4 py-2 rounded\">\n Open Conversation DB\n</a>\n<div class=\"grid grid-cols-1 gap-4\">\n<label>System Prompt:\n<textarea id=\"systemPrompt\" class=\"w-full p-2 bg-gray-700 rounded mt-1\">You are a friendly AI.</textarea>\n</label>\n<label>Model Path:\n<input type=\"text\" id=\"modelPath\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"./finetuned_model\">\n</label>\n<label>LLM Path (GGUF):\n<input type=\"text\" id=\"llmPath\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"./models/llama3-8b-instruct.gguf\">\n</label>\n\n<!-- Primary Reference (Required) 
-->\n<div class=\"reference-section\">\n  <div class=\"reference-heading\">\n    <h3 class=\"text-lg font-medium text-indigo-300\">Primary Reference Audio (Required)</h3>\n  </div>\n  <label>Reference Audio:\n  <input type=\"text\" id=\"referenceAudio\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"./reference.wav\">\n  </label>\n  <label>Reference Text:\n  <input type=\"text\" id=\"referenceText\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"Hi, how can I help?\">\n  </label>\n</div>\n\n<!-- Second Reference (Optional) -->\n<div class=\"reference-section\">\n  <div class=\"reference-heading\">\n    <h3 class=\"text-lg font-medium text-indigo-300\">Secondary Reference (Optional)</h3>\n    <span class=\"text-xs text-gray-400\">For better voice quality</span>\n  </div>\n  <label>Reference Audio 2:\n  <input type=\"text\" id=\"referenceAudio2\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"\">\n  </label>\n  <label>Reference Text 2:\n  <input type=\"text\" id=\"referenceText2\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"\">\n  </label>\n</div>\n\n<!-- Third Reference (Optional) -->\n<div class=\"reference-section\">\n  <div class=\"reference-heading\">\n    <h3 class=\"text-lg font-medium text-indigo-300\">Tertiary Reference (Optional)</h3>\n    <span class=\"text-xs text-gray-400\">For even better voice quality</span>\n  </div>\n  <label>Reference Audio 3:\n  <input type=\"text\" id=\"referenceAudio3\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"\">\n  </label>\n  <label>Reference Text 3:\n  <input type=\"text\" id=\"referenceText3\" class=\"w-full p-2 bg-gray-700 rounded mt-1\" value=\"\">\n  </label>\n</div>\n\n<label>Microphone:\n<select id=\"micSelect\" class=\"w-full p-2 bg-gray-700 rounded mt-1\"></select>\n</label>\n<div>\n<canvas id=\"micCanvas\" width=\"300\" height=\"50\"></canvas>\n<button id=\"testMicBtn\" class=\"bg-indigo-600 hover:bg-indigo-500 px-4 py-2 rounded mt-2\">Test 
Mic</button>\n</div>\n<label>Headset / Output:\n<select id=\"outputSelect\" class=\"w-full p-2 bg-gray-700 rounded mt-1\"></select>\n</label>\n<div>\n<canvas id=\"outputCanvas\" width=\"300\" height=\"50\"></canvas>\n<button id=\"testOutputBtn\" class=\"bg-indigo-600 hover:bg-indigo-500 px-4 py-2 rounded mt-2\">Test Output</button>\n</div>\n<div class=\"flex flex-col items-center justify-center mt-4\">\n<button id=\"saveAndStart\" class=\"bg-green-600 hover:bg-green-500 px-4 py-2 rounded w-full\">Start Companion</button>\n<div id=\"loadingContainer\" class=\"mt-4 text-center hidden\">\n<div class=\"spinner\" id=\"loadingSpinner\"></div>\n<p id=\"loadingText\" class=\"mt-2 text-indigo-300\">Initializing models, please wait...</p>\n</div>\n</div>\n</div>\n</div>\n</div>\n</body>\n</html>"
  },
  {
    "path": "test.py",
    "content": "import time\nfrom generator import Segment, load_csm_1b, generate_streaming_audio\nimport torchaudio\n\nprint(f\"Starting script at: {time.strftime('%H:%M:%S')}\")\nstart_time = time.time()\n\nprint(\"Downloading model...\")\nmodel_start = time.time()\nprint(f\"Model download completed in {time.time() - model_start:.2f} seconds\")\n\nprint(\"Loading model to CUDA...\")\nload_start = time.time()\ngenerator = load_csm_1b(\"cuda\")\nprint(f\"Model loaded in {time.time() - load_start:.2f} seconds\")\n\nspeakers = [0, 1, 0, 0]\ntranscripts = [\n    \"Hey how are you doing.\",\n    \"Pretty good, pretty good.\",\n    \"I'm great.\",\n    \"So happy to be speaking to you.\",\n]\naudio_paths = [\n    \"utterance_0.wav\",\n    \"utterance_1.wav\",\n    \"utterance_2.wav\",\n    \"utterance_3.wav\",\n]\n\ndef load_audio(audio_path):\n    print(f\"Loading reference audio: {audio_path}\")\n    audio_load_start = time.time()\n    audio_tensor, sample_rate = torchaudio.load(audio_path)\n    audio_tensor = torchaudio.functional.resample(\n        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate\n    )\n    print(f\"Audio loaded and resampled in {time.time() - audio_load_start:.2f} seconds\")\n    return audio_tensor\n\nprint(\"Creating segments with reference audio...\")\nsegments_start = time.time()\nsegments = [\n    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))\n    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)\n]\nprint(f\"Segments created in {time.time() - segments_start:.2f} seconds\")\n\n# Option 1: Regular generation with streaming internally enabled\nprint(\"Generating audio (with internal streaming)...\")\ngen_start = time.time()\naudio = generator.generate(\n    text=\"Me too, this is some cool stuff huh?\",\n    speaker=0,\n    context=segments,\n    max_audio_length_ms=10_000,\n    stream=True  # Enable internal streaming\n)\nprint(f\"Audio generation completed in 
{time.time() - gen_start:.2f} seconds\")\n\nprint(\"Saving audio file...\")\nsave_start = time.time()\ntorchaudio.save(\"audio_regular.wav\", audio.unsqueeze(0).cpu(), generator.sample_rate)\nprint(f\"Audio saved in {time.time() - save_start:.2f} seconds\")\n\n# Option 2: Use the streaming helper function that saves as it goes\nprint(\"Generating audio using streaming API...\")\ngenerate_streaming_audio(\n    generator=generator,\n    text=\"Me too, this is some cool stuff huh?\",\n    speaker=0,\n    context=segments,\n    output_file=\"audio_streamed.wav\",\n    max_audio_length_ms=10_000,\n    play_audio=True  # Set to True to play audio in real-time (requires sounddevice package)\n)\n\ntotal_time = time.time() - start_time\nprint(f\"Total execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)\")\nprint(f\"Script completed at: {time.strftime('%H:%M:%S')}\")"
  },
  {
    "path": "vad.py",
    "content": "import numpy as np\nimport torch\nfrom typing import Callable, Dict, List\nfrom collections import deque\nclass VoiceActivityDetector:\n    def __init__(\n        self,\n        model,\n        utils,\n        sample_rate: int = 16000,\n        threshold: float = 0.3,\n        silence_duration: int = 45\n    ):\n        self.model = model\n        self.sample_rate = sample_rate\n        self.threshold = threshold\n        self.silence_duration = silence_duration\n        \n        # Get functions from utils\n        self.get_speech_timestamps = utils[0]\n        \n        self.is_speaking = False\n        self.silent_frames = 0\n        self.frame_size = 512 if sample_rate == 16000 else 256  # Required by Silero VAD\n        \n        print(f\"VAD initialized with threshold {threshold}, frame size {self.frame_size}, silence duration {silence_duration}\")\n\n    def reset(self) -> None:\n        self.is_speaking   = False\n        self.silent_frames = 0\n\n        if hasattr(self.model, \"reset_states\"):\n            self.model.reset_states()\n        elif hasattr(self.model, \"reset_state\"):\n            self.model.reset_state()\n        else:\n            for buf in (\"h\", \"c\"):\n                if hasattr(self.model, buf):\n                    getattr(self.model, buf).zero_()\n\n    def process_audio_chunk(self, audio_chunk: np.ndarray) -> bool:\n        # Prepare audio chunk\n        if audio_chunk.ndim > 1:\n            audio_chunk = np.mean(audio_chunk, axis=1)\n        if audio_chunk.dtype != np.float32:\n            audio_chunk = audio_chunk.astype(np.float32)\n        \n        # Process in chunks of the correct size\n        speech_detected = False\n        turn_ended = False\n        \n        speech_probs = []\n        \n        # Process audio in correct sized chunks for Silero VAD\n        for i in range(0, len(audio_chunk), self.frame_size):\n            # Get chunk of correct size\n            chunk = 
audio_chunk[i:i+self.frame_size]\n            \n            # If we don't have enough samples, pad with zeros\n            if len(chunk) < self.frame_size:\n                chunk = np.pad(chunk, (0, self.frame_size - len(chunk)))\n            \n            # Convert to tensor\n            audio_tensor = torch.tensor(chunk).to('cpu')\n            \n            # Get speech probability\n            \n            speech_prob = self.model(audio_tensor, self.sample_rate).item()\n            \n            speech_probs.append(speech_prob)\n            \n            # Update speaking state\n            if speech_prob >= self.threshold:\n                speech_detected = True\n                self.silent_frames = 0\n            else:\n                if self.is_speaking:\n                    self.silent_frames += 1\n        \n        # Print detailed speech detection information\n        # print(f\"Speech probabilities: {speech_probs}\")\n        # print(f\"Speech detected: {speech_detected}, Current state: {self.is_speaking}\")\n        # print(f\"Silent frames: {self.silent_frames}, Threshold: {self.silence_duration}\")\n        \n        # Update speaking state based on all chunks\n        if speech_detected:\n            self.is_speaking = True\n            self.silent_frames = 0\n        elif self.is_speaking and self.silent_frames >= self.silence_duration:\n            # Transition to not speaking if we've had enough silent frames\n            self.is_speaking = False\n            turn_ended = True\n            print(f\"Turn ended after {self.silent_frames} silent frames\")\n            self.silent_frames = 0\n        \n        return turn_ended\n\n\nclass AudioStreamProcessor:\n    def __init__(\n        self,\n        model,\n        utils,\n        sample_rate: int = 16000,\n        chunk_size: int = 512,\n        vad_threshold: float = 0.3,\n        callbacks: Dict[str, Callable] = None,\n        pre_speech_buffer_size: int = 10\n    ):\n        self.sample_rate = 
sample_rate\n        self.chunk_size = chunk_size\n        self.pre_speech_buffer = deque(maxlen=pre_speech_buffer_size)\n        # Ensure model is on CPU\n        if hasattr(model, 'to'):\n            model = model.to('cpu')\n            \n        self.vad = VoiceActivityDetector(\n            model=model,\n            utils=utils,\n            sample_rate=sample_rate,\n            threshold=vad_threshold,\n            silence_duration=45  # Increased for better end detection\n        )\n        \n        self.audio_buffer = []\n        self.is_collecting = False\n        self.callbacks = callbacks or {}\n        self.silent_chunk_count = 0\n        self.max_silent_chunks = 30  # Force end after this many silent chunks\n        \n        print(f\"AudioStreamProcessor initialized with threshold: {vad_threshold}\")\n    \n    def process_audio(self, audio_chunk: np.ndarray):\n        # Always add to pre-speech buffer\n        self.pre_speech_buffer.append(audio_chunk)\n        \n        if self.is_collecting:\n            self.audio_buffer.append(audio_chunk)\n        \n        # Process with VAD\n        is_turn_end = self.vad.process_audio_chunk(audio_chunk)\n        \n        # Start collecting on speech detection\n        if self.vad.is_speaking and not self.is_collecting:\n            self.is_collecting = True\n            self.silent_chunk_count = 0\n            # Include pre-speech buffer in the audio buffer\n            self.audio_buffer = list(self.pre_speech_buffer)\n            print(f\"Speech started, beginning collection with {len(self.pre_speech_buffer)} pre-speech chunks\")\n            if \"on_speech_start\" in self.callbacks:\n                self.callbacks[\"on_speech_start\"]()\n        \n        # Count silent chunks when collecting but not speaking\n        if self.is_collecting and not self.vad.is_speaking:\n            self.silent_chunk_count += 1\n            print(f\"Silent chunk count: {self.silent_chunk_count}, max: 
{self.max_silent_chunks}\")\n            # Force end after too many silent chunks\n            if self.silent_chunk_count >= self.max_silent_chunks:\n                is_turn_end = True\n                print(f\"Forcing speech end after {self.silent_chunk_count} silent chunks\")\n        else:\n            self.silent_chunk_count = 0\n                \n        # End collection on turn end\n        if is_turn_end and self.is_collecting:\n            print(\"Turn end detected, processing collected audio\")\n            self.is_collecting = False\n            if self.audio_buffer:\n                print(f\"Audio buffer length: {len(self.audio_buffer)} chunks\")\n                print(\"Speech ended, processing collected audio\")\n                complete_audio = np.concatenate(self.audio_buffer)\n                print(f\"Complete audio length: {len(complete_audio)}\")\n                \n                if \"on_speech_end\" in self.callbacks:\n                    try:\n                        print(\"Calling on_speech_end callback\")\n                        self.callbacks[\"on_speech_end\"](complete_audio, self.sample_rate)\n                        print(\"on_speech_end callback completed successfully\")\n                    except Exception as e:\n                        print(f\"Error in on_speech_end callback: {e}\")\n                \n                # Clear buffer after processing\n                self.audio_buffer = []\n                self.silent_chunk_count = 0\n    \n    def reset(self):\n        self.vad.reset()\n        self.audio_buffer = []\n        self.is_collecting = False\n        self.silent_chunk_count = 0\n        print(\"AudioStreamProcessor reset\")"
  }
]