Repository: boson-ai/higgs-audio
Branch: main
Commit: 8b1539a02d57
Files: 79
Total size: 463.3 KB
Directory structure:
gitextract_40a6xiei/
├── .github/
│ └── workflows/
│ └── test.yml
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── SUPPORT_GUIDELINES.md
├── boson_multimodal/
│ ├── __init__.py
│ ├── audio_processing/
│ │ ├── LICENSE
│ │ ├── descriptaudiocodec/
│ │ │ ├── __init__.py
│ │ │ └── dac/
│ │ │ ├── model/
│ │ │ │ ├── base.py
│ │ │ │ └── dac.py
│ │ │ └── nn/
│ │ │ ├── layers.py
│ │ │ └── quantize.py
│ │ ├── higgs_audio_tokenizer.py
│ │ ├── quantization/
│ │ │ ├── __init__.py
│ │ │ ├── ac.py
│ │ │ ├── core_vq.py
│ │ │ ├── core_vq_lsx_version.py
│ │ │ ├── ddp_utils.py
│ │ │ ├── distrib.py
│ │ │ └── vq.py
│ │ └── semantic_module.py
│ ├── constants.py
│ ├── data_collator/
│ │ ├── __init__.py
│ │ └── higgs_audio_collator.py
│ ├── data_types.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ └── chatml_dataset.py
│ ├── model/
│ │ ├── __init__.py
│ │ └── higgs_audio/
│ │ ├── __init__.py
│ │ ├── audio_head.py
│ │ ├── common.py
│ │ ├── configuration_higgs_audio.py
│ │ ├── cuda_graph_runner.py
│ │ ├── custom_modules.py
│ │ ├── modeling_higgs_audio.py
│ │ └── utils.py
│ └── serve/
│ ├── serve_engine.py
│ └── utils.py
├── examples/
│ ├── README.md
│ ├── generation.py
│ ├── scene_prompts/
│ │ ├── quiet_indoor.txt
│ │ └── reading_blog.txt
│ ├── serve_engine/
│ │ ├── README.md
│ │ ├── input_samples.py
│ │ └── run_hf_example.py
│ ├── transcript/
│ │ ├── multi_speaker/
│ │ │ ├── en_argument.txt
│ │ │ └── en_higgs.txt
│ │ └── single_speaker/
│ │ ├── en_basic.txt
│ │ ├── en_dl.txt
│ │ ├── en_higgs_audio_blog.md
│ │ ├── experimental/
│ │ │ ├── en_bgm.txt
│ │ │ └── en_humming.txt
│ │ └── zh_ai.txt
│ ├── vllm/
│ │ ├── README.md
│ │ └── run_chat_completion.py
│ └── voice_prompts/
│ ├── belinda.txt
│ ├── bigbang_amy.txt
│ ├── bigbang_sheldon.txt
│ ├── broom_salesman.txt
│ ├── chadwick.txt
│ ├── en_man.txt
│ ├── en_woman.txt
│ ├── fiftyshades_anna.txt
│ ├── mabaoguo.txt
│ ├── mabel.txt
│ ├── profile.yaml
│ ├── shrek_donkey.txt
│ ├── shrek_donkey_es.txt
│ ├── shrek_fiona.txt
│ ├── shrek_shrek.txt
│ ├── vex.txt
│ └── zh_man_sichuan.txt
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
└── tech_blogs/
├── ARCHITECTURE_BLOG.md
└── TOKENIZER_BLOG.md
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/test.yml
================================================
name: Unit Test
on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
jobs:
  lint:
    name: Lint
    runs-on: ubuntu-22.04
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Check Code Formatting with Ruff
        run: |
          echo "python version: $(python --version)"
          pip install ruff==0.12.2 # Ensure ruff is installed
          ruff format --check .
================================================
FILE: .gitignore
================================================
# Temporary files generated in training
dpo_samples*
scoring_results
results/
hf_slurm_logs/
slurm_results/
enroot_images/
slurm*.out
cache_*
mlruns/
local_download_dir/
audioverse/data
# the folder pattern is sft_{year}.
sft_20*
data/
audioverse/cache
# vim ipython plugin generated files
.jukit
# node
node_modules
package.json
package-lock.json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
!tests/*
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
/.conda_env*
/.env*
/.higgs_audio_env*
/.venv*
/conda_env*
/env*
/ENV*
/higgs_audio_env*
/venv*
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.jsonl
download
.DS_Store
*entry.py
# Pytorch
torch_compile_debug/
# Out Dir
result/
# Ruff
.ruff_cache/
================================================
FILE: .gitmodules
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
================================================
FILE: README.md
================================================
<h1 align="center">Higgs Audio: Redefining Expressiveness in Audio Generation</h1>
<div align="center" style="display: flex; justify-content: center; margin-top: 10px;">
<a href="https://boson.ai/blog/higgs-audio-v2"><img src='https://img.shields.io/badge/🚀-V2 Blogpost-228B22' style="margin-right: 5px;"></a>
<a href="https://www.boson.ai/blog/higgs-audio-v2.5"><img src='https://img.shields.io/badge/🚀-V2.5 Blogpost-228B22' style="margin-right: 5px;"></a>
<a href="https://boson.ai/demo/tts"><img src="https://img.shields.io/badge/🕹️-Boson%20AI%20Playground-9C276A" style="margin-right: 5px;"></a>
<a href="https://huggingface.co/spaces/smola/higgs_audio_v2"><img src="https://img.shields.io/badge/🎮-HF%20Space%20Playground-8A2BE2" style="margin-right: 5px;"></a>
<a href="https://huggingface.co/bosonai/higgs-audio-v2-generation-3B-base"><img src="https://img.shields.io/badge/🤗-Checkpoints (3.6B LLM + 2.2B audio adapter)-ED5A22.svg" style="margin-right: 5px;"></a>
</div>
## NEWS!
We are proud to launch **Higgs-Audio V2.5**, the latest iteration of Boson AI’s Audio model, designed to bring high-fidelity generation into production environments. Building on Higgs-Audio V2, this release combines improved efficiency with the stability required for real-world deployment.
With V2.5, we condensed the model architecture to 1B parameters while surpassing the speed and accuracy of the prior 3B model. This result comes from a new alignment strategy that applies Group Relative Policy Optimization (GRPO) on our curated Voice Bank dataset, combined with improved voice cloning and finer-grained style control.
For detailed model performance, key improvements, and usage, please check our [blog](https://www.boson.ai/blog/higgs-audio-v2.5).
## Higgs Audio V2
We are open-sourcing Higgs Audio v2, a powerful audio foundation model pretrained on over 10 million hours of audio data and a diverse set of text data. Despite having no post-training or fine-tuning, Higgs Audio v2 excels in expressive audio generation, thanks to its deep language and acoustic understanding.
On [EmergentTTS-Eval](https://github.com/boson-ai/emergenttts-eval-public), it achieves win rates of **75.7%** and **55.7%** over "gpt-4o-mini-tts" on the "Emotions" and "Questions" categories, respectively. It also obtains state-of-the-art performance on traditional TTS benchmarks like Seed-TTS Eval and Emotional Speech Dataset (ESD). Moreover, the model demonstrates capabilities rarely seen in previous systems, including generating natural multi-speaker dialogues in multiple languages, automatic prosody adaptation during narration, melodic humming with the cloned voice, and simultaneous generation of speech and background music.
<p align="center">
<img src="figures/emergent-tts-emotions-win-rate.png" width=900>
</p>
Here's the demo video that shows some of its emergent capabilities (remember to unmute):
<video src="https://github.com/user-attachments/assets/0fd73fad-097f-48a9-9f3f-bc2a63b3818d" type="video/mp4" width="80%" controls>
</video>
Here's another demo video that showcases the model's multilingual capability and how it enables live translation (remember to unmute):
<video src="https://github.com/user-attachments/assets/2b9b01ff-67fc-4bd9-9714-7c7df09e38d6" type="video/mp4" width="80%" controls>
</video>
## Installation
We recommend using an NVIDIA Deep Learning Container to manage the CUDA environment. We have verified the following two Docker images:
- nvcr.io/nvidia/pytorch:25.02-py3
- nvcr.io/nvidia/pytorch:25.01-py3
Here's an example command for launching a Docker container environment. Please also check the [official NVIDIA documentation](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).
```bash
docker run --gpus all --ipc=host --net=host --ulimit memlock=-1 --ulimit stack=67108864 -it --rm nvcr.io/nvidia/pytorch:25.02-py3 bash
```
### Option 1: Direct installation
```bash
git clone https://github.com/boson-ai/higgs-audio.git
cd higgs-audio
pip install -r requirements.txt
pip install -e .
```
### Option 2: Using venv
```bash
git clone https://github.com/boson-ai/higgs-audio.git
cd higgs-audio
python3 -m venv higgs_audio_env
source higgs_audio_env/bin/activate
pip install -r requirements.txt
pip install -e .
```
### Option 3: Using conda
```bash
git clone https://github.com/boson-ai/higgs-audio.git
cd higgs-audio
conda create -y --prefix ./conda_env --override-channels --strict-channel-priority --channel "conda-forge" "python==3.10.*"
conda activate ./conda_env
pip install -r requirements.txt
pip install -e .
# Uninstalling environment:
conda deactivate
conda remove -y --prefix ./conda_env --all
```
### Option 4: Using uv
```bash
git clone https://github.com/boson-ai/higgs-audio.git
cd higgs-audio
uv venv --python 3.10
source .venv/bin/activate
uv pip install -r requirements.txt
uv pip install -e .
```
### Option 5: Using vllm
For advanced usage with higher throughput, we also provide an OpenAI-compatible API server backed by the vLLM engine.
Please refer to [examples/vllm](./examples/vllm) for more details.
## Usage
> [!TIP]
> For optimal performance, run the generation examples on a machine with a GPU that has at least 24GB of memory!
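To check a machine against that recommendation before downloading the checkpoints, here is a minimal sketch using PyTorch's CUDA device properties. It reads "24GB" as decimal gigabytes (24 × 10^9 bytes), which is an assumption; the helper names are hypothetical and not part of `boson_multimodal`.

```python
from typing import List, Tuple


def meets_gpu_memory_recommendation(total_memory_bytes: int, min_gb: float = 24.0) -> bool:
    """Return True if a device has at least `min_gb` of memory.

    Assumption: "24GB" is interpreted as decimal gigabytes (24 * 10**9 bytes).
    """
    return total_memory_bytes >= min_gb * 1e9


def check_local_gpus(min_gb: float = 24.0) -> List[Tuple[int, str, float, bool]]:
    """List (index, name, memory in GB, meets recommendation) per visible CUDA device.

    Returns an empty list if PyTorch is not installed or CUDA is unavailable.
    """
    try:
        import torch
    except ImportError:
        return []
    if not torch.cuda.is_available():
        return []
    report = []
    for index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(index)
        report.append(
            (
                index,
                props.name,
                props.total_memory / 1e9,
                meets_gpu_memory_recommendation(props.total_memory, min_gb),
            )
        )
    return report


if __name__ == "__main__":
    gpus = check_local_gpus()
    if not gpus:
        print("No CUDA device found.")
    for index, name, gb, ok in gpus:
        status = "OK" if ok else "below the 24GB recommendation"
        print(f"cuda:{index} {name}: {gb:.1f} GB -> {status}")
```

The threshold is a plain function so it can be checked without a GPU; only `check_local_gpus` touches PyTorch.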
### Get Started
Here's a basic python snippet to help you get started.
```python
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
from boson_multimodal.data_types import ChatMLSample, Message
import torch
import torchaudio

MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"

system_prompt = (
    "Generate audio following instruction.\n\n<|scene_desc_start|>\nAudio is recorded from a quiet room.\n<|scene_desc_end|>"
)

messages = [
    Message(
        role="system",
        content=system_prompt,
    ),
    Message(
        role="user",
        content="The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.",
    ),
]

device = "cuda" if torch.cuda.is_available() else "cpu"
serve_engine = HiggsAudioServeEngine(MODEL_PATH, AUDIO_TOKENIZER_PATH, device=device)

output: HiggsAudioResponse = serve_engine.generate(
    chat_ml_sample=ChatMLSample(messages=messages),
    max_new_tokens=1024,
    temperature=0.3,
    top_p=0.95,
    top_k=50,
    stop_strings=["<|end_of_text|>", "<|eot_id|>"],
)

# torchaudio.save expects a (channels, samples) tensor, so add a leading channel dimension
torchaudio.save("output.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate)
```
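The system prompt above wraps a scene description in `<|scene_desc_start|>` and `<|scene_desc_end|>` tags. If you want to swap scenes, a small helper can build that string from plain text or from a file such as `examples/scene_prompts/quiet_indoor.txt`. This is a sketch that reproduces the format shown in the snippet; the helper names are hypothetical, and it assumes the scene file contains only the description text.

```python
from pathlib import Path

SCENE_DESC_START = "<|scene_desc_start|>"
SCENE_DESC_END = "<|scene_desc_end|>"
INSTRUCTION = "Generate audio following instruction."


def build_system_prompt(scene_description: str) -> str:
    """Wrap a scene description in the tags used by the Get Started system prompt."""
    return f"{INSTRUCTION}\n\n{SCENE_DESC_START}\n{scene_description.strip()}\n{SCENE_DESC_END}"


def build_system_prompt_from_file(path: str) -> str:
    """Read a scene description file (assumed to be plain text) and wrap it."""
    return build_system_prompt(Path(path).read_text(encoding="utf-8"))
```

For example, `system_prompt = build_system_prompt_from_file("examples/scene_prompts/quiet_indoor.txt")` can replace the hard-coded string when run from the repository root.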
More examples are available under [examples](./examples). Below we highlight a few of them to help you use Higgs Audio v2.
### Zero-Shot Voice Cloning
Generate audio that sounds similar to the provided [reference audio](./examples/voice_prompts/belinda.wav).
```bash
python3 examples/generation.py \
--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
--ref_audio belinda \
--temperature 0.3 \
--out_path generation.wav
```
The generation script automatically uses `cuda:0` if it finds that CUDA is available. To use a different device, specify `--device_id`:
```bash
python3 examples/generation.py \
--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
--ref_audio belinda \
--temperature 0.3 \
--device_id 0 \
--out_path generation.wav
```
You can also try other voices; more example voices are available in [examples/voice_prompts](./examples/voice_prompts). You can also add your own voice to that folder.
```bash
python3 examples/generation.py \
--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
--ref_audio broom_salesman \
--temperature 0.3 \
--out_path generation.wav
```
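Before passing a custom voice to `--ref_audio`, it can help to confirm that its files are in place. The sketch below assumes each voice `<name>` is stored as `<name>.wav` plus a `<name>.txt` transcript, which matches `belinda.wav` linked above and `belinda.txt` in `examples/voice_prompts`; the function name is hypothetical and is not how `generation.py` resolves voices internally.

```python
from pathlib import Path
from typing import List, Tuple


def resolve_voice_prompts(
    ref_audio: str, root: str = "examples/voice_prompts"
) -> List[Tuple[Path, Path]]:
    """Map a comma-separated --ref_audio value to (audio, transcript) file pairs.

    Assumed layout: each voice <name> is stored as <name>.wav plus <name>.txt.
    """
    pairs = []
    for name in ref_audio.split(","):
        name = name.strip()
        if not name:
            continue  # tolerate "a,,b" or a trailing comma
        audio = Path(root) / f"{name}.wav"
        transcript = Path(root) / f"{name}.txt"
        missing = [str(p) for p in (audio, transcript) if not p.is_file()]
        if missing:
            raise FileNotFoundError(f"voice '{name}' is incomplete, missing: {', '.join(missing)}")
        pairs.append((audio, transcript))
    return pairs
```

Run from the repository root, `resolve_voice_prompts("belinda,broom_salesman")` should return two pairs if both voices are complete, and raises `FileNotFoundError` naming the missing file otherwise.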
### Single-speaker Generation with Smart Voice
If you do not specify a reference voice, the model chooses a voice based on the transcript.
```bash
python3 examples/generation.py \
--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
--temperature 0.3 \
--out_path generation.wav
```
### Multi-speaker Dialog with Smart Voice
Generate a multi-speaker dialog. The model chooses the voices based on the transcript.
```bash
python3 examples/generation.py \
--transcript examples/transcript/multi_speaker/en_argument.txt \
--seed 12345 \
--out_path generation.wav
```
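Multi-speaker transcripts such as `en_argument.txt` mark who speaks each turn. The sketch below assumes tags of the form `[SPEAKER0]`, `[SPEAKER1]`, ... at the start of each turn (check the example file for the exact convention) and splits a transcript into `(speaker, text)` turns. It is an illustration of the format, not the parsing code used by `generation.py`.

```python
import re
from typing import List, Tuple

# Assumed tag format: "[SPEAKER0]", "[SPEAKER1]", ... at the start of each turn.
SPEAKER_TAG = re.compile(r"\[(SPEAKER\d+)\]")


def parse_turns(transcript: str) -> List[Tuple[str, str]]:
    """Split a speaker-tagged transcript into (speaker, text) turns."""
    # With one capture group, re.split yields [prefix, tag1, text1, tag2, text2, ...];
    # any text before the first tag (the prefix) is discarded.
    parts = SPEAKER_TAG.split(transcript)
    turns = []
    for speaker, text in zip(parts[1::2], parts[2::2]):
        text = " ".join(text.split())  # collapse newlines and repeated spaces
        if text:
            turns.append((speaker, text))
    return turns
```

Keeping the speaker label with each turn makes it easy to count distinct speakers or to group consecutive text per speaker.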
### Multi-speaker Dialog with Voice Clone
Generate multi-speaker dialog with the voices you picked.
```bash
python3 examples/generation.py \
--transcript examples/transcript/multi_speaker/en_argument.txt \
--ref_audio belinda,broom_salesman \
--ref_audio_in_system_message \
--chunk_method speaker \
--seed 12345 \
--out_path generation.wav
```
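With `--ref_audio belinda,broom_salesman`, each listed voice corresponds to one speaker in the transcript. A sketch of that pairing, assuming the i-th voice goes to the i-th distinct speaker id in ascending order (`[SPEAKER0]` gets the first voice); this ordering and the helper name are assumptions, so verify against `generation.py` before relying on it.

```python
import re
from typing import Dict

SPEAKER_ID = re.compile(r"\[SPEAKER(\d+)\]")


def assign_reference_voices(ref_audio: str, transcript: str) -> Dict[str, str]:
    """Pair voices with speakers in ascending speaker-id order (assumed convention)."""
    voices = [v.strip() for v in ref_audio.split(",") if v.strip()]
    speaker_ids = sorted({int(n) for n in SPEAKER_ID.findall(transcript)})
    if len(voices) != len(speaker_ids):
        raise ValueError(
            f"{len(voices)} reference voice(s) given for {len(speaker_ids)} speaker(s)"
        )
    return {f"SPEAKER{i}": voice for i, voice in zip(speaker_ids, voices)}
```

Raising on a count mismatch is a design choice of this sketch: it surfaces a missing or extra voice before any audio is generated.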
## Technical Details
<img src="figures/higgs_audio_v2_architecture_combined.png" width=900>
Higgs Audio v2 adopts the "generation variant" depicted in the architecture figure above. Its strong performance is driven by three key technical innovations:
- We developed an automated annotation pipeline that leverages multiple ASR models, sound event classification models, and our in-house audio understanding model. Using this pipeline, we cleaned and annotated 10 million hours of audio data, which we refer to as **AudioVerse**. The in-house understanding model is fine-tuned on top of [Higgs Audio v1 Understanding](https://www.boson.ai/blog/higgs-audio), which adopts the "understanding variant" shown in the architecture figure.
- We trained a unified audio tokenizer from scratch that captures both semantic and acoustic features. We also open-sourced our evaluation set on [HuggingFace](https://huggingface.co/datasets/bosonai/AudioTokenBench). Learn more in the [tokenizer blog](./tech_blogs/TOKENIZER_BLOG.md).
- We proposed the DualFFN architecture, which enhances the LLM’s ability to model acoustic tokens with minimal computational overhead. See the [architecture blog](./tech_blogs/ARCHITECTURE_BLOG.md).
## Evaluation
Here's the performance of Higgs Audio v2 on four benchmarks, [Seed-TTS Eval](https://github.com/BytedanceSpeech/seed-tts-eval), [Emotional Speech Dataset (ESD)](https://paperswithcode.com/dataset/esd), [EmergentTTS-Eval](https://arxiv.org/abs/2505.23009), and Multi-speaker Eval:
#### Seed-TTS Eval & ESD
We prompt Higgs Audio v2 with the reference text, reference audio, and target text for zero-shot TTS. We use the standard evaluation metrics from Seed-TTS Eval and ESD.
| | SeedTTS-Eval| | ESD | |
|------------------------------|--------|--------|---------|-------------------|
| | WER ↓ | SIM ↑ | WER ↓ | SIM (emo2vec) ↑ |
| Cosyvoice2 | 2.28 | 65.49 | 2.71 | 80.48 |
| Qwen2.5-omni† | 2.33 | 64.10 | - | - |
| ElevenLabs Multilingual V2 | **1.43** | 50.00 | 1.66 | 65.87 |
| Higgs Audio v1 | 2.18 | 66.27 | **1.49** | 82.84 |
| Higgs Audio v2 (base) | 2.44 | **67.70** | 1.78 | **86.13** |
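WER is the word-level edit distance (substitutions + deletions + insertions) between an ASR transcript of the generated audio and the reference text, divided by the number of reference words; the tables here report it as a percentage. The sketch below computes only that core ratio. Real evaluation pipelines, including Seed-TTS Eval, also run an ASR model and apply their own text normalization, whereas this sketch just lowercases and splits on whitespace.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """(substitutions + deletions + insertions) / number of reference words."""
    ref = reference.lower().split()
    hyp = hypothesis.lower().split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between ref[:i-1] and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)  # curr[0]: delete all i reference words
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,  # deletion
                curr[j - 1] + 1,  # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[-1] / len(ref)
```

For example, `word_error_rate("the sun rises in the east", "the sun rose in east")` counts one substitution and one deletion, giving 2/6 (about 33.3%).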
#### EmergentTTS-Eval ("Emotions" and "Questions")
Following the [EmergentTTS-Eval Paper](https://arxiv.org/abs/2505.23009), we report the win-rate over "gpt-4o-mini-tts" with the "alloy" voice. The judge model is Gemini 2.5 Pro.
| Model | Emotions (%) ↑ | Questions (%) ↑ |
|------------------------------------|--------------|----------------|
| Higgs Audio v2 (base) | **75.71%** | **55.71%** |
| [gpt-4o-audio-preview†](https://platform.openai.com/docs/models/gpt-4o-audio-preview) | 61.64% | 47.85% |
| [Hume.AI](https://www.hume.ai/research) | 61.60% | 43.21% |
| **BASELINE:** [gpt-4o-mini-tts](https://platform.openai.com/docs/models/gpt-4o-mini-tts) | 50.00% | 50.00% |
| [Qwen 2.5 Omni†](https://github.com/QwenLM/Qwen2.5-Omni) | 41.60% | 51.78% |
| [minimax/speech-02-hd](https://replicate.com/minimax/speech-02-hd) | 40.86% | 47.32% |
| [ElevenLabs Multilingual v2](https://elevenlabs.io/blog/eleven-multilingual-v2) | 30.35% | 39.46% |
| [DeepGram Aura-2](https://deepgram.com/learn/introducing-aura-2-enterprise-text-to-speech) | 29.28% | 48.21% |
| [Sesame csm-1B](https://github.com/SesameAILabs/csm) | 15.96% | 31.78% |
<sup><sub>'†' means using the strong-prompting method described in the paper.</sub></sup>
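Each row above is a win rate against gpt-4o-mini-tts, whose own row is 50.00%. A minimal sketch of the arithmetic, assuming every judged comparison ends in a win, tie, or loss and that a tie counts as half a win; this tie rule is an assumption, so consult the EmergentTTS-Eval paper for the exact definition.

```python
from typing import Iterable


def win_rate(outcomes: Iterable[str]) -> float:
    """Percentage score vs. the baseline; ties count as half a win (assumed)."""
    points = {"win": 1.0, "tie": 0.5, "loss": 0.0}
    score = 0.0
    n = 0
    for outcome in outcomes:
        if outcome not in points:
            raise ValueError(f"unknown outcome: {outcome!r}")
        score += points[outcome]
        n += 1
    if n == 0:
        raise ValueError("no comparisons given")
    return 100.0 * score / n
```

Under this rule, a system indistinguishable from the baseline converges to 50%, which is why values above 50% indicate the judge preferred the candidate.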
#### Multi-speaker Eval
We also designed a multi-speaker evaluation benchmark to evaluate the capability of Higgs Audio v2 for multi-speaker dialog generation. The benchmark contains three subsets:
- `two-speaker-conversation`: 1000 synthetic dialogues between two randomly chosen personas, each containing 4 to 10 utterances. We fix two reference audio clips to evaluate the model's ability to clone two voices within the same dialogue.
- `small talk (no ref)`: 250 synthetic dialogues curated in the same way as above, but characterized by short utterances and a limited number of turns (4–6). We do not fix reference audio in this case; this set is designed to evaluate the model's ability to automatically assign appropriate voices to speakers.
- `small talk (ref)`: 250 synthetic dialogues similar to the above, but with even shorter utterances, since this set is meant to include reference clips in its context, similar to `two-speaker-conversation`.
We report the word error rate (WER) and the geometric mean of intra-speaker similarity and inter-speaker dissimilarity on these three subsets. Besides Higgs Audio v2, we also evaluated [MoonCast](https://github.com/jzq2000/MoonCast) and [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626), two of the most popular open-source models capable of multi-speaker dialog generation. Results are summarized in the following table. We were unable to run [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) on our "two-speaker-conversation" subset because of its strict limits on utterance length and output audio length.
| | two-speaker-conversation | | small talk (ref) | | small talk (no ref) | |
| ---------------------------------------------- | -------------- | ------------------ | ---------- | -------------- | ------------------- | -------------- |
| | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ |
| [MoonCast](https://github.com/jzq2000/MoonCast) | 38.77 | 46.02 | **8.33** | 63.68 | 24.65 | 53.94 |
| [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) | \- | \- | 17.62 | 63.15 | 19.46 | **61.14** |
| Higgs Audio v2 (base) | **18.88** | **51.95** | 11.89 | **67.92** | **14.65** | 55.28 |
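The "Mean Sim & Dis-sim" columns combine two scores with a geometric mean, which stays high only when both components are high. How each component is measured (for example, which speaker-embedding model is used) is not specified here, so this sketch simply assumes two non-negative scores already on the table's 0 to 100 scale; the function name is hypothetical.

```python
import math


def sim_dissim_score(intra_speaker_sim: float, inter_speaker_dissim: float) -> float:
    """Geometric mean of intra-speaker similarity and inter-speaker dissimilarity.

    Both inputs are assumed to be non-negative and on the same scale (e.g. 0 to 100).
    """
    if intra_speaker_sim < 0 or inter_speaker_dissim < 0:
        raise ValueError("scores must be non-negative")
    return math.sqrt(intra_speaker_sim * inter_speaker_dissim)
```

For instance, `sim_dissim_score(90.0, 30.0)` is about 51.96, well below the arithmetic mean of 60: a system that clones each voice faithfully but makes the two speakers sound alike is penalized more heavily than an arithmetic mean would penalize it.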
## Contribution and Support
For contribution and support guidelines, please see the support guidelines at [SUPPORT_GUIDELINES.md](SUPPORT_GUIDELINES.md).
## Citation
If you find this repository helpful, please cite it as:
```
@misc{higgsaudio2025,
author = {{Boson AI}},
title = {{Higgs Audio V2: Redefining Expressiveness in Audio Generation}},
year = {2025},
howpublished = {\url{https://github.com/boson-ai/higgs-audio}},
note = {GitHub repository. Release blog available at \url{https://www.boson.ai/blog/higgs-audio-v2}},
}
```
## Third-Party Licenses
The `boson_multimodal/audio_processing/` directory contains code derived from third-party repositories, primarily from [xcodec](https://github.com/zhenye234/xcodec). Please see the [`LICENSE`](boson_multimodal/audio_processing/LICENSE) in that directory for complete attribution and licensing information.
## We Are Hiring!
If you are passionate about multimodal AI, speech/audio models, or large-scale systems,
check out our open positions at [Boson AI Careers](https://jobs.lever.co/bosonai).
================================================
FILE: SUPPORT_GUIDELINES.md
================================================
# Contribution & Support Guidelines
Thank you for your interest in this project! Before opening an issue, please take a moment to read the following guidelines:
## Self-Check First
- Write your question in **English** or include an English translation so the community can understand and assist you better.
- Verify that you have **installed the correct version** of the package.
- Check the GitHub [README](README.md), [Hugging Face Space](https://huggingface.co/spaces/smola/higgs_audio_v2), [Model Card](https://huggingface.co/bosonai/higgs-audio-v2-generation-3B-base) and existing issues — many questions already have answers.
- Ensure your problem can be **reproduced** and is directly related to this project.
## Asking Properly
- Provide **clear reproduction steps / minimal code examples / error logs**.
- Keep the issue title **concise and descriptive**, and include enough context in the body.
- Avoid vague questions like *“It doesn’t work, what should I do?”* or *“Can you debug this for me?”*.
## About Support
- This is a **community-driven open source project**. Maintainers will respond when time allows.
- There is **no obligation** to answer every request — please be patient and understanding.
- For more reliable or timely support, consider:
  - Submitting a **Pull Request** to improve code or documentation.
  - Providing detailed context so that the community can help.
## Code of Conduct
- Be **respectful and polite**.
- Do not spam or repeatedly demand responses.
- Off-topic, vague, or inappropriate questions may be closed.
================================================
FILE: boson_multimodal/__init__.py
================================================
================================================
FILE: boson_multimodal/audio_processing/LICENSE
================================================
Third-Party License Attribution for Audio Processing Module
===========================================================
This directory contains code derived from multiple open-source projects.
The following sections detail the licenses and attributions for third-party code.
## XCodec Repository
The code in this directory is derived from:
https://github.com/zhenye234/xcodec
## Individual File Attributions
### Quantization Module (quantization/)
- Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
- Individual files contain their own license headers where applicable
- The vector-quantize-pytorch portions are licensed under the MIT License
## License Terms
### MIT License (for applicable portions)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## Attribution Requirements
When using this code, please ensure proper attribution to:
1. The original xcodec repository: https://github.com/zhenye234/xcodec
2. Any other repositories mentioned in individual file headers
3. This derivative work and its modifications
## Disclaimer
This directory contains modified versions of the original code. Please refer to
the original repositories for the canonical implementations and their specific
license terms.
For any questions about licensing or attribution, please check the individual
file headers and the original source repositories.
================================================
FILE: boson_multimodal/audio_processing/descriptaudiocodec/__init__.py
================================================
================================================
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py
================================================
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import numpy as np
import torch
import tqdm
from audiotools import AudioSignal
from torch import nn
SUPPORTED_VERSIONS = ["1.0.0"]
@dataclass
class DACFile:
codes: torch.Tensor
# Metadata
chunk_length: int
original_length: int
input_db: float
channels: int
sample_rate: int
padding: bool
dac_version: str
def save(self, path):
artifacts = {
"codes": self.codes.numpy().astype(np.uint16),
"metadata": {
"input_db": self.input_db.numpy().astype(np.float32),
"original_length": self.original_length,
"sample_rate": self.sample_rate,
"chunk_length": self.chunk_length,
"channels": self.channels,
"padding": self.padding,
"dac_version": SUPPORTED_VERSIONS[-1],
},
}
path = Path(path).with_suffix(".dac")
with open(path, "wb") as f:
np.save(f, artifacts)
return path
@classmethod
def load(cls, path):
artifacts = np.load(path, allow_pickle=True)[()]
codes = torch.from_numpy(artifacts["codes"].astype(int))
if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
return cls(codes=codes, **artifacts["metadata"])
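# Illustrative sketch (hypothetical helper, not part of descript-audio-codec):
# DACFile.save stores a plain dict via np.save, which pickles it into a 0-d
# object array. DACFile.load unwraps it with np.load(..., allow_pickle=True)[()].
# This in-memory round trip mirrors that format. It assumes every code fits in
# uint16, which holds as long as codebook_size <= 65536.
import io

import numpy as np


def _dac_artifacts_roundtrip_sketch(codes: np.ndarray, metadata: dict) -> dict:
    """Serialize codes and metadata the way DACFile.save does, then read them back."""
    buffer = io.BytesIO()
    np.save(buffer, {"codes": codes.astype(np.uint16), "metadata": metadata})
    buffer.seek(0)
    # [()] indexes the 0-d object array to recover the original dict
    return np.load(buffer, allow_pickle=True)[()]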
class CodecMixin:
@property
def padding(self):
if not hasattr(self, "_padding"):
self._padding = True
return self._padding
@padding.setter
def padding(self, value):
assert isinstance(value, bool)
layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
for layer in layers:
if value:
if hasattr(layer, "original_padding"):
layer.padding = layer.original_padding
else:
layer.original_padding = layer.padding
layer.padding = tuple(0 for _ in range(len(layer.padding)))
self._padding = value
def get_delay(self):
# Any number works here, delay is invariant to input length
l_out = self.get_output_length(0)
L = l_out
layers = []
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
layers.append(layer)
for layer in reversed(layers):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.ConvTranspose1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.Conv1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.ceil(L)
l_in = L
return (l_in - l_out) // 2
def get_output_length(self, input_length):
L = input_length
# Calculate output length
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.Conv1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.ConvTranspose1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.floor(L)
return L
@torch.no_grad()
def compress(
self,
audio_path_or_signal: Union[str, Path, AudioSignal],
win_duration: float = 1.0,
verbose: bool = False,
normalize_db: float = -16,
n_quantizers: int = None,
) -> DACFile:
"""Processes an audio signal from a file or AudioSignal object into
discrete codes. This function processes the signal in short windows,
using constant GPU memory.
Parameters
----------
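        audio_path_or_signal : Union[str, Path, AudioSignal]
            Audio signal (or path to an audio file) to compress
        win_duration : float, optional
            Window duration in seconds, by default 1.0.
            If None, the whole signal is processed as a single window.
        verbose : bool, optional
            Prints progress if True, by default False
        normalize_db : float, optional
            Loudness in dB to normalize the input to before encoding, by default -16
        n_quantizers : int, optional
            Number of quantizers to use, by default None (all quantizers)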
Returns
-------
DACFile
Object containing compressed codes and metadata
required for decompression
"""
audio_signal = audio_path_or_signal
if isinstance(audio_signal, (str, Path)):
audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
self.eval()
original_padding = self.padding
original_device = audio_signal.device
audio_signal = audio_signal.clone()
original_sr = audio_signal.sample_rate
resample_fn = audio_signal.resample
loudness_fn = audio_signal.loudness
        # If audio is >= 10 hours long (10 * 60 * 60 seconds), use the ffmpeg versions
if audio_signal.signal_duration >= 10 * 60 * 60:
resample_fn = audio_signal.ffmpeg_resample
loudness_fn = audio_signal.ffmpeg_loudness
original_length = audio_signal.signal_length
resample_fn(self.sample_rate)
input_db = loudness_fn()
if normalize_db is not None:
audio_signal.normalize(normalize_db)
audio_signal.ensure_max_of_audio()
nb, nac, nt = audio_signal.audio_data.shape
audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
win_duration = audio_signal.signal_duration if win_duration is None else win_duration
if audio_signal.signal_duration <= win_duration:
# Unchunked compression (used if signal length < win duration)
self.padding = True
n_samples = nt
hop = nt
else:
# Chunked inference
self.padding = False
# Zero-pad signal on either side by the delay
audio_signal.zero_pad(self.delay, self.delay)
n_samples = int(win_duration * self.sample_rate)
            # Round n_samples up to the nearest multiple of the hop length
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
hop = self.get_output_length(n_samples)
codes = []
range_fn = range if not verbose else tqdm.trange
for i in range_fn(0, nt, hop):
x = audio_signal[..., i : i + n_samples]
x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
audio_data = x.audio_data.to(self.device)
audio_data = self.preprocess(audio_data, self.sample_rate)
_, c, _, _, _ = self.encode(audio_data, n_quantizers)
codes.append(c.to(original_device))
chunk_length = c.shape[-1]
        codes = torch.cat(codes, dim=-1)
        if n_quantizers is not None:
            # Keep only the first n_quantizers codebooks before storing the codes
            codes = codes[:, :n_quantizers, :]
        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version=SUPPORTED_VERSIONS[-1],
        )
self.padding = original_padding
return dac_file
@torch.no_grad()
def decompress(
self,
obj: Union[str, Path, DACFile],
verbose: bool = False,
) -> AudioSignal:
"""Reconstruct audio from a given .dac file
Parameters
----------
obj : Union[str, Path, DACFile]
.dac file location or corresponding DACFile object.
verbose : bool, optional
Prints progress if True, by default False
Returns
-------
AudioSignal
Object with the reconstructed audio
"""
self.eval()
if isinstance(obj, (str, Path)):
obj = DACFile.load(obj)
original_padding = self.padding
self.padding = obj.padding
range_fn = range if not verbose else tqdm.trange
codes = obj.codes
original_device = codes.device
chunk_length = obj.chunk_length
recons = []
for i in range_fn(0, codes.shape[-1], chunk_length):
c = codes[..., i : i + chunk_length].to(self.device)
z = self.quantizer.from_codes(c)[0]
r = self.decode(z)
recons.append(r.to(original_device))
recons = torch.cat(recons, dim=-1)
recons = AudioSignal(recons, self.sample_rate)
resample_fn = recons.resample
loudness_fn = recons.loudness
        # If audio is >= 10 hours long (10 * 60 * 60 seconds), use the ffmpeg versions
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
self.padding = original_padding
return recons
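# Illustrative sketch (hypothetical helpers, not the library's API): the length
# arithmetic behind CodecMixin.get_output_length and CodecMixin.get_delay,
# applied to a list of (kind, kernel_size, stride, dilation) tuples instead of
# real nn.Conv1d / nn.ConvTranspose1d modules. Like those methods, it ignores
# layer padding, so it describes the unpadded (padding=False) mode used for
# chunked compression. The delay is found by inverting each layer's length
# formula in reverse order and halving the resulting growth in length.
import math
from typing import List, Tuple

_LayerSpec = Tuple[str, int, int, int]  # ("conv" or "tconv", kernel_size, stride, dilation)


def _conv_output_length_sketch(layers: List[_LayerSpec], input_length: int) -> int:
    L = input_length
    for kind, k, s, d in layers:
        if kind == "conv":
            L = ((L - d * (k - 1) - 1) / s) + 1
        else:  # transposed convolution
            L = (L - 1) * s + d * (k - 1) + 1
        L = math.floor(L)
    return L


def _conv_delay_sketch(layers: List[_LayerSpec]) -> int:
    # Like get_delay, start from an input length of 0
    l_out = _conv_output_length_sketch(layers, 0)
    L = l_out
    for kind, k, s, d in reversed(layers):
        # Invert each layer: a conv is undone by the transposed-conv formula and vice versa
        if kind == "tconv":
            L = ((L - d * (k - 1) - 1) / s) + 1
        else:
            L = (L - 1) * s + d * (k - 1) + 1
        L = math.ceil(L)
    return (L - l_out) // 2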
================================================
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py
================================================
import math
from typing import List
from typing import Union
import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from torch import nn
from .base import CodecMixin
from dac.nn.layers import Snake1d
from dac.nn.layers import WNConv1d
from dac.nn.layers import WNConvTranspose1d
from dac.nn.quantize import ResidualVectorQuantize
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
if pad > 0:
x = x[..., pad:-pad]
return x + y
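# Illustrative sketch (hypothetical helper): ResidualUnit pads its dilated
# kernel-7 convolution with ((7 - 1) * dilation) // 2 samples. For an odd kernel
# and stride 1, this means the block output has the same length as its input,
# so the residual addition x + y lines up without cropping. The length below
# uses the standard stride-1 Conv1d formula L + 2 * pad - dilation * (kernel_size - 1).
from typing import Tuple


def _residual_unit_lengths_sketch(length: int, dilation: int, kernel_size: int = 7) -> Tuple[int, int]:
    """Return (padding, output_length) for the dilated conv in ResidualUnit."""
    pad = ((kernel_size - 1) * dilation) // 2
    out_length = length + 2 * pad - dilation * (kernel_size - 1)
    return pad, out_length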
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1),
ResidualUnit(dim // 2, dilation=3),
ResidualUnit(dim // 2, dilation=9),
Snake1d(dim // 2),
WNConv1d(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
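# Illustrative sketch (hypothetical helper): the final WNConv1d in EncoderBlock
# uses kernel_size=2*stride, stride=stride and padding=ceil(stride/2). Plugging
# these into the standard Conv1d length formula
#     floor((L + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
# shows that, for any stride of 2 or more, an input whose length is a multiple
# of the stride is downsampled by exactly that stride.
import math


def _encoder_block_output_length_sketch(length: int, stride: int) -> int:
    kernel_size = 2 * stride
    padding = math.ceil(stride / 2)
    # Integer floor division matches floor() for the positive numerators used here
    return (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1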
class Encoder(nn.Module):
def __init__(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 256,
):
super().__init__()
# Create first convolution
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Snake1d(d_model),
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
]
        # Wrap block into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
return self.block(x)
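# Illustrative sketch (hypothetical helper): Encoder doubles the channel count
# in every EncoderBlock and downsamples by each stride. The total hop length is
# therefore the product of the strides, the same quantity DAC.__init__ computes
# as np.prod(encoder_rates). DAC's default latent_dim,
# encoder_dim * 2 ** len(encoder_rates), equals the channel count after the last block.
from typing import List, Sequence, Tuple


def _encoder_schedule_sketch(d_model: int, strides: Sequence[int]) -> Tuple[List[int], int]:
    """Return (channels after each EncoderBlock, total hop length)."""
    channels = []
    hop_length = 1
    for stride in strides:
        d_model *= 2
        channels.append(d_model)
        hop_length *= stride
    return channels, hop_length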
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
super().__init__()
self.block = nn.Sequential(
Snake1d(input_dim),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2, # out_pad,
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
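# Illustrative sketch (hypothetical helper): DecoderBlock's WNConvTranspose1d
# uses kernel_size=2*stride, stride=stride, padding=ceil(stride/2) and
# output_padding=stride % 2. The standard ConvTranspose1d length formula is
#     (L - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
# With these settings, the block upsamples its input by exactly `stride`, and
# the residual units that follow preserve length. The output_padding term is
# what makes odd strides come out exact.
import math


def _decoder_block_output_length_sketch(length: int, stride: int, dilation: int = 1) -> int:
    kernel_size = 2 * stride
    padding = math.ceil(stride / 2)
    output_padding = stride % 2
    return (length - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1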
class Decoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
d_out: int = 1,
):
super().__init__()
# Add first conv layer
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
if i == 1:
out_pad = 1
else:
out_pad = 0
layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
# nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
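# Illustrative sketch (hypothetical helper): Decoder halves the channel count in
# each DecoderBlock (output_dim = channels // 2 ** (i + 1)), while each block's
# transposed convolution upsamples by its rate. The total upsampling factor is
# therefore the product of the rates. With DAC's defaults (decoder_dim=1536,
# decoder_rates=[8, 8, 4, 2]), this factor equals the default encoder hop length of 512.
from typing import List, Sequence, Tuple


def _decoder_schedule_sketch(channels: int, rates: Sequence[int]) -> Tuple[List[int], int]:
    """Return (output channels of each DecoderBlock, total upsampling factor)."""
    output_dims = []
    upsampling = 1
    for i, rate in enumerate(rates):
        output_dims.append(channels // 2 ** (i + 1))
        upsampling *= rate
    return output_dims, upsampling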
class DAC(BaseModel, CodecMixin):
def __init__(
self,
encoder_dim: int = 64,
encoder_rates: List[int] = [2, 4, 8, 8],
latent_dim: int = None,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 4, 2],
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: bool = False,
sample_rate: int = 44100,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.sample_rate = sample_rate
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer = ResidualVectorQuantize(
input_dim=latent_dim,
n_codebooks=n_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.decoder = Decoder(
latent_dim,
decoder_dim,
decoder_rates,
)
self.sample_rate = sample_rate
self.apply(init_weights)
self.delay = self.get_delay()
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
length = audio_data.shape[-1]
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def encode(
self,
audio_data: torch.Tensor,
n_quantizers: int = None,
):
"""Encode given audio data and return quantized latent codes
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
n_quantizers : int, optional
Number of quantizers to use, by default None
If None, all quantizers are used.
        Returns
        -------
        tuple
            ``(z, codes, latents, commitment_loss, codebook_loss)``, where:
            z : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            codebook_loss : Tensor[1]
                Codebook loss to update the codebook
"""
z = self.encoder(audio_data)
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
return z, codes, latents, commitment_loss, codebook_loss
def decode(self, z: torch.Tensor):
        """Decode quantized latents and return audio data
        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input
        Returns
        -------
        Tensor[B x 1 x L]
            Decoded audio data, where L is T times the product of the decoder rates
        """
return self.decoder(z)
def forward(
self,
audio_data: torch.Tensor,
sample_rate: int = None,
n_quantizers: int = None,
):
"""Model forward pass
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
sample_rate : int, optional
Sample rate of audio data in Hz, by default None
If None, defaults to `self.sample_rate`
n_quantizers : int, optional
Number of quantizers to use, by default None.
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
length = audio_data.shape[-1]
audio_data = self.preprocess(audio_data, sample_rate)
z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
x = self.decode(z)
return {
"audio": x[..., :length],
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
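# Illustrative sketch (hypothetical helper): DAC.preprocess right-pads the audio
# with zeros up to the next multiple of hop_length, so the encoder produces a
# whole number of frames. forward() then trims the decoded audio back to the
# original length with x[..., :length].
import math
from typing import Tuple


def _preprocess_padding_sketch(length: int, hop_length: int) -> Tuple[int, int]:
    """Return (right_pad, n_frames) for an input of `length` samples."""
    right_pad = math.ceil(length / hop_length) * hop_length - length
    return right_pad, (length + right_pad) // hop_length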
if __name__ == "__main__":
import numpy as np
from functools import partial
model = DAC().to("cpu")
for n, m in model.named_modules():
o = m.extra_repr()
p = sum([np.prod(p.size()) for p in m.parameters()])
fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
setattr(m, "extra_repr", partial(fn, o=o, p=p))
print(model)
print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
length = 88200 * 2
x = torch.randn(1, 1, length).to(model.device)
x.requires_grad_(True)
x.retain_grad()
# Make a forward pass
out = model(x)["audio"]
print("Input shape:", x.shape)
print("Output shape:", out.shape)
# Create gradient variable
grad = torch.zeros_like(out)
grad[:, :, grad.shape[-1] // 2] = 1
# Make a backward pass
out.backward(grad)
# Check non-zero values
gradmap = x.grad.squeeze(0)
gradmap = (gradmap != 0).sum(0) # sum across features
rf = (gradmap != 0).sum()
print(f"Receptive field: {rf.item()}")
x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
model.decompress(model.compress(x, verbose=True), verbose=True)
================================================
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this speeds up the model by about 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
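# Illustrative sketch (hypothetical helper): a NumPy reference for the snake
# activation computed above, x + (1 / alpha) * sin(alpha * x) ** 2. It keeps the
# 1e-9 term that guards against division by zero when alpha is 0, and can be
# used to check the scripted version on small inputs.
import numpy as np


def _snake_numpy_sketch(x, alpha):
    x = np.asarray(x, dtype=np.float64)
    return x + (1.0 / (alpha + 1e-9)) * np.sin(alpha * x) ** 2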
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
================================================
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py
================================================
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
from dac.nn.layers import WNConv1d
class VectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
Additionally uses following tricks from Improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
self.codebook = nn.Embedding(codebook_size, codebook_dim)
def forward(self, z):
        """Quantizes the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
z_e = self.in_proj(z) # z_e : (B x D x T)
z_q, indices = self.decode_latents(z_e)
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
z_q = self.out_proj(z_q)
return z_q, commitment_loss, codebook_loss, indices, z_e
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
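# Illustrative sketch (hypothetical helper): the nearest-neighbour lookup in
# VectorQuantize.decode_latents, written in NumPy. After L2-normalizing both the
# encodings and the codebook, the squared Euclidean distance equals
# 2 - 2 * cosine_similarity. Taking the minimum distance (the code's
# (-dist).max(1)[1]) therefore picks the code with the highest cosine similarity.
import numpy as np


def _l2_normalized_lookup_sketch(encodings: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """encodings: (M x D), codebook: (N x D); returns (M,) code indices."""
    enc = encodings / np.linalg.norm(encodings, axis=1, keepdims=True)
    cb = codebook / np.linalg.norm(codebook, axis=1, keepdims=True)
    dist = (enc ** 2).sum(1, keepdims=True) - 2 * enc @ cb.T + (cb ** 2).sum(1, keepdims=True).T
    return dist.argmin(axis=1)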
class ResidualVectorQuantize(nn.Module):
"""
Introduced in SoundStream: An end2end neural audio codec
https://arxiv.org/abs/2107.03312
"""
def __init__(
self,
input_dim: int = 512,
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: float = 0.0,
):
super().__init__()
if isinstance(codebook_dim, int):
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
self.n_codebooks = n_codebooks
self.codebook_dim = codebook_dim
self.codebook_size = codebook_size
self.quantizers = nn.ModuleList(
[VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
)
self.quantizer_dropout = quantizer_dropout
def forward(self, z, n_quantizers: int = None):
        """Quantizes the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
        Returns
        -------
        tuple
            ``(z_q, codes, latents, commitment_loss, codebook_loss)``, where:
            z_q : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            codebook_loss : Tensor[1]
                Codebook loss to update the codebook
"""
z_q = 0
residual = z
commitment_loss = 0
codebook_loss = 0
codebook_indices = []
latents = []
if n_quantizers is None:
n_quantizers = self.n_codebooks
if self.training:
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
# Create mask to apply quantizer dropout
mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i
# Sum losses
commitment_loss += (commitment_loss_i * mask).mean()
codebook_loss += (codebook_loss_i * mask).mean()
codebook_indices.append(indices_i)
latents.append(z_e_i)
codes = torch.stack(codebook_indices, dim=1)
latents = torch.cat(latents, dim=1)
return z_q, codes, latents, commitment_loss, codebook_loss
def from_codes(self, codes: torch.Tensor):
"""Given the quantized codes, reconstruct the continuous representation
Parameters
----------
codes : Tensor[B x N x T]
Quantized discrete representation of input
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
"""
z_q = 0.0
z_p = []
n_codebooks = codes.shape[1]
for i in range(n_codebooks):
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
z_p.append(z_p_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), codes
def from_latents(self, latents: torch.Tensor):
"""Given the unquantized latents, reconstruct the
continuous representation after quantization.
Parameters
----------
latents : Tensor[B x N x T]
Continuous representation of input after projection
Returns
-------
Tensor[B x D x T]
Quantized representation of full-projected space
Tensor[B x D x T]
Quantized representation of latent space
"""
z_q = 0
z_p = []
codes = []
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
for i in range(n_codebooks):
j, k = dims[i], dims[i + 1]
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
z_p.append(z_p_i)
codes.append(codes_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
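# Illustrative sketch (hypothetical helper): the residual quantization loop of
# ResidualVectorQuantize.forward, reduced to plain Euclidean nearest-neighbour
# lookups in NumPy. It omits the per-stage in/out projections, the L2
# normalization and quantizer dropout. Each stage quantizes what the previous
# stages left over, and the quantized stages are summed.
from typing import List, Tuple

import numpy as np


def _residual_vq_sketch(x: np.ndarray, codebooks: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """x: (M x D); each codebook: (N_i x D). Returns (z_q, codes of shape M x n_stages)."""
    residual = np.asarray(x, dtype=np.float64)
    z_q = np.zeros_like(residual)
    codes = []
    for codebook in codebooks:
        codebook = np.asarray(codebook, dtype=np.float64)
        # Squared distance from every residual vector to every code vector
        dist = ((residual[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
        idx = dist.argmin(axis=1)
        quantized = codebook[idx]
        z_q = z_q + quantized
        residual = residual - quantized
        codes.append(idx)
    return z_q, np.stack(codes, axis=1)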
if __name__ == "__main__":
    rvq = ResidualVectorQuantize(quantizer_dropout=True)
    x = torch.randn(16, 512, 80)
    # forward returns a tuple, not a dict
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
================================================
FILE: boson_multimodal/audio_processing/higgs_audio_tokenizer.py
================================================
# Based on code from: https://github.com/zhenye234/xcodec
# Licensed under MIT License
# Modifications by BosonAI
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Sequence
import numpy as np
from transformers import AutoModel
import torchaudio
import json
import librosa
from huggingface_hub import snapshot_download
from vector_quantize_pytorch import ResidualFSQ
from .descriptaudiocodec.dac.model import dac as dac2
from .quantization.vq import ResidualVectorQuantizer
from .semantic_module import Encoder, Decoder
class EncodedResult:
def __init__(self, audio_codes):
self.audio_codes = audio_codes
class HiggsAudioFeatureExtractor(nn.Module):
def __init__(self, sampling_rate=16000):
super().__init__()
self.sampling_rate = sampling_rate
def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
# Convert from librosa to torch
audio_signal = torch.tensor(raw_audio)
audio_signal = audio_signal.unsqueeze(0)
if len(audio_signal.shape) < 3:
audio_signal = audio_signal.unsqueeze(0)
return {"input_values": audio_signal}
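# Illustrative sketch (hypothetical helper): the rate arithmetic in
# HiggsAudioTokenizer.__init__.
# - hop_length is the product of the encoder ratios.
# - frame_rate is ceil(sample_rate / hop_length).
# - semantic_downsample_factor divides the hop length (rescaled to the semantic
#   model's sample rate) by the constant 320. We assume 320 corresponds to the
#   20 ms frame hop of 16 kHz HuBERT-style semantic models.
import math
from typing import Sequence, Tuple

import numpy as np


def _tokenizer_rates_sketch(
    ratios: Sequence[int], sample_rate: int, semantic_sample_rate: int
) -> Tuple[int, int, int]:
    """Return (hop_length, frame_rate, semantic_downsample_factor)."""
    hop_length = int(np.prod(ratios))
    frame_rate = math.ceil(sample_rate / hop_length)
    semantic_downsample_factor = int(hop_length / (sample_rate / semantic_sample_rate) / 320)
    return hop_length, frame_rate, semantic_downsample_factor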
class HiggsAudioTokenizer(nn.Module):
def __init__(
self,
n_filters: int = 32,
D: int = 128,
target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
sample_rate: int = 16000,
bins: int = 1024,
n_q: int = 8,
codebook_dim: int = None,
normalize: bool = False,
causal: bool = False,
semantic_techer: str = "hubert_base_general",
last_layer_semantic: bool = True,
merge_mode: str = "concat",
downsample_mode: str = "step_down",
semantic_mode: str = "classic",
vq_scale: int = 1,
semantic_sample_rate: int = None,
device: str = "cuda",
):
super().__init__()
self.hop_length = np.prod(ratios)
self.semantic_techer = semantic_techer
self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
self.target_bandwidths = target_bandwidths
self.n_q = n_q
self.sample_rate = sample_rate
self.encoder = dac2.Encoder(64, ratios, D)
self.decoder_2 = dac2.Decoder(D, 1024, ratios)
self.last_layer_semantic = last_layer_semantic
self.device = device
if semantic_techer == "hubert_base":
self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
self.semantic_sample_rate = 16000
self.semantic_dim = 768
self.encoder_semantic_dim = 768
elif semantic_techer == "wavlm_base_plus":
self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
self.semantic_sample_rate = 16000
self.semantic_dim = 768
self.encoder_semantic_dim = 768
elif semantic_techer == "hubert_base_general":
self.semantic_model = AutoModel.from_pretrained("bosonai/hubert_base", trust_remote_code=True)
self.semantic_sample_rate = 16000
self.semantic_dim = 768
self.encoder_semantic_dim = 768
# Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
if semantic_sample_rate is not None:
self.semantic_sample_rate = semantic_sample_rate
self.semantic_model.eval()
        # Freeze the semantic model: its parameters do not need gradients
for param in self.semantic_model.parameters():
param.requires_grad = False
self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
self.decoder_semantic = Decoder(
code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim
)
# out_D=D+768
if isinstance(bins, int): # RVQ
self.quantizer = ResidualVectorQuantizer(
dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins
)
self.quantizer_type = "RVQ"
else: # RFSQ
self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
self.quantizer_type = "RFSQ"
self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
self.fc_post2 = nn.Linear(self.quantizer_dim, D)
self.downsample_mode = downsample_mode
if downsample_mode == "avg":
self.semantic_pooling = nn.AvgPool1d(
kernel_size=self.semantic_downsample_factor, stride=self.semantic_downsample_factor
)
self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
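The rate bookkeeping in `__init__` can be checked by hand. The sketch below mirrors it in plain Python: the hop length is the product of `ratios`, the frame rate is `ceil(sample_rate / hop_length)`, and `semantic_downsample_factor` expresses one acoustic hop in semantic-model samples and divides by 320, the frame stride of 16 kHz HuBERT/WavLM-base encoders. The helper name `rates` and both example configurations (24 kHz with `ratios = [8, 5, 4, 2, 3]`, 16 kHz with `ratios = [8, 5, 4, 2]`) are illustrative assumptions, not values read from a released config.

```python
import math


def rates(sample_rate, ratios, semantic_sample_rate=16000, semantic_hop=320):
    """Mirror the rate arithmetic of HiggsAudioTokenizer.__init__ (illustrative only)."""
    hop_length = math.prod(ratios)
    frame_rate = math.ceil(sample_rate / hop_length)
    # One acoustic hop measured in semantic-model samples, then in semantic frames.
    downsample = int(hop_length / (sample_rate / semantic_sample_rate) / semantic_hop)
    return hop_length, frame_rate, downsample


print(rates(24000, [8, 5, 4, 2, 3]))  # (960, 25, 2)
print(rates(16000, [8, 5, 4, 2]))     # (320, 50, 1)
```

With a factor of 2, each acoustic frame corresponds to two semantic frames, which `get_regress_target` reduces to one by striding or average pooling.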
@property
def tps(self):
return self.frame_rate
@property
def sampling_rate(self):
return self.sample_rate
@property
def num_codebooks(self):
return self.n_q
@property
def codebook_size(self):
return self.quantizer_dim
def get_last_layer(self):
return self.decoder.layers[-1].weight
def calculate_rec_loss(self, rec, target):
target = target / target.norm(dim=-1, keepdim=True)
rec = rec / rec.norm(dim=-1, keepdim=True)
rec_loss = (1 - (target * rec).sum(-1)).mean()
return rec_loss
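`calculate_rec_loss` is a cosine-distance loss: both inputs are L2-normalized along the last dimension, so `(target * rec).sum(-1)` is the cosine similarity and the loss is one minus it, averaged over rows. A dependency-free sketch (the name `cosine_rec_loss` is hypothetical) that, like the original, assumes no row is all zeros; the torch version has no epsilon and would produce NaN for such a row.

```python
import math


def cosine_rec_loss(rec, target):
    """Mean of (1 - cosine similarity) over paired rows."""
    total = 0.0
    for r, t in zip(rec, target):
        dot = sum(a * b for a, b in zip(r, t))
        norm_r = math.sqrt(sum(a * a for a in r))
        norm_t = math.sqrt(sum(b * b for b in t))
        total += 1.0 - dot / (norm_r * norm_t)
    return total / len(rec)


# Same direction gives 0 (magnitude is ignored), orthogonal gives 1, opposite gives 2.
print(cosine_rec_loss([[2.0, 0.0]], [[1.0, 0.0]]))  # 0.0
print(cosine_rec_loss([[0.0, 3.0]], [[1.0, 0.0]]))  # 1.0
```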
@torch.no_grad()
def get_regress_target(self, x):
x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
if (
self.semantic_techer == "hubert_base"
or self.semantic_techer == "hubert_base_general"
or self.semantic_techer == "wavlm_base_plus"
):
x = x[:, 0, :]
x = F.pad(x, (160, 160))
target = self.semantic_model(x, output_hidden_states=True).hidden_states
target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
# average for all layers
target = target.mean(1)
# target = target[9]
# if self.hop_length > 320:
# target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
elif self.semantic_techer == "w2v_bert2":
target = self.semantic_model(x)
elif self.semantic_techer.startswith("whisper"):
if self.last_layer_semantic:
target = self.semantic_model(x, avg_layers=False)
else:
target = self.semantic_model(x, avg_layers=True)
elif self.semantic_techer.startswith("mert_music"):
if self.last_layer_semantic:
target = self.semantic_model(x, avg_layers=False)
else:
target = self.semantic_model(x, avg_layers=True)
elif self.semantic_techer.startswith("qwen_audio_omni"):
target = self.semantic_model(x)
if self.downsample_mode == "step_down":
if self.semantic_downsample_factor > 1:
target = target[:, :: self.semantic_downsample_factor, :]
elif self.downsample_mode == "avg":
target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
return target
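After averaging the teacher's hidden states over all layers, `get_regress_target` brings the semantic frames down to the acoustic frame rate in one of two ways: `step_down` keeps every `semantic_downsample_factor`-th frame, and `avg` averages non-overlapping windows. The sketch below shows both on a toy list of frames; `step_down` and `avg_pool` are hypothetical helper names, and `avg_pool` assumes the default `nn.AvgPool1d` behaviour (`ceil_mode=False`), which drops a trailing partial window.

```python
def step_down(frames, factor):
    """Keep every `factor`-th frame, like target[:, ::factor, :]."""
    return frames[::factor] if factor > 1 else frames


def avg_pool(frames, factor):
    """Mean over non-overlapping windows (kernel == stride == factor); a partial tail is dropped."""
    out = []
    for start in range(0, len(frames) - factor + 1, factor):
        window = frames[start:start + factor]
        # Average each feature dimension across the frames in the window.
        out.append([sum(col) / factor for col in zip(*window)])
    return out


frames = [[0.0], [2.0], [4.0], [6.0], [8.0]]  # 5 frames of a 1-dim feature
print(step_down(frames, 2))  # [[0.0], [4.0], [8.0]]
print(avg_pool(frames, 2))   # [[1.0], [5.0]]
```

Note that the two modes can return different lengths for the same input (3 versus 2 frames here), so semantic and acoustic frame counts are not guaranteed to match without further alignment.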
def forward(self, x: torch.Tensor, bw: int):
e_semantic_input = self.get_regress_target(x).detach()
e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
e_acoustic = self.encoder(x)
e = torch.cat([e_acoustic, e_semantic], dim=1)
e = self.fc_prior(e.transpose(1, 2))
if self.quantizer_type == "RVQ":
e = e.transpose(1, 2)
quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
quantized = quantized.transpose(1, 2)
else:
quantized, codes = self.quantizer(e)
commit_loss = torch.tensor(0.0)
quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
o = self.decoder_2(quantized_acoustic)
o_semantic = self.decoder_semantic(quantized_semantic)
semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
return o, commit_loss, semantic_recon_loss, None
def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, loudness_threshold=-23.0):
if isinstance(audio_path_or_wv, str):
wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
else:
wv = audio_path_or_wv
assert sr is not None
if loudness_normalize:
import pyloudnorm as pyln
meter = pyln.Meter(sr)
l = meter.integrated_loudness(wv)
wv = pyln.normalize.loudness(wv, l, loudness_threshold)
if sr != self.sampling_rate:
wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
if self.audio_tokenizer_feature_extractor is not None:
inputs = self.audio_tokenizer_feature_extractor(
raw_audio=wv, sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate, return_tensors="pt"
)
input_values = inputs["input_values"].to(self.device)
else:
input_values = torch.from_numpy(wv).float().view(1, 1, -1).to(self.device) # [batch, channels, samples]
with torch.no_grad():
encoder_outputs = self._xcodec_encode(input_values)
vq_code = encoder_outputs.audio_codes[0]
return vq_code
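When `loudness_normalize=True`, `encode` measures integrated loudness with `pyloudnorm` and rescales the waveform toward `loudness_threshold` (default -23 LUFS). As far as we understand pyloudnorm's `normalize.loudness`, this is a single linear gain of `10 ** ((target - measured) / 20)`. The sketch below reproduces only that gain step (the K-weighted, gated loudness measurement is left to `pyln.Meter`), and the helper names are hypothetical.

```python
def loudness_gain(measured_lufs, target_lufs=-23.0):
    """Linear amplitude gain that moves integrated loudness from measured to target."""
    return 10.0 ** ((target_lufs - measured_lufs) / 20.0)


def normalize(samples, measured_lufs, target_lufs=-23.0):
    """Scale every sample by the same gain (no limiting is applied)."""
    gain = loudness_gain(measured_lufs, target_lufs)
    return [gain * s for s in samples]


# A clip measured at -29 LUFS needs +6 dB, i.e. roughly a 2x amplitude gain.
print(round(loudness_gain(-29.0), 3))  # 1.995
```

The sketch applies the gain without any limiting, so a large positive gain can push samples beyond [-1, 1].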
def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> "EncodedResult":
bw = target_bw
e_semantic_input = self.get_regress_target(x).detach()
e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
e_acoustic = self.encoder(x)
if e_acoustic.shape[2] != e_semantic.shape[2]:
pad_size = 160 * self.semantic_downsample_factor
e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(1)) # restore [B, 1, T] for any batch size
if e_acoustic.shape[2] != e_semantic.shape[2]:
if e_acoustic.shape[2] > e_semantic.shape[2]:
e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
else:
e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
e = torch.cat([e_acoustic, e_semantic], dim=1)
e = self.fc_prior(e.transpose(1, 2))
if self.quantizer_type == "RVQ":
e = e.transpose(1, 2)
quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
codes = codes.permute(1, 0, 2)
else:
quantized, codes = self.quantizer(e)
codes = codes.permute(0, 2, 1)
# return codes
return EncodedResult(codes)
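`_xcodec_encode` concatenates acoustic and semantic features along the channel axis, so both must have the same number of frames. It first retries the acoustic encoder on input padded by `160 * semantic_downsample_factor` samples per side and, if the counts still differ, trims the longer sequence. The trimming fallback amounts to the following, with plain lists standing in for the frame axis (`align_lengths` is a hypothetical name).

```python
def align_lengths(acoustic, semantic):
    """Trim the longer frame sequence so both have the same number of frames."""
    n = min(len(acoustic), len(semantic))
    return acoustic[:n], semantic[:n]


a, s = align_lengths(list(range(26)), list(range(25)))
print(len(a), len(s))  # 25 25
```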
def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
vq_code = vq_code.to(self.device)
if self.quantizer_type == "RVQ":
vq_code = vq_code.permute(1, 0, 2)
quantized = self.quantizer.decode(vq_code)
quantized = quantized.transpose(1, 2)
else:
vq_code = vq_code.permute(0, 2, 1)
quantized = self.quantizer.get_output_from_indices(vq_code)
quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
o = self.decoder_2(quantized_acoustic)
return o.detach().cpu().numpy()
def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
is_local = os.path.exists(tokenizer_name_or_path)
if not is_local:
tokenizer_path = snapshot_download(tokenizer_name_or_path)
else:
tokenizer_path = tokenizer_name_or_path
config_path = os.path.join(tokenizer_path, "config.json")
model_path = os.path.join(tokenizer_path, "model.pth")
config = json.load(open(config_path))
model = HiggsAudioTokenizer(
**config,
device=device,
)
parameter_dict = torch.load(model_path, map_location=device)
model.load_state_dict(parameter_dict, strict=False)
model.to(device)
model.eval()
return model
================================================
FILE: boson_multimodal/audio_processing/quantization/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer
================================================
FILE: boson_multimodal/audio_processing/quantization/ac.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Arithmetic coder."""
import io
import math
import random
import typing as tp
import torch
from ..binary import BitPacker, BitUnpacker
def build_stable_quantized_cdf(
pdf: torch.Tensor, total_range_bits: int, roundoff: float = 1e-8, min_range: int = 2, check: bool = True
) -> torch.Tensor:
"""Turn the given PDF into a quantized CDF that splits
[0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
to the PDF.
Args:
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
during the coding process is `[0, 2 ** total_range_bits - 1]`.
roundoff (float): will round the pdf up to that level to remove difference coming
from e.g. evaluating the Language Model on different architectures.
min_range (int): minimum range width. Should always be at least 2 for numerical
stability. Use this to avoid pathological behavior if a value
that is expected to be rare actually happens in real life.
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
"""
pdf = pdf.detach()
if roundoff:
pdf = (pdf / roundoff).floor() * roundoff
# interpolate with uniform distribution to achieve desired minimum probability.
total_range = 2**total_range_bits
cardinality = len(pdf)
alpha = min_range * cardinality / total_range
assert alpha <= 1, "you must reduce min_range"
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
ranges += min_range
quantized_cdf = torch.cumsum(ranges, dim=-1)
if min_range < 2:
raise ValueError("min_range must be at least 2.")
if check:
assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
raise ValueError("You must increase your total_range_bits.")
return quantized_cdf
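To see the arithmetic of `build_stable_quantized_cdf` on small numbers, here is a pure-Python sketch of the same steps with the `roundoff` step omitted: the pdf is mixed with a uniform floor so every symbol gets at least `min_range` slots, the slot counts are floored, and their cumulative sum is returned. The function name `quantized_cdf` and the 4-bit range are illustrative choices (the coder itself defaults to 24 bits).

```python
import math
from itertools import accumulate


def quantized_cdf(pdf, total_range_bits, min_range=2):
    """Integer CDF over [0, 2 ** total_range_bits) with every chunk at least min_range wide."""
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    total_range = 2 ** total_range_bits
    alpha = min_range * len(pdf) / total_range
    if alpha > 1:
        raise ValueError("you must reduce min_range")
    # Scale the pdf into the slots left after reserving min_range per symbol.
    ranges = [math.floor((1 - alpha) * total_range * p) + min_range for p in pdf]
    cdf = list(accumulate(ranges))
    assert cdf[-1] <= total_range
    return cdf


print(quantized_cdf([0.75, 0.25], total_range_bits=4))  # [11, 16]
```

Here `alpha = 2 * 2 / 16 = 0.25`, so the pdf is scaled to 12 slots (9 and 3) and each symbol gets 2 extra, giving chunk widths 11 and 5.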
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
sequence `(s_t)` by doing the following:
1) Initialize the current range to `[0, 2 ** B - 1]`.
2) For each time step t, split the current range into contiguous chunks,
one for each possible outcome, with size roughly proportional to `p`.
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
would be `{[0, 2], [3, 3]}`.
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
4) When done encoding all the values, just select any value remaining in the range.
You will notice that this procedure can fail: for instance if at any point in time
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
possible outcome. Intuitively, the more likely a value is, the less the range width
shrinks, and the longer we can go on encoding values. This makes sense: for any efficient
coding scheme, likely outcomes take fewer bits, and more of them can be coded
with a fixed budget.
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
when the current range decreases below a given limit (given by `total_range_bits`), without
having to redo all the computations. If we mostly encode likely values, we will seldom
need to inject new bits, but a single rare value can deplete our stock of entropy!
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
code works for any sequence `(p_t)` possibly different for each timestep.
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
the KL between the true distribution and `p_t`, the more efficient the coding will be.
Args:
fo (IO[bytes]): file-like object to which the bytes will be written to.
total_range_bits (int): the limit described above is `2 ** total_range_bits`.
Any time the current range width falls under this limit, new bits will
be injected to rescale the initial range.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
assert total_range_bits <= 30
self.total_range_bits = total_range_bits
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
self.low: int = 0
self.high: int = 0
self.max_bit: int = -1
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
@property
def delta(self) -> int:
"""Return the current range width."""
return self.high - self.low + 1
def _flush_common_prefix(self):
# If self.low and self.high start with the same bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
assert self.high < 2 ** (self.max_bit + 1)
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= b1 << self.max_bit
self.high -= b1 << self.max_bit
assert self.high >= self.low, (self.high, self.low, self.max_bit)
assert self.low >= 0
self.max_bit -= 1
self.packer.push(b1)
else:
break
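The prefix flushing works bit by bit from position `max_bit` downward: as long as `low` and `high` agree on their top bit, that bit can never change again, so it is emitted and removed from both bounds. A standalone sketch (hypothetical function name) that assumes the same invariant `high < 2 ** (max_bit + 1)` the method asserts, so each shifted value is a single bit.

```python
def flush_common_prefix(low, high, max_bit):
    """Emit the leading bits shared by low and high; return (bits, low, high, max_bit)."""
    bits = []
    while max_bit >= 0:
        b_low = low >> max_bit
        b_high = high >> max_bit
        if b_low != b_high:
            break
        # Top bit is settled: output it and strip it from both bounds.
        low -= b_low << max_bit
        high -= b_low << max_bit
        bits.append(b_low)
        max_bit -= 1
    return bits, low, high, max_bit


print(flush_common_prefix(0b0101, 0b0111, 3))  # ([0, 1], 1, 3, 1)
```

In the example, `0b0101` and `0b0111` share the prefix `01`, so two bits are emitted, the interval becomes `[1, 3]`, and `max_bit` drops from 3 to 1.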
def push(self, symbol: int, quantized_cdf: torch.Tensor):
"""Push the given symbol on the stream, flushing out bits
if possible.
Args:
symbol (int): symbol to encode with the AC.
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate.
"""
while self.delta < 2**self.total_range_bits:
self.low *= 2
self.high = self.high * 2 + 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
self._dbg.append((self.low, self.high))
self._dbg2.append((self.low, self.high))
outs = self._flush_common_prefix()
assert self.low <= self.high
assert self.max_bit >= -1
assert self.max_bit <= 61, self.max_bit
return outs
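Before flushing, each call to `push` does two things: it widens `[low, high]` by appending bits until its width reaches `2 ** total_range_bits`, then maps the symbol's slice of the quantized CDF proportionally into that interval, using `ceil` for the lower bound and `floor` for the upper one so adjacent symbols never overlap. The sketch below isolates those two steps; the function name, the toy 4-bit range and the CDF `[11, 16]` (slot boundaries for `p = [0.75, 0.25]`) are illustrative.

```python
import math


def narrow(low, high, max_bit, symbol, quantized_cdf, total_range_bits):
    """One ArithmeticCoder.push step without bit flushing; returns (low, high, max_bit)."""
    # Renormalize: append one bit of precision until the width reaches the limit.
    while high - low + 1 < 2 ** total_range_bits:
        low *= 2
        high = high * 2 + 1
        max_bit += 1
    scale = (high - low + 1) / 2 ** total_range_bits
    range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1]
    range_high = quantized_cdf[symbol] - 1
    # Map the symbol's CDF slice into the current interval.
    new_low = low + math.ceil(range_low * scale)
    new_high = low + math.floor(range_high * scale)
    return new_low, new_high, max_bit


# Fresh coder state is low = high = 0, max_bit = -1.
print(narrow(0, 0, -1, 0, [11, 16], 4))  # (0, 10, 3)
print(narrow(0, 0, -1, 1, [11, 16], 4))  # (11, 15, 3)
print(narrow(0, 31, 4, 0, [11, 16], 4))  # (0, 20, 4)
print(narrow(0, 31, 4, 1, [11, 16], 4))  # (22, 30, 4)
```

The last two cases show why chunks are only roughly proportional: with a width of 32 the two symbols get `[0, 20]` and `[22, 30]`, and the values 21 and 31 are left unused.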
def flush(self):
"""Flush the remaining information to the stream."""
while self.max_bit >= 0:
b1 = (self.low >> self.max_bit) & 1
self.packer.push(b1)
self.max_bit -= 1
self.packer.flush()
class ArithmeticDecoder:
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
If the AC encoder's current range is [L, H], with `L` and `H` sharing a common
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
`[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
self.total_range_bits = total_range_bits
self.low: int = 0
self.high: int = 0
self.current: int = 0
self.max_bit: int = -1
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
# Following is for debugging
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
self._last: tp.Any = None
@property
def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
# Given the current range [L, H], if both have a common prefix,
# we know we can remove it from our representation to avoid handling large numbers.
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= b1 << self.max_bit
self.high -= b1 << self.max_bit
self.current -= b1 << self.max_bit
assert self.high >= self.low
assert self.low >= 0
self.max_bit -= 1
else:
break
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
This returns `None` when the stream has been exhausted.
Args:
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate. This must be **exactly**
the same cdf as the one used at encoding time.
"""
while self.delta < 2**self.total_range_bits:
bit = self.unpacker.pull()
if bit is None:
return None
self.low *= 2
self.high = self.high * 2 + 1
self.current = self.current * 2 + bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
# Binary search is not just for coding interviews :)
if high_idx < low_idx:
raise RuntimeError("Binary search failed")
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
if self.current <= high:
return (mid, low, high, self.current)
else:
return bin_search(mid + 1, high_idx)
else:
return bin_search(low_idx, mid - 1)
self._last = (self.low, self.high, self.current, self.max_bit)
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
self._dbg.append((self.low, self.high, self.current))
self._flush_common_prefix()
self._dbg2.append((self.low, self.high, self.current))
return sym
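The symbol lookup in `pull` is a binary search over symbols: for each candidate it scales the candidate's CDF slice into the current interval exactly as the encoder does, and moves left or right depending on where `current` falls. An iterative sketch of the same search, using a hypothetical name `find_symbol` and a toy 4-bit CDF `[11, 16]`.

```python
import math


def find_symbol(current, low, high, quantized_cdf, total_range_bits):
    """Return (symbol, sym_low, sym_high) with sym_low <= current <= sym_high."""
    scale = (high - low + 1) / 2 ** total_range_bits
    lo_idx, hi_idx = 0, len(quantized_cdf) - 1
    while lo_idx <= hi_idx:
        mid = (lo_idx + hi_idx) // 2
        range_low = quantized_cdf[mid - 1] if mid > 0 else 0
        range_high = quantized_cdf[mid] - 1
        # Same scaling as the encoder, so the decoder sees identical chunk bounds.
        sym_low = low + math.ceil(range_low * scale)
        sym_high = low + math.floor(range_high * scale)
        if current < sym_low:
            hi_idx = mid - 1
        elif current > sym_high:
            lo_idx = mid + 1
        else:
            return mid, sym_low, sym_high
    raise RuntimeError("Binary search failed")


print(find_symbol(13, 0, 15, [11, 16], 4))  # (1, 11, 15)
```

A `current` value in an unused gap (for example 21 when the interval is `[0, 31]`) matches no symbol and raises, mirroring the `RuntimeError` in the original; for a stream produced by the matching encoder this should not happen.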
def test():
torch.manual_seed(1234)
random.seed(1234)
for _ in range(4):
pdfs = []
cardinality = random.randrange(4000)
steps = random.randrange(100, 500)
fo = io.BytesIO()
encoder = ArithmeticCoder(fo)
symbols = []
for step in range(steps):
pdf = torch.softmax(torch.randn(cardinality), dim=0)
pdfs.append(pdf)
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
symbol = torch.multinomial(pdf, 1).item()
symbols.append(symbol)
encoder.push(symbol, q_cdf)
encoder.flush()
fo.seek(0)
decoder = ArithmeticDecoder(fo)
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
decoded_symbol = decoder.pull(q_cdf)
assert decoded_symbol == symbol, idx
assert decoder.pull(torch.zeros(1)) is None
if __name__ == "__main__":
test()
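The docstring of `ArithmeticCoder` describes coding as repeatedly narrowing an interval to the chunk of the observed symbol. With exact rational arithmetic and no bit budget, that idea fits in a few lines. The sketch below (hypothetical names, standard-library `fractions` only) is the idealized version that the integer coder approximates with finite precision and renormalization.

```python
from fractions import Fraction


def _cumulative(probs):
    """Start offset of each symbol's chunk inside [0, 1)."""
    cum = [Fraction(0)]
    for p in probs:
        cum.append(cum[-1] + p)
    return cum


def encode(symbols, probs):
    """Narrow [0, 1) symbol by symbol; return the final interval as (low, width)."""
    cum = _cumulative(probs)
    low, width = Fraction(0), Fraction(1)
    for s in symbols:
        low += width * cum[s]  # move to the start of symbol s's chunk
        width *= probs[s]      # shrink to that chunk's size
    return low, width


def decode(value, probs, n):
    """Recover n symbols from any value inside the final interval."""
    cum = _cumulative(probs)
    low, width = Fraction(0), Fraction(1)
    out = []
    for _ in range(n):
        target = (value - low) / width
        # Last chunk whose start is <= target is the chunk containing value.
        s = max(i for i in range(len(probs)) if cum[i] <= target)
        out.append(s)
        low += width * cum[s]
        width *= probs[s]
    return out


probs = [Fraction(3, 4), Fraction(1, 4)]
low, width = encode([0, 0, 1], probs)
print(low, width)             # 27/64 9/64
print(decode(low, probs, 3))  # [0, 0, 1]
```

The final width is the product of the symbol probabilities, 3/4 * 3/4 * 1/4 = 9/64, so identifying a value inside it takes about log2(64/9) ≈ 2.8 bits, the information content of the sequence.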
================================================
FILE: boson_multimodal/audio_processing/quantization/core_vq.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from .distrib import broadcast_tensors, rank
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
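`kmeans` here is plain Lloyd iteration with one detail worth noting: a centroid that receives no samples (`zero_mask`) keeps its previous value instead of being divided by zero. A pure-Python sketch of the same loop. Unlike the original, which seeds centroids with `sample_vectors`, it takes explicit initial means so the result is deterministic, and the helper names are hypothetical.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(samples, init_means, num_iters=10):
    """Lloyd's k-means; empty clusters keep their previous centroid."""
    means = [list(m) for m in init_means]
    k = len(means)
    bins = [0] * k
    for _ in range(num_iters):
        # Assignment step: index of the closest centroid for every sample.
        buckets = []
        for x in samples:
            dists = [sq_dist(x, m) for m in means]
            buckets.append(dists.index(min(dists)))
        bins = [buckets.count(c) for c in range(k)]
        # Update step: mean of the assigned samples, skipping empty clusters.
        for c in range(k):
            if bins[c] == 0:
                continue
            members = [x for x, b in zip(samples, buckets) if b == c]
            means[c] = [sum(col) / bins[c] for col in zip(*members)]
    return means, bins


samples = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
print(kmeans(samples, [[0.0, 0.0], [10.0, 10.0], [50.0, 50.0]]))
# ([[0.0, 0.5], [10.0, 10.5], [50.0, 50.0]], [2, 2, 0])
```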
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
a randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: bool = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
@torch.jit.ignore
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
return x
def quantize(self, x):
embed = self.embed.t()
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices
return embed_ind
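`quantize` scores codes with the expansion `|x - e|^2 = |x|^2 - 2 x·e + |e|^2`, negated, so `max` picks the nearest code while every cross term comes from one matrix product `x @ embed`. A per-vector sketch of the same scoring (hypothetical `nearest_code`, toy 2-D codebook).

```python
def nearest_code(x, codebook):
    """Index of the nearest code, scored as -(|x|^2 - 2 x.e + |e|^2) == -|x - e|^2."""
    x_sq = sum(v * v for v in x)
    scores = []
    for e in codebook:
        dot = sum(a * b for a, b in zip(x, e))
        e_sq = sum(v * v for v in e)
        scores.append(-(x_sq - 2 * dot + e_sq))
    # Highest score means smallest squared distance.
    return max(range(len(codebook)), key=scores.__getitem__)


codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 0.0]]
print(nearest_code([0.9, 1.2], codebook))  # 1
print(nearest_code([3.0, 0.5], codebook))  # 2
```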
def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed) # get embedding based on index
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = self.preprocess(x)
# quantize
embed_ind = self.quantize(x) # get index based on Euclidean distance
# post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = self.postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
return quantize, embed_ind
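During training, `forward` updates the codebook with exponential moving averages rather than gradients: it tracks an EMA of how many vectors each code receives (`cluster_size`) and of their sum (`embed_avg`), Laplace-smooths the counts so unused codes never divide by zero, and sets each code to the ratio. One step of that update in plain Python; the names are hypothetical, and `decay=0.5` is used in the example only to keep the arithmetic readable (the module default is 0.99).

```python
def ema(old, new, decay):
    """Exponential moving average, as ema_inplace computes in place."""
    return decay * old + (1 - decay) * new


def ema_codebook_step(cluster_size, embed_avg, batch, assignments, decay=0.99, eps=1e-5):
    """Return updated (cluster_size, embed_avg, embed) after one batch."""
    k, dim = len(cluster_size), len(embed_avg[0])
    # Per-code counts and vector sums (embed_onehot.sum(0) and x.t() @ embed_onehot).
    counts = [0.0] * k
    sums = [[0.0] * dim for _ in range(k)]
    for x, c in zip(batch, assignments):
        counts[c] += 1.0
        sums[c] = [s + v for s, v in zip(sums[c], x)]
    cluster_size = [ema(o, n, decay) for o, n in zip(cluster_size, counts)]
    embed_avg = [[ema(o, n, decay) for o, n in zip(ro, rn)] for ro, rn in zip(embed_avg, sums)]
    # Laplace smoothing keeps every count positive while preserving their total.
    total = sum(cluster_size)
    smoothed = [(n + eps) / (total + k * eps) * total for n in cluster_size]
    embed = [[v / n for v in row] for row, n in zip(embed_avg, smoothed)]
    return cluster_size, embed_avg, embed


cluster_size, embed_avg, embed = ema_codebook_step(
    cluster_size=[1.0, 1.0],
    embed_avg=[[2.0], [4.0]],
    batch=[[1.0], [1.0], [5.0]],
    assignments=[0, 0, 1],
    decay=0.5,
)
print(cluster_size)                                 # [1.5, 1.0]
print([round(v, 4) for row in embed for v in row])  # [1.3333, 4.5]
```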
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
a randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.0,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(
dim=_codebook_dim,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
def forward(self, x, n_q: tp.Optional[int] = None):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
quantized, indices, loss = layer(residual)
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
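Residual vector quantization chains codebooks: each stage quantizes the residual left by the previous stages, and decoding sums the selected code vectors, so using only the first `n_q` stages gives a coarser reconstruction. A toy 1-D sketch with three hand-picked codebooks; all names and values are hypothetical, and the real layers also learn their codebooks and may project to `codebook_dim`.

```python
def nearest(x, codebook):
    """Index of the code vector closest to x (squared Euclidean distance)."""
    dists = [sum((a - b) ** 2 for a, b in zip(x, e)) for e in codebook]
    return dists.index(min(dists))


def rvq_encode(x, codebooks):
    """Each stage quantizes what the previous stages left over."""
    residual = list(x)
    codes = []
    for cb in codebooks:
        idx = nearest(residual, cb)
        codes.append(idx)
        residual = [r - q for r, q in zip(residual, cb[idx])]
    return codes


def rvq_decode(codes, codebooks):
    """Sum the selected code vectors; fewer codes means a coarser reconstruction."""
    out = [0.0] * len(codebooks[0][0])
    for idx, cb in zip(codes, codebooks):
        out = [o + q for o, q in zip(out, cb[idx])]
    return out


codebooks = [
    [[0.0], [4.0], [8.0]],    # coarse stage
    [[-1.0], [0.0], [1.0]],   # finer stage
    [[-0.5], [0.0], [0.5]],   # finest stage
]
codes = rvq_encode([5.3], codebooks)
print(codes)                         # [1, 2, 2]
print(rvq_decode(codes, codebooks))  # [5.5]
```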
================================================
FILE: boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py
================================================
# Copyright (c)
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# This implementation is inspired from
# https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
# https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from .distrib import broadcast_tensors, is_distributed
from .ddp_utils import SyncFunction
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
"""
Memory-efficient K-means clustering.
Args:
samples (tensor): shape [N, D]
num_clusters (int): number of centroids.
num_iters (int): number of iterations.
frames_to_use (int): subsample size from total samples.
batch_size (int): batch size used in distance computation.
Returns:
means: [num_clusters, D]
bins: [num_clusters] (number of points per cluster)
"""
N, D = samples.shape
dtype, device = samples.dtype, samples.device
if frames_to_use < N:
indices = torch.randperm(N, device=device)[:frames_to_use]
samples = samples[indices]
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
# Store cluster assignments
all_assignments = []
for i in range(0, samples.shape[0], batch_size):
batch = samples[i : i + batch_size] # [B, D]
dists = torch.cdist(batch, means, p=2) # [B, C]
assignments = dists.argmin(dim=1) # [B]
all_assignments.append(assignments)
buckets = torch.cat(all_assignments, dim=0) # [N]
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
# Compute new means
new_means = torch.zeros_like(means)
for i in range(num_clusters):
mask = buckets == i
if mask.any():
new_means[i] = samples[mask].mean(dim=0)
means = torch.where(zero_mask[:, None], means, new_means)
return means, bins
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
        kmeans_init: bool = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
# Flag variable to indicate whether the codebook is initialized
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        # Running EMA cluster size/count: N_i^t in eq. (6) in vqvae paper
self.register_buffer("cluster_size", torch.zeros(codebook_size))
# Codebook
self.register_buffer("embed", embed)
# EMA codebook: eq. (7) in vqvae paper
self.register_buffer("embed_avg", embed.clone())
@torch.jit.ignore
def init_embed_(self, data):
"""Initialize codebook.
Args:
data (tensor): [B * T, D].
"""
if self.inited:
return
## NOTE (snippet added by Songxiang Liu): gather data from all gpus
if dist.is_available() and dist.is_initialized():
# [B * T * world_size, D]
data = SyncFunction.apply(data)
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
## NOTE (snippet added by Songxiang Liu): gather data from all gpus
if is_distributed():
# [B * T * world_size, D]
batch_samples = SyncFunction.apply(batch_samples)
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
return x
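    def quantize(self, x):
        embed = self.embed.t()
        # Negative squared Euclidean distance -(||x||^2 - 2 x.e + ||e||^2); its argmax is the nearest code.
        # Named `neg_dist` so it does not shadow the `torch.distributed as dist` import.
        neg_dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = neg_dist.max(dim=-1).indices
        return embed_ind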
def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = self.preprocess(x) # [B, T, D] -> [B*T, D]
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
# shape: [B, T, D]
shape, dtype = x.shape, x.dtype
x = self.preprocess(x) # [B, T, D] -> [B*T, D]
# Initialize codebook
self.init_embed_(x)
embed_ind = self.quantize(x) # [B*T,]
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size]
embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T]
quantize = self.dequantize(embed_ind) # [B, T, D]
if self.training:
### Update codebook by EMA
embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
embed_sum = x.t() @ embed_onehot # [D, cb-size]
if is_distributed():
dist.all_reduce(embed_onehot_sum)
dist.all_reduce(embed_sum)
# Update ema cluster count N_i^t, eq. (6) in vqvae paper
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
# Update ema embed: eq. (7) in vqvae paper
self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
# apply laplace smoothing
n = self.cluster_size.sum()
cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
# Update ema embed: eq. (8) in vqvae paper
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
return quantize, embed_ind
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
    Currently supports only Euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.0,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(
dim=_codebook_dim,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
device = x.device
x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d]
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n]
return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
    Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
def forward(self, x, n_q: tp.Optional[int] = None):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
quantized, indices, loss = layer(residual)
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
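The `kmeans` function in core_vq_lsx_version.py alternates mini-batched nearest-centroid assignment with a mean update. A centroid whose cluster receives no points keeps its previous value, which is what the `torch.where(zero_mask[:, None], means, new_means)` line does. The following is a minimal, dependency-free sketch of that loop on plain Python lists. The names `kmeans_sketch` and `sq_dist` are illustrative and are not part of the repository.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def kmeans_sketch(samples, num_clusters, num_iters=10, batch_size=2, seed=0):
    rng = random.Random(seed)
    # Initialise centroids from random samples (with replacement if there are too few samples).
    if len(samples) >= num_clusters:
        means = [list(s) for s in rng.sample(samples, num_clusters)]
    else:
        means = [list(rng.choice(samples)) for _ in range(num_clusters)]
    bins = [0] * num_clusters
    for _ in range(num_iters):
        # Assign each point to its nearest centroid, one mini-batch at a time.
        assignments = []
        for start in range(0, len(samples), batch_size):
            for x in samples[start:start + batch_size]:
                assignments.append(min(range(num_clusters), key=lambda c: sq_dist(x, means[c])))
        bins = [assignments.count(c) for c in range(num_clusters)]
        # Recompute the means; an empty cluster keeps its previous centroid.
        for c in range(num_clusters):
            members = [x for x, a in zip(samples, assignments) if a == c]
            if members:
                means[c] = [sum(col) / len(members) for col in zip(*members)]
    return means, bins
```

As in the original, `bins` reports the cluster sizes from the last assignment pass.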
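`EuclideanCodebook.quantize` never forms `x - e` explicitly. It expands the squared distance as `||x||^2 - 2 x.e + ||e||^2`, negates it, and takes the argmax, which picks the same code as the argmin of the plain distance. The sketch below checks that equivalence on plain lists. Both helper names are hypothetical.

```python
def nearest_codes(xs, codebook):
    """Nearest code per row via the expanded, negated distance used by EuclideanCodebook.quantize."""
    codes = []
    for x in xs:
        x_sq = sum(v * v for v in x)
        scores = []
        for e in codebook:
            dot = sum(a * b for a, b in zip(x, e))
            e_sq = sum(v * v for v in e)
            # -(||x||^2 - 2 x.e + ||e||^2) == -||x - e||^2
            scores.append(-(x_sq - 2 * dot + e_sq))
        # The largest (least negative) score belongs to the closest code.
        codes.append(max(range(len(codebook)), key=scores.__getitem__))
    return codes


def nearest_codes_direct(xs, codebook):
    """Reference version: argmin of the plain squared Euclidean distance."""
    return [
        min(range(len(codebook)), key=lambda i: sum((a - b) ** 2 for a, b in zip(x, codebook[i])))
        for x in xs
    ]
```

The expanded form maps onto a single matrix multiply (`x @ embed`) plus two broadcast row and column norms, which is why the tensor code uses it.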
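The training branch of `EuclideanCodebook.forward` applies the VQ-VAE EMA updates cited in its comments:

- decayed per-code counts (eq. 6);
- decayed per-code sums of the assigned vectors (eq. 7);
- Laplace smoothing of the counts;
- the new codebook, computed as sums divided by smoothed counts (eq. 8).

The following is a sketch of one such step on plain lists. `ema_codebook_step` is a hypothetical name, and the distributed `all_reduce` of the statistics is omitted.

```python
def ema_codebook_step(embed_avg, cluster_size, xs, codes, decay=0.99, eps=1e-5):
    k, dim = len(cluster_size), len(embed_avg[0])
    # Per-code assignment counts n_i and sums of the vectors assigned to each code.
    counts = [0.0] * k
    sums = [[0.0] * dim for _ in range(k)]
    for x, c in zip(xs, codes):
        counts[c] += 1.0
        for d in range(dim):
            sums[c][d] += x[d]
    # eq. (6): N_i <- decay * N_i + (1 - decay) * n_i
    cluster_size = [decay * n + (1 - decay) * c for n, c in zip(cluster_size, counts)]
    # eq. (7): m_i <- decay * m_i + (1 - decay) * sum of assigned x
    embed_avg = [
        [decay * m + (1 - decay) * s for m, s in zip(m_row, s_row)]
        for m_row, s_row in zip(embed_avg, sums)
    ]
    # Laplace smoothing: for n > 0 every smoothed N_i is strictly positive and they still sum to n.
    n = sum(cluster_size)
    smoothed = [(ni + eps) / (n + k * eps) * n for ni in cluster_size]
    # eq. (8): e_i = m_i / N_i
    embed = [[m / ni for m in m_row] for m_row, ni in zip(embed_avg, smoothed)]
    return embed, embed_avg, cluster_size
```

With `decay=0.0` a single step reduces to "each code becomes the mean of the vectors assigned to it". The smoothing keeps an unused code from causing a division by zero.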
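`ResidualVectorQuantization` follows Algorithm 1 of the SoundStream paper cited in its docstring. Each stage quantizes whatever residual the previous stages left, and decoding sums the selected codes. The following toy sketch uses hand-picked one-dimensional codebooks; the values are illustrative, not trained, and the helper names are hypothetical.

```python
def nearest(v, codebook):
    """Index of the codebook entry closest to v (squared Euclidean distance)."""
    return min(range(len(codebook)), key=lambda i: sum((a - b) ** 2 for a, b in zip(v, codebook[i])))


def rvq_encode(x, codebooks):
    residual = list(x)
    indices = []
    for cb in codebooks:
        i = nearest(residual, cb)
        indices.append(i)
        # The next stage only sees what this stage failed to represent.
        residual = [r - c for r, c in zip(residual, cb[i])]
    return indices


def rvq_decode(indices, codebooks):
    out = [0.0] * len(codebooks[0][0])
    for i, cb in zip(indices, codebooks):
        out = [o + c for o, c in zip(out, cb[i])]
    return out
```

Decoding with only the first `n` codebooks gives a coarser reconstruction. This is how `n_q` trades bitrate for fidelity in `ResidualVectorQuantizer`.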
================================================
FILE: boson_multimodal/audio_processing/quantization/ddp_utils.py
================================================
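import logging
import logging.config  # `import logging` alone does not load the config submodule used by get_logger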
import random
import subprocess
from datetime import datetime
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
import torch.optim
import torch.utils.data
from packaging import version
from omegaconf import OmegaConf
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def is_logging_process():
return not dist.is_initialized() or dist.get_rank() == 0
def get_logger(cfg, name=None):
    # Only the logging process (rank 0, or the single process when not distributed) applies the config.
if is_logging_process():
logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
return logging.getLogger(name)
# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
@staticmethod
# @torch.no_grad()
def forward(ctx, tensor):
ctx.batch_size = tensor.shape[0]
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(gathered_tensor, tensor)
gathered_tensor = torch.cat(gathered_tensor, 0)
return gathered_tensor
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
idx_from = torch.distributed.get_rank() * ctx.batch_size
idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
return grad_input[idx_from:idx_to]
def get_timestamp():
return datetime.now().strftime("%y%m%d-%H%M%S")
def get_commit_hash():
message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
return message.strip().decode("utf-8")
class DDP(DistributedDataParallel):
"""
Override the forward call in lightning so it goes to training and validation step respectively
"""
def forward(self, *inputs, **kwargs): # pragma: no cover
        # Strip any local build tag (e.g. "+cu118") so the version string always parses.
        if version.parse(str(torch.__version__).split("+")[0]) < version.parse("1.11"):
self._sync_params()
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
assert len(self.device_ids) == 1
if self.module.training:
output = self.module.training_step(*inputs[0], **kwargs[0])
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
else:
output = self.module.validation_step(*inputs[0], **kwargs[0])
if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
from torch.nn.parallel.distributed import (
logging,
Join,
_DDPSink,
_tree_flatten_with_rref,
_tree_unflatten_with_rref,
)
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.logger.set_runtime_stats_and_log()
self.num_iterations += 1
self.reducer.prepare_for_forward()
# Notify the join context that this process has not joined, if
# needed
work = Join.notify_join_context(self)
if work:
self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
                # Calling _rebuild_buckets before forward computation,
# It may allocate new buckets before deallocating old buckets
# inside _rebuild_buckets. To save peak memory usage,
# call _rebuild_buckets before the peak memory usage increases
# during forward computation.
# This should be called only once during whole training period.
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
logging.info("Reducer buckets have been rebuilt in this iteration.")
self._has_rebuilt_buckets = True
# sync params according to location (before/after forward) user
# specified as part of hook, if hook was specified.
buffer_hook_registered = hasattr(self, "buffer_hook")
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
if self._join_config.enable:
# Notify joined ranks whether they should sync in backwards pass or not.
self._check_global_requires_backward_grad_sync(is_joined_rank=False)
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if self.module.training:
output = self.module.training_step(*inputs[0], **kwargs[0])
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
else:
output = self.module.validation_step(*inputs[0], **kwargs[0])
# sync params according to location (before/after forward) user
# specified as part of hook, if hook was specified.
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters and not self.static_graph:
# Do not need to populate this for static graph.
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
self.require_forward_param_sync = False
# TODO: DDPSink is currently enabled for unused parameter detection and
# static graph training for first iteration.
if (self.find_unused_parameters and not self.static_graph) or (
self.static_graph and self.num_iterations == 1
):
state_dict = {
"static_graph": self.static_graph,
"num_iterations": self.num_iterations,
}
output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
output_placeholders = [None for _ in range(len(output_tensor_list))]
# Do not touch tensors that have no grad_fn, which can cause issues
# such as https://github.com/pytorch/pytorch/issues/60733
for i, output in enumerate(output_tensor_list):
if torch.is_tensor(output) and output.grad_fn is None:
output_placeholders[i] = output
# When find_unused_parameters=True, makes tensors which require grad
# run through the DDPSink backward pass. When not all outputs are
# used in loss, this makes those corresponding tensors receive
# undefined gradient which the reducer then handles to ensure
# param.grad field is not touched and we don't error out.
passthrough_tensor_list = _DDPSink.apply(
self.reducer,
state_dict,
*output_tensor_list,
)
for i in range(len(output_placeholders)):
if output_placeholders[i] is None:
output_placeholders[i] = passthrough_tensor_list[i]
# Reconstruct output data structure.
output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
return output
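`SyncFunction` concatenates every rank's batch in `forward`. Each rank's loss may depend on all gathered rows, so the gradient for a rank's own rows is the sum of every rank's gradient for those rows. That is why `backward` all-reduces with SUM and then keeps the slice `[rank * batch_size, (rank + 1) * batch_size)`. The following sketch simulates both steps in a single process, with lists standing in for tensors. The functions are hypothetical stand-ins for the collectives, not PyTorch calls.

```python
def simulated_all_gather(per_rank_batches):
    # forward: concatenate every rank's rows along dim 0
    return [row for batch in per_rank_batches for row in batch]


def simulated_sync_backward(grads_per_rank, rank, batch_size):
    # all_reduce(SUM): element-wise sum of each rank's gradient w.r.t. the gathered tensor
    summed = [sum(col) for col in zip(*grads_per_rank)]
    # keep only the rows this rank contributed in forward
    return summed[rank * batch_size:(rank + 1) * batch_size]
```

The slicing assumes every rank has the same batch size, which `all_gather` also requires because it needs equally shaped tensors.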
================================================
FILE: boson_multimodal/audio_processing/quantization/distrib.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp
import torch
def rank():
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return 0
def world_size():
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return 1
def is_distributed():
return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
if is_distributed():
return torch.distributed.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
def _check_number_of_params(params: tp.List[torch.Tensor]):
# utility function to check that the number of params in all workers is the same,
# and thus avoid a deadlock with distributed all reduce.
if not is_distributed() or not params:
return
# print('params[0].device ', params[0].device)
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
all_reduce(tensor)
if tensor.item() != len(params) * world_size():
# If not all the workers have the same number, for at least one of them,
# this inequality will be verified.
raise RuntimeError(
f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
)
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
"""Broadcast the tensors from the given parameters to all workers.
This can be used to ensure that all workers have the same model to start with.
"""
if not is_distributed():
return
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
_check_number_of_params(tensors)
handles = []
for tensor in tensors:
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
def sync_buffer(buffers, average=True):
"""
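    Sync floating-point buffers across workers. By default they are averaged (all-reduce SUM, then divide by the world size); if average is False, they are broadcast from rank 0 instead.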
"""
if not is_distributed():
return
handles = []
for buffer in buffers:
if torch.is_floating_point(buffer.data):
if average:
handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
else:
handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
handles.append((buffer, handle))
for buffer, handle in handles:
handle.wait()
if average:
                buffer.data /= world_size()
def sync_grad(params):
"""
Simpler alternative to DistributedDataParallel, that doesn't rely
on any black magic. For simple models it can also be as fast.
Just call this on your model parameters after the call to backward!
"""
if not is_distributed():
return
handles = []
for p in params:
if p.grad is not None:
handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
handles.append((p, handle))
for p, handle in handles:
handle.wait()
p.grad.data /= world_size()
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
"""Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
"""
if not is_distributed():
return metrics
keys, values = zip(*metrics.items())
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
tensor *= count
all_reduce(tensor)
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
return dict(zip(keys, averaged))
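`_check_number_of_params` guards against a deadlock by all-reducing each worker's tensor count. If every worker holds `n` tensors, the sum is `n * world_size` everywhere. If the counts differ, at least one worker's count differs from the mean, and that worker raises. The following is a single-process simulation of the per-worker check. The helper name is hypothetical and no real collectives are involved.

```python
def mismatch_flags(param_counts):
    world_size = len(param_counts)
    total = sum(param_counts)  # what all_reduce(SUM) yields on every worker
    # Worker i raises when total != n_i * world_size, i.e. when n_i differs from the mean count.
    return [total != n * world_size for n in param_counts]
```

In the `[1, 3, 2]` case, the worker holding exactly the mean count sees no mismatch, but the other two raise. This is why the source comment only promises detection on "at least one" worker.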
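`average_metrics` computes a count-weighted mean with a single all-reduce. Each worker appends a `1` to its metric values, scales the whole vector by `count`, and the summed last slot becomes the total weight. The following sketch simulates this across workers in one process, using a hypothetical function name. The real function unzips `metrics.items()`, so it relies on every worker passing its keys in the same order; the sketch looks values up by key instead.

```python
def simulated_average_metrics(per_worker):
    """per_worker: list of (metrics, count) pairs, one per worker."""
    keys = list(per_worker[0][0])
    totals = [0.0] * (len(keys) + 1)
    for metrics, count in per_worker:
        # Each worker builds [v_1, ..., v_k, 1] and scales it by its count.
        row = [metrics[k] for k in keys] + [1.0]
        for i, v in enumerate(row):
            totals[i] += v * count  # all_reduce(SUM)
    # The last slot holds the total count, so this is a count-weighted mean.
    return {k: totals[i] / totals[-1] for i, k in enumerate(keys)}
```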
================================================
FILE: boson_multimodal/audio_processing/quantization/vq.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""
from dataclasses import dataclass, field
import math
import typing as tp
import torch
from torch import nn
# from .core_vq import ResidualVectorQuantization
from .core_vq_lsx_version import ResidualVectorQuantization
@dataclass
class QuantizedResult:
quantized: torch.Tensor
codes: torch.Tensor
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
penalty: tp.Optional[torch.Tensor] = None
metrics: dict = field(default_factory=dict)
class ResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer.
Args:
        dimension (int): Dimension of the input features (and of the codebooks when codebook_dim is None).
        codebook_dim (int, optional): Dimension of the codebook entries. If set and different from
            dimension, inputs are linearly projected to it before quantization and back afterwards.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
        codebook_dim: tp.Optional[int] = None,
n_q: int = 8,
bins: int = 1024,
decay: float = 0.99,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.n_q = n_q
self.dimension = dimension
self.codebook_dim = codebook_dim
self.bins = bins
self.decay = decay
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.threshold_ema_dead_code = threshold_ema_dead_code
self.vq = ResidualVectorQuantization(
dim=self.dimension,
codebook_dim=self.codebook_dim,
codebook_size=self.bins,
num_quantizers=self.n_q,
decay=self.decay,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
    def forward(
        self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Residual vector quantization on the given input tensor.
Args:
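            x (torch.Tensor): Input tensor of shape [B, D, T].
            sample_rate (int): Frame rate (Hz) of the latent sequence `x`, used to compute bandwidth.
            bandwidth (float, optional): Target bandwidth in kb/s. If None or not positive, all n_q quantizers are used.
        Returns:
            tuple: (quantized, codes, bandwidth, penalty), where `quantized` has the same shape as `x`,
                `codes` has shape [n_q, B, T], `bandwidth` is the bandwidth used in kb/s, and `penalty`
                is the commitment loss averaged over quantizers.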
"""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return quantized, codes, bw, torch.mean(commit_loss)
# return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
"""Return n_q based on specified target bandwidth."""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
n_q = self.n_q
if bandwidth and bandwidth > 0.0:
n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
return n_q
def get_bandwidth_per_quantizer(self, sample_rate: int):
"""Return bandwidth per quantizer for a given input sample rate."""
return math.log2(self.bins) * sample_rate / 1000
def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
and returns indices for each quantizer.
"""
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
codes = self.vq.encode(x, n_q=n_q)
return codes
def decode(self, codes: torch.Tensor) -> torch.Tensor:
"""Decode the given codes to the quantized representation."""
quantized = self.vq.decode(codes)
return quantized
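`get_bandwidth_per_quantizer` and `get_num_quantizers_for_bandwidth` reduce to simple arithmetic. Each quantizer contributes `log2(bins)` bits per latent frame, and the requested bandwidth is divided by that per-quantizer rate and floored, with a minimum of one quantizer. The following standalone sketch reproduces that arithmetic. The function names are hypothetical, and the numbers (1024-entry codebooks, a 25 Hz frame rate, 8 quantizers) are illustrative assumptions, not values read from this repository's config.

```python
import math


def bandwidth_per_quantizer(bins, frame_rate):
    # log2(bins) bits per frame * frames per second, expressed in kb/s
    return math.log2(bins) * frame_rate / 1000


def num_quantizers_for_bandwidth(n_q_max, bins, frame_rate, bandwidth=None):
    bw_per_q = bandwidth_per_quantizer(bins, frame_rate)
    n_q = n_q_max
    if bandwidth and bandwidth > 0.0:
        n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
    return n_q
```

The method does not clamp its result to `self.n_q`. In the 4.0 kb/s case (see the test), `self.vq` still uses only the available layers (via `self.layers[:n_q]`), but `forward` reports `n_q * bw_per_q` as the bandwidth used.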
================================================
FILE: boson_multimodal/audio_processing/semantic_module.py
================================================
# Based on code from: https://github.com/zhenye234/xcodec
# Licensed under MIT License
# Modifications by BosonAI
import torch
import torch.nn as nn
class Conv1d1x1(nn.Conv1d):
"""1x1 Conv1d."""
def __init__(self, in_channels, out_channels, bias=True):
super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
class Conv1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = -1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
if padding < 0:
padding = (kernel_size - 1) // 2 * dilation
self.dilation = dilation
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x):
"""
Args:
x (Tensor): Float tensor variable with the shape (B, C, T).
Returns:
Tensor: Float tensor variable with the shape (B, C, T).
"""
x = self.conv(x)
return x
class ResidualUnit(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
dilation=1,
bias=False,
nonlinear_activation="ELU",
nonlinear_activation_params={},
):
super().__init__()
self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
self.conv1 = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
dilation=dilation,
bias=bias,
)
self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
def forward(self, x):
y = self.conv1(self.activation(x))
y = self.conv2(self.activation(y))
return x + y
class ConvTranspose1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
padding=-1,
output_padding=-1,
groups=1,
bias=True,
):
super().__init__()
if padding < 0:
padding = (stride + 1) // 2
if output_padding < 0:
output_padding = 1 if stride % 2 else 0
self.deconv = nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
)
def forward(self, x):
"""
Args:
x (Tensor): Float tensor variable with the shape (B, C, T).
Returns:
Tensor: Float tensor variable with the shape (B, C', T').
"""
x = self.deconv(x)
return x
class EncoderBlock(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True
):
super().__init__()
self.res_units = torch.nn.ModuleList()
for dilation in dilations:
self.res_units += [ResidualUnit(in_channels, in_channels, kernel_size=unit_kernel_size, dilation=dilation)]
self.num_res = len(self.res_units)
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
stride=stride,
bias=bias,
)
def forward(self, x):
for idx in range(self.num_res):
x = self.res_units[idx](x)
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
input_channels: int,
encode_channels: int,
channel_ratios=(1, 1),
strides=(1, 1),
kernel_size=3,
bias=True,
block_dilations=(1, 1),
unit_kernel_size=3,
):
super().__init__()
assert len(channel_ratios) == len(strides)
self.conv = Conv1d(
in_channels=input_channels, out_channels=encode_channels, kernel_size=kernel_size, stride=1, bias=False
)
self.conv_blocks = torch.nn.ModuleList()
in_channels = encode_channels
for idx, stride in enumerate(strides):
out_channels = int(encode_channels * channel_ratios[idx]) # could be float
self.conv_blocks += [
EncoderBlock(
in_channels,
out_channels,
stride,
dilations=block_dilations,
unit_kernel_size=unit_kernel_size,
bias=bias,
)
]
in_channels = out_channels
self.num_blocks = len(self.conv_blocks)
self.out_channels = out_channels
def forward(self, x):
x = self.conv(x)
for i in range(self.num_blocks):
x = self.conv_blocks[i](x)
return x
class DecoderBlock(nn.Module):
"""Decoder block (no up-sampling)"""
def __init__(
self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True
):
super().__init__()
if stride == 1:
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
stride=stride,
bias=bias,
)
else:
self.conv = ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(2 * stride),
stride=stride,
bias=bias,
)
self.res_units = torch.nn.ModuleList()
for idx, dilation in enumerate(dilations):
self.res_units += [
ResidualUnit(out_channels, out_channels, kernel_size=unit_kernel_size, dilation=dilation)
]
self.num_res = len(self.res_units)
def forward(self, x):
x = self.conv(x)
for idx in range(self.num_res):
x = self.res_units[idx](x)
return x
class Decoder(nn.Module):
def __init__(
self,
code_dim: int,
output_channels: int,
decode_channels: int,
channel_ratios=(1, 1),
strides=(1, 1),
kernel_size=3,
bias=True,
block_dilations=(1, 1),
unit_kernel_size=3,
):
super().__init__()
assert len(channel_ratios) == len(strides)
self.conv1 = Conv1d(
in_channels=code_dim,
out_channels=int(decode_channels * channel_ratios[0]),
kernel_size=kernel_size,
stride=1,
bias=False,
)
self.conv_blocks = torch.nn.ModuleList()
for idx, stride in enumerate(strides):
in_channels = int(decode_channels * channel_ratios[idx])
if idx < (len(channel_ratios) - 1):
out_channels = int(decode_channels * channel_ratios[idx + 1])
else:
out_channels = decode_channels
self.conv_blocks += [
DecoderBlock(
in_channels,
out_channels,
stride,
dilations=block_dilations,
unit_kernel_size=unit_kernel_size,
bias=bias,
)
]
self.num_blocks = len(self.conv_blocks)
self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
def forward(self, z):
x = self.conv1(z)
for i in range(self.num_blocks):
x = self.conv_blocks[i](x)
x = self.conv2(x)
return x
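`EncoderBlock` downsamples with kernel `2 * stride`, or kernel 3 when `stride == 1`. `Conv1d` defaults to padding `(kernel_size - 1) // 2 * dilation`. The sketch below applies the output-length formula from the PyTorch `nn.Conv1d` documentation to show that these choices keep lengths unchanged at stride 1 and divide them exactly by the stride otherwise (for lengths divisible by the stride). Plain integer arithmetic stands in for the layers; the helper names are hypothetical.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    # L_out = floor((L_in + 2p - d(k - 1) - 1) / s) + 1
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def encoder_block_out_len(length, stride):
    # EncoderBlock's choice of kernel, then Conv1d's default padding (dilation = 1)
    kernel_size = 3 if stride == 1 else 2 * stride
    padding = (kernel_size - 1) // 2
    return conv1d_out_len(length, kernel_size, stride, padding)
```

For stride `s >= 2`, the padding is `s - 1`, and a length of `m * s` maps to `floor((m * s - 2) / s) + 1 = m`.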
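`ConvTranspose1d` upsamples with kernel `2 * stride` and chooses its defaults as padding `(stride + 1) // 2` and `output_padding = 1` for odd strides (0 otherwise). The following sketch uses the output-length formula from the PyTorch `nn.ConvTranspose1d` documentation to check that these defaults multiply the length by exactly `stride`. It applies to `stride >= 2`, since `DecoderBlock` uses a plain `Conv1d` at stride 1. The helper names are hypothetical.

```python
def conv_transpose1d_out_len(length, kernel_size, stride, padding, output_padding, dilation=1):
    # L_out = (L_in - 1) * s - 2p + d(k - 1) + output_padding + 1
    return (length - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def decoder_upsample_len(length, stride):
    # Defaults chosen by ConvTranspose1d above when padding/output_padding are negative
    padding = (stride + 1) // 2
    output_padding = 1 if stride % 2 else 0
    return conv_transpose1d_out_len(length, 2 * stride, stride, padding, output_padding)
```

For odd strides, the extra `output_padding` compensates for the padding being rounded up, so encoder and decoder lengths stay mirror images of each other.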
================================================
FILE: boson_multimodal/constants.py
================================================
AUDIO_IN_TOKEN = "<|AUDIO|>"
AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
EOS_TOKEN = "<|end_of_text|>"
================================================
FILE: boson_multimodal/data_collator/__init__.py
================================================
================================================
FILE: boson_multimodal/data_collator/higgs_audio_collator.py
================================================
import librosa
import torch
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
from transformers.models.whisper.processing_whisper import WhisperProcessor
from ..dataset.chatml_dataset import ChatMLDatasetSample
from ..model.higgs_audio.utils import build_delay_pattern_mask
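When use_delay_pattern is set, the collator passes audio codes through build_delay_pattern_mask, imported above from ..model.higgs_audio.utils; its body is not part of this section. The sketch below assumes it implements the common delay pattern in which codebook k is shifted right by k frames, filled with BOS on the left and padding on the right, which is consistent with the collator passing bos_token_id and pad_token_id. It works on plain lists, and the function name delay_pattern is hypothetical.

```python
from typing import List


def delay_pattern(codes: List[List[int]], bos_id: int, pad_id: int) -> List[List[int]]:
    """Shift codebook k right by k frames (assumed delay-pattern layout).

    codes: num_codebooks rows of equal length T. The result has
    T + num_codebooks - 1 columns: row k starts with k BOS fillers and
    ends with (num_codebooks - 1 - k) PAD fillers.
    """
    num_codebooks = len(codes)
    out = []
    for k, row in enumerate(codes):
        out.append([bos_id] * k + list(row) + [pad_id] * (num_codebooks - 1 - k))
    return out


codes = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # 3 codebooks, 3 frames
for row in delay_pattern(codes, bos_id=-1, pad_id=-2):
    print(row)
# [1, 2, 3, -2, -2]
# [-1, 4, 5, 6, -2]
# [-1, -1, 7, 8, 9]
```

The collator keeps only the first returned value (`[0].squeeze(0)`), and for audio codes it passes audio_stream_eos_id as the pad value.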
def _ceil_to_nearest(n, round_to):
return (n + round_to - 1) // round_to * round_to
def _ceil_to_next_power_of_two(x):
return 1 if x == 0 else 2 ** (x - 1).bit_length()
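The two helpers above round lengths up; _ceil_to_nearest sets the padded batch length used later in the collator. A minimal plain-Python sketch of that rounding plus the left/right padding and attention-mask construction; pad_batch is a hypothetical name, and the real collator works on torch tensors with F.pad.

```python
def ceil_to_nearest(n: int, round_to: int) -> int:
    # Same arithmetic as _ceil_to_nearest above.
    return (n + round_to - 1) // round_to * round_to


def ceil_to_next_power_of_two(x: int) -> int:
    # Same arithmetic as _ceil_to_next_power_of_two above.
    return 1 if x == 0 else 2 ** (x - 1).bit_length()


def pad_batch(seqs, pad_id, round_to=8, pad_left=False):
    """Pad token lists to a common length and build attention masks.

    The target length is the longest sequence rounded up to a multiple
    of round_to, as in the collator.
    """
    max_len = ceil_to_nearest(max(len(s) for s in seqs), round_to)
    ids, masks = [], []
    for s in seqs:
        n_pad = max_len - len(s)
        if pad_left:
            ids.append([pad_id] * n_pad + list(s))
            masks.append([0] * n_pad + [1] * len(s))
        else:
            ids.append(list(s) + [pad_id] * n_pad)
            masks.append([1] * len(s) + [0] * n_pad)
    return ids, masks


print(ceil_to_nearest(13, 8), ceil_to_next_power_of_two(13))  # 16 16
ids, masks = pad_batch([[5, 6, 7], [8]], pad_id=0, round_to=4, pad_left=True)
print(ids)    # [[0, 5, 6, 7], [0, 0, 0, 8]]
print(masks)  # [[0, 1, 1, 1], [0, 0, 0, 1]]
```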
@dataclass
class HiggsAudioBatchInput:
input_ids: torch.LongTensor # shape (bsz, seq_len).
attention_mask: torch.Tensor # shape (bsz, seq_len).
audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len).
audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len).
audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,)
# `audio_out_ids_start_group_loc` has the same length as `audio_out_ids_start` and records, for each audio segment, the batch index of the sample it comes from.
# Audio segments are concatenated along the time dimension (dim 1) to handle variable segment lengths, but the alignment stage still needs each segment's location in the batch.
# For example, suppose audio_out_ids_start = [0, 2, 4, 8], the first two audio segments come from the same sample, and the other two come from two different samples.
# This is a batch of 3 samples, so the group locations are:
# audio_out_ids_start_group_loc = [0, 0, 1, 2]
audio_out_ids_start_group_loc: Optional[
torch.LongTensor
] # shape (num_audio_out,): the batch index of the sample each audio segment belongs to
audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length)
audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,)
label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len)
label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
reward: Optional[float] = None
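The start offsets and group locations described in the comment above can be illustrated in plain Python. This sketch assumes each batch sample contributes a list of audio-out segment lengths; it computes starts as an exclusive prefix sum, the same way the collator builds audio_out_ids_start with torch.cumsum over [0] + lengths[:-1]. The function name is hypothetical.

```python
from itertools import accumulate
from typing import List, Tuple


def segment_starts_and_groups(seg_lengths_per_sample: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Start offsets and batch locations for segments concatenated along time.

    seg_lengths_per_sample[b] lists the lengths (in frames) of the
    audio-out segments of batch sample b.
    """
    lengths = [n for sample in seg_lengths_per_sample for n in sample]
    # Exclusive prefix sum: segment i starts after all earlier segments.
    starts = list(accumulate([0] + lengths[:-1])) if lengths else []
    groups = [b for b, sample in enumerate(seg_lengths_per_sample) for _ in sample]
    return starts, groups


# Segments of length 2, 2, 4, 3: the first two from sample 0,
# then one each from samples 1 and 2.
starts, groups = segment_starts_and_groups([[2, 2], [4], [3]])
print(starts)  # [0, 2, 4, 8]
print(groups)  # [0, 0, 1, 2]
```

The length of the last segment (3 here) does not affect the start offsets, which is why the comment's example can omit it.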
class HiggsAudioSampleCollator:
"""Sample collator for Higgs-Audio model.
Args:
whisper_processor (WhisperProcessor): The Whisper processor used to extract audio-in features.
audio_in_token_id (int): The token id for audio-in (`<|AUDIO|>`).
audio_out_token_id (int): The token id for audio-out (`<|AUDIO_OUT|>`).
pad_token_id (int): The token id for padding.
audio_stream_bos_id (int): The token id marking the beginning of an audio stream.
audio_stream_eos_id (int): The token id marking the end of an audio stream.
round_to (int): Round the padded sequence length up to a multiple of this value.
pad_left (bool): Whether to pad on the left instead of the right.
encode_whisper_embed (bool): Whether to extract Whisper features for audio-in segments.
return_audio_in_tokens (bool): Whether to return audio-in tokens.
audio_num_codebooks (Optional[int]): If set, keep only the first `audio_num_codebooks` codebooks.
use_delay_pattern (bool): Whether to apply the delay pattern to audio codes.
disable_audio_codes_transform (bool): If True, do not add audio-stream BOS/EOS tokens to the audio codes.
chunk_size_seconds (int): Maximum duration, in seconds, of each audio-in chunk passed to the Whisper feature extractor (used only when `encode_whisper_embed` is True).
add_new_bos_eos_for_long_chunk (bool): Whether to add new BOS and EOS tokens for long chunks.
mask_audio_out_token_label (bool): Whether to always mask the label of the `<|AUDIO_OUT|>` token. Since `<|AUDIO_OUT|>` always follows `<|audio_out_bos|>`, it can be safely masked.
"""
def __init__(
self,
whisper_processor: WhisperProcessor,
audio_in_token_id,
audio_out_token_id,
pad_token_id,
audio_stream_bos_id,
audio_stream_eos_id,
round_to=8,
pad_left=False,
encode_whisper_embed=True,
return_audio_in_tokens=True,
audio_num_codebooks=None,
use_delay_pattern=False,
disable_audio_codes_transform=False,
chunk_size_seconds=30, # Maximum duration for each chunk
add_new_bos_eos_for_long_chunk=True,
mask_audio_out_token_label=True,
):
self.whisper_processor = whisper_processor
self.round_to = round_to
self.pad_left = pad_left
self.audio_in_token_id = audio_in_token_id
self.audio_out_token_id = audio_out_token_id
self.audio_stream_bos_id = audio_stream_bos_id
self.audio_stream_eos_id = audio_stream_eos_id
self.pad_token_id = pad_token_id
self.encode_whisper_embed = encode_whisper_embed
self.return_audio_in_tokens = return_audio_in_tokens
self.audio_num_codebooks = audio_num_codebooks
self.use_delay_pattern = use_delay_pattern
if encode_whisper_embed:
self.chunk_size_seconds = chunk_size_seconds
self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
else:
self.chunk_size_seconds = None
self.chunk_size_samples = None
self.disable_audio_codes_transform = disable_audio_codes_transform
self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
self.mask_audio_out_token_label = mask_audio_out_token_label
def _process_and_duplicate_audio_tokens(
self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, sr: int, labels: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor], int]:
"""Process long audio and duplicate corresponding audio tokens.
Args:
input_ids: Input token ids
audio_idx: Index of the audio token in the sequence
wv: Audio waveform
sr: Sample rate
labels: Optional label ids to be duplicated alongside input ids
Returns:
Tuple of:
- New input ids with duplicated audio tokens
- New label ids (if labels were provided) or None
- Number of chunks created
"""
# Calculate number of chunks needed
total_samples = len(wv)
num_chunks = math.ceil(total_samples / self.chunk_size_samples)
if num_chunks <= 1:
return input_ids, labels, 1
# Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
# Duplicate sequence for each chunk
duplicated_sequence = audio_token_seq.repeat(num_chunks)
# Create new input_ids with duplicated tokens
new_input_ids = torch.cat([input_ids[: audio_idx - 1], duplicated_sequence, input_ids[audio_idx + 2 :]])
# If labels are provided, duplicate them as well
new_labels = None
if labels is not None:
label_seq = labels[audio_idx - 1 : audio_idx + 2]
duplicated_labels = label_seq.repeat(num_chunks)
new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
return new_input_ids, new_labels, num_chunks
def __call__(self, batch: List[ChatMLDatasetSample]):
"""Collate the input data with support for long audio processing."""
label_ids = None
label_audio_ids = None
if all([ele.label_ids is None for ele in batch]):
return_labels = False
else:
return_labels = True
if self.encode_whisper_embed:
# Process each sample in the batch to handle long audio
# TODO(?) The implementation here can be optimized.
processed_batch = []
for i in range(len(batch)):
sample = batch[i]
audio_in_mask = sample.input_ids == self.audio_in_token_id
audio_in_indices = torch.where(audio_in_mask)[0]
audio_out_mask = sample.input_ids == self.audio_out_token_id
# Process each audio token and duplicate if needed
modified_input_ids = sample.input_ids
modified_labels = sample.label_ids if return_labels else None
modified_waveforms_concat = []
modified_waveforms_start = []
modified_sample_rate = []
offset = 0 # Track position changes from duplicating tokens
curr_wv_offset = 0
# Process input audio tokens
for idx, audio_idx in enumerate(audio_in_indices):
# Get the audio for this token
wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index
if sr != self.whisper_processor.feature_extractor.sampling_rate:
resampled_wv = librosa.resample(
wv.cpu().numpy(),
orig_sr=sr,
target_sr=self.whisper_processor.feature_extractor.sampling_rate,
)
else:
resampled_wv = wv.cpu().numpy()
wv = torch.tensor(resampled_wv, device=wv.device)
sr = self.whisper_processor.feature_extractor.sampling_rate
# Process and duplicate tokens if necessary
token_pos = audio_idx + offset
modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
modified_input_ids, token_pos, wv, sr, modified_labels
)
# Update audio data
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * self.chunk_size_samples
chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
chunk_wv = wv[chunk_start:chunk_end]
modified_waveforms_concat.append(chunk_wv)
modified_waveforms_start.append(curr_wv_offset)
curr_wv_offset += len(chunk_wv)
modified_sample_rate.append(sr)
# Update offset for next iteration
offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens
# Create new sample with modified tokens and audio data
processed_sample = ChatMLDatasetSample(
input_ids=modified_input_ids,
label_ids=modified_labels if return_labels else sample.label_ids,
audio_ids_concat=sample.audio_ids_concat,
audio_ids_start=sample.audio_ids_start,
audio_waveforms_concat=torch.cat(modified_waveforms_concat)
if modified_waveforms_concat
else sample.audio_waveforms_concat,
audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
if modified_waveforms_start
else sample.audio_waveforms_start,
audio_sample_rate=torch.tensor(modified_sample_rate)
if modified_sample_rate
else sample.audio_sample_rate,
audio_speaker_indices=torch.tensor([]),
# FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
audio_label_ids_concat=sample.audio_label_ids_concat,
)
# audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
# assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
processed_batch.append(processed_sample)
else:
processed_batch = batch
# Get the max sequence length based on processed batch
max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
# Get the ids for audio-in and audio-out for each batch
audio_in_wv_l = []
audio_in_ids_l = []
audio_out_ids_l = []
audio_out_ids_group_loc_l = []
audio_in_label_ids_l = None
audio_out_label_ids_l = None
reward_l = []
if return_labels:
audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not.
# Process the audio inputs and outputs
for i in range(len(processed_batch)):
audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
audio_ids = torch.ones_like(processed_batch[i].input_ids)
audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
audio_in_ids = audio_ids[audio_in_mask]
audio_out_ids = audio_ids[audio_out_mask]
if return_labels:
audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
if self.mask_audio_out_token_label:
processed_batch[i].label_ids[audio_out_mask] = -100
# Process audio inputs
if self.return_audio_in_tokens:
audio_in_ids_l.extend(
[processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
)
if processed_batch[i].audio_label_ids_concat is not None:
if audio_in_label_ids_l is None:
audio_in_label_ids_l = []
audio_in_label_ids_l.extend(
[
processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
for idx in audio_in_ids
]
)
audio_out_ids_l.extend(
[processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
)
audio_out_ids_group_loc_l.append(i)
if processed_batch[i].reward is not None:
reward_l.append(processed_batch[i].reward)
if processed_batch[i].audio_label_ids_concat is not None:
if audio_out_label_ids_l is None:
audio_out_label_ids_l = []
audio_out_label_ids_l.extend(
[
processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
for idx in audio_out_ids
]
)
if self.encode_whisper_embed:
for idx in audio_in_ids:
wv, sr = processed_batch[i].get_wv(idx)
resampled_wv = wv.cpu().numpy()
# Split long audio into chunks
total_samples = len(resampled_wv)
for chunk_start in range(0, total_samples, self.chunk_size_samples):
chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
chunk = resampled_wv[chunk_start:chunk_end]
audio_in_wv_l.append(chunk)
# assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
# f"Assertion failed: Mismatch in number of audios. " \
# f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."
if return_labels:
audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
# Process all audio features
if len(audio_in_wv_l) > 0:
feature_ret = self.whisper_processor.feature_extractor(
audio_in_wv_l,
sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
return_attention_mask=True,
padding="max_length",
)
audio_features = torch.from_numpy(feature_ret["input_features"])
audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
else:
if self.encode_whisper_embed:
audio_features = torch.zeros(
(
0,
self.whisper_processor.feature_extractor.feature_size,
self.whisper_processor.feature_extractor.nb_max_frames,
),
dtype=torch.float32,
)
audio_feature_attention_mask = torch.zeros(
(0, self.whisper_processor.feature_extractor.nb_max_frames), dtype=torch.int32
)
else:
audio_features = None
audio_feature_attention_mask = None
# Process audio input tokens
if len(audio_in_ids_l) > 0:
# Append audio-stream-bos and eos tokens
new_audio_in_ids_l = []
for ele in audio_in_ids_l:
if self.disable_audio_codes_transform:
# Do not add audio-stream-bos or eos tokens.
# This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
audio_codes = ele
else:
audio_codes = torch.cat(
[
torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
ele,
torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
],
dim=1,
)
if self.use_delay_pattern:
audio_codes = build_delay_pattern_mask(
audio_codes.unsqueeze(0),
bos_token_id=self.audio_stream_bos_id,
pad_token_id=self.audio_stream_eos_id,
)[0].squeeze(0)
new_audio_in_ids_l.append(audio_codes)
audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
audio_in_ids_start = torch.cumsum(
torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]), dim=0
)
else:
audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
audio_in_ids_start = torch.zeros(0, dtype=torch.long)
# Process audio output tokens
audio_out_ids_start_group_loc = None
if len(audio_out_ids_l) > 0:
new_audio_out_ids_l = []
label_audio_ids_l = []
for idx, ele in enumerate(audio_out_ids_l):
if self.disable_audio_codes_transform:
# Do not add audio-stream-bos or eos tokens.
# This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
audio_codes = ele
if return_labels:
label_audio_ids = audio_out_label_ids_l[idx]
else:
audio_codes = torch.cat(
[
torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
ele,
torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
],
dim=1,
)
if return_labels:
label_audio_ids = torch.cat(
[
torch.full((ele.shape[0], 1), -100, dtype=torch.long),
ele,
torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
],
dim=1,
)
if self.use_delay_pattern:
audio_codes = build_delay_pattern_mask(
audio_codes.unsqueeze(0),
bos_token_id=self.audio_stream_bos_id,
pad_token_id=self.audio_stream_eos_id,
)[0].squeeze(0)
if return_labels:
label_audio_ids = build_delay_pattern_mask(
label_audio_ids.unsqueeze(0),
bos_token_id=-100,
pad_token_id=-100,
)[0].squeeze(0)
new_audio_out_ids_l.append(audio_codes)
if return_labels:
if audio_out_no_train_flag[idx]:
label_audio_ids[:] = -100
label_audio_ids_l.append(label_audio_ids)
audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
if return_labels:
label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
audio_out_ids_start = torch.cumsum(
torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]), dim=0
)
audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
else:
audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
audio_out_ids_start = torch.zeros(0, dtype=torch.long)
if return_labels:
label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
reward = torch.tensor(reward_l, dtype=torch.float32)
# Handle padding for input ids and attention mask
if self.pad_left:
input_ids = torch.stack(
[
F.pad(ele.input_ids, (max_seq_length - len(ele.input_ids), 0), value=self.pad_token_id)
for ele in processed_batch
]
)
if return_labels:
label_ids = torch.stack(
[
F.pad(ele.label_ids, (max_seq_length - len(ele.label_ids), 0), value=-100)
for ele in processed_batch
]
)
attention_mask = torch.stack(
[
F.pad(torch.ones_like(ele.input_ids), (max_seq_length - len(ele.input_ids), 0), value=0)
for ele in processed_batch
]
)
else:
input_ids = torch.stack(
[
F.pad(ele.input_ids, (0, max_seq_length - len(ele.input_ids)), value=self.pad_token_id)
for ele in processed_batch
]
)
if return_labels:
label_ids = torch.stack(
[
F.pad(ele.label_ids, (0, max_seq_length - len(ele.label_ids)), value=-100)
for ele in processed_batch
]
)
attention_mask = torch.stack(
[
F.pad(torch.ones_like(ele.input_ids), (0, max_seq_length - len(ele.input_ids)), value=0)
for ele in processed_batch
]
)
if not self.return_audio_in_tokens:
audio_in_ids = None
audio_in_ids_start = None
# Apply audio_num_codebooks limit if specified
if self.audio_num_codebooks is not None:
if audio_in_ids is not None:
audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
if audio_out_ids is not None:
audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
if label_audio_ids is not None:
label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
return HiggsAudioBatchInput(
input_ids=input_ids,
attention_mask=attention_mask,
audio_features=audio_features,
audio_feature_attention_mask=audio_feature_attention_mask,
audio_out_ids=audio_out_ids,
audio_out_ids_start=audio_out_ids_start,
audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
audio_in_ids=audio_in_ids,
audio_in_ids_start=audio_in_ids_start,
label_ids=label_ids,
label_audio_ids=label_audio_ids,
reward=reward,
)
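To make the long-audio handling in _process_and_duplicate_audio_tokens concrete: audio longer than chunk_size_samples is split into ceil(num_samples / chunk_size_samples) chunks, and the three-token placeholder around the audio-in token is repeated once per chunk. This is a plain-list sketch. The token ids 10, 11, 12 standing for `<|audio_bos|>`, `<|AUDIO|>`, `<|audio_eos|>` are made up for the example, the helper names are hypothetical, and the real method works on tensors and duplicates labels as well.

```python
import math
from typing import List, Tuple


def chunk_bounds(num_samples: int, chunk_size_samples: int) -> List[Tuple[int, int]]:
    """(start, end) sample ranges, matching the collator's chunking loop."""
    return [
        (start, min(start + chunk_size_samples, num_samples))
        for start in range(0, num_samples, chunk_size_samples)
    ]


def duplicate_audio_placeholder(
    input_ids: List[int], audio_idx: int, num_samples: int, chunk_size_samples: int
) -> Tuple[List[int], int]:
    """Repeat the 3-token placeholder around input_ids[audio_idx] once per chunk."""
    num_chunks = math.ceil(num_samples / chunk_size_samples)
    if num_chunks <= 1:
        return input_ids, 1
    placeholder = input_ids[audio_idx - 1 : audio_idx + 2]  # <bos>, <AUDIO>, <eos>
    new_ids = input_ids[: audio_idx - 1] + placeholder * num_chunks + input_ids[audio_idx + 2 :]
    return new_ids, num_chunks


sr, chunk = 16_000, 30 * 16_000  # 30 s chunks at 16 kHz
ids, n = duplicate_audio_placeholder([1, 10, 11, 12, 2], audio_idx=2, num_samples=70 * sr, chunk_size_samples=chunk)
print(n)    # 3
print(ids)  # [1, 10, 11, 12, 10, 11, 12, 10, 11, 12, 2]
print(chunk_bounds(70 * sr, chunk))  # [(0, 480000), (480000, 960000), (960000, 1120000)]
```

Each extra chunk adds three tokens, which is why the collator advances `offset` by `(num_chunks - 1) * 3` before handling the next audio-in token.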
================================================
FILE: boson_multimodal/data_types.py
================================================
"""Basic data types for multimodal ChatML format."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
@dataclass
class AudioContent:
audio_url: str
# Base64 encoded audio bytes
raw_audio: Optional[str] = None
offset: Optional[float] = None
duration: Optional[float] = None
row_id: Optional[int] = None
type: str = "audio"
@dataclass
class TextContent:
text: str
type: str = "text"
@dataclass
class Message:
role: str
content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
recipient: Optional[str] = None
@dataclass
class ChatMLSample:
"""Dataclass to hold multimodal ChatML data."""
messages: List[Message]
start_index: Optional[int] = None # messages[:start_index] are masked out of the loss when fine-tuning the LLM.
misc: Optional[Dict] = None
speaker: Optional[str] = None
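Message.content accepts a plain string, a single TextContent or AudioContent, or a list mixing them. prepare_chatml_sample in chatml_dataset.py flattens these into a list of content objects, turning bare strings into TextContent. The following is a standalone sketch of that normalization; it redefines trimmed copies of the two dataclasses (optional fields omitted) so it runs on its own, and normalize_content is a hypothetical name.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class TextContent:  # trimmed copy of the dataclass above
    text: str
    type: str = "text"


@dataclass
class AudioContent:  # trimmed copy; optional fields omitted
    audio_url: str
    type: str = "audio"


Content = Union[str, TextContent, AudioContent]


def normalize_content(content: Union[Content, List[Content]]) -> List[Union[TextContent, AudioContent]]:
    """Turn any allowed Message.content value into a flat list of content objects."""
    items = content if isinstance(content, list) else [content]
    return [TextContent(text=c) if isinstance(c, str) else c for c in items]


parts = normalize_content(["Say this:", AudioContent(audio_url="ref.wav")])
print([p.type for p in parts])  # ['text', 'audio']
```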
================================================
FILE: boson_multimodal/dataset/__init__.py
================================================
================================================
FILE: boson_multimodal/dataset/chatml_dataset.py
================================================
import dacite
import pandas as pd
import torch
import json
import numpy as np
import multiprocessing as mp
from dataclasses import dataclass, fields
from abc import ABC, abstractmethod
from typing import Union, List, Dict, Optional
from ..data_types import ChatMLSample, TextContent, AudioContent
from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
from loguru import logger
# Whisper processor: 30 s of audio -> 3000 feature frames (100 per second).
# The audio tower then downsamples by 4, from 3000 frames to 750, which gives 25 Hz.
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
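The 25 Hz constant follows from the arithmetic in the comment: the Whisper feature extractor produces 3000 frames per 30 s window (100 per second), and dividing by 4 gives 25. cal_num_tokens uses it to estimate audio-in tokens as ceil(25 * num_samples / sample_rate) per audio. A small sketch of that estimate; the constant names and the function name are hypothetical.

```python
import math

WHISPER_FRAMES_PER_SEC = 3000 // 30  # 3000 mel frames per 30 s window -> 100
AUDIO_TOWER_DOWNSAMPLE = 4           # factor stated in the comment above
HIDDEN_STATES_PER_SEC = WHISPER_FRAMES_PER_SEC // AUDIO_TOWER_DOWNSAMPLE  # 25


def estimate_audio_in_tokens(num_samples: int, sample_rate: int) -> int:
    """Per-audio estimate matching the np.ceil expression in cal_num_tokens."""
    return math.ceil(HIDDEN_STATES_PER_SEC * num_samples / sample_rate)


print(HIDDEN_STATES_PER_SEC)                     # 25
print(estimate_audio_in_tokens(40_000, 16_000))  # 2.5 s -> 63
```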
@dataclass
class ChatMLDatasetSample:
input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens.
label_ids: torch.LongTensor # Shape (seq_len,): The label ids.
audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
# Here `audio_seq_len` is the length of the concatenated audio tokens.
audio_ids_start: (
torch.LongTensor
) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
audio_waveforms_concat: (
torch.Tensor
) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
audio_waveforms_start: (
torch.LongTensor
) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms.
audio_speaker_indices: (
torch.LongTensor
) # Shape (num_audios,): The speaker index of each audio; -1 means unknown speaker.
audio_label_ids_concat: Optional[torch.LongTensor] = (
None # Shape (num_codebooks, audio_seq_len): The concatenated audio label tokens.
)
# Here `audio_seq_len` is the length of the concatenated audio tokens.
reward: Optional[float] = None
def num_audios(self):
return max(len(self.audio_waveforms_start), len(self.audio_ids_start))
def get_audio_codes(self, idx):
code_start = self.audio_ids_start[idx]
if idx < len(self.audio_ids_start) - 1:
code_end = self.audio_ids_start[idx + 1]
else:
code_end = self.audio_ids_concat.shape[-1]
return self.audio_ids_concat[:, code_start:code_end]
def get_audio_codes_labels(self, idx):
if self.audio_label_ids_concat is None:
return None
code_start = self.audio_ids_start[idx]
if idx < len(self.audio_ids_start) - 1:
code_end = self.audio_ids_start[idx + 1]
else:
code_end = self.audio_ids_concat.shape[-1]
return self.audio_label_ids_concat[:, code_start:code_end]
def get_wv(self, idx):
wv_start = self.audio_waveforms_start[idx]
sr = self.audio_sample_rate[idx]
if idx < len(self.audio_waveforms_start) - 1:
wv_end = self.audio_waveforms_start[idx + 1]
else:
wv_end = self.audio_waveforms_concat.shape[-1]
return self.audio_waveforms_concat[wv_start:wv_end], sr
def cal_num_tokens(
self,
encode_whisper_embed: bool = True,
encode_audio_in_tokens: bool = False,
encode_audio_out_tokens: bool = True,
audio_in_token_id: int = 128015,
audio_out_token_id: int = 128016,
) -> int:
# First exclude the <|AUDIO|> and <|AUDIO_OUT|> placeholders: we do late merging and replace those positions with the actual audio features and audio token ids.
# We assume audio_ids are always present when audio_waveforms are (but not vice versa).
num_tokens = len(self.input_ids) - len(self.audio_ids_start)
if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
audio_lengths = torch.diff(self.audio_waveforms_start)
if len(audio_lengths):
# Sum before calling .item()
num_tokens += (
(
np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
).sum()
).item()
# add the last audio's token estimation
num_tokens += (
np.ceil(
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
* (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
/ self.audio_sample_rate[-1]
)
).item()
if self.audio_ids_concat.size(1) > 0:
audio_io_ids = self.input_ids[
(self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
]
audio_io_id_lengths = torch.concat(
[
torch.diff(self.audio_ids_start),
torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
]
)
if encode_audio_in_tokens:
num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
if encode_audio_out_tokens:
num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
return int(num_tokens)
@classmethod
def merge(
cls,
samples: List["ChatMLDatasetSample"],
eos_token_id: int,
ignore_index: int,
padding_size: Optional[int] = None,
) -> "ChatMLDatasetSample":
"""Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.
Args:
samples (List[ChatMLDatasetSample]): List of samples to merge.
eos_token_id (int): Tokens to be inserted into input_ids between samples.
ignore_index (int): Default label for padding.
padding_size (Optional[int]): If provided, append `padding_size` EOS tokens (labeled with `ignore_index`) to the end of the merged sequence.
Returns:
ChatMLDatasetSample: Merged and potentially padded sample.
"""
if not samples:
logger.fatal("The samples list is empty and cannot be merged.")
raise ValueError("The samples list is empty and cannot be merged.")
# Initialize empty lists for concatenation
input_ids_list = []
label_ids_list = []
audio_ids_concat_list = []
audio_ids_start_list = []
audio_waveforms_concat_list = []
audio_waveforms_start_list = []
audio_sample_rate_list = []
audio_speaker_indices_list = []
# Track offsets
audio_ids_offset = 0
audio_waveforms_offset = 0
for sample in samples:
# Add input_ids and label_ids with padding
if input_ids_list:
input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
input_ids_list.append(sample.input_ids)
label_ids_list.append(sample.label_ids)
# Add audio_ids_concat and handle empty audio ids
if sample.audio_ids_concat.size(1) > 0:
audio_ids_concat_list.append(sample.audio_ids_concat)
# Offset and add audio_ids_start
audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
audio_ids_offset += sample.audio_ids_concat.size(1) # audio_ids_concat is (num_codebooks, audio_seq_len)
# Add audio_waveforms_concat
if sample.audio_waveforms_concat.size(0) > 0:
# Check dimensions of the audio waveform to ensure consistency
if (
audio_waveforms_concat_list
and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
):
logger.warning(
f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
)
continue
audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
audio_waveforms_offset += sample.audio_waveforms_concat.size(0)
# Add audio_sample_rate and audio_speaker_indices
audio_sample_rate_list.append(sample.audio_sample_rate)
audio_speaker_indices_list.append(sample.audio_speaker_indices)
# Concatenate all tensors
input_ids = torch.cat(input_ids_list, dim=0)
label_ids = torch.cat(label_ids_list, dim=0)
# Apply padding if padding_size is specified
if padding_size is not None and padding_size > 0:
input_ids = torch.cat([input_ids, torch.full((padding_size,), eos_token_id, dtype=torch.long)], dim=0)
label_ids = torch.cat([label_ids, torch.full((padding_size,), ignore_index, dtype=torch.long)], dim=0)
# Safely concatenate audio tensors with proper error handling
try:
audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])
# Check for dimensional consistency in audio waveforms
if audio_waveforms_concat_list:
dims = [t.dim() for t in audio_waveforms_concat_list]
if not all(d == dims[0] for d in dims):
# If dimensions don't match, log warning and filter out the problematic tensors
logger.warning(
f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
)
expected_dim = max(set(dims), key=dims.count) # Most common dimension
audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
# Recalculate audio_waveforms_start with the filtered list
if audio_waveforms_concat_list:
audio_waveforms_offset = 0
audio_waveforms_start_list = []
for waveform in audio_waveforms_concat_list:
audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
audio_waveforms_offset += waveform.size(0)
audio_waveforms_concat = (
torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
)
audio_waveforms_start = (
torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
)
audio_sample_rate = (
torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
)
audio_speaker_indices = (
torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
)
except RuntimeError as e:
logger.error(f"Error during tensor concatenation: {str(e)}")
logger.warning("Falling back to empty audio tensors")
# Fall back to empty tensors
audio_ids_concat = torch.tensor([[]])
audio_ids_start = torch.tensor([])
audio_waveforms_concat = torch.tensor([])
audio_waveforms_start = torch.tensor([])
audio_sample_rate = torch.tensor([])
audio_speaker_indices = torch.tensor([])
# Create the merged sample
merged_sample = cls(
input_ids=input_ids,
label_ids=label_ids,
audio_ids_concat=audio_ids_concat,
audio_ids_start=audio_ids_start,
audio_waveforms_concat=audio_waveforms_concat,
audio_waveforms_start=audio_waveforms_start,
audio_sample_rate=audio_sample_rate,
audio_speaker_indices=audio_speaker_indices,
)
return merged_sample
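merge concatenates samples with an EOS token between them and re-bases each sample's audio_ids_start by the total audio length seen so far, so offsets point into the merged audio_ids_concat. A plain-list sketch of just that bookkeeping, assuming each sample is given as (input_ids, audio_ids_start, audio_seq_len); labels, waveforms, and padding are left out, and the function name is hypothetical.

```python
from typing import List, Tuple


def merge_token_streams(
    samples: List[Tuple[List[int], List[int], int]],
    eos_token_id: int,
) -> Tuple[List[int], List[int]]:
    """Join samples with an EOS separator and shift their audio start offsets.

    Each sample is (input_ids, audio_ids_start, audio_seq_len), where
    audio_ids_start holds offsets into that sample's own audio_ids_concat.
    """
    input_ids: List[int] = []
    audio_starts: List[int] = []
    offset = 0
    for ids, starts, audio_len in samples:
        if input_ids:
            input_ids.append(eos_token_id)  # separator between samples
        input_ids.extend(ids)
        audio_starts.extend(s + offset for s in starts)  # re-base into merged concat
        offset += audio_len
    return input_ids, audio_starts


ids, starts = merge_token_streams([([1, 2], [0, 5], 9), ([3], [0], 4)], eos_token_id=0)
print(ids)     # [1, 2, 0, 3]
print(starts)  # [0, 5, 9]
```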
@dataclass
class RankedChatMLDatasetSampleTuple:
samples: List[ChatMLDatasetSample]
scores: List[float]
def max_score_sample(self) -> ChatMLDatasetSample:
idx = self.scores.index(max(self.scores))
self.samples[idx].reward = self.scores[idx]
return self.samples[idx]
def min_score_sample(self) -> ChatMLDatasetSample:
idx = self.scores.index(min(self.scores))
self.samples[idx].reward = self.scores[idx]
return self.samples[idx]
@dataclass
class ChatMLDatasetStorageSample:
input_tokens: torch.LongTensor
label_tokens: torch.LongTensor
audio_bytes_cache_dir_index: int
audio_codes_cache_dir_index: int
audio_bytes_indices: torch.LongTensor
audio_codes_indices: torch.LongTensor
speaker_indices: torch.LongTensor
file_index: int
original_sample_index: int
# TODO(sxjscience): We need to revisit the logic for parsing speaker ids.
# Currently, we assume that the speaker id is stored at the "misc" field in ChatMLSample.
def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
"""Preprocess the ChatML sample to get the tokens for the text part.
Args:
sample (Union[ChatMLSample, Dict]): The ChatML sample to preprocess, or a dict with the same fields.
tokenizer: The tokenizer to use for encoding the text.
"""
try:
if not isinstance(sample, ChatMLSample):
# Handle all fields that could be NaN
if "speaker" in sample and pd.isna(sample["speaker"]):
sample["speaker"] = None
if "start_index" in sample and pd.isna(sample["start_index"]):
sample["start_index"] = None
if "content" in sample and pd.isna(sample["content"]):
sample["content"] = ""
# Convert any other potential NaN values in nested structures
def convert_nan_to_none(obj):
if isinstance(obj, (pd.Series, np.ndarray)):
return obj.tolist()
elif pd.api.types.is_scalar(obj) and pd.isna(obj):
return None
elif isinstance(obj, dict):
return {k: convert_nan_to_none(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)): # Handle both lists and tuples
return [convert_nan_to_none(item) for item in obj]
return obj
# Clean the sample data
clean_sample = convert_nan_to_none(sample)
val_keys = []
for field in fields(ChatMLSample):
if field.name in clean_sample:
val_keys.append(field.name)
clean_sample = {k: clean_sample[k] for k in val_keys}
try:
sample = dacite.from_dict(
data_class=ChatMLSample, data=clean_sample, config=dacite.Config(strict=True, check_types=True)
)
except Exception as e:
print(f"Failed to convert to ChatMLSample: {e}")
print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
return None, None, None, None
input_tokens = []
label_tokens = []
audio_contents = []
speaker_id = None
if sample.speaker is not None:
speaker_id = sample.speaker
elif sample.misc is not None:
if "speaker" in sample.misc:
speaker_id = sample.misc["speaker"]
total_m = len(sample.messages)
for turn_id, message in enumerate(sample.messages):
role = message.role
recipient = message.recipient
content = message.content
content_l = []
if isinstance(content, str):
content_l.append(TextContent(text=content))
elif isinstance(content, TextContent):
content_l.append(content)
elif isinstance(content, AudioContent):
content_l.append(content)
elif isinstance(content, list):
for ele in content:
if isinstance(ele, str):
content_l.append(TextContent(text=ele))
else:
content_l.append(ele)
if turn_id == 0:
prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
else:
prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
eot_postfix = "<|eot_id|>"
eom_postfix = "<|eom_id|>"
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
input_tokens.extend(prefix_tokens)
label_tokens.extend([-100 for _ in prefix_tokens])
if recipient:
assert role == "assistant", "Recipient is only available for assistant role."
recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
input_tokens.extend(recipient_tokens)
label_tokens.extend(recipient_tokens)
for content in content_l:
if content.type == "text":
text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
input_tokens.extend(text_tokens)
if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
label_tokens.extend(text_tokens)
else:
label_tokens.extend([-100 for _ in text_tokens])
elif content.type == "audio":
# Generate the text-part of the audio tokens
audio_contents.append(content)
if role == "user" or role == "system":
# Add the text tokens
text_tokens = tokenizer.encode(
f"<|audio_bos|><|AUDIO|><|audio_eos|>",
add_special_tokens=False,
)
input_tokens.extend(text_tokens)
label_tokens.extend([-100 for _ in text_tokens])
elif role == "assistant":
# Add the text tokens for audio-out part.
text_tokens = tokenizer.encode(
f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
add_special_tokens=False,
)
input_tokens.extend(text_tokens)
if sample.start_index is None or turn_id >= sample.start_index:
label_tokens.extend(text_tokens)
else:
label_tokens.extend([-100 for _ in text_tokens])
next_id = turn_id + 1
if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
input_tokens.extend(postfix_tokens)
else:
postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
input_tokens.extend(postfix_tokens)
if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
label_tokens.extend(postfix_tokens)
else:
label_tokens.extend([-100 for _ in postfix_tokens])
return input_tokens, label_tokens, audio_contents, speaker_id
except Exception as e:
print(f"Error in prepare_chatml_sample: {str(e)}")
print(f"Sample data: {json.dumps(sample, indent=2)}")
return None, None, None, None
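The loop above builds `input_tokens` and `label_tokens` in lockstep. Every label is -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`) except on supervised assistant turns. The sketch below reproduces only the text path of that logic, under several simplifications. `ToyTokenizer` is a hypothetical stand-in for a Hugging Face tokenizer. Messages are plain `(role, text)` pairs. Audio placeholders and recipients are left out.

```python
import re
from typing import Dict, List, Optional, Sequence, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


class ToyTokenizer:
    """Hypothetical tokenizer: each <|...|> special token gets its own id,
    every other character maps to its Unicode code point."""

    _special = re.compile(r"<\|[^|]+\|>")

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        # add_special_tokens is accepted only for signature compatibility and ignored
        ids: List[int] = []
        pos = 0
        for match in self._special.finditer(text):
            ids.extend(ord(ch) for ch in text[pos:match.start()])
            ids.append(self.vocab.setdefault(match.group(0), 1_000_000 + len(self.vocab)))
            pos = match.end()
        ids.extend(ord(ch) for ch in text[pos:])
        return ids


def build_text_tokens(
    messages: Sequence[Tuple[str, str]], tokenizer: ToyTokenizer, start_index: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Return (input_ids, labels) for text-only (role, text) messages."""
    input_ids: List[int] = []
    labels: List[int] = []
    for turn_id, (role, text) in enumerate(messages):
        header = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
        if turn_id == 0:
            header = "<|begin_of_text|>" + header
        header_ids = tokenizer.encode(header, add_special_tokens=False)
        input_ids += header_ids
        labels += [IGNORE_INDEX] * len(header_ids)  # role headers are never supervised

        # <|eom_id|> if the assistant also sends the next message, else <|eot_id|>
        next_is_assistant = turn_id + 1 < len(messages) and messages[turn_id + 1][0] == "assistant"
        postfix = "<|eom_id|>" if role == "assistant" and next_is_assistant else "<|eot_id|>"
        # Text and postfix are encoded separately, as in prepare_chatml_sample
        body_ids = tokenizer.encode(text, add_special_tokens=False) + tokenizer.encode(postfix, add_special_tokens=False)
        input_ids += body_ids

        # Supervise assistant turns at or after start_index; mask everything else
        supervised = role == "assistant" and (start_index is None or turn_id >= start_index)
        labels += body_ids if supervised else [IGNORE_INDEX] * len(body_ids)
    return input_ids, labels
```

With a real tokenizer the loop is the same; only the ids differ. `start_index` lets earlier assistant turns act as context (for example, a voice-cloning reference) without being trained on.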
def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
    """Extract the generation prompt and reference answer from the input tokens.
    The prompt is the decoded text up to and including the last assistant header. The reference answer
    is the text between that header and the next <|eot_id|>, or the rest of the text if there is none.
    For example:
        Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
        What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
        <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
        -->
        Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
        What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
        <|start_header_id|>assistant<|end_header_id|>\n\n',
        Reference = 'At first they went by quick, too quick to even get.'
    Args:
        input_tokens: The input tokens.
        tokenizer: The tokenizer to use for decoding and re-encoding the text.
    Returns:
        prompt_tokens: The tokens for the prompt.
        reference_answer: The reference answer.
        num_audios_in_reference: The number of AUDIO_IN_TOKEN and AUDIO_OUT_TOKEN placeholders in the reference answer.
    """
    input_text = tokenizer.decode(input_tokens)
    generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    postfix = "<|eot_id|>"
    assert generation_prefix in input_text
    generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
    generation_prompt = input_text[:generation_prompt_end_loc]
    reference_end_loc = input_text.find(postfix, generation_prompt_end_loc)
    if reference_end_loc == -1:
        # No <|eot_id|> after the last assistant header: keep the rest of the text
        reference_end_loc = len(input_text)
    reference_answer = input_text[generation_prompt_end_loc:reference_end_loc]
    num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
    return tokenizer.encode(generation_prompt, add_special_tokens=False), reference_answer, num_audios_in_reference
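The same prompt/reference split can be shown on plain strings. This is a minimal sketch that skips the tokenizer decode/encode round trip. It assumes the placeholder strings `<|AUDIO|>` and `<|AUDIO_OUT|>` for `AUDIO_IN_TOKEN` and `AUDIO_OUT_TOKEN`, whose actual values are defined elsewhere in the package. `split_prompt_and_reference` is a hypothetical name.

```python
from typing import Tuple

GENERATION_PREFIX = "<|start_header_id|>assistant<|end_header_id|>\n\n"
EOT = "<|eot_id|>"
AUDIO_IN_TOKEN = "<|AUDIO|>"  # assumed value
AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"  # assumed value


def split_prompt_and_reference(text: str) -> Tuple[str, str, int]:
    """Split decoded ChatML text at the last assistant header.

    Returns (prompt, reference_answer, number_of_audio_placeholders_in_reference).
    """
    start = text.rfind(GENERATION_PREFIX)
    if start == -1:
        raise ValueError("no assistant header in text")
    prompt_end = start + len(GENERATION_PREFIX)
    stop = text.find(EOT, prompt_end)
    if stop == -1:  # unterminated answer: keep the rest of the text
        stop = len(text)
    reference = text[prompt_end:stop]
    num_audios = reference.count(AUDIO_IN_TOKEN) + reference.count(AUDIO_OUT_TOKEN)
    return text[:prompt_end], reference, num_audios
```

Searching with `rfind` means that in a multi-turn sample only the final assistant turn becomes the reference. Earlier assistant turns stay in the prompt as context.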
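def prepare_chatml_dataframe_single_process(df, tokenizer):
    """Run prepare_chatml_sample on every row of the DataFrame and return the list of result tuples."""
    ret = []
    for _, row in df.iterrows():
        input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
        ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
    return ret
def prepare_chatml_dataframe(df, tokenizer, num_process=16):
    """Prepare the ChatML DataFrame, optionally in parallel.
    If num_process is None, all rows are processed in the current process. Otherwise the DataFrame is
    split into contiguous chunks handled by a pool of at most num_process workers (about one per 1000
    rows, and at least one). Results keep the original row order.
    """
    if num_process is None:
        return prepare_chatml_dataframe_single_process(df, tokenizer)
    else:
        num_process = max(min(len(df) // 1000, num_process), 1)
        workloads = np.array_split(df, num_process)
        with mp.Pool(num_process) as pool:
            ret = pool.starmap(
                prepare_chatml_dataframe_single_process, [(workload, tokenizer) for workload in workloads]
            )
        # Flatten the per-chunk lists in linear time (sum(ret, []) is quadratic)
        return [item for chunk in ret for item in chunk]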
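The worker-count rule and contiguous split used by `prepare_chatml_dataframe` can be isolated from pandas. This is a minimal sketch on a plain list. `process_chunk` is a hypothetical stand-in for `prepare_chatml_dataframe_single_process`. `split_contiguous` mimics how `np.array_split` sizes its chunks.

```python
from itertools import chain
from multiprocessing import Pool
from typing import Any, List, Sequence


def plan_num_workers(n_rows: int, max_workers: int = 16, rows_per_worker: int = 1000) -> int:
    """Roughly one worker per rows_per_worker rows, clamped to [1, max_workers]."""
    return max(min(n_rows // rows_per_worker, max_workers), 1)


def split_contiguous(items: Sequence[Any], k: int) -> List[Sequence[Any]]:
    """Split into k contiguous chunks; like np.array_split, the first len % k chunks get one extra item."""
    q, r = divmod(len(items), k)
    chunks, start = [], 0
    for i in range(k):
        end = start + q + (1 if i < r else 0)
        chunks.append(items[start:end])
        start = end
    return chunks


def process_chunk(chunk: Sequence[int]) -> List[int]:
    """Hypothetical per-row work; must be a top-level function so Pool can pickle it."""
    return [x * 2 for x in chunk]


def process_all(items: Sequence[int], max_workers: int = 16) -> List[int]:
    k = plan_num_workers(len(items), max_workers)
    chunks = split_contiguous(items, k)
    if k == 1:
        results = [process_chunk(chunks[0])]  # small inputs: skip the pool entirely
    else:
        with Pool(k) as pool:
            results = pool.map(process_chunk, chunks)  # map returns results in chunk order
    # chain.from_iterable flattens in linear time
    return list(chain.from_iterable(results))
```

Unlike the function above, which creates a pool whenever `num_process` is not None, this sketch runs inline when only one worker is planned. On platforms that start workers with `spawn`, call `process_all` from under an `if __name__ == "__main__":` guard.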
class DatasetInterface(ABC):
    @abstractmethod
    def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Retrieve a dataset sample by index."""
        raise NotImplementedError
class IterableDatasetInterface(ABC):
    @abstractmethod
    def __iter__(self) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Retrieve a sample by iterating through the dataset."""
        raise NotImplementedError
@dataclass
class DatasetInfo:
    dataset_type: str
    group_type: Optional[str] = None
    mask_text: Optional[bool] = None  # Whether to mask the text tokens for pretraining samples.
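`DatasetInterface` only fixes the indexing contract. Because `__getitem__` is decorated with `abstractmethod`, Python refuses to instantiate any subclass that does not override it. Below is a minimal sketch with a hypothetical list-backed subclass, `InMemoryDataset`. The interface is re-declared so the snippet runs on its own.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Sequence


class DatasetInterface(ABC):
    """Minimal re-declaration of the interface above, for a self-contained example."""

    @abstractmethod
    def __getitem__(self, idx) -> Any:
        raise NotImplementedError


class InMemoryDataset(DatasetInterface):
    """Hypothetical concrete dataset backed by a Python list."""

    def __init__(self, samples: Sequence[Any]) -> None:
        self.samples: List[Any] = list(samples)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx) -> Any:
        return self.samples[idx]
```

Because `__getitem__` raises `IndexError` past the end, Python's sequence iteration protocol also makes `list(InMemoryDataset([...]))` work without defining `__iter__`.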
================================================
FILE: boson_multimodal/model/__init__.py
================================================
================================================
FILE: boson_multimodal/model/higgs_audio/__init__.py
================================================
from transformers import AutoConfig, AutoModel
from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
from .modeling_higgs_audio import HiggsAudioModel
AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
AutoConfig.register("higgs_audio", HiggsAudioConfig)
AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
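These `register` calls make the `transformers` auto classes aware of the `higgs_audio_encoder` and `higgs_audio` model types and of which model class belongs to `HiggsAudioConfig`. The pure-Python sketch below illustrates only the underlying lookup-table idea: a map from a `model_type` string to a config class. `CONFIG_REGISTRY`, `register_config`, `config_for`, and `ToyAudioConfig` are hypothetical names. This is not the `transformers` implementation.

```python
from dataclasses import dataclass
from typing import Dict

CONFIG_REGISTRY: Dict[str, type] = {}


def register_config(model_type: str, config_cls: type) -> None:
    """Map a model_type string to a config class (hypothetical helper)."""
    if getattr(config_cls, "model_type", None) != model_type:
        raise ValueError(f"{config_cls.__name__}.model_type does not match {model_type!r}")
    if model_type in CONFIG_REGISTRY:
        raise ValueError(f"{model_type!r} is already registered")
    CONFIG_REGISTRY[model_type] = config_cls


def config_for(model_type: str, **kwargs):
    """Instantiate the config class registered under model_type."""
    return CONFIG_REGISTRY[model_type](**kwargs)


@dataclass
class ToyAudioConfig:
    model_type = "toy_audio"  # plain class attribute (no annotation), so not a dataclass field
    hidden_size: int = 16


register_config("toy_audio", ToyAudioConfig)
```

Keeping `model_type` on the class lets a saved config name its own type, which is what makes lookup by string possible when a checkpoint is loaded.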
================================================
FILE: boson_multimodal/model/higgs_audio/audio_head.py
================================================
"""Projector that maps hidden states from the LLM component to multimodal logits."""
import torch
from torch import nn
from dataclasses import dataclass
from typing import Optional, Tuple
from .common import HiggsAudioPreTrainedModel
from .configuration_higgs_audio import HiggsAudioConfig
@dataclass
class HiggsAudioDecoderLayerOutput:
    logits: torch.FloatTensor
    audio_logits: torch.FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    past_key_values: Opti
SYMBOL INDEX (374 symbols across 28 files)
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py
class DACFile (line 16) | class DACFile:
method save (line 28) | def save(self, path):
method load (line 47) | def load(cls, path):
class CodecMixin (line 55) | class CodecMixin:
method padding (line 57) | def padding(self):
method padding (line 63) | def padding(self, value):
method get_delay (line 78) | def get_delay(self):
method get_output_length (line 104) | def get_output_length(self, input_length):
method compress (line 122) | def compress(
method decompress (line 230) | def decompress(
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py
function init_weights (line 18) | def init_weights(m):
class ResidualUnit (line 24) | class ResidualUnit(nn.Module):
method __init__ (line 25) | def __init__(self, dim: int = 16, dilation: int = 1):
method forward (line 35) | def forward(self, x):
class EncoderBlock (line 43) | class EncoderBlock(nn.Module):
method __init__ (line 44) | def __init__(self, dim: int = 16, stride: int = 1):
method forward (line 60) | def forward(self, x):
class Encoder (line 64) | class Encoder(nn.Module):
method __init__ (line 65) | def __init__(
method forward (line 90) | def forward(self, x):
class DecoderBlock (line 94) | class DecoderBlock(nn.Module):
method __init__ (line 95) | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: i...
method forward (line 112) | def forward(self, x):
class Decoder (line 116) | class Decoder(nn.Module):
method __init__ (line 117) | def __init__(
method forward (line 148) | def forward(self, x):
class DAC (line 152) | class DAC(BaseModel, CodecMixin):
method __init__ (line 153) | def __init__(
method preprocess (line 203) | def preprocess(self, audio_data, sample_rate):
method encode (line 214) | def encode(
method decode (line 252) | def decode(self, z: torch.Tensor):
method forward (line 271) | def forward(
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py
function WNConv1d (line 9) | def WNConv1d(*args, **kwargs):
function WNConvTranspose1d (line 13) | def WNConvTranspose1d(*args, **kwargs):
function snake (line 19) | def snake(x, alpha):
class Snake1d (line 27) | class Snake1d(nn.Module):
method __init__ (line 28) | def __init__(self, channels):
method forward (line 32) | def forward(self, x):
FILE: boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py
class VectorQuantize (line 13) | class VectorQuantize(nn.Module):
method __init__ (line 25) | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: i...
method forward (line 34) | def forward(self, z):
method embed_code (line 70) | def embed_code(self, embed_id):
method decode_code (line 73) | def decode_code(self, embed_id):
method decode_latents (line 76) | def decode_latents(self, latents):
class ResidualVectorQuantize (line 95) | class ResidualVectorQuantize(nn.Module):
method __init__ (line 101) | def __init__(
method forward (line 122) | def forward(self, z, n_quantizers: int = None):
method from_codes (line 191) | def from_codes(self, codes: torch.Tensor):
method from_latents (line 213) | def from_latents(self, latents: torch.Tensor):
FILE: boson_multimodal/audio_processing/higgs_audio_tokenizer.py
class EncodedResult (line 24) | class EncodedResult:
method __init__ (line 25) | def __init__(self, audio_codes):
class HiggsAudioFeatureExtractor (line 29) | class HiggsAudioFeatureExtractor(nn.Module):
method __init__ (line 30) | def __init__(self, sampling_rate=16000):
method forward (line 34) | def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
class HiggsAudioTokenizer (line 43) | class HiggsAudioTokenizer(nn.Module):
method __init__ (line 44) | def __init__(
method tps (line 138) | def tps(self):
method sampling_rate (line 142) | def sampling_rate(self):
method num_codebooks (line 146) | def num_codebooks(self):
method codebook_size (line 150) | def codebook_size(self):
method get_last_layer (line 153) | def get_last_layer(self):
method calculate_rec_loss (line 156) | def calculate_rec_loss(self, rec, target):
method get_regress_target (line 164) | def get_regress_target(self, x):
method forward (line 209) | def forward(self, x: torch.Tensor, bw: int):
method encode (line 237) | def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, ...
method _xcodec_encode (line 263) | def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = N...
method decode (line 296) | def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
function load_higgs_audio_tokenizer (line 312) | def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
FILE: boson_multimodal/audio_processing/quantization/ac.py
function build_stable_quantized_cdf (line 18) | def build_stable_quantized_cdf(
class ArithmeticCoder (line 56) | class ArithmeticCoder:
method __init__ (line 96) | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
method delta (line 107) | def delta(self) -> int:
method _flush_common_prefix (line 111) | def _flush_common_prefix(self):
method push (line 130) | def push(self, symbol: int, quantized_cdf: torch.Tensor):
method flush (line 160) | def flush(self):
class ArithmeticDecoder (line 169) | class ArithmeticDecoder:
method __init__ (line 185) | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
method delta (line 198) | def delta(self) -> int:
method _flush_common_prefix (line 201) | def _flush_common_prefix(self):
method pull (line 217) | def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
function test (line 263) | def test():
FILE: boson_multimodal/audio_processing/quantization/core_vq.py
function default (line 44) | def default(val: tp.Any, d: tp.Any) -> tp.Any:
function ema_inplace (line 48) | def ema_inplace(moving_avg, new, decay: float):
function laplace_smoothing (line 52) | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
function uniform_init (line 56) | def uniform_init(*shape: int):
function sample_vectors (line 62) | def sample_vectors(samples, num: int):
function kmeans (line 73) | def kmeans(samples, num_clusters: int, num_iters: int = 10):
class EuclideanCodebook (line 96) | class EuclideanCodebook(nn.Module):
method __init__ (line 112) | def __init__(
method init_embed_ (line 139) | def init_embed_(self, data):
method replace_ (line 151) | def replace_(self, samples, mask):
method expire_codes_ (line 155) | def expire_codes_(self, batch_samples):
method preprocess (line 167) | def preprocess(self, x):
method quantize (line 171) | def quantize(self, x):
method postprocess_emb (line 177) | def postprocess_emb(self, embed_ind, shape):
method dequantize (line 180) | def dequantize(self, embed_ind):
method encode (line 184) | def encode(self, x):
method decode (line 194) | def decode(self, embed_ind):
method forward (line 198) | def forward(self, x):
class VectorQuantization (line 225) | class VectorQuantization(nn.Module):
method __init__ (line 242) | def __init__(
method codebook (line 276) | def codebook(self):
method encode (line 279) | def encode(self, x):
method decode (line 285) | def decode(self, embed_ind):
method forward (line 291) | def forward(self, x):
class ResidualVectorQuantization (line 313) | class ResidualVectorQuantization(nn.Module):
method __init__ (line 318) | def __init__(self, *, num_quantizers, **kwargs):
method forward (line 322) | def forward(self, x, n_q: tp.Optional[int] = None):
method encode (line 342) | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> tor...
method decode (line 354) | def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
FILE: boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py
function default (line 54) | def default(val: tp.Any, d: tp.Any) -> tp.Any:
function ema_inplace (line 58) | def ema_inplace(moving_avg, new, decay: float):
function laplace_smoothing (line 62) | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
function uniform_init (line 66) | def uniform_init(*shape: int):
function sample_vectors (line 72) | def sample_vectors(samples, num: int):
function kmeans (line 83) | def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_us...
class EuclideanCodebook (line 132) | class EuclideanCodebook(nn.Module):
method __init__ (line 148) | def __init__(
method init_embed_ (line 179) | def init_embed_(self, data):
method replace_ (line 200) | def replace_(self, samples, mask):
method expire_codes_ (line 204) | def expire_codes_(self, batch_samples):
method preprocess (line 221) | def preprocess(self, x):
method quantize (line 225) | def quantize(self, x):
method postprocess_emb (line 231) | def postprocess_emb(self, embed_ind, shape):
method dequantize (line 234) | def dequantize(self, embed_ind):
method encode (line 238) | def encode(self, x):
method decode (line 248) | def decode(self, embed_ind):
method forward (line 252) | def forward(self, x):
class VectorQuantization (line 290) | class VectorQuantization(nn.Module):
method __init__ (line 307) | def __init__(
method codebook (line 341) | def codebook(self):
method encode (line 344) | def encode(self, x):
method decode (line 350) | def decode(self, embed_ind):
method forward (line 356) | def forward(self, x):
class ResidualVectorQuantization (line 378) | class ResidualVectorQuantization(nn.Module):
method __init__ (line 383) | def __init__(self, *, num_quantizers, **kwargs):
method forward (line 387) | def forward(self, x, n_q: tp.Optional[int] = None):
method encode (line 407) | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> tor...
method decode (line 419) | def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
FILE: boson_multimodal/audio_processing/quantization/ddp_utils.py
function set_random_seed (line 17) | def set_random_seed(seed):
function is_logging_process (line 24) | def is_logging_process():
function get_logger (line 28) | def get_logger(cfg, name=None):
class SyncFunction (line 36) | class SyncFunction(torch.autograd.Function):
method forward (line 39) | def forward(ctx, tensor):
method backward (line 50) | def backward(ctx, grad_output):
function get_timestamp (line 59) | def get_timestamp():
function get_commit_hash (line 63) | def get_commit_hash():
class DDP (line 68) | class DDP(DistributedDataParallel):
method forward (line 73) | def forward(self, *inputs, **kwargs): # pragma: no cover
FILE: boson_multimodal/audio_processing/quantization/distrib.py
function rank (line 14) | def rank():
function world_size (line 21) | def world_size():
function is_distributed (line 28) | def is_distributed():
function all_reduce (line 32) | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
function _is_complex_or_float (line 37) | def _is_complex_or_float(tensor):
function _check_number_of_params (line 41) | def _check_number_of_params(params: tp.List[torch.Tensor]):
function broadcast_tensors (line 57) | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
function sync_buffer (line 73) | def sync_buffer(buffers, average=True):
function sync_grad (line 93) | def sync_grad(params):
function average_metrics (line 111) | def average_metrics(metrics: tp.Dict[str, float], count=1.0):
FILE: boson_multimodal/audio_processing/quantization/vq.py
class QuantizedResult (line 21) | class QuantizedResult:
class ResidualVectorQuantizer (line 29) | class ResidualVectorQuantizer(nn.Module):
method __init__ (line 43) | def __init__(
method forward (line 74) | def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Opt...
method get_num_quantizers_for_bandwidth (line 92) | def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth...
method get_bandwidth_per_quantizer (line 100) | def get_bandwidth_per_quantizer(self, sample_rate: int):
method encode (line 104) | def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Opti...
method decode (line 113) | def decode(self, codes: torch.Tensor) -> torch.Tensor:
FILE: boson_multimodal/audio_processing/semantic_module.py
class Conv1d1x1 (line 9) | class Conv1d1x1(nn.Conv1d):
method __init__ (line 12) | def __init__(self, in_channels, out_channels, bias=True):
class Conv1d (line 16) | class Conv1d(nn.Module):
method __init__ (line 17) | def __init__(
method forward (line 46) | def forward(self, x):
class ResidualUnit (line 57) | class ResidualUnit(nn.Module):
method __init__ (line 58) | def __init__(
method forward (line 80) | def forward(self, x):
class ConvTranspose1d (line 86) | class ConvTranspose1d(nn.Module):
method __init__ (line 87) | def __init__(
method forward (line 114) | def forward(self, x):
class EncoderBlock (line 125) | class EncoderBlock(nn.Module):
method __init__ (line 126) | def __init__(
method forward (line 143) | def forward(self, x):
class Encoder (line 150) | class Encoder(nn.Module):
method __init__ (line 151) | def __init__(
method forward (line 186) | def forward(self, x):
class DecoderBlock (line 193) | class DecoderBlock(nn.Module):
method __init__ (line 196) | def __init__(
method forward (line 225) | def forward(self, x):
class Decoder (line 232) | class Decoder(nn.Module):
method __init__ (line 233) | def __init__(
method forward (line 277) | def forward(self, z):
FILE: boson_multimodal/data_collator/higgs_audio_collator.py
function _ceil_to_nearest (line 15) | def _ceil_to_nearest(n, round_to):
function _ceil_to_next_power_of_two (line 19) | def _ceil_to_next_power_of_two(self, x):
class HiggsAudioBatchInput (line 24) | class HiggsAudioBatchInput:
class HiggsAudioSampleCollator (line 47) | class HiggsAudioSampleCollator:
method __init__ (line 68) | def __init__(
method _process_and_duplicate_audio_tokens (line 109) | def _process_and_duplicate_audio_tokens(
method __call__ (line 151) | def __call__(self, batch: List[ChatMLDatasetSample]):
FILE: boson_multimodal/data_types.py
class AudioContent (line 8) | class AudioContent:
class TextContent (line 19) | class TextContent:
class Message (line 25) | class Message:
class ChatMLSample (line 32) | class ChatMLSample:
FILE: boson_multimodal/dataset/chatml_dataset.py
class ChatMLDatasetSample (line 24) | class ChatMLDatasetSample:
method num_audios (line 48) | def num_audios(self):
method get_audio_codes (line 51) | def get_audio_codes(self, idx):
method get_audio_codes_labels (line 60) | def get_audio_codes_labels(self, idx):
method get_wv (line 71) | def get_wv(self, idx):
method cal_num_tokens (line 80) | def cal_num_tokens(
method merge (line 129) | def merge(
class RankedChatMLDatasetSampleTuple (line 277) | class RankedChatMLDatasetSampleTuple:
method max_score_sample (line 281) | def max_score_sample(self) -> ChatMLDatasetSample:
method min_score_sample (line 286) | def min_score_sample(self) -> ChatMLDatasetSample:
class ChatMLDatasetStorageSample (line 293) | class ChatMLDatasetStorageSample:
function prepare_chatml_sample (line 307) | def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
function extract_generation_prompt_from_input_tokens (line 455) | def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
function prepare_chatml_dataframe_single_process (line 493) | def prepare_chatml_dataframe_single_process(df, tokenizer):
function prepare_chatml_dataframe (line 502) | def prepare_chatml_dataframe(df, tokenizer, num_process=16):
class DatasetInterface (line 515) | class DatasetInterface(ABC):
method __getitem__ (line 517) | def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChat...
class IterableDatasetInterface (line 522) | class IterableDatasetInterface(ABC):
method __iter__ (line 524) | def __iter__(self) -> Union["ChatMLDatasetSample", "RankedChatMLDatase...
class DatasetInfo (line 530) | class DatasetInfo:
FILE: boson_multimodal/model/higgs_audio/audio_head.py
class HiggsAudioDecoderLayerOutput (line 14) | class HiggsAudioDecoderLayerOutput:
class HiggsAudioDecoderProjector (line 21) | class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
method __init__ (line 29) | def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] ...
method forward (line 39) | def forward(
FILE: boson_multimodal/model/higgs_audio/common.py
class HiggsAudioPreTrainedModel (line 8) | class HiggsAudioPreTrainedModel(PreTrainedModel):
method _init_weights (line 17) | def _init_weights(self, module):
FILE: boson_multimodal/model/higgs_audio/configuration_higgs_audio.py
class HiggsAudioEncoderConfig (line 5) | class HiggsAudioEncoderConfig(PretrainedConfig):
method __init__ (line 10) | def __init__(
class HiggsAudioConfig (line 47) | class HiggsAudioConfig(PretrainedConfig):
method __init__ (line 118) | def __init__(
FILE: boson_multimodal/model/higgs_audio/cuda_graph_runner.py
class CUDAGraphRunner (line 12) | class CUDAGraphRunner(nn.Module):
method __init__ (line 13) | def __init__(self, model):
method graph (line 23) | def graph(self):
method capture (line 27) | def capture(
method forward (line 106) | def forward(
FILE: boson_multimodal/model/higgs_audio/custom_modules.py
class PartiallyFrozenEmbedding (line 5) | class PartiallyFrozenEmbedding(nn.Module):
method __init__ (line 14) | def __init__(self, original_embedding: nn.Embedding, freeze_until_idx:...
method forward (line 46) | def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
method to_unsplit (line 81) | def to_unsplit(self) -> nn.Embedding:
class PartiallyFrozenLinear (line 96) | class PartiallyFrozenLinear(nn.Module):
method __init__ (line 99) | def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
method forward (line 135) | def forward(self, input_tensor):
method to_unsplit (line 141) | def to_unsplit(self) -> nn.Linear:
FILE: boson_multimodal/model/higgs_audio/modeling_higgs_audio.py
class GenerationMode (line 45) | class GenerationMode(Enum):
function _whisper_encoder_zero_shape_forward (line 53) | def _whisper_encoder_zero_shape_forward(whisper_encoder, *args, **kwargs):
function _prepare_4d_causal_attention_mask_with_cache_position (line 110) | def _prepare_4d_causal_attention_mask_with_cache_position(
class HiggsAudioFeatureProjector (line 163) | class HiggsAudioFeatureProjector(nn.Module):
method __init__ (line 166) | def __init__(self, config: HiggsAudioConfig):
method forward (line 170) | def forward(self, audio_features):
class HiggsAudioEncoder (line 177) | class HiggsAudioEncoder(HiggsAudioPreTrainedModel):
method __init__ (line 191) | def __init__(self, config: HiggsAudioEncoderConfig):
method _freeze_parameters (line 218) | def _freeze_parameters(self):
method get_input_embeddings (line 223) | def get_input_embeddings(self) -> nn.Module:
method set_input_embeddings (line 226) | def set_input_embeddings(self, value: nn.Module):
method forward (line 229) | def forward(
method _get_feat_extract_output_lengths (line 353) | def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTe...
class HiggsAudioDualFFNDecoderLayer (line 362) | class HiggsAudioDualFFNDecoderLayer(nn.Module):
method __init__ (line 397) | def __init__(
method forward (line 430) | def forward(
class HiggsAudioModelOutputWithPast (line 732) | class HiggsAudioModelOutputWithPast(ModelOutput):
class HiggsAudioGenerationOutput (line 752) | class HiggsAudioGenerationOutput(ModelOutput):
class HiggsAudioModel (line 794) | class HiggsAudioModel(HiggsAudioPreTrainedModel, GenerationMixin):
method __init__ (line 815) | def __init__(self, config: HiggsAudioConfig):
method set_num_activation_checkpointing_layers (line 907) | def set_num_activation_checkpointing_layers(self, num_layers):
method set_delay_pattern (line 910) | def set_delay_pattern(self):
method set_audio_special_tokens (line 914) | def set_audio_special_tokens(self, tokenizer: AutoTokenizer):
method _embed_audio_ids (line 918) | def _embed_audio_ids(self, audio_ids):
method _apply_audio_tower (line 939) | def _apply_audio_tower(self, audio_features, audio_feature_attention_m...
method _update_causal_mask (line 985) | def _update_causal_mask(
method _prepare_all_static_kv_cache_masks (line 1051) | def _prepare_all_static_kv_cache_masks(self, hidden_states, attention_...
method _forward_core (line 1078) | def _forward_core(
method forward (line 1142) | def forward(
method _update_model_kwargs_for_generation (line 1420) | def _update_model_kwargs_for_generation(
method _copy_kv_cache (line 1454) | def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
method _prepare_kv_cache (line 1467) | def _prepare_kv_cache(
method _sample_audio_tokens (line 1490) | def _sample_audio_tokens(
method _sample_text_tokens (line 1577) | def _sample_text_tokens(
method _sample (line 1624) | def _sample(
method generate (line 1933) | def generate(
method parameter_count_per_component (line 2025) | def parameter_count_per_component(self):
method set_skip_audio_tower (line 2126) | def set_skip_audio_tower(self):
method set_encode_audio_in_tokens (line 2130) | def set_encode_audio_in_tokens(self):
method freeze_audio_tower (line 2133) | def freeze_audio_tower(self):
method freeze_audio_encoder_proj (line 2138) | def freeze_audio_encoder_proj(self):
method freeze_llm (line 2143) | def freeze_llm(self, freeze_embed=True, freeze_embed_until_idx: Option...
method freeze_text_head (line 2173) | def freeze_text_head(self, freeze_text_head_until_idx: Optional[int] =...
method merge_weights_from_checkpoint (line 2186) | def merge_weights_from_checkpoint(cls, checkpoint_dir: str, merged_out...
method capture_model (line 2242) | def capture_model(self, past_key_values: list[Union[Cache, List[torch....
FILE: boson_multimodal/model/higgs_audio/utils.py
function _ceil_to_nearest (line 15) | def _ceil_to_nearest(n, round_to):
function count_parameters (line 19) | def count_parameters(model, trainable_only=True):
function build_delay_pattern_mask (line 26) | def build_delay_pattern_mask(
function revert_delay_pattern (line 91) | def revert_delay_pattern(data):
function merge_input_ids_with_audio_features (line 110) | def merge_input_ids_with_audio_features(
function is_deepspeed_ulysses_enabled (line 436) | def is_deepspeed_ulysses_enabled():
function support_deepspeed_ulysses (line 444) | def support_deepspeed_ulysses(module):
function deepspeed_ulysses_attention (line 479) | def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
function deepspeed_ulysses_rope (line 510) | def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
function _gather_tensors (line 530) | def _gather_tensors(input_, group=None):
function _scatter_tensors (line 548) | def _scatter_tensors(input_, group=None):
class _GatherTensors (line 557) | class _GatherTensors(torch.autograd.Function):
method symbolic (line 561) | def symbolic(graph, input_, group):
method forward (line 565) | def forward(ctx, input_, group):
method backward (line 570) | def backward(ctx, grad_output):
function all_gather_tensors (line 574) | def all_gather_tensors(input_, size=None, dim=0, group=None):
function get_sequence_data_parallel_world_size (line 591) | def get_sequence_data_parallel_world_size():
function get_sequence_data_parallel_rank (line 595) | def get_sequence_data_parallel_rank():
function get_sequence_data_parallel_group (line 599) | def get_sequence_data_parallel_group():
function _gather_tokens (line 609) | def _gather_tokens(input_, dim=0, group=None):
function _drop_tokens (line 632) | def _drop_tokens(input_, dim=0, group=None):
class _DropTokens (line 646) | class _DropTokens(torch.autograd.Function):
method symbolic (line 650) | def symbolic(graph, input_, dim, group, grad_scale):
method forward (line 654) | def forward(ctx, input_, dim, group, grad_scale):
method backward (line 661) | def backward(ctx, grad_output):
class _GatherTokens (line 668) | class _GatherTokens(torch.autograd.Function):
method symbolic (line 672) | def symbolic(graph, input_, dim, group, grad_scale):
method forward (line 676) | def forward(ctx, input_, dim, group, grad_scale):
method backward (line 683) | def backward(ctx, grad_output):
function drop_tokens (line 690) | def drop_tokens(input_, dim=0, group=None, grad_scale=1):
function gather_tokens (line 697) | def gather_tokens(input_, dim=0, group=None, grad_scale=1):
function sequence_chunking_per_rank (line 704) | def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
function disable_deepspeed_ulysses (line 740) | def disable_deepspeed_ulysses():
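The index lists `build_delay_pattern_mask` and `revert_delay_pattern` in `utils.py` but does not show their bodies. The pure-Python sketch below illustrates the general delay-pattern idea used by multi-codebook audio models, where codebook k is shifted right by k frames. It is an illustration only. The names `apply_delay` and `revert_delay`, the list-of-lists layout, and the pad value are assumptions, not the repository's tensor-based signatures.

```python
from typing import List

Codes = List[List[int]]  # one row per codebook, one column per frame


def apply_delay(codes: Codes, pad: int) -> Codes:
    """Shift codebook k right by k frames so every row has T + K - 1 entries.

    Hypothetical helper illustrating the delay pattern; not the repository's API.
    """
    num_codebooks = len(codes)
    delayed = []
    for k, row in enumerate(codes):
        # k leading pads delay this codebook; trailing pads equalize row lengths.
        delayed.append([pad] * k + list(row) + [pad] * (num_codebooks - 1 - k))
    return delayed


def revert_delay(delayed: Codes) -> Codes:
    """Undo apply_delay by reading T frames from row k, starting at column k."""
    num_codebooks = len(delayed)
    num_frames = len(delayed[0]) - num_codebooks + 1
    return [row[k : k + num_frames] for k, row in enumerate(delayed)]
```

The usual motivation for this layout is that each decoding step emits frame t of codebook 0 together with frame t-1 of codebook 1, and so on. Coarser codebooks for a frame are therefore already available when finer ones are predicted. Reverting is a per-row slice, and `revert_delay(apply_delay(codes, pad))` returns `codes` unchanged.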
FILE: boson_multimodal/serve/serve_engine.py
class HiggsAudioStreamerDelta (line 27) | class HiggsAudioStreamerDelta:
class AsyncHiggsAudioStreamer (line 36) | class AsyncHiggsAudioStreamer(BaseStreamer):
method __init__ (line 76) | def __init__(
method put (line 100) | def put(self, value: torch.Tensor):
method end (line 128) | def end(self):
method __aiter__ (line 134) | def __aiter__(self):
method __anext__ (line 137) | async def __anext__(self):
class AsyncStoppingCriteria (line 153) | class AsyncStoppingCriteria(StoppingCriteria):
method __init__ (line 161) | def __init__(self, stop_signal: threading.Event):
method __call__ (line 164) | def __call__(self, input_ids, scores, **kwargs) -> bool:
class HiggsAudioResponse (line 172) | class HiggsAudioResponse:
class HiggsAudioServeEngine (line 181) | class HiggsAudioServeEngine:
method __init__ (line 182) | def __init__(
method _prepare_inputs (line 280) | def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_ge...
method _prepare_kv_caches (line 337) | def _prepare_kv_caches(self):
method generate (line 341) | def generate(
method generate_delta_stream (line 428) | async def generate_delta_stream(
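`AsyncHiggsAudioStreamer` exposes `put`, `end`, `__aiter__`, and `__anext__`, and `AsyncStoppingCriteria` takes a `threading.Event`. Together they suggest a bridge between a blocking generation thread and an async consumer. The standard-library sketch below shows one common way to build such a bridge. It is a sketch under assumptions: the class name `QueueStreamer`, the sentinel object, the `stop_requested` helper, and the use of `loop.call_soon_threadsafe` are all hypothetical. The repository's classes derive from `BaseStreamer` and `StoppingCriteria`; this sketch does not.

```python
import asyncio
import threading
from typing import Any, List


class QueueStreamer:
    """Hypothetical thread-to-asyncio streamer (not the repository's class)."""

    _END = object()  # sentinel marking the end of the stream

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self.loop = loop
        self.queue: "asyncio.Queue[Any]" = asyncio.Queue()

    def put(self, value: Any) -> None:
        # Called from the generation thread: schedule the enqueue on the event loop.
        self.loop.call_soon_threadsafe(self.queue.put_nowait, value)

    def end(self) -> None:
        # Enqueue the sentinel so the async iterator knows to stop.
        self.loop.call_soon_threadsafe(self.queue.put_nowait, self._END)

    def __aiter__(self) -> "QueueStreamer":
        return self

    async def __anext__(self) -> Any:
        item = await self.queue.get()
        if item is self._END:
            raise StopAsyncIteration
        return item


def stop_requested(stop_signal: threading.Event) -> bool:
    """Stopping check in the spirit of AsyncStoppingCriteria: stop once the event is set."""
    return stop_signal.is_set()


async def stream_demo(n: int) -> List[int]:
    streamer = QueueStreamer(asyncio.get_running_loop())
    stop = threading.Event()

    def generate() -> None:
        # Stand-in for a blocking generate() call that emits n tokens.
        for token in range(n):
            if stop_requested(stop):
                break
            streamer.put(token)
        streamer.end()

    worker = threading.Thread(target=generate)
    worker.start()
    received = [token async for token in streamer]
    worker.join()  # the worker has already called end(), so this returns promptly
    return received
```

`call_soon_threadsafe` runs callbacks in the order they were scheduled, so tokens arrive in the order the worker produced them. Calling `stop.set()` from the consumer side would end generation early.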
FILE: boson_multimodal/serve/utils.py
function random_uuid (line 15) | def random_uuid() -> str:
function async_generator_wrap (line 19) | async def async_generator_wrap(first_element, gen: AsyncGenerator):
function encode_base64_content_from_file (line 27) | def encode_base64_content_from_file(file_path: str) -> str:
function pcm16_to_target_format (line 35) | def pcm16_to_target_format(
function contains_chinese (line 63) | def contains_chinese(text: str):
function replace_blank (line 68) | def replace_blank(text: str):
function replace_corner_mark (line 79) | def replace_corner_mark(text: str):
function remove_bracket (line 86) | def remove_bracket(text: str):
function split_paragraph (line 98) | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, toke...
function is_only_punctuation (line 153) | def is_only_punctuation(text: str):
function spell_out_number (line 160) | def spell_out_number(text: str, inflect_parser):
function remove_emoji (line 179) | def remove_emoji(text: str):
function remove_repeated_punctuations (line 197) | def remove_repeated_punctuations(text, punctuations):
function full_to_half_width (line 204) | def full_to_half_width(text: str) -> str:
function split_interleaved_delayed_audios (line 212) | def split_interleaved_delayed_audios(
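`serve/utils.py` includes several text-normalization helpers whose bodies are not shown. The sketch below gives plausible stand-ins for three of them, using well-known techniques. `contains_chinese` is assumed to test for the CJK Unified Ideographs block (U+4E00 to U+9FFF). `full_to_half_width` maps the full-width ASCII variants U+FF01 to U+FF5E down by 0xFEE0 and maps the ideographic space U+3000 to a plain space. `remove_repeated_punctuations` collapses runs of the same mark. These behaviors are assumptions about the repository's functions, not copies of them.

```python
import re

# Assumed definition of "Chinese": the CJK Unified Ideographs basic block.
_CJK_PATTERN = re.compile(r"[\u4e00-\u9fff]")


def contains_chinese(text: str) -> bool:
    return bool(_CJK_PATTERN.search(text))


def full_to_half_width(text: str) -> str:
    out = []
    for ch in text:
        code = ord(ch)
        if code == 0x3000:  # ideographic space -> ASCII space
            code = 0x20
        elif 0xFF01 <= code <= 0xFF5E:  # full-width ASCII variants -> ASCII
            code -= 0xFEE0
        out.append(chr(code))
    return "".join(out)


def remove_repeated_punctuations(text: str, punctuations: str) -> str:
    """Collapse runs of the same punctuation mark, e.g. '!!!' -> '!' (assumed behavior)."""
    for p in punctuations:
        # Escape the mark for the pattern; a function replacement avoids escape issues.
        text = re.sub(re.escape(p) + "{2,}", lambda m: m.group(0)[0], text)
    return text
```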
FILE: examples/generation.py
function normalize_chinese_punctuation (line 44) | def normalize_chinese_punctuation(text):
function prepare_chunk_text (line 83) | def prepare_chunk_text(
function _build_system_message_with_audio_prompt (line 160) | def _build_system_message_with_audio_prompt(system_message):
class HiggsAudioModelClient (line 178) | class HiggsAudioModelClient:
method __init__ (line 179) | def __init__(
method _init_static_kv_cache (line 242) | def _init_static_kv_cache(self):
method _prepare_kv_caches (line 263) | def _prepare_kv_caches(self):
method generate (line 268) | def generate(
function prepare_generation_context (line 387) | def prepare_generation_context(scene_prompt, ref_audio, ref_audio_in_sys...
function main (line 625) | def main(
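`prepare_chunk_text` in `examples/generation.py` splits a long transcript before generation, but its parameter list is truncated in the index. The sketch below shows a minimal word-based chunker under assumed semantics: split on blank lines into paragraphs, then cap each chunk at `max_words` whitespace-separated words. The name `chunk_by_words` and this policy are hypothetical, and the script may support other strategies.

```python
from typing import List


def chunk_by_words(text: str, max_words: int) -> List[str]:
    """Hypothetical chunker: paragraphs first, then at most max_words words per chunk."""
    if max_words < 1:
        raise ValueError("max_words must be positive")
    chunks: List[str] = []
    for paragraph in text.split("\n\n"):
        words = paragraph.split()
        # Emit consecutive slices of at most max_words words from each paragraph.
        for start in range(0, len(words), max_words):
            chunks.append(" ".join(words[start : start + max_words]))
    return chunks
```

Chinese text has no spaces between words, so for transcripts such as `zh_ai.txt` a whitespace-based limit like this one would need to be replaced with a character- or tokenizer-based count.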
FILE: examples/serve_engine/input_samples.py
function encode_base64_content_from_file (line 6) | def encode_base64_content_from_file(file_path: str) -> str:
function get_interleaved_dialogue_input_sample (line 14) | def get_interleaved_dialogue_input_sample():
function get_zero_shot_input_sample (line 38) | def get_zero_shot_input_sample():
function get_voice_clone_input_sample (line 59) | def get_voice_clone_input_sample():
FILE: examples/serve_engine/run_hf_example.py
function main (line 18) | def main(example: str):
FILE: examples/vllm/run_chat_completion.py
function encode_base64_content_from_file (line 26) | def encode_base64_content_from_file(file_path: str) -> str:
function run_smart_voice (line 34) | def run_smart_voice() -> None:
function run_voice_clone (line 63) | def run_voice_clone(stream: bool = False) -> None:
function run_generate_multispeaker (line 131) | def run_generate_multispeaker(stream: bool = False) -> None:
function main (line 184) | def main(args) -> None:
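Three files in the index define `encode_base64_content_from_file(file_path: str) -> str`: `serve/utils.py`, `examples/serve_engine/input_samples.py`, and `examples/vllm/run_chat_completion.py`. The standard-library sketch below shows what a function with that name and signature typically does: read the raw bytes, Base64-encode them, and return text that can be embedded in JSON. Treat the body as an assumption about the implementation.

```python
import base64


def encode_base64_content_from_file(file_path: str) -> str:
    """Return the Base64 text encoding of a file's raw bytes (assumed behavior)."""
    with open(file_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")
```

Decoding the result with `base64.b64decode` recovers the original bytes exactly, so a server can reconstruct the audio file it receives.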
Condensed preview: 79 files, each showing path, character count, and a content snippet.
[
{
"path": ".github/workflows/test.yml",
"chars": 436,
"preview": "name: Unit Test\non:\n push:\n branches: [ main ]\n pull_request:\n branches: [ main ]\n\njobs:\n lint:\n name: Lint\n"
},
{
"path": ".gitignore",
"chars": 3596,
"preview": "# Temporary files generated in training\ndpo_samples*\nscoring_results\nresults/\nhf_slurm_logs/\nslurm_results/\nenroot_image"
},
{
"path": ".gitmodules",
"chars": 0,
"preview": ""
},
{
"path": "LICENSE",
"chars": 10141,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 16581,
"preview": "<h1 align=\"center\">Higgs Audio: Redefining Expressiveness in Audio Generation</h1>\n\n<div align=\"center\" style=\"display: "
},
{
"path": "SUPPORT_GUIDELINES.md",
"chars": 1558,
"preview": "# Contribution & Support Guidelines\n\nThank you for your interest in this project! Before opening an issue, please take a"
},
{
"path": "boson_multimodal/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "boson_multimodal/audio_processing/LICENSE",
"chars": 2374,
"preview": "Third-Party License Attribution for Audio Processing Module\n===========================================================\n"
},
{
"path": "boson_multimodal/audio_processing/descriptaudiocodec/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py",
"chars": 9286,
"preview": "import math\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Union\n\nimport numpy as np\nimpo"
},
{
"path": "boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py",
"chars": 11197,
"preview": "import math\nfrom typing import List\nfrom typing import Union\n\nimport numpy as np\nimport torch\nfrom audiotools import Aud"
},
{
"path": "boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py",
"chars": 809,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom "
},
{
"path": "boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py",
"chars": 8906,
"preview": "from typing import Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ein"
},
{
"path": "boson_multimodal/audio_processing/higgs_audio_tokenizer.py",
"chars": 12458,
"preview": "# Based on code from: https://github.com/zhenye234/xcodec\n# Licensed under MIT License\n# Modifications by BosonAI\n\nimpor"
},
{
"path": "boson_multimodal/audio_processing/quantization/__init__.py",
"chars": 271,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "boson_multimodal/audio_processing/quantization/ac.py",
"chars": 12863,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "boson_multimodal/audio_processing/quantization/core_vq.py",
"chars": 13149,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py",
"chars": 15846,
"preview": "# Copyright (c)\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of "
},
{
"path": "boson_multimodal/audio_processing/quantization/ddp_utils.py",
"chars": 9039,
"preview": "import logging\nimport random\nimport subprocess\nfrom datetime import datetime\n\nimport numpy as np\nimport torch\nimport tor"
},
{
"path": "boson_multimodal/audio_processing/quantization/distrib.py",
"chars": 4044,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "boson_multimodal/audio_processing/quantization/vq.py",
"chars": 4694,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "boson_multimodal/audio_processing/semantic_module.py",
"chars": 8305,
"preview": "# Based on code from: https://github.com/zhenye234/xcodec\n# Licensed under MIT License\n# Modifications by BosonAI\n\nimpor"
},
{
"path": "boson_multimodal/constants.py",
"chars": 93,
"preview": "AUDIO_IN_TOKEN = \"<|AUDIO|>\"\nAUDIO_OUT_TOKEN = \"<|AUDIO_OUT|>\"\nEOS_TOKEN = \"<|end_of_text|>\"\n"
},
{
"path": "boson_multimodal/data_collator/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "boson_multimodal/data_collator/higgs_audio_collator.py",
"chars": 24373,
"preview": "import librosa\nimport torch\nimport torch.nn.functional as F\nimport math\nfrom typing import List, Tuple\n\nfrom dataclasses"
},
{
"path": "boson_multimodal/data_types.py",
"chars": 914,
"preview": "\"\"\"Basic data types for multimodal ChatML format.\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Dict, List, O"
},
{
"path": "boson_multimodal/dataset/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "boson_multimodal/dataset/chatml_dataset.py",
"chars": 23529,
"preview": "import dacite\nimport pandas as pd\nimport torch\nimport json\n\nimport numpy as np\nimport multiprocessing as mp\n\nfrom datacl"
},
{
"path": "boson_multimodal/model/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "boson_multimodal/model/higgs_audio/__init__.py",
"chars": 356,
"preview": "from transformers import AutoConfig, AutoModel\n\nfrom .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncod"
},
{
"path": "boson_multimodal/model/higgs_audio/audio_head.py",
"chars": 5304,
"preview": "\"\"\"Projector that maps hidden states from the LLM component to multimodal logits.\"\"\"\n\nimport torch\nfrom torch import nn\n"
},
{
"path": "boson_multimodal/model/higgs_audio/common.py",
"chars": 1003,
"preview": "from torch import nn\n\nfrom transformers.modeling_utils import PreTrainedModel\n\nfrom .configuration_higgs_audio import Hi"
},
{
"path": "boson_multimodal/model/higgs_audio/configuration_higgs_audio.py",
"chars": 11997,
"preview": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.models.auto import CONFIG_MAPPING\n\n\nclas"
},
{
"path": "boson_multimodal/model/higgs_audio/cuda_graph_runner.py",
"chars": 5106,
"preview": "import torch\nimport torch.nn as nn\nfrom typing import Optional, List, Dict, Tuple, Union\nimport gc\n\nfrom transformers.ca"
},
{
"path": "boson_multimodal/model/higgs_audio/custom_modules.py",
"chars": 6186,
"preview": "import torch\nimport torch.nn as nn\n\n\nclass PartiallyFrozenEmbedding(nn.Module):\n \"\"\"Split an existing `nn.Embedding` "
},
{
"path": "boson_multimodal/model/higgs_audio/modeling_higgs_audio.py",
"chars": 113442,
"preview": "\"\"\"Higgs-Audio is an end-to-end multimodal model with the capability to understand and generate text / audio.\"\"\"\n\nimport"
},
{
"path": "boson_multimodal/model/higgs_audio/utils.py",
"chars": 30556,
"preview": "import contextlib\nfrom contextlib import contextmanager\nfrom functools import wraps\nimport torch\nfrom transformers.integ"
},
{
"path": "boson_multimodal/serve/serve_engine.py",
"chars": 19997,
"preview": "import asyncio\nimport base64\nimport torch\nimport numpy as np\nfrom io import BytesIO\nfrom dataclasses import dataclass\nfr"
},
{
"path": "boson_multimodal/serve/utils.py",
"chars": 7417,
"preview": "import uuid\nimport base64\nimport re\nimport regex\nfrom typing import AsyncGenerator, Union\nimport io\nfrom pydub import Au"
},
{
"path": "examples/README.md",
"chars": 5694,
"preview": "# Examples\n\n> [!NOTE] \n> If you do not like the audio you get, you can generate multiple times with different seeds. In"
},
{
"path": "examples/generation.py",
"chars": 29197,
"preview": "\"\"\"Example script for generating audio using HiggsAudio.\"\"\"\n\nimport click\nimport soundfile as sf\nimport langid\nimport ji"
},
{
"path": "examples/scene_prompts/quiet_indoor.txt",
"chars": 37,
"preview": "Audio is recorded from a quiet room.\n"
},
{
"path": "examples/scene_prompts/reading_blog.txt",
"chars": 458,
"preview": "In this audio, the person is reading a blog post aloud. The content is informative and engaging, with the speaker using "
},
{
"path": "examples/serve_engine/README.md",
"chars": 800,
"preview": "# Examples to use HiggsAudioServeEngine\n\nThe `run_hf_example.py` script provides three different examples for using the "
},
{
"path": "examples/serve_engine/input_samples.py",
"chars": 3330,
"preview": "import base64\nimport os\nfrom boson_multimodal.data_types import ChatMLSample, Message, AudioContent\n\n\ndef encode_base64_"
},
{
"path": "examples/serve_engine/run_hf_example.py",
"chars": 1515,
"preview": "\"\"\"Example for using HiggsAudio for generating both the transcript and audio in an interleaved manner.\"\"\"\n\nfrom boson_mu"
},
{
"path": "examples/transcript/multi_speaker/en_argument.txt",
"chars": 371,
"preview": "[SPEAKER0] I can't believe you did that without even asking me first!\n[SPEAKER1] Oh, come on! It wasn't a big deal, and "
},
{
"path": "examples/transcript/multi_speaker/en_higgs.txt",
"chars": 759,
"preview": "[SPEAKER0] You're training HiggsAudio again? Aren't you tired of staring at it all day?\n[SPEAKER1] Ha! This time, I'm tr"
},
{
"path": "examples/transcript/single_speaker/en_basic.txt",
"chars": 117,
"preview": "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.\n"
},
{
"path": "examples/transcript/single_speaker/en_dl.txt",
"chars": 781,
"preview": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\nIt’s your host, Alex, and today, we’re diving into a topic that’s bec"
},
{
"path": "examples/transcript/single_speaker/en_higgs_audio_blog.md",
"chars": 2030,
"preview": "At Boson AI, we work on making communication with AI as easy, natural and fun as talking to a human. Today, we are excit"
},
{
"path": "examples/transcript/single_speaker/experimental/en_bgm.txt",
"chars": 203,
"preview": "[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it’s due, s"
},
{
"path": "examples/transcript/single_speaker/experimental/en_humming.txt",
"chars": 103,
"preview": "Are you asking if I can hum a tune? Of course I can! [humming start] la la la la la [humming end] See?\n"
},
{
"path": "examples/transcript/single_speaker/zh_ai.txt",
"chars": 173,
"preview": "大家好,欢迎收听本期的跟李沐学AI。今天沐哥在忙着洗数据,所以由我,希格斯主播代替他讲这期视频。\n今天我们要聊的是一个你绝对不能忽视的话题\"多模态学习\"。\n无论你是开发者,数据科学爱好者,还是只是对人工智能感兴趣的人都一定听说过这个词。它已"
},
{
"path": "examples/vllm/README.md",
"chars": 2896,
"preview": "# Serve Higgs Audio with vLLM\n\nWe provided both OpenAI compatible chat completion and audio speech server backed by vLLM"
},
{
"path": "examples/vllm/run_chat_completion.py",
"chars": 8253,
"preview": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"An example showing how to use vLLM to serve multimodal models\nand run online in"
},
{
"path": "examples/voice_prompts/belinda.txt",
"chars": 121,
"preview": "Twas the night before my birthday. Hooray! It's almost here! It may not be a holiday, but it's the best day of the year."
},
{
"path": "examples/voice_prompts/bigbang_amy.txt",
"chars": 244,
"preview": "If that was slang, I'm unfamiliar with it. <SE>[Laughter]</SE> If it was literal, I share your aversion to soiled hosier"
},
{
"path": "examples/voice_prompts/bigbang_sheldon.txt",
"chars": 238,
"preview": "Hello, Amy Farrah Fowler. I'm sorry to inform you that you have been taken in by unsupportable mathematics designed to p"
},
{
"path": "examples/voice_prompts/broom_salesman.txt",
"chars": 265,
"preview": "I would imagine so. A wand with a dragon heartstring core is capable of dazzling magic. And the bond between you and you"
},
{
"path": "examples/voice_prompts/chadwick.txt",
"chars": 119,
"preview": "Oh dear, who left all this junk lying around? Whoops, there it goes! Mind your pointed little pink head, starfish man.\n"
},
{
"path": "examples/voice_prompts/en_man.txt",
"chars": 120,
"preview": "Maintaining your ability to learn translates into increased marketability, improved career options and higher salaries.\n"
},
{
"path": "examples/voice_prompts/en_woman.txt",
"chars": 127,
"preview": "The device would work during the day as well, if you took steps to either block direct sunlight or point it away from th"
},
{
"path": "examples/voice_prompts/fiftyshades_anna.txt",
"chars": 75,
"preview": "I'm working at the hardware store till 7. I think I'd like that too. What?\n"
},
{
"path": "examples/voice_prompts/mabaoguo.txt",
"chars": 136,
"preview": "我是浑元形意太极门掌门人马保国,刚才有个朋友问我:马老师发生什么事啦.我说怎么回事,给我发了几张截图,我一看,哦,原来是昨天,有两个年轻人,三十多岁,一个体重九十多公斤,一个体重八十多公斤.他们说,哎,有一个说是:我在健身房练功,颈椎练坏了"
},
{
"path": "examples/voice_prompts/mabel.txt",
"chars": 194,
"preview": "You do talk an awful lot about weather, did you know that? Sometimes I wonder if you're actually content to be a wizard "
},
{
"path": "examples/voice_prompts/profile.yaml",
"chars": 699,
"preview": "profiles:\n male_en: Male, American accent, modern speaking rate, moderate-pitch, friendly tone, and very clear audio.\n "
},
{
"path": "examples/voice_prompts/shrek_donkey.txt",
"chars": 360,
"preview": "And I've got a great idea, I'll stick with you. You're a mean green fighting machine, together we'll scare the spit out "
},
{
"path": "examples/voice_prompts/shrek_donkey_es.txt",
"chars": 156,
"preview": "¡Uy, guau! Eso sí que asusta. Y si el rugido no funciona, tu mal aliento seguro los desmaya. Necesitas unas pastillitas "
},
{
"path": "examples/voice_prompts/shrek_fiona.txt",
"chars": 209,
"preview": "Well, when one lives alone, one has to learn these things in case there's a... There's an arrow in your butt!\nCalm down."
},
{
"path": "examples/voice_prompts/shrek_shrek.txt",
"chars": 181,
"preview": "Well, it's no wonder you don't have any friends. Listen, little donkey, take a look at me. What am I?\nNo! I'm an ogre! Y"
},
{
"path": "examples/voice_prompts/vex.txt",
"chars": 62,
"preview": "Uhh, this is going to take forever. Why is everything so far?\n"
},
{
"path": "examples/voice_prompts/zh_man_sichuan.txt",
"chars": 37,
"preview": "对,这就是我,万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
},
{
"path": "pyproject.toml",
"chars": 2310,
"preview": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.ruff]\nline-length = 119\ntarget-v"
},
{
"path": "requirements.txt",
"chars": 240,
"preview": "descript-audio-codec\ntorch\ntransformers>=4.45.1,<4.47.0\nlibrosa\ndacite\nboto3==1.35.36\ns3fs\ntorchvision\ntorchaudio\njson_r"
},
{
"path": "setup.cfg",
"chars": 310,
"preview": "[metadata]\nname = boson_multimodal\nauthor = Boson AI\nversion = 0.1.0\nurl = https://github.com/boson-ai/higgs-audio\ndescr"
},
{
"path": "setup.py",
"chars": 39,
"preview": "from setuptools import setup\n\n\nsetup()\n"
},
{
"path": "tech_blogs/ARCHITECTURE_BLOG.md",
"chars": 1939,
"preview": "# HiggsAudio-V2 Model Architecture\n<img src=\"../figures/higgs_audio_v2_architecture_combined.png\" width=800/>\n\n\nOur mode"
},
{
"path": "tech_blogs/TOKENIZER_BLOG.md",
"chars": 13647,
"preview": "# Higgs Audio Tokenizer\n\nIn this work, we introduce a new discretized audio tokenizer that runs at just **25 frames per "
}
]
About this extraction: the source of the boson-ai/higgs-audio GitHub repository as plain text, 79 files (463.3 KB, approximately 108.8k tokens), with a symbol index of 374 extracted functions, classes, methods, constants, and types. Extracted by GitExtract.