Repository: kyutai-labs/moshivis
Branch: main
Commit: 8624f4e01d4b
Files: 170
Total size: 751.6 KB
Directory structure:
gitextract_d_6tdl44/
├── .dockerignore
├── .gitattributes
├── .github/
│ ├── actions/
│ │ └── rust_build/
│ │ └── action.yml
│ ├── requirements_github_actions.txt
│ └── workflows/
│ ├── checks.yml
│ └── rust-ci.yml
├── .gitignore
├── CONTRIBUTING.md
├── ISSUE_TEMPLATE/
│ ├── bug.yml
│ └── question.yml
├── LICENSE-APACHE
├── LICENSE-MIT
├── LICENSE.md
├── PULL_REQUEST_TEMPLATE.md
├── README.md
├── client/
│ ├── .eslinrc.json
│ ├── .nvmrc
│ ├── .prettierignore
│ ├── .prettierrc.json
│ ├── Dockerfile
│ ├── LICENSE
│ ├── README.md
│ ├── index.html
│ ├── package.json
│ ├── postcss.config.js
│ ├── public/
│ │ └── assets/
│ │ ├── decoderWorker.min.wasm
│ │ └── images/
│ │ └── demo/
│ │ └── attribution.txt
│ ├── src/
│ │ ├── app.tsx
│ │ ├── audio-processor.ts
│ │ ├── components/
│ │ │ ├── Button/
│ │ │ │ └── Button.tsx
│ │ │ ├── ImageGallery/
│ │ │ │ └── ImageGallery.tsx
│ │ │ └── Input/
│ │ │ └── Input.tsx
│ │ ├── decoder/
│ │ │ └── decoderWorker.ts
│ │ ├── env.ts
│ │ ├── index.css
│ │ ├── modules.d.ts
│ │ ├── pages/
│ │ │ ├── Conversation/
│ │ │ │ ├── Conversation.tsx
│ │ │ │ ├── MediaContext.ts
│ │ │ │ ├── SocketContext.ts
│ │ │ │ ├── components/
│ │ │ │ │ ├── AudioVisualizer/
│ │ │ │ │ │ ├── AudioVisualizer.tsx
│ │ │ │ │ │ ├── ClientVisualizer.tsx
│ │ │ │ │ │ └── ServerVisualizer.tsx
│ │ │ │ │ ├── Controls/
│ │ │ │ │ │ └── Controls.tsx
│ │ │ │ │ ├── ModelParams/
│ │ │ │ │ │ └── ModelParams.tsx
│ │ │ │ │ ├── ServerAudio/
│ │ │ │ │ │ ├── ServerAudio.tsx
│ │ │ │ │ │ └── ServerAudioStats.tsx
│ │ │ │ │ ├── ServerInfo/
│ │ │ │ │ │ └── ServerInfo.tsx
│ │ │ │ │ ├── TextDisplay/
│ │ │ │ │ │ ├── TextDisplay.tsx
│ │ │ │ │ │ └── TextDisplayStats.tsx
│ │ │ │ │ └── UserAudio/
│ │ │ │ │ ├── UserAudio.tsx
│ │ │ │ │ └── UserAudioStats.tsx
│ │ │ │ ├── getMimeType.ts
│ │ │ │ └── hooks/
│ │ │ │ ├── audioUtils.ts
│ │ │ │ ├── useModelParams.ts
│ │ │ │ ├── useServerAudio.ts
│ │ │ │ ├── useServerInfo.ts
│ │ │ │ ├── useServerText.ts
│ │ │ │ ├── useSocket.ts
│ │ │ │ └── useUserAudio.ts
│ │ │ └── Queue/
│ │ │ ├── Queue.tsx
│ │ │ ├── api/
│ │ │ │ ├── client.ts
│ │ │ │ ├── errors/
│ │ │ │ │ ├── api_error.ts
│ │ │ │ │ └── response_error.ts
│ │ │ │ └── validators.ts
│ │ │ └── hooks/
│ │ │ └── useUserEmail.ts
│ │ └── protocol/
│ │ ├── encoder.ts
│ │ ├── testMessages.ts
│ │ └── types.ts
│ ├── tailwind.config.js
│ ├── tsconfig.json
│ └── vite.config.ts
├── docker-bake.hcl
├── kyuteye_mlx/
│ ├── .pylintrc
│ ├── LICENSE
│ ├── MANIFEST.in
│ ├── README.md
│ ├── kyuteye_mlx/
│ │ ├── __init__.py
│ │ ├── benchmark.py
│ │ ├── local_web.py
│ │ ├── mlx_vlm/
│ │ │ ├── LICENSE
│ │ │ ├── __init__.py
│ │ │ └── models/
│ │ │ ├── __init__.py
│ │ │ ├── pixtral/
│ │ │ │ ├── __init__.py
│ │ │ │ └── vision.py
│ │ │ └── siglip/
│ │ │ └── vision.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── generate.py
│ │ │ ├── lm.py
│ │ │ ├── pixtral.py
│ │ │ └── siglip.py
│ │ ├── modules/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── cross_attention.py
│ │ │ ├── kv_cache.py
│ │ │ └── transformer.py
│ │ ├── py.typed
│ │ ├── quantize.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── loading.py
│ │ ├── profiling.py
│ │ └── sampling.py
│ ├── pixtral-12b-8bit.config
│ ├── pyproject.toml
│ ├── siglip448.config
│ └── tests/
│ └── test_siglip.py
├── kyuteye_pt/
│ ├── .pylintrc
│ ├── LICENSE.md
│ ├── README.md
│ ├── configs/
│ │ └── moshika-vis.yaml
│ ├── kyuteye/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── enums.py
│ │ │ ├── kyuteye_config.py
│ │ │ └── subconfigs.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── docker-bake.hcl
│ │ │ ├── helium.py
│ │ │ ├── hf_model_configs.py
│ │ │ ├── image_projection.py
│ │ │ ├── loaders.py
│ │ │ └── moshivis.py
│ │ ├── modules/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── cross_attention.py
│ │ │ ├── image_encoder.py
│ │ │ ├── image_transforms.py
│ │ │ ├── streaming_utils.py
│ │ │ ├── transformer.py
│ │ │ └── utils.py
│ │ ├── server.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── dist_utils.py
│ │ ├── logging_utils.py
│ │ └── struct_utils.py
│ ├── pyproject.toml
│ └── tests/
│ └── hello.py
├── kyuteye_rs/
│ ├── Cargo.toml
│ ├── configs/
│ │ ├── config-moshika-vis-q8.json
│ │ └── config-moshika-vis.json
│ ├── moshi-backend/
│ │ ├── Cargo.toml
│ │ ├── build.rs
│ │ └── src/
│ │ ├── audio.rs
│ │ ├── build.rs
│ │ ├── image_embedder.rs
│ │ ├── main.rs
│ │ ├── metrics.rs
│ │ ├── standalone.rs
│ │ ├── stream_both.rs
│ │ └── utils.rs
│ └── moshi-core/
│ ├── Cargo.toml
│ └── src/
│ ├── conv.rs
│ ├── dynamic_logits_processor.rs
│ ├── lib.rs
│ ├── lm.rs
│ ├── lm_generate.rs
│ ├── lm_generate_multistream.rs
│ ├── mimi.rs
│ ├── nn.rs
│ ├── quantization.rs
│ ├── seanet.rs
│ ├── streaming.rs
│ └── transformer.rs
├── scripts/
│ ├── convert_ckpt_utils.py
│ └── get_static_client.py
└── ssvd/
├── README.md
├── __init__.py
├── generate.py
├── multiturn_instruct.py
├── multiturn_prompting.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .dockerignore
================================================
**/target/
**/node_modules/
**/dist
ssvd/synthetic_visual_dialogues/
================================================
FILE: .gitattributes
================================================
*.wav filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .github/actions/rust_build/action.yml
================================================
name: rust_build
description: 'Setup rust env'
inputs:
os:
default: ubuntu-latest
toolchain:
default: stable
target:
default: check
runs:
using: "composite"
steps:
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ inputs.toolchain }}
override: true
- name: cargo cache
uses: actions/cache@v3
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
kyuteye_rs/target/
key: ${{ inputs.os }}-cargo-${{ inputs.target }}-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ inputs.os }}-cargo-
- name: install deps
shell: bash
run: |
sudo apt-get update
sudo apt-get install libasound2-dev
================================================
FILE: .github/requirements_github_actions.txt
================================================
# Main setup
# old version: transformers 4.43.3 and accelerate 0.33.0
# new version (for pixtral): transformers 4.46.0 and accelerate 1.0.1
accelerate==1.0.1
anls
anls-star
av<12
auditok<0.3.0
cython
datasets
deepspeed
demucs
einops
encodec
fasttext
flashy>=0.0.1
gradio
huggingface_hub
hydra_colorlog
hydra-core>=1.1
ipywidgets
jiwer
julius
jupyterlab
librosa
maturin
num2words
numpy
onnxruntime
opencv-python
protobuf
pyannote.audio
pyannote.metrics
pycocoevalcap
pycocotools
sentencepiece
spacy==3.5.2
tensorboard
timm
torch==2.2.0
torchaudio==2.2.0
torchmetrics
torchtyping
torchvision==0.17.0
tqdm
transformers==4.47.0 # need Encodec there.
webdataset==0.2.100 # for sanity
evaluate
rouge-score
xformers==0.0.24
# specific clip commit
clip @ https://github.com/openai/CLIP/archive/master.zip#sha256=11c3593912e6e6446fb0bde144c5ea374f7e19eeab9072c3eb00b59dd8afb706
# launching + code prettifying stuff
fire
rich
pyyaml
black
mypy==1.11.2
pylint
matplotlib
seaborn
================================================
FILE: .github/workflows/checks.yml
================================================
name: Checks
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
workflow_dispatch:
jobs:
pylint_pytorch:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
- name: Static lint analysis with pylint
run: |
cd kyuteye_pt && uv run --locked pylint --rcfile=.pylintrc --fail-under=8.5 ./kyuteye
ruff_mlx:
runs-on: macos-14
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
      - name: Static lint analysis with ruff
run: |
cd kyuteye_mlx && uv run ruff format --diff && uv run ruff check --select I
sanity_check_pytorch:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
- name: Sanity check
run: |
cd kyuteye_pt && uv run --locked sanity-check
sanity_check_mlx:
runs-on: macos-14
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
- name: Sanity check
run: |
cd kyuteye_mlx && uv run --locked sanity-check
sanity_check_rust:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Cache Cargo
uses: actions/cache@v3
with:
path: |
~/.cargo/registry
kyuteye_rs/target
key: cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: cargo-
- run: cd kyuteye_rs && cargo fmt --all -- --check
- name: Ubuntu dependencies
run: |
sudo apt-get update
sudo apt-get install -y -qq libasound2-dev libssl-dev libpulse-dev libdbus-1-dev portaudio19-dev protobuf-compiler
- name: Clippy
        run: cd kyuteye_rs && cargo clippy --locked --workspace --tests --examples -- -D warnings
build_client:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: docker buildx bake client
- run: tail client/dist/index.html
================================================
FILE: .github/workflows/rust-ci.yml
================================================
on:
push:
branches: [ main ]
pull_request:
branches: [ main, refacto ]
name: Rust CI
jobs:
check:
name: Check
defaults:
run:
working-directory: ./kyuteye_rs
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/rust_build
- name: check
shell: bash
run: |
cargo check
- name: clippy
shell: bash
run: |
cargo clippy -- -D warnings
- name: fmt
shell: bash
run: |
cargo fmt --all -- --check
test:
name: Test
defaults:
run:
working-directory: ./kyuteye_rs
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v5
with:
python-version: 3.11
- uses: ./.github/actions/rust_build
with:
target: test
- name: test
shell: bash
run: |
cargo test
================================================
FILE: .gitignore
================================================
~*
__pycache__
*.pt
*.pth
*.ipynb*
*.egg-info
*.jsonl
nohup.out
.idea/*
client/node_modules
client/dist
target/
*.safetensors
.DS_Store
*.lprof
*.prof
cert.pem
key.pem
.mypy_cache
Gemfile.lock
project_page/_site/*
kyuteye_mlx/static/*
ssvd/synthetic_visual_dialogues/
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to MoshiVis
## Pull Requests
MoshiVis is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
However, we certainly welcome them for bug fixes.
1. Fork the repo and create your branch from `main`.
2. If you have changed APIs, update the documentation accordingly.
3. Ensure pre-commit hooks pass properly, in particular the linting and typing.
4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
5. Accept the Contributor License Agreement (see after).
Note that in general, we will not accept refactoring of the code.
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a Contributor License Agreement.
If you agree with the full CLA provided in the next paragraph, copy the following statement into your PR, replacing {your GitHub handle} with your own:
> I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms.
The full CLA is provided as follows:
> I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free,
> irrevocable license to use, modify, distribute, and sublicense my Contributions.
> I understand and accept that Contributions are limited to modifications, improvements, or changes
> to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to
> review, accept, reject, or request changes to any Contributions I submit, and that submitting
> a pull request does not guarantee its inclusion in the project.
> By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify,
> reproduce, distribute, and create derivative works based on my Contributions.
> I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions,
> giving Kyutai-labs full rights to file for and enforce patents.
> I understand that Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me.
> I confirm that my Contributions are original and that I have the legal right to grant this license.
> If my Contributions include third-party materials, I will ensure that I have the necessary permissions
> and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at Kyutai-labs' discretion.
> I acknowledge that I am making these Contributions voluntarily and will not receive any compensation.
> Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties.
> By submitting a pull request, I agree to be bound by these terms.
## Issues
Please submit issues on our GitHub repository.
## License
By contributing to MoshiVis, you agree that your contributions will be licensed
under the LICENSE-* files in the root directory of this source tree.
In particular, the Rust code is licensed under Apache 2.0, and the Python code under MIT.
================================================
FILE: ISSUE_TEMPLATE/bug.yml
================================================
name: Bug Report
description: You found a bug.
labels: ["bug", "triage"]
body:
- type: dropdown
id: backend
attributes:
label: Backend impacted
description: Which backend is concerned with your bug report?
options:
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
description: What is your operating system?
options:
- Linux
- Mac OS X
- Windows (unsupported)
default: 0
validations:
required: true
- type: dropdown
id: hardware
attributes:
label: Hardware
description: What hardware are you using?
options:
- CPU
- GPU with CUDA
- Metal with MLX
default: 0
validations:
required: true
- type: textarea
id: description
attributes:
label: Description
description: Provide a detailed description of your bug.
placeholder:
value:
validations:
required: true
- type: textarea
id: more_info
attributes:
label: Extra information
description: Please provide any other relevant information, such as log extracts, code etc.
placeholder:
value:
validations:
required: true
- type: textarea
id: env
attributes:
label: Environment
      description: Please fill in the following information about your environment.
placeholder:
value: |
Fill in the following information on your system.
- Operating system version:
If the backend impacted is PyTorch:
- Python version:
- PyTorch version:
- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
- GPU model and memory:
If the backend is MLX:
- Mac model:
validations:
required: true
================================================
FILE: ISSUE_TEMPLATE/question.yml
================================================
name: Question
description: You have a question about Moshi/Mimi or this codebase.
labels: ["question", "triage"]
body:
- type: markdown
attributes:
value: |
Please first check the [FAQ](https://github.com/kyutai-labs/moshi/blob/main/FAQ.md).
- type: checkboxes
id: terms
attributes:
label: Due diligence
description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
options:
- label: I have done my due diligence in trying to find the answer myself.
required: true
- type: dropdown
id: backend
attributes:
label: Topic
description: What is your question about?
options:
- The paper
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: textarea
id: question
attributes:
label: Question
description: What is your question?
placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
value:
validations:
required: true
================================================
FILE: LICENSE-APACHE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: LICENSE-MIT
================================================
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
================================================
FILE: LICENSE.md
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: PULL_REQUEST_TEMPLATE.md
================================================
## Checklist
- [ ] Read CONTRIBUTING.md, and accept the CLA by including the provided snippet. We will not accept PRs without this.
- [ ] Run pre-commit hook.
- [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
## PR Description
================================================
FILE: README.md
================================================
# M👁️shiVis: Teaching Speech Models to Converse about Images

[[Preprint]][moshi-vision-arxiv] [[Demo]][talk-to-moshivis] [[Models on Hugging Face]](https://huggingface.co/collections/kyutai/)
MoshiVis is a **Vision Speech Model** (VSM) that builds directly on the speech-text foundation model [Moshi][moshi-arxiv], augmenting it with the ability to freely discuss an image while maintaining its natural conversation style and low latency. In total, MoshiVis adds $\sim$206M adapter parameters on top of the 7B Moshi and a pretrained, frozen 400M PaliGemma2 vision encoder.
This repository currently contains inference code to run your own MoshiVis server, supporting three different backends served through a web UI frontend. We also plan to release training/finetuning code in the future.
For more information about our speech codec Mimi and speech model Moshi, please visit the original [Moshi repo][moshi-github].
For more technical details on MoshiVis, see our [blog post][blog] and [preprint][moshi-vision-arxiv].
[Talk to MoshiVis][talk-to-moshivis] now on our live demo!
To inject visual inputs into Moshi's stream of *speech tokens*, we extend the core transformer with a **cross-attention mechanism** that infuses visual information into the token stream. To maintain Moshi's **low latency** and reduce memory usage, the cross-attention projection weights are shared **across layers**.
Moreover, to ensure that Moshi’s original conversational abilities are not lost in the process, the cross-attention modules feature a gating mechanism that allows the model to modulate the visual input stream at will.
For more details on MoshiVis, including our training pipeline, synthetic data generation pipeline, and ablation experiments on the gating mechanism, see our [preprint][moshi-vision-arxiv].
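To make the mechanism concrete, here is a minimal, self-contained PyTorch sketch of a gated cross-attention block with layer-shared projections. All module names, shapes, and the exact gating form are our own assumptions for illustration, not the actual `kyuteye_pt` implementation.
```python
# Hedged sketch of gated cross-attention with projections shared across
# layers; names and shapes are illustrative assumptions only.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedCrossAttentionProjections(nn.Module):
    """Query/key/value/output projections shared by every transformer layer."""

    def __init__(self, dim: int, image_dim: int):
        super().__init__()
        self.q = nn.Linear(dim, dim, bias=False)
        self.kv = nn.Linear(image_dim, 2 * dim, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)


class GatedCrossAttention(nn.Module):
    """Per-layer gate over a cross-attention whose weights live in `shared`."""

    def __init__(self, shared: SharedCrossAttentionProjections, dim: int, num_heads: int):
        super().__init__()
        self.shared = shared
        self.num_heads = num_heads
        # The gate lets the model modulate (or shut off) the visual stream,
        # helping preserve Moshi's original conversational abilities.
        self.gate = nn.Linear(dim, dim, bias=False)

    def forward(self, speech: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        b, t, d = speech.shape
        h = self.num_heads
        q = self.shared.q(speech).view(b, t, h, d // h).transpose(1, 2)
        k, v = self.shared.kv(image).chunk(2, dim=-1)
        k = k.view(b, -1, h, d // h).transpose(1, 2)
        v = v.view(b, -1, h, d // h).transpose(1, 2)
        ctx = F.scaled_dot_product_attention(q, k, v)
        ctx = ctx.transpose(1, 2).reshape(b, t, d)
        # Input-dependent gate on the residual branch.
        return speech + torch.sigmoid(self.gate(speech)) * self.shared.out(ctx)


shared = SharedCrossAttentionProjections(dim=512, image_dim=768)
layers = [GatedCrossAttention(shared, dim=512, num_heads=8) for _ in range(4)]
x = torch.randn(1, 16, 512)    # speech token stream
img = torch.randn(1, 64, 768)  # image embeddings
for layer in layers:           # only the tiny per-layer gates differ
    x = layer(x, img)
```
Because only the small gating weights are per-layer, the memory cost of adding cross-attention grows slowly with depth.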
## Model Release
We release MoshikaVis, based on the original Moshika (*female voice*) checkpoints from Moshi's open-source release. For the image embedding part, we rely on publicly available off-the-shelf image-text encoders: the checkpoints we release use the frozen weights of a vision encoder from the [PaliGemma2](https://arxiv.org/abs/2412.03555) family, specifically the weights provided on [Hugging Face](https://huggingface.co/google/paligemma2-3b-pt-448). Note that for convenience, each MoshiVis checkpoint contains the full model: i.e., the vision adaptation module weights are bundled together with the weights of Mimi (speech codec), the Helium text tokenizer, the image encoder, and the base Moshi model.
For each model, we release several variants compatible with three different backends and quantization formats. Further instructions for each backend can be found below.
| Backend | Moshi**ka** |
| ------- | ----------- |
| [PyTorch](#pytorch-backend) | [BF16](https://huggingface.co/kyutai/moshika-vis-pytorch-bf16) |
| [Rust](#rust-backend) | [BF16](https://huggingface.co/kyutai/moshika-vis-candle-bf16) [Q8_0](https://huggingface.co/kyutai/moshika-vis-candle-q8) |
| [MLX](#mlx-backend) | [BF16](https://huggingface.co/kyutai/moshika-vis-mlx-bf16) |
All model weights (*excluding the bundled vision encoder*) are released under the CC-BY 4.0 license; The bundled vision encoder (*PaliGemma2's vision encoder*) is released under the [Gemma license](https://ai.google.dev/gemma/terms).
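For instance, a full checkpoint can be fetched locally with `huggingface_hub` before pointing a backend at it; a minimal sketch, using one of the repo ids from the table above:
```python
# Hedged sketch: download one released checkpoint with huggingface_hub.
from huggingface_hub import snapshot_download

# The repo id comes from the table above; the exact directory layout of
# the downloaded files is not documented here and may vary.
local_dir = snapshot_download("kyutai/moshika-vis-pytorch-bf16")
print(local_dir)  # contains the bundled Mimi, tokenizer, encoder and Moshi weights
```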
## Organisation of the Repository
For the **frontend**, we recommend using the provided web UI as it allows for additional echo cancellation that helps
the overall model quality. To obtain the client, you can either **(i)** build it yourself from the sources in [`client`](client/) as [described here](#webui) or **(ii)** download the pre-built static
version we provide:
```bash
# Download prebuilt client sources
# option 1: using uv dependency manager
uv run scripts/get_static_client.py
# OR option 2: with pip
pip install fire rich huggingface_hub
python scripts/get_static_client.py
```
Most commands below will serve this UI by default using the `https` protocol (see more info [here](#http-vs-https)). To connect via `https`, you will need to generate SSL certificates first, as follows:
```bash
# Generate the SSL certificates in the root directory
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout key.pem -out cert.pem
```
We provide three different **backends** for the MoshiVis inference stack in this repo. While we hope that the present codebase will work on Windows, we do not provide official support for it.
- A [PyTorch](#pytorch-backend) version in the [`kyuteye_pt`](kyuteye_pt) directory.
- A [Rust](#rust-backend) version (as used in the online demo) in the [`kyuteye_rs`](kyuteye_rs/) directory.
- An [MLX](#mlx-backend) version (tested on a MacBook Pro M3) in the [`kyuteye_mlx`](kyuteye_mlx/) directory.
For the PyTorch and MLX backends, we recommend using [uv](https://docs.astral.sh/uv/) to set up and run the code,
as it will manage all dependencies for you transparently.
`uv` is provided as a lightweight binary and can be installed as:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
### PyTorch Backend
> Note: At the moment, we do not support quantization
> for the PyTorch version, so you will need a GPU with a significant amount of memory ($\sim$ 24GB).
You can start the MoshiVis PyTorch server with the following command and then access the web UI at [https://localhost:8088](https://localhost:8088):
```bash
cd kyuteye_pt
uv run server configs/moshika-vis.yaml --port 8088
```
Note that if your GPU is on a remote machine, you may need to forward the remote 8088 port to your localhost using ssh's `-L` flag, then connect to [https://localhost:8088](https://localhost:8088) as mentioned previously.
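A hypothetical forwarding command (the host name is a placeholder for your own machine):
```bash
# Forward the remote server port to your local machine
ssh -L 8088:localhost:8088 user@remote-gpu-host
```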
### Rust Backend
> For the Rust backend, you will need a recent version of the [Rust toolchain](https://rustup.rs/).
> To compile GPU support, you will need a valid [CUDA](https://developer.nvidia.com/cuda-toolkit) installation, in particular with `nvcc`.
In order to run the Rust inference server, use the following command:
```bash
cd kyuteye_rs
pip install pkg-config
cargo run --features cuda --bin moshi-backend -r -- --config configs/config-moshika-vis.json standalone --vis
```
When using macOS, you can replace `--features cuda` with `--features metal`.
Alternatively, you can use `config-moshika-vis-q8.json` rather than `config-moshika-vis.json` to use the
quantized q8 model. You can also change some of the server options (e.g., the starting port) directly in the json file.
Once the server has printed 'standalone worker listening', the model is ready.
By default the Rust server will be accessible at [https://localhost:8088](https://localhost:8088).
### MLX Backend
We provide an MLX model checkpoint in `bfloat16` as well as quantized checkpoints
using `q4` and `q8`.
To start the MoshiVis MLX backend you can then run the following commands:
```bash
cd kyuteye_mlx
# In bfloat16 - weights will be downloaded from HF
uv run server
# In q4
uv run server -q 4
# In q8
uv run server -q 8
```
You can then access the web UI at [http://localhost:8008](http://localhost:8008).
Note that unlike the other backends, not all settings available in the web UI are propagated to the MLX backend. Instead, you can configure some options directly via the command line, e.g. `--text-temperature`.
### Frontends
#### WebUI
We recommend using the WebUI frontend as explained [here](#organisation-of-the-repository). If you want to build the sources yourself, follow these steps (further installation and build instructions can be found in the `client` directory):
**via NPM.**
```bash
cd client
npm install
npm run build
```
**via Docker.** If you have `docker` installed, you can also build the client via
```bash
docker buildx bake client
```
After building the sources, the static dir for the web UI can be found in the
`client/dist` directory, and will be used by default by the different backends.
#### Rust Command Line
Alternatively, we also provide a command line interface for the Rust backend:
```bash
cd kyuteye_rs;
cargo run --bin moshi-cli -r -- tui --host localhost
```
## Troubleshooting
### http vs https
By default, the web UI server starts with the `https` protocol rather than `http`. Accessing a non-localhost server via `http` may prevent the web UI from using the microphone (some browsers only allow microphone access over `https`).
To use an `https` connection, you will first need to setup SSL certificates:
```bash
# Generate the SSL certificates in the root directory
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout key.pem -out cert.pem
```
Note that if you want to use an `http` connection instead you can:
* For the PyTorch backend, add the flag `--ssl False`
* For the MLX backend, `http` is the default and `https` can be used with `--ssl certdir` where `certdir` is the directory that contains the certificates.
Note that when using `https` you may get warnings from the browser about the site being unsafe.
In Chrome, for instance, you can bypass these by selecting "Details" or "Advanced", then "Visit this unsafe
site" or "Proceed to localhost (unsafe)".
## License
The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.
The web client code is provided under the MIT license.
The model weights (*excluding the vision encoder*) are released under the CC-BY 4.0 license; the bundled vision encoder is covered by the [Gemma license](https://ai.google.dev/gemma/terms), as noted above.
All images displayed in the web UI are obtained under the free Unsplash license. For the precise list of image urls and authors, please refer to [this file](client/public/assets/images/demo/attribution.txt).
## Datasets
We also release two data-related artifacts to accompany MoshiVis:
* In the `ssvd` directory, we include code and instructions to reproduce our synthetic visual dialogue datasets described in Section 3.3 and Appendix E of our preprint
* For evaluation purposes, we also release [`Babillage`](https://huggingface.co/datasets/kyutai/Babillage) on HuggingFace, which contains spoken versions of three common VLM benchmarks (COCO-Captions 2014, OCR-VQA and VQAv2) for prompting the model's visual understanding in audio form.
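As a hedged sketch, Babillage can be browsed with the Hugging Face `datasets` library; the configuration name below is an assumption, so check the dataset card for the actual names:
```python
# Hedged sketch: loading Babillage with Hugging Face `datasets`.
# See https://huggingface.co/datasets/kyutai/Babillage for actual
# configuration and split names.
from datasets import load_dataset

babillage = load_dataset("kyutai/Babillage", "vqav2")  # hypothetical config name
print(babillage)
```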
## Citation
If you use MoshiVis in your research, please cite our work:
```
@article{kyutai2025moshivis,
author = {Amélie Royer and Moritz Böhle and Gabriel de Marmiesse and
Laurent Mazaré and Alexandre Défossez and Neil Zeghidour and Patrick Pérez},
year = {2025},
title = {Vision-Speech Models: Teaching Speech Models to Converse about Images},
journal = {ArXiv},
url = {https://arxiv.org/abs/2503.15633}
}
@techreport{kyutai2024moshi,
title={Moshi: a speech-text foundation model for real-time dialogue},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and
Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
year={2024},
eprint={2410.00037},
archivePrefix={arXiv},
primaryClass={eess.AS},
url={https://arxiv.org/abs/2410.00037},
}
```
[blog]: https://kyutai.org/moshivis
[moshi-vision-arxiv]: https://arxiv.org/abs/2503.15633
[moshi-arxiv]: https://arxiv.org/abs/2410.00037
[moshi-github]: https://github.com/kyutai-labs/moshi/tree/main?tab=readme-ov-file#models
[talk-to-moshivis]: https://vis.moshi.chat
================================================
FILE: client/.eslinrc.json
================================================
{
"env": {
"browser": true,
"es2021": true
},
"extends": [
"plugin:react/recommended",
"standard-with-typescript",
"plugin:import/typescript",
"plugin:prettier/recommended"
],
"parser": "@typescript-eslint/parser",
"overrides": [],
"parserOptions": {
"ecmaVersion": "latest",
"sourceType": "module",
"project": "./tsconfig.json"
},
"plugins": ["react", "prettier"],
"rules": {
"@typescript-eslint/triple-slash-reference": "off"
}
}
================================================
FILE: client/.nvmrc
================================================
v20.12.2
================================================
FILE: client/.prettierignore
================================================
dist/*
================================================
FILE: client/.prettierrc.json
================================================
{
"arrowParens": "avoid",
"singleQuote": false,
"trailingComma": "all",
"tabWidth": 2,
"useTabs": false,
"semi": true,
"printWidth": 80,
"plugins": ["prettier-plugin-tailwindcss"]
}
================================================
FILE: client/Dockerfile
================================================
FROM node:20 AS builder
WORKDIR /app
COPY . /app
RUN npm install
RUN npx vite build
FROM scratch AS build_result
COPY --from=builder /app/dist /
================================================
FILE: client/LICENSE
================================================
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
================================================
FILE: client/README.md
================================================
# moshi-client
Frontend for the demo.
## Quickstart
To start developing, you will need a basic environment with NodeJS, for instance:
```bash
cd client
micromamba create -n node22 python=3.10
micromamba activate node22
micromamba install nodejs=22.11
# install
npm install
```
Alternatively, you can use [NVM](https://github.com/nvm-sh/nvm) to manage your node version and make sure you're on the recommended version for this project. If you do so, run `nvm use`.
To run the client in dev mode, use:
```bash
# typically will start on port 5173
npm run dev
```
When you're satisfied, build the client (into the `dist` directory) that will be used as the
static dir by the different backends:
```bash
npm run build
```
If Docker is available, you can skip all the previous steps and just run
```
docker buildx bake
```
from the root of this repository. It will output the static sources for the website in `client/dist`.
### License
The present code is provided under the MIT license.
================================================
FILE: client/index.html
================================================
moshi.chat
================================================
FILE: client/package.json
================================================
{
"name": "kyutai-client",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc && vite build",
"lint": "eslint",
"lint:fix": "eslint --fix",
"prettier": "prettier --write .",
"preview": "vite preview"
},
"devDependencies": {
"@eslint/js": "^9.3.0",
"@types/react": "^18.3.1",
"@types/react-dom": "^18.3.0",
"@types/ws": "^8.5.10",
"autoprefixer": "^10.4.19",
"daisyui": "^4.12.2",
"eslint": "^8.57.0",
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-prettier": "^5.1.3",
"eslint-plugin-react": "^7.34.1",
"globals": "^15.2.0",
"postcss": "^8.4.38",
"prettier": "^3.2.5",
"prettier-eslint": "^16.3.0",
"prettier-plugin-tailwindcss": "^0.5.14",
"tailwindcss": "^3.4.3",
"typescript": "^5.2.2",
"typescript-eslint": "^7.9.0",
"vite": "^6.2.1",
"vite-plugin-top-level-await": "^1.4.1"
},
"dependencies": {
"eruda": "^3.0.1",
"opus-recorder": "^8.0.5",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-router-dom": "^6.23.1",
"webm-duration-fix": "^1.0.4",
"ws": "^8.16.0",
"zod": "^3.23.8"
}
}
================================================
FILE: client/postcss.config.js
================================================
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
};
================================================
FILE: client/public/assets/images/demo/attribution.txt
================================================
image1.jpg https://unsplash.com/photos/seven-brushes-and-water-color-palette-TTwwVG4Isjw Crystal de Passillé-Chabot
image2.jpg https://unsplash.com/photos/a-bunch-of-stuffed-animals-that-are-on-a-shelf-uTDvpjnF2nw Hoyoun Lee
image3.jpg https://unsplash.com/photos/an-orange-and-white-clownfish-in-an-aquarium-8iHpPG7Vk9Y James Lee
image4.jpg https://unsplash.com/photos/red-dragon-action-figure-on-table-X-A-LJVAhzk Clint Bustrillos
image5.jpg https://unsplash.com/photos/panda-bear-sitting-on-bamboo-sticks-surrounded-with-trees-NsNRu6dfRds Ying Wu
image6.jpg https://unsplash.com/photos/carousel-with-string-lights-OEBeLcrzlaw cmophoto.net
image7.jpg https://unsplash.com/photos/flight-of-pigeons-flying-above-grass-field-near-eiffel-tower-in-paris-m45uW4f9YQg Stijn te Strake
image8.jpg https://unsplash.com/photos/gray-typewriter-and-macbook-1F4MukO0UNg Glenn Carstens-Peters
image9.jpg https://unsplash.com/photos/a-couple-of-sandwiches-sitting-on-top-of-a-cutting-board-smv9xho-dnE Deepthi Clicks
image10.jpg https://unsplash.com/photos/man-in-white-hat-and-black-shirt-painting-nkVa4ylaWG0 Federico Scarionati
image11.jpg https://unsplash.com/photos/a-statue-of-a-gnome-next-to-a-light-pole-UNT9ExjTgZE Lionel Mermoz
image12.jpg https://unsplash.com/photos/brown-concrete-building-on-green-grass-field-during-daytime-mDIMJzdu5D0 Celine Chamiot-Poncet
image13.jpg https://unsplash.com/photos/astronaut-in-spacesuit-floating-in-space-Yj1M5riCKk4 NASA
image14.jpg https://unsplash.com/photos/marble-toy-lot-near-yellow-drawstring-pouch-1kZzV02D2hM Crissy Jarvis
image15.jpg https://unsplash.com/photos/person-holding-black-frying-pan-APDMfLHZiRA Kevin McCutcheon
image16.jpg https://unsplash.com/photos/selective-focus-photo-of-four-green-humming-birds-with-red-flowers-5TU1htuOUn4 James wainscoat
image17.jpg https://unsplash.com/photos/orange-and-white-tabby-cat-sitting-on-brown-wooden-table-in-kitchen-room-w2DsS-ZAP4U Paul Hanoka
image18.jpg https://unsplash.com/photos/lantern-on-the-street-at-nighttime--F3wMFrZ7z0 Denys Nevozhai
image19.jpg https://unsplash.com/photos/baked-breads-in-rack-ZnPNZpjzi0M Dan Gold
image20.jpg https://unsplash.com/photos/lemonades-on-tray-JB5YCqOXV1o Rod Long
================================================
FILE: client/src/app.tsx
================================================
import ReactDOM from "react-dom/client";
import {
createBrowserRouter,
RouterProvider,
} from "react-router-dom";
import "./index.css";
// @ts-expect-error - Worker is not recognized by the TS compiler
import { DecoderWorker } from "./decoder/decoderWorker";
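// Importing DecoderWorker here eagerly spawns the Opus decoder worker (see decoder/decoderWorker.ts),
// presumably so it is ready before a conversation starts.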
import { Queue } from "./pages/Queue/Queue";
const router = createBrowserRouter([
{
path: "/",
    element: <Queue />,
},
]);
ReactDOM.createRoot(document.getElementById("root") as HTMLElement).render(
  <RouterProvider router={router} />,
);
================================================
FILE: client/src/audio-processor.ts
================================================
// @ts-nocheck
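// `sampleRate`, `currentTime` and `currentFrame` below are globals provided by the
// AudioWorkletGlobalScope this processor runs in, hence the @ts-nocheck above.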
function asMs(samples) {
return (samples * 1000 / sampleRate).toFixed(1);
}
function asSamples(millis) {
  return Math.round(millis * sampleRate / 1000);
}
class MoshiProcessor extends AudioWorkletProcessor {
constructor() {
super();
console.log("Moshi processor lives", currentFrame, sampleRate);
console.log(currentTime);
// Buffer length definitions
let frameSize = asSamples(80);
// initialBufferSamples: we wait to have at least that many samples before starting to play
this.initialBufferSamples = 2 * frameSize;
// once we have enough samples, we further wait that long before starting to play.
// This allows to have buffer lengths that are not a multiple of frameSize.
this.partialBufferSamples = asSamples(80);
// If the buffer length goes over that many, we will drop the oldest packets until
// we reach back initialBufferSamples + partialBufferSamples.
this.maxBufferSamples = asSamples(80);
// increments
this.partialBufferIncrement = asSamples(40);
this.maxPartialWithIncrements = asSamples(240);
this.maxBufferSamplesIncrement = asSamples(40);
this.maxMaxBufferWithIncrements = asSamples(240);
// State and metrics
this.initState();
this.port.onmessage = (event) => {
if (event.data.type == "reset") {
console.log("Reset audio processor state.");
this.initState();
return;
}
let frame = event.data.frame;
this.frames.push(frame);
if (this.currentSamples() >= this.initialBufferSamples && !this.started) {
this.start();
}
if (this.pidx < 20) {
console.log(this.timestamp(), "Got packet", this.pidx++, asMs(this.currentSamples()), asMs(frame.length))
}
if (this.currentSamples() >= this.totalMaxBufferSamples()) {
console.log(this.timestamp(), "Dropping packets", asMs(this.currentSamples()), asMs(this.totalMaxBufferSamples()));
let target = this.initialBufferSamples + this.partialBufferSamples
while (this.currentSamples() > (this.initialBufferSamples + this.partialBufferSamples)) {
let first = this.frames[0];
let to_remove = this.currentSamples() - target;
to_remove = Math.min(first.length - this.offsetInFirstBuffer, to_remove);
this.offsetInFirstBuffer += to_remove;
this.timeInStream += to_remove / sampleRate;
if (this.offsetInFirstBuffer == first.length) {
this.frames.shift();
this.offsetInFirstBuffer = 0;
}
}
console.log(this.timestamp(), "Packet dropped", asMs(this.currentSamples()));
this.maxBufferSamples += this.maxBufferSamplesIncrement;
this.maxBufferSamples = Math.min(this.maxMaxBufferWithIncrements, this.maxBufferSamples);
console.log("Increased maxBuffer to", asMs(this.maxBufferSamples));
}
let delay = this.currentSamples() / sampleRate;
this.port.postMessage({
totalAudioPlayed: this.totalAudioPlayed,
actualAudioPlayed: this.actualAudioPlayed,
delay: event.data.micDuration - this.timeInStream,
minDelay: this.minDelay,
maxDelay: this.maxDelay,
});
};
}
initState() {
this.frames = new Array();
this.offsetInFirstBuffer = 0;
this.firstOut = false;
this.remainingPartialBufferSamples = 0;
this.timeInStream = 0.;
this.resetStart();
// Metrics
this.totalAudioPlayed = 0.;
this.actualAudioPlayed = 0.;
this.maxDelay = 0.;
this.minDelay = 2000.;
// Debug
this.pidx = 0;
// For now let's reset the buffer params.
this.partialBufferSamples = asSamples(80);
this.maxBufferSamples = asSamples(80);
}
totalMaxBufferSamples() {
return this.maxBufferSamples + this.partialBufferSamples + this.initialBufferSamples;
}
timestamp() {
return Date.now() % 1000;
}
currentSamples() {
let samples = 0;
for (let k = 0; k < this.frames.length; k++) {
samples += this.frames[k].length
}
samples -= this.offsetInFirstBuffer;
return samples;
}
resetStart() {
this.started = false;
}
start() {
this.started = true;
this.remainingPartialBufferSamples = this.partialBufferSamples;
this.firstOut = true;
}
canPlay() {
return this.started && this.frames.length > 0 && this.remainingPartialBufferSamples <= 0;
}
process(inputs, outputs, parameters) {
let delay = this.currentSamples() / sampleRate;
if (this.canPlay()) {
this.maxDelay = Math.max(this.maxDelay, delay);
this.minDelay = Math.min(this.minDelay, delay);
}
const output = outputs[0][0];
if (!this.canPlay()) {
if (this.actualAudioPlayed > 0) {
this.totalAudioPlayed += output.length / sampleRate;
}
this.remainingPartialBufferSamples -= output.length;
return true;
}
if (this.firstOut) {
console.log(this.timestamp(), "Audio resumed", asMs(this.currentSamples()), this.remainingPartialBufferSamples);
}
let out_idx = 0;
while (out_idx < output.length && this.frames.length) {
let first = this.frames[0];
let to_copy = Math.min(first.length - this.offsetInFirstBuffer, output.length - out_idx);
output.set(first.subarray(this.offsetInFirstBuffer, this.offsetInFirstBuffer + to_copy), out_idx);
this.offsetInFirstBuffer += to_copy;
out_idx += to_copy;
if (this.offsetInFirstBuffer == first.length) {
this.offsetInFirstBuffer = 0;
this.frames.shift();
}
}
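    // First output after (re)starting playback: apply a linear fade-in to avoid an audible click.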
if (this.firstOut) {
this.firstOut = false;
for (let i = 0; i < out_idx; i++) {
output[i] *= i / out_idx;
}
}
if (out_idx < output.length) {
console.log(this.timestamp(), "Missed some audio", output.length - out_idx);
this.partialBufferSamples += this.partialBufferIncrement;
this.partialBufferSamples = Math.min(this.partialBufferSamples, this.maxPartialWithIncrements);
console.log("Increased partial buffer to", asMs(this.partialBufferSamples));
// We ran out of a buffer, let's revert to the started state to replenish it.
this.resetStart();
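      // Fade out the samples we did manage to play so the underrun does not end on a click.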
for (let i = 0; i < out_idx; i++) {
output[i] *= (out_idx - i) / out_idx;
}
}
this.totalAudioPlayed += output.length / sampleRate;
this.actualAudioPlayed += out_idx / sampleRate;
this.timeInStream += out_idx / sampleRate;
return true;
}
}
registerProcessor("moshi-processor", MoshiProcessor);
================================================
FILE: client/src/components/Button/Button.tsx
================================================
import { FC } from "react";
type ButtonProps = React.ButtonHTMLAttributes<HTMLButtonElement>;
export const Button: FC<ButtonProps> = ({ children, className, ...props }) => {
  return (
    <button className={className} {...props}>
      {children}
    </button>
  );
};
================================================
FILE: client/src/components/ImageGallery/ImageGallery.tsx
================================================
import { useState, ChangeEvent } from "react";
import { Button } from "../Button/Button";
// Natural images
import img1 from "/assets/images/demo/image1.jpg";
import img2 from "/assets/images/demo/image2.jpg";
import img3 from "/assets/images/demo/image3.jpg";
import img4 from "/assets/images/demo/image4.jpg";
import img5 from "/assets/images/demo/image5.jpg";
import img6 from "/assets/images/demo/image6.jpg";
import img7 from "/assets/images/demo/image7.jpg";
import img8 from "/assets/images/demo/image8.jpg";
import img9 from "/assets/images/demo/image9.jpg";
import img10 from "/assets/images/demo/image10.jpg";
import img11 from "/assets/images/demo/image11.jpg";
import img12 from "/assets/images/demo/image12.jpg";
import img13 from "/assets/images/demo/image13.jpg";
import img14 from "/assets/images/demo/image14.jpg";
import img15 from "/assets/images/demo/image15.jpg";
import img16 from "/assets/images/demo/image16.jpg";
import img17 from "/assets/images/demo/image17.jpg";
import img18 from "/assets/images/demo/image18.jpg";
import img19 from "/assets/images/demo/image19.jpg";
import img20 from "/assets/images/demo/image20.jpg";
const images = [
img1,
img2,
img3,
img4,
img5,
img6,
img7,
img8,
img9,
img10,
img11,
img12,
img13,
img14,
img15,
img16,
img17,
img18,
img19,
img20,
];
const images_order: number[] = images.map((_, i) => i);
type ImageGalleryProps = React.InputHTMLAttributes<HTMLInputElement> & {
// Properties for the ImageGallery
paramsSetter: Function;
clickAction: Function;
size: number;
numImages: number;
}
type ImageItemProps = React.InputHTMLAttributes<HTMLInputElement> & {
// Properties for a single item in the ImageGallery
// Two actions:
// paramsSetter sets the chosen image url into the model params
// clickAction then starts the conversation
paramsSetter: Function;
clickAction: Function;
size: number;
imageUrl: string;
}
function ImageSelect(props: ImageItemProps) {
// Represents a single image in the gallery
const [isHover, setIsHover] = useState(false);
const handleMouseEnter = () => {
setIsHover(true);
};
const handleMouseLeave = () => {
setIsHover(false);
};
let bordercolor = isHover ? "#f7a319" : "black";
let bgalpha = isHover ? 0.05 : 0.6;
  let textalpha = isHover ? 1.0 : 0.0;
let label = isHover ? "Select" : "X";
let style = {
width: props.size,
height: props.size,
background: `url(${props.imageUrl})`,
backgroundSize: "100% 100%",
border: `3px solid ${bordercolor}`,
margin: "2px",
padding: "0px",
color: `rgba(255, 255, 255, ${textalpha})`,
boxShadow: `inset 0 0 0 1000px rgba(0,0,0,${bgalpha})`,
textShadow: `2px 2px 2px rgba(0, 0, 0, ${textalpha})`
};
  return (
    <button
      style={style}
      onMouseEnter={handleMouseEnter}
      onMouseLeave={handleMouseLeave}
      onClick={async () => { await props.paramsSetter(props.imageUrl); sessionStorage.removeItem("imageUrl"); props.clickAction(); }}
    >
      {label}
    </button>
  );
}
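// Sorting with a random comparator is a quick-and-dirty shuffle (not uniformly random),
// which is fine for picking a display order.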
const shuffle = (array: number[]) => {
return array.sort(() => Math.random() - 0.5);
};
export const ImageGallery = (props: ImageGalleryProps) => {
const [ordering, SetOrdering] = useState(images_order);
const [preview, setPreview] = useState(sessionStorage.getItem("imageUrl"));
const handleFileChange = (e: ChangeEvent, isCapture: boolean) => {
if (e.target.files && e.target.files[0]) {
const file = e.target.files[0];
const url = URL.createObjectURL(file);
setPreview(url);
props.paramsSetter(url);
// only save the image URL when it's an uploaded file
// doesn't really seem to work with one-shot photo otherwise
if (!isCapture) {
sessionStorage.setItem("imageUrl", url);
}
}
};
const resetFile = () => {
setPreview(null);
props.paramsSetter(undefined);
sessionStorage.removeItem("imageUrl");
};
function handleShuffle() {
SetOrdering(shuffle([...ordering]));
}
// Image Gallery widget (random subset)
const steps = [];
for (let i = 0; i < props.numImages; i++) {
    steps.push(<ImageSelect key={i} size={props.size} imageUrl={images[ordering[i]]} paramsSetter={props.paramsSetter} clickAction={props.clickAction} />);
}
return (
{preview &&
}
{preview && await props.clickAction()}>Connect }
{preview && X }
{!preview && }
{!preview && }
{!preview && }
{!preview &&
⟳
}
{steps}
)
;
};
================================================
FILE: client/src/components/Input/Input.tsx
================================================
type InputProps = React.InputHTMLAttributes<HTMLInputElement> & {
  error?: string;
};
export const Input = ({ className, error, ...props }: InputProps) => {
  return (
    <>
      <input className={className} {...props} />
      {error && <span>{error}</span>}
    </>
  );
};
================================================
FILE: client/src/decoder/decoderWorker.ts
================================================
export const DecoderWorker = new Worker(
new URL("/assets/decoderWorker.min.js", import.meta.url),
);
================================================
FILE: client/src/env.ts
================================================
type ENV = {
VITE_QUEUE_API_PATH: string;
VITE_ENV: 'development' | 'production';
};
const parseEnv = (): ENV => {
const VITE_QUEUE_API_PATH = import.meta.env.VITE_QUEUE_API_PATH;
if (!VITE_QUEUE_API_PATH) {
throw new Error("VITE_QUEUE_API_PATH is not defined");
}
return {
VITE_QUEUE_API_PATH,
VITE_ENV: import.meta.env.DEV ? 'development' : 'production',
};
};
export const env = parseEnv();
================================================
FILE: client/src/index.css
================================================
@tailwind base;
@tailwind components;
@tailwind utilities;
@layer utilities {
/* Hide scrollbar for Chrome, Safari and Opera */
.no-scrollbar::-webkit-scrollbar {
display: none;
}
/* Hide scrollbar for IE, Edge and Firefox */
.no-scrollbar {
-ms-overflow-style: none;
/* IE and Edge */
scrollbar-width: none;
/* Firefox */
}
.scrollbar::-webkit-scrollbar {
width: 10px;
}
.scrollbar::-webkit-scrollbar-track {
background: transparent;
}
.scrollbar::-webkit-scrollbar-thumb {
background: white;
border: 3px solid #f6f7ed;
}
}
.settingsbutton#changed:before {
content: "C";
width: 13px;
height: 13px;
line-height: 18px;
text-align: center;
display: block;
border-radius: 50%;
background: #54e8b3;
border: 1px solid #FFF;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4);
color: #FFF;
position: absolute;
top: -7px;
right: -7px;
}
.main-grid {
display: grid;
grid-template-columns: 1fr;
grid-template-rows: min-content 1fr 1fr;
gap: 30px;
grid-auto-flow: column;
grid-template-areas:
"controls"
"player"
"player-text";
@media screen and (min-width: 768px) {
grid-template-columns: 2fr 2.5fr;
grid-template-rows: min-content min-content min-content 1fr;
gap: 30px 30px;
grid-auto-flow: column;
align-items: center;
justify-items: center;
grid-template-areas:
"controls controls"
"player player-stats"
"player player-text"
"player player-text";
}
}
.presentation {
max-width: 450px;
}
.presentation>p {
padding-top: 10px;
}
.gallery {
max-width: 450px;
}
.cute-words {
color: #54e8b3;
}
.vis-words {
color: #f7a319;
}
.explain-links {
color: #BCFCE5;
}
.controls {
grid-area: controls;
}
.player {
grid-area: player;
grid-template-areas:
"server-audio"
"user-audio"
"user-image"
"download-links";
display: grid;
grid-template-columns: 1fr 1fr;
grid-template-rows: 3fr;
justify-items: stretch;
row-gap: 30px;
/* margin:auto; */
}
.user-image {
grid-area: user-image;
grid-column: 1 / -1;
grid-row: 1;
height: 200px
}
.user-image img {
height: 100%;
width: auto;
margin: auto
}
.server-audio {
grid-area: server-audio;
grid-column: 1;
grid-row: 2;
}
.user-audio {
grid-area: user-audio;
grid-column: 2;
grid-row: 2;
}
.download-links {
grid-area: download-links;
grid-column: 1/-1;
grid-row: 3;
color: #54e8b3;
height: 10%;
}
.player-stats {
grid-area: player-stats;
width: 100%;
height: 100%;
}
.commands {
grid-area: commands;
width: 100%;
height: 100%;
}
.player-text {
grid-area: player-text;
width: 100%;
height: 100%;
overflow: scroll;
}
================================================
FILE: client/src/modules.d.ts
================================================
declare module "opus-recorder";
================================================
FILE: client/src/pages/Conversation/Conversation.tsx
================================================
import { FC, MutableRefObject, useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useSocket } from "./hooks/useSocket";
import { SocketContext } from "./SocketContext";
import { ServerAudio } from "./components/ServerAudio/ServerAudio";
import { UserAudio } from "./components/UserAudio/UserAudio";
import { Button } from "../../components/Button/Button";
import { ServerAudioStats } from "./components/ServerAudio/ServerAudioStats";
import { AudioStats } from "./hooks/useServerAudio";
import { TextDisplay } from "./components/TextDisplay/TextDisplay";
import { MediaContext } from "./MediaContext";
import { ServerInfo } from "./components/ServerInfo/ServerInfo";
import { ModelParamsValues, useModelParams } from "./hooks/useModelParams";
import { ModelParams } from "./components/ModelParams/ModelParams";
import fixWebmDuration from "webm-duration-fix";
import canvasLogo from "./canvas-logo.png";
import { getMimeType, getExtension } from "./getMimeType";
type ConversationProps = {
workerAddr: string;
workerAuthId?: string;
sessionAuthId?: string;
sessionId?: number;
email?: string;
  audioContext: MutableRefObject<AudioContext>;
  worklet: MutableRefObject<AudioWorkletNode>;
onConversationEnd?: () => void;
isBypass?: boolean;
} & Partial;
const buildURL = ({
workerAddr,
params,
workerAuthId,
email,
textSeed,
audioSeed,
}: {
workerAddr: string;
params: ModelParamsValues;
workerAuthId?: string;
email?: string;
textSeed: number;
audioSeed: number;
}) => {
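  // "same" (or an empty address) means: connect back to the host that served this page.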
if (workerAddr == "same" || workerAddr == "") {
workerAddr = window.location.hostname + ":" + window.location.port;
console.log("Overriding workerAddr to", workerAddr);
}
const wsProtocol = (window.location.protocol === 'https:') ? 'wss' : 'ws';
const url = new URL(`${wsProtocol}://${workerAddr}/api/chat`);
if (workerAuthId) {
url.searchParams.append("worker_auth_id", workerAuthId);
}
if (email) {
url.searchParams.append("email", email);
}
url.searchParams.append("text_temperature", params.textTemperature.toString());
url.searchParams.append("text_topk", params.textTopk.toString());
url.searchParams.append("audio_temperature", params.audioTemperature.toString());
url.searchParams.append("audio_topk", params.audioTopk.toString());
url.searchParams.append("pad_mult", params.padMult.toString());
url.searchParams.append("text_seed", textSeed.toString());
url.searchParams.append("audio_seed", audioSeed.toString());
url.searchParams.append("repetition_penalty_context", params.repetitionPenaltyContext.toString());
url.searchParams.append("repetition_penalty", params.repetitionPenalty.toString());
// Add image params if given
if (params.imageUrl != undefined) {
url.searchParams.append("image_resolution", params.imageResolution.toString());
url.searchParams.append("center_crop", params.centerCrop.toString());
url.searchParams.append("xa_start", params.gateDelay.toString());
url.searchParams.append("text_temperature_gating_influence", params.gateInfluence.toString());
}
return url.toString();
};
export const Conversation: FC<ConversationProps> = ({
workerAddr,
workerAuthId,
audioContext,
worklet,
sessionAuthId,
sessionId,
onConversationEnd,
isBypass = false,
email,
...params
}) => {
const getAudioStats = useRef<() => AudioStats>(() => ({
playedAudioDuration: 0,
missedAudioDuration: 0,
totalAudioMessages: 0,
delay: 0,
minPlaybackDelay: 0,
maxPlaybackDelay: 0,
}));
const isRecording = useRef(false);
  const videoChunks = useRef<Blob[]>([]);
  const audioChunks = useRef<Blob[]>([]);
  const audioStreamDestination = useRef(audioContext.current.createMediaStreamDestination());
  const mediaRecorder = useRef<MediaRecorder | null>(null);
const audioRecorder = useRef(new MediaRecorder(audioStreamDestination.current.stream, { mimeType: getMimeType("audio"), audioBitsPerSecond: 128000 }));
const [videoURL, setVideoURL] = useState("");
const [audioURL, setAudioURL] = useState("");
const [userRating, setUserRating] = useState(0);
const [userRatingHovered, setUserRatingHovered] = useState(0);
const [baseBlobName, setBaseBlobName] = useState("moshi");
const [isOver, setIsOver] = useState(false);
const modelParams = useModelParams(params);
const micDuration = useRef(0);
const actualAudioPlayed = useRef(0);
  const textContainerRef = useRef<HTMLDivElement>(null);
const textSeed = useMemo(() => Math.round(1000000 * Math.random()), []);
const audioSeed = useMemo(() => Math.round(1000000 * Math.random()), []);
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const logoRef = useRef<HTMLImageElement>(null);
const [isLogoLoaded, setIsLogoLoaded] = useState(false);
const WSURL = buildURL({
workerAddr,
params: modelParams,
workerAuthId,
email: email,
textSeed: textSeed,
audioSeed: audioSeed,
});
const onDisconnect = useCallback(() => {
setIsOver(true);
console.log("on disconnect!");
stopRecording();
}, [setIsOver]);
const { isConnected, sendMessage, socket, start, stop } = useSocket({
// onMessage,
uri: WSURL,
onDisconnect,
imageUrl: params.imageUrl,
});
useEffect(() => {
audioRecorder.current.ondataavailable = (e) => {
audioChunks.current.push(e.data);
};
audioRecorder.current.onstop = async () => {
let blob: Blob;
const mimeType = getMimeType("audio");
if (mimeType.includes("webm")) {
blob = await fixWebmDuration(new Blob(audioChunks.current, { type: mimeType }));
} else {
blob = new Blob(audioChunks.current, { type: mimeType });
}
setAudioURL(URL.createObjectURL(blob));
audioChunks.current = [];
console.log("Audio Recording and encoding finished");
};
}, [mediaRecorder, audioRecorder, setVideoURL, setAudioURL, videoChunks, audioChunks]);
const RatingButton = (props: { rating: number }) => {
const [isHover, setIsHover] = useState(false);
const handleMouseEnter = () => {
setUserRatingHovered(props.rating);
setIsHover(true);
};
const handleMouseLeave = () => {
setUserRatingHovered(0);
setIsHover(false);
};
let style = {
color: (isHover || userRating >= props.rating || userRatingHovered >= props.rating) ? `#f7a319` : '#333333',
};
return (
{
setUserRating(props.rating); sendMessage({
type: "user_rating",
data: props.rating,
})
}}
>
★
);
};
useEffect(() => {
start();
return () => {
stop();
};
}, [start, workerAuthId]);
useEffect(() => {
if (!canvasRef) {
console.log("No canvas ref");
return;
}
if (!logoRef) {
console.log("No logo ref");
return;
}
if (!isLogoLoaded) {
console.log("Logo not loaded");
return;
}
if (!canvasRef.current) {
console.log("No canvas");
return;
}
if (!logoRef.current) {
console.log("No logo");
return;
}
const ctx = canvasRef.current.getContext("2d");
if (ctx) {
ctx.drawImage(logoRef.current, 20, 250, 320, 98);
ctx.lineWidth = 1;
ctx.strokeStyle = "white";
ctx.strokeRect(5, 5, 370, 370);
}
}, [canvasRef, logoRef, isLogoLoaded]);
const startRecording = useCallback(() => {
if (isRecording.current) {
return;
}
console.log(Date.now() % 1000, "Starting recording");
console.log("Starting recording");
if (canvasRef.current) {
// Note: Attaching a track from this stream to the existing MediaRecorder
// rather than creating a new MediaRecorder for the canvas stream
// doesn't work on Safari as it just ends the recording immediately.
// It works on Chrome though and is much cleaner.
console.log("Adding canvas to stream");
const captureStream = canvasRef.current.captureStream(30);
captureStream.addTrack(audioStreamDestination.current.stream.getAudioTracks()[0]);
mediaRecorder.current = new MediaRecorder(captureStream, { mimeType: getMimeType("video"), videoBitsPerSecond: 1000000 });
mediaRecorder.current.ondataavailable = (e) => {
console.log("Video data available");
videoChunks.current.push(e.data);
};
mediaRecorder.current.onstop = async () => {
let blob: Blob;
const mimeType = getMimeType("video");
if (mimeType.includes("webm")) {
blob = await fixWebmDuration(new Blob(videoChunks.current, { type: mimeType }));
} else {
blob = new Blob(videoChunks.current, { type: mimeType });
}
setVideoURL(URL.createObjectURL(blob));
videoChunks.current = [];
console.log("Video Recording and encoding finished");
};
}
worklet.current?.connect(audioStreamDestination.current);
// videoStream.current.addTrack(audioStreamDestination.current.stream.getAudioTracks()[0]);
setVideoURL("");
setAudioURL("");
mediaRecorder.current?.start();
audioRecorder.current.start();
isRecording.current = true;
  }, [isRecording, setVideoURL, setAudioURL, worklet, audioStreamDestination, mediaRecorder, audioRecorder, canvasRef]);
const stopRecording = useCallback(() => {
console.log("Stopping recording");
console.log("isRecording", isRecording)
if (!isRecording.current) {
return;
}
worklet.current?.disconnect(audioStreamDestination.current);
audioRecorder.current.stop();
mediaRecorder.current?.stop();
isRecording.current = false;
}, [isRecording, worklet, audioStreamDestination, mediaRecorder, audioRecorder]);
return (
{isOver && !isBypass && (
{
// Reload the page to reset the conversation on iOS
const isIOS = /iPad|iPhone|iPod/.test(navigator.userAgent)
if (onConversationEnd && !isIOS) {
onConversationEnd();
return;
}
sessionStorage.setItem("textTemperature", modelParams.textTemperature.toString());
sessionStorage.setItem("textTopk", modelParams.textTopk.toString());
sessionStorage.setItem("audioTemperature", modelParams.audioTemperature.toString());
sessionStorage.setItem("audioTopk", modelParams.audioTopk.toString());
sessionStorage.setItem("padMult", modelParams.padMult.toString());
sessionStorage.setItem("repetitionPenalty", modelParams.repetitionPenalty.toString());
sessionStorage.setItem("repetitionPenaltyContext", modelParams.repetitionPenaltyContext.toString());
sessionStorage.setItem("imageResolution", modelParams.imageResolution.toString());
sessionStorage.setItem("gateDelay", modelParams.gateDelay.toString());
sessionStorage.setItem("gateInfluence", modelParams.gateInfluence.toString());
sessionStorage.setItem("displayColor", modelParams.displayColor.toString());
sessionStorage.setItem("centerCrop", modelParams.centerCrop.toString());
document.location.reload();
}}
>
Start Over
)
}
{
(!isOver || isBypass) && (
{
audioContext.current.resume();
isConnected ? stop() : start();
}}
>
{!isConnected ? "Connect" : "Disconnect"}
)
}
AudioStats) =>
(getAudioStats.current = callback)
}
/>
Feel free to rate the interaction before ending the session:
{audioURL &&
}
{videoURL &&
}
{videoURL && getExtension("video") === "webm" &&
}
setBaseBlobName(x)} />
{!workerAuthId && }
{
console.log("Logo loaded");
setIsLogoLoaded(true);
}} />
);
};
================================================
FILE: client/src/pages/Conversation/MediaContext.ts
================================================
import { MutableRefObject, createContext, useContext } from "react";
type MediaContextType = {
startRecording: () => void;
stopRecording: () => void;
  audioContext: MutableRefObject<AudioContext>;
  audioStreamDestination: MutableRefObject<MediaStreamAudioDestinationNode>;
  worklet: MutableRefObject<AudioWorkletNode>;
  micDuration: MutableRefObject<number>;
  actualAudioPlayed: MutableRefObject<number>;
};
export const MediaContext = createContext<MediaContextType | null>(null);
export const useMediaContext = () => {
const context = useContext(MediaContext);
if (!context) {
throw new Error(
"useMediaContext must be used within a MediaContextProvider",
);
}
return context;
};
================================================
FILE: client/src/pages/Conversation/SocketContext.ts
================================================
import { createContext, useContext } from "react";
import { WSMessage } from "../../protocol/types";
type SocketContextType = {
isConnected: boolean;
socket: WebSocket | null;
sendMessage: (message: WSMessage) => void;
};
export const SocketContext = createContext<SocketContextType>({
isConnected: false,
socket: null,
sendMessage: () => {},
});
export const useSocketContext = () => {
return useContext(SocketContext);
};
================================================
FILE: client/src/pages/Conversation/components/AudioVisualizer/AudioVisualizer.tsx
================================================
import { FC, useCallback, useEffect, useRef } from "react";
type AudioVisualizerProps = {
analyser: AnalyserNode | null;
};
export const AudioVisualizer: FC<AudioVisualizerProps> = ({ analyser }) => {
  const requestRef = useRef<number | null>(null);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);
const visualizeData = useCallback(() => {
requestRef.current = window.requestAnimationFrame(() => visualizeData());
if (!canvasRef.current) {
console.log("Canvas not found");
return;
}
const audioData = new Uint8Array(140);
analyser?.getByteFrequencyData(audioData);
const bar_width = 3;
let start = 0;
const ctx = canvasRef.current.getContext("2d");
if (!ctx) {
console.log("Canvas context not found");
return;
}
ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
for (let i = 0; i < audioData.length; i++) {
start = i * 4;
let gradient = ctx.createLinearGradient(
0,
0,
canvasRef.current.width,
canvasRef.current.height,
);
gradient.addColorStop(0.2, "#2392f5");
gradient.addColorStop(0.5, "#fe0095");
gradient.addColorStop(1.0, "purple");
ctx.fillStyle = gradient;
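      // A negative height makes fillRect draw the bar upward from the bottom edge of the canvas.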
ctx.fillRect(
start,
canvasRef.current.height,
bar_width,
(-audioData[i] * 100) / 255,
);
}
}, [analyser]);
const resetCanvas = useCallback(() => {
if (!canvasRef.current) {
return;
}
const ctx = canvasRef.current.getContext("2d");
if (!ctx) {
return;
}
ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
}, []);
useEffect(() => {
if (!analyser) {
return;
}
visualizeData();
return () => {
if (requestRef.current) {
console.log("Canceling animation frame");
cancelAnimationFrame(requestRef.current);
}
};
}, [visualizeData, analyser, resetCanvas]);
  return <canvas ref={canvasRef} />;
};
================================================
FILE: client/src/pages/Conversation/components/AudioVisualizer/ClientVisualizer.tsx
================================================
import { FC, RefObject, useCallback, useEffect, useRef, useState } from "react";
import { clamp } from "../../hooks/audioUtils";
type AudioVisualizerProps = {
analyser: AnalyserNode | null;
  parent: RefObject<HTMLDivElement>;
  copyCanvasRef: RefObject<HTMLCanvasElement>;
};
const MAX_INTENSITY = 255;
const COLORS = [
"#197556",
"#299e77",
"#32b89b",
"#31d4b8",
"#14d9d5",
"#41eff2",
"#7ff3f5",
"#789bf5",
"#eb94eb",
"#e63280",
"#c41862",
];
export const ClientVisualizer: FC<AudioVisualizerProps> = ({ analyser, parent, copyCanvasRef }) => {
  const [canvasWidth, setCanvasWidth] = useState(parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0);
  const requestRef = useRef<number | null>(null);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);
const drawBars = useCallback(
(
ctx: CanvasRenderingContext2D,
x: number,
y: number,
volume: number,
height: number,
width: number,
gap: number,
) => {
const barHeight = height / 10 - gap;
for (let i = 1; i <= 10; i++) {
const barY = y + height + gap + Math.min(1, width / 30) - (i * barHeight + i * gap);
ctx.fillStyle = COLORS[i - 1];
ctx.strokeStyle = "white";
ctx.lineWidth = Math.min(1, height / 100);
if (i <= volume) {
ctx.fillRect(x, barY, width, barHeight);
}
ctx.strokeRect(x, barY, width, barHeight);
}
},
[],
);
const draw = useCallback((ctx: CanvasRenderingContext2D, audioData: Uint8Array, x: number, y: number, width: number, height: number) => {
const stereoGap = Math.floor(width / 30);
const barGap = Math.floor(height / 30);
const padding = Math.floor(width / 30);
const maxBarHeight = Math.floor(height - padding * 2);
const maxBarWidth = Math.floor(
width / 2.5 - stereoGap - padding * 2,
);
const centerX = x + width / 2;
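    // Root mean square of the frequency bins, used as a rough loudness estimate.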
const averageIntensity = Math.sqrt(
audioData.reduce((acc, curr) => acc + curr * curr, 0) / audioData.length,
);
const intensity = clamp(
averageIntensity * 1.4,
averageIntensity,
MAX_INTENSITY,
);
const volume = Math.floor((intensity * 10) / MAX_INTENSITY);
ctx.fillStyle = "rgba(0, 0, 0, 0)";
ctx.fillRect(x, y, width, height);
drawBars(
ctx,
centerX - maxBarWidth - stereoGap / 2,
y,
volume,
maxBarHeight,
maxBarWidth,
barGap,
);
drawBars(
ctx,
centerX + stereoGap / 2,
y,
volume,
maxBarHeight,
maxBarWidth,
barGap,
);
}, [analyser, drawBars]);
const visualizeData = useCallback(() => {
const width = parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0
if (width !== canvasWidth) {
console.log("Setting canvas width");
setCanvasWidth(width);
}
requestRef.current = window.requestAnimationFrame(() => visualizeData());
if (!canvasRef.current) {
console.log("Canvas not found");
return;
}
const audioData = new Uint8Array(140);
analyser?.getByteFrequencyData(audioData);
const ctx = canvasRef.current.getContext("2d");
if (!ctx) {
console.log("Canvas context not found");
return;
}
ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
draw(ctx, audioData, 0, 0, width, width);
if (copyCanvasRef?.current) {
const copyCtx = copyCanvasRef.current.getContext("2d");
if (copyCtx) {
const x = 240;
const y = 140;
const width = 140 / 1.25; // slightly scaled down version
const height = 180 / 1.25; // slightly scaled down version
copyCtx.clearRect(x, y, width, height);
draw(copyCtx, audioData, x, y, width, height);
}
}
}, [analyser, canvasWidth, drawBars, parent, copyCanvasRef, draw]);
useEffect(() => {
visualizeData();
return () => {
if (requestRef.current) {
console.log("Canceling animation frame");
cancelAnimationFrame(requestRef.current);
}
};
}, [visualizeData, analyser]);
  return (
    <canvas ref={canvasRef} width={canvasWidth} height={canvasWidth} />
  );
};
================================================
FILE: client/src/pages/Conversation/components/AudioVisualizer/ServerVisualizer.tsx
================================================
import { FC, RefObject, useCallback, useEffect, useRef, useState } from "react";
import { clamp } from "../../hooks/audioUtils";
import { useSocketContext } from "../../SocketContext";
type AudioVisualizerProps = {
analyser: AnalyserNode | null;
  parent: RefObject<HTMLDivElement>;
  imageUrl: string | undefined;
  copyCanvasRef?: RefObject<HTMLCanvasElement>;
};
const MAX_INTENSITY = 255;
export const ServerVisualizer: FC<AudioVisualizerProps> = ({ analyser, parent, imageUrl, copyCanvasRef }) => {
  const [canvasWidth, setCanvasWidth] = useState(parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0);
  const requestRef = useRef<number | null>(null);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);
const { isConnected } = useSocketContext();
const draw = useCallback((width: number, centerX: number, centerY: number, audioData: Uint8Array, ctx: CanvasRenderingContext2D) => {
const maxCircleWidth = Math.floor(width * 0.95);
const averageIntensity = Math.sqrt(
audioData.reduce((acc, curr) => acc + curr * curr, 0) / audioData.length,
);
const intensity = clamp(
averageIntensity * 1.4,
averageIntensity,
MAX_INTENSITY,
);
const relIntensity = intensity / MAX_INTENSITY;
const radius = ((isConnected ? 0.3 + 0.7 * relIntensity : relIntensity) * maxCircleWidth) / 2;
// Draw a circle with radius based on intensity
ctx.clearRect(centerX - width / 2, centerY - width / 2, width, width);
ctx.fillStyle = 'rgba(0, 0, 0, 0)';
ctx.fillRect(centerX - width / 2, centerY - width / 2, width, width);
ctx.beginPath();
//ctx.fillStyle = "#39e3a7";
ctx.fillStyle = 'rgba(57, 227, 167, 0.5)';
ctx.arc(centerX, centerY, radius, 0, 2 * Math.PI);
ctx.fill();
ctx.closePath();
// Draw an inner circle if we are connected.
if (isConnected) {
ctx.beginPath();
ctx.arc(centerX, centerY, maxCircleWidth / 6, 0, 2 * Math.PI);
// ctx.fillStyle = "#BCFCE5";
ctx.fillStyle = 'rgba(188, 252, 229, 0.5)';
ctx.fill();
ctx.closePath();
}
//Draw a circle with max radius
ctx.beginPath();
ctx.arc(centerX, centerY, maxCircleWidth / 2, 0, 2 * Math.PI);
ctx.strokeStyle = "white";
ctx.lineWidth = (width / 50 < 3) ? 3 : width / 50;
ctx.stroke();
ctx.fillStyle = 'rgba(0, 0, 0, 0.6)';
ctx.fill()
ctx.closePath();
}, [isConnected]);
const visualizeData = useCallback(() => {
const width = parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0;
if (width !== canvasWidth) {
console.log("Setting canvas width");
setCanvasWidth(width);
}
requestRef.current = window.requestAnimationFrame(() => visualizeData());
if (!canvasRef.current) {
console.log("Canvas not found");
return;
}
const ctx = canvasRef.current.getContext("2d");
const audioData = new Uint8Array(140);
analyser?.getByteFrequencyData(audioData);
if (!ctx) {
console.log("Canvas context not found");
return;
}
const centerX = width / 2;
const centerY = width / 2;
draw(width, centerX, centerY, audioData, ctx);
if (copyCanvasRef?.current) {
const copyCtx = copyCanvasRef.current.getContext("2d");
if (copyCtx) {
draw(100, 295, 70, audioData, copyCtx);
if (imageUrl) {
const img = new Image()
img.src = imageUrl;
img.onload = function () {
copyCtx.drawImage(img, 25, 25, 200, 200);
copyCtx.strokeStyle = 'white';
copyCtx.rect(25, 25, 200, 200);
copyCtx.stroke();
};
}
}
}
  }, [analyser, isConnected, canvasWidth, parent, copyCanvasRef, draw, imageUrl]);
useEffect(() => {
if (!analyser) {
return;
}
analyser.smoothingTimeConstant = 0.95;
visualizeData();
return () => {
if (requestRef.current) {
console.log("Canceling animation frame");
cancelAnimationFrame(requestRef.current);
}
};
}, [visualizeData, analyser]);
  return (
    <canvas ref={canvasRef} width={canvasWidth} height={canvasWidth} />
  );
};
================================================
FILE: client/src/pages/Conversation/components/Controls/Controls.tsx
================================================
import {
controlBOSMessage,
controlEOSMessage,
} from "../../../../protocol/testMessages";
import { useSocketContext } from "../../SocketContext";
import { Button } from "../../../../components/Button/Button";
export const Controls = () => {
const { sendMessage } = useSocketContext();
const sendControlBOS = () => {
sendMessage(controlBOSMessage);
};
const sendControlEOS = () => {
sendMessage(controlEOSMessage);
};
  return (
    <>
      <Button onClick={sendControlEOS}>eos</Button>
      <Button onClick={sendControlBOS}>bos</Button>
    </>
  );
};
================================================
FILE: client/src/pages/Conversation/components/ModelParams/ModelParams.tsx
================================================
import { FC, RefObject } from "react";
import { useModelParams } from "../../hooks/useModelParams";
import { Button } from "../../../../components/Button/Button";
type ModelParamsProps = {
isConnected: boolean;
  modal?: RefObject<HTMLDialogElement>,
} & ReturnType;
export const ModelParams: FC = ({
textTemperature,
textTopk,
audioTemperature,
audioTopk,
padMult,
repetitionPenalty,
repetitionPenaltyContext,
imageResolution,
gateDelay,
gateInfluence,
displayColor,
centerCrop,
setTextTemperature,
setTextTopk,
setAudioTemperature,
setAudioTopk,
setPadMult,
setRepetitionPenalty,
setRepetitionPenaltyContext,
setImageResolution,
setGateDelay,
setGateInfluence,
setDisplayColor,
setCenterCrop,
resetParams,
isConnected,
modal,
}) => {
return (
{!isConnected &&
Hover on each element to display a helpful tooltip }
{!isConnected && Reset }
{!isConnected && modal?.current?.close()} className="mt-6 ml-4">Ok }
)
};
================================================
FILE: client/src/pages/Conversation/components/ServerAudio/ServerAudio.tsx
================================================
import { FC, useRef } from "react";
import { AudioStats, useServerAudio } from "../../hooks/useServerAudio";
import { ServerVisualizer } from "../AudioVisualizer/ServerVisualizer";
type ServerAudioProps = {
setGetAudioStats: (getAudioStats: () => AudioStats) => void;
imageUrl: string | undefined;
  copyCanvasRef?: React.RefObject<HTMLCanvasElement>;
};
export const ServerAudio: FC<ServerAudioProps> = ({ setGetAudioStats, imageUrl, copyCanvasRef }) => {
const { analyser, hasCriticalDelay, setHasCriticalDelay } = useServerAudio({
setGetAudioStats,
});
  const containerRef = useRef<HTMLDivElement>(null);
return (
<>
{hasCriticalDelay && (
A connection issue has been detected, you've been reconnected
{
setHasCriticalDelay(false);
}}
className="bg-white p-1 text-black"
>
Dismiss
)}
>
);
};
================================================
FILE: client/src/pages/Conversation/components/ServerAudio/ServerAudioStats.tsx
================================================
import { useState, useEffect, useRef } from "react";
type ServerAudioStatsProps = {
getAudioStats: React.MutableRefObject<
() => {
playedAudioDuration: number;
missedAudioDuration: number;
totalAudioMessages: number;
delay: number;
minPlaybackDelay: number;
maxPlaybackDelay: number;
}
>;
};
export const ServerAudioStats = ({ getAudioStats }: ServerAudioStatsProps) => {
const [audioStats, setAudioStats] = useState(getAudioStats.current());
const movingAverageSum = useRef(0.);
const movingAverageCount = useRef(0.);
const movingBeta = 0.85;
let convertMinSecs = (total_secs: number) => {
// convert secs to the format mm:ss.cc
let mins = (Math.floor(total_secs / 60)).toString();
let secs = (Math.floor(total_secs) % 60).toString();
let cents = (Math.floor(100 * (total_secs - Math.floor(total_secs)))).toString();
if (secs.length < 2) {
secs = "0" + secs;
}
if (cents.length < 2) {
cents = "0" + cents;
}
return mins + ":" + secs + "." + cents;
};
useEffect(() => {
const interval = setInterval(() => {
const newAudioStats = getAudioStats.current();
setAudioStats(newAudioStats);
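      // Exponentially decayed sum and count: their ratio (displayed below) is a
      // bias-corrected exponential moving average of the reported delay.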
movingAverageCount.current *= movingBeta;
movingAverageCount.current += (1 - movingBeta) * 1;
movingAverageSum.current *= movingBeta;
movingAverageSum.current += (1 - movingBeta) * newAudioStats.delay;
}, 141);
return () => {
clearInterval(interval);
};
}, []);
return (
Server Audio Stats
Audio played:
{convertMinSecs(audioStats.playedAudioDuration)}
Missed audio:
{convertMinSecs(audioStats.missedAudioDuration)}
Latency:
{(movingAverageSum.current / movingAverageCount.current).toFixed(3)}
Min/Max buffer:
{audioStats.minPlaybackDelay.toFixed(3)} / {audioStats.maxPlaybackDelay.toFixed(3)}
);
};
================================================
FILE: client/src/pages/Conversation/components/ServerInfo/ServerInfo.tsx
================================================
import { useServerInfo } from "../../hooks/useServerInfo";
function pretty_format(num: number): number {
return Math.round((num + Number.EPSILON) * 100) / 100
}
export const ServerInfo = (props: { setFileName: Function }) => {
const { serverInfo } = useServerInfo();
if (!serverInfo) {
return null;
}
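  // Note: this runs during render, so it is called again on every re-render once serverInfo is set.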
props.setFileName(serverInfo.base_filename);
return (
Our server is running on the following configuration:
Image resolution: {serverInfo.image_resolution} px
Text temperature: {pretty_format(serverInfo.text_temperature)}
Text topk: {serverInfo.text_topk}
Temperature gating: {pretty_format(serverInfo.text_temperature_gating_influence)}
Audio temperature: {pretty_format(serverInfo.audio_temperature)}
Audio topk: {serverInfo.audio_topk}
Pad mult: {serverInfo.pad_mult}
Repeat penalty last N: {serverInfo.repetition_penalty_context}
Repeat penalty: {serverInfo.repetition_penalty}
LM model file: {serverInfo.lm_model_file}
Instance name: {serverInfo.instance_name}
);
};
================================================
FILE: client/src/pages/Conversation/components/TextDisplay/TextDisplay.tsx
================================================
import { FC, useEffect, useRef } from "react";
import { useServerText } from "../../hooks/useServerText";
type TextDisplayProps = {
containerRef: React.RefObject;
displayColor: boolean | undefined;
};
// Palette 2: Purple to Green Moshi
// sns.diverging_palette(288, 145, s=90, l=72, n=11).as_hex()
// Palette 2: Green to orange Moshi
// sns.diverging_palette(145, 40, s=90, l=72, n=11).as_hex()
const textDisplayColors = [
'#38c886', '#5bd09a', '#80d9af',
'#a4e2c4', '#c8ead9', '#f2f1f1',
'#f4e0cb', '#f5cea6', '#f5bd81',
'#f6ac5b', '#f79b37']
function clamp_color(v: number) {
  return v <= 0
    ? 0
    : v >= textDisplayColors.length
      ? textDisplayColors.length - 1
      : v;
}
export const TextDisplay: FC = ({
containerRef, displayColor
}) => {
const { text, textColor } = useServerText();
const currentIndex = text.length - 1;
const prevScrollTop = useRef(0);
useEffect(() => {
if (containerRef.current) {
prevScrollTop.current = containerRef.current.scrollTop;
containerRef.current.scroll({
top: containerRef.current.scrollHeight,
behavior: "smooth",
});
}
}, [text]);
if (displayColor && (textColor.length == text.length)) {
return (
{text.map((t, i) => (
{t}
))
}
);
}
else {
return (
{text.map((t, i) => (
{t}
))}
);
};
};
================================================
FILE: client/src/pages/Conversation/components/TextDisplay/TextDisplayStats.tsx
================================================
import { FC } from "react";
type TextDisplayStatsProps = {
totalTextMessages: number;
};
export const TextDisplayStats: FC = ({
totalTextMessages,
}) => {
return (
Text Display Stats
Total messages:
{totalTextMessages}
);
};
================================================
FILE: client/src/pages/Conversation/components/UserAudio/UserAudio.tsx
================================================
import { FC, useCallback, useEffect, useRef, useState } from "react";
import { useSocketContext } from "../../SocketContext";
import { useUserAudio } from "../../hooks/useUserAudio";
import { ClientVisualizer } from "../AudioVisualizer/ClientVisualizer";
type UserAudioProps = {
  copyCanvasRef: React.RefObject<HTMLCanvasElement>;
};
export const UserAudio: FC<UserAudioProps> = ({ copyCanvasRef }) => {
  const [analyser, setAnalyser] = useState<AnalyserNode | null>(null);
  const { sendMessage, isConnected } = useSocketContext();
  const containerRef = useRef<HTMLDivElement>(null);
const onRecordingStart = useCallback(() => {
console.log("Recording started");
}, []);
const onRecordingStop = useCallback(() => {
console.log("Recording stopped");
}, []);
const onRecordingChunk = useCallback(
(chunk: Uint8Array) => {
if (!isConnected) {
return;
}
sendMessage({
type: "audio",
data: chunk,
});
},
[sendMessage, isConnected],
);
const { startRecordingUser, stopRecording } = useUserAudio({
constraints: {
audio: {
echoCancellation: true,
noiseSuppression: true,
autoGainControl: true,
channelCount: 1,
},
video: false,
},
onDataChunk: onRecordingChunk,
onRecordingStart,
onRecordingStop,
});
useEffect(() => {
    let res: Awaited<ReturnType<typeof startRecordingUser>>;
if (isConnected) {
startRecordingUser().then(result => {
if (result) {
res = result;
setAnalyser(result.analyser);
}
});
}
return () => {
console.log("Stop recording called from somewhere else.");
stopRecording();
res?.source?.disconnect();
};
}, [startRecordingUser, stopRecording, isConnected]);
  return (
    <div ref={containerRef}>
      <ClientVisualizer analyser={analyser} parent={containerRef} copyCanvasRef={copyCanvasRef} />
    </div>
  );
};
================================================
FILE: client/src/pages/Conversation/components/UserAudio/UserAudioStats.tsx
================================================
import { FC } from "react";
type UserAudioStatsProps = {
sentMessagesCount: number;
};
export const UserAudioStats: FC = ({
sentMessagesCount,
}) => {
return (
User Audio Stats
Total messages:
{sentMessagesCount}
);
};
================================================
FILE: client/src/pages/Conversation/getMimeType.ts
================================================
export const mimeTypeCheck = () => {
const types = [
"audio/ogg",
"audio/wav",
"audio/webm;codecs=opus",
"audio/webm;codecs=pcm",
"audio/webm;codecs=pcm_s16le",
"audio/webm;codecs=pcm_f32le",
"audio/mp3",
"audio/aac",
"audio/mp4",
"audio/webm",
"audio/mpeg",
"video/mp4",
"video/webm;codecs=vp9",
"video/webm;codecs=vp8",
"video/webm",
];
for (const mime of types) {
console.log(mime, MediaRecorder.isTypeSupported(mime));
}
}
const getVideoMimeType = () => {
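  // Some browsers do not expose MediaRecorder.isTypeSupported; in that case fall back
  // to mp4, which they can typically record.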
if (!MediaRecorder.isTypeSupported){
return "video/mp4";
}
if (MediaRecorder.isTypeSupported("video/webm")) {
return "video/webm";
}
if (MediaRecorder.isTypeSupported("video/mp4")) {
return "video/mp4";
}
console.log("No supported video mime type found")
return "";
};
const getAudioMimeType = () => {
if (!MediaRecorder.isTypeSupported){
return "audio/mp4";
}
if (MediaRecorder.isTypeSupported("audio/webm")) {
return "audio/webm";
}
if (MediaRecorder.isTypeSupported("audio/mpeg")) {
return "audio/mpeg";
  }
if (MediaRecorder.isTypeSupported("audio/mp4")) {
return "audio/mp4";
}
console.log("No supported audio mime type found")
return "";
}
export const getMimeType = (type: "audio" | "video") => {
if(type === "audio") {
return getAudioMimeType();
}
return getVideoMimeType();
}
export const getExtension = (type: "audio" | "video") => {
if(getMimeType(type).includes("mp4")) {
return "mp4";
}
if(getMimeType(type).includes("mpeg")) {
return "mp3";
}
return "webm";
}
================================================
FILE: client/src/pages/Conversation/hooks/audioUtils.ts
================================================
export const clamp = (value: number, min: number, max: number) => {
return Math.min(Math.max(value, min), max);
};
================================================
FILE: client/src/pages/Conversation/hooks/useModelParams.ts
================================================
import { useCallback, useState } from "react";
export const DEFAULT_TEXT_TEMPERATURE = 0.45;
export const DEFAULT_TEXT_TOPK = 25;
export const DEFAULT_AUDIO_TEMPERATURE = 0.7;
export const DEFAULT_AUDIO_TOPK = 250;
export const DEFAULT_PAD_MULT = 0;
export const DEFAULT_REPETITION_PENALTY_CONTEXT = 64;
export const DEFAULT_REPETITION_PENALTY = 1.15;
export const DEFAULT_IMAGE_RESOLUTION = 448;
export const DEFAULT_IMAGE_URL = undefined;
export const DEFAULT_GATE_DELAY = 16;
export const DEFAULT_GATE_INFLUENCE = 0.0;
export const DEFAULT_DISPLAY_COLOR = true;
export const DEFAULT_CENTER_CROP = false;
export type ModelParamsValues = {
textTemperature: number;
textTopk: number;
audioTemperature: number;
audioTopk: number;
padMult: number;
  repetitionPenaltyContext: number;
  repetitionPenalty: number;
  imageResolution: number;
  imageUrl: string | undefined;
  gateDelay: number;
  gateInfluence: number;
  displayColor: boolean;
  centerCrop: boolean;
};
export function importantSettingsHaveChanged(params: ModelParamsValues): boolean {
return (params.textTemperature != DEFAULT_TEXT_TEMPERATURE) ||
(params.textTopk != DEFAULT_TEXT_TOPK) ||
(params.audioTemperature != DEFAULT_AUDIO_TEMPERATURE) ||
(params.audioTopk != DEFAULT_AUDIO_TOPK) ||
(params.padMult != DEFAULT_PAD_MULT) ||
(params.repetitionPenalty != DEFAULT_REPETITION_PENALTY) ||
(params.repetitionPenaltyContext != DEFAULT_REPETITION_PENALTY_CONTEXT) ||
(params.imageResolution != DEFAULT_IMAGE_RESOLUTION) ||
(params.gateDelay != DEFAULT_GATE_DELAY) ||
(params.gateInfluence != DEFAULT_GATE_INFLUENCE) ||
(params.centerCrop != DEFAULT_CENTER_CROP)
}
type useModelParamsArgs = Partial<ModelParamsValues>;
export const useModelParams = (params?: useModelParamsArgs) => {
const [textTemperature, setTextTemperatureBase] = useState(params?.textTemperature || DEFAULT_TEXT_TEMPERATURE);
const [textTopk, setTextTopkBase] = useState(params?.textTopk || DEFAULT_TEXT_TOPK);
const [audioTemperature, setAudioTemperatureBase] = useState(params?.audioTemperature || DEFAULT_AUDIO_TEMPERATURE);
const [audioTopk, setAudioTopkBase] = useState(params?.audioTopk || DEFAULT_AUDIO_TOPK);
const [padMult, setPadMultBase] = useState(params?.padMult || DEFAULT_PAD_MULT);
const [repetitionPenalty, setRepetitionPenaltyBase] = useState(params?.repetitionPenalty || DEFAULT_REPETITION_PENALTY);
const [repetitionPenaltyContext, setRepetitionPenaltyContextBase] = useState(params?.repetitionPenaltyContext || DEFAULT_REPETITION_PENALTY_CONTEXT);
const [imageResolution, setImageResolutionBase] = useState(params?.imageResolution || DEFAULT_IMAGE_RESOLUTION);
const [imageUrl, setImageUrlBase] = useState(params?.imageUrl || DEFAULT_IMAGE_URL);
const [gateDelay, setGateDelayBase] = useState(params?.gateDelay || DEFAULT_GATE_DELAY);
const [gateInfluence, setGateInfluenceBase] = useState(params?.gateInfluence || DEFAULT_GATE_INFLUENCE);
const [displayColor, setDisplayColorBase] = useState(params?.displayColor == undefined ? DEFAULT_DISPLAY_COLOR : params?.displayColor);
const [centerCrop, setCenterCropBase] = useState(params?.centerCrop == undefined ? DEFAULT_CENTER_CROP : params?.centerCrop);
const resetParams = useCallback(() => {
setTextTemperatureBase(DEFAULT_TEXT_TEMPERATURE);
setTextTopkBase(DEFAULT_TEXT_TOPK);
setAudioTemperatureBase(DEFAULT_AUDIO_TEMPERATURE);
setAudioTopkBase(DEFAULT_AUDIO_TOPK);
setPadMultBase(DEFAULT_PAD_MULT);
setRepetitionPenaltyBase(DEFAULT_REPETITION_PENALTY);
setRepetitionPenaltyContextBase(DEFAULT_REPETITION_PENALTY_CONTEXT);
setImageResolutionBase(DEFAULT_IMAGE_RESOLUTION);
setImageUrlBase(DEFAULT_IMAGE_URL);
setGateDelayBase(DEFAULT_GATE_DELAY);
setGateInfluenceBase(DEFAULT_GATE_INFLUENCE);
setDisplayColorBase(DEFAULT_DISPLAY_COLOR);
setCenterCropBase(DEFAULT_CENTER_CROP);
}, [
setTextTemperatureBase,
setTextTopkBase,
setAudioTemperatureBase,
setAudioTopkBase,
setPadMultBase,
setRepetitionPenaltyBase,
setRepetitionPenaltyContextBase,
setImageResolutionBase,
    setImageUrlBase,
    setGateDelayBase,
    setGateInfluenceBase,
setDisplayColorBase,
setCenterCropBase,
]);
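  // The setters below ignore values outside their accepted range rather than clamping.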
const setTextTemperature = useCallback((value: number) => {
if (value <= 1.2 && value >= 0.2) {
setTextTemperatureBase(value);
}
}, []);
const setTextTopk = useCallback((value: number) => {
if (value <= 500 && value >= 10) {
setTextTopkBase(value);
}
}, []);
const setAudioTemperature = useCallback((value: number) => {
if (value <= 1.2 && value >= 0.2) {
setAudioTemperatureBase(value);
}
}, []);
const setAudioTopk = useCallback((value: number) => {
if (value <= 500 && value >= 10) {
setAudioTopkBase(value);
}
}, []);
const setPadMult = useCallback((value: number) => {
if (value <= 4 && value >= -4) {
setPadMultBase(value);
}
}, []);
const setRepetitionPenalty = useCallback((value: number) => {
if (value <= 2.0 && value >= 1.0) {
setRepetitionPenaltyBase(value);
}
}, []);
const setRepetitionPenaltyContext = useCallback((value: number) => {
if (value <= 200 && value >= 0) {
setRepetitionPenaltyContextBase(value);
}
}, []);
const setImageResolution = useCallback((value: number) => {
if (value <= 512 && value >= 160) {
setImageResolutionBase(value);
}
}, []);
const setImageUrl = useCallback((value: string | undefined) => {
setImageUrlBase(value);
}, []);
const setGateDelay = useCallback((value: number) => {
if (value <= 32 && value >= 0) {
setGateDelayBase(value);
}
}, []);
const setGateInfluence = useCallback((value: number) => {
if (value <= 1.0 && value >= 0.0) {
setGateInfluenceBase(value);
}
}, []);
const setDisplayColor = useCallback((value: boolean) => {
setDisplayColorBase(value);
}, []);
const setCenterCrop = useCallback((value: boolean) => {
setCenterCropBase(value);
}, []);
return {
textTemperature,
textTopk,
audioTemperature,
audioTopk,
padMult,
repetitionPenalty,
repetitionPenaltyContext,
imageResolution,
imageUrl,
gateDelay,
gateInfluence,
displayColor,
centerCrop,
setTextTemperature,
setTextTopk,
setAudioTemperature,
setAudioTopk,
setPadMult,
setRepetitionPenalty,
setRepetitionPenaltyContext,
setImageUrl,
setImageResolution,
setGateDelay,
setGateInfluence,
setDisplayColor,
setCenterCrop,
resetParams,
}
}
================================================
FILE: client/src/pages/Conversation/hooks/useServerAudio.ts
================================================
import { useCallback, useEffect, useRef, useState } from "react";
import { useSocketContext } from "../SocketContext";
import { decodeMessage } from "../../../protocol/encoder";
import { useMediaContext } from "../MediaContext";
import { DecoderWorker } from "../../../decoder/decoderWorker";
export type AudioStats = {
playedAudioDuration: number;
missedAudioDuration: number;
totalAudioMessages: number;
delay: number;
minPlaybackDelay: number;
maxPlaybackDelay: number;
};
type useServerAudioArgs = {
setGetAudioStats?: (getAudioStats: () => AudioStats) => void;
};
type WorkletStats = {
totalAudioPlayed: number;
actualAudioPlayed: number;
delay: number;
minDelay: number;
maxDelay: number;
};
export const useServerAudio = ({setGetAudioStats}: useServerAudioArgs) => {
const { socket } = useSocketContext();
const {startRecording, stopRecording, audioContext, worklet, micDuration, actualAudioPlayed } =
useMediaContext();
const analyser = useRef(audioContext.current.createAnalyser());
worklet.current.connect(analyser.current);
  const startTime = useRef<number | null>(null);
const decoderWorker = useRef(DecoderWorker);
const [hasCriticalDelay, setHasCriticalDelay] = useState(false);
const totalAudioMessages = useRef(0);
const receivedDuration = useRef(0);
  const workletStats = useRef<WorkletStats>({
    totalAudioPlayed: 0,
    actualAudioPlayed: 0,
    delay: 0,
    minDelay: 0,
    maxDelay: 0,
  });
const onDecode = useCallback(
async (data: Float32Array) => {
receivedDuration.current += data.length / audioContext.current.sampleRate;
worklet.current.port.postMessage({frame: data, type: "audio", micDuration: micDuration.current});
},
[],
);
const onWorkletMessage = useCallback(
(event: MessageEvent) => {
workletStats.current = event.data;
actualAudioPlayed.current = workletStats.current.actualAudioPlayed;
},
[],
);
worklet.current.port.onmessage = onWorkletMessage;
const getAudioStats = useCallback(() => {
return {
playedAudioDuration: workletStats.current.actualAudioPlayed,
delay: workletStats.current.delay,
minPlaybackDelay: workletStats.current.minDelay,
maxPlaybackDelay: workletStats.current.maxDelay,
missedAudioDuration: workletStats.current.totalAudioPlayed - workletStats.current.actualAudioPlayed,
totalAudioMessages: totalAudioMessages.current,
};
}, []);
const onWorkerMessage = useCallback(
(e: MessageEvent) => {
if (!e.data) {
return;
}
onDecode(e.data[0]);
},
[onDecode],
);
let midx = 0;
const decodeAudio = useCallback((data: Uint8Array) => {
if (midx < 5) {
console.log(Date.now() % 1000, "Got NETWORK message", micDuration.current - workletStats.current.actualAudioPlayed, midx++);
}
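    // Listing data.buffer as a transferable moves the encoded pages to the worker
    // without copying; data must not be reused afterwards.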
decoderWorker.current.postMessage(
{
command: "decode",
pages: data,
},
[data.buffer],
);
}, []);
const onSocketMessage = useCallback(
(e: MessageEvent) => {
const dataArray = new Uint8Array(e.data);
const message = decodeMessage(dataArray);
if (message.type === "audio") {
decodeAudio(message.data);
//For stats purposes for now
totalAudioMessages.current++;
}
},
[decodeAudio],
);
useEffect(() => {
const currentSocket = socket;
if (!currentSocket) {
return;
}
worklet.current.port.postMessage({type: "reset"});
console.log(Date.now() % 1000, "Should start in a bit");
startRecording();
currentSocket.addEventListener("message", onSocketMessage);
totalAudioMessages.current = 0;
return () => {
console.log("Stop recording called in unknown function.")
stopRecording();
startTime.current = null;
currentSocket.removeEventListener("message", onSocketMessage);
};
}, [socket]);
useEffect(() => {
if (setGetAudioStats) {
console.log("Setting getAudioStats");
setGetAudioStats(getAudioStats);
}
}, [setGetAudioStats, getAudioStats]);
useEffect(() => {
decoderWorker.current.onmessage = onWorkerMessage;
// 960 = 24000 / 12.5 / 2
// The /2 is a bit optional, but won't hurt for recording the mic, and for
// the decoding it might help getting some decoded audio out asap.
decoderWorker.current.postMessage({
command: "init",
bufferLength: 960 * audioContext.current.sampleRate / 24000,
decoderSampleRate: 24000,
outputBufferSampleRate: audioContext.current.sampleRate,
resampleQuality: 0,
});
return () => {
console.log("Terminating worker");
};
}, [onWorkerMessage]);
return {
decodeAudio,
analyser,
getAudioStats,
hasCriticalDelay,
setHasCriticalDelay,
};
};
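// Usage sketch (illustrative only; the surrounding component wiring is hypothetical):
//   const { analyser, getAudioStats, hasCriticalDelay } = useServerAudio({});
//   getAudioStats(); // => { playedAudioDuration, missedAudioDuration, totalAudioMessages,
//                    //      delay, minPlaybackDelay, maxPlaybackDelay }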
================================================
FILE: client/src/pages/Conversation/hooks/useServerInfo.ts
================================================
import { useCallback, useEffect, useState } from "react";
import { useSocketContext } from "../SocketContext";
import { decodeMessage } from "../../../protocol/encoder";
import { z } from "zod";
const ServersInfoSchema = z.object({
text_temperature: z.number(),
text_topk: z.number(),
text_temperature_gating_influence: z.number(),
audio_temperature: z.number(),
audio_topk: z.number(),
pad_mult: z.number(),
repetition_penalty_context: z.number(),
repetition_penalty: z.number(),
image_resolution: z.number(),
lm_model_file: z.string(),
instance_name: z.string(),
base_filename: z.string(),
build_info: z.object({
build_timestamp: z.string(),
build_date: z.string(),
git_branch: z.string(),
git_timestamp: z.string(),
git_date: z.string(),
git_hash: z.string(),
git_describe: z.string(),
rustc_host_triple: z.string(),
rustc_version: z.string(),
cargo_target_triple: z.string(),
}),
});
const parseInfo = (infos: any) => {
const serverInfo = ServersInfoSchema.safeParse(infos);
if (!serverInfo.success) {
console.error(serverInfo.error);
return null;
}
return serverInfo.data;
};
type ServerInfo = {
text_temperature: number;
text_topk: number;
text_temperature_gating_influence: number;
audio_temperature: number;
audio_topk: number;
pad_mult: number;
repetition_penalty_context: number;
repetition_penalty: number;
image_resolution: number;
lm_model_file: string;
instance_name: string;
base_filename: string;
build_info: {
build_timestamp: string;
build_date: string;
git_branch: string;
git_timestamp: string;
git_date: string;
git_hash: string;
git_describe: string;
rustc_host_triple: string;
rustc_version: string;
cargo_target_triple: string;
};
}
export const useServerInfo = () => {
const [serverInfo, setServerInfo] = useState<ServerInfo | null>(null);
const { socket } = useSocketContext();
const onSocketMessage = useCallback((e: MessageEvent) => {
const dataArray = new Uint8Array(e.data);
const message = decodeMessage(dataArray);
if (message.type === "metadata") {
const infos = parseInfo(message.data);
if (infos) {
setServerInfo(infos);
console.log("received metadata", infos);
}
}
}, [setServerInfo]);
useEffect(() => {
const currentSocket = socket;
if (!currentSocket) {
return;
}
setServerInfo(null);
currentSocket.addEventListener("message", onSocketMessage);
return () => {
currentSocket.removeEventListener("message", onSocketMessage);
};
}, [socket]);
return { serverInfo };
};
================================================
FILE: client/src/pages/Conversation/hooks/useServerText.ts
================================================
import { useCallback, useEffect, useState } from "react";
import { useSocketContext } from "../SocketContext";
import { decodeMessage } from "../../../protocol/encoder";
export const useServerText = () => {
const [text, setText] = useState<string[]>([]);
const [textColor, setTextColor] = useState<number[]>([]);
const [totalTextMessages, setTotalTextMessages] = useState(0);
const { socket } = useSocketContext();
const onSocketMessage = useCallback((e: MessageEvent) => {
const dataArray = new Uint8Array(e.data);
const message = decodeMessage(dataArray);
if (message.type === "text") {
setText(text => [...text, message.data]);
setTotalTextMessages(count => count + 1);
} else if (message.type === "coloredtext") {
setText(text => [...text, message.data]);
setTextColor(textColor => [...textColor, message.color]);
setTotalTextMessages(count => count + 1);
}
}, []);
useEffect(() => {
const currentSocket = socket;
if (!currentSocket) {
return;
}
setText([]);
currentSocket.addEventListener("message", onSocketMessage);
return () => {
currentSocket.removeEventListener("message", onSocketMessage);
};
}, [socket]);
return { text, textColor, totalTextMessages };
};
================================================
FILE: client/src/pages/Conversation/hooks/useSocket.ts
================================================
import { useState, useEffect, useCallback, useRef } from "react";
import { WSMessage } from "../../../protocol/types";
import { decodeMessage, encodeMessage } from "../../../protocol/encoder";
export const useSocket = ({
onMessage,
uri,
onDisconnect: onDisconnectProp,
imageUrl,
}: {
onMessage?: (message: WSMessage) => void;
uri: string;
onDisconnect?: () => void;
imageUrl?: string;
}) => {
const lastMessageTime = useRef<number | null>(null);
const [isConnected, setIsConnected] = useState(false);
const [imageSent, setImageSent] = useState(false);
const [onConnectDone, setOnConnectDone] = useState(false);
const [socket, setSocket] = useState<WebSocket | null>(null);
const sendMessage = useCallback(
(message: WSMessage) => {
if (!socket) {
console.log("socket not present");
return false;
}
// audio message with no connection
if (message.type == "audio" && !isConnected) {
console.log("isConnected false on audio message, please wait for handshake");
return false;
}
// otherwise send message
socket.send(encodeMessage(message));
return true;
},
[isConnected, socket],
);
useEffect(() => {
async function sendImage() {
console.log("image send", imageSent);
console.log("image url", imageUrl);
if (imageUrl && !imageSent) {
const imageBytes = await fetchImageBytes(imageUrl);
const sent = sendMessage({
type: "image",
data: imageBytes,
});
if (sent) {
console.log("Image sent");
setImageSent(true);
}
}
}
sendImage();
}, [socket, onConnectDone, imageUrl, imageSent]);
const onConnect = useCallback(() => {
console.log("connected, now waiting for handshake.");
setOnConnectDone(true);
}, [setIsConnected, socket]);
const onDisconnect = useCallback(() => {
console.log("disconnected");
if (onDisconnectProp) {
onDisconnectProp();
}
setIsConnected(false);
}, [onDisconnectProp]);
const onMessageEvent = useCallback(
(eventData: MessageEvent) => {
lastMessageTime.current = Date.now();
const dataArray = new Uint8Array(eventData.data);
const message = decodeMessage(dataArray);
if (message.type == "handshake") {
console.log("Handshake received, let's rocknroll.");
setIsConnected(true);
}
if (!onMessage) {
return;
}
onMessage(message);
},
[onMessage, setIsConnected],
);
const start = useCallback(() => {
const ws = new WebSocket(uri);
ws.binaryType = "arraybuffer";
ws.addEventListener("open", onConnect);
ws.addEventListener("close", onDisconnect);
ws.addEventListener("message", onMessageEvent);
setSocket(ws);
console.log("Socket created", ws);
lastMessageTime.current = Date.now();
}, [uri, onMessage, onDisconnectProp]);
const stop = useCallback(() => {
setIsConnected(false);
if (onDisconnectProp) {
onDisconnectProp();
}
socket?.close();
setSocket(null);
}, [socket]);
useEffect(() => {
if (!isConnected) {
return;
}
let intervalId = setInterval(() => {
if (lastMessageTime.current && Date.now() - lastMessageTime.current > 10000) {
console.log("closing socket due to inactivity", socket);
socket?.close();
onDisconnect();
clearInterval(intervalId);
}
}, 500);
return () => {
lastMessageTime.current = null;
clearInterval(intervalId);
};
}, [isConnected, socket]);
return {
isConnected,
socket,
sendMessage,
start,
stop,
};
};
async function fetchImageBytes(imageUrl: string) {
const response = await fetch(imageUrl);
if (!response.ok) {
throw new Error(`Failed to fetch image: ${response.statusText}`);
}
const arrayBuffer = await response.arrayBuffer();
return new Uint8Array(arrayBuffer);
}
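// Usage sketch (hypothetical URI and image path, for illustration only):
//   const { start, stop, sendMessage, isConnected } = useSocket({
//     uri: "wss://example.com/api/chat",
//     imageUrl: "/assets/images/demo/image1.jpg", // fetched and sent once the socket is up
//   });
//   start(); // audio messages are only accepted after the server handshake arrives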
================================================
FILE: client/src/pages/Conversation/hooks/useUserAudio.ts
================================================
import { useCallback, useRef, useState } from "react";
import Recorder from "opus-recorder";
import encoderPath from "opus-recorder/dist/encoderWorker.min.js?url";
import { useMediaContext } from "../MediaContext";
export enum UserMediaStatuses {
IDLE = "IDLE",
READY = "READY",
WAITING_FOR_PERMISSION = "WAITING_FOR_PERMISSION",
ERROR = "ERROR",
RECORDING = "RECORDING",
STOPPED = "STOPPED",
STOPPING = "STOPPING",
}
type useUserAudioArgs = {
constraints: MediaStreamConstraints;
onDataChunk?: (chunk: Uint8Array) => void;
onRecordingStart?: () => void;
onRecordingStop?: () => void;
};
export const useUserAudio = ({
constraints,
onDataChunk,
onRecordingStart = () => {},
onRecordingStop = () => {},
}: useUserAudioArgs) => {
const { audioStreamDestination, audioContext, micDuration } = useMediaContext();
const [error, setError] = useState<string | null>(null);
const [status, setStatus] = useState(
UserMediaStatuses.IDLE,
);
//TODO: Fix any type for recorder
const recorder = useRef<any>(null);
const getMediaStream = useCallback(async () => {
setStatus(UserMediaStatuses.WAITING_FOR_PERMISSION);
try {
const stream =
await window.navigator.mediaDevices.getUserMedia(constraints);
setStatus(UserMediaStatuses.IDLE);
return stream;
} catch (error: any) {
console.error(error);
setError(error.name);
setStatus(UserMediaStatuses.ERROR);
return null;
}
}, [constraints, setStatus]);
const startRecordingUser = useCallback(async () => {
console.log(Date.now() % 1000, "Starting recording in user audio");
const mediaStream = await getMediaStream();
if (mediaStream) {
const analyser = audioContext.current.createAnalyser();
const source = audioContext.current.createMediaStreamSource(mediaStream);
source.connect(analyser);
source.connect(audioStreamDestination.current);
// For buffer length: 960 = 24000 / 12.5 / 2
// The /2 is a bit optional, but won't hurt for recording the mic.
// Note that bufferLength actually has 0 impact for mono audio, only
// the frameSize and maxFramesPerPage seems to have any.
const recorderOptions = {
mediaTrackConstraints: constraints,
encoderPath,
bufferLength: Math.round(960 * audioContext.current.sampleRate / 24000),
encoderFrameSize: 20,
encoderSampleRate: 24000,
maxFramesPerPage: 2,
numberOfChannels: 1,
recordingGain: 1,
resampleQuality: 3,
encoderComplexity: 0,
encoderApplication: 2049,
streamPages: true,
};
let chunk_idx = 0;
let lastpos = 0;
recorder.current = new Recorder(recorderOptions);
recorder.current.ondataavailable = (data: Uint8Array) => {
// opus actually always works at 48khz, so it seems this is the proper value to use here.
micDuration.current = recorder.current.encodedSamplePosition / 48000;
if (chunk_idx < 5) {
console.log(Date.now() % 1000, "Mic Data chunk", chunk_idx++, (recorder.current.encodedSamplePosition - lastpos) / 48000, micDuration.current);
lastpos = recorder.current.encodedSamplePosition;
}
if (onDataChunk) {
onDataChunk(data);
}
};
recorder.current.onstart = () => {
setStatus(UserMediaStatuses.RECORDING);
onRecordingStart();
};
recorder.current.onstop = () => {
setStatus(UserMediaStatuses.STOPPED);
source.disconnect();
onRecordingStop();
recorder.current = null;
};
if (recorder.current) {
// setTimeout(() => {recorder.current.start(); setStatus(UserMediaStatuses.RECORDING);}, 1500);
recorder.current.start();
}
return {
analyser,
mediaStream,
source,
};
}
return {
analyser: null,
mediaStream: null,
source: null,
};
}, [setStatus, onDataChunk, onRecordingStart, onRecordingStop]);
const stopRecording = useCallback(() => {
setStatus(UserMediaStatuses.STOPPING);
if (recorder.current) {
recorder.current.stop();
}
}, [setStatus]);
return {
status,
error,
startRecordingUser,
stopRecording,
};
};
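// Timing note (derived from the recorder options above): encoderFrameSize is 20 ms
// and maxFramesPerPage is 2, so with streamPages enabled each Ogg page carries about
// 40 ms of audio, i.e. onDataChunk fires roughly 25 times per second. micDuration is
// tracked on Opus's fixed 48 kHz clock: encodedSamplePosition / 48000 seconds.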
================================================
FILE: client/src/pages/Queue/Queue.tsx
================================================
import moshiProcessorUrl from "../../audio-processor.ts?worker&url";
import { FC, useEffect, useMemo, useState, useCallback, useRef, MutableRefObject } from "react";
import eruda from "eruda";
import { useSearchParams } from "react-router-dom";
import { Conversation } from "../Conversation/Conversation";
import { Button } from "../../components/Button/Button";
import { ImageGallery } from "../../components/ImageGallery/ImageGallery";
import { useModelParams, importantSettingsHaveChanged } from "../Conversation/hooks/useModelParams";
import { ModelParams } from "../Conversation/components/ModelParams/ModelParams";
import { env } from "../../env";
import { useUserEmail } from "./hooks/useUserEmail";
import { Input } from "../../components/Input/Input";
import { getAPIClient } from "./api/client";
type Status = "connecting" | "in_queue" | "has_credentials" | "error" | "no_queue" | "idle" | "bypass";
function getFloatFromStorage(val: string | null) {
  return (val == null) ? undefined : parseFloat(val);
}
function getIntFromStorage(val: string | null) {
  return (val == null) ? undefined : parseInt(val, 10);
}
function getBoolFromStorage(val: string | null) {
  return (val == 'true') ? true : ((val == 'false') ? false : undefined);
}
export const Queue: FC = () => {
const [searchParams] = useSearchParams();
let queueId = searchParams.get("queue_id");
if (!queueId) {
queueId = 'talktomoshi';
}
const [sessionId, setSessionId] = useState<number | null>(null);
const [sessionAuthId, setSessionAuthId] = useState<string | null>(null);
const [workerAddr, setWorkerAddr] = useState<string | null>(null);
const [workerAuthId, setWorkerAuthId] = useState<string | null>(null);
const [currentPosition, setCurrentPosition] = useState<string | null>(null);
const [error, setError] = useState<string | null>(null);
const overrideWorkerAddr = searchParams.get("worker_addr");
const [hasMicrophoneAccess, setHasMicrophoneAccess] = useState(false);
const [showMicrophoneAccessMessage, setShowMicrophoneAccessMessage] = useState(false);
const [shouldConnect, setShouldConnect] = useState(false);
let default_image_url = sessionStorage.getItem("imageUrl");
const modelParams = useModelParams({
textTemperature: getFloatFromStorage(sessionStorage.getItem("textTemperature")),
textTopk: getIntFromStorage(sessionStorage.getItem("textTopk")),
audioTemperature: getFloatFromStorage(sessionStorage.getItem("audioTemperature")),
audioTopk: getIntFromStorage(sessionStorage.getItem("audioTopk")),
padMult: getFloatFromStorage(sessionStorage.getItem("padMult")),
repetitionPenalty: getFloatFromStorage(sessionStorage.getItem("repetitionPenalty")),
repetitionPenaltyContext: getIntFromStorage(sessionStorage.getItem("repetitionPenaltyContext")),
imageResolution: getIntFromStorage(sessionStorage.getItem("imageResolution")),
gateDelay: getIntFromStorage(sessionStorage.getItem("gateDelay")),
gateInfluence: getFloatFromStorage(sessionStorage.getItem("gateInfluence")),
displayColor: getBoolFromStorage(sessionStorage.getItem("displayColor")),
centerCrop: getBoolFromStorage(sessionStorage.getItem("centerCrop")),
imageUrl: (default_image_url == null) ? undefined : default_image_url
});
const modalRef = useRef<HTMLDialogElement | null>(null);
let def_user_email = sessionStorage.getItem("userEmail");
const { userEmail, setUserEmail, error: emailError, validate } = useUserEmail(!!overrideWorkerAddr, (def_user_email == null) ? '' : def_user_email);
const audioContext = useRef<AudioContext | null>(null);
const worklet = useRef<AudioWorkletNode | null>(null);
// enable eruda in development
useEffect(() => {
if (env.VITE_ENV === "development") {
eruda.init();
}
return () => {
  if (env.VITE_ENV === "development") {
    eruda.destroy();
  }
};
}, []);
const getMicrophoneAccess = useCallback(async () => {
try {
await window.navigator.mediaDevices.getUserMedia({ audio: true });
setHasMicrophoneAccess(true);
return true;
} catch (e) {
console.error(e);
setShowMicrophoneAccessMessage(true);
setHasMicrophoneAccess(false);
}
return false;
}, [setHasMicrophoneAccess, setShowMicrophoneAccessMessage, setShouldConnect]);
const startProcessor = useCallback(async () => {
if (!audioContext.current) {
audioContext.current = new AudioContext();
}
if (worklet.current) {
return;
}
let ctx = audioContext.current;
ctx.resume();
try {
worklet.current = new AudioWorkletNode(ctx, 'moshi-processor');
} catch (err) {
await ctx.audioWorklet.addModule(moshiProcessorUrl);
worklet.current = new AudioWorkletNode(ctx, 'moshi-processor');
}
worklet.current.connect(ctx.destination);
}, [audioContext, worklet]);
const onConnect = useCallback(async () => {
if (!validate(userEmail)) {
return;
}
await startProcessor();
const hasAccess = await getMicrophoneAccess();
if (hasAccess) {
setShouldConnect(true);
}
}, [setShouldConnect, startProcessor, userEmail, getMicrophoneAccess, validate]);
const status: Status = useMemo(() => {
if (overrideWorkerAddr) {
return "bypass";
}
if (!queueId) {
return "no_queue";
}
if (error) {
return "error";
}
if (!shouldConnect) {
return "idle";
}
if (workerAddr && workerAuthId) {
return "has_credentials";
}
if (!sessionId || !sessionAuthId) {
return "connecting";
}
return "in_queue";
}, [queueId, sessionId, sessionAuthId, workerAddr, workerAuthId, currentPosition, hasMicrophoneAccess, error, shouldConnect]);
const client = useMemo(() => {
return getAPIClient(env.VITE_QUEUE_API_PATH)
}, [env.VITE_QUEUE_API_PATH]);
useEffect(() => {
if (!shouldConnect) {
return;
}
if (status !== "connecting" || !queueId) {
return;
}
client.addUser(queueId)
.then(({ session_id, session_auth_id }) => {
setSessionId(session_id);
setSessionAuthId(session_auth_id);
console.log("Added user to queue", session_id, session_auth_id);
})
.catch((e) => {
setError(e.message);
console.error(e);
});
}, [queueId, client, status, shouldConnect]);
useEffect(() => {
if (!sessionId || !sessionAuthId) {
return;
}
if (status === "has_credentials") {
return;
}
let isQuerying = false;
let intervalId: number | null = null;
const checkUser = () => {
if (isQuerying) {
return;
}
isQuerying = true;
client.checkUser(sessionId, sessionAuthId)
.then(({ worker_addr, worker_auth_id, current_position }) => {
setCurrentPosition(current_position);
if (worker_addr && worker_auth_id) {
setWorkerAddr(worker_addr);
setWorkerAuthId(worker_auth_id);
if (intervalId !== null) {
clearInterval(intervalId);
}
}
})
.catch((e) => {
if (intervalId !== null) {
clearInterval(intervalId);
}
setError(e.message);
console.error(e);
}).finally(() => {
isQuerying = false;
});
}
intervalId = setInterval(checkUser, 400);
return () => {
if (intervalId !== null) {
clearInterval(intervalId);
}
};
}, [sessionId, sessionAuthId, client, setCurrentPosition, setWorkerAddr, setWorkerAuthId, status, setError]);
if (status === "bypass" && hasMicrophoneAccess && audioContext.current && worklet.current) {
return (
}
worklet={worklet as MutableRefObject}
{...modelParams}
/>
);
}
if (status === "has_credentials" && workerAddr && audioContext.current && workerAuthId && sessionId && sessionAuthId && worklet?.current) {
return (
}
worklet={worklet as MutableRefObject}
sessionId={sessionId}
sessionAuthId={sessionAuthId}
onConversationEnd={() => {
setWorkerAddr(null);
setWorkerAuthId(null);
setSessionId(null);
setSessionAuthId(null);
setShouldConnect(false);
}}
{...modelParams}
/>
);
}
return (
  /* reconstructed markup: the original tags, class names and link targets were
     lost in extraction; the structure below is a minimal frame around the
     surviving text */
  <div>
    <h1>M👁️shiVis</h1>
    {/*
    To add more space to the top add padding to the top of the following div
    by changing the pt-4 class to pt-8 or pt-12. (see: https://tailwindcss.com/docs/padding)
    👁️ If you'd like to move this part to the bottom of the screen, change the class to pb-4 or pb-8 and move the following div so it is contained by the last one in the page.
    Font size can be changed by changing the text-sm class to text-lg or text-xl. (see: https://tailwindcss.com/docs/font-size)
    As for the links you can use the one below as an example and add more by copying it and changing the href and text.
    */}
    <div>
      <p>
        MoshiVis is an experimental multimodal conversational AI. Like Moshi,
        MoshiVis can listen to you and talk at all times for maximum
        conversational flow, now augmented with visual inputs. For instance, you
        can now ask Moshi to describe your favorite movie poster, grill it on
        details about the plot, then go back for more details about the image,
        or ask it to do some Pirate role play.
      </p>
      <p>
        We strive to support all browsers but Chrome works best. Conversations
        are limited to 5 min. Head to the Settings to configure the image size
        and other parameters. For more information about this project, check out
        the MoshiVis project page! Baked with &lt;3 @Kyutai.
      </p>
      <p>
        Add your email address first, then feel free to upload your own image or
        select one below. Uploaded images should be smaller than 15 MB.
      </p>
    </div>
    {status == 'error' && <p>{error}</p>}
    {status == 'no_queue' && <p>No queue id provided</p>}
    {(status === 'idle' || status === 'bypass') && (
      <>
        {showMicrophoneAccessMessage && (
          <p>Please enable your microphone before proceeding</p>
        )}
        <Input
          value={userEmail}
          onChange={(e) => setUserEmail(e.target.value)}
          error={emailError ?? ""}
          onKeyDown={(e) => {
            if (e.key === "Enter") {
              if (modelParams.imageUrl == undefined) {
                modelParams.setImageUrl("/assets/images/demo/image" + Math.floor(1 + Math.random() * 19) + ".jpg");
              }
              onConnect();
            }
          }}
        />
        <Button onClick={() => modalRef.current?.showModal()}>Settings</Button>
        <ImageGallery {...modelParams} />
        <dialog ref={modalRef}>
          <ModelParams {...modelParams} />
        </dialog>
      </>
    )}
    {status === "connecting" && <p>Connecting to queue...</p>}
    {status === "in_queue" && (
      <>
        <p>You're in the queue!</p>
        {currentPosition && <p>Current position: {currentPosition}</p>}
      </>
    )}
  </div>
);
};
================================================
FILE: client/src/pages/Queue/api/client.ts
================================================
import { APIError } from "./errors/api_error";
import { ResponseError } from "./errors/response_error";
import { validateAddUser, validateCheckUser } from "./validators";
export const getAPIClient = (url:string) => ({
addUser: async (queueId:string) => {
const encodedQueueId = encodeURIComponent(queueId);
const response = await fetch(`${url}/add_user?queue_id=${encodedQueueId}`);
if (!response.ok) {
const errorText = await response.text();
throw new APIError(errorText , response.status);
}
const json = await response.json();
const result = validateAddUser(json);
if(result.success) {
return result.data;
}
console.error(result.error.message);
throw new ResponseError("Failed to validate response");
},
checkUser: async (sessionId:number, sessionAuthId:string) => {
const encodedSessionAuthId = encodeURIComponent(sessionAuthId);
const encodedSessionId = encodeURIComponent(sessionId);
const response = await fetch(`${url}/check_user?session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}`);
if (!response.ok) {
const errorText = await response.text();
throw new APIError(errorText , response.status);
}
const json = await response.json();
const result = validateCheckUser(json);
if(result.success) {
return result.data;
}
console.error(result.error.message);
throw new ResponseError("Failed to validate response");
},
addFeedback: async ({
workerAuthId,
sessionId,
sessionAuthId,
feedback,
timestamp,
email
}:{
workerAuthId:string;
sessionId:number;
sessionAuthId:string;
feedback:0|1;
timestamp:number;
email:string;
} ) => {
const encodedWorkerAuthId = encodeURIComponent(workerAuthId);
const encodedSessionAuthId = encodeURIComponent(sessionAuthId);
const encodedSessionId = encodeURIComponent(sessionId);
const encodedFeedback = encodeURIComponent(feedback);
const encodedTimestamp = encodeURIComponent(timestamp);
const encodedEmail = encodeURIComponent(email);
const response = await fetch(`${url}/user_feedback?worker_auth_id=${encodedWorkerAuthId}&session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}&feedback=${encodedFeedback}&timestamp=${encodedTimestamp}&email=${encodedEmail}`);
if (!response.ok) {
const errorText = await response.text();
throw new APIError(errorText , response.status);
}
return response.json();
}
});
================================================
FILE: client/src/pages/Queue/api/errors/api_error.ts
================================================
export class APIError extends Error {
status:number;
constructor(message:string, status:number) {
super(message);
this.status = status;
this.name = "APIError";
}
}
================================================
FILE: client/src/pages/Queue/api/errors/response_error.ts
================================================
export class ResponseError extends Error {
constructor(message:string) {
super(message);
this.name = "ResponseError";
}
}
================================================
FILE: client/src/pages/Queue/api/validators.ts
================================================
import { z } from "zod"
export const validateAddUser = (response: unknown) => {
const AddUser = z.object({
session_id: z.number(),
session_auth_id: z.string(),
});
return AddUser.safeParse(response);
};
export const validateCheckUser = (response: unknown) => {
const CheckUser = z.object({
session_id: z.number(),
// TODO: add more statuses
status: z.enum(['wait', 'ready']),
worker_auth_id: z.string().nullable(),
worker_addr: z.string().nullable(),
current_position: z.string(),
});
return CheckUser.safeParse(response);
}
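// Example (illustrative payload): a queued user not yet assigned to a worker.
//   validateCheckUser({
//     session_id: 42,
//     status: "wait",
//     worker_auth_id: null,
//     worker_addr: null,
//     current_position: "3",
//   }).success // => true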
================================================
FILE: client/src/pages/Queue/hooks/useUserEmail.ts
================================================
import { useCallback, useState } from "react";
import { z } from "zod";
const validateEmail = z.string().email();
export const useUserEmail = (isBypass: boolean, init_value: string) => {
const [userEmail, setUserEmail] = useState(init_value);
const [error, setError] = useState(null);
const validate = useCallback((email: string) => {
if (isBypass) {
setError(null);
return true;
}
const result = validateEmail.safeParse(email);
if (result.success) {
setError(null);
sessionStorage.setItem("userEmail", email.toString());
return true;
}
setError('Invalid email address');
return false;
}, [setError]);
return { userEmail, setUserEmail, error, validate };
}
================================================
FILE: client/src/protocol/encoder.ts
================================================
import {
CONTROL_MESSAGE,
CONTROL_MESSAGES_MAP,
MODELS_MAP,
WSMessage,
VERSIONS_MAP,
} from "./types";
export const encodeMessage = (message: WSMessage): Uint8Array => {
switch (message.type) {
case "handshake":
return new Uint8Array([
0x00,
VERSIONS_MAP[message.version],
MODELS_MAP[message.model],
]);
case "audio":
return new Uint8Array([0x01, ...message.data]);
case "text":
// Not used in practice
return new Uint8Array([0x02, ...new TextEncoder().encode(message.data)]);
case "control":
// Not used in practice
return new Uint8Array([0x03, CONTROL_MESSAGES_MAP[message.action]]);
case "metadata":
// Not used in practice
return new Uint8Array([
0x04,
...new TextEncoder().encode(JSON.stringify(message.data)),
]);
case "error":
// Not used in practice
return new Uint8Array([0x05, ...new TextEncoder().encode(message.data)]);
case "ping":
// Not used in practice
return new Uint8Array([0x06]);
case "coloredtext":
// Not used in practice
return new Uint8Array([0x07, 0x05, ...new TextEncoder().encode(message.data)]);
case "image":
return new Uint8Array([0x08, ...message.data]);
case "user_rating":
return new Uint8Array([0x0A, message.data]);
}
};
export const decodeMessage = (data: Uint8Array): WSMessage => {
const type = data[0];
const payload = data.slice(1);
switch (type) {
case 0x00: {
return {
type: "handshake",
version: 0,
model: 0,
};
}
case 0x01:
return {
type: "audio",
data: payload,
};
case 0x02:
return {
type: "text",
data: new TextDecoder().decode(payload),
};
case 0x03: {
const action = Object.keys(CONTROL_MESSAGES_MAP).find(
key => CONTROL_MESSAGES_MAP[key as CONTROL_MESSAGE] === payload[0],
) as CONTROL_MESSAGE | undefined;
//TODO: log this and don't throw
if (!action) {
throw new Error("Unknown control message");
}
return {
type: "control",
action,
};
}
case 0x04:
return {
type: "metadata",
data: JSON.parse(new TextDecoder().decode(payload)),
}
case 0x05:
return {
type: "error",
data: new TextDecoder().decode(payload),
}
case 0x06:
return {
type: "ping",
}
case 0x07:
return {
type: "coloredtext",
color: payload[0],
data: new TextDecoder().decode(payload.slice(1)),
};
case 0x08:
return {
type: "image",
data: payload,
};
// never used in practice
case 0x0A:
return {
type: "user_rating",
data: payload[0],
};
default: {
console.log(type);
throw new Error("Unknown message type");
}
}
};
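// Wire-format sketch: byte 0 is the message tag (0x01 audio, 0x02 text, 0x08 image, ...),
// and the remaining bytes are the payload. Illustrative round trip:
//   const bytes = encodeMessage({ type: "text", data: "hi" }); // Uint8Array [0x02, 0x68, 0x69]
//   const back = decodeMessage(bytes); // { type: "text", data: "hi" }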
================================================
FILE: client/src/protocol/testMessages.ts
================================================
import { WSMessage } from "./types";
export const handshakeMessage: WSMessage = {
type: "handshake",
version: 0,
model: 0,
};
export const audioMessage: WSMessage = {
type: "audio",
data: new Uint8Array(10),
};
export const textMessage: WSMessage = {
type: "text",
data: "Hello",
};
export const controlBOSMessage: WSMessage = {
type: "control",
action: "start",
};
export const controlEOSMessage: WSMessage = {
type: "control",
action: "endTurn",
};
export const metadataMessage: WSMessage = {
type: "metadata",
data: { key: "value" },
};
================================================
FILE: client/src/protocol/types.ts
================================================
export type MessageType =
| "handshake"
| "audio"
| "text"
| "coloredtext"
| "control"
| "metadata";
export const VERSIONS_MAP = {
0: 0b00000000,
} as const;
export const MODELS_MAP = {
0: 0b00000000,
} as const;
export type VERSION = keyof typeof VERSIONS_MAP;
export type MODEL = keyof typeof MODELS_MAP;
export type WSMessage =
| {
type: "handshake";
version: VERSION;
model: MODEL;
}
| {
type: "user_rating";
data: number;
}
| {
type: "audio";
data: Uint8Array;
}
| {
type: "text";
data: string;
}
| {
type: "coloredtext";
color: number;
data: string;
}
| {
type: "control";
action: CONTROL_MESSAGE;
}
| {
type: "metadata";
data: unknown;
}
| {
type: "error";
data: string;
}
| {
type: "ping";
}
| {
type: "image";
data: Uint8Array;
}
export const CONTROL_MESSAGES_MAP = {
start: 0b00000000,
endTurn: 0b00000001,
pause: 0b00000010,
restart: 0b00000011,
} as const;
export type CONTROL_MESSAGE = keyof typeof CONTROL_MESSAGES_MAP;
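// Byte-level sketch (tags as used in encoder.ts): a version-0 handshake for model 0
// encodes as [0x00, VERSIONS_MAP[0], MODELS_MAP[0]], i.e. [0x00, 0b00000000, 0b00000000],
// and a "pause" control message encodes as [0x03, CONTROL_MESSAGES_MAP.pause],
// i.e. [0x03, 0b00000010].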
================================================
FILE: client/tailwind.config.js
================================================
/** @type {import('tailwindcss').Config} */
export default {
content: ["./src/**/*.{js,jsx,ts,tsx}", "./index.html"],
theme: {
extend: {},
},
plugins: [require('daisyui')],
};
================================================
FILE: client/tsconfig.json
================================================
{
"compilerOptions": {
"target": "ES2020",
"useDefineForClassFields": true,
"module": "ESNext",
"lib": [
"ES2020",
"DOM",
"DOM.Iterable"
],
"skipLibCheck": true,
"outDir": "dist",
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noFallthroughCasesInSwitch": true,
"types": [
"vite/client"
]
},
"include": [
"src"
]
}
================================================
FILE: client/vite.config.ts
================================================
import { ProxyOptions, defineConfig, loadEnv } from "vite";
import topLevelAwait from "vite-plugin-top-level-await";
export default defineConfig(({ mode }) => {
const env = loadEnv(mode, process.cwd());
const proxyConf: Record<string, ProxyOptions> = env.VITE_QUEUE_API_URL ? {
"/api": {
target: env.VITE_QUEUE_API_URL,
changeOrigin: true,
},
} : {};
return {
server: {
host: "0.0.0.0",
https: {
cert: "./cert.pem",
key: "./key.pem",
},
proxy: {
...proxyConf,
}
},
plugins: [
topLevelAwait({
// The export name of top-level await promise for each chunk module
promiseExportName: "__tla",
// The function to generate import names of top-level await promise in each chunk module
promiseImportName: i => `__tla_${i}`,
}),
],
};
});
================================================
FILE: docker-bake.hcl
================================================
group "default" {
targets = ["client"]
}
target "client" {
context = "./client"
# Specify output type as a local directory
output = [
"type=local,dest=./client/dist"
]
}
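# Usage sketch (assumes Docker Buildx is installed):
#   docker buildx bake        # builds the default group, i.e. the "client" target
# The built client bundle is written to ./client/dist via the local output above.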
================================================
FILE: kyuteye_mlx/.pylintrc
================================================
[MAIN]
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Clear in-memory caches upon conclusion of linting. Useful if running pylint
# in a server-like mode.
clear-cache-post-run=no
# Load and enable all available extensions. Use --list-extensions to see a list
# all available extensions.
#enable-all-extensions=
# In error mode, messages with a category besides ERROR or FATAL are
# suppressed, and no reports are done by default. Error mode is compatible with
# disabling specific errors.
#errors-only=
# Always return a 0 (non-error) status code, even if lint errors are found.
# This is primarily useful in continuous integration scripts.
#exit-zero=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=
# Specify a score threshold under which the program will exit with error.
fail-under=10
# Interpret the stdin as a python script, whose filename needs to be passed as
# the module_or_package argument.
#from-stdin=
# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
# Emacs file locks
ignore-patterns=^\.#
# List of module names for which member attributes should not be checked and
# will not be imported (useful for modules/projects where namespaces are
# manipulated during runtime and thus existing member attributes cannot be
# deduced by static analysis). It supports qualified module names, as well as
# Unix pattern matching.
ignored-modules=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.10
# Discover python modules and packages in the file system subtree.
recursive=no
# Add paths to the list of the source roots. Supports globbing patterns. The
# source root is an absolute path or a path relative to the current working
# directory used to determine a package namespace for modules located under the
# source root.
source-roots=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# In verbose mode, extra non-checker-related info will be displayed.
#verbose=
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Regular expression matching correct type alias names. If left empty, type
# alias names will be checked with the set naming style.
#typealias-rgx=
# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
#variable-rgx=
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp,
asyncSetUp,
__post_init__
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# List of regular expressions of class ancestor names to ignore when counting
# public methods (see R0903)
exclude-too-few-public-methods=
# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when caught.
overgeneral-exceptions=builtins.BaseException,builtins.Exception
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=1200
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
# Allow explicit reexports by alias from a package __init__.
allow-reexport-from-package=no
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=
# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=
# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=
# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
# UNDEFINED.
confidence=HIGH,
CONTROL_FLOW,
INFERENCE,
INFERENCE_FAILURE,
UNDEFINED
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
use-implicit-booleaness-not-comparison-to-string,
use-implicit-booleaness-not-comparison-to-zero,
too-many-locals,
unspecified-encoding,
too-many-arguments,
too-many-instance-attributes,
too-many-branches,
too-many-statements,
too-many-return-statements,
too-many-public-methods,
too-few-public-methods,
use-dict-literal,
unnecessary-lambda-assignment,
too-many-function-args
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=
[METHOD_ARGS]
# List of qualified names (i.e., library.method) which require a timeout
# parameter e.g. 'requests.api.get,requests.api.post'
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
# Regular expression of note tags to take in consideration.
notes-rgx=
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
# 'convention', and 'info' which contain the number of messages in each
# category, as well as 'statement' which is the total number of statements
# analyzed. This score is used by the global evaluation report (RP0004).
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
msg-template=
# Set the output format. Available formats are: text, parseable, colorized,
# json2 (improved json format), json (old json format) and msvs (visual
# studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[SIMILARITIES]
# Comments are removed from the similarity computation
ignore-comments=yes
# Docstrings are removed from the similarity computation
ignore-docstrings=yes
# Imports are removed from the similarity computation
ignore-imports=yes
# Signatures are removed from the similarity computation
ignore-signatures=yes
# Minimum lines number of a similarity.
min-similarity-lines=12
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. No available dictionaries : You need to install
# both the python package and the system dependency for enchant to work.
spelling-dict=
# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of symbolic message names to ignore for Mixin members.
ignored-checks-for-mixins=no-member,
not-async-context-manager,
not-context-manager,
attribute-defined-outside-init
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
# Regex pattern to define which classes are considered mixins.
mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
signature-mutators=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of names allowed to shadow builtins
allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
================================================
FILE: kyuteye_mlx/LICENSE
================================================
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
================================================
FILE: kyuteye_mlx/MANIFEST.in
================================================
include LICENSE*
include *.md
include *.cfg
include requirements.txt
include kyuteye_mlx/py.typed
================================================
FILE: kyuteye_mlx/README.md
================================================
# MoshiVis - MLX
See the [top-level README.md][main_repo] for more information on MoshiVis.
This is the MLX implementation for MoshiVis.
## Usage
We have tested the MLX version on a MacBook Air M3 (4-bit quantization) and a Mac Mini M4 Pro (both 4- and 8-bit quantization).
You can start the server with:
```bash
# In Bfloat16 - weights will be downloaded from HF
uv run server
# In Q4
uv run server -q 4
# In Q8
uv run server -q 8
```
This will start the web UI, which you can connect to over HTTP at [localhost:8008](http://localhost:8008).
Note that unlike other backends, not all settings available in the web UI are propagated to the MLX backend. Instead, you can configure some options directly via the command line, e.g. `--text-temperature`.
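For example, such a sampling option can be combined with the quantization flags shown above (the value here is purely illustrative):
```bash
uv run server -q 8 --text-temperature 0.7
```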
## License
The present code is provided under the MIT license.
Some of this code was taken from mlx-vlm v0.1.9, the code can be found here:
https://github.com/Blaizzy/mlx-vlm/tree/a11c034adf6ae4bca5a197990d1ecb77aba83c47
The license of mlx-vlm is MIT.
## Citation
If you use MoshiVis or Moshi, please cite the following papers:
```bibtex
@article{kyutai2025moshivis,
author = {Amélie Royer and Moritz Böhle and Gabriel de Marmiesse and
Laurent Mazaré and Alexandre Défossez and Neil Zeghidour and Patrick Pérez},
year = {2025},
title = {Vision-Speech Models: Teaching Speech Models to Converse about Images},
journal = {ArXiv},
url = {https://arxiv.org/abs/2503.15633}
}
@techreport{kyutai2024moshi,
title={Moshi: a speech-text foundation model for real-time dialogue},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and
Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
year={2024},
eprint={2410.00037},
archivePrefix={arXiv},
primaryClass={eess.AS},
url={https://arxiv.org/abs/2410.00037},
}
```
[moshi]: https://kyutai.org/Moshi.pdf
[main_repo]: https://github.com/kyutai-labs/MoshiVis
================================================
FILE: kyuteye_mlx/kyuteye_mlx/__init__.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""
kyuteye_mlx is the MLX inference codebase for Kyutai audio+vision generation models.
"""
import os
import subprocess
# TODO: remove this when https://github.com/ml-explore/mlx/issues/1963 is fixed
if subprocess.check_output(["sysctl", "hw.model"]).decode().split(":")[1].strip() == "Mac15,12":
os.environ["MLX_MAX_OPS_PER_BUFFER"] = "8"
os.environ["MLX_MAX_MB_PER_BUFFER"] = "1000000"
from jaxtyping import install_import_hook
# Set to True to get runtime type-checking and other small checks
# but will slow down the steps by ~2%
DEBUG_MODE = bool(int(os.environ.get("MOSHI_MLX_DEBUG_MODE", "0")))
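# e.g. enable with: MOSHI_MLX_DEBUG_MODE=1 uv run server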
if DEBUG_MODE:
print("debug mode enabled.")
install_import_hook("kyuteye_mlx", "beartype.beartype")
else:
print("debug mode disabled.")
from . import models, modules, utils
__version__ = "0.1.0"
================================================
FILE: kyuteye_mlx/kyuteye_mlx/benchmark.py
================================================
import time
import mlx.core as mx
import numpy as np
from .local_web import get_args_for_main, get_model, predict_text_and_audio
def main():
args = get_args_for_main()
args.quantized = 4
gen = get_model(args, load_weights=False)
sum_times = 0
for i in range(100):
data = mx.arange(8, dtype=mx.uint32).reshape(8, 1)
uploaded_image_embeddings = mx.arange(1024 * 1152, dtype=mx.bfloat16).reshape(1, 1024, 1152)
mx.eval((data, uploaded_image_embeddings))
t1 = time.time()
predict_text_and_audio(gen, data, uploaded_image_embeddings)
t2 = time.time()
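# treat the first 5 iterations as warm-up and average over the remaining 95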
if i >= 5:
sum_times += t2 - t1
print(f"average time per step: {(sum_times / 95) * 1000:1f} ms")
================================================
FILE: kyuteye_mlx/kyuteye_mlx/local_web.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Entrypoint for the web server."""
import argparse
import asyncio
import io
import multiprocessing
import multiprocessing.queues
import os
import queue
import sys
import tarfile
import time
import webbrowser
from enum import Enum
from pathlib import Path
from typing import Any
import aiohttp
import huggingface_hub
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import rustymimi
import sentencepiece
import sphn
from aiohttp import web
from jaxtyping import UInt32
from PIL import Image
from kyuteye_mlx import models, utils
from kyuteye_mlx.models.pixtral import PixtralWrapper
from kyuteye_mlx.models.siglip import SiglipWrapper
from kyuteye_mlx.quantize import quantize
from kyuteye_mlx.utils.loading import repeat_shared_weights, split_embedder_weights
from kyuteye_mlx.utils.profiling import PROFILING_ENABLED, profile
SAMPLE_RATE = 24000
FRAME_SIZE = 1920
CHANNELS = 1
class ModelInput(Enum):
AUDIO = 0
IMAGE = 1
class ModelOutput(Enum):
AUDIO = 0
TEXT = 1
START = 2
IMAGE_PROCESSED = 3
class ServerMediaInput(Enum):
AUDIO = 1
IMAGE = 8
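# Byte-level websocket protocol: the first byte of each binary message is a
# kind tag. Client -> server: 1 = opus audio, 8 = encoded image (see
# ServerMediaInput above). Server -> client: 0x00 = handshake, 0x01 = opus
# audio, 0x02 = UTF-8 text.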
ClientServerQueue = multiprocessing.queues.Queue[tuple[ModelInput, Any]]
ServerClientQueue = multiprocessing.queues.Queue[tuple[ModelOutput, Any]]
def colorize(text: str, color: str) -> str:
code = f"\033[{color}m"
restore = "\033[0m"
return "".join([code, text, restore])
def log(level: str, msg: str) -> None:
if level == "warning":
prefix = colorize("[Warn]", "1;31")
elif level == "info":
prefix = colorize("[Info]", "1;34")
elif level == "error":
prefix = colorize("[Err ]", "1;31")
else:
raise ValueError(f"Unknown level {level}")
print(prefix + " " + msg)
def hf_hub_download(repo: str | None, path: str) -> str:
if repo is None or repo == "":
raise ValueError(f"the --hf-repo flag is required to retrieve {path}")
return huggingface_hub.hf_hub_download(repo, path)
def full_warmup(
audio_tokenizer: rustymimi.StreamTokenizer,
client_to_server: ClientServerQueue,
server_to_client: ServerClientQueue,
) -> None:
for i in range(4):
pcm_data = np.array([0.0] * 1920).astype(np.float32)
audio_tokenizer.encode(pcm_data)
while True:
time.sleep(0.01)
data = audio_tokenizer.get_encoded()
if data is not None:
break
client_to_server.put_nowait((ModelInput.AUDIO, data))
if i == 0:
continue
while True:
kind, data = server_to_client.get()
if kind == ModelOutput.AUDIO:
audio_tokenizer.decode(data)
break
while True:
time.sleep(0.01)
data = audio_tokenizer.get_decoded()
if data is not None:
break
def get_model_file(args) -> str:
model_file = args.moshi_weights
if model_file is None:
if args.quantized in (4, 8):
model_file = hf_hub_download(args.hf_repo, f"model.q{args.quantized}.safetensors")
elif args.quantized is not None:
raise ValueError(f"Invalid quantized value: {args.quantized}")
else:
model_file = hf_hub_download(args.hf_repo, "model.safetensors")
return model_file
def get_tokenizer(args) -> sentencepiece.SentencePieceProcessor:
tokenizer_file = args.tokenizer
if tokenizer_file is None:
tokenizer_file = hf_hub_download(args.hf_repo, "tokenizer_spm_32k_3.model")
log("info", f"[SERVER] loading text tokenizer {tokenizer_file}")
return sentencepiece.SentencePieceProcessor(tokenizer_file) # type: ignore
def get_embedder(args) -> SiglipWrapper | PixtralWrapper:
model_file = get_model_file(args)
if args.encoder == "pixtral":
lm_config = models.config_pixtral()
elif args.encoder == "siglip":
lm_config = models.config_siglip()
else:
raise ValueError(f"Unknown encoder {args.encoder}")
weights = mx.load(model_file)
if lm_config.transformer.xa_shared:
# for shared cross-attention, we have the weights only once in ckpt
weights = repeat_shared_weights(weights, lm_config.transformer.num_layers)
_, embedder_weights = split_embedder_weights(weights)
if args.encoder == "siglip":
embedder = SiglipWrapper()
elif args.encoder == "pixtral":
embedder = PixtralWrapper()
else:
raise ValueError(f"Unknown encoder {args.encoder}")
if embedder_weights:
log("info", "[SERVER] loading embedder weights")
embedder_weights["model.embeddings.patch_embedding.weight"] = embedder_weights[
"model.embeddings.patch_embedding.weight"
].transpose(0, 2, 3, 1)
embedder.load_weights(list(embedder_weights.items()), strict=True)
log("info", "[SERVER] embedder weights loaded")
embedder.eval()
log("info", "[SERVER] Embedder warmed up")
return embedder
def get_model(args, load_weights: bool = True) -> models.LmGen:
mx.random.seed(299792458)
if args.encoder == "pixtral":
lm_config = models.config_pixtral()
elif args.encoder == "siglip":
lm_config = models.config_siglip()
else:
raise ValueError(f"Unknown encoder {args.encoder}")
lm_config.transformer.xa_start = args.xa_start
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
if args.quantized is not None:
group_size = 32 if args.quantized == 4 else 64
nn.quantize(model, bits=args.quantized, group_size=group_size)
if load_weights:
model_file = get_model_file(args)
log("info", f"[SERVER] loading weights {model_file}")
weights = mx.load(model_file)
if lm_config.transformer.xa_shared:
# for shared cross-attention, we have the weights only once in ckpt
weights = repeat_shared_weights(weights, lm_config.transformer.num_layers)
weights, _ = split_embedder_weights(weights)
model.load_weights(list(weights.items()), strict=True)
log("info", "[SERVER] weights loaded")
model.warmup()
log("info", "[SERVER] model warmed up")
gen = models.LmGen(
model=model,
max_steps=args.steps + 5,
text_sampler=utils.Sampler(temp=args.text_temperature, top_p=args.text_top_p),
audio_sampler=utils.Sampler(temp=args.audio_temperature, top_p=args.audio_top_p),
check=False,
)
return gen
def model_server(
client_to_server: ClientServerQueue,
server_to_client: ServerClientQueue,
args: argparse.Namespace,
):
gen = get_model(args)
embedder = get_embedder(args)
text_tokenizer = get_tokenizer(args)
server_to_client.put((ModelOutput.START, "start"))
log("info", "[SERVER] connected!")
try:
uploaded_image_embeddings = None
gen.reset()
for i in range(10000000000000):
if i == 150:
if PROFILING_ENABLED:
profile.print_stats()
data_type, data = client_to_server.get()
if data_type == ModelInput.AUDIO:
handle_audio(
data,
gen,
uploaded_image_embeddings,
server_to_client,
text_tokenizer,
)
elif data_type == ModelInput.IMAGE:
img = Image.open(io.BytesIO(data))
# resize so the longer side equals args.img_size, preserving the aspect ratio
if args.encoder == "pixtral":
w, h = img.width, img.height
if w > h:
new_w = args.img_size
new_h = int(h * new_w / w)
else:
new_h = args.img_size
new_w = int(w * new_h / h)
img = img.resize((new_w, new_h), resample=Image.Resampling.BICUBIC)
else:
img = img.resize((args.img_size, args.img_size), resample=Image.Resampling.BICUBIC)
image = np.asarray(img)
uploaded_image_embeddings = embedder(mx.array(image)).astype(mx.bfloat16)
gen.reset()
log("info", f"received image embeddings: {image.shape}")
server_to_client.put_nowait((ModelOutput.IMAGE_PROCESSED, None))
except KeyboardInterrupt:
pass
def handle_audio(
data: UInt32[np.ndarray, "tokens 1"],
gen: models.LmGen,
uploaded_image_embeddings: mx.array | None,
server_to_client: ServerClientQueue,
text_tokenizer: sentencepiece.SentencePieceProcessor,
):
t_start = time.time()
text_token, audio_tokens = predict_text_and_audio(gen, mx.array(data), uploaded_image_embeddings)
elapsed_eval_seconds = time.time() - t_start
elapsed_eval_milliseconds = elapsed_eval_seconds * 1000
print(f"eval in {elapsed_eval_milliseconds} ms")
text_token = text_token.item()
if text_token not in (0, 3):
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
server_to_client.put_nowait((ModelOutput.TEXT, _text))
if audio_tokens is not None:
audio_tokens = np.array(audio_tokens).astype(np.uint32)
server_to_client.put_nowait((ModelOutput.AUDIO, audio_tokens))
elapsed_seconds = time.time() - t_start
elapsed_milliseconds = elapsed_seconds * 1000
print(f"step in {elapsed_milliseconds} ms")
def predict_text_and_audio(
gen: models.LmGen, data: UInt32[mx.array, "tokens 1"], uploaded_image_embeddings
) -> tuple[mx.array, mx.array | None]:
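# (tokens, 1) -> (1, tokens), keeping the 8 incoming audio codebooks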
data = data.transpose(1, 0)[:, :8]
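# the image embeddings only need to be passed once: after the first step,
# their keys/values live in the cross-attention cache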
if gen.model.xa_cache is not None and not gen.model.xa_cache.is_set:
text_token = gen.step(data, uploaded_image_embeddings)
else:
text_token = gen.step(data, None)
text_token = text_token[0]
audio_tokens = gen.last_audio_tokens()
mx.eval((text_token, audio_tokens))
return text_token, audio_tokens
def web_server(
client_to_server: ClientServerQueue,
server_to_client: ServerClientQueue,
args: argparse.Namespace,
):
mimi_file = args.mimi_weights
if mimi_file is None:
mimi_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors")
input_queue = queue.Queue()
output_queue = queue.Queue()
text_queue = queue.Queue()
audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) # type: ignore
kind, start_message = server_to_client.get()
if kind != ModelOutput.START:
log(
"error",
f"[CLIENT] recieve {(kind, start_message)} at startup, this is unexpected.",
)
log("info", f"[CLIENT] received '{start_message}' from server, starting...")
full_warmup(audio_tokenizer, client_to_server, server_to_client)
log("info", "warmup done")
async def send_loop() -> None:
while True:
await asyncio.sleep(0.001)
try:
pcm_data = input_queue.get(block=False)
audio_tokenizer.encode(pcm_data)
except queue.Empty:
continue
async def recv_loop() -> None:
while True:
data = audio_tokenizer.get_decoded()
if data is None:
await asyncio.sleep(0.001)
continue
output_queue.put_nowait(data)
async def send_loop2() -> None:
while True:
data = audio_tokenizer.get_encoded()
if data is None:
await asyncio.sleep(0.001)
continue
client_to_server.put_nowait((ModelInput.AUDIO, data))
async def recv_loop2() -> None:
while True:
try:
kind, data = server_to_client.get(block=False)
if kind == ModelOutput.AUDIO:
audio_tokenizer.decode(data)
elif kind == ModelOutput.TEXT:
text_queue.put_nowait(data)
except queue.Empty:
await asyncio.sleep(0.001)
continue
lock = asyncio.Lock()
async def handle_chat(request) -> None:
ws = web.WebSocketResponse()
await ws.prepare(request)
async def recv_loop() -> None:
nonlocal close
try:
async for message in ws:
if message.type == aiohttp.WSMsgType.ERROR:
log("error", f"{ws.exception()}")
break
elif message.type == aiohttp.WSMsgType.CLOSED:
break
elif message.type != aiohttp.WSMsgType.BINARY:
log("error", f"unexpected message type {message.type}")
continue
message = message.data
if not isinstance(message, bytes):
log("error", f"unsupported message type {type(message)}")
continue
if len(message) == 0:
log("warning", "empty message")
continue
kind = message[0]
if kind == ServerMediaInput.AUDIO.value:
payload = message[1:]
opus_reader.append_bytes(payload)
else:
log("warning", f"unknown message kind {kind}")
finally:
close = True
log("info", "connection closed")
async def opus_loop() -> None:
all_pcm_data = None
while True:
if close:
return
await asyncio.sleep(0.001)
pcm = opus_reader.read_pcm()
if pcm.shape[-1] == 0:
continue
if all_pcm_data is None:
all_pcm_data = pcm
else:
all_pcm_data = np.concatenate((all_pcm_data, pcm))
while all_pcm_data.shape[-1] >= FRAME_SIZE:
chunk = all_pcm_data[:FRAME_SIZE]
all_pcm_data = all_pcm_data[FRAME_SIZE:]
input_queue.put_nowait(chunk)
async def send_loop() -> None:
while True:
if close:
return
await asyncio.sleep(0.001)
msg = opus_writer.read_bytes()
if len(msg) > 0:
await ws.send_bytes(b"\x01" + msg)
try:
_text = text_queue.get(block=False)
await ws.send_bytes(b"\x02" + bytes(_text, encoding="utf8"))
except queue.Empty:
continue
async def another_loop() -> None:
while True:
if close:
return
await asyncio.sleep(0.001)
try:
pcm_data = output_queue.get(block=False)
assert pcm_data.shape == (1920,), pcm_data.shape
opus_writer.append_pcm(pcm_data)
except queue.Empty:
continue
log("info", "accepted connection")
close = False
async with lock:
opus_writer = sphn.OpusStreamWriter(SAMPLE_RATE)
opus_reader = sphn.OpusStreamReader(SAMPLE_RATE)
# Send the handshake.
encoded_image = await extract_image(ws)
client_to_server.put_nowait((ModelInput.IMAGE, encoded_image))
message_type, _ = server_to_client.get()
if message_type != ModelOutput.IMAGE_PROCESSED:
log("error", f"We recieved {message_type} instead of IMAGE_PROCESSED.")
await ws.send_bytes(b"\x00")
await asyncio.gather(opus_loop(), recv_loop(), send_loop(), another_loop())
log("info", "done with connection")
return ws
async def extract_image(ws: web.WebSocketResponse) -> bytes:
"""Get the image at the start of the stream.
The bytes returned are encoded (png, jpeg, etc...).
"""
first_message = await ws.receive()
first_message = first_message.data
kind = first_message[0]
if kind != ServerMediaInput.IMAGE.value:
log("error", f"First messsage should be an image, got {kind}")
return first_message[1:]
async def go() -> None:
app = web.Application()
app.router.add_get("/api/chat", handle_chat)
static_path: None | str = None
if args.static is None:
log("info", "retrieving the static content")
dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "vis_dist.tgz")
dist_tgz = Path(dist_tgz)
dist = dist_tgz.parent / "dist"
if not dist.exists():
with tarfile.open(dist_tgz, "r:gz") as tar:
tar.extractall(path=dist_tgz.parent)
static_path = str(dist)
elif args.static != "none":
# When set to the "none" string, we don't serve any static content.
static_path = args.static
if static_path is not None:
async def handle_root(_) -> web.FileResponse:
return web.FileResponse(os.path.join(static_path, "index.html"))
log("info", f"serving static content from {static_path}")
app.router.add_get("/", handle_root)
app.router.add_static("/", path=static_path, name="static")
runner = web.AppRunner(app)
await runner.setup()
ssl_context = None
protocol = "http"
if args.ssl is not None:
import ssl
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
cert_file = os.path.join(args.ssl, "cert.pem")
key_file = os.path.join(args.ssl, "key.pem")
ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file)
protocol = "https"
site = web.TCPSite(runner, args.host, args.port, ssl_context=ssl_context)
log("info", f"listening to {protocol}://{args.host}:{args.port}")
if not args.no_browser:
log("info", f"opening browser at {protocol}://{args.host}:{args.port}")
webbrowser.open(f"{protocol}://{args.host}:{args.port}")
await asyncio.gather(recv_loop(), send_loop(), recv_loop2(), send_loop2(), site.start())
await runner.cleanup()
try:
asyncio.run(go())
except KeyboardInterrupt:
pass
def get_args_for_main() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer", type=str)
parser.add_argument("--moshi-weights", type=str)
parser.add_argument("--mimi-weights", type=str)
parser.add_argument("-q", "--quantized", type=int, choices=[4, 8])
parser.add_argument("--steps", default=4000, type=int)
parser.add_argument("--hf-repo", type=str, default="kyutai/moshika-vis-mlx")
parser.add_argument("--static", type=str)
parser.add_argument("--img_size", type=int, default=448)
parser.add_argument("--encoder", type=str, default="siglip")
parser.add_argument("--host", default="localhost", type=str)
parser.add_argument("--port", default=8008, type=int)
parser.add_argument(
"--ssl",
type=str,
help=(
"use https instead of http, this flag should point to a directory "
"that contains valid key.pem and cert.pem files"
),
)
parser.add_argument("--no-browser", action="store_true")
parser.add_argument("--text-temperature", type=float, default=0.45)
parser.add_argument("--audio-temperature", type=float, default=0.7)
parser.add_argument("--text-top-p", type=float, default=0.95)
parser.add_argument("--audio-top-p", type=float, default=0.95)
parser.add_argument("--xa-start", type=int, default=16)
args = parser.parse_args()
if args.moshi_weights is not None:
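# infer the quantization level from the checkpoint filename when -q is not given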
args.quantized = args.quantized or (
8 if "q8" in args.moshi_weights else 4 if "q4" in args.moshi_weights else None
)
return args
def main() -> None:
args = get_args_for_main()
client_to_server: ClientServerQueue = multiprocessing.Queue()
server_to_client: ServerClientQueue = multiprocessing.Queue()
# Create two processes
subprocess_args = client_to_server, server_to_client, args
p1 = multiprocessing.Process(target=web_server, args=subprocess_args)
p2 = multiprocessing.Process(target=model_server, args=subprocess_args)
# Start the processes
p1.start()
p2.start()
try:
while p1.is_alive() and p2.is_alive():
time.sleep(0.001)
except KeyboardInterrupt:
log("warning", "Interrupting, exiting connection.")
p1.terminate()
p2.terminate()
# Wait for both processes to finish
p1.join()
p2.join()
log("info", "All done!")
def sanity_check() -> None:
"""A small sanity check to make sure all packages can be imported."""
if __name__ == "__main__":
main()
================================================
FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/LICENSE
================================================
MIT License
Copyright © 2023 Apple Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/__init__.py
================================================
================================================
FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/__init__.py
================================================
================================================
FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/pixtral/__init__.py
================================================
================================================
FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/pixtral/vision.py
================================================
import inspect
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
@dataclass
class VisionConfig:
model_type: str
num_hidden_layers: int = 24
hidden_size: int = 1024
head_dim: int = 64
intermediate_size: int = 4096
num_attention_heads: int = 16
image_size: int = 336
patch_size: int = 14
projection_dim: int = 768
vocab_size: int = 32000
num_channels: int = 3
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
def position_ids_in_meshgrid(patch_embeds_list: list[mx.array], max_width: int):
positions = []
for patch in patch_embeds_list:
height, width = patch.shape[1], patch.shape[2]
h_grid, v_grid = mx.meshgrid(mx.arange(height), mx.arange(width), indexing="ij")
h_grid = h_grid.reshape(-1, 1)
v_grid = v_grid.reshape(-1, 1)
ids = h_grid * max_width + v_grid
positions.append(ids.flatten())
return mx.concatenate(positions)
def generate_block_attention_mask(patch_embeds_list: list[int], tensor: mx.array):
seq_len = tensor.shape[1]
d_min = -1e9 # Using a large negative value as MLX doesn't have finfo
causal_mask = mx.full((seq_len, seq_len), vals=d_min)
block_end_idx = mx.cumsum(mx.array(patch_embeds_list))
block_start_idx = mx.concatenate([mx.array([0]), mx.array(patch_embeds_list[:-1])])
block_start_idx = mx.cumsum(block_start_idx)
for start, end in zip(block_start_idx, block_end_idx):
start, end = int(start), int(end) # Convert to integers for indexing
causal_mask[start:end, start:end] = 0
causal_mask = mx.broadcast_to(causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len))
return causal_mask
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate((-x2, x1), axis=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
cos = mx.expand_dims(cos, axis=unsqueeze_dim)
sin = mx.expand_dims(sin, axis=unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Attention(nn.Module):
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: int | None = None,
key_input_dims: int | None = None,
value_input_dims: int | None = None,
value_dims: int | None = None,
value_output_dims: int | None = None,
bias: bool = False,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
"The input feature dimensions should be divisible by the "
f"number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.embed_dim = dims
self.num_heads = num_heads
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
def __call__(self, queries, keys, values, position_embeddings, mask=None):
queries = self.q_proj(queries)
keys = self.k_proj(keys)
values = self.v_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
cos, sin = position_embeddings
queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, unsqueeze_dim=0)
attn_weights = mx.matmul(queries, keys.transpose(0, 1, 3, 2)) * self.scale
if mask is not None:
attn_weights = attn_weights + mask
attn_weights = mx.softmax(attn_weights, axis=-1)
output = mx.matmul(attn_weights, values)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
dim = config.hidden_size
hidden_dim = config.intermediate_size
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class EncoderLayer(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.attention = Attention(config.hidden_size, config.num_attention_heads, bias=True)
self.attention_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
self.feed_forward = MLP(config)
self.ffn_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
position_embeddings: mx.array,
mask: mx.array | None = None,
) -> mx.array:
y = self.attention_norm(x)
y = self.attention(y, y, y, position_embeddings, mask)
x = x + y
y = self.ffn_norm(x)
y = self.feed_forward(y)
return x + y
class Encoder(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
class PixtralRotaryEmbedding:
def __init__(self, config):
self.dim = config.head_dim
self.base = config.rope_theta
max_patches_per_side = config.image_size // config.patch_size
freqs = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
h = mx.arange(max_patches_per_side)
w = mx.arange(max_patches_per_side)
freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32)
freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32)
inv_freq = mx.concatenate(
[
mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)),
mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1)),
],
axis=-1,
).reshape(-1, self.dim // 2)
self.inv_freq = mx.concatenate((inv_freq, inv_freq), axis=-1)
def __call__(self, x, position_ids):
freqs = self.inv_freq[position_ids]
cos = mx.cos(freqs)
sin = mx.sin(freqs)
return cos.astype(x.dtype), sin.astype(x.dtype)
class PixtralVisionModel(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.config = config
self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
bias=False,
)
self.ln_pre = nn.RMSNorm(config.hidden_size)
self.transformer = Encoder(config)
self.patch_positional_embedding = PixtralRotaryEmbedding(config)
def __call__(
self,
x: list[mx.array],
output_hidden_states: bool | None = None,
) -> tuple[mx.array, tuple[mx.array, ...] | None]:
B, H, W, C = x[0].shape
patch_embeds_list = [self.patch_conv(img) for img in x]
patch_embeds = mx.concatenate([p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1)
patch_embeds = self.ln_pre(patch_embeds)
position_ids = position_ids_in_meshgrid(
patch_embeds_list,
max_width=self.config.image_size // self.config.patch_size,
)
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
mask = generate_block_attention_mask(
[p.shape[2] * p.shape[1] for p in patch_embeds_list], patch_embeds
)
encoder_states = (patch_embeds,) if output_hidden_states else None
for layer in self.transformer.layers:
patch_embeds = layer(patch_embeds, mask=mask, position_embeddings=position_embedding)
if output_hidden_states:
encoder_states = encoder_states + (patch_embeds,)
return patch_embeds, encoder_states
================================================
FILE: kyuteye_mlx/kyuteye_mlx/mlx_vlm/models/siglip/vision.py
================================================
import inspect
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
import numpy as np
@dataclass
class VisionConfig:
model_type: str
num_hidden_layers: int
hidden_size: int
intermediate_size: int
num_attention_heads: int
patch_size: int
projection_dim: int
image_size: int = 224
num_channels: int = 3
layer_norm_eps: float = 1e-6
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
def check_array_shape(arr):
shape = arr.shape
# Check if the shape has 4 dimensions
if len(shape) != 4:
return False
out_channels, kH, kW, _ = shape
# Check if out_channels is the largest, and kH and kW are the same
return (out_channels >= kH) and (out_channels >= kW) and (kH == kW)
class Attention(nn.Module):
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: int | None = None,
key_input_dims: int | None = None,
value_input_dims: int | None = None,
value_dims: int | None = None,
value_output_dims: int | None = None,
bias: bool = True,
) -> None:
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
"The input feature dimensions should be divisible by the "
f"number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
head_dim = dims // num_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
def __call__(self, x, mask=None):
queries = self.q_proj(x)
keys = self.k_proj(x)
values = self.v_proj(x)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output)
class MLP(nn.Module):
def __init__(self, config: VisionConfig) -> None:
super().__init__()
self.activation_fn = nn.GELU(approx="precise")
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
def __call__(self, x: mx.array) -> mx.array:
x = self.fc1(x)
x = self.activation_fn(x)
x = self.fc2(x)
return x
class EncoderLayer(nn.Module):
def __init__(self, config: VisionConfig) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Attention(config.hidden_size, config.num_attention_heads, bias=True)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = MLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
r = self.self_attn(self.layer_norm1(x), mask)
h = x + r
r = self.mlp(self.layer_norm2(h))
return h + r
class Encoder(nn.Module):
def __init__(self, config: VisionConfig) -> None:
super().__init__()
self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
def __call__(
self,
x: mx.array,
output_hidden_states: bool | None = None,
mask: mx.array | None = None,
) -> tuple[mx.array, mx.array | None]:
encoder_states = (x,) if output_hidden_states else None
h = x
for layer in self.layers:
x = layer(x, mask=mask)
if output_hidden_states:
encoder_states = encoder_states + (x,)
h = x[0]
return (h, encoder_states)
class VisionEmbeddings(nn.Module):
def __init__(self, config: VisionConfig) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
def __call__(self, x: mx.array) -> mx.array:
patch_embeddings = self.patch_embedding(x)
patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
position_ids = mx.array(np.arange(self.num_positions)[None, :])
embeddings = patch_embeddings
embeddings += self.position_embedding(position_ids)
return embeddings
class SigLipVisionModel(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.embeddings = VisionEmbeddings(config)
self.encoder = Encoder(config)
self.post_layernorm = nn.LayerNorm(config.hidden_size)
def __call__(
self,
x: mx.array,
output_hidden_states: bool | None = None,
) -> tuple[mx.array, mx.array, mx.array | None]:
x = self.embeddings(x)
encoder_outputs = self.encoder(x=x, output_hidden_states=output_hidden_states, mask=None)
pooler_output = self.post_layernorm(encoder_outputs[0])
return pooler_output, x, encoder_outputs[-1]
class VisionModel(nn.Module):
def __init__(self, config: VisionConfig) -> None:
super().__init__()
self.model_type = config.model_type
if self.model_type != "siglip_vision_model":
raise ValueError(f"Unsupported model type: {self.model_type}")
self.vision_model = SigLipVisionModel(config)
def __call__(self, x: mx.array, output_hidden_states: bool | None = None) -> tuple[mx.array, mx.array, mx.array | None]:
return self.vision_model(x, output_hidden_states)
def sanitize(self, weights):
sanitized_weights = {}
for k, v in weights.items():
if "position_ids" in k:
# Remove unused position_ids
continue
elif "patch_embedding.weight" in k:
# PyTorch conv2d weight tensors have shape:
# [out_channels, in_channels, kH, KW]
# MLX conv2d expects the weight be of shape:
# [out_channels, kH, KW, in_channels]
if check_array_shape(v):
sanitized_weights[k] = v
else:
sanitized_weights[k] = v.transpose(0, 2, 3, 1)
else:
sanitized_weights[k] = v
return sanitized_weights
================================================
FILE: kyuteye_mlx/kyuteye_mlx/models/__init__.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
"""
from .lm import (
Lm,
LmConfig,
config_v0_1,
config_siglip,
config_pixtral,
config1b_202412,
config1b_202412_16rvq,
config_helium_1_preview_2b,
)
from .generate import LmGen
================================================
FILE: kyuteye_mlx/kyuteye_mlx/models/generate.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import mlx.core as mx
from jaxtyping import BFloat16, Int32, UInt32
from kyuteye_mlx import DEBUG_MODE
from ..models import Lm
from ..utils import sampling
class LmGen:
def __init__(
self,
model: Lm,
max_steps: int,
text_sampler: sampling.Sampler,
audio_sampler: sampling.Sampler,
check: bool = False,
):
self.model: Lm = model
self.text_sampler = text_sampler
self.audio_sampler = audio_sampler
self.max_steps = max_steps
self.check = check
self.num_codebooks = 1 + model.cfg.audio_codebooks
self.gen_sequence = mx.full(
shape=(1, self.num_codebooks, max_steps),
vals=self.ungenerated_token,
dtype=mx.int32,
)
self.step_idx = 0
self.audio_padding_token = self.model.cfg.audio_padding_token
self.audio_delays = self.model.cfg.audio_delays
self.max_delay = max(self.audio_delays)
self.main_codebooks = self.model.cfg.depformer.num_slices
@property
def zero_token(self) -> int:
"""Special value in the input tokens, indicating that no sampling should
happen for that value, and no input should be given to the model."""
return -1
@property
def ungenerated_token(self) -> int:
"""Special value that can be provided in the prompt to indicate that this specific
value should be predicted and sampled. This allows for partial teacher forcing, by generating
one modality, with the other one fixed.
"""
return -2
@property
def nb_input_tokens(self) -> int:
return self.model.cfg.audio_codebooks - self.main_codebooks
# Runs one step of inference and returns the generated text token.
def step(
self,
other_audio_tokens: UInt32[mx.array, "1 {self.nb_input_tokens}"],
image_embeddings: BFloat16[mx.array, "1 dim1 dim2"] | None,
) -> UInt32[mx.array, "1"]:
if self.step_idx >= self.max_steps:
raise ValueError(f"reached max-steps {self.max_steps}")
if self.step_idx == 0:
text_tokens = mx.array([[32000]])
else:
text_tokens = self.gen_sequence[:, 0, self.step_idx - 1][None]
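# the trailing codebooks of gen_sequence hold the incoming (user) audio stream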
self.gen_sequence[:, 1 + self.main_codebooks :, self.step_idx] = other_audio_tokens
audio_tokens = []
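# fetch this model's own audio tokens from `delay` steps back; before a
# codebook's delay has elapsed, feed the audio padding token instead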
for cb_idx, delay in enumerate(self.audio_delays):
gen_idx = self.step_idx - 1 - delay
if gen_idx >= 0:
audio_token = self.gen_sequence[:, cb_idx + 1, gen_idx][None]
else:
audio_token = mx.array([[self.audio_padding_token]])
if DEBUG_MODE and (audio_token == self.ungenerated_token).any(): # type: ignore
raise ValueError(f"ungenerated value in audio tokens cb: {cb_idx} step: {self.step_idx}")
assert audio_token.shape == (1, 1), "invalid audio-tokens shape"
audio_tokens.append(audio_token)
if DEBUG_MODE and (text_tokens == self.ungenerated_token).any(): # type: ignore
raise ValueError(f"ungenerated value in text tokens {self.step_idx}")
assert text_tokens.shape == (1, 1), "invalid text-tokens shape"
text_tokens, audio_tokens = self.model.sample(
text_tokens,
audio_tokens,
self.step_idx,
self.text_sampler,
self.audio_sampler,
image_embeddings=image_embeddings,
)
assert audio_tokens.shape == (8,), "invalid output audio-token shape"
self.gen_sequence[:, 0, self.step_idx] = text_tokens
for cb_idx, delay in enumerate(self.audio_delays[: self.main_codebooks]):
gen_idx = self.step_idx - delay
if gen_idx >= 0:
self.gen_sequence[:, cb_idx + 1, gen_idx] = audio_tokens[cb_idx]
self.step_idx += 1
return text_tokens
def last_audio_tokens(self) -> Int32[mx.array, "1 {self.nb_input_tokens}"] | None:
gen_idx = self.step_idx - 1 - self.max_delay
if gen_idx < 0:
return None
tokens = self.gen_sequence[:, 1 : 1 + self.main_codebooks, gen_idx]
if DEBUG_MODE and (tokens == self.audio_padding_token).any(): # type: ignore
return None
if DEBUG_MODE and (tokens == self.ungenerated_token).any(): # type: ignore
raise ValueError(f"ungenerated value in last-audio tokens {self.step_idx}")
return tokens
def reset(self) -> None:
self.gen_sequence = mx.full(
shape=(1, self.num_codebooks, self.max_steps),
vals=self.ungenerated_token,
dtype=mx.int32,
)
self.step_idx = 0
self.model.reset_all_caches()
================================================
FILE: kyuteye_mlx/kyuteye_mlx/models/lm.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from importlib.metadata import version
import mlx.core as mx
import mlx.nn as nn
from packaging.version import parse as parse_version
from ..modules.kv_cache import KVCache, RotatingKVCache, XACache
from ..modules.transformer import Transformer, TransformerConfig
from ..utils import sampling
MLX_BELOW_0_22_0 = parse_version(version("mlx")) < parse_version("0.22.0")
@dataclass
class DepFormerConfig:
transformer: TransformerConfig
num_slices: int
@dataclass
class LmConfig:
transformer: TransformerConfig
depformer: DepFormerConfig
text_in_vocab_size: int
text_out_vocab_size: int
audio_vocab_size: int
audio_codebooks: int
audio_delays: list[int]
@property
def audio_eos_token(self) -> int:
return self.audio_vocab_size - 2
@property
def audio_padding_token(self) -> int:
return self.audio_vocab_size - 1
class DepFormerSlice(nn.Module):
def __init__(
self,
in_vocab_size: int,
out_vocab_size: int,
main_transformer_dim: int,
cfg: TransformerConfig,
):
super().__init__()
dim = cfg.d_model
self.emb = nn.Embedding(in_vocab_size, dim)
self.linear_in = nn.Linear(main_transformer_dim, dim, bias=False)
self.linear_out = nn.Linear(dim, out_vocab_size, bias=False)
self.transformer = Transformer(cfg)
def __call__(self, _: mx.array) -> mx.array:
raise ValueError("not implemented")
class DepFormer(nn.Module):
def __init__(self, cfg: LmConfig):
super().__init__()
self.slices: list[DepFormerSlice] = []
for slice_idx in range(cfg.depformer.num_slices):
in_vs = cfg.text_in_vocab_size if slice_idx == 0 else cfg.audio_vocab_size
slice = DepFormerSlice(
in_vs,
cfg.audio_vocab_size - 1,
main_transformer_dim=cfg.transformer.d_model,
cfg=cfg.depformer.transformer,
)
self.slices.append(slice)
def __call__(self, _: mx.array) -> mx.array:
raise ValueError("not implemented")
def sample(
self,
main_transformer_out: mx.array,
step_idx: int,
sampler: sampling.Sampler,
text_token: mx.array,
cache: list[KVCache] | list[RotatingKVCache],
) -> mx.array:
tokens = []
last_token = text_token
# The cache is shared between the depformer slices but not persisted between sample calls.
for c in cache:
c.reset()
for slice_idx, slice in enumerate(self.slices):
last_token = last_token if step_idx > 0 or slice_idx in (0, 1, 9) else mx.array(2048)
if MLX_BELOW_0_22_0:
embedding = slice.emb(last_token)
else:
last_token_reshaped = mx.expand_dims(last_token, axis=0)
embedding = slice.emb(last_token_reshaped).squeeze(0)
xs = slice.linear_in(main_transformer_out) + embedding
xs = slice.transformer(xs, cache=cache)
logits = slice.linear_out(xs)
last_token, _ = sampler(logits[0])
tokens.append(last_token)
tokens = mx.concatenate(tokens)
return tokens
class Lm(nn.Module):
def __init__(self, cfg: LmConfig):
super().__init__()
dim = cfg.transformer.d_model
self.transformer: Transformer = Transformer(cfg.transformer, with_img_prefix=True)
self.depformer: DepFormer = DepFormer(cfg)
self.text_emb = nn.Embedding(cfg.text_in_vocab_size, dim)
self.cfg: LmConfig = cfg
if cfg.transformer.norm == "layer_norm":
self.out_norm = nn.LayerNorm(dim, 1e-5)
elif cfg.transformer.norm == "rms_norm":
self.out_norm = nn.RMSNorm(dim, 1e-8)
else:
raise ValueError(f"unsupported norm type {cfg.transformer.norm}")
self.text_linear = nn.Linear(dim, cfg.text_out_vocab_size, bias=False)
self.audio_embs = [nn.Embedding(cfg.audio_vocab_size, dim) for _ in range(cfg.audio_codebooks)]
self.transformer_cache: list[RotatingKVCache] = self.transformer.make_rot_cache()
if cfg.transformer.cross_attention:
self.xa_cache = XACache()
if len(self.depformer.slices) > 0:
self.depformer_cache: list[KVCache] = self.depformer.slices[0].transformer.make_cache()
else:
self.depformer_cache = []
def __call__(
self,
token_ids: mx.array,
) -> mx.array:
# Note that this does not apply the depformer.
xs = self.text_emb(token_ids)
transformer_out = self.transformer(xs, cache=self.transformer_cache)
transformer_out = self.out_norm(transformer_out)
text_logits = self.text_linear(transformer_out)
return text_logits
def sample(
self,
text_token_ids: mx.array,
audio_token_ids: list[mx.array],
step_idx: int,
text_sampler: sampling.Sampler,
audio_sampler: sampling.Sampler,
image_embeddings: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
xs = self.text_emb(text_token_ids)
for token_ids, emb in zip(audio_token_ids, self.audio_embs):
xs = xs + emb(token_ids)
transformer_out = self.transformer(
xs,
cache=self.transformer_cache,
img_embeds=image_embeddings,
xa_cache=self.xa_cache,
)
transformer_out = self.out_norm(transformer_out)
text_logits = self.text_linear(transformer_out)
text_token, _ = text_sampler(text_logits[:, 0])
audio_tokens = self.depformer.sample(
transformer_out,
step_idx,
audio_sampler,
text_token,
self.depformer_cache,
)
return text_token, audio_tokens
def warmup(self) -> None:
text, audio = self.sample(
mx.array([[32000]]),
[mx.array([[0]])] * 8,
0,
text_sampler=sampling.Sampler(0.5, 0.5),
audio_sampler=sampling.Sampler(0.5, 0.5),
)
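# .item() forces evaluation of the lazy MLX graph; the comparison with 42
# is effectively a no-op that only exists to consume the outputs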
if text.sum().item() == 42:
raise ValueError(42)
if audio.sum().item() == 42:
raise ValueError(42)
for c in self.transformer_cache:
c.reset()
def reset_all_caches(self) -> None:
for c in self.transformer_cache:
c.reset()
if hasattr(self, "xa_cache"):
self.xa_cache.reset()
for c in self.depformer_cache:
c.reset()
def config1b_202412() -> LmConfig:
transformer = TransformerConfig(
d_model=2048,
num_heads=16,
num_layers=16,
dim_feedforward=2048 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=750,
max_period=100000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=False,
gating=True,
norm="rms_norm",
positional_embedding="rope",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
)
depformer = DepFormerConfig(
transformer=TransformerConfig(
d_model=1024,
num_heads=16,
num_layers=6,
dim_feedforward=1024 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=8,
max_period=10000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=False,
gating=True,
norm="rms_norm",
positional_embedding="none",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
),
num_slices=8,
)
return LmConfig(
transformer=transformer,
depformer=depformer,
audio_vocab_size=2049,
text_in_vocab_size=48001,
text_out_vocab_size=48000,
audio_codebooks=16,
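# per stream: no delay on the semantic codebook, one step of delay on the
# 7 acoustic codebooks; the pattern is repeated for both audio streams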
audio_delays=([0] + [1] * 7) * 2,
)
def config1b_202412_16rvq() -> LmConfig:
transformer = TransformerConfig(
d_model=2048,
num_heads=16,
num_layers=16,
dim_feedforward=2048 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=750,
max_period=100000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=False,
gating=True,
norm="rms_norm",
positional_embedding="rope",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
)
depformer = DepFormerConfig(
transformer=TransformerConfig(
d_model=1024,
num_heads=16,
num_layers=6,
dim_feedforward=1024 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=16,
max_period=10000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=False,
gating=True,
norm="rms_norm",
positional_embedding="none",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
),
num_slices=16,
)
return LmConfig(
transformer=transformer,
depformer=depformer,
audio_vocab_size=2049,
text_in_vocab_size=48001,
text_out_vocab_size=48000,
audio_codebooks=32,
audio_delays=([0] + [1] * 15) * 2,
)
def config_v0_1() -> LmConfig:
transformer = TransformerConfig(
d_model=4096,
num_heads=32,
num_layers=32,
dim_feedforward=4096 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=3000,
max_period=10000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=True,
gating=True,
norm="rms_norm",
positional_embedding="rope",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
xa_gating="sigmoid",
xa_shared=True,
img_emb_dim=1152,
)
depformer = DepFormerConfig(
transformer=TransformerConfig(
d_model=1024,
num_heads=16,
num_layers=6,
dim_feedforward=1024 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=8,
max_period=10000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=False,
gating=True,
norm="rms_norm",
positional_embedding="none",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
xa_gating="none",
xa_shared=True,
),
num_slices=8,
)
return LmConfig(
transformer=transformer,
depformer=depformer,
audio_vocab_size=2049,
text_in_vocab_size=32001,
text_out_vocab_size=32000,
audio_codebooks=16,
audio_delays=([0] + [1] * 7) * 2,
)
def config_siglip() -> LmConfig:
config = config_v0_1()
config.transformer.img_emb_dim = 1152
return config
def config_pixtral() -> LmConfig:
config = config_v0_1()
config.transformer.img_emb_dim = 1024
return config
def config_helium_1_preview_2b() -> LmConfig:
transformer = TransformerConfig(
d_model=2560,
num_heads=20,
num_layers=24,
dim_feedforward=2560 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=4096,
max_period=100000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=False,
gating=True,
norm="rms_norm",
positional_embedding="rope",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
)
depformer = DepFormerConfig(
transformer=transformer,
num_slices=0,
)
return LmConfig(
transformer=transformer,
depformer=depformer,
audio_vocab_size=2049,
text_in_vocab_size=48000,
text_out_vocab_size=48000,
audio_codebooks=0,
audio_delays=[],
)
================================================
FILE: kyuteye_mlx/kyuteye_mlx/models/pixtral.py
================================================
import json
import mlx
import mlx.core as mx
import mlx.nn
from ..mlx_vlm.models.pixtral.vision import PixtralVisionModel, VisionConfig
class PixtralWrapper(mlx.nn.Module):
"""Pixtral encoder returning penultimate features"""
def __init__(self) -> None:
super().__init__()
# if not os.path.exists("pixtral-12b-8bit.safetensors"):
# self.load_pixtral_weights()
with open("pixtral-12b-8bit.config", "r") as f:
vision_config = VisionConfig(**json.load(f))
self.model = PixtralVisionModel(vision_config)
weights = mx.load("pixtral-12b-8bit.safetensors")
mlx.nn.quantize(self.model, bits=8, group_size=64)
self.model.load_weights(list(weights.items()))
# def load_pixtral_weights(self):
# from mlx_vlm.models.pixtral import Model
# from dataclasses import asdict
# pixtral = Model.from_pretrained("mlx-community/pixtral-12b-8bit")
# self.model = pixtral.vision_tower.vision_model
# self.model.save_weights("pixtral-12b-8bit.safetensors")
# with open("pixtral-12b-8bit.config", "w") as f:
# json.dump(asdict(pixtral.config.vision_config), f)
def __call__(self, x: mx.array) -> mx.array:
"""Forward to the last hidden states
We expect the input to be an RGB uint8 image (H, W, C)
"""
means = mx.array([0.48145466, 0.4578275, 0.40821073])
std = mx.array([0.26862954, 0.26130258, 0.27577711])
x = ((x / 255.0) - means[None, None, :]) / std[None, None, :]
x = [x[None, :, :, :].astype(mx.float32)]
assert isinstance(x, list), "Pixtral expects a list of tensors."
return self.model(x, output_hidden_states=False)[0]
def warmup(self) -> None:
mx.eval(self(mx.zeros((224, 224, 3), dtype=mx.uint8)))
================================================
FILE: kyuteye_mlx/kyuteye_mlx/models/siglip.py
================================================
import json
import os
import mlx
import mlx.core as mx
import mlx.nn
from ..mlx_vlm.models.siglip.vision import VisionConfig, VisionModel
class SiglipWrapper(mlx.nn.Module):
"""Siglip encoder returning penultimate features"""
def __init__(self) -> None:
super().__init__()
with open("siglip448.config", "r") as f:
vision_config = VisionConfig(**json.load(f))
print("loaded!")
self.model = VisionModel(vision_config).vision_model
def __call__(self, x: mx.array) -> mx.array:
"""Forward to the last hidden states
We expect the input to be an RGB uint8 image (H, W, C)
"""
means = mx.array([0.5] * 3)
std = mx.array([0.5] * 3)
x = ((x / 255.0) - means[None, None, :]) / std[None, None, :]
x = x[None, :, :, :].astype(mx.float32)
out = self.model(x, output_hidden_states=False)[0]
return out[None]
def warmup(self) -> None:
mx.eval(self(mx.zeros((224, 224, 3), dtype=mx.uint8)))
================================================
FILE: kyuteye_mlx/kyuteye_mlx/modules/__init__.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""Modules used for building the models."""
from .kv_cache import KVCache, RotatingKVCache
from .transformer import Transformer, TransformerConfig
================================================
FILE: kyuteye_mlx/kyuteye_mlx/modules/config.py
================================================
from dataclasses import dataclass
from typing import Literal
@dataclass
class TransformerConfig:
d_model: int
num_heads: int
num_layers: int
causal: bool
norm_first: bool
bias_ff: bool
bias_attn: bool
layer_scale: float | None
positional_embedding: str
use_conv_block: bool
cross_attention: bool
xa_shared: bool
xa_gating: Literal["none", "sigmoid", "tanh"]
conv_kernel_size: int
use_conv_bias: bool
gating: bool
norm: str
context: int
max_period: int
max_seq_len: int
kv_repeat: int
dim_feedforward: int
conv_layout: bool
img_emb_dim: int | None = None
xa_start: int = 0
@property
def head_dim(self) -> int:
return self.d_model // self.num_heads
================================================
FILE: kyuteye_mlx/kyuteye_mlx/modules/cross_attention.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Gated cross-attention module"""
from typing import Any, Literal
import mlx.core as mx
import mlx.nn as nn
from jaxtyping import BFloat16
from .config import TransformerConfig
from .kv_cache import XACache
class SharedModuleType(type):
"""Wrapper to build shared Pytorch modules"""
_instances = {} # type: ignore
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
cls._instances[cls] = super(SharedModuleType, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class CrossAttention(nn.Module):
def __init__(self, cfg: TransformerConfig):
super().__init__()
self.cfg = cfg
self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
self.kv_proj = nn.Linear(cfg.d_model, 2 * cfg.d_model, bias=False)
self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias_attn)
self.scale = cfg.head_dim ** (-0.5)
def __call__(
self,
xs: BFloat16[mx.array, "batch time {self.cfg.d_model}"],
xa_cache: XACache,
kv: BFloat16[mx.array, "batch kv_size {self.cfg.d_model}"] | None = None,
) -> BFloat16[mx.array, "batch time {self.cfg.d_model}"]:
k, v = xa_cache.state if xa_cache is not None else (None, None)
assert kv is not None or (k is not None and v is not None), (
"Need to provide embeddings or pre-computed keys and values"
)
b, t, hd = xs.shape
q = self.q_proj(xs).reshape(b, self.cfg.num_heads, t, self.cfg.head_dim)
if k is None and v is None:
assert kv is not None, "No image embeds given but also no cache found"
kv = self.kv_proj(kv).reshape(b, -1, 2, self.cfg.num_heads, self.cfg.head_dim)
k = kv[:, :, 0].transpose(0, 2, 1, 3)
v = kv[:, :, 1].transpose(0, 2, 1, 3)
xa_cache.set(k, v)
xs = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
xs = xs.reshape(b, t, hd)
xs = self.out_proj(xs)
return xs
class SharedCrossAttention(CrossAttention, metaclass=SharedModuleType):
"""Shared Cross Attention projection across all layers"""
pass # pylint: disable=unnecessary-pass
class XAGate(nn.Module):
def __init__(
self,
cfg: TransformerConfig,
hidden_dims_factor: float = 0.125,
activation: Literal["tanh", "sigmoid"] = "sigmoid",
conditional_gating: bool = True,
):
super().__init__()
assert conditional_gating
self.dims = cfg.d_model
hidden_dims = int(hidden_dims_factor * self.dims)
self.alpha = nn.Sequential(
nn.Linear(self.dims, hidden_dims, bias=False),
nn.ReLU(),
nn.Linear(hidden_dims, self.dims, bias=False),
)
if activation == "tanh":
self.act = nn.Tanh()
elif activation == "sigmoid":
# shift the sigmoid input so the gate output starts close to 0 at initialization
self.act = lambda x: mx.sigmoid(x - 4)
else:
raise NotImplementedError("Unknown activation function", activation)
def __call__(
self, xs: BFloat16[mx.array, "batch time {self.dims}"]
) -> BFloat16[mx.array, "batch time {self.dims}"]:
return xs * self.act(self.alpha(xs))
class GatedCrossAttention(nn.Module):
def __init__(self, cfg: TransformerConfig) -> None:
super().__init__()
self.mha = (SharedCrossAttention if cfg.xa_shared else CrossAttention)(cfg)
# Output Gating
self.gate: nn.Module | None = None
if cfg.xa_gating != "none":
self.gate = XAGate(cfg)
def __call__(
self,
xs: BFloat16[mx.array, "batch time features"],
xa_cache: XACache,
kv: BFloat16[mx.array, "batch kv_size features"] | None = None,
) -> BFloat16[mx.array, "batch time features"]:
if kv is None and not xa_cache.is_set:
return xs
xs = self.mha(xs=xs, xa_cache=xa_cache, kv=kv)
if self.gate is not None:
xs = self.gate(xs)
return xs
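The `SharedModuleType` metaclass turns `SharedCrossAttention` into a per-class singleton, so every transformer layer that requests a shared cross-attention module receives the same object and hence the same projection weights. A minimal, self-contained sketch of the mechanism (the `Counter` class below is a stand-in, not part of the codebase):

```python
from typing import Any


class SharedType(type):
    """Same pattern as SharedModuleType: one instance per class."""

    _instances: dict = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Counter(metaclass=SharedType):
    def __init__(self) -> None:
        self.n = 0


a, b = Counter(), Counter()
a.n += 1
assert a is b and b.n == 1  # every "layer" sees the same instance
```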
================================================
FILE: kyuteye_mlx/kyuteye_mlx/modules/kv_cache.py
================================================
# Most of the code below comes from:
# https://github.com/ml-explore/mlx-examples/blob/6c2369e4b97f49fb5906ec46033497b39931b25d/llms/mlx_lm/models/base.py#L1
# Copyright © 2023-2024 Apple Inc.
import inspect
from dataclasses import dataclass
from typing import Any, Self
import mlx.core as mx
import numpy as np
class XACache:
def __init__(self) -> None:
self.keys: mx.array | None = None
self.values: mx.array | None = None
self.is_set: bool = False
def set(self, k: mx.array, v: mx.array) -> None:
# See https://github.com/ml-explore/mlx/issues/1918 for an
# explanation of this hack.
self.keys = mx.array(np.array(k.astype(mx.float32))).astype(mx.bfloat16)
self.values = mx.array(np.array(v.astype(mx.float32))).astype(mx.bfloat16)
self.is_set = True
def reset(self) -> None:
self.keys = None
self.values = None
self.is_set = False
@property
def state(self) -> tuple[mx.array | None, mx.array | None]:
return self.keys, self.values
class KVCache:
def __init__(self, head_dim: int | tuple[int, int], n_kv_heads: int) -> None:
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keys: mx.array | None = None
self.values: mx.array | None = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys: mx.array, values: mx.array) -> tuple[mx.array, mx.array]:
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
assert self.values is not None
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
assert self.values is not None
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
def reset(self) -> None:
self.offset = 0
@property
def state(self) -> tuple[mx.array | None, mx.array | None]:
return self.keys, self.values
class RotatingKVCache:
def __init__(
self,
head_dim: int | tuple[int, int],
n_kv_heads: int,
max_size: int,
keep: int = 0,
step: int = 256,
) -> None:
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keys: mx.array | None = None
self.values: mx.array | None = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size: int, v: mx.array, append: mx.array | None = None) -> mx.array:
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def update_and_fetch(self, keys: mx.array, values: mx.array) -> tuple[mx.array, mx.array]:
prev = self.offset
B, _, S = keys.shape[:3]
# Prefill mode
if S > 1:
if self.keys is None:
self.keys = keys
self.values = values
else:
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self.keys.shape[2] - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += S
self._idx = self.keys.shape[2]
return self.keys, self.values
# Generation mode
# May not have hit the max size yet, so potentially
# keep growing the cache
if self.keys is None or (prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size):
new_size = min(self.step, self.max_size - prev)
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
assert self.values is not None
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + 1, :] = keys
assert self.values is not None
self.values[..., self._idx : self._idx + 1, :] = values
self.offset += 1
self._idx += 1
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def reset(self) -> None:
self.offset = 0
self._idx = 0
@property
def state(self) -> tuple[mx.array | None, mx.array | None]:
return self.keys, self.values
@dataclass
class BaseModelArgs:
@classmethod
def from_dict(cls, params: dict[str, Any]):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
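A quick sketch of `KVCache.update_and_fetch` semantics: the backing buffer grows along axis 2 in chunks of `step` (256) positions, while only the first `offset` positions are ever returned. The sizes below are illustrative:

```python
import mlx.core as mx

from kyuteye_mlx.modules.kv_cache import KVCache

cache = KVCache(head_dim=8, n_kv_heads=2)
k = mx.zeros((1, 2, 3, 8), mx.bfloat16)  # (B, n_kv_heads, S=3, head_dim)
v = mx.zeros((1, 2, 3, 8), mx.bfloat16)
ks, vs = cache.update_and_fetch(k, v)
assert ks.shape == (1, 2, 3, 8)            # only `offset` positions are returned
assert cache.keys.shape == (1, 2, 256, 8)  # buffer is allocated in steps of 256
assert cache.offset == 3
```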
================================================
FILE: kyuteye_mlx/kyuteye_mlx/modules/transformer.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import mlx.core as mx
import mlx.nn as nn
from jaxtyping import BFloat16
from .config import TransformerConfig
from .cross_attention import GatedCrossAttention
from .kv_cache import KVCache, RotatingKVCache, XACache
class Attention(nn.Module):
def __init__(self, cfg: TransformerConfig) -> None:
super().__init__()
num_kv = cfg.num_heads // cfg.kv_repeat
out_dim = cfg.d_model + 2 * num_kv * cfg.d_model // cfg.num_heads
self.cfg = cfg
self.in_proj = nn.Linear(cfg.d_model, out_dim, bias=cfg.bias_attn)
self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias_attn)
self.scale = cfg.head_dim ** (-0.5)
self.rope = None
if cfg.positional_embedding == "rope":
self.rope = nn.RoPE(cfg.head_dim, traditional=True, base=cfg.max_period)
def __call__(
self,
xs: BFloat16[mx.array, "batch time features"],
cache: KVCache | RotatingKVCache,
mask: BFloat16[mx.array, "batch kv_size features"] | None = None,
) -> BFloat16[mx.array, "batch time features"]:
assert self.cfg.kv_repeat == 1, "only kv_repeat==1 is supported"
b, t, hd = xs.shape
qkv = self.in_proj(xs).reshape(b, t, 3, self.cfg.num_heads, self.cfg.head_dim)
q = qkv[:, :, 0].transpose(0, 2, 1, 3)
k = qkv[:, :, 1].transpose(0, 2, 1, 3)
v = qkv[:, :, 2].transpose(0, 2, 1, 3)
if self.rope is not None:
q = self.rope(q, offset=cache.offset)
k = self.rope(k, offset=cache.offset)
k, v = cache.update_and_fetch(k, v)
k_len = k.shape[2]
k_target_len = t + min(self.cfg.context, k_len - t)
if k_target_len < k_len:
k = k[:, :, k_len - k_target_len :]
v = v[:, :, k_len - k_target_len :]
xs = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
xs = xs.transpose(0, 2, 1, 3).reshape(b, t, hd)
xs = self.out_proj(xs)
return xs
class MlpGating(nn.Module):
def __init__(self, cfg: TransformerConfig) -> None:
super().__init__()
hidden = 2 * cfg.dim_feedforward // 3
if cfg.dim_feedforward == 4 * cfg.d_model:
hidden = 11 * cfg.d_model // 4
self.linear_in = nn.Linear(cfg.d_model, 2 * hidden, bias=cfg.bias_ff)
self.linear_out = nn.Linear(hidden, cfg.d_model, bias=cfg.bias_ff)
def __call__(
self, xs: BFloat16[mx.array, "batch time features"]
) -> BFloat16[mx.array, "batch time features"]:
xs = self.linear_in(xs)
b, t, _ = xs.shape
xs = xs.reshape(b, t, 2, -1)
return self.linear_out(nn.silu(xs[:, :, 0]) * xs[:, :, 1])
class MlpNoGating(nn.Module):
def __init__(self, cfg: TransformerConfig) -> None:
super().__init__()
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward, bias=cfg.bias_ff)
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model, bias=cfg.bias_ff)
def __call__(self, xs: mx.array) -> mx.array:
return self.linear2(nn.gelu_approx(self.linear1(xs)))
class TransformerLayer(nn.Module):
def __init__(self, cfg: TransformerConfig) -> None:
super().__init__()
self.cfg = cfg
assert not cfg.use_conv_block, "conv-block is not supported"
if cfg.gating:
self.gating = MlpGating(cfg)
else:
self.gating = MlpNoGating(cfg)
if cfg.norm == "layer_norm":
self.norm1 = nn.LayerNorm(cfg.d_model, 1e-5)
self.norm2 = nn.LayerNorm(cfg.d_model, 1e-5)
elif cfg.norm == "rms_norm":
self.norm1 = nn.RMSNorm(cfg.d_model, 1e-8)
self.norm2 = nn.RMSNorm(cfg.d_model, 1e-8)
else:
raise ValueError(f"unsupported norm type {cfg.norm}")
self.self_attn = Attention(cfg)
if cfg.cross_attention:
self.cross_attention = GatedCrossAttention(cfg)
if cfg.norm == "rms_norm":
self.norm_cross = nn.RMSNorm(cfg.d_model, 1e-8)
elif cfg.norm == "layer_norm":
self.norm_cross = nn.LayerNorm(cfg.d_model, 1e-5)
else:
raise ValueError(f"unsupported norm type {cfg.norm}")
else:
self.cross_attention = None
self.xa_start = cfg.xa_start
def __call__(
self,
xs: BFloat16[mx.array, "batch time {self.cfg.d_model}"],
cache: KVCache | RotatingKVCache,
img_embeds: (BFloat16[mx.array, "batch kv_size {self.cfg.d_model}"] | None) = None,
xa_cache: XACache | None = None,
) -> BFloat16[mx.array, "batch time {self.cfg.d_model}"]:
xs = xs + self.self_attn(self.norm1(xs), cache=cache)
if self.cross_attention is not None:
if xa_cache is None:
raise ValueError("xa_cache should never be None when using cross attention.")
if cache.offset >= self.xa_start:
xs = xs + self.cross_attention(self.norm_cross(xs), xa_cache=xa_cache, kv=img_embeds)
xs = xs + self.gating(self.norm2(xs))
return xs
class ImagePrefix(nn.Module):
def __init__(self, cfg: TransformerConfig) -> None:
super().__init__()
self.cfg = cfg
self.norm_xa = nn.RMSNorm(cfg.d_model, 1e-8)
self.proj_xa = nn.Linear(cfg.img_emb_dim, cfg.d_model, bias=True)
def __call__(
self, xa: BFloat16[mx.array, "batch kv_size {self.cfg.img_emb_dim}"]
) -> BFloat16[mx.array, "batch kv_size {self.cfg.d_model}"]:
xa = self.proj_xa(xa)
xa = self.norm_xa(xa)
return xa
class Transformer(nn.Module):
def __init__(self, cfg: TransformerConfig, with_img_prefix: bool = False) -> None:
super().__init__()
self.cfg = cfg
self.layers = [TransformerLayer(cfg=cfg) for _ in range(cfg.num_layers)]
if with_img_prefix:
self.image_prefix = ImagePrefix(cfg)
def __call__(
self,
xs: BFloat16[mx.array, "batch time {self.cfg.d_model}"],
cache: list[KVCache] | list[RotatingKVCache],
img_embeds: (BFloat16[mx.array, "batch kv_size {self.cfg.img_emb_dim}"] | None) = None,
xa_cache: XACache | None = None,
) -> BFloat16[mx.array, "batch time {self.cfg.d_model}"]:
if img_embeds is not None and xa_cache is not None and not xa_cache.is_set:
img_embeds = self.image_prefix(img_embeds)
else:
img_embeds = None
for layer, c in zip(self.layers, cache):
xs = layer(xs, cache=c, xa_cache=xa_cache, img_embeds=img_embeds)
return xs
def make_cache(self) -> list[KVCache]:
num_kv_heads = self.cfg.num_heads // self.cfg.kv_repeat
return [KVCache(head_dim=self.cfg.head_dim, n_kv_heads=num_kv_heads) for _ in self.layers]
def make_rot_cache(self) -> list[RotatingKVCache]:
num_kv_heads = self.cfg.num_heads // self.cfg.kv_repeat
return [
RotatingKVCache(
head_dim=self.cfg.head_dim,
n_kv_heads=num_kv_heads,
max_size=self.cfg.max_seq_len,
)
for _ in self.layers
]
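The context truncation in `Attention.__call__` keeps the `t` current positions plus at most `cfg.context` past positions: `k_target_len = t + min(context, k_len - t)`. A worked example with plain integers (values are illustrative, using the `context = 3000` default of the Moshi backbone config):

```python
t, context = 1, 3000      # one decoded step, 3000-position context window
k_len = 3500              # cached key length after update_and_fetch
k_target_len = t + min(context, k_len - t)  # 1 + min(3000, 3499) = 3001
assert k_target_len == 3001
# k[:, :, k_len - k_target_len:] then keeps the most recent 3001 positions,
# dropping the 499 oldest ones:
assert k_len - k_target_len == 499
```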
================================================
FILE: kyuteye_mlx/kyuteye_mlx/py.typed
================================================
================================================
FILE: kyuteye_mlx/kyuteye_mlx/quantize.py
================================================
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "fire",
# "mlx==0.18.1",
# "safetensors >= 0.4.0, < 0.5",
# ]
# ///
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Literal, Optional
import fire
import mlx.core as mx
from mlx import nn
from mlx.utils import tree_flatten
from kyuteye_mlx import models
from kyuteye_mlx.utils.loading import (
remove_shared_weights,
repeat_shared_weights,
split_embedder_weights,
)
def quantize(
model_file: str,
out_file: Optional[str] = None,
img_embed: Literal["siglip", "pixtral"] = "siglip",
bits: int = 8,
group_size: int = 64,
quantize_embedder: bool = False,
) -> None:
if out_file is None:
out_file = model_file.replace(".safetensors", f".q{bits}.safetensors")
weights = mx.load(model_file)
if img_embed == "siglip":
lm_config = models.config_siglip()
if quantize_embedder:
            from kyuteye_mlx.models.siglip import SiglipWrapper
embedder = SiglipWrapper()
else:
embedder = None
else:
lm_config = models.config_pixtral()
if quantize_embedder:
            from kyuteye_mlx.models.pixtral import PixtralWrapper
embedder = PixtralWrapper()
else:
embedder = None
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
weights = repeat_shared_weights(weights, lm_config.transformer.num_layers)
weights, embed_weights = split_embedder_weights(weights)
model.load_weights(list(weights.items()), strict=True)
print("weights loaded")
nn.quantize(model, bits=bits, group_size=group_size)
if quantize_embedder:
assert embedder is not None
nn.quantize(embedder, bits=bits, group_size=group_size)
embed_weights = dict(tree_flatten(embedder.parameters()))
print(f"saving the quantized q{bits} weights in {out_file}")
new_weights = dict(tree_flatten(model.parameters()))
new_weights = remove_shared_weights(new_weights, lm_config.transformer.num_layers)
# Re-adding prefix for consistency with other scripts.
new_weights.update({"img_embedder." + k: v for k, v in embed_weights.items()})
mx.save_safetensors(out_file, new_weights)
def main():
fire.Fire(quantize)
if __name__ == "__name__":
main()
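Since `main` dispatches through `fire.Fire`, `quantize` can also be driven directly from Python. A hedged usage sketch; the checkpoint path is a placeholder:

```python
from kyuteye_mlx.quantize import quantize

# "moshika-vis.safetensors" is a hypothetical local checkpoint path.
quantize("moshika-vis.safetensors", bits=8, group_size=64)
# With out_file=None this writes "moshika-vis.q8.safetensors" next to the input.
```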
================================================
FILE: kyuteye_mlx/kyuteye_mlx/utils/__init__.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""Utilities."""
from .sampling import Sampler
================================================
FILE: kyuteye_mlx/kyuteye_mlx/utils/loading.py
================================================
import mlx.core as mx
def repeat_shared_weights(weights: dict[str, mx.array], num_layers: int) -> dict[str, mx.array]:
for layer_idx in range(1, num_layers):
for srckey in [
"transformer.layers.0.cross_attention.mha.kv_proj",
"transformer.layers.0.cross_attention.mha.q_proj",
"transformer.layers.0.cross_attention.mha.out_proj",
]:
for subkey in ["weight", "scales", "biases"]:
k = srckey + "." + subkey
if k in weights:
weights[k.replace(".0.", f".{layer_idx}.")] = weights[k]
return weights
def remove_shared_weights(weights: dict[str, mx.array], num_layers: int) -> dict[str, mx.array]:
for layer_idx in range(1, num_layers):
k1 = "transformer.layers.0.cross_attention.mha.kv_proj.weight"
weights.pop(k1.replace(".0.", f".{layer_idx}."))
return weights
def split_embedder_weights(
weights: dict[str, mx.array],
) -> tuple[dict[str, mx.array], dict[str, mx.array]]:
embedder_weights = {}
model_weights = {}
for k, v in weights.items():
if k.startswith("img_embedder."):
embedder_weights[k[len("img_embedder.") :]] = v
else:
model_weights[k] = v
return model_weights, embedder_weights
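On the shared cross-attention projections these helpers round-trip: `repeat_shared_weights` duplicates layer 0's projections into every other layer before `load_weights`, and `remove_shared_weights` strips the duplicated `kv_proj.weight` entries again before saving. A toy demonstration with dummy arrays and two layers:

```python
import mlx.core as mx

from kyuteye_mlx.utils.loading import remove_shared_weights, repeat_shared_weights

w = {"transformer.layers.0.cross_attention.mha.kv_proj.weight": mx.zeros((4, 4))}
w = repeat_shared_weights(w, num_layers=2)
assert "transformer.layers.1.cross_attention.mha.kv_proj.weight" in w
w = remove_shared_weights(w, num_layers=2)
assert list(w) == ["transformer.layers.0.cross_attention.mha.kv_proj.weight"]
```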
================================================
FILE: kyuteye_mlx/kyuteye_mlx/utils/profiling.py
================================================
from typing import Callable
import line_profiler
PROFILING_ENABLED = False
profile: line_profiler.LineProfiler | Callable
if PROFILING_ENABLED:
profile = line_profiler.LineProfiler()
else:
def profile(x: Callable) -> Callable:
return x
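With `PROFILING_ENABLED = False` the `@profile` decorator is a no-op; flipping the flag swaps in a real `line_profiler.LineProfiler`, whose timings can then be printed through the standard LineProfiler API. A sketch:

```python
from kyuteye_mlx.utils.profiling import PROFILING_ENABLED, profile


@profile  # identity function unless PROFILING_ENABLED is True
def hot_loop() -> int:
    return sum(range(1000))


hot_loop()
if PROFILING_ENABLED:
    profile.print_stats()  # LineProfiler.print_stats()
```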
================================================
FILE: kyuteye_mlx/kyuteye_mlx/utils/sampling.py
================================================
# Taken from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from functools import partial
import mlx.core as mx
from jaxtyping import BFloat16, UInt32
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(
logits: BFloat16[mx.array, "batch vocab"], top_p: float, temperature: float
) -> UInt32[mx.array, "batch"]:
"""
Apply top-p (nucleus) sampling to logits.
Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
Returns:
token selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 # noqa
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices.squeeze(0)]
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
# select tokens with cumulative probs below threshold
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
0,
)
sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
return token
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits: BFloat16[mx.array, "batch vocab"], temp: float) -> UInt32[mx.array, "batch"]:
return mx.random.categorical(logits * (1 / temp))
@dataclass
class Sampler:
temp: float
top_p: float
def __call__(
self, logits: BFloat16[mx.array, "batch vocab"]
) -> tuple[UInt32[mx.array, "batch"], BFloat16[mx.array, "batch vocab"]]:
logit_bias: dict[int, float] | None = None
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
if self.temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if self.top_p > 0 and self.top_p < 1.0:
token = top_p_sampling(logits, self.top_p, self.temp)
else:
token = categorical_sampling(logits, self.temp)
logprobs = logits - mx.logsumexp(logits)
return token, logprobs
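A small sketch of `Sampler` on toy logits: with `temp == 0` it reduces to argmax, otherwise it routes through `top_p_sampling` when `0 < top_p < 1` and plain categorical sampling otherwise. Values are illustrative:

```python
import mlx.core as mx

from kyuteye_mlx.utils.sampling import Sampler

logits = mx.array([[0.1, 2.0, -1.0]], mx.bfloat16)  # (batch=1, vocab=3)
greedy = Sampler(temp=0.0, top_p=1.0)
token, logprobs = greedy(logits)
assert int(token[0]) == 1            # argmax of the toy logits
stochastic = Sampler(temp=0.8, top_p=0.95)
token, _ = stochastic(logits)        # goes through top_p_sampling
```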
================================================
FILE: kyuteye_mlx/pixtral-12b-8bit.config
================================================
{"model_type": "pixtral", "num_hidden_layers": 24, "hidden_size": 1024, "head_dim": 64, "intermediate_size": 4096, "num_attention_heads": 16, "image_size": 1024, "patch_size": 16, "projection_dim": 768, "vocab_size": 32000, "num_channels": 3, "rms_norm_eps": 1e-05, "rope_theta": 10000.0}
================================================
FILE: kyuteye_mlx/pyproject.toml
================================================
[project]
name = "kyuteye_mlx"
requires-python = ">= 3.10,<3.13"
description = "Kyutai with an 'eye', but running on macOS"
dependencies = [
"numpy >= 2.1.0, < 2.2",
"safetensors >= 0.4.0, < 0.5",
"huggingface-hub >= 0.24, < 0.25",
"rustymimi == 0.2.2",
"sentencepiece == 0.2",
"sounddevice == 0.5",
"sphn >= 0.1.4",
    # Do not change this version of mlx: every earlier version up
    # to 0.22.1 is slower for this codebase.
"mlx==0.23.1",
"aiohttp>=3.10.5, <3.11",
"pillow",
"line-profiler>=4.2.0",
"rich>=13.9.4",
"packaging>=24.2",
"jaxtyping==0.3.0",
"beartype>=0.19.0",
"fire>=0.7.0",
]
authors = [
{ name = "Gabriel de Marmiesse", email = "gabriel@kyutai.org" },
{ name = "Moritz Boehle", email = "moritz@kyutai.org" }
]
license = {text = "MIT"}
dynamic = ["version"]
readme = "README.md"
[project.scripts]
server = "kyuteye_mlx.local_web:main"
sanity-check = "kyuteye_mlx.local_web:sanity_check"
quantize = "kyuteye_mlx.quantize:main"
benchmark = "kyuteye_mlx.benchmark:main"
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[tool.setuptools.dynamic]
version = {attr = "kyuteye_mlx.__version__"}
[tool.ruff]
line-length = 110
[tool.mypy]
ignore_missing_imports = true
disallow_untyped_defs = true
[[tool.mypy.overrides]]
module = "kyuteye_mlx.mlx_vlm.*"
disallow_untyped_defs = false
[dependency-groups]
dev = [
"mypy>=1.11.2",
"pylint>=3.3.4",
"pytest>=8.3.4",
"torch>=2.3.0",
"moshi==0.1.0",
"ruff>=0.9.7",
"monkeytype>=23.3.0",
]
================================================
FILE: kyuteye_mlx/siglip448.config
================================================
{"model_type": "siglip_vision_model", "num_hidden_layers": 27, "hidden_size": 1152, "intermediate_size": 4304, "num_attention_heads": 16, "patch_size": 14, "projection_dim": 2304, "image_size": 448, "num_channels": 3, "layer_norm_eps": 1e-06}
================================================
FILE: kyuteye_mlx/tests/test_siglip.py
================================================
import mlx.core as mx
import numpy as np
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from kyuteye_mlx.models.siglip import SiglipWrapper
def convert_weights_for_mlx(weights: dict[str, torch.Tensor]) -> dict[str, mx.array]:
new_weights = {}
for k, v in weights.items():
new_key = k.removeprefix("vision_model.")
new_key = "model." + new_key
new_weights[new_key] = mx.array(v)
new_weights["model.embeddings.patch_embedding.weight"] = new_weights[
"model.embeddings.patch_embedding.weight"
].transpose(0, 2, 3, 1)
return new_weights
@torch.no_grad()
def test_siglip_weights_conversion() -> None:
model_id = "google/paligemma2-3b-pt-448"
processor = AutoProcessor.from_pretrained(model_id)
image_processor = processor.image_processor
model = AutoModelForImageTextToText.from_pretrained(model_id)
np.random.seed(99)
img = np.random.randint(0, 255, size=(448, 448, 3), dtype=np.uint8)
inp = image_processor(images=img, return_tensors="pt")
out_pytorch = model.vision_tower(pixel_values=inp["pixel_values"])
out_pytorch_as_np = out_pytorch.last_hidden_state.detach().numpy()
as_mlx_weights = convert_weights_for_mlx(model.vision_tower.state_dict())
siglip_wrapper = SiglipWrapper()
siglip_wrapper.load_weights(list(as_mlx_weights.items()), strict=True)
output = siglip_wrapper(mx.array(img))
out_mlx_as_np = np.array(output, copy=False)
diff_average = np.mean(np.abs(out_pytorch_as_np - out_mlx_as_np))
assert diff_average < 2e-5, f"Average difference between the two embeddings is {diff_average}"
================================================
FILE: kyuteye_pt/.pylintrc
================================================
[MAIN]
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Clear in-memory caches upon conclusion of linting. Useful if running pylint
# in a server-like mode.
clear-cache-post-run=no
# Load and enable all available extensions. Use --list-extensions to see a list
# all available extensions.
#enable-all-extensions=
# In error mode, messages with a category besides ERROR or FATAL are
# suppressed, and no reports are done by default. Error mode is compatible with
# disabling specific errors.
#errors-only=
# Always return a 0 (non-error) status code, even if lint errors are found.
# This is primarily useful in continuous integration scripts.
#exit-zero=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=
# Specify a score threshold under which the program will exit with error.
fail-under=10
# Interpret the stdin as a python script, whose filename needs to be passed as
# the module_or_package argument.
#from-stdin=
# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
# Emacs file locks
ignore-patterns=^\.#
# List of module names for which member attributes should not be checked and
# will not be imported (useful for modules/projects where namespaces are
# manipulated during runtime and thus existing member attributes cannot be
# deduced by static analysis). It supports qualified module names, as well as
# Unix pattern matching.
ignored-modules=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.10
# Discover python modules and packages in the file system subtree.
recursive=no
# Add paths to the list of the source roots. Supports globbing patterns. The
# source root is an absolute path or a path relative to the current working
# directory used to determine a package namespace for modules located under the
# source root.
source-roots=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# In verbose mode, extra non-checker-related info will be displayed.
#verbose=
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Regular expression matching correct type alias names. If left empty, type
# alias names will be checked with the set naming style.
#typealias-rgx=
# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
#variable-rgx=
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp,
asyncSetUp,
__post_init__
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# List of regular expressions of class ancestor names to ignore when counting
# public methods (see R0903)
exclude-too-few-public-methods=
# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when caught.
overgeneral-exceptions=builtins.BaseException,builtins.Exception
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=1200
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
# Allow explicit reexports by alias from a package __init__.
allow-reexport-from-package=no
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=
# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=
# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=
# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
# UNDEFINED.
confidence=HIGH,
CONTROL_FLOW,
INFERENCE,
INFERENCE_FAILURE,
UNDEFINED
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
use-implicit-booleaness-not-comparison-to-string,
use-implicit-booleaness-not-comparison-to-zero,
too-many-locals,
unspecified-encoding,
too-many-arguments,
too-many-instance-attributes,
too-many-branches,
too-many-statements,
too-many-return-statements,
too-many-public-methods,
too-few-public-methods,
use-dict-literal,
unnecessary-lambda-assignment,
too-many-function-args
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=
[METHOD_ARGS]
# List of qualified names (i.e., library.method) which require a timeout
# parameter e.g. 'requests.api.get,requests.api.post'
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
# Regular expression of note tags to take in consideration.
notes-rgx=
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
# 'convention', and 'info' which contain the number of messages in each
# category, as well as 'statement' which is the total number of statements
# analyzed. This score is used by the global evaluation report (RP0004).
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
msg-template=
# Set the output format. Available formats are: text, parseable, colorized,
# json2 (improved json format), json (old json format) and msvs (visual
# studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[SIMILARITIES]
# Comments are removed from the similarity computation
ignore-comments=yes
# Docstrings are removed from the similarity computation
ignore-docstrings=yes
# Imports are removed from the similarity computation
ignore-imports=yes
# Signatures are removed from the similarity computation
ignore-signatures=yes
# Minimum lines number of a similarity.
min-similarity-lines=12
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. No available dictionaries : You need to install
# both the python package and the system dependency for enchant to work.
spelling-dict=
# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of symbolic message names to ignore for Mixin members.
ignored-checks-for-mixins=no-member,
not-async-context-manager,
not-context-manager,
attribute-defined-outside-init
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
# Regex pattern to define which classes are considered mixins.
mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
signature-mutators=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of names allowed to shadow builtins
allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
================================================
FILE: kyuteye_pt/LICENSE.md
================================================
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
================================================
FILE: kyuteye_pt/README.md
================================================
# MoshiVis - PyTorch
See the [top-level README.md][main_repo] for more information on MoshiVis.
This is the PyTorch implementation for MoshiVis.
## License
The present code is provided under the MIT license.
## Citation
If you use MoshiVis, please cite the MoshiVis paper and the Moshi technical report below.
```
@article{kyutai2025moshivis,
author = {Amélie Royer and Moritz Böhle and Gabriel de Marmiesse and
Laurent Mazaré and Alexandre Défossez and Neil Zeghidour and Patrick Pérez},
year = {2025},
title = {Vision-Speech Models: Teaching Speech Models to Converse about Images},
journal = {ArXiv},
url = {https://arxiv.org/abs/2503.15633}
}
@techreport{kyutai2024moshi,
title={Moshi: a speech-text foundation model for real-time dialogue},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and
Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
year={2024},
eprint={2410.00037},
archivePrefix={arXiv},
primaryClass={eess.AS},
url={https://arxiv.org/abs/2410.00037},
}
```
[main_repo]: https://github.com/kyutai-labs/moshivis
================================================
FILE: kyuteye_pt/configs/moshika-vis.yaml
================================================
add_boi_eoi: false
align_img_and_speech_tokens_dim: true
encoder_name: siglip_gemma2_448
hf_repo: 'kyutai/moshika-vis-pytorch-bf16'
image_size: 512
interpolation: bicubic
mimi_codec: tokenizer-e351c8d8-checkpoint125.safetensors
model: model.safetensors
norm_extra: null
norm_xa: rms_norm
num_crossattended_tokens: -1
num_extra_tokens: 0
text_tokenizer: tokenizer_spm_32k_3.model
xa_conditional_gating: true
xa_delay: 0
xa_dim: null
xa_start: start
xa_end: end
xa_gate_shared: false
xa_gating: sigmoid
xa_layers: []
xa_shared: true
xa_step: 1
================================================
FILE: kyuteye_pt/kyuteye/__init__.py
================================================
================================================
FILE: kyuteye_pt/kyuteye/config/__init__.py
================================================
================================================
FILE: kyuteye_pt/kyuteye/config/enums.py
================================================
"""Knowledge base of useful fixed values used across the codebase"""
from enum import Enum, unique
from typing import List, Optional, Tuple
@unique
class ImageEncoder(Enum):
"""Encapsulate every image encoder"""
CLIP_VIT = "clip_vit"
CLIP_VIT_LARGE = "clip_vit_large"
MOBILECLIP_S1 = "mobileclip_s1"
MOBILECLIP_S2 = "mobileclip_s2"
# original siglip
SIGLIP = "siglip"
# Siglip Pretrained: Phase 2 of PaliGemma 1 training
SIGLIP_GEMMA1_224 = "siglip_gemma1_224"
# Same but for PaliGemma 2
SIGLIP_GEMMA2_224 = "siglip_gemma2_224"
SIGLIP_GEMMA2_448 = "siglip_gemma2_448"
SIGLIP_GEMMA2_896 = "siglip_gemma2_896"
PIXTRAL = "pixtral"
@property
def out_dims(self) -> int:
"""Return the number of dimensions output by the given image encoder"""
if self in {
ImageEncoder.PIXTRAL,
ImageEncoder.CLIP_VIT_LARGE,
ImageEncoder.MOBILECLIP_S1,
}:
return 1024
if self == ImageEncoder.CLIP_VIT:
return 512
if self == ImageEncoder.MOBILECLIP_S2:
return 1280
if self in {
ImageEncoder.SIGLIP,
ImageEncoder.SIGLIP_GEMMA1_224,
ImageEncoder.SIGLIP_GEMMA2_224,
ImageEncoder.SIGLIP_GEMMA2_448,
ImageEncoder.SIGLIP_GEMMA2_896,
}:
return 1152
raise NotImplementedError("Unknown image encoder", self.name)
def to_rust(self) -> str:
"""Return the corresponding `ImageEncoder` enum name in the rust codebase"""
if self == ImageEncoder.PIXTRAL:
return "Pixtral"
if self in {
ImageEncoder.SIGLIP,
ImageEncoder.SIGLIP_GEMMA1_224,
ImageEncoder.SIGLIP_GEMMA2_224,
}:
return "Siglip224"
if self == ImageEncoder.SIGLIP_GEMMA2_448:
return "Siglip448"
if self == ImageEncoder.SIGLIP_GEMMA2_896:
return "Siglip896"
if self == ImageEncoder.MOBILECLIP_S1:
return "MobileclipS1"
if self == ImageEncoder.MOBILECLIP_S2:
return "MobileclipS2"
raise ValueError(
f"Image encoder {self.name} is not implemented in the Rust codebase"
)
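The enum doubles as a small knowledge base: `out_dims` reports the encoder's embedding width and `to_rust` maps to the matching enum name in the Rust codebase. For example:

```python
from kyuteye.config.enums import ImageEncoder

enc = ImageEncoder("siglip_gemma2_448")
assert enc.out_dims == 1152          # SigLIP hidden size
assert enc.to_rust() == "Siglip448"
```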
================================================
FILE: kyuteye_pt/kyuteye/config/kyuteye_config.py
================================================
"""Main configuration object used to configure the model and training pipeline"""
import os
from copy import deepcopy
from dataclasses import asdict, fields
from pathlib import Path
from typing import Any, Dict, Optional, Sequence
import torch
import yaml
from kyuteye.config.enums import ImageEncoder
from kyuteye.config.subconfigs import (
FusionConfig,
ImageEncoderConfig,
LMConfig,
MoshiConfig,
)
from kyuteye.utils.dist_utils import print_main
from kyuteye.utils.logging_utils import flatten_nested_dict
class KyuteyeConfig:
"""Base class encapsulating all options for training and evaluating a multimodal model"""
fuse: FusionConfig
image: ImageEncoderConfig
lm: LMConfig
def __init__(self, **kwargs: Any):
self._fields_to_sub: Dict[str, str] = {}
# Define all modular subconfigs defined in subconfigs.py
backup_kwargs = deepcopy(kwargs)
self._subnames = []
for name, constructor in [
("fuse", FusionConfig),
("image", ImageEncoderConfig),
("lm", LMConfig),
("moshi", MoshiConfig),
]:
keys = {f.name for f in fields(constructor)}
setattr(
self,
name,
constructor(
**{k: backup_kwargs.pop(k) for k in keys if k in backup_kwargs}
),
)
self._subnames.append(name)
        # Extra arguments coming from the CLI argparser may be passed but not used;
        # we still log them as a reminder
if len(backup_kwargs):
print_main(
"\n[bold yellow]WARN:[/bold yellow] Found superfluous arguments "
"passed to [cyan]KyuteyeConfig[/cyan]:\n"
+ "\n".join(f" - {k} = {v}" for k, v in backup_kwargs.items())
+ "\n",
rich=True,
flush=True,
)
# Map field names to the correct subconfig it belongs to
for name in self._subnames:
for f in fields(getattr(self, name)):
if f.name in self._fields_to_sub:
raise AssertionError(
f"Subconfig {name} uses field {f.name} which is already in use"
)
if f.name in self._subnames:
raise AssertionError(
f"Subconfig {name} uses field {f.name} which"
" is already defined in KyuteyeConfig"
)
self._fields_to_sub[f.name] = name
# Postinit
if self.fuse.num_crossattended_tokens == 0:
self.image.norm_xa = None
if self.fuse.num_extra_tokens == 0:
self.image.norm_extra = None
if self.fuse.xa_dim is None:
if not self.fuse.align_img_and_speech_tokens_dim:
self.fuse.xa_dim = ImageEncoder(self.image.encoder_name).out_dims
def __getattribute__(self, name: str) -> Any:
"""Getattr with direct shortcut to all subconfigs' subfields"""
if name not in ["_fields_to_sub", "_subnames"] and name in self._fields_to_sub:
return getattr(getattr(self, self._fields_to_sub[name]), name)
return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None:
"""Setattr with direct shortcut to all subconfigs' fields"""
if hasattr(self, "_fields_to_sub") and name in self._fields_to_sub:
setattr(getattr(self, self._fields_to_sub[name]), name, value)
else:
super().__setattr__(name, value)
@property
def moshi_constructor_kwargs(self) -> Dict[str, Any]:
"""Return constructor for constructing a MoshiVis model"""
return dict(
**asdict(self.moshi),
xa_dim=self.fuse.xa_dim,
**self.fuse.crossattention_kwargs,
)
@classmethod
def from_yml(cls, path: Path | str) -> "KyuteyeConfig":
"""Initialize current config from a yaml file"""
return cls(**__load_yaml__(str(path)))
def to_yml(self, path: Optional[Path | str] = None) -> None:
"""Save current config to a yaml file"""
if path is None:
path = self.output_dir
path = str(path)
__save_yaml__(
{
k: (
v.name
if hasattr(v, "name")
else (
str(v).replace("torch.", "")
if isinstance(v, torch.dtype)
else (
tuple(x.name for x in v)
if k in {"train_dataset", "eval_dataset", "blind_dataset"}
else v
)
)
)
for k, v in self.to_dict().items()
},
path,
)
def print(self, flat: bool = False, only: Optional[Sequence[str]] = None) -> None:
"""Pretty print current config"""
print_main("-" * 100, "[bold green]Config[/bold green]", rich=True)
if flat:
print_main(
"\n".join(
f"\t{k} = {v}"
for k, v in self.to_dict(flat=True).items()
if (only is None or k in only)
),
rich=True,
)
else:
for name, subd in self.to_dict(flat=False).items():
if only is not None and name not in only:
continue
print_main("\n\t", "-" * 50, f"[cyan]{name}[/cyan]", rich=True)
print_main(
"\n".join(f"\t\t{k} = {v}" for k, v in subd.items()),
rich=True,
)
print_main("-" * 100)
def to_dict(self, flat: bool = True) -> Dict[str, Any]:
"""Returns current config as a (flat) dictionary"""
d = {
subconfig: asdict(getattr(self, subconfig)) for subconfig in self._subnames
}
if flat:
return flatten_nested_dict(d)
return d
def __load_yaml__(path: Path | str) -> Dict:
"""Load a dictionary from a YAML file
:param path: Path to load the config from
:return: the dictionary of kwargs that will be fed to `KyuteyeConfig`
"""
path = str(path)
assert os.path.exists(path), f"Could not load config {path}: File does not exist"
with open(path, "r") as stream:
config = yaml.safe_load(stream)
# KyuteyeConfig works with flattened dict as inputs
config = flatten_nested_dict(config)
    # YAML parses sequences as lists -> convert them to tuples
config = {k: tuple(v) if isinstance(v, list) else v for k, v in config.items()}
return config
def __save_yaml__(config: Dict, path: Path | str) -> None:
"""Save a config dictionary to a YAML
:param config: Kyuteye config converted to a dict
:param path: Path to save the config to
"""
path = str(path)
path = path + (".yml" if not path.endswith(".yml") else "")
base_dir = os.path.abspath(os.path.dirname(path))
os.makedirs(base_dir, exist_ok=True)
with open(path, "w") as stream:
        yaml.safe_dump(config, stream)
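Because `__getattribute__` and `__setattr__` consult `_fields_to_sub`, every subconfig field is reachable directly on the top-level config object. A minimal sketch:

```python
from kyuteye.config.kyuteye_config import KyuteyeConfig

cfg = KyuteyeConfig(image_size=448, xa_gating="sigmoid")
assert cfg.image_size == 448      # proxied to cfg.image.image_size
cfg.xa_gating = "tanh"            # proxied to cfg.fuse.xa_gating
assert cfg.fuse.xa_gating == "tanh"
```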
================================================
FILE: kyuteye_pt/kyuteye/config/subconfigs.py
================================================
"""Modular configs for configuring a Kyuteye model training run"""
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Tuple
from kyuteye.config.enums import ImageEncoder
from kyuteye.utils.dist_utils import print_main
def __is_nonstring_iterable__(arg: Any) -> bool:
return isinstance(arg, Iterable) and not isinstance(arg, str)
@dataclass(frozen=True)
class LMConfig:
"""Configure the model weights"""
# if repo is None, then model, mimi_codec and tokenizer are
# expected to point to local files
hf_repo: Optional[str] = "kyutai/moshika-vis-pytorch-bf16"
# MoshiVis model; Note that if the model doesn't contain weights
# for the frozen image encoder, those will be loaded from audiocraft
# directly
model: str = "model.safetensors"
# Mimi codec
mimi_codec: str = "tokenizer-e351c8d8-checkpoint125.safetensors"
# Tokenizer for Helium
text_tokenizer: str = "tokenizer_spm_32k_3.model"
@staticmethod
def help(field_name: str) -> str:
"""Optional; returns the argparse's help message for each field name
in this subconfig"""
if field_name == "model_path":
return (
"Path to .safetensors containing model weights (vision encoder + LLM)"
)
if field_name == "mimi_codec":
return "Path to .safetensors containing Mimi weights"
if field_name == "text_tokenizer":
return "Path to .safetensors containing text tokenizer"
return ""
@dataclass(frozen=False)
class ImageEncoderConfig:
"""Configure the image encoder for MoshiVis"""
# Main image backbone to load
encoder_name: str = "pixtral"
interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = (
"bicubic"
)
    # Input image size to the encoder
image_size: int = 256
# normalization used after projecting the extra tokens
norm_extra: Optional[Literal["layer_norm", "rms_norm"]] = "rms_norm"
# normalization used after projecting the xa tokens
norm_xa: Optional[Literal["layer_norm", "rms_norm"]] = "rms_norm"
def __post_init__(self) -> None:
assert self.image_size > 0
self.encoder_name = self.encoder_name.lower()
try:
ImageEncoder(self.encoder_name)
except ValueError as e:
raise ValueError(f"Unknown image encoder {self.encoder_name}") from e
self.aug_strategy = "Pixtral" if self.encoder_name == "pixtral" else "None"
@staticmethod
def help(field_name: str) -> str:
"""Optional; returns the argparse's help message for each field name
in this subconfig"""
if field_name == "encoder_name":
return "Name of the pretrained image encoder to use"
if field_name == "interpolation":
return "Interpolation algorithm for image resizing"
if field_name == "image_size":
return "Input image size to the encoder"
return ""
@dataclass(frozen=True)
class MoshiConfig:
"""Configure the backbone Moshi"""
dim: int = 4096
text_card: int = 32000
padding_token_id: int = 3
n_q: int = 16
dep_q: int = 8
audio_card: int = 2048
num_heads: int = 32
num_layers: int = 32
hidden_scale: float = 4.125
causal: bool = True
context: int = 3000
max_period: int = 10000
gating: bool = True
activation: str = "silu"
norm: str = "rms_norm_f32"
positional_embedding: str = "rope"
depformer: bool = True
depformer_dim: int = 1024
depformer_dim_feedforward: int = int(4.125 * 1024)
depformer_num_heads: int = 16
depformer_num_layers: int = 6
depformer_multi_linear: bool = True
depformer_context: int = 8
depformer_gating: bool = True
depformer_activation: str = "silu"
depformer_pos_emb: str = "none"
depformer_weights_per_step: bool = True
delays: Tuple[int, ...] = (0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1)
@staticmethod
def help(field_name: str) -> str:
"""Optional; returns the argparse's help message for each field name
in this subconfig"""
if field_name == "model_path":
return (
"Path to .safetensors containing model weights (vision encoder + LLM)"
)
if field_name == "mimi_codec":
return "Path to .safetensors containing Mimi weights"
if field_name == "text_tokenizer":
return "Path to .safetensors containing text tokenizer"
return ""
@dataclass(frozen=False)
class FusionConfig:
"""Configures how we integrate the image information in MoshiVis"""
num_extra_tokens: int = -1
add_boi_eoi: bool = False
token_insertion: Optional[Literal["prefix"]] = "prefix"
num_crossattended_tokens: int = 0
align_img_and_speech_tokens_dim: bool = True
xa_dim: Optional[int] = None
xa_start: Optional[Literal["start", "boi", "eoi"]] = None
xa_end: Optional[Literal["end", "eoi"] | int] = None
xa_step: int = 1
xa_delay: int = 0
xa_shared: bool = True
xa_gate_shared: bool = False
xa_gating: Literal["tanh", "sigmoid", "none"] = "tanh"
xa_conditional_gating: bool = False
xa_layers: Tuple[int, ...] = ()
xa_extended_layer_dims: Optional[int] = None
xa_extended_layer_embed_dims: int = 1024
xa_extended_layer_p_norm: int = 2
@staticmethod
def help(field_name: str) -> str:
"""Optional; returns the argparse's help message for each field name
in this subconfig"""
if field_name == "num_extra_tokens":
return (
"Number of extra tokens (containing image information) to insert into"
"the text+audio tokens. If -1, uses all the image tokens."
)
if field_name == "token_insertion":
return (
"An option used later in `fuse_utils` to determine where"
" to insert the extra tokens"
)
if field_name == "num_crossattended_tokens":
return (
"Number of extra tokens (containing image information) to use as keys/values"
" source in the cross-attention mechanism. If -1, uses all the image tokens"
)
if field_name == "xa_shared":
return (
"If True, the linear layers of the cross-attention mechanism"
"are shared across layers"
)
if field_name == "xa_gate_shared":
return (
"If True, the gating modules of the cross-attention mechanism"
"are shared across layers"
)
if field_name == "xa_gating":
return "Type of multiplicative gating at the output of each cross-attention"
if field_name == "xa_conditional_gating":
return "Whether to make the gating input dependent"
if field_name == "xa_layers":
return (
"If given, specifies the subset of layers to apply cross-attention to"
)
if field_name == "xa_start":
return "At which token to start applying the cross-attention"
if field_name == "xa_end":
return "Until which token to apply the cross-attention"
if field_name == "xa_step":
return "Applies cross-attention every `xa_step` token between xa_start and xa_end"
if field_name == "xa_delay":
return "Applies cross-attention to token t with the embedding from token t - xa_delay"
if field_name == "xa_extended_layer_dims":
return "By how many dimensions to extend the layers in xa_layers."
return ""
def __post_init__(self) -> None:
"""Check that the subconfig is valid"""
if self.num_extra_tokens == 0 and self.num_crossattended_tokens == 0:
print_main("[bold yellow]WARN:[/bold yellow]: No fusion mechanisms active")
if not self.align_img_and_speech_tokens_dim and self.xa_dim is not None:
            print_main(
                "[bold yellow]WARN:[/bold yellow] xa_dim is set, hence it "
                "takes precedence over `align_img_and_speech_tokens_dim`",
                rich=True,
            )
# Convert
if isinstance(self.xa_layers, list):
self.xa_layers = tuple(self.xa_layers)
if self.xa_end is not None and self.xa_end not in {"end", "eoi"}:
self.xa_end = int(self.xa_end)
if self.num_extra_tokens == 0:
self.add_boi_eoi = False
if self.num_crossattended_tokens == 0:
self.xa_start = None
self.xa_end = None
else:
assert (
self.xa_start is not None
), "xa_start should be given to use cross-attention"
assert (
self.xa_end is not None
), "xa_end should be given to use cross-attention"
if self.num_extra_tokens == 0:
assert (
self.xa_start
not in {
"boi",
"eoi",
}
and self.xa_end != "eoi"
), "BoI and EoI tokens are not inserted if `num_extra_tokens` is 0"
elif self.xa_start != "start":
print_main(
"[yellow]WARN:[/yellow]Making cross_attention start before BoI is weird",
rich=True,
)
if self.xa_step > 1:
raise NotImplementedError("xa_step > 1 is not implemented yet")
if self.xa_extended_layer_dims is not None:
assert (
self.xa_extended_layer_dims > 0
), "xa_extended_layer_dims should be positive"
return
@property
def crossattention_kwargs(self) -> Dict[str, Any]:
"""Return crossattention related kwargs used for model initialization.
These are passed to modules/cross_attention.py:GatedCrossAttention
down the line"""
return {
# passed to Moshi/Helium constructor
"cross_attention": self.num_crossattended_tokens != 0,
"xa_layers": self.xa_layers,
            # Further passed to the GatedCrossAttention constructor
"xa_gating": self.xa_gating,
"xa_conditional_gating": self.xa_conditional_gating,
"xa_shared": self.xa_shared,
"xa_gate_shared": self.xa_gate_shared,
"xa_start": self.xa_start,
"xa_end": self.xa_end,
"xa_step": self.xa_step,
"xa_delay": self.xa_delay,
}
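
# --- Illustrative sketch (editor's example, not part of the library) ---
# FusionConfig.__post_init__ normalizes and validates the fusion options;
# for instance, enabling cross-attention only (no inserted tokens) requires
# both xa_start and xa_end:
#
#     fuse = FusionConfig(
#         num_extra_tokens=0,            # no image tokens inserted in the stream
#         num_crossattended_tokens=-1,   # cross-attend to all image tokens
#         xa_start="start",
#         xa_end="end",
#     )
#     fuse.crossattention_kwargs["cross_attention"]   # -> True
#
# Leaving xa_start/xa_end unset with num_crossattended_tokens != 0 raises an
# AssertionError instead.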
================================================
FILE: kyuteye_pt/kyuteye/models/__init__.py
================================================
================================================
FILE: kyuteye_pt/kyuteye/models/docker-bake.hcl
================================================
================================================
FILE: kyuteye_pt/kyuteye/models/helium.py
================================================
# pylint: disable=redefined-outer-name, pointless-string-statement
"""Port of Helium from Jax to Pytorch and then HF.
The architecture is also made to be easily converted to the mimi/audiocraft codebase"""
from typing import Any, Literal, Optional, Tuple
import torch
from kyuteye.modules.transformer import Transformer
from kyuteye.modules.utils import ClampedEmbedding, NormalizationLayer
class Helium(torch.nn.Module):
"""Jax -> Pytorch port of Helium (text LLM)
:param dim: Inner dimension of the tokens
    :param num_heads: Number of heads
    :param num_layers: Number of layers
    :param hidden_scale: Scale for the inner dimension of the FFN layers (wrt. `dim`)
:param context: Context size for the model. This is only used to determine the automatic
causal mask and also set the KV cache accordingly at inference
:param cross_attention: Whether to add cross-attention layers
:param card: Cardinality of the vocabulary
:param output_card: (Optional) can be used to specify a different number of output tokens
than card (which is used for the embedding layers). This is useful for audio
models that add an extra *initial token* which is never predicted
:param norm: Type of normalization layers to use
:param positional_embedding: Type of positional embedding to use
:param max_period: Maximum period / theta for RoPE embeddings
:param causal: Whether to automatically force a causal mask in MHSA
    :param gating: If True, will use a gated FFN (SwiGLU-like)
:param activation: Activation to use for the FFN gating, if `gating` is True
:param padding_token_id: Padding token ID of the tokenizer
:param freeze_padding_embedding: If True, the embedding of the padding token ID
will not receive gradients/updates during training
:param device: Device to load the model on
:param dtype: Dtype to define the model parameters as
"""
def __init__(
self,
# Architecture
dim: int,
num_heads: int,
num_layers: int,
hidden_scale: float = 4.125,
context: int = 2048,
cross_attention: bool = False,
card: int = 32000,
output_card: Optional[int] = None,
norm: Literal[
"rms_norm",
"rms_norm_f32",
"real_rms_norm",
"real_rms_norm_f32",
] = "real_rms_norm_f32",
positional_embedding: Literal["none", "sin", "rope", "sin_rope"] = "rope",
max_period: float = 10000,
causal: bool = True,
gating: bool = True,
activation: str = "silu",
# Padding token
padding_token_id: int = 3,
freeze_padding_embedding: bool = False,
zero_token_id: int = -1,
# Other kwargs
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
**kwargs: Any,
):
super().__init__()
# Initial tokens projection
self.dim = dim
self.text_emb = ClampedEmbedding(
card,
dim,
padding_idx=padding_token_id if freeze_padding_embedding else None,
zero_idx=zero_token_id,
dtype=dtype,
device=device,
)
# Output RMS Norm and linear layer
self.out_norm = getattr(NormalizationLayer, norm.upper()).create_norm_fn(
dim, device=device, dtype=dtype
)
self.text_linear = torch.nn.Linear(
dim, card if output_card is None else output_card, bias=False
)
self.cross_attention = cross_attention
# Main transformer
self.transformer = Transformer(
# Architecture
d_model=dim,
num_heads=num_heads,
num_layers=num_layers,
dim_feedforward=int(hidden_scale * dim),
causal=causal,
context=context,
cross_attention=cross_attention,
positional_embedding=positional_embedding,
max_period=max_period,
# Transformer Layer kwargs
gating=gating,
activation=activation,
norm=norm,
# Others
device=device,
dtype=dtype,
**kwargs,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_src: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
return_features: bool = False,
) -> Tuple[torch.Tensor, float]:
"""Forward function
:param input_ids: Input tokens IDs with size `(batch_size, seq_length)`
:param inputs_embeds: If given, this will skip input_ids and the initial embedding phase.
Instead, `inputs_embeds` are directly fed to the transformer
:param attention_mask: 0/1 Attention mask with shape `(batch_size, seq_length)`.
Indicates which tokens to mask in the attention, e.g. padding tokens
:param cross_attention_src: Cross-attention source with shape
`(batch_size, seq_length, dim)`
:param cross_attention_mask: Cross-attention mask with shape `(batch_size, seq_length)`
:param return_features: If True, will skip the last classification layer and only
output features with dimensions `(batch_size, seq_length, dim)`
"""
if inputs_embeds is None:
assert input_ids is not None
inputs_embeds = self.text_emb(input_ids)
x, gate_weight = self.transformer(
inputs_embeds,
attention_mask=attention_mask,
cross_attention_src=cross_attention_src,
cross_attention_mask=cross_attention_mask,
)
x = self.out_norm(x)
if return_features:
return x, gate_weight
return self.text_linear(x), gate_weight
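
# --- Illustrative sketch (editor's example, not part of the library) ---
# A minimal forward pass under assumed (small) hyper-parameters; the model
# returns logits over the vocabulary together with the cross-attention gate
# statistic reported by the transformer:
#
#     import torch
#     model = Helium(dim=64, num_heads=4, num_layers=2, card=128)
#     tokens = torch.randint(0, 128, (1, 10))    # (batch, seq)
#     logits, gate = model(input_ids=tokens)
#     assert logits.shape == (1, 10, 128)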
================================================
FILE: kyuteye_pt/kyuteye/models/hf_model_configs.py
================================================
# pylint: disable=protected-access
"""Configuration for HF-compliant models"""
from typing import Any, Literal, Optional, Sequence
from transformers import PretrainedConfig
class HeliumConfig(PretrainedConfig):
"""Config class for Helium language models (LLM part)"""
model_type = "Helium_v2"
def __init__(
self,
dim: int = 1024,
num_heads: int = 12,
num_layers: int = 24,
hidden_scale: int = 4,
context: int = 2048,
cross_attention: bool = False,
card: int = 32000,
output_card: Optional[int] = None,
norm: Literal[
"rms_norm",
"rms_norm_f32",
"real_rms_norm",
"real_rms_norm_f32",
] = "real_rms_norm_f32",
positional_embedding: Literal["none", "sin", "rope", "sin_rope"] = "rope",
freeze_padding_embedding: bool = False,
max_period: float = 10000,
causal: bool = True,
gating: bool = True,
activation: Literal[
"none",
"identity",
"sigmoid",
"tanh",
"relu",
"leaky_relu",
"elu",
"gelu",
"silu",
"mish",
"softsign",
] = "silu",
bos_token_id: int = 1,
eos_token_id: int = 2,
pad_token_id: int = 3,
**kwargs: Any,
):
super().__init__(
vocab_size=32000,
hidden_size=dim,
num_attention_heads=num_heads,
num_hidden_layers=num_layers,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.hidden_scale = hidden_scale
self.context = context
self.cross_attention = cross_attention
self.positional_embedding = positional_embedding
self.card = card
self.output_card = output_card
self.norm = norm
self.max_period = max_period
self.causal = causal
self.gating = gating
self.activation = activation
self.freeze_padding_embedding = freeze_padding_embedding
class MoshiVisConfig(HeliumConfig):
"""Config for Moshi-Vis"""
model_type = "Moshi_v1"
def __init__(
self,
text_card: int = 32000,
text_context: int = 3000,
n_q: int = 8,
n_q_per_source: Optional[int] = None,
audio_card: int = 1024,
depformer: bool = False,
depformer_multi_linear: bool = False,
depformer_pos_emb: Optional[Literal["none", "sin", "rope", "sin_rope"]] = None,
depformer_dim: Optional[int] = None,
depformer_dim_feedforward: Optional[int] = None,
depformer_num_layers: Optional[int] = None,
depformer_num_heads: Optional[int] = None,
depformer_weights_per_step: bool = False,
depformer_input_cumsum: bool = False,
depformer_skip_self_attn: bool = False,
delays: Optional[Sequence[int]] = None,
same_initial: bool = False,
text_loss_weight: float = 1.0,
audio_loss_weight: float = 1.0,
audio_other_channel_loss_weight: float = 1.0,
audio_semantic_loss_weight: float = 100.0,
audio_acoustic_loss_weight: float = 1.0,
padding_loss_weight: float = 1.0,
textonly_padding_loss_weight: float = 1.0,
audio_padding_loss_weight: float = 1.0,
sparsity_loss_weight: float = 0,
mask_audio_codebooks_from: int = -1,
add_pad_embed_text_only_other_tokens: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
**kwargs,
)
# rename some args from HeliumConfig to explicitly
# distinguish audio from text
del self.card
self.text_card = text_card
del self.context
self.text_context = text_context
# add Audio config
self.n_q = n_q
self.n_q_per_source = n_q_per_source or n_q
self.audio_card = audio_card
self.depformer = depformer
self.depformer_multi_linear = depformer_multi_linear
self.depformer_pos_emb = depformer_pos_emb
self.depformer_dim = depformer_dim
self.depformer_dim_feedforward = depformer_dim_feedforward
self.depformer_num_layers = depformer_num_layers
self.depformer_num_heads = depformer_num_heads
self.depformer_weights_per_step = depformer_weights_per_step
self.depformer_input_cumsum = depformer_input_cumsum
self.depformer_skip_self_attn = depformer_skip_self_attn
self.delays = list(delays) if delays is not None else None
self.same_initial = same_initial
# loss weights for the different codebbooks
self.text_loss_weight = text_loss_weight
self.audio_loss_weight = audio_loss_weight
self._audio_other_channel_loss_weight = audio_other_channel_loss_weight
self._audio_semantic_loss_weight = audio_semantic_loss_weight
self._audio_acoustic_loss_weight = audio_acoustic_loss_weight
self.padding_loss_weight = padding_loss_weight
self.textonly_padding_loss_weight = textonly_padding_loss_weight
self.audio_padding_loss_weight = audio_padding_loss_weight
self._sparsity_loss_weight = sparsity_loss_weight
self.mask_audio_codebooks_from = mask_audio_codebooks_from
self.add_pad_embed_text_only_other_tokens = add_pad_embed_text_only_other_tokens
if self.mask_audio_codebooks_from >= 0:
self.num_audio_tokens_in_loss_main = self.mask_audio_codebooks_from
self.num_audio_tokens_in_loss_other = self.mask_audio_codebooks_from
else:
self.num_audio_tokens_in_loss_main = self.n_q_per_source - 1
self.num_audio_tokens_in_loss_other = (self.n_q - self.n_q_per_source) - 1
@property
def total_audio_loss_weight(self) -> float:
"""Total weight used to normalize the losses on audio tokens"""
return (
# weight for Moshi
self._audio_semantic_loss_weight
+ self._audio_acoustic_loss_weight * self.num_audio_tokens_in_loss_main
# weight for Other
+ self._audio_other_channel_loss_weight
* (
self._audio_semantic_loss_weight
+ self._audio_acoustic_loss_weight * self.num_audio_tokens_in_loss_other
)
)
@property
def audio_semantic_loss_weight(self) -> float:
"""Loss weight set on the semantic token"""
return (
self.audio_loss_weight
* self._audio_semantic_loss_weight
/ (self.total_audio_loss_weight + 1e-6)
)
@property
def audio_acoustic_loss_weight(self) -> float:
"""Loss weight set on the semantic token"""
return (
self.audio_loss_weight
* self._audio_acoustic_loss_weight
/ (self.total_audio_loss_weight + 1e-6)
)
@property
def audio_other_semantic_loss_weight(self) -> float:
"""Loss weight set on audio codebooks for OTHER channel"""
return self._audio_other_channel_loss_weight * self.audio_semantic_loss_weight
@property
def audio_other_acoustic_loss_weight(self) -> float:
"""Loss weight set on audio codebooks for OTHER channel"""
return self._audio_other_channel_loss_weight * self.audio_acoustic_loss_weight
@property
def sparsity_loss_weight(self) -> float:
"""Loss weight set on the sparsity loss for the extended transformer."""
return self._sparsity_loss_weight
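
# --- Illustrative sketch (editor's example, not part of the library) ---
# How the per-codebook loss weights are normalized, assuming a two-source
# setup with n_q=16 and n_q_per_source=8 (i.e. 7 acoustic codebooks per
# source next to the semantic one) and the default 100/1 semantic/acoustic
# weights:
#
#     cfg = MoshiVisConfig(n_q=16, n_q_per_source=8)
#     # total = (100 + 1 * 7) + 1.0 * (100 + 1 * 7) = 214
#     cfg.total_audio_loss_weight        # -> 214.0
#     cfg.audio_semantic_loss_weight     # -> 100 / 214 ~= 0.467
#     cfg.audio_acoustic_loss_weight     # -> 1 / 214 ~= 0.0047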
================================================
FILE: kyuteye_pt/kyuteye/models/image_projection.py
================================================
"""Image encoders (CLIP, SigLIP)"""
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
from einops import rearrange
from kyuteye.config.enums import ImageEncoder
from kyuteye.modules.image_encoder import (
PixtralOutput,
get_img_normalize,
load_image_encoder,
)
from kyuteye.modules.utils import NormalizationLayer
if TYPE_CHECKING:
from kyuteye.config.kyuteye_config import KyuteyeConfig
class ImageProjection(torch.nn.Module):
"""
Takes in a batch of images and returns a batch of embeddings of the
same dimensions, which are then fed to the LM, either by inserting
the tokens in the stream, or by using them as source for cross-attention
    modules (or both).
:param config: KyuteyeConfig object
:param lm_model_dim: Output dimension (number of channels) for this module
"""
def __init__(
self,
kyuteye_config: "KyuteyeConfig",
lm_model_dim: Optional[int],
load_pretrained_encoder: bool = True,
):
super().__init__()
self.kyuteye_config = kyuteye_config
try:
self.encoder_type = getattr(
ImageEncoder, self.kyuteye_config.image.encoder_name.upper()
)
except AttributeError as e:
            raise NotImplementedError(
                f"Unknown image encoder {self.kyuteye_config.image.encoder_name}"
            ) from e
# Number of output dimensions of the entire module (i.e. including
# potential projection after the encoder)
if self.kyuteye_config.xa_dim is not None:
self.out_dim = self.kyuteye_config.xa_dim
else:
assert lm_model_dim is not None
self.out_dim = lm_model_dim
# Load the image encoder
self.enc = load_image_encoder(
self.encoder_type, pretrained=load_pretrained_encoder
)
# Projection layer; there are two possible projection targets
# A. for the extra tokens
self.proj_extra = self.init_proj_module(
self.kyuteye_config.fuse.num_extra_tokens
)
# B. for the cross attention
self.proj_xa = self.init_proj_module(
self.kyuteye_config.fuse.num_crossattended_tokens
)
# Output normalizations
self.norm_extra = self.init_norm_module(self.kyuteye_config.image.norm_extra)
self.norm_xa = self.init_norm_module(self.kyuteye_config.image.norm_xa)
@classmethod
def from_config(
cls,
kyuteye_config: "KyuteyeConfig",
lm_model_dim: Optional[int] = None,
moshi_weight: Optional[Dict[str, Any]] = None,
device: str | torch.device = "cpu",
) -> "ImageProjection":
"""Init image projection from config"""
load_pretrained_encoder = moshi_weight is None or not any(
x.startswith("enc.") for x in moshi_weight
)
image_projection = cls(
kyuteye_config,
lm_model_dim,
load_pretrained_encoder=load_pretrained_encoder,
)
if moshi_weight is not None:
missing_keys, _ = image_projection.load_state_dict(
moshi_weight, strict=False
)
encoder_keys: List[str] = []
proj_keys: List[str] = []
for key in missing_keys:
(encoder_keys if key.startswith("enc.") else proj_keys).append(key)
print(proj_keys)
assert len(proj_keys) == 0, "Failed to load image to speech projections"
print(encoder_keys)
assert len(encoder_keys) == 0, "Failed to load frozen image encoder"
return image_projection.to(device)
def init_proj_module(self, num_tokens: int) -> Optional[torch.nn.Module]:
"""Init the project module for the inserted and/or cross-attended iamge tokens"""
if num_tokens == 0:
return None
if num_tokens == -1:
if self.encoder_out_dim != self.out_dim:
return torch.nn.Linear(self.encoder_out_dim, self.out_dim)
return torch.nn.Identity()
raise ValueError(f"Found negative number of tokens for projection {num_tokens}")
@property
def encoder_out_dim(self) -> int:
"""Number of dimension output by the encoder"""
return self.encoder_type.out_dims
@property
def to_tensor_and_normalize(self) -> Callable:
"""Image normalization function"""
return get_img_normalize(self.encoder_type)()
def init_norm_module(self, norm_type: Optional[str]) -> Optional[torch.nn.Module]:
"""Init normalization module"""
if norm_type is None:
return None
return getattr(NormalizationLayer, norm_type.upper()).create_norm_fn(
self.out_dim
)
def forward(self, x: torch.Tensor | List[torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Image embedding mapping"""
# Apply image encoder
encoded, mask = self.encode(x)
# Apply different projection for extra vs cross attended tokens
out = {}
# The mask will be handled by the QP mapper and this module will output the same
# number of tokens for every sample in the batch, i.e., no padding is needed anymore.
# => We will not forward the mask in this case.
if mask is not None:
out["cross_attention_mask"] = mask
if self.proj_extra is not None:
assert mask is None, "proj_extra is not implemented yet for pixtral."
out["image_embeds"] = self.project_extra(encoded)
if self.proj_xa is not None:
out["cross_attention_src"] = self.project_xa(encoded)
return out
def encode(
self, x: torch.Tensor | List[torch.Tensor]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Pass through the image encoder backbone and reshape to (batch, seq, dim) Tensors"""
logits = self.enc(x)
if self.encoder_type == ImageEncoder.PIXTRAL:
assert isinstance(logits, PixtralOutput)
return logits.out, logits.mask
if logits.ndim == 4:
logits = rearrange(logits, "b d h w -> b (h w) d")
if logits.ndim != 3:
raise ValueError(
"The image encoder should output a tensor of shape"
" (B, Seq, D) (ViT) or (B, D, H, W) (CNN)"
)
return logits, None
def project_extra(self, logits: torch.Tensor) -> torch.Tensor:
"""Projection 1: Used for inserted extra tokens"""
assert self.proj_extra is not None
logits = self.proj_extra(logits)
if self.norm_extra is not None:
return self.norm_extra(logits)
return logits
def project_xa(self, logits: torch.Tensor) -> torch.Tensor:
"""Projection 2: Used for cross-attended tokens"""
assert self.proj_xa is not None
logits = self.proj_xa(logits)
if self.norm_xa is not None:
return self.norm_xa(logits)
return logits
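
# --- Illustrative sketch (editor's example, not part of the library) ---
# forward() returns a dict whose entries depend on the configured fusion
# paths; only projections enabled through num_extra_tokens /
# num_crossattended_tokens are present:
#
#     out = image_projection(images)
#     out.get("image_embeds")          # extra tokens inserted into the stream
#     out.get("cross_attention_src")   # keys/values source for cross-attention
#     out.get("cross_attention_mask")  # padding mask (variable-size Pixtral inputs)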
================================================
FILE: kyuteye_pt/kyuteye/models/loaders.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Load moshi-vis neccessary components."""
from typing import Any, Dict, Optional, Tuple
import torch
from kyuteye.config.kyuteye_config import KyuteyeConfig
from kyuteye.models.image_projection import ImageProjection
from kyuteye.models.moshivis import MoshiVisGen
def get_moshi_vis(
kyuteye_config: KyuteyeConfig,
moshi_weight: Optional[str] = None,
device: str | torch.device = "cpu",
dtype: torch.dtype = torch.bfloat16,
gen_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[MoshiVisGen, ImageProjection]:
"""Return main Moshi model"""
image_proj_state: Dict[str, torch.Tensor] = {}
model_state: Dict[str, torch.Tensor] = {}
if moshi_weight is not None:
from safetensors.torch import load_file
for key, v in load_file(moshi_weight, device=device).items(): # type: ignore
if key.startswith("image_prefix."):
image_proj_state[key[13:]] = v
else:
model_state[key] = v
moshi_vis = MoshiVisGen.from_config(
kyuteye_config, model_state, device, dtype, **(gen_kwargs or {})
)
image_embedder = ImageProjection.from_config(
kyuteye_config, moshi_vis.model_dim, image_proj_state, device
)
return moshi_vis.to(dtype), image_embedder.to(dtype)
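
# --- Illustrative sketch (editor's example, not part of the library) ---
# Typical loading flow, assuming a KyuteyeConfig built from a YAML such as
# configs/moshika-vis.yaml and a local checkpoint whose image-projection
# weights are stored under the "image_prefix." prefix:
#
#     config = ...  # a KyuteyeConfig instance
#     moshi_vis, image_embedder = get_moshi_vis(
#         config, moshi_weight="model.safetensors", device="cuda"
#     )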
================================================
FILE: kyuteye_pt/kyuteye/models/moshivis.py
================================================
"""Moshi the little AI"""
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Tuple
import torch
from kyuteye.config.kyuteye_config import KyuteyeConfig
from kyuteye.models.helium import Helium
from kyuteye.modules.streaming_utils import StreamingModule
from kyuteye.modules.transformer import Transformer
from kyuteye.modules.utils import ClampedEmbedding
from moshi.utils.sampling import sample_token
class MoshiVis(StreamingModule):
"""Moshi model derived from Audiocraft with extra stuff for vision conditioninign"""
# Class attributes; extra special tokens
end_of_text_padding_id = 0
zero_token_id = -1
ungenerated_token_id = -2
def __init__(
self,
hidden_scale: float = 4.125,
norm: str = "real_rms_norm_f32",
gating: bool = True,
activation: str = "silu",
n_q: int = 8,
dep_q: Optional[int] = None,
audio_card: int = 1024,
audio_context: Optional[int] = None,
depformer: bool = False,
depformer_multi_linear: bool = False,
depformer_pos_emb: Optional[Literal["none", "sin", "rope", "sin_rope"]] = None,
depformer_dim: Optional[int] = None,
depformer_dim_feedforward: Optional[int] = None,
depformer_num_layers: Optional[int] = None,
depformer_num_heads: Optional[int] = None,
depformer_weights_per_step: bool = False,
depformer_context: Optional[int] = 8,
depformer_gating: Optional[bool] = None,
depformer_activation: Optional[str] = None,
delays: Optional[List[int]] = None,
text_card: int = 32000,
text_context: Optional[int] = None,
padding_token_id: int = 3,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
**kwargs: Any,
) -> None:
"""Initialize a MoshiVis model"""
super().__init__()
# Set parameter for generation/preprocessing
self.text_card = text_card
self.audio_card = audio_card
self.text_context = text_context
self.text_padding_token_id = padding_token_id
self.audio_context = audio_context
self.n_q = n_q
self.dep_q = dep_q or self.n_q
assert delays is not None and len(delays) > 0, "Delays must be non empty"
assert len(delays) <= self.num_codebooks, "Too many delays"
if len(delays) < self.num_codebooks:
delays = delays + [delays[-1]] * (self.num_codebooks - len(delays))
self.delays = delays
embeddings_factory = partial(
ClampedEmbedding, device=device, dtype=dtype, zero_idx=self.zero_token_id
)
# LLM backbone (includes text embedding + text linear projection)
self.llm = Helium(
hidden_scale=hidden_scale,
card=text_card
+ 1, # Add an initial token in the embedding but not in the text linear
output_card=text_card + int(padding_token_id is None),
padding_token_id=padding_token_id,
device=device,
dtype=dtype,
zero_token_id=self.zero_token_id,
**kwargs,
)
# Audio input embeddings
self.audio_emb = torch.nn.ModuleList(
[
embeddings_factory(audio_card + 1, self.llm.dim)
for _ in range(self.num_audio_codebooks_in)
]
)
# Depformer
self.depformer: Optional[torch.nn.Module] = None
self.depformer_multi_linear = depformer_multi_linear
if depformer:
assert depformer_dim is not None
assert depformer_num_heads is not None
assert depformer_num_layers is not None
assert depformer_pos_emb is not None
if depformer_dim_feedforward is None:
depformer_dim_feedforward = int(hidden_scale * depformer_dim)
assert depformer_dim_feedforward is not None
self.depformer_in = torch.nn.ModuleList(
[
torch.nn.Linear(self.llm.dim, depformer_dim, bias=False)
for _ in range(
self.num_audio_codebooks_out if depformer_multi_linear else 1
)
]
)
# Text and audio input embeddings for the depformer
self.depformer_emb = torch.nn.ModuleList(
[
embeddings_factory(audio_card + 1, depformer_dim)
for _ in range(self.num_audio_codebooks_out - 1)
]
)
self.depformer_text_emb = embeddings_factory(text_card + 1, depformer_dim)
self.depformer = Transformer(
d_model=depformer_dim,
dim_feedforward=depformer_dim_feedforward,
positional_embedding=depformer_pos_emb,
num_heads=depformer_num_heads,
num_layers=depformer_num_layers,
norm=norm,
device=device,
dtype=dtype,
causal=True,
cross_attention=False,
context=depformer_context,
gating=depformer_gating or gating,
activation=depformer_activation or activation,
weights_per_step=dep_q if depformer_weights_per_step else None,
)
# Output projection
self.audio_linears = torch.nn.ModuleList(
[
torch.nn.Linear(depformer_dim, audio_card, bias=False)
for _ in range(self.num_audio_codebooks_out)
]
)
@property
def cross_attention(self) -> bool:
"""Shortcut for checking whether cross_attention i sused"""
return self.llm.cross_attention
@property
def num_audio_codebooks_in(self) -> int:
"""Number of audio codebooks to model as input"""
return self.n_q
@property
def num_audio_codebooks_out(self) -> int:
"""Number of audio codebooks to model in the depformer"""
return self.dep_q
@property
def num_codebooks(self) -> int:
"""Number codebooks including text"""
return self.num_audio_codebooks_in + 1
@property
def initial_audio_token_id(self) -> int:
"""Initial token for the audio codebooks"""
return self.audio_card
@property
def initial_text_token_id(self) -> int:
"""Initial token for the text; takes into account the "fake/proxy"
tokens for beginning and end of image if they have been set"""
return self.text_card
@property
def audio_offset(self) -> int:
"""Offset in the audio codebook. Returns 1 because we always generate with text"""
return 1
def forward_text(
self,
input_ids: torch.Tensor,
cross_attention_src: Optional[
Tuple[torch.Tensor, torch.Tensor] | torch.Tensor
] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
"""Forward pass for Moshi
:param input_ids: Text + audio tokens of shape (batch, codebooks, seq length)
:param cross_attention_src: Conditioning (image) tokens that can be
cross-attended to through the cross attention module
:param cross_attention_mask: Additional mask for the cross_attention_src.
This is necessary mainly for Pixtral models, as the cross-attended images
might be of different sizes and therefore padded.
:param attention_mask: Optional attention mask on input_ids (e.g. used at
generation for batched inference with left padding)
        :return: A tuple containing
            * the transformer output features
            * the text logits
            * the cross-attention gate weight
        """
# Embed tokens
inputs_embeds = torch.zeros((), device=input_ids.device)
if self.audio_offset > 0:
inputs_embeds = self.llm.text_emb(input_ids[:, 0, :])
for cb_index in range(self.num_audio_codebooks_in):
update = self.audio_emb[cb_index](
input_ids[:, cb_index + self.audio_offset, :]
)
inputs_embeds += update
# Pass through Helium
transformer_out, gate_weight = self.llm(
inputs_embeds=inputs_embeds,
cross_attention_src=cross_attention_src,
cross_attention_mask=cross_attention_mask,
attention_mask=attention_mask,
return_features=True,
)
# Output proj
text_logits = self.llm.text_linear(transformer_out)[:, None]
return transformer_out, text_logits, gate_weight
def forward_depformer(
self,
depformer_cb_index: int,
input_ids: torch.Tensor,
depformer_input: torch.Tensor,
) -> torch.Tensor:
"""Forward one depformer step"""
_, num_codes, seq_len = input_ids.shape
assert self.depformer is not None
assert (
num_codes == 1
), f"Codebooks for Depformer streaming should be passed 1 by 1, got {num_codes}."
assert (
seq_len == 1
), f"Steps for Depformer streaming should be passed 1 by 1, got {seq_len}."
assert (
depformer_input.shape[1] == 1
), "Transformer output should be a for a single step."
# project transformer out
depformer_input = self.depformer_in[
depformer_cb_index if self.depformer_multi_linear else 0
](depformer_input)
# project input ids
if depformer_cb_index == 0:
depformer_input += self.depformer_text_emb(input_ids[:, 0])
else:
depformer_input += self.depformer_emb[depformer_cb_index - 1](
input_ids[:, 0]
)
# depformer_input is [B, 1, depformer_dim].
# The streaming state of the depformer ensures that the proper layer is run.
dep_output, _ = self.depformer(depformer_input)
logits = self.audio_linears[depformer_cb_index](dep_output)
logits = logits[:, None]
assert logits.dim() == 4, logits.shape # [B, Ka, S, card]
return logits
@property
def device(self) -> torch.device:
"""Torch device"""
return next(iter(self.parameters())).device
def get_initial_token(self) -> torch.Tensor:
"""Returns the initial token that will be fed to the model to predict the
        very first timestep. This is akin to a beginning-of-sentence token, but
        adapted to handle potentially delayed codebooks
        :return: A Tensor of shape (B, K, 1)
"""
zero = torch.full(
[1, 1, 1], MoshiVis.zero_token_id, device=self.device, dtype=torch.long
)
audio_token = torch.full_like(
zero, self.initial_audio_token_id or MoshiVis.zero_token_id
)
text_token = torch.full_like(
zero, self.initial_text_token_id or MoshiVis.zero_token_id
)
return torch.cat(
[text_token, audio_token.expand(-1, self.num_audio_codebooks_in, -1)], dim=1
)
class MoshiVisGen(StreamingModule):
"""MoshiVis for autoregressive generation at inference"""
def __init__(
self,
moshi_vis: MoshiVis,
use_sampling: bool = True,
temp: float = 0.8,
temp_text: float = 0.7,
top_k: int = 250,
top_k_text: int = 25,
check: bool = False,
):
assert not moshi_vis.training, "generation shouldn't be used in training mode."
super().__init__()
self.lm_model = moshi_vis
self.use_sampling = use_sampling
self.temp = temp
self.temp_text = temp_text
self.top_k = top_k
self.top_k_text = top_k_text
self.check = check
self.max_delay = max(
moshi_vis.delays
) # with delays, we need to generate a few more time steps.
self.delays_cuda = torch.tensor(
moshi_vis.delays, device=self.lm_model.device, dtype=torch.long
)
self.initial_token = self.lm_model.get_initial_token()
def update_gen_kwargs(
self,
temp: Optional[float] = None,
temp_text: Optional[float] = None,
top_k: Optional[int] = None,
top_k_text: Optional[int] = None,
) -> None:
"""update params for sampling during generation"""
self.temp = temp or self.temp
self.temp_text = temp_text or self.temp_text
self.top_k = top_k or self.top_k
self.top_k_text = top_k_text or self.top_k_text
@property
def model_dim(self) -> int:
"""Return dimension of the tokens in the model"""
return self.lm_model.llm.dim
@property
def num_audio_codebooks_out(self) -> int:
"""Number of audio codebooks generated by the model"""
return self.lm_model.num_audio_codebooks_out
@classmethod
def from_config(
cls,
kyuteye_config: KyuteyeConfig,
moshi_weight: Optional[Dict[str, Any]] = None,
device: str | torch.device = "cpu",
dtype: torch.dtype = torch.bfloat16,
**gen_kwargs: Any,
) -> "MoshiVisGen":
"""Instantiate model from a config
:param base config:
:param moshi_weight
"""
moshivis = MoshiVis(**kyuteye_config.moshi_constructor_kwargs, dtype=dtype)
if moshi_weight is not None:
missing_keys, _ = moshivis.load_state_dict(moshi_weight, strict=False)
# cross-attention MHSA is shared across layers
missing_keys = [
k
for k in missing_keys
if ("cross_attention.mha" not in k or "layers.0" in k)
]
assert len(missing_keys) == 0
return MoshiVisGen(moshi_vis=moshivis.eval().to(device), **gen_kwargs)
@torch.no_grad()
def precompte_ca_kv(
self, embeddings: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precompte kv proj for cross-attention"""
ca_layer = self.lm_model.llm.transformer.layers[0].cross_attention.mha
if hasattr(ca_layer, "in_proj_weight_kv"):
splits = torch.chunk(ca_layer.in_proj_weight_kv, 2)
else:
splits = torch.chunk(ca_layer.in_proj_weight, 3)
k = torch.nn.functional.linear( # pylint: disable=not-callable
embeddings, splits[-2]
)
v = torch.nn.functional.linear( # pylint: disable=not-callable
embeddings, splits[-1] # type: ignore
)
return k, v
@torch.no_grad()
def step(
self,
input_tokens: torch.Tensor,
ca_src: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = None,
) -> Tuple[torch.Tensor | None, float]:
"""One step of generation"""
state = self._streaming_state
if state is None:
raise RuntimeError(
"You should wrap those calls with a `with lm_gen.streaming(): ...`."
)
lm_model = self.lm_model
assert input_tokens.dim() == 3, "Shape should be [B, K, T]."
batch_size, num_codes, seq_len = input_tokens.shape
assert seq_len == 1, "Only support being given steps one by one."
needed_tokens = lm_model.num_codebooks - lm_model.num_audio_codebooks_out - 1
assert (
num_codes == needed_tokens
), f"We expect {needed_tokens} tokens from the user stream, got {num_codes}."
current_input_cache = self.get_streaming_attribute(
"cache",
torch.full(
(batch_size, self.lm_model.num_codebooks, self.max_delay + 2),
self.lm_model.ungenerated_token_id,
device=self.lm_model.device,
dtype=torch.long,
),
)
current_offset = self.get_streaming_attribute("offset", 0)
dcache_len = current_input_cache.shape[2]
# write input_tokens (sent from Mimi) in OTHER codebooks
for q_other in range(input_tokens.shape[1]):
k = lm_model.num_audio_codebooks_out + lm_model.audio_offset + q_other
write_position = (current_offset + lm_model.delays[k]) % dcache_len
current_input_cache[:, k, write_position : write_position + 1] = (
input_tokens[:, q_other]
)
        # Only at the very beginning, we extend the initial token for the acoustic
        # tokens that are delayed, and thus have no good value to take.
position = current_offset % dcache_len
for k, delay in enumerate(lm_model.delays):
if current_offset <= delay:
current_input_cache[:, k, position] = self.initial_token[:, k, 0]
# Transformer forward
input_ = current_input_cache[:, :, position : position + 1]
if self.check:
# Check that we are not feeding in any value that is not generated yet.
assert not (input_ == lm_model.ungenerated_token_id).any(), (
current_offset,
input_,
)
assert (
input_[:, lm_model.audio_offset :] <= lm_model.audio_card
).all(), input_
assert (input_[:, :1] <= lm_model.text_card).all()
transformer_out, text_logits, gate_weight = self.lm_model.forward_text(
input_, cross_attention_src=ca_src
)
# Sample text tokens
# Shape of text_logits should be [B, K_text=1, T=1, Card_text]
text_token = sample_token(
text_logits.float(),
self.use_sampling,
self.temp_text,
self.top_k_text,
)
assert text_token.dim() == 3, text_token.shape
assert text_token.shape[2] == 1
assert text_token.shape[1] == 1, "Only one text stream supported."
text_token = text_token[:, 0, 0] # shape is [B]
# Generate and sample audio tokens
audio_tokens = self.depformer_step(text_token, transformer_out)
# Write generated tokens
current_offset += 1
position = current_offset % dcache_len
current_input_cache[:, 0, position] = text_token
current_input_cache[
:,
lm_model.audio_offset : lm_model.num_audio_codebooks_out
+ lm_model.audio_offset,
position,
] = audio_tokens
# if <= max_delay, we continue partial-generation
# until removing all ungenerated tokens
if current_offset <= self.max_delay:
self.add_streaming_attribute("cache", current_input_cache)
self.add_streaming_attribute("offset", current_offset)
return None, 0.0
# otherwise, retrieve tokens with the correct delay
gen_delays_cuda = self.delays_cuda[
: lm_model.num_audio_codebooks_out + lm_model.audio_offset
]
index = (
((current_offset - self.max_delay + gen_delays_cuda) % dcache_len)
.view(1, -1, 1)
.expand(current_input_cache.shape[0], -1, 1)
)
out = current_input_cache.gather(dim=2, index=index)
self.add_streaming_attribute("offset", current_offset)
self.add_streaming_attribute("cache", current_input_cache)
return out, gate_weight
def depformer_step(
self,
text_token: torch.Tensor,
transformer_out: torch.Tensor,
) -> torch.Tensor:
"""A step of the depformer"""
batch_size = text_token.shape[0]
depformer_tokens: list[torch.Tensor] = []
assert self.lm_model.depformer is not None
with self.lm_model.depformer.streaming():
next_token = text_token[:, None, None]
for cb_index in range(self.lm_model.num_audio_codebooks_out):
logits = self.lm_model.forward_depformer(
cb_index, next_token, transformer_out
)
next_token = sample_token(
logits.float(),
self.use_sampling,
self.temp,
self.top_k,
)
assert next_token.shape == (batch_size, 1, 1)
depformer_tokens.append(next_token[:, 0, 0])
out = torch.stack(depformer_tokens, dim=1)
return out
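
# --- Illustrative sketch (editor's example, not part of the library) ---
# step() must run inside the streaming context; each call consumes one
# timestep of user audio codebooks and returns the generated text + audio
# tokens once the maximum codebook delay has elapsed (None before that).
# `user_stream` and `image_kv` are placeholders here; `image_kv` would come
# from precompte_ca_kv applied to the image embeddings:
#
#     gen = MoshiVisGen(moshi_vis.eval())
#     with gen.streaming():
#         for user_codes in user_stream:          # each of shape (B, K_user, 1)
#             out, gate = gen.step(user_codes, ca_src=image_kv)
#             if out is not None:
#                 ...                             # (B, 1 + dep_q, 1) tokens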
================================================
FILE: kyuteye_pt/kyuteye/modules/__init__.py
================================================
================================================
FILE: kyuteye_pt/kyuteye/modules/attention.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""MultiHead Self-Attention module with optional KV Caching"""
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import torch
from einops import rearrange
from kyuteye.modules.streaming_utils import StreamingModule
from kyuteye.modules.utils import RotaryEmbedding, multi_linear
@dataclass
class KVCache:
"""Efficient streaming KVCache to avoid allocating new memory too many times.
:param batch_size: Batch size.
:param num_heads: Number of heads in the attention.
:param dim_per_head: Dimension per head.
:param context: Context size for the attention, if None, will grow exponentially,
otherwise will use a fixed allocation with a bit of overhead.
:param growth: Growth factor for the exponential growth, fraction of overhead
when context is not None.
:param initial_size: Initial size of the cache, used only when context is None.
:param device: Device on which to initialize the cache.
:param dtype: dtype to use for the cache.
:param cache: Initial cache, if provided.
:param current_end: Current end of the cache, used only when cache is provided.
"""
def __init__(
self,
batch_size: int,
num_heads: int,
dim_per_head: int,
context: Optional[int] = None,
growth: float = 1.2,
initial_size: int = 100,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.bfloat16,
cache: Optional[torch.Tensor] = None,
current_end: int = 0,
) -> None:
if cache is None:
assert current_end == 0
assert growth > 1
self.growth = growth
if context is not None:
initial_size = 1 + int(growth * context)
self.capacity = initial_size
self.context = context
self.current_end = current_end
if cache is None:
self._cache = torch.full(
(2, batch_size, initial_size, num_heads, dim_per_head),
float("NaN"),
device=device,
dtype=dtype,
)
else:
self._cache = cache
def clone(self) -> "KVCache":
"""Return a separate memory copy of the KV cache"""
return KVCache(
self._cache.shape[1],
self._cache.shape[3],
self._cache.shape[4],
self.context,
self.growth,
self.capacity,
self._cache.device,
self._cache.dtype,
self._cache.clone(),
self.current_end,
)
@property
def current_start(self) -> int:
"""Current start of the KV cache (0 if no context size)"""
return 0 if self.context is None else max(self.current_end - self.context, 0)
def __maybe_increase_capacity__(self, required_capacity: int) -> None:
"""If needed, increase capacity to the `required_capacity`
using exponential growth strategy"""
if required_capacity > self.capacity:
if self.context is None:
# We take an exponential growth approach.
new_capacity = self.capacity
while required_capacity > new_capacity:
new_capacity = int(math.ceil(new_capacity * self.growth))
new_shape = list(self._cache.shape)
new_shape[2] = new_capacity
new_cache = torch.full(
tuple(new_shape),
float("NaN"),
device=self._cache.device,
dtype=self._cache.dtype,
)
new_cache[:, :, : self.current_end] = self._cache[
:, :, : self.current_end
]
self._cache = new_cache
self.capacity = new_capacity
else:
                # With context, we just have to roll the cache to the left and
                # use the new space on the right.
assert self.current_start > 0
self._cache[:] = self._cache.roll(-self.current_start, dims=2)
self.current_end -= self.current_start
def complete(
self, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Add keys `k` and values `v` to the current cache and returns
cache up to the context size"""
assert k.shape[1] == v.shape[1]
self.__maybe_increase_capacity__(self.current_end + k.shape[1])
assert self.current_end + k.shape[1] <= self.capacity, (
self.current_end,
k.shape[1],
self.capacity,
)
self._cache[0, :, self.current_end : self.current_end + k.shape[1]] = k
self._cache[1, :, self.current_end : self.current_end + v.shape[1]] = v
self.current_end += k.shape[1]
valid = self._cache[:, :, self.current_start : self.current_end]
return valid[0], valid[1]
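
# --- Illustrative sketch (editor's example, not part of the library) ---
# With a fixed context, complete() appends the new keys/values and returns
# only the steps within the context window; once the buffer fills up, the
# cache is rolled to the left to make room. A minimal example with assumed
# shapes (batch=1, heads=2, dim_per_head=8):
#
#     cache = KVCache(1, 2, 8, context=2, device=torch.device("cpu"))
#     k = v = torch.zeros(1, 3, 2, 8, dtype=torch.bfloat16)   # (B, T, H, D)
#     k_out, v_out = cache.complete(k, v)        # clipped to the last 2 steps
#     step = torch.zeros(1, 1, 2, 8, dtype=torch.bfloat16)
#     k_out, v_out = cache.complete(step, step)  # rolls left, still 2 steps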
class MultiheadAttention(StreamingModule):
"""Similar to `nn.MultiheadAttention` but with support for causal evaluation.
Args:
:param embed_dim: Dimension to project to.
:param num_heads: Number of heads.
:param causal: If true, applies causal mask automatically.
    :param context: Number of time steps the attention has access to.
When causal, can access `context` time steps into the past, and when non causal,
can access `context // 2` steps in the past, and the same in the future.
:param rope: Rope embedding to use. If None, no rope embedding is applied
:param cross_attention: Should be true when used as a cross attention.
        Cannot be used with `causal` or `rope` (as it wouldn't make sense to
interpret the time steps in the keys relative to those in the queries).
:param use_kv_cache: If True, enables a KV cache with context size `context`.
:param weights_per_step: use different weights per depformer step. If non zero,
should correspond to the number of possible time steps.
:param device: Device on which to initialize the module.
:param dtype: dtype to use.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
causal: bool = False,
context: Optional[int] = None,
rope: Optional[RotaryEmbedding] = None,
cross_attention: bool = False,
use_kv_cache: bool = False,
weights_per_step: int = 0,
xa_dim: Optional[int] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.embed_dim = embed_dim
self.causal = causal
self.context = context
self.rope = rope
self.cross_attention = cross_attention
self.num_heads = num_heads
self.use_kv_cache = use_kv_cache
self.weights_per_step = weights_per_step
mult = max(1, weights_per_step)
if cross_attention:
assert not causal, "Cannot set causal mask when `cross attention` is True."
assert (
not context
), "Cannot set context size when `cross attention` is True."
# if cross-attention source have != num_dims than the speech tokens,
# we need to separate the KV and Q embeddings
if cross_attention and xa_dim is not None and xa_dim != embed_dim:
in_proj_q = torch.nn.Linear(
embed_dim, mult * embed_dim, bias=False, **factory_kwargs
)
in_proj_kv = torch.nn.Linear(
xa_dim, mult * 2 * embed_dim, bias=False, **factory_kwargs
)
self.in_proj_weight_q = in_proj_q.weight
self.in_proj_bias_q = in_proj_q.bias
self.in_proj_weight_kv = in_proj_kv.weight
self.in_proj_bias_kv = in_proj_kv.bias
self.in_proj_weight = None
self.in_proj_bias = None
else:
in_proj = torch.nn.Linear(
embed_dim, mult * 3 * embed_dim, bias=False, **factory_kwargs
)
self.in_proj_weight = in_proj.weight
self.in_proj_bias = in_proj.bias
self.out_proj = torch.nn.Linear(
embed_dim, mult * embed_dim, bias=False, **factory_kwargs
)
def _complete_kv(
self, k: torch.Tensor, v: torch.Tensor, initial_kv_cache_size: int = 256
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add key/values to the KV cache"""
# With cross attention we assume all keys and values
# are already available, and streaming is with respect
# to the queries only.
if self._is_streaming and not self.cross_attention:
if "kv_cache" not in self._streaming_state:
self._streaming_state["kv_cache"] = KVCache( # type: ignore
k.shape[0],
k.shape[2],
k.shape[3],
self.context,
initial_size=self.weights_per_step or initial_kv_cache_size,
device=k.device,
dtype=k.dtype,
)
self.streaming_offset = torch.zeros(1) # type: ignore
kv_cache: KVCache = self._streaming_state["kv_cache"] # type: ignore
self.streaming_offset += k.shape[1]
return kv_cache.complete(k, v)
return k, v
def forward(
self,
query: torch.Tensor,
key: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""If self.cross attention is False, we only expects the first input. Otherwise,
when using cross attention, we need to explicitly give the source for the
respective query/key/value embeddings"""
# Get current streaming offset before it gets potentially modified by the KV cache update
current_streaming_offset = self.streaming_offset
if self.cross_attention:
assert key is not None, "Missing inputs in cross attention"
            if isinstance(key, torch.Tensor):
                if value is None:
                    value = key
assert value is not None, "Missing inputs in cross attention"
# Case 1: Inputs x and ca_src have the same number of dimension
# We have a single big weight for the QKV projections
if self.in_proj_weight is not None:
q = torch.nn.functional.linear( # pylint: disable=not-callable
query, self.in_proj_weight[: self.embed_dim]
)
if isinstance(key, torch.Tensor):
k = torch.nn.functional.linear( # pylint: disable=not-callable
key, self.in_proj_weight[self.embed_dim : 2 * self.embed_dim]
)
v = torch.nn.functional.linear( # pylint: disable=not-callable
value, self.in_proj_weight[2 * self.embed_dim :] # type: ignore
)
else:
k, v = key
# Case 2: Inputs x and ca_src have different number of dimension
# We have to separate the Q and KV proj
else:
q = torch.nn.functional.linear( # pylint: disable=not-callable
query, self.in_proj_weight_q[: self.embed_dim]
)
if isinstance(key, torch.Tensor):
k = torch.nn.functional.linear( # pylint: disable=not-callable
key, self.in_proj_weight_kv[: self.embed_dim]
)
v = torch.nn.functional.linear( # pylint: disable=not-callable
value, self.in_proj_weight_kv[self.embed_dim :] # type: ignore
)
else:
k, v = key
q, k, v = [
rearrange(x, "b t (h d) -> b t h d", h=self.num_heads)
for x in [q, k, v]
]
else:
assert self.in_proj_weight is not None
if self.weights_per_step > 0:
projected = multi_linear(
self.weights_per_step,
self.in_proj_weight,
query,
offset=current_streaming_offset,
)
else:
projected = torch.nn.functional.linear( # pylint: disable=not-callable
query, self.in_proj_weight
)
packed = rearrange(
projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads
)
q, k, v = torch.unbind(packed, dim=2)
if self.rope:
q, k = self.rope(q, k, offset=current_streaming_offset)
k, v = self._complete_kv(k, v)
# Attention
q, k, v = [x.transpose(1, 2) for x in [q, k, v]]
x = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable
q, k, v, is_causal=False, attn_mask=attention_mask
)
x = x.transpose(1, 2)
# output projection
x = rearrange(x, "b t h d -> b t (h d)")
if self.weights_per_step > 0:
x = multi_linear(
self.weights_per_step,
self.out_proj.weight,
x,
offset=current_streaming_offset,
)
else:
x = self.out_proj(x)
return x
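
# --- Illustrative sketch (editor's example, not part of the library) ---
# When the cross-attention source has a different width than the model
# (xa_dim != embed_dim), the module keeps separate Q and KV projections, so
# the KV projection can be precomputed once per image and reused at every
# step (see MoshiVisGen.precompte_ca_kv). Non-streaming sketch:
#
#     mha = MultiheadAttention(
#         embed_dim=64, num_heads=4, cross_attention=True, xa_dim=128
#     )
#     q_src = torch.randn(1, 10, 64)    # speech tokens
#     kv_src = torch.randn(1, 5, 128)   # image tokens
#     out = mha(q_src, key=kv_src)      # -> (1, 10, 64)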
================================================
FILE: kyuteye_pt/kyuteye/modules/cross_attention.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Gated cross-attention module"""
from typing import Any, Callable, Literal, Optional, Tuple, Union
import torch
from kyuteye.modules.attention import MultiheadAttention
from kyuteye.modules.streaming_utils import StreamingModule
class SharedModuleType(type):
"""Wrapper to build shared Pytorch modules"""
_instances = {} # type: ignore
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
cls._instances[cls] = super(SharedModuleType, cls).__call__(*args, **kwargs)
return cls._instances[cls]
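
# --- Illustrative sketch (editor's example, not part of the library) ---
# Classes using SharedModuleType as their metaclass behave as singletons:
# every "construction" returns the first instance (later constructor
# arguments are ignored), which is how the cross-attention projections
# defined below end up shared across transformer layers:
#
#     a = SharedCrossAttention(embed_dim=64, num_heads=4)
#     b = SharedCrossAttention(embed_dim=64, num_heads=4)
#     assert a is b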
class XAGate(torch.nn.Module):
"""Learned multiplicative gating per layer"""
def __init__(
self,
device: Optional[Union[str, torch.device]] = None,
activation: Literal["tanh", "sigmoid"] = "tanh",
conditional_gating: bool = False,
dims: Optional[int] = None,
hidden_dims_factor: float = 0.125,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
self.alpha: torch.nn.Parameter | torch.nn.Module
self.conditional_gating = conditional_gating
if self.conditional_gating:
assert dims is not None
hidden_dims = int(hidden_dims_factor * dims)
self.alpha = torch.nn.Sequential(
torch.nn.Linear(dims, hidden_dims, bias=False),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dims, dims, bias=False),
)
else:
self.alpha = torch.nn.Parameter(
torch.full((1, 1, 1), 0.0, device=device, dtype=dtype)
)
self.act: Callable
if activation == "tanh":
self.act = torch.tanh
elif activation == "sigmoid":
            # shift the sigmoid so that the gate initializes close to 0
self.act = lambda x: torch.sigmoid(x - 4)
else:
raise NotImplementedError("Unknown activation function", activation)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Gating (constant scaling or input-dependent)"""
if isinstance(self.alpha, torch.nn.Parameter):
gate_weight = self.act(self.alpha)
else:
gate_weight = self.act(self.alpha(x))
return x * gate_weight, gate_weight
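
# --- Illustrative sketch (editor's example, not part of the library) ---
# With the default tanh activation and a constant alpha initialized at 0,
# the gate starts fully closed (scale tanh(0) = 0), so a freshly added
# cross-attention layer initially leaves the backbone unchanged:
#
#     gate = XAGate(activation="tanh")
#     y, w = gate(torch.ones(1, 1, 8))
#     assert float(w) == 0.0 and float(y.abs().sum()) == 0.0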
class SharedXaGate(XAGate, metaclass=SharedModuleType):
"""Shared XaGate"""
pass # pylint: disable=unnecessary-pass
class CrossAttention(MultiheadAttention):
"""Cross attention module"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["cross_attention"] = True
super().__init__(*args, **kwargs)
class SharedCrossAttention(CrossAttention, metaclass=SharedModuleType):
"""Shared Cross Attention projection across all layers"""
pass # pylint: disable=unnecessary-pass
class GatedCrossAttention(StreamingModule):
"""Gated Cross Attention with or without sharing parameters across layers"""
def __init__(
self,
embed_dim: int,
xa_gating: Literal["tanh", "sigmoid", "none"] = "tanh",
xa_conditional_gating: bool = False,
xa_shared: bool = True,
xa_gate_shared: bool = False,
xa_delay: int = 0,
xa_start: Literal["start", "boi", "eoi"] = "start",
xa_end: Literal["end", "eoi"] | int = "end",
xa_step: int = 1,
xa_dim: Optional[int] = None,
**attn_kwargs: Any,
) -> None:
"""
Initializes a Gated CrossAttention module.
:param xa_gating: Whether (and which type of) to add multiplicative gating at
the output of the cross-attention layer
:param xa_shared: Whether to share the projection parameters of this
corss-attention module across all layers
tanh_gate (bool, optional): Whether to apply tanh activation to the gate.
Defaults to True.
share_tanh_gate (bool, optional): Whether to share the tanh gate parameters
across different attention heads. Defaults to True.
shared_parameters (bool, optional): Whether to share the attention parameters
across different attention heads. Defaults to True.
shift (int, optional): The shift value for the cross attention, i.e., whether
the update is to be based on past queries.
Defaults to 0, i.e., no shifting.
xa_scope (Literal["images", "all_after_first_boi"], optional): The scope of the
attention. Can be "images" or "all_after_first_boi". Defaults to "images".
**attn_kwargs: Additional keyword arguments to be passed to the CrossAttention
or SharedCrossAttention class.
"""
super().__init__()
device = attn_kwargs.get("device", None)
dtype = attn_kwargs.get("dtype", None)
# Attention module
self.mha = (SharedCrossAttention if xa_shared else CrossAttention)(
embed_dim=embed_dim, xa_dim=xa_dim, **attn_kwargs
)
# Output Gating
self.gate: Optional[torch.nn.Module] = None
if xa_gating != "none":
self.gate = (SharedXaGate if xa_gate_shared else XAGate)(
activation=xa_gating,
device=device,
dtype=dtype,
dims=embed_dim,
conditional_gating=xa_conditional_gating,
)
# If the XA module AND gates are shared, we add a per-layer
# coefficient to still have some modularity
self.per_layer_alpha: Optional[torch.nn.Parameter] = None
if xa_shared and xa_gate_shared:
self.per_layer_alpha = torch.nn.Parameter(
torch.full((1, 1, 1), 1.0, device=device, dtype=dtype)
)
# Determine the xa scope
self.xa_start = xa_start
self.xa_end = xa_end
self.xa_step = xa_step
self.xa_delay = xa_delay
self._active = True
def get_xa_scope(
self, x: torch.Tensor, image_tokens_mask: Optional[torch.Tensor]
    ) -> torch.Tensor | int:
"""Build the mask of which tokens should receive contribution from
the cross-attention image tokens"""
if self.is_streaming and isinstance(self.xa_start, int):
return int(self.streaming_offset >= self.xa_start)
if not (
self.xa_start in {"start", "boi", "eoi"} or isinstance(self.xa_start, int)
) or not (self.xa_end in {"end", "eoi"} or isinstance(self.xa_end, int)):
raise NotImplementedError(
f"Unsupported XA scope : {self.xa_start}, {self.xa_end}"
)
if self.xa_start == "start":
# scope = 'all'
if self.xa_end == "end":
mask = torch.ones_like(x[:, :, :1])
            # scope = 'start + relative'
elif isinstance(self.xa_end, int):
mask = torch.zeros_like(x[:, :, :1])
mask[:, : self.xa_end, :] = 1
            # other combinations such as (start, boi) or (start, eoi) are not supported
else:
raise NotImplementedError(
f"Unsupported XA scope : {self.xa_start}, {self.xa_end}"
)
elif isinstance(self.xa_start, int):
if isinstance(self.xa_end, int):
mask = torch.zeros_like(x[:, :, :1])
# self.xa_end is relative to self.xa_start
mask[:, self.xa_start : self.xa_start + self.xa_end, :] = 1
elif self.xa_end == "end":
# If for some reason the start is further than the end, we do not attend
mask = torch.zeros_like(x[:, :, :1])
if not self.xa_start > mask.shape[1]:
mask[:, self.xa_start :, :] = 1
else:
raise NotImplementedError(
f"Unsupported XA scope : {self.xa_start}, {self.xa_end}"
)
else:
assert image_tokens_mask is not None
# another easy case is attention only to the image tokens
if self.xa_start == "boi" and self.xa_end == "eoi":
mask = image_tokens_mask
# Otherwise, we build the mask manually
# first, determine the start, either EoI or BoI
elif self.xa_start == "boi":
# everything that comes after boi
# e.g. (0 0 0 1 1 1 1 0 0 0 ) becomes
# (0 0 0 1 2 3 4 4 4 4 )
mask = torch.cumsum(image_tokens_mask, dim=1)
elif self.xa_start == "eoi":
# everything that comes after eoi
# e.g. (0 0 0 1 1 1 1 0 0 0 ) becomes
# (0 0 0 0 0 0 0 4 4 4 )
mask = (1 - image_tokens_mask) * torch.cumsum(image_tokens_mask, dim=1)
else:
raise NotImplementedError(
f"Unsupported XA scope starting at {self.xa_start}"
)
# then determine the end of the mask
mask = torch.greater(mask, 0).to(x.dtype)
if isinstance(self.xa_end, int):
mask = torch.cumsum(mask, dim=1)
mask = torch.lt(mask, self.xa_end + 1).to(x.dtype)
# then apply xa_step
if self.xa_step > 1:
if self.xa_start not in {"start", 0}:
raise NotImplementedError("xa_step")
step_mask = torch.eq(
torch.remainder(torch.arange(x.shape[1]), self.xa_step), 0
).to(mask)
mask *= step_mask[None, :, None].float()
return mask
def is_active(self, image_tokens_mask: Optional[torch.Tensor] = None) -> bool:
"""Whether this model is active during the forward pass"""
if self.is_streaming:
# case 1: never stop
if self.xa_end == "end":
return self.xa_start == "start" or (
isinstance(self.xa_start, int)
and self.streaming_offset >= self.xa_start
)
# case 2: XA applies to the image; we only apply cross-attention
# in the step where the image is inserted
if self.xa_end == "eoi" and image_tokens_mask is None:
return False
# Case 3: XA end is relative to XA start
if isinstance(self.xa_end, int):
if self.xa_start == "start":
offset = 0
elif isinstance(self.xa_start, int):
offset = self.xa_start
elif self.xa_start == "boi":
if self.has_streaming_attribute("image_insert_start"):
offset = self.get_streaming_info_as_int("image_insert_start")
else:
return False
elif self.xa_start == "eoi":
if self.has_streaming_attribute("image_insert_end"):
offset = self.get_streaming_info_as_int("image_insert_end")
else:
return False
else:
raise ValueError("Unsupported xa_start option", self.xa_start)
# if xa_step is active
if self.xa_step > 1:
return (
offset <= self.streaming_offset < offset + self.xa_end
) and ((self.streaming_offset - offset) % self.xa_step == 0)
# base case
return offset <= self.streaming_offset < offset + self.xa_end
# In training, we are always active and just build the xa scope mask
return True
def forward(
self,
x: torch.Tensor,
key: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
image_tokens_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Gated Cross attention
:param x: Input query tensor
:param key: Source tokens for the keys
:param value: Source tokens for the values (typically, equal to keys)
:param cross_attention_mask: Mask to apply to the cross-attention
:param image_tokens_mask: Mask indicating where the image tokens
are in the stream. This is used to determine which token should
cross-attend to the image
"""
gate_weight = None
if self.is_streaming:
if not self.has_streaming_attribute("offset"):
self.streaming_offset = 0
# Mark the last inserted image's position in streaming mode
if image_tokens_mask is not None:
self.add_streaming_attribute(
"image_insert_start", self.streaming_offset
)
image_lengths = torch.sum(image_tokens_mask[:, :, 0], dim=1)
assert torch.all(
torch.eq(image_lengths, int(image_lengths[0].item()))
), "All inserted images must have the same number of tokens"
self.add_streaming_attribute(
"image_insert_end",
self.streaming_offset + int(image_lengths[0].item()),
)
if not self.is_active(image_tokens_mask=image_tokens_mask):
x = torch.zeros_like(x)
else:
x = self.mha(
query=x, key=key, value=value, attention_mask=cross_attention_mask
)
if self.gate is not None:
x, gate_weight = self.gate(x)
if self.per_layer_alpha is not None:
x *= self.per_layer_alpha
# Mask out tokens that should not receive signal from the image
x = x * self.get_xa_scope(x, image_tokens_mask=image_tokens_mask)
# Update streaming offset
if self.is_streaming:
self.streaming_offset += x.shape[1]
return x, gate_weight
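def _xa_scope_mask_sketch() -> None:
    """Worked example (hypothetical helper) of the "boi"/"eoi" scope masks
    built in ``GatedCrossAttention.get_xa_scope``, using a simplified 2-D
    mask (the real code carries a trailing channel dimension): a cumulative
    sum over the image-token mask marks everything from the first image
    token onward, and multiplying by (1 - mask) keeps only the tokens
    strictly after the image block."""
    image_tokens_mask = torch.tensor([[0, 0, 1, 1, 0, 0]])
    # xa_start == "boi": everything from the first image token onward
    after_boi = torch.greater(torch.cumsum(image_tokens_mask, dim=1), 0).int()
    assert after_boi.tolist() == [[0, 0, 1, 1, 1, 1]]
    # xa_start == "eoi": everything strictly after the last image token
    after_eoi = torch.greater(
        (1 - image_tokens_mask) * torch.cumsum(image_tokens_mask, dim=1), 0
    ).int()
    assert after_eoi.tolist() == [[0, 0, 0, 0, 1, 1]]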
================================================
FILE: kyuteye_pt/kyuteye/modules/image_encoder.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Pretrained image encoders from timm and/or transformers"""
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import torch
from kyuteye.config.enums import ImageEncoder
from kyuteye.modules.image_transforms import (
Normalize,
PixtralNormalize,
SigLIPNormalize,
)
from transformers import (
AutoConfig,
AutoModelForImageTextToText,
LlavaForConditionalGeneration,
SiglipVisionConfig,
SiglipVisionModel,
)
class TrimmedFlexiViTWrapper(torch.nn.Module):
"""ViT module without the classification tower"""
def __init__(
self, model: torch.nn.Module, interpolate_pos_encoding: bool = False
) -> None:
super().__init__()
self.interpolate_pos_encoding = interpolate_pos_encoding
self.model = model
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Get last hidden states"""
return self.model(
x, interpolate_pos_encoding=self.interpolate_pos_encoding
).last_hidden_state
def load_paligemma_vision_encoder(
name: str, device: torch.device | str = "cpu", pretrained: bool = False
) -> torch.nn.Module:
"""Load Paligemma encoder from the shared HuggingFace cache"""
if pretrained:
model = AutoModelForImageTextToText.from_pretrained(
name
).vision_tower.vision_model
else:
image_size = int(name.rsplit("-", 1)[-1])
model = SiglipVisionModel(
SiglipVisionConfig(
**{
"hidden_size": 1152,
"image_size": image_size,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": (image_size // 14) ** 2,
"num_positions": 256,
"patch_size": 14,
"projection_dim": 2304,
"torch_dtype": "bfloat16",
"vision_use_head": False,
}
)
).vision_model
return TrimmedFlexiViTWrapper(model.to(device), interpolate_pos_encoding=True)
@dataclass
class PixtralOutput:
"""Pixtral Output"""
out: torch.Tensor
mask: torch.Tensor
class PixtralWrapper(torch.nn.Module):
"""Pixtral encoder returning penultimate features"""
def __init__(
self, device: Optional[torch.device | str] = None, pretrained: bool = False
) -> None:
super().__init__()
if pretrained:
self.model = AutoModelForImageTextToText.from_pretrained(
"mistral-community/pixtral-12b"
).vision_tower.to(device)
else:
config = AutoConfig.from_pretrained("mistral-community/pixtral-12b")
self.model = LlavaForConditionalGeneration(config).vision_tower.to(device)
self.patch_size = torch.prod(
torch.tensor(self.model.patch_conv.weight.shape[-2:]) # type: ignore
)
def __get_num_output_tokens__(self, x: List[torch.Tensor]) -> List[int]:
"""Get number of tokens for each image in the list"""
return [
(torch.prod(torch.tensor(img[1].shape[-2:])) // self.patch_size).item()
for img in x
]
@staticmethod
def split_and_pad_output(
x: torch.Tensor, split_points: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Split the output of the model into a list of tensors corresponding
to the input images and create a mask tensor"""
splits = list(torch.split(x, split_points, dim=1))
assert sum(split.shape[1] for split in splits) == x.shape[1]
# Pad the splits to have the same size along the sequence dimension
max_len = max(split.shape[1] for split in splits)
padded_splits = []
mask = torch.zeros((len(splits), max_len), dtype=torch.bool, device=x.device)
for i, split in enumerate(splits):
pad_len = max_len - split.shape[1]
# Right padding of the second to last (i.e., the sequence) dimension
padded_split = torch.nn.functional.pad(split, (0, 0, 0, pad_len))
padded_splits.append(padded_split)
mask[i, : split.shape[1]] = 1
return torch.cat(padded_splits, dim=0), mask
def forward(self, x: List[torch.Tensor] | torch.Tensor) -> PixtralOutput:
"""Forward to the last hidden states"""
if isinstance(x, torch.Tensor):
x = list(x)
assert isinstance(x, list), "Pixtral expects a list of tensors."
split_points = self.__get_num_output_tokens__(x)
# Pixtral expects list of images
model_out = self.model(x).last_hidden_state
split_output, mask = self.split_and_pad_output(model_out, split_points)
return PixtralOutput(out=split_output, mask=mask)
def get_img_normalize(
img_encoder: ImageEncoder,
) -> Callable[..., Normalize | PixtralNormalize]:
"""Return input normalization function"""
if img_encoder == ImageEncoder.PIXTRAL:
return PixtralNormalize
if img_encoder in {
ImageEncoder.SIGLIP_GEMMA2_224,
ImageEncoder.SIGLIP_GEMMA2_448,
ImageEncoder.SIGLIP_GEMMA2_896,
}:
return SigLIPNormalize
raise NotImplementedError("Unknown image encoder", img_encoder.name)
def load_image_encoder(
img_encoder: ImageEncoder,
device: torch.device | str = "cpu",
pretrained: bool = False,
) -> torch.nn.Module:
"""Load Image encoder as a torch module"""
if img_encoder == ImageEncoder.PIXTRAL:
return PixtralWrapper(device=device, pretrained=pretrained)
if img_encoder in {
ImageEncoder.SIGLIP_GEMMA2_224,
ImageEncoder.SIGLIP_GEMMA2_448,
ImageEncoder.SIGLIP_GEMMA2_896,
}:
size = int(img_encoder.name.rsplit("_", 1)[-1])
return load_paligemma_vision_encoder(
name=f"google/paligemma2-3b-pt-{size}",
device=device,
pretrained=pretrained,
)
raise NotImplementedError(f"image encoder {img_encoder.name} not recognized")
================================================
FILE: kyuteye_pt/kyuteye/modules/image_transforms.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Image transforms"""
from typing import List, Literal, Optional, Sequence, Tuple, Union
import torch
import torchvision.transforms.v2 as T
from PIL.Image import Image
try:
from transformers.models.pixtral.image_processing_pixtral import (
PixtralImageProcessor,
)
except ImportError:
print("Cannot find Pixtral encoder, you need to upgrade to transformers >= 0.46")
def get_minimal_transforms(
img_size: Union[Tuple[int], int] = 224,
interpolation: Literal[
"bicubic", "bilinear", "nearest", "nearest_exact"
] = "bicubic",
to_tensor: bool = False,
keep_aspect_ratio: bool = False,
max_img_size: Optional[int] = None,
) -> T.Transform:
"""Minimal transform from converting a PIL image to a Tensor.
This is used as default in most of our datasets when the img_transforms
is not provided. If keep_aspect_ratio is False, it resizes the image to (img_size, img_size),
without respecting aspect ratio. Otherwise, it only resizes the smaller side to img_size,
:param img_size: Target image (square) size
:param interpolation: Resizing interpolation
:param to_tensor: Whether to also convert to a Tensor type (or leave it to a later transform)
:param keep_aspect_ratio: Whether to keep the aspect ratio
:param max_img_size: Maximum size an image can be along the longer side
"""
if not isinstance(img_size, tuple):
img_size = img_size if keep_aspect_ratio else (img_size, img_size) # type: ignore
return T.Compose(
[
T.Resize(
img_size,
interpolation=getattr(T.InterpolationMode, interpolation.upper()),
max_size=max_img_size if keep_aspect_ratio else None,
),
T.PILToTensor() if to_tensor else T.Identity(),
]
)
class Normalize:
"""Normalization types for the different image encoders. These will be
set in image_projection.py"""
def __init__(self, mean: Sequence[float], std: Sequence[float]) -> None:
super().__init__()
self.std = std
self.mean = mean
self.transform = T.Compose(
[
T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale=True)]),
T.Normalize(
mean,
std,
),
]
)
def __call__(
self, image: Union[Image, List[Image]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
if isinstance(image, list):
return [self.transform(img) for img in image]
return self.transform(image)
def to_pil_transform(self, mode: str = "RGB") -> T.Transform:
"""Returns the function that inverts this normalization"""
return T.Compose(
[
T.Normalize([0 for _ in self.mean], [1 / (x + 1e-6) for x in self.std]),
T.Normalize([-x for x in self.mean], [1 for _ in self.std]),
T.ToPILImage(mode=mode),
]
)
class UnitNormalize(Normalize):
"""Normalization for SigLIP encoder"""
def __init__(self) -> None:
super().__init__(
mean=(0.0, 0.0, 0.0),
std=(1.0, 1.0, 1.0),
)
class CLIPNormalize(Normalize):
"""Normalization for SigLIP encoder"""
def __init__(self) -> None:
super().__init__(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
)
class SigLIPNormalize(Normalize):
"""Normalization for SigLIP encoder"""
def __init__(self) -> None:
super().__init__(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
class PixtralNormalize:
"""Image preprocessing for Pixtral
https://github.com/huggingface/transformers/blob/fc1ae7f30f1d16c7652c28dd8d91c5d8a8ed2f15/src/transformers/models/pixtral/image_processing_pixtral.py#L369
"""
def __init__(self) -> None:
self.preprocess = PixtralImageProcessor.from_pretrained(
"mistral-community/pixtral-12b"
)
def __call__(
self, image: Union[Image, List[Image], torch.Tensor]
) -> Union[torch.Tensor, List[torch.Tensor]]:
# Image input to pixtral can be an individual image or a list of
# images or a list of lists of images (multiple images in a single text sequence)
# Case 1: Single image
if isinstance(image, Image) or isinstance(image, torch.Tensor):
# Pixtral converts a single image into [[image]] — a list of lists of images
return self.preprocess(image, return_tensors="pt")["pixel_values"][0][0]
        # Case 2: List of images
if isinstance(image, list) and isinstance(image[0], Image):
return [
self.preprocess(subimg, return_tensors="pt")["pixel_values"][0][0]
for subimg in image
]
# Case 3: List of lists of images
# This is not currently supported by us
raise ValueError(
"PixtralNormalize does not support list of lists of images currently."
)
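def _normalize_roundtrip_sketch() -> None:
    """Illustrative sketch (hypothetical helper): SigLIP normalization maps
    [0, 1] pixel values to [-1, 1], and ``to_pil_transform`` approximately
    inverts it back to a PIL image."""
    import PIL.Image  # local import, for the sketch only

    normalize = SigLIPNormalize()
    white = PIL.Image.new("RGB", (8, 8), color=(255, 255, 255))
    tensor = normalize(white)
    assert torch.allclose(tensor, torch.ones_like(tensor))  # (1.0 - 0.5) / 0.5
    restored = normalize.to_pil_transform()(tensor)
    assert restored.size == (8, 8)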
================================================
FILE: kyuteye_pt/kyuteye/modules/streaming_utils.py
================================================
# pylint: disable=protected-access
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Common API for streaming modules during inference"""
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, Optional
import torch
State = Dict[str, float | int | torch.Tensor]
class StreamingModule(torch.nn.Module):
"""Common API for streaming components."""
def __init__(self) -> None:
super().__init__()
self._streaming_state: State = {}
self._is_streaming = False
@property
def empty_streaming_state(self) -> bool:
"""whether streaming state is empty"""
return len(self._streaming_state) == 0
def has_streaming_attribute(self, key: str) -> bool:
"""Whether `key` exists in the current streaming state"""
return self._is_streaming and key in self._streaming_state
def add_streaming_attribute(
self, key: str, value: float | int | torch.Tensor
) -> None:
"""Add `value` into streaming state's `key`"""
self._streaming_state[key] = value
def get_streaming_attribute(self, key: str, default: Any = None) -> Any:
"""Add `value` into streaming state's `key`"""
return self._streaming_state.get(key, default)
@property
def is_streaming(self) -> bool:
"""in streaming mode"""
return self._is_streaming
def get_streaming_info_as_int(self, attr_name: str, default: int = 0) -> int:
"""Tries to get attr_name as an integer"""
if self._is_streaming and attr_name in self._streaming_state:
if isinstance(self._streaming_state[attr_name], int):
return self._streaming_state[attr_name] # type: ignore
if isinstance(self._streaming_state[attr_name], torch.Tensor):
return int(self._streaming_state[attr_name].item()) # type: ignore
raise ValueError(
f"Unexpected type {type(self._streaming_state[attr_name])} in streaming state"
)
return default
@property
def streaming_offset(self) -> int:
"""Shortcut to get the current temporal offset in streaming mode"""
return self.get_streaming_info_as_int("offset", default=0)
@streaming_offset.setter
def streaming_offset(self, value: int | torch.Tensor) -> None:
if not self._is_streaming:
raise NotImplementedError(
"Updating streaming offset of a non-streaming module"
)
self._streaming_state["offset"] = value # type: ignore
def _apply_named_streaming(self, fn: Callable) -> None:
for name, module in self.named_modules():
if isinstance(module, StreamingModule):
fn(name, module)
def _set_streaming(self, streaming: bool) -> None:
def _set_streaming(_: str, module: StreamingModule) -> None:
module._is_streaming = streaming
self._apply_named_streaming(_set_streaming)
@contextmanager
def streaming(self) -> Iterator:
"""Context manager to enter streaming mode. Reset streaming state on exit."""
self._set_streaming(True)
try:
yield
finally:
self._set_streaming(False)
self.reset_streaming()
def streaming_forever(self, batch_size: Optional[int] = None) -> None:
"""Set in permanent streaming state"""
del batch_size
self._set_streaming(True)
def reset_streaming(self) -> None:
"""Reset the streaming state."""
def _reset(_: str, module: StreamingModule) -> None:
module._streaming_state.clear()
self._apply_named_streaming(_reset)
def get_streaming_state(self) -> State:
"""Return the streaming state, including that of sub-modules."""
state: State = {}
def _add(name: str, module: StreamingModule) -> None:
if name:
name += "."
for key, value in module._streaming_state.items():
state[name + key] = value
self._apply_named_streaming(_add)
return state
def set_streaming_state(self, state: State) -> None:
"""Set the streaming state, including that of sub-modules."""
state = dict(state)
def _set(name: str, module: StreamingModule) -> None:
if name:
name += "."
module._streaming_state.clear()
for key, value in list(state.items()):
# complexity is not ideal here, but probably fine.
if key.startswith(name):
local_key = key[len(name) :]
if "." not in local_key:
module._streaming_state[local_key] = value
del state[key]
self._apply_named_streaming(_set)
assert len(state) == 0, list(state.keys())
def flush(self, x: Optional[torch.Tensor] = None) -> Optional["StreamingModule"]:
"""Flush any remaining outputs that were waiting for completion.
Typically, for convolutions, this will add the final padding
and process the last buffer.
This should take an optional argument `x`, which will be provided
if a module before this one in the streaming pipeline has already
        spit out a flushed buffer.
"""
if x is None:
return None
return self(x)
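def _streaming_sketch() -> None:
    """Illustrative sketch (hypothetical helper): inside the ``streaming()``
    context, per-module state such as the temporal offset persists across
    calls; on exit the state is cleared again."""
    module = StreamingModule()
    with module.streaming():
        module.streaming_offset = 0
        module.streaming_offset += 7
        assert module.streaming_offset == 7
        assert module.get_streaming_state() == {"offset": 7}
    assert module.empty_streaming_state  # state reset on exit
    assert not module.is_streaming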
================================================
FILE: kyuteye_pt/kyuteye/modules/transformer.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Transformer base"""
from typing import Any, List, Literal, Optional, Tuple
import torch
from kyuteye.modules.attention import MultiheadAttention
from kyuteye.modules.cross_attention import GatedCrossAttention
from kyuteye.modules.streaming_utils import StreamingModule
from kyuteye.modules.utils import (
NormalizationLayer,
RotaryEmbedding,
create_sin_embedding,
get_activation,
make_ffn,
)
class TransformerLayer(StreamingModule):
"""Base TransformerLayer with causal support.
This also integrates cross_attention, when passing `cross_attention=True`,
rather than having two separate classes like in PyTorch.
:param d_model: Dimension of the data.
:param num_heads: Number of heads.
:param dim_feedforward: Intermediate dimension of FF module.
    :param causal: If True, a causal mask is applied automatically.
    :param context: Receptive field for the causal mask, infinite if None.
:param cross_attention: If True, expect to get secondary input for cross-attention.
Cross attention will use the default MHA, as it typically won't require
special treatment.
:param rope: Optional Rope embedding to use.
:param norm: Normalization to use. Currently, only 'layer_norm' is supported.
:param layer_scale: If not None, LayerScale will be used with the given value as initial scale.
    :param weights_per_step: use different weights per depformer step. If non-zero,
        should correspond to the number of possible time steps.
:param gating: if True, uses SwiGLU like gating in the FFN
:param activation: Activation function to use in the FFN layer
:param device: Device on which to initialize the module.
:param dtype: Dtype to use.
"""
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int | List[int] = 2048,
causal: bool = True,
context: Optional[int] = None,
cross_attention: bool = False,
rope: Optional[RotaryEmbedding] = None,
norm: Literal[
"layer_norm",
"layer_norm_f32",
"rms_norm",
"rms_norm_f32",
"real_rms_norm",
"real_rms_norm_f32",
] = "layer_norm",
weights_per_step: int = 0,
gating: bool = True,
activation: Literal[
"none",
"identity",
"sigmoid",
"tanh",
"relu",
"leaky_relu",
"elu",
"gelu",
"silu",
"mish",
"softsign",
] = "silu",
# Cross attention to image tokens parameters
xa_gating: Literal["sigmoid", "tanh", "none"] = "tanh",
xa_conditional_gating: bool = False,
xa_shared: bool = True,
xa_gate_shared: bool = False,
xa_delay: int = 0,
xa_start: Optional[Literal["start", "boi", "eoi"]] = None,
xa_end: Optional[Literal["end", "eoi"] | int] = None,
xa_step: int = 1,
xa_dim: Optional[int] = None,
# Factory kwargs
device: Optional[str | torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
assert norm.upper() in [x.name for x in NormalizationLayer]
factory_kwargs = {"device": device, "dtype": dtype}
self.causal = causal
self.self_attn: MultiheadAttention = MultiheadAttention(
causal=causal,
context=context,
rope=rope,
weights_per_step=weights_per_step,
embed_dim=d_model,
num_heads=num_heads,
**factory_kwargs, # type: ignore
) # type: ignore
self.norm1 = getattr(NormalizationLayer, norm.upper()).create_norm_fn(
d_model, **factory_kwargs
)
# Cross attention (optional)
self.cross_attention: Optional[torch.nn.Module] = None
if cross_attention:
assert xa_start is not None and xa_end is not None
self.cross_attention = GatedCrossAttention(
xa_gating=xa_gating,
xa_conditional_gating=xa_conditional_gating,
xa_shared=xa_shared,
xa_gate_shared=xa_gate_shared,
xa_delay=xa_delay,
xa_start=xa_start,
xa_end=xa_end,
xa_step=xa_step,
embed_dim=d_model,
xa_dim=xa_dim,
num_heads=num_heads,
**factory_kwargs, # type: ignore
)
self.norm_cross = getattr(NormalizationLayer, norm.upper()).create_norm_fn(
d_model, **factory_kwargs
)
# gating = FFN/MLP
self.activation = get_activation(activation)
self.gating = make_ffn(
d_model,
dim_feedforward,
self.activation,
gating=gating,
weights_per_step=weights_per_step,
**factory_kwargs,
)
self.norm2 = getattr(NormalizationLayer, norm.upper()).create_norm_fn(
d_model, **factory_kwargs
)
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
"""Feed forward block"""
x_orig = x
x = self.norm2(x)
if isinstance(self.gating, torch.nn.ModuleList):
# Inference
ys: List[torch.Tensor] = []
for t in range(x.shape[1]):
y = self.gating[self.streaming_offset + t](x[:, len(ys) : len(ys) + 1])
ys.append(y)
update = torch.cat(ys, dim=1)
else:
# Training: Apply all levels in parallel
update = self.gating(x)
return x_orig + update
def _maybe_cross_attend(
self,
x: torch.Tensor,
cross_attention_src: Optional[Tuple[torch.Tensor, torch.Tensor] | torch.Tensor],
cross_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Cross attention"""
if self.cross_attention is not None and cross_attention_src is not None:
x_orig = x
update, gate_weight = self.cross_attention(
self.norm_cross(x),
cross_attention_src,
None,
cross_attention_mask,
)
return x_orig + update, gate_weight
return x, None
def _self_attend(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Self-attention"""
x_orig = x
update = self.self_attn(
self.norm1(x),
attention_mask=attention_mask,
)
return x_orig + update
def forward(
self,
x: torch.Tensor,
cross_attention_src: Optional[
Tuple[torch.Tensor, torch.Tensor] | torch.Tensor
] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""(Optional) MHSA => (Optional) cross-attention => FFN
:param x: Input query tensor
:param cross_attention_src: Optional tensor used as source for the
keys and values in the cross-attention
:param cross_attention_mask: Mask for the cross-attention
:param attention_mask: Attention mask for the self-attention. Can be
used to mask out tokens that shouldn't be involved in the
self-attention computation
        """
if self.is_streaming and not self.has_streaming_attribute("offset"):
self.streaming_offset = 0
x = self._self_attend(x, attention_mask=attention_mask)
x, gate_weight = self._maybe_cross_attend(
x,
cross_attention_src=cross_attention_src,
cross_attention_mask=cross_attention_mask,
)
x = self._ff_block(x)
# Update streaming offset for the multi linear FFNs in the depformer
if self.is_streaming:
self.streaming_offset += x.shape[1]
return x, gate_weight
class Transformer(StreamingModule):
"""Transformer with Streaming / Causal support.
:param d_model: Dimension of the data.
:param num_heads: Number of heads.
:param num_layers: Number of transformer layers
:param dim_feedforward: Intermediate dimension of FF module.
:param causal: If True, automatically applies a causal mask.
:param context: Size of the receptive field for the causal mask.
If None, assumes infinite context.
:param cross_attention: If True, `forward` will expect to get
secondary input for cross-attention.
    :param xa_layers: If a non-empty tuple, specifies which layers
        to add cross-attention to. If None or an empty tuple
        and cross_attention is True, cross-attention is applied
        in every layer
:param positional_embedding: Positional embedding strategy
(sin, rope, sin_rope, or none).
:param max_period: Maximum period for the sin/cos in RoPE embedding.
:param positional_scale: Scale of positional embedding, set to 0 to deactivate.
    :param device: Device on which to initialize the model.
    :param dtype: Data type to use.
**kwargs: Extra arguments fed to the `TransformerLayer` constructor (e.g. layer_scale)
"""
def __init__(
self,
d_model: int,
num_heads: int,
num_layers: int,
dim_feedforward: int | list[int] = 2048,
causal: bool = False,
context: Optional[int] = None,
cross_attention: bool = False,
xa_layers: Optional[Tuple[int, ...]] = None,
positional_embedding: Literal["none", "sin", "rope", "sin_rope"] = "sin",
max_period: float = 10000,
positional_scale: float = 1.0,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
**kwargs: Any,
) -> None:
super().__init__()
assert d_model % num_heads == 0
self.positional_embedding = positional_embedding
self.max_period = max_period
self.positional_scale = positional_scale
assert positional_embedding in {"sin", "rope", "sin_rope", "none"}
self.rope: Optional[RotaryEmbedding] = None
if self.positional_embedding in {"rope", "sin_rope"}:
self.rope = RotaryEmbedding(max_period=max_period)
self.layers = torch.nn.ModuleList()
for layer_idx in range(num_layers):
cross_attend_layer = (
xa_layers is None or len(xa_layers) == 0 or layer_idx in xa_layers
)
self.layers.append(
TransformerLayer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
causal=causal,
context=context,
cross_attention=cross_attention and cross_attend_layer,
rope=self.rope,
device=device,
dtype=dtype,
**kwargs,
)
)
def set_context(self, context: Optional[int] = None) -> None:
"""Update context size in all MHSA layers"""
for module in self.modules():
if isinstance(module, MultiheadAttention):
module.context = context
def forward(
self, x: torch.Tensor, *args: Any, **kwargs: Any
) -> Tuple[torch.Tensor, float]:
"""Forward pass"""
_, seq_len, channels = x.shape
if self.positional_embedding in {"sin", "sin_rope"}:
positions = torch.arange(seq_len, device=x.device).view(1, -1, 1)
pos_emb = create_sin_embedding(
positions, channels, max_period=self.max_period, dtype=x.dtype
)
x = x + self.positional_scale * pos_emb
alpha = 0.0
for layer_idx, layer in enumerate(self.layers):
x, gate_weight = layer(x, *args, **kwargs)
if gate_weight is not None and layer_idx >= len(self.layers) - 10:
alpha += torch.mean(gate_weight).cpu().item()
return x, alpha / min(10, len(self.layers))
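def _transformer_sketch() -> None:
    """Hypothetical usage sketch: a tiny Transformer without cross-attention.
    Constructor arguments mirror the ones documented above; the shapes and the
    returned mean gate weight (0.0 when no gated cross-attention runs) are the
    only things this sketch checks."""
    model = Transformer(d_model=32, num_heads=4, num_layers=2)
    x = torch.randn(2, 6, 32)  # (batch, time, channels)
    out, mean_gate = model(x)
    assert out.shape == x.shape
    assert mean_gate == 0.0  # no cross-attention layers => no gate weights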
================================================
FILE: kyuteye_pt/kyuteye/modules/utils.py
================================================
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Smaller building blocks:
* FFN layers with Swi-GLU like gating
* Normalization layers
* Input embedding layer
* Positional embeddings
"""
from enum import Enum, unique
from typing import Any, Callable, List, Literal, Optional, Tuple
import torch
def multi_linear(
num_linear: int, weight: torch.Tensor, x: torch.Tensor, offset: int = 0
) -> torch.Tensor:
"""Utility to apply a multi linear layer to the given input. A multi linear layer
applies a different set of weight for each time step.
Args:
num_linear (int): Number of possible time steps and so number of linears.
weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`.
x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
offset (int): offset for the current time step, in particular for decoding, with
time steps provided one by one.
"""
ys = []
# when calling the depformer, x.shape[1] is always 1, and the offset contains the
# codebook index we care about
for t in range(x.shape[1]):
y = torch.nn.functional.linear( # pylint: disable=not-callable
x[:, t], weight.chunk(num_linear)[offset + t]
)
ys.append(y)
out = torch.stack(ys, 1)
return out
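def _multi_linear_sketch() -> None:
    """Worked example (hypothetical helper): with ``num_linear`` stacked
    weight matrices, time step ``offset + t`` of the input is routed through
    chunk ``offset + t`` of the weight tensor."""
    num_linear, chin, chout = 3, 4, 5
    weight = torch.randn(num_linear * chout, chin)
    x = torch.randn(2, 1, chin)  # a single decoding step
    out = multi_linear(num_linear, weight, x, offset=2)
    expected = torch.nn.functional.linear(x[:, 0], weight.chunk(num_linear)[2])
    assert torch.equal(out[:, 0], expected)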
def get_activation(
name: Literal[
"sigmoid",
"tanh",
"relu",
"leaky_relu",
"elu",
"gelu",
"silu",
"mish",
"softsign",
"identity",
"none",
],
) -> Callable[[torch.Tensor], torch.Tensor]:
"""Return correct activation function from the given name"""
if name in {"sigmoid", "tanh", "relu"}:
return getattr(torch, name)
if name in {"leaky_relu", "elu", "gelu", "silu", "mish", "softsign"}:
return getattr(torch.nn.functional, name)
if name in {"identity", "none"}:
return torch.nn.Identity()
raise NotImplementedError(f"Unknown activation {name}")
def gating_forward_kernel(
weight_in: torch.Tensor,
weight_out: torch.Tensor,
activation: Callable[[torch.Tensor], torch.Tensor],
x: torch.Tensor,
) -> torch.Tensor:
"""Simple multiplicative gating strategy (SwiGLU like)"""
x = torch.nn.functional.linear(x, weight_in) # pylint: disable=not-callable
batch_size, seq_len, _ = x.shape
x = x.view(batch_size, seq_len, 2, -1)
x = activation(x[..., 0, :]) * x[..., 1, :]
x = torch.nn.functional.linear(x, weight_out) # pylint: disable=not-callable
return x
class ActivationGating(torch.nn.Module):
"""
FFN layer with multiplicative gating using the given activation.
:param dim: Dimensions of the tokens.
:param activation: Activation function to use.
:param factory_kwargs: Other kwargs passed to the linear layer, in particular device and dtype.
"""
def __init__(
self,
dim: int,
dim_feedforward: int,
activation: Callable[[torch.Tensor], torch.Tensor],
**factory_kwargs: Any,
):
super().__init__()
        # A plain 4*d MLP has 8 d^2 params; the gated FFN has
        # 2 * h * d + h * d = 3 * h * d params, so matching the budget gives
        # h = 8 * d / 3, but following Hervé's advice we use 21 / 8 as an approx.
if dim_feedforward == 4 * dim:
hidden = (21 * dim) // 8
else:
hidden = (2 * dim_feedforward) // 3
self.linear_in = torch.nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
self.linear_out = torch.nn.Linear(hidden, dim, bias=False, **factory_kwargs)
self.activation = activation
max_params = 2 * dim * dim_feedforward
params = sum(p.numel() for p in self.parameters())
assert params <= max_params, f"Gating has {params} params, max is {max_params}"
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""FFN with Swi-GLU like gating but customizable activation"""
return gating_forward_kernel(
self.linear_in.weight, self.linear_out.weight, self.activation, x
)
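def _activation_gating_param_sketch() -> None:
    """Worked arithmetic check (hypothetical helper) for the comment above:
    with the conventional dim_feedforward == 4 * dim, the hidden size
    (21 * dim) // 8 keeps the gated FFN within the 2 * dim * dim_feedforward
    = 8 * dim^2 parameter budget of a plain 2-layer MLP."""
    dim = 64
    ffn = ActivationGating(dim, 4 * dim, torch.nn.functional.silu)
    hidden = (21 * dim) // 8  # = 168 for dim = 64
    params = sum(p.numel() for p in ffn.parameters())
    assert params == 2 * hidden * dim + hidden * dim  # linear_in + linear_out
    assert params <= 8 * dim * dim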
class NoGating(torch.nn.Module):
"""
Simple 2 layer MLP FFN layer
"""
def __init__(
self,
dim: int,
dim_feedforward: int,
activation: Callable[[torch.Tensor], torch.Tensor],
**factory_kwargs: Any,
):
super().__init__()
self.linear1 = torch.nn.Linear(
dim, dim_feedforward, bias=False, **factory_kwargs
)
self.linear2 = torch.nn.Linear(
dim_feedforward, dim, bias=False, **factory_kwargs
)
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Simple two layers MLP FFN"""
return self.linear2(self.activation(self.linear1(x)))
def make_ffn(
dim: int,
dim_feedforward: int | List[int],
activation_fn: Callable[[torch.Tensor], torch.Tensor],
gating: bool = False,
weights_per_step: int = 0,
**factory_kwargs: Any,
) -> torch.nn.Module:
"""Create a FNN module
:param dim: Number of input dimensions
:param dim_feedforward: Nubmer of inner dimensions
:param activation_fn: Activation function
:param gating: If True, uses FFN with multiplicative gating
:param factory_kwargs: Any extra argument fed to the Linear layer
constructors (e.g. device, dtype)
"""
ffn: torch.nn.Module
if gating:
if weights_per_step > 0:
if isinstance(dim_feedforward, int):
dim_feedforward = [dim_feedforward] * weights_per_step
assert isinstance(dim_feedforward, list), dim_feedforward
ffn = torch.nn.ModuleList(
[
ActivationGating(dim, dim_out, activation_fn, **factory_kwargs)
for dim_out in dim_feedforward
]
)
else:
assert isinstance(dim_feedforward, int)
ffn = ActivationGating(
dim, dim_feedforward, activation_fn, **factory_kwargs
)
else:
assert isinstance(dim_feedforward, int)
assert (
weights_per_step == 0
), f"weights per step {weights_per_step} > 0 is not supported without gated FFN"
ffn = NoGating(dim, dim_feedforward, activation_fn, **factory_kwargs)
return ffn
class LayerNormF32(torch.nn.LayerNorm):
"""Layer norm executed in Float32 for maximal precision"""
def forward(
self, input: torch.Tensor # pylint: disable=redefined-builtin
) -> torch.Tensor:
"""Applies the layer norm"""
x_f32 = input.float()
out_f32 = super().forward(x_f32)
return out_f32.to(input.dtype)
def _rms_norm(
x: torch.Tensor,
alpha: torch.Tensor,
dtype: Optional[torch.dtype],
eps: float,
use_var: bool,
return_factor: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 3, f"RMSNorm expects 3D inputs but got {x.shape}"
in_dtype = x.dtype
if dtype is not None:
x = x.to(dtype)
if use_var:
var = eps + x.var(dim=2, keepdim=True)
else:
var = eps + torch.mean(x**2, dim=2, keepdim=True)
if return_factor:
factor = alpha.to(var) * torch.rsqrt(var)
return (x * factor).to(in_dtype), factor.to(in_dtype)
return (x * (alpha.to(var) * torch.rsqrt(var))).to(in_dtype)
class RMSNorm(torch.nn.Module):
"""RMSNorm layer
:param dim: Input channels dimension
:param eps: Epsilon
"""
def __init__(
self,
dim: int,
eps: float = 1e-5,
use_var: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__()
self.eps = eps
self.dtype = dtype
self.use_var = use_var
self.alpha = torch.nn.Parameter(
torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""RMS Norm call"""
out = _rms_norm(x, self.alpha, self.dtype, self.eps, self.use_var)
return out # type: ignore[return-value]
@unique
class NormalizationLayer(Enum):
"""Select one of several normalization layers"""
LAYER_NORM = 0
LAYER_NORM_F32 = 1
RMS_NORM = 2
RMS_NORM_F32 = 3
REAL_RMS_NORM = 4
REAL_RMS_NORM_F32 = 5
def create_norm_fn(self, dim: int, **kwargs: Any) -> torch.nn.Module:
"""Return the proper normalization layer initializer"""
# Layer Norm
if self == NormalizationLayer.LAYER_NORM:
return torch.nn.LayerNorm(dim, eps=1e-5, **kwargs)
if self == NormalizationLayer.LAYER_NORM_F32:
return LayerNormF32(
dim, eps=1e-8, **{k: v for k, v in kwargs.items() if k != "dtype"}
)
# Real RMS Norm using |x**2| normalization
if self == NormalizationLayer.REAL_RMS_NORM:
return RMSNorm(dim, eps=1e-5, use_var=False, **kwargs)
if self == NormalizationLayer.REAL_RMS_NORM_F32:
return RMSNorm(
dim,
eps=1e-8,
dtype=torch.float32,
use_var=False,
**{k: v for k, v in kwargs.items() if k != "dtype"},
)
# RMS Norm using variance of the data
if self == NormalizationLayer.RMS_NORM:
return RMSNorm(dim, eps=1e-5, **kwargs)
if self == NormalizationLayer.RMS_NORM_F32:
return RMSNorm(
dim,
eps=1e-8,
dtype=torch.float32,
**{k: v for k, v in kwargs.items() if k != "dtype"},
)
raise NotImplementedError(f"Unknown norm type: {self.name}")
class ClampedEmbedding(torch.nn.Embedding):
"""An embedding layer such that all input IDs of the ID `zero_idx < 0`
are mapped to zero at the output of the module
Args:
lr (float or None): Learning rate for the embedding layer if provided.
norm (bool): if True, uses a layer norm after the embedding.
zero_idx (int): special value indicating that the output should be exactly 0.
"""
def __init__(
self, *args: Any, norm: bool = False, zero_idx: int = -1, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.norm = None
if norm:
self.norm = NormalizationLayer.LAYER_NORM.create_norm_fn(self.embedding_dim)
assert zero_idx < 0, "Please use negative values for the zero_idx."
self.zero_idx = zero_idx
def forward( # pylint: disable=arguments-renamed
self, inputs: torch.Tensor
) -> torch.Tensor:
"""Embed the input IDs"""
is_zero = inputs == self.zero_idx
zero = torch.zeros(1, dtype=inputs.dtype, device=inputs.device)
y = super().forward(inputs.clamp(min=0))
if self.norm is not None:
y = self.norm(y)
y = torch.where(is_zero[..., None], zero, y)
return y
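def _clamped_embedding_sketch() -> None:
    """Illustrative sketch (hypothetical helper): IDs equal to ``zero_idx``
    (-1 by default) produce an exact zero vector instead of indexing the
    embedding table."""
    embed = ClampedEmbedding(10, 4)
    out = embed(torch.tensor([[0, -1, 3]]))
    assert out.shape == (1, 3, 4)
    assert torch.all(out[0, 1] == 0)  # the -1 entry is zeroed out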
def create_sin_embedding(
positions: torch.Tensor,
dim: int,
max_period: float = 10000,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Create fixed sinusoidal positional embedding, with shape `[B, T, C]`.
Args:
positions (torch.Tensor): LongTensor of positions.
dim (int): Dimension of the embedding.
max_period (float): Maximum period of the cosine/sine functions.
dtype (torch.dtype or str): dtype to use to generate the embedding.
Returns:
torch.Tensor: Sinusoidal positional embedding.
"""
# Assumes BTC format
assert dim % 2 == 0
half_dim = dim // 2
positions = positions.to(dtype)
adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
max_period_tensor = torch.full(
[], max_period, device=positions.device, dtype=dtype
) # avoid sync point
phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
def apply_rope(
q: torch.Tensor,
k: torch.Tensor,
max_period: float = 10_000,
offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    Apply RoPE embedding to the input queries/keys
    :param q: queries, shape `[B, T, H, D]`.
    :param k: keys, shape `[B, T, H, D]`.
    :param max_period: maximum period for the cos and sin (aka. `theta_rope`).
    :param offset: offset added to the temporal positions (used for streaming decoding).
    """
batch, seq_length, num_heads, dim = q.shape
assert k.shape == q.shape
ds = torch.arange(dim // 2, device=q.device, dtype=torch.float32)
max_period_t = torch.full([1], max_period, device=q.device, dtype=torch.float32)
freqs = 1.0 / (max_period_t ** (2 * ds / dim))
ts = torch.arange(
offset, seq_length + offset, device=q.device, dtype=torch.float32
).view(-1, 1, 1)
q = q.view(batch, seq_length, num_heads, dim // 2, 2)
k = k.view(batch, seq_length, num_heads, dim // 2, 2)
# convention is `r` suffix is real part, `i` is imaginary.
qr = q[..., 0].float()
qi = q[..., 1].float()
kr = k[..., 0].float()
ki = k[..., 1].float()
rotr = torch.cos(freqs * ts)
roti = torch.sin(freqs * ts)
qor = qr * rotr - qi * roti
qoi = qr * roti + qi * rotr
kor = kr * rotr - ki * roti
koi = kr * roti + ki * rotr
dtype = q.dtype
qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
return qo.view(batch, seq_length, num_heads, dim), ko.view(
batch, seq_length, num_heads, dim
)
class RotaryEmbedding(torch.nn.Module):
"""Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
:param max_period: Maximum period of the rotation frequencies (aka `theta_rope`).
"""
def __init__(self, max_period: float = 10000.0) -> None:
super().__init__()
self.max_period = max_period
def forward(
self, q: torch.Tensor, k: torch.Tensor, offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply rope rotation to query or key tensor with an `offset` on temporal positions."""
return apply_rope(q, k, max_period=self.max_period, offset=offset)
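def _rope_sketch() -> None:
    """Illustrative sketch (hypothetical helper): RoPE is a pure rotation of
    (real, imaginary) feature pairs, so it preserves the norm of queries and
    keys, and position 0 (with offset 0) is left untouched."""
    rope = RotaryEmbedding()
    q = torch.randn(1, 4, 2, 8)  # (batch, time, heads, dim)
    k = torch.randn(1, 4, 2, 8)
    qo, ko = rope(q, k)
    assert ko.shape == k.shape
    assert torch.allclose(qo.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
    assert torch.allclose(qo[:, 0], q[:, 0])  # cos(0) = 1, sin(0) = 0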
================================================
FILE: kyuteye_pt/kyuteye/server.py
================================================
# pylint: disable=protected-access,no-member
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Start Pytorch backend"""
import asyncio
import os
import random
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple
import aiohttp
import fire
import numpy as np
import sentencepiece
import sphn
import torch
from aiohttp import web
from huggingface_hub import hf_hub_download
from kyuteye.config.enums import ImageEncoder
from kyuteye.config.kyuteye_config import KyuteyeConfig
from kyuteye.models.loaders import get_moshi_vis
from kyuteye.modules.image_transforms import get_minimal_transforms
from moshi.models.loaders import get_mimi
from torchvision.io import ImageReadMode, decode_image
if TYPE_CHECKING:
from kyuteye.models.image_projection import ImageProjection
from kyuteye.models.moshivis import MoshiVisGen
from moshi.models import MimiModel
def colorize(text: str, color: str) -> str:
"""Add colors to log"""
code = f"\033[{color}m"
restore = "\033[0m"
return "".join([code, text, restore])
def make_log(level: str, msg: str) -> str:
"""Create log"""
if level == "warning":
prefix = colorize("[Warn]", "1;31")
elif level == "info":
prefix = colorize("[Info]", "1;34")
elif level == "error":
prefix = colorize("[Err ]", "1;31")
else:
raise ValueError(f"Unknown level {level}")
return prefix + " " + msg
def log(level: str, msg: str) -> None:
"""Log with colors"""
print(make_log(level, msg))
def seed_all(seed: int) -> None:
"""Seed"""
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # for multi-GPU setups
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
@dataclass
class ServerState:
"""Main server state"""
mimi: "MimiModel"
text_tokenizer: sentencepiece.SentencePieceProcessor
moshi_vis: "MoshiVisGen"
image_encoder_model: "ImageProjection"
image_size: int
xa_start: int
lock: asyncio.Lock
dtype: torch.dtype
display_gating: bool
def __init__(
self,
mimi: "MimiModel",
text_tokenizer: sentencepiece.SentencePieceProcessor,
moshi_vis: "MoshiVisGen",
image_encoder_model: "ImageProjection",
device: str | torch.device,
dtype: torch.dtype = torch.bfloat16,
max_msg_size: int = 0,
image_size: int = 448,
xa_start: int = 0,
):
self.mimi = mimi
self.text_tokenizer = text_tokenizer
self.moshi_vis = moshi_vis
self.image_encoder_model = image_encoder_model
self.embeddings: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | None = None
self.max_msg_size = max_msg_size
self.image_size = image_size
self.xa_start = xa_start
self.display_gating = True
self.device = device
self.dtype = dtype
self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
self.lock = asyncio.Lock()
self.mimi.streaming_forever(1)
self.moshi_vis.streaming_forever(1)
def warmup(self) -> None:
"""Warmup the models"""
for _ in range(4):
chunk = torch.zeros(
1, 1, self.frame_size, dtype=torch.float32, device=self.device
)
codes = self.mimi.encode(chunk)
ca_src = self.image_encoder_model(
torch.zeros(1, 3, 224, 224, device=self.device)
)["cross_attention_src"]
for c in range(codes.shape[-1]):
tokens, _ = self.moshi_vis.step(codes[:, :, c : c + 1], ca_src=ca_src)
if tokens is None:
continue
_ = self.mimi.decode(tokens[:, 1:])
torch.cuda.synchronize()
self.mimi.reset_streaming()
self.moshi_vis.reset_streaming()
async def handle_chat(self, request: Any) -> Any:
"""start conversation"""
ws = web.WebSocketResponse(max_msg_size=self.max_msg_size)
await ws.prepare(request)
close = False
async def recv_loop() -> None:
nonlocal close
try:
async for message in ws:
if message.type == aiohttp.WSMsgType.ERROR:
log("error", f"{ws.exception()}")
break
if message.type == aiohttp.WSMsgType.CLOSED:
log("info", "closed received")
break
if message.type != aiohttp.WSMsgType.BINARY:
log("error", f"unexpected message type {message.type}")
continue
message = message.data
if not isinstance(message, bytes):
log("error", f"unsupported message type {type(message)}")
continue
if len(message) == 0:
log("warning", "empty message")
continue
kind = message[0]
if kind == 1: # audio
payload = message[1:]
opus_reader.append_bytes(payload)
elif kind == 10:
log("info", f"received user rating {message[1]}")
else:
log("warning", f"unknown message kind {kind}")
except Exception as e:
print("Exception raised:", e)
finally:
close = True
log("info", "connection closed")
async def opus_loop() -> None:
all_pcm_data = None
while True:
if close:
return
await asyncio.sleep(0.001)
pcm = opus_reader.read_pcm()
if pcm.shape[-1] == 0:
continue
if all_pcm_data is None:
all_pcm_data = pcm
else:
all_pcm_data = np.concatenate((all_pcm_data, pcm))
while all_pcm_data.shape[-1] >= self.frame_size:
be = time.time()
chunk = all_pcm_data[: self.frame_size]
all_pcm_data = all_pcm_data[self.frame_size :]
chunk = torch.from_numpy(chunk)
chunk = chunk.to(device=self.device)[None, None]
codes = self.mimi.encode(chunk)
for c in range(codes.shape[-1]):
tokens, gate_weight = self.moshi_vis.step(
codes[:, :, c : c + 1],
ca_src=(
self.embeddings
if self.moshi_vis.get_streaming_attribute("offset", 0)
>= self.xa_start
else None
),
)
if tokens is None:
continue
assert (
tokens.shape[1]
== self.moshi_vis.num_audio_codebooks_out + 1
)
main_pcm = self.mimi.decode(tokens[:, 1:])
main_pcm = main_pcm.cpu()
opus_writer.append_pcm(main_pcm[0, 0].numpy())
text_token = tokens[0, 0, 0].item()
if text_token not in (0, 3):
_text = self.text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
text_color = round(
max(min((gate_weight - 0.005) / 0.016, 1.0), 0.0) * 10
)
msg = (
b"\x07"
+ text_color.to_bytes(1, "big")
+ bytes(_text, encoding="utf8")
)
log("info", f"text token '{_text}'")
await ws.send_bytes(msg)
log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms")
async def send_loop() -> None:
while True:
if close:
return
await asyncio.sleep(0.001)
msg = opus_writer.read_bytes()
if len(msg) > 0:
await ws.send_bytes(b"\x01" + msg)
log("info", "accepted connection")
close = False
query_parameters = request.rel_url.query
self.moshi_vis.update_gen_kwargs(
temp_text=(
float(aux)
if (aux := query_parameters.get("text_temperature", None)) is not None
else None
),
temp=(
float(aux)
if (aux := query_parameters.get("audio_temperature", None)) is not None
else None
),
top_k_text=(
int(aux)
if (aux := query_parameters.get("text_topk", None)) is not None
else None
),
top_k=(
int(aux)
if (aux := query_parameters.get("audio_topk", None)) is not None
else None
),
)
self.image_size = (
int(aux)
if (aux := query_parameters.get("image_resolution", None)) is not None
else self.image_size
)
if (aux := query_parameters.get("xa_start", None)) is not None:
self.xa_start = int(aux)
async with self.lock:
opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate) # type: ignore
opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate) # type: ignore
self.mimi.reset_streaming()
self.moshi_vis.reset_streaming()
await self.extract_image(ws)
# Send the handshake.
await ws.send_bytes(b"\x00")
await asyncio.gather(opus_loop(), recv_loop(), send_loop())
log("info", "done with connection")
return ws
async def extract_image(self, ws: web.WebSocketResponse) -> None:
"""Embed imageat the beginning of the stream"""
first_message = await ws.receive()
first_message = first_message.data
try:
kind = first_message[0]
except Exception as e:
raise RuntimeError(f"Error in message: {first_message}") from e
if kind != 8: # image
raise RuntimeError(f"unknown message kind {kind}")
payload = first_message[1:]
image_tensor = decode_image(
torch.frombuffer(payload, dtype=torch.uint8), mode=ImageReadMode.RGB
)
image_tensor = get_minimal_transforms(self.image_size)(image_tensor)
image_tensor = self.image_encoder_model.to_tensor_and_normalize(image_tensor)
log("info", f"Loaded image tensor with shape {image_tensor.shape}")
if self.image_encoder_model.encoder_type == ImageEncoder.PIXTRAL:
image_tensor = [image_tensor.to(self.device)]
else:
image_tensor = image_tensor[None, ...].to(self.device)
k, v = self.moshi_vis.precompte_ca_kv(
self.image_encoder_model(image_tensor)["cross_attention_src"]
)
self.embeddings = (k.to(self.dtype), v.to(self.dtype))
def start_server(
kyuteye_config_path: str,
host: str = "localhost",
port: int = 8998,
static: Optional[str] = None,
device: str = "cuda",
dtype: Literal["float32", "bfloat16"] = "bfloat16",
ssl: bool = True,
ssl_cert_dir: Optional[str] = None,
) -> None:
"""Start server
:param kyuteye_config: Config of the model to load
:param host: Host to start the server on
:param port: Port to start the server on
:param static: Path to the built client source. If None, defaults to the local client
:param device: Device of execution
:param dtype: Dtype of execution
:param ssl: Whether to launch on https or http protocol
:param max_img_size: Max image size (in MB) that can be
sent via aiohttp; If 0, no limit is set. Note that input images
are resized in any case before being sent to the encoder
"""
assert kyuteye_config_path is not None
root_dir = Path(__file__).parents[2]
static_path: None | str = None
if static is None:
static_path = str(root_dir / "client" / "dist")
else:
static_path = static
static_path = os.path.abspath(static_path)
assert static_path is not None and os.path.exists(static_path)
seed_all(42)
setup_tunnel = None
tunnel_token = ""
kyuteye_config = KyuteyeConfig.from_yml(kyuteye_config_path)
# Load main model components
log("info", "loading mimi")
if kyuteye_config.hf_repo is None:
assert os.path.exists(kyuteye_config.mimi_codec)
mimi_weight = kyuteye_config.mimi_codec
else:
mimi_weight = hf_hub_download(kyuteye_config.hf_repo, kyuteye_config.mimi_codec)
mimi = get_mimi(mimi_weight, device)
log("info", "mimi loaded")
if kyuteye_config.hf_repo is None:
assert os.path.exists(kyuteye_config.text_tokenizer)
text_tokenizer_path = kyuteye_config.text_tokenizer
else:
text_tokenizer_path = hf_hub_download(
kyuteye_config.hf_repo, kyuteye_config.text_tokenizer
)
text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer_path) # type: ignore
log("info", "loading moshi + vision")
if kyuteye_config.hf_repo is None:
assert os.path.exists(kyuteye_config.model)
moshi_weight = kyuteye_config.model
if not moshi_weight.endswith("_pt.safetensors"):
assert moshi_weight.endswith(".safetensors")
moshi_weight = moshi_weight.replace(".safetensors", "_pt.safetensors")
print(f"Will load from {moshi_weight}")
else:
moshi_weight = hf_hub_download(kyuteye_config.hf_repo, kyuteye_config.model)
torch_dtype = getattr(torch, dtype)
moshi_vis, image_embedder = get_moshi_vis(
kyuteye_config, moshi_weight, device, torch_dtype
)
log("info", "moshi + vision loaded")
state = ServerState(
mimi=mimi,
text_tokenizer=text_tokenizer,
moshi_vis=moshi_vis,
image_encoder_model=image_embedder,
device=device,
dtype=torch_dtype,
xa_start=kyuteye_config.xa_start,
)
log("info", "warming up the model")
state.warmup()
app = web.Application()
app.router.add_get("/api/chat", state.handle_chat)
async def handle_root(_): # type: ignore
return web.FileResponse(os.path.join(static_path, "index.html"))
log("info", f"serving static content from {static_path}")
app.router.add_get("/", handle_root)
app.router.add_static("/", path=static_path, follow_symlinks=True, name="static")
protocol = "http"
ssl_context = None
if ssl:
import ssl as ssl_module
ssl_context = ssl_module.create_default_context(ssl_module.Purpose.CLIENT_AUTH)
ssl_cert_dir = ssl_cert_dir or str(root_dir)
ssl_context.load_cert_chain(
certfile=os.path.join(ssl_cert_dir, "cert.pem"),
keyfile=os.path.join(ssl_cert_dir, "key.pem"),
)
protocol = "https"
log("info", f"Access the Web UI directly at {protocol}://{host}:{port}")
if setup_tunnel is not None:
tunnel = setup_tunnel("localhost", port, tunnel_token, None)
log(
"info",
f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.",
)
log(
"info",
"Note that this tunnel goes through the US and you"
" might experience high latency in Europe.",
)
with torch.no_grad():
web.run_app(app, port=port, ssl_context=ssl_context)
def sanity_check() -> None:
    """No-op sanity check"""
def main() -> None:
"""main"""
fire.Fire(start_server)
================================================
FILE: kyuteye_pt/kyuteye/utils/__init__.py
================================================
================================================
FILE: kyuteye_pt/kyuteye/utils/dist_utils.py
================================================
"""Some utils for distributed training"""
import os
from typing import Any
import torch.distributed as dist
from rich import print as rich_print
def is_main() -> bool:
"""Returns True iff the current process is the main one"""
# torch distributed
if dist.is_initialized():
return dist.get_rank() == 0
# procid
if "LOCAL_RANK" in os.environ:
return int(os.environ.get("LOCAL_RANK", 0)) == 0
return int(os.environ.get("SLURM_PROCID", 0)) == 0
def print_main(*args: Any, rich: bool = False, **kwargs: Any) -> None:
"""Print function that only activate for the main process"""
if is_main():
if rich:
rich_print(*args, **kwargs)
else:
print(*args, **kwargs)
================================================
FILE: kyuteye_pt/kyuteye/utils/logging_utils.py
================================================
"""Some utils for experiment tracking and logging"""
import json
import subprocess
from typing import Dict, Tuple
import rich
def flatten_nested_dict(d: Dict) -> Dict:
"""Flatten a nested config dictionary"""
flattened_dict = {}
for k, v in d.items():
if isinstance(v, dict):
flattened_dict.update(v)
else:
flattened_dict[k] = v
return flattened_dict
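# Example (illustrative): one level of nesting is merged into the top level,
# and the parent key itself is dropped:
#   flatten_nested_dict({"optim": {"lr": 0.1}, "seed": 42})
#   -> {"lr": 0.1, "seed": 42}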
def get_git_revision_hash(verbose: bool = True) -> Tuple[str, str]:
"""Return current git branch and commit"""
git_branch = (
subprocess.check_output(["git", "branch", "--show-current"])
.decode("ascii")
.strip()
)
commit_hash = (
subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
)
if verbose:
rich.print(
f"[magenta]Git:[/] Commit [bold]{commit_hash}[/] on branch {git_branch}"
)
return git_branch, commit_hash
def pretty_json(config_dict: dict) -> str:
"""Pretty print the given dict as json"""
config_dict = {
k: (
v.name
if hasattr(v, "name")
else v if isinstance(v, (str, int, float)) else str(v)
)
for k, v in config_dict.items()
}
json_config_dict = json.dumps(config_dict, indent=4)
return "".join("\t" + line for line in json_config_dict.splitlines(True))
================================================
FILE: kyuteye_pt/kyuteye/utils/struct_utils.py
================================================
"""Useful structure and simple class definition
FrozenEnum are used to hold global configs shared across multiple files"""
from enum import Enum, EnumMeta
from typing import Any
class FrozenEnumMeta(EnumMeta):
"Enum metaclass that freezes an enum entirely"
def __new__(mcs, name: str, bases: Any, classdict: Any) -> type:
classdict["__frozenenummeta_creating_class__"] = True
enum = super().__new__(mcs, name, bases, classdict)
del enum.__frozenenummeta_creating_class__ # type: ignore[attr-defined]
return enum
def __setattr__(cls, name: str, value: Any) -> None:
members = cls.__dict__.get("_member_map_", {})
if hasattr(cls, "__frozenenummeta_creating_class__") or name in members:
return super().__setattr__(name, value)
if hasattr(cls, name):
msg = "{!r} object attribute {!r} is read-only"
else:
msg = "{!r} object has no attribute {!r}"
raise AttributeError(msg.format(cls.__name__, name))
def __delattr__(cls, name: str) -> None:
members = cls.__dict__.get("_member_map_", {})
if hasattr(cls, "__frozenenummeta_creating_class__") or name in members:
return super().__delattr__(name)
if hasattr(cls, name):
msg = "{!r} object attribute {!r} is read-only"
else:
msg = "{!r} object has no attribute {!r}"
raise AttributeError(msg.format(cls.__name__, name))
class FrozenEnum(Enum, metaclass=FrozenEnumMeta):
"""Frozen Enum type used for immutable configurations"""
pass
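# Usage sketch with a hypothetical subclass: FrozenEnum behaves like a regular
# Enum, but any attempt to rebind or delete an attribute after class creation
# raises AttributeError.
#
#   class Precision(FrozenEnum):
#       BF16 = "bf16"
#       FP32 = "fp32"
#
#   Precision.BF16 = "fp16"  # raises AttributeError
#   del Precision.FP32       # raises AttributeError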
================================================
FILE: kyuteye_pt/pyproject.toml
================================================
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "kyuteye"
version = "0.0.0"
description = "Kyutai with an 'eye'"
authors = [
{ name = "Amelie Royer", email = "amelie@kyutai.org" },
{ name = "Moritz Boehle", email = "moritz@kyutai.org" }
]
maintainers = [{ name = "Amelie Royer", email = "amelie@kyutai.org" }]
keywords = []
readme = "README.md"
license = { text = "MIT" }
requires-python = ">=3.10"
dependencies = [
"einops",
"huggingface_hub",
"moshi==0.1.0",
"numpy<2",
"sentencepiece",
"torch==2.2.0",
"torchvision==0.17.0",
"tqdm",
"transformers==4.47.0",
"triton",
"fire",
"rich",
"pyyaml",
"black",
"setuptools",
"sphn >= 0.1.4",
]
[project.scripts]
server = "kyuteye.server:main"
sanity-check = "kyuteye.server:sanity_check"
[tool.mypy]
python_version = "3.10"
[[tool.mypy.overrides]]
module = ["fire", "yaml", "setuptools", "transformers.*",
"torchvision.*", "timm.*", "sentencepiece", "moshi.*",
"huggingface_hub", "sphn"]
ignore_missing_imports = true
[tool.setuptools.packages.find]
where = ["."]
[dependency-groups]
dev = [
"mypy==1.11.2",
"pylint>=3.3.4",
]
================================================
FILE: kyuteye_pt/tests/hello.py
================================================
from transformers import AutoProcessor, AutoModelForImageTextToText
import numpy as np
import torch
from moshi.models import loaders, LMGen
import random
from pathlib import Path
from huggingface_hub import hf_hub_download
def write_weights_for_analysis(model: torch.nn.Module):
    file_content = ""
    if isinstance(model, torch.nn.Module):
        framework_name = "torch"
        params = model.state_dict().items()
    else:
        print("Unsupported model type")
        return
    for i, (key, value) in enumerate(params):
        file_content += f"{i} {key} {value.shape}\n"
    dest = Path(f"/tmp/weights_{random.randint(0, 2**16)}_{framework_name}.txt")
    dest.write_text(file_content)
    print(f"Wrote layers description into {dest}")
@torch.no_grad()
def test_weights_conversion_moshi():
moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight)
lm_gen_torch = LMGen(moshi, temp=0.8, temp_text=0.7)
write_weights_for_analysis(lm_gen_torch)
if __name__ == "__main__":
test_weights_conversion_moshi()
================================================
FILE: kyuteye_rs/Cargo.toml
================================================
[workspace]
members = [
"moshi-core",
"moshi-backend",
]
resolver = "2"
[workspace.dependencies]
anyhow = "1"
axum = { version = "0.8.1", features = ["ws"] }
axum-server = { version = "0.7.1", features = ["tls-rustls"] }
base64 = "0.21.7"
bincode = "1.3.3"
byteorder = "1.5.0"
candle = { version = "0.8.3", package = "candle-core" }
candle-flash-attn = "0.8.3"
candle-nn = "0.8.3"
candle-transformers = "0.8.3"
clap = { version = "4.4.12", features = ["derive"] }
color-eyre = "0.6.2"
console_error_panic_hook = "0.1.7"
cpal = "0.15.3"
crossterm = { version = "0.27.0", features = ["event-stream"] }
cudarc = { version = "=0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
futures = "0.3.28"
futures-util = "0.3.30"
hf-hub = { version = "0.3.2", features = ["tokio"] }
http = "1.1.0"
hyper = "1.3.1"
image = "0.25.2"
js-sys = "0.3.66"
lazy_static = "1.5.0"
log = "0.4.20"
moshi = { path = "./moshi-core" }
native-tls = "0.2.11"
numpy = "0.23.0"
ogg = { version = "0.9.1", features = ["async"] }
opus = "0.3.0"
prometheus = "0.13.4"
prost = "0.12"
pyo3 = "0.23.0"
rand = { version = "0.8.5", features = ["getrandom"] }
rand_chacha = "0.3.1"
ratatui = "=0.26.0"
rayon = "1.8.1"
rcgen = "0.13.1"
regex = "1.10.3"
reqwest = { version = "0.12", features = ["stream", "json"] }
rubato = "0.15.0"
rustls = { version = "0.23.20", features = ["ring"] }
sentencepiece = "0.11.2"
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.115"
symphonia = { version = "0.5.3", features = ["all"] }
timens = "0.1.9"
tokio = { version = "1.35.1", features = ["full"] }
tokio-rustls = "0.24.1"
tokio-stream = "0.1"
tokio-tungstenite = { version = "0.21.0", features = ["rustls", "native-tls"] }
tonic = "0.11"
tonic-build = "0.11"
tower = "0.4.13"
tower-http = { version = "0.5", features = ["full"] }
tracing = "0.1.40"
tracing-appender = "0.2.3"
tracing-subscriber = "0.3.18"
tui-logger = "=0.11.1"
vergen = { version = "8.3.1", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] }
webrtc = "0.10.1"
================================================
FILE: kyuteye_rs/configs/config-moshika-vis-q8.json
================================================
{
"instance_name": "foo",
"hf_repo": "kyutai/moshika-vis-candle-q8",
"lm_model_file": "$HOME/tmp/model.q8_0.gguf",
"text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
"mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
"log_dir": "$HOME/tmp/moshi-logs",
"mimi_num_codebooks": 8,
"static_dir": "../client/dist",
"addr": "0.0.0.0",
"port": 8008,
"cert_dir": "../",
"lm_config": {
"acoustic_delay": 1,
"generated_audio_codebooks": 8,
"input_audio_codebooks": 8,
"audio_vocab_size": 2049,
"text_eop_token": 0,
"text_pad_token": 3,
"text_start_token": 32000
},
"image_prefix_backbone": "Siglip448",
"image_prefix_use_rms_norm": true,
"cross_attention_gating": "ConditionalGatedSigmoid",
"cross_attention_in_dims": null
}
================================================
FILE: kyuteye_rs/configs/config-moshika-vis.json
================================================
{
"instance_name": "foo",
"hf_repo": "kyutai/moshika-vis-candle-bf16",
"lm_model_file": "$HOME/tmp/model.safetensors",
"text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
"mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
"log_dir": "$HOME/tmp/moshi-logs",
"mimi_num_codebooks": 8,
"static_dir": "../client/dist",
"addr": "0.0.0.0",
"port": 8008,
"cert_dir": "../",
"lm_config": {
"acoustic_delay": 1,
"generated_audio_codebooks": 8,
"input_audio_codebooks": 8,
"audio_vocab_size": 2049,
"text_eop_token": 0,
"text_pad_token": 3,
"text_start_token": 32000
},
"image_prefix_backbone": "Siglip448",
"image_prefix_use_rms_norm": true,
"cross_attention_gating": "ConditionalGatedSigmoid",
"cross_attention_in_dims": null
}
================================================
FILE: kyuteye_rs/moshi-backend/Cargo.toml
================================================
[package]
name = "moshi-backend"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
axum-server = { workspace = true }
bincode = { workspace = true }
byteorder = { workspace = true }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
clap = { workspace = true }
futures-util = { workspace = true }
hf-hub = { workspace = true }
http = { workspace = true }
image = { workspace = true }
lazy_static = { workspace = true }
log = { workspace = true }
moshi = { workspace = true }
ogg = { workspace = true }
opus = { workspace = true }
prometheus = { workspace = true }
rand = { workspace = true }
rand_chacha = { workspace = true }
rcgen = { workspace = true }
regex = { workspace = true }
reqwest = { workspace = true }
rubato = { workspace = true }
rustls = { version = "0.23.20", features = ["ring"] }
sentencepiece = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { workspace = true }
tokio = { workspace = true }
tokio-rustls = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
tracing-appender = { workspace = true }
tracing-subscriber = { workspace = true }
[build-dependencies]
anyhow = { workspace = true }
vergen = { workspace = true }
[features]
default = []
cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"]
[profile.release]
debug = true
[profile.release-no-debug]
inherits = "release"
debug = false
================================================
FILE: kyuteye_rs/moshi-backend/build.rs
================================================
use anyhow::Result;
use vergen::EmitBuilder;
pub fn main() -> Result<()> {
// NOTE: This will output everything, and requires all features enabled.
// NOTE: See the EmitBuilder documentation for configuration options.
EmitBuilder::builder()
.all_build()
.all_cargo()
.all_git()
.all_rustc()
.all_sysinfo()
.emit()?;
Ok(())
}
================================================
FILE: kyuteye_rs/moshi-backend/src/audio.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#![allow(unused)]
use std::io::prelude::*;
pub trait Sample {
fn to_i16(&self) -> i16;
}
impl Sample for f32 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for f64 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for i16 {
fn to_i16(&self) -> i16 {
*self
}
}
pub fn write_pcm_as_wav<W: Write, S: Sample>(
    w: &mut W,
    samples: &[S],
    sample_rate: u32,
) -> std::io::Result<()> {
let len = 12u32; // header
let len = len + 24u32; // fmt
let len = len + samples.len() as u32 * 2 + 8; // data
let n_channels = 1u16;
let bytes_per_second = sample_rate * 2 * n_channels as u32;
w.write_all(b"RIFF")?;
w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
w.write_all(b"WAVE")?;
// Format block
w.write_all(b"fmt ")?;
w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
w.write_all(&1u16.to_le_bytes())?; // PCM
w.write_all(&n_channels.to_le_bytes())?; // one channel
w.write_all(&sample_rate.to_le_bytes())?;
w.write_all(&bytes_per_second.to_le_bytes())?;
w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
w.write_all(&16u16.to_le_bytes())?; // bits per sample
// Data block
w.write_all(b"data")?;
w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
for sample in samples.iter() {
w.write_all(&sample.to_i16().to_le_bytes())?
}
Ok(())
}
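// Usage sketch (illustrative): serializing one second of silence at 24 kHz
// into an in-memory WAV buffer.
//
//     let samples = vec![0f32; 24_000];
//     let mut wav: Vec<u8> = Vec::new();
//     write_pcm_as_wav(&mut wav, &samples, 24_000)?;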
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> anyhow::Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
    let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}
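// Usage sketch (illustrative; `pcm_48k` is a hypothetical input buffer):
// bringing 48 kHz audio down to the 24 kHz used elsewhere in this crate.
//
//     let pcm_24k = resample(&pcm_48k, 48_000, 24_000)?;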
pub(crate) fn write_opus_header<W: std::io::Write>(w: &mut W) -> std::io::Result<()> {
    use byteorder::WriteBytesExt;
    // https://wiki.xiph.org/OggOpus#ID_Header
    w.write_all(b"OpusHead")?;
    w.write_u8(1)?; // version
    w.write_u8(1)?; // channel count
    w.write_u16::<byteorder::LittleEndian>(3840)?; // pre-skip
    w.write_u32::<byteorder::LittleEndian>(48000)?; // sample-rate in Hz
    w.write_i16::<byteorder::LittleEndian>(0)?; // output gain Q7.8 in dB
    w.write_u8(0)?; // channel map
    Ok(())
}
pub(crate) fn write_opus_tags<W: std::io::Write>(w: &mut W) -> std::io::Result<()> {
    use byteorder::WriteBytesExt;
    // https://wiki.xiph.org/OggOpus#Comment_Header
    let vendor = "KyutaiMoshi";
    w.write_all(b"OpusTags")?;
    w.write_u32::<byteorder::LittleEndian>(vendor.len() as u32)?; // vendor string length
    w.write_all(vendor.as_bytes())?; // vendor string, UTF8 encoded
    w.write_u32::<byteorder::LittleEndian>(0u32)?; // number of tags
    Ok(())
}
================================================
FILE: kyuteye_rs/moshi-backend/src/build.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use anyhow::Result;
use vergen::EmitBuilder;
pub fn main() -> Result<()> {
// NOTE: This will output everything, and requires all features enabled.
// NOTE: See the EmitBuilder documentation for configuration options.
EmitBuilder::builder().all_build().all_cargo().all_git().all_rustc().all_sysinfo().emit()?;
Ok(())
}
================================================
FILE: kyuteye_rs/moshi-backend/src/image_embedder.rs
================================================
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{linear, Linear, VarBuilder};
use candle_transformers::models::fastvit as MobileClip;
use candle_transformers::models::pixtral::vision_model as Pixtral;
use candle_transformers::models::siglip as Siglip;
use moshi::transformer::{CaSrc, Norm};
use moshi::NormType;
fn load_image(
bytes: &[u8],
max_size: usize,
mean: &[f32; 3],
std: &[f32; 3],
center_crop: bool,
preserve_aspect_ratio: bool,
) -> candle::Result<Tensor> {
    // Load an RGB image and resize it such that its longest side equals `max_size`.
    // If `center_crop` is true, the image is first center-cropped to a square.
let mut img = image::ImageReader::new(std::io::Cursor::new(bytes))
.with_guessed_format()?
.decode()
.map_err(candle::Error::wrap)?;
// if not center crop, we just resize to the max size
let (img, width, height) = if !center_crop {
if preserve_aspect_ratio {
let (width, height) = (img.width(), img.height());
let (new_width, new_height) = if width < height {
(((width * max_size as u32) / height) as usize, max_size)
} else {
(max_size, ((height * max_size as u32) / width) as usize)
};
            let img = img.resize_exact(
                new_width as u32,
                new_height as u32,
                image::imageops::FilterType::CatmullRom,
            );
(img, new_width, new_height)
} else {
(
img.resize_exact(
max_size as u32,
max_size as u32,
image::imageops::FilterType::CatmullRom,
),
max_size,
max_size,
)
}
}
    // otherwise, we first center crop to a square of (max_size, max_size) then resize
else {
let (width, height) = (img.width(), img.height());
let min_dim = if width > height { height } else { width };
let x = (width - min_dim) / 2;
let y = (height - min_dim) / 2;
//print!("center crop: {} {} {} {} {}\n", x, y, min_dim, width, height);
let img = img.crop(x, y, min_dim, min_dim);
let img = img.resize_exact(
max_size as u32,
max_size as u32,
image::imageops::FilterType::CatmullRom,
);
(img, max_size, max_size)
};
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
tracing::info!(data = ?data.shape(), "image loaded");
let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
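// The returned tensor has shape (3, height, width) and is normalized
// channel-wise as (pixel / 255 - mean) / std, matching the preprocessing
// statistics of the selected vision backbone.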
// Image Encoder for vision-conditioned models
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum ImageEncoder {
Siglip224,
Siglip448,
Siglip896,
MobileclipS1,
MobileclipS2,
Pixtral,
}
#[derive(Debug, Clone)]
pub enum ImageEncoderModel {
Siglip(Siglip::VisionModel),
Mobileclip(candle_nn::Func<'static>),
Pixtral(Pixtral::Model),
}
fn init_output_proj(in_dims: usize, out_dims: usize, vb: VarBuilder) -> Result<Option<Linear>> {
let proj = if vb.contains_tensor("proj_xa.weight") {
Some(linear(in_dims, out_dims, vb.pp("proj_xa"))?)
} else {
None
};
Ok(proj)
}
#[derive(Debug, Clone)]
pub struct ImageEmbedder {
model: ImageEncoderModel,
// optional output proj
    proj: Option<Linear>,
// output norm
norm: Norm,
    // image loading params
mean: [f32; 3],
std: [f32; 3],
// max image size
// corresponds to maximum allowed image dimension for the model
// For Pixtral, this is the maximum side
// For Siglip and Mobileclip, images are fixed to a square size for now
max_image_size: usize,
patch_size: usize,
// whether the lm model will be quantized (need F32 inputs)
quantized_lm_model: bool,
}
impl ImageEmbedder {
    pub fn new(
        model_file: &str,
        image_prefix_backbone: Option<ImageEncoder>,
        image_prefix_rmsnorm: bool,
        cross_attention_in_dims: Option<usize>,
        dev: &Device,
    ) -> Result<Self> {
let out_dims = cross_attention_in_dims.unwrap_or(4096);
let quantized_lm_model = model_file.ends_with(".gguf");
let vb = if quantized_lm_model {
// model_file should have format XXX.quant_format.gguf
// and the unquantized vision tower weights should be in XXX_vision_tower_unquant.safetensors
let base_path = model_file.rsplitn(3, '.').nth(2);
let out_path = match base_path {
None => anyhow::bail!(".gguf file does not have a corresponding vision tower"),
Some(p) => format!("{}_vision_tower_unquant.safetensors", p),
};
            tracing::info!(?out_path, "loading unquantized vision encoder");
unsafe {
VarBuilder::from_mmaped_safetensors(&[out_path.as_str()], candle::DType::F32, dev)?
}
} else {
// if safetensors, we assume the file contains *all* model's tensors
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, dev)? }
};
let vb = vb.pp("image_prefix");
let norm_typ = match image_prefix_rmsnorm {
true => NormType::RmsNorm,
false => NormType::LayerNorm,
};
// output norm
let norm = Norm::new_shortcut(
out_dims,
norm_typ,
moshi::nn::MaybeQuantizedVarBuilder::Real(vb.pp("norm_xa")),
)?;
let max_image_size = match image_prefix_backbone {
Some(
ImageEncoder::Siglip224 | ImageEncoder::MobileclipS1 | ImageEncoder::MobileclipS2,
) => 224,
Some(ImageEncoder::Siglip448) => 448,
Some(ImageEncoder::Siglip896) => 896,
Some(ImageEncoder::Pixtral) => 1024,
None => anyhow::bail!("Image encoder type not specified in config"),
};
// Main backbone
match image_prefix_backbone {
Some(ImageEncoder::Siglip224 | ImageEncoder::Siglip448 | ImageEncoder::Siglip896) => {
// Siglip
// https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
let siglip_cfg = match image_prefix_backbone {
Some(ImageEncoder::Siglip224) => Siglip::VisionConfig::paligemma_3b_224(),
Some(ImageEncoder::Siglip448) => Siglip::VisionConfig::paligemma_3b_448(),
Some(ImageEncoder::Siglip896) => Siglip::VisionConfig::paligemma_3b_896(),
_ => anyhow::bail!("Impossible match arm. Congrats on reaching there"),
};
let model = Siglip::VisionModel::new(&siglip_cfg, false, vb.pp("enc.model"))?;
let proj = init_output_proj(siglip_cfg.hidden_size, out_dims, vb)?;
Ok(Self {
model: ImageEncoderModel::Siglip(model),
proj,
norm,
mean: [0.5, 0.5, 0.5],
std: [0.5, 0.5, 0.5],
max_image_size,
patch_size: siglip_cfg.patch_size,
quantized_lm_model,
})
}
Some(ImageEncoder::MobileclipS1) | Some(ImageEncoder::MobileclipS2) => {
// mobileclip (mci1/2 from candle)
let vit_cfg = match image_prefix_backbone {
Some(ImageEncoder::MobileclipS1) => MobileClip::Config::mci1(),
Some(ImageEncoder::MobileclipS2) => MobileClip::Config::mci2(),
_ => anyhow::bail!("No image backbone specified for cross-attention layers"),
};
let model = MobileClip::fastvit_no_final_layer(&vit_cfg, vb.pp("enc.model"))?;
let proj = init_output_proj(vit_cfg.in_channels * 16, out_dims, vb)?;
Ok(Self {
model: ImageEncoderModel::Mobileclip(model),
proj,
norm,
mean: [0., 0., 0.],
std: [1., 1., 1.],
max_image_size,
patch_size: 1,
quantized_lm_model,
})
}
Some(ImageEncoder::Pixtral) => {
let pixtral_cfg = Pixtral::Config::pixtral_12b_2409();
let model = Pixtral::Model::new(&pixtral_cfg, vb.pp("enc.model"))?;
let proj = init_output_proj(pixtral_cfg.hidden_size, out_dims, vb)?;
Ok(Self {
model: ImageEncoderModel::Pixtral(model),
proj,
norm,
mean: [0.481_454_66, 0.457_827_5, 0.408_210_73],
std: [0.268_629_54, 0.261_302_6, 0.275_777_1],
max_image_size,
patch_size: pixtral_cfg.patch_size,
quantized_lm_model,
})
}
None => {
anyhow::bail!("No image backbone specified for cross-attention layers")
}
}
}
    pub fn output_proj(&self, img_features: Tensor, dev: &Device) -> Result<Tensor> {
// Output linear + normalization
let img_features = match &self.proj {
None => img_features.apply(&self.norm)?,
Some(module) => img_features.apply(module)?.apply(&self.norm)?,
};
tracing::info!(feat = ?img_features.shape(), "image features generated");
// Type conversion
let dtype = if dev.is_cuda() && !self.quantized_lm_model {
candle::DType::BF16
} else {
candle::DType::F32
};
Ok(img_features.to_dtype(dtype)?)
}
    pub fn embed(
        &self,
        img_bytes: &[u8],
        image_size: usize,
        center_crop: bool,
        dev: &Device,
    ) -> Result<CaSrc> {
        // Load the uint8 image bytes as a tensor, then embed it.
        // To avoid issues with position-embedding interpolation, resize
        // images to the closest multiple of the model's patch size.
let image_size = if image_size % self.patch_size > self.patch_size / 2 {
image_size - image_size % self.patch_size + self.patch_size
} else {
image_size - image_size % self.patch_size
};
        // Very small image sizes are far out of distribution, so we clamp
        // the input image size to something reasonable.
let min_image_size = self.patch_size * 10;
let image_size = if !(min_image_size..=self.max_image_size).contains(&image_size) {
tracing::info!(
"Limiting image size up to be in [{}, {}]",
min_image_size,
self.max_image_size
);
image_size.clamp(min_image_size, self.max_image_size)
} else {
image_size
};
let img_features = match &self.model {
// Pixtral handles dynamic image size + non-square ratios by nature
ImageEncoderModel::Pixtral(m) => load_image(
img_bytes,
image_size,
&self.mean,
&self.std,
center_crop,
true,
)?
.to_device(dev)?
.unsqueeze(0)?
.apply(m)?,
// Siglip now also handles dynamic size/ratios with positional embedding interpolation
ImageEncoderModel::Siglip(m) => load_image(
img_bytes,
image_size,
&self.mean,
&self.std,
center_crop,
false,
)?
.to_device(dev)?
.unsqueeze(0)?
.apply(m)?,
            // But MobileClip always resizes to its own fixed (and square) image size
ImageEncoderModel::Mobileclip(m) => load_image(
img_bytes,
self.max_image_size,
&self.mean,
&self.std,
center_crop,
false,
)?
.to_device(dev)?
.unsqueeze(0)?
.apply(m)?
.flatten_from(2)?
.transpose(1, 2)?,
};
Ok(CaSrc::Tokens(self.output_proj(img_features, dev)?))
}
    pub fn embed_from_tensor(&self, img: Tensor, dev: &Device) -> Result<CaSrc> {
// embed image preloaded as safetensors
let img_features = match &self.model {
            // Currently Siglip and MobileClip only support a fixed size (from config)
ImageEncoderModel::Siglip(m) => img.apply(m)?,
ImageEncoderModel::Mobileclip(m) => img.apply(m)?.flatten_from(2)?.transpose(1, 2)?,
// Pixtral supports any size, but we may need to update the positional embeddings
ImageEncoderModel::Pixtral(m) => img.apply(m)?,
};
Ok(CaSrc::Tokens(self.output_proj(img_features, dev)?))
}
}
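// Usage sketch (illustrative; the path and parameter values are placeholders):
//
//     let embedder = ImageEmbedder::new(
//         "model.safetensors",
//         Some(ImageEncoder::Siglip448),
//         /* image_prefix_rmsnorm= */ true,
//         /* cross_attention_in_dims= */ None,
//         &device,
//     )?;
//     let ca_src = embedder.embed(&img_bytes, 448, /* center_crop= */ false, &device)?;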
================================================
FILE: kyuteye_rs/moshi-backend/src/main.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use anyhow::Result;
use clap::Parser;
use std::str::FromStr;
mod audio;
mod image_embedder;
mod metrics;
mod standalone;
mod stream_both;
mod utils;
#[derive(Parser, Debug)]
#[clap(name = "server", about = "moshi web server")]
struct Args {
#[clap(short = 'l', long = "log", default_value = "info")]
log_level: String,
#[clap(long, default_value = "configs/config-moshika-vis.json")]
config: String,
#[clap(long)]
silent: bool,
#[command(subcommand)]
command: Command,
}
#[derive(Parser, Debug)]
struct StandaloneArgs {
#[clap(long)]
cpu: bool,
    #[clap(short = 's', long = "sig", default_value = None)]
    sig: Option<String>,
    #[clap(short = 'e', long = "epoch", default_value = None)]
    epoch: Option<usize>,
    #[clap(short = 'u', long = "user", default_value = None)]
    user: Option<String>,
    #[clap(long)]
    img: Option<String>,
#[clap(long, default_value_t = 224)]
img_size: usize,
#[clap(long, conflicts_with = "img")]
vis: bool,
#[clap(long)]
s2s: bool,
#[clap(long)]
asr: bool,
}
#[derive(Debug, clap::Subcommand)]
enum Command {
Standalone(StandaloneArgs),
}
/// A TLS acceptor that sets `TCP_NODELAY` on accepted streams.
#[derive(Clone, Debug)]
pub struct NoDelayAcceptor;
impl<S: Send + 'static> axum_server::accept::Accept<tokio::net::TcpStream, S> for NoDelayAcceptor {
type Stream = tokio::net::TcpStream;
type Service = S;
type Future =
futures_util::future::BoxFuture<'static, std::io::Result<(Self::Stream, Self::Service)>>;
fn accept(&self, stream: tokio::net::TcpStream, service: S) -> Self::Future {
Box::pin(async move {
// Disable Nagle's algorithm.
stream.set_nodelay(true)?;
Ok::<_, std::io::Error>((stream, service))
})
}
}
fn tracing_init(
    log_dir: &str,
    instance_name: &str,
    log_level: &str,
    silent: bool,
) -> Result<tracing_appender::non_blocking::WorkerGuard> {
use tracing_subscriber::prelude::*;
let build_info = utils::BuildInfo::new();
let file_appender = tracing_appender::rolling::daily(log_dir, format!("log.{}", instance_name));
let (non_blocking, guard) = tracing_appender::non_blocking(file_appender);
let filter = tracing_subscriber::filter::LevelFilter::from_str(log_level)?;
let mut layers = vec![tracing_subscriber::fmt::layer()
.with_writer(non_blocking)
.with_filter(filter)
.boxed()];
if !silent {
layers.push(Box::new(
tracing_subscriber::fmt::layer()
.with_writer(std::io::stdout)
.with_filter(filter),
))
};
tracing_subscriber::registry().with(layers).init();
tracing::info!(?build_info);
Ok(guard)
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<()> {
rustls::crypto::ring::default_provider()
.install_default()
.map_err(|_| anyhow::Error::msg("unable to install default crypto"))?;
let args = Args::parse();
match args.command {
Command::Standalone(standalone_args) => {
let mut config = standalone::Config::load(&args.config)?;
let _guard = tracing_init(
&config.stream.log_dir,
&config.stream.instance_name,
&args.log_level,
args.silent,
)?;
tracing::info!("starting process with pid {}", std::process::id());
if config.stream.requires_model_download() {
standalone::download_from_hub(&mut config.stream).await?;
}
if !std::path::PathBuf::from(&config.static_dir).exists() {
use hf_hub::api::tokio::Api;
let api = Api::new()?;
let repo = api.model("kyutai/moshi-artifacts".to_string());
let dist_tgz = repo.get("vis_dist.tgz").await?;
if let Some(parent) = dist_tgz.parent() {
let dist = parent.join("dist");
if !dist.exists() {
let output = std::process::Command::new("tar")
.arg("-xzf")
.arg(&dist_tgz)
.arg("-C")
.arg(parent)
.output()?;
if !output.status.success() {
anyhow::bail!(
"error extract {dist_tgz:?}: {}",
String::from_utf8_lossy(&output.stderr)
);
}
}
config.static_dir = dist.to_string_lossy().to_string()
}
}
standalone::run(&standalone_args, &config).await?;
}
}
Ok(())
}
================================================
FILE: kyuteye_rs/moshi-backend/src/metrics.rs
================================================
use lazy_static::lazy_static;
use prometheus::Histogram;
use prometheus::{histogram_opts, register_histogram};
pub mod worker {
use super::*;
lazy_static! {
pub static ref MODEL_STEP_DURATION: Histogram = register_histogram!(histogram_opts!(
"worker_model_step_duration",
"Model step duration distribution.",
vec![40e-3, 50e-3, 60e-3, 75e-3, 80e-3, 0.1],
))
.unwrap();
}
}
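// Usage sketch: wrap one model step with the histogram timer; the elapsed
// time (in seconds) is recorded when the timer is dropped.
//
//     let _timer = worker::MODEL_STEP_DURATION.start_timer();
//     // ... run one model step ...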
================================================
FILE: kyuteye_rs/moshi-backend/src/standalone.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use anyhow::{Context, Result};
use axum::extract::ws;
use std::process;
use std::sync::Arc;
use std::{path::Path, str::FromStr};
use crate::{image_embedder, stream_both, utils, StandaloneArgs};
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
cert_dir: String,
#[serde(default = "utils::default_static_dir")]
pub static_dir: String,
addr: String,
port: u16,
#[serde(flatten)]
pub stream: stream_both::Config,
}
impl Config {
    pub fn load<P: AsRef<Path>>(p: P) -> Result<Self> {
let config = std::fs::read_to_string(p)?;
let mut config: Self = serde_json::from_str(&config)?;
config.static_dir = crate::utils::replace_env_vars(&config.static_dir);
config.cert_dir = crate::utils::replace_env_vars(&config.cert_dir);
// location of the images in the static dir
config.stream.images_dir =
Some(std::path::PathBuf::from(&config.static_dir).join("assets/images/demo"));
config.stream.log_dir = crate::utils::replace_env_vars(&config.stream.log_dir);
config.stream.text_tokenizer_file =
crate::utils::replace_env_vars(&config.stream.text_tokenizer_file);
config.stream.mimi_model_file =
crate::utils::replace_env_vars(&config.stream.mimi_model_file);
config.stream.lm_model_file = crate::utils::replace_env_vars(&config.stream.lm_model_file);
Ok(config)
}
    pub fn cert_file(&self, name: &str) -> Result<std::path::PathBuf> {
let cert_dir = std::path::PathBuf::from(&self.cert_dir);
let cert_file = cert_dir.join(name);
if !cert_file.is_file() {
anyhow::bail!("missing file {cert_file:?}");
}
Ok(cert_file)
}
}
pub(crate) fn device(cpu: bool) -> Result<candle::Device> {
use candle::Device;
if cpu {
Ok(Device::Cpu)
} else if candle::utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if candle::utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
impl stream_both::AppStateInner {
    pub fn new(args: &StandaloneArgs, config: &stream_both::Config) -> Result<Self> {
let device = device(args.cpu)?;
let mut config = config.clone();
if let Some(sig) = args.sig.as_ref() {
tracing::info!(sig, "Loading checkpoint from sig");
let mut cmd = process::Command::new("python");
cmd.arg("-m").arg("scripts.mimi_import").arg(sig).arg("-s");
if let Some(epoch) = args.epoch {
tracing::info!(epoch, "using epoch");
cmd.arg("-e").arg(epoch.to_string());
}
if let Some(user) = args.user.as_ref() {
tracing::info!(user, "taking checkpoint from user");
cmd.arg("-u").arg(user);
}
let output = cmd.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr[..]);
tracing::error!("Error while trying to get checkpoint: {stderr}");
anyhow::bail!("Couldn't convert the checkpoint.");
}
let lm_model_file = String::from_utf8_lossy(&output.stdout[..])
.trim_end_matches('\n')
.to_string();
            tracing::info!(lm_model_file, "model path was overridden");
config.lm_model_file = lm_model_file;
}
let dtype = if device.is_cuda() {
candle::DType::BF16
} else {
candle::DType::F32
};
let image_embedder = match args.vis {
// Standard Moshi: no vision features
false => None,
// --vis: Load the image encoder and the image will be embedded
// on-the-fly in stream_handler
true => Some(image_embedder::ImageEmbedder::new(
&config.lm_model_file,
config.image_prefix_backbone.clone(),
config.image_prefix_use_rms_norm,
config.cross_attention_in_dims,
&device,
)?),
};
// update cross-attention gating based on user-provided JSON config
let lm_model = if args.vis || args.img.is_some() {
moshi::lm::load_vision(
&config.lm_model_file,
config.cross_attention_gating,
config.cross_attention_in_dims,
dtype,
&device,
)?
} else {
moshi::lm::load_streaming(&config.lm_model_file, dtype, &device)?
};
let mimi_device = if config.use_cpu_for_mimi {
&candle::Device::Cpu
} else {
&device
};
let mimi_model = moshi::mimi::load(
&config.mimi_model_file,
Some(config.mimi_num_codebooks),
mimi_device,
)?;
let text_tokenizer =
sentencepiece::SentencePieceProcessor::open(&config.text_tokenizer_file)?;
// Warm-up code.
{
tracing::info!(?dtype, ?device, "warming up the model");
// Warmup the image encoder if any
let fake_image = candle::Tensor::zeros((1, 3, 224, 224), candle::DType::F32, &device)?;
let ca_src = match &image_embedder {
None => None,
Some(m) => Some(m.embed_from_tensor(fake_image, &device)?),
};
// Warm up the LM model w/ cross-attention as needed
let mut lm_model = lm_model.clone();
let (_v, ys) = match &ca_src {
None => lm_model.forward(None, vec![None; config.mimi_num_codebooks])?,
Some(x) => lm_model.forward_ca(None, vec![None; config.mimi_num_codebooks], x)?,
};
let mut lp = candle_transformers::generation::LogitsProcessor::new(123, None, None);
let _ = lm_model.depformer_sample(&ys, None, &[], &mut lp)?;
let mut mimi_model = mimi_model.clone();
let config = mimi_model.config();
let frame_length = (config.sample_rate / config.frame_rate).ceil() as usize;
let fake_pcm =
candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, mimi_device)?;
let codes = mimi_model.encode_step(&fake_pcm.into())?;
let ys = mimi_model.decode_step(&codes)?;
if ys.as_option().is_none() {
anyhow::bail!("Expected Mimi to output some stuff, but nothing came out.");
}
device.synchronize()?;
tracing::info!("model is ready to roll!");
}
Ok(Self {
lm_model,
mimi_model,
device,
config: config.clone(),
text_tokenizer,
image_embedder,
})
}
}
async fn handle_socket(socket: ws::WebSocket, sm: stream_both::StreamingModel) {
if let Err(err) = stream_both::handle_socket(socket, sm, None).await {
tracing::error!(err = err.to_string(), "handle_socket")
}
}
pub async fn stream_handler(
    ws: ws::WebSocketUpgrade,
    axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<std::net::SocketAddr>,
    state: axum::extract::State<stream_both::AppState>,
    req: axum::extract::Query<stream_both::SessionConfigReq>,
) -> crate::utils::AxumResult<axum::response::Response> {
tracing::info!(?addr, "received connection");
let sm = stream_both::StreamingModel::new(&state.0, req.0);
Ok(ws.on_upgrade(move |v| handle_socket(v, sm)))
}
pub async fn download_from_hub(config: &mut stream_both::Config) -> Result<()> {
use hf_hub::api::tokio::Api;
let api = Api::new()?;
let repo = api.model(config.hf_repo.clone());
    let extract_filename = |path: &str| -> Result<String> {
Path::new(path)
.file_name()
.and_then(|f| f.to_str())
.map(String::from)
.ok_or_else(|| anyhow::anyhow!("'{path}' has no file name"))
};
for file_path in [
&mut config.lm_model_file,
&mut config.mimi_model_file,
&mut config.text_tokenizer_file,
]
.iter_mut()
{
let filename = extract_filename(file_path)
.with_context(|| format!("Failed to extract filename for '{file_path}'"))?;
let downloaded_path = repo
.get(&filename)
.await
.with_context(|| format!("Failed to download '{file_path}' file"))?;
**file_path = downloaded_path
.into_os_string()
.into_string()
.map_err(|_| anyhow::anyhow!("'{file_path}' path is not a valid string"))?;
}
// Download vision tower unquantized
if config.lm_model_file.ends_with(".gguf") {
// model_file should have format XXX.quant_format.gguf
// and the unquantized vision tower weights should be in XXX_vision_tower_unquant.safetensors
let base_path = config.lm_model_file.rsplitn(3, '.').nth(2);
let vision_tower_path = match base_path {
None => anyhow::bail!(".gguf file does not have a corresponding vision tower"),
Some(p) => format!("{p}_vision_tower_unquant.safetensors"),
};
let filename = extract_filename(&vision_tower_path)
.with_context(|| format!("Failed to extract filename for '{vision_tower_path}'"))?;
repo.get(&filename)
.await
.with_context(|| format!("Failed to download '{vision_tower_path}' file"))?;
};
Ok(())
}
pub async fn run(args: &StandaloneArgs, config: &Config) -> Result<()> {
    let cert_dir = std::path::PathBuf::from(&config.cert_dir);
    let cert_pem = cert_dir.join("cert.pem");
    let key_pem = cert_dir.join("key.pem");
    // Generate a self-signed localhost certificate if none is present.
    if !cert_pem.exists() || !key_pem.exists() {
let rcgen::CertifiedKey { cert, key_pair } =
rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
std::fs::write(&cert_pem, cert.pem())?;
std::fs::write(&key_pem, key_pair.serialize_pem())?;
}
let tls_config =
axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_pem, key_pem).await?;
let sock_addr = std::net::SocketAddr::from((
std::net::IpAddr::from_str(config.addr.as_str())
.unwrap_or(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
config.port,
));
let state = Arc::new(stream_both::AppStateInner::new(args, &config.stream)?);
tracing::info!("serving static dir {}", config.static_dir);
let app = axum::Router::new()
.route("/api/chat", axum::routing::get(stream_handler))
.fallback_service(
tower_http::services::ServeDir::new(&config.static_dir)
.append_index_html_on_directories(true),
)
.layer(tower::ServiceBuilder::new().layer(tower_http::trace::TraceLayer::new_for_http()))
.with_state(state);
tracing::info!(
"standalone worker listening on https://{}?worker_addr={}",
sock_addr,
sock_addr
);
axum_server::bind_rustls(sock_addr, tls_config)
        .serve(app.into_make_service_with_connect_info::<std::net::SocketAddr>())
.await?;
Ok(())
}
================================================
FILE: kyuteye_rs/moshi-backend/src/stream_both.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use crate::image_embedder::{ImageEmbedder, ImageEncoder};
use anyhow::Result;
use axum::extract::ws;
use futures_util::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use moshi::{dynamic_logits_processor::GateInfluencedLogitsProcessor, transformer::CaSrc};
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct ForceSessionConfig {
pub text_temperature: f64,
pub text_topk: usize,
pub audio_temperature: f64,
pub audio_topk: usize,
    pub pad_mult: Option<f32>,
pub repetition_penalty: Option<(usize, f32)>,
    pub xa_start: Option<usize>,
    pub text_temperature_gating_influence: Option<f32>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Config {
pub instance_name: String,
#[serde(default)]
pub hf_repo: String,
pub lm_model_file: String,
pub log_dir: String,
pub text_tokenizer_file: String,
pub mimi_model_file: String,
pub mimi_num_codebooks: usize,
    pub lm_config: Option<moshi::lm_generate_multistream::Config>,
#[serde(default = "default_false")]
pub use_cpu_for_mimi: bool,
    pub asr_delay_in_tokens: Option<usize>,
// optional config options for image conditioning
    pub image_prefix_backbone: Option<ImageEncoder>,
#[serde(default = "default_false")]
pub image_prefix_use_rms_norm: bool,
    pub cross_attention_gating: Option<moshi::transformer::CrossAttentionGating>,
    pub cross_attention_in_dims: Option<usize>,
    pub images_dir: Option<std::path::PathBuf>,
    pub force_session_config: Option<ForceSessionConfig>,
}
fn default_false() -> bool {
false
}
impl Config {
    /// Returns true if any of the model files is missing locally and must be downloaded.
pub fn requires_model_download(&self) -> bool {
[
&self.lm_model_file,
&self.mimi_model_file,
&self.text_tokenizer_file,
]
.iter()
.any(|file| !std::path::Path::new(file).exists())
}
}
pub type AppState = Arc;
pub struct AppStateInner {
pub lm_model: moshi::lm::LmModel,
pub mimi_model: moshi::mimi::Mimi,
pub text_tokenizer: sentencepiece::SentencePieceProcessor,
pub device: candle::Device,
pub config: Config,
pub image_embedder: Option,
}
impl AppStateInner {
fn text(
&self,
prev_text_token: u32,
text_token: u32,
config: &moshi::lm_generate_multistream::Config,
) -> Option {
if text_token != config.text_start_token
&& text_token != config.text_pad_token
&& text_token != config.text_eop_token
{
if prev_text_token == config.text_start_token {
self.text_tokenizer.decode_piece_ids(&[text_token]).ok()
} else {
let prev_ids = self
.text_tokenizer
.decode_piece_ids(&[prev_text_token])
.ok();
let ids = self
.text_tokenizer
.decode_piece_ids(&[prev_text_token, text_token])
.ok();
prev_ids.and_then(|prev_ids| {
ids.map(|ids| {
if ids.len() > prev_ids.len() {
ids[prev_ids.len()..].to_string()
} else {
String::new()
}
})
})
}
} else {
None
}
}
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct SessionConfigReq {
    pub text_temperature: Option<f64>,
    pub text_topk: Option<usize>,
    pub audio_temperature: Option<f64>,
    pub audio_topk: Option<usize>,
    pub max_steps: Option<usize>,
    pub audio_seed: Option<u64>,
    pub text_seed: Option<u64>,
    pub email: Option<String>,
    pub pad_mult: Option<f32>,
    pub repetition_penalty_context: Option<usize>,
    pub repetition_penalty: Option<f32>,
    pub image_resolution: Option<usize>,
    pub center_crop: Option<bool>,
    pub xa_start: Option<usize>,
    pub text_temperature_gating_influence: Option<f32>,
}
#[derive(serde::Serialize, Debug, Clone)]
pub struct SessionConfig {
pub text_temperature: f64,
pub text_topk: usize,
pub audio_temperature: f64,
pub audio_topk: usize,
pub max_steps: usize,
pub audio_seed: u64,
pub text_seed: u64,
    pub pad_mult: Option<f32>,
pub repetition_penalty: Option<(usize, f32)>,
    pub image_resolution: Option<usize>,
    pub center_crop: Option<bool>,
pub xa_start: usize,
pub text_temperature_gating_influence: f32,
    pub email: Option<String>,
    pub user_feedback: Option<String>,
}
#[derive(serde::Serialize, Debug, Clone)]
struct SessionSummary<'a> {
#[serde(flatten)]
session_config: &'a SessionConfig,
last_step_idx: usize,
transcript: String,
    addr: Option<String>,
lm_model_file: &'a str,
mimi_model_file: &'a str,
#[serde(flatten)]
    lm_config: &'a Option<moshi::lm_generate_multistream::Config>,
}
impl SessionConfigReq {
fn into_session_config(self, force_cfg: Option<&ForceSessionConfig>) -> SessionConfig {
use rand::Rng;
let text_temperature = match force_cfg {
None => self.text_temperature.unwrap_or(0.8),
Some(v) => v.text_temperature,
};
let audio_temperature = match force_cfg {
None => self.audio_temperature.unwrap_or(0.8),
Some(v) => v.audio_temperature,
};
let text_topk = match force_cfg {
None => self.text_topk.unwrap_or(250),
Some(v) => v.text_topk,
};
let audio_topk = match force_cfg {
None => self.audio_topk.unwrap_or(250),
Some(v) => v.audio_topk,
};
let pad_mult = match force_cfg {
None => self.pad_mult,
Some(v) => v.pad_mult,
};
let repetition_penalty = match force_cfg {
None => self.repetition_penalty_context.zip(self.repetition_penalty),
Some(v) => v.repetition_penalty,
};
let xa_start = match force_cfg {
None => self.xa_start.unwrap_or(0),
Some(v) => v.xa_start.unwrap_or(0),
};
let text_temperature_gating_influence = match force_cfg {
None => self.text_temperature_gating_influence.unwrap_or(0.0),
Some(v) => v.text_temperature_gating_influence.unwrap_or(0.0),
};
SessionConfig {
text_temperature,
text_topk,
text_seed: self.text_seed.unwrap_or_else(|| rand::thread_rng().gen()),
audio_temperature,
audio_topk,
audio_seed: self.audio_seed.unwrap_or_else(|| rand::thread_rng().gen()),
email: self.email,
user_feedback: None,
max_steps: self.max_steps.unwrap_or(4500).min(4500),
pad_mult,
repetition_penalty,
image_resolution: self.image_resolution,
center_crop: self.center_crop,
xa_start,
text_temperature_gating_influence,
}
}
}
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct MetaData {
text_temperature: f64,
text_topk: usize,
text_temperature_gating_influence: f64,
audio_temperature: f64,
audio_topk: usize,
pad_mult: f32,
repetition_penalty_context: usize,
repetition_penalty: f32,
image_resolution: usize,
lm_model_file: String,
mimi_model_file: String,
build_info: crate::utils::BuildInfo,
instance_name: String,
base_filename: String,
}
#[derive(Debug, Clone)]
pub enum StreamOut {
Ready,
    MetaData { metadata: Box<MetaData> },
    Pcm { pcm: Vec<f32> },
ColoredText { text: String, intensity: f32 },
}
// This must be an allowed value among 120, 240, 480, 960, 1920, and 2880.
// Using a different value would result in a BadArg "invalid argument" error when calling encode.
// https://opus-codec.org/docs/opus_api-1.2/group__opus__encoder.html#ga4ae9905859cd241ef4bb5c59cd5e5309
const OPUS_ENCODER_FRAME_SIZE: usize = 960;
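// With the 24 kHz mono encoder created in `MsgSender::new`, 960 samples make
// up one 40 ms Opus frame (24_000 / 960 = 25 frames per second).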
#[derive(Debug, Clone, Copy)]
pub enum MsgType {
Handshake,
Audio,
Text,
Control,
Metadata,
Error,
Ping,
ColoredText,
Image,
UserRating,
}
impl MsgType {
    pub fn from_u8(v: u8) -> Result<Self> {
let s = match v {
0 => MsgType::Handshake,
1 => MsgType::Audio,
2 => MsgType::Text,
3 => MsgType::Control,
4 => MsgType::Metadata,
5 => MsgType::Error,
6 => MsgType::Ping,
7 => MsgType::ColoredText,
8 => MsgType::Image,
10 => MsgType::UserRating,
_ => anyhow::bail!("unexpected msg type {v}"),
};
Ok(s)
}
pub fn to_u8(self) -> u8 {
match self {
MsgType::Handshake => 0,
MsgType::Audio => 1,
MsgType::Text => 2,
MsgType::Control => 3,
MsgType::Metadata => 4,
MsgType::Error => 5,
MsgType::Ping => 6,
MsgType::ColoredText => 7,
MsgType::Image => 8,
MsgType::UserRating => 10,
}
}
}
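// Wire-format sketch, as implied by the senders below: every websocket frame
// is binary, with byte 0 carrying the `MsgType` tag and the remaining bytes
// the payload, e.g. [1, <ogg/opus bytes>] for Audio or
// [7, <intensity byte>, <utf8 text>] for ColoredText. The value 9 is unassigned.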
pub enum ModelInput {
    Image(Vec<u8>),
    Audio(Vec<f32>),
Rating(u32),
}
pub struct MsgSender {
    pw: ogg::PacketWriter<'static, Vec<u8>>,
    encoder: opus::Encoder,
    out_pcm: std::collections::VecDeque<f32>,
    out_pcm_buf: Vec<u8>,
    total_data: usize,
    sender: SplitSink<ws::WebSocket, ws::Message>,
    last_input_pcm: Option<std::time::Instant>,
}
impl MsgSender {
    fn new(sender: SplitSink<ws::WebSocket, ws::Message>) -> Result<Self> {
let encoder = opus::Encoder::new(24000, opus::Channels::Mono, opus::Application::Voip)?;
// Not sure what the appropriate buffer size would be here.
let out_pcm_buf = vec![0u8; 50_000];
let out_pcm = std::collections::VecDeque::with_capacity(2 * OPUS_ENCODER_FRAME_SIZE);
let all_data = Vec::new();
let mut pw = ogg::PacketWriter::new(all_data);
let mut head = Vec::new();
crate::audio::write_opus_header(&mut head)?;
pw.write_packet(head, 42, ogg::PacketWriteEndInfo::EndPage, 0)?;
let mut tags = Vec::new();
crate::audio::write_opus_tags(&mut tags)?;
pw.write_packet(tags, 42, ogg::PacketWriteEndInfo::EndPage, 0)?;
Ok(Self {
pw,
encoder,
out_pcm,
out_pcm_buf,
total_data: 0,
sender,
last_input_pcm: None,
})
}
async fn send_colored_text(&mut self, text: String, intensity: f32) -> Result<()> {
let int_intensity = (((intensity - 0.005) / 0.016).clamp(0., 1.) * 10.).round() as u8;
        let msg: Vec<u8> = [
&[MsgType::ColoredText.to_u8()],
&[int_intensity],
text.as_bytes(),
]
.concat();
let msg = ws::Message::Binary(msg.into());
self.sender.send(msg).await?;
Ok(())
}
async fn send_ready(&mut self) -> Result<()> {
// The payload is made of two fields.
// 1. Protocol version (`u32`) - always 0 for now.
// 2. Model version (`u32`).
        let msg: Vec<u8> = [&[MsgType::Handshake.to_u8()], [0u8; 8].as_slice()].concat();
let msg = ws::Message::Binary(msg.into());
self.sender.send(msg).await?;
Ok(())
}
async fn send_metadata(&mut self, md: Box) -> Result<()> {
let bytes = serde_json::to_vec(&md)?;
        let msg: Vec<u8> = [&[MsgType::Metadata.to_u8()], bytes.as_slice()].concat();
let msg = ws::Message::Binary(msg.into());
self.sender.send(msg).await?;
Ok(())
}
    async fn send_pcm(&mut self, pcm: Vec<f32>) -> Result<()> {
self.out_pcm.extend(pcm.iter());
self.total_data += pcm.len();
let nchunks = self.out_pcm.len() / OPUS_ENCODER_FRAME_SIZE;
for _chunk_id in 0..nchunks {
let mut chunk = Vec::with_capacity(OPUS_ENCODER_FRAME_SIZE);
for _i in 0..OPUS_ENCODER_FRAME_SIZE {
let v = match self.out_pcm.pop_front() {
None => anyhow::bail!("unexpected err popping from pcms"),
Some(v) => v,
};
chunk.push(v)
}
let size = self.encoder.encode_float(&chunk, &mut self.out_pcm_buf)?;
if size > 0 {
let msg = self.out_pcm_buf[..size].to_vec();
self.pw.write_packet(
msg,
42,
ogg::PacketWriteEndInfo::EndPage,
self.total_data as u64,
)?
} else {
tracing::error!("OPUS SIZE 0")
}
let data = self.pw.inner_mut();
if !data.is_empty() {
                let msg: Vec<u8> = [&[MsgType::Audio.to_u8()], data.as_slice()].concat();
let msg = ws::Message::Binary(msg.into());
self.sender.send(msg).await?;
self.sender.flush().await?;
data.clear();
} else {
tracing::error!("OGG SIZE 0")
}
}
Ok(())
}
}
pub struct StreamingModel {
state: AppState,
device: candle::Device,
config: moshi::lm_generate_multistream::Config,
session_config: SessionConfig,
}
impl StreamingModel {
fn run_with_state(
&self,
state: &mut moshi::lm_generate_multistream::State,
        receiver: std::sync::mpsc::Receiver<ModelInput>,
        sender: tokio::sync::mpsc::UnboundedSender<StreamOut>,
) -> Result<()> {
use candle::IndexOp;
let app_state = &self.state;
let mut mimi = app_state.mimi_model.clone();
let config = state.config().clone();
mimi.reset_state();
tracing::info!("processing loop");
let mut prev_text_token = config.text_start_token;
let mut tensor_tokens = vec![];
let mimi_device = if self.state.config.use_cpu_for_mimi {
&candle::Device::Cpu
} else {
&self.device
};
let ca_src = match &app_state.image_embedder {
Some(image_embedder) => {
// The first message is the image to use as input for the model.
let image_bytes = match receiver.recv() {
Ok(ModelInput::Image(image)) => image,
_ => anyhow::bail!("Expected image as first message, we are in vision mode."),
};
tracing::info!("image received");
let resolution = self.session_config.image_resolution.unwrap_or(224);
let center_crop = self.session_config.center_crop.unwrap_or(false);
let ca_src =
image_embedder.embed(&image_bytes, resolution, center_crop, mimi_device)?;
let ca_src = self.state.lm_model.maybe_precompute_ca_kv(Some(ca_src))?;
match &ca_src {
Some(ca_src) => match ca_src {
CaSrc::Tokens(tensor) => {
tracing::info!(shape=?tensor.shape(), "image processed")
}
CaSrc::KeysValues((keys, values)) => {
tracing::info!(keys_shape=?keys.shape(), values_shape=?values.shape(), "image processed")
}
},
None => anyhow::bail!("ca_src should never be None here."),
};
ca_src
}
None => None,
};
mimi_device.synchronize()?;
sender.send(StreamOut::Ready)?;
let mut counter = 0;
while let Ok(in_pcm) = receiver.recv() {
match in_pcm {
ModelInput::Audio(in_pcm) => {
if in_pcm.is_empty() {
continue;
}
let pcm_len = in_pcm.len();
let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), mimi_device)?;
let audio_tokens = mimi.encode_step(&pcms.into())?;
let audio_tokens = match audio_tokens.as_option() {
None => continue,
Some(audio_tokens) => audio_tokens,
};
let (_one, _codebooks, steps) = audio_tokens.dims3()?;
for step in 0..steps {
                        let codes = audio_tokens.i((0, .., step))?.to_vec1::<u32>()?;
//let text_token = state.step(prev_text_token, &codes, None, ca_src.as_ref())?;
let (text_token, gate_weight) = state.step_with_gate_weight(
prev_text_token,
&codes,
None,
if counter >= self.session_config.xa_start {
ca_src.as_ref()
} else {
None
},
)?;
if let Some(audio_tokens) = state.last_audio_tokens() {
let audio_tokens = {
let cb = app_state.config.mimi_num_codebooks;
candle::Tensor::from_slice(
&audio_tokens[..cb],
(1, cb, 1),
mimi_device,
)?
};
tensor_tokens.push(audio_tokens.clone());
let pcm = mimi.decode_step(&audio_tokens.into())?;
if let Some(pcm) = pcm.as_option() {
                                let pcm = pcm.i((0, 0))?.to_vec1::<f32>()?;
sender.send(StreamOut::Pcm { pcm })?;
}
}
if let Some(text) = app_state.text(prev_text_token, text_token, &config) {
//sender.send(StreamOut::Text { text })?;
sender.send(StreamOut::ColoredText {
text,
intensity: gate_weight,
})?;
}
prev_text_token = text_token;
}
counter += 1;
}
ModelInput::Image(_) => {
// See PR#105 https://github.com/0x53504852/moshi-rs/pull/105
// Due to the front-end async mode, we may have duplicate identical
// image messages which we should ignore while waiting for the first
// audio message to arrive
tracing::error!("Cannot handle new images during a conversation yet; skipping");
}
ModelInput::Rating(grade) => state.set_user_rating(grade),
}
}
tracing::info!("finished the processing loop");
Ok(())
}
pub fn new(state: &AppState, session_config: SessionConfigReq) -> Self {
let config = match state.config.lm_config.as_ref() {
None => moshi::lm_generate_multistream::Config::v0_1(),
Some(config) => config.clone(),
};
let session_config =
session_config.into_session_config(state.config.force_session_config.as_ref());
Self {
state: state.clone(),
device: state.device.clone(),
config,
session_config,
}
}
pub fn run(
&self,
        receiver: std::sync::mpsc::Receiver<ModelInput>,
        sender: tokio::sync::mpsc::UnboundedSender<StreamOut>,
        addr: Option<String>,
) -> Result<()> {
let app_state = &self.state;
let (repetition_penalty_context, repetition_penalty) =
self.session_config.repetition_penalty.unwrap_or((32, 1.));
// base path for logging
let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
let (secs, us) = (since_epoch.as_secs(), since_epoch.subsec_micros());
let base_filename = format!("{}-{secs}-{us}", app_state.config.instance_name);
        // metadata that will be stored alongside the session logs
let metadata = MetaData {
text_temperature: self.session_config.text_temperature,
text_topk: self.session_config.text_topk,
text_temperature_gating_influence: self.session_config.text_temperature_gating_influence
as f64,
audio_temperature: self.session_config.audio_temperature,
audio_topk: self.session_config.audio_topk,
pad_mult: self.session_config.pad_mult.unwrap_or(0.),
repetition_penalty,
image_resolution: self.session_config.image_resolution.unwrap_or(0),
repetition_penalty_context,
lm_model_file: self.state.config.lm_model_file.to_string(),
mimi_model_file: self.state.config.mimi_model_file.to_string(),
build_info: crate::utils::BuildInfo::new(),
instance_name: self.state.config.instance_name.to_string(),
base_filename: base_filename.clone(),
};
sender.send(StreamOut::MetaData {
metadata: Box::new(metadata),
})?;
let lm_model = app_state.lm_model.clone();
// Load ca_src either from the --img command-line flag (static image), or
// from the vision UI's websocket (contained in the StreamingModel).
let audio_lp = candle_transformers::generation::LogitsProcessor::from_sampling(
self.session_config.audio_seed,
candle_transformers::generation::Sampling::TopK {
k: self.session_config.audio_topk,
temperature: self.session_config.audio_temperature,
},
);
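// Unlike the audio sampler above, the text sampler also modulates its
// temperature with the cross-attention gate weight, scaled by
// text_temperature_gating_influence.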
let text_lp = GateInfluencedLogitsProcessor::from_sampling_with_scale(
self.session_config.text_seed,
candle_transformers::generation::Sampling::TopK {
k: self.session_config.text_topk,
temperature: self.session_config.text_temperature,
},
self.session_config.text_temperature_gating_influence as f64,
);
let mut state = moshi::lm_generate_multistream::State::new(
lm_model,
self.session_config.max_steps,
audio_lp,
text_lp,
self.session_config.pad_mult,
self.session_config.repetition_penalty,
self.config.clone(),
);
// We want to log the output even if the run function returns an error.
let run_result = self.run_with_state(&mut state, receiver, sender);
let log_result = (|| {
let text_tokens = state.text_tokens(false);
let transcript = {
let text_tokens = text_tokens
.iter()
.filter_map(|v| {
let v = *v;
if v != moshi::lm_generate_multistream::UNGENERATED
&& v != self.config.text_pad_token
&& v != self.config.text_eop_token
&& v != self.config.text_start_token
{
Some(v)
} else {
None
}
})
.collect::<Vec<_>>();
self.state
.text_tokenizer
.decode_piece_ids(&text_tokens)
.unwrap_or_else(|_| String::new())
};
let gate_weights = state.gate_weights(false);
let audio_tokens = state.audio_tokens(false);
let audio_tokens = audio_tokens
.iter()
.map(|v| {
v.iter()
.map(|v| {
if *v == moshi::lm_generate_multistream::UNGENERATED {
-1
} else {
*v as i64
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?
.to_dtype(candle::DType::I64)?;
let gate_weights = candle::Tensor::new(gate_weights, &candle::Device::Cpu)?
.to_dtype(candle::DType::F32)?;
let audio_tokens = candle::Tensor::new(audio_tokens, &candle::Device::Cpu)?;
let user_rating =
candle::Tensor::full(state.user_rating(), (1,), &candle::Device::Cpu)?;
let log_dir = &app_state.config.log_dir;
let base_path = format!("{}/{}", log_dir, base_filename);
let json_filename = format!("{base_path}.json");
let json_content = serde_json::to_string_pretty(&SessionSummary {
session_config: &self.session_config,
last_step_idx: state.step_idx(),
transcript,
addr,
mimi_model_file: &self.state.config.mimi_model_file,
lm_model_file: &self.state.config.lm_model_file,
lm_config: &self.state.config.lm_config,
})?;
std::fs::write(json_filename, json_content)?;
let st_filename = format!("{base_path}.safetensors");
let st_content = std::collections::HashMap::from([
("text", text_tokens),
("gate_weights", gate_weights),
("audio", audio_tokens),
("user_rating", user_rating),
]);
candle::safetensors::save(&st_content, st_filename)?;
Ok(())
})();
run_result.and(log_result)
}
}
type Handle = tokio::task::JoinHandle<Result<()>>;
fn spawn_recv_loops(
mut receiver: SplitStream<ws::WebSocket>,
sender: std::sync::mpsc::Sender<ModelInput>,
) -> Result<(Handle, Handle)> {
use tokio::io::AsyncWriteExt;
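// In-memory pipe: the websocket loop below writes raw Ogg/Opus payloads into
// `tx` while the Ogg packet reader consumes them from `rx`.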
let (mut tx, rx) = tokio::io::duplex(100_000);
let mut pr = ogg::reading::async_api::PacketReader::new(rx);
let mut decoder = opus::Decoder::new(24000, opus::Channels::Mono)?;
let handle1_sender = sender.clone();
let handle1 = tokio::spawn({
async move {
loop {
match receiver.next().await {
None => {
// Close logic: when this loop exits, tx gets dropped, which closes pr,
// which makes the second task exit, which in turn drops sender.
break;
}
Some(v) => {
let v = v?.into_data();
if v.is_empty() {
continue;
}
let msg_type = MsgType::from_u8(v[0])?;
match msg_type {
MsgType::Metadata => {}
MsgType::Handshake => {}
MsgType::Control => {}
MsgType::Text => {}
MsgType::Error => {}
MsgType::Ping => {}
MsgType::Audio => tx.write_all(&v[1..]).await?,
MsgType::ColoredText => {}
MsgType::Image => {
handle1_sender.send(ModelInput::Image(v[1..].to_vec()))?
}
MsgType::UserRating => {
handle1_sender.send(ModelInput::Rating(v[1] as u32))?
}
}
}
}
}
tracing::info!("socket closed");
Ok::<_, anyhow::Error>(())
}
});
let handle2 = tokio::spawn(async move {
// TODO: dynamic sizing?
let mut pcm_buf = vec![0f32; 24_000 * 10];
let mut size_in_buf = 0;
loop {
match pr.next().await {
None => {
break;
}
Some(packet) => {
let packet = packet?;
if packet.data.starts_with(b"OpusHead") || packet.data.starts_with(b"OpusTags")
{
continue;
}
let read_size = decoder.decode_float(
&packet.data,
&mut pcm_buf[size_in_buf..],
/* Forward Error Correction */ false,
)?;
size_in_buf += read_size;
// flush the decoded audio every half model step (24_000 / 25 = 960 samples, i.e. 40ms at 24kHz)
if size_in_buf >= 24_000 / 25 {
if sender
.send(ModelInput::Audio(pcm_buf[..size_in_buf].to_vec()))
.is_err()
{
break;
}
size_in_buf = 0;
}
}
}
}
tracing::info!("decoder closed");
Ok::<_, anyhow::Error>(())
});
Ok((handle1, handle2))
}
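/// Forwards the StreamOut messages produced by the model thread to the
/// websocket client, recording the model step duration metric whenever a pcm
/// chunk is emitted after fresh input audio.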
async fn sender_loop(
mut stream_out_rx: tokio::sync::mpsc::UnboundedReceiver<StreamOut>,
mut sender: MsgSender,
) -> Result<()> {
// It is important for the recv here to be an async-aware one: a blocking
// recv on the tokio runtime could otherwise lead to deadlocks.
while let Some(v) = stream_out_rx.recv().await {
match v {
StreamOut::Pcm { pcm } => {
if let Some(last_input_pcm) = sender.last_input_pcm {
let model_step_duration = last_input_pcm.elapsed().as_secs_f64();
crate::metrics::worker::MODEL_STEP_DURATION.observe(model_step_duration);
sender.last_input_pcm = None;
}
sender.send_pcm(pcm).await?
}
StreamOut::Ready => sender.send_ready().await?,
StreamOut::MetaData { metadata } => sender.send_metadata(metadata).await?,
StreamOut::ColoredText { text, intensity } => {
sender.send_colored_text(text, intensity).await?
}
}
}
Ok::<_, anyhow::Error>(())
}
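/// Handles one full session over a websocket: spawns the two receive loops,
/// a dedicated thread running the model, and the sender loop, then waits for
/// any of them to finish or for the session timeout to expire.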
pub async fn handle_socket(
socket: ws::WebSocket,
sm: StreamingModel,
addr: Option<String>,
) -> Result<()> {
tracing::info!("accepted websocket connection");
let (sender, receiver) = socket.split();
let sender = MsgSender::new(sender)?;
tracing::info!("starting streaming");
let (in_pcm_tx, in_pcm_rx) = std::sync::mpsc::channel();
let (stream_out_tx, stream_out_rx) = tokio::sync::mpsc::unbounded_channel();
let (loop1, loop2) = spawn_recv_loops(receiver, in_pcm_tx)?;
std::thread::spawn(move || {
if let Err(err) = sm.run(in_pcm_rx, stream_out_tx, addr) {
tracing::error!("{err}")
}
});
let sender_loop = tokio::spawn(async move {
match sender_loop(stream_out_rx, sender).await {
Ok(()) => tracing::info!("sender closed"),
Err(err) => {
// Using the Display trait rather than the Debug one so as not to include the backtrace.
let err = format!("{err}");
tracing::info!(err, "sender err")
}
}
});
let sleep = tokio::time::sleep(std::time::Duration::from_secs(360));
tokio::pin!(sleep);
// select should ensure that all the threads get aborted on timeout.
tokio::select! {
_ = &mut sleep => {
tracing::error!("reached timeout");
}
r = loop1 => {
tracing::error!(?r, "loop1 ended")
}
r = loop2 => {
tracing::error!(?r, "loop2 ended")
}
r = sender_loop => {
tracing::error!(?r, "sender loop ended")
}
}
Ok(())
}
================================================
FILE: kyuteye_rs/moshi-backend/src/utils.rs
================================================
#[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
pub struct BuildInfo {
build_timestamp: String,
build_date: String,
git_branch: String,
git_timestamp: String,
git_date: String,
git_hash: String,
git_describe: String,
rustc_host_triple: String,
rustc_version: String,
cargo_target_triple: String,
}
impl BuildInfo {
pub fn new() -> BuildInfo {
BuildInfo {
build_timestamp: String::from(env!("VERGEN_BUILD_TIMESTAMP")),
build_date: String::from(env!("VERGEN_BUILD_DATE")),
git_branch: String::from(env!("VERGEN_GIT_BRANCH")),
git_timestamp: String::from(env!("VERGEN_GIT_COMMIT_TIMESTAMP")),
git_date: String::from(env!("VERGEN_GIT_COMMIT_DATE")),
git_hash: String::from(env!("VERGEN_GIT_SHA")),
git_describe: String::from(env!("VERGEN_GIT_DESCRIBE")),
rustc_host_triple: String::from(env!("VERGEN_RUSTC_HOST_TRIPLE")),
rustc_version: String::from(env!("VERGEN_RUSTC_SEMVER")),
cargo_target_triple: String::from(env!("VERGEN_CARGO_TARGET_TRIPLE")),
}
}
}
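/// Wraps an anyhow::Result so that it can be returned from an axum handler:
/// Ok values are serialized to JSON, errors become a 500 response carrying
/// the error message.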
pub struct WrapJson<T>(pub anyhow::Result<T>);
impl<T: serde::Serialize> axum::response::IntoResponse for WrapJson<T> {
fn into_response(self) -> axum::response::Response {
match self.0 {
Ok(v) => axum::Json(v).into_response(),
Err(err) => {
tracing::error!(?err, "returning internal server error 500");
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("{err}"),
)
.into_response()
}
}
}
}
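/// Replaces `$VAR`-style references in `input` with the value of the
/// corresponding environment variable, or with the empty string when the
/// variable is unset.
///
/// A minimal usage sketch (the variable name and crate path are illustrative):
/// ```no_run
/// # use moshi_backend::utils::replace_env_vars;
/// std::env::set_var("LOG_DIR", "/tmp/logs");
/// assert_eq!(replace_env_vars("$LOG_DIR/session"), "/tmp/logs/session");
/// ```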
pub fn replace_env_vars(input: &str) -> String {
let re = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").unwrap();
re.replace_all(input, |caps: &regex::Captures| {
let var_name = &caps[1];
std::env::var(var_name).unwrap_or_else(|_| "".to_string())
})
.to_string()
}
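/// Same as WrapJson, but serializing the Ok payload with bincode.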
pub struct WrapBincode<T>(pub anyhow::Result<T>);
impl<T: serde::Serialize> axum::response::IntoResponse for WrapBincode<T> {
fn into_response(self) -> axum::response::Response {
match self.0.and_then(|v| Ok(bincode::serialize(&v)?)) {
Ok(v) => (axum::http::StatusCode::OK, v).into_response(),
Err(err) => {
tracing::error!(?err, "returning internal server error 500");
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("{err}"),
)
.into_response()
}
}
}
}
pub fn default_static_dir() -> String {
"./client/dist".to_string()
}
pub struct AxumError(anyhow::Error);
impl axum::response::IntoResponse for AxumError {
fn into_response(self) -> axum::response::Response {
let err = self.0;
tracing::error!(?err);
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("{err:?}"),
)
.into_response()
}
}
impl<E: Into<anyhow::Error>> From<E> for AxumError {
fn from(value: E) -> Self {
Self(value.into())
}
}
pub type AxumResult<R> = std::result::Result<R, AxumError>;
================================================
FILE: kyuteye_rs/moshi-core/Cargo.toml
================================================
[package]
name = "moshi"
version = "0.1.0"
edition = "2021"
[dependencies]
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
cudarc = { version = "=0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false, optional = true }
rand = { workspace = true }
rayon = "1.8.1"
serde = { version = "1.0", features = ["derive"] }
tracing = "0.1.40"
[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda", "cudarc"]
metal = ["candle/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
[profile.release]
debug = true
[profile.release-no-debug]
inherits = "release"
debug = false
================================================
FILE: kyuteye_rs/moshi-core/src/conv.rs
================================================
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use crate::streaming::{StreamTensor, StreamingModule};
use candle::{Module, Result, Tensor, D};
use candle_nn::{Conv1d, VarBuilder};
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Norm {
WeightNorm,
SpectralNorm,
TimeGroupNorm,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PadMode {
Constant,
Reflect,
Replicate,
}
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
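// Concretely: weight = weight_g * weight_v / ||weight_v||, with the norm of
// weight_v taken over the (in_c, kernel_size) dimensions.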
fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
bias: bool,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = if vb.contains_tensor("weight") {
vb.get((out_c, in_c, kernel_size), "weight")?
} else {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
};
let bias = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
Ok(Conv1d::new(weight, bias, config))
}
#[derive(Debug, Clone)]
pub struct NormConv1d {
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
span: tracing::Span,
}
impl NormConv1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
causal: bool,
norm: Option<Norm>,
bias: bool,
cfg: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Self> {
let conv = match norm {
None | Some(Norm::TimeGroupNorm) => {
if bias {
candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
} else {
candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
}
}
Some(Norm::WeightNorm) => {
conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
}
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
};
let norm = match norm {
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
Some(Norm::TimeGroupNorm) => {
if causal {
candle::bail!("GroupNorm doesn't support causal evaluation.")
}
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(norm)
}
};
Ok(Self {
conv,
norm,
span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
})
}
}
impl Module for NormConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = xs.apply(&self.conv)?;
match self.norm.as_ref() {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
#[derive(Debug, Clone)]
pub struct NormConvTranspose1d {
ws: Tensor,
bs: Option<Tensor>,
k_size: usize,
stride: usize,
groups: usize,
norm: Option<candle_nn::GroupNorm>,
span: tracing::Span,
}
impl NormConvTranspose1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
causal: bool,
norm: Option<Norm>,
bias: bool,
stride: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("convtr");
let bs = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
let ws = match norm {
None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
Some(Norm::WeightNorm) => {
if vb.contains_tensor("weight") {
vb.get((in_c, out_c, k_size), "weight")?
} else {
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
}
}
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
};
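// Depthwise case (groups == in_c == out_c): expand the weight into an
// equivalent dense kernel by masking a repeated copy with the identity so
// that the transposed convolution can run with groups = 1.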
let (ws, groups) = if groups == out_c && in_c == out_c {
let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
let ws = ws
.repeat((1, out_c, 1))?
.mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
(ws, 1)
} else {
(ws, groups)
};
let norm = match norm {
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
Some(Norm::TimeGroupNorm) => {
if causal {
candle::bail!("GroupNorm doesn't support causal evaluation.")
}
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(norm)
}
};
Ok(Self {
ws,
bs,
k_size,
stride,
groups,
norm,
span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
})
}
}
impl Module for NormConvTranspose1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
// conv-transpose1d seems to be broken on metal after enough iterations. Causing
// the following error:
// _status < MTLCommandBufferStatusCommitted >
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
// This is now fixed in candle.
let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
let xs = match &self.bs {
None => xs,
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1))?;
xs.broadcast_add(&bias)?
}
};
match self.norm.as_ref() {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
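// Extra right-padding needed so that the last convolution window is full:
// with n_frames = (len + padding_total - k_size) / stride + 1, the ideal
// input length is (ceil(n_frames) - 1) * stride + k_size - padding_total.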
fn get_extra_padding_for_conv1d(
xs: &Tensor,
k_size: usize,
stride: usize,
padding_total: usize,
) -> Result<usize> {
let len = xs.dim(D::Minus1)?;
let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
let ideal_len =
((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
Ok(ideal_len.saturating_sub(len))
}
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
match mode {
PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
}
}
fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
let len = xs.dim(D::Minus1)?;
if len < unpad_l + unpad_r {
candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
}
xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
}
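// A streaming-friendly 1d convolution: state_prev_xs caches the input tail
// that has not yet produced a full output frame, so that step-by-step
// evaluation matches the batch forward exactly (see the tests below).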
#[derive(Debug, Clone)]
pub struct StreamableConv1d {
conv: NormConv1d,
causal: bool,
pad_mode: PadMode,
state_prev_xs: StreamTensor,
left_pad_applied: bool,
kernel_size: usize,
span: tracing::Span,
}
impl StreamableConv1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
stride: usize,
dilation: usize,
groups: usize,
bias: bool,
causal: bool,
norm: Option<Norm>,
pad_mode: PadMode,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::Conv1dConfig {
padding: 0,
stride,
dilation,
groups,
};
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb.pp("conv"))?;
if k_size < stride {
candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
}
Ok(Self {
conv,
causal,
pad_mode,
state_prev_xs: StreamTensor::empty(),
left_pad_applied: false,
kernel_size: k_size,
span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
})
}
}
impl Module for StreamableConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b, _t, _c) = xs.dims3()?;
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
let conv_cfg = self.conv.conv.config();
// Effective kernel size with dilations.
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
let padding_total = k_size - conv_cfg.stride;
let extra_padding =
get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
let xs = if self.causal {
pad1d(xs, padding_total, extra_padding, self.pad_mode)?
} else {
let padding_right = padding_total / 2;
let padding_left = padding_total - padding_right;
pad1d(
xs,
padding_left,
padding_right + extra_padding,
self.pad_mode,
)?
};
xs.apply(&self.conv)
}
}
impl StreamingModule for StreamableConv1d {
fn reset_state(&mut self) {
self.state_prev_xs.reset();
self.left_pad_applied = false;
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let xs = match xs.as_option() {
None => return Ok(().into()),
Some(xs) => xs.clone(),
};
let xs = if self.left_pad_applied {
xs
} else {
self.left_pad_applied = true;
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
let conv_cfg = self.conv.conv.config();
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
let padding_total = k_size - conv_cfg.stride;
pad1d(&xs, padding_total, 0, self.pad_mode)?
};
let cfg = self.conv.conv.config();
let stride = cfg.stride;
let dilation = cfg.dilation;
let kernel = (self.kernel_size - 1) * dilation + 1;
let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
let seq_len = xs.seq_len(D::Minus1)?;
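// Number of complete frames available, i.e. floor((seq_len - kernel) / stride) + 1
// when seq_len >= kernel, written with saturating arithmetic.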
let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
if num_frames > 0 {
let offset = num_frames * stride;
self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
let in_l = (num_frames - 1) * stride + kernel;
let xs = xs.narrow(D::Minus1, 0, in_l)?;
// We apply the underlying conv directly rather than through forward so as
// not to apply any padding here.
xs.apply(&self.conv.conv)
} else {
self.state_prev_xs = xs;
Ok(StreamTensor::empty())
}
}
}
#[derive(Debug, Clone)]
pub struct StreamableConvTranspose1d {
convtr: NormConvTranspose1d,
causal: bool,
state_prev_ys: StreamTensor,
kernel_size: usize,
span: tracing::Span,
}
impl StreamableConvTranspose1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
stride: usize,
groups: usize,
bias: bool,
causal: bool,
norm: Option<Norm>,
vb: VarBuilder,
) -> Result<Self> {
let convtr = NormConvTranspose1d::new(
in_c,
out_c,
k_size,
causal,
norm,
bias,
stride,
groups,
vb.pp("convtr"),
)?;
Ok(Self {
convtr,
causal,
kernel_size: k_size,
state_prev_ys: StreamTensor::empty(),
span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
})
}
}
impl Module for StreamableConvTranspose1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let k_size = self.convtr.k_size;
let stride = self.convtr.stride;
let padding_total = k_size.saturating_sub(stride);
let xs = xs.apply(&self.convtr)?;
if self.causal {
// This corresponds to trim_right_ratio = 1.
unpad1d(&xs, 0, padding_total)
} else {
let padding_right = padding_total / 2;
let padding_left = padding_total - padding_right;
unpad1d(&xs, padding_left, padding_right)
}
}
}
impl StreamingModule for StreamableConvTranspose1d {
fn reset_state(&mut self) {
self.state_prev_ys.reset()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let xs = match xs.as_option() {
Some(xs) => xs,
None => return Ok(StreamTensor::empty()),
};
let stride = self.convtr.stride;
// We apply the underlying convtr directly rather than through the streaming
// forward so that no unpadding is applied here.
let ys = self.convtr.forward(xs)?;
let ot = ys.dim(D::Minus1)?;
let ys = match self.state_prev_ys.as_option() {
None => ys,
Some(prev_ys) => {
let pt = prev_ys.dim(D::Minus1)?;
// Remove the bias as it will be applied multiple times.
let prev_ys = match &self.convtr.bs {
None => prev_ys.clone(),
Some(bias) => {
let bias = bias.reshape((1, (), 1))?;
prev_ys.broadcast_sub(&bias)?
}
};
let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
Tensor::cat(&[ys1, ys2], D::Minus1)?
}
};
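// The trailing kernel_size - stride outputs still miss contributions from
// future inputs: hold them back in state_prev_ys and overlap-add them onto
// the next chunk.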
let invalid_steps = self.kernel_size - stride;
let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
self.state_prev_ys = prev_ys;
Ok(ys)
}
}
#[derive(Debug, Clone)]
pub struct ConvDownsample1d {
conv: StreamableConv1d,
}
impl ConvDownsample1d {
pub fn new(
stride: usize,
dim: usize,
causal: bool,
learnt: bool,
vb: VarBuilder,
) -> Result<Self> {
if !learnt {
candle::bail!("only learnt=true is supported")
}
let conv = StreamableConv1d::new(
/* in_c */ dim,
/* out_c */ dim,
/* k_size */ 2 * stride,
/* stride */ stride,
/* dilation */ 1,
/* groups */ 1, // channel_wise = false
/* bias */ false,
/* causal */ causal,
/* norm */ None,
/* pad_mode */ PadMode::Replicate,
vb.pp("conv"),
)?;
Ok(Self { conv })
}
}
impl Module for ConvDownsample1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.conv)
}
}
impl StreamingModule for ConvDownsample1d {
fn reset_state(&mut self) {
self.conv.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
self.conv.step(xs)
}
}
#[derive(Debug, Clone)]
pub struct ConvTrUpsample1d {
convtr: StreamableConvTranspose1d,
}
impl ConvTrUpsample1d {
pub fn new(
stride: usize,
dim: usize,
causal: bool,
learnt: bool,
vb: VarBuilder,
) -> Result<Self> {
if !learnt {
candle::bail!("only learnt=true is supported")
}
let convtr = StreamableConvTranspose1d::new(
dim,
dim,
/* k_size */ 2 * stride,
/* stride */ stride,
/* groups */ dim,
/* bias */ false,
/* causal */ causal,
/* norm */ None,
vb.pp("convtr"),
)?;
Ok(Self { convtr })
}
}
impl Module for ConvTrUpsample1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.convtr)
}
}
impl StreamingModule for ConvTrUpsample1d {
fn reset_state(&mut self) {
self.convtr.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
self.convtr.step(xs)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle::IndexOp;
fn run_conv1d(
k_size: usize,
stride: usize,
dilation: usize,
step_size: usize,
len: usize,
bias: bool,
) -> Result<()> {
// TODO: We should ensure the seed is constant when running these tests.
let dev = &candle::Device::Cpu;
let vm = candle_nn::VarMap::new();
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
let conv1d = StreamableConv1d::new(
/* in_c */ 2,
/* out_c */ 3,
/* k_size */ k_size,
/* stride */ stride,
/* dilation */ dilation,
/* groups */ 1,
/* bias */ bias,
/* causal */ true,
/* norm */ None,
/* pad_mode */ PadMode::Constant,
vb,
)?;
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
let ys = conv1d.forward(&xs)?;
let mut conv1d = conv1d;
let mut ys_steps = vec![];
for idx in 0..len {
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
let ys = conv1d.step(&xs.into())?;
if let Some(ys) = ys.as_option() {
ys_steps.push(ys.clone())
}
}
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
let diff = (&ys - &ys_steps)?
.abs()?
.flatten_all()?
.max(0)?
.to_vec0::<f32>()?;
if diff > 1e-5 {
println!("{xs}");
println!("{ys}");
println!("{ys_steps}");
candle::bail!("larger diff than expected {diff}")
}
Ok(())
}
fn run_conv_tr1d(
k_size: usize,
stride: usize,
step_size: usize,
len: usize,
bias: bool,
) -> Result<()> {
// TODO: We should ensure the seed is constant when running these tests.
let dev = &candle::Device::Cpu;
let vm = candle_nn::VarMap::new();
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
let conv1d = StreamableConvTranspose1d::new(
/* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
/* stride */ stride, /* groups */ 1, /* bias */ bias,
/* causal */ true, /* norm */ None, vb,
)?;
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
let ys = conv1d.forward(&xs)?;
let mut conv1d = conv1d;
let mut ys_steps = vec![];
for idx in 0..len {
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
let ys = conv1d.step(&xs.into())?;
if let Some(ys) = ys.as_option() {
ys_steps.push(ys.clone())
}
}
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
let diff = (&ys - &ys_steps)?
.abs()?
.flatten_all()?
.max(0)?
.to_vec0::<f32>()?;
if diff > 1e-5 {
println!("{xs}");
println!("{ys}");
println!("{ys_steps}");
candle::bail!("larger diff than expected {diff}")
}
Ok(())
}
#[test]
fn conv1d() -> Result<()> {
for step_size in [1, 2, 3] {
for bias in [false, true] {
run_conv1d(1, 1, 1, step_size, 5, bias)?;
run_conv1d(2, 1, 1, step_size, 5, bias)?;
run_conv1d(2, 2, 1, step_size, 6, bias)?;
run_conv1d(3, 2, 1, step_size, 8, bias)?;
run_conv1d(3, 2, 2, step_size, 8, bias)?;
}
}
Ok(())
}
#[test]
fn conv_tr1d() -> Result<()> {
for step_size in [1, 2, 3] {
for bias in [false, true] {
run_conv_tr1d(1, 1, step_size, 5, bias)?;
run_conv_tr1d(2, 1, step_size, 5, bias)?;
run_conv_tr1d(3, 1, step_size, 5, bias)?;
run_conv_tr1d(3, 2, step_size, 5, bias)?;
}
}
Ok(())
}
}
================================================
FILE: kyuteye_rs/moshi-core/src/dynamic_logits_processor.rs
================================================
use candle::{Context, DType, Error, Result, Tensor};
use candle_transformers::generation::Sampling;
use rand::{distributions::Distribution, SeedableRng};
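/// A logits processor for the text stream whose sampling temperature is
/// modulated by the image cross-attention gate weight, with the strength of
/// the modulation controlled by text_temperature_gating_influence.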
pub struct GateInfluencedLogitsProcessor {
rng: rand::rngs::StdRng,
sampling: Sampling,
text_temperature_gating_influence: f64,
}
impl GateInfluencedLogitsProcessor {
pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
let rng = rand::rngs::StdRng::seed_from_u64(seed);
Self {
rng,
sampling,
text_temperature_gating_influence: 0.0,
}
}
pub fn from_sampling_with_scale(
seed: u64,
sampling: Sampling,
text_temperature_gating_influence: f64,
) -> Self {
let rng = rand::rngs::StdRng::seed_from_u64(seed);
Self {
rng,
sampling,
text_temperature_gating_influence,
}
}
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
let sampling = match temperature {
None => Sampling::ArgMax,
Some(temperature) => match top_p {
None => Sampling::All { temperature },
Some(p) => Sampling::TopP { p, temperature },
},
};
Self::from_sampling(seed, sampling)
}
fn sample_argmax(&mut self, logits: Tensor) -> Result<u32>