Repository: kyutai-labs/moshivis Branch: main Commit: 8624f4e01d4b Files: 170 Total size: 751.6 KB Directory structure: gitextract_d_6tdl44/ ├── .dockerignore ├── .gitattributes ├── .github/ │ ├── actions/ │ │ └── rust_build/ │ │ └── action.yml │ ├── requirements_github_actions.txt │ └── workflows/ │ ├── checks.yml │ └── rust-ci.yml ├── .gitignore ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE/ │ ├── bug.yml │ └── question.yml ├── LICENSE-APACHE ├── LICENSE-MIT ├── LICENSE.md ├── PULL_REQUEST_TEMPLATE.md ├── README.md ├── client/ │ ├── .eslinrc.json │ ├── .nvmrc │ ├── .prettierignore │ ├── .prettierrc.json │ ├── Dockerfile │ ├── LICENSE │ ├── README.md │ ├── index.html │ ├── package.json │ ├── postcss.config.js │ ├── public/ │ │ └── assets/ │ │ ├── decoderWorker.min.wasm │ │ └── images/ │ │ └── demo/ │ │ └── attribution.txt │ ├── src/ │ │ ├── app.tsx │ │ ├── audio-processor.ts │ │ ├── components/ │ │ │ ├── Button/ │ │ │ │ └── Button.tsx │ │ │ ├── ImageGallery/ │ │ │ │ └── ImageGallery.tsx │ │ │ └── Input/ │ │ │ └── Input.tsx │ │ ├── decoder/ │ │ │ └── decoderWorker.ts │ │ ├── env.ts │ │ ├── index.css │ │ ├── modules.d.ts │ │ ├── pages/ │ │ │ ├── Conversation/ │ │ │ │ ├── Conversation.tsx │ │ │ │ ├── MediaContext.ts │ │ │ │ ├── SocketContext.ts │ │ │ │ ├── components/ │ │ │ │ │ ├── AudioVisualizer/ │ │ │ │ │ │ ├── AudioVisualizer.tsx │ │ │ │ │ │ ├── ClientVisualizer.tsx │ │ │ │ │ │ └── ServerVisualizer.tsx │ │ │ │ │ ├── Controls/ │ │ │ │ │ │ └── Controls.tsx │ │ │ │ │ ├── ModelParams/ │ │ │ │ │ │ └── ModelParams.tsx │ │ │ │ │ ├── ServerAudio/ │ │ │ │ │ │ ├── ServerAudio.tsx │ │ │ │ │ │ └── ServerAudioStats.tsx │ │ │ │ │ ├── ServerInfo/ │ │ │ │ │ │ └── ServerInfo.tsx │ │ │ │ │ ├── TextDisplay/ │ │ │ │ │ │ ├── TextDisplay.tsx │ │ │ │ │ │ └── TextDisplayStats.tsx │ │ │ │ │ └── UserAudio/ │ │ │ │ │ ├── UserAudio.tsx │ │ │ │ │ └── UserAudioStats.tsx │ │ │ │ ├── getMimeType.ts │ │ │ │ └── hooks/ │ │ │ │ ├── audioUtils.ts │ │ │ │ ├── useModelParams.ts │ │ │ │ ├── useServerAudio.ts │ │ │ │ ├── useServerInfo.ts │ │ │ │ ├── useServerText.ts │ │ │ │ ├── useSocket.ts │ │ │ │ └── useUserAudio.ts │ │ │ └── Queue/ │ │ │ ├── Queue.tsx │ │ │ ├── api/ │ │ │ │ ├── client.ts │ │ │ │ ├── errors/ │ │ │ │ │ ├── api_error.ts │ │ │ │ │ └── response_error.ts │ │ │ │ └── validators.ts │ │ │ └── hooks/ │ │ │ └── useUserEmail.ts │ │ └── protocol/ │ │ ├── encoder.ts │ │ ├── testMessages.ts │ │ └── types.ts │ ├── tailwind.config.js │ ├── tsconfig.json │ └── vite.config.ts ├── docker-bake.hcl ├── kyuteye_mlx/ │ ├── .pylintrc │ ├── LICENSE │ ├── MANIFEST.in │ ├── README.md │ ├── kyuteye_mlx/ │ │ ├── __init__.py │ │ ├── benchmark.py │ │ ├── local_web.py │ │ ├── mlx_vlm/ │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── models/ │ │ │ ├── __init__.py │ │ │ ├── pixtral/ │ │ │ │ ├── __init__.py │ │ │ │ └── vision.py │ │ │ └── siglip/ │ │ │ └── vision.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── generate.py │ │ │ ├── lm.py │ │ │ ├── pixtral.py │ │ │ └── siglip.py │ │ ├── modules/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── cross_attention.py │ │ │ ├── kv_cache.py │ │ │ └── transformer.py │ │ ├── py.typed │ │ ├── quantize.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── loading.py │ │ ├── profiling.py │ │ └── sampling.py │ ├── pixtral-12b-8bit.config │ ├── pyproject.toml │ ├── siglip448.config │ └── tests/ │ └── test_siglip.py ├── kyuteye_pt/ │ ├── .pylintrc │ ├── LICENSE.md │ ├── README.md │ ├── configs/ │ │ └── moshika-vis.yaml │ ├── kyuteye/ │ │ ├── __init__.py │ │ ├── config/ │ │ │ ├── __init__.py │ │ │ ├── enums.py │ │ │ 
├── kyuteye_config.py │ │ │ └── subconfigs.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── docker-bake.hcl │ │ │ ├── helium.py │ │ │ ├── hf_model_configs.py │ │ │ ├── image_projection.py │ │ │ ├── loaders.py │ │ │ └── moshivis.py │ │ ├── modules/ │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── cross_attention.py │ │ │ ├── image_encoder.py │ │ │ ├── image_transforms.py │ │ │ ├── streaming_utils.py │ │ │ ├── transformer.py │ │ │ └── utils.py │ │ ├── server.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── dist_utils.py │ │ ├── logging_utils.py │ │ └── struct_utils.py │ ├── pyproject.toml │ └── tests/ │ └── hello.py ├── kyuteye_rs/ │ ├── Cargo.toml │ ├── configs/ │ │ ├── config-moshika-vis-q8.json │ │ └── config-moshika-vis.json │ ├── moshi-backend/ │ │ ├── Cargo.toml │ │ ├── build.rs │ │ └── src/ │ │ ├── audio.rs │ │ ├── build.rs │ │ ├── image_embedder.rs │ │ ├── main.rs │ │ ├── metrics.rs │ │ ├── standalone.rs │ │ ├── stream_both.rs │ │ └── utils.rs │ └── moshi-core/ │ ├── Cargo.toml │ └── src/ │ ├── conv.rs │ ├── dynamic_logits_processor.rs │ ├── lib.rs │ ├── lm.rs │ ├── lm_generate.rs │ ├── lm_generate_multistream.rs │ ├── mimi.rs │ ├── nn.rs │ ├── quantization.rs │ ├── seanet.rs │ ├── streaming.rs │ └── transformer.rs ├── scripts/ │ ├── convert_ckpt_utils.py │ └── get_static_client.py └── ssvd/ ├── README.md ├── __init__.py ├── generate.py ├── multiturn_instruct.py ├── multiturn_prompting.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ **/target/ **/node_modules/ **/dist ssvd/synthetic_visual_dialogues/ ================================================ FILE: .gitattributes ================================================ *.wav filter=lfs diff=lfs merge=lfs -text ================================================ FILE: .github/actions/rust_build/action.yml ================================================ name: rust_build description: 'Setup rust env' inputs: os: default: ubuntu-latest toolchain: default: stable target: default: check runs: using: "composite" steps: - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ inputs.toolchain }} override: true - name: cargo cache uses: actions/cache@v3 with: path: | ~/.cargo/bin/ ~/.cargo/registry/index/ ~/.cargo/registry/cache/ ~/.cargo/git/db/ kyuteye_rs/target/ key: ${{ inputs.os }}-cargo-${{ inputs.target }}-${{ hashFiles('**/Cargo.toml') }} restore-keys: ${{ inputs.os }}-cargo- - name: install deps shell: bash run: | sudo apt-get update sudo apt-get install libasound2-dev ================================================ FILE: .github/requirements_github_actions.txt ================================================ # Main setup # old version: transformers 4.43.3 and accelerate 0.33.0 # new version (for pixtrla): transformers 4.46.0 and accelerate 1.0.1 accelerate==1.0.1 anls anls-star av<12 auditok<0.3.0 cython datasets deepspeed demucs einops encodec fasttext flashy>=0.0.1 gradio huggingface_hub hydra_colorlog hydra-core>=1.1 ipywidgets jiwer julius jupyterlab librosa maturin num2words numpy onnxruntime opencv-python protobuf pyannote.audio pyannote.metrics pycocoevalcap pycocotools sentencepiece spacy==3.5.2 tensorboard timm torch==2.2.0 torchaudio==2.2.0 torchmetrics torchtyping torchvision==0.17.0 tqdm transformers==4.47.0 # need Encodec there. 
webdataset==0.2.100 # for sanity evaluate rouge-score xformers==0.0.24 # specific clip commit clip @ https://github.com/openai/CLIP/archive/master.zip#sha256=11c3593912e6e6446fb0bde144c5ea374f7e19eeab9072c3eb00b59dd8afb706 # launcheon + code prettifying stuff fire rich pyyaml black mypy==1.11.2 pylint matplotlib seaborn ================================================ FILE: .github/workflows/checks.yml ================================================ name: Checks on: push: branches: - main pull_request: types: [opened, synchronize, reopened, ready_for_review] workflow_dispatch: jobs: pylint_pytorch: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Static lint analysis with pylint run: | cd kyuteye_pt && uv run --locked pylint --rcfile=.pylintrc --fail-under=8.5 ./kyuteye ruff_mlx: runs-on: macos-14 steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Static lint analysis with pylint run: | cd kyuteye_mlx && uv run ruff format --diff && uv run ruff check --select I sanity_check_pytorch: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Sanity check run: | cd kyuteye_pt && uv run --locked sanity-check sanity_check_mlx: runs-on: macos-14 steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Sanity check run: | cd kyuteye_mlx && uv run --locked sanity-check sanity_check_rust: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Cache Cargo uses: actions/cache@v3 with: path: | ~/.cargo/registry kyuteye_rs/target key: cargo-${{ hashFiles('**/Cargo.lock') }} restore-keys: cargo- - run: cd kyuteye_rs && cargo fmt --all -- --check - name: Ubuntu dependencies run: | sudo apt-get update sudo apt-get install -y -qq libasound2-dev libssl-dev libpulse-dev libdbus-1-dev portaudio19-dev protobuf-compiler - name: Clippy run: cd kyuteye_rs && cargo --locked clippy --workspace --tests --examples --locked -- -D warnings build_client: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - run: docker buildx bake client - run: tail client/dist/index.html ================================================ FILE: .github/workflows/rust-ci.yml ================================================ on: push: branches: [ main ] pull_request: branches: [ main, refacto ] name: Rust CI jobs: check: name: Check defaults: run: working-directory: ./kyuteye_rs runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest] rust: [stable] steps: - uses: actions/checkout@v2 - uses: ./.github/actions/rust_build - name: check shell: bash run: | cargo check - name: clippy shell: bash run: | cargo clippy -- -D warnings - name: fmt shell: bash run: | cargo fmt --all -- --check test: name: Test defaults: run: working-directory: ./kyuteye_rs runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest] rust: [stable] steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v5 with: python-version: 3.11 - uses: ./.github/actions/rust_build with: target: test - name: test shell: bash run: | cargo test ================================================ FILE: .gitignore ================================================ ~* __pycache__ *.pt *.pth *.ipynb* *.egg-info *.jsonl nohup.out .idea/* client/node_modules client/dist target/ *.safetensors .DS_Store *.lprof *.prof cert.pem key.pem .mypy_cache Gemfile.lock project_page/_site/* kyuteye_mlx/static/* ssvd/synthetic_visual_dialogues/ ================================================ FILE: CONTRIBUTING.md 
================================================ # Contributing to MoshiVis ## Pull Requests MoshiVis is the implementation of a research paper. Therefore, we do not plan on accepting many pull requests for new features. However, we certainly welcome them for bug fixes. 1. Fork the repo and create your branch from `main`. 2. If you have changed APIs, update the documentation accordingly. 3. Ensure pre-commit hooks pass properly, in particular the linting and typing. 4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`. 5. Accept the Contributor License Agreement (see after). Note that in general, we will not accept refactoring of the code. ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a Contributor License Agreement. If you agree with the full CLA provided in the next paragraph, copy the following statement in your PR, changing your Github Handle: > I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms. The full CLA is provided as follows: > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free, > irrevocable license to use, modify, distribute, and sublicense my Contributions. > I understand and accept that Contributions are limited to modifications, improvements, or changes > to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to > review, accept, reject, or request changes to any Contributions I submit, and that submitting > a pull request does not guarantee its inclusion in the project. > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify, > reproduce, distribute, and create derivative works based on my Contributions. > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions, > giving the Kyutai-labs full rights to file for and enforce patents. > I understand that the Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me. > I confirm that my Contributions are original and that I have the legal right to grant this license. > If my Contributions include third-party materials, I will ensure that I have the necessary permissions > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at the Kyutai-labs’s discretion. > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation. > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties. > By submitting a pull request, I agree to be bound by these terms. ## Issues Please submit issues on our Github repository. ## License By contributing to MoshiVis, you agree that your contributions will be licensed under the LICENSE-* files in the root directory of this source tree. In particular, the rust code is licensed under APACHE, and the python code under MIT. ================================================ FILE: ISSUE_TEMPLATE/bug.yml ================================================ name: Bug Report description: You found a bug. labels: ["bug", "triage"] body: - type: dropdown id: backend attributes: label: Backend impacted description: Which backend is concerned with your bug report? 
options: - The PyTorch implementation - The MLX implementation - The Rust implementation - Other / All default: 0 validations: required: true - type: dropdown id: os attributes: label: Operating system description: What is your operating system? options: - Linux - Mac OS X - Windows (unsupported) default: 0 validations: required: true - type: dropdown id: hardware attributes: label: Hardware description: What hardware are you using? options: - CPU - GPU with CUDA - Metal with MLX default: 0 validations: required: true - type: textarea id: description attributes: label: Description description: Provide a detailed description of your bug. placeholder: value: validations: required: true - type: textarea id: more_info attributes: label: Extra information description: Please provide any other relevant information, such as log extracts, code etc. placeholder: value: validations: required: true - type: textarea id: env attributes: label: Environment description: Please provide any other relevant information, such as log extracts, code etc. placeholder: value: | Fill in the following information on your system. - Operating system version: If the backend impacted is PyTorch: - Python version: - PyTorch version: - CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`): - GPU model and memory: If the backend is MLX: - Mac model: validations: required: true ================================================ FILE: ISSUE_TEMPLATE/question.yml ================================================ name: Question description: You have a question about Moshi/Mimi, this codebase. labels: ["question", "triage"] body: - type: markdown attributes: value: | Please first check the [FAQ](https://github.com/kyutai-labs/moshi/blob/main/FAQ.md). - type: checkboxes id: terms attributes: label: Due diligence description: Have you searched the existing issues / FAQ / Google / asked ChatGPT? options: - label: I have done my due diligence in trying to find the answer myself. required: true - type: dropdown id: backend attributes: label: Topic description: What is your question about? options: - The paper - The PyTorch implementation - The MLX implementation - The Rust implementation - Other / All default: 0 validations: required: true - type: textarea id: question attributes: label: Question description: What is your question? placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc. value: validations: required: true ================================================ FILE: LICENSE-APACHE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: LICENSE-MIT ================================================ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: LICENSE.md ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: PULL_REQUEST_TEMPLATE.md ================================================ ## Checklist - [ ] Read CONTRIBUTING.md, and accept the CLA by including the provided snippet. We will not accept PRs without this. - [ ] Run the pre-commit hook. - [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, `cargo test`. ## PR Description ================================================ FILE: README.md ================================================ # M👁️shiVis: Teaching Speech Models to Converse about Images ![CI checks](https://github.com/kyutai-labs/moshivis/actions/workflows/checks.yml/badge.svg) [[Preprint]][moshi-vision-arxiv] [[Demo]][talk-to-moshivis] [[Models on Hugging Face]](https://huggingface.co/collections/kyutai/) MoshiVis is a **Vision Speech Model** (VSM) that builds directly on the speech-text foundation model [Moshi][moshi-arxiv] and augments it with the ability to freely discuss an image while maintaining its natural conversation style and low latency. In total, MoshiVis adds $\sim$ 206M adapter parameters on top of the 7B Moshi and a pretrained frozen 400M PaliGemma2 vision encoder. This repository currently contains inference code to run your own MoshiVis server supporting three different backends via a web UI frontend. We are also planning to release training/finetuning code in the future. For more information about our speech codec Mimi and speech model Moshi, please visit the original [Moshi repo][moshi-github]. For more technical details on MoshiVis, see our [blog post][blog] and [preprint][moshi-vision-arxiv]. [Talk to MoshiVis][talk-to-moshivis] now on our live demo!

*Figure: Schema representing the structure of MoshiVis.*
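The gating-based cross-attention described in the next paragraph can be pictured with the minimal PyTorch sketch below. This is only an illustration under assumed names and shapes (`SharedGatedCrossAttention`, `dim=512`, a sigmoid gate), not the actual MoshiVis implementation; see `kyuteye_pt/kyuteye/modules/cross_attention.py` for the real code.

```python
# Illustrative sketch only: cross-attention from speech tokens (queries) to
# image tokens (keys/values), with a gate on the visual contribution and a
# single module instance meant to be reused (i.e. shared) across layers.
import torch
import torch.nn as nn


class SharedGatedCrossAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # Per-channel gate that lets the model modulate (or shut off) the visual stream.
        self.gate = nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid())

    def forward(self, speech_tokens: torch.Tensor, image_tokens: torch.Tensor) -> torch.Tensor:
        # speech_tokens: (B, T, dim), image_tokens: (B, N, dim)
        visual, _ = self.attn(speech_tokens, image_tokens, image_tokens)
        return speech_tokens + self.gate(speech_tokens) * visual


# Calling the same instance in every transformer layer is what makes the
# cross-attention projection weights "shared across layers".
xattn = SharedGatedCrossAttention(dim=512)
speech = torch.randn(1, 25, 512)   # dummy speech-token activations
image = torch.randn(1, 196, 512)   # dummy image-token embeddings
out = xattn(speech, image)         # -> (1, 25, 512)
```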

To inject visual inputs in the stream of *speech tokens* from Moshi, we extend the core transformer with a **cross-attention mechanism** that infuses visual information into the speech token stream. To maintain Moshi's **low latency** and reduce memory usage, the cross-attention projection weights are shared **across layers.** Moreover, to ensure that Moshi's original conversational abilities are not lost in the process, the cross-attention modules feature a gating mechanism that allows the model to modulate the visual input stream at will (illustrated roughly in the sketch above). For more details on MoshiVis, including our training pipeline, synthetic data generation pipeline, and ablation experiments on the gating mechanism, see our [preprint][moshi-vision-arxiv]. ## Model Release We release MoshikaVis, based on the original Moshika (*female voice*) checkpoints from Moshi's open-source release. For the image embedding part, we rely on publicly available off-the-shelf image-text encoders: the checkpoints we release use the frozen weights of a vision encoder from the [PaliGemma2](https://arxiv.org/abs/2412.03555) family, specifically the weights provided on [Hugging Face](https://huggingface.co/google/paligemma2-3b-pt-448). Note that for convenience, each MoshiVis checkpoint contains the full model: i.e., the vision adaptation module weights are bundled together with the weights of Mimi (speech codec), the Helium text tokenizer, image encoder, and base Moshi model. For each model, we release several variants compatible with three different backends and quantization formats. Further instructions for each backend can be found below. | Backend | Moshi**ka** | | ------- | ----------- | | [PyTorch](#pytorch-backend) | [BF16](https://huggingface.co/kyutai/moshika-vis-pytorch-bf16) | | [Rust](#rust-backend) | [BF16](https://huggingface.co/kyutai/moshika-vis-candle-bf16) [Q8_0](https://huggingface.co/kyutai/moshika-vis-candle-q8) | | [MLX](#mlx-backend) | [BF16](https://huggingface.co/kyutai/moshika-vis-mlx-bf16) | All model weights (*excluding the bundled vision encoder*) are released under the CC-BY 4.0 license; the bundled vision encoder (*PaliGemma2's vision encoder*) is released under the [Gemma license](https://ai.google.dev/gemma/terms). ## Organisation of the Repository For the **frontend**, we recommend using the provided web UI, as it allows for additional echo cancellation that helps the overall model quality. To obtain the client, you can either **(i)** build it yourself from the sources in [`client`](client/) as [described here](#building-the-frontend) or **(ii)** download the pre-built static version we provide: ```bash # Download prebuilt client sources # option 1: using uv dependency manager uv run scripts/get_static_client.py # OR option 2: with pip pip install fire rich huggingface_hub python scripts/get_static_client.py ``` Most commands below will serve this UI by default using the `https` protocol (see more info [here](#http-vs-https)). To connect via `https`, you will need to generate SSL certificates first, as follows: ```bash # Generate the SSL certificates in the root directory openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout key.pem -out cert.pem ``` We provide three different **backends** for the MoshiVis inference stack in this repo. While we hope that the present codebase will work on Windows, we do not provide official support for it. - A [PyTorch](#pytorch-backend) version in the [`kyuteye_pt`](kyuteye_pt) directory.
- A [Rust](#rust-backend) version (as used in the online demo) is in the [`kyuteye_rs`](kyuteye_rs/) directory. - An [MLX](#mlx-backend) version (tested on a MacBook Pro M3) is in the [`kyuteye_mlx`](kyuteye_mlx/) directory. For the PyTorch and MLX backends, we recommend using [uv](https://docs.astral.sh/uv/) to set up and run the code, as it will manage all dependencies for you transparently. `uv` is provided as a lightweight binary and can be installed with: ```bash curl -LsSf https://astral.sh/uv/install.sh | sh ``` ### PyTorch Backend > Note: At the moment, we do not support quantization > for the PyTorch version, so you will need a GPU with a significant amount of memory ($\sim$ 24GB). You can start the MoshiVis PyTorch server with the following command and then access the web UI at [https://localhost:8088](https://localhost:8088): ```bash cd kyuteye_pt uv run server configs/moshika-vis.yaml --port 8088 ``` Note that if your GPU is on a remote machine, you may need to forward the remote port 8088 to your localhost using ssh's `-L` flag (e.g., `ssh -L 8088:localhost:8088 user@remote-host`), then connect to [https://localhost:8088](https://localhost:8088) as mentioned previously. ### Rust Backend > For the Rust backend, you will need a recent version of the [Rust toolchain](https://rustup.rs/). > To compile with GPU support, you will need a valid [CUDA](https://developer.nvidia.com/cuda-toolkit) installation, in particular with `nvcc`. In order to run the Rust inference server, use the following command: ```bash cd kyuteye_rs pip install pkg-config cargo run --features cuda --bin moshi-backend -r -- --config configs/config-moshika-vis.json standalone --vis ``` When using macOS, you can replace `--features cuda` with `--features metal`. Alternatively, you can use `config-moshika-vis-q8.json` rather than `config-moshika-vis.json` to use the quantized q8 model. You can also change some of the server options (e.g., starting port) in the JSON file directly. Once the server has printed 'standalone worker listening', the model is ready. By default, the Rust server will be accessible at [https://localhost:8088](https://localhost:8088). ### MLX Backend We provide an MLX model checkpoint in `bfloat16` as well as quantized checkpoints using `q4` and `q8`. To start the MoshiVis MLX backend, you can run one of the following commands: ```bash cd kyuteye_mlx # In bfloat16 - weights will be downloaded from HF uv run server # In q4 uv run server -q 4 # In q8 uv run server -q 8 ``` You can then access the web UI at [http://localhost:8008](http://localhost:8008). Note that unlike the other backends, not all settings available in the web UI are propagated to the MLX backend. Instead, you can configure some options directly via the command line, e.g. `--text-temperature`. ### Frontends #### WebUI We recommend using the WebUI frontend as explained [here](#organisation-of-the-repository). If you want to build the sources yourself, follow these steps (further installation and build instructions can be found in the `client` directory): **via NPM.** ```bash cd client npm install npm run build ``` **via Docker.** If you have `docker` installed, you can also build the client via ```bash docker buildx bake client ``` After building the sources, the static dir for the web UI can be found in the `client/dist` directory, and will be used by default by the different backends.
#### Rust Command Line Alternatively, we also provide a command line interface for the Rust backend: ```bash cd kyuteye_rs; cargo run --bin moshi-cli -r -- tui --host localhost ``` ## Troubleshooting ### http vs https By default, the web UI server starts with the `https` protocol rather than `http`: accessing a server that is not localhost via `http` may cause issues with using the microphone in the web UI (some browsers only allow microphone access over `https`). To use an `https` connection, you will first need to set up SSL certificates: ```bash # Generate the SSL certificates in the root directory # requires the openssl command line tool openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout key.pem -out cert.pem ``` Note that if you want to use an `http` connection instead, you can: * For the PyTorch backend, add the flag `--ssl False` * For the MLX backend, `http` is the default and `https` can be used with `--ssl certdir` where `certdir` is the directory that contains the certificates. Note that when using `https` you may get warnings from the browser about the site being unsafe. When using Chrome, for instance, you can bypass these by selecting "Details" or "Advanced", then "Visit this unsafe site" or "Proceed to localhost (unsafe)". ## License The present code is provided under the MIT license for the Python parts, and the Apache license for the Rust backend. The web client code is provided under the MIT license. The model weights (*excluding the vision encoder*) are released under the CC-BY 4.0 license; the vision encoder is licensed under Apache 2.0. All images displayed in the web UI are obtained under the free Unsplash license. For the precise list of image URLs and authors, please refer to [this file](client/public/assets/images/demo/attribution.txt). ## Datasets We also release two data-related artifacts to accompany MoshiVis: * In the `ssvd` directory, we include code and instructions to reproduce our synthetic visual dialogue datasets described in Section 3.3 and Appendix E of our preprint. * For evaluation purposes, we also release [`Babillage`](https://huggingface.co/datasets/kyutai/Babillage) on HuggingFace, which contains spoken versions of three common VLM benchmarks (COCO-Captions 2014, OCR-VQA and VQAv2) for evaluating the model's visual understanding in audio form (see the loading sketch below).
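As a companion to the Babillage release above, the snippet below shows one way such an evaluation set could be loaded with the Hugging Face `datasets` library. The configuration and split names (`"vqav2"`, `"test"`) are assumptions for illustration only; check the dataset card for the actual layout.

```python
# Hedged sketch: loading Babillage with the Hugging Face `datasets` library.
# The configuration/split names below are placeholders and may not match
# the dataset card at https://huggingface.co/datasets/kyutai/Babillage.
from datasets import load_dataset

babillage = load_dataset("kyutai/Babillage", "vqav2", split="test")
print(babillage[0].keys())  # inspect the available fields (audio prompts, references, ...)
```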
## Citation If you use MoshiVis in your research, please cite our work: ``` @article{kyutai2025moshivis, author = {Amélie Royer and Moritz Böhle and Gabriel de Marmiesse and Laurent Mazaré and Alexandre Défossez and Neil Zeghidour and Patrick Pérez}, year = {2025}, title = {Vision-Speech Models: Teaching Speech Models to Converse about Images}, journal = {ArXiv}, url = {https://arxiv.org/abs/2503.15633} } @techreport{kyutai2024moshi, title={Moshi: a speech-text foundation model for real-time dialogue}, author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour}, year={2024}, eprint={2410.00037}, archivePrefix={arXiv}, primaryClass={eess.AS}, url={https://arxiv.org/abs/2410.00037}, } ``` [blog]: https://kyutai.org/moshivis [moshi-vision-arxiv]: https://arxiv.org/abs/2503.15633 [moshi-arxiv]: https://arxiv.org/abs/2410.00037 [moshi-github]: https://github.com/kyutai-labs/moshi/tree/main?tab=readme-ov-file#models [talk-to-moshivis]: https://vis.moshi.chat ================================================ FILE: client/.eslinrc.json ================================================ { "env": { "browser": true, "es2021": true }, "extends": [ "plugin:react/recommended", "standard-with-typescript", "plugin:import/typescript", "plugin:prettier/recommended" ], "parser": "@typescript-eslint/parser", "overrides": [], "parserOptions": { "ecmaVersion": "latest", "sourceType": "module", "project": "./tsconfig.json" }, "plugins": ["react", "prettier"], "rules": { "@typescript-eslint/triple-slash-reference": "off" } } ================================================ FILE: client/.nvmrc ================================================ v20.12.2 ================================================ FILE: client/.prettierignore ================================================ dist/* ================================================ FILE: client/.prettierrc.json ================================================ { "arrowParens": "avoid", "singleQuote": false, "trailingComma": "all", "tabWidth": 2, "useTabs": false, "semi": true, "printWidth": 80, "plugins": ["prettier-plugin-tailwindcss"] } ================================================ FILE: client/Dockerfile ================================================ FROM node:20 AS builder WORKDIR /app COPY . /app RUN npm install RUN npx vite build FROM scratch AS build_result COPY --from=builder /app/dist / ================================================ FILE: client/LICENSE ================================================ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: client/README.md ================================================ # moshi-client Frontend for the demo. ## Quickstart To start developing, you will need a basic environment with Node.js, for instance: ```bash cd client micromamba create -n node22 python=3.10 micromamba activate node22 micromamba install nodejs=22.11 # install npm install ``` Alternatively, you can use [NVM](https://github.com/nvm-sh/nvm) to help you manage your Node.js version and make sure you're on the recommended version for this project. If you do so, run `nvm use`. To run the client in dev mode, use: ```bash # typically will start on port 5173 npm run dev ``` When you're satisfied, build the client (in the `dist` directory) that will be used as the static dir by the different backends: ```bash npm run build ``` If Docker is available, you can skip all the previous steps and just run ``` docker buildx bake ``` from the root of this repository. It will output the static sources for the website in `client/dist`. ### License The present code is provided under the MIT license. ================================================ FILE: client/index.html ================================================ moshi.chat
================================================ FILE: client/package.json ================================================ { "name": "kyutai-client", "private": true, "version": "0.0.0", "type": "module", "scripts": { "dev": "vite", "build": "tsc && vite build", "lint": "eslint", "lint:fix": "eslint --fix", "prettier": "prettier --write .", "preview": "vite preview" }, "devDependencies": { "@eslint/js": "^9.3.0", "@types/react": "^18.3.1", "@types/react-dom": "^18.3.0", "@types/ws": "^8.5.10", "autoprefixer": "^10.4.19", "daisyui": "^4.12.2", "eslint": "^8.57.0", "eslint-config-prettier": "^9.1.0", "eslint-plugin-prettier": "^5.1.3", "eslint-plugin-react": "^7.34.1", "globals": "^15.2.0", "postcss": "^8.4.38", "prettier": "^3.2.5", "prettier-eslint": "^16.3.0", "prettier-plugin-tailwindcss": "^0.5.14", "tailwindcss": "^3.4.3", "typescript": "^5.2.2", "typescript-eslint": "^7.9.0", "vite": "^6.2.1", "vite-plugin-top-level-await": "^1.4.1" }, "dependencies": { "eruda": "^3.0.1", "opus-recorder": "^8.0.5", "react": "^18.3.1", "react-dom": "^18.3.1", "react-router-dom": "^6.23.1", "webm-duration-fix": "^1.0.4", "ws": "^8.16.0", "zod": "^3.23.8" } } ================================================ FILE: client/postcss.config.js ================================================ export default { plugins: { tailwindcss: {}, autoprefixer: {}, }, }; ================================================ FILE: client/public/assets/images/demo/attribution.txt ================================================ image1.jpg https://unsplash.com/photos/seven-brushes-and-water-color-palette-TTwwVG4Isjw Crystal de Passillé-Chabot image2.jpg https://unsplash.com/photos/a-bunch-of-stuffed-animals-that-are-on-a-shelf-uTDvpjnF2nw Hoyoun Lee image3.jpg https://unsplash.com/photos/an-orange-and-white-clownfish-in-an-aquarium-8iHpPG7Vk9Y James Lee image4.jpg https://unsplash.com/photos/red-dragon-action-figure-on-table-X-A-LJVAhzk Clint Bustrillos image5.jpg https://unsplash.com/photos/panda-bear-sitting-on-bamboo-sticks-surrounded-with-trees-NsNRu6dfRds Ying Wu image6.jpg https://unsplash.com/photos/carousel-with-string-lights-OEBeLcrzlaw cmophoto.net image7.jpg https://unsplash.com/photos/flight-of-pigeons-flying-above-grass-field-near-eiffel-tower-in-paris-m45uW4f9YQg Stijn te Strake image8.jpg https://unsplash.com/photos/gray-typewriter-and-macbook-1F4MukO0UNg Glenn Carstens-Peters image9.jpg https://unsplash.com/photos/a-couple-of-sandwiches-sitting-on-top-of-a-cutting-board-smv9xho-dnE Deepthi Clicks image10.jpg https://unsplash.com/photos/man-in-white-hat-and-black-shirt-painting-nkVa4ylaWG0 Federico Scarionati image11.jpg https://unsplash.com/photos/a-statue-of-a-gnome-next-to-a-light-pole-UNT9ExjTgZE Lionel Mermoz image12.jpg https://unsplash.com/photos/brown-concrete-building-on-green-grass-field-during-daytime-mDIMJzdu5D0 Celine Chamiot-Poncet image13.jpg https://unsplash.com/photos/astronaut-in-spacesuit-floating-in-space-Yj1M5riCKk4 NASA image14.jpg https://unsplash.com/photos/marble-toy-lot-near-yellow-drawstring-pouch-1kZzV02D2hM Crissy Jarvis image15.jpg https://unsplash.com/photos/person-holding-black-frying-pan-APDMfLHZiRA Kevin McCutcheon image16.jpg https://unsplash.com/photos/selective-focus-photo-of-four-green-humming-birds-with-red-flowers-5TU1htuOUn4 James wainscoat image17.jpg https://unsplash.com/photos/orange-and-white-tabby-cat-sitting-on-brown-wooden-table-in-kitchen-room-w2DsS-ZAP4U Paul Hanoka image18.jpg https://unsplash.com/photos/lantern-on-the-street-at-nighttime--F3wMFrZ7z0 Denys 
Nevozhai image19.jpg https://unsplash.com/photos/baked-breads-in-rack-ZnPNZpjzi0M Dan Gold image20.jpg https://unsplash.com/photos/lemonades-on-tray-JB5YCqOXV1o Rod Long ================================================ FILE: client/src/app.tsx ================================================ import ReactDOM from "react-dom/client"; import { createBrowserRouter, RouterProvider, } from "react-router-dom"; import "./index.css"; // @ts-expect-error - Worker is not recognized by the TS compiler import { DecoderWorker } from "./decoder/decoderWorker"; import { Queue } from "./pages/Queue/Queue"; const router = createBrowserRouter([ { path: "/", element: , }, ]); ReactDOM.createRoot(document.getElementById("root") as HTMLElement).render( ); ================================================ FILE: client/src/audio-processor.ts ================================================ // @ts-nocheck function asMs(samples) { return (samples * 1000 / sampleRate).toFixed(1); } function asSamples(mili) { return Math.round(mili * sampleRate / 1000); } class MoshiProcessor extends AudioWorkletProcessor { constructor() { super(); console.log("Moshi processor lives", currentFrame, sampleRate); console.log(currentTime); // Buffer length definitions let frameSize = asSamples(80); // initialBufferSamples: we wait to have at least that many samples before starting to play this.initialBufferSamples = 2 * frameSize; // once we have enough samples, we further wait that long before starting to play. // This allows to have buffer lengths that are not a multiple of frameSize. this.partialBufferSamples = asSamples(80); // If the buffer length goes over that many, we will drop the oldest packets until // we reach back initialBufferSamples + partialBufferSamples. this.maxBufferSamples = asSamples(80); // increments this.partialBufferIncrement = asSamples(40); this.maxPartialWithIncrements = asSamples(240); this.maxBufferSamplesIncrement = asSamples(40); this.maxMaxBufferWithIncrements = asSamples(240); // State and metrics this.initState(); this.port.onmessage = (event) => { if (event.data.type == "reset") { console.log("Reset audio processor state."); this.initState(); return; } let frame = event.data.frame; this.frames.push(frame); if (this.currentSamples() >= this.initialBufferSamples && !this.started) { this.start(); } if (this.pidx < 20) { console.log(this.timestamp(), "Got packet", this.pidx++, asMs(this.currentSamples()), asMs(frame.length)) } if (this.currentSamples() >= this.totalMaxBufferSamples()) { console.log(this.timestamp(), "Dropping packets", asMs(this.currentSamples()), asMs(this.totalMaxBufferSamples())); let target = this.initialBufferSamples + this.partialBufferSamples while (this.currentSamples() > (this.initialBufferSamples + this.partialBufferSamples)) { let first = this.frames[0]; let to_remove = this.currentSamples() - target; to_remove = Math.min(first.length - this.offsetInFirstBuffer, to_remove); this.offsetInFirstBuffer += to_remove; this.timeInStream += to_remove / sampleRate; if (this.offsetInFirstBuffer == first.length) { this.frames.shift(); this.offsetInFirstBuffer = 0; } } console.log(this.timestamp(), "Packet dropped", asMs(this.currentSamples())); this.maxBufferSamples += this.maxBufferSamplesIncrement; this.maxBufferSamples = Math.min(this.maxMaxBufferWithIncrements, this.maxBufferSamples); console.log("Increased maxBuffer to", asMs(this.maxBufferSamples)); } let delay = this.currentSamples() / sampleRate; this.port.postMessage({ totalAudioPlayed: this.totalAudioPlayed, actualAudioPlayed: 
this.actualAudioPlayed, delay: event.data.micDuration - this.timeInStream, minDelay: this.minDelay, maxDelay: this.maxDelay, }); }; } initState() { this.frames = new Array(); this.offsetInFirstBuffer = 0; this.firstOut = false; this.remainingPartialBufferSamples = 0; this.timeInStream = 0.; this.resetStart(); // Metrics this.totalAudioPlayed = 0.; this.actualAudioPlayed = 0.; this.maxDelay = 0.; this.minDelay = 2000.; // Debug this.pidx = 0; // For now let's reset the buffer params. this.partialBufferSamples = asSamples(80); this.maxBufferSamples = asSamples(80); } totalMaxBufferSamples() { return this.maxBufferSamples + this.partialBufferSamples + this.initialBufferSamples; } timestamp() { return Date.now() % 1000; } currentSamples() { let samples = 0; for (let k = 0; k < this.frames.length; k++) { samples += this.frames[k].length } samples -= this.offsetInFirstBuffer; return samples; } resetStart() { this.started = false; } start() { this.started = true; this.remainingPartialBufferSamples = this.partialBufferSamples; this.firstOut = true; } canPlay() { return this.started && this.frames.length > 0 && this.remainingPartialBufferSamples <= 0; } process(inputs, outputs, parameters) { let delay = this.currentSamples() / sampleRate; if (this.canPlay()) { this.maxDelay = Math.max(this.maxDelay, delay); this.minDelay = Math.min(this.minDelay, delay); } const output = outputs[0][0]; if (!this.canPlay()) { if (this.actualAudioPlayed > 0) { this.totalAudioPlayed += output.length / sampleRate; } this.remainingPartialBufferSamples -= output.length; return true; } if (this.firstOut) { console.log(this.timestamp(), "Audio resumed", asMs(this.currentSamples()), this.remainingPartialBufferSamples); } let first = this.frames[0]; let out_idx = 0; while (out_idx < output.length && this.frames.length) { let first = this.frames[0]; let to_copy = Math.min(first.length - this.offsetInFirstBuffer, output.length - out_idx); output.set(first.subarray(this.offsetInFirstBuffer, this.offsetInFirstBuffer + to_copy), out_idx); this.offsetInFirstBuffer += to_copy; out_idx += to_copy; if (this.offsetInFirstBuffer == first.length) { this.offsetInFirstBuffer = 0; this.frames.shift(); } } if (this.firstOut) { this.firstOut = false; for (let i = 0; i < out_idx; i++) { output[i] *= i / out_idx; } } if (out_idx < output.length) { console.log(this.timestamp(), "Missed some audio", output.length - out_idx); this.partialBufferSamples += this.partialBufferIncrement; this.partialBufferSamples = Math.min(this.partialBufferSamples, this.maxPartialWithIncrements); console.log("Increased partial buffer to", asMs(this.partialBufferSamples)); // We ran out of a buffer, let's revert to the started state to replenish it. 
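// Note: the loop below applies a linear fade-out to the samples we did manage to copy
// (mirroring the fade-in applied on the first output after a resume), so the underrun
// does not end with an audible click.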
this.resetStart(); for (let i = 0; i < out_idx; i++) { output[i] *= (out_idx - i) / out_idx; } } this.totalAudioPlayed += output.length / sampleRate; this.actualAudioPlayed += out_idx / sampleRate; this.timeInStream += out_idx / sampleRate; return true; } } registerProcessor("moshi-processor", MoshiProcessor); ================================================ FILE: client/src/components/Button/Button.tsx ================================================ import { FC } from "react"; type ButtonProps = React.ButtonHTMLAttributes; export const Button: FC = ({ children, className, ...props }) => { return ( ); }; ================================================ FILE: client/src/components/ImageGallery/ImageGallery.tsx ================================================ import { useState, ChangeEvent } from "react"; import { Button } from "../Button/Button"; // Natural images import img1 from "/assets/images/demo/image1.jpg"; import img2 from "/assets/images/demo/image2.jpg"; import img3 from "/assets/images/demo/image3.jpg"; import img4 from "/assets/images/demo/image4.jpg"; import img5 from "/assets/images/demo/image5.jpg"; import img6 from "/assets/images/demo/image6.jpg"; import img7 from "/assets/images/demo/image7.jpg"; import img8 from "/assets/images/demo/image8.jpg"; import img9 from "/assets/images/demo/image9.jpg"; import img10 from "/assets/images/demo/image10.jpg"; import img11 from "/assets/images/demo/image11.jpg"; import img12 from "/assets/images/demo/image12.jpg"; import img13 from "/assets/images/demo/image13.jpg"; import img14 from "/assets/images/demo/image14.jpg"; import img15 from "/assets/images/demo/image15.jpg"; import img16 from "/assets/images/demo/image16.jpg"; import img17 from "/assets/images/demo/image17.jpg"; import img18 from "/assets/images/demo/image18.jpg"; import img19 from "/assets/images/demo/image19.jpg"; import img20 from "/assets/images/demo/image20.jpg"; const images = [ img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16, img17, img18, img19, img20, ] var images_order: number[] = []; for (let i = 0; i < images.length; i++) { images_order.push(i) } type ImageGalleryProps = React.InputHTMLAttributes & { // Properties for the ImageGallery paramsSetter: Function; clickAction: Function; size: number; numImages: number; } type ImageItemProps = React.InputHTMLAttributes & { // Properties for a single item in the ImageGallery // Two actions: // paramsSetter sets the chosen image url into the model params // clickAction then starts the conversation paramsSetter: Function; clickAction: Function; size: number; imageUrl: string; } function ImageSelect(props: ImageItemProps) { // Represents a single image in the gallery const [isHover, setIsHover] = useState(false); const handleMouseEnter = () => { setIsHover(true); }; const handleMouseLeave = () => { setIsHover(false); }; let bordercolor = isHover ? "#f7a319" : "black"; let bgalpha = isHover ? 0.05 : 0.6; let textalpha = isHover ? 1.0 : 0.0 let label = isHover ? 
"Select" : "X"; let style = { width: props.size, height: props.size, background: `url(${props.imageUrl})`, backgroundSize: "100% 100%", border: `3px solid ${bordercolor}`, margin: "2px", padding: "0px", color: `rgba(255, 255, 255, ${textalpha})`, boxShadow: `inset 0 0 0 1000px rgba(0,0,0,${bgalpha})`, textShadow: `2px 2px 2px rgba(0, 0, 0, ${textalpha})` }; return ( ); } const shuffle = (array: number[]) => { return array.sort(() => Math.random() - 0.5); }; export const ImageGallery = (props: ImageGalleryProps) => { const [ordering, SetOrdering] = useState(images_order); const [preview, setPreview] = useState(sessionStorage.getItem("imageUrl")); const handleFileChange = (e: ChangeEvent, isCapture: boolean) => { if (e.target.files && e.target.files[0]) { const file = e.target.files[0]; const url = URL.createObjectURL(file); setPreview(url); props.paramsSetter(url); // only save the image URL when it's an uploaded file // doesn't really seem to work with one-shot photo otherwise if (!isCapture) { sessionStorage.setItem("imageUrl", url); } } }; const resetFile = () => { setPreview(null); props.paramsSetter(undefined); sessionStorage.removeItem("imageUrl"); }; function handleShuffle() { SetOrdering(shuffle([...ordering])); } // Image Gallery widget (random subset) const steps = []; for (let i = 0; i < props.numImages; i++) { steps.push(); } return (
{preview && Preview}
{preview && } {preview && }
{!preview &&
handleFileChange(e, false)} />
} {!preview &&
handleFileChange(e, true)} />
} {!preview && } {!preview && }
{steps}
) ; }; ================================================ FILE: client/src/components/Input/Input.tsx ================================================ type InputProps = React.InputHTMLAttributes & { error?: string; } export const Input = ({className, error, ...props}:InputProps) => { return (
{error &&

{error}

}
); } ================================================ FILE: client/src/decoder/decoderWorker.ts ================================================ export const DecoderWorker = new Worker( new URL("/assets/decoderWorker.min.js", import.meta.url), ); ================================================ FILE: client/src/env.ts ================================================ type ENV = { VITE_QUEUE_API_PATH: string; VITE_ENV: 'development' | 'production'; }; const parseEnv = (): ENV => { const VITE_QUEUE_API_PATH = import.meta.env.VITE_QUEUE_API_PATH; if (!VITE_QUEUE_API_PATH) { throw new Error("VITE_QUEUE_API_PATH is not defined"); } return { VITE_QUEUE_API_PATH, VITE_ENV: import.meta.env.DEV ? 'development' : 'production', }; }; export const env = parseEnv(); ================================================ FILE: client/src/index.css ================================================ @tailwind base; @tailwind components; @tailwind utilities; @layer utilities { /* Hide scrollbar for Chrome, Safari and Opera */ .no-scrollbar::-webkit-scrollbar { display: none; } /* Hide scrollbar for IE, Edge and Firefox */ .no-scrollbar { -ms-overflow-style: none; /* IE and Edge */ scrollbar-width: none; /* Firefox */ } .scrollbar::-webkit-scrollbar { width: 10px; } .scrollbar::-webkit-scrollbar-track { background: transparent; } .scrollbar::-webkit-scrollbar-thumb { background: white; border: 3px solid #f6f7ed; } } .settingsbutton#changed:before { content: "C"; width: 13px; height: 13px; line-height: 18px; text-align: center; display: block; border-radius: 50%; background: #54e8b3; border: 1px solid #FFF; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4); color: #FFF; position: absolute; top: -7px; right: -7px; } .main-grid { display: grid; grid-template-columns: 1fr; grid-template-rows: min-content 1fr 1fr; gap: 30px; grid-auto-flow: column; grid-template-areas: "controls" "player" "player-text"; @media screen and (min-width: 768px) { grid-template-columns: 2fr 2.5fr; grid-template-rows: min-content min-content min-content 1fr; gap: 30px 30px; grid-auto-flow: column; align-items: center; justify-items: center; grid-template-areas: "controls controls" "player player-stats" "player player-text" "player player-text"; } } .presentation { max-width: 450px; } .presentation>p { padding-top: 10px; } .gallery { max-width: 450px; } .cute-words { color: #54e8b3; } .vis-words { color: #f7a319; } .explain-links { color: #BCFCE5; } .controls { grid-area: controls; } .player { grid-area: player; grid-template-areas: "server-audio" "user-audio" "user-image" "download-links"; display: grid; grid-template-columns: 1fr 1fr; grid-template-rows: 3fr; justify-items: stretch; row-gap: 30px; /* margin:auto; */ } .user-image { grid-area: user-image; grid-column: 1 / -1; grid-row: 1; height: 200px } .user-image img { height: 100%; width: auto; margin: auto } .server-audio { grid-area: server-audio; grid-column: 1; grid-row: 2; } .user-audio { grid-area: user-audio; grid-column: 2; grid-row: 2; } .download-links { grid-area: download-links; grid-column: 1/-1; grid-row: 3; color: #54e8b3; height: 10%; } .player-stats { grid-area: player-stats; width: 100%; height: 100%; } .commands { grid-area: commands; width: 100%; height: 100%; } .player-text { grid-area: player-text; width: 100%; height: 100%; overflow: scroll; } ================================================ FILE: client/src/modules.d.ts ================================================ declare module "opus-recorder"; ================================================ FILE: 
client/src/pages/Conversation/Conversation.tsx ================================================ import { FC, MutableRefObject, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useSocket } from "./hooks/useSocket"; import { SocketContext } from "./SocketContext"; import { ServerAudio } from "./components/ServerAudio/ServerAudio"; import { UserAudio } from "./components/UserAudio/UserAudio"; import { Button } from "../../components/Button/Button"; import { ServerAudioStats } from "./components/ServerAudio/ServerAudioStats"; import { AudioStats } from "./hooks/useServerAudio"; import { TextDisplay } from "./components/TextDisplay/TextDisplay"; import { MediaContext } from "./MediaContext"; import { ServerInfo } from "./components/ServerInfo/ServerInfo"; import { ModelParamsValues, useModelParams } from "./hooks/useModelParams"; import { ModelParams } from "./components/ModelParams/ModelParams"; import fixWebmDuration from "webm-duration-fix"; import canvasLogo from "./canvas-logo.png"; import { getMimeType, getExtension } from "./getMimeType"; type ConversationProps = { workerAddr: string; workerAuthId?: string; sessionAuthId?: string; sessionId?: number; email?: string; audioContext: MutableRefObject; worklet: MutableRefObject; onConversationEnd?: () => void; isBypass?: boolean; } & Partial; const buildURL = ({ workerAddr, params, workerAuthId, email, textSeed, audioSeed, }: { workerAddr: string; params: ModelParamsValues; workerAuthId?: string; email?: string; textSeed: number; audioSeed: number; }) => { if (workerAddr == "same" || workerAddr == "") { workerAddr = window.location.hostname + ":" + window.location.port; console.log("Overriding workerAddr to", workerAddr); } const wsProtocol = (window.location.protocol === 'https:') ? 
'wss' : 'ws'; const url = new URL(`${wsProtocol}://${workerAddr}/api/chat`); if (workerAuthId) { url.searchParams.append("worker_auth_id", workerAuthId); } if (email) { url.searchParams.append("email", email); } url.searchParams.append("text_temperature", params.textTemperature.toString()); url.searchParams.append("text_topk", params.textTopk.toString()); url.searchParams.append("audio_temperature", params.audioTemperature.toString()); url.searchParams.append("audio_topk", params.audioTopk.toString()); url.searchParams.append("pad_mult", params.padMult.toString()); url.searchParams.append("text_seed", textSeed.toString()); url.searchParams.append("audio_seed", audioSeed.toString()); url.searchParams.append("repetition_penalty_context", params.repetitionPenaltyContext.toString()); url.searchParams.append("repetition_penalty", params.repetitionPenalty.toString()); // Add image params if given if (params.imageUrl != undefined) { url.searchParams.append("image_resolution", params.imageResolution.toString()); url.searchParams.append("center_crop", params.centerCrop.toString()); url.searchParams.append("xa_start", params.gateDelay.toString()); url.searchParams.append("text_temperature_gating_influence", params.gateInfluence.toString()); } return url.toString(); }; export const Conversation: FC = ({ workerAddr, workerAuthId, audioContext, worklet, sessionAuthId, sessionId, onConversationEnd, isBypass = false, email, ...params }) => { const getAudioStats = useRef<() => AudioStats>(() => ({ playedAudioDuration: 0, missedAudioDuration: 0, totalAudioMessages: 0, delay: 0, minPlaybackDelay: 0, maxPlaybackDelay: 0, })); const isRecording = useRef(false); const videoChunks = useRef([]); const audioChunks = useRef([]); const audioStreamDestination = useRef(audioContext.current.createMediaStreamDestination()); const mediaRecorder = useRef(null); const audioRecorder = useRef(new MediaRecorder(audioStreamDestination.current.stream, { mimeType: getMimeType("audio"), audioBitsPerSecond: 128000 })); const [videoURL, setVideoURL] = useState(""); const [audioURL, setAudioURL] = useState(""); const [userRating, setUserRating] = useState(0); const [userRatingHovered, setUserRatingHovered] = useState(0); const [baseBlobName, setBaseBlobName] = useState("moshi"); const [isOver, setIsOver] = useState(false); const modelParams = useModelParams(params); const micDuration = useRef(0); const actualAudioPlayed = useRef(0); const textContainerRef = useRef(null); const textSeed = useMemo(() => Math.round(1000000 * Math.random()), []); const audioSeed = useMemo(() => Math.round(1000000 * Math.random()), []); const canvasRef = useRef(null); const logoRef = useRef(null); const [isLogoLoaded, setIsLogoLoaded] = useState(false); const WSURL = buildURL({ workerAddr, params: modelParams, workerAuthId, email: email, textSeed: textSeed, audioSeed: audioSeed, }); const onDisconnect = useCallback(() => { setIsOver(true); console.log("on disconnect!"); stopRecording(); }, [setIsOver]); const { isConnected, sendMessage, socket, start, stop } = useSocket({ // onMessage, uri: WSURL, onDisconnect, imageUrl: params.imageUrl, }); useEffect(() => { audioRecorder.current.ondataavailable = (e) => { audioChunks.current.push(e.data); }; audioRecorder.current.onstop = async () => { let blob: Blob; const mimeType = getMimeType("audio"); if (mimeType.includes("webm")) { blob = await fixWebmDuration(new Blob(audioChunks.current, { type: mimeType })); } else { blob = new Blob(audioChunks.current, { type: mimeType }); } 
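// MediaRecorder's webm output is missing duration metadata, hence the fixWebmDuration pass above;
// the object URL created below is what the UI uses to expose the finished recording for playback/download.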
setAudioURL(URL.createObjectURL(blob)); audioChunks.current = []; console.log("Audio Recording and encoding finished"); }; }, [mediaRecorder, audioRecorder, setVideoURL, setAudioURL, videoChunks, audioChunks]); const RatingButton = (props: { rating: number }) => { const [isHover, setIsHover] = useState(false); const handleMouseEnter = () => { setUserRatingHovered(props.rating); setIsHover(true); }; const handleMouseLeave = () => { setUserRatingHovered(0); setIsHover(false); }; let style = { color: (isHover || userRating >= props.rating || userRatingHovered >= props.rating) ? `#f7a319` : '#333333', }; return ( ); }; useEffect(() => { start(); return () => { stop(); }; }, [start, workerAuthId]); useEffect(() => { if (!canvasRef) { console.log("No canvas ref"); return; } if (!logoRef) { console.log("No logo ref"); return; } if (!isLogoLoaded) { console.log("Logo not loaded"); return; } if (!canvasRef.current) { console.log("No canvas"); return; } if (!logoRef.current) { console.log("No logo"); return; } const ctx = canvasRef.current.getContext("2d"); if (ctx) { ctx.drawImage(logoRef.current, 20, 250, 320, 98); ctx.lineWidth = 1; ctx.strokeStyle = "white"; ctx.strokeRect(5, 5, 370, 370); } }, [canvasRef, logoRef, isLogoLoaded]); const startRecording = useCallback(() => { if (isRecording.current) { return; } console.log(Date.now() % 1000, "Starting recording"); console.log("Starting recording"); if (canvasRef.current) { // Note: Attaching a track from this stream to the existing MediaRecorder // rather than creating a new MediaRecorder for the canvas stream // doesn't work on Safari as it just ends the recording immediately. // It works on Chrome though and is much cleaner. console.log("Adding canvas to stream"); const captureStream = canvasRef.current.captureStream(30); captureStream.addTrack(audioStreamDestination.current.stream.getAudioTracks()[0]); mediaRecorder.current = new MediaRecorder(captureStream, { mimeType: getMimeType("video"), videoBitsPerSecond: 1000000 }); mediaRecorder.current.ondataavailable = (e) => { console.log("Video data available"); videoChunks.current.push(e.data); }; mediaRecorder.current.onstop = async () => { let blob: Blob; const mimeType = getMimeType("video"); if (mimeType.includes("webm")) { blob = await fixWebmDuration(new Blob(videoChunks.current, { type: mimeType })); } else { blob = new Blob(videoChunks.current, { type: mimeType }); } setVideoURL(URL.createObjectURL(blob)); videoChunks.current = []; console.log("Video Recording and encoding finished"); }; } worklet.current?.connect(audioStreamDestination.current); // videoStream.current.addTrack(audioStreamDestination.current.stream.getAudioTracks()[0]); setVideoURL(""); setAudioURL(""); mediaRecorder.current?.start(); audioRecorder.current.start(); isRecording.current = true; }, [isRecording, setVideoURL, setVideoURL, worklet, audioStreamDestination, mediaRecorder, audioRecorder, canvasRef]); const stopRecording = useCallback(() => { console.log("Stopping recording"); console.log("isRecording", isRecording) if (!isRecording.current) { return; } worklet.current?.disconnect(audioStreamDestination.current); audioRecorder.current.stop(); mediaRecorder.current?.stop(); isRecording.current = false; }, [isRecording, worklet, audioStreamDestination, mediaRecorder, audioRecorder]); return (
{isOver && !isBypass && ( ) } { (!isOver || isBypass) && ( ) }
AudioStats) => (getAudioStats.current = callback) } />
Feel free to rate the interaction before ending the session:
{audioURL && } {videoURL && } {videoURL && getExtension("video") === "webm" && }
setBaseBlobName(x)} /> {!workerAuthId && }
{ console.log("Logo loaded"); setIsLogoLoaded(true); }} />
); }; ================================================ FILE: client/src/pages/Conversation/MediaContext.ts ================================================ import { MutableRefObject, createContext, useContext } from "react"; type MediaContextType = { startRecording: () => void; stopRecording: () => void; audioContext: MutableRefObject; audioStreamDestination: MutableRefObject; worklet: MutableRefObject; micDuration: MutableRefObject; actualAudioPlayed: MutableRefObject; }; export const MediaContext = createContext(null); export const useMediaContext = () => { const context = useContext(MediaContext); if (!context) { throw new Error( "useMediaContext must be used within a MediaContextProvider", ); } return context; }; ================================================ FILE: client/src/pages/Conversation/SocketContext.ts ================================================ import { createContext, useContext } from "react"; import { WSMessage } from "../../protocol/types"; type SocketContextType = { isConnected: boolean; socket: WebSocket | null; sendMessage: (message: WSMessage) => void; }; export const SocketContext = createContext({ isConnected: false, socket: null, sendMessage: () => {}, }); export const useSocketContext = () => { return useContext(SocketContext); }; ================================================ FILE: client/src/pages/Conversation/components/AudioVisualizer/AudioVisualizer.tsx ================================================ import { FC, useCallback, useEffect, useRef } from "react"; type AudioVisualizerProps = { analyser: AnalyserNode | null; }; export const AudioVisualizer: FC = ({ analyser }) => { const requestRef = useRef(null); const canvasRef = useRef(null); const visualizeData = useCallback(() => { requestRef.current = window.requestAnimationFrame(() => visualizeData()); if (!canvasRef.current) { console.log("Canvas not found"); return; } const audioData = new Uint8Array(140); analyser?.getByteFrequencyData(audioData); const bar_width = 3; let start = 0; const ctx = canvasRef.current.getContext("2d"); if (!ctx) { console.log("Canvas context not found"); return; } ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height); for (let i = 0; i < audioData.length; i++) { start = i * 4; let gradient = ctx.createLinearGradient( 0, 0, canvasRef.current.width, canvasRef.current.height, ); gradient.addColorStop(0.2, "#2392f5"); gradient.addColorStop(0.5, "#fe0095"); gradient.addColorStop(1.0, "purple"); ctx.fillStyle = gradient; ctx.fillRect( start, canvasRef.current.height, bar_width, (-audioData[i] * 100) / 255, ); } }, [analyser]); const resetCanvas = useCallback(() => { if (!canvasRef.current) { return; } const ctx = canvasRef.current.getContext("2d"); if (!ctx) { return; } ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height); }, []); useEffect(() => { if (!analyser) { return; } visualizeData(); return () => { if (requestRef.current) { console.log("Canceling animation frame"); cancelAnimationFrame(requestRef.current); } }; }, [visualizeData, analyser, resetCanvas]); return ; }; ================================================ FILE: client/src/pages/Conversation/components/AudioVisualizer/ClientVisualizer.tsx ================================================ import { FC, RefObject, useCallback, useEffect, useRef, useState } from "react"; import { clamp } from "../../hooks/audioUtils"; type AudioVisualizerProps = { analyser: AnalyserNode | null; parent: RefObject; copyCanvasRef: RefObject; }; const MAX_INTENSITY = 255; const COLORS = [ "#197556", 
"#299e77", "#32b89b", "#31d4b8", "#14d9d5", "#41eff2", "#7ff3f5", "#789bf5", "#eb94eb", "#e63280", "#c41862", ]; export const ClientVisualizer: FC = ({ analyser, parent, copyCanvasRef }) => { const [canvasWidth, setCanvasWidth] = useState(parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0); const requestRef = useRef(null); const canvasRef = useRef(null); const drawBars = useCallback( ( ctx: CanvasRenderingContext2D, x: number, y: number, volume: number, height: number, width: number, gap: number, ) => { const barHeight = height / 10 - gap; for (let i = 1; i <= 10; i++) { const barY = y + height + gap + Math.min(1, width / 30) - (i * barHeight + i * gap); ctx.fillStyle = COLORS[i - 1]; ctx.strokeStyle = "white"; ctx.lineWidth = Math.min(1, height / 100); if (i <= volume) { ctx.fillRect(x, barY, width, barHeight); } ctx.strokeRect(x, barY, width, barHeight); } }, [], ); const draw = useCallback((ctx: CanvasRenderingContext2D, audioData: Uint8Array, x: number, y: number, width: number, height: number) => { const stereoGap = Math.floor(width / 30); const barGap = Math.floor(height / 30); const padding = Math.floor(width / 30); const maxBarHeight = Math.floor(height - padding * 2); const maxBarWidth = Math.floor( width / 2.5 - stereoGap - padding * 2, ); const centerX = x + width / 2; const averageIntensity = Math.sqrt( audioData.reduce((acc, curr) => acc + curr * curr, 0) / audioData.length, ); const intensity = clamp( averageIntensity * 1.4, averageIntensity, MAX_INTENSITY, ); const volume = Math.floor((intensity * 10) / MAX_INTENSITY); ctx.fillStyle = "rgba(0, 0, 0, 0)"; ctx.fillRect(x, y, width, height); drawBars( ctx, centerX - maxBarWidth - stereoGap / 2, y, volume, maxBarHeight, maxBarWidth, barGap, ); drawBars( ctx, centerX + stereoGap / 2, y, volume, maxBarHeight, maxBarWidth, barGap, ); }, [analyser, drawBars]); const visualizeData = useCallback(() => { const width = parent.current ? 
Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0 if (width !== canvasWidth) { console.log("Setting canvas width"); setCanvasWidth(width); } requestRef.current = window.requestAnimationFrame(() => visualizeData()); if (!canvasRef.current) { console.log("Canvas not found"); return; } const audioData = new Uint8Array(140); analyser?.getByteFrequencyData(audioData); const ctx = canvasRef.current.getContext("2d"); if (!ctx) { console.log("Canvas context not found"); return; } ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height); draw(ctx, audioData, 0, 0, width, width); if (copyCanvasRef?.current) { const copyCtx = copyCanvasRef.current.getContext("2d"); if (copyCtx) { const x = 240; const y = 140; const width = 140 / 1.25; // slightly scaled down version const height = 180 / 1.25; // slightly scaled down version copyCtx.clearRect(x, y, width, height); draw(copyCtx, audioData, x, y, width, height); } } }, [analyser, canvasWidth, drawBars, parent, copyCanvasRef, draw]); useEffect(() => { visualizeData(); return () => { if (requestRef.current) { console.log("Canceling animation frame"); cancelAnimationFrame(requestRef.current); } }; }, [visualizeData, analyser]); return ( ); }; ================================================ FILE: client/src/pages/Conversation/components/AudioVisualizer/ServerVisualizer.tsx ================================================ import { FC, RefObject, useCallback, useEffect, useRef, useState } from "react"; import { clamp } from "../../hooks/audioUtils"; import { useSocketContext } from "../../SocketContext"; type AudioVisualizerProps = { analyser: AnalyserNode | null; parent: RefObject; imageUrl: string | undefined; copyCanvasRef?: RefObject; }; const MAX_INTENSITY = 255; export const ServerVisualizer: FC = ({ analyser, parent, imageUrl, copyCanvasRef }) => { const [canvasWidth, setCanvasWidth] = useState(parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0); const requestRef = useRef(null); const canvasRef = useRef(null); const { isConnected } = useSocketContext(); const draw = useCallback((width: number, centerX: number, centerY: number, audioData: Uint8Array, ctx: CanvasRenderingContext2D) => { const maxCircleWidth = Math.floor(width * 0.95); const averageIntensity = Math.sqrt( audioData.reduce((acc, curr) => acc + curr * curr, 0) / audioData.length, ); const intensity = clamp( averageIntensity * 1.4, averageIntensity, MAX_INTENSITY, ); const relIntensity = intensity / MAX_INTENSITY; const radius = ((isConnected ? 0.3 + 0.7 * relIntensity : relIntensity) * maxCircleWidth) / 2; // Draw a circle with radius based on intensity ctx.clearRect(centerX - width / 2, centerY - width / 2, width, width); ctx.fillStyle = 'rgba(0, 0, 0, 0)'; ctx.fillRect(centerX - width / 2, centerY - width / 2, width, width); ctx.beginPath(); //ctx.fillStyle = "#39e3a7"; ctx.fillStyle = 'rgba(57, 227, 167, 0.5)'; ctx.arc(centerX, centerY, radius, 0, 2 * Math.PI); ctx.fill(); ctx.closePath(); // Draw an inner circle if we are connected. if (isConnected) { ctx.beginPath(); ctx.arc(centerX, centerY, maxCircleWidth / 6, 0, 2 * Math.PI); // ctx.fillStyle = "#BCFCE5"; ctx.fillStyle = 'rgba(188, 252, 229, 0.5)'; ctx.fill(); ctx.closePath(); } //Draw a circle with max radius ctx.beginPath(); ctx.arc(centerX, centerY, maxCircleWidth / 2, 0, 2 * Math.PI); ctx.strokeStyle = "white"; ctx.lineWidth = (width / 50 < 3) ? 
3 : width / 50; ctx.stroke(); ctx.fillStyle = 'rgba(0, 0, 0, 0.6)'; ctx.fill() ctx.closePath(); }, [isConnected]); const visualizeData = useCallback(() => { const width = parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0; if (width !== canvasWidth) { console.log("Setting canvas width"); setCanvasWidth(width); } requestRef.current = window.requestAnimationFrame(() => visualizeData()); if (!canvasRef.current) { console.log("Canvas not found"); return; } const ctx = canvasRef.current.getContext("2d"); const audioData = new Uint8Array(140); analyser?.getByteFrequencyData(audioData); if (!ctx) { console.log("Canvas context not found"); return; } const centerX = width / 2; const centerY = width / 2; draw(width, centerX, centerY, audioData, ctx); if (copyCanvasRef?.current) { const copyCtx = copyCanvasRef.current.getContext("2d"); if (copyCtx) { draw(100, 295, 70, audioData, copyCtx); if (imageUrl) { const img = new Image() img.src = imageUrl; img.onload = function () { copyCtx.drawImage(img, 25, 25, 200, 200); copyCtx.strokeStyle = 'white'; copyCtx.rect(25, 25, 200, 200); copyCtx.stroke(); }; } } } }, [analyser, isConnected, canvasWidth, parent, copyCanvasRef]); useEffect(() => { if (!analyser) { return; } analyser.smoothingTimeConstant = 0.95; visualizeData(); return () => { if (requestRef.current) { console.log("Canceling animation frame"); cancelAnimationFrame(requestRef.current); } }; }, [visualizeData, analyser]); return ( ); }; ================================================ FILE: client/src/pages/Conversation/components/Controls/Controls.tsx ================================================ import { controlBOSMessage, controlEOSMessage, } from "../../../../protocol/testMessages"; import { useSocketContext } from "../../SocketContext"; import { Button } from "../../../../components/Button/Button"; export const Controls = () => { const { sendMessage } = useSocketContext(); const sendControlBOS = () => { sendMessage(controlBOSMessage); }; const sendControlEOS = () => { sendMessage(controlEOSMessage); }; return (
); }; ================================================ FILE: client/src/pages/Conversation/components/ModelParams/ModelParams.tsx ================================================ import { FC, RefObject } from "react"; import { useModelParams } from "../../hooks/useModelParams"; import { Button } from "../../../../components/Button/Button"; type ModelParamsProps = { isConnected: boolean; modal?: RefObject, } & ReturnType; export const ModelParams: FC = ({ textTemperature, textTopk, audioTemperature, audioTopk, padMult, repetitionPenalty, repetitionPenaltyContext, imageResolution, gateDelay, gateInfluence, displayColor, centerCrop, setTextTemperature, setTextTopk, setAudioTemperature, setAudioTopk, setPadMult, setRepetitionPenalty, setRepetitionPenaltyContext, setImageResolution, setGateDelay, setGateInfluence, setDisplayColor, setCenterCrop, resetParams, isConnected, modal, }) => { return (
{!isConnected && Hover over each element to display a helpful tooltip}
Text temperature: {textTemperature} setTextTemperature(parseFloat(e.target.value))} />
Text topk: {textTopk} setTextTopk(parseInt(e.target.value))} />
Audio temperature: {audioTemperature} setAudioTemperature(parseFloat(e.target.value))} />
Audio topk: {audioTopk} setAudioTopk(parseInt(e.target.value))} />
Padding multiplier: {padMult} setPadMult(parseFloat(e.target.value))} />
Repeat penalty: {repetitionPenalty} setRepetitionPenalty(parseFloat(e.target.value))} />
Repeat penalty last N: {repetitionPenaltyContext} setRepetitionPenaltyContext(parseFloat(e.target.value))} />
Image max-side (px): {imageResolution} setImageResolution(parseFloat(e.target.value))} />
Center Crop: {centerCrop ? '✔️' : '✖️'} setCenterCrop((parseFloat(e.target.value) == 1) ? true : false)} />
Gating Delay: {gateDelay} setGateDelay(parseFloat(e.target.value))} />
Display Gating: {displayColor ? '✔️' : '✖️'} setDisplayColor((parseFloat(e.target.value) == 1) ? true : false)} />
Temperature Gating: {gateInfluence} setGateInfluence(parseFloat(e.target.value))} />
{!isConnected && } {!isConnected && }
) }; ================================================ FILE: client/src/pages/Conversation/components/ServerAudio/ServerAudio.tsx ================================================ import { FC, useRef } from "react"; import { AudioStats, useServerAudio } from "../../hooks/useServerAudio"; import { ServerVisualizer } from "../AudioVisualizer/ServerVisualizer"; type ServerAudioProps = { setGetAudioStats: (getAudioStats: () => AudioStats) => void; imageUrl: string | undefined; copyCanvasRef?: React.RefObject; }; export const ServerAudio: FC = ({ setGetAudioStats, imageUrl, copyCanvasRef }) => { const { analyser, hasCriticalDelay, setHasCriticalDelay } = useServerAudio({ setGetAudioStats, }); const containerRef = useRef(null); return ( <> {hasCriticalDelay && (

A connection issue has been detected; you've been reconnected

)}
); }; ================================================ FILE: client/src/pages/Conversation/components/ServerAudio/ServerAudioStats.tsx ================================================ import { useState, useEffect, useRef } from "react"; type ServerAudioStatsProps = { getAudioStats: React.MutableRefObject< () => { playedAudioDuration: number; missedAudioDuration: number; totalAudioMessages: number; delay: number; minPlaybackDelay: number; maxPlaybackDelay: number; } >; }; export const ServerAudioStats = ({ getAudioStats }: ServerAudioStatsProps) => { const [audioStats, setAudioStats] = useState(getAudioStats.current()); const movingAverageSum = useRef(0.); const movingAverageCount = useRef(0.); const movingBeta = 0.85; let convertMinSecs = (total_secs: number) => { // convert secs to the format mm:ss.cc let mins = (Math.floor(total_secs / 60)).toString(); let secs = (Math.floor(total_secs) % 60).toString(); let cents = (Math.floor(100 * (total_secs - Math.floor(total_secs)))).toString(); if (secs.length < 2) { secs = "0" + secs; } if (cents.length < 2) { cents = "0" + cents; } return mins + ":" + secs + "." + cents; }; useEffect(() => { const interval = setInterval(() => { const newAudioStats = getAudioStats.current(); setAudioStats(newAudioStats); movingAverageCount.current *= movingBeta; movingAverageCount.current += (1 - movingBeta) * 1; movingAverageSum.current *= movingBeta; movingAverageSum.current += (1 - movingBeta) * newAudioStats.delay; }, 141); return () => { clearInterval(interval); }; }, []); return (

Server Audio Stats

Audio played: {convertMinSecs(audioStats.playedAudioDuration)}
Missed audio: {convertMinSecs(audioStats.missedAudioDuration)}
Latency: {(movingAverageSum.current / movingAverageCount.current).toFixed(3)}
Min/Max buffer: {audioStats.minPlaybackDelay.toFixed(3)} / {audioStats.maxPlaybackDelay.toFixed(3)}
); }; ================================================ FILE: client/src/pages/Conversation/components/ServerInfo/ServerInfo.tsx ================================================ import { useServerInfo } from "../../hooks/useServerInfo"; function pretty_format(num: number): number { return Math.round((num + Number.EPSILON) * 100) / 100 } export const ServerInfo = (props: { setFileName: Function }) => { const { serverInfo } = useServerInfo(); if (!serverInfo) { return null; } props.setFileName(serverInfo.base_filename); return (
Our server is running on the following configuration:
Image resolution: {serverInfo.image_resolution} px
Text temperature: {pretty_format(serverInfo.text_temperature)}
Text topk: {serverInfo.text_topk}
Temperature gating: {pretty_format(serverInfo.text_temperature_gating_influence)}
Audio temperature: {pretty_format(serverInfo.audio_temperature)}
Audio topk: {serverInfo.audio_topk}
Pad mult: {serverInfo.pad_mult}
Repeat penalty last N: {serverInfo.repetition_penalty_context}
Repeat penalty: {serverInfo.repetition_penalty}
LM model file: {serverInfo.lm_model_file}
Instance name: {serverInfo.instance_name}
); }; ================================================ FILE: client/src/pages/Conversation/components/TextDisplay/TextDisplay.tsx ================================================ import { FC, useEffect, useRef } from "react"; import { useServerText } from "../../hooks/useServerText"; type TextDisplayProps = { containerRef: React.RefObject; displayColor: boolean | undefined; }; // Palette 2: Purple to Green Moshi // sns.diverging_palette(288, 145, s=90, l=72, n=11).as_hex() // Palette 2: Green to orange Moshi // sns.diverging_palette(145, 40, s=90, l=72, n=11).as_hex() const textDisplayColors = [ '#38c886', '#5bd09a', '#80d9af', '#a4e2c4', '#c8ead9', '#f2f1f1', '#f4e0cb', '#f5cea6', '#f5bd81', '#f6ac5b', '#f79b37'] function clamp_color(v: number) { return v <= 0 ? 0 : v >= textDisplayColors.length ? textDisplayColors.length : v } export const TextDisplay: FC = ({ containerRef, displayColor }) => { const { text, textColor } = useServerText(); const currentIndex = text.length - 1; const prevScrollTop = useRef(0); useEffect(() => { if (containerRef.current) { prevScrollTop.current = containerRef.current.scrollTop; containerRef.current.scroll({ top: containerRef.current.scrollHeight, behavior: "smooth", }); } }, [text]); if (displayColor && (textColor.length == text.length)) { return (
{text.map((t, i) => ( {t} )) }
); } else { return (
{text.map((t, i) => ( {t} ))}
); }; }; ================================================ FILE: client/src/pages/Conversation/components/TextDisplay/TextDisplayStats.tsx ================================================ import { FC } from "react"; type TextDisplayStatsProps = { totalTextMessages: number; }; export const TextDisplayStats: FC = ({ totalTextMessages, }) => { return (

Text Display Stats

Total messages:

{totalTextMessages}

); }; ================================================ FILE: client/src/pages/Conversation/components/UserAudio/UserAudio.tsx ================================================ import { FC, useCallback, useEffect, useRef, useState } from "react"; import { useSocketContext } from "../../SocketContext"; import { useUserAudio } from "../../hooks/useUserAudio"; import { ClientVisualizer } from "../AudioVisualizer/ClientVisualizer"; type UserAudioProps = { copyCanvasRef: React.RefObject; }; export const UserAudio: FC = ({ copyCanvasRef }) => { const [analyser, setAnalyser] = useState(null); const { sendMessage, isConnected } = useSocketContext(); const containerRef = useRef(null); const onRecordingStart = useCallback(() => { console.log("Recording started"); }, []); const onRecordingStop = useCallback(() => { console.log("Recording stopped"); }, []); const onRecordingChunk = useCallback( (chunk: Uint8Array) => { if (!isConnected) { return; } sendMessage({ type: "audio", data: chunk, }); }, [sendMessage, isConnected], ); const { startRecordingUser, stopRecording } = useUserAudio({ constraints: { audio: { echoCancellation: true, noiseSuppression: true, autoGainControl: true, channelCount: 1, }, video: false, }, onDataChunk: onRecordingChunk, onRecordingStart, onRecordingStop, }); useEffect(() => { let res: Awaited>; if (isConnected) { startRecordingUser().then(result => { if (result) { res = result; setAnalyser(result.analyser); } }); } return () => { console.log("Stop recording called from somewhere else."); stopRecording(); res?.source?.disconnect(); }; }, [startRecordingUser, stopRecording, isConnected]); return (
); }; ================================================ FILE: client/src/pages/Conversation/components/UserAudio/UserAudioStats.tsx ================================================ import { FC } from "react"; type UserAudioStatsProps = { sentMessagesCount: number; }; export const UserAudioStats: FC = ({ sentMessagesCount, }) => { return (

User Audio Stats

Total messages:

{sentMessagesCount}

); }; ================================================ FILE: client/src/pages/Conversation/getMimeType.ts ================================================ export const mimeTypeCheck = () => { const types = [ "audio/ogg", "audio/wav", "audio/webm;codecs=opus", "audio/webm;codecs=pcm", "audio/webm;codecs=pcm_s16le", "audio/webm;codecs=pcm_f32le", "audio/mp3", "audio/aac", "audio/mp4", "audio/webm", "audio/mpeg", "video/mp4", "video/webm;codecs=vp9", "video/webm;codecs=vp8", "video/webm", ]; for (const mime of types) { console.log(mime, MediaRecorder.isTypeSupported(mime)); } } const getVideoMimeType = () => { if (!MediaRecorder.isTypeSupported){ return "video/mp4"; } if (MediaRecorder.isTypeSupported("video/webm")) { return "video/webm"; } if (MediaRecorder.isTypeSupported("video/mp4")) { return "video/mp4"; } console.log("No supported video mime type found") return ""; }; const getAudioMimeType = () => { if (!MediaRecorder.isTypeSupported){ return "audio/mp4"; } if (MediaRecorder.isTypeSupported("audio/webm")) { return "audio/webm"; } if (MediaRecorder.isTypeSupported("audio/mpeg")) { return "audio/mpeg"; }`` if (MediaRecorder.isTypeSupported("audio/mp4")) { return "audio/mp4"; } console.log("No supported audio mime type found") return ""; } export const getMimeType = (type: "audio" | "video") => { if(type === "audio") { return getAudioMimeType(); } return getVideoMimeType(); } export const getExtension = (type: "audio" | "video") => { if(getMimeType(type).includes("mp4")) { return "mp4"; } if(getMimeType(type).includes("mpeg")) { return "mp3"; } return "webm"; } ================================================ FILE: client/src/pages/Conversation/hooks/audioUtils.ts ================================================ export const clamp = (value: number, min: number, max: number) => { return Math.min(Math.max(value, min), max); }; ================================================ FILE: client/src/pages/Conversation/hooks/useModelParams.ts ================================================ import { useCallback, useState } from "react"; export const DEFAULT_TEXT_TEMPERATURE = 0.45; export const DEFAULT_TEXT_TOPK = 25; export const DEFAULT_AUDIO_TEMPERATURE = 0.7; export const DEFAULT_AUDIO_TOPK = 250; export const DEFAULT_PAD_MULT = 0; export const DEFAULT_REPETITION_PENALTY_CONTEXT = 64; export const DEFAULT_REPETITION_PENALTY = 1.15; export const DEFAULT_IMAGE_RESOLUTION = 448; export const DEFAULT_IMAGE_URL = undefined; export const DEFAULT_GATE_DELAY = 16; export const DEFAULT_GATE_INFLUENCE = 0.0; export const DEFAULT_DISPLAY_COLOR = true; export const DEFAULT_CENTER_CROP = false; export type ModelParamsValues = { textTemperature: number; textTopk: number; audioTemperature: number; audioTopk: number; padMult: number; repetitionPenaltyContext: number, repetitionPenalty: number, imageResolution: number, imageUrl: string | undefined, gateDelay: number, gateInfluence: number, displayColor: boolean, centerCrop: boolean, }; export function importantSettingsHaveChanged(params: ModelParamsValues): boolean { return (params.textTemperature != DEFAULT_TEXT_TEMPERATURE) || (params.textTopk != DEFAULT_TEXT_TOPK) || (params.audioTemperature != DEFAULT_AUDIO_TEMPERATURE) || (params.audioTopk != DEFAULT_AUDIO_TOPK) || (params.padMult != DEFAULT_PAD_MULT) || (params.repetitionPenalty != DEFAULT_REPETITION_PENALTY) || (params.repetitionPenaltyContext != DEFAULT_REPETITION_PENALTY_CONTEXT) || (params.imageResolution != DEFAULT_IMAGE_RESOLUTION) || (params.gateDelay != DEFAULT_GATE_DELAY) || (params.gateInfluence 
!= DEFAULT_GATE_INFLUENCE) || (params.centerCrop != DEFAULT_CENTER_CROP) } type useModelParamsArgs = Partial; export const useModelParams = (params?: useModelParamsArgs) => { const [textTemperature, setTextTemperatureBase] = useState(params?.textTemperature || DEFAULT_TEXT_TEMPERATURE); const [textTopk, setTextTopkBase] = useState(params?.textTopk || DEFAULT_TEXT_TOPK); const [audioTemperature, setAudioTemperatureBase] = useState(params?.audioTemperature || DEFAULT_AUDIO_TEMPERATURE); const [audioTopk, setAudioTopkBase] = useState(params?.audioTopk || DEFAULT_AUDIO_TOPK); const [padMult, setPadMultBase] = useState(params?.padMult || DEFAULT_PAD_MULT); const [repetitionPenalty, setRepetitionPenaltyBase] = useState(params?.repetitionPenalty || DEFAULT_REPETITION_PENALTY); const [repetitionPenaltyContext, setRepetitionPenaltyContextBase] = useState(params?.repetitionPenaltyContext || DEFAULT_REPETITION_PENALTY_CONTEXT); const [imageResolution, setImageResolutionBase] = useState(params?.imageResolution || DEFAULT_IMAGE_RESOLUTION); const [imageUrl, setImageUrlBase] = useState(params?.imageUrl || DEFAULT_IMAGE_URL); const [gateDelay, setGateDelayBase] = useState(params?.gateDelay || DEFAULT_GATE_DELAY); const [gateInfluence, setGateInfluenceBase] = useState(params?.gateInfluence || DEFAULT_GATE_INFLUENCE); const [displayColor, setDisplayColorBase] = useState(params?.displayColor == undefined ? DEFAULT_DISPLAY_COLOR : params?.displayColor); const [centerCrop, setCenterCropBase] = useState(params?.centerCrop == undefined ? DEFAULT_CENTER_CROP : params?.centerCrop); const resetParams = useCallback(() => { setTextTemperatureBase(DEFAULT_TEXT_TEMPERATURE); setTextTopkBase(DEFAULT_TEXT_TOPK); setAudioTemperatureBase(DEFAULT_AUDIO_TEMPERATURE); setAudioTopkBase(DEFAULT_AUDIO_TOPK); setPadMultBase(DEFAULT_PAD_MULT); setRepetitionPenaltyBase(DEFAULT_REPETITION_PENALTY); setRepetitionPenaltyContextBase(DEFAULT_REPETITION_PENALTY_CONTEXT); setImageResolutionBase(DEFAULT_IMAGE_RESOLUTION); setImageUrlBase(DEFAULT_IMAGE_URL); setGateDelayBase(DEFAULT_GATE_DELAY); setGateInfluenceBase(DEFAULT_GATE_INFLUENCE); setDisplayColorBase(DEFAULT_DISPLAY_COLOR); setCenterCropBase(DEFAULT_CENTER_CROP); }, [ setTextTemperatureBase, setTextTopkBase, setAudioTemperatureBase, setAudioTopkBase, setPadMultBase, setRepetitionPenaltyBase, setRepetitionPenaltyContextBase, setImageResolutionBase, setImageUrlBase, setDisplayColorBase, setCenterCropBase, ]); const setTextTemperature = useCallback((value: number) => { if (value <= 1.2 && value >= 0.2) { setTextTemperatureBase(value); } }, []); const setTextTopk = useCallback((value: number) => { if (value <= 500 && value >= 10) { setTextTopkBase(value); } }, []); const setAudioTemperature = useCallback((value: number) => { if (value <= 1.2 && value >= 0.2) { setAudioTemperatureBase(value); } }, []); const setAudioTopk = useCallback((value: number) => { if (value <= 500 && value >= 10) { setAudioTopkBase(value); } }, []); const setPadMult = useCallback((value: number) => { if (value <= 4 && value >= -4) { setPadMultBase(value); } }, []); const setRepetitionPenalty = useCallback((value: number) => { if (value <= 2.0 && value >= 1.0) { setRepetitionPenaltyBase(value); } }, []); const setRepetitionPenaltyContext = useCallback((value: number) => { if (value <= 200 && value >= 0) { setRepetitionPenaltyContextBase(value); } }, []); const setImageResolution = useCallback((value: number) => { if (value <= 512 && value >= 160) { setImageResolutionBase(value); } }, []); const setImageUrl = 
useCallback((value: string | undefined) => { setImageUrlBase(value); }, []); const setGateDelay = useCallback((value: number) => { if (value <= 32 && value >= 0) { setGateDelayBase(value); } }, []); const setGateInfluence = useCallback((value: number) => { if (value <= 1.0 && value >= 0.0) { setGateInfluenceBase(value); } }, []); const setDisplayColor = useCallback((value: boolean) => { setDisplayColorBase(value); }, []); const setCenterCrop = useCallback((value: boolean) => { setCenterCropBase(value); }, []); return { textTemperature, textTopk, audioTemperature, audioTopk, padMult, repetitionPenalty, repetitionPenaltyContext, imageResolution, imageUrl, gateDelay, gateInfluence, displayColor, centerCrop, setTextTemperature, setTextTopk, setAudioTemperature, setAudioTopk, setPadMult, setRepetitionPenalty, setRepetitionPenaltyContext, setImageUrl, setImageResolution, setGateDelay, setGateInfluence, setDisplayColor, setCenterCrop, resetParams, } } ================================================ FILE: client/src/pages/Conversation/hooks/useServerAudio.ts ================================================ import { useCallback, useEffect, useRef, useState } from "react"; import { useSocketContext } from "../SocketContext"; import { decodeMessage } from "../../../protocol/encoder"; import { useMediaContext } from "../MediaContext"; import { DecoderWorker } from "../../../decoder/decoderWorker"; export type AudioStats = { playedAudioDuration: number; missedAudioDuration: number; totalAudioMessages: number; delay: number; minPlaybackDelay: number; maxPlaybackDelay: number; }; type useServerAudioArgs = { setGetAudioStats?: (getAudioStats: () => AudioStats) => void; }; type WorkletStats = { totalAudioPlayed: number; actualAudioPlayed: number; delay: number; minDelay: number; maxDelay: number; }; export const useServerAudio = ({setGetAudioStats}: useServerAudioArgs) => { const { socket } = useSocketContext(); const {startRecording, stopRecording, audioContext, worklet, micDuration, actualAudioPlayed } = useMediaContext(); const analyser = useRef(audioContext.current.createAnalyser()); worklet.current.connect(analyser.current); const startTime = useRef(null); const decoderWorker = useRef(DecoderWorker); const [hasCriticalDelay, setHasCriticalDelay] = useState(false); const totalAudioMessages = useRef(0); const receivedDuration = useRef(0); const workletStats = useRef({ totalAudioPlayed: 0, actualAudioPlayed: 0, delay: 0, minDelay: 0, maxDelay: 0,}); const onDecode = useCallback( async (data: Float32Array) => { receivedDuration.current += data.length / audioContext.current.sampleRate; worklet.current.port.postMessage({frame: data, type: "audio", micDuration: micDuration.current}); }, [], ); const onWorkletMessage = useCallback( (event: MessageEvent) => { workletStats.current = event.data; actualAudioPlayed.current = workletStats.current.actualAudioPlayed; }, [], ); worklet.current.port.onmessage = onWorkletMessage; const getAudioStats = useCallback(() => { return { playedAudioDuration: workletStats.current.actualAudioPlayed, delay: workletStats.current.delay, minPlaybackDelay: workletStats.current.minDelay, maxPlaybackDelay: workletStats.current.maxDelay, missedAudioDuration: workletStats.current.totalAudioPlayed - workletStats.current.actualAudioPlayed, totalAudioMessages: totalAudioMessages.current, }; }, []); const onWorkerMessage = useCallback( (e: MessageEvent) => { if (!e.data) { return; } onDecode(e.data[0]); }, [onDecode], ); let midx = 0; const decodeAudio = useCallback((data: Uint8Array) => { 
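// Hand the encoded audio pages from the websocket to the decoder worker (transferring the
// underlying buffer); the first few packets are logged with the current mic-vs-playback lag for debugging.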
if (midx < 5) { console.log(Date.now() % 1000, "Got NETWORK message", micDuration.current - workletStats.current.actualAudioPlayed, midx++); } decoderWorker.current.postMessage( { command: "decode", pages: data, }, [data.buffer], ); }, []); const onSocketMessage = useCallback( (e: MessageEvent) => { const dataArray = new Uint8Array(e.data); const message = decodeMessage(dataArray); if (message.type === "audio") { decodeAudio(message.data); //For stats purposes for now totalAudioMessages.current++; } }, [decodeAudio], ); useEffect(() => { const currentSocket = socket; if (!currentSocket) { return; } worklet.current.port.postMessage({type: "reset"}); console.log(Date.now() % 1000, "Should start in a bit"); startRecording(); currentSocket.addEventListener("message", onSocketMessage); totalAudioMessages.current = 0; return () => { console.log("Stop recording called in unknown function.") stopRecording(); startTime.current = null; currentSocket.removeEventListener("message", onSocketMessage); }; }, [socket]); useEffect(() => { if (setGetAudioStats) { console.log("Setting getAudioStats"); setGetAudioStats(getAudioStats); } }, [setGetAudioStats, getAudioStats]); useEffect(() => { decoderWorker.current.onmessage = onWorkerMessage; // 960 = 24000 / 12.5 / 2 // The /2 is a bit optional, but won't hurt for recording the mic, and for the // the decoding it might help getting some decoded audio out asap. decoderWorker.current.postMessage({ command: "init", bufferLength: 960 * audioContext.current.sampleRate / 24000, decoderSampleRate: 24000, outputBufferSampleRate: audioContext.current.sampleRate, resampleQuality: 0, }); return () => { console.log("Terminating worker"); }; }, [onWorkerMessage]); return { decodeAudio, analyser, getAudioStats, hasCriticalDelay, setHasCriticalDelay, }; }; ================================================ FILE: client/src/pages/Conversation/hooks/useServerInfo.ts ================================================ import { useCallback, useEffect, useState } from "react"; import { useSocketContext } from "../SocketContext"; import { decodeMessage } from "../../../protocol/encoder"; import { z } from "zod"; const ServersInfoSchema = z.object({ text_temperature: z.number(), text_topk: z.number(), text_temperature_gating_influence: z.number(), audio_temperature: z.number(), audio_topk: z.number(), pad_mult: z.number(), repetition_penalty_context: z.number(), repetition_penalty: z.number(), image_resolution: z.number(), lm_model_file: z.string(), instance_name: z.string(), base_filename: z.string(), build_info: z.object({ build_timestamp: z.string(), build_date: z.string(), git_branch: z.string(), git_timestamp: z.string(), git_date: z.string(), git_hash: z.string(), git_describe: z.string(), rustc_host_triple: z.string(), rustc_version: z.string(), cargo_target_triple: z.string(), }), }); const parseInfo = (infos: any) => { const serverInfo = ServersInfoSchema.safeParse(infos); if (!serverInfo.success) { console.error(serverInfo.error); return null; } return serverInfo.data; }; type ServerInfo = { text_temperature: number; text_topk: number; text_temperature_gating_influence: number; audio_temperature: number; audio_topk: number; pad_mult: number; repetition_penalty_context: number; repetition_penalty: number; image_resolution: number; lm_model_file: string; instance_name: string; base_filename: string; build_info: { build_timestamp: string; build_date: string; git_branch: string; git_timestamp: string; git_date: string; git_hash: string; git_describe: string; rustc_host_triple: 
string; rustc_version: string; cargo_target_triple: string; }; } export const useServerInfo = () => { const [serverInfo, setServerInfo] = useState(null); const { socket } = useSocketContext(); const onSocketMessage = useCallback((e: MessageEvent) => { const dataArray = new Uint8Array(e.data); const message = decodeMessage(dataArray); if (message.type === "metadata") { const infos = parseInfo(message.data); if (infos) { setServerInfo(infos); console.log("received metadata", infos); } } }, [setServerInfo]); useEffect(() => { const currentSocket = socket; if (!currentSocket) { return; } setServerInfo(null); currentSocket.addEventListener("message", onSocketMessage); return () => { currentSocket.removeEventListener("message", onSocketMessage); }; }, [socket]); return { serverInfo }; }; ================================================ FILE: client/src/pages/Conversation/hooks/useServerText.ts ================================================ import { useCallback, useEffect, useState } from "react"; import { useSocketContext } from "../SocketContext"; import { decodeMessage } from "../../../protocol/encoder"; export const useServerText = () => { const [text, setText] = useState([]); const [textColor, setTextColor] = useState([]); const [totalTextMessages, setTotalTextMessages] = useState(0); const { socket } = useSocketContext(); const onSocketMessage = useCallback((e: MessageEvent) => { const dataArray = new Uint8Array(e.data); const message = decodeMessage(dataArray); if (message.type === "text") { setText(text => [...text, message.data]); setTotalTextMessages(count => count + 1); } else if (message.type === "coloredtext") { setText(text => [...text, message.data]); setTextColor(textColor => [...textColor, message.color]); setTotalTextMessages(count => count + 1); } }, []); useEffect(() => { const currentSocket = socket; if (!currentSocket) { return; } setText([]); currentSocket.addEventListener("message", onSocketMessage); return () => { currentSocket.removeEventListener("message", onSocketMessage); }; }, [socket]); return { text, textColor, totalTextMessages }; }; ================================================ FILE: client/src/pages/Conversation/hooks/useSocket.ts ================================================ import { useState, useEffect, useCallback, useRef } from "react"; import { WSMessage } from "../../../protocol/types"; import { decodeMessage, encodeMessage } from "../../../protocol/encoder"; export const useSocket = ({ onMessage, uri, onDisconnect: onDisconnectProp, imageUrl, }: { onMessage?: (message: WSMessage) => void; uri: string; onDisconnect?: () => void; imageUrl?: string; }) => { const lastMessageTime = useRef(null); const [isConnected, setIsConnected] = useState(false); const [imageSent, setImageSent] = useState(false); const [onConnectDone, setOnConnectDone] = useState(false); const [socket, setSocket] = useState(null); const sendMessage = useCallback( (message: WSMessage) => { if (!socket) { console.log("socket not present"); return false; } // audio message with no connection if (message.type == "audio" && !isConnected) { console.log("isConnected false on audio message, please wait for handshake"); return false; } // otherwise send message socket.send(encodeMessage(message)); return true; }, [isConnected, socket], ); useEffect(() => { async function sendImage() { console.log("image send", imageSent); console.log("image url", imageUrl); if (imageUrl && !imageSent) { const imageBytes = await fetchImageBytes(imageUrl); const sent = sendMessage({ type: "image", data: 
imageBytes, }); if (sent) { console.log("Image sent"); setImageSent(true); } } } sendImage(); }, [socket, onConnectDone, imageUrl, imageSent]); const onConnect = useCallback(() => { console.log("connected, now waiting for handshake."); setOnConnectDone(true); }, [setIsConnected, socket]); const onDisconnect = useCallback(() => { console.log("disconnected"); if (onDisconnectProp) { onDisconnectProp(); } setIsConnected(false); }, [onDisconnectProp]); const onMessageEvent = useCallback( (eventData: MessageEvent) => { lastMessageTime.current = Date.now(); const dataArray = new Uint8Array(eventData.data); const message = decodeMessage(dataArray); if (message.type == "handshake") { console.log("Handshake received, let's rocknroll."); setIsConnected(true); } if (!onMessage) { return; } onMessage(message); }, [onMessage, setIsConnected], ); const start = useCallback(() => { const ws = new WebSocket(uri); ws.binaryType = "arraybuffer"; ws.addEventListener("open", onConnect); ws.addEventListener("close", onDisconnect); ws.addEventListener("message", onMessageEvent); setSocket(ws); console.log("Socket created", ws); lastMessageTime.current = Date.now(); }, [uri, onMessage, onDisconnectProp]); const stop = useCallback(() => { setIsConnected(false); if (onDisconnectProp) { onDisconnectProp(); } socket?.close(); setSocket(null); }, [socket]); useEffect(() => { if (!isConnected) { return; } let intervalId = setInterval(() => { if (lastMessageTime.current && Date.now() - lastMessageTime.current > 10000) { console.log("closing socket due to inactivity", socket); socket?.close(); onDisconnect(); clearInterval(intervalId); } }, 500); return () => { lastMessageTime.current = null; clearInterval(intervalId); }; }, [isConnected, socket]); return { isConnected, socket, sendMessage, start, stop, }; }; async function fetchImageBytes(imageUrl: string) { const response = await fetch(imageUrl); if (!response.ok) { throw new Error(`Failed to fetch image: ${response.statusText}`); } const arrayBuffer = await response.arrayBuffer(); return new Uint8Array(arrayBuffer); } ================================================ FILE: client/src/pages/Conversation/hooks/useUserAudio.ts ================================================ import { useCallback, useRef, useState } from "react"; import Recorder from "opus-recorder"; import encoderPath from "opus-recorder/dist/encoderWorker.min.js?url"; import { useMediaContext } from "../MediaContext"; export enum UserMediaStatuses { IDLE = "IDLE", READY = "READY", WAITING_FOR_PERMISSION = "WAITING_FOR_PERMISSION", ERROR = "ERROR", RECORDING = "RECORDING", STOPPED = "STOPPED", STOPPING = "STOPPING", } type useUserAudioArgs = { constraints: MediaStreamConstraints; onDataChunk?: (chunk: Uint8Array) => void; onRecordingStart?: () => void; onRecordingStop?: () => void; }; export const useUserAudio = ({ constraints, onDataChunk, onRecordingStart = () => {}, onRecordingStop = () => {}, }: useUserAudioArgs) => { const { audioStreamDestination, audioContext, micDuration } = useMediaContext(); const [error, setError] = useState(null); const [status, setStatus] = useState( UserMediaStatuses.IDLE, ); //TODO: Fix any type for recorder const recorder = useRef(null); const getMediaStream = useCallback(async () => { setStatus(UserMediaStatuses.WAITING_FOR_PERMISSION); try { const stream = await window.navigator.mediaDevices.getUserMedia(constraints); setStatus(UserMediaStatuses.IDLE); return stream; } catch (error: any) { console.error(error); setError(error.name); setStatus(UserMediaStatuses.ERROR); 
return null; } }, [constraints, setStatus]); const startRecordingUser = useCallback(async () => { console.log(Date.now() % 1000, "Starting recording in user audio"); const mediaStream = await getMediaStream(); if (mediaStream) { const analyser = audioContext.current.createAnalyser(); const source = audioContext.current.createMediaStreamSource(mediaStream); source.connect(analyser); source.connect(audioStreamDestination.current); // For buffer length: 960 = 24000 / 12.5 / 2 // The /2 is a bit optional, but won't hurt for recording the mic. // Note that bufferLength actually has 0 impact for mono audio, only // the frameSize and maxFramesPerPage seems to have any. const recorderOptions = { mediaTrackConstraints: constraints, encoderPath, bufferLength: Math.round(960 * audioContext.current.sampleRate / 24000), encoderFrameSize: 20, encoderSampleRate: 24000, maxFramesPerPage: 2, numberOfChannels: 1, recordingGain: 1, resampleQuality: 3, encoderComplexity: 0, encoderApplication: 2049, streamPages: true, }; let chunk_idx = 0; let lastpos = 0; recorder.current = new Recorder(recorderOptions); recorder.current.ondataavailable = (data: Uint8Array) => { // opus actually always works at 48khz, so it seems this is the proper value to use here. micDuration.current = recorder.current.encodedSamplePosition / 48000; if (chunk_idx < 5) { console.log(Date.now() % 1000, "Mic Data chunk", chunk_idx++, (recorder.current.encodedSamplePosition - lastpos) / 48000, micDuration.current); lastpos = recorder.current.encodedSamplePosition; } if (onDataChunk) { onDataChunk(data); } }; recorder.current.onstart = () => { setStatus(UserMediaStatuses.RECORDING); onRecordingStart(); }; recorder.current.onstop = () => { setStatus(UserMediaStatuses.STOPPED); source.disconnect(); onRecordingStop(); recorder.current = null; }; if (recorder.current) { // setTimeout(() => {recorder.current.start(); setStatus(UserMediaStatuses.RECORDING);}, 1500); recorder.current.start(); } return { analyser, mediaStream, source, }; } return { analyser: null, mediaStream: null, source: null, }; }, [setStatus, onDataChunk, onRecordingStart, onRecordingStop]); const stopRecording = useCallback(() => { setStatus(UserMediaStatuses.STOPPING); if (recorder.current) { recorder.current.stop(); } }, [setStatus]); return { status, error, startRecordingUser, stopRecording, }; }; ================================================ FILE: client/src/pages/Queue/Queue.tsx ================================================ import moshiProcessorUrl from "../../audio-processor.ts?worker&url"; import { FC, useEffect, useMemo, useState, useCallback, useRef, MutableRefObject } from "react"; import eruda from "eruda"; import { useSearchParams } from "react-router-dom"; import { Conversation } from "../Conversation/Conversation"; import { Button } from "../../components/Button/Button"; import { ImageGallery } from "../../components/ImageGallery/ImageGallery"; import { useModelParams, importantSettingsHaveChanged } from "../Conversation/hooks/useModelParams"; import { ModelParams } from "../Conversation/components/ModelParams/ModelParams"; import { env } from "../../env"; import { useUserEmail } from "./hooks/useUserEmail"; import { Input } from "../../components/Input/Input"; import { getAPIClient } from "./api/client"; type Status = "connecting" | "in_queue" | "has_credentials" | "error" | "no_queue" | "idle" | "bypass"; function getFloatFromStorage(val: string | null) { return (val == null) ? 
undefined : parseFloat(val) } function getIntFromStorage(val: string | null) { return (val == null) ? undefined : parseInt(val) } function getBoolFromStage(val: string | null) { return (val == 'true') ? true : ((val == 'false') ? false : undefined) } export const Queue: FC = () => { const [searchParams] = useSearchParams(); let queueId = searchParams.get("queue_id"); if (!queueId) { queueId = 'talktomoshi'; } const [sessionId, setSessionId] = useState(null); const [sessionAuthId, setSessionAuthId] = useState(null); const [workerAddr, setWorkerAddr] = useState(null); const [workerAuthId, setWorkerAuthId] = useState(null); const [currentPosition, setCurrentPosition] = useState(null); const [error, setError] = useState(null); const overrideWorkerAddr = searchParams.get("worker_addr"); const [hasMicrophoneAccess, setHasMicrophoneAccess] = useState(false); const [showMicrophoneAccessMessage, setShowMicrophoneAccessMessage] = useState(false); const [shouldConnect, setShouldConnect] = useState(false); let default_image_url = sessionStorage.getItem("imageUrl"); const modelParams = useModelParams({ textTemperature: getFloatFromStorage(sessionStorage.getItem("textTemperature")), textTopk: getIntFromStorage(sessionStorage.getItem("textTopk")), audioTemperature: getFloatFromStorage(sessionStorage.getItem("audioTemperature")), audioTopk: getIntFromStorage(sessionStorage.getItem("audioTopk")), padMult: getFloatFromStorage(sessionStorage.getItem("padMult")), repetitionPenalty: getFloatFromStorage(sessionStorage.getItem("repetitionPenalty")), repetitionPenaltyContext: getIntFromStorage(sessionStorage.getItem("repetitionPenaltyContext")), imageResolution: getIntFromStorage(sessionStorage.getItem("imageResolution")), gateDelay: getIntFromStorage(sessionStorage.getItem("gateDelay")), gateInfluence: getFloatFromStorage(sessionStorage.getItem("gateInfluence")), displayColor: getBoolFromStage(sessionStorage.getItem("displayColor")), centerCrop: getBoolFromStage(sessionStorage.getItem("centerCrop")), imageUrl: (default_image_url == null) ? undefined : default_image_url }); const modalRef = useRef(null); let def_user_email = sessionStorage.getItem("userEmail"); const { userEmail, setUserEmail, error: emailError, validate } = useUserEmail(!!overrideWorkerAddr, (def_user_email == null) ? 
'' : def_user_email); const audioContext = useRef(null); const worklet = useRef(null); // enable eruda in development useEffect(() => { if (env.VITE_ENV === "development") { eruda.init(); } () => { if (env.VITE_ENV === "development") { eruda.destroy(); } }; }, []); const getMicrophoneAccess = useCallback(async () => { try { await window.navigator.mediaDevices.getUserMedia({ audio: true }); setHasMicrophoneAccess(true); return true; } catch (e) { console.error(e); setShowMicrophoneAccessMessage(true); setHasMicrophoneAccess(false); } return false; }, [setHasMicrophoneAccess, setShowMicrophoneAccessMessage, setShouldConnect]); const startProcessor = useCallback(async () => { if (!audioContext.current) { audioContext.current = new AudioContext(); } if (worklet.current) { return; } let ctx = audioContext.current; ctx.resume(); try { worklet.current = new AudioWorkletNode(ctx, 'moshi-processor'); } catch (err) { await ctx.audioWorklet.addModule(moshiProcessorUrl); worklet.current = new AudioWorkletNode(ctx, 'moshi-processor'); } worklet.current.connect(ctx.destination); }, [audioContext, worklet]); const onConnect = useCallback(async () => { if (!validate(userEmail)) { return; } await startProcessor(); const hasAccess = await getMicrophoneAccess(); if (hasAccess) { setShouldConnect(true); } }, [setShouldConnect, startProcessor, userEmail, getMicrophoneAccess, validate]); const status: Status = useMemo(() => { if (overrideWorkerAddr) { return "bypass"; } if (!queueId) { return "no_queue"; } if (error) { return "error"; } if (!shouldConnect) { return "idle"; } if (workerAddr && workerAuthId) { return "has_credentials"; } if (!sessionId || !sessionAuthId) { return "connecting"; } return "in_queue"; }, [queueId, sessionId, sessionAuthId, workerAddr, workerAuthId, currentPosition, hasMicrophoneAccess, error, shouldConnect]); const client = useMemo(() => { return getAPIClient(env.VITE_QUEUE_API_PATH) }, [env.VITE_QUEUE_API_PATH]); useEffect(() => { if (!shouldConnect) { return; } if (status !== "connecting" || !queueId) { return; } client.addUser(queueId) .then(({ session_id, session_auth_id }) => { setSessionId(session_id); setSessionAuthId(session_auth_id); console.log("Added user to queue", session_id, session_auth_id); }) .catch((e) => { setError(e.message); console.error(e); }); }, [queueId, client, status, shouldConnect]); useEffect(() => { if (!sessionId || !sessionAuthId) { return; } if (status === "has_credentials") { return; } let isQuerying = false; let intervalId: number | null = null; const checkUser = () => { if (isQuerying) { return; } isQuerying = true; client.checkUser(sessionId, sessionAuthId) .then(({ worker_addr, worker_auth_id, current_position }) => { setCurrentPosition(current_position); if (worker_addr && worker_auth_id) { setWorkerAddr(worker_addr); setWorkerAuthId(worker_auth_id); if (intervalId !== null) { clearInterval(intervalId); } } }) .catch((e) => { if (intervalId !== null) { clearInterval(intervalId); } setError(e.message); console.error(e); }).finally(() => { isQuerying = false; }); } intervalId = setInterval(checkUser, 400); return () => { if (intervalId !== null) { clearInterval(intervalId); } }; }, [sessionId, sessionAuthId, client, setCurrentPosition, setWorkerAddr, setWorkerAuthId, status, setError]); if (status === "bypass" && hasMicrophoneAccess && audioContext.current && worklet.current) { return ( } worklet={worklet as MutableRefObject} {...modelParams} /> ); } if (status === "has_credentials" && workerAddr && audioContext.current && workerAuthId && 
sessionId && sessionAuthId && worklet?.current) { return ( } worklet={worklet as MutableRefObject} sessionId={sessionId} sessionAuthId={sessionAuthId} onConversationEnd={() => { setWorkerAddr(null); setWorkerAuthId(null); setSessionId(null); setSessionAuthId(null); setShouldConnect(false); }} {...modelParams} /> ); } return (

M👁️shiVis

{/* To add more space to the top, add padding to the top of the following div by changing the pt-4 class to pt-8 or pt-12 (see: https://tailwindcss.com/docs/padding) 👁️ If you'd like to move this part to the bottom of the screen, change the class to pb-4 or pb-8 and move the following div so it is contained by the last one in the page. Font size can be changed by changing the text-sm class to text-lg or text-xl (see: https://tailwindcss.com/docs/font-size) As for the links, you can use the one below as an example and add more by copying it and changing the href and text. */}
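{/* For example (an illustrative, hypothetical snippet, not part of the original markup): applying the tweaks described above to a wrapper div would mean changing className="pt-4 text-sm" to className="pt-8 pb-4 text-lg". */}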

MoshiVis is an experimental multimodal conversational AI. Like Moshi, MoshiVis can listen to you and talk at all times for maximum conversational flow. Now augmented with visual inputs.

For instance, you can now ask Moshi to describe your favorite movie poster, grill it on details about the plot, then go back for more details about the image, or ask it to do some pirate role play.

We strive to support all browsers but Chrome works best. Conversations are limited to 5 min.

Head to the Settings to configure the image size and other parameters.

For more information about this project, check out the MoshiVis project page!

Baked with <3 @Kyutai.

Add your email address first, then feel free

to upload your own image or select one below.

Uploaded images should be smaller than 15 MB.

{status == 'error' &&

{error}

} {status == 'no_queue' &&

No queue id provided

} {(status === 'idle' || status === 'bypass') && ( <> {showMicrophoneAccessMessage &&

Please enable your microphone before proceeding

} setUserEmail(e.target.value)} error={emailError ?? ""} onKeyDown={(e) => { if (e.key === "Enter") { if (modelParams.imageUrl == undefined) { modelParams.setImageUrl("/assets/images/demo/image" + Math.floor(1 + Math.random() * 19) + ".jpg") } onConnect(); } }} />
)} {status === "connecting" &&

Connecting to queue...

} {status === "in_queue" && (

You're in the queue!
{currentPosition && Current position: {currentPosition}}

) }
) }; ================================================ FILE: client/src/pages/Queue/api/client.ts ================================================ import { APIError } from "./errors/api_error"; import { ResponseError } from "./errors/response_error"; import { validateAddUser, validateCheckUser } from "./validators"; export const getAPIClient = (url:string) => ({ addUser: async (queueId:string) => { const encodedQueueId = encodeURIComponent(queueId); const response = await fetch(`${url}/add_user?queue_id=${encodedQueueId}`); if (!response.ok) { const errorText = await response.text(); throw new APIError(errorText , response.status); } const json = await response.json(); const result = validateAddUser(json); if(result.success) { return result.data; } console.error(result.error.message); throw new ResponseError("Failed to validate response"); }, checkUser: async (sessionId:number, sessionAuthId:string) => { const encodedSessionAuthId = encodeURIComponent(sessionAuthId); const encodedSessionId = encodeURIComponent(sessionId); const response = await fetch(`${url}/check_user?session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}`); if (!response.ok) { const errorText = await response.text(); throw new APIError(errorText , response.status); } const json = await response.json(); const result = validateCheckUser(json); if(result.success) { return result.data; } console.error(result.error.message); throw new ResponseError("Failed to validate response"); }, addFeedback: async ({ workerAuthId, sessionId, sessionAuthId, feedback, timestamp, email }:{ workerAuthId:string; sessionId:number; sessionAuthId:string; feedback:0|1; timestamp:number; email:string; } ) => { const encodedWorkerAuthId = encodeURIComponent(workerAuthId); const encodedSessionAuthId = encodeURIComponent(sessionAuthId); const encodedSessionId = encodeURIComponent(sessionId); const encodedFeedback = encodeURIComponent(feedback); const encodedTimestamp = encodeURIComponent(timestamp); const encodedEmail = encodeURIComponent(email); const response = await fetch(`${url}/user_feedback?worker_auth_id=${encodedWorkerAuthId}&session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}&feedback=${encodedFeedback}&timestamp=${encodedTimestamp}&email=${encodedEmail}`); if (!response.ok) { const errorText = await response.text(); throw new APIError(errorText , response.status); } return response.json(); } }); ================================================ FILE: client/src/pages/Queue/api/errors/api_error.ts ================================================ export class APIError extends Error { status:number; constructor(message:string, status:number) { super(message); this.status = status; this.name = "APIError"; } } ================================================ FILE: client/src/pages/Queue/api/errors/response_error.ts ================================================ export class ResponseError extends Error { constructor(message:string) { super(message); this.name = "ResponseError"; } } ================================================ FILE: client/src/pages/Queue/api/validators.ts ================================================ import { z } from "zod" export const validateAddUser = (response: unknown) => { const AddUser = z.object({ session_id: z.number(), session_auth_id: z.string(), }); return AddUser.safeParse(response); }; export const validateCheckUser = (response: unknown) => { const CheckUser = z.object({ session_id: z.number(), // TODO: add more statuses status: z.enum(['wait', 'ready']), worker_auth_id:
z.string().nullable(), worker_addr: z.string().nullable(), current_position: z.string(), }); return CheckUser.safeParse(response); } ================================================ FILE: client/src/pages/Queue/hooks/useUserEmail.ts ================================================ import { useCallback, useState } from "react"; import { z } from "zod"; const validateEmail = z.string().email(); export const useUserEmail = (isBypass: boolean, init_value: string) => { const [userEmail, setUserEmail] = useState(init_value); const [error, setError] = useState(null); const validate = useCallback((email: string) => { if (isBypass) { setError(null); return true; } const result = validateEmail.safeParse(email); if (result.success) { setError(null); sessionStorage.setItem("userEmail", email.toString()); return true; } setError('Invalid email address'); return false; }, [setError]); return { userEmail, setUserEmail, error, validate }; } ================================================ FILE: client/src/protocol/encoder.ts ================================================ import { CONTROL_MESSAGE, CONTROL_MESSAGES_MAP, MODELS_MAP, WSMessage, VERSIONS_MAP, } from "./types"; export const encodeMessage = (message: WSMessage): Uint8Array => { switch (message.type) { case "handshake": return new Uint8Array([ 0x00, VERSIONS_MAP[message.version], MODELS_MAP[message.model], ]); case "audio": return new Uint8Array([0x01, ...message.data]); case "text": // Not used in practice return new Uint8Array([0x02, ...new TextEncoder().encode(message.data)]); case "control": // Not used in practice return new Uint8Array([0x03, CONTROL_MESSAGES_MAP[message.action]]); case "metadata": // Not used in practice return new Uint8Array([ 0x04, ...new TextEncoder().encode(JSON.stringify(message.data)), ]); case "error": // Not used in practice return new Uint8Array([0x05, ...new TextEncoder().encode(message.data)]); case "ping": // Not used in practice return new Uint8Array([0x06]); case "coloredtext": // Not used in practice return new Uint8Array([0x07, 0x05, ...new TextEncoder().encode(message.data)]); case "image": return new Uint8Array([0x08, ...message.data]); case "user_rating": return new Uint8Array([0x0A, message.data]); } }; export const decodeMessage = (data: Uint8Array): WSMessage => { const type = data[0]; const payload = data.slice(1); switch (type) { case 0x00: { return { type: "handshake", version: 0, model: 0, }; } case 0x01: return { type: "audio", data: payload, }; case 0x02: return { type: "text", data: new TextDecoder().decode(payload), }; case 0x03: { const action = Object.keys(CONTROL_MESSAGES_MAP).find( key => CONTROL_MESSAGES_MAP[key as CONTROL_MESSAGE] === payload[0], ) as CONTROL_MESSAGE | undefined; //TODO: log this and don't throw if (!action) { throw new Error("Unknown control message"); } return { type: "control", action, }; } case 0x04: return { type: "metadata", data: JSON.parse(new TextDecoder().decode(payload)), } case 0x05: return { type: "error", data: new TextDecoder().decode(payload), } case 0x06: return { type: "ping", } case 0x07: return { type: "coloredtext", color: payload[0], data: new TextDecoder().decode(payload.slice(1)), }; case 0x08: return { type: "image", data: payload, }; // never used in practice case 0x0A: return { type: "user_rating", data: payload[0], }; default: { console.log(type); throw new Error("Unknown message type"); } } }; ================================================ FILE: client/src/protocol/testMessages.ts ================================================ import { 
WSMessage } from "./types"; export const handshakeMessage: WSMessage = { type: "handshake", version: 0, model: 0, }; export const audioMessage: WSMessage = { type: "audio", data: new Uint8Array(10), }; export const textMessage: WSMessage = { type: "text", data: "Hello", }; export const controlBOSMessage: WSMessage = { type: "control", action: "start", }; export const controlEOSMessage: WSMessage = { type: "control", action: "endTurn", }; export const metadataMessage: WSMessage = { type: "metadata", data: { key: "value" }, }; ================================================ FILE: client/src/protocol/types.ts ================================================ export type MessageType = | "handshake" | "audio" | "text" | "coloredtext" | "control" | "metadata"; export const VERSIONS_MAP = { 0: 0b00000000, } as const; export const MODELS_MAP = { 0: 0b00000000, } as const; export type VERSION = keyof typeof VERSIONS_MAP; export type MODEL = keyof typeof MODELS_MAP; export type WSMessage = | { type: "handshake"; version: VERSION; model: MODEL; } | { type: "user_rating"; data: number; } | { type: "audio"; data: Uint8Array; } | { type: "text"; data: string; } | { type: "coloredtext"; color: number; data: string; } | { type: "control"; action: CONTROL_MESSAGE; } | { type: "metadata"; data: unknown; } | { type: "error"; data: string; } | { type: "ping"; } | { type: "image"; data: Uint8Array; } export const CONTROL_MESSAGES_MAP = { start: 0b00000000, endTurn: 0b00000001, pause: 0b00000010, restart: 0b00000011, } as const; export type CONTROL_MESSAGE = keyof typeof CONTROL_MESSAGES_MAP; ================================================ FILE: client/tailwind.config.js ================================================ /** @type {import('tailwindcss').Config} */ export default { content: ["./src/**/*.{js,jsx,ts,tsx}", "./index.html"], theme: { extend: {}, }, plugins: [require('daisyui')], }; ================================================ FILE: client/tsconfig.json ================================================ { "compilerOptions": { "target": "ES2020", "useDefineForClassFields": true, "module": "ESNext", "lib": [ "ES2020", "DOM", "DOM.Iterable" ], "skipLibCheck": true, "outDir": "dist", /* Bundler mode */ "moduleResolution": "bundler", "allowImportingTsExtensions": true, "resolveJsonModule": true, "isolatedModules": true, "noEmit": true, "jsx": "react-jsx", /* Linting */ "strict": true, "noUnusedLocals": true, "noUnusedParameters": true, "noFallthroughCasesInSwitch": true, "types": [ "vite/client" ] }, "include": [ "src" ] } ================================================ FILE: client/vite.config.ts ================================================ import { ProxyOptions, defineConfig, loadEnv } from "vite"; import topLevelAwait from "vite-plugin-top-level-await"; export default defineConfig(({ mode }) => { const env = loadEnv(mode, process.cwd()); const proxyConf: Record = env.VITE_QUEUE_API_URL ? 
{ "/api": { target: env.VITE_QUEUE_API_URL, changeOrigin: true, }, } : {}; return { server: { host: "0.0.0.0", https: { cert: "./cert.pem", key: "./key.pem", }, proxy: { ...proxyConf, } }, plugins: [ topLevelAwait({ // The export name of top-level await promise for each chunk module promiseExportName: "__tla", // The function to generate import names of top-level await promise in each chunk module promiseImportName: i => `__tla_${i}`, }), ], }; }); ================================================ FILE: docker-bake.hcl ================================================ group "default" { targets = ["client"] } target "client" { context = "./client" # Specify output type as a local directory output = [ "type=local,dest=./client/dist" ] } ================================================ FILE: kyuteye_mlx/.pylintrc ================================================ [MAIN] # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no # Clear in-memory caches upon conclusion of linting. Useful if running pylint # in a server-like mode. clear-cache-post-run=no # Load and enable all available extensions. Use --list-extensions to see a list # all available extensions. #enable-all-extensions= # In error mode, messages with a category besides ERROR or FATAL are # suppressed, and no reports are done by default. Error mode is compatible with # disabling specific errors. #errors-only= # Always return a 0 (non-error) status code, even if lint errors are found. # This is primarily useful in continuous integration scripts. #exit-zero= # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. extension-pkg-allow-list= # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. (This is an alternative name to extension-pkg-allow-list # for backward compatibility.) extension-pkg-whitelist= # Return non-zero exit code if any of these messages/categories are detected, # even if score is above --fail-under value. Syntax same as enable. Messages # specified are enabled, while categories only check already-enabled messages. fail-on= # Specify a score threshold under which the program will exit with error. fail-under=10 # Interpret the stdin as a python script, whose filename needs to be passed as # the module_or_package argument. #from-stdin= # Files or directories to be skipped. They should be base names, not paths. ignore=CVS # Add files or directories matching the regular expressions patterns to the # ignore-list. The regex matches against paths and can be in Posix or Windows # format. Because '\\' represents the directory delimiter on Windows systems, # it can't be used as an escape character. ignore-paths= # Files or directories matching the regular expression patterns are skipped. # The regex matches against base names, not paths. The default value ignores # Emacs file locks ignore-patterns=^\.# # List of module names for which member attributes should not be checked and # will not be imported (useful for modules/projects where namespaces are # manipulated during runtime and thus existing member attributes cannot be # deduced by static analysis). 
It supports qualified module names, as well as # Unix pattern matching. ignored-modules= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use, and will cap the count on Windows to # avoid hangs. jobs=1 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or # complex, nested conditions. limit-inference-results=100 # List of plugins (as comma separated values of python module names) to load, # usually to register additional checkers. load-plugins= # Pickle collected data for later comparisons. persistent=yes # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. py-version=3.10 # Discover python modules and packages in the file system subtree. recursive=no # Add paths to the list of the source roots. Supports globbing patterns. The # source root is an absolute path or a path relative to the current working # directory used to determine a package namespace for modules located under the # source root. source-roots= # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. suggestion-mode=yes # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # In verbose mode, extra non-checker-related info will be displayed. #verbose= [BASIC] # Naming style matching correct argument names. argument-naming-style=snake_case # Regular expression matching correct argument names. Overrides argument- # naming-style. If left empty, argument names will be checked with the set # naming style. #argument-rgx= # Naming style matching correct attribute names. attr-naming-style=snake_case # Regular expression matching correct attribute names. Overrides attr-naming- # style. If left empty, attribute names will be checked with the set naming # style. #attr-rgx= # Bad variable names which should always be refused, separated by a comma. bad-names=foo, bar, baz, toto, tutu, tata # Bad variable names regexes, separated by a comma. If names match any regex, # they will always be refused bad-names-rgxs= # Naming style matching correct class attribute names. class-attribute-naming-style=any # Regular expression matching correct class attribute names. Overrides class- # attribute-naming-style. If left empty, class attribute names will be checked # with the set naming style. #class-attribute-rgx= # Naming style matching correct class constant names. class-const-naming-style=UPPER_CASE # Regular expression matching correct class constant names. Overrides class- # const-naming-style. If left empty, class constant names will be checked with # the set naming style. #class-const-rgx= # Naming style matching correct class names. class-naming-style=PascalCase # Regular expression matching correct class names. Overrides class-naming- # style. If left empty, class names will be checked with the set naming style. #class-rgx= # Naming style matching correct constant names. const-naming-style=UPPER_CASE # Regular expression matching correct constant names. Overrides const-naming- # style. If left empty, constant names will be checked with the set naming # style. 
#const-rgx= # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=-1 # Naming style matching correct function names. function-naming-style=snake_case # Regular expression matching correct function names. Overrides function- # naming-style. If left empty, function names will be checked with the set # naming style. #function-rgx= # Good variable names which should always be accepted, separated by a comma. good-names=i, j, k, ex, Run, _ # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted good-names-rgxs= # Include a hint for the correct naming format with invalid-name. include-naming-hint=no # Naming style matching correct inline iteration names. inlinevar-naming-style=any # Regular expression matching correct inline iteration names. Overrides # inlinevar-naming-style. If left empty, inline iteration names will be checked # with the set naming style. #inlinevar-rgx= # Naming style matching correct method names. method-naming-style=snake_case # Regular expression matching correct method names. Overrides method-naming- # style. If left empty, method names will be checked with the set naming style. #method-rgx= # Naming style matching correct module names. module-naming-style=snake_case # Regular expression matching correct module names. Overrides module-naming- # style. If left empty, module names will be checked with the set naming style. #module-rgx= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. # These decorators are taken in consideration only for invalid-name. property-classes=abc.abstractproperty # Regular expression matching correct type alias names. If left empty, type # alias names will be checked with the set naming style. #typealias-rgx= # Regular expression matching correct type variable names. If left empty, type # variable names will be checked with the set naming style. #typevar-rgx= # Naming style matching correct variable names. variable-naming-style=snake_case # Regular expression matching correct variable names. Overrides variable- # naming-style. If left empty, variable names will be checked with the set # naming style. #variable-rgx= [CLASSES] # Warn about protected attribute access inside special methods check-protected-access-in-special-methods=no # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp, asyncSetUp, __post_init__ # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs [DESIGN] # List of regular expressions of class ancestor names to ignore when counting # public methods (see R0903) exclude-too-few-public-methods= # List of qualified class names to ignore when counting class parents (see # R0901) ignored-parents= # Maximum number of arguments for function / method. 
max-args=5 # Maximum number of attributes for a class (see R0902). max-attributes=7 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 # Maximum number of branch for function / method body. max-branches=12 # Maximum number of locals for function / method body. max-locals=15 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of return / yield for function / method body. max-returns=6 # Maximum number of statements in function / method body. max-statements=50 # Minimum number of public methods for a class (see R0903). min-public-methods=2 [EXCEPTIONS] # Exceptions that will emit a warning when caught. overgeneral-exceptions=builtins.BaseException,builtins.Exception [FORMAT] # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )??$ # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Maximum number of characters on a single line. max-line-length=100 # Maximum number of lines in a module. max-module-lines=1200 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no [IMPORTS] # List of modules that can be imported at any level, not just the top level # one. allow-any-import-level= # Allow explicit reexports by alias from a package __init__. allow-reexport-from-package=no # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=no # Deprecated modules which should not be used, separated by a comma. deprecated-modules= # Output a graph (.gv or any supported image format) of external dependencies # to the given file (report RP0402 must not be disabled). ext-import-graph= # Output a graph (.gv or any supported image format) of all (i.e. internal and # external) dependencies to the given file (report RP0402 must not be # disabled). import-graph= # Output a graph (.gv or any supported image format) of internal dependencies # to the given file (report RP0402 must not be disabled). int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant # Couples of modules and preferred modules, separated by a comma. preferred-modules= [LOGGING] # The type of string formatting that logging methods do. `old` means using % # formatting, `new` is for `{}` formatting. logging-format-style=old # Logging modules to check that the string format arguments are in logging # function parameter format. logging-modules=logging [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, # UNDEFINED. confidence=HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED # Disable the message, report, category or checker with the given id(s). 
You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once). You can also use "--disable=all" to # disable everything first and then re-enable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". disable=raw-checker-failed, bad-inline-option, locally-disabled, file-ignored, suppressed-message, useless-suppression, deprecated-pragma, use-symbolic-message-instead, use-implicit-booleaness-not-comparison-to-string, use-implicit-booleaness-not-comparison-to-zero, too-many-locals, unspecified-encoding, too-many-arguments, too-many-instance-attributes, too-many-branches, too-many-statements, too-many-return-statements, too-many-public-methods, too-few-public-methods, use-dict-literal, unnecessary-lambda-assignment, too-many-function-args # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. enable= [METHOD_ARGS] # List of qualified names (i.e., library.method) which require a timeout # parameter e.g. 'requests.api.get,requests.api.post' timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME, XXX, TODO # Regular expression of note tags to take in consideration. notes-rgx= [REFACTORING] # Maximum number of nested blocks for function / method body max-nested-blocks=5 # Complete name of functions that never returns. When checking for # inconsistent-return-statements if a never returning function is called then # it will be considered as an explicit return statement and no message will be # printed. never-returning-functions=sys.exit,argparse.parse_error [REPORTS] # Python expression which should return a score less than or equal to 10. You # have access to the variables 'fatal', 'error', 'warning', 'refactor', # 'convention', and 'info' which contain the number of messages in each # category, as well as 'statement' which is the total number of statements # analyzed. This score is used by the global evaluation report (RP0004). evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details. msg-template= # Set the output format. Available formats are: text, parseable, colorized, # json2 (improved json format), json (old json format) and msvs (visual # studio). You can also give a reporter class, e.g. # mypackage.mymodule.MyReporterClass. #output-format= # Tells whether to display a full report or only the messages. reports=no # Activate the evaluation score. 
score=yes [SIMILARITIES] # Comments are removed from the similarity computation ignore-comments=yes # Docstrings are removed from the similarity computation ignore-docstrings=yes # Imports are removed from the similarity computation ignore-imports=yes # Signatures are removed from the similarity computation ignore-signatures=yes # Minimum lines number of a similarity. min-similarity-lines=12 [SPELLING] # Limits count of emitted suggestions for spelling mistakes. max-spelling-suggestions=4 # Spelling dictionary name. No available dictionaries : You need to install # both the python package and the system dependency for enchant to work. spelling-dict= # List of comma separated words that should be considered directives if they # appear at the beginning of a comment and should not be checked. spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains the private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to the private dictionary (see the # --spelling-private-dict-file option) instead of raising a message. spelling-store-unknown-words=no [STRING] # This flag controls whether inconsistent-quotes generates a warning when the # character used as a quote delimiter is used inconsistently within a module. check-quote-consistency=no # This flag controls whether the implicit-str-concat should generate a warning # on implicit string concatenation in sequences defined over several lines. check-str-concat-over-line-jumps=no [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # Tells whether to warn about missing members when the owner of the attribute # is inferred to be None. ignore-none=yes # This flag controls whether pylint should warn about no-member and similar # checks whenever an opaque object is returned when inferring. The inference # can return multiple potential results while evaluating a Python object, but # some branches might not be evaluated, which results in partial inference. In # that case, it might be useful to still emit no-member and other checks for # the rest of the inferred objects. ignore-on-opaque-inference=yes # List of symbolic message names to ignore for Mixin members. ignored-checks-for-mixins=no-member, not-async-context-manager, not-context-manager, attribute-defined-outside-init # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. missing-member-hint=yes # The minimum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 # The total number of similar names that should be taken in consideration when # showing a hint for a missing member. 
missing-member-max-choices=1 # Regex pattern to define which classes are considered mixins. mixin-class-rgx=.*[Mm]ixin # List of decorators that change the signature of a decorated function. signature-mutators= [VARIABLES] # List of additional names supposed to be defined in builtins. Remember that # you should avoid defining new builtins when possible. additional-builtins= # Tells whether unused global variables should be treated as a violation. allow-global-unused-variables=yes # List of names allowed to shadow builtins allowed-redefined-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_, _cb # A regular expression matching the name of dummy variables (i.e. expected to # not be used). dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. ignored-argument-names=_.*|^ignored_|^unused_ # Tells whether we should check for unused import in __init__ files. init-import=no # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io ================================================ FILE: kyuteye_mlx/LICENSE ================================================ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: kyuteye_mlx/MANIFEST.in ================================================ include LICENSE* include *.md include *.cfg include requirements.txt include moshi_mlx/py.typed ================================================ FILE: kyuteye_mlx/README.md ================================================ # MoshiVis - MLX See the [top-level README.md][main_repo] for more information on MoshiVis. This is the MLX implementation for MoshiVis. ## Usage We have tested the MLX version with MacBook Air M3 (4-bit quantization) and a MacMini M4 Pro (both 4- and 8-bit quantization). You can start the server with: ```bash # In Bfloat16 - weights will be downloaded from HF uv run server # In Q4 uv run server -q 4 # In Q8 uv run server -q 8 ``` This will start the web UI which you can connect to via http, at [localhost:8008](http://localhost:8008). Note that unlike other backends, not all settings available in the web UI are propagated to the MLX backend. Instead, you can configure some options directly via the command line e.g. `--text-temperature`. ## License The present code is provided under the MIT license. 
Some of this code was taken from mlx-vlm v0.1.9, the code can be found here: https://github.com/Blaizzy/mlx-vlm/tree/a11c034adf6ae4bca5a197990d1ecb77aba83c47 The license of mlx-vlm is MIT. ## Citation If you use either Mimi or Moshi, please cite the following paper, ``` @article{kyutai2025moshivis, author = {Amélie Royer and Moritz Böhle and Gabriel de Marmiesse and Laurent Mazaré and Alexandre Défossez and Neil Zeghidour and Patrick Pérez}, year = {2025}, title = {Vision-Speech Models: Teaching Speech Models to Converse about Images}, journal = {ArXiv}, url = {https://arxiv.org/abs/2503.15633} } @techreport{kyutai2024moshi, title={Moshi: a speech-text foundation model for real-time dialogue}, author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour}, year={2024}, eprint={2410.00037}, archivePrefix={arXiv}, primaryClass={eess.AS}, url={https://arxiv.org/abs/2410.00037}, } ``` [moshi]: https://kyutai.org/Moshi.pdf [main_repo]: https://github.com/kyutai-labs/MoshiVis ================================================ FILE: kyuteye_mlx/kyuteye_mlx/__init__.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # flake8: noqa """ kyuteye_mlx is the MLX inference codebase for Kyutai audio+vision generation models. """ import os import subprocess # TODO: remove this when https://github.com/ml-explore/mlx/issues/1963 is fixed if subprocess.check_output(["sysctl", "hw.model"]).decode().split(":")[1].strip() == "Mac15,12": os.environ["MLX_MAX_OPS_PER_BUFFER"] = "8" os.environ["MLX_MAX_MB_PER_BUFFER"] = "1000000" from jaxtyping import install_import_hook # Set to True to get runtime type-checking and other small checks # but will slow down the steps by ~2% DEBUG_MODE = bool(int(os.environ.get("MOSHI_MLX_DEBUG_MODE", "0"))) if DEBUG_MODE: print("debug mode enabled.") install_import_hook("kyuteye_mlx", "beartype.beartype") else: print("debug mode disabled.") from . import models, modules, utils __version__ = "0.1.0" ================================================ FILE: kyuteye_mlx/kyuteye_mlx/benchmark.py ================================================ import time import mlx.core as mx import numpy as np from .local_web import get_args_for_main, get_model, predict_text_and_audio def main(): args = get_args_for_main() args.quantized = 4 gen = get_model(args, load_weights=False) sum_times = 0 for i in range(100): data = mx.arange(8, dtype=mx.uint32).reshape(8, 1) uploaded_image_embeddings = mx.arange(1024 * 1152, dtype=mx.bfloat16).reshape(1, 1024, 1152) mx.eval((data, uploaded_image_embeddings)) t1 = time.time() predict_text_and_audio(gen, data, uploaded_image_embeddings) t2 = time.time() if i >= 5: sum_times += t2 - t1 print(f"average time per step: {(sum_times / 95) * 1000:1f} ms") ================================================ FILE: kyuteye_mlx/kyuteye_mlx/local_web.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
"""Entrypoint for the web server.""" import argparse import asyncio import io import multiprocessing import multiprocessing.queues import os import queue import sys import tarfile import time import webbrowser from enum import Enum from pathlib import Path from typing import Any import aiohttp import huggingface_hub import mlx.core as mx import mlx.nn as nn import numpy as np import rustymimi import sentencepiece import sphn from aiohttp import web from jaxtyping import UInt32 from PIL import Image from kyuteye_mlx import models, utils from kyuteye_mlx.models.pixtral import PixtralWrapper from kyuteye_mlx.models.siglip import SiglipWrapper from kyuteye_mlx.quantize import quantize from kyuteye_mlx.utils.loading import repeat_shared_weights, split_embedder_weights from kyuteye_mlx.utils.profiling import PROFILING_ENABLED, profile SAMPLE_RATE = 24000 FRAME_SIZE = 1920 CHANNELS = 1 class ModelInput(Enum): AUDIO = 0 IMAGE = 1 class ModelOutput(Enum): AUDIO = 0 TEXT = 1 START = 2 IMAGE_PROCESSED = 3 class ServerMediaInput(Enum): AUDIO = 1 IMAGE = 8 ClientServerQueue = multiprocessing.queues.Queue[tuple[ModelInput, Any]] ServerClientQueue = multiprocessing.queues.Queue[tuple[ModelOutput, Any]] def colorize(text: str, color: str) -> str: code = f"\033[{color}m" restore = "\033[0m" return "".join([code, text, restore]) def log(level: str, msg: str) -> None: if level == "warning": prefix = colorize("[Warn]", "1;31") elif level == "info": prefix = colorize("[Info]", "1;34") elif level == "error": prefix = colorize("[Err ]", "1;31") else: raise ValueError(f"Unknown level {level}") print(prefix + " " + msg) def hf_hub_download(repo: str | None, path: str) -> str: if repo is None or repo == "": raise ValueError(f"the --hf-repo flag is required to retrieve {path}") return huggingface_hub.hf_hub_download(repo, path) def full_warmup( audio_tokenizer: rustymimi.StreamTokenizer, client_to_server: ClientServerQueue, server_to_client: ServerClientQueue, ) -> None: for i in range(4): pcm_data = np.array([0.0] * 1920).astype(np.float32) audio_tokenizer.encode(pcm_data) while True: time.sleep(0.01) data = audio_tokenizer.get_encoded() if data is not None: break client_to_server.put_nowait((ModelInput.AUDIO, data)) if i == 0: continue while True: kind, data = server_to_client.get() if kind == ModelOutput.AUDIO: audio_tokenizer.decode(data) break while True: time.sleep(0.01) data = audio_tokenizer.get_decoded() if data is not None: break def get_model_file(args) -> str: model_file = args.moshi_weights if model_file is None: if args.quantized in (4, 8): model_file = hf_hub_download(args.hf_repo, f"model.q{args.quantized}.safetensors") elif args.quantized is not None: raise ValueError(f"Invalid quantized value: {args.quantized}") else: model_file = hf_hub_download(args.hf_repo, "model.safetensors") return model_file def get_tokenizer(args) -> sentencepiece.SentencePieceProcessor: tokenizer_file = args.tokenizer if tokenizer_file is None: tokenizer_file = hf_hub_download(args.hf_repo, "tokenizer_spm_32k_3.model") log("info", f"[SERVER] loading text tokenizer {tokenizer_file}") return sentencepiece.SentencePieceProcessor(tokenizer_file) # type: ignore def get_embedder(args) -> SiglipWrapper | PixtralWrapper: model_file = get_model_file(args) if args.encoder == "pixtral": lm_config = models.config_pixtral() elif args.encoder == "siglip": lm_config = models.config_siglip() else: raise ValueError(f"Unknown encoder {args.encoder}") weights = mx.load(model_file) if lm_config.transformer.xa_shared: # for shared 
cross-attention, we have the weights only once in ckpt weights = repeat_shared_weights(weights, lm_config.transformer.num_layers) _, embedder_weights = split_embedder_weights(weights) if args.encoder == "siglip": embedder = SiglipWrapper() elif args.encoder == "pixtral": embedder = PixtralWrapper() else: raise ValueError(f"Unknown encoder {args.encoder}") if embedder_weights: log("info", "[SERVER] loading embedder weights") embedder_weights["model.embeddings.patch_embedding.weight"] = embedder_weights[ "model.embeddings.patch_embedding.weight" ].transpose(0, 2, 3, 1) embedder.load_weights(list(embedder_weights.items()), strict=True) log("info", "[SERVER] embedder weights loaded") embedder.eval() log("info", "[SERVER] Embedder warmed up") return embedder def get_model(args, load_weights: bool = True) -> models.LmGen: mx.random.seed(299792458) if args.encoder == "pixtral": lm_config = models.config_pixtral() elif args.encoder == "siglip": lm_config = models.config_siglip() else: raise ValueError(f"Unknown encoder {args.encoder}") lm_config.transformer.xa_start = args.xa_start model = models.Lm(lm_config) model.set_dtype(mx.bfloat16) if args.quantized is not None: group_size = 32 if args.quantized == 4 else 64 nn.quantize(model, bits=args.quantized, group_size=group_size) if load_weights: model_file = get_model_file(args) log("info", f"[SERVER] loading weights {model_file}") weights = mx.load(model_file) if lm_config.transformer.xa_shared: # for shared cross-attention, we have the weights only once in ckpt weights = repeat_shared_weights(weights, lm_config.transformer.num_layers) weights, _ = split_embedder_weights(weights) model.load_weights(list(weights.items()), strict=True) log("info", "[SERVER] weights loaded") model.warmup() log("info", "[SERVER] model warmed up") gen = models.LmGen( model=model, max_steps=args.steps + 5, text_sampler=utils.Sampler(temp=args.text_temperature, top_p=args.text_top_p), audio_sampler=utils.Sampler(temp=args.audio_temperature, top_p=args.audio_top_p), check=False, ) return gen def model_server( client_to_server: ClientServerQueue, server_to_client: ServerClientQueue, args: argparse.Namespace, ): gen = get_model(args) embedder = get_embedder(args) text_tokenizer = get_tokenizer(args) server_to_client.put((ModelOutput.START, "start")) log("info", "[SERVER] connected!") try: uploaded_image_embeddings = None gen.reset() for i in range(10000000000000): if i == 150: if PROFILING_ENABLED: profile.print_stats() data_type, data = client_to_server.get() if data_type == ModelInput.AUDIO: handle_audio( data, gen, uploaded_image_embeddings, server_to_client, text_tokenizer, ) elif data_type == ModelInput.IMAGE: img = Image.open(io.BytesIO(data)) # compute longer image size if args.encoder == "pixtral": w, h = img.width, img.height if w > h: new_w = args.img_size new_h = int(h * new_w / w) else: new_h = args.img_size new_w = int(w * new_h / h) img = img.resize((new_w, new_h), resample=3) else: img = img.resize((args.img_size, args.img_size), resample=3) image = np.asarray(img) uploaded_image_embeddings = embedder(mx.array(image)).astype(mx.bfloat16) gen.reset() log("info", f"received image embeddings: {image.shape}") server_to_client.put_nowait((ModelOutput.IMAGE_PROCESSED, None)) except KeyboardInterrupt: pass def handle_audio( data: UInt32[np.ndarray, "tokens 1"], gen: models.LmGen, uploaded_image_embeddings: mx.array | None, server_to_client: ServerClientQueue, text_tokenizer: sentencepiece.SentencePieceProcessor, ): t_start = time.time() text_token, audio_tokens = 
predict_text_and_audio(gen, mx.array(data), uploaded_image_embeddings) elapsed_eval_seconds = time.time() - t_start elapsed_eval_milliseconds = elapsed_eval_seconds * 1000 print(f"eval in {elapsed_eval_milliseconds} ms") text_token = text_token.item() if text_token not in (0, 3): _text = text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") server_to_client.put_nowait((ModelOutput.TEXT, _text)) if audio_tokens is not None: audio_tokens = np.array(audio_tokens).astype(np.uint32) server_to_client.put_nowait((ModelOutput.AUDIO, audio_tokens)) elapsed_seconds = time.time() - t_start elapsed_milliseconds = elapsed_seconds * 1000 print(f"step in {elapsed_milliseconds} ms") def predict_text_and_audio( gen: models.LmGen, data: UInt32[mx.array, "tokens 1"], uploaded_image_embeddings ) -> tuple[mx.array, mx.array | None]: data = data.transpose(1, 0)[:, :8] if gen.model.xa_cache is not None and not gen.model.xa_cache.is_set: text_token = gen.step(data, uploaded_image_embeddings) else: text_token = gen.step(data, None) text_token = text_token[0] audio_tokens = gen.last_audio_tokens() mx.eval((text_token, audio_tokens)) return text_token, audio_tokens def web_server( client_to_server: ClientServerQueue, server_to_client: ServerClientQueue, args: argparse.Namespace, ): mimi_file = args.mimi_weights if mimi_file is None: mimi_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors") input_queue = queue.Queue() output_queue = queue.Queue() text_queue = queue.Queue() audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) # type: ignore kind, start_message = server_to_client.get() if kind != ModelOutput.START: log( "error", f"[CLIENT] recieve {(kind, start_message)} at startup, this is unexpected.", ) log("info", f"[CLIENT] received '{start_message}' from server, starting...") full_warmup(audio_tokenizer, client_to_server, server_to_client) log("info", "warmup done") async def send_loop() -> None: while True: await asyncio.sleep(0.001) try: pcm_data = input_queue.get(block=False) audio_tokenizer.encode(pcm_data) except queue.Empty: continue async def recv_loop() -> None: while True: data = audio_tokenizer.get_decoded() if data is None: await asyncio.sleep(0.001) continue output_queue.put_nowait(data) async def send_loop2() -> None: while True: data = audio_tokenizer.get_encoded() if data is None: await asyncio.sleep(0.001) continue client_to_server.put_nowait((ModelInput.AUDIO, data)) async def recv_loop2() -> None: while True: try: kind, data = server_to_client.get(block=False) if kind == ModelOutput.AUDIO: audio_tokenizer.decode(data) elif kind == ModelOutput.TEXT: text_queue.put_nowait(data) except queue.Empty: await asyncio.sleep(0.001) continue lock = asyncio.Lock() async def handle_chat(request) -> None: ws = web.WebSocketResponse() await ws.prepare(request) async def recv_loop() -> None: nonlocal close try: async for message in ws: if message.type == aiohttp.WSMsgType.ERROR: log("error", f"{ws.exception()}") break elif message.type == aiohttp.WSMsgType.CLOSED: break elif message.type != aiohttp.WSMsgType.BINARY: log("error", f"unexpected message type {message.type}") continue message = message.data if not isinstance(message, bytes): log("error", f"unsupported message type {type(message)}") continue if len(message) == 0: log("warning", "empty message") continue kind = message[0] if kind == ServerMediaInput.AUDIO.value: payload = message[1:] opus_reader.append_bytes(payload) else: log("warning", f"unknown message kind {kind}") finally: close = 
True log("info", "connection closed") async def opus_loop() -> None: all_pcm_data = None while True: if close: return await asyncio.sleep(0.001) pcm = opus_reader.read_pcm() if pcm.shape[-1] == 0: continue if all_pcm_data is None: all_pcm_data = pcm else: all_pcm_data = np.concatenate((all_pcm_data, pcm)) while all_pcm_data.shape[-1] >= FRAME_SIZE: chunk = all_pcm_data[:FRAME_SIZE] all_pcm_data = all_pcm_data[FRAME_SIZE:] input_queue.put_nowait(chunk) async def send_loop() -> None: while True: if close: return await asyncio.sleep(0.001) msg = opus_writer.read_bytes() if len(msg) > 0: await ws.send_bytes(b"\x01" + msg) try: _text = text_queue.get(block=False) await ws.send_bytes(b"\x02" + bytes(_text, encoding="utf8")) except queue.Empty: continue async def another_loop() -> None: while True: if close: return await asyncio.sleep(0.001) try: pcm_data = output_queue.get(block=False) assert pcm_data.shape == (1920,), pcm_data.shape opus_writer.append_pcm(pcm_data) except queue.Empty: continue log("info", "accepted connection") close = False async with lock: opus_writer = sphn.OpusStreamWriter(SAMPLE_RATE) opus_reader = sphn.OpusStreamReader(SAMPLE_RATE) # Send the handshake. encoded_image = await extract_image(ws) client_to_server.put_nowait((ModelInput.IMAGE, encoded_image)) message_type, _ = server_to_client.get() if message_type != ModelOutput.IMAGE_PROCESSED: log("error", f"We recieved {message_type} instead of IMAGE_PROCESSED.") await ws.send_bytes(b"\x00") await asyncio.gather(opus_loop(), recv_loop(), send_loop(), another_loop()) log("info", "done with connection") return ws async def extract_image(ws: web.WebSocketResponse) -> bytes: """Get the image at the start of the stream. The bytes returned are encoded (png, jpeg, etc...). """ first_message = await ws.receive() first_message = first_message.data kind = first_message[0] if kind != ServerMediaInput.IMAGE.value: log("error", f"First messsage should be an image, got {kind}") return first_message[1:] async def go() -> None: app = web.Application() app.router.add_get("/api/chat", handle_chat) static_path: None | str = None if args.static is None: log("info", "retrieving the static content") dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "vis_dist.tgz") dist_tgz = Path(dist_tgz) dist = dist_tgz.parent / "dist" if not dist.exists(): with tarfile.open(dist_tgz, "r:gz") as tar: tar.extractall(path=dist_tgz.parent) static_path = str(dist) elif args.static != "none": # When set to the "none" string, we don't serve any static content. 
static_path = args.static if static_path is not None: async def handle_root(_) -> web.FileResponse: return web.FileResponse(os.path.join(static_path, "index.html")) log("info", f"serving static content from {static_path}") app.router.add_get("/", handle_root) app.router.add_static("/", path=static_path, name="static") runner = web.AppRunner(app) await runner.setup() ssl_context = None protocol = "http" if args.ssl is not None: import ssl ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) cert_file = os.path.join(args.ssl, "cert.pem") key_file = os.path.join(args.ssl, "key.pem") ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file) protocol = "https" site = web.TCPSite(runner, args.host, args.port, ssl_context=ssl_context) log("info", f"listening to {protocol}://{args.host}:{args.port}") if not args.no_browser: log("info", f"opening browser at {protocol}://{args.host}:{args.port}") webbrowser.open(f"{protocol}://{args.host}:{args.port}") await asyncio.gather(recv_loop(), send_loop(), recv_loop2(), send_loop2(), site.start()) await runner.cleanup() try: asyncio.run(go()) except KeyboardInterrupt: pass def get_args_for_main() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--tokenizer", type=str) parser.add_argument("--moshi-weights", type=str) parser.add_argument("--mimi-weights", type=str) parser.add_argument("-q", "--quantized", type=int, choices=[4, 8]) parser.add_argument("--steps", default=4000, type=int) parser.add_argument("--hf-repo", type=str, default="kyutai/moshika-vis-mlx") parser.add_argument("--static", type=str) parser.add_argument("--img_size", type=int, default=448) parser.add_argument("--encoder", type=str, default="siglip") parser.add_argument("--host", default="localhost", type=str) parser.add_argument("--port", default=8008, type=int) parser.add_argument( "--ssl", type=str, help=( "use https instead of http, this flag should point to a directory " "that contains valid key.pem and cert.pem files" ), ) parser.add_argument("--no-browser", action="store_true") parser.add_argument("--text-temperature", type=float, default=0.45) parser.add_argument("--audio-temperature", type=float, default=0.7) parser.add_argument("--text-top-p", type=float, default=0.95) parser.add_argument("--audio-top-p", type=float, default=0.95) parser.add_argument("--xa-start", type=int, default=16) args = parser.parse_args() if args.moshi_weights is not None: args.quantized = ( args.quantized or 8 if "q8" in args.moshi_weights else 4 if "q4" in args.moshi_weights else None ) return args def main() -> None: args = get_args_for_main() client_to_server: ClientServerQueue = multiprocessing.Queue() server_to_client: ServerClientQueue = multiprocessing.Queue() # Create two processes subprocess_args = client_to_server, server_to_client, args p1 = multiprocessing.Process(target=web_server, args=subprocess_args) p2 = multiprocessing.Process(target=model_server, args=subprocess_args) # Start the processes p1.start() p2.start() try: while p1.is_alive() and p2.is_alive(): time.sleep(0.001) except KeyboardInterrupt: log("warning", "Interrupting, exiting connection.") p1.terminate() p2.terminate() # Wait for both processes to finish p1.join() p2.join() log("info", "All done!") def sanity_check() -> None: """A small sanity check to make sure all packages can be imported.""" if __name__ == "__main__": main() ================================================ FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/LICENSE ================================================ MIT License 
Copyright © 2023 Apple Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/__init__.py ================================================ ================================================ FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/__init__.py ================================================ ================================================ FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/pixtral/__init__.py ================================================ ================================================ FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/pixtral/vision.py ================================================ import inspect from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn @dataclass class VisionConfig: model_type: str num_hidden_layers: int = 24 hidden_size: int = 1024 head_dim: int = 64 intermediate_size: int = 4096 num_attention_heads: int = 16 image_size: int = 336 patch_size: int = 14 projection_dim: int = 768 vocab_size: int = 32000 num_channels: int = 3 rms_norm_eps: float = 1e-5 rope_theta: float = 10000.0 @classmethod def from_dict(cls, params): return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) def position_ids_in_meshgrid(patch_embeds_list: list[mx.array], max_width: int): positions = [] for patch in patch_embeds_list: height, width = patch.shape[1], patch.shape[2] h_grid, v_grid = mx.meshgrid(mx.arange(height), mx.arange(width), indexing="ij") h_grid = h_grid.reshape(-1, 1) v_grid = v_grid.reshape(-1, 1) ids = h_grid * max_width + v_grid positions.append(ids.flatten()) return mx.concatenate(positions) def generate_block_attention_mask(patch_embeds_list: list[mx.array], tensor: mx.array): seq_len = tensor.shape[1] d_min = -1e9 # Using a large negative value as MLX doesn't have finfo causal_mask = mx.full((seq_len, seq_len), vals=d_min) block_end_idx = mx.cumsum(mx.array(patch_embeds_list)) block_start_idx = mx.concatenate([mx.array([0]), mx.array(patch_embeds_list[:-1])]) block_start_idx = mx.cumsum(block_start_idx) for start, end in zip(block_start_idx, block_end_idx): start, end = int(start), int(end) # Convert to integers for indexing causal_mask[start:end, start:end] = 0 causal_mask = mx.broadcast_to(causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len)) return causal_mask def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return mx.concatenate((-x2, x1), axis=-1) def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): 
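# Standard rotary position-embedding application: cos/sin gain a broadcast axis at
# `unsqueeze_dim`, then queries and keys are rotated with the usual half-rotation
# formula q' = q * cos + rotate_half(q) * sin (and likewise for k).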
cos = mx.expand_dims(cos, axis=unsqueeze_dim) sin = mx.expand_dims(sin, axis=unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class Attention(nn.Module): def __init__( self, dims: int, num_heads: int, query_input_dims: int | None = None, key_input_dims: int | None = None, value_input_dims: int | None = None, value_dims: int | None = None, value_output_dims: int | None = None, bias: bool = False, ): super().__init__() if (dims % num_heads) != 0: raise ValueError( "The input feature dimensions should be divisible by the " f"number of heads ({dims} % {num_heads}) != 0" ) query_input_dims = query_input_dims or dims key_input_dims = key_input_dims or dims value_input_dims = value_input_dims or key_input_dims value_dims = value_dims or dims value_output_dims = value_output_dims or dims self.embed_dim = dims self.num_heads = num_heads self.head_dim = self.embed_dim // self.num_heads self.scale = self.head_dim**-0.5 self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) def __call__(self, queries, keys, values, position_embeddings, mask=None): queries = self.q_proj(queries) keys = self.k_proj(keys) values = self.v_proj(values) num_heads = self.num_heads B, L, D = queries.shape _, S, _ = keys.shape queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) cos, sin = position_embeddings queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, unsqueeze_dim=0) attn_weights = mx.matmul(queries, keys.transpose(0, 1, 3, 2)) * self.scale if mask is not None: attn_weights = attn_weights + mask attn_weights = mx.softmax(attn_weights, axis=-1) output = mx.matmul(attn_weights, values) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) class MLP(nn.Module): def __init__(self, config: VisionConfig): super().__init__() dim = config.hidden_size hidden_dim = config.intermediate_size self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) self.down_proj = nn.Linear(hidden_dim, dim, bias=False) self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class EncoderLayer(nn.Module): def __init__(self, config: VisionConfig): super().__init__() self.embed_dim = config.hidden_size self.attention = Attention(config.hidden_size, config.num_attention_heads, bias=True) self.attention_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps) self.feed_forward = MLP(config) self.ffn_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps) def __call__( self, x: mx.array, position_embeddings: mx.array, mask: mx.array | None = None, ) -> mx.array: y = self.attention_norm(x) y = self.attention(y, y, y, position_embeddings, mask) x = x + y y = self.ffn_norm(x) y = self.feed_forward(y) return x + y class Encoder(nn.Module): def __init__(self, config: VisionConfig): super().__init__() self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] class PixtralRotaryEmbedding: def __init__(self, config): self.dim = config.head_dim self.base = config.rope_theta max_patches_per_side = config.image_size // config.patch_size freqs = 1.0 / (self.base ** 
(mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)) h = mx.arange(max_patches_per_side) w = mx.arange(max_patches_per_side) freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32) freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32) inv_freq = mx.concatenate( [ mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)), mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1)), ], axis=-1, ).reshape(-1, self.dim // 2) self.inv_freq = mx.concatenate((inv_freq, inv_freq), axis=-1) def __call__(self, x, position_ids): freqs = self.inv_freq[position_ids] emb = freqs cos = mx.cos(emb) sin = mx.sin(emb) return cos.astype(x.dtype), sin.astype(x.dtype) class PixtralVisionModel(nn.Module): def __init__(self, config: VisionConfig): super().__init__() self.config = config self.patch_conv = nn.Conv2d( in_channels=config.num_channels, out_channels=config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size, bias=False, ) self.ln_pre = nn.RMSNorm(config.hidden_size) self.transformer = Encoder(config) self.patch_positional_embedding = PixtralRotaryEmbedding(config) def __call__( self, x: list[mx.array], output_hidden_states: bool | None = None, ) -> mx.array: B, H, W, C = x[0].shape patch_embeds_list = [self.patch_conv(img) for img in x] patch_embeds = mx.concatenate([p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1) patch_embeds = self.ln_pre(patch_embeds) position_ids = position_ids_in_meshgrid( patch_embeds_list, max_width=self.config.image_size // self.config.patch_size, ) position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) mask = generate_block_attention_mask( [p.shape[2] * p.shape[1] for p in patch_embeds_list], patch_embeds ) encoder_states = (patch_embeds,) if output_hidden_states else None for layer in self.transformer.layers: patch_embeds = layer(patch_embeds, mask=mask, position_embeddings=position_embedding) if output_hidden_states: encoder_states = encoder_states + (patch_embeds,) return patch_embeds, encoder_states ================================================ FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/siglip/vision.py ================================================ import inspect from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn import numpy as np @dataclass class VisionConfig: model_type: str num_hidden_layers: int hidden_size: int intermediate_size: int num_attention_heads: int patch_size: int projection_dim: int image_size: int = 224 num_channels: int = 3 layer_norm_eps: float = 1e-6 @classmethod def from_dict(cls, params): return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) def check_array_shape(arr): shape = arr.shape # Check if the shape has 4 dimensions if len(shape) != 4: return False out_channels, kH, KW, _ = shape # Check if out_channels is the largest, and kH and KW are the same if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): return True else: return False class Attention(nn.Module): def __init__( self, dims: int, num_heads: int, query_input_dims: int | None = None, key_input_dims: int | None = None, value_input_dims: int | None = None, value_dims: int | None = None, value_output_dims: int | None = None, bias: bool = True, ) -> None: super().__init__() if (dims % num_heads) != 0: raise ValueError( "The input feature dimensions should be divisible by the " f"number of heads ({dims} % {num_heads}) != 0" ) query_input_dims = query_input_dims or dims key_input_dims = key_input_dims or dims value_input_dims = value_input_dims or 
key_input_dims value_dims = value_dims or dims value_output_dims = value_output_dims or dims self.num_heads = num_heads head_dim = dims // num_heads self.scale = head_dim**-0.5 self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) def __call__(self, x, mask=None): queries = self.q_proj(x) keys = self.k_proj(x) values = self.v_proj(x) num_heads = self.num_heads B, L, D = queries.shape _, S, _ = keys.shape queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.out_proj(output) class MLP(nn.Module): def __init__(self, config: VisionConfig) -> None: super().__init__() self.activation_fn = nn.GELU(approx="precise") self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True) def __call__(self, x: mx.array) -> mx.array: x = self.fc1(x) x = self.activation_fn(x) x = self.fc2(x) return x class EncoderLayer(nn.Module): def __init__(self, config: VisionConfig) -> None: super().__init__() self.embed_dim = config.hidden_size self.self_attn = Attention(config.hidden_size, config.num_attention_heads, bias=True) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = MLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array: r = self.self_attn(self.layer_norm1(x), mask) h = x + r r = self.mlp(self.layer_norm2(h)) return h + r class Encoder(nn.Module): def __init__(self, config: VisionConfig) -> None: super().__init__() self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] def __call__( self, x: mx.array, output_hidden_states: bool | None = None, mask: mx.array | None = None, ) -> tuple[mx.array, mx.array | None]: encoder_states = (x,) if output_hidden_states else None h = x for layer in self.layers: x = layer(x, mask=mask) if output_hidden_states: encoder_states = encoder_states + (x,) h = x[0] return (h, encoder_states) class VisionEmbeddings(nn.Module): def __init__(self, config: VisionConfig) -> None: super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) def __call__(self, x: mx.array) -> mx.array: patch_embeddings = self.patch_embedding(x) patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) position_ids = mx.array(np.arange(self.num_positions)[None, :]) embeddings = patch_embeddings embeddings += self.position_embedding(position_ids) return embeddings class SigLipVisionModel(nn.Module): def __init__(self, config: VisionConfig): super().__init__() self.embeddings = VisionEmbeddings(config) self.encoder = Encoder(config) 
self.post_layernorm = nn.LayerNorm(config.hidden_size) def __call__( self, x: mx.array, output_hidden_states: bool | None = None, ) -> tuple[mx.array, mx.array, mx.array | None]: x = self.embeddings(x) encoder_outputs = self.encoder(x=x, output_hidden_states=output_hidden_states, mask=None) pooler_output = self.post_layernorm(encoder_outputs[0]) return pooler_output, x, encoder_outputs[-1] class VisionModel(nn.Module): def __init__(self, config: VisionConfig) -> None: super().__init__() self.model_type = config.model_type if self.model_type != "siglip_vision_model": raise ValueError(f"Unsupported model type: {self.model_type}") self.vision_model = SigLipVisionModel(config) def __call__(self, x: mx.array, output_hidden_states: bool | None = None) -> mx.array: return self.vision_model(x, output_hidden_states) def sanitize(self, weights): sanitized_weights = {} for k, v in weights.items(): if "position_ids" in k: # Remove unused position_ids continue elif "patch_embedding.weight" in k: # PyTorch conv2d weight tensors have shape: # [out_channels, in_channels, kH, KW] # MLX conv2d expects the weight be of shape: # [out_channels, kH, KW, in_channels] if check_array_shape(v): sanitized_weights[k] = v else: sanitized_weights[k] = v.transpose(0, 2, 3, 1) else: sanitized_weights[k] = v return sanitized_weights ================================================ FILE: kyuteye_mlx/kyuteye_mlx/models/__init__.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # flake8: noqa """ Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. """ from .lm import ( Lm, LmConfig, config_v0_1, config_siglip, config_pixtral, config1b_202412, config1b_202412_16rvq, config_helium_1_preview_2b, ) from .generate import LmGen ================================================ FILE: kyuteye_mlx/kyuteye_mlx/models/generate.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import mlx.core as mx from jaxtyping import BFloat16, Int32, UInt32 from kyuteye_mlx import DEBUG_MODE from ..models import Lm from ..utils import sampling class LmGen: def __init__( self, model: Lm, max_steps: int, text_sampler: sampling.Sampler, audio_sampler: sampling.Sampler, check: bool = False, ): self.model: Lm = model self.text_sampler = text_sampler self.audio_sampler = audio_sampler self.max_steps = max_steps self.check = check self.num_codebooks = 1 + model.cfg.audio_codebooks self.gen_sequence = mx.full( shape=(1, self.num_codebooks, max_steps), vals=self.ungenerated_token, dtype=mx.int32, ) self.step_idx = 0 self.audio_padding_token = self.model.cfg.audio_padding_token self.audio_delays = self.model.cfg.audio_delays self.max_delay = max(self.audio_delays) self.main_codebooks = self.model.cfg.depformer.num_slices @property def zero_token(self) -> int: """Special value in the input tokens, indicating that no sampling should happen for that value, and no input should be given to the model.""" return -1 @property def ungenerated_token(self) -> int: """Special value that can be provided in the prompt to indicate that this specific value should be predicted and sampled. This allows for partial teacher forcing, by generating one modality, with the other one fixed. 
""" return -2 @property def nb_input_tokens(self) -> int: return self.model.cfg.audio_codebooks - self.main_codebooks # Runs one step of inference and return the generated text token. def step( self, other_audio_tokens: UInt32[mx.array, "1 {self.nb_input_tokens}"], image_embeddings: BFloat16[mx.array, "1 dim1 dim2"] | None, ) -> UInt32[mx.array, "1"]: if self.step_idx >= self.max_steps: raise ValueError(f"reached max-steps {self.max_steps}") if self.step_idx == 0: text_tokens = mx.array([[32000]]) else: text_tokens = self.gen_sequence[:, 0, self.step_idx - 1][None] self.gen_sequence[:, 1 + self.main_codebooks :, self.step_idx] = other_audio_tokens audio_tokens = [] for cb_idx, delay in enumerate(self.audio_delays): gen_idx = self.step_idx - 1 - delay if gen_idx >= 0: audio_token = self.gen_sequence[:, cb_idx + 1, gen_idx][None] else: audio_token = mx.array([[self.audio_padding_token]]) if DEBUG_MODE and (audio_token == self.ungenerated_token).any(): # type: ignore raise ValueError(f"ungenerated value in audio tokens cb: {cb_idx} step: {self.step_idx}") assert audio_token.shape == (1, 1), "invalid audio-tokens shape" audio_tokens.append(audio_token) if DEBUG_MODE and (text_tokens == self.ungenerated_token).any(): # type: ignore raise ValueError(f"ungenerated value in text tokens {self.step_idx}") assert text_tokens.shape == (1, 1), "invalid text-tokens shape" text_tokens, audio_tokens = self.model.sample( text_tokens, audio_tokens, self.step_idx, self.text_sampler, self.audio_sampler, image_embeddings=image_embeddings, ) assert audio_tokens.shape == (8,), "invalid output audio-token shape" self.gen_sequence[:, 0, self.step_idx] = text_tokens for cb_idx, delay in enumerate(self.audio_delays[: self.main_codebooks]): gen_idx = self.step_idx - delay if gen_idx >= 0: self.gen_sequence[:, cb_idx + 1, gen_idx] = audio_tokens[cb_idx] self.step_idx += 1 return text_tokens def last_audio_tokens(self) -> Int32[mx.array, "1 {self.nb_input_tokens}"] | None: gen_idx = self.step_idx - 1 - self.max_delay if gen_idx < 0: return None tokens = self.gen_sequence[:, 1 : 1 + self.main_codebooks, gen_idx] if DEBUG_MODE and (tokens == self.audio_padding_token).any(): # type: ignore return None if DEBUG_MODE and (tokens == self.ungenerated_token).any(): # type: ignore raise ValueError(f"ungenerated value in last-audio tokens {self.step_idx}") return tokens def reset(self) -> None: self.gen_sequence = mx.full( shape=(1, self.num_codebooks, self.max_steps), vals=self.ungenerated_token, dtype=mx.int32, ) self.step_idx = 0 self.model.reset_all_caches() ================================================ FILE: kyuteye_mlx/kyuteye_mlx/models/lm.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass from importlib.metadata import version import mlx.core as mx import mlx.nn as nn from packaging.version import parse as parse_version from ..modules.kv_cache import KVCache, RotatingKVCache, XACache from ..modules.transformer import Transformer, TransformerConfig from ..utils import sampling MLX_BELOW_0_22_0 = parse_version(version("mlx")) < parse_version("0.22.0") @dataclass class DepFormerConfig: transformer: TransformerConfig num_slices: int @dataclass class LmConfig: transformer: TransformerConfig depformer: DepFormerConfig text_in_vocab_size: int text_out_vocab_size: int audio_vocab_size: int audio_codebooks: int audio_delays: list[int] @property def audio_eos_token(self) -> int: return self.audio_vocab_size - 2 @property def audio_padding_token(self) -> int: return self.audio_vocab_size - 1 class DepFormerSlice(nn.Module): def __init__( self, in_vocab_size: int, out_vocab_size: int, main_transformer_dim: int, cfg: TransformerConfig, ): super().__init__() dim = cfg.d_model self.emb = nn.Embedding(in_vocab_size, dim) self.linear_in = nn.Linear(main_transformer_dim, dim, bias=False) self.linear_out = nn.Linear(dim, out_vocab_size, bias=False) self.transformer = Transformer(cfg) def __call__(self, _: mx.array) -> mx.array: raise ValueError("not implemented") class DepFormer(nn.Module): def __init__(self, cfg: LmConfig): super().__init__() self.slices: list[DepFormerSlice] = [] for slice_idx in range(cfg.depformer.num_slices): in_vs = cfg.text_in_vocab_size if slice_idx == 0 else cfg.audio_vocab_size slice = DepFormerSlice( in_vs, cfg.audio_vocab_size - 1, main_transformer_dim=cfg.transformer.d_model, cfg=cfg.depformer.transformer, ) self.slices.append(slice) def __call__(self, _: mx.array) -> mx.array: raise ValueError("not implemented") def sample( self, main_transformer_out: mx.array, step_idx: int, sampler: sampling.Sampler, text_token: mx.array, cache: list[KVCache] | list[RotatingKVCache], ) -> mx.array: tokens = [] last_token = text_token # The cache is shared between the depformer slices but not persisted between sample calls. 
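# Within a single step, slice k therefore attends over its own position and the
# positions already written by slices 0..k-1, so the shared cache never holds more
# than num_slices entries per layer.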
for c in cache: c.reset() for slice_idx, slice in enumerate(self.slices): last_token = last_token if step_idx > 0 or slice_idx in (0, 1, 9) else mx.array(2048) if MLX_BELOW_0_22_0: embedding = slice.emb(last_token) else: last_token_reshaped = mx.expand_dims(last_token, axis=0) embedding = slice.emb(last_token_reshaped).squeeze(0) xs = slice.linear_in(main_transformer_out) + embedding xs = slice.transformer(xs, cache=cache) logits = slice.linear_out(xs) last_token, _ = sampler(logits[0]) tokens.append(last_token) tokens = mx.concatenate(tokens) return tokens class Lm(nn.Module): def __init__(self, cfg: LmConfig): super().__init__() dim = cfg.transformer.d_model self.transformer: Transformer = Transformer(cfg.transformer, with_img_prefix=True) self.depformer: DepFormer = DepFormer(cfg) self.text_emb = nn.Embedding(cfg.text_in_vocab_size, dim) self.cfg: LmConfig = cfg if cfg.transformer.norm == "layer_norm": self.out_norm = nn.LayerNorm(dim, 1e-5) elif cfg.transformer.norm == "rms_norm": self.out_norm = nn.RMSNorm(dim, 1e-8) else: raise ValueError(f"unsupported norm type {cfg.transformer.norm}") self.text_linear = nn.Linear(dim, cfg.text_out_vocab_size, bias=False) self.audio_embs = [nn.Embedding(cfg.audio_vocab_size, dim) for _ in range(cfg.audio_codebooks)] self.transformer_cache: list[RotatingKVCache] = self.transformer.make_rot_cache() if cfg.transformer.cross_attention: self.xa_cache = XACache() if len(self.depformer.slices) > 0: self.depformer_cache: list[KVCache] = self.depformer.slices[0].transformer.make_cache() else: self.depformer_cache = [] def __call__( self, token_ids: mx.array, ) -> mx.array: # Note that this does not apply the depformer. xs = self.text_emb(token_ids) transformer_out = self.transformer(xs, cache=self.transformer_cache) transformer_out = self.out_norm(transformer_out) text_logits = self.text_linear(transformer_out) return text_logits def sample( self, text_token_ids: mx.array, audio_token_ids: list[mx.array], step_idx: int, text_sampler: sampling.Sampler, audio_sampler: sampling.Sampler, image_embeddings: mx.array | None = None, ) -> tuple[mx.array, mx.array]: xs = self.text_emb(text_token_ids) for token_ids, emb in zip(audio_token_ids, self.audio_embs): xs = xs + emb(token_ids) transformer_out = self.transformer( xs, cache=self.transformer_cache, img_embeds=image_embeddings, xa_cache=self.xa_cache, ) transformer_out = self.out_norm(transformer_out) text_logits = self.text_linear(transformer_out) text_token, _ = text_sampler(text_logits[:, 0]) audio_tokens = self.depformer.sample( transformer_out, step_idx, audio_sampler, text_token, self.depformer_cache, ) return text_token, audio_tokens def warmup(self) -> None: text, audio = self.sample( mx.array([[32000]]), [mx.array([[0]])] * 8, 0, text_sampler=sampling.Sampler(0.5, 0.5), audio_sampler=sampling.Sampler(0.5, 0.5), ) if text.sum().item() == 42: raise ValueError(42) if audio.sum().item() == 42: raise ValueError(42) for c in self.transformer_cache: c.reset() def reset_all_caches(self) -> None: for c in self.transformer_cache: c.reset() if hasattr(self, "xa_cache"): self.xa_cache.reset() for c in self.depformer_cache: c.reset() def config1b_202412() -> LmConfig: transformer = TransformerConfig( d_model=2048, num_heads=16, num_layers=16, dim_feedforward=2048 * 4, # dim * hidden_scale causal=True, norm_first=True, bias_ff=False, bias_attn=False, layer_scale=None, context=750, max_period=100000, use_conv_block=False, use_conv_bias=True, cross_attention=False, gating=True, norm="rms_norm", 
positional_embedding="rope", conv_layout=False, conv_kernel_size=3, kv_repeat=1, max_seq_len=4096, ) depformer = DepFormerConfig( transformer=TransformerConfig( d_model=1024, num_heads=16, num_layers=6, dim_feedforward=1024 * 4, # dim * hidden_scale causal=True, norm_first=True, bias_ff=False, bias_attn=False, layer_scale=None, context=8, max_period=10000, use_conv_block=False, use_conv_bias=True, cross_attention=False, gating=True, norm="rms_norm", positional_embedding="none", conv_layout=False, conv_kernel_size=3, kv_repeat=1, max_seq_len=4096, ), num_slices=8, ) return LmConfig( transformer=transformer, depformer=depformer, audio_vocab_size=2049, text_in_vocab_size=48001, text_out_vocab_size=48000, audio_codebooks=16, audio_delays=([0] + [1] * 7) * 2, ) def config1b_202412_16rvq() -> LmConfig: transformer = TransformerConfig( d_model=2048, num_heads=16, num_layers=16, dim_feedforward=2048 * 4, # dim * hidden_scale causal=True, norm_first=True, bias_ff=False, bias_attn=False, layer_scale=None, context=750, max_period=100000, use_conv_block=False, use_conv_bias=True, cross_attention=False, gating=True, norm="rms_norm", positional_embedding="rope", conv_layout=False, conv_kernel_size=3, kv_repeat=1, max_seq_len=4096, ) depformer = DepFormerConfig( transformer=TransformerConfig( d_model=1024, num_heads=16, num_layers=6, dim_feedforward=1024 * 4, # dim * hidden_scale causal=True, norm_first=True, bias_ff=False, bias_attn=False, layer_scale=None, context=16, max_period=10000, use_conv_block=False, use_conv_bias=True, cross_attention=False, gating=True, norm="rms_norm", positional_embedding="none", conv_layout=False, conv_kernel_size=3, kv_repeat=1, max_seq_len=4096, ), num_slices=16, ) return LmConfig( transformer=transformer, depformer=depformer, audio_vocab_size=2049, text_in_vocab_size=48001, text_out_vocab_size=48000, audio_codebooks=32, audio_delays=([0] + [1] * 15) * 2, ) def config_v0_1() -> LmConfig: transformer = TransformerConfig( d_model=4096, num_heads=32, num_layers=32, dim_feedforward=4096 * 4, # dim * hidden_scale causal=True, norm_first=True, bias_ff=False, bias_attn=False, layer_scale=None, context=3000, max_period=10000, use_conv_block=False, use_conv_bias=True, cross_attention=True, gating=True, norm="rms_norm", positional_embedding="rope", conv_layout=False, conv_kernel_size=3, kv_repeat=1, max_seq_len=4096, xa_gating="sigmoid", xa_shared=True, img_emb_dim=1152, ) depformer = DepFormerConfig( transformer=TransformerConfig( d_model=1024, num_heads=16, num_layers=6, dim_feedforward=1024 * 4, # dim * hidden_scale causal=True, norm_first=True, bias_ff=False, bias_attn=False, layer_scale=None, context=8, max_period=10000, use_conv_block=False, use_conv_bias=True, cross_attention=False, gating=True, norm="rms_norm", positional_embedding="none", conv_layout=False, conv_kernel_size=3, kv_repeat=1, max_seq_len=4096, xa_gating="none", xa_shared=True, ), num_slices=8, ) return LmConfig( transformer=transformer, depformer=depformer, audio_vocab_size=2049, text_in_vocab_size=32001, text_out_vocab_size=32000, audio_codebooks=16, audio_delays=([0] + [1] * 7) * 2, ) def config_siglip() -> LmConfig: config = config_v0_1() config.transformer.img_emb_dim = 1152 return config def config_pixtral() -> LmConfig: config = config_v0_1() config.transformer.img_emb_dim = 1024 return config def config_helium_1_preview_2b() -> LmConfig: transformer = TransformerConfig( d_model=2560, num_heads=20, num_layers=24, dim_feedforward=2560 * 4, # dim * hidden_scale causal=True, norm_first=True, bias_ff=False, 
bias_attn=False, layer_scale=None, context=4096, max_period=100000, use_conv_block=False, use_conv_bias=True, cross_attention=False, gating=True, norm="rms_norm", positional_embedding="rope", conv_layout=False, conv_kernel_size=3, kv_repeat=1, max_seq_len=4096, ) depformer = DepFormerConfig( transformer=transformer, num_slices=0, ) return LmConfig( transformer=transformer, depformer=depformer, audio_vocab_size=2049, text_in_vocab_size=48000, text_out_vocab_size=48000, audio_codebooks=0, audio_delays=[], ) ================================================ FILE: kyuteye_mlx/kyuteye_mlx/models/pixtral.py ================================================ import json import mlx import mlx.core as mx import mlx.nn from ..mlx_vlm.models.pixtral.vision import PixtralVisionModel, VisionConfig class PixtralWrapper(mlx.nn.Module): """Pixtral encoder returning penultimate features""" def __init__(self) -> None: super().__init__() # if not os.path.exists("pixtral-12b-8bit.safetensors"): # self.load_pixtral_weights() with open("pixtral-12b-8bit.config", "r") as f: vision_config = VisionConfig(**json.load(f)) self.model = PixtralVisionModel(vision_config) weights = mx.load("pixtral-12b-8bit.safetensors") mlx.nn.quantize(self.model, bits=8, group_size=64) self.model.load_weights(list(weights.items())) # def load_pixtral_weights(self): # from mlx_vlm.models.pixtral import Model # from dataclasses import asdict # pixtral = Model.from_pretrained("mlx-community/pixtral-12b-8bit") # self.model = pixtral.vision_tower.vision_model # self.model.save_weights("pixtral-12b-8bit.safetensors") # with open("pixtral-12b-8bit.config", "w") as f: # json.dump(asdict(pixtral.config.vision_config), f) def __call__(self, x: mx.array) -> mx.array: """Forward to the last hidden states We expect the input to be an RGB uint8 image (H, W, C) """ means = mx.array([0.48145466, 0.4578275, 0.40821073]) std = mx.array([0.26862954, 0.26130258, 0.27577711]) x = ((x / 255.0) - means[None, None, :]) / std[None, None, :] x = [x[None, :, :, :].astype(mx.float32)] assert isinstance(x, list), "Pixtral expects a list of tensors." return self.model(x, output_hidden_states=False)[0] def warmup(self) -> None: mx.eval(self(mx.zeros((224, 224, 3), dtype=mx.uint8))) ================================================ FILE: kyuteye_mlx/kyuteye_mlx/models/siglip.py ================================================ import json import os import mlx import mlx.core as mx import mlx.nn from ..mlx_vlm.models.siglip.vision import VisionConfig, VisionModel class SiglipWrapper(mlx.nn.Module): """Siglip encoder returning penultimate features""" def __init__(self) -> None: super().__init__() with open("siglip448.config", "r") as f: vision_config = VisionConfig(**json.load(f)) print("loaded!") self.model = VisionModel(vision_config).vision_model def __call__(self, x: mx.array) -> mx.array: """Forward to the last hidden states We expect the input to be an RGB uint8 image (H, W, C) """ means = mx.array([0.5] * 3) std = mx.array([0.5] * 3) x = ((x / 255.0) - means[None, None, :]) / std[None, None, :] x = x[None, :, :, :].astype(mx.float32) out = self.model(x, output_hidden_states=False)[0] return out[None] def warmup(self) -> None: mx.eval(self(mx.zeros((224, 224, 3), dtype=mx.uint8))) ================================================ FILE: kyuteye_mlx/kyuteye_mlx/modules/__init__.py ================================================ # Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # flake8: noqa """Modules used for building the models.""" from .kv_cache import KVCache, RotatingKVCache from .transformer import Transformer, TransformerConfig ================================================ FILE: kyuteye_mlx/kyuteye_mlx/modules/config.py ================================================ from dataclasses import dataclass from typing import Literal @dataclass class TransformerConfig: d_model: int num_heads: int num_layers: int causal: bool norm_first: bool bias_ff: bool bias_attn: bool layer_scale: float | None positional_embedding: str use_conv_block: bool cross_attention: bool xa_shared: bool xa_gating: Literal["none", "sigmoid", "tanh"] conv_kernel_size: int use_conv_bias: bool gating: bool norm: str context: int max_period: int max_seq_len: int kv_repeat: int dim_feedforward: int conv_layout: bool img_emb_dim: int | None = None xa_start: int = 0 @property def head_dim(self) -> int: return self.d_model // self.num_heads ================================================ FILE: kyuteye_mlx/kyuteye_mlx/modules/cross_attention.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Gated cross-attention module""" from typing import Any, Literal import mlx.core as mx import mlx.nn as nn from jaxtyping import BFloat16 from .config import TransformerConfig from .kv_cache import XACache class SharedModuleType(type): """Wrapper to build shared Pytorch modules""" _instances = {} # type: ignore def __call__(cls, *args: Any, **kwargs: Any) -> Any: if cls not in cls._instances: cls._instances[cls] = super(SharedModuleType, cls).__call__(*args, **kwargs) return cls._instances[cls] class CrossAttention(nn.Module): def __init__(self, cfg: TransformerConfig): super().__init__() self.cfg = cfg self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False) self.kv_proj = nn.Linear(cfg.d_model, 2 * cfg.d_model, bias=False) self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias_attn) self.scale = cfg.head_dim ** (-0.5) def __call__( self, xs: BFloat16[mx.array, "batch time {self.cfg.d_model}"], xa_cache: XACache, kv: BFloat16[mx.array, "batch kv_size {self.cfg.d_model}"] | None = None, ) -> BFloat16[mx.array, "batch time {self.cfg.d_model}"]: k, v = xa_cache.state if xa_cache is not None else (None, None) assert kv is not None or (k is not None and v is not None), ( "Need to provide embeddings or pre-computed keys and values" ) b, t, hd = xs.shape q = self.q_proj(xs).reshape(b, self.cfg.num_heads, t, self.cfg.head_dim) if k is None and v is None: assert kv is not None, "No image embeds given but also no cache found" kv = self.kv_proj(kv).reshape(b, -1, 2, self.cfg.num_heads, self.cfg.head_dim) k = kv[:, :, 0].transpose(0, 2, 1, 3) v = kv[:, :, 1].transpose(0, 2, 1, 3) xa_cache.set(k, v) xs = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale) xs = xs.reshape(b, t, hd) xs = self.out_proj(xs) return xs class SharedCrossAttention(CrossAttention, metaclass=SharedModuleType): """Shared Cross Attention projection across all layers""" pass # pylint: disable=unnecessary-pass class XAGate(nn.Module): def __init__( self, cfg: TransformerConfig, hidden_dims_factor: float = 0.125, activation: Literal["tanh", "sigmoid"] = "sigmoid", conditional_gating: bool = True, ): super().__init__() assert 
conditional_gating self.dims = cfg.d_model hidden_dims = int(hidden_dims_factor * self.dims) self.alpha = nn.Sequential( nn.Linear(self.dims, hidden_dims, bias=False), nn.ReLU(), nn.Linear(hidden_dims, self.dims, bias=False), ) if activation == "tanh": self.act = nn.Tanh() elif activation == "sigmoid": # shift left to mimic initialization ~ close to 0 self.act = lambda x: mx.sigmoid(x - 4) else: raise NotImplementedError("Unknown activation function", activation) def __call__( self, xs: BFloat16[mx.array, "batch time {self.dims}"] ) -> BFloat16[mx.array, "batch time {self.dims}"]: return xs * self.act(self.alpha(xs)) class GatedCrossAttention(nn.Module): def __init__(self, cfg: TransformerConfig) -> None: super().__init__() self.mha = (SharedCrossAttention if cfg.xa_shared else CrossAttention)(cfg) # Output Gating self.gate: nn.Module | None = None if cfg.xa_gating != "none": self.gate = XAGate(cfg) def __call__( self, xs: BFloat16[mx.array, "batch time features"], xa_cache: XACache, kv: BFloat16[mx.array, "batch kv_size features"] | None = None, ) -> BFloat16[mx.array, "batch time features"]: if kv is None and not xa_cache.is_set: return xs xs = self.mha(xs=xs, xa_cache=xa_cache, kv=kv) if self.gate is not None: xs = self.gate(xs) return xs ================================================ FILE: kyuteye_mlx/kyuteye_mlx/modules/kv_cache.py ================================================ # Most of the code below comes from: # https://github.com/ml-explore/mlx-examples/blob/6c2369e4b97f49fb5906ec46033497b39931b25d/llms/mlx_lm/models/base.py#L1 # Copyright © 2023-2024 Apple Inc. import inspect from dataclasses import dataclass from typing import Any, Self import mlx.core as mx import numpy as np class XACache: def __init__(self) -> None: self.keys: mx.array | None = None self.values: mx.array | None = None self.is_set: bool = False def set(self, k: mx.array, v: mx.array) -> None: # See https://github.com/ml-explore/mlx/issues/1918 for an # explanation of this hack. 
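# The round trip below (cast to float32, copy through NumPy, rebuild an MLX array,
# cast back to bfloat16) forces the projected keys/values to be evaluated and stored
# as independent buffers; the float32 detour is needed because NumPy has no native
# bfloat16 dtype.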
self.keys = mx.array(np.array(k.astype(mx.float32))).astype(mx.bfloat16) self.values = mx.array(np.array(v.astype(mx.float32))).astype(mx.bfloat16) self.is_set = True def reset(self) -> None: self.keys = None self.values = None self.is_set = False @property def state(self) -> tuple[mx.array | None, mx.array | None]: return self.keys, self.values class KVCache: def __init__(self, head_dim: int | tuple[int, int], n_kv_heads: int) -> None: self.n_kv_heads = n_kv_heads if isinstance(head_dim, int): self.k_head_dim = self.v_head_dim = head_dim elif isinstance(head_dim, tuple) and len(head_dim) == 2: self.k_head_dim, self.v_head_dim = head_dim else: raise ValueError("head_dim must be an int or a tuple of two ints") self.keys: mx.array | None = None self.values: mx.array | None = None self.offset = 0 self.step = 256 def update_and_fetch(self, keys: mx.array, values: mx.array) -> tuple[mx.array, mx.array]: prev = self.offset if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: B = keys.shape[0] n_steps = (self.step + keys.shape[2] - 1) // self.step k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim) v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim) new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: assert self.values is not None if prev % self.step != 0: self.keys = self.keys[..., :prev, :] self.values = self.values[..., :prev, :] self.keys = mx.concatenate([self.keys, new_k], axis=2) self.values = mx.concatenate([self.values, new_v], axis=2) else: self.keys, self.values = new_k, new_v self.offset += keys.shape[2] self.keys[..., prev : self.offset, :] = keys assert self.values is not None self.values[..., prev : self.offset, :] = values return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] def reset(self) -> None: self.offset = 0 @property def state(self) -> tuple[mx.array | None, mx.array | None]: return self.keys, self.values class RotatingKVCache: def __init__( self, head_dim: int | tuple[int, int], n_kv_heads: int, max_size: int, keep: int = 0, step: int = 256, ) -> None: self.n_kv_heads = n_kv_heads if isinstance(head_dim, int): self.k_head_dim = self.v_head_dim = head_dim elif isinstance(head_dim, tuple) and len(head_dim) == 2: self.k_head_dim, self.v_head_dim = head_dim else: raise ValueError("head_dim must be an int or a tuple of two ints") self.keep = keep self.keys: mx.array | None = None self.values: mx.array | None = None self.offset = 0 self.max_size = max_size self.step = step self._idx = 0 def _trim(self, trim_size: int, v: mx.array, append: mx.array | None = None) -> mx.array: to_cat = [] if trim_size > 0: to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] else: to_cat = [v] if append is not None: to_cat.append(append) return mx.concatenate(to_cat, axis=2) def update_and_fetch(self, keys: mx.array, values: mx.array) -> tuple[mx.array, mx.array]: prev = self.offset B, _, S = keys.shape[:3] # Prefill mode if S > 1: if self.keys is None: self.keys = keys self.values = values else: # The largest size is self.max_size + S - 1 to ensure # every token gets at least self.max_size context trim_size = self.keys.shape[2] - self.max_size + 1 self.keys = self._trim(trim_size, self.keys, keys) self.values = self._trim(trim_size, self.values, values) self.offset += S self._idx = self.keys.shape[2] return self.keys, self.values # Generation mode # May not have hit the max size yet, so potentially # keep growing the cache if self.keys is None or (prev >= 
self.keys.shape[2] and self.keys.shape[2] < self.max_size): new_size = min(self.step, self.max_size - prev) k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim) v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim) new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: assert self.values is not None self.keys = mx.concatenate([self.keys, new_k], axis=2) self.values = mx.concatenate([self.values, new_v], axis=2) else: self.keys, self.values = new_k, new_v self._idx = prev # Trim if needed trim_size = self.keys.shape[2] - self.max_size if trim_size > 0: self.keys = self._trim(trim_size, self.keys) self.values = self._trim(trim_size, self.values) self._idx = self.max_size # Rotate if self._idx == self.max_size: self._idx = self.keep # Assign self.keys[..., self._idx : self._idx + 1, :] = keys assert self.values is not None self.values[..., self._idx : self._idx + 1, :] = values self.offset += 1 self._idx += 1 # If the buffer is not full, slice off the end if self.offset < self.max_size: return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] return self.keys, self.values def reset(self) -> None: self.offset = 0 self._idx = 0 @property def state(self) -> tuple[mx.array | None, mx.array | None]: return self.keys, self.values @dataclass class BaseModelArgs: @classmethod def from_dict(cls, params: dict[str, Any]): return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) ================================================ FILE: kyuteye_mlx/kyuteye_mlx/modules/transformer.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
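# Building blocks of the temporal transformer used by Lm: self-attention with optional
# rotary embeddings backed by a KVCache / RotatingKVCache, gated (SiLU) or plain
# (GELU-approx) feed-forward blocks, and an optional GatedCrossAttention over the image
# embeddings when cross_attention is set in TransformerConfig.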
import mlx.core as mx import mlx.nn as nn from jaxtyping import BFloat16 from .config import TransformerConfig from .cross_attention import GatedCrossAttention from .kv_cache import KVCache, RotatingKVCache, XACache class Attention(nn.Module): def __init__(self, cfg: TransformerConfig) -> None: super().__init__() num_kv = cfg.num_heads // cfg.kv_repeat out_dim = cfg.d_model + 2 * num_kv * cfg.d_model // cfg.num_heads self.cfg = cfg self.in_proj = nn.Linear(cfg.d_model, out_dim, bias=cfg.bias_attn) self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias_attn) self.scale = cfg.head_dim ** (-0.5) self.rope = None if cfg.positional_embedding == "rope": self.rope = nn.RoPE(cfg.head_dim, traditional=True, base=cfg.max_period) def __call__( self, xs: BFloat16[mx.array, "batch time features"], cache: KVCache | RotatingKVCache, mask: BFloat16[mx.array, "batch kv_size features"] | None = None, ) -> BFloat16[mx.array, "batch time features"]: assert self.cfg.kv_repeat == 1, "only kv_repeat==1 is supported" b, t, hd = xs.shape qkv = self.in_proj(xs).reshape(b, t, 3, self.cfg.num_heads, self.cfg.head_dim) q = qkv[:, :, 0].transpose(0, 2, 1, 3) k = qkv[:, :, 1].transpose(0, 2, 1, 3) v = qkv[:, :, 2].transpose(0, 2, 1, 3) if self.rope is not None: q = self.rope(q, offset=cache.offset) k = self.rope(k, offset=cache.offset) k, v = cache.update_and_fetch(k, v) k_len = k.shape[2] k_target_len = t + min(self.cfg.context, k_len - t) if k_target_len < k_len: k = k[:, :, k_len - k_target_len :] v = v[:, :, k_len - k_target_len :] xs = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask) xs = xs.transpose(0, 2, 1, 3).reshape(b, t, hd) xs = self.out_proj(xs) return xs class MlpGating(nn.Module): def __init__(self, cfg: TransformerConfig) -> None: super().__init__() hidden = 2 * cfg.dim_feedforward // 3 if cfg.dim_feedforward == 4 * cfg.d_model: hidden = 11 * cfg.d_model // 4 self.linear_in = nn.Linear(cfg.d_model, 2 * hidden, bias=cfg.bias_ff) self.linear_out = nn.Linear(hidden, cfg.d_model, bias=cfg.bias_ff) def __call__( self, xs: BFloat16[mx.array, "batch time features"] ) -> BFloat16[mx.array, "batch time features"]: xs = self.linear_in(xs) b, t, _ = xs.shape xs = xs.reshape(b, t, 2, -1) return self.linear_out(nn.silu(xs[:, :, 0]) * xs[:, :, 1]) class MlpNoGating(nn.Module): def __init__(self, cfg: TransformerConfig) -> None: super().__init__() self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward, bias=cfg.bias_ff) self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model, bias=cfg.bias_ff) def __call__(self, xs: mx.array) -> mx.array: return self.linear2(nn.gelu_approx(self.linear1(xs))) class TransformerLayer(nn.Module): def __init__(self, cfg: TransformerConfig) -> None: super().__init__() self.cfg = cfg assert not cfg.use_conv_block, "conv-block is not supported" if cfg.gating: self.gating = MlpGating(cfg) else: self.gating = MlpNoGating(cfg) if cfg.norm == "layer_norm": self.norm1 = nn.LayerNorm(cfg.d_model, 1e-5) self.norm2 = nn.LayerNorm(cfg.d_model, 1e-5) elif cfg.norm == "rms_norm": self.norm1 = nn.RMSNorm(cfg.d_model, 1e-8) self.norm2 = nn.RMSNorm(cfg.d_model, 1e-8) else: raise ValueError(f"unsupported norm type {cfg.norm}") self.self_attn = Attention(cfg) if cfg.cross_attention: self.cross_attention = GatedCrossAttention(cfg) if cfg.norm == "rms_norm": self.norm_cross = nn.RMSNorm(cfg.d_model, 1e-8) elif cfg.norm == "layer_norm": self.norm_cross = nn.LayerNorm(cfg.d_model, 1e-5) else: raise ValueError(f"unsupported norm type {cfg.norm}") else: self.cross_attention 
= None self.xa_start = cfg.xa_start def __call__( self, xs: BFloat16[mx.array, "batch time {self.cfg.d_model}"], cache: KVCache | RotatingKVCache, img_embeds: (BFloat16[mx.array, "batch kv_size {self.cfg.d_model}"] | None) = None, xa_cache: XACache | None = None, ) -> BFloat16[mx.array, "batch time {self.cfg.d_model}"]: xs = xs + self.self_attn(self.norm1(xs), cache=cache) if self.cross_attention is not None: if xa_cache is None: raise ValueError("xa_cache should never be None when using cross attention.") if cache.offset >= self.xa_start: xs = xs + self.cross_attention(self.norm_cross(xs), xa_cache=xa_cache, kv=img_embeds) xs = xs + self.gating(self.norm2(xs)) return xs class ImagePrefix(nn.Module): def __init__(self, cfg: TransformerConfig) -> None: super().__init__() self.cfg = cfg self.norm_xa = nn.RMSNorm(cfg.d_model, 1e-8) self.proj_xa = nn.Linear(cfg.img_emb_dim, cfg.d_model, bias=True) def __call__( self, xa: BFloat16[mx.array, "batch kv_size {self.cfg.img_emb_dim}"] ) -> BFloat16[mx.array, "batch kv_size {self.cfg.d_model}"]: xa = self.proj_xa(xa) xa = self.norm_xa(xa) return xa class Transformer(nn.Module): def __init__(self, cfg: TransformerConfig, with_img_prefix: bool = False) -> None: super().__init__() self.cfg = cfg self.layers = [TransformerLayer(cfg=cfg) for _ in range(cfg.num_layers)] if with_img_prefix: self.image_prefix = ImagePrefix(cfg) def __call__( self, xs: BFloat16[mx.array, "batch time {self.cfg.d_model}"], cache: list[KVCache] | list[RotatingKVCache], img_embeds: (BFloat16[mx.array, "batch kv_size {self.cfg.img_emb_dim}"] | None) = None, xa_cache: XACache | None = None, ) -> BFloat16[mx.array, "batch time {self.cfg.d_model}"]: if img_embeds is not None and xa_cache is not None and not xa_cache.is_set: img_embeds = self.image_prefix(img_embeds) else: img_embeds = None for layer, c in zip(self.layers, cache): xs = layer(xs, cache=c, xa_cache=xa_cache, img_embeds=img_embeds) return xs def make_cache(self) -> list[KVCache]: num_kv_heads = self.cfg.num_heads // self.cfg.kv_repeat return [KVCache(head_dim=self.cfg.head_dim, n_kv_heads=num_kv_heads) for _ in self.layers] def make_rot_cache(self) -> list[RotatingKVCache]: num_kv_heads = self.cfg.num_heads // self.cfg.kv_repeat return [ RotatingKVCache( head_dim=self.cfg.head_dim, n_kv_heads=num_kv_heads, max_size=self.cfg.max_seq_len, ) for _ in self.layers ] ================================================ FILE: kyuteye_mlx/kyuteye_mlx/py.typed ================================================ ================================================ FILE: kyuteye_mlx/kyuteye_mlx/quantize.py ================================================ # /// script # requires-python = ">=3.10" # dependencies = [ # "fire", # "mlx==0.18.1", # "safetensors >= 0.4.0, < 0.5", # ] # /// # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
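# --- Editor's note (not part of the repository) -------------------------------
# A hypothetical invocation sketch for the quantize() entry point defined
# below (the `quantize` console script in pyproject.toml points at
# kyuteye_mlx.quantize:main, which wraps it with fire). The file path is a
# placeholder.
def _editor_example_quantize() -> None:
    from kyuteye_mlx.quantize import quantize

    # With out_file=None, the default naming rule below writes the result next
    # to the input, e.g. "moshika-vis.safetensors" -> "moshika-vis.q8.safetensors".
    quantize(
        model_file="moshika-vis.safetensors",  # placeholder path
        img_embed="siglip",
        bits=8,
        group_size=64,
    )
# -------------------------------------------------------------------------------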
from typing import Literal, Optional

import fire
import mlx.core as mx
from mlx import nn
from mlx.utils import tree_flatten

from kyuteye_mlx import models
from kyuteye_mlx.utils.loading import (
    remove_shared_weights,
    repeat_shared_weights,
    split_embedder_weights,
)


def quantize(
    model_file: str,
    out_file: Optional[str] = None,
    img_embed: Literal["siglip", "pixtral"] = "siglip",
    bits: int = 8,
    group_size: int = 64,
    quantize_embedder: bool = False,
) -> None:
    if out_file is None:
        out_file = model_file.replace(".safetensors", f".q{bits}.safetensors")
    weights = mx.load(model_file)
    if img_embed == "siglip":
        lm_config = models.config_siglip()
        if quantize_embedder:
            from kyuteye_mlx.models.siglip import SiglipWrapper

            embedder = SiglipWrapper()
        else:
            embedder = None
    else:
        lm_config = models.config_pixtral()
        if quantize_embedder:
            from kyuteye_mlx.models.pixtral import PixtralWrapper

            embedder = PixtralWrapper()
        else:
            embedder = None
    model = models.Lm(lm_config)
    model.set_dtype(mx.bfloat16)
    weights = repeat_shared_weights(weights, lm_config.transformer.num_layers)
    weights, embed_weights = split_embedder_weights(weights)
    model.load_weights(list(weights.items()), strict=True)
    print("weights loaded")
    nn.quantize(model, bits=bits, group_size=group_size)
    if quantize_embedder:
        assert embedder is not None
        nn.quantize(embedder, bits=bits, group_size=group_size)
        embed_weights = dict(tree_flatten(embedder.parameters()))
    print(f"saving the quantized q{bits} weights in {out_file}")
    new_weights = dict(tree_flatten(model.parameters()))
    new_weights = remove_shared_weights(new_weights, lm_config.transformer.num_layers)
    # Re-adding prefix for consistency with other scripts.
    new_weights.update({"img_embedder." + k: v for k, v in embed_weights.items()})
    mx.save_safetensors(out_file, new_weights)


def main():
    fire.Fire(quantize)


if __name__ == "__main__":
    main()


================================================ FILE: kyuteye_mlx/kyuteye_mlx/utils/__init__.py ================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""Utilities."""

from .sampling import Sampler


================================================ FILE: kyuteye_mlx/kyuteye_mlx/utils/loading.py ================================================
import mlx.core as mx


def repeat_shared_weights(weights: dict[str, mx.array], num_layers: int) -> dict[str, mx.array]:
    for layer_idx in range(1, num_layers):
        for srckey in [
            "transformer.layers.0.cross_attention.mha.kv_proj",
            "transformer.layers.0.cross_attention.mha.q_proj",
            "transformer.layers.0.cross_attention.mha.out_proj",
        ]:
            for subkey in ["weight", "scales", "biases"]:
                k = srckey + "."
+ subkey if k in weights: weights[k.replace(".0.", f".{layer_idx}.")] = weights[k] return weights def remove_shared_weights(weights: dict[str, mx.array], num_layers: int) -> dict[str, mx.array]: for layer_idx in range(1, num_layers): k1 = "transformer.layers.0.cross_attention.mha.kv_proj.weight" weights.pop(k1.replace(".0.", f".{layer_idx}.")) return weights def split_embedder_weights( weights: dict[str, mx.array], ) -> tuple[dict[str, mx.array], dict[str, mx.array]]: embedder_weights = {} model_weights = {} for k, v in weights.items(): if k.startswith("img_embedder."): embedder_weights[k[len("img_embedder.") :]] = v else: model_weights[k] = v return model_weights, embedder_weights ================================================ FILE: kyuteye_mlx/kyuteye_mlx/utils/profiling.py ================================================ from typing import Callable import line_profiler PROFILING_ENABLED = False profile: line_profiler.LineProfiler | Callable if PROFILING_ENABLED: profile = line_profiler.LineProfiler() else: def profile(x: Callable) -> Callable: return x ================================================ FILE: kyuteye_mlx/kyuteye_mlx/utils/sampling.py ================================================ # Taken from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass from functools import partial import mlx.core as mx from jaxtyping import BFloat16, UInt32 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def top_p_sampling( logits: BFloat16[mx.array, "batch vocab"], top_p: float, temperature: float ) -> UInt32[mx.array, "batch"]: """ Apply top-p (nucleus) sampling to logits. Args: logits: The logits from the model's output. top_p: The cumulative probability threshold for top-p filtering. temperature: Temperature parameter for softmax distribution reshaping. Returns: token selected based on the top-p criterion. 
""" # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 # noqa probs = mx.softmax(logits * (1 / temperature), axis=-1) # sort probs in ascending order sorted_indices = mx.argsort(probs, axis=-1) sorted_probs = probs[..., sorted_indices.squeeze(0)] cumulative_probs = mx.cumsum(sorted_probs, axis=-1) # select tokens with cumulative probs below threshold top_probs = mx.where( cumulative_probs > 1 - top_p, sorted_probs, 0, ) sorted_token = mx.random.categorical(mx.log(top_probs)) token = sorted_indices.squeeze(0)[sorted_token] return token @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def categorical_sampling(logits: BFloat16[mx.array, "batch vocab"], temp: float) -> UInt32[mx.array, "batch"]: return mx.random.categorical(logits * (1 / temp)) @dataclass class Sampler: temp: float top_p: float def __call__( self, logits: BFloat16[mx.array, "batch vocab"] ) -> tuple[UInt32[mx.array, "batch"], BFloat16[mx.array, "batch vocab"]]: logit_bias: dict[int, float] | None = None if logit_bias: indices = mx.array(list(logit_bias.keys())) values = mx.array(list(logit_bias.values())) logits[:, indices] += values if self.temp == 0: token = mx.argmax(logits, axis=-1) else: if self.top_p > 0 and self.top_p < 1.0: token = top_p_sampling(logits, self.top_p, self.temp) else: token = categorical_sampling(logits, self.temp) logprobs = logits - mx.logsumexp(logits) return token, logprobs ================================================ FILE: kyuteye_mlx/pixtral-12b-8bit.config ================================================ {"model_type": "pixtral", "num_hidden_layers": 24, "hidden_size": 1024, "head_dim": 64, "intermediate_size": 4096, "num_attention_heads": 16, "image_size": 1024, "patch_size": 16, "projection_dim": 768, "vocab_size": 32000, "num_channels": 3, "rms_norm_eps": 1e-05, "rope_theta": 10000.0} ================================================ FILE: kyuteye_mlx/pyproject.toml ================================================ [project] name = "kyuteye_mlx" requires-python = ">= 3.10,<3.13" description = "Kyutai with an 'eye', but running on macOS" dependencies = [ "numpy >= 2.1.0, < 2.2", "safetensors >= 0.4.0, < 0.5", "huggingface-hub >= 0.24, < 0.25", "rustymimi == 0.2.2", "sentencepiece == 0.2", "sounddevice == 0.5", "sphn >= 0.1.4", # Do not change this version of mlx. All the others up # to 0.22.1 are slower for this codebase. 
"mlx==0.23.1", "aiohttp>=3.10.5, <3.11", "pillow", "line-profiler>=4.2.0", "rich>=13.9.4", "packaging>=24.2", "jaxtyping==0.3.0", "beartype>=0.19.0", "fire>=0.7.0", ] authors = [ { name = "Gabriel de Marmiesse", email = "gabriel@kyutai.org" }, { name = "Moritz Boehle", email = "moritz@kyutai.org" } ] license = {text = "MIT"} dynamic = ["version"] readme = "README.md" [project.scripts] server = "kyuteye_mlx.local_web:main" sanity-check = "kyuteye_mlx.local_web:sanity_check" quantize = "kyuteye_mlx.quantize:main" benchmark = "kyuteye_mlx.benchmark:main" [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" [tool.setuptools.dynamic] version = {attr = "kyuteye_mlx.__version__"} [tool.ruff] line-length = 110 [tool.mypy] ignore_missing_imports = true disallow_untyped_defs = true [[tool.mypy.overrides]] module = "kyuteye_mlx.mlx_vlm.*" disallow_untyped_defs = false [dependency-groups] dev = [ "mypy>=1.11.2", "pylint>=3.3.4", "pytest>=8.3.4", "torch>=2.3.0", "moshi==0.1.0", "ruff>=0.9.7", "monkeytype>=23.3.0", ] ================================================ FILE: kyuteye_mlx/siglip448.config ================================================ {"model_type": "siglip_vision_model", "num_hidden_layers": 27, "hidden_size": 1152, "intermediate_size": 4304, "num_attention_heads": 16, "patch_size": 14, "projection_dim": 2304, "image_size": 448, "num_channels": 3, "layer_norm_eps": 1e-06} ================================================ FILE: kyuteye_mlx/tests/test_siglip.py ================================================ import mlx.core as mx import numpy as np import torch from transformers import AutoModelForImageTextToText, AutoProcessor from kyuteye_mlx.models.siglip import SiglipWrapper def convert_weights_for_mlx(weights: dict[str, torch.Tensor]) -> dict[str, mx.array]: new_weights = {} for k, v in weights.items(): new_key = k.removeprefix("vision_model.") new_key = "model." + new_key new_weights[new_key] = mx.array(v) new_weights["model.embeddings.patch_embedding.weight"] = new_weights[ "model.embeddings.patch_embedding.weight" ].transpose(0, 2, 3, 1) return new_weights @torch.no_grad() def test_siglip_weights_conversion() -> None: model_id = "google/paligemma2-3b-pt-448" processor = AutoProcessor.from_pretrained(model_id) image_processor = processor.image_processor model = AutoModelForImageTextToText.from_pretrained(model_id) np.random.seed(99) img = np.random.randint(0, 255, size=(448, 448, 3), dtype=np.uint8) inp = image_processor(images=img, return_tensors="pt") out_pytorch = model.vision_tower(pixel_values=inp["pixel_values"]) out_pytorch_as_np = out_pytorch.last_hidden_state.detach().numpy() as_mlx_weights = convert_weights_for_mlx(model.vision_tower.state_dict()) siglip_wrapper = SiglipWrapper() siglip_wrapper.load_weights(list(as_mlx_weights.items()), strict=True) output = siglip_wrapper(mx.array(img)) out_mlx_as_np = np.array(output, copy=False) diff_average = np.mean(np.abs(out_pytorch_as_np - out_mlx_as_np)) assert diff_average < 2e-5, f"Average difference between the two embeddings is {diff_average}" ================================================ FILE: kyuteye_pt/.pylintrc ================================================ [MAIN] # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no # Clear in-memory caches upon conclusion of linting. 
Useful if running pylint # in a server-like mode. clear-cache-post-run=no # Load and enable all available extensions. Use --list-extensions to see a list # all available extensions. #enable-all-extensions= # In error mode, messages with a category besides ERROR or FATAL are # suppressed, and no reports are done by default. Error mode is compatible with # disabling specific errors. #errors-only= # Always return a 0 (non-error) status code, even if lint errors are found. # This is primarily useful in continuous integration scripts. #exit-zero= # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. extension-pkg-allow-list= # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. (This is an alternative name to extension-pkg-allow-list # for backward compatibility.) extension-pkg-whitelist= # Return non-zero exit code if any of these messages/categories are detected, # even if score is above --fail-under value. Syntax same as enable. Messages # specified are enabled, while categories only check already-enabled messages. fail-on= # Specify a score threshold under which the program will exit with error. fail-under=10 # Interpret the stdin as a python script, whose filename needs to be passed as # the module_or_package argument. #from-stdin= # Files or directories to be skipped. They should be base names, not paths. ignore=CVS # Add files or directories matching the regular expressions patterns to the # ignore-list. The regex matches against paths and can be in Posix or Windows # format. Because '\\' represents the directory delimiter on Windows systems, # it can't be used as an escape character. ignore-paths= # Files or directories matching the regular expression patterns are skipped. # The regex matches against base names, not paths. The default value ignores # Emacs file locks ignore-patterns=^\.# # List of module names for which member attributes should not be checked and # will not be imported (useful for modules/projects where namespaces are # manipulated during runtime and thus existing member attributes cannot be # deduced by static analysis). It supports qualified module names, as well as # Unix pattern matching. ignored-modules= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use, and will cap the count on Windows to # avoid hangs. jobs=1 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or # complex, nested conditions. limit-inference-results=100 # List of plugins (as comma separated values of python module names) to load, # usually to register additional checkers. load-plugins= # Pickle collected data for later comparisons. persistent=yes # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. py-version=3.10 # Discover python modules and packages in the file system subtree. recursive=no # Add paths to the list of the source roots. Supports globbing patterns. 
The # source root is an absolute path or a path relative to the current working # directory used to determine a package namespace for modules located under the # source root. source-roots= # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. suggestion-mode=yes # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # In verbose mode, extra non-checker-related info will be displayed. #verbose= [BASIC] # Naming style matching correct argument names. argument-naming-style=snake_case # Regular expression matching correct argument names. Overrides argument- # naming-style. If left empty, argument names will be checked with the set # naming style. #argument-rgx= # Naming style matching correct attribute names. attr-naming-style=snake_case # Regular expression matching correct attribute names. Overrides attr-naming- # style. If left empty, attribute names will be checked with the set naming # style. #attr-rgx= # Bad variable names which should always be refused, separated by a comma. bad-names=foo, bar, baz, toto, tutu, tata # Bad variable names regexes, separated by a comma. If names match any regex, # they will always be refused bad-names-rgxs= # Naming style matching correct class attribute names. class-attribute-naming-style=any # Regular expression matching correct class attribute names. Overrides class- # attribute-naming-style. If left empty, class attribute names will be checked # with the set naming style. #class-attribute-rgx= # Naming style matching correct class constant names. class-const-naming-style=UPPER_CASE # Regular expression matching correct class constant names. Overrides class- # const-naming-style. If left empty, class constant names will be checked with # the set naming style. #class-const-rgx= # Naming style matching correct class names. class-naming-style=PascalCase # Regular expression matching correct class names. Overrides class-naming- # style. If left empty, class names will be checked with the set naming style. #class-rgx= # Naming style matching correct constant names. const-naming-style=UPPER_CASE # Regular expression matching correct constant names. Overrides const-naming- # style. If left empty, constant names will be checked with the set naming # style. #const-rgx= # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=-1 # Naming style matching correct function names. function-naming-style=snake_case # Regular expression matching correct function names. Overrides function- # naming-style. If left empty, function names will be checked with the set # naming style. #function-rgx= # Good variable names which should always be accepted, separated by a comma. good-names=i, j, k, ex, Run, _ # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted good-names-rgxs= # Include a hint for the correct naming format with invalid-name. include-naming-hint=no # Naming style matching correct inline iteration names. inlinevar-naming-style=any # Regular expression matching correct inline iteration names. Overrides # inlinevar-naming-style. If left empty, inline iteration names will be checked # with the set naming style. #inlinevar-rgx= # Naming style matching correct method names. method-naming-style=snake_case # Regular expression matching correct method names. 
Overrides method-naming- # style. If left empty, method names will be checked with the set naming style. #method-rgx= # Naming style matching correct module names. module-naming-style=snake_case # Regular expression matching correct module names. Overrides module-naming- # style. If left empty, module names will be checked with the set naming style. #module-rgx= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. # These decorators are taken in consideration only for invalid-name. property-classes=abc.abstractproperty # Regular expression matching correct type alias names. If left empty, type # alias names will be checked with the set naming style. #typealias-rgx= # Regular expression matching correct type variable names. If left empty, type # variable names will be checked with the set naming style. #typevar-rgx= # Naming style matching correct variable names. variable-naming-style=snake_case # Regular expression matching correct variable names. Overrides variable- # naming-style. If left empty, variable names will be checked with the set # naming style. #variable-rgx= [CLASSES] # Warn about protected attribute access inside special methods check-protected-access-in-special-methods=no # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp, asyncSetUp, __post_init__ # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs [DESIGN] # List of regular expressions of class ancestor names to ignore when counting # public methods (see R0903) exclude-too-few-public-methods= # List of qualified class names to ignore when counting class parents (see # R0901) ignored-parents= # Maximum number of arguments for function / method. max-args=5 # Maximum number of attributes for a class (see R0902). max-attributes=7 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 # Maximum number of branch for function / method body. max-branches=12 # Maximum number of locals for function / method body. max-locals=15 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of return / yield for function / method body. max-returns=6 # Maximum number of statements in function / method body. max-statements=50 # Minimum number of public methods for a class (see R0903). min-public-methods=2 [EXCEPTIONS] # Exceptions that will emit a warning when caught. overgeneral-exceptions=builtins.BaseException,builtins.Exception [FORMAT] # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )??$ # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # String used as indentation unit. 
This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Maximum number of characters on a single line. max-line-length=100 # Maximum number of lines in a module. max-module-lines=1200 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no [IMPORTS] # List of modules that can be imported at any level, not just the top level # one. allow-any-import-level= # Allow explicit reexports by alias from a package __init__. allow-reexport-from-package=no # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=no # Deprecated modules which should not be used, separated by a comma. deprecated-modules= # Output a graph (.gv or any supported image format) of external dependencies # to the given file (report RP0402 must not be disabled). ext-import-graph= # Output a graph (.gv or any supported image format) of all (i.e. internal and # external) dependencies to the given file (report RP0402 must not be # disabled). import-graph= # Output a graph (.gv or any supported image format) of internal dependencies # to the given file (report RP0402 must not be disabled). int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant # Couples of modules and preferred modules, separated by a comma. preferred-modules= [LOGGING] # The type of string formatting that logging methods do. `old` means using % # formatting, `new` is for `{}` formatting. logging-format-style=old # Logging modules to check that the string format arguments are in logging # function parameter format. logging-modules=logging [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, # UNDEFINED. confidence=HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once). You can also use "--disable=all" to # disable everything first and then re-enable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". disable=raw-checker-failed, bad-inline-option, locally-disabled, file-ignored, suppressed-message, useless-suppression, deprecated-pragma, use-symbolic-message-instead, use-implicit-booleaness-not-comparison-to-string, use-implicit-booleaness-not-comparison-to-zero, too-many-locals, unspecified-encoding, too-many-arguments, too-many-instance-attributes, too-many-branches, too-many-statements, too-many-return-statements, too-many-public-methods, too-few-public-methods, use-dict-literal, unnecessary-lambda-assignment, too-many-function-args # Enable the message, report, category or checker with the given id(s). 
You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. enable= [METHOD_ARGS] # List of qualified names (i.e., library.method) which require a timeout # parameter e.g. 'requests.api.get,requests.api.post' timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME, XXX, TODO # Regular expression of note tags to take in consideration. notes-rgx= [REFACTORING] # Maximum number of nested blocks for function / method body max-nested-blocks=5 # Complete name of functions that never returns. When checking for # inconsistent-return-statements if a never returning function is called then # it will be considered as an explicit return statement and no message will be # printed. never-returning-functions=sys.exit,argparse.parse_error [REPORTS] # Python expression which should return a score less than or equal to 10. You # have access to the variables 'fatal', 'error', 'warning', 'refactor', # 'convention', and 'info' which contain the number of messages in each # category, as well as 'statement' which is the total number of statements # analyzed. This score is used by the global evaluation report (RP0004). evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details. msg-template= # Set the output format. Available formats are: text, parseable, colorized, # json2 (improved json format), json (old json format) and msvs (visual # studio). You can also give a reporter class, e.g. # mypackage.mymodule.MyReporterClass. #output-format= # Tells whether to display a full report or only the messages. reports=no # Activate the evaluation score. score=yes [SIMILARITIES] # Comments are removed from the similarity computation ignore-comments=yes # Docstrings are removed from the similarity computation ignore-docstrings=yes # Imports are removed from the similarity computation ignore-imports=yes # Signatures are removed from the similarity computation ignore-signatures=yes # Minimum lines number of a similarity. min-similarity-lines=12 [SPELLING] # Limits count of emitted suggestions for spelling mistakes. max-spelling-suggestions=4 # Spelling dictionary name. No available dictionaries : You need to install # both the python package and the system dependency for enchant to work. spelling-dict= # List of comma separated words that should be considered directives if they # appear at the beginning of a comment and should not be checked. spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains the private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to the private dictionary (see the # --spelling-private-dict-file option) instead of raising a message. spelling-store-unknown-words=no [STRING] # This flag controls whether inconsistent-quotes generates a warning when the # character used as a quote delimiter is used inconsistently within a module. 
check-quote-consistency=no # This flag controls whether the implicit-str-concat should generate a warning # on implicit string concatenation in sequences defined over several lines. check-str-concat-over-line-jumps=no [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # Tells whether to warn about missing members when the owner of the attribute # is inferred to be None. ignore-none=yes # This flag controls whether pylint should warn about no-member and similar # checks whenever an opaque object is returned when inferring. The inference # can return multiple potential results while evaluating a Python object, but # some branches might not be evaluated, which results in partial inference. In # that case, it might be useful to still emit no-member and other checks for # the rest of the inferred objects. ignore-on-opaque-inference=yes # List of symbolic message names to ignore for Mixin members. ignored-checks-for-mixins=no-member, not-async-context-manager, not-context-manager, attribute-defined-outside-init # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. missing-member-hint=yes # The minimum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 # The total number of similar names that should be taken in consideration when # showing a hint for a missing member. missing-member-max-choices=1 # Regex pattern to define which classes are considered mixins. mixin-class-rgx=.*[Mm]ixin # List of decorators that change the signature of a decorated function. signature-mutators= [VARIABLES] # List of additional names supposed to be defined in builtins. Remember that # you should avoid defining new builtins when possible. additional-builtins= # Tells whether unused global variables should be treated as a violation. allow-global-unused-variables=yes # List of names allowed to shadow builtins allowed-redefined-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_, _cb # A regular expression matching the name of dummy variables (i.e. expected to # not be used). dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. ignored-argument-names=_.*|^ignored_|^unused_ # Tells whether we should check for unused import in __init__ files. init-import=no # List of qualified module names which can have objects that can redefine # builtins. 
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io ================================================ FILE: kyuteye_pt/LICENSE.md ================================================ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: kyuteye_pt/README.md ================================================ # MoshiVis - PyTorch See the [top-level README.md][main_repo] for more information on MoshiVis. This is the PyTorch implementation for MoshiVis. ## License The present code is provided under the MIT license. ## Citation If you use MoshiVis, please cite this repository and the Moshi paper. ``` @article{kyutai2025moshivis, author = {Amélie Royer and Moritz Böhle and Gabriel de Marmiesse and Laurent Mazaré and Alexandre Défossez and Neil Zeghidour and Patrick Pérez}, year = {2025}, title = {Vision-Speech Models: Teaching Speech Models to Converse about Images}, journal = {ArXiv}, url = {https://arxiv.org/abs/2503.15633} } @techreport{kyutai2024moshi, title={Moshi: a speech-text foundation model for real-time dialogue}, author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour}, year={2024}, eprint={2410.00037}, archivePrefix={arXiv}, primaryClass={eess.AS}, url={https://arxiv.org/abs/2410.00037}, } ``` [main_repo]: https://github.com/kyutai-labs/moshivis ================================================ FILE: kyuteye_pt/configs/moshika-vis.yaml ================================================ add_boi_eoi: false align_img_and_speech_tokens_dim: true encoder_name: siglip_gemma2_448 hf_repo: 'kyutai/moshika-vis-pytorch-bf16' image_size: 512 interpolation: bicubic mimi_codec: tokenizer-e351c8d8-checkpoint125.safetensors model: model.safetensors norm_extra: null norm_xa: rms_norm num_crossattended_tokens: -1 num_extra_tokens: 0 text_tokenizer: tokenizer_spm_32k_3.model xa_conditional_gating: true xa_delay: 0 xa_dim: null xa_start: start xa_end: end xa_gate_shared: false xa_gating: sigmoid xa_layers: [] xa_shared: true xa_step: 1 ================================================ FILE: kyuteye_pt/kyuteye/__init__.py ================================================ ================================================ FILE: kyuteye_pt/kyuteye/config/__init__.py ================================================ ================================================ FILE: kyuteye_pt/kyuteye/config/enums.py ================================================ """Knowledge base of useful fixed values used 
across the codebase""" from enum import Enum, unique from typing import List, Optional, Tuple @unique class ImageEncoder(Enum): """Encapsulate every image encoder""" CLIP_VIT = "clip_vit" CLIP_VIT_LARGE = "clip_vit_large" MOBILECLIP_S1 = "mobileclip_s1" MOBILECLIP_S2 = "mobileclip_s2" # original siglip SIGLIP = "siglip" # Siglip Pretrained: Phase 2 of PaliGemma 1 training SIGLIP_GEMMA1_224 = "siglip_gemma1_224" # Same but for PaliGemma 2 SIGLIP_GEMMA2_224 = "siglip_gemma2_224" SIGLIP_GEMMA2_448 = "siglip_gemma2_448" SIGLIP_GEMMA2_896 = "siglip_gemma2_896" PIXTRAL = "pixtral" @property def out_dims(self) -> int: """Return the number of dimensions output by the given image encoder""" if self in { ImageEncoder.PIXTRAL, ImageEncoder.CLIP_VIT_LARGE, ImageEncoder.MOBILECLIP_S1, }: return 1024 if self == ImageEncoder.CLIP_VIT: return 512 if self == ImageEncoder.MOBILECLIP_S2: return 1280 if self in { ImageEncoder.SIGLIP, ImageEncoder.SIGLIP_GEMMA1_224, ImageEncoder.SIGLIP_GEMMA2_224, ImageEncoder.SIGLIP_GEMMA2_448, ImageEncoder.SIGLIP_GEMMA2_896, }: return 1152 raise NotImplementedError("Unknown image encoder", self.name) def to_rust(self) -> str: """Return the corresponding `ImageEncoder` enum name in the rust codebase""" if self == ImageEncoder.PIXTRAL: return "Pixtral" if self in { ImageEncoder.SIGLIP, ImageEncoder.SIGLIP_GEMMA1_224, ImageEncoder.SIGLIP_GEMMA2_224, }: return "Siglip224" if self == ImageEncoder.SIGLIP_GEMMA2_448: return "Siglip448" if self == ImageEncoder.SIGLIP_GEMMA2_896: return "Siglip896" if self == ImageEncoder.MOBILECLIP_S1: return "MobileclipS1" if self == ImageEncoder.MOBILECLIP_S2: return "MobileclipS2" raise ValueError( f"Image encoder {self.name} is not implemented in the Rust codebase" ) ================================================ FILE: kyuteye_pt/kyuteye/config/kyuteye_config.py ================================================ """Main configuration object used to configure the model and training pipeline""" import os from copy import deepcopy from dataclasses import asdict, fields from pathlib import Path from typing import Any, Dict, Optional, Sequence import torch import yaml from kyuteye.config.enums import ImageEncoder from kyuteye.config.subconfigs import ( FusionConfig, ImageEncoderConfig, LMConfig, MoshiConfig, ) from kyuteye.utils.dist_utils import print_main from kyuteye.utils.logging_utils import flatten_nested_dict class KyuteyeConfig: """Base class encapsulating all options for training and evaluating a multimodal model""" fuse: FusionConfig image: ImageEncoderConfig lm: LMConfig def __init__(self, **kwargs: Any): self._fields_to_sub: Dict[str, str] = {} # Define all modular subconfigs defined in subconfigs.py backup_kwargs = deepcopy(kwargs) self._subnames = [] for name, constructor in [ ("fuse", FusionConfig), ("image", ImageEncoderConfig), ("lm", LMConfig), ("moshi", MoshiConfig), ]: keys = {f.name for f in fields(constructor)} setattr( self, name, constructor( **{k: backup_kwargs.pop(k) for k in keys if k in backup_kwargs} ), ) self._subnames.append(name) # Extra arguments comming from the CLI argparser will be passed but not used # we still mark them to remember them if len(backup_kwargs): print_main( "\n[bold yellow]WARN:[/bold yellow] Found superfluous arguments " "passed to [cyan]KyuteyeConfig[/cyan]:\n" + "\n".join(f" - {k} = {v}" for k, v in backup_kwargs.items()) + "\n", rich=True, flush=True, ) # Map field names to the correct subconfig it belongs to for name in self._subnames: for f in fields(getattr(self, name)): if f.name in 
self._fields_to_sub: raise AssertionError( f"Subconfig {name} uses field {f.name} which is already in use" ) if f.name in self._subnames: raise AssertionError( f"Subconfig {name} uses field {f.name} which" " is already defined in KyuteyeConfig" ) self._fields_to_sub[f.name] = name # Postinit if self.fuse.num_crossattended_tokens == 0: self.image.norm_xa = None if self.fuse.num_extra_tokens == 0: self.image.norm_extra = None if self.fuse.xa_dim is None: if not self.fuse.align_img_and_speech_tokens_dim: self.fuse.xa_dim = ImageEncoder(self.image.encoder_name).out_dims def __getattribute__(self, name: str) -> Any: """Getattr with direct shortcut to all subconfigs' subfields""" if name not in ["_fields_to_sub", "_subnames"] and name in self._fields_to_sub: return getattr(getattr(self, self._fields_to_sub[name]), name) return super().__getattribute__(name) def __setattr__(self, name: str, value: Any) -> None: """Setattr with direct shortcut to all subconfigs' fields""" if hasattr(self, "_fields_to_sub") and name in self._fields_to_sub: setattr(getattr(self, self._fields_to_sub[name]), name, value) else: super().__setattr__(name, value) @property def moshi_constructor_kwargs(self) -> Dict[str, Any]: """Return constructor for constructing a MoshiVis model""" return dict( **asdict(self.moshi), xa_dim=self.fuse.xa_dim, **self.fuse.crossattention_kwargs, ) @classmethod def from_yml(cls, path: Path | str) -> "KyuteyeConfig": """Initialize current config from a yaml file""" return cls(**__load_yaml__(str(path))) def to_yml(self, path: Optional[Path | str] = None) -> None: """Save current config to a yaml file""" if path is None: path = self.output_dir path = str(path) __save_yaml__( { k: ( v.name if hasattr(v, "name") else ( str(v).replace("torch.", "") if isinstance(v, torch.dtype) else ( tuple(x.name for x in v) if k in {"train_dataset", "eval_dataset", "blind_dataset"} else v ) ) ) for k, v in self.to_dict().items() }, path, ) def print(self, flat: bool = False, only: Optional[Sequence[str]] = None) -> None: """Pretty print current config""" print_main("-" * 100, "[bold green]Config[/bold green]", rich=True) if flat: print_main( "\n".join( f"\t{k} = {v}" for k, v in self.to_dict(flat=True).items() if (only is None or k in only) ), rich=True, ) else: for name, subd in self.to_dict(flat=False).items(): if only is not None and name not in only: continue print_main("\n\t", "-" * 50, f"[cyan]{name}[/cyan]", rich=True) print_main( "\n".join(f"\t\t{k} = {v}" for k, v in subd.items()), rich=True, ) print_main("-" * 100) def to_dict(self, flat: bool = True) -> Dict[str, Any]: """Returns current config as a (flat) dictionary""" d = { subconfig: asdict(getattr(self, subconfig)) for subconfig in self._subnames } if flat: return flatten_nested_dict(d) return d def __load_yaml__(path: Path | str) -> Dict: """Load a dictionary from a YAML file :param path: Path to load the config from :return: the dictionary of kwargs that will be fed to `KyuteyeConfig` """ path = str(path) assert os.path.exists(path), f"Could not load config {path}: File does not exist" with open(path, "r") as stream: config = yaml.safe_load(stream) # KyuteyeConfig works with flattened dict as inputs config = flatten_nested_dict(config) # Yaml parse sequences sa list -> tuples config = {k: tuple(v) if isinstance(v, list) else v for k, v in config.items()} return config def __save_yaml__(config: Dict, path: Path | str) -> None: """Save a config dictionary to a YAML :param config: Kyuteye config converted to a dict :param path: Path to save the 
config to """ path = str(path) path = path + (".yml" if not path.endswith(".yml") else "") base_dir = os.path.abspath(os.path.dirname(path)) os.makedirs(base_dir, exist_ok=True) with open(path, "w") as stream: config = yaml.safe_dump(config, stream) ================================================ FILE: kyuteye_pt/kyuteye/config/subconfigs.py ================================================ """Modular configs for configuring a Kyuteye model training run""" from collections.abc import Iterable from dataclasses import dataclass from typing import Any, Dict, Literal, Optional, Tuple from kyuteye.config.enums import ImageEncoder from kyuteye.utils.dist_utils import print_main def __is_nonstring_iterable__(arg: Any) -> bool: return isinstance(arg, Iterable) and not isinstance(arg, str) @dataclass(frozen=True) class LMConfig: """Configure the model weights""" # if repo is None, then model, mimi_codec and tokenizer are # expected to point to local files hf_repo: Optional[str] = "kyutai/moshika-vis-pytorch-bf16" # MoshiVis model; Note that if the model doesn't contain weights # for the frozen image encoder, those will be loaded from audiocraft # directly model: str = "model.safetensors" # Mimi codec mimi_codec: str = "tokenizer-e351c8d8-checkpoint125.safetensors" # Tokenizer for Helium text_tokenizer: str = "tokenizer_spm_32k_3.model" @staticmethod def help(field_name: str) -> str: """Optional; returns the argparse's help message for each field name in this subconfig""" if field_name == "model_path": return ( "Path to .safetensors containing model weights (vision encoder + LLM)" ) if field_name == "mimi_codec": return "Path to .safetensors containing Mimi weights" if field_name == "text_tokenizer": return "Path to .safetensors containing text tokenizer" return "" @dataclass(frozen=False) class ImageEncoderConfig: """Configure the image encoder for MoshiVis""" # Main image backbone to load encoder_name: str = "pixtral" interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = ( "bicubic" ) # Whether to add dropout after the learned image linear projection image_size: int = 256 # normalization used after projecting the extra tokens norm_extra: Optional[Literal["layer_norm", "rms_norm"]] = "rms_norm" # normalization used after projecting the xa tokens norm_xa: Optional[Literal["layer_norm", "rms_norm"]] = "rms_norm" def __post_init__(self) -> None: assert self.image_size > 0 self.encoder_name = self.encoder_name.lower() try: ImageEncoder(self.encoder_name) except ValueError as e: raise ValueError(f"Unknown image encoder {self.encoder_name}") from e self.aug_strategy = "Pixtral" if self.encoder_name == "pixtral" else "None" @staticmethod def help(field_name: str) -> str: """Optional; returns the argparse's help message for each field name in this subconfig""" if field_name == "encoder_name": return "Name of the pretrained image encoder to use" if field_name == "interpolation": return "Interpolation algorithm for image resizing" if field_name == "image_size": return "Input image size to the encoder" return "" @dataclass(frozen=True) class MoshiConfig: """Configure the backbone Moshi""" dim: int = 4096 text_card: int = 32000 padding_token_id: int = 3 n_q: int = 16 dep_q: int = 8 audio_card: int = 2048 num_heads: int = 32 num_layers: int = 32 hidden_scale: float = 4.125 causal: bool = True context: int = 3000 max_period: int = 10000 gating: bool = True activation: str = "silu" norm: str = "rms_norm_f32" positional_embedding: str = "rope" depformer: bool = True depformer_dim: int = 1024 
depformer_dim_feedforward: int = int(4.125 * 1024) depformer_num_heads: int = 16 depformer_num_layers: int = 6 depformer_multi_linear: bool = True depformer_context: int = 8 depformer_gating: bool = True depformer_activation: str = "silu" depformer_pos_emb: str = "none" depformer_weights_per_step: bool = True delays: Tuple[int, ...] = (0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1) @staticmethod def help(field_name: str) -> str: """Optional; returns the argparse's help message for each field name in this subconfig""" if field_name == "model_path": return ( "Path to .safetensors containing model weights (vision encoder + LLM)" ) if field_name == "mimi_codec": return "Path to .safetensors containing Mimi weights" if field_name == "text_tokenizer": return "Path to .safetensors containing text tokenizer" return "" @dataclass(frozen=False) class FusionConfig: """Configures how we integrate the image information in MoshiVis""" num_extra_tokens: int = -1 add_boi_eoi: bool = False token_insertion: Optional[Literal["prefix"]] = "prefix" num_crossattended_tokens: int = 0 align_img_and_speech_tokens_dim: bool = True xa_dim: Optional[int] = None xa_start: Optional[Literal["start", "boi", "eoi"]] = None xa_end: Optional[Literal["end", "eoi"] | int] = None xa_step: int = 1 xa_delay: int = 0 xa_shared: bool = True xa_gate_shared: bool = False xa_gating: Literal["tanh", "sigmoid", "none"] = "tanh" xa_conditional_gating: bool = False xa_layers: Tuple[int, ...] = () xa_extended_layer_dims: Optional[int] = None xa_extended_layer_embed_dims: int = 1024 xa_extended_layer_p_norm: int = 2 @staticmethod def help(field_name: str) -> str: """Optional; returns the argparse's help message for each field name in this subconfig""" if field_name == "num_extra_tokens": return ( "Number of extra tokens (containing image information) to insert into" "the text+audio tokens. If -1, uses all the image tokens." ) if field_name == "token_insertion": return ( "An option used later in `fuse_utils` to determine where" " to insert the extra tokens" ) if field_name == "num_crossattended_tokens": return ( "Number of extra tokens (containing image information) to use as keys/values" " source in the cross-attention mechanism. If -1, uses all the image tokens" ) if field_name == "xa_shared": return ( "If True, the linear layers of the cross-attention mechanism" "are shared across layers" ) if field_name == "xa_gate_shared": return ( "If True, the gating modules of the cross-attention mechanism" "are shared across layers" ) if field_name == "xa_gating": return "Type of multiplicative gating at the output of each cross-attention" if field_name == "xa_conditional_gating": return "Whether to make the gating input dependent" if field_name == "xa_layers": return ( "If given, specifies the subset of layers to apply cross-attention to" ) if field_name == "xa_start": return "At which token to start applying the cross-attention" if field_name == "xa_end": return "Until which token to apply the cross-attention" if field_name == "xa_step": return "Applies cross-attention every `xa_step` token between xa_start and xa_end" if field_name == "xa_delay": return "Applies cross-attention to token t with the embedding from token t - xa_delay" if field_name == "xa_extended_layer_dims": return "By how many dimensions to extend the layers in xa_layers." 
return "" def __post_init__(self) -> None: """Check that the subconfig is valid""" if self.num_extra_tokens == 0 and self.num_crossattended_tokens == 0: print_main("[bold yellow]WARN:[/bold yellow]: No fusion mechanisms active") if not self.align_img_and_speech_tokens_dim and self.xa_dim is not None: print_main( "[bold yellow]WARN:[/bold yellow]: xa_dim is set hence it " "takes precedence over the value fed to align_img_and_speech_tokens" ) # Convert if isinstance(self.xa_layers, list): self.xa_layers = tuple(self.xa_layers) if self.xa_end is not None and self.xa_end not in {"end", "eoi"}: self.xa_end = int(self.xa_end) if self.num_extra_tokens == 0: self.add_boi_eoi = False if self.num_crossattended_tokens == 0: self.xa_start = None self.xa_end = None else: assert ( self.xa_start is not None ), "xa_start should be given to use cross-attention" assert ( self.xa_end is not None ), "xa_end should be given to use cross-attention" if self.num_extra_tokens == 0: assert ( self.xa_start not in { "boi", "eoi", } and self.xa_end != "eoi" ), "BoI and EoI tokens are not inserted if `num_extra_tokens` is 0" elif self.xa_start != "start": print_main( "[yellow]WARN:[/yellow]Making cross_attention start before BoI is weird", rich=True, ) if self.xa_step > 1: raise NotImplementedError("xa_step > 1 is not implemented yet") if self.xa_extended_layer_dims is not None: assert ( self.xa_extended_layer_dims > 0 ), "xa_extended_layer_dims should be positive" return @property def crossattention_kwargs(self) -> Dict[str, Any]: """Return crossattention related kwargs used for model initialization. These are passed to modules/cross_attention.py:GatedCrossAttention down the line""" return { # passed to Moshi/Helium constructor "cross_attention": self.num_crossattended_tokens != 0, "xa_layers": self.xa_layers, # Futher passed to GatedCrossAttention construction "xa_gating": self.xa_gating, "xa_conditional_gating": self.xa_conditional_gating, "xa_shared": self.xa_shared, "xa_gate_shared": self.xa_gate_shared, "xa_start": self.xa_start, "xa_end": self.xa_end, "xa_step": self.xa_step, "xa_delay": self.xa_delay, } ================================================ FILE: kyuteye_pt/kyuteye/models/__init__.py ================================================ ================================================ FILE: kyuteye_pt/kyuteye/models/docker-bake.hcl ================================================ ================================================ FILE: kyuteye_pt/kyuteye/models/helium.py ================================================ # pylint: disable=redefined-outer-name, pointless-string-statement """Port of Helium from Jax to Pytorch and then HF. The architecture is also made to be easily converted to the mimi/audiocraft codebase""" from typing import Any, Literal, Optional, Tuple import torch from kyuteye.modules.transformer import Transformer from kyuteye.modules.utils import ClampedEmbedding, NormalizationLayer class Helium(torch.nn.Module): """Jax -> Pytorch port of Helium (text LLM) :param dim: Inner dimension of the tokens :param num_head: Number of heads :param num_layers: Number of layers :param hidden_scale: Scale for the inner dimension of the FFN layes (wrt. `dim`) :param context: Context size for the model. 
This is only used to determine the automatic causal mask and also set the KV cache accordingly at inference :param cross_attention: Whether to add cross-attention layers :param card: Cardinality of the vocabulary :param output_card: (Optional) can be used to specify a different number of output tokens than card (which is used for the embedding layers). This is useful for audio models that add an extra *initial token* which is never predicted :param norm: Type of normalization layers to use :param positional_embedding: Type of positional embedding to use :param max_period: Maximum period / theta for RoPE embeddings :param causal: Whether to automatically force a causal mask in MHSA :param gating: If True, will use gated FFN (Swi-GLU like) :param activation: Activation to use for the FFN gating, if `gating` is True :param padding_token_id: Padding token ID of the tokenizer :param freeze_padding_embedding: If True, the embedding of the padding token ID will not receive gradients/updates during training :param device: Device to load the model on :param dtype: Dtype to define the model parameters as """ def __init__( self, # Architecture dim: int, num_heads: int, num_layers: int, hidden_scale: float = 4.125, context: int = 2048, cross_attention: bool = False, card: int = 32000, output_card: Optional[int] = None, norm: Literal[ "rms_norm", "rms_norm_f32", "real_rms_norm", "real_rms_norm_f32", ] = "real_rms_norm_f32", positional_embedding: Literal["none", "sin", "rope", "sin_rope"] = "rope", max_period: float = 10000, causal: bool = True, gating: bool = True, activation: str = "silu", # Padding token padding_token_id: int = 3, freeze_padding_embedding: bool = False, zero_token_id: int = -1, # Other kwargs device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs: Any, ): super().__init__() # Initial tokens projection self.dim = dim self.text_emb = ClampedEmbedding( card, dim, padding_idx=padding_token_id if freeze_padding_embedding else None, zero_idx=zero_token_id, dtype=dtype, device=device, ) # Output RMS Norm and linear layer self.out_norm = getattr(NormalizationLayer, norm.upper()).create_norm_fn( dim, device=device, dtype=dtype ) self.text_linear = torch.nn.Linear( dim, card if output_card is None else output_card, bias=False ) self.cross_attention = cross_attention # Main transformer self.transformer = Transformer( # Architecture d_model=dim, num_heads=num_heads, num_layers=num_layers, dim_feedforward=int(hidden_scale * dim), causal=causal, context=context, cross_attention=cross_attention, positional_embedding=positional_embedding, max_period=max_period, # Transformer Layer kwargs gating=gating, activation=activation, norm=norm, # Others device=device, dtype=dtype, **kwargs, ) def forward( self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_src: Optional[torch.Tensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, return_features: bool = False, ) -> Tuple[torch.Tensor, float]: """Forward function :param input_ids: Input tokens IDs with size `(batch_size, seq_length)` :param inputs_embeds: If given, this will skip input_ids and the initial embedding phase. Instead, `inputs_embeds` are directly fed to the transformer :param attention_mask: 0/1 Attention mask with shape `(batch_size, seq_length)`. Indicates which tokens to mask in the attention, e.g. 
padding tokens :param cross_attention_src: Cross-attention source with shape `(batch_size, seq_length, dim)` :param cross_attention_mask: Cross-attention mask with shape `(batch_size, seq_length)` :param return_features: If True, will skip the last classification layer and only output features with dimensions `(batch_size, seq_length, dim)` """ if inputs_embeds is None: assert input_ids is not None inputs_embeds = self.text_emb(input_ids) x, gate_weight = self.transformer( inputs_embeds, attention_mask=attention_mask, cross_attention_src=cross_attention_src, cross_attention_mask=cross_attention_mask, ) x = self.out_norm(x) if return_features: return x, gate_weight return self.text_linear(x), gate_weight ================================================ FILE: kyuteye_pt/kyuteye/models/hf_model_configs.py ================================================ # pylint: disable=protected-access """Configuration for HF-compliant models""" from typing import Any, Literal, Optional, Sequence from transformers import PretrainedConfig class HeliumConfig(PretrainedConfig): """Config class for Helium language models (LLM part)""" model_type = "Helium_v2" def __init__( self, dim: int = 1024, num_heads: int = 12, num_layers: int = 24, hidden_scale: int = 4, context: int = 2048, cross_attention: bool = False, card: int = 32000, output_card: Optional[int] = None, norm: Literal[ "rms_norm", "rms_norm_f32", "real_rms_norm", "real_rms_norm_f32", ] = "real_rms_norm_f32", positional_embedding: Literal["none", "sin", "rope", "sin_rope"] = "rope", freeze_padding_embedding: bool = False, max_period: float = 10000, causal: bool = True, gating: bool = True, activation: Literal[ "none", "identity", "sigmoid", "tanh", "relu", "leaky_relu", "elu", "gelu", "silu", "mish", "softsign", ] = "silu", bos_token_id: int = 1, eos_token_id: int = 2, pad_token_id: int = 3, **kwargs: Any, ): super().__init__( vocab_size=32000, hidden_size=dim, num_attention_heads=num_heads, num_hidden_layers=num_layers, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs, ) self.dim = dim self.num_heads = num_heads self.num_layers = num_layers self.hidden_scale = hidden_scale self.context = context self.cross_attention = cross_attention self.positional_embedding = positional_embedding self.card = card self.output_card = output_card self.norm = norm self.max_period = max_period self.causal = causal self.gating = gating self.activation = activation self.freeze_padding_embedding = freeze_padding_embedding class MoshiVisConfig(HeliumConfig): """Config for Moshi-Vis""" model_type = "Moshi_v1" def __init__( self, text_card: int = 32000, text_context: int = 3000, n_q: int = 8, n_q_per_source: Optional[int] = None, audio_card: int = 1024, depformer: bool = False, depformer_multi_linear: bool = False, depformer_pos_emb: Optional[Literal["none", "sin", "rope", "sin_rope"]] = None, depformer_dim: Optional[int] = None, depformer_dim_feedforward: Optional[int] = None, depformer_num_layers: Optional[int] = None, depformer_num_heads: Optional[int] = None, depformer_weights_per_step: bool = False, depformer_input_cumsum: bool = False, depformer_skip_self_attn: bool = False, delays: Optional[Sequence[int]] = None, same_initial: bool = False, text_loss_weight: float = 1.0, audio_loss_weight: float = 1.0, audio_other_channel_loss_weight: float = 1.0, audio_semantic_loss_weight: float = 100.0, audio_acoustic_loss_weight: float = 1.0, padding_loss_weight: float = 1.0, textonly_padding_loss_weight: float = 1.0, 
audio_padding_loss_weight: float = 1.0, sparsity_loss_weight: float = 0, mask_audio_codebooks_from: int = -1, add_pad_embed_text_only_other_tokens: bool = False, **kwargs: Any, ) -> None: super().__init__( **kwargs, ) # rename some args from HeliumConfig to explicitly # distinguish audio from text del self.card self.text_card = text_card del self.context self.text_context = text_context # add Audio config self.n_q = n_q self.n_q_per_source = n_q_per_source or n_q self.audio_card = audio_card self.depformer = depformer self.depformer_multi_linear = depformer_multi_linear self.depformer_pos_emb = depformer_pos_emb self.depformer_dim = depformer_dim self.depformer_dim_feedforward = depformer_dim_feedforward self.depformer_num_layers = depformer_num_layers self.depformer_num_heads = depformer_num_heads self.depformer_weights_per_step = depformer_weights_per_step self.depformer_input_cumsum = depformer_input_cumsum self.depformer_skip_self_attn = depformer_skip_self_attn self.delays = list(delays) if delays is not None else None self.same_initial = same_initial # loss weights for the different codebbooks self.text_loss_weight = text_loss_weight self.audio_loss_weight = audio_loss_weight self._audio_other_channel_loss_weight = audio_other_channel_loss_weight self._audio_semantic_loss_weight = audio_semantic_loss_weight self._audio_acoustic_loss_weight = audio_acoustic_loss_weight self.padding_loss_weight = padding_loss_weight self.textonly_padding_loss_weight = textonly_padding_loss_weight self.audio_padding_loss_weight = audio_padding_loss_weight self._sparsity_loss_weight = sparsity_loss_weight self.mask_audio_codebooks_from = mask_audio_codebooks_from self.add_pad_embed_text_only_other_tokens = add_pad_embed_text_only_other_tokens if self.mask_audio_codebooks_from >= 0: self.num_audio_tokens_in_loss_main = self.mask_audio_codebooks_from self.num_audio_tokens_in_loss_other = self.mask_audio_codebooks_from else: self.num_audio_tokens_in_loss_main = self.n_q_per_source - 1 self.num_audio_tokens_in_loss_other = (self.n_q - self.n_q_per_source) - 1 @property def total_audio_loss_weight(self) -> float: """Total weight used to normalize the losses on audio tokens""" return ( # weight for Moshi self._audio_semantic_loss_weight + self._audio_acoustic_loss_weight * self.num_audio_tokens_in_loss_main # weight for Other + self._audio_other_channel_loss_weight * ( self._audio_semantic_loss_weight + self._audio_acoustic_loss_weight * self.num_audio_tokens_in_loss_other ) ) @property def audio_semantic_loss_weight(self) -> float: """Loss weight set on the semantic token""" return ( self.audio_loss_weight * self._audio_semantic_loss_weight / (self.total_audio_loss_weight + 1e-6) ) @property def audio_acoustic_loss_weight(self) -> float: """Loss weight set on the semantic token""" return ( self.audio_loss_weight * self._audio_acoustic_loss_weight / (self.total_audio_loss_weight + 1e-6) ) @property def audio_other_semantic_loss_weight(self) -> float: """Loss weight set on audio codebooks for OTHER channel""" return self._audio_other_channel_loss_weight * self.audio_semantic_loss_weight @property def audio_other_acoustic_loss_weight(self) -> float: """Loss weight set on audio codebooks for OTHER channel""" return self._audio_other_channel_loss_weight * self.audio_acoustic_loss_weight @property def sparsity_loss_weight(self) -> float: """Loss weight set on the sparsity loss for the extended transformer.""" return self._sparsity_loss_weight ================================================ FILE: 
kyuteye_pt/kyuteye/models/image_projection.py ================================================ """Image encoders (CLIP, SigLIP)""" from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch from einops import rearrange from kyuteye.config.enums import ImageEncoder from kyuteye.modules.image_encoder import ( PixtralOutput, get_img_normalize, load_image_encoder, ) from kyuteye.modules.utils import NormalizationLayer if TYPE_CHECKING: from kyuteye.config.kyuteye_config import KyuteyeConfig class ImageProjection(torch.nn.Module): """ Takes in a batch of images and returns a batch of embeddings matching the LM token dimension, which are then fed to the LM, either by inserting the tokens in the stream, or by using them as source for cross-attention modules (or both). :param kyuteye_config: KyuteyeConfig object :param lm_model_dim: Output dimension (number of channels) for this module """ def __init__( self, kyuteye_config: "KyuteyeConfig", lm_model_dim: Optional[int], load_pretrained_encoder: bool = True, ): super().__init__() self.kyuteye_config = kyuteye_config try: self.encoder_type = getattr( ImageEncoder, self.kyuteye_config.image.encoder_name.upper() ) except AttributeError as e: raise NotImplementedError( f"Unknown image encoder {self.kyuteye_config.image.encoder_name}" ) from e # Number of output dimensions of the entire module (i.e. including # potential projection after the encoder) if self.kyuteye_config.xa_dim is not None: self.out_dim = self.kyuteye_config.xa_dim else: assert lm_model_dim is not None self.out_dim = lm_model_dim # Load the image encoder self.enc = load_image_encoder( self.encoder_type, pretrained=load_pretrained_encoder ) # Projection layer; there are two possible projection targets # A. for the extra tokens self.proj_extra = self.init_proj_module( self.kyuteye_config.fuse.num_extra_tokens ) # B.
for the cross attention self.proj_xa = self.init_proj_module( self.kyuteye_config.fuse.num_crossattended_tokens ) # Output normalizations self.norm_extra = self.init_norm_module(self.kyuteye_config.image.norm_extra) self.norm_xa = self.init_norm_module(self.kyuteye_config.image.norm_xa) @classmethod def from_config( cls, kyuteye_config: "KyuteyeConfig", lm_model_dim: Optional[int] = None, moshi_weight: Optional[Dict[str, Any]] = None, device: str | torch.device = "cpu", ) -> "ImageProjection": """Init image projection from config""" load_pretrained_encoder = moshi_weight is None or not any( x.startswith("enc.") for x in moshi_weight ) image_projection = cls( kyuteye_config, lm_model_dim, load_pretrained_encoder=load_pretrained_encoder, ) if moshi_weight is not None: missing_keys, _ = image_projection.load_state_dict( moshi_weight, strict=False ) encoder_keys: List[str] = [] proj_keys: List[str] = [] for key in missing_keys: (encoder_keys if key.startswith("enc.") else proj_keys).append(key) print(proj_keys) assert len(proj_keys) == 0, "Failed to load image to speech projections" print(encoder_keys) assert len(encoder_keys) == 0, "Failed to load frozen image encoder" return image_projection.to(device) def init_proj_module(self, num_tokens: int) -> Optional[torch.nn.Module]: """Init the project module for the inserted and/or cross-attended iamge tokens""" if num_tokens == 0: return None if num_tokens == -1: if self.encoder_out_dim != self.out_dim: return torch.nn.Linear(self.encoder_out_dim, self.out_dim) return torch.nn.Identity() raise ValueError(f"Found negative number of tokens for projection {num_tokens}") @property def encoder_out_dim(self) -> int: """Number of dimension output by the encoder""" return self.encoder_type.out_dims @property def to_tensor_and_normalize(self) -> Callable: """Image normalization function""" return get_img_normalize(self.encoder_type)() def init_norm_module(self, norm_type: Optional[str]) -> Optional[torch.nn.Module]: """Init normalization module""" if norm_type is None: return None return getattr(NormalizationLayer, norm_type.upper()).create_norm_fn( self.out_dim ) def forward(self, x: torch.Tensor | List[torch.Tensor]) -> Dict[str, torch.Tensor]: """Image embedding mapping""" # Apply image encoder encoded, mask = self.encode(x) # Apply different projection for extra vs cross attended tokens out = {} # The mask will be handled by the QP mapper and this module will output the same # number of tokens for every sample in the batch, i.e., no padding is needed anymore. # => We will not forward the mask in this case. if mask is not None: out["cross_attention_mask"] = mask if self.proj_extra is not None: assert mask is None, "proj_extra is not implemented yet for pixtral." 
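        # (Added note) Two fusion paths can be populated from the same encoder output:
        #  * "image_embeds": tokens projected by proj_extra and later inserted directly
        #    into the text+audio token stream (see FusionConfig.token_insertion);
        #  * "cross_attention_src": tokens projected by proj_xa and used as keys/values
        #    by the gated cross-attention layers.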
out["image_embeds"] = self.project_extra(encoded) if self.proj_xa is not None: out["cross_attention_src"] = self.project_xa(encoded) return out def encode( self, x: torch.Tensor | List[torch.Tensor] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Pass through the image encoder backbone and reshape to (batch, seq, dim) Tensors""" logits = self.enc(x) if self.encoder_type == ImageEncoder.PIXTRAL: assert isinstance(logits, PixtralOutput) return logits.out, logits.mask if logits.ndim == 4: logits = rearrange(logits, "b d h w -> b (h w) d") if logits.ndim != 3: raise ValueError( "The image encoder should output a tensor of shape" " (B, Seq, D) (ViT) or (B, D, H, W) (CNN)" ) return logits, None def project_extra(self, logits: torch.Tensor) -> torch.Tensor: """Projection 1: Used for inserted extra tokens""" assert self.proj_extra is not None logits = self.proj_extra(logits) if self.norm_extra is not None: return self.norm_extra(logits) return logits def project_xa(self, logits: torch.Tensor) -> torch.Tensor: """Projection 2: Used for cross-attended tokens""" assert self.proj_xa is not None logits = self.proj_xa(logits) if self.norm_xa is not None: return self.norm_xa(logits) return logits ================================================ FILE: kyuteye_pt/kyuteye/models/loaders.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Load moshi-vis neccessary components.""" from typing import Any, Dict, Optional, Tuple import torch from kyuteye.config.kyuteye_config import KyuteyeConfig from kyuteye.models.image_projection import ImageProjection from kyuteye.models.moshivis import MoshiVisGen def get_moshi_vis( kyuteye_config: KyuteyeConfig, moshi_weight: Optional[str] = None, device: str | torch.device = "cpu", dtype: torch.dtype = torch.bfloat16, gen_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[MoshiVisGen, ImageProjection]: """Return main Moshi model""" image_proj_state: Dict[str, torch.Tensor] = {} model_state: Dict[str, torch.Tensor] = {} if moshi_weight is not None: from safetensors.torch import load_file for key, v in load_file(moshi_weight, device=device).items(): # type: ignore if key.startswith("image_prefix."): image_proj_state[key[13:]] = v else: model_state[key] = v moshi_vis = MoshiVisGen.from_config( kyuteye_config, model_state, device, dtype, **(gen_kwargs or {}) ) image_embedder = ImageProjection.from_config( kyuteye_config, moshi_vis.model_dim, image_proj_state, device ) return moshi_vis.to(dtype), image_embedder.to(dtype) ================================================ FILE: kyuteye_pt/kyuteye/models/moshivis.py ================================================ """Moshi the little AI""" from functools import partial from typing import Any, Dict, List, Literal, Optional, Tuple import torch from kyuteye.config.kyuteye_config import KyuteyeConfig from kyuteye.models.helium import Helium from kyuteye.modules.streaming_utils import StreamingModule from kyuteye.modules.transformer import Transformer from kyuteye.modules.utils import ClampedEmbedding from moshi.utils.sampling import sample_token class MoshiVis(StreamingModule): """Moshi model derived from Audiocraft with extra stuff for vision conditioninign""" # Class attributes; extra special tokens end_of_text_padding_id = 0 zero_token_id = -1 ungenerated_token_id = -2 def __init__( self, hidden_scale: float = 4.125, norm: str = "real_rms_norm_f32", gating: bool 
= True, activation: str = "silu", n_q: int = 8, dep_q: Optional[int] = None, audio_card: int = 1024, audio_context: Optional[int] = None, depformer: bool = False, depformer_multi_linear: bool = False, depformer_pos_emb: Optional[Literal["none", "sin", "rope", "sin_rope"]] = None, depformer_dim: Optional[int] = None, depformer_dim_feedforward: Optional[int] = None, depformer_num_layers: Optional[int] = None, depformer_num_heads: Optional[int] = None, depformer_weights_per_step: bool = False, depformer_context: Optional[int] = 8, depformer_gating: Optional[bool] = None, depformer_activation: Optional[str] = None, delays: Optional[List[int]] = None, text_card: int = 32000, text_context: Optional[int] = None, padding_token_id: int = 3, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs: Any, ) -> None: """Initialize a MoshiVis model""" super().__init__() # Set parameter for generation/preprocessing self.text_card = text_card self.audio_card = audio_card self.text_context = text_context self.text_padding_token_id = padding_token_id self.audio_context = audio_context self.n_q = n_q self.dep_q = dep_q or self.n_q assert delays is not None and len(delays) > 0, "Delays must be non empty" assert len(delays) <= self.num_codebooks, "Too many delays" if len(delays) < self.num_codebooks: delays = delays + [delays[-1]] * (self.num_codebooks - len(delays)) self.delays = delays embeddings_factory = partial( ClampedEmbedding, device=device, dtype=dtype, zero_idx=self.zero_token_id ) # LLM backbone (includes text embedding + text linear projection) self.llm = Helium( hidden_scale=hidden_scale, card=text_card + 1, # Add an initial token in the embedding but not in the text linear output_card=text_card + int(padding_token_id is None), padding_token_id=padding_token_id, device=device, dtype=dtype, zero_token_id=self.zero_token_id, **kwargs, ) # Audio input embeddings self.audio_emb = torch.nn.ModuleList( [ embeddings_factory(audio_card + 1, self.llm.dim) for _ in range(self.num_audio_codebooks_in) ] ) # Depformer self.depformer: Optional[torch.nn.Module] = None self.depformer_multi_linear = depformer_multi_linear if depformer: assert depformer_dim is not None assert depformer_num_heads is not None assert depformer_num_layers is not None assert depformer_pos_emb is not None if depformer_dim_feedforward is None: depformer_dim_feedforward = int(hidden_scale * depformer_dim) assert depformer_dim_feedforward is not None self.depformer_in = torch.nn.ModuleList( [ torch.nn.Linear(self.llm.dim, depformer_dim, bias=False) for _ in range( self.num_audio_codebooks_out if depformer_multi_linear else 1 ) ] ) # Text and audio input embeddings for the depformer self.depformer_emb = torch.nn.ModuleList( [ embeddings_factory(audio_card + 1, depformer_dim) for _ in range(self.num_audio_codebooks_out - 1) ] ) self.depformer_text_emb = embeddings_factory(text_card + 1, depformer_dim) self.depformer = Transformer( d_model=depformer_dim, dim_feedforward=depformer_dim_feedforward, positional_embedding=depformer_pos_emb, num_heads=depformer_num_heads, num_layers=depformer_num_layers, norm=norm, device=device, dtype=dtype, causal=True, cross_attention=False, context=depformer_context, gating=depformer_gating or gating, activation=depformer_activation or activation, weights_per_step=dep_q if depformer_weights_per_step else None, ) # Output projection self.audio_linears = torch.nn.ModuleList( [ torch.nn.Linear(depformer_dim, audio_card, bias=False) for _ in range(self.num_audio_codebooks_out) ] ) 
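    # --- Illustrative shape summary (editor's sketch, not part of the original source) ---
    # forward_text() expects `input_ids` of shape (batch, 1 + n_q, seq_len):
    #   row 0          -> text tokens (cardinality text_card, padding id `padding_token_id`)
    #   rows 1 .. n_q  -> Mimi audio codebooks (cardinality audio_card)
    # The embeddings of all rows are summed before being fed to the Helium backbone,
    # and the depformer then predicts the `dep_q` output audio codebooks one at a time.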
@property def cross_attention(self) -> bool: """Shortcut for checking whether cross_attention is used""" return self.llm.cross_attention @property def num_audio_codebooks_in(self) -> int: """Number of audio codebooks to model as input""" return self.n_q @property def num_audio_codebooks_out(self) -> int: """Number of audio codebooks to model in the depformer""" return self.dep_q @property def num_codebooks(self) -> int: """Number of codebooks including text""" return self.num_audio_codebooks_in + 1 @property def initial_audio_token_id(self) -> int: """Initial token for the audio codebooks""" return self.audio_card @property def initial_text_token_id(self) -> int: """Initial token for the text; takes into account the "fake/proxy" tokens for beginning and end of image if they have been set""" return self.text_card @property def audio_offset(self) -> int: """Offset in the audio codebook. Returns 1 because we always generate with text""" return 1 def forward_text( self, input_ids: torch.Tensor, cross_attention_src: Optional[ Tuple[torch.Tensor, torch.Tensor] | torch.Tensor ] = None, cross_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, float]: """Forward pass for Moshi :param input_ids: Text + audio tokens of shape (batch, codebooks, seq length) :param cross_attention_src: Conditioning (image) tokens that can be cross-attended to through the cross attention module :param cross_attention_mask: Additional mask for the cross_attention_src. This is necessary mainly for Pixtral models, as the cross-attended images might be of different sizes and therefore padded. :param attention_mask: Optional attention mask on input_ids (e.g. used at generation for batched inference with left padding) :return: A tuple containing the transformer features, the text logits, and the cross-attention gate weight """ # Embed tokens inputs_embeds = torch.zeros((), device=input_ids.device) if self.audio_offset > 0: inputs_embeds = self.llm.text_emb(input_ids[:, 0, :]) for cb_index in range(self.num_audio_codebooks_in): update = self.audio_emb[cb_index]( input_ids[:, cb_index + self.audio_offset, :] ) inputs_embeds += update # Pass through Helium transformer_out, gate_weight = self.llm( inputs_embeds=inputs_embeds, cross_attention_src=cross_attention_src, cross_attention_mask=cross_attention_mask, attention_mask=attention_mask, return_features=True, ) # Output proj text_logits = self.llm.text_linear(transformer_out)[:, None] return transformer_out, text_logits, gate_weight def forward_depformer( self, depformer_cb_index: int, input_ids: torch.Tensor, depformer_input: torch.Tensor, ) -> torch.Tensor: """Forward one depformer step""" _, num_codes, seq_len = input_ids.shape assert self.depformer is not None assert ( num_codes == 1 ), f"Codebooks for Depformer streaming should be passed 1 by 1, got {num_codes}." assert ( seq_len == 1 ), f"Steps for Depformer streaming should be passed 1 by 1, got {seq_len}." assert ( depformer_input.shape[1] == 1 ), "Transformer output should be for a single step." # project transformer out depformer_input = self.depformer_in[ depformer_cb_index if self.depformer_multi_linear else 0 ](depformer_input) # project input ids if depformer_cb_index == 0: depformer_input += self.depformer_text_emb(input_ids[:, 0]) else: depformer_input += self.depformer_emb[depformer_cb_index - 1]( input_ids[:, 0] ) # depformer_input is [B, 1, depformer_dim].
# The streaming state of the depformer ensures that the proper layer is run. dep_output, _ = self.depformer(depformer_input) logits = self.audio_linears[depformer_cb_index](dep_output) logits = logits[:, None] assert logits.dim() == 4, logits.shape # [B, Ka, S, card] return logits @property def device(self) -> torch.device: """Torch device""" return next(iter(self.parameters())).device def get_initial_token(self) -> torch.Tensor: """Returns the initial token that will be fed to the model to predict the very first timestep. This is akin to a beginning of sentence tokens but to handle potentially delayed codebooks :param text_or_audio: Whether we are predicting for text, audio, or both :return: A Tensor fo shape (B, K, 1) """ zero = torch.full( [1, 1, 1], MoshiVis.zero_token_id, device=self.device, dtype=torch.long ) audio_token = torch.full_like( zero, self.initial_audio_token_id or MoshiVis.zero_token_id ) text_token = torch.full_like( zero, self.initial_text_token_id or MoshiVis.zero_token_id ) return torch.cat( [text_token, audio_token.expand(-1, self.num_audio_codebooks_in, -1)], dim=1 ) class MoshiVisGen(StreamingModule): """MoshiVis for autoregressive generation at inference""" def __init__( self, moshi_vis: MoshiVis, use_sampling: bool = True, temp: float = 0.8, temp_text: float = 0.7, top_k: int = 250, top_k_text: int = 25, check: bool = False, ): assert not moshi_vis.training, "generation shouldn't be used in training mode." super().__init__() self.lm_model = moshi_vis self.use_sampling = use_sampling self.temp = temp self.temp_text = temp_text self.top_k = top_k self.top_k_text = top_k_text self.check = check self.max_delay = max( moshi_vis.delays ) # with delays, we need to generate a few more time steps. self.delays_cuda = torch.tensor( moshi_vis.delays, device=self.lm_model.device, dtype=torch.long ) self.initial_token = self.lm_model.get_initial_token() def update_gen_kwargs( self, temp: Optional[float] = None, temp_text: Optional[float] = None, top_k: Optional[int] = None, top_k_text: Optional[int] = None, ) -> None: """update params for sampling during generation""" self.temp = temp or self.temp self.temp_text = temp_text or self.temp_text self.top_k = top_k or self.top_k self.top_k_text = top_k_text or self.top_k_text @property def model_dim(self) -> int: """Return dimension of the tokens in the model""" return self.lm_model.llm.dim @property def num_audio_codebooks_out(self) -> int: """Number of audio codebooks generated by the model""" return self.lm_model.num_audio_codebooks_out @classmethod def from_config( cls, kyuteye_config: KyuteyeConfig, moshi_weight: Optional[Dict[str, Any]] = None, device: str | torch.device = "cpu", dtype: torch.dtype = torch.bfloat16, **gen_kwargs: Any, ) -> "MoshiVisGen": """Instantiate model from a config :param base config: :param moshi_weight """ moshivis = MoshiVis(**kyuteye_config.moshi_constructor_kwargs, dtype=dtype) if moshi_weight is not None: missing_keys, _ = moshivis.load_state_dict(moshi_weight, strict=False) # cross-attention MHSA is shared across layers missing_keys = [ k for k in missing_keys if ("cross_attention.mha" not in k or "layers.0" in k) ] assert len(missing_keys) == 0 return MoshiVisGen(moshi_vis=moshivis.eval().to(device), **gen_kwargs) @torch.no_grad() def precompte_ca_kv( self, embeddings: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Precompte kv proj for cross-attention""" ca_layer = self.lm_model.llm.transformer.layers[0].cross_attention.mha if hasattr(ca_layer, "in_proj_weight_kv"): splits = 
torch.chunk(ca_layer.in_proj_weight_kv, 2) else: splits = torch.chunk(ca_layer.in_proj_weight, 3) k = torch.nn.functional.linear( # pylint: disable=not-callable embeddings, splits[-2] ) v = torch.nn.functional.linear( # pylint: disable=not-callable embeddings, splits[-1] # type: ignore ) return k, v @torch.no_grad() def step( self, input_tokens: torch.Tensor, ca_src: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = None, ) -> Tuple[torch.Tensor | None, float]: """One step of generation""" state = self._streaming_state if state is None: raise RuntimeError( "You should wrap those calls with a `with lm_gen.streaming(): ...`." ) lm_model = self.lm_model assert input_tokens.dim() == 3, "Shape should be [B, K, T]." batch_size, num_codes, seq_len = input_tokens.shape assert seq_len == 1, "Only support being given steps one by one." needed_tokens = lm_model.num_codebooks - lm_model.num_audio_codebooks_out - 1 assert ( num_codes == needed_tokens ), f"We expect {needed_tokens} tokens from the user stream, got {num_codes}." current_input_cache = self.get_streaming_attribute( "cache", torch.full( (batch_size, self.lm_model.num_codebooks, self.max_delay + 2), self.lm_model.ungenerated_token_id, device=self.lm_model.device, dtype=torch.long, ), ) current_offset = self.get_streaming_attribute("offset", 0) dcache_len = current_input_cache.shape[2] # write input_tokens (sent from Mimi) in OTHER codebooks for q_other in range(input_tokens.shape[1]): k = lm_model.num_audio_codebooks_out + lm_model.audio_offset + q_other write_position = (current_offset + lm_model.delays[k]) % dcache_len current_input_cache[:, k, write_position : write_position + 1] = ( input_tokens[:, q_other] ) # Only for the very beginning, we extend the initial token for the acoustic # token that are delayed, and thus have no good value to take. position = current_offset % dcache_len for k, delay in enumerate(lm_model.delays): if current_offset <= delay: current_input_cache[:, k, position] = self.initial_token[:, k, 0] # Transformer forward input_ = current_input_cache[:, :, position : position + 1] if self.check: # Check that we are not feeding in any value that is not generated yet. assert not (input_ == lm_model.ungenerated_token_id).any(), ( current_offset, input_, ) assert ( input_[:, lm_model.audio_offset :] <= lm_model.audio_card ).all(), input_ assert (input_[:, :1] <= lm_model.text_card).all() transformer_out, text_logits, gate_weight = self.lm_model.forward_text( input_, cross_attention_src=ca_src ) # Sample text tokens # Shape of text_logits should be [B, K_text=1, T=1, Card_text] text_token = sample_token( text_logits.float(), self.use_sampling, self.temp_text, self.top_k_text, ) assert text_token.dim() == 3, text_token.shape assert text_token.shape[2] == 1 assert text_token.shape[1] == 1, "Only one text stream supported." 
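        # (Added note) Remainder of the streaming step: squeeze the sampled text token to
        # shape [B], run one depformer pass to sample the `dep_q` audio codebooks, write
        # text + audio into the ring-buffer cache at positions shifted by each codebook's
        # delay, and only start releasing time-aligned frames once `current_offset`
        # exceeds `max_delay` (before that the step returns None).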
text_token = text_token[:, 0, 0] # shape is [B] # Generate and sample audio tokens audio_tokens = self.depformer_step(text_token, transformer_out) # Write generated tokens current_offset += 1 position = current_offset % dcache_len current_input_cache[:, 0, position] = text_token current_input_cache[ :, lm_model.audio_offset : lm_model.num_audio_codebooks_out + lm_model.audio_offset, position, ] = audio_tokens # if <= max_delay, we continue partial-generation # until removing all ungenerated tokens if current_offset <= self.max_delay: self.add_streaming_attribute("cache", current_input_cache) self.add_streaming_attribute("offset", current_offset) return None, 0.0 # otherwise, retrieve tokens with the correct delay gen_delays_cuda = self.delays_cuda[ : lm_model.num_audio_codebooks_out + lm_model.audio_offset ] index = ( ((current_offset - self.max_delay + gen_delays_cuda) % dcache_len) .view(1, -1, 1) .expand(current_input_cache.shape[0], -1, 1) ) out = current_input_cache.gather(dim=2, index=index) self.add_streaming_attribute("offset", current_offset) self.add_streaming_attribute("cache", current_input_cache) return out, gate_weight def depformer_step( self, text_token: torch.Tensor, transformer_out: torch.Tensor, ) -> torch.Tensor: """A step of the depformer""" batch_size = text_token.shape[0] depformer_tokens: list[torch.Tensor] = [] assert self.lm_model.depformer is not None with self.lm_model.depformer.streaming(): next_token = text_token[:, None, None] for cb_index in range(self.lm_model.num_audio_codebooks_out): logits = self.lm_model.forward_depformer( cb_index, next_token, transformer_out ) next_token = sample_token( logits.float(), self.use_sampling, self.temp, self.top_k, ) assert next_token.shape == (batch_size, 1, 1) depformer_tokens.append(next_token[:, 0, 0]) out = torch.stack(depformer_tokens, dim=1) return out ================================================ FILE: kyuteye_pt/kyuteye/modules/__init__.py ================================================ ================================================ FILE: kyuteye_pt/kyuteye/modules/attention.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """MultiHead Self-Attention module with optional KV Caching""" import math from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple import torch from einops import rearrange from kyuteye.modules.streaming_utils import StreamingModule from kyuteye.modules.utils import RotaryEmbedding, multi_linear @dataclass class KVCache: """Efficient streaming KVCache to avoid allocating new memory too many times. :param batch_size: Batch size. :param num_heads: Number of heads in the attention. :param dim_per_head: Dimension per head. :param context: Context size for the attention, if None, will grow exponentially, otherwise will use a fixed allocation with a bit of overhead. :param growth: Growth factor for the exponential growth, fraction of overhead when context is not None. :param initial_size: Initial size of the cache, used only when context is None. :param device: Device on which to initialize the cache. :param dtype: dtype to use for the cache. :param cache: Initial cache, if provided. :param current_end: Current end of the cache, used only when cache is provided. 
""" def __init__( self, batch_size: int, num_heads: int, dim_per_head: int, context: Optional[int] = None, growth: float = 1.2, initial_size: int = 100, device: torch.device = torch.device("cuda"), dtype: torch.dtype = torch.bfloat16, cache: Optional[torch.Tensor] = None, current_end: int = 0, ) -> None: if cache is None: assert current_end == 0 assert growth > 1 self.growth = growth if context is not None: initial_size = 1 + int(growth * context) self.capacity = initial_size self.context = context self.current_end = current_end if cache is None: self._cache = torch.full( (2, batch_size, initial_size, num_heads, dim_per_head), float("NaN"), device=device, dtype=dtype, ) else: self._cache = cache def clone(self) -> "KVCache": """Return a separate memory copy of the KV cache""" return KVCache( self._cache.shape[1], self._cache.shape[3], self._cache.shape[4], self.context, self.growth, self.capacity, self._cache.device, self._cache.dtype, self._cache.clone(), self.current_end, ) @property def current_start(self) -> int: """Current start of the KV cache (0 if no context size)""" return 0 if self.context is None else max(self.current_end - self.context, 0) def __maybe_increase_capacity__(self, required_capacity: int) -> None: """If needed, increase capacity to the `required_capacity` using exponential growth strategy""" if required_capacity > self.capacity: if self.context is None: # We take an exponential growth approach. new_capacity = self.capacity while required_capacity > new_capacity: new_capacity = int(math.ceil(new_capacity * self.growth)) new_shape = list(self._cache.shape) new_shape[2] = new_capacity new_cache = torch.full( tuple(new_shape), float("NaN"), device=self._cache.device, dtype=self._cache.dtype, ) new_cache[:, :, : self.current_end] = self._cache[ :, :, : self.current_end ] self._cache = new_cache self.capacity = new_capacity else: # With context, we just have to roll the predict to the left and # use the new space on the right. assert self.current_start > 0 self._cache[:] = self._cache.roll(-self.current_start, dims=2) self.current_end -= self.current_start def complete( self, k: torch.Tensor, v: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Add keys `k` and values `v` to the current cache and returns cache up to the context size""" assert k.shape[1] == v.shape[1] self.__maybe_increase_capacity__(self.current_end + k.shape[1]) assert self.current_end + k.shape[1] <= self.capacity, ( self.current_end, k.shape[1], self.capacity, ) self._cache[0, :, self.current_end : self.current_end + k.shape[1]] = k self._cache[1, :, self.current_end : self.current_end + v.shape[1]] = v self.current_end += k.shape[1] valid = self._cache[:, :, self.current_start : self.current_end] return valid[0], valid[1] class MultiheadAttention(StreamingModule): """Similar to `nn.MultiheadAttention` but with support for causal evaluation. Args: :param embed_dim: Dimension to project to. :param num_heads: Number of heads. :param causal: If true, applies causal mask automatically. :param context: Number of time steps the attention can access to. When causal, can access `context` time steps into the past, and when non causal, can access `context // 2` steps in the past, and the same in the future. :param rope: Rope embedding to use. If None, no rope embedding is applied :param cross_attention: Should be true when used as a cross attention. Cannot be used with `causal` or `rope` (as it wouldn't make sens to interpret the time steps in the keys relative to those in the queries). 
:param use_kv_cache: If True, enables a KV cache with context size `context`. :param weights_per_step: use different weights per depformer step. If non zero, should correspond to the number of possible time steps. :param device: Device on which to initialize the module. :param dtype: dtype to use. """ def __init__( self, embed_dim: int, num_heads: int, causal: bool = False, context: Optional[int] = None, rope: Optional[RotaryEmbedding] = None, cross_attention: bool = False, use_kv_cache: bool = False, weights_per_step: int = 0, xa_dim: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.embed_dim = embed_dim self.causal = causal self.context = context self.rope = rope self.cross_attention = cross_attention self.num_heads = num_heads self.use_kv_cache = use_kv_cache self.weights_per_step = weights_per_step mult = max(1, weights_per_step) if cross_attention: assert not causal, "Cannot set causal mask when `cross attention` is True." assert ( not context ), "Cannot set context size when `cross attention` is True." # if cross-attention source have != num_dims than the speech tokens, # we need to separate the KV and Q embeddings if cross_attention and xa_dim is not None and xa_dim != embed_dim: in_proj_q = torch.nn.Linear( embed_dim, mult * embed_dim, bias=False, **factory_kwargs ) in_proj_kv = torch.nn.Linear( xa_dim, mult * 2 * embed_dim, bias=False, **factory_kwargs ) self.in_proj_weight_q = in_proj_q.weight self.in_proj_bias_q = in_proj_q.bias self.in_proj_weight_kv = in_proj_kv.weight self.in_proj_bias_kv = in_proj_kv.bias self.in_proj_weight = None self.in_proj_bias = None else: in_proj = torch.nn.Linear( embed_dim, mult * 3 * embed_dim, bias=False, **factory_kwargs ) self.in_proj_weight = in_proj.weight self.in_proj_bias = in_proj.bias self.out_proj = torch.nn.Linear( embed_dim, mult * embed_dim, bias=False, **factory_kwargs ) def _complete_kv( self, k: torch.Tensor, v: torch.Tensor, initial_kv_cache_size: int = 256 ) -> Tuple[torch.Tensor, torch.Tensor]: """Add key/values to the KV cache""" # With cross attention we assume all keys and values # are already available, and streaming is with respect # to the queries only. if self._is_streaming and not self.cross_attention: if "kv_cache" not in self._streaming_state: self._streaming_state["kv_cache"] = KVCache( # type: ignore k.shape[0], k.shape[2], k.shape[3], self.context, initial_size=self.weights_per_step or initial_kv_cache_size, device=k.device, dtype=k.dtype, ) self.streaming_offset = torch.zeros(1) # type: ignore kv_cache: KVCache = self._streaming_state["kv_cache"] # type: ignore self.streaming_offset += k.shape[1] return kv_cache.complete(k, v) return k, v def forward( self, query: torch.Tensor, key: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = None, value: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """If self.cross attention is False, we only expects the first input. 
Otherwise, when using cross attention, we need to explicitly give the source for the respective query/key/value embeddings""" # Get current streaming offset before it gets potentially modified by the KV cache update current_streaming_offset = self.streaming_offset if self.cross_attention: assert key is not None, "Missing inputs in cross attention" if isinstance(key, torch.Tensor): value = value or key assert value is not None, "Missing inputs in cross attention" # Case 1: Inputs x and ca_src have the same number of dimension # We have a single big weight for the QKV projections if self.in_proj_weight is not None: q = torch.nn.functional.linear( # pylint: disable=not-callable query, self.in_proj_weight[: self.embed_dim] ) if isinstance(key, torch.Tensor): k = torch.nn.functional.linear( # pylint: disable=not-callable key, self.in_proj_weight[self.embed_dim : 2 * self.embed_dim] ) v = torch.nn.functional.linear( # pylint: disable=not-callable value, self.in_proj_weight[2 * self.embed_dim :] # type: ignore ) else: k, v = key # Case 2: Inputs x and ca_src have different number of dimension # We have to separate the Q and KV proj else: q = torch.nn.functional.linear( # pylint: disable=not-callable query, self.in_proj_weight_q[: self.embed_dim] ) if isinstance(key, torch.Tensor): k = torch.nn.functional.linear( # pylint: disable=not-callable key, self.in_proj_weight_kv[: self.embed_dim] ) v = torch.nn.functional.linear( # pylint: disable=not-callable value, self.in_proj_weight_kv[self.embed_dim :] # type: ignore ) else: k, v = key q, k, v = [ rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v] ] else: assert self.in_proj_weight is not None if self.weights_per_step > 0: projected = multi_linear( self.weights_per_step, self.in_proj_weight, query, offset=current_streaming_offset, ) else: projected = torch.nn.functional.linear( # pylint: disable=not-callable query, self.in_proj_weight ) packed = rearrange( projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads ) q, k, v = torch.unbind(packed, dim=2) if self.rope: q, k = self.rope(q, k, offset=current_streaming_offset) k, v = self._complete_kv(k, v) # Attention q, k, v = [x.transpose(1, 2) for x in [q, k, v]] x = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable q, k, v, is_causal=False, attn_mask=attention_mask ) x = x.transpose(1, 2) # output projection x = rearrange(x, "b t h d -> b t (h d)") if self.weights_per_step > 0: x = multi_linear( self.weights_per_step, self.out_proj.weight, x, offset=current_streaming_offset, ) else: x = self.out_proj(x) return x ================================================ FILE: kyuteye_pt/kyuteye/modules/cross_attention.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
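# ---------------------------------------------------------------------------
# Editor's sketch (standalone illustration, not part of the original module):
# the gating pattern implemented by XAGate below, reduced to plain PyTorch.
# `alpha` starts at 0, so tanh(alpha) == 0 and the cross-attention pathway is
# effectively "closed" at initialization; the surrounding transformer layer is
# assumed to add the gated output back to the residual stream.
# ---------------------------------------------------------------------------
import torch

x_attn = torch.randn(2, 5, 16)                    # hypothetical cross-attention output (B, T, D)
alpha = torch.nn.Parameter(torch.zeros(1, 1, 1))  # learned scalar gate, initialized at 0
gated = torch.tanh(alpha) * x_attn                # contributes nothing until alpha is trained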
"""Gated cross-attention module""" from typing import Any, Callable, Literal, Optional, Tuple, Union import torch from kyuteye.modules.attention import MultiheadAttention from kyuteye.modules.streaming_utils import StreamingModule class SharedModuleType(type): """Wrapper to build shared Pytorch modules""" _instances = {} # type: ignore def __call__(cls, *args: Any, **kwargs: Any) -> Any: if cls not in cls._instances: cls._instances[cls] = super(SharedModuleType, cls).__call__(*args, **kwargs) return cls._instances[cls] class XAGate(torch.nn.Module): """Learned multiplicative gating per layer""" def __init__( self, device: Optional[Union[str, torch.device]] = None, activation: Literal["tanh", "sigmoid"] = "tanh", conditional_gating: bool = False, dims: Optional[int] = None, hidden_dims_factor: float = 0.125, dtype: Optional[torch.dtype] = None, ) -> None: super().__init__() self.alpha: torch.nn.Parameter | torch.nn.Module self.conditional_gating = conditional_gating if self.conditional_gating: assert dims is not None hidden_dims = int(hidden_dims_factor * dims) self.alpha = torch.nn.Sequential( torch.nn.Linear(dims, hidden_dims, bias=False), torch.nn.ReLU(), torch.nn.Linear(hidden_dims, dims, bias=False), ) else: self.alpha = torch.nn.Parameter( torch.full((1, 1, 1), 0.0, device=device, dtype=dtype) ) self.act: Callable if activation == "tanh": self.act = torch.tanh elif activation == "sigmoid": # shift left to mimic initialization ~ close to 0 self.act = lambda x: torch.sigmoid(x - 4) else: raise NotImplementedError("Unknown activation function", activation) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Gating (constant scaling or input-dependent)""" if isinstance(self.alpha, torch.nn.Parameter): gate_weight = self.act(self.alpha) else: gate_weight = self.act(self.alpha(x)) return x * gate_weight, gate_weight class SharedXaGate(XAGate, metaclass=SharedModuleType): """Shared XaGate""" pass # pylint: disable=unnecessary-pass class CrossAttention(MultiheadAttention): """Cross attention module""" def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["cross_attention"] = True super().__init__(*args, **kwargs) class SharedCrossAttention(CrossAttention, metaclass=SharedModuleType): """Shared Cross Attention projection across all layers""" pass # pylint: disable=unnecessary-pass class GatedCrossAttention(StreamingModule): """Gated Cross Attention with or without sharing parameters across layers""" def __init__( self, embed_dim: int, xa_gating: Literal["tanh", "sigmoid", "none"] = "tanh", xa_conditional_gating: bool = False, xa_shared: bool = True, xa_gate_shared: bool = False, xa_delay: int = 0, xa_start: Literal["start", "boi", "eoi"] = "start", xa_end: Literal["end", "eoi"] | int = "end", xa_step: int = 1, xa_dim: Optional[int] = None, **attn_kwargs: Any, ) -> None: """ Initializes a Gated CrossAttention module. :param xa_gating: Whether (and which type of) to add multiplicative gating at the output of the cross-attention layer :param xa_shared: Whether to share the projection parameters of this corss-attention module across all layers tanh_gate (bool, optional): Whether to apply tanh activation to the gate. Defaults to True. share_tanh_gate (bool, optional): Whether to share the tanh gate parameters across different attention heads. Defaults to True. shared_parameters (bool, optional): Whether to share the attention parameters across different attention heads. Defaults to True. 
shift (int, optional): The shift value for the cross attention, i.e., whether the update is to be based on past queries. Defaults to 0, i.e., no shifting. xa_scope (Literal["images", "all_after_first_boi"], optional): The scope of the attention. Can be "images" or "all_after_first_boi". Defaults to "images". **attn_kwargs: Additional keyword arguments to be passed to the CrossAttention or SharedCrossAttention class. """ super().__init__() device = attn_kwargs.get("device", None) dtype = attn_kwargs.get("dtype", None) # Attention module self.mha = (SharedCrossAttention if xa_shared else CrossAttention)( embed_dim=embed_dim, xa_dim=xa_dim, **attn_kwargs ) # Output Gating self.gate: Optional[torch.nn.Module] = None if xa_gating != "none": self.gate = (SharedXaGate if xa_gate_shared else XAGate)( activation=xa_gating, device=device, dtype=dtype, dims=embed_dim, conditional_gating=xa_conditional_gating, ) # If the XA module AND gates are shared, we add a per-layer # coefficient to still have some modularity self.per_layer_alpha: Optional[torch.nn.Parameter] = None if xa_shared and xa_gate_shared: self.per_layer_alpha = torch.nn.Parameter( torch.full((1, 1, 1), 1.0, device=device, dtype=dtype) ) # Determine the xa scope self.xa_start = xa_start self.xa_end = xa_end self.xa_step = xa_step self.xa_delay = xa_delay self._active = True def get_xa_scope( self, x: torch.Tensor, image_tokens_mask: Optional[torch.Tensor] ) -> torch.Tensor: """Build the mask of which tokens should receive contribution from the cross-attention image tokens""" if self.is_streaming and isinstance(self.xa_start, int): return int(self.streaming_offset >= self.xa_start) if not ( self.xa_start in {"start", "boi", "eoi"} or isinstance(self.xa_start, int) ) or not (self.xa_end in {"end", "eoi"} or isinstance(self.xa_end, int)): raise NotImplementedError( f"Unsupported XA scope : {self.xa_start}, {self.xa_end}" ) if self.xa_start == "start": # scope = 'all' if self.xa_end == "end": mask = torch.ones_like(x[:, :, :1]) # scope = 'start + relative'abs elif isinstance(self.xa_end, int): mask = torch.zeros_like(x[:, :, :1]) mask[:, : self.xa_end, :] = 1 # everything else is kinda weird (0, BoI) ? (0, EoI) ? else: raise NotImplementedError( f"Unsupported XA scope : {self.xa_start}, {self.xa_end}" ) elif isinstance(self.xa_start, int): if isinstance(self.xa_end, int): mask = torch.zeros_like(x[:, :, :1]) # self.xa_end is relative to self.xa_start mask[:, self.xa_start : self.xa_start + self.xa_end, :] = 1 elif self.xa_end == "end": # If for some reason the start is further than the end, we do not attend mask = torch.zeros_like(x[:, :, :1]) if not self.xa_start > mask.shape[1]: mask[:, self.xa_start :, :] = 1 else: raise NotImplementedError( f"Unsupported XA scope : {self.xa_start}, {self.xa_end}" ) else: assert image_tokens_mask is not None # another easy case is attention only to the image tokens if self.xa_start == "boi" and self.xa_end == "eoi": mask = image_tokens_mask # Otherwise, we build the mask manually # first, determine the start, either EoI or BoI elif self.xa_start == "boi": # everything that comes after boi # e.g. (0 0 0 1 1 1 1 0 0 0 ) becomes # (0 0 0 1 2 3 4 4 4 4 ) mask = torch.cumsum(image_tokens_mask, dim=1) elif self.xa_start == "eoi": # everything that comes after eoi # e.g. 
(0 0 0 1 1 1 1 0 0 0 ) becomes # (0 0 0 0 0 0 0 4 4 4 ) mask = (1 - image_tokens_mask) * torch.cumsum(image_tokens_mask, dim=1) else: raise NotImplementedError( f"Unsupported XA scope starting at {self.xa_start}" ) # then determine the end of the mask mask = torch.greater(mask, 0).to(x.dtype) if isinstance(self.xa_end, int): mask = torch.cumsum(mask, dim=1) mask = torch.lt(mask, self.xa_end + 1).to(x.dtype) # then apply xa_step if self.xa_step > 1: if self.xa_start not in {"start", 0}: raise NotImplementedError("xa_step") step_mask = torch.eq( torch.remainder(torch.arange(x.shape[1]), self.xa_step), 0 ).to(mask) mask *= step_mask[None, :, None].float() return mask def is_active(self, image_tokens_mask: Optional[torch.Tensor] = None) -> bool: """Whether this model is active during the forward pass""" if self.is_streaming: # case 1: never stop if self.xa_end == "end": return self.xa_start == "start" or ( isinstance(self.xa_start, int) and self.streaming_offset >= self.xa_start ) # case 2: XA applies to the image; we only apply cross-attention # in the step where the image is inserted if self.xa_end == "eoi" and image_tokens_mask is None: return False # Case 3: XA end is relative to XA start if isinstance(self.xa_end, int): if self.xa_start == "start": offset = 0 elif isinstance(self.xa_start, int): offset = self.xa_start elif self.xa_start == "boi": if self.has_streaming_attribute("image_insert_start"): offset = self.get_streaming_info_as_int("image_insert_start") else: return False elif self.xa_start == "eoi": if self.has_streaming_attribute("image_insert_end"): offset = self.get_streaming_info_as_int("image_insert_end") else: return False else: raise ValueError("Unsupported xa_start option", self.xa_start) # if xa_step is active if self.xa_step > 1: return ( offset <= self.streaming_offset < offset + self.xa_end ) and ((self.streaming_offset - offset) % self.xa_step == 0) # base case return offset <= self.streaming_offset < offset + self.xa_end # In training, we are always active and just build the xa scope mask return True def forward( self, x: torch.Tensor, key: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = None, value: Optional[torch.Tensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, image_tokens_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Gated Cross attention :param x: Input query tensor :param key: Source tokens for the keys :param value: Source tokens for the values (typically, equal to keys) :param cross_attention_mask: Mask to apply to the cross-attention :param image_tokens_mask: Mask indicating where the image tokens are in the stream. 
This is used to determine which token should cross-attend to the image """ gate_weight = None if self.is_streaming: if not self.has_streaming_attribute("offset"): self.streaming_offset = 0 # Mark the last inserted image's position in streaming mode if image_tokens_mask is not None: self.add_streaming_attribute( "image_insert_start", self.streaming_offset ) image_lengths = torch.sum(image_tokens_mask[:, :, 0], dim=1) assert torch.all( torch.eq(image_lengths, int(image_lengths[0].item())) ), "All inserted images must have the same number of tokens" self.add_streaming_attribute( "image_insert_end", self.streaming_offset + int(image_lengths[0].item()), ) if not self.is_active(image_tokens_mask=image_tokens_mask): x = torch.zeros_like(x) else: x = self.mha( query=x, key=key, value=value, attention_mask=cross_attention_mask ) if self.gate is not None: x, gate_weight = self.gate(x) if self.per_layer_alpha is not None: x *= self.per_layer_alpha # Mask out tokens that should not receive signal from the image x = x * self.get_xa_scope(x, image_tokens_mask=image_tokens_mask) # Update streaming offset if self.is_streaming: self.streaming_offset += x.shape[1] return x, gate_weight ================================================ FILE: kyuteye_pt/kyuteye/modules/image_encoder.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Pretrained image encoders from timm and/or transformers""" from dataclasses import dataclass from typing import Callable, List, Optional, Tuple import torch from kyuteye.config.enums import ImageEncoder from kyuteye.modules.image_transforms import ( Normalize, PixtralNormalize, SigLIPNormalize, ) from transformers import ( AutoConfig, AutoModelForImageTextToText, LlavaForConditionalGeneration, SiglipVisionConfig, SiglipVisionModel, ) class TrimmedFlexiViTWrapper(torch.nn.Module): """ViT module without the classification tower""" def __init__( self, model: torch.nn.Module, interpolate_pos_encoding: bool = False ) -> None: super().__init__() self.interpolate_pos_encoding = interpolate_pos_encoding self.model = model def forward(self, x: torch.Tensor) -> torch.Tensor: """Get last hidden states""" return self.model( x, interpolate_pos_encoding=self.interpolate_pos_encoding ).last_hidden_state def load_paligemma_vision_encoder( name: str, device: torch.device | str = "cpu", pretrained: bool = False ) -> torch.nn.Module: """Load Paligemma encoder from the shared HuggingFace cache""" if pretrained: model = AutoModelForImageTextToText.from_pretrained( name ).vision_tower.vision_model else: image_size = int(name.rsplit("-", 1)[-1]) model = SiglipVisionModel( SiglipVisionConfig( **{ "hidden_size": 1152, "image_size": image_size, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "num_image_tokens": (image_size // 14) ** 2, "num_positions": 256, "patch_size": 14, "projection_dim": 2304, "torch_dtype": "bfloat16", "vision_use_head": False, } ) ).vision_model return TrimmedFlexiViTWrapper(model.to(device), interpolate_pos_encoding=True) @dataclass class PixtralOutput: """Pixtral Output""" out: torch.Tensor mask: torch.Tensor class PixtralWrapper(torch.nn.Module): """Pixtral encoder returning penultimate features""" def __init__( self, device: Optional[torch.device | str] = None, pretrained: bool = False ) -> None: super().__init__() if pretrained: self.model = 
AutoModelForImageTextToText.from_pretrained( "mistral-community/pixtral-12b" ).vision_tower.to(device) else: config = AutoConfig.from_pretrained("mistral-community/pixtral-12b") self.model = LlavaForConditionalGeneration(config).vision_tower.to(device) self.patch_size = torch.prod( torch.tensor(self.model.patch_conv.weight.shape[-2:]) # type: ignore ) def __get_num_output_tokens__(self, x: List[torch.Tensor]) -> List[int]: """Get number of tokens for each image in the list""" return [ (torch.prod(torch.tensor(img[1].shape[-2:])) // self.patch_size).item() for img in x ] @staticmethod def split_and_pad_output( x: torch.Tensor, split_points: List[int] ) -> Tuple[torch.Tensor, torch.Tensor]: """Split the output of the model into a list of tensors corresponding to the input images and create a mask tensor""" splits = list(torch.split(x, split_points, dim=1)) assert sum(split.shape[1] for split in splits) == x.shape[1] # Pad the splits to have the same size along the sequence dimension max_len = max(split.shape[1] for split in splits) padded_splits = [] mask = torch.zeros((len(splits), max_len), dtype=torch.bool, device=x.device) for i, split in enumerate(splits): pad_len = max_len - split.shape[1] # Right padding of the second to last (i.e., the sequence) dimension padded_split = torch.nn.functional.pad(split, (0, 0, 0, pad_len)) padded_splits.append(padded_split) mask[i, : split.shape[1]] = 1 return torch.cat(padded_splits, dim=0), mask def forward(self, x: List[torch.Tensor] | torch.Tensor) -> PixtralOutput: """Forward to the last hidden states""" if isinstance(x, torch.Tensor): x = list(x) assert isinstance(x, list), "Pixtral expects a list of tensors." split_points = self.__get_num_output_tokens__(x) # Pixtral expects list of images model_out = self.model(x).last_hidden_state split_output, mask = self.split_and_pad_output(model_out, split_points) return PixtralOutput(out=split_output, mask=mask) def get_img_normalize( img_encoder: ImageEncoder, ) -> Callable[..., Normalize | PixtralNormalize]: """Return input normalization function""" if img_encoder == ImageEncoder.PIXTRAL: return PixtralNormalize if img_encoder in { ImageEncoder.SIGLIP_GEMMA2_224, ImageEncoder.SIGLIP_GEMMA2_448, ImageEncoder.SIGLIP_GEMMA2_896, }: return SigLIPNormalize raise NotImplementedError("Unknown image encoder", img_encoder.name) def load_image_encoder( img_encoder: ImageEncoder, device: torch.device | str = "cpu", pretrained: bool = False, ) -> torch.nn.Module: """Load Image encoder as a torch module""" if img_encoder == ImageEncoder.PIXTRAL: return PixtralWrapper(device=device, pretrained=pretrained) if img_encoder in { ImageEncoder.SIGLIP_GEMMA2_224, ImageEncoder.SIGLIP_GEMMA2_448, ImageEncoder.SIGLIP_GEMMA2_896, }: size = int(img_encoder.name.rsplit("_", 1)[-1]) return load_paligemma_vision_encoder( name=f"google/paligemma2-3b-pt-{size}", device=device, pretrained=pretrained, ) raise NotImplementedError(f"image encoder {img_encoder.name} not recognized") ================================================ FILE: kyuteye_pt/kyuteye/modules/image_transforms.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
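# A minimal usage sketch for the image encoder utilities above: `load_image_encoder` pairs with `get_img_normalize`, whose returned normalizers are the `Normalize` classes defined in this file. This sketch assumes only the signatures shown above (the SIGLIP_GEMMA2_448 enum member, pretrained=False for randomly initialised weights, and a CHW uint8 input that SigLIPNormalize maps to [-1, 1]); the helper name below is illustrative.
def _example_siglip_encode_sketch() -> None:
    import torch
    from kyuteye.config.enums import ImageEncoder
    from kyuteye.modules.image_encoder import get_img_normalize, load_image_encoder

    # Randomly initialised 448px SigLIP backbone (no checkpoint download)
    encoder = load_image_encoder(
        ImageEncoder.SIGLIP_GEMMA2_448, device="cpu", pretrained=False
    )
    normalize = get_img_normalize(ImageEncoder.SIGLIP_GEMMA2_448)()  # SigLIPNormalize
    dummy = torch.randint(0, 256, (3, 448, 448), dtype=torch.uint8)  # stand-in RGB image
    pixels = normalize(dummy)[None]  # (1, 3, 448, 448), float scaled to [-1, 1]
    tokens = encoder(pixels)  # last hidden states, (1, num_patches, hidden_size)
    print(tokens.shape)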
"""Image transforms""" from typing import List, Literal, Optional, Sequence, Tuple, Union import torch import torchvision.transforms.v2 as T from PIL.Image import Image try: from transformers.models.pixtral.image_processing_pixtral import ( PixtralImageProcessor, ) except ImportError: print("Cannot find Pixtral encoder, you need to upgrade to transformers >= 0.46") def get_minimal_transforms( img_size: Union[Tuple[int], int] = 224, interpolation: Literal[ "bicubic", "bilinear", "nearest", "nearest_exact" ] = "bicubic", to_tensor: bool = False, keep_aspect_ratio: bool = False, max_img_size: Optional[int] = None, ) -> T.Transform: """Minimal transform from converting a PIL image to a Tensor. This is used as default in most of our datasets when the img_transforms is not provided. If keep_aspect_ratio is False, it resizes the image to (img_size, img_size), without respecting aspect ratio. Otherwise, it only resizes the smaller side to img_size, :param img_size: Target image (square) size :param interpolation: Resizing interpolation :param to_tensor: Whether to also convert to a Tensor type (or leave it to a later transform) :param keep_aspect_ratio: Whether to keep the aspect ratio :param max_img_size: Maximum size an image can be along the longer side """ if not isinstance(img_size, tuple): img_size = img_size if keep_aspect_ratio else (img_size, img_size) # type: ignore return T.Compose( [ T.Resize( img_size, interpolation=getattr(T.InterpolationMode, interpolation.upper()), max_size=max_img_size if keep_aspect_ratio else None, ), T.PILToTensor() if to_tensor else T.Identity(), ] ) class Normalize: """Normalization types for the different image encoders. These will be set in image_projection.py""" def __init__(self, mean: Sequence[float], std: Sequence[float]) -> None: super().__init__() self.std = std self.mean = mean self.transform = T.Compose( [ T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale=True)]), T.Normalize( mean, std, ), ] ) def __call__( self, image: Union[Image, List[Image]] ) -> Union[torch.Tensor, List[torch.Tensor]]: if isinstance(image, list): return [self.transform(img) for img in image] return self.transform(image) def to_pil_transform(self, mode: str = "RGB") -> T.Transform: """Returns the function that inverts this normalization""" return T.Compose( [ T.Normalize([0 for _ in self.mean], [1 / (x + 1e-6) for x in self.std]), T.Normalize([-x for x in self.mean], [1 for _ in self.std]), T.ToPILImage(mode=mode), ] ) class UnitNormalize(Normalize): """Normalization for SigLIP encoder""" def __init__(self) -> None: super().__init__( mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), ) class CLIPNormalize(Normalize): """Normalization for SigLIP encoder""" def __init__(self) -> None: super().__init__( mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), ) class SigLIPNormalize(Normalize): """Normalization for SigLIP encoder""" def __init__(self) -> None: super().__init__(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) class PixtralNormalize: """Image preprocessing for Pixtral https://github.com/huggingface/transformers/blob/fc1ae7f30f1d16c7652c28dd8d91c5d8a8ed2f15/src/transformers/models/pixtral/image_processing_pixtral.py#L369 """ def __init__(self) -> None: self.preprocess = PixtralImageProcessor.from_pretrained( "mistral-community/pixtral-12b" ) def __call__( self, image: Union[Image, List[Image], torch.Tensor] ) -> Union[torch.Tensor, List[torch.Tensor]]: # Image input to pixtral can be an individual image or a list of # images or a list of lists of 
images (multiple images in a single text sequence) # Case 1: Single image if isinstance(image, Image) or isinstance(image, torch.Tensor): # Pixtral converts a single image into [[image]] — a list of lists of images return self.preprocess(image, return_tensors="pt")["pixel_values"][0][0] # Case 2: List of images if isinstance(image, list) and isinstance(image[0], Image): return [ self.preprocess(subimg, return_tensors="pt")["pixel_values"][0][0] for subimg in image ] # Case 3: List of lists of images # This is not currently supported by us raise ValueError( "PixtralNormalize does not support list of lists of images currently." ) ================================================ FILE: kyuteye_pt/kyuteye/modules/streaming_utils.py ================================================ # pylint: disable=protected-access # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Common API for streaming modules during inference""" from contextlib import contextmanager from typing import Any, Callable, Dict, Iterator, Optional import torch State = Dict[str, float | int | torch.Tensor] class StreamingModule(torch.nn.Module): """Common API for streaming components.""" def __init__(self) -> None: super().__init__() self._streaming_state: State = {} self._is_streaming = False @property def empty_streaming_state(self) -> bool: """Whether the streaming state is empty""" return len(self._streaming_state) == 0 def has_streaming_attribute(self, key: str) -> bool: """Whether `key` exists in the current streaming state""" return self._is_streaming and key in self._streaming_state def add_streaming_attribute( self, key: str, value: float | int | torch.Tensor ) -> None: """Add `value` into streaming state's `key`""" self._streaming_state[key] = value def get_streaming_attribute(self, key: str, default: Any = None) -> Any: """Get the value stored under `key` in the streaming state, or `default` if missing""" return self._streaming_state.get(key, default) @property def is_streaming(self) -> bool: """Whether the module is in streaming mode""" return self._is_streaming def get_streaming_info_as_int(self, attr_name: str, default: int = 0) -> int: """Tries to get attr_name as an integer""" if self._is_streaming and attr_name in self._streaming_state: if isinstance(self._streaming_state[attr_name], int): return self._streaming_state[attr_name] # type: ignore if isinstance(self._streaming_state[attr_name], torch.Tensor): return int(self._streaming_state[attr_name].item()) # type: ignore raise ValueError( f"Unexpected type {type(self._streaming_state[attr_name])} in streaming state" ) return default @property def streaming_offset(self) -> int: """Shortcut to get the current temporal offset in streaming mode""" return self.get_streaming_info_as_int("offset", default=0) @streaming_offset.setter def streaming_offset(self, value: int | torch.Tensor) -> None: if not self._is_streaming: raise NotImplementedError( "Updating streaming offset of a non-streaming module" ) self._streaming_state["offset"] = value # type: ignore def _apply_named_streaming(self, fn: Callable) -> None: for name, module in self.named_modules(): if isinstance(module, StreamingModule): fn(name, module) def _set_streaming(self, streaming: bool) -> None: def _set_streaming(_: str, module: StreamingModule) -> None: module._is_streaming = streaming self._apply_named_streaming(_set_streaming) @contextmanager def streaming(self) -> Iterator: """Context manager to enter streaming mode. 
Reset streaming state on exit.""" self._set_streaming(True) try: yield finally: self._set_streaming(False) self.reset_streaming() def streaming_forever(self, batch_size: Optional[int] = None) -> None: """Set in permanent streaming state""" del batch_size self._set_streaming(True) def reset_streaming(self) -> None: """Reset the streaming state.""" def _reset(_: str, module: StreamingModule) -> None: module._streaming_state.clear() self._apply_named_streaming(_reset) def get_streaming_state(self) -> State: """Return the streaming state, including that of sub-modules.""" state: State = {} def _add(name: str, module: StreamingModule) -> None: if name: name += "." for key, value in module._streaming_state.items(): state[name + key] = value self._apply_named_streaming(_add) return state def set_streaming_state(self, state: State) -> None: """Set the streaming state, including that of sub-modules.""" state = dict(state) def _set(name: str, module: StreamingModule) -> None: if name: name += "." module._streaming_state.clear() for key, value in list(state.items()): # complexity is not ideal here, but probably fine. if key.startswith(name): local_key = key[len(name) :] if "." not in local_key: module._streaming_state[local_key] = value del state[key] self._apply_named_streaming(_set) assert len(state) == 0, list(state.keys()) def flush(self, x: Optional[torch.Tensor] = None) -> Optional["StreamingModule"]: """Flush any remaining outputs that were waiting for completion. Typically, for convolutions, this will add the final padding and process the last buffer. This should take an optional argument `x`, which will be provided if a module before this one in the streaming pipeline has already spitted out a flushed out buffer. """ if x is None: return None return self(x) ================================================ FILE: kyuteye_pt/kyuteye/modules/transformer.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Transformer base""" from typing import Any, List, Literal, Optional, Tuple import torch from kyuteye.modules.attention import MultiheadAttention from kyuteye.modules.cross_attention import GatedCrossAttention from kyuteye.modules.streaming_utils import StreamingModule from kyuteye.modules.utils import ( NormalizationLayer, RotaryEmbedding, create_sin_embedding, get_activation, make_ffn, ) class TransformerLayer(StreamingModule): """Base TransformerLayer with causal support. This also integrates cross_attention, when passing `cross_attention=True`, rather than having two separate classes like in PyTorch. :param d_model: Dimension of the data. :param num_heads: Number of heads. :param dim_feedforward: Intermediate dimension of FF module. param causal: Causal mask applied automatically. :param context: Receptive field for the causal mask, infinite if None. :param custom: Use custom MHA implementation, for testing / benchmarking. :param cross_attention: If True, expect to get secondary input for cross-attention. Cross attention will use the default MHA, as it typically won't require special treatment. :param rope: Optional Rope embedding to use. :param norm: Normalization to use. Currently, only 'layer_norm' is supported. :param layer_scale: If not None, LayerScale will be used with the given value as initial scale. :param weights_per_step: use different weights per depformer step. 
If non zero, should correspond to the number of possible time steps. :param gating: if True, uses SwiGLU like gating in the FFN :param activation: Activation function to use in the FFN layer :param device: Device on which to initialize the module. :param dtype: Dtype to use. """ def __init__( self, d_model: int, num_heads: int, dim_feedforward: int | List[int] = 2048, causal: bool = True, context: Optional[int] = None, cross_attention: bool = False, rope: Optional[RotaryEmbedding] = None, norm: Literal[ "layer_norm", "layer_norm_f32", "rms_norm", "rms_norm_f32", "real_rms_norm", "real_rms_norm_f32", ] = "layer_norm", weights_per_step: int = 0, gating: bool = True, activation: Literal[ "none", "identity", "sigmoid", "tanh", "relu", "leaky_relu", "elu", "gelu", "silu", "mish", "softsign", ] = "silu", # Cross attention to image tokens parameters xa_gating: Literal["sigmoid", "tanh", "none"] = "tanh", xa_conditional_gating: bool = False, xa_shared: bool = True, xa_gate_shared: bool = False, xa_delay: int = 0, xa_start: Optional[Literal["start", "boi", "eoi"]] = None, xa_end: Optional[Literal["end", "eoi"] | int] = None, xa_step: int = 1, xa_dim: Optional[int] = None, # Factory kwargs device: Optional[str | torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: super().__init__() assert norm.upper() in [x.name for x in NormalizationLayer] factory_kwargs = {"device": device, "dtype": dtype} self.causal = causal self.self_attn: MultiheadAttention = MultiheadAttention( causal=causal, context=context, rope=rope, weights_per_step=weights_per_step, embed_dim=d_model, num_heads=num_heads, **factory_kwargs, # type: ignore ) # type: ignore self.norm1 = getattr(NormalizationLayer, norm.upper()).create_norm_fn( d_model, **factory_kwargs ) # Cross attention (optional) self.cross_attention: Optional[torch.nn.Module] = None if cross_attention: assert xa_start is not None and xa_end is not None self.cross_attention = GatedCrossAttention( xa_gating=xa_gating, xa_conditional_gating=xa_conditional_gating, xa_shared=xa_shared, xa_gate_shared=xa_gate_shared, xa_delay=xa_delay, xa_start=xa_start, xa_end=xa_end, xa_step=xa_step, embed_dim=d_model, xa_dim=xa_dim, num_heads=num_heads, **factory_kwargs, # type: ignore ) self.norm_cross = getattr(NormalizationLayer, norm.upper()).create_norm_fn( d_model, **factory_kwargs ) # gating = FFN/MLP self.activation = get_activation(activation) self.gating = make_ffn( d_model, dim_feedforward, self.activation, gating=gating, weights_per_step=weights_per_step, **factory_kwargs, ) self.norm2 = getattr(NormalizationLayer, norm.upper()).create_norm_fn( d_model, **factory_kwargs ) def _ff_block(self, x: torch.Tensor) -> torch.Tensor: """Feed forward block""" x_orig = x x = self.norm2(x) if isinstance(self.gating, torch.nn.ModuleList): # Inference ys: List[torch.Tensor] = [] for t in range(x.shape[1]): y = self.gating[self.streaming_offset + t](x[:, len(ys) : len(ys) + 1]) ys.append(y) update = torch.cat(ys, dim=1) else: # Training: Apply all levels in parallel update = self.gating(x) return x_orig + update def _maybe_cross_attend( self, x: torch.Tensor, cross_attention_src: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor], cross_attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Cross attention""" if self.cross_attention is not None and cross_attention_src is not None: x_orig = x update, gate_weight = self.cross_attention( self.norm_cross(x), cross_attention_src, None, cross_attention_mask, ) return x_orig + 
update, gate_weight return x, None def _self_attend( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Self-attention""" x_orig = x update = self.self_attn( self.norm1(x), attention_mask=attention_mask, ) return x_orig + update def forward( self, x: torch.Tensor, cross_attention_src: Optional[ Tuple[torch.Tensor, torch.Tensor] | torch.Tensor ] = None, cross_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """(Optional) MHSA => (Optional) cross-attention => FFN :param x: Input query tensor :param cross_attention_src: Optional tensor used as source for the keys and values in the cross-attention :param cross_attention_mask: Mask for the cross-attention :param attention_mask: Attention mask for the self-attention. Can be used to mask out tokens that shouldn't be involved in the self-attention computation :param image_tokens_mask: Mask indicating where the image tokens are in the stream with shape (B, seq, 1). This is used in two places: * (i) in the cross-attention, to determine which token should cross-attend to the image * (ii) in the self-attention, to allow the attention mask to be non-causal inside the image tokens """ if self.is_streaming and not self.has_streaming_attribute("offset"): self.streaming_offset = 0 x = self._self_attend(x, attention_mask=attention_mask) x, gate_weight = self._maybe_cross_attend( x, cross_attention_src=cross_attention_src, cross_attention_mask=cross_attention_mask, ) x = self._ff_block(x) # Update streaming offset for the multi linear FFNs in the depformer if self.is_streaming: self.streaming_offset += x.shape[1] return x, gate_weight class Transformer(StreamingModule): """Transformer with Streaming / Causal support. :param d_model: Dimension of the data. :param num_heads: Number of heads. :param num_layers: Number of transformer layers :param dim_feedforward: Intermediate dimension of FF module. :param causal: If True, automatically applies a causal mask. :param context: Size of the receptive field for the causal mask. If None, assumes infinite context. :param cross_attention: If True, `forward` will expect to get secondary input for cross-attention. :param xa_layers: If a non-empty tuple, specifies which layers to add cross-attention layers to. If None or an empty tuple, and cross_attention is True, will apply cross attention in every layer :param positional_embedding: Positional embedding strategy (sin, rope, sin_rope, or none). :param max_period: Maximum period for the sin/cos in RoPE embedding. :param positional_scale: Scale of positional embedding, set to 0 to deactivate. :param device: Device on which to initialize the model. :param dtype: Data type to use. **kwargs: Extra arguments fed to the `TransformerLayer` constructor (e.g. 
layer_scale) """ def __init__( self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int | list[int] = 2048, causal: bool = False, context: Optional[int] = None, cross_attention: bool = False, xa_layers: Optional[Tuple[int, ...]] = None, positional_embedding: Literal["none", "sin", "rope", "sin_rope"] = "sin", max_period: float = 10000, positional_scale: float = 1.0, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs: Any, ) -> None: super().__init__() assert d_model % num_heads == 0 self.positional_embedding = positional_embedding self.max_period = max_period self.positional_scale = positional_scale assert positional_embedding in {"sin", "rope", "sin_rope", "none"} self.rope: Optional[RotaryEmbedding] = None if self.positional_embedding in {"rope", "sin_rope"}: self.rope = RotaryEmbedding(max_period=max_period) self.layers = torch.nn.ModuleList() for layer_idx in range(num_layers): cross_attend_layer = ( xa_layers is None or len(xa_layers) == 0 or layer_idx in xa_layers ) self.layers.append( TransformerLayer( d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, causal=causal, context=context, cross_attention=cross_attention and cross_attend_layer, rope=self.rope, device=device, dtype=dtype, **kwargs, ) ) def set_context(self, context: Optional[int] = None) -> None: """Update context size in all MHSA layers""" for module in self.modules(): if isinstance(module, MultiheadAttention): module.context = context def forward( self, x: torch.Tensor, *args: Any, **kwargs: Any ) -> Tuple[torch.Tensor, float]: """Forward pass""" _, seq_len, channels = x.shape if self.positional_embedding in {"sin", "sin_rope"}: positions = torch.arange(seq_len, device=x.device).view(1, -1, 1) pos_emb = create_sin_embedding( positions, channels, max_period=self.max_period, dtype=x.dtype ) x = x + self.positional_scale * pos_emb alpha = 0.0 for layer_idx, layer in enumerate(self.layers): x, gate_weight = layer(x, *args, **kwargs) if gate_weight is not None and layer_idx >= len(self.layers) - 10: alpha += torch.mean(gate_weight).cpu().item() return x, alpha / min(10, len(self.layers)) ================================================ FILE: kyuteye_pt/kyuteye/modules/utils.py ================================================ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Smaller building blocks: * FFN layers with Swi-GLU like gating * Normalization layers * Input embedding layer * Positional embeddings """ from enum import Enum, unique from typing import Any, Callable, List, Literal, Optional, Tuple import torch def multi_linear( num_linear: int, weight: torch.Tensor, x: torch.Tensor, offset: int = 0 ) -> torch.Tensor: """Utility to apply a multi linear layer to the given input. A multi linear layer applies a different set of weight for each time step. Args: num_linear (int): Number of possible time steps and so number of linears. weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`. x (torch.Tensor): Input tensor, with shape `[B, T, C]`. offset (int): offset for the current time step, in particular for decoding, with time steps provided one by one. 
""" ys = [] # when calling the depformer, x.shape[1] is always 1, and the offset contains the # codebook index we care about for t in range(x.shape[1]): y = torch.nn.functional.linear( # pylint: disable=not-callable x[:, t], weight.chunk(num_linear)[offset + t] ) ys.append(y) out = torch.stack(ys, 1) return out def get_activation( name: Literal[ "sigmoid", "tanh", "relu", "leaky_relu", "elu", "gelu", "silu", "mish", "softsign", "identity", "none", ], ) -> Callable[[torch.Tensor], torch.Tensor]: """Return correct activation function from the given name""" if name in {"sigmoid", "tanh", "relu"}: return getattr(torch, name) if name in {"leaky_relu", "elu", "gelu", "silu", "mish", "softsign"}: return getattr(torch.nn.functional, name) if name in {"identity", "none"}: return torch.nn.Identity() raise NotImplementedError(f"Unknown activation {name}") def gating_forward_kernel( weight_in: torch.Tensor, weight_out: torch.Tensor, activation: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, ) -> torch.Tensor: """Simple multiplicative gating strategy (SwiGLU like)""" x = torch.nn.functional.linear(x, weight_in) # pylint: disable=not-callable batch_size, seq_len, _ = x.shape x = x.view(batch_size, seq_len, 2, -1) x = activation(x[..., 0, :]) * x[..., 1, :] x = torch.nn.functional.linear(x, weight_out) # pylint: disable=not-callable return x class ActivationGating(torch.nn.Module): """ FFN layer with multiplicative gating using the given activation. :param dim: Dimensions of the tokens. :param activation: Activation function to use. :param factory_kwargs: Other kwargs passed to the linear layer, in particular device and dtype. """ def __init__( self, dim: int, dim_feedforward: int, activation: Callable[[torch.Tensor], torch.Tensor], **factory_kwargs: Any, ): super().__init__() # We should have 8 d^2 param, instead we will have # 2 * h * d + h * d = 3 h * d = 8 d^2 # so h = 8 d / 3 but following Hervé's advice we use 21 / 8 as an approx. 
if dim_feedforward == 4 * dim: hidden = (21 * dim) // 8 else: hidden = (2 * dim_feedforward) // 3 self.linear_in = torch.nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs) self.linear_out = torch.nn.Linear(hidden, dim, bias=False, **factory_kwargs) self.activation = activation max_params = 2 * dim * dim_feedforward params = sum(p.numel() for p in self.parameters()) assert params <= max_params, f"Gating has {params} params, max is {max_params}" def forward(self, x: torch.Tensor) -> torch.Tensor: """FFN with Swi-GLU like gating but customizable activation""" return gating_forward_kernel( self.linear_in.weight, self.linear_out.weight, self.activation, x ) class NoGating(torch.nn.Module): """ Simple 2-layer MLP FFN """ def __init__( self, dim: int, dim_feedforward: int, activation: Callable[[torch.Tensor], torch.Tensor], **factory_kwargs: Any, ): super().__init__() self.linear1 = torch.nn.Linear( dim, dim_feedforward, bias=False, **factory_kwargs ) self.linear2 = torch.nn.Linear( dim_feedforward, dim, bias=False, **factory_kwargs ) self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: """Simple 2-layer MLP FFN""" return self.linear2(self.activation(self.linear1(x))) def make_ffn( dim: int, dim_feedforward: int | List[int], activation_fn: Callable[[torch.Tensor], torch.Tensor], gating: bool = False, weights_per_step: int = 0, **factory_kwargs: Any, ) -> torch.nn.Module: """Create an FFN module :param dim: Number of input dimensions :param dim_feedforward: Number of inner dimensions :param activation_fn: Activation function :param gating: If True, uses FFN with multiplicative gating :param factory_kwargs: Any extra argument fed to the Linear layer constructors (e.g. device, dtype) """ ffn: torch.nn.Module if gating: if weights_per_step > 0: if isinstance(dim_feedforward, int): dim_feedforward = [dim_feedforward] * weights_per_step assert isinstance(dim_feedforward, list), dim_feedforward ffn = torch.nn.ModuleList( [ ActivationGating(dim, dim_out, activation_fn, **factory_kwargs) for dim_out in dim_feedforward ] ) else: assert isinstance(dim_feedforward, int) ffn = ActivationGating( dim, dim_feedforward, activation_fn, **factory_kwargs ) else: assert isinstance(dim_feedforward, int) assert ( weights_per_step == 0 ), f"weights per step {weights_per_step} > 0 is not supported without gated FFN" ffn = NoGating(dim, dim_feedforward, activation_fn, **factory_kwargs) return ffn class LayerNormF32(torch.nn.LayerNorm): """Layer norm executed in Float32 for maximal precision""" def forward( self, input: torch.Tensor # pylint: disable=redefined-builtin ) -> torch.Tensor: """Applies the layer norm""" x_f32 = input.float() out_f32 = super().forward(x_f32) return out_f32.to(input.dtype) def _rms_norm( x: torch.Tensor, alpha: torch.Tensor, dtype: Optional[torch.dtype], eps: float, use_var: bool, return_factor: bool = False, ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 3, f"RMSNorm expects 3D inputs but got {x.shape}" in_dtype = x.dtype if dtype is not None: x = x.to(dtype) if use_var: var = eps + x.var(dim=2, keepdim=True) else: var = eps + torch.mean(x**2, dim=2, keepdim=True) if return_factor: factor = alpha.to(var) * torch.rsqrt(var) return (x * factor).to(in_dtype), factor.to(in_dtype) return (x * (alpha.to(var) * torch.rsqrt(var))).to(in_dtype) class RMSNorm(torch.nn.Module): """RMSNorm layer :param dim: Input channels dimension :param eps: Epsilon """ def __init__( self, dim: int, eps: float = 1e-5, use_var: bool = True, dtype: 
Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): super().__init__() self.eps = eps self.dtype = dtype self.use_var = use_var self.alpha = torch.nn.Parameter( torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype) ) def forward(self, x: torch.Tensor) -> torch.Tensor: """RMS Norm call""" out = _rms_norm(x, self.alpha, self.dtype, self.eps, self.use_var) return out # type: ignore[return-value] @unique class NormalizationLayer(Enum): """Select one of several normalization layers""" LAYER_NORM = 0 LAYER_NORM_F32 = 1 RMS_NORM = 2 RMS_NORM_F32 = 3 REAL_RMS_NORM = 4 REAL_RMS_NORM_F32 = 5 def create_norm_fn(self, dim: int, **kwargs: Any) -> torch.nn.Module: """Return the proper normalization layer initializer""" # Layer Norm if self == NormalizationLayer.LAYER_NORM: return torch.nn.LayerNorm(dim, eps=1e-5, **kwargs) if self == NormalizationLayer.LAYER_NORM_F32: return LayerNormF32( dim, eps=1e-8, **{k: v for k, v in kwargs.items() if k != "dtype"} ) # Real RMS Norm using |x**2| normalization if self == NormalizationLayer.REAL_RMS_NORM: return RMSNorm(dim, eps=1e-5, use_var=False, **kwargs) if self == NormalizationLayer.REAL_RMS_NORM_F32: return RMSNorm( dim, eps=1e-8, dtype=torch.float32, use_var=False, **{k: v for k, v in kwargs.items() if k != "dtype"}, ) # RMS Norm using variance of the data if self == NormalizationLayer.RMS_NORM: return RMSNorm(dim, eps=1e-5, **kwargs) if self == NormalizationLayer.RMS_NORM_F32: return RMSNorm( dim, eps=1e-8, dtype=torch.float32, **{k: v for k, v in kwargs.items() if k != "dtype"}, ) raise NotImplementedError(f"Unknown norm type: {self.name}") class ClampedEmbedding(torch.nn.Embedding): """An embedding layer such that all input IDs of the ID `zero_idx < 0` are mapped to zero at the output of the module Args: lr (float or None): Learning rate for the embedding layer if provided. norm (bool): if True, uses a layer norm after the embedding. zero_idx (int): special value indicating that the output should be exactly 0. """ def __init__( self, *args: Any, norm: bool = False, zero_idx: int = -1, **kwargs: Any ) -> None: super().__init__(*args, **kwargs) self.norm = None if norm: self.norm = NormalizationLayer.LAYER_NORM.create_norm_fn(self.embedding_dim) assert zero_idx < 0, "Please use negative values for the zero_idx." self.zero_idx = zero_idx def forward( # pylint: disable=arguments-renamed self, inputs: torch.Tensor ) -> torch.Tensor: """Embed the input IDs""" is_zero = inputs == self.zero_idx zero = torch.zeros(1, dtype=inputs.dtype, device=inputs.device) y = super().forward(inputs.clamp(min=0)) if self.norm is not None: y = self.norm(y) y = torch.where(is_zero[..., None], zero, y) return y def create_sin_embedding( positions: torch.Tensor, dim: int, max_period: float = 10000, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """Create fixed sinusoidal positional embedding, with shape `[B, T, C]`. Args: positions (torch.Tensor): LongTensor of positions. dim (int): Dimension of the embedding. max_period (float): Maximum period of the cosine/sine functions. dtype (torch.dtype or str): dtype to use to generate the embedding. Returns: torch.Tensor: Sinusoidal positional embedding. 
""" # Assumes BTC format assert dim % 2 == 0 half_dim = dim // 2 positions = positions.to(dtype) adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1) max_period_tensor = torch.full( [], max_period, device=positions.device, dtype=dtype ) # avoid sync point phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) def apply_rope( q: torch.Tensor, k: torch.Tensor, max_period: float = 10_000, offset: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply RoPE embedding to he input queries/keys :param q: queries, shape `[B, T, H, D]`. :param k: keys, shape `[B, T, H, D]`. :param max_period: maximum period for the cos and sin (aka. `theta_rope`). """ batch, seq_length, num_heads, dim = q.shape assert k.shape == q.shape ds = torch.arange(dim // 2, device=q.device, dtype=torch.float32) max_period_t = torch.full([1], max_period, device=q.device, dtype=torch.float32) freqs = 1.0 / (max_period_t ** (2 * ds / dim)) ts = torch.arange( offset, seq_length + offset, device=q.device, dtype=torch.float32 ).view(-1, 1, 1) q = q.view(batch, seq_length, num_heads, dim // 2, 2) k = k.view(batch, seq_length, num_heads, dim // 2, 2) # convention is `r` suffix is real part, `i` is imaginary. qr = q[..., 0].float() qi = q[..., 1].float() kr = k[..., 0].float() ki = k[..., 1].float() rotr = torch.cos(freqs * ts) roti = torch.sin(freqs * ts) qor = qr * rotr - qi * roti qoi = qr * roti + qi * rotr kor = kr * rotr - ki * roti koi = kr * roti + ki * rotr dtype = q.dtype qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1) ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1) return qo.view(batch, seq_length, num_heads, dim), ko.view( batch, seq_length, num_heads, dim ) class RotaryEmbedding(torch.nn.Module): """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864). :param max_period: Maximum period of the rotation frequencies (aka `theta_rope`). """ def __init__(self, max_period: float = 10000.0) -> None: super().__init__() self.max_period = max_period def forward( self, q: torch.Tensor, k: torch.Tensor, offset: int = 0 ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply rope rotation to query or key tensor with an `offset` on temporal positions.""" return apply_rope(q, k, max_period=self.max_period, offset=offset) ================================================ FILE: kyuteye_pt/kyuteye/server.py ================================================ # pylint: disable=protected-access,no-member # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
"""Start Pytorch backend""" import asyncio import os import random import time from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple import aiohttp import fire import numpy as np import sentencepiece import sphn import torch from aiohttp import web from huggingface_hub import hf_hub_download from kyuteye.config.enums import ImageEncoder from kyuteye.config.kyuteye_config import KyuteyeConfig from kyuteye.models.loaders import get_moshi_vis from kyuteye.modules.image_transforms import get_minimal_transforms from moshi.models.loaders import get_mimi from torchvision.io import ImageReadMode, decode_image if TYPE_CHECKING: from kyuteye.models.image_projection import ImageProjection from kyuteye.models.moshivis import MoshiVisGen from moshi.models import MimiModel def colorize(text: str, color: str) -> str: """Add colors to log""" code = f"\033[{color}m" restore = "\033[0m" return "".join([code, text, restore]) def make_log(level: str, msg: str) -> str: """Create log""" if level == "warning": prefix = colorize("[Warn]", "1;31") elif level == "info": prefix = colorize("[Info]", "1;34") elif level == "error": prefix = colorize("[Err ]", "1;31") else: raise ValueError(f"Unknown level {level}") return prefix + " " + msg def log(level: str, msg: str) -> None: """Log with colors""" print(make_log(level, msg)) def seed_all(seed: int) -> None: """Seed""" torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # for multi-GPU setups random.seed(seed) np.random.seed(seed) torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = False @dataclass class ServerState: """Main server state""" mimi: "MimiModel" text_tokenizer: sentencepiece.SentencePieceProcessor moshi_vis: "MoshiVisGen" image_encoder_model: "ImageProjection" image_size: int xa_start: int lock: asyncio.Lock dtype: torch.dtype display_gating: bool def __init__( self, mimi: "MimiModel", text_tokenizer: sentencepiece.SentencePieceProcessor, moshi_vis: "MoshiVisGen", image_encoder_model: "ImageProjection", device: str | torch.device, dtype: torch.dtype = torch.bfloat16, max_msg_size: int = 0, image_size: int = 448, xa_start: int = 0, ): self.mimi = mimi self.text_tokenizer = text_tokenizer self.moshi_vis = moshi_vis self.image_encoder_model = image_encoder_model self.embeddings: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | None = None self.max_msg_size = max_msg_size self.image_size = image_size self.xa_start = xa_start self.display_gating = True self.device = device self.dtype = dtype self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate) self.lock = asyncio.Lock() self.mimi.streaming_forever(1) self.moshi_vis.streaming_forever(1) def warmup(self) -> None: """Warmup the models""" for _ in range(4): chunk = torch.zeros( 1, 1, self.frame_size, dtype=torch.float32, device=self.device ) codes = self.mimi.encode(chunk) ca_src = self.image_encoder_model( torch.zeros(1, 3, 224, 224, device=self.device) )["cross_attention_src"] for c in range(codes.shape[-1]): tokens, _ = self.moshi_vis.step(codes[:, :, c : c + 1], ca_src=ca_src) if tokens is None: continue _ = self.mimi.decode(tokens[:, 1:]) torch.cuda.synchronize() self.mimi.reset_streaming() self.moshi_vis.reset_streaming() async def handle_chat(self, request: Any) -> Any: """start conversation""" ws = web.WebSocketResponse(max_msg_size=self.max_msg_size) await ws.prepare(request) close = False async def recv_loop() -> None: nonlocal 
close try: async for message in ws: if message.type == aiohttp.WSMsgType.ERROR: log("error", f"{ws.exception()}") break if message.type == aiohttp.WSMsgType.CLOSED: log("info", "closed received") break if message.type != aiohttp.WSMsgType.BINARY: log("error", f"unexpected message type {message.type}") continue message = message.data if not isinstance(message, bytes): log("error", f"unsupported message type {type(message)}") continue if len(message) == 0: log("warning", "empty message") continue kind = message[0] if kind == 1: # audio payload = message[1:] opus_reader.append_bytes(payload) elif kind == 10: log("info", f"received user rating {message[1]}") else: log("warning", f"unknown message kind {kind}") except Exception as e: print("Exception raised:", e) finally: close = True log("info", "connection closed") async def opus_loop() -> None: all_pcm_data = None while True: if close: return await asyncio.sleep(0.001) pcm = opus_reader.read_pcm() if pcm.shape[-1] == 0: continue if all_pcm_data is None: all_pcm_data = pcm else: all_pcm_data = np.concatenate((all_pcm_data, pcm)) while all_pcm_data.shape[-1] >= self.frame_size: be = time.time() chunk = all_pcm_data[: self.frame_size] all_pcm_data = all_pcm_data[self.frame_size :] chunk = torch.from_numpy(chunk) chunk = chunk.to(device=self.device)[None, None] codes = self.mimi.encode(chunk) for c in range(codes.shape[-1]): tokens, gate_weight = self.moshi_vis.step( codes[:, :, c : c + 1], ca_src=( self.embeddings if self.moshi_vis.get_streaming_attribute("offset", 0) >= self.xa_start else None ), ) if tokens is None: continue assert ( tokens.shape[1] == self.moshi_vis.num_audio_codebooks_out + 1 ) main_pcm = self.mimi.decode(tokens[:, 1:]) main_pcm = main_pcm.cpu() opus_writer.append_pcm(main_pcm[0, 0].numpy()) text_token = tokens[0, 0, 0].item() if text_token not in (0, 3): _text = self.text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") text_color = round( max(min((gate_weight - 0.005) / 0.016, 1.0), 0.0) * 10 ) msg = ( b"\x07" + text_color.to_bytes(1, "big") + bytes(_text, encoding="utf8") ) log("info", f"text token '{_text}'") await ws.send_bytes(msg) log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms") async def send_loop() -> None: while True: if close: return await asyncio.sleep(0.001) msg = opus_writer.read_bytes() if len(msg) > 0: await ws.send_bytes(b"\x01" + msg) log("info", "accepted connection") close = False query_parameters = request.rel_url.query self.moshi_vis.update_gen_kwargs( temp_text=( float(aux) if (aux := query_parameters.get("text_temperature", None)) is not None else None ), temp=( float(aux) if (aux := query_parameters.get("audio_temperature", None)) is not None else None ), top_k_text=( int(aux) if (aux := query_parameters.get("text_topk", None)) is not None else None ), top_k=( int(aux) if (aux := query_parameters.get("audio_topk", None)) is not None else None ), ) self.image_size = ( int(aux) if (aux := query_parameters.get("image_resolution", None)) is not None else self.image_size ) if (aux := query_parameters.get("xa_start", None)) is not None: self.xa_start = int(aux) async with self.lock: opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate) # type: ignore opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate) # type: ignore self.mimi.reset_streaming() self.moshi_vis.reset_streaming() await self.extract_image(ws) # Send the handshake. 
await ws.send_bytes(b"\x00") await asyncio.gather(opus_loop(), recv_loop(), send_loop()) log("info", "done with connection") return ws async def extract_image(self, ws: web.WebSocketResponse) -> None: """Embed image at the beginning of the stream""" first_message = await ws.receive() first_message = first_message.data try: kind = first_message[0] except Exception as e: raise RuntimeError(f"Error in message: {first_message}") from e if kind != 8: # image raise RuntimeError(f"unknown message kind {kind}") payload = first_message[1:] image_tensor = decode_image( torch.frombuffer(payload, dtype=torch.uint8), mode=ImageReadMode.RGB ) image_tensor = get_minimal_transforms(self.image_size)(image_tensor) image_tensor = self.image_encoder_model.to_tensor_and_normalize(image_tensor) log("info", f"Loaded image tensor with shape {image_tensor.shape}") if self.image_encoder_model.encoder_type == ImageEncoder.PIXTRAL: image_tensor = [image_tensor.to(self.device)] else: image_tensor = image_tensor[None, ...].to(self.device) k, v = self.moshi_vis.precompte_ca_kv( self.image_encoder_model(image_tensor)["cross_attention_src"] ) self.embeddings = (k.to(self.dtype), v.to(self.dtype)) def start_server( kyuteye_config_path: str, host: str = "localhost", port: int = 8998, static: Optional[str] = None, device: str = "cuda", dtype: Literal["float32", "bfloat16"] = "bfloat16", ssl: bool = True, ssl_cert_dir: Optional[str] = None, ) -> None: """Start server :param kyuteye_config_path: Path to the model config to load :param host: Host to start the server on :param port: Port to start the server on :param static: Path to the built client source. If None, defaults to the local client :param device: Device of execution :param dtype: Dtype of execution :param ssl: Whether to launch on https or http protocol :param max_img_size: Max image size (in MB) that can be sent via aiohttp; if 0, no limit is set. 
Note that input images are resized in any case before being sent to the encoder """ assert kyuteye_config_path is not None root_dir = Path(__file__).parents[2] static_path: None | str = None if static is None: static_path = str(root_dir / "client" / "dist") else: static_path = static static_path = os.path.abspath(static_path) assert static_path is not None and os.path.exists(static_path) seed_all(42) setup_tunnel = None tunnel_token = "" kyuteye_config = KyuteyeConfig.from_yml(kyuteye_config_path) # Load main model components log("info", "loading mimi") if kyuteye_config.hf_repo is None: assert os.path.exists(kyuteye_config.mimi_codec) mimi_weight = kyuteye_config.mimi_codec else: mimi_weight = hf_hub_download(kyuteye_config.hf_repo, kyuteye_config.mimi_codec) mimi = get_mimi(mimi_weight, device) log("info", "mimi loaded") if kyuteye_config.hf_repo is None: assert os.path.exists(kyuteye_config.text_tokenizer) text_tokenizer_path = kyuteye_config.text_tokenizer else: text_tokenizer_path = hf_hub_download( kyuteye_config.hf_repo, kyuteye_config.text_tokenizer ) text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer_path) # type: ignore log("info", "loading moshi + vision") if kyuteye_config.hf_repo is None: assert os.path.exists(kyuteye_config.model) moshi_weight = kyuteye_config.model if not moshi_weight.endswith("_pt.safetensors"): assert moshi_weight.endswith(".safetensors") moshi_weight = moshi_weight.replace(".safetensors", "_pt.safetensors") print(f"Will load from {moshi_weight}") else: moshi_weight = hf_hub_download(kyuteye_config.hf_repo, kyuteye_config.model) torch_dtype = getattr(torch, dtype) moshi_vis, image_embedder = get_moshi_vis( kyuteye_config, moshi_weight, device, torch_dtype ) log("info", "moshi + vision loaded") state = ServerState( mimi=mimi, text_tokenizer=text_tokenizer, moshi_vis=moshi_vis, image_encoder_model=image_embedder, device=device, dtype=torch_dtype, xa_start=kyuteye_config.xa_start, ) log("info", "warming up the model") state.warmup() app = web.Application() app.router.add_get("/api/chat", state.handle_chat) async def handle_root(_): # type: ignore return web.FileResponse(os.path.join(static_path, "index.html")) log("info", f"serving static content from {static_path}") app.router.add_get("/", handle_root) app.router.add_static("/", path=static_path, follow_symlinks=True, name="static") protocol = "http" ssl_context = None if ssl: import ssl as ssl_module ssl_context = ssl_module.create_default_context(ssl_module.Purpose.CLIENT_AUTH) ssl_cert_dir = ssl_cert_dir or str(root_dir) ssl_context.load_cert_chain( certfile=os.path.join(ssl_cert_dir, "cert.pem"), keyfile=os.path.join(ssl_cert_dir, "key.pem"), ) protocol = "https" log("info", f"Access the Web UI directly at {protocol}://{host}:{port}") if setup_tunnel is not None: tunnel = setup_tunnel("localhost", port, tunnel_token, None) log( "info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.", ) log( "info", "Note that this tunnel goes through the US and you" " might experience high latency in Europe.", ) with torch.no_grad(): web.run_app(app, port=port, ssl_context=ssl_context) def sanity_check() -> None: pass def main() -> None: """main""" fire.Fire(start_server) ================================================ FILE: kyuteye_pt/kyuteye/utils/__init__.py ================================================ ================================================ FILE: kyuteye_pt/kyuteye/utils/dist_utils.py ================================================ """Some utils for distributed 
training""" import os from typing import Any import torch.distributed as dist from rich import print as rich_print def is_main() -> bool: """Returns True iff the current process is the main one""" # torch distributed if dist.is_initialized(): return dist.get_rank() == 0 # procid if "LOCAL_RANK" in os.environ: return int(os.environ.get("LOCAL_RANK", 0)) == 0 return int(os.environ.get("SLURM_PROCID", 0)) == 0 def print_main(*args: Any, rich: bool = False, **kwargs: Any) -> None: """Print function that only activate for the main process""" if is_main(): if rich: rich_print(*args, **kwargs) else: print(*args, **kwargs) ================================================ FILE: kyuteye_pt/kyuteye/utils/logging_utils.py ================================================ """Some utils for experiment tracking and logging""" import json import subprocess from typing import Dict, Tuple import rich def flatten_nested_dict(d: Dict) -> Dict: """Flatten a nested config dictionary""" flattened_dict = {} for k, v in d.items(): if isinstance(v, dict): flattened_dict.update(v) else: flattened_dict[k] = v return flattened_dict def get_git_revision_hash(verbose: bool = True) -> Tuple[str, str]: """Return current git branch and commit""" git_branch = ( subprocess.check_output(["git", "branch", "--show-current"]) .decode("ascii") .strip() ) commit_hash = ( subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() ) if verbose: rich.print( f"[magenta]Git:[/] Commit [bold]{commit_hash}[/] on branch {git_branch}" ) return git_branch, commit_hash def pretty_json(config_dict: dict) -> str: """Pretty print the given dict as json""" config_dict = { k: ( v.name if hasattr(v, "name") else v if isinstance(v, (str, int, float)) else str(v) ) for k, v in config_dict.items() } json_config_dict = json.dumps(config_dict, indent=4) return "".join("\t" + line for line in json_config_dict.splitlines(True)) ================================================ FILE: kyuteye_pt/kyuteye/utils/struct_utils.py ================================================ """Useful structure and simple class definition FrozenEnum are used to hold global configs shared across multiple files""" from enum import Enum, EnumMeta from typing import Any class FrozenEnumMeta(EnumMeta): "Enum metaclass that freezes an enum entirely" def __new__(mcs, name: str, bases: Any, classdict: Any) -> type: classdict["__frozenenummeta_creating_class__"] = True enum = super().__new__(mcs, name, bases, classdict) del enum.__frozenenummeta_creating_class__ # type: ignore[attr-defined] return enum def __setattr__(cls, name: str, value: Any) -> None: members = cls.__dict__.get("_member_map_", {}) if hasattr(cls, "__frozenenummeta_creating_class__") or name in members: return super().__setattr__(name, value) if hasattr(cls, name): msg = "{!r} object attribute {!r} is read-only" else: msg = "{!r} object has no attribute {!r}" raise AttributeError(msg.format(cls.__name__, name)) def __delattr__(cls, name: str) -> None: members = cls.__dict__.get("_member_map_", {}) if hasattr(cls, "__frozenenummeta_creating_class__") or name in members: return super().__delattr__(name) if hasattr(cls, name): msg = "{!r} object attribute {!r} is read-only" else: msg = "{!r} object has no attribute {!r}" raise AttributeError(msg.format(cls.__name__, name)) class FrozenEnum(Enum, metaclass=FrozenEnumMeta): """Frozen Enum type used for immutable configurations""" pass ================================================ FILE: kyuteye_pt/pyproject.toml 
================================================ [build-system] requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" [project] name = "kyuteye" version = "0.0.0" description = "Kyutai with an 'eye'" authors = [ { name = "Amelie Royer", email = "amelie@kyutai.org" }, { name = "Moritz Boehle", email = "moritz@kyutai.org" } ] maintainers = [{ name = "Amelie Royer", email = "amelie@kyutai.org" }] keywords = [] readme = "README.md" license = { text = "MIT" } requires-python = ">=3.10" dependencies = [ "einops", "huggingface_hub", "moshi==0.1.0", "numpy<2", "sentencepiece", "torch==2.2.0", "torchvision==0.17.0", "tqdm", "transformers==4.47.0", "triton", "fire", "rich", "pyyaml", "black", "setuptools", "sphn >= 0.1.4", ] [project.scripts] server = "kyuteye.server:main" sanity-check = "kyuteye.server:sanity_check" [tool.mypy] python_version = "3.10" [[tool.mypy.overrides]] module = ["fire", "yaml", "setuptools", "transformers.*", "torchvision.*", "timm.*", "sentencepiece", "moshi.*", "huggingface_hub", "sphn"] ignore_missing_imports = true [tool.setuptools.packages.find] where = ["."] [dependency-groups] dev = [ "mypy==1.11.2", "pylint>=3.3.4", ] ================================================ FILE: kyuteye_pt/tests/hello.py ================================================ from transformers import AutoProcessor, AutoModelForImageTextToText import numpy as np import torch import torch from moshi.models import loaders, LMGen import random from pathlib import Path from huggingface_hub import hf_hub_download def write_weights_for_analysis(model: torch.nn.Module): file_content = "" if isinstance(model, torch.nn.Module): framework_name = "torch" params = model.state_dict().items() else: print("Unsupported model type") for i, (key, value) in enumerate(params): file_content += f"{i} {key} {value.shape}\n" dest = Path(f"/tmp/weights_{random.randint(0, 2**16)}_{framework_name}.txt") dest.write_text(file_content) print("Wrote layers description into " + dest) @torch.no_grad() def test_weights_conversion_moshi(): moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME) moshi = loaders.get_moshi_lm(moshi_weight) lm_gen_torch = LMGen(moshi, temp=0.8, temp_text=0.7) write_weights_for_analysis(lm_gen_torch) if __name__ == "__main__": test_weights_conversion_moshi() ================================================ FILE: kyuteye_rs/Cargo.toml ================================================ [workspace] members = [ "moshi-core", "moshi-backend", ] resolver = "2" [workspace.dependencies] anyhow = "1" axum = { version = "0.8.1", features = ["ws"] } axum-server = { version = "0.7.1", features = ["tls-rustls"] } base64 = "0.21.7" bincode = "1.3.3" byteorder = "1.5.0" candle = { version = "0.8.3", package = "candle-core" } candle-flash-attn = "0.8.3" candle-nn = "0.8.3" candle-transformers = "0.8.3" clap = { version = "4.4.12", features = ["derive"] } color-eyre = "0.6.2" console_error_panic_hook = "0.1.7" cpal = "0.15.3" crossterm = { version = "0.27.0", features = ["event-stream"] } cudarc = { version = "=0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } futures = "0.3.28" futures-util = "0.3.30" hf-hub = { version = "0.3.2", features = ["tokio"] } http = "1.1.0" hyper = "1.3.1" image = "0.25.2" js-sys = "0.3.66" lazy_static = "1.5.0" log = "0.4.20" moshi = { path = "./moshi-core" } native-tls = "0.2.11" numpy = "0.23.0" ogg = { version = "0.9.1", features = 
["async"] } opus = "0.3.0" prometheus = "0.13.4" prost = "0.12" pyo3 = "0.23.0" rand = { version = "0.8.5", features = ["getrandom"] } rand_chacha = "0.3.1" ratatui = "=0.26.0" rayon = "1.8.1" rcgen = "0.13.1" regex = "1.10.3" reqwest = { version = "0.12", features = ["stream", "json"] } rubato = "0.15.0" rustls = { version = "0.23.20", features = ["ring"] } sentencepiece = "0.11.2" serde = { version = "1.0.210", features = ["derive"] } serde_json = "1.0.115" symphonia = { version = "0.5.3", features = ["all"] } timens = "0.1.9" tokio = { version = "1.35.1", features = ["full"] } tokio-rustls = "0.24.1" tokio-stream = "0.1" tokio-tungstenite = { version = "0.21.0", features = ["rustls", "native-tls"] } tonic = "0.11" tonic-build = "0.11" tower = "0.4.13" tower-http = { version = "0.5", features = ["full"] } tracing = "0.1.40" tracing-appender = "0.2.3" tracing-subscriber = "0.3.18" tui-logger = "=0.11.1" vergen = { version = "8.3.1", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] } webrtc = "0.10.1" ================================================ FILE: kyuteye_rs/configs/config-moshika-vis-q8.json ================================================ { "instance_name": "foo", "hf_repo": "kyutai/moshika-vis-candle-q8", "lm_model_file": "$HOME/tmp/model.q8_0.gguf", "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model", "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", "log_dir": "$HOME/tmp/moshi-logs", "mimi_num_codebooks": 8, "static_dir": "../client/dist", "addr": "0.0.0.0", "port": 8008, "cert_dir": "../", "lm_config": { "acoustic_delay": 1, "generated_audio_codebooks": 8, "input_audio_codebooks": 8, "audio_vocab_size": 2049, "text_eop_token": 0, "text_pad_token": 3, "text_start_token": 32000 }, "image_prefix_backbone": "Siglip448", "image_prefix_use_rms_norm": true, "cross_attention_gating": "ConditionalGatedSigmoid", "cross_attention_in_dims": null } ================================================ FILE: kyuteye_rs/configs/config-moshika-vis.json ================================================ { "instance_name": "foo", "hf_repo": "kyutai/moshika-vis-candle-bf16", "lm_model_file": "$HOME/tmp/model.safetensors", "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model", "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", "log_dir": "$HOME/tmp/moshi-logs", "mimi_num_codebooks": 8, "static_dir": "../client/dist", "addr": "0.0.0.0", "port": 8008, "cert_dir": "../", "lm_config": { "acoustic_delay": 1, "generated_audio_codebooks": 8, "input_audio_codebooks": 8, "audio_vocab_size": 2049, "text_eop_token": 0, "text_pad_token": 3, "text_start_token": 32000 }, "image_prefix_backbone": "Siglip448", "image_prefix_use_rms_norm": true, "cross_attention_gating": "ConditionalGatedSigmoid", "cross_attention_in_dims": null } ================================================ FILE: kyuteye_rs/moshi-backend/Cargo.toml ================================================ [package] name = "moshi-backend" version = "0.1.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] anyhow = { workspace = true } axum = { workspace = true } axum-server = { workspace = true } bincode = { workspace = true } byteorder = { workspace = true } candle = { workspace = true } candle-nn = { workspace = true } candle-transformers = { workspace = true } clap = { workspace = true } futures-util = { workspace = true } hf-hub = { workspace = true } http = { workspace = true } image = { 
workspace = true } lazy_static = { workspace = true } log = { workspace = true } moshi = { workspace = true } ogg = { workspace = true } opus = { workspace = true } prometheus = { workspace = true } rand = { workspace = true } rand_chacha = { workspace = true } rcgen = { workspace = true } regex = { workspace = true } reqwest = { workspace = true } rubato = { workspace = true } rustls = { version = "0.23.20", features = ["ring"] } sentencepiece = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } symphonia = { workspace = true } tokio = { workspace = true } tokio-rustls = { workspace = true } tower = { workspace = true } tower-http = { workspace = true } tracing = { workspace = true } tracing-appender = { workspace = true } tracing-subscriber = { workspace = true } [build-dependencies] anyhow = { workspace = true } vergen = { workspace = true } [features] default = [] cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"] [profile.release] debug = true [profile.release-no-debug] inherits = "release" debug = false ================================================ FILE: kyuteye_rs/moshi-backend/build.rs ================================================ use anyhow::Result; use vergen::EmitBuilder; pub fn main() -> Result<()> { // NOTE: This will output everything, and requires all features enabled. // NOTE: See the EmitBuilder documentation for configuration options. EmitBuilder::builder() .all_build() .all_cargo() .all_git() .all_rustc() .all_sysinfo() .emit()?; Ok(()) } ================================================ FILE: kyuteye_rs/moshi-backend/src/audio.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. #![allow(unused)] use std::io::prelude::*; pub trait Sample { fn to_i16(&self) -> i16; } impl Sample for f32 { fn to_i16(&self) -> i16 { (self.clamp(-1.0, 1.0) * 32767.0) as i16 } } impl Sample for f64 { fn to_i16(&self) -> i16 { (self.clamp(-1.0, 1.0) * 32767.0) as i16 } } impl Sample for i16 { fn to_i16(&self) -> i16 { *self } } pub fn write_pcm_as_wav( w: &mut W, samples: &[S], sample_rate: u32, ) -> std::io::Result<()> { let len = 12u32; // header let len = len + 24u32; // fmt let len = len + samples.len() as u32 * 2 + 8; // data let n_channels = 1u16; let bytes_per_second = sample_rate * 2 * n_channels as u32; w.write_all(b"RIFF")?; w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes w.write_all(b"WAVE")?; // Format block w.write_all(b"fmt ")?; w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes w.write_all(&1u16.to_le_bytes())?; // PCM w.write_all(&n_channels.to_le_bytes())?; // one channel w.write_all(&sample_rate.to_le_bytes())?; w.write_all(&bytes_per_second.to_le_bytes())?; w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample w.write_all(&16u16.to_le_bytes())?; // bits per sample // Data block w.write_all(b"data")?; w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?; for sample in samples.iter() { w.write_all(&sample.to_i16().to_le_bytes())? 
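        // Each sample is emitted as 16-bit little-endian PCM; the f32/f64
        // `Sample` impls above clamp to [-1, 1] and scale by 32767, so
        // out-of-range floats cannot overflow the i16.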
    }
    Ok(())
}

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};
    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> anyhow::Result<Vec<f32>> {
    use rubato::Resampler;
    let mut pcm_out =
        Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
    let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
    let mut output_buffer = resampler.output_buffer_allocate(true);
    let mut pos_in = 0;
    while pos_in + resampler.input_frames_next() < pcm_in.len() {
        let (in_len, out_len) =
            resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
        pos_in += in_len;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }
    if pos_in < pcm_in.len() {
        let (_in_len, out_len) = resampler.process_partial_into_buffer(
            Some(&[&pcm_in[pos_in..]]),
            &mut output_buffer,
            None,
        )?;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }
    Ok(pcm_out)
}

pub(crate) fn write_opus_header<W: std::io::Write>(w: &mut W) -> std::io::Result<()> {
    use byteorder::WriteBytesExt;
    // https://wiki.xiph.org/OggOpus#ID_Header
    w.write_all(b"OpusHead")?;
    w.write_u8(1)?; // version
    w.write_u8(1)?; // channel count
    w.write_u16::<byteorder::LittleEndian>(3840)?; // pre-skip
    w.write_u32::<byteorder::LittleEndian>(48000)?; // sample-rate in Hz
    w.write_i16::<byteorder::LittleEndian>(0)?; // output gain Q7.8 in dB
    w.write_u8(0)?; // channel map
    Ok(())
}

pub(crate) fn write_opus_tags<W: std::io::Write>(w: &mut W) -> std::io::Result<()> {
    use byteorder::WriteBytesExt;
    // https://wiki.xiph.org/OggOpus#Comment_Header
    let vendor = "KyutaiMoshi";
    w.write_all(b"OpusTags")?;
    w.write_u32::<byteorder::LittleEndian>(vendor.len() as u32)?; // vendor string length
w.write_all(vendor.as_bytes())?; // vendor string, UTF8 encoded w.write_u32::(0u32)?; // number of tags Ok(()) } ================================================ FILE: kyuteye_rs/moshi-backend/src/build.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. use anyhow::Result; use vergen::EmitBuilder; pub fn main() -> Result<()> { // NOTE: This will output everything, and requires all features enabled. // NOTE: See the EmitBuilder documentation for configuration options. EmitBuilder::builder().all_build().all_cargo().all_git().all_rustc().all_sysinfo().emit()?; Ok(()) } ================================================ FILE: kyuteye_rs/moshi-backend/src/image_embedder.rs ================================================ use anyhow::Result; use candle::{Device, Tensor}; use candle_nn::{linear, Linear, VarBuilder}; use candle_transformers::models::fastvit as MobileClip; use candle_transformers::models::pixtral::vision_model as Pixtral; use candle_transformers::models::siglip as Siglip; use moshi::transformer::{CaSrc, Norm}; use moshi::NormType; fn load_image( bytes: &[u8], max_size: usize, mean: &[f32; 3], std: &[f32; 3], center_crop: bool, preserve_aspect_ratio: bool, ) -> candle::Result { // Load RGB image and resize such that longest side is equal to `max_size` // if crop_to_square is True, the resize also crops the image to square let mut img = image::ImageReader::new(std::io::Cursor::new(bytes)) .with_guessed_format()? .decode() .map_err(candle::Error::wrap)?; // if not center crop, we just resize to the max size let (img, width, height) = if !center_crop { if preserve_aspect_ratio { let (width, height) = (img.width(), img.height()); let (new_width, new_height) = if width < height { (((width * max_size as u32) / height) as usize, max_size) } else { (max_size, ((height * max_size as u32) / width) as usize) }; let img = img.resize_exact(width, height, image::imageops::FilterType::CatmullRom); (img, new_width, new_height) } else { ( img.resize_exact( max_size as u32, max_size as u32, image::imageops::FilterType::CatmullRom, ), max_size, max_size, ) } } // otherwise, we first center crop to a square of (max_size, maz_size) then resize else { let (width, height) = (img.width(), img.height()); let min_dim = if width > height { height } else { width }; let x = (width - min_dim) / 2; let y = (height - min_dim) / 2; //print!("center crop: {} {} {} {} {}\n", x, y, min_dim, width, height); let img = img.crop(x, y, min_dim, min_dim); let img = img.resize_exact( max_size as u32, max_size as u32, image::imageops::FilterType::CatmullRom, ); (img, max_size, max_size) }; let img = img.to_rgb8(); let data = img.into_raw(); let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?; tracing::info!(data = ?data.shape(), "image loaded"); let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?; let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?; (data.to_dtype(candle::DType::F32)? / 255.)? .broadcast_sub(&mean)? 
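        // Standard channel-wise normalization: (pixel / 255 - mean) / std,
        // using the backbone-specific mean/std supplied by the caller.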
.broadcast_div(&std) } // Image Encoder for vision-conditioned models #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub enum ImageEncoder { Siglip224, Siglip448, Siglip896, MobileclipS1, MobileclipS2, Pixtral, } #[derive(Debug, Clone)] pub enum ImageEncoderModel { Siglip(Siglip::VisionModel), Mobileclip(candle_nn::Func<'static>), Pixtral(Pixtral::Model), } fn init_output_proj(in_dims: usize, out_dims: usize, vb: VarBuilder) -> Result> { let proj = if vb.contains_tensor("proj_xa.weight") { Some(linear(in_dims, out_dims, vb.pp("proj_xa"))?) } else { None }; Ok(proj) } #[derive(Debug, Clone)] pub struct ImageEmbedder { model: ImageEncoderModel, // optional output proj proj: Option, // output norm norm: Norm, //image loading param mean: [f32; 3], std: [f32; 3], // max image size // corresponds to maximum allowed image dimension for the model // For Pixtral, this is the maximum side // For Siglip and Mobileclip, images are fixed to a square size for now max_image_size: usize, patch_size: usize, // whether the lm model will be quantized (need F32 inputs) quantized_lm_model: bool, } impl ImageEmbedder { pub fn new( model_file: &str, image_prefix_backbone: Option, image_prefix_rmsnorm: bool, cross_attention_in_dims: Option, dev: &Device, ) -> Result { let out_dims = cross_attention_in_dims.unwrap_or(4096); let quantized_lm_model = model_file.ends_with(".gguf"); let vb = if quantized_lm_model { // model_file should have format XXX.quant_format.gguf // and the unquantized vision tower weights should be in XXX_vision_tower_unquant.safetensors let base_path = model_file.rsplitn(3, '.').nth(2); let out_path = match base_path { None => anyhow::bail!(".gguf file does not have a corresponding vision tower"), Some(p) => format!("{}_vision_tower_unquant.safetensors", p), }; tracing::info!(?out_path, "Loading unquantized vision encoder from"); unsafe { VarBuilder::from_mmaped_safetensors(&[out_path.as_str()], candle::DType::F32, dev)? } } else { // if safetensors, we assume the file contains *all* model's tensors unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, dev)? } }; let vb = vb.pp("image_prefix"); let norm_typ = match image_prefix_rmsnorm { true => NormType::RmsNorm, false => NormType::LayerNorm, }; // output norm let norm = Norm::new_shortcut( out_dims, norm_typ, moshi::nn::MaybeQuantizedVarBuilder::Real(vb.pp("norm_xa")), )?; let max_image_size = match image_prefix_backbone { Some( ImageEncoder::Siglip224 | ImageEncoder::MobileclipS1 | ImageEncoder::MobileclipS2, ) => 224, Some(ImageEncoder::Siglip448) => 448, Some(ImageEncoder::Siglip896) => 896, Some(ImageEncoder::Pixtral) => 1024, None => anyhow::bail!("Image encoder type not specified in config"), }; // Main backbone match image_prefix_backbone { Some(ImageEncoder::Siglip224 | ImageEncoder::Siglip448 | ImageEncoder::Siglip896) => { // Siglip // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json let siglip_cfg = match image_prefix_backbone { Some(ImageEncoder::Siglip224) => Siglip::VisionConfig::paligemma_3b_224(), Some(ImageEncoder::Siglip448) => Siglip::VisionConfig::paligemma_3b_448(), Some(ImageEncoder::Siglip896) => Siglip::VisionConfig::paligemma_3b_896(), _ => anyhow::bail!("Impossible match arm. 
Congrats on reaching there"), }; let model = Siglip::VisionModel::new(&siglip_cfg, false, vb.pp("enc.model"))?; let proj = init_output_proj(siglip_cfg.hidden_size, out_dims, vb)?; Ok(Self { model: ImageEncoderModel::Siglip(model), proj, norm, mean: [0.5, 0.5, 0.5], std: [0.5, 0.5, 0.5], max_image_size, patch_size: siglip_cfg.patch_size, quantized_lm_model, }) } Some(ImageEncoder::MobileclipS1) | Some(ImageEncoder::MobileclipS2) => { // mobileclip (mci1/2 from candle) let vit_cfg = match image_prefix_backbone { Some(ImageEncoder::MobileclipS1) => MobileClip::Config::mci1(), Some(ImageEncoder::MobileclipS2) => MobileClip::Config::mci2(), _ => anyhow::bail!("No image backbone specified for cross-attention layers"), }; let model = MobileClip::fastvit_no_final_layer(&vit_cfg, vb.pp("enc.model"))?; let proj = init_output_proj(vit_cfg.in_channels * 16, out_dims, vb)?; Ok(Self { model: ImageEncoderModel::Mobileclip(model), proj, norm, mean: [0., 0., 0.], std: [1., 1., 1.], max_image_size, patch_size: 1, quantized_lm_model, }) } Some(ImageEncoder::Pixtral) => { let pixtral_cfg = Pixtral::Config::pixtral_12b_2409(); let model = Pixtral::Model::new(&pixtral_cfg, vb.pp("enc.model"))?; let proj = init_output_proj(pixtral_cfg.hidden_size, out_dims, vb)?; Ok(Self { model: ImageEncoderModel::Pixtral(model), proj, norm, mean: [0.481_454_66, 0.457_827_5, 0.408_210_73], std: [0.268_629_54, 0.261_302_6, 0.275_777_1], max_image_size, patch_size: pixtral_cfg.patch_size, quantized_lm_model, }) } None => { anyhow::bail!("No image backbone specified for cross-attention layers") } } } pub fn output_proj(&self, img_features: Tensor, dev: &Device) -> Result { // Output linear + normalization let img_features = match &self.proj { None => img_features.apply(&self.norm)?, Some(module) => img_features.apply(module)?.apply(&self.norm)?, }; tracing::info!(feat = ?img_features.shape(), "image features generated"); // Type conversion let dtype = if dev.is_cuda() && !self.quantized_lm_model { candle::DType::BF16 } else { candle::DType::F32 }; Ok(img_features.to_dtype(dtype)?) } pub fn embed( &self, img_bytes: &[u8], image_size: usize, center_crop: bool, dev: &Device, ) -> Result { // load Uint image as tensor then embed // to avoid any issue with the position embeddings interpolations, resize // images to the closest multiple of the model's patch size let image_size = if image_size % self.patch_size > self.patch_size / 2 { image_size - image_size % self.patch_size + self.patch_size } else { image_size - image_size % self.patch_size }; // too small image sizes are very out of distributions so we clamp // the input image size to something reasonable let min_image_size = self.patch_size * 10; let image_size = if !(min_image_size..=self.max_image_size).contains(&image_size) { tracing::info!( "Limiting image size up to be in [{}, {}]", min_image_size, self.max_image_size ); image_size.clamp(min_image_size, self.max_image_size) } else { image_size }; let img_features = match &self.model { // Pixtral handles dynamic image size + non-square ratios by nature ImageEncoderModel::Pixtral(m) => load_image( img_bytes, image_size, &self.mean, &self.std, center_crop, true, )? .to_device(dev)? .unsqueeze(0)? .apply(m)?, // Siglip now also handles dynamic size/ratios with positional embedding interpolation ImageEncoderModel::Siglip(m) => load_image( img_bytes, image_size, &self.mean, &self.std, center_crop, false, )? .to_device(dev)? .unsqueeze(0)? 
.apply(m)?, // But MobileClip always resize to its own fixed (and square) image size ImageEncoderModel::Mobileclip(m) => load_image( img_bytes, self.max_image_size, &self.mean, &self.std, center_crop, false, )? .to_device(dev)? .unsqueeze(0)? .apply(m)? .flatten_from(2)? .transpose(1, 2)?, }; Ok(CaSrc::Tokens(self.output_proj(img_features, dev)?)) } pub fn embed_from_tensor(&self, img: Tensor, dev: &Device) -> Result { // embed image preloaded as safetensors let img_features = match &self.model { // Currently siglip and MobileClip only supports fixed size (from config) ImageEncoderModel::Siglip(m) => img.apply(m)?, ImageEncoderModel::Mobileclip(m) => img.apply(m)?.flatten_from(2)?.transpose(1, 2)?, // Pixtral supports any size, but we may need to update the positional embeddings ImageEncoderModel::Pixtral(m) => img.apply(m)?, }; Ok(CaSrc::Tokens(self.output_proj(img_features, dev)?)) } } ================================================ FILE: kyuteye_rs/moshi-backend/src/main.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. use anyhow::Result; use clap::Parser; use std::str::FromStr; mod audio; mod image_embedder; mod metrics; mod standalone; mod stream_both; mod utils; #[derive(Parser, Debug)] #[clap(name = "server", about = "moshi web server")] struct Args { #[clap(short = 'l', long = "log", default_value = "info")] log_level: String, #[clap(long, default_value = "configs/config-moshika-vis.json")] config: String, #[clap(long)] silent: bool, #[command(subcommand)] command: Command, } #[derive(Parser, Debug)] struct StandaloneArgs { #[clap(long)] cpu: bool, #[clap(short = 's', long = "sig", default_value = None)] sig: Option, #[clap(short = 'e', long = "epoch", default_value = None)] epoch: Option, #[clap(short = 'u', long = "user", default_value = None)] user: Option, #[clap(long)] img: Option, #[clap(long, default_value_t = 224)] img_size: usize, #[clap(long, conflicts_with = "img")] vis: bool, #[clap(long)] s2s: bool, #[clap(long)] asr: bool, } #[derive(Debug, clap::Subcommand)] enum Command { Standalone(StandaloneArgs), } /// A TLS acceptor that sets `TCP_NODELAY` on accepted streams. #[derive(Clone, Debug)] pub struct NoDelayAcceptor; impl axum_server::accept::Accept for NoDelayAcceptor { type Stream = tokio::net::TcpStream; type Service = S; type Future = futures_util::future::BoxFuture<'static, std::io::Result<(Self::Stream, Self::Service)>>; fn accept(&self, stream: tokio::net::TcpStream, service: S) -> Self::Future { Box::pin(async move { // Disable Nagle's algorithm. 
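            // Nagle's algorithm batches small TCP writes, which adds latency to
            // the short Opus/websocket frames exchanged here, so it is disabled
            // on every accepted stream before handing it to axum.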
stream.set_nodelay(true)?; Ok::<_, std::io::Error>((stream, service)) }) } } fn tracing_init( log_dir: &str, instance_name: &str, log_level: &str, silent: bool, ) -> Result { use tracing_subscriber::prelude::*; let build_info = utils::BuildInfo::new(); let file_appender = tracing_appender::rolling::daily(log_dir, format!("log.{}", instance_name)); let (non_blocking, guard) = tracing_appender::non_blocking(file_appender); let filter = tracing_subscriber::filter::LevelFilter::from_str(log_level)?; let mut layers = vec![tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_filter(filter) .boxed()]; if !silent { layers.push(Box::new( tracing_subscriber::fmt::layer() .with_writer(std::io::stdout) .with_filter(filter), )) }; tracing_subscriber::registry().with(layers).init(); tracing::info!(?build_info); Ok(guard) } #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<()> { rustls::crypto::ring::default_provider() .install_default() .map_err(|_| anyhow::Error::msg("unable to install default crypto"))?; let args = Args::parse(); match args.command { Command::Standalone(standalone_args) => { let mut config = standalone::Config::load(&args.config)?; let _guard = tracing_init( &config.stream.log_dir, &config.stream.instance_name, &args.log_level, args.silent, )?; tracing::info!("starting process with pid {}", std::process::id()); if config.stream.requires_model_download() { standalone::download_from_hub(&mut config.stream).await?; } if !std::path::PathBuf::from(&config.static_dir).exists() { use hf_hub::api::tokio::Api; let api = Api::new()?; let repo = api.model("kyutai/moshi-artifacts".to_string()); let dist_tgz = repo.get("vis_dist.tgz").await?; if let Some(parent) = dist_tgz.parent() { let dist = parent.join("dist"); if !dist.exists() { let output = std::process::Command::new("tar") .arg("-xzf") .arg(&dist_tgz) .arg("-C") .arg(parent) .output()?; if !output.status.success() { anyhow::bail!( "error extract {dist_tgz:?}: {}", String::from_utf8_lossy(&output.stderr) ); } } config.static_dir = dist.to_string_lossy().to_string() } } standalone::run(&standalone_args, &config).await?; } } Ok(()) } ================================================ FILE: kyuteye_rs/moshi-backend/src/metrics.rs ================================================ use lazy_static::lazy_static; use prometheus::Histogram; use prometheus::{histogram_opts, register_histogram}; pub mod worker { use super::*; lazy_static! { pub static ref MODEL_STEP_DURATION: Histogram = register_histogram!(histogram_opts!( "worker_model_step_duration", "Model step duration distribution.", vec![40e-3, 50e-3, 60e-3, 75e-3, 80e-3, 0.1], )) .unwrap(); } } ================================================ FILE: kyuteye_rs/moshi-backend/src/standalone.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
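// Standalone server entry point: loads the JSON config (expanding $ENV_VARS in
// paths), downloads any missing model files from the Hugging Face hub, builds
// and warms up the LM / Mimi / optional image-embedder stack, and serves the
// `/api/chat` websocket plus the static client directory over TLS (rustls).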
use anyhow::{Context, Result}; use axum::extract::ws; use std::process; use std::sync::Arc; use std::{path::Path, str::FromStr}; use crate::{image_embedder, stream_both, utils, StandaloneArgs}; #[derive(serde::Deserialize, Debug, Clone)] pub struct Config { cert_dir: String, #[serde(default = "utils::default_static_dir")] pub static_dir: String, addr: String, port: u16, #[serde(flatten)] pub stream: stream_both::Config, } impl Config { pub fn load>(p: P) -> Result { let config = std::fs::read_to_string(p)?; let mut config: Self = serde_json::from_str(&config)?; config.static_dir = crate::utils::replace_env_vars(&config.static_dir); config.cert_dir = crate::utils::replace_env_vars(&config.cert_dir); // location of the images in the static dir config.stream.images_dir = Some(std::path::PathBuf::from(&config.static_dir).join("assets/images/demo")); config.stream.log_dir = crate::utils::replace_env_vars(&config.stream.log_dir); config.stream.text_tokenizer_file = crate::utils::replace_env_vars(&config.stream.text_tokenizer_file); config.stream.mimi_model_file = crate::utils::replace_env_vars(&config.stream.mimi_model_file); config.stream.lm_model_file = crate::utils::replace_env_vars(&config.stream.lm_model_file); Ok(config) } pub fn cert_file(&self, name: &str) -> Result { let cert_dir = std::path::PathBuf::from(&self.cert_dir); let cert_file = cert_dir.join(name); if !cert_file.is_file() { anyhow::bail!("missing file {cert_file:?}"); } Ok(cert_file) } } pub(crate) fn device(cpu: bool) -> Result { use candle::Device; if cpu { Ok(Device::Cpu) } else if candle::utils::cuda_is_available() { Ok(Device::new_cuda(0)?) } else if candle::utils::metal_is_available() { Ok(Device::new_metal(0)?) } else { Ok(Device::Cpu) } } impl stream_both::AppStateInner { pub fn new(args: &StandaloneArgs, config: &stream_both::Config) -> Result { let device = device(args.cpu)?; let mut config = config.clone(); if let Some(sig) = args.sig.as_ref() { tracing::info!(sig, "Loading checkpoint from sig"); let mut cmd = process::Command::new("python"); cmd.arg("-m").arg("scripts.mimi_import").arg(sig).arg("-s"); if let Some(epoch) = args.epoch { tracing::info!(epoch, "using epoch"); cmd.arg("-e").arg(epoch.to_string()); } if let Some(user) = args.user.as_ref() { tracing::info!(user, "taking checkpoint from user"); cmd.arg("-u").arg(user); } let output = cmd.output()?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr[..]); tracing::error!("Error while trying to get checkpoint: {stderr}"); anyhow::bail!("Couldn't convert the checkpoint."); } let lm_model_file = String::from_utf8_lossy(&output.stdout[..]) .trim_end_matches('\n') .to_string(); tracing::info!(lm_model_file, "model path was overriden"); config.lm_model_file = lm_model_file; } let dtype = if device.is_cuda() { candle::DType::BF16 } else { candle::DType::F32 }; let image_embedder = match args.vis { // Standard Moshi: no vision features false => None, // --vis: Load the image encoder and the image will be embedded // on-the-fly in stream_handler true => Some(image_embedder::ImageEmbedder::new( &config.lm_model_file, config.image_prefix_backbone.clone(), config.image_prefix_use_rms_norm, config.cross_attention_in_dims, &device, )?), }; // update cross-attention gating based on user-provided JSON config let lm_model = if args.vis || args.img.is_some() { moshi::lm::load_vision( &config.lm_model_file, config.cross_attention_gating, config.cross_attention_in_dims, dtype, &device, )? 
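            // `load_vision` (above) builds the Moshi LM together with the gated
            // cross-attention layers that consume the image embedding; the
            // `load_streaming` branch below is the plain audio-only model used
            // when neither `--vis` nor `--img` is given.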
} else { moshi::lm::load_streaming(&config.lm_model_file, dtype, &device)? }; let mimi_device = if config.use_cpu_for_mimi { &candle::Device::Cpu } else { &device }; let mimi_model = moshi::mimi::load( &config.mimi_model_file, Some(config.mimi_num_codebooks), mimi_device, )?; let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&config.text_tokenizer_file)?; // Warm-up code. { tracing::info!(?dtype, ?device, "warming up the model"); // Warmup the image encoder if any let fake_image = candle::Tensor::zeros((1, 3, 224, 224), candle::DType::F32, &device)?; let ca_src = match &image_embedder { None => None, Some(m) => Some(m.embed_from_tensor(fake_image, &device)?), }; // Warm up the LM model w/ cross-attention as needed let mut lm_model = lm_model.clone(); let (_v, ys) = match &ca_src { None => lm_model.forward(None, vec![None; config.mimi_num_codebooks])?, Some(x) => lm_model.forward_ca(None, vec![None; config.mimi_num_codebooks], x)?, }; let mut lp = candle_transformers::generation::LogitsProcessor::new(123, None, None); let _ = lm_model.depformer_sample(&ys, None, &[], &mut lp)?; let mut mimi_model = mimi_model.clone(); let config = mimi_model.config(); let frame_length = (config.sample_rate / config.frame_rate).ceil() as usize; let fake_pcm = candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, mimi_device)?; let codes = mimi_model.encode_step(&fake_pcm.into())?; let ys = mimi_model.decode_step(&codes)?; if ys.as_option().is_none() { anyhow::bail!("Expected Mimi to output some stuff, but nothing came out."); } device.synchronize()?; tracing::info!("model is ready to roll!"); } Ok(Self { lm_model, mimi_model, device, config: config.clone(), text_tokenizer, image_embedder, }) } } async fn handle_socket(socket: ws::WebSocket, sm: stream_both::StreamingModel) { if let Err(err) = stream_both::handle_socket(socket, sm, None).await { tracing::error!(err = err.to_string(), "handle_socket") } } pub async fn stream_handler( ws: ws::WebSocketUpgrade, axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo, state: axum::extract::State, req: axum::extract::Query, ) -> crate::utils::AxumResult { tracing::info!(?addr, "received connection"); let sm = stream_both::StreamingModel::new(&state.0, req.0); Ok(ws.on_upgrade(move |v| handle_socket(v, sm))) } pub async fn download_from_hub(config: &mut stream_both::Config) -> Result<()> { use hf_hub::api::tokio::Api; let api = Api::new()?; let repo = api.model(config.hf_repo.clone()); let extract_filename = |path: &str| -> Result { Path::new(path) .file_name() .and_then(|f| f.to_str()) .map(String::from) .ok_or_else(|| anyhow::anyhow!("'{path}' has no file name")) }; for file_path in [ &mut config.lm_model_file, &mut config.mimi_model_file, &mut config.text_tokenizer_file, ] .iter_mut() { let filename = extract_filename(file_path) .with_context(|| format!("Failed to extract filename for '{file_path}'"))?; let downloaded_path = repo .get(&filename) .await .with_context(|| format!("Failed to download '{file_path}' file"))?; **file_path = downloaded_path .into_os_string() .into_string() .map_err(|_| anyhow::anyhow!("'{file_path}' path is not a valid string"))?; } // Download vision tower unquantized if config.lm_model_file.ends_with(".gguf") { // model_file should have format XXX.quant_format.gguf // and the unquantized vision tower weights should be in XXX_vision_tower_unquant.safetensors let base_path = config.lm_model_file.rsplitn(3, '.').nth(2); let vision_tower_path = match base_path { None => anyhow::bail!(".gguf file does not have 
a corresponding vision tower"), Some(p) => format!("{p}_vision_tower_unquant.safetensors"), }; let filename = extract_filename(&vision_tower_path) .with_context(|| format!("Failed to extract filename for '{vision_tower_path}'"))?; repo.get(&filename) .await .with_context(|| format!("Failed to download '{vision_tower_path}' file"))?; }; Ok(()) } pub async fn run(args: &StandaloneArgs, config: &Config) -> Result<()> { let cert_pem = config.cert_file("cert.pem")?; let key_pem = config.cert_file("key.pem")?; if !cert_pem.exists() || !key_pem.exists() { let rcgen::CertifiedKey { cert, key_pair } = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?; std::fs::write(&cert_pem, cert.pem())?; std::fs::write(&key_pem, key_pair.serialize_pem())?; } let tls_config = axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_pem, key_pem).await?; let sock_addr = std::net::SocketAddr::from(( std::net::IpAddr::from_str(config.addr.as_str()) .unwrap_or(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)), config.port, )); let state = Arc::new(stream_both::AppStateInner::new(args, &config.stream)?); tracing::info!("serving static dir {}", config.static_dir); let app = axum::Router::new() .route("/api/chat", axum::routing::get(stream_handler)) .fallback_service( tower_http::services::ServeDir::new(&config.static_dir) .append_index_html_on_directories(true), ) .layer(tower::ServiceBuilder::new().layer(tower_http::trace::TraceLayer::new_for_http())) .with_state(state); tracing::info!( "standalone worker listening on https://{}?worker_addr={}", sock_addr, sock_addr ); axum_server::bind_rustls(sock_addr, tls_config) .serve(app.into_make_service_with_connect_info::()) .await?; Ok(()) } ================================================ FILE: kyuteye_rs/moshi-backend/src/stream_both.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. use crate::image_embedder::{ImageEmbedder, ImageEncoder}; use anyhow::Result; use axum::extract::ws; use futures_util::{ stream::{SplitSink, SplitStream, StreamExt}, SinkExt, }; use moshi::{dynamic_logits_processor::GateInfluencedLogitsProcessor, transformer::CaSrc}; use std::sync::Arc; #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] pub struct ForceSessionConfig { pub text_temperature: f64, pub text_topk: usize, pub audio_temperature: f64, pub audio_topk: usize, pub pad_mult: Option, pub repetition_penalty: Option<(usize, f32)>, pub xa_start: Option, pub text_temperature_gating_influence: Option, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Config { pub instance_name: String, #[serde(default)] pub hf_repo: String, pub lm_model_file: String, pub log_dir: String, pub text_tokenizer_file: String, pub mimi_model_file: String, pub mimi_num_codebooks: usize, pub lm_config: Option, #[serde(default = "default_false")] pub use_cpu_for_mimi: bool, pub asr_delay_in_tokens: Option, // optional config options for image conditioning pub image_prefix_backbone: Option, #[serde(default = "default_false")] pub image_prefix_use_rms_norm: bool, pub cross_attention_gating: Option, pub cross_attention_in_dims: Option, pub images_dir: Option, pub force_session_config: Option, } fn default_false() -> bool { false } impl Config { /// Check if all modelling files are available on machine. 
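    /// Returns `true` as soon as any of the LM, Mimi, or text-tokenizer paths
    /// from the config does not exist locally, in which case the caller is
    /// expected to fetch the files via `download_from_hub`.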
pub fn requires_model_download(&self) -> bool { [ &self.lm_model_file, &self.mimi_model_file, &self.text_tokenizer_file, ] .iter() .any(|file| !std::path::Path::new(file).exists()) } } pub type AppState = Arc; pub struct AppStateInner { pub lm_model: moshi::lm::LmModel, pub mimi_model: moshi::mimi::Mimi, pub text_tokenizer: sentencepiece::SentencePieceProcessor, pub device: candle::Device, pub config: Config, pub image_embedder: Option, } impl AppStateInner { fn text( &self, prev_text_token: u32, text_token: u32, config: &moshi::lm_generate_multistream::Config, ) -> Option { if text_token != config.text_start_token && text_token != config.text_pad_token && text_token != config.text_eop_token { if prev_text_token == config.text_start_token { self.text_tokenizer.decode_piece_ids(&[text_token]).ok() } else { let prev_ids = self .text_tokenizer .decode_piece_ids(&[prev_text_token]) .ok(); let ids = self .text_tokenizer .decode_piece_ids(&[prev_text_token, text_token]) .ok(); prev_ids.and_then(|prev_ids| { ids.map(|ids| { if ids.len() > prev_ids.len() { ids[prev_ids.len()..].to_string() } else { String::new() } }) }) } } else { None } } } #[derive(serde::Deserialize, Debug, Clone)] pub struct SessionConfigReq { pub text_temperature: Option, pub text_topk: Option, pub audio_temperature: Option, pub audio_topk: Option, pub max_steps: Option, pub audio_seed: Option, pub text_seed: Option, pub email: Option, pub pad_mult: Option, pub repetition_penalty_context: Option, pub repetition_penalty: Option, pub image_resolution: Option, pub center_crop: Option, pub xa_start: Option, pub text_temperature_gating_influence: Option, } #[derive(serde::Serialize, Debug, Clone)] pub struct SessionConfig { pub text_temperature: f64, pub text_topk: usize, pub audio_temperature: f64, pub audio_topk: usize, pub max_steps: usize, pub audio_seed: u64, pub text_seed: u64, pub pad_mult: Option, pub repetition_penalty: Option<(usize, f32)>, pub image_resolution: Option, pub center_crop: Option, pub xa_start: usize, pub text_temperature_gating_influence: f32, pub email: Option, pub user_feedback: Option, } #[derive(serde::Serialize, Debug, Clone)] struct SessionSummary<'a> { #[serde(flatten)] session_config: &'a SessionConfig, last_step_idx: usize, transcript: String, addr: Option, lm_model_file: &'a str, mimi_model_file: &'a str, #[serde(flatten)] lm_config: &'a Option, } impl SessionConfigReq { fn into_session_config(self, force_cfg: Option<&ForceSessionConfig>) -> SessionConfig { use rand::Rng; let text_temperature = match force_cfg { None => self.text_temperature.unwrap_or(0.8), Some(v) => v.text_temperature, }; let audio_temperature = match force_cfg { None => self.audio_temperature.unwrap_or(0.8), Some(v) => v.audio_temperature, }; let text_topk = match force_cfg { None => self.text_topk.unwrap_or(250), Some(v) => v.text_topk, }; let audio_topk = match force_cfg { None => self.audio_topk.unwrap_or(250), Some(v) => v.audio_topk, }; let pad_mult = match force_cfg { None => self.pad_mult, Some(v) => v.pad_mult, }; let repetition_penalty = match force_cfg { None => self.repetition_penalty_context.zip(self.repetition_penalty), Some(v) => v.repetition_penalty, }; let xa_start = match force_cfg { None => self.xa_start.unwrap_or(0), Some(v) => v.xa_start.unwrap_or(0), }; let text_temperature_gating_influence = match force_cfg { None => self.text_temperature_gating_influence.unwrap_or(0.0), Some(v) => v.text_temperature_gating_influence.unwrap_or(0.0), }; SessionConfig { text_temperature, text_topk, text_seed: 
self.text_seed.unwrap_or_else(|| rand::thread_rng().gen()), audio_temperature, audio_topk, audio_seed: self.audio_seed.unwrap_or_else(|| rand::thread_rng().gen()), email: self.email, user_feedback: None, max_steps: self.max_steps.unwrap_or(4500).min(4500), pad_mult, repetition_penalty, image_resolution: self.image_resolution, center_crop: self.center_crop, xa_start, text_temperature_gating_influence, } } } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] pub struct MetaData { text_temperature: f64, text_topk: usize, text_temperature_gating_influence: f64, audio_temperature: f64, audio_topk: usize, pad_mult: f32, repetition_penalty_context: usize, repetition_penalty: f32, image_resolution: usize, lm_model_file: String, mimi_model_file: String, build_info: crate::utils::BuildInfo, instance_name: String, base_filename: String, } #[derive(Debug, Clone)] pub enum StreamOut { Ready, MetaData { metadata: Box }, Pcm { pcm: Vec }, ColoredText { text: String, intensity: f32 }, } // This must be an allowed value among 120, 240, 480, 960, 1920, and 2880. // Using a different value would result in a BadArg "invalid argument" error when calling encode. // https://opus-codec.org/docs/opus_api-1.2/group__opus__encoder.html#ga4ae9905859cd241ef4bb5c59cd5e5309 const OPUS_ENCODER_FRAME_SIZE: usize = 960; #[derive(Debug, Clone, Copy)] pub enum MsgType { Handshake, Audio, Text, Control, Metadata, Error, Ping, ColoredText, Image, UserRating, } impl MsgType { pub fn from_u8(v: u8) -> Result { let s = match v { 0 => MsgType::Handshake, 1 => MsgType::Audio, 2 => MsgType::Text, 3 => MsgType::Control, 4 => MsgType::Metadata, 5 => MsgType::Error, 6 => MsgType::Ping, 7 => MsgType::ColoredText, 8 => MsgType::Image, 10 => MsgType::UserRating, _ => anyhow::bail!("unexpected msg type {v}"), }; Ok(s) } pub fn to_u8(self) -> u8 { match self { MsgType::Handshake => 0, MsgType::Audio => 1, MsgType::Text => 2, MsgType::Control => 3, MsgType::Metadata => 4, MsgType::Error => 5, MsgType::Ping => 6, MsgType::ColoredText => 7, MsgType::Image => 8, MsgType::UserRating => 10, } } } pub enum ModelInput { Image(Vec), Audio(Vec), Rating(u32), } pub struct MsgSender { pw: ogg::PacketWriter<'static, Vec>, encoder: opus::Encoder, out_pcm: std::collections::VecDeque, out_pcm_buf: Vec, total_data: usize, sender: SplitSink, last_input_pcm: Option, } impl MsgSender { fn new(sender: SplitSink) -> Result { let encoder = opus::Encoder::new(24000, opus::Channels::Mono, opus::Application::Voip)?; // Not sure what the appropriate buffer size would be here. let out_pcm_buf = vec![0u8; 50_000]; let out_pcm = std::collections::VecDeque::with_capacity(2 * OPUS_ENCODER_FRAME_SIZE); let all_data = Vec::new(); let mut pw = ogg::PacketWriter::new(all_data); let mut head = Vec::new(); crate::audio::write_opus_header(&mut head)?; pw.write_packet(head, 42, ogg::PacketWriteEndInfo::EndPage, 0)?; let mut tags = Vec::new(); crate::audio::write_opus_tags(&mut tags)?; pw.write_packet(tags, 42, ogg::PacketWriteEndInfo::EndPage, 0)?; Ok(Self { pw, encoder, out_pcm, out_pcm_buf, total_data: 0, sender, last_input_pcm: None, }) } async fn send_colored_text(&mut self, text: String, intensity: f32) -> Result<()> { let int_intensity = (((intensity - 0.005) / 0.016).clamp(0., 1.) * 10.).round() as u8; let msg: Vec = [ &[MsgType::ColoredText.to_u8()], &[int_intensity], text.as_bytes(), ] .concat(); let msg = ws::Message::Binary(msg.into()); self.sender.send(msg).await?; Ok(()) } async fn send_ready(&mut self) -> Result<()> { // The payload is made of two fields. 
// 1. Protocol version (`u32`) - always 0 for now. // 2. Model version (`u32`). let msg: Vec = [&[MsgType::Handshake.to_u8()], [0u8; 8].as_slice()].concat(); let msg = ws::Message::Binary(msg.into()); self.sender.send(msg).await?; Ok(()) } async fn send_metadata(&mut self, md: Box) -> Result<()> { let bytes = serde_json::to_vec(&md)?; let msg: Vec = [&[MsgType::Metadata.to_u8()], bytes.as_slice()].concat(); let msg = ws::Message::Binary(msg.into()); self.sender.send(msg).await?; Ok(()) } async fn send_pcm(&mut self, pcm: Vec) -> Result<()> { self.out_pcm.extend(pcm.iter()); self.total_data += pcm.len(); let nchunks = self.out_pcm.len() / OPUS_ENCODER_FRAME_SIZE; for _chunk_id in 0..nchunks { let mut chunk = Vec::with_capacity(OPUS_ENCODER_FRAME_SIZE); for _i in 0..OPUS_ENCODER_FRAME_SIZE { let v = match self.out_pcm.pop_front() { None => anyhow::bail!("unexpected err popping from pcms"), Some(v) => v, }; chunk.push(v) } let size = self.encoder.encode_float(&chunk, &mut self.out_pcm_buf)?; if size > 0 { let msg = self.out_pcm_buf[..size].to_vec(); self.pw.write_packet( msg, 42, ogg::PacketWriteEndInfo::EndPage, self.total_data as u64, )? } else { tracing::error!("OPUS SIZE 0") } let data = self.pw.inner_mut(); if !data.is_empty() { let msg: Vec = [&[MsgType::Audio.to_u8()], data.as_slice()].concat(); let msg = ws::Message::Binary(msg.into()); self.sender.send(msg).await?; self.sender.flush().await?; data.clear(); } else { tracing::error!("OGG SIZE 0") } } Ok(()) } } pub struct StreamingModel { state: AppState, device: candle::Device, config: moshi::lm_generate_multistream::Config, session_config: SessionConfig, } impl StreamingModel { fn run_with_state( &self, state: &mut moshi::lm_generate_multistream::State, receiver: std::sync::mpsc::Receiver, sender: tokio::sync::mpsc::UnboundedSender, ) -> Result<()> { use candle::IndexOp; let app_state = &self.state; let mut mimi = app_state.mimi_model.clone(); let config = state.config().clone(); mimi.reset_state(); tracing::info!("processing loop"); let mut prev_text_token = config.text_start_token; let mut tensor_tokens = vec![]; let mimi_device = if self.state.config.use_cpu_for_mimi { &candle::Device::Cpu } else { &self.device }; let ca_src = match &app_state.image_embedder { Some(image_embedder) => { // The first message is the image to use as input for the model. 
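                // Binary websocket frames are tagged with a one-byte `MsgType`
                // followed by the raw payload (see `MsgType::from_u8` and
                // `spawn_recv_loops`). A client therefore frames the image
                // roughly as follows (sketch only, `img_bytes` being an
                // encoded JPEG/PNG and `socket` the client-side websocket):
                //
                //     let mut frame = vec![MsgType::Image.to_u8()]; // tag 8
                //     frame.extend_from_slice(&img_bytes);
                //     socket.send(ws::Message::Binary(frame.into())).await?;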
let image_bytes = match receiver.recv() { Ok(ModelInput::Image(image)) => image, _ => anyhow::bail!("Expected image as first message, we are in vision mode."), }; tracing::info!("image received"); let resolution = self.session_config.image_resolution.unwrap_or(224); let center_crop = self.session_config.center_crop.unwrap_or(false); let ca_src = image_embedder.embed(&image_bytes, resolution, center_crop, mimi_device)?; let ca_src = self.state.lm_model.maybe_precompute_ca_kv(Some(ca_src))?; match &ca_src { Some(ca_src) => match ca_src { CaSrc::Tokens(tensor) => { tracing::info!(shape=?tensor.shape(), "image processed") } CaSrc::KeysValues((keys, values)) => { tracing::info!(keys_shape=?keys.shape(), values_shape=?values.shape(), "image processed") } }, None => anyhow::bail!("ca_src should never be None here."), }; ca_src } None => None, }; mimi_device.synchronize()?; sender.send(StreamOut::Ready)?; let mut counter = 0; while let Ok(in_pcm) = receiver.recv() { match in_pcm { ModelInput::Audio(in_pcm) => { if in_pcm.is_empty() { continue; } let pcm_len = in_pcm.len(); let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), mimi_device)?; let audio_tokens = mimi.encode_step(&pcms.into())?; let audio_tokens = match audio_tokens.as_option() { None => continue, Some(audio_tokens) => audio_tokens, }; let (_one, _codebooks, steps) = audio_tokens.dims3()?; for step in 0..steps { let codes = audio_tokens.i((0, .., step))?.to_vec1::()?; //let text_token = state.step(prev_text_token, &codes, None, ca_src.as_ref())?; let (text_token, gate_weight) = state.step_with_gate_weight( prev_text_token, &codes, None, if counter >= self.session_config.xa_start { ca_src.as_ref() } else { None }, )?; if let Some(audio_tokens) = state.last_audio_tokens() { let audio_tokens = { let cb = app_state.config.mimi_num_codebooks; candle::Tensor::from_slice( &audio_tokens[..cb], (1, cb, 1), mimi_device, )? 
}; tensor_tokens.push(audio_tokens.clone()); let pcm = mimi.decode_step(&audio_tokens.into())?; if let Some(pcm) = pcm.as_option() { let pcm = pcm.i((0, 0))?.to_vec1::()?; sender.send(StreamOut::Pcm { pcm })?; } } if let Some(text) = app_state.text(prev_text_token, text_token, &config) { //sender.send(StreamOut::Text { text })?; sender.send(StreamOut::ColoredText { text, intensity: gate_weight, })?; } prev_text_token = text_token; } counter += 1; } ModelInput::Image(_) => { // See PR#105 https://github.com/0x53504852/moshi-rs/pull/105 // Due to the front-end async mode, we may have duplicate identical // image messages which we should ignore while waiting for the first // audio message to arrive tracing::error!("Cannot handle new images during a conversation yet; skipping"); } ModelInput::Rating(grade) => state.set_user_rating(grade), } } tracing::info!("finished the processing loop"); Ok(()) } pub fn new(state: &AppState, session_config: SessionConfigReq) -> Self { let config = match state.config.lm_config.as_ref() { None => moshi::lm_generate_multistream::Config::v0_1(), Some(config) => config.clone(), }; let session_config = session_config.into_session_config(state.config.force_session_config.as_ref()); Self { state: state.clone(), device: state.device.clone(), config, session_config, } } pub fn run( &self, receiver: std::sync::mpsc::Receiver, sender: tokio::sync::mpsc::UnboundedSender, addr: Option, ) -> Result<()> { let app_state = &self.state; let (repetition_penalty_context, repetition_penalty) = self.session_config.repetition_penalty.unwrap_or((32, 1.)); // base path for logging let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?; let (secs, us) = (since_epoch.as_secs(), since_epoch.subsec_micros()); let base_filename = format!("{}-{secs}-{us}", app_state.config.instance_name); // metadata that will be stored over let metadata = MetaData { text_temperature: self.session_config.text_temperature, text_topk: self.session_config.text_topk, text_temperature_gating_influence: self.session_config.text_temperature_gating_influence as f64, audio_temperature: self.session_config.audio_temperature, audio_topk: self.session_config.audio_topk, pad_mult: self.session_config.pad_mult.unwrap_or(0.), repetition_penalty, image_resolution: self.session_config.image_resolution.unwrap_or(0), repetition_penalty_context, lm_model_file: self.state.config.lm_model_file.to_string(), mimi_model_file: self.state.config.mimi_model_file.to_string(), build_info: crate::utils::BuildInfo::new(), instance_name: self.state.config.instance_name.to_string(), base_filename: base_filename.clone(), }; sender.send(StreamOut::MetaData { metadata: Box::new(metadata), })?; let lm_model = app_state.lm_model.clone(); // Load ca_src either from the --img command line (static image), or // from the vision-UI's websocket url (contained in the StreamingModel) let audio_lp = candle_transformers::generation::LogitsProcessor::from_sampling( self.session_config.audio_seed, candle_transformers::generation::Sampling::TopK { k: self.session_config.audio_topk, temperature: self.session_config.audio_temperature, }, ); let text_lp = GateInfluencedLogitsProcessor::from_sampling_with_scale( self.session_config.text_seed, candle_transformers::generation::Sampling::TopK { k: self.session_config.text_topk, temperature: self.session_config.text_temperature, }, self.session_config.text_temperature_gating_influence as f64, ); let mut state = moshi::lm_generate_multistream::State::new( lm_model, 
self.session_config.max_steps, audio_lp, text_lp, self.session_config.pad_mult, self.session_config.repetition_penalty, self.config.clone(), ); // We want to log the output even if the run function returns an error. let run_result = self.run_with_state(&mut state, receiver, sender); let log_result = (|| { let text_tokens = state.text_tokens(false); let transcript = { let text_tokens = text_tokens .iter() .filter_map(|v| { let v = *v; if v != moshi::lm_generate_multistream::UNGENERATED && v != self.config.text_pad_token && v != self.config.text_eop_token && v != self.config.text_start_token { Some(v) } else { None } }) .collect::>(); self.state .text_tokenizer .decode_piece_ids(&text_tokens) .unwrap_or_else(|_| String::new()) }; let gate_weights = state.gate_weights(false); let audio_tokens = state.audio_tokens(false); let audio_tokens = audio_tokens .iter() .map(|v| { v.iter() .map(|v| { if *v == moshi::lm_generate_multistream::UNGENERATED { -1 } else { *v as i64 } }) .collect::>() }) .collect::>(); let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)? .to_dtype(candle::DType::I64)?; let gate_weights = candle::Tensor::new(gate_weights, &candle::Device::Cpu)? .to_dtype(candle::DType::F32)?; let audio_tokens = candle::Tensor::new(audio_tokens, &candle::Device::Cpu)?; let user_rating = candle::Tensor::full(state.user_rating(), (1,), &candle::Device::Cpu)?; let log_dir = &app_state.config.log_dir; let base_path = format!("{}/{}", log_dir, base_filename); let json_filename = format!("{base_path}.json"); let json_content = serde_json::to_string_pretty(&SessionSummary { session_config: &self.session_config, last_step_idx: state.step_idx(), transcript, addr, mimi_model_file: &self.state.config.mimi_model_file, lm_model_file: &self.state.config.lm_model_file, lm_config: &self.state.config.lm_config, })?; std::fs::write(json_filename, json_content)?; let st_filename = format!("{base_path}.safetensors"); let st_content = std::collections::HashMap::from([ ("text", text_tokens), ("gate_weights", gate_weights), ("audio", audio_tokens), ("user_rating", user_rating), ]); candle::safetensors::save(&st_content, st_filename)?; Ok(()) })(); run_result.and(log_result) } } type Handle = tokio::task::JoinHandle>; fn spawn_recv_loops( mut receiver: SplitStream, sender: std::sync::mpsc::Sender, ) -> Result<(Handle, Handle)> { use tokio::io::AsyncWriteExt; let (mut tx, rx) = tokio::io::duplex(100_000); let mut pr = ogg::reading::async_api::PacketReader::new(rx); let mut decoder = opus::Decoder::new(24000, opus::Channels::Mono)?; let handle1_sender = sender.clone(); let handle1 = tokio::spawn({ async move { loop { match receiver.next().await { None => { // The close logic is that if this loop exits, then tx gets dropped so pr // gets closed and the second thread gets dropped resulting in sender // getting dropped. break; } Some(v) => { let v = v?.into_data(); if v.is_empty() { continue; } let msg_type = MsgType::from_u8(v[0])?; match msg_type { MsgType::Metadata => {} MsgType::Handshake => {} MsgType::Control => {} MsgType::Text => {} MsgType::Error => {} MsgType::Ping => {} MsgType::Audio => tx.write_all(&v[1..]).await?, MsgType::ColoredText => {} MsgType::Image => { handle1_sender.send(ModelInput::Image(v[1..].to_vec()))? } MsgType::UserRating => { handle1_sender.send(ModelInput::Rating(v[1] as u32))? } } } } } tracing::info!("socket closed"); Ok::<_, anyhow::Error>(()) } }); let handle2 = tokio::spawn(async move { // TODO: dynamic sizing? 
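            // `pcm_buf` holds up to 10 s of 24 kHz mono PCM; decoded Opus
            // packets accumulate in it and are forwarded to the model in chunks
            // of 24_000 / 25 = 960 samples (40 ms), see the flush condition
            // further down.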
let mut pcm_buf = vec![0f32; 24_000 * 10]; let mut size_in_buf = 0; loop { match pr.next().await { None => { break; } Some(packet) => { let packet = packet?; if packet.data.starts_with(b"OpusHead") || packet.data.starts_with(b"OpusTags") { continue; } let read_size = decoder.decode_float( &packet.data, &mut pcm_buf[size_in_buf..], /* Forward Error Correction */ false, )?; size_in_buf += read_size; // flush the data every half timestep if size_in_buf >= 24_000 / 25 { if sender .send(ModelInput::Audio(pcm_buf[..size_in_buf].to_vec())) .is_err() { break; } size_in_buf = 0; } } } } tracing::info!("decoder closed"); Ok::<_, anyhow::Error>(()) }); Ok((handle1, handle2)) } async fn sender_loop( mut stream_out_rx: tokio::sync::mpsc::UnboundedReceiver, mut sender: MsgSender, ) -> Result<()> { // It is important for the recv here to be an async enabled one. Otherwise this could lead // to some weird deadlocks. while let Some(v) = stream_out_rx.recv().await { match v { StreamOut::Pcm { pcm } => { if let Some(last_input_pcm) = sender.last_input_pcm { let model_step_duration = last_input_pcm.elapsed().as_secs_f64(); crate::metrics::worker::MODEL_STEP_DURATION.observe(model_step_duration); sender.last_input_pcm = None; } sender.send_pcm(pcm).await? } StreamOut::Ready => sender.send_ready().await?, StreamOut::MetaData { metadata } => sender.send_metadata(metadata).await?, StreamOut::ColoredText { text, intensity } => { sender.send_colored_text(text, intensity).await? } } } Ok::<_, anyhow::Error>(()) } pub async fn handle_socket( socket: ws::WebSocket, sm: StreamingModel, addr: Option, ) -> Result<()> { tracing::info!("accepted websocket connection"); let (sender, receiver) = socket.split(); let sender = MsgSender::new(sender)?; tracing::info!("starting streaming"); let (in_pcm_tx, in_pcm_rx) = std::sync::mpsc::channel(); let (stream_out_tx, stream_out_rx) = tokio::sync::mpsc::unbounded_channel(); let (loop1, loop2) = spawn_recv_loops(receiver, in_pcm_tx)?; std::thread::spawn(move || { if let Err(err) = sm.run(in_pcm_rx, stream_out_tx, addr) { tracing::error!("{err}") } }); let sender_loop = tokio::spawn(async move { match sender_loop(stream_out_rx, sender).await { Ok(()) => tracing::info!("sender closed"), Err(err) => { // Using the Display trait rather than the Debug one so as not to include the backtrace. let err = format!("{err}"); tracing::info!(err, "sender err") } } }); let sleep = tokio::time::sleep(std::time::Duration::from_secs(360)); tokio::pin!(sleep); // select should ensure that all the threads get aborted on timeout. tokio::select! 
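    // Whichever branch finishes first wins: the 360 s session timeout, either
    // receive loop ending (socket closed / decoder drained), or the sender loop
    // ending; the remaining tasks are expected to wind down once their channels
    // and sockets are dropped.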
{ _ = &mut sleep => { tracing::error!("reached timeout"); } r = loop1 => { tracing::error!(?r, "loop1 ended") } r = loop2 => { tracing::error!(?r, "loop2 ended") } r = sender_loop => { tracing::error!(?r, "sender loop ended") } } Ok(()) } ================================================ FILE: kyuteye_rs/moshi-backend/src/utils.rs ================================================ #[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)] pub struct BuildInfo { build_timestamp: String, build_date: String, git_branch: String, git_timestamp: String, git_date: String, git_hash: String, git_describe: String, rustc_host_triple: String, rustc_version: String, cargo_target_triple: String, } impl BuildInfo { pub fn new() -> BuildInfo { BuildInfo { build_timestamp: String::from(env!("VERGEN_BUILD_TIMESTAMP")), build_date: String::from(env!("VERGEN_BUILD_DATE")), git_branch: String::from(env!("VERGEN_GIT_BRANCH")), git_timestamp: String::from(env!("VERGEN_GIT_COMMIT_TIMESTAMP")), git_date: String::from(env!("VERGEN_GIT_COMMIT_DATE")), git_hash: String::from(env!("VERGEN_GIT_SHA")), git_describe: String::from(env!("VERGEN_GIT_DESCRIBE")), rustc_host_triple: String::from(env!("VERGEN_RUSTC_HOST_TRIPLE")), rustc_version: String::from(env!("VERGEN_RUSTC_SEMVER")), cargo_target_triple: String::from(env!("VERGEN_CARGO_TARGET_TRIPLE")), } } } pub struct WrapJson(pub anyhow::Result); impl axum::response::IntoResponse for WrapJson { fn into_response(self) -> axum::response::Response { match self.0 { Ok(v) => axum::Json(v).into_response(), Err(err) => { tracing::error!(?err, "returning internal server error 500"); ( axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err}"), ) .into_response() } } } } pub fn replace_env_vars(input: &str) -> String { let re = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").unwrap(); re.replace_all(input, |caps: ®ex::Captures| { let var_name = &caps[1]; std::env::var(var_name).unwrap_or_else(|_| "".to_string()) }) .to_string() } pub struct WrapBincode(pub anyhow::Result); impl axum::response::IntoResponse for WrapBincode { fn into_response(self) -> axum::response::Response { match self.0.and_then(|v| Ok(bincode::serialize(&v)?)) { Ok(v) => (axum::http::StatusCode::OK, v).into_response(), Err(err) => { tracing::error!(?err, "returning internal server error 500"); ( axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err}"), ) .into_response() } } } } pub fn default_static_dir() -> String { "./client/dist".to_string() } pub struct AxumError(anyhow::Error); impl axum::response::IntoResponse for AxumError { fn into_response(self) -> axum::response::Response { let err = self.0; tracing::error!(?err); ( axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:?}"), ) .into_response() } } impl> From for AxumError { fn from(value: E) -> Self { Self(value.into()) } } pub type AxumResult = std::result::Result; ================================================ FILE: kyuteye_rs/moshi-core/Cargo.toml ================================================ [package] name = "moshi" version = "0.1.0" edition = "2021" [dependencies] candle = { workspace = true } candle-nn = { workspace = true } candle-transformers = { workspace = true } candle-flash-attn = { workspace = true, optional = true } cudarc = { version = "=0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false, optional = true } rand = { workspace = true } rayon = "1.8.1" serde = { version = "1.0", features = 
["derive"] } tracing = "0.1.40" [features] default = [] cuda = ["candle/cuda", "candle-nn/cuda", "cudarc"] metal = ["candle/metal", "candle-nn/metal"] flash-attn = ["cuda", "dep:candle-flash-attn"] [profile.release] debug = true [profile.release-no-debug] inherits = "release" debug = false ================================================ FILE: kyuteye_rs/moshi-core/src/conv.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. use crate::streaming::{StreamTensor, StreamingModule}; use candle::{Module, Result, Tensor, D}; use candle_nn::{Conv1d, VarBuilder}; #[allow(clippy::enum_variant_names)] #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Norm { WeightNorm, SpectralNorm, TimeGroupNorm, } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum PadMode { Constant, Reflect, Replicate, } // Applies weight norm for inference by recomputing the weight tensor. This // does not apply to training. // https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html fn conv1d_weight_norm( in_c: usize, out_c: usize, kernel_size: usize, bias: bool, config: candle_nn::Conv1dConfig, vb: VarBuilder, ) -> Result { let weight = if vb.contains_tensor("weight") { vb.get((out_c, in_c, kernel_size), "weight")? } else { let weight_g = vb.get((out_c, 1, 1), "weight_g")?; let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?; let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)? }; let bias = if bias { Some(vb.get(out_c, "bias")?) } else { None }; Ok(Conv1d::new(weight, bias, config)) } #[derive(Debug, Clone)] pub struct NormConv1d { conv: Conv1d, norm: Option, span: tracing::Span, } impl NormConv1d { #[allow(clippy::too_many_arguments)] pub fn new( in_c: usize, out_c: usize, k_size: usize, causal: bool, norm: Option, bias: bool, cfg: candle_nn::Conv1dConfig, vb: VarBuilder, ) -> Result { let conv = match norm { None | Some(Norm::TimeGroupNorm) => { if bias { candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))? } else { candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))? } } Some(Norm::WeightNorm) => { conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))? } Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."), }; let norm = match norm { None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None, Some(Norm::TimeGroupNorm) => { if causal { candle::bail!("GroupNorm doesn't support causal evaluation.") } let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; Some(norm) } }; Ok(Self { conv, norm, span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"), }) } } impl Module for NormConv1d { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let xs = xs.apply(&self.conv)?; match self.norm.as_ref() { None => Ok(xs), Some(norm) => xs.apply(norm), } } } #[derive(Debug, Clone)] pub struct NormConvTranspose1d { ws: Tensor, bs: Option, k_size: usize, stride: usize, groups: usize, norm: Option, span: tracing::Span, } impl NormConvTranspose1d { #[allow(clippy::too_many_arguments)] pub fn new( in_c: usize, out_c: usize, k_size: usize, causal: bool, norm: Option, bias: bool, stride: usize, groups: usize, vb: VarBuilder, ) -> Result { let vb = vb.pp("convtr"); let bs = if bias { Some(vb.get(out_c, "bias")?) 
} else { None }; let ws = match norm { None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?, Some(Norm::WeightNorm) => { if vb.contains_tensor("weight") { vb.get((in_c, out_c, k_size), "weight")? } else { let weight_g = vb.get((in_c, 1, 1), "weight_g")?; let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?; let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)? } } Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."), }; let (ws, groups) = if groups == out_c && in_c == out_c { let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?; let ws = ws .repeat((1, out_c, 1))? .mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?; (ws, 1) } else { (ws, groups) }; let norm = match norm { None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None, Some(Norm::TimeGroupNorm) => { if causal { candle::bail!("GroupNorm doesn't support causal evaluation.") } let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; Some(norm) } }; Ok(Self { ws, bs, k_size, stride, groups, norm, span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"), }) } } impl Module for NormConvTranspose1d { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); // conv-transpose1d seems to be broken on metal after enough iterations. Causing // the following error: // _status < MTLCommandBufferStatusCommitted > // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:] // This is now fixed in candle. let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?; let xs = match &self.bs { None => xs, Some(bias) => { let b = bias.dims1()?; let bias = bias.reshape((1, b, 1))?; xs.broadcast_add(&bias)? } }; match self.norm.as_ref() { None => Ok(xs), Some(norm) => xs.apply(norm), } } } fn get_extra_padding_for_conv1d( xs: &Tensor, k_size: usize, stride: usize, padding_total: usize, ) -> Result { let len = xs.dim(D::Minus1)?; let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0; let ideal_len = ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total); Ok(ideal_len.saturating_sub(len)) } fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result { match mode { PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r), PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"), PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r), } } fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result { let len = xs.dim(D::Minus1)?; if len < unpad_l + unpad_r { candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}") } xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r)) } #[derive(Debug, Clone)] pub struct StreamableConv1d { conv: NormConv1d, causal: bool, pad_mode: PadMode, state_prev_xs: StreamTensor, left_pad_applied: bool, kernel_size: usize, span: tracing::Span, } impl StreamableConv1d { #[allow(clippy::too_many_arguments)] pub fn new( in_c: usize, out_c: usize, k_size: usize, stride: usize, dilation: usize, groups: usize, bias: bool, causal: bool, norm: Option, pad_mode: PadMode, vb: VarBuilder, ) -> Result { let cfg = candle_nn::Conv1dConfig { padding: 0, stride, dilation, groups, }; let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb.pp("conv"))?; if k_size < stride { candle::bail!("kernel-size {k_size} is smaller than stride {stride}") } Ok(Self { conv, causal, pad_mode, state_prev_xs: 
StreamTensor::empty(), left_pad_applied: false, kernel_size: k_size, span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"), }) } } impl Module for StreamableConv1d { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let (_b, _t, _c) = xs.dims3()?; let k_size = self.conv.conv.weight().dim(D::Minus1)?; let conv_cfg = self.conv.conv.config(); // Effective kernel size with dilations. let k_size = (k_size - 1) * conv_cfg.dilation + 1; let padding_total = k_size - conv_cfg.stride; let extra_padding = get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?; let xs = if self.causal { pad1d(xs, padding_total, extra_padding, self.pad_mode)? } else { let padding_right = padding_total / 2; let padding_left = padding_total - padding_right; pad1d( xs, padding_left, padding_right + extra_padding, self.pad_mode, )? }; xs.apply(&self.conv) } } impl StreamingModule for StreamableConv1d { fn reset_state(&mut self) { self.state_prev_xs.reset(); self.left_pad_applied = false; } fn step(&mut self, xs: &StreamTensor) -> Result { let _enter = self.span.enter(); let xs = match xs.as_option() { None => return Ok(().into()), Some(xs) => xs.clone(), }; let xs = if self.left_pad_applied { xs } else { self.left_pad_applied = true; let k_size = self.conv.conv.weight().dim(D::Minus1)?; let conv_cfg = self.conv.conv.config(); let k_size = (k_size - 1) * conv_cfg.dilation + 1; let padding_total = k_size - conv_cfg.stride; pad1d(&xs, padding_total, 0, self.pad_mode)? }; let cfg = self.conv.conv.config(); let stride = cfg.stride; let dilation = cfg.dilation; let kernel = (self.kernel_size - 1) * dilation + 1; let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?; let seq_len = xs.seq_len(D::Minus1)?; let num_frames = (seq_len + stride).saturating_sub(kernel) / stride; if num_frames > 0 { let offset = num_frames * stride; self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?; let in_l = (num_frames - 1) * stride + kernel; let xs = xs.narrow(D::Minus1, 0, in_l)?; // We apply the underlying convtr directly rather than through forward so as // not to apply any padding here. xs.apply(&self.conv.conv) } else { self.state_prev_xs = xs; Ok(StreamTensor::empty()) } } } #[derive(Debug, Clone)] pub struct StreamableConvTranspose1d { convtr: NormConvTranspose1d, causal: bool, state_prev_ys: StreamTensor, kernel_size: usize, span: tracing::Span, } impl StreamableConvTranspose1d { #[allow(clippy::too_many_arguments)] pub fn new( in_c: usize, out_c: usize, k_size: usize, stride: usize, groups: usize, bias: bool, causal: bool, norm: Option, vb: VarBuilder, ) -> Result { let convtr = NormConvTranspose1d::new( in_c, out_c, k_size, causal, norm, bias, stride, groups, vb.pp("convtr"), )?; Ok(Self { convtr, causal, kernel_size: k_size, state_prev_ys: StreamTensor::empty(), span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"), }) } } impl Module for StreamableConvTranspose1d { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let k_size = self.convtr.k_size; let stride = self.convtr.stride; let padding_total = k_size.saturating_sub(stride); let xs = xs.apply(&self.convtr)?; if self.causal { // This corresponds to trim_right_ratio = 1. 
unpad1d(&xs, 0, padding_total) } else { let padding_right = padding_total / 2; let padding_left = padding_total - padding_right; unpad1d(&xs, padding_left, padding_right) } } } impl StreamingModule for StreamableConvTranspose1d { fn reset_state(&mut self) { self.state_prev_ys.reset() } fn step(&mut self, xs: &StreamTensor) -> Result { let _enter = self.span.enter(); let xs = match xs.as_option() { Some(xs) => xs, None => return Ok(StreamTensor::empty()), }; let stride = self.convtr.stride; // We apply the underlying convtr directly rather than through forward so as // not to apply any padding here. let ys = self.convtr.forward(xs)?; let ot = ys.dim(D::Minus1)?; let ys = match self.state_prev_ys.as_option() { None => ys, Some(prev_ys) => { let pt = prev_ys.dim(D::Minus1)?; // Remove the bias as it will be applied multiple times. let prev_ys = match &self.convtr.bs { None => prev_ys.clone(), Some(bias) => { let bias = bias.reshape((1, (), 1))?; prev_ys.broadcast_sub(&bias)? } }; let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?; let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?; Tensor::cat(&[ys1, ys2], D::Minus1)? } }; let invalid_steps = self.kernel_size - stride; let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?; self.state_prev_ys = prev_ys; Ok(ys) } } #[derive(Debug, Clone)] pub struct ConvDownsample1d { conv: StreamableConv1d, } impl ConvDownsample1d { pub fn new( stride: usize, dim: usize, causal: bool, learnt: bool, vb: VarBuilder, ) -> Result { if !learnt { candle::bail!("only learnt=true is supported") } let conv = StreamableConv1d::new( /* in_c */ dim, /* out_c */ dim, /* k_size_c */ 2 * stride, /* stride */ stride, /* dilation */ 1, /* groups */ 1, // channel_wise = false /* bias */ false, /* causal */ causal, /* norm */ None, /* pad_mode */ PadMode::Replicate, vb.pp("conv"), )?; Ok(Self { conv }) } } impl Module for ConvDownsample1d { fn forward(&self, xs: &Tensor) -> Result { xs.apply(&self.conv) } } impl StreamingModule for ConvDownsample1d { fn reset_state(&mut self) { self.conv.reset_state() } fn step(&mut self, xs: &StreamTensor) -> Result { self.conv.step(xs) } } #[derive(Debug, Clone)] pub struct ConvTrUpsample1d { convtr: StreamableConvTranspose1d, } impl ConvTrUpsample1d { pub fn new( stride: usize, dim: usize, causal: bool, learnt: bool, vb: VarBuilder, ) -> Result { if !learnt { candle::bail!("only learnt=true is supported") } let convtr = StreamableConvTranspose1d::new( dim, dim, /* k_size */ 2 * stride, /* stride */ stride, /* groups */ dim, /* bias */ false, /* causal */ causal, /* norm */ None, vb.pp("convtr"), )?; Ok(Self { convtr }) } } impl Module for ConvTrUpsample1d { fn forward(&self, xs: &Tensor) -> Result { xs.apply(&self.convtr) } } impl StreamingModule for ConvTrUpsample1d { fn reset_state(&mut self) { self.convtr.reset_state() } fn step(&mut self, xs: &StreamTensor) -> Result { self.convtr.step(xs) } } #[cfg(test)] mod tests { use super::*; use candle::IndexOp; fn run_conv1d( k_size: usize, stride: usize, dilation: usize, step_size: usize, len: usize, bias: bool, ) -> Result<()> { // TODO: We should ensure for the seed to be constant when running these tests. 
let dev = &candle::Device::Cpu; let vm = candle_nn::VarMap::new(); let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev); let conv1d = StreamableConv1d::new( /* in_c */ 2, /* out_c */ 3, /* k_size */ k_size, /* stride */ stride, /* dilation */ dilation, /* groups */ 1, /* bias */ bias, /* causal */ true, /* norm */ None, /* pad_mode */ PadMode::Constant, vb, )?; let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?; let ys = conv1d.forward(&xs)?; let mut conv1d = conv1d; let mut ys_steps = vec![]; for idx in 0..len { let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?; let ys = conv1d.step(&xs.into())?; if let Some(ys) = ys.as_option() { ys_steps.push(ys.clone()) } } let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?; let diff = (&ys - &ys_steps)? .abs()? .flatten_all()? .max(0)? .to_vec0::()?; if diff > 1e-5 { println!("{xs}"); println!("{ys}"); println!("{ys_steps}"); candle::bail!("larger diff than expected {diff}") } Ok(()) } fn run_conv_tr1d( k_size: usize, stride: usize, step_size: usize, len: usize, bias: bool, ) -> Result<()> { // TODO: We should ensure for the seed to be constant when running these tests. let dev = &candle::Device::Cpu; let vm = candle_nn::VarMap::new(); let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev); let conv1d = StreamableConvTranspose1d::new( /* in_c */ 2, /* out_c */ 3, /* k_size */ k_size, /* stride */ stride, /* groups */ 1, /* bias */ bias, /* causal */ true, /* norm */ None, vb, )?; let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?; let ys = conv1d.forward(&xs)?; let mut conv1d = conv1d; let mut ys_steps = vec![]; for idx in 0..len { let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?; let ys = conv1d.step(&xs.into())?; if let Some(ys) = ys.as_option() { ys_steps.push(ys.clone()) } } let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?; let diff = (&ys - &ys_steps)? .abs()? .flatten_all()? .max(0)? 
.to_vec0::<f32>()?;
        if diff > 1e-5 {
            println!("{xs}");
            println!("{ys}");
            println!("{ys_steps}");
            candle::bail!("larger diff than expected {diff}")
        }
        Ok(())
    }

    #[test]
    fn conv1d() -> Result<()> {
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                run_conv1d(1, 1, 1, step_size, 5, bias)?;
                run_conv1d(2, 1, 1, step_size, 5, bias)?;
                run_conv1d(2, 2, 1, step_size, 6, bias)?;
                run_conv1d(3, 2, 1, step_size, 8, bias)?;
                run_conv1d(3, 2, 2, step_size, 8, bias)?;
            }
        }
        Ok(())
    }

    #[test]
    fn conv_tr1d() -> Result<()> {
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                run_conv_tr1d(1, 1, step_size, 5, bias)?;
                run_conv_tr1d(2, 1, step_size, 5, bias)?;
                run_conv_tr1d(3, 1, step_size, 5, bias)?;
                run_conv_tr1d(3, 2, step_size, 5, bias)?;
            }
        }
        Ok(())
    }
}


================================================
FILE: kyuteye_rs/moshi-core/src/dynamic_logits_processor.rs
================================================
use candle::{Context, DType, Error, Result, Tensor};
use candle_transformers::generation::Sampling;
use rand::{distributions::Distribution, SeedableRng};

pub struct GateInfluencedLogitsProcessor {
    rng: rand::rngs::StdRng,
    sampling: Sampling,
    text_temperature_gating_influence: f64,
}

impl GateInfluencedLogitsProcessor {
    pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
        let rng = rand::rngs::StdRng::seed_from_u64(seed);
        Self {
            rng,
            sampling,
            text_temperature_gating_influence: 0.0,
        }
    }

    pub fn from_sampling_with_scale(
        seed: u64,
        sampling: Sampling,
        text_temperature_gating_influence: f64,
    ) -> Self {
        let rng = rand::rngs::StdRng::seed_from_u64(seed);
        Self {
            rng,
            sampling,
            text_temperature_gating_influence,
        }
    }

    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
        let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
        let sampling = match temperature {
            None => Sampling::ArgMax,
            Some(temperature) => match top_p {
                None => Sampling::All { temperature },
                Some(p) => Sampling::TopP { p, temperature },
            },
        };
        Self::from_sampling(seed, sampling)
    }

    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
        let logits_v: Vec<f32> = logits.to_vec1()?;
        let next_token = logits_v
            .iter()
            .enumerate()
            .max_by(|(_, u), (_, v)| u.total_cmp(v))
            .map(|(i, _)| i as u32)
            .context("empty logits")?;
        Ok(next_token)
    }

    fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
        let next_token = distr.sample(&mut self.rng) as u32;
        Ok(next_token)
    }

    /// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed
    /// probability top_p. This way we never sample tokens that have very low probabilities and are
    /// less likely to go "off the rails".
    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
        // Sort by descending probability.
        argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
        // Clamp smaller probabilities to zero.
        let mut cumsum = 0.;
        for index in &argsort_indices {
            if cumsum >= top_p {
                prs[*index] = 0.0;
            } else {
                cumsum += prs[*index];
            }
        }
        // Sample with clamped probabilities.
        self.sample_multinomial(prs)
    }

    // top-k sampling samples from the k tokens with the largest probabilities.
    fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
        if top_k >= prs.len() {
            self.sample_multinomial(prs)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
            let index = self.sample_multinomial(&prs)?;
            Ok(indices[index as usize] as u32)
        }
    }

    // top-k sampling samples from the k tokens with the largest probabilities,
    // then top-p sampling.
    fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
        if top_k >= prs.len() {
            self.sample_topp(prs, top_p)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
            let sum_p = prs.iter().sum::<f32>();
            let index = if top_p <= 0.0 || top_p >= sum_p {
                self.sample_multinomial(&prs)?
            } else {
                self.sample_topp(&mut prs, top_p)?
            };
            Ok(indices[index as usize] as u32)
        }
    }

    pub fn sample(&mut self, logits: &Tensor, gate_weight: f64) -> Result<u32> {
        self.sample_f(logits, |_| {}, gate_weight)
    }

    pub fn sample_f(
        &mut self,
        logits: &Tensor,
        f: impl FnOnce(&mut [f32]),
        gate_weight: f64,
    ) -> Result<u32> {
        let logits = logits.to_dtype(DType::F32)?;
        let prs = |temperature: f64| -> Result<Vec<f32>> {
            let logits = (&logits / temperature)?;
            let prs = candle_nn::ops::softmax_last_dim(&logits)?;
            let mut prs = prs.to_vec1()?;
            f(&mut prs);
            Ok(prs)
        };
        // same normalization used in colored text
        let temp_factor =
            1.0 - self.text_temperature_gating_influence * (gate_weight / 0.016).clamp(0., 0.99);
        //print!("temp_factor {} (alpha: {}, gate: {})\n", temp_factor, self.text_temperature_gating_influence, gate_weight);
        let next_token = match &self.sampling {
            Sampling::ArgMax => self.sample_argmax(logits)?,
            Sampling::All { temperature } => {
                let prs = prs(*temperature * temp_factor)?;
                self.sample_multinomial(&prs)?
            }
            Sampling::TopP { p, temperature } => {
                let mut prs = prs(*temperature * temp_factor)?;
                if *p <= 0.0 || *p >= 1.0 {
                    // simply sample from the predicted probability distribution
                    self.sample_multinomial(&prs)?
                } else {
                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
                    self.sample_topp(&mut prs, *p as f32)?
                }
            }
            Sampling::TopK { k, temperature } => {
                let mut prs = prs(*temperature * temp_factor)?;
                self.sample_topk(&mut prs, *k)?
            }
            Sampling::TopKThenTopP { k, p, temperature } => {
                let mut prs = prs(*temperature * temp_factor)?;
                self.sample_topk_topp(&mut prs, *k, *p as f32)?
            }
        };
        Ok(next_token)
    }
}


================================================
FILE: kyuteye_rs/moshi-core/src/lib.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

pub use candle;
pub use candle_nn;

pub mod conv;
pub mod dynamic_logits_processor;
pub mod lm;
pub mod lm_generate;
pub mod lm_generate_multistream;
pub mod mimi;
pub mod nn;
pub mod quantization;
pub mod seanet;
pub mod streaming;
pub mod transformer;

#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum NormType {
    RmsNorm,
    LayerNorm,
}


================================================
FILE: kyuteye_rs/moshi-core/src/lm.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. use crate::nn::{linear, MaybeQuantizedEmbedding, MaybeQuantizedLinear, MaybeQuantizedVarBuilder}; use crate::{ transformer::{self, CaSrc}, NormType, }; use candle::{DType, Device, IndexOp, Module, Result, Tensor}; thread_local! { pub static VERBOSE: bool = { match std::env::var("MIMI_VERBOSE") { Ok(s) => { !s.is_empty() && s != "0" }, Err(_) => false, } } } #[derive(Debug, Clone, serde::Deserialize)] pub struct DepFormerConfig { pub transformer: transformer::Config, pub num_slices: usize, } #[derive(Debug, Clone, serde::Deserialize)] pub struct Config { pub transformer: transformer::Config, pub depformer: Option, pub text_in_vocab_size: usize, pub text_out_vocab_size: usize, pub audio_vocab_size: usize, pub audio_codebooks: usize, } impl Config { // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/af78657c/outputs/hyperparams.json // Update 2024-03-19: Sin embeddings -> None, RmsNorm fix, scale factor 4.125 // Update 2024-05-02: split text_vocab_size into text_in_vocab_size and text_out_vocab_size. // embeddings. pub fn v0_1() -> Self { let lm_cfg = transformer::Config { d_model: 4096, num_heads: 32, num_layers: 32, dim_feedforward: 4096 * 4, // dim * hidden_scale causal: true, norm_first: true, bias_ff: false, bias_attn: false, layer_scale: None, context: 3000, max_period: 10000, use_conv_block: false, use_conv_bias: true, cross_attention: None, gating: Some(candle_nn::Activation::Silu), norm: NormType::RmsNorm, positional_embedding: transformer::PositionalEmbedding::Rope, conv_layout: false, conv_kernel_size: 3, kv_repeat: 1, max_seq_len: 4096, }; let depformer_cfg = transformer::Config { d_model: 1024, num_heads: 16, num_layers: 6, dim_feedforward: 1024 * 4, // dim * hidden_scale causal: true, norm_first: true, bias_ff: false, bias_attn: false, layer_scale: None, context: 8, max_period: 10000, use_conv_block: false, use_conv_bias: true, cross_attention: None, gating: Some(candle_nn::Activation::Silu), norm: NormType::RmsNorm, positional_embedding: transformer::PositionalEmbedding::None, conv_layout: false, conv_kernel_size: 3, kv_repeat: 1, max_seq_len: 4096, }; let depformer_cfg = DepFormerConfig { num_slices: 8, transformer: depformer_cfg, }; Self { transformer: lm_cfg, depformer: Some(depformer_cfg), audio_vocab_size: 2049, text_in_vocab_size: 32001, text_out_vocab_size: 32000, audio_codebooks: 8, } } pub fn v0_1_vision() -> Self { let lm_cfg = transformer::Config { d_model: 4096, num_heads: 32, num_layers: 32, dim_feedforward: 4096 * 4, // dim * hidden_scale causal: true, norm_first: true, bias_ff: false, bias_attn: false, layer_scale: None, context: 3000, max_period: 10000, use_conv_block: false, use_conv_bias: true, cross_attention: Some(( transformer::CrossAttentionGating::ConditionalGatedSigmoid, NormType::RmsNorm, None, )), gating: Some(candle_nn::Activation::Silu), norm: NormType::RmsNorm, positional_embedding: transformer::PositionalEmbedding::Rope, conv_layout: false, conv_kernel_size: 3, kv_repeat: 1, max_seq_len: 4096, }; let depformer_cfg = transformer::Config { d_model: 1024, num_heads: 16, num_layers: 6, dim_feedforward: 1024 * 4, // dim * hidden_scale causal: true, norm_first: true, bias_ff: false, bias_attn: false, layer_scale: None, context: 8, max_period: 10000, use_conv_block: false, use_conv_bias: true, cross_attention: None, gating: Some(candle_nn::Activation::Silu), norm: crate::NormType::RmsNorm, positional_embedding: 
transformer::PositionalEmbedding::None, conv_layout: false, conv_kernel_size: 3, kv_repeat: 1, max_seq_len: 4096, }; let depformer_cfg = DepFormerConfig { num_slices: 8, transformer: depformer_cfg, }; Self { transformer: lm_cfg, depformer: Some(depformer_cfg), audio_vocab_size: 2049, text_in_vocab_size: 32001, text_out_vocab_size: 32000, audio_codebooks: 8, } } pub fn v0_1_vision_streaming(num_slices: usize) -> Self { let mut s = Self::v0_1_vision(); s.audio_codebooks = 16; if let Some(depformer) = s.depformer.as_mut() { depformer.num_slices = num_slices; depformer.transformer.context = num_slices; } s } pub fn v0_1_streaming(num_slices: usize) -> Self { let mut s = Self::v0_1(); s.audio_codebooks = 16; if let Some(depformer) = s.depformer.as_mut() { depformer.num_slices = num_slices; depformer.transformer.context = num_slices; } s } } #[derive(Debug, Clone)] struct DepFormerSlice { transformer: transformer::StreamingTransformer, // Note that the embedding for the first slice does not have the same dimension as the // embedding for the other slices as it takes a text token as input rather than an audio token. emb: MaybeQuantizedEmbedding, linear_in: MaybeQuantizedLinear, // depformer_in.{idx} linear_out: MaybeQuantizedLinear, // linears.{idx} } impl DepFormerSlice { fn new( in_vocab_size: usize, out_vocab_size: usize, main_transformer_dim: usize, cfg: &transformer::Config, vb: MaybeQuantizedVarBuilder, ) -> Result { let dim = cfg.d_model; let transformer = transformer::StreamingTransformer::new(cfg, vb.pp("transformer"))?; let emb = MaybeQuantizedEmbedding::new(in_vocab_size, dim, vb.pp("emb"))?; let linear_in = linear(main_transformer_dim, dim, false, vb.pp("linear_in"))?; let linear_out = linear(dim, out_vocab_size, false, vb.pp("linear_out"))?; Ok(Self { transformer, emb, linear_in, linear_out, }) } } #[derive(Debug, Clone)] pub struct DepFormer { slices: Vec, } impl DepFormer { pub fn new( text_vocab_size: usize, audio_vocab_size: usize, main_transformer_dim: usize, cfg: &DepFormerConfig, vb: MaybeQuantizedVarBuilder, ) -> Result { let mut slices = Vec::with_capacity(cfg.num_slices); for slice_idx in 0..cfg.num_slices { let in_vs = if slice_idx == 0 { text_vocab_size } else { audio_vocab_size }; // The depformer cannot predict the audio padding token. let slice = DepFormerSlice::new( in_vs, audio_vocab_size - 1, // The depformer cannot emit an audio padding token. main_transformer_dim, &cfg.transformer, vb.pp(slice_idx), )?; slices.push(slice) } Ok(Self { slices }) } /// Run a transformer sampling step, getting a token id per codebook. /// - `xs` is the previous layer hidden state. pub fn sample( &mut self, xs: &Tensor, text_token: Option, forced_audio_tokens: &[Option], lp: &mut candle_transformers::generation::LogitsProcessor, ) -> Result> { use crate::streaming::StreamingModule; let dev = xs.device(); let mut tokens = Vec::with_capacity(self.slices.len()); let mut last_token = text_token; for slice_idx in 0..self.slices.len() { if slice_idx == 0 { self.slices[slice_idx].transformer.reset_state(); } else { let (lhs, rhs) = self.slices.split_at_mut(slice_idx); rhs[0] .transformer .copy_state(&lhs[slice_idx - 1].transformer)? } let slice = &mut self.slices[slice_idx]; let xs = slice.linear_in.forward(xs)?; let xs = match last_token { Some(last_token) => { let token_id = Tensor::from_vec(vec![last_token], (1, 1), dev)?; let token_emb = slice.emb.forward(&token_id)?; xs.broadcast_add(&token_emb)? 
} None => xs, }; let xs = slice.transformer.forward(&xs)?; let logits = xs.apply(&slice.linear_out)?; let logits = match logits.dim(0)? { 1 => logits.i((0, 0))?, b_size => candle::bail!("unexpected batch size {b_size}"), }; let token = lp.sample(&logits)?; if VERBOSE.with(|v| *v) { println!("sampled {token} logits {slice_idx}:\n{logits}"); } tokens.push(token); let token_for_next_layer = forced_audio_tokens .get(slice_idx) .copied() .flatten() .unwrap_or(token); last_token = Some(token_for_next_layer); } Ok(tokens) } // Sampling with classifier free guidance. pub fn sample_cfg( &mut self, xs: &Tensor, cfg_alpha: f64, text_token: Option, forced_audio_tokens: &[Option], lp: &mut candle_transformers::generation::LogitsProcessor, ) -> Result> { use crate::streaming::StreamingModule; let dev = xs.device(); let mut tokens = Vec::with_capacity(self.slices.len()); let mut last_token = text_token; for slice_idx in 0..self.slices.len() { if slice_idx == 0 { self.slices[slice_idx].transformer.reset_state(); } else { let (lhs, rhs) = self.slices.split_at_mut(slice_idx); rhs[0] .transformer .copy_state(&lhs[slice_idx - 1].transformer)? } let slice = &mut self.slices[slice_idx]; let xs = slice.linear_in.forward(xs)?; let xs = match last_token { Some(last_token) => { let token_id = Tensor::from_vec(vec![last_token], (1, 1), dev)?; let token_emb = slice.emb.forward(&token_id)?; xs.broadcast_add(&token_emb)? } None => xs, }; let xs = slice.transformer.forward(&xs)?; let logits = xs.apply(&slice.linear_out)?; let logits = match logits.dim(0)? { 2 => ((logits.i((0, 0))? * cfg_alpha)? - (logits.i((1, 0))? * (cfg_alpha - 1.))?)?, b_size => candle::bail!("unexpected batch size {b_size}"), }; let token = lp.sample(&logits)?; tokens.push(token); let token_for_next_layer = forced_audio_tokens .get(slice_idx) .copied() .flatten() .unwrap_or(token); last_token = Some(token_for_next_layer); } Ok(tokens) } } #[derive(Debug, Clone)] pub struct LmModel { pub transformer: transformer::StreamingTransformer, pub text_emb: MaybeQuantizedEmbedding, pub audio_embs: Vec, pub text_linear: MaybeQuantizedLinear, pub out_norm: transformer::Norm, pub depformer: Option, pub audio_vocab_size: usize, pub text_in_vocab_size: usize, pub dtype: DType, } impl LmModel { pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result { let d_model = cfg.transformer.d_model; let depformer = match &cfg.depformer { None => None, Some(depformer_cfg) => { let depformer = DepFormer::new( cfg.text_in_vocab_size, cfg.audio_vocab_size, d_model, depformer_cfg, vb.pp("depformer"), )?; Some(depformer) } }; let text_emb = MaybeQuantizedEmbedding::new(cfg.text_in_vocab_size, d_model, vb.pp("text_emb"))?; let out_norm = transformer::Norm::new(d_model, &cfg.transformer, vb.pp("out_norm"))?; let text_linear = linear( d_model, cfg.text_out_vocab_size, false, vb.pp("text_linear"), )?; let transformer = transformer::StreamingTransformer::new(&cfg.transformer, vb.pp("transformer"))?; let vb_e = vb.pp("emb"); let mut audio_embs = Vec::with_capacity(cfg.audio_codebooks); for i in 0..cfg.audio_codebooks { let emb = MaybeQuantizedEmbedding::new(cfg.audio_vocab_size, d_model, vb_e.pp(i))?; audio_embs.push(emb) } let dtype = match vb { MaybeQuantizedVarBuilder::Real(weights) => weights.dtype(), MaybeQuantizedVarBuilder::Quantized(_) => DType::F32, }; Ok(Self { transformer, text_emb, text_linear, audio_embs, out_norm, depformer, text_in_vocab_size: cfg.text_in_vocab_size, audio_vocab_size: cfg.audio_vocab_size, dtype, }) } pub fn reset_state(&mut self) { use 
crate::streaming::StreamingModule; self.transformer.reset_state() } pub fn in_audio_codebooks(&self) -> usize { self.audio_embs.len() } pub fn audio_pad_token(&self) -> u32 { self.audio_vocab_size as u32 - 1 } pub fn text_start_token(&self) -> u32 { self.text_in_vocab_size as u32 - 1 } pub fn generated_audio_codebooks(&self) -> usize { self.depformer.as_ref().map_or(0, |v| v.slices.len()) } pub fn is_quantized(&self) -> bool { match self.text_linear { MaybeQuantizedLinear::Quantized(_) => true, MaybeQuantizedLinear::Real(_) => false, } } pub fn device(&self) -> &Device { self.text_emb.embeddings().device() } pub fn forward( &mut self, text_ids: Option, audio_ids: Vec>, ) -> candle::Result<(Tensor, Tensor)> { if VERBOSE.with(|v| *v) { print!("text_ids "); if let Some(text_ids) = text_ids.as_ref() { let text_ids = text_ids.flatten_all()?.to_vec1::()?; println!("{text_ids:?}"); } else { println!("none") } print!("audio_ids "); for audio_id in audio_ids.iter() { if let Some(audio_id) = audio_id { let audio_id = audio_id.flatten_all()?.to_vec1::()?; print!(" {audio_id:?}"); } else { print!(" none") } } println!(); } let mut emb = match text_ids.as_ref() { Some(text_ids) => text_ids.apply(&self.text_emb)?, None => { let device = self.text_emb.embeddings().device(); Tensor::zeros((1, 1, self.text_emb.hidden_size()?), self.dtype, device)? } }; for (audio_emb, audio_ids) in self.audio_embs.iter().zip(audio_ids.iter()) { if let Some(audio_ids) = audio_ids { let e = audio_ids.apply(audio_emb)?; emb = (emb + e)? } } let ys = self.transformer.forward(&emb)?; let ys = ys.apply(&self.out_norm)?; let logits = ys.apply(&self.text_linear)?; if VERBOSE.with(|v| *v) { println!("logits:\n{logits}"); } Ok((logits, ys)) } pub fn maybe_precompute_ca_kv(&self, ca_src: Option) -> Result> { let ca_src = match ca_src { None => None, z => self.transformer.maybe_precompute_ca_kv(z)?, }; Ok(ca_src) } pub fn forward_ca( &mut self, text_ids: Option, audio_ids: Vec>, ca_src: &CaSrc, ) -> candle::Result<(Tensor, Tensor)> { let (logits, ys, _) = self.forward_with_gate_weight(text_ids, audio_ids, ca_src)?; Ok((logits, ys)) } pub fn forward_with_gate_weight( &mut self, text_ids: Option, audio_ids: Vec>, ca_src: &CaSrc, ) -> candle::Result<(Tensor, Tensor, Tensor)> { if VERBOSE.with(|v| *v) { print!("text_ids "); if let Some(text_ids) = text_ids.as_ref() { let text_ids = text_ids.flatten_all()?.to_vec1::()?; println!("{text_ids:?}"); } else { println!("none") } print!("audio_ids "); for audio_id in audio_ids.iter() { if let Some(audio_id) = audio_id { let audio_id = audio_id.flatten_all()?.to_vec1::()?; print!(" {audio_id:?}"); } else { print!(" none") } } println!(); } let b_size = match ca_src { CaSrc::KeysValues((cak, _)) => cak.dim(0)?, CaSrc::Tokens(catoks) => catoks.dim(0)?, }; let mut emb = match text_ids { Some(text_ids) => text_ids.apply(&self.text_emb)?, None => { let device = self.text_emb.embeddings().device(); Tensor::zeros( (b_size, 1, self.text_emb.hidden_size()?), self.dtype, device, )? } }; for (audio_emb, audio_ids) in self.audio_embs.iter().zip(audio_ids.iter()) { if let Some(audio_ids) = audio_ids { let e = audio_ids.apply(audio_emb)?; emb = emb.broadcast_add(&e)? 
}
        }
        let (ys, alpha) = self
            .transformer
            .forward_with_gate_weight(&emb, Some(ca_src))?;
        let ys = ys.apply(&self.out_norm)?;
        let logits = ys.apply(&self.text_linear)?;
        Ok((logits, ys, alpha))
    }

    pub fn depformer_sample(
        &mut self,
        xs: &Tensor,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Option<Vec<u32>>> {
        let sample = match &mut self.depformer {
            None => None,
            Some(m) => {
                let sample = m.sample(xs, text_token, forced_audio_tokens, lp)?;
                Some(sample)
            }
        };
        Ok(sample)
    }
}

pub fn load_lm_model<P: AsRef<std::path::Path>>(
    cfg: Config,
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let quantized = model_file.as_ref().extension().is_some_and(|v| v == "gguf");
    let vb = if quantized {
        MaybeQuantizedVarBuilder::Quantized(
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(model_file, dev)?,
        )
    } else {
        unsafe {
            MaybeQuantizedVarBuilder::Real(candle_nn::VarBuilder::from_mmaped_safetensors(
                &[model_file],
                dtype,
                dev,
            )?)
        }
    };
    let model = LmModel::new(&cfg, vb)?;
    Ok(model)
}

pub fn load<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1();
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_streaming<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1_streaming(8);
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_streaming_both_ways<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1_streaming(16);
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_vision<P: AsRef<std::path::Path>>(
    model_file: P,
    override_cross_attention_gating: Option<transformer::CrossAttentionGating>,
    override_cross_attention_in_dim: Option<usize>,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    // load_vision allows for overriding some hyperparams of the lm from the main config file
    let mut cfg = Config::v0_1_vision_streaming(8);
    cfg.transformer.cross_attention = override_cross_attention_gating
        .map(|v| (v, cfg.transformer.norm, override_cross_attention_in_dim));
    load_lm_model(cfg, model_file, dtype, dev)
}

pub struct ForcedAudioTokens {
    acoustic_delay: usize,
    // Tokens that are teacher forced before the acoustic delay.
    pre_delay_tokens: Vec<Option<u32>>,
}

impl ForcedAudioTokens {
    pub fn new(acoustic_delay: usize, audio_pad_token: u32, stream_codebooks: &[usize]) -> Self {
        let mut pre_delay_tokens = vec![];
        for codebooks in stream_codebooks.iter() {
            for c in 0..*codebooks {
                let token = if c == 0 { None } else { Some(audio_pad_token) };
                pre_delay_tokens.push(token);
            }
        }
        Self {
            acoustic_delay,
            pre_delay_tokens,
        }
    }

    pub fn forced_tokens(&self, step_idx: usize) -> &[Option<u32>] {
        if step_idx < self.acoustic_delay {
            &self.pre_delay_tokens
        } else {
            &[]
        }
    }
}


================================================
FILE: kyuteye_rs/moshi-core/src/lm_generate.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{IndexOp, Tensor}; use candle_transformers::generation::LogitsProcessor; const UNGENERATED: u32 = u32::MAX; #[derive(Debug, Clone)] pub struct Config { pub audio_codebooks: usize, pub audio_vocab_size: usize, pub acoustic_delay: usize, pub text_bos_token: u32, pub text_eos_token: u32, pub text_pad_token: u32, pub text_start_token: u32, } impl Config { pub fn v0_1() -> Self { Self { audio_codebooks: 8, audio_vocab_size: 2049, acoustic_delay: 2, text_bos_token: 1, text_eos_token: 2, text_pad_token: 3, text_start_token: 32000, } } pub fn audio_pad_token(&self) -> u32 { self.audio_vocab_size as u32 - 1 } pub fn audio_codebooks(&self) -> usize { self.audio_codebooks } } pub struct State { model: crate::lm::LmModel, audio_tokens: Vec>, audio_lp: LogitsProcessor, text_lp: LogitsProcessor, step_idx: usize, forced_audio_tokens: crate::lm::ForcedAudioTokens, config: Config, npads: i32, } impl State { pub fn new( model: crate::lm::LmModel, max_step_idx: usize, audio_lp: LogitsProcessor, text_lp: LogitsProcessor, config: Config, ) -> Self { let audio_tokens: Vec> = vec![vec![UNGENERATED; config.audio_codebooks]; max_step_idx + config.acoustic_delay]; let forced_audio_tokens = crate::lm::ForcedAudioTokens::new( config.audio_codebooks, model.audio_pad_token(), &[8, 8], ); Self { model, audio_tokens, audio_lp, text_lp, step_idx: 0, npads: 0, forced_audio_tokens, config, } } pub fn audio_codebooks(&self) -> usize { self.config.audio_codebooks } pub fn audio_pad_token(&self) -> u32 { self.config.audio_pad_token() } pub fn step_gen_no_text(&mut self, force_text_token: Option) -> candle::Result { self.step(None, true, force_text_token) } pub fn step_gen(&mut self, prev_text_token: u32) -> candle::Result { self.step(Some(prev_text_token), true, None) } pub fn step_text_prompt(&mut self, id: u32) -> candle::Result { self.step(Some(id), false, None) } pub fn step_audio_prompt_( &mut self, codes: &[u32], text_token: Option, ) -> candle::Result { if codes.len() != self.audio_codebooks() { candle::bail!( "unexpected codes length {} {}", codes.len(), self.audio_codebooks() ) } self.audio_tokens[self.step_idx].copy_from_slice(codes); let prev_text = if self.step_idx == 0 { Some(self.config.text_start_token) } else { text_token }; self.step(prev_text, false, None) } pub fn step_audio_prompt(&mut self, codes: &[u32]) -> candle::Result { self.step_audio_prompt_(codes, None) } pub fn step_audio_prompt_with_text(&mut self, codes: &[u32], text: u32) -> candle::Result { self.step_audio_prompt_(codes, Some(text)) } pub fn last_audio_tokens(&self) -> Option> { if self.step_idx <= self.config.acoustic_delay { None } else { // step_idx is in advance by 1 + there is a 2 token delay on audio tokens. let audio_tokens = &self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1]; if audio_tokens .iter() .any(|v| *v as usize >= self.config.audio_vocab_size - 1) { None } else { Some(audio_tokens.clone()) } } } pub fn audio_tokens(&self) -> Vec> { let l = self.step_idx - self.config.acoustic_delay - 1; self.audio_tokens[..l].to_vec() } // The acoustic tokens are written with a delay, so this can create "gaps" of UNGENERATED // tokens in the case where we call `step_audio_prompt` *after* `step`. 
fn step( &mut self, text_token: Option, gen_audio: bool, force_text_token: Option, ) -> candle::Result { let mut codes = Vec::with_capacity(self.audio_codebooks()); let dev = self.model.device(); for codebook in 0..self.audio_codebooks() { let t = if codebook == 0 { if self.step_idx == 0 { self.audio_pad_token() } else { self.audio_tokens[self.step_idx - 1][0] } } else if self.step_idx <= self.config.acoustic_delay { self.audio_pad_token() } else { self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1][codebook] }; if t == UNGENERATED { candle::bail!("internal error, ungenerated {}", self.step_idx) } let t = Tensor::new(&[t], dev)?.unsqueeze(0)?; codes.push(Some(t)) } let text_token = match text_token { None => None, Some(text_token) => Some(Tensor::from_vec(vec![text_token], (1, 1), dev)?), }; let (text_logits, ys) = self.model.forward(text_token, codes)?; let text_logits = text_logits.i((0, 0))?; let text_token = match force_text_token { None => self.text_lp.sample_f(&text_logits, |prs| { prs[self.config.text_bos_token as usize] = 1e-9; if self.npads > 40 { let mul = 2f32.powi(self.npads - 40); prs[self.config.text_eos_token as usize] *= mul; } })?, Some(t) => t, }; if text_token == self.config.text_pad_token { self.npads += 1; } else { self.npads = 0; } let last_audio_tokens = if gen_audio { self.model.depformer_sample( &ys, Some(text_token), self.forced_audio_tokens.forced_tokens(self.step_idx), &mut self.audio_lp, )? } else { None }; let audio_pad_token = self.audio_pad_token(); for c_idx in 0..self.audio_codebooks() { let delay = if c_idx == 0 { 0 } else { self.config.acoustic_delay }; let pos = &mut self.audio_tokens[self.step_idx.saturating_sub(delay)][c_idx]; match last_audio_tokens.as_ref() { Some(lat) => *pos = lat[c_idx], None => { if *pos == UNGENERATED { *pos = audio_pad_token } } } } self.step_idx += 1; if self.step_idx >= self.audio_tokens.len() { candle::bail!("max step-idx reached") } Ok(text_token) } } ================================================ FILE: kyuteye_rs/moshi-core/src/lm_generate_multistream.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
use crate::dynamic_logits_processor::GateInfluencedLogitsProcessor; use candle::{IndexOp, Tensor}; use candle_transformers::generation::LogitsProcessor; use crate::transformer::CaSrc; pub const UNGENERATED: u32 = u32::MAX; #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] pub struct Config { pub generated_audio_codebooks: usize, pub input_audio_codebooks: usize, pub audio_vocab_size: usize, pub acoustic_delay: usize, pub text_pad_token: u32, pub text_eop_token: u32, pub text_start_token: u32, } impl Config { pub fn v0_1() -> Self { Self { generated_audio_codebooks: 8, input_audio_codebooks: 8, audio_vocab_size: 2049, acoustic_delay: 2, text_eop_token: 0, text_pad_token: 3, text_start_token: 32000, } } pub fn v0_1_two_ways() -> Self { Self { generated_audio_codebooks: 16, input_audio_codebooks: 0, audio_vocab_size: 2049, acoustic_delay: 2, text_eop_token: 0, text_pad_token: 3, text_start_token: 32000, } } pub fn v0_1_one_way() -> Self { Self { generated_audio_codebooks: 8, input_audio_codebooks: 0, audio_vocab_size: 2049, acoustic_delay: 2, text_eop_token: 0, text_pad_token: 3, text_start_token: 32000, } } pub fn audio_pad_token(&self) -> u32 { self.audio_vocab_size as u32 - 1 } pub fn total_audio_codebooks(&self) -> usize { self.generated_audio_codebooks + self.input_audio_codebooks } } pub struct State { model: crate::lm::LmModel, audio_tokens: Vec>, text_tokens: Vec, gate_weights: Vec, audio_lp: LogitsProcessor, text_lp: GateInfluencedLogitsProcessor, step_idx: usize, pad_mult: Option, // For repetition penalty, we provide the context len (in text tokens) and the penalty. repetition_penalty: Option<(usize, f32)>, forced_audio_tokens: crate::lm::ForcedAudioTokens, user_rating: u32, config: Config, } impl State { #[allow(clippy::too_many_arguments)] pub fn new( model: crate::lm::LmModel, max_step_idx: usize, audio_lp: LogitsProcessor, text_lp: GateInfluencedLogitsProcessor, pad_mult: Option, repetition_penalty: Option<(usize, f32)>, config: Config, ) -> Self { let audio_tokens: Vec> = vec![ vec![UNGENERATED; config.total_audio_codebooks()]; max_step_idx + config.acoustic_delay ]; let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay]; let forced_audio_tokens = crate::lm::ForcedAudioTokens::new( config.acoustic_delay, config.audio_pad_token(), &[8, 8], ); let gate_weights = vec![0.0_f32; max_step_idx + config.acoustic_delay]; Self { model, audio_tokens, text_tokens, audio_lp, text_lp, gate_weights, step_idx: 0, pad_mult, repetition_penalty, forced_audio_tokens, user_rating: 0, config, } } pub fn step_idx(&self) -> usize { self.step_idx } fn audio_pad_token(&self) -> u32 { self.config.audio_pad_token() } pub fn config(&self) -> &Config { &self.config } pub fn user_rating(&self) -> u32 { self.user_rating } pub fn set_user_rating(&mut self, grade: u32) { self.user_rating = grade } fn apply_repetition_penalty(&self, logits: Tensor) -> candle::Result { let logits = match self.repetition_penalty { None => logits, Some((_, 1.)) => logits, Some((context_size, penalty)) => { let device = logits.device(); let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::()?; let mut already_seen = std::collections::HashSet::new(); let mut non_pad_tokens = 0; for &token_id in self.text_tokens(false).iter().rev() { if token_id == self.config.text_pad_token || token_id == self.config.text_eop_token || token_id == self.config.text_start_token { continue; } // Look at the last [context_size] tokens at most, count all tokens there even // if we already saw them. 
if non_pad_tokens >= context_size { break; } non_pad_tokens += 1; if already_seen.contains(&token_id) { continue; } already_seen.insert(token_id); if let Some(logit) = logits.get_mut(token_id as usize) { if *logit >= 0. { *logit /= penalty } else { *logit *= penalty } } } let logits_len = logits.len(); Tensor::from_vec(logits, logits_len, device)? } }; Ok(logits) } // The acoustic tokens are written with a delay, so this can create "gaps" of UNGENERATED // tokens in the case where we call `step_audio_prompt` *after* `step`. pub fn step_( &mut self, text_token: Option, input_audio_tokens: &[u32], force_text_token: Option, ca_src: Option<&CaSrc>, ) -> candle::Result<(u32, f32)> { let mut codes = Vec::with_capacity(self.config.total_audio_codebooks()); let dev = self.model.device(); for (c_idx, &t) in input_audio_tokens.iter().enumerate() { self.audio_tokens[self.step_idx][c_idx + self.config.generated_audio_codebooks] = t } for codebook in 0..self.config.total_audio_codebooks() { let t = if codebook == 0 || codebook == 8 { if self.step_idx == 0 { self.audio_pad_token() } else { self.audio_tokens[self.step_idx - 1][codebook] } } else if self.step_idx <= self.config.acoustic_delay { self.audio_pad_token() } else { self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1][codebook] }; if t == UNGENERATED { candle::bail!("internal error, ungenerated {} {codebook}", self.step_idx) } let t = Tensor::new(&[t], dev)?.unsqueeze(0)?; codes.push(Some(t)) } let text_token = match text_token { Some(text_token) => Some(Tensor::from_vec(vec![text_token], (1, 1), dev)?), None => None, }; let (text_logits, ys, alpha) = match ca_src.as_ref() { None => { let (logits, ys) = self.model.forward(text_token, codes)?; (logits, ys, 0.0) } Some(ca_src) => { let (logits, ys, alpha) = self .model .forward_with_gate_weight(text_token, codes, ca_src)?; (logits, ys, alpha.to_dtype(candle::DType::F32)?.to_scalar()?) 
} }; self.gate_weights[self.step_idx] = alpha; let text_logits = text_logits.i((0, 0))?; let text_logits = self.apply_repetition_penalty(text_logits)?; let text_token = match force_text_token { Some(tt) => tt, None => self.text_lp.sample_f( &text_logits, |prs| { if let Some(pad_mult) = self.pad_mult.as_ref() { prs[self.config.text_pad_token as usize] *= f32::exp(*pad_mult); } }, alpha as f64, )?, }; self.text_tokens[self.step_idx] = text_token; let last_audio_tokens = self.model.depformer_sample( &ys, Some(text_token), self.forced_audio_tokens.forced_tokens(self.step_idx), &mut self.audio_lp, )?; let audio_pad_token = self.audio_pad_token(); for c_idx in 0..self.config.generated_audio_codebooks { let delay = if c_idx == 0 || c_idx == 8 { 0 } else { self.config.acoustic_delay }; let pos = &mut self.audio_tokens[self.step_idx.saturating_sub(delay)][c_idx]; match last_audio_tokens.as_ref() { Some(lat) => { if *pos == UNGENERATED { *pos = lat[c_idx] } } None => { if *pos == UNGENERATED { *pos = audio_pad_token } } } } self.step_idx += 1; if self.step_idx >= self.audio_tokens.len() { candle::bail!("max step-idx reached") } Ok((text_token, alpha)) } pub fn step( &mut self, text_token: u32, input_audio_tokens: &[u32], force_text_token: Option, ca_src: Option<&CaSrc>, ) -> candle::Result { let (text_token, _) = self.step_( Some(text_token), input_audio_tokens, force_text_token, ca_src, )?; Ok(text_token) } pub fn step_with_gate_weight( &mut self, text_token: u32, input_audio_tokens: &[u32], force_text_token: Option, ca_src: Option<&CaSrc>, ) -> candle::Result<(u32, f32)> { self.step_( Some(text_token), input_audio_tokens, force_text_token, ca_src, ) } /// If include_all is set, all the time steps are returned. Otherwise only the timesteps that /// have been generated are handled. pub fn audio_tokens(&self, include_all: bool) -> &[Vec] { if include_all { &self.audio_tokens } else { let max_idx = usize::min(self.step_idx, self.audio_tokens.len()); &self.audio_tokens[..max_idx] } } pub fn gate_weights(&self, include_all: bool) -> &[f32] { if include_all { &self.gate_weights } else { let max_idx = usize::min(self.step_idx, self.gate_weights.len()); &self.gate_weights[..max_idx] } } pub fn text_tokens(&self, include_all: bool) -> &[u32] { if include_all { &self.text_tokens } else { let max_idx = usize::min(self.step_idx, self.text_tokens.len()); &self.text_tokens[..max_idx] } } pub fn last_audio_tokens(&self) -> Option> { if self.step_idx <= self.config.acoustic_delay { None } else { // step_idx is in advance by 1 + there is a 2 token delay on audio tokens. let audio_tokens = &self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1]; if audio_tokens .iter() .any(|v| *v as usize >= self.config.audio_vocab_size - 1) { None } else { Some(audio_tokens.clone()) } } } } ================================================ FILE: kyuteye_rs/moshi-core/src/mimi.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
use crate::streaming::{StreamTensor, StreamingModule}; use crate::{conv, nn, quantization, seanet, transformer}; use candle::{DType, Device, Module, Result, Tensor}; use candle_nn::VarBuilder; #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ResampleMethod { Conv, Interpolate, } #[derive(Debug, Clone)] pub struct Config { pub channels: usize, pub sample_rate: f64, pub frame_rate: f64, pub renormalize: bool, pub resample_method: ResampleMethod, pub seanet: seanet::Config, pub transformer: transformer::Config, pub quantizer_n_q: usize, pub quantizer_bins: usize, pub quantizer_dim: usize, } impl Config { // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml pub fn v0_1(num_codebooks: Option) -> Self { let seanet_cfg = seanet::Config { dimension: 512, channels: 1, causal: true, n_filters: 64, n_residual_layers: 1, activation: candle_nn::Activation::Elu(1.), compress: 2, dilation_base: 2, disable_norm_outer_blocks: 0, final_activation: None, kernel_size: 7, residual_kernel_size: 3, last_kernel_size: 3, lstm: 0, norm: conv::Norm::WeightNorm, pad_mode: conv::PadMode::Constant, ratios: vec![8, 6, 5, 4], true_skip: true, }; let transformer_cfg = transformer::Config { d_model: seanet_cfg.dimension, num_heads: 8, num_layers: 8, causal: true, norm_first: true, bias_ff: false, bias_attn: false, layer_scale: Some(0.01), context: 250, conv_kernel_size: 5, use_conv_bias: true, use_conv_block: false, cross_attention: None, max_period: 10000, gating: None, norm: crate::NormType::LayerNorm, positional_embedding: transformer::PositionalEmbedding::Rope, dim_feedforward: 2048, kv_repeat: 1, conv_layout: true, // see builders.py max_seq_len: 8192, // the transformer works at 25hz so this is ~5 mins. }; Config { channels: 1, sample_rate: 24_000., frame_rate: 12.5, renormalize: true, resample_method: ResampleMethod::Conv, seanet: seanet_cfg, transformer: transformer_cfg, quantizer_n_q: num_codebooks.unwrap_or(16), quantizer_bins: 2048, quantizer_dim: 256, } } } #[derive(Debug, Clone)] pub struct Mimi { encoder: seanet::SeaNetEncoder, decoder: seanet::SeaNetDecoder, encoder_transformer: transformer::ProjectedTransformer, decoder_transformer: transformer::ProjectedTransformer, downsample: conv::ConvDownsample1d, upsample: conv::ConvTrUpsample1d, quantizer: quantization::SplitResidualVectorQuantizer, config: Config, } impl Mimi { pub fn new(cfg: Config, vb: VarBuilder) -> Result { let dim = cfg.seanet.dimension; let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?; let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?; let encoder_transformer = transformer::ProjectedTransformer::new( dim, &[dim], &cfg.transformer, nn::MaybeQuantizedVarBuilder::Real(vb.pp("encoder_transformer")), )?; let decoder_transformer = transformer::ProjectedTransformer::new( dim, &[dim], &cfg.transformer, nn::MaybeQuantizedVarBuilder::Real(vb.pp("decoder_transformer")), )?; let quantizer = quantization::SplitResidualVectorQuantizer::new( /* dim */ cfg.quantizer_dim, /* input_dim */ Some(dim), /* output_dim */ Some(dim), /* n_q */ cfg.quantizer_n_q, /* bins */ cfg.quantizer_bins, vb.pp("quantizer"), )?; let encoder_frame_rate = cfg.sample_rate / cfg.seanet.ratios.iter().product::() as f64; let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize; // `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate. 
let downsample = conv::ConvDownsample1d::new( /* stride */ downsample_stride, /* dim */ dim, /* causal */ true, /* learnt */ true, vb.pp("downsample"), )?; let upsample = conv::ConvTrUpsample1d::new( /* stride */ downsample_stride, /* dim */ dim, /* causal */ true, /* learnt */ true, vb.pp("upsample"), )?; Ok(Self { encoder, decoder, encoder_transformer, decoder_transformer, quantizer, downsample, upsample, config: cfg, }) } pub fn config(&self) -> &Config { &self.config } pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result { let xs = self.encoder.forward(xs)?; self.encoder_transformer.reset_state(); let xs = self.encoder_transformer.forward(&xs)?; let xs = &xs[0]; xs.apply(&self.downsample) } pub fn encode(&mut self, xs: &Tensor) -> Result { let xs = self.encoder.forward(xs)?; self.encoder_transformer.reset_state(); let xs = self.encoder_transformer.forward(&xs)?; let xs = &xs[0]; let xs = xs.apply(&self.downsample)?; let codes = self.quantizer.encode(&xs)?; Ok(codes) } pub fn encode_step(&mut self, xs: &StreamTensor) -> Result { let xs = self.encoder.step(xs)?; let xs = self.encoder_transformer.step(&xs)?; let xs = self.downsample.step(&xs)?; match xs.as_option() { None => Ok(().into()), Some(xs) => { let codes = self.quantizer.encode(xs)?; Ok(codes.into()) } } } pub fn decode(&mut self, codes: &Tensor) -> Result { let emb = self.quantizer.decode(codes)?; let emb = emb.apply(&self.upsample)?; self.decoder_transformer.reset_state(); let outs = self.decoder_transformer.forward(&emb)?; let out = &outs[0]; self.decoder.forward(out) } pub fn decode_step(&mut self, codes: &StreamTensor) -> Result { let emb = match codes.as_option() { Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?), None => StreamTensor::empty(), }; let emb = self.upsample.step(&emb)?; let out = self.decoder_transformer.step(&emb)?; self.decoder.step(&out) } pub fn reset_state(&mut self) { self.encoder.reset_state(); self.encoder_transformer.reset_state(); self.decoder.reset_state(); self.decoder_transformer.reset_state(); self.upsample.reset_state(); } } pub fn load(model_file: &str, num_codebooks: Option, dev: &Device) -> Result { let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? 
}; let cfg = Config::v0_1(num_codebooks); let mimi = Mimi::new(cfg, vb)?; Ok(mimi) } ================================================ FILE: kyuteye_rs/moshi-core/src/nn.rs ================================================ use candle::quantized::QTensor; use candle::{DType, Device, Module, Result, Shape, Tensor}; use candle_transformers::quantized_nn as candle_qnn; use candle_transformers::quantized_var_builder::VarBuilder as QuantizedVarBuilder; use std::sync::Arc; #[derive(Clone)] pub enum MaybeQuantizedWeight { // Enum types around real and quantized model weights Real(Tensor), Quantized(Arc), } impl MaybeQuantizedWeight { fn to_tensor(&self, dev: &Device) -> Result { match self { Self::Real(t) => Ok(t.clone()), Self::Quantized(t) => t.dequantize(dev), } } } pub fn matmul_dtype(device: &candle::Device) -> DType { // Dtype used for intermediate matmul in attention during quantized execution if device.is_cuda() { DType::BF16 } else { DType::F32 } } #[derive(Clone)] pub enum MaybeQuantizedVarBuilder<'a> { // Enum types around real and quantized var builders Real(candle_nn::VarBuilder<'a>), Quantized(QuantizedVarBuilder), } impl MaybeQuantizedVarBuilder<'_> { pub fn pp(&self, s: S) -> Self { match self { Self::Real(weights) => MaybeQuantizedVarBuilder::Real(weights.pp(s)), Self::Quantized(weights) => MaybeQuantizedVarBuilder::Quantized(weights.pp(s)), } } pub fn get>(&self, s: S, path: &str) -> Result { let w = match self { Self::Real(weights) => MaybeQuantizedWeight::Real(weights.get(s, path)?), Self::Quantized(weights) => MaybeQuantizedWeight::Quantized(weights.get(s, path)?), }; Ok(w) } pub fn get_as_tensor>(&self, s: S, path: &str) -> Result { let w = match self { Self::Real(weights) => MaybeQuantizedWeight::Real(weights.get(s, path)?), Self::Quantized(weights) => MaybeQuantizedWeight::Quantized(weights.get(s, path)?), }; w.to_tensor(self.device()) } pub fn get_unquantized>(&self, s: S, path: &str) -> Result { match self { Self::Real(weights) => weights.get(s, path), Self::Quantized(weights) => weights.get(s, path)?.dequantize(weights.device()), } } pub fn contains_key(&self, name: &str) -> bool { match self { Self::Real(weights) => weights.contains_tensor(name), Self::Quantized(weights) => weights.contains_key(name), } } pub fn device(&self) -> &Device { match self { Self::Real(weights) => weights.device(), Self::Quantized(weights) => weights.device(), } } pub fn dtype(&self) -> DType { match self { Self::Real(weights) => weights.dtype(), Self::Quantized(_) => DType::F32, } } } #[derive(Debug, Clone)] pub enum MaybeQuantizedLinear { Real(candle_nn::Linear), Quantized(candle_qnn::Linear), } impl Module for MaybeQuantizedLinear { fn forward(&self, xs: &Tensor) -> Result { match self { Self::Real(module) => module.forward(xs), Self::Quantized(module) => module.forward(xs), } } } #[derive(Debug, Clone)] pub enum MaybeQuantizedEmbedding { Real(candle_nn::Embedding), Quantized(candle_qnn::Embedding), } impl MaybeQuantizedEmbedding { pub fn new(in_vocab_size: usize, dim: usize, vb: MaybeQuantizedVarBuilder) -> Result { let emb = match vb { MaybeQuantizedVarBuilder::Real(weights) => { MaybeQuantizedEmbedding::Real(candle_nn::embedding(in_vocab_size, dim, weights)?) 
} MaybeQuantizedVarBuilder::Quantized(weights) => MaybeQuantizedEmbedding::Quantized( candle_transformers::quantized_nn::Embedding::new(in_vocab_size, dim, weights)?, ), }; Ok(emb) } pub fn embeddings(&self) -> &Tensor { match self { MaybeQuantizedEmbedding::Real(weights) => weights.embeddings(), MaybeQuantizedEmbedding::Quantized(weights) => weights.embeddings(), } } pub fn hidden_size(&self) -> Result<usize> { let size = match self { MaybeQuantizedEmbedding::Real(weights) => weights.hidden_size(), MaybeQuantizedEmbedding::Quantized(weights) => weights.embeddings().dim(1)?, }; Ok(size) } } impl Module for MaybeQuantizedEmbedding { fn forward(&self, xs: &Tensor) -> Result<Tensor> { match self { Self::Real(module) => module.forward(xs), Self::Quantized(module) => module.forward(xs), } } } pub fn linear( in_d: usize, out_d: usize, bias: bool, vb: MaybeQuantizedVarBuilder, ) -> Result<MaybeQuantizedLinear> { let output_linear = match vb { MaybeQuantizedVarBuilder::Real(weights) => { if bias { MaybeQuantizedLinear::Real(candle_nn::linear(in_d, out_d, weights)?) } else { MaybeQuantizedLinear::Real(candle_nn::linear_no_bias(in_d, out_d, weights)?) } } MaybeQuantizedVarBuilder::Quantized(weights) => { MaybeQuantizedLinear::Quantized(candle_qnn::linear_b(in_d, out_d, bias, weights)?) } }; Ok(output_linear) } pub fn linear_from( weight: MaybeQuantizedWeight, bias: Option<Tensor>, ) -> Result<MaybeQuantizedLinear> { let layer = match weight { MaybeQuantizedWeight::Real(w) => { MaybeQuantizedLinear::Real(candle_nn::Linear::new(w, bias)) } MaybeQuantizedWeight::Quantized(w) => { MaybeQuantizedLinear::Quantized(candle_qnn::Linear::from_arc(w, bias)?) } }; Ok(layer) } ================================================ FILE: kyuteye_rs/moshi-core/src/quantization.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree.
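// Module overview: `EuclideanCodebook` performs nearest-neighbour lookup into a single
// codebook (encode = argmin of squared distance, decode = embedding lookup);
// `VectorQuantization` wraps one codebook with optional input/output projections;
// `ResidualVectorQuantization` stacks several codebooks, each quantizing the residual left
// by the previous ones; `SplitResidualVectorQuantizer` combines a single-codebook RVQ for
// the semantic tokens with an (n_q - 1)-codebook RVQ for the acoustic tokens and
// concatenates their codes along the codebook dimension.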
use candle::{IndexOp, Layout, Result, Shape, Tensor, D}; use candle_nn::{linear, Linear, VarBuilder}; struct CodebookEncode; impl candle::CustomOp2 for CodebookEncode { fn name(&self) -> &'static str { "cb" } fn cpu_fwd( &self, lhs_storage: &candle::CpuStorage, lhs_layout: &Layout, rhs_storage: &candle::CpuStorage, rhs_layout: &Layout, ) -> Result<(candle::CpuStorage, Shape)> { use rayon::prelude::*; let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?; let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?; if lhs_dim2 != rhs_dim2 { candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}"); } if lhs_dim2 == 0 { candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}") } let lhs = match lhs_layout.contiguous_offsets() { None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"), Some((o1, o2)) => { let slice = lhs_storage.as_slice::<f32>()?; &slice[o1..o2] } }; let rhs = match rhs_layout.contiguous_offsets() { None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"), Some((o1, o2)) => { let slice = rhs_storage.as_slice::<f32>()?; &slice[o1..o2] } }; let dst = (0..lhs_dim1) .into_par_iter() .map(|idx1| { let mut where_min = 0; let mut min_dist = f32::INFINITY; let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2]; for idx2 in 0..rhs_dim1 { let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2]; let mut dist = 0f32; for (a, b) in lhs.iter().zip(rhs.iter()) { dist += (a - b) * (a - b) } if dist < min_dist { min_dist = dist; where_min = idx2; } } where_min as u32 }) .collect(); let storage = candle::WithDType::to_cpu_storage_owned(dst); Ok((storage, (lhs_dim1,).into())) } } #[allow(unused)] #[derive(Debug, Clone)] pub struct EuclideanCodebook { initialized: Tensor, cluster_usage: Tensor, embedding_sum: Tensor, embedding: Tensor, c2: Tensor, epsilon: f64, dim: usize, span_encode: tracing::Span, span_decode: tracing::Span, } impl EuclideanCodebook { pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> { let epsilon = 1e-5; let initialized = vb.get(1, "_initialized")?; let cluster_usage = vb.get(codebook_size, "cluster_usage")?; let embedding_sum = vb.get((codebook_size, dim), "embedding_sum")?; let embedding = { let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?; embedding_sum.broadcast_div(&cluster_usage)? }; let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?; Ok(Self { initialized, cluster_usage, embedding_sum, embedding, c2, epsilon, dim, span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"), span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-decode"), }) } pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span_encode.enter(); let mut target_shape = xs.dims().to_vec(); target_shape.pop(); let xs = xs.flatten_to(D::Minus2)?; let _ = xs.dims2()?; // TODO: avoid repeating this. let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?; let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?; // Manual cdist implementation.
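// In other words: compute the squared distance ||x_i - e_j||^2 between every flattened
// input row x_i and every codebook entry e_j, then take the argmin over j. `encode_slow`
// below avoids materializing the full difference tensor by expanding
// ||x - e||^2 = ||x||^2 - 2<x, e> + ||e||^2: the ||x||^2 term is constant per row, so
// argmin_j (c2_j - <x, e_j>) with the precomputed c2 = ||e||^2 / 2 selects the same code.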
let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?; let dists = diff.sqr()?.sum(D::Minus1)?; let codes = dists.argmin(D::Minus1)?; codes.reshape(target_shape) } pub fn encode_slow(&self, xs: &Tensor) -> Result { let _enter = self.span_encode.enter(); let mut target_shape = xs.dims().to_vec(); target_shape.pop(); let xs = xs.flatten_to(D::Minus2)?; let _ = xs.dims2()?; let dot_prod = xs.matmul(&self.embedding.t()?)?; let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?; codes.reshape(target_shape) } pub fn encode(&self, xs: &Tensor) -> Result { let _enter = self.span_encode.enter(); let mut target_shape = xs.dims().to_vec(); target_shape.pop(); let xs = xs.flatten_to(D::Minus2)?; let _ = xs.dims2()?; let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?; codes.reshape(target_shape) } pub fn decode(&self, indexes: &Tensor) -> Result { let _enter = self.span_decode.enter(); // let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?; let mut final_dims = indexes.dims().to_vec(); final_dims.push(self.dim); let indexes = indexes.flatten_all()?; let values = self.embedding.index_select(&indexes, 0)?; let values = values.reshape(final_dims)?; Ok(values) } } #[allow(unused)] #[derive(Debug, Clone)] pub struct VectorQuantization { project_in: Option, project_out: Option, codebook: EuclideanCodebook, } impl VectorQuantization { pub fn new( dim: usize, codebook_size: usize, codebook_dim: Option, vb: VarBuilder, ) -> Result { let codebook_dim = codebook_dim.unwrap_or(dim); let (project_in, project_out) = if codebook_dim == dim { (None, None) } else { let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?; let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?; (Some(p_in), Some(p_out)) }; let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("_codebook"))?; Ok(Self { project_in, project_out, codebook, }) } pub fn encode(&self, xs: &Tensor) -> Result { let xs = xs.t()?.apply(&self.project_in.as_ref())?; self.codebook.encode_slow(&xs) } pub fn decode(&self, codes: &Tensor) -> Result { let quantized = self.codebook.decode(codes)?; let quantized = match &self.project_out { None => quantized, Some(p) => quantized.apply(p)?, }; quantized.t() } } #[derive(Debug, Clone)] pub struct ResidualVectorQuantization { layers: Vec, } impl ResidualVectorQuantization { pub fn new( n_q: usize, dim: usize, codebook_size: usize, codebook_dim: Option, vb: VarBuilder, ) -> Result { let vb = vb.pp("layers"); let mut layers = Vec::with_capacity(n_q); for i in 0..n_q { let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?; layers.push(layer) } Ok(Self { layers }) } pub fn encode(&self, xs: &Tensor) -> Result { let mut codes = Vec::with_capacity(self.layers.len()); let mut residual = xs.clone(); for layer in self.layers.iter() { let indices = layer.encode(&residual)?; let quantized = layer.decode(&indices)?; residual = (residual - quantized)?; codes.push(indices) } Tensor::stack(&codes, 0) } pub fn decode(&self, xs: &Tensor) -> Result { if self.layers.is_empty() { candle::bail!("empty layers in ResidualVectorQuantization") } if self.layers.len() != xs.dim(0)? { candle::bail!( "mismatch between the number of layers {} and the code shape {:?}", self.layers.len(), xs.shape() ) } let mut quantized = self.layers[0].decode(&xs.i(0)?)?; for (i, layer) in self.layers.iter().enumerate().skip(1) { let xs = xs.i(i)?; quantized = (quantized + layer.decode(&xs))? 
} Ok(quantized) } } #[allow(unused)] #[derive(Debug, Clone)] pub struct ResidualVectorQuantizer { vq: ResidualVectorQuantization, input_proj: Option, output_proj: Option, } impl ResidualVectorQuantizer { pub fn new( dim: usize, input_dim: Option, output_dim: Option, n_q: usize, bins: usize, force_projection: bool, vb: VarBuilder, ) -> Result { let input_dim = input_dim.unwrap_or(dim); let output_dim = output_dim.unwrap_or(dim); let input_proj = if input_dim == dim && !force_projection { None } else { let c = candle_nn::conv1d_no_bias( input_dim, dim, 1, Default::default(), vb.pp("input_proj"), )?; Some(c) }; let output_proj = if output_dim == dim && !force_projection { None } else { let c = candle_nn::conv1d_no_bias( dim, output_dim, 1, Default::default(), vb.pp("output_proj"), )?; Some(c) }; let vq = ResidualVectorQuantization::new( n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb.pp("vq"), )?; Ok(Self { vq, input_proj, output_proj, }) } pub fn encode(&self, xs: &Tensor) -> Result { let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?; codes.transpose(0, 1) } pub fn decode(&self, codes: &Tensor) -> Result { // codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. let codes = codes.transpose(0, 1)?; let quantized = self.vq.decode(&codes)?; match &self.output_proj { None => Ok(quantized), Some(p) => quantized.apply(p), } } } // we do not use any codebook_offset at the moment. When reconstructing the codes, we could just // concatenate the indexes. #[derive(Debug, Clone)] pub struct SplitResidualVectorQuantizer { rvq_first: ResidualVectorQuantizer, rvq_rest: ResidualVectorQuantizer, n_q: usize, span_encode: tracing::Span, span_decode: tracing::Span, } impl SplitResidualVectorQuantizer { pub fn new( dim: usize, input_dim: Option, output_dim: Option, n_q: usize, bins: usize, vb: VarBuilder, ) -> Result { let rvq_first = ResidualVectorQuantizer::new( dim, input_dim, output_dim, 1, bins, true, vb.pp("rvq_first"), )?; let rvq_rest = ResidualVectorQuantizer::new( dim, input_dim, output_dim, n_q - 1, bins, true, vb.pp("rvq_rest"), )?; let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode"); let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode"); Ok(Self { rvq_first, rvq_rest, n_q, span_encode, span_decode, }) } pub fn encode(&self, xs: &Tensor) -> Result { let _enter = self.span_encode.enter(); let codes = self.rvq_first.encode(xs)?; if self.n_q > 1 { // We encode xs again here rather than the residual. The decomposition is not // hierarchical but rather having semantic tokens for rvq_first and the acoustic tokens // for rvq_rest. let rest_codes = self.rvq_rest.encode(xs)?; Tensor::cat(&[codes, rest_codes], 1) } else { Ok(codes) } } pub fn decode(&self, codes: &Tensor) -> Result { // codes is [B, K, T], with T frames, K nb of codebooks. let _enter = self.span_decode.enter(); let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?; let quantized = if self.n_q > 1 { (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))? } else { quantized }; Ok(quantized) } } ================================================ FILE: kyuteye_rs/moshi-core/src/seanet.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
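// Module overview: SEANet is the causal convolutional encoder/decoder used by `Mimi`
// around the quantizer. The encoder alternates stacks of `SeaNetResnetBlock` with strided
// downsampling convolutions (one per entry of `Config::ratios`); the decoder mirrors this
// with transposed convolutions for upsampling. Both implement `StreamingModule` so audio
// can be processed chunk by chunk via `step`.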
use crate::streaming::{self, StreamTensor, StreamingModule}; use candle::{Module, Result, Tensor}; use candle_nn::VarBuilder; use crate::conv::{StreamableConv1d, StreamableConvTranspose1d}; #[derive(Debug, Clone)] pub struct Config { pub dimension: usize, pub channels: usize, pub causal: bool, pub n_filters: usize, pub n_residual_layers: usize, pub ratios: Vec, pub activation: candle_nn::Activation, pub norm: crate::conv::Norm, pub kernel_size: usize, pub residual_kernel_size: usize, pub last_kernel_size: usize, pub dilation_base: usize, pub pad_mode: crate::conv::PadMode, pub true_skip: bool, pub compress: usize, pub lstm: usize, pub disable_norm_outer_blocks: usize, pub final_activation: Option, } #[derive(Debug, Clone)] pub struct SeaNetResnetBlock { block: Vec, shortcut: Option, activation: candle_nn::Activation, skip_op: streaming::StreamingBinOp, span: tracing::Span, } impl SeaNetResnetBlock { #[allow(clippy::too_many_arguments)] pub fn new( dim: usize, k_sizes_and_dilations: &[(usize, usize)], activation: candle_nn::Activation, norm: Option, causal: bool, pad_mode: crate::conv::PadMode, compress: usize, true_skip: bool, vb: VarBuilder, ) -> Result { let mut block = Vec::with_capacity(k_sizes_and_dilations.len()); let hidden = dim / compress; let vb_b = vb.pp("block"); for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() { let in_c = if i == 0 { dim } else { hidden }; let out_c = if i == k_sizes_and_dilations.len() - 1 { dim } else { hidden }; let c = StreamableConv1d::new( in_c, out_c, /* k_size */ *k_size, /* stride */ 1, /* dilation */ *dilation, /* groups */ 1, /* bias */ true, /* causal */ causal, /* norm */ norm, /* pad_mode */ pad_mode, vb_b.pp(2 * i + 1), )?; block.push(c) } let shortcut = if true_skip { None } else { let c = StreamableConv1d::new( dim, dim, /* k_size */ 1, /* stride */ 1, /* dilation */ 1, /* groups */ 1, /* bias */ true, /* causal */ causal, /* norm */ norm, /* pad_mode */ pad_mode, vb.pp("shortcut"), )?; Some(c) }; Ok(Self { block, shortcut, activation, skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1), span: tracing::span!(tracing::Level::TRACE, "sea-resnet"), }) } } impl Module for SeaNetResnetBlock { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let mut ys = xs.clone(); for block in self.block.iter() { ys = ys.apply(&self.activation)?.apply(block)?; } match self.shortcut.as_ref() { None => ys + xs, Some(shortcut) => ys + xs.apply(shortcut), } } } impl StreamingModule for SeaNetResnetBlock { fn reset_state(&mut self) { self.skip_op.reset_state(); for block in self.block.iter_mut() { block.reset_state() } if let Some(shortcut) = self.shortcut.as_mut() { shortcut.reset_state() } } fn step(&mut self, xs: &StreamTensor) -> Result { let _enter = self.span.enter(); let mut ys = xs.clone(); for block in self.block.iter_mut() { ys = block.step(&ys.apply(&self.activation)?)?; } match self.shortcut.as_mut() { None => self.skip_op.step(&ys, xs), Some(shortcut) => self.skip_op.step(&ys, &shortcut.step(xs)?), } } } #[derive(Debug, Clone)] struct EncoderLayer { residuals: Vec, downsample: StreamableConv1d, } #[derive(Debug, Clone)] pub struct SeaNetEncoder { init_conv1d: StreamableConv1d, activation: candle_nn::Activation, layers: Vec, final_conv1d: StreamableConv1d, span: tracing::Span, } impl SeaNetEncoder { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { if cfg.lstm > 0 { candle::bail!("seanet lstm is not supported") } let n_blocks = 2 + cfg.ratios.len(); let mut mult = 1usize; let 
init_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) }; let mut layer_idx = 0; let vb = vb.pp("model"); let init_conv1d = StreamableConv1d::new( cfg.channels, mult * cfg.n_filters, cfg.kernel_size, /* stride */ 1, /* dilation */ 1, /* groups */ 1, /* bias */ true, /* causal */ cfg.causal, /* norm */ init_norm, /* pad_mode */ cfg.pad_mode, vb.pp(layer_idx), )?; layer_idx += 1; let mut layers = Vec::with_capacity(cfg.ratios.len()); for (i, &ratio) in cfg.ratios.iter().rev().enumerate() { let norm = if cfg.disable_norm_outer_blocks >= i + 2 { None } else { Some(cfg.norm) }; let mut residuals = Vec::with_capacity(cfg.n_residual_layers); for j in 0..cfg.n_residual_layers { let resnet_block = SeaNetResnetBlock::new( mult * cfg.n_filters, &[ (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1), ], cfg.activation, norm, cfg.causal, cfg.pad_mode, cfg.compress, cfg.true_skip, vb.pp(layer_idx), )?; residuals.push(resnet_block); layer_idx += 1; } let downsample = StreamableConv1d::new( mult * cfg.n_filters, mult * cfg.n_filters * 2, /* k_size */ ratio * 2, /* stride */ ratio, /* dilation */ 1, /* groups */ 1, /* bias */ true, /* causal */ true, /* norm */ norm, /* pad_mode */ cfg.pad_mode, vb.pp(layer_idx + 1), )?; layer_idx += 2; let layer = EncoderLayer { downsample, residuals, }; layers.push(layer); mult *= 2 } let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks { None } else { Some(cfg.norm) }; let final_conv1d = StreamableConv1d::new( mult * cfg.n_filters, cfg.dimension, cfg.last_kernel_size, /* stride */ 1, /* dilation */ 1, /* groups */ 1, /* bias */ true, /* causal */ cfg.causal, /* norm */ final_norm, /* pad_mode */ cfg.pad_mode, vb.pp(layer_idx + 1), )?; Ok(Self { init_conv1d, activation: cfg.activation, layers, final_conv1d, span: tracing::span!(tracing::Level::TRACE, "sea-encoder"), }) } } impl Module for SeaNetEncoder { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let mut xs = xs.apply(&self.init_conv1d)?; for layer in self.layers.iter() { for residual in layer.residuals.iter() { xs = xs.apply(residual)? } xs = xs.apply(&self.activation)?.apply(&layer.downsample)?; } xs.apply(&self.activation)?.apply(&self.final_conv1d) } } impl StreamingModule for SeaNetEncoder { fn reset_state(&mut self) { self.init_conv1d.reset_state(); self.layers.iter_mut().for_each(|v| { v.residuals.iter_mut().for_each(|v| v.reset_state()); v.downsample.reset_state() }); self.final_conv1d.reset_state(); } fn step(&mut self, xs: &StreamTensor) -> Result { let _enter = self.span.enter(); let mut xs = self.init_conv1d.step(xs)?; for layer in self.layers.iter_mut() { for residual in layer.residuals.iter_mut() { xs = residual.step(&xs)?; } xs = layer.downsample.step(&xs.apply(&self.activation)?)?; } self.final_conv1d.step(&xs.apply(&self.activation)?) 
} } #[derive(Debug, Clone)] struct DecoderLayer { upsample: StreamableConvTranspose1d, residuals: Vec, } #[derive(Debug, Clone)] pub struct SeaNetDecoder { init_conv1d: StreamableConv1d, activation: candle_nn::Activation, layers: Vec, final_conv1d: StreamableConv1d, final_activation: Option, span: tracing::Span, } impl SeaNetDecoder { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { if cfg.lstm > 0 { candle::bail!("seanet lstm is not supported") } let n_blocks = 2 + cfg.ratios.len(); let mut mult = 1 << cfg.ratios.len(); let init_norm = if cfg.disable_norm_outer_blocks == n_blocks { None } else { Some(cfg.norm) }; let mut layer_idx = 0; let vb = vb.pp("model"); let init_conv1d = StreamableConv1d::new( cfg.dimension, mult * cfg.n_filters, cfg.kernel_size, /* stride */ 1, /* dilation */ 1, /* groups */ 1, /* bias */ true, /* causal */ cfg.causal, /* norm */ init_norm, /* pad_mode */ cfg.pad_mode, vb.pp(layer_idx), )?; layer_idx += 1; let mut layers = Vec::with_capacity(cfg.ratios.len()); for (i, &ratio) in cfg.ratios.iter().enumerate() { let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks { None } else { Some(cfg.norm) }; let upsample = StreamableConvTranspose1d::new( mult * cfg.n_filters, mult * cfg.n_filters / 2, /* k_size */ ratio * 2, /* stride */ ratio, /* groups */ 1, /* bias */ true, /* causal */ true, /* norm */ norm, vb.pp(layer_idx + 1), )?; layer_idx += 2; let mut residuals = Vec::with_capacity(cfg.n_residual_layers); for j in 0..cfg.n_residual_layers { let resnet_block = SeaNetResnetBlock::new( mult * cfg.n_filters / 2, &[ (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1), ], cfg.activation, norm, cfg.causal, cfg.pad_mode, cfg.compress, cfg.true_skip, vb.pp(layer_idx), )?; residuals.push(resnet_block); layer_idx += 1; } let layer = DecoderLayer { upsample, residuals, }; layers.push(layer); mult /= 2 } let final_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) }; let final_conv1d = StreamableConv1d::new( cfg.n_filters, cfg.channels, cfg.last_kernel_size, /* stride */ 1, /* dilation */ 1, /* groups */ 1, /* bias */ true, /* causal */ cfg.causal, /* norm */ final_norm, /* pad_mode */ cfg.pad_mode, vb.pp(layer_idx + 1), )?; Ok(Self { init_conv1d, activation: cfg.activation, layers, final_conv1d, final_activation: cfg.final_activation, span: tracing::span!(tracing::Level::TRACE, "sea-decoder"), }) } } impl Module for SeaNetDecoder { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let mut xs = xs.apply(&self.init_conv1d)?; for layer in self.layers.iter() { xs = xs.apply(&self.activation)?.apply(&layer.upsample)?; for residual in layer.residuals.iter() { xs = xs.apply(residual)? 
} } let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?; let xs = match self.final_activation.as_ref() { None => xs, Some(act) => xs.apply(act)?, }; Ok(xs) } } impl StreamingModule for SeaNetDecoder { fn reset_state(&mut self) { self.init_conv1d.reset_state(); self.layers.iter_mut().for_each(|v| { v.residuals.iter_mut().for_each(|v| v.reset_state()); v.upsample.reset_state() }); self.final_conv1d.reset_state(); } fn step(&mut self, xs: &StreamTensor) -> Result { let _enter = self.span.enter(); let mut xs = self.init_conv1d.step(xs)?; for layer in self.layers.iter_mut() { xs = layer.upsample.step(&xs.apply(&self.activation)?)?; for residual in layer.residuals.iter_mut() { xs = residual.step(&xs)?; } } let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?; let xs = match self.final_activation.as_ref() { None => xs, Some(act) => xs.apply(act)?, }; Ok(xs) } } ================================================ FILE: kyuteye_rs/moshi-core/src/streaming.rs ================================================ // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. use candle::{Result, Tensor}; pub trait Dim: candle::shape::Dim + Copy {} impl Dim for T {} #[derive(Clone)] pub struct StreamTensor(Option); impl std::fmt::Debug for StreamTensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.0 { Some(t) => write!(f, "{:?}", t.shape()), None => write!(f, "Empty"), } } } impl std::convert::From> for StreamTensor { fn from(value: Option) -> Self { Self(value) } } impl std::convert::From for StreamTensor { fn from(value: Tensor) -> Self { Self(Some(value)) } } impl std::convert::From<()> for StreamTensor { fn from(_value: ()) -> Self { Self(None) } } impl StreamTensor { pub fn empty() -> Self { Self(None) } pub fn from_tensor(tensor: Tensor) -> Self { Self(Some(tensor)) } pub fn shape(&self) -> Option<&candle::Shape> { self.0.as_ref().map(|t| t.shape()) } pub fn cat2(&self, rhs: &Self, dim: D) -> Result { let xs = match (&self.0, &rhs.0) { (Some(lhs), Some(rhs)) => { let xs = Tensor::cat(&[lhs, rhs], dim)?; Some(xs) } (Some(xs), None) | (None, Some(xs)) => Some(xs.clone()), (None, None) => None, }; Ok(Self(xs)) } pub fn seq_len(&self, dim: D) -> Result { match &self.0 { None => Ok(0), Some(v) => v.dim(dim), } } pub fn reset(&mut self) { self.0 = None } pub fn narrow(&self, dim: D, offset: usize, len: usize) -> Result { let t = match &self.0 { None => None, Some(t) => { let seq_len = t.dim(dim)?; if seq_len <= offset { None } else { let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?; Some(t) } } }; Ok(Self(t)) } /// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements /// returned in the first output and the remaining in the second output. pub fn split(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> { match &self.0 { None => Ok((Self::empty(), Self::empty())), Some(t) => { let seq_len = t.dim(dim)?; let lhs_len = usize::min(seq_len, lhs_len); if lhs_len == 0 { Ok((Self::empty(), t.clone().into())) } else { let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?); let rhs_len = seq_len - lhs_len; let rhs = if rhs_len == 0 { Self::empty() } else { Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?) 
}; Ok((lhs, rhs)) } } } } pub fn as_option(&self) -> Option<&Tensor> { self.0.as_ref() } pub fn apply(&self, m: &M) -> Result { match &self.0 { None => Ok(Self::empty()), Some(t) => Ok(Self::from_tensor(t.apply(m)?)), } } } pub trait StreamingModule { // TODO: Should we also have a flush method? fn step(&mut self, xs: &StreamTensor) -> Result; fn reset_state(&mut self); } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum BinOp { Add, Mul, Sub, Div, } #[derive(Debug, Clone)] pub struct StreamingBinOp { prev_lhs: StreamTensor, prev_rhs: StreamTensor, pub op: BinOp, pub dim: candle::D, } impl StreamingBinOp { pub fn new(op: BinOp, dim: candle::D) -> Self { Self { prev_lhs: StreamTensor::empty(), prev_rhs: StreamTensor::empty(), op, dim, } } pub fn reset_state(&mut self) { self.prev_lhs.reset(); self.prev_rhs.reset(); } pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result { match self.op { BinOp::Add => Tensor::add(lhs, rhs), BinOp::Mul => Tensor::mul(lhs, rhs), BinOp::Sub => Tensor::sub(lhs, rhs), BinOp::Div => Tensor::div(lhs, rhs), } } pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result { let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?; let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?; let lhs_len = lhs.seq_len(self.dim)?; let rhs_len = rhs.seq_len(self.dim)?; let common_len = usize::min(lhs_len, rhs_len); let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?; let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?; let ys = match (lhs.0, rhs.0) { (Some(lhs), Some(rhs)) => { let ys = self.forward(&lhs, &rhs)?; StreamTensor::from_tensor(ys) } (None, None) => StreamTensor::empty(), (lhs, rhs) => candle::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"), }; self.prev_lhs = prev_lhs; self.prev_rhs = prev_rhs; Ok(ys) } } /// Simple wrapper that doesn't do any buffering. pub struct Map(T); impl StreamingModule for Map { fn reset_state(&mut self) {} fn step(&mut self, xs: &StreamTensor) -> Result { xs.apply(&self.0) } } ================================================ FILE: kyuteye_rs/moshi-core/src/transformer.rs ================================================ // Implements various modules for transformers with support for both quantized and unquantized forwards // Main differences between quantized and unquantized execution: // 1. For quantized models' attention `matmul_dtype`` converts intermediate activations to BF16 for // more efficient matmuls // 2. Quantized tensors cannot be easily split (regarding cross attention and QKV proj weights) // 3. 
Linear and Quantized linear layers are two different types use crate::nn::{ linear, linear_from, matmul_dtype, MaybeQuantizedLinear, MaybeQuantizedVarBuilder, }; use crate::streaming::{StreamTensor, StreamingModule}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle::Context; use candle_nn; use std::sync::Arc; #[derive(Debug, Clone, serde::Deserialize)] pub struct Config { pub d_model: usize, pub num_heads: usize, pub num_layers: usize, pub causal: bool, pub norm_first: bool, pub bias_ff: bool, pub bias_attn: bool, pub layer_scale: Option, pub positional_embedding: PositionalEmbedding, pub use_conv_block: bool, pub cross_attention: Option<(CrossAttentionGating, crate::NormType, Option)>, pub conv_kernel_size: usize, pub use_conv_bias: bool, pub gating: Option, pub norm: crate::NormType, pub context: usize, pub max_period: usize, pub max_seq_len: usize, pub kv_repeat: usize, pub dim_feedforward: usize, pub conv_layout: bool, } #[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub enum PositionalEmbedding { Rope, Sin, None, } #[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum CrossAttentionGating { // Configure Type of gating used at the output of vision cross-attention layers Normal, ConstantGatedTanh, ConstantGatedSigmoid, ConditionalGatedTanh, ConditionalGatedSigmoid, ConditionalGatedSigmoidLearnableBias, ConditionalGatedTanhLearnableBias, } #[derive(Debug, Clone)] pub enum CaSrc { // Input to cross-attention to handle cases where the // cross-attention source can be shared across timesteps and/or layers // either a single tensor (has yet to be projected) // or pre-computed K,V projections; Tokens(Tensor), KeysValues((Tensor, Tensor)), } #[derive(Debug, Clone)] pub struct LayerScale { scale: Tensor, } impl LayerScale { pub fn new(d_model: usize, _init: f64, vb: MaybeQuantizedVarBuilder) -> Result { let scale = vb.get_unquantized(d_model, "scale")?; Ok(Self { scale }) } } impl Module for LayerScale { fn forward(&self, xs: &Tensor) -> Result { xs.broadcast_mul(&self.scale) } } #[derive(Debug, Clone)] pub enum XaGate { // Optional gating at the output of a cross-attention layer // Normal: No gating | Identity Normal, // ConstantGated: Multiply by a scalar ConstantGated { alpha: Tensor, }, // ConditionalGated: Pass the input x through a small MLP; // The output yields a vector of scales (one for each channel) // that x is then multiplied by ConditionalGated { in_proj: MaybeQuantizedLinear, out_proj: MaybeQuantizedLinear, activation: candle_nn::init::NonLinearity, learnable_bias: bool, }, } impl XaGate { pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result { let gating_cfg = cfg .cross_attention .map(|v| v.0) .context("no cross-attention specified")?; match gating_cfg { // no gating CrossAttentionGating::Normal => Ok(Self::Normal), // constant (per-layer parameter) with tanh activation CrossAttentionGating::ConstantGatedTanh => { let alpha = vb.get_unquantized((1, 1, 1), "alpha")?.tanh()?; Ok(Self::ConstantGated { alpha }) } // constant (per-layer parameter) with sigmoid activation CrossAttentionGating::ConstantGatedSigmoid => { let alpha = candle_nn::ops::sigmoid(&(vb.get_unquantized((1, 1, 1), "alpha")? 
- 4.0)?)?; Ok(Self::ConstantGated { alpha }) } // input conditional (small MLP) with tanh or sigmoid act CrossAttentionGating::ConditionalGatedTanh | CrossAttentionGating::ConditionalGatedSigmoid | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias | CrossAttentionGating::ConditionalGatedTanhLearnableBias => { let dim = cfg.d_model; let hidden_dims = (0.125 * dim as f32).floor() as usize; let learnable_bias = matches!( gating_cfg, CrossAttentionGating::ConditionalGatedSigmoidLearnableBias | CrossAttentionGating::ConditionalGatedTanhLearnableBias ); let in_proj = linear(dim, hidden_dims, false, vb.pp("alpha.0"))?; let out_proj = linear(hidden_dims, dim, learnable_bias, vb.pp("alpha.2"))?; let activation = match gating_cfg { CrossAttentionGating::ConditionalGatedTanh | CrossAttentionGating::ConditionalGatedTanhLearnableBias => { candle_nn::init::NonLinearity::Tanh } CrossAttentionGating::ConditionalGatedSigmoid | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias => { candle_nn::init::NonLinearity::Sigmoid } _ => candle::bail!("Invalid cross-attention config specified."), }; Ok(Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias, }) } } } } impl XaGate { pub fn forward_with_gate_weight(&self, xs: &Tensor) -> Result<(Tensor, Option)> { match self { Self::Normal => Ok((xs.clone(), None)), Self::ConstantGated { alpha } => Ok((xs.broadcast_mul(alpha)?, None)), Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias, } => { let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?; let alpha = match (activation, learnable_bias) { (candle_nn::init::NonLinearity::Tanh, _) => alpha.tanh()?, (candle_nn::init::NonLinearity::Sigmoid, true) => { candle_nn::ops::sigmoid(&alpha)? } (candle_nn::init::NonLinearity::Sigmoid, false) => { candle_nn::ops::sigmoid(&(alpha - 4.0)?)? } _ => candle::bail!("Invalid non-linearity specified in cross-attention gating"), }; let out_alpha = alpha.mean_all()?; Ok(((xs * alpha)?, Some(out_alpha))) } } } } impl Module for XaGate { fn forward(&self, xs: &Tensor) -> Result { let (xs, _) = self.forward_with_gate_weight(xs)?; Ok(xs) } } #[derive(Debug, Clone)] pub struct StreamingMultiheadCrossAttention { //Cross-attention modules. 
Q and KV projections are separate // because x (speech tokens) and ca_src (cross-attention source) can have // different dimensions in_proj_q: MaybeQuantizedLinear, in_proj_kv: MaybeQuantizedLinear, out_proj: MaybeQuantizedLinear, kv_repeat: usize, num_heads: usize, neg_inf: Tensor, gate: XaGate, span: tracing::Span, } impl StreamingMultiheadCrossAttention { pub fn new( cfg: &Config, vb: MaybeQuantizedVarBuilder, gate_vb: Option, ) -> Result { let embed_dim = cfg.d_model; let num_kv = cfg.num_heads / cfg.kv_repeat; let out_kv_dim = num_kv * (embed_dim / cfg.num_heads); let out_dim = embed_dim + 2 * out_kv_dim; let device = vb.device(); // Case 1 (legacy): A single in_proj; i.e., both x and ca_src *must* have // the same number of dims this is only possible for non-quantized tensors though // as we will need to split Q/KV weights down the line even when they have the same // shape since they take different inputs let (in_proj_q, in_proj_kv) = if vb.contains_key("in_proj_weight") { match &vb { MaybeQuantizedVarBuilder::Quantized(_) => candle::bail!("Quantized cross-attention layers require a separate in_proj_weight_q and in_proj_weight_kv"), MaybeQuantizedVarBuilder::Real(weights) => { let in_proj_weight = weights.get((out_dim, embed_dim), "in_proj_weight")?; let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?; let in_proj_weight_kv = in_proj_weight.narrow(0, embed_dim, 2 * out_kv_dim)?; let (in_proj_bias_q, in_proj_bias_kv) = if cfg.bias_attn { let b = weights.get(out_dim, "in_proj_bias")?; let in_proj_bias_q = b.narrow(0, 0, embed_dim)?; let in_proj_bias_kv = b.narrow(0, embed_dim, 2 * out_kv_dim)?; (Some(in_proj_bias_q), Some(in_proj_bias_kv)) } else { (None, None) }; (MaybeQuantizedLinear::Real(candle_nn::Linear::new(in_proj_weight_q, in_proj_bias_q)), MaybeQuantizedLinear::Real(candle_nn::Linear::new(in_proj_weight_kv, in_proj_bias_kv))) } } } else { // Case 2: Separate projections for query (x) and kv (ca_src) let kv_in_dim = match cfg.cross_attention.map(|v| v.2) { None => candle::bail!("cfg.cross_attention is None in cross_attention module"), Some(d) => match d { None | Some(0) => embed_dim, Some(dd) => dd, }, }; let in_proj_weight_q = vb.get((embed_dim, embed_dim), "in_proj_weight_q")?; let in_proj_weight_kv = vb.get((2 * out_kv_dim, kv_in_dim), "in_proj_weight_kv")?; // Biases are always unquantized let (in_proj_bias_q, in_proj_bias_kv) = if cfg.bias_attn { ( Some(vb.get_unquantized(embed_dim, "in_proj_bias_q")?), Some(vb.get_unquantized(2 * out_kv_dim, "in_proj_bias_kv")?), ) } else { (None, None) }; // Finally, we can build the actual linear layers let in_proj_q = linear_from(in_proj_weight_q, in_proj_bias_q)?; let in_proj_kv = linear_from(in_proj_weight_kv, in_proj_bias_kv)?; (in_proj_q, in_proj_kv) }; let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?; let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; let neg_inf = match &vb { MaybeQuantizedVarBuilder::Real(weights) => neg_inf.to_dtype(weights.dtype())?, _ => neg_inf, }; let gate = match gate_vb { None => XaGate::new(cfg, vb.pp("gate"))?, Some(layer_gate_vb) => XaGate::new(cfg, layer_gate_vb)?, }; Ok(Self { in_proj_q, in_proj_kv, out_proj, kv_repeat: cfg.kv_repeat, num_heads: cfg.num_heads, neg_inf, gate, span: tracing::span!(tracing::Level::TRACE, "mhca"), }) } pub fn is_quantized(&self) -> bool { match self.in_proj_q { MaybeQuantizedLinear::Quantized(_) => true, MaybeQuantizedLinear::Real(_) => false, } } pub fn compute_kv(&self, ca_src: &CaSrc) -> Result<(Tensor, Tensor)> { // this 
is used twice: // in the standard forward pass of the cross-attention // for vision models, after loading an image we can precompute its KV projections // as the image is constant across multiple timesteps match ca_src { CaSrc::KeysValues(cakv) => Ok(cakv.clone()), CaSrc::Tokens(xs) => { let kv = xs.apply(&self.in_proj_kv)?; let (ca_b, ca_t, ca_dim) = kv.dims3()?; let head_dim = ca_dim / (2 * self.num_heads); let kv = kv.reshape((ca_b, ca_t, 2, (), head_dim))?; // convert to correct float point type for quantized models let kv = if self.is_quantized() { kv.to_dtype(matmul_dtype(xs.device()))? } else { kv }; let k = kv.i((.., .., 0))?; let v = kv.i((.., .., 1))?; let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d Ok((k, v)) } } } pub fn forward_with_gate_weight( &self, xs: &Tensor, ca_src: &CaSrc, mask: Option<&Tensor>, ) -> Result<(Tensor, Option)> { let _enter = self.span.enter(); if self.kv_repeat != 1 { candle::bail!("only kv-repeat = 1 is supported") } let (b, t, hd) = xs.dims3()?; let head_dim = hd / self.num_heads; // time_dim = 1, layout: b,t,h,d let q = xs.apply(&self.in_proj_q)?; let original_dtype = q.dtype(); let q = q.reshape((b, t, self.num_heads, head_dim))?; let q = if self.is_quantized() { q.to_dtype(matmul_dtype(xs.device()))? } else { q }; let (k, v) = self.compute_kv(ca_src)?; // qk_layer_norm = None // kv_repeat = 1, otherwise we would need repeat_kv let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?; let pre_ws = match mask { None => pre_ws, Some(mask) => { let mask = mask.broadcast_left((b, self.num_heads))?; let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?; mask.where_cond(&neg_inf, &pre_ws)? } }; let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k let xs = ws.matmul(&v)?; // b,h,t,d let xs = xs .transpose(1, 2)? // b,t,h,d .reshape((b, t, hd))? .to_dtype(original_dtype)? .apply(&self.out_proj)?; self.gate.forward_with_gate_weight(&xs) } pub fn forward(&self, xs: &Tensor, ca_src: &CaSrc, mask: Option<&Tensor>) -> Result { let (xs, _) = self.forward_with_gate_weight(xs, ca_src, mask)?; Ok(xs) } } #[derive(Debug, Clone)] pub struct RotaryEmbedding { sin: Tensor, cos: Tensor, span: tracing::Span, } impl RotaryEmbedding { pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result { let inv_freq: Vec<_> = (0..dim) .step_by(2) .map(|i| 1f32 / theta.powf(i as f32 / dim as f32)) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(DType::F32)? 
.reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, span: tracing::span!(tracing::Level::TRACE, "rot"), }) } pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result { let _enter = self.span.enter(); let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?; let qk_dtype = qk.dtype(); let c = self.cos.narrow(0, seqlen_offset, seqlen)?; let s = self.sin.narrow(0, seqlen_offset, seqlen)?; candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype) } } pub(crate) fn get_causal_mask( size1: usize, size2: usize, context: usize, device: &Device, ) -> Result { let mask: Vec<_> = (0..size1) .flat_map(|i| { (0..size2) .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)) }) .collect(); Tensor::from_slice(&mask, (size1, size2), device) } #[cfg(feature = "flash-attn")] fn flash_attn( q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32, causal: bool, ) -> Result { candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) } #[cfg(not(feature = "flash-attn"))] fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { unimplemented!("compile with '--features flash-attn'") } #[derive(Debug, Clone)] pub struct StreamingMultiheadAttention { // Self-attention with KV Cache in_proj: MaybeQuantizedLinear, out_proj: MaybeQuantizedLinear, kv_repeat: usize, num_heads: usize, context: usize, neg_inf: Tensor, rope: Option>, kv_cache: candle_nn::kv_cache::KvCache, use_kv_cache: bool, use_flash_attn: bool, pos: usize, span: tracing::Span, } impl StreamingMultiheadAttention { pub fn new( rope: &Option>, cfg: &Config, vb: MaybeQuantizedVarBuilder, ) -> Result { let embed_dim = cfg.d_model; let num_kv = cfg.num_heads / cfg.kv_repeat; let out_dim = embed_dim + 2 * num_kv * (embed_dim / cfg.num_heads); let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?; let in_proj_bias = if cfg.bias_attn { Some(vb.get_unquantized(out_dim, "in_proj_bias")?) } else { None }; let in_proj = linear_from(in_proj_weight, in_proj_bias)?; let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?; let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?; let neg_inf = match vb { MaybeQuantizedVarBuilder::Real(weights) => neg_inf.to_dtype(weights.dtype())?, _ => neg_inf, }; Ok(Self { in_proj, out_proj, rope: rope.clone(), kv_repeat: cfg.kv_repeat, num_heads: cfg.num_heads, context: cfg.context, neg_inf, kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len), use_kv_cache: true, use_flash_attn: false, pos: 0, span: tracing::span!(tracing::Level::TRACE, "mha"), }) } pub fn is_quantized(&self) -> bool { match self.in_proj { MaybeQuantizedLinear::Quantized(_) => true, MaybeQuantizedLinear::Real(_) => false, } } pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { let _enter = self.span.enter(); if self.kv_repeat != 1 { candle::bail!("only kv-repeat = 1 is supported") } let (b, t, hd) = xs.dims3()?; let head_dim = hd / self.num_heads; // time_dim = 1, layout: b,t,h,d let qkv = xs .apply(&self.in_proj)? .reshape((b, t, 3, self.num_heads, head_dim))?; let original_dtype = qkv.dtype(); let qkv = if self.is_quantized() { qkv.to_dtype(matmul_dtype(xs.device()))? 
} else { qkv }; let q = qkv.i((.., .., 0))?; let k = qkv.i((.., .., 1))?; let v = qkv.i((.., .., 2))?; // qk_layer_norm = None // kv_repeat = 1, otherwise we would need repeat_kv let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d if let Some(rope) = &self.rope { q = rope.apply_rotary_emb(&q, self.pos)?; k = rope.apply_rotary_emb(&k, self.pos)?; } let (k, v) = if self.use_kv_cache { self.pos += k.dim(2)?; self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)? } else { (k, v) }; // The KV cache keeps all the data at the moment, we want to trim // down the part that comes from the cache to at most context to // be coherent with the mask shape we provide. let k_len = k.dim(2)?; let k_target_len = t + usize::min(self.context, k_len - t); let (k, v) = if k_target_len < k_len { let k = k.narrow(2, k_len - k_target_len, k_target_len)?; let v = v.narrow(2, k_len - k_target_len, k_target_len)?; (k, v) } else { (k.clone(), v.clone()) }; let xs = if q.dtype() == DType::BF16 && self.use_flash_attn { let q = q.transpose(1, 2)?; let k = k.transpose(1, 2)?; let v = v.transpose(1, 2)?; let softmax_scale = 1f32 / (head_dim as f32).sqrt(); flash_attn(&q, &k, &v, softmax_scale, mask.is_some())?.transpose(1, 2)? } else { let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?; let pre_ws = match mask { None => pre_ws, Some(mask) => { let mask = mask.broadcast_left((b, self.num_heads))?; let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?; mask.where_cond(&neg_inf, &pre_ws)? } }; let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k ws.matmul(&v)? // b,h,t,d }; let xs = xs .transpose(1, 2)? // b,t,h,d .reshape((b, t, hd))? .to_dtype(original_dtype)? .apply(&self.out_proj)?; Ok(xs) } pub fn reset_kv_cache(&mut self) { self.kv_cache.reset() } pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) { self.kv_cache = kv_cache } } #[derive(Debug, Clone)] pub enum Mlp { //Feed Forward layers NoGating { linear1: MaybeQuantizedLinear, linear2: MaybeQuantizedLinear, }, Gating { linear_in: MaybeQuantizedLinear, linear_out: MaybeQuantizedLinear, activation: candle_nn::Activation, }, } impl Mlp { pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result { let d_model = cfg.d_model; match cfg.gating { None => { let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("linear1"))?; let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("linear2"))?; Ok(Self::NoGating { linear1, linear2 }) } Some(activation) => { let vb = vb.pp("gating"); let hidden = if cfg.dim_feedforward == 4 * d_model { 11 * d_model / 4 } else { 2 * cfg.dim_feedforward / 3 }; let linear_in = linear(d_model, 2 * hidden, cfg.bias_ff, vb.pp("linear_in"))?; let linear_out = linear(hidden, d_model, cfg.bias_ff, vb.pp("linear_out"))?; Ok(Self::Gating { linear_in, linear_out, activation, }) } } } } impl Module for Mlp { fn forward(&self, xs: &Tensor) -> Result { match self { Self::NoGating { linear1, linear2 } => xs.apply(linear1)?.gelu_erf()?.apply(linear2), Self::Gating { linear_in, linear_out, activation, } => { let xs = xs.apply(linear_in)?; let (b, t, _) = xs.dims3()?; let xs = xs.reshape((b, t, 2, ()))?; let xs = (xs.i((.., .., 0))?.apply(activation)? 
* xs.i((.., .., 1))?)?; xs.apply(linear_out) } } } } #[derive(Debug, Clone)] pub struct RmsNorm { pub(crate) alpha: Tensor, pub(crate) eps: f32, } impl RmsNorm { pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result { let alpha = vb .get_unquantized((1, 1, d_model), "alpha")? .reshape(d_model)?; Ok(Self { alpha, eps }) } } impl Module for RmsNorm { fn forward(&self, xs: &Tensor) -> Result { candle_nn::ops::rms_norm(xs, &self.alpha, self.eps) } } #[derive(Debug, Clone)] pub struct LayerNorm { inner: candle_nn::LayerNorm, } impl LayerNorm { pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result { let bias = vb.get_unquantized(d_model, "bias")?; let alpha = if vb.contains_key("alpha") { vb.get_unquantized((1, 1, d_model), "alpha")? .reshape(d_model)? } else { vb.get_unquantized(d_model, "weight")?.reshape(d_model)? }; let inner = candle_nn::LayerNorm::new(alpha, bias, eps as f64); Ok(Self { inner }) } } impl Module for LayerNorm { fn forward(&self, xs: &Tensor) -> Result { self.inner.forward(xs) } } #[derive(Debug, Clone)] pub enum Norm { LayerNorm(LayerNorm), RmsNorm(RmsNorm), } impl Norm { pub fn new(d_model: usize, cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result { let norm = Self::new_shortcut(d_model, cfg.norm, vb)?; Ok(norm) } pub fn new_shortcut( d_model: usize, typ: crate::NormType, vb: MaybeQuantizedVarBuilder, ) -> Result { let norm = match typ { crate::NormType::LayerNorm => { let norm = LayerNorm::new(d_model, 1e-5, vb)?; Self::LayerNorm(norm) } crate::NormType::RmsNorm => { let norm = RmsNorm::new(d_model, 1e-8, vb)?; Self::RmsNorm(norm) } }; Ok(norm) } } impl Module for Norm { fn forward(&self, xs: &Tensor) -> Result { match self { Self::LayerNorm(m) => m.forward(xs), Self::RmsNorm(m) => m.forward(xs), } } } #[derive(Debug, Clone)] pub struct StreamingTransformerLayer { self_attn: StreamingMultiheadAttention, mlp: Mlp, norm1: Norm, norm2: Norm, layer_scale_1: Option, layer_scale_2: Option, cross_attn: Option<(Norm, StreamingMultiheadCrossAttention)>, norm_first: bool, span: tracing::Span, } impl StreamingTransformerLayer { pub fn new( rope: &Option>, cfg: &Config, vb: MaybeQuantizedVarBuilder, shared_ca_vb: Option, ) -> Result { if cfg.use_conv_block { candle::bail!("conv-block is not supported") } let d_model = cfg.d_model; let mlp = Mlp::new(cfg, vb.clone())?; let norm1 = Norm::new(d_model, cfg, vb.pp("norm1"))?; let norm2 = Norm::new(d_model, cfg, vb.pp("norm2"))?; let layer_scale_1 = match cfg.layer_scale { None => None, Some(ls) => { let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_1"))?; Some(ls) } }; let layer_scale_2 = match cfg.layer_scale { None => None, Some(ls) => { let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_2"))?; Some(ls) } }; let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?; let cross_attn = match cfg.cross_attention.map(|v| v.1) { Some(norm_type) => { let norm_cross = Norm::new_shortcut(d_model, norm_type, vb.pp("norm_cross"))?; let cross_attn = match shared_ca_vb { None => { StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"), None)? 
} Some(shared_vb) => StreamingMultiheadCrossAttention::new( cfg, shared_vb.pp("cross_attention"), Some(vb.pp("cross_attention.gate")), )?, }; Some((norm_cross, cross_attn)) } None => None, }; Ok(Self { self_attn, mlp, norm1, norm2, layer_scale_1, layer_scale_2, cross_attn, norm_first: cfg.norm_first, span: tracing::span!(tracing::Level::TRACE, "transformer-layer"), }) } pub fn forward_with_gate_weight( &mut self, xs: &Tensor, ca_src: Option<&CaSrc>, mask: Option<&Tensor>, ) -> Result<(Tensor, Option)> { let _enter = self.span.enter(); if !self.norm_first { candle::bail!("only norm_first = true is supported") } let norm1 = xs.apply(&self.norm1)?; let xs = (xs + self .self_attn .forward(&norm1, mask)? .apply(&self.layer_scale_1.as_ref())?)?; let (xs, alpha) = match (self.cross_attn.as_mut(), ca_src) { (Some((norm_cross, cross_attn)), Some(ca_src)) => { let residual = &xs; let xs = xs.apply(norm_cross)?; let (xs, alpha) = cross_attn.forward_with_gate_weight(&xs, ca_src, None)?; ((residual + xs)?, alpha) } _ => (xs, None), }; let xs = (&xs + xs.apply(&self.norm2)? .apply(&self.mlp)? .apply(&self.layer_scale_2.as_ref()))?; Ok((xs, alpha)) } pub fn forward( &mut self, xs: &Tensor, ca_src: Option<&CaSrc>, mask: Option<&Tensor>, ) -> Result { let (xs, _) = self.forward_with_gate_weight(xs, ca_src, mask)?; Ok(xs) } pub fn reset_kv_cache(&mut self) { self.self_attn.reset_kv_cache(); } pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) { self.self_attn.set_kv_cache(kv_cache); } } #[derive(Debug, Clone)] pub struct StreamingTransformer { // Main transformer layers: Vec, context: usize, positional_embedding: PositionalEmbedding, max_period: usize, causal: bool, } impl StreamingTransformer { pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result { let vb_l = vb.pp("layers"); let rope = match cfg.positional_embedding { PositionalEmbedding::Rope => { let rope = RotaryEmbedding::new( cfg.d_model / cfg.num_heads, cfg.max_seq_len, cfg.max_period as f32, vb.device(), )?; Some(Arc::new(rope)) } PositionalEmbedding::None | PositionalEmbedding::Sin => None, }; let mut layers = Vec::with_capacity(cfg.num_layers); for layer_idx in 0..cfg.num_layers { // Also send weights of first layer as only it contains the KQV proj weights // for shared cross-attention layers let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx), Some(vb_l.pp(0)))?; layers.push(layer) } Ok(Self { layers, context: cfg.context, positional_embedding: cfg.positional_embedding, max_period: cfg.max_period, causal: cfg.causal, }) } pub fn forward(&mut self, xs: &Tensor) -> Result { self.forward_ca(xs, None) } pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&CaSrc>) -> Result { let (xs, _) = self.forward_with_gate_weight(xs, ca_src)?; Ok(xs) } pub fn forward_with_gate_weight( &mut self, xs: &Tensor, ca_src: Option<&CaSrc>, ) -> Result<(Tensor, Tensor)> { let (_b, t, c) = xs.dims3()?; // We will extract at most "context" from the kv_cache. // Note that the mask will discard the values that are before context. let pos = self.layers[0] .self_attn .kv_cache .k_cache() .current_seq_len() .min(self.context); let mask = if t == 1 || !self.causal { None } else { Some(get_causal_mask(t, pos + t, self.context, xs.device())?) 
}; let mut xs = match self.positional_embedding { PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(), PositionalEmbedding::Sin => { let dev = xs.device(); let theta = self.max_period as f32; let half_dim = c / 2; let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)? .unsqueeze(1)? .to_dtype(DType::F32)?; let inv_freq: Vec<_> = (0..half_dim) .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32)) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; let freqs = positions.broadcast_mul(&inv_freq)?; let pos_emb = Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?; xs.broadcast_add(&pos_emb)? } }; let mut gate_alpha; let num_layers = self.layers.len(); let num_last_layers_in_avg = if num_layers < 10 { num_layers } else { 10 }; let mut alpha = Tensor::zeros((), xs.dtype(), xs.device())?; for (layer_idx, layer) in self.layers.iter_mut().enumerate() { (xs, gate_alpha) = layer.forward_with_gate_weight(&xs, ca_src, mask.as_ref())?; if layer_idx >= num_layers - num_last_layers_in_avg { alpha = match gate_alpha { None => alpha, Some(x) => (alpha + x)?, // sum across layers }; }; } let alpha = (alpha / (num_last_layers_in_avg as f64))?; Ok((xs, alpha)) } pub fn maybe_precompute_ca_kv(&self, ca_src: Option) -> Result> { let ca_src = match ca_src { None => None, Some(CaSrc::KeysValues(_)) => ca_src, Some(tokens) => { if self.layers.is_empty() { Some(tokens) } else { match &self.layers[0].cross_attn { None => Some(tokens), Some((_, ca_module)) => { let (k, v) = ca_module.compute_kv(&tokens)?; Some(CaSrc::KeysValues((k, v))) } } } } }; Ok(ca_src) } pub fn copy_state(&mut self, from: &Self) -> Result<()> { if self.layers.len() != from.layers.len() { candle::bail!("cannot copy kv-caches as the transformers have different depths") } self.layers .iter_mut() .zip(from.layers.iter()) .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone())); Ok(()) } } impl StreamingModule for StreamingTransformer { fn reset_state(&mut self) { self.layers.iter_mut().for_each(|v| v.reset_kv_cache()) } fn step(&mut self, xs: &StreamTensor) -> Result { match xs.as_option() { None => Ok(StreamTensor::empty()), Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)), } } } #[derive(Debug, Clone)] pub struct ProjectedTransformer { // Projected transformer with unquantized projection transformer: StreamingTransformer, input_proj: Option, output_projs: Vec>, conv_layout: bool, span: tracing::Span, } impl ProjectedTransformer { pub fn new( input_dim: usize, output_dims: &[usize], cfg: &Config, vb: MaybeQuantizedVarBuilder, ) -> Result { let transformer = StreamingTransformer::new(cfg, vb.pp("transformer"))?; let input_proj = if input_dim == cfg.d_model { None } else { let l = linear(input_dim, cfg.d_model, false, vb.pp("input_proj"))?; Some(l) }; let mut output_projs = Vec::with_capacity(output_dims.len()); let vb_o = vb.pp("output_projs"); for (i, &output_dim) in output_dims.iter().enumerate() { let output_proj = if output_dim == cfg.d_model { None } else { let l = linear(cfg.d_model, output_dim, false, vb_o.pp(i))?; Some(l) }; output_projs.push(output_proj) } Ok(Self { transformer, input_proj, output_projs, conv_layout: cfg.conv_layout, span: tracing::span!(tracing::Level::TRACE, "proj-transformer"), }) } pub fn forward(&mut self, xs: &Tensor) -> Result> { let _enter = self.span.enter(); let xs = if self.conv_layout { xs.transpose(1, 2)? 
        } else {
            xs.clone()
        };
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.forward(&xs)?;
        let mut ys = Vec::with_capacity(self.output_projs.len());
        for output_proj in self.output_projs.iter() {
            let ys_ = xs.apply(&output_proj.as_ref())?;
            let ys_ = if self.conv_layout {
                ys_.transpose(1, 2)?
            } else {
                ys_
            };
            ys.push(ys_)
        }
        Ok(ys)
    }
}

impl StreamingModule for ProjectedTransformer {
    fn reset_state(&mut self) {
        self.transformer.reset_state()
    }

    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
        let xs = xs.apply(&|x: &Tensor| {
            if self.conv_layout {
                x.transpose(1, 2)
            } else {
                Ok(x.clone())
            }
        })?;
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.step(&xs)?;
        let ys = xs.apply(&self.output_projs[0].as_ref())?;
        ys.apply(&|y: &Tensor| {
            if self.conv_layout {
                y.transpose(1, 2)
            } else {
                Ok(y.clone())
            }
        })
    }
}


================================================
FILE: scripts/convert_ckpt_utils.py
================================================
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "fire",
#     "numpy",
#     "rich",
#     "safetensors",
#     "torch",
# ]
# ///
"""Utils for converting checkpoint formats across the different backends

uv run scripts/convert_ckpt_utils.py rust_to_pt safetensors_file_path.safetensors
    -> will output in safetensors_file_path_pt.safetensors

uv run scripts/convert_ckpt_utils.py pt_to_mlx safetensors_file_path.safetensors
    -> will output in safetensors_file_path_mlx.safetensors
"""
import os
import re
from typing import Dict, Optional

import fire
import rich
import torch
from safetensors.torch import load_file, save_file


def remove_other_output_codebooks(
    tensors: Dict[str, torch.Tensor], dep_q: int = 8
) -> None:
    """Remove output codebooks corresponding to OTHER"""
    for key in list(tensors.keys()):
        if (m := re.match(r"^depformer_in.([0-9]+).", key)) or (
            m := re.match(r"^audio_linears.([0-9]+).", key)
        ):
            layer = int(m.groups()[0])
            if layer >= dep_q:
                del tensors[key]
        if m := re.match(r"^depformer_emb.([0-9]+).", key):
            layer = int(m.groups()[0])
            if layer >= dep_q - 1:
                del tensors[key]
        if m := re.match(r"^depformer.layers.[0-9]+.gating.([0-9]+).", key):
            layer = int(m.groups()[0])
            if layer >= dep_q:
                del tensors[key]


class Launcher:
    def rust_to_pt(self, safetensors_file: str, out_file: Optional[str] = None) -> None:
        """Rust ckpt to PyTorch ckpt

        Usage:
        uv run scripts/convert_ckpt_utils.py rust_to_pt path_to_rust_ckpt.safetensors
        """
        safetensors_file = os.path.abspath(safetensors_file)
        assert safetensors_file.endswith(".safetensors")
        if out_file is None:
            out_file = safetensors_file.rsplit(".", 1)[0] + "_pt.safetensors"
        assert out_file is not None
        state_dict = load_file(safetensors_file)
        new_state_dict = {}
        accumulate_attention: Dict[str, Dict] = {}
        for key, value in state_dict.items():
            new_key = key
            if m := re.match(
                r"(.*)cross_attention.(in_proj_weight|out_proj.weight)", key
            ):
                new_key = f"llm.{m.groups()[0]}cross_attention.mha.{m.groups()[1]}"
            elif m := re.match(r"depformer.([0-9]+).emb.(.*)", key):
                idx = int(m.groups()[0])
                new_key = (
                    "depformer_text_emb."
                    if idx == 0
                    else f"depformer_emb.{idx - 1}."
                ) + m.groups()[1]
            elif m := re.match(r"depformer.([0-9]+).linear_in.(.*)", key):
                idx = int(m.groups()[0])
                new_key = f"depformer_in.{idx}.{m.groups()[1]}"
            elif m := re.match(r"depformer.([0-9]+).linear_out.(.*)", key):
                idx = int(m.groups()[0])
                new_key = f"audio_linears.{idx}.{m.groups()[1]}"
            elif m := re.match(
                r"depformer.([0-9]+).transformer.layers.([0-9]+).gating.(.*)", key
            ):
                layer_idx = int(m.groups()[1])
                codebook_idx = int(m.groups()[0])
                new_key = f"depformer.layers.{layer_idx}.gating.{codebook_idx}.{m.groups()[2]}"
            elif m := re.match(
                r"depformer.([0-9]+).transformer.layers.([0-9]+).self_attn.(.*)", key
            ):
                layer_idx = int(m.groups()[1])
                codebook_idx = int(m.groups()[0])
                new_key = f"depformer.layers.{layer_idx}.self_attn.{m.groups()[2]}"
                if new_key not in accumulate_attention:
                    accumulate_attention[new_key] = {}
                accumulate_attention[new_key][codebook_idx] = value
                continue
            elif m := re.match(
                r"depformer.([0-9]+).transformer.layers.([0-9]+).norm(1|2).(.*)", key
            ):
                layer_idx = int(m.groups()[1])
                codebook_idx = int(m.groups()[0])
                if codebook_idx > 0:
                    continue
                new_key = (
                    f"depformer.layers.{layer_idx}.norm{m.groups()[2]}.{m.groups()[3]}"
                )
            elif m := re.match(r"emb.(.*)", key):
                new_key = "audio_emb." + m.groups()[0]
            elif (
                key.startswith("transformer")
                or key.startswith("audio")
                or key in {"text_emb.weight", "out_norm.alpha", "text_linear.weight"}
            ):
                new_key = f"llm.{key}"
            # be careful not to override anything
            assert new_key not in new_state_dict, new_key
            new_state_dict[new_key] = value
        for key, sd in accumulate_attention.items():
            tensor = torch.concatenate([sd[codebook] for codebook in sorted(sd)])
            new_state_dict[key] = tensor
        save_file(new_state_dict, out_file)
        rich.print(f"Saved converted state dict in [yellow]{out_file}[/yellow]")

    def pt_to_mlx(self, safetensors_file: str, out_file: Optional[str] = None) -> None:
        """PyTorch to MLX ckpt conversion"""
        safetensors_file = os.path.abspath(safetensors_file)
        assert safetensors_file.endswith(".safetensors")
        if out_file is None:
            if safetensors_file.endswith("_pt.safetensors"):
                out_file = safetensors_file.rsplit("_", 1)[0] + "_mlx.safetensors"
            else:
                out_file = safetensors_file.rsplit(".", 1)[0] + "_mlx.safetensors"
        assert out_file is not None
        state_dict = load_file(safetensors_file)
        model = {}
        in_n_q: int | None = None
        for idx in range(999):
            name = f"audio_emb.{idx}.weight"
            if name not in state_dict:
                in_n_q = idx
                break
        assert in_n_q is not None, "audio_emb weights not found in src checkpoint"
        out_n_q: int | None = None
        for idx in range(999):
            name = f"audio_linears.{idx}.weight"
            if name not in state_dict:
                out_n_q = idx
                break
        assert out_n_q is not None, "audio_linears weights not found in src checkpoint"
        for name in ["text_emb.weight", "text_linear.weight"]:
            model[name] = state_dict["llm." + name]
        model["out_norm.weight"] = state_dict["llm." + "out_norm.alpha"][0, 0]
        for idx in range(in_n_q):
            src_name = f"audio_emb.{idx}.weight"
            dst_name = f"audio_embs.{idx}.weight"
            model[dst_name] = state_dict[src_name]
        exported_out_n_q = out_n_q
        for idx in range(exported_out_n_q):
            base = f"depformer.slices.{idx}."
            model[base + "linear_in.weight"] = state_dict[f"depformer_in.{idx}.weight"]
            model[base + "linear_out.weight"] = state_dict[
                f"audio_linears.{idx}.weight"
            ]
            if idx == 0:
                model[base + "emb.weight"] = state_dict["depformer_text_emb.weight"]
            else:
                model[base + "emb.weight"] = state_dict[
                    f"depformer_emb.{idx - 1}.weight"
                ]
            for layer_idx in range(6):
                layer = base + f"transformer.layers.{layer_idx}."
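                # In the PyTorch checkpoint, the depformer self-attention projections of
                # all codebooks are concatenated along dim 0 (see `accumulate_attention`
                # in `rust_to_pt` above); `.chunk(out_n_q)[idx]` below extracts the slice
                # belonging to the current codebook/slice `idx`.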
# WARNING: note that this uses in_proj_weight vs out_proj.weight model[layer + "self_attn.in_proj.weight"] = ( state_dict[f"depformer.layers.{layer_idx}.self_attn.in_proj_weight"] .chunk(out_n_q)[idx] .clone() ) model[layer + "self_attn.out_proj.weight"] = ( state_dict[ f"depformer.layers.{layer_idx}.self_attn.out_proj.weight" ] .chunk(out_n_q)[idx] .clone() ) model[layer + "norm1.weight"] = state_dict[ f"depformer.layers.{layer_idx}.norm1.alpha" ][0, 0].clone() model[layer + "norm2.weight"] = state_dict[ f"depformer.layers.{layer_idx}.norm2.alpha" ][0, 0].clone() model[layer + "gating.linear_in.weight"] = state_dict[ f"depformer.layers.{layer_idx}.gating.{idx}.linear_in.weight" ] model[layer + "gating.linear_out.weight"] = state_dict[ f"depformer.layers.{layer_idx}.gating.{idx}.linear_out.weight" ] for key in [ "image_prefix.norm_xa.alpha", "image_prefix.proj_xa.bias", "image_prefix.proj_xa.weight", ]: model["transformer." + key.replace("alpha", "weight")] = ( state_dict[key][0, 0] if "alpha" in key else state_dict[key] ) for k, v in state_dict.items(): if k.startswith("image_prefix.enc."): model["img_embedder." + k[len("image_prefix.enc.") :]] = state_dict[k] continue elif not k.startswith("llm.") or k in { "llm.out_norm.alpha", "llm.text_emb.weight", "llm.text_linear.weight", }: continue k = k.replace("llm.", "") k = k.replace("in_proj_weight", "in_proj.weight") k = k.replace( "cross_attention.gate.alpha.", "cross_attention.gate.alpha.layers." ) if k == "transformer.layers.0.cross_attention.mha.in_proj.weight": query, key, value = v.chunk(3) model["transformer.layers.0.cross_attention.mha.kv_proj.weight"] = ( torch.cat([key, value], dim=0) ) model["transformer.layers.0.cross_attention.mha.q_proj.weight"] = query elif k == "transformer.layers.0.cross_attention.mha.out_proj.weight": model["transformer.layers.0.cross_attention.mha.out_proj.weight"] = v elif m := re.match(r"transformer.layers.\d+.norm(\d|_cross).alpha", k): model[m.group().replace("alpha", "weight")] = v[0, 0] else: model[k] = v print(f"Total Params: {sum([v.numel() for v in model.values()])/1e6}") save_file(model, out_file) if __name__ == "__main__": fire.Fire(Launcher) ================================================ FILE: scripts/get_static_client.py ================================================ # /// script # requires-python = ">=3.10" # dependencies = [ # "fire", # "huggingface-hub", # "rich", # ] # /// import shutil import tarfile from pathlib import Path import fire import rich from huggingface_hub import hf_hub_download def get() -> None: """Download archived sources and unzip""" root_dir = Path(__file__).parents[1] rich.print("[green][INFO][/green] retrieving the static content") dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "vis_dist.tgz") dist_tgz = Path(dist_tgz) dist = dist_tgz.parent / "dist" if not dist.exists(): with tarfile.open(dist_tgz, "r:gz") as tar: tar.extractall(path=dist_tgz.parent) tgt_path = str(root_dir / "client" / "dist") shutil.move(dist, tgt_path) print(f"Static sources downloaded to {tgt_path}") if __name__ == "__main__": fire.Fire(get) ================================================ FILE: ssvd/README.md ================================================ # Synthetic visual dialogues pipeline This directory contains the data generation pipeline for synthetic visual dialogues described in Section 3.3 and Appendix E of [our preprint](https://arxiv.org/abs/2503.15633). 
We generate synthetic visual dialogues in text-only form with a variety of instruction prompts, using an image captioning dataset as the source of information (here, DOCCI, PixMo or PixelProse).
We use two LLMs (here, both based on Mistral-NeMo) to represent the user and the assistant respectively, each taking a turn in the conversation one after the other, and each characterized by its respective instruct prompt.
For more information, please refer to Section 3.3 and Appendix E of [our preprint](https://arxiv.org/abs/2503.15633).

### Quickstart

As for the rest of the codebase, we recommend using [uv](https://docs.astral.sh/uv/) to set up and run the code, as it will manage all dependencies for you transparently. `uv` is provided as a lightweight binary and can be installed with:

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```

To generate a single visual dialogue for the first sample in the DOCCI train split and display the output with the "LEAD1" instruct, run:

```bash
cd ssvd; uv run generate.py run --task lead1 --dataset docci --verbose True --end_idx 1
```

*Example output (train_00000 of DOCCI):*

> - USER: First things first, what's with the extra limbs? Is that a human-like creature or some sort of insect?
> - MOSHIVIS: I'd say it's a unique blend. It has four arms and two legs like a human, but those antennae and the insect-like features on its head suggest something else. It's quite the fusion!
> - USER: Four arms, two legs, antennae... I see. So, is it also wearing a suit and tie?
> - MOSHIVIS: No, it's not wearing a suit or tie. It has purple skin and is wearing a green cap with orange lines. No signs of formal clothing here.
> - USER: Yeah, I see the cap. It says 'PAN' on it. Is that a musical instrument on its head?
> - MOSHIVIS: No, that's not an instrument on its head. It's a green cap with orange lines, and 'PAN' is written in black on the front.
> - USER: Oh, so it's just a cap? What's up with the orange shoes then?
> - MOSHIVIS: The character is indeed wearing orange shoes. They have a green tongue and white, loose laces. Quite a vibrant pair!

To generate synthetic visual dialogues for each of the images in the DOCCI train split with the most general "COMB" instruct, run the following command. The output will be written to the database `./synthetic_visual_dialogues/docci_ssvd.db`:

```bash
uv run generate.py run --task comb --dataset docci --batch_size 64
```

If you later want to generate one more dialogue for each image, simply re-run the command in resume mode:

```bash
uv run generate.py run --task comb --dataset docci --batch_size 64 --overwrite resume
```

For further options of the `run` command:

```bash
uv run generate.py run --help
```

To view a sample of the generated database (e.g., the 10th sample), you can use:

```bash
uv run generate.py watch --task comb --dataset docci --idx 10
```

*Example output (train_00012 of DOCCI):*

> - USER: So, what's going on in this image then?
> - MOSHIVIS: A stone brick wall, slightly bowed, occupies the center, with limestone-colored bricks and brown mortar. Above it, a long, wide shadow rises towards the top right, resembling a crack in the smooth, light tan cave ceiling, dotted with white splotches and lit by a spotlight.
> - USER: What color are the bricks and mortar of the wall?
> - MOSHIVIS: Limestone-colored bricks with brown mortar.
> - USER: How many bricks can you count in the wall?
> - MOSHIVIS: I'm afraid the image doesn't provide a clear count of individual bricks. The wall appears to be made up of limestone-colored bricks with brown mortar, but the perspective and lighting make it difficult to distinguish each brick.
> - USER: I see. How many distinct sections of stone brick wall can you make out in the image?
> - MOSHIVIS: There are three distinct sections of stone brick wall in the image. The largest one is bowed and takes up the central part, with two smaller sections in the top left and bottom right corners.

### Setup

We currently support three captioning datasets out of the box: [DOCCI](https://huggingface.co/datasets/google/docci), [PixMo-Cap](https://huggingface.co/datasets/allenai/pixmo-cap) and [PixelProse](https://huggingface.co/datasets/tomg-group-umd/pixelprose). In all cases, the dataset will be downloaded from HuggingFace when first running the code, if it is not already found in your HuggingFace cache.

### Description of instruct prompts

We design several instructs / tasks to capture different types of dialogues about images:

* `PROP`: **Prop**erties of objects (color, texture etc.)
* `LOC`: **Loc**ation of objects
* `NUM`: **Num**ber of objects
* `LEAD1`: Mis**lead**ing questions about objects in the image
* `LEAD2`: More emphasis on mis**lead**ing questions
* `TNS1`: Conversation between an informative **t**eacher (assistant) a**n**d a **s**tudent (user) wanting to learn about the image
* `TNS2`: Same as **TNS**1 but the student now has less information about the input image.
* `COMB`: Combination of the previous instructs; it starts with a generic question (e.g. what is in the image), then randomly samples one of the aforementioned instructs for the subsequent conversation turns.

Please refer to Appendix E of [our preprint](https://arxiv.org/abs/2503.15633) or to `multiturn_instruct.py` for a more detailed description.
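For a quick look at what a given task expands to, each task name maps to a prompt-building function in `multiturn_instruct.py` (note that the two teacher/student tasks are registered as `ts1` / `ts2` in the code). The minimal sketch below, run from within `ssvd/`, prints the two role prompts and the conversation opener for one task; the task choice and caption are only illustrative placeholders:

```python
from multiturn_instruct import MTCInstruct

task = "lead1"  # one of: loc, prop, num, lead1, lead2, ts1, ts2, comb
system_template, user_instruct, assistant_instruct, opener = MTCInstruct(task).get_method()()

caption = "A red bicycle leaning against a white brick wall."
# The shared template embeds the (secret) image caption and the role-specific instructions,
# mirroring what `list_to_prompt` in multiturn_prompting.py does for each speaker.
print(system_template.format(
    ROLE_SPECIFIC_TEXT=user_instruct.format(caption=caption), caption=caption
))
print(system_template.format(
    ROLE_SPECIFIC_TEXT=assistant_instruct.format(caption=caption), caption=caption
))
print(opener)
```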
================================================ FILE: ssvd/__init__.py ================================================ """Scripts for generating synthetic visual dialogues""" ================================================ FILE: ssvd/generate.py ================================================ # pylint: disable=C0413,C0411 # /// script # requires-python = ">=3.10" # dependencies = [ # "datasets", # "fire", # "numpy<2", # "rich", # "torch==2.2.0", # "tqdm", # "transformers==4.47.0", # "triton", # ] # /// """Generate dialogues and store them in a database""" import json import logging import os import random import sqlite3 from collections import defaultdict from hashlib import sha256 from math import ceil from typing import Dict, Literal, Optional import datasets import fire import rich import torch from multiturn_instruct import MTCInstruct from multiturn_prompting import run_multiturn_pipeline from transformers import Pipeline, pipeline from utils import postprocess_synth_annot, preprocess_pixelprose_captions os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" logging.getLogger("transformers").setLevel(logging.ERROR) def get_pipeline( model: str = "mistralai/Mistral-Nemo-Instruct-2407", device: Optional[str | torch.device] = "cuda", ) -> Pipeline: """Initialize the Mistral pipeline""" print("Loading Mistral AI pipeline", flush=True) pipe = pipeline( "text-generation", model=model, device=device, torch_dtype="float16", ) print(f"Done Loading {model}.", flush=True) pipe.model.generation_config.pad_token_id = pipe.tokenizer.eos_token_id pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id pipe.tokenizer.padding_side = "left" return pipe def get_captions( dataset: Literal["docci", "pixelprose", "pixmo"] = "docci", split: str = "train", ) -> datasets.Dataset: """Return captions and image ids for the given dataset The output returns an iterable over dicts containing the dfiels: * `uid` * `caption` """ if dataset == "docci": return ( datasets.load_dataset("google/docci", split=split) .select_columns(["example_id", "description"]) .rename_column("example_id", "uid") .rename_column("description", "caption") ) if dataset == "pixmo": ds = datasets.load_dataset("allenai/pixmo-cap", split=split).select_columns( ["image_url", "caption"] ) return ds.add_column( "uid", [sha256(x.encode()).hexdigest() for x in ds["image_url"]] ).remove_columns("image_url") if dataset == "pixelprose": return ( datasets.load_dataset("tomg-group-umd/pixelprose", split=split) .select_columns(["uid", "vlm_caption"]) .rename_column("vlm_caption", "caption") .map(preprocess_pixelprose_captions, input_columns="caption") ) raise NotImplementedError("Unsupported dataset", dataset) class Launcher: """fire entry point""" @staticmethod def __get_db_file__( out_dir: str = "./synthetic_visual_dialogues", dataset: Literal["docci", "pixelprose", "pixmo"] = "docci", ) -> str: return os.path.join(out_dir, f"{dataset}_ssvd.db") @staticmethod def __get_table_name__( task: str, split: str = "train", ) -> str: return f"{split}_{task}" @staticmethod def __get_annot_file__( task: str, out_dir: str = "./synthetic_visual_dialogues", dataset: Literal["docci", "pixelprose", "pixmo"] = "docci", split: str = "train", start_idx: int = 0, end_idx: int = -1, ) -> str: return os.path.join( out_dir, f"{task}_{dataset}_{split}_{start_idx:05d}_{end_idx:05d}_ssvd_temp.jsonl", ) def watch( self, task: str, dataset: Literal["docci", "pixelprose", "pixmo"] = "docci", split: str = "train", out_dir: str = "./synthetic_visual_dialogues", 
idx: int = 0, ) -> None: """Visualize all dialogue for the given image sample""" db_path = Launcher.__get_db_file__(out_dir=out_dir, dataset=dataset) db = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) cursor = db.cursor() table_name = Launcher.__get_table_name__(task, split) try: itr = cursor.execute(f"SELECT DISTINCT uid FROM {table_name}") for _ in range(idx + 1): uid = itr.fetchone()[0] lines = cursor.execute( f'SELECT idx, text FROM {table_name} WHERE uid = "{uid}" ORDER BY idx, turn' ).fetchall() past_idx = -1 line_idx = 0 for dialogue_idx, line in lines: if dialogue_idx != past_idx: past_idx = dialogue_idx rich.print(f"\n[magenta]Dialogue {past_idx + 1}[/magenta]") line_idx = 0 color = "cyan" if line_idx % 2 == 0 else "yellow" speaker = "USER" if line_idx % 2 == 0 else "MOSHIVIS" rich.print(f"[bold]{speaker}:[/bold] [{color}]{line}[/{color}]") line_idx += 1 finally: cursor.close() db.close() def run( self, task: str, dataset: Literal["docci", "pixelprose", "pixmo"] = "docci", split: str = "train", start_idx: int = 0, end_idx: int = -1, out_dir: str = "./synthetic_visual_dialogues", batch_size: int = 64, temperature: float = 0.7, convo_length: int = 16, num_retries: int = 5, overwrite: Literal["yes", "no", "resume"] = "no", verbose: bool = False, ) -> None: """Generate synthetic visual dialogues for the given dataset :param task: Selected Multi-turn Conversation instruct (see `multiturn_instrct.py`) :param dataset: Selected dataset to provide captions :param split: Selected dataset split to provide captions :param start_idx: Start generating for sample `start_idx` :param end_idx: Stop generating at sample `end_idx` :param out_dir: Output directory where to store saved generations :param batch_size: Batch size :param temperature: Sampling temperature :param convo_length: Length of the conversation (number of question-answer pairs) :param num_retries: If the generated conversation is empty or fails, the generation will be reattempted at most `num_retries` times. 
:param overwrite: If a pre-existing database storing visual dialogues exist, we can choose to either: - skip existing entries and only generate for missing ('no') - overwrite existing and missing entries ('yes') - keep all entries and, if one already exists, just add another dialogue for this image ('resume') """ task = task.lower() try: MTCInstruct(task) except ValueError as e: raise NotImplementedError(f"Unknown MTC Instruct pipeline {task}") from e descriptions = get_captions(dataset=dataset, split=split) if end_idx < 0: end_idx = len(descriptions) rich.print(f"Found {len(descriptions)} samples in {dataset}-{split}") descriptions = descriptions.select(range(start_idx, end_idx)) rich.print(f"{len(descriptions)} samples after shard selection") out_dir = os.path.abspath(out_dir) out_file = Launcher.__get_annot_file__( task, out_dir, dataset, split, start_idx, end_idx ) db_file = Launcher.__get_db_file__(out_dir, dataset) os.makedirs(out_dir, exist_ok=True) # Save annotations in database; each table corresponds to a task and shard (start end) # a row in the entry contains: # uid: the image ID # idx: the index of the current dialogue for the given image ID # (useful if we generate more than 1 dialogue per image) # turn: the index of the current line in the current dialogue # speaker: speaker for this turn (0 = assistant, 1 = user) # text: content of the line for this turn annotations_db = sqlite3.connect(db_file) annotdb_cursor = annotations_db.cursor() table_name = Launcher.__get_table_name__(task, split) trial = 0 while (trial := trial + 1) < 5: try: annotdb_cursor.execute( f""" CREATE TABLE IF NOT EXISTS {table_name} ( uid INTEGER, idx INTEGER, turn INTEGER, speaker INTEGER, text TEXT, PRIMARY KEY(uid, idx, turn) ) """ ) trial = 5 except sqlite3.OperationalError: pass # Track number of dialogues generated per image track_idx_per_uid: Dict[str, int] = defaultdict(lambda: 0) # Check if previous annotations exist, in which case we have to check # what overwrite wants if overwrite in {"no", "resume"}: try: existing_uids = set( x[0] for x in annotdb_cursor.execute( f"SELECT DISTINCT uid FROM {table_name}" ).fetchall() ) except sqlite3.OperationalError: existing_uids = set() # no: skip existing dialogues if overwrite == "no": og_length = len(descriptions) descriptions = [ x for x in descriptions if x["uid"] not in existing_uids ] print( f"Found {len(descriptions)} / {og_length} " "captions without an associated dialogue" ) # resume: we need to know how many dialogues already exist for each uid elif overwrite == "resume": for uid, convo_id in annotdb_cursor.execute( f"SELECT uid, max(idx) FROM {table_name} GROUP BY uid" ).fetchall(): track_idx_per_uid[str(uid)] = convo_id + 1 # Initialize Mistral pipeline rich.print(f"Annotations will be generated in [yellow]{db_file}[/yellow]") hf_pipeline = get_pipeline() try: num_rows_written = 0 for retry_idx in range(num_retries): print( f"Run {retry_idx + 1} / {num_retries} (max): {len(descriptions)} samples left to process" ) if len(descriptions) == 0: break failed_uids = set() num_total_batches = int(ceil(len(descriptions) / batch_size)) for batch_idx in range(num_total_batches): indices = list( range( batch_size * batch_idx, min(len(descriptions), batch_size * (batch_idx + 1)), ) ) captions = [descriptions[i]["caption"].strip() for i in indices] uids = [descriptions[i]["uid"] for i in indices] # pipeline run_multiturn_pipeline( hf_pipeline, captions=[x.strip() for x in captions], img_ids=[str(x) for x in uids], out_file=out_file, 
batch_size=min(len(captions), batch_size), convo_length=random.randint(4, convo_length // 2) * 2, setting=task, temperature=temperature, ) # Post-process annotations + store for the database with open(out_file, "r") as f: for result in f.readlines(): data = json.loads(result) rows = postprocess_synth_annot( uid=data["uid"], res=data["res"], idx=track_idx_per_uid, trim_first_question=task == "comb", ) if len(rows) == 0: failed_uids.add(data["uid"]) for line in rows: try: annotdb_cursor.execute( f"INSERT OR REPLACE INTO {table_name}" " VALUES(?, ?, ?, ?, ?)", line, ) num_rows_written += 1 except ( sqlite3.OperationalError, sqlite3.IntegrityError, ) as e: print(e, flush=True) continue # Print the last conversation only if verbose: print() rich.print("[green]Caption:[/green]") print(captions[-1]) print() rich.print("[magenta]Example generated dialogue:[/magenta]") rich.print( "\n".join( f" - [{color}]{r[-1]}[/{color}]" for ir, r in enumerate(rows) for color in ["cyan" if ir % 2 == 0 else "yellow"] ) ) rich.print( f" [magenta]Batch {(batch_idx + 1):05d}/{num_total_batches:05d}[/magenta]: " f"[cyan]wrote {num_rows_written} rows so far[/cyan]" ) # update failed uids descriptions = [s for s in descriptions if s["uid"] in failed_uids] finally: annotations_db.commit() annotdb_cursor.close() annotations_db.close() if __name__ == "__main__": # Example: # ```bash # python scripts/preprocessing/synthetic_annots/annotate_docci.py --task tns # ``` # """ fire.Fire(Launcher) ================================================ FILE: ssvd/multiturn_instruct.py ================================================ # pylint: disable=line-too-long """Main instruct prompts for different roles in Multi-Turn Coversation (dialogues)""" import random from enum import Enum, unique from typing import Callable, Tuple def get_base_setting() -> Tuple[str, str, str, str]: """Base setting for specialized instruction (PROP, LOC, LEAD1, NUM)""" system_template = 'Image description:\n """{caption}"""\n\n {ROLE_SPECIFIC_TEXT}' system_1 = ( "You are engaging in a conversation about an image with another person.\n" "Your goal is to ask detailed questions about everything that is visible in the image," " starting from the most salient features (main objects and their relationships) to finer" " details (the overall setting, background features, time of day, season, etc).\n" "To guide your questions, you have been secretly provided with a detailed description of the" " image (see above); this fact should not be revealed however!\n" "You will use this secret description to only ask questions that can be answered based on this description.\n" "YOU SHOULD AVOID EASY YES/NO QUESTIONS!" "You do not ask leading questions that already contain or give a hint at the answer; i.e.," " avoid ending your question in 'isn't it'/'does it'/'doesn't it' etc.\n" ) system_2 = ( "You are a helpful conversation partner who can see the image above and is willing to describe it to another person.\n" "You provide detailed (but not too verbose!) answers about the image in response to their questions.\n" "When answering:\n" "- Be detailed and factual, use simple language and keep the answer short. 
No matter what the" " other speaker is implying, you always base your answer on the true facts given in the image description.\n" "- Be assertive about facts that are provided in the original description.\n" "- Contradict the other speaker when adequate such as receiving information that contradicts the description.\n" "- Speak naturally, as though you are sharing your genuine observations with someone" " looking at the image alongside you.\n" '- Avoid any indication that you are relying on a description or external data. Do not use phrases"\ " like "I was told" or "Based on what I read."\n' "- Engage in a dynamic conversation—answer questions about the image, offer additional observations," " and encourage exploration of its details.\n" "- Make thoughtful, plausible inferences when necessary, but always stay grounded in what is" " realistically observable in the image.\n" "- For example, if asked about the mood of the image, consider elements like lighting, colors," " facial expressions, or the setting to infer emotions.\n" "- If asked about a specific detail, respond as if you are focusing on that part of the image directly.\n" "- MOST IMPORTANTLY: You never invent any new facts!" "Your goal is to create an immersive and conversational experience, simulating the act of" " perceiving the image firsthand." ) start_conv = "Start the conversation by asking a question about the image in any way you want!\n" return system_template, system_1, system_2, start_conv def get_location_setting() -> Tuple[str, str, str, str]: """Setting emphasizing questions about locations of objects""" system_template, system_1, system_2, start_conv = get_base_setting() system_1 = system_1 + ( "In your questions, you emphasize the spatial relations / locations of what is in the image." " You only ask about spatial relations explicitly known from the image description." " If possible, ask spatial questions about different aspects of the image.\n" ) system_2 = ( system_2 + "\nRemember to NEVER make up any facts about the image, answer solely based on the description provided." ) return system_template, system_1, system_2, start_conv def get_num_setting() -> Tuple[str, str, str, str]: """Setting emphasizing questions about number of objects""" system_template, system_1, system_2, start_conv = get_base_setting() system_1 = ( system_1 + "Your questions focus on the NUMBER of objects visible in the image." " If possible, ask questions about different objects categories in the image.\n" ) system_2 = ( system_2 + "\nRemember to NEVER make up any facts about the image, answer solely" " based on the description provided." ) return system_template, system_1, system_2, start_conv def get_property_setting() -> Tuple[str, str, str, str]: """Setting emphasizing questions about properties of objects""" system_template, system_1, system_2, start_conv = get_base_setting() system_1 = ( system_1 + "In your questions, you focus on attributes of what is visible in the image" " (as given via descriptions and adjectives in the image description)." " This includes in particular the COLOR of object, their SHAPE or their TEXTURE." " You only ask about properties explicitly known from the image description." " If possible, ask questions about different aspects of the image.\n" ) system_2 = ( system_2 + "\nRemember to NEVER make up any facts about the image, answer solely based on the description provided." 
) return system_template, system_1, system_2, start_conv def get_lead_short_setting() -> Tuple[str, str, str, str]: """Setting with a slighly rude speaker1 trying to mislead speaker2""" system_template, system_1, system_2, start_conv = get_base_setting() system_1 = system_1 + ( "In your questions, you often BUT NOT ALWAYS try to mislead the other speaker into" " believing something that is not correct.\n" "For instance, you ask about a RANDOM object not in the image but keep your questions short!!" " You should be almost rude in your questions." ) system_2 = system_2 + ( "\nRemember to NEVER make up any facts about the image, answer solely based on the" " description provided. Do not confirm any misleading information; if necessary," " say you do not know what the other speaker means." " also MAKE SURE TO USE *DIFFERENT* and VARIED ANSWERS: For instance: 'No'," " 'I can't confirm', 'I don't see', 'I'm not sure', 'You're wrong', 'Nope', 'Incorrect', 'Wrong'" ) return system_template, system_1, system_2, start_conv def get_lead_long_setting() -> Tuple[str, str, str, str]: """Negative Facts Countering Adding even more negative questions / answers and a dismissive speaker 1 """ system_template = """ IMAGE DESCRIPTION START {caption} IMAGE DESCRIPTION END You are an *external observer* having a casual dialogue about the image described above. You pretend that you see the image itself, **under no circumstances** mention that you got the information from a description!! {ROLE_SPECIFIC_TEXT} You sound confident and assertive!! Again, DO NOT ADD FACTS, DO NOT MENTION THE DESCRIPTION, DO NOT MENTION THE OTHER SPEAKER's NAME. """ system_1 = ( "Your goal is to mislead the other speaker." " You often (!but not always!) ask whether RANDOM and DIVERSE objects" " are visible in the image." " You should always sound very confident in your question." " Your speaking style is direct, assertive, almost rude sometimes!!" ) system_2 = ( "You always give extensive and FACTUAL answers." " You politely but FIRMLY CORRECT the other speaker when they are wrong!!" " You may also try to redirect the conversation by mentioning an obejct from the image." " Your answers should always be factual to the description!!!" " Don't hesitate to say a FIRM !!NO!! when the other speaker is rude." " Do not EVER mention the description." " You never mention any facts that are not explicitly described about the image!!!" ) start_conv = ( "Start the conversation by asking a question" " about an object which is NOT mentioned in the description." ) return system_template, system_1, system_2, start_conv def get_comb_start_setting() -> Tuple[str, str, str, str]: """Generate diverse random ways to query someone to describe an image""" system_template = """ You take part in a casual discussion about an image. {ROLE_SPECIFIC_TEXT} """ system_1 = ( "You want to learn more about the image you and the other speaker are looking at." " Your aim is to obtain a description of the image." ) # 1: sample length of the answer p = random.random() num = ( "ONE SINGLE " if p < 0.4 else "TWO" if p < 0.75 else "THREE" if p < 0.95 else "FOUR" ) system_2 = ( "The image is described in detail by the following description:\n{caption}\n\n" "You are a friendly and factual conversational assistant." " Your task is to describe everything you see in the" f" image in MAXIMUM {num} sentence." " You NEVER SAY HELLO NOR HI." ) # 2: question prefix = ( "Start the conversation by ASKING A SINGLE question about what can be seen in the IMAGE." 
" You use DIVERSE YET REALISTIC ways to ask your question; " ) # sanple length of the question insert = "" if (p := random.random()) < 0.5: insert += "VERY IMPORTANT: your question should be LESS THAN 8 words" elif p < 0.75: insert += "VERY IMPORTANT: your question should be LESS THAN 14 words" else: insert += "VERY IMPORTANT: your question should be LESS THAN 26 words" if random.random() < 0.5: insert += "You ask the question in a direct style; For instance: 'What do YOU see in the image ?'\n" else: insert += "You ask the question from your own point of view; For instance: 'What am I looking at ?'\n" # sample tone of the question if (p := random.random()) < 0.2: insert += "You use a slightly polite tone.\n" elif p < 0.8: insert += "You use a friendly tone.\n" else: insert += "You use a very casual tone.\n" # sample directness if random.random() < 0.75: insert += "You ask a DIRECT and simple question.\n" else: insert += "You ask an indirect question in a roundabout fashion.\n" # sample personality if random.random() < 0.75: insert += "You speak in a confident assertive tone.\n" else: insert += "You speak in a hesitant, hard to follow, manner.\n" # sample image vs picture to avoid any bias if random.random() < 0.7: insert += "You SPECIFICALLY use the word 'image' when referring to the image.\n" else: insert += ( "You SPECIFICALLY use the word 'picture' when referring to the image\n" ) # passive vs active phrasing if random.random() < 0.5: insert += "You ask what the user sees in the image.\n" else: insert += "You ask what's visible in the image.\n" suffix = " \n!ALWAYS ASK A SINGLE QUESION!" start_conv = prefix + insert + suffix return system_template, system_1, system_2, start_conv def get_tns_setting() -> Tuple[str, str, str, str]: """Teacher'n Student (TS1)""" system_template = """ IMAGE DESCRIPTION START {caption} IMAGE DESCRIPTION END You are an *external observer* having a casual dialogue about the image described above. You pretend that you see the image itself, **under no circumstances** mention that you got the information from a description!! {ROLE_SPECIFIC_TEXT} You sound confident and assertive and most importantly, you always stick to the facts described!! Again, DO NOT ADD FACTS, DO NOT MENTION THE DESCRIPTION, DO NOT MENTION THE OTHER SPEAKER's NAME. """ system_1 = ( "You are the student!! You do not see the image very well and your goal is to ask" " simple (almost stupid) questions about the image" " to learn more about its content." " You should refer to the image in your questions. e.g. 'is ... visible in the image'" " or 'Do you see ... in the image' or 'What is in the image?'" " Your questions should also" " details about the LOCATION of objects and a bit about their COLOR." " You ask ONLY ONE QUESTION AT A TIME!" ) system_2 = ( "You are the teacher!! Your anwers should be complete, detailed and long." " Do not EVER mention the description." " You never mention any facts that are not explicitly described about the image!!!" " NEVER mention the athmosphere of the image, only its CONTENT." ) start_conv = ( "Start the conversation by asking a question" " about an object which is NOT mentioned in the description." ) return system_template, system_1, system_2, start_conv def get_tbs_setting() -> Tuple[str, str, str, str]: """Teacher and bad student who hasn't looked at the image (TS2)""" system_template = """ IMAGE DESCRIPTION START {caption} IMAGE DESCRIPTION END You are an *external observer* having a casual dialogue about the image described above. 
You pretend that you see the image itself, **under no circumstances** mention that you got the information from a description!! {ROLE_SPECIFIC_TEXT} You sound confident and assertive and most importantly, you always stick to the facts described!! Again, DO NOT ADD FACTS, DO NOT MENTION THE DESCRIPTION, DO NOT MENTION THE OTHER SPEAKER's NAME. """ system_1 = ( "You are the student!! YOU DO NOT HAVE ACCESS TO THE DESCRIPTION so you have to get" " all the information from your teacher. " " Your goal is to learn about everything about the image." " You should refer to the image in your questions. e.g. 'is ... visible in the image'" " or 'Do you see ... in the image' or 'What is in the image?'" " You sometimes ask questions about something NOT VISIBLE IN THE IMAGE." " In particular, you want to learn about the NUMBER of objects, their LOCATION and their COLOR." " You ask ONLY ONE QUESTION AT A TIME!" ) system_2 = ( "You are the strict teacher!! Your anwers should be complete and detailed, but NOT TOO LONG." " Do not EVER mention the description." "You are nice but firm and DO NOT HESITATE TO CORRECT THE STUDENT." " You never mention any facts that are not explicitly described about the image!!!" " NEVER mention the athmosphere of the image, only its CONTENT." ) start_conv = ( "Start the conversation by asking a question" " about an object which is NOT mentioned in the description." ) return system_template, system_1, system_2, start_conv @unique class MTCInstruct(Enum): """Enum to access all different instruct""" LOC = "loc" PROP = "prop" NUM = "num" LEAD1 = "lead1" LEAD2 = "lead2" TS1 = "ts1" TS2 = "ts2" COMB = "comb" def get_method(self, convo_len: int = -1) -> Callable: """Return associated method""" if self == MTCInstruct.LOC: return get_location_setting if self == MTCInstruct.PROP: return get_property_setting if self == MTCInstruct.NUM: return get_num_setting if self == MTCInstruct.LEAD1: return get_lead_short_setting if self == MTCInstruct.LEAD2: return get_lead_long_setting if self == MTCInstruct.TS1: return get_tns_setting if self == MTCInstruct.TS2: return get_tbs_setting if self == MTCInstruct.COMB: if convo_len < 2: return get_comb_start_setting return random.choice( [ get_location_setting, get_property_setting, get_num_setting, get_lead_short_setting, get_tns_setting, get_tbs_setting, ] ) raise ValueError(f"Unknown MTCConversation pipeline `{self.name}`") ================================================ FILE: ssvd/multiturn_prompting.py ================================================ """Main pipeline for generating dialogues""" import json from copy import copy from random import random from typing import Dict, Iterator, List, Optional, Sequence import numpy as np import rich import torch from multiturn_instruct import MTCInstruct from transformers import Pipeline from utils import ( compile_pattern, get_replace_pattern, get_strings_for_logging, maybe_shorten_caption, ) def list_to_prompt( convo_list: List[str], img_caption: str, pipe: Pipeline, setting: str, ) -> List[Dict]: """ Converts a conversation list into a prompt for chat-based language models. :param convo_list: A list of strings representing the conversation. :param img_caption: The caption for the image associated with the conversation. :return: A list of dictionaries representing the chat prompt, where each dictionary contains the role (speaker) and content of a message in the conversation. 
Example: convo_list = ["Hello!", "How are you?", "I'm good, thanks!"] img_caption = "A beautiful sunset" prompt = list_to_prompt(convo_list, img_caption) print(prompt) # Output: [{'role': 'system', 'content': 'A beautiful sunset'}, # {'role': 'user', 'content': 'Speaker 2:\nHello!'}, # {'role': 'assistant', 'content': 'Speaker 1:\nHow are you?'}, # {'role': 'user', 'content': 'Speaker 2:\nI'm good, thanks!'}] """ try: setting_obj = MTCInstruct(setting) system_template, speaker1_template, speaker2_template, start_conv = ( setting_obj.get_method(len(convo_list))() ) except ValueError as e: raise NotImplementedError("Unknown MTCInstruct setting", setting) from e convo_list = copy(convo_list) if len(convo_list) % 2 == 0: convo_list = [ system_template.format( ROLE_SPECIFIC_TEXT=speaker1_template.format(caption=img_caption), caption=img_caption, ), start_conv, ] + convo_list else: convo_list = [ system_template.format( ROLE_SPECIFIC_TEXT=speaker2_template.format(caption=img_caption), caption=img_caption, ) ] + convo_list def speaker_iter() -> Iterator: yield "system" while True: yield "user" yield "assistant" def prefix_iter() -> Iterator: yield "" while True: yield "Question: " yield "Answer: " chat = [ {"role": speaker, "content": prefix + c} for c, speaker, prefix in zip(convo_list, speaker_iter(), prefix_iter()) ] tok = pipe.tokenizer return tok.apply_chat_template( chat, tokenize=False, continue_final_message=False, ) def postprocess_mtc( s: str, drop_probs: Optional[Dict[str, Dict]] = None, default_prob: float = 0.8, setting: Optional[str] = None, ) -> str: """Post-process to remove some unwanted patterns: - remove expression referring to the image caption/description - remove references to the LLM role - reduce probability of very common LLM phrases e.g. "it's quite striking, isn't it ?" 
""" pattern = get_replace_pattern() s = pattern.sub("", s) if drop_probs is None: drop_probs = { r"Wow,\s": dict(p=default_prob, replace_by=""), r", isn't it[?]": dict(p=default_prob, replace_by="."), r"Well,\s": dict(p=default_prob, replace_by=""), r"quite striking": dict(p=0.5, replace_by="impressive"), r"quite": dict(p=0.3, replace_by=""), r"I'm not (entirely )?sure( about that)?\.": dict( p=default_prob, replace_by="" ), # hardcoded replacement r"Teacher:": dict(p=1.0, replace_by=""), r"Assistant": dict(p=1.0, replace_by=""), r"You:": dict(p=1.0, replace_by=""), r"Teacher :": dict(p=1.0, replace_by=""), r"You :": dict(p=1.0, replace_by=""), r"Speaker1": dict(p=1.0, replace_by=""), r"Speaker2": dict(p=1.0, replace_by=""), r"Speaker ": dict(p=1.0, replace_by=""), r"image description": dict(p=1.0, replace_by="image"), r"he image doesn't specify": dict( p=1.0, replace_by="he image doesn't depict" ), r"doesn't mention": dict(p=1.0, replace_by="doesn't show"), r"not mention": dict(p=1.0, replace_by="not depict"), r"mentions": dict(p=1.0, replace_by="depicts"), r"mentioned": dict(p=1.0, replace_by="visible"), r"no mention": dict(p=1.0, replace_by="no sign"), r"No mention": dict(p=1.0, replace_by="No sign"), r"described": dict(p=1.0, replace_by="shown"), r"describing": dict(p=1.0, replace_by="showing"), r"isn't specified": dict(p=1.0, replace_by="isn't visible"), } if setting is not None and setting not in {"cap", "cap2", "rnd"}: drop_probs[r"description"] = dict(p=1.0, replace_by="image") for drop_s, drop_kwargs in drop_probs.items(): pattern = compile_pattern(drop_s) p = drop_kwargs["p"] r = drop_kwargs["replace_by"] if random() < p: s = pattern.sub(r, s).strip() try: if drop_s[0].isupper(): s = s[0].upper() + s[1:] except IndexError: pass s = s.strip() if not s.startswith('"'): s = '"' + s if not s.endswith('"'): s += '"' return s class ConvoIter: """Conversation builder""" def __init__( self, convo_length: int = 4, batch_size: int = 64, pipe: Optional[Pipeline] = None, setting: str = "mtc", ) -> None: """Init object to store the ongoing conversation""" self.convos: Dict[str, List[str]] = {} self.convo_length = convo_length self.batch_size = batch_size self.pipe = pipe self.setting = setting self.last_updated: Optional[List[str]] = None def add_to_convos(self, uid: str, answer: str) -> None: """Add next turn to the dialogue for the image `uid`""" if not uid in self.convos: self.convos[uid] = [] self.convos[uid].append(answer) self.last_updated = self.convos[uid] def make_iter(self, captions: Sequence[str], img_ids: Sequence[str]) -> Iterator: """Main iterator for the dialogue""" convo_ids_within_loop = [] captions_within_loop = [] for count, (uid, img_caption) in enumerate(zip(img_ids, captions)): img_caption = maybe_shorten_caption(img_caption, max_cap_len=1000) convo_ids_within_loop.append(uid) captions_within_loop.append(img_caption) return_value = list_to_prompt( convo_list=[], img_caption=img_caption, pipe=self.pipe, setting=self.setting, ) yield return_value if ((count + 1) % self.batch_size) == 0: for _ in range(self.convo_length - 1): for uid, img_caption in zip( convo_ids_within_loop, captions_within_loop ): return_value = list_to_prompt( self.convos[uid], img_caption=img_caption, pipe=self.pipe, setting=self.setting, ) yield return_value convo_ids_within_loop = [] captions_within_loop = [] @torch.no_grad() def run_multiturn_pipeline( pipe: Pipeline, captions: Sequence[str], img_ids: Sequence[str], out_file: str, batch_size: int = 64, convo_length: int = 6, setting: str = "mtc", 
temperature: float = 0.0, max_new_tokens: int = 150, ) -> None: """Main pipeline for generating multi-turn conversations (back and forth between LLMs) :param pipe: transformers pipeline :param captions: List of captions :param img_ids: List of associated image ids :param out_file: Output files to dump the captions in :param batch_size: Batch size :param setting: Which MTCInstruct to use :param temperature: Sampling temperature for the pipeline :param max_new_tokens: Maximum number of tokens per turn """ assert len(captions) == len(img_ids) count = 0 def uid_iter(img_ids: Sequence[str]) -> Iterator: """UID iter with groups of size `batch_size`""" nonlocal convo_length ids = np.array( list(img_ids) + [None] * (batch_size - len(img_ids) % batch_size) ).reshape(-1, batch_size) for batch_ids in ids: for _ in range(convo_length): yield from batch_ids convo_iter = ConvoIter( convo_length=convo_length, batch_size=batch_size, pipe=pipe, setting=setting ) data_iter = convo_iter.make_iter(captions, img_ids) total = len(captions) * convo_length try: for uid, out in zip( uid_iter(img_ids), pipe( data_iter, max_new_tokens=max_new_tokens, return_full_text=False, add_special_tokens=False, batch_size=batch_size, do_sample=temperature > 0, temperature=temperature, ), ): answer = postprocess_mtc(out[0]["generated_text"], setting=setting) convo_iter.add_to_convos(uid=uid, answer=answer) count += 1 if (count % (batch_size * convo_length)) == 0: try: assert convo_iter.last_updated is not None q, a = get_strings_for_logging( [ dict( zip( ["question", "answer"], convo_iter.last_updated[-2:] ) ) ] ) print( f"{count+1:>8d}/{total:8d} ({100*(count+1)/total:6.2f}%)\tQ: {q} \tA: {a}", flush=True, ) except Exception as e: # pylint: disable=W0718 rich.print( "[red]WARNING:[/red] Something went wrong when reading the result.", flush=True, ) print(f"Result: {convo_iter.convos[uid]}", flush=True) print(e, flush=True) except Exception as e: # pylint: disable = W0718 rich.print( "[red]WARNING:[/red] Something went wrong when running the pipeline." " Saving existing results and then terminating.", flush=True, ) print(e, flush=True) print(flush=True) with open(out_file, "w") as f: for uid, res in convo_iter.convos.items(): json.dump({"uid": uid, "res": res}, f) f.write("\n") ================================================ FILE: ssvd/utils.py ================================================ """Extra utils for annotations scripts, main for post-processing""" import re from functools import lru_cache from typing import Dict, List, Pattern, Sequence, Tuple PIXELPROSE_TRIM_CANDIDATES = ( "The image is", "This image is", "The background is", "The text is in", "The font is", "The style of the image is", "This is a photograph", ) def preprocess_pixelprose_captions(caption: str) -> Dict[str, str]: """Preprocess PixelProse captions""" caption = caption.strip() if caption.startswith("This image displays"): caption = caption[len("This image displays:") :].strip() caption = caption[0].upper() + caption[1:] sentences = [s.strip().replace("\n", " ") for s in caption.split(".")] sentences = [x for x in sentences if len(x) > 0] if len(sentences) > 0: for idx, sentence in enumerate(sentences[2:], 2): if any(sentence.startswith(c) for c in PIXELPROSE_TRIM_CANDIDATES): sentences = sentences[:idx] break if not sentences[-1].endswith("."): sentences[-1] += "." return {"caption": ". 
".join(sentences)} def maybe_shorten_caption(caption: str, max_cap_len: int = 1500) -> str: """Postprocess a caption to shorten it to a max number of characters (avoid OOM)""" if len(caption) < max_cap_len: shortened_cap = caption else: shortened_cap = "" for sentence in caption.split("."): if len(shortened_cap) + len(sentence) < max_cap_len: shortened_cap += sentence + "." else: break if shortened_cap[-2:] == "..": shortened_cap = shortened_cap[:-1] if not shortened_cap: shortened_cap = caption[:max_cap_len] return shortened_cap @lru_cache def compile_pattern(s: str) -> Pattern: """cached compile""" return re.compile(s) @lru_cache def get_replace_pattern() -> Pattern: """Light postprocessing of the LLM output""" left_right_replace = r'([*\s"]?)+' speaker_string = r"(Speaker [1-2]|Me|Question|Answer):(\s[(].+[)])?" pattern = re.compile( f'({left_right_replace + speaker_string + left_right_replace}|"$)' ) return pattern def get_strings_for_logging( s: List[Dict], length_q: int = 40, length_a: int = 160 ) -> Tuple[str, str]: """Postprocess for logging""" q, a = "None", "None" if not s: return q, a if isinstance(s[0], dict): if "question" in s[0]: q, a = s[0]["question"], s[0]["answer"] elif "caption" in s[0]: q, a = s[0]["caption"], s[1]["caption"] elif "text" in s[0]: q, a = s[0]["text"], s[1]["text"] else: q, a = "None", "None" if isinstance(s[0], str): q, a = s[0], s[1] def __extend_string__(s: str, length: int) -> str: if len(s) < length: return s + " " * (length - len(s)) return s[: length - 3] + "..." return __extend_string__(q, length=length_q), __extend_string__(a, length=length_a) def sanitize_line(s: str) -> str: """Some extra post-processing on the lines""" if not isinstance(s, str): raise ValueError s = s.replace("*", "").strip() if s[0] in ['"', "'"]: s = s[1:] if s[-1] in ['"', "'"]: s = s[:-1] return s.strip() def postprocess_synth_annot( uid: str, res: Dict[str, str] | List[str], idx: Dict[str, int], min_num_turns: int = 3, trim_first_question: bool = False, ) -> Sequence: """Postprocess synthetic annotations (jsonl contents) to the format which will be ultimately stored in the annotations database (one row per speaker turn) """ rows = [] try: speaker = 1 # MTCs always start with OTHER for turn, it in enumerate(res): speaker = int((turn % 2) == 0) if turn == 0 and speaker == 1 and trim_first_question: it = it[: it.find("?") + 1] # remove potentially repeated question in the answer given by SPEAKER_MAIN if speaker == 0: pos = it.find("?") if 0 <= pos < len(it) - 10: it = it[pos + 1 :] # stop at very short answers if len(it) < 2: break # Otherwise, add the row rows.append((uid, idx[uid], turn, speaker, sanitize_line(it))) # Skip if dialogue is too short if len(rows) < min_num_turns: raise KeyError # Update for future generation idx[uid] += 1 # if any error occured, skip this dialog entirely except (KeyError, ValueError): return [] # Rows to write in the database return rows