Repository: Zyphra/Zonos
Branch: main
Commit: bc40d98e1e1a
Files: 21
Total size: 104.1 KB
Directory structure:
gitextract_bh66ekbi/
├── .gitignore
├── .python-version
├── CONDITIONING_README.md
├── Dockerfile
├── LICENSE
├── README.md
├── docker-compose.yml
├── gradio_interface.py
├── pyproject.toml
├── sample.py
└── zonos/
    ├── autoencoder.py
    ├── backbone/
    │   ├── __init__.py
    │   ├── _mamba_ssm.py
    │   └── _torch.py
    ├── codebook_pattern.py
    ├── conditioning.py
    ├── config.py
    ├── model.py
    ├── sampling.py
    ├── speaker_cloning.py
    └── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
# Misc.
.ipynb_checkpoints/
================================================
FILE: .python-version
================================================
3.12
================================================
FILE: CONDITIONING_README.md
================================================
# Conditioning explanations
This document lists every conditioning input the model accepts, with a short description and tips for optimal use. Any conditioning with a learned unconditional can be set to unconditional, which lets the model infer an appropriate setting on its own.
### espeak
- **Type:** `EspeakPhonemeConditioner`
- **Description:**
Responsible for cleaning, phonemizing, tokenizing, and embedding the text provided to the model. This is the text pre-processing pipeline. If you would like to change how a word is pronounced or enter raw phonemes, you can do that here.
Supported by transformer and hybrid models.
---
### speaker
- **Type:** `PassthroughConditioner`
- **Attributes:**
  - **cond_dim:** `128`
  - **uncond_type:** `learned`
  - **projection:** `linear`
- **Description:**
An embedded representation of the speaker's voice. We use [these](https://huggingface.co/Zyphra/Zonos-v0.1-speaker-embedding) speaker embedding models. The embedding can capture a surprising amount of detail from the reference clip and supports input of arbitrary length. Use clean reference clips that contain only speech. Concatenating multiple clean samples from the same speaker into one long sample is valid and may lead to better cloning. If the speaker clip is very long, it is advisable to cut out long speech-free segments, such as background music, if they exist. If the reference clip yields noisy outputs with denoising enabled, we recommend doing source separation before cloning.
Supported by transformer and hybrid models.
---
### emotion
- **Type:** `FourierConditioner`
- **Attributes:**
  - **input_dim:** `8`
  - **uncond_type:** `learned`
- **Description:**
Encodes emotion as an 8D vector. The emotions are, in order: Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, Neutral. This vector tends to be entangled with several other conditioning inputs. Most notably, it is entangled with the sentiment of the text (e.g. angry text is conditioned toward anger more effectively, while trying to make the same text sound sad is much less effective). It is also entangled with pitch standard deviation, since larger values there tend to correlate with more emotional utterances, and it is heavily correlated with VQScore and DNSMOS, as these conditionings favor neutral speech. A form of "negative prompting" is also possible: run CFG with the unconditional branch set to a highly neutral emotion vector instead of the true unconditional value. This exaggerates the emotions because it pushes the model away from being neutral.
Supported by transformer and hybrid models.
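
As a minimal sketch of the ordering described above, the hypothetical helper below (`emotion_vector` is an illustrative name, not part of Zonos) builds the 8D vector from named weights using plain Python lists. The weights are arbitrary, and this document does not state whether the vector needs to be normalized, so treat the values as placeholders. The all-neutral vector is the kind of input the negative-prompting idea would place on the unconditional branch.

```python
# Emotion order as documented above. The helper and its keyword interface
# are assumptions made for illustration only.
EMOTIONS = (
    "happiness", "sadness", "disgust", "fear",
    "surprise", "anger", "other", "neutral",
)


def emotion_vector(**weights: float) -> list[float]:
    """Return an 8-element list in the documented emotion order."""
    unknown = set(weights) - set(EMOTIONS)
    if unknown:
        raise ValueError(f"unknown emotions: {sorted(unknown)}")
    # Emotions that are not mentioned default to 0.0.
    return [float(weights.get(name, 0.0)) for name in EMOTIONS]


angry = emotion_vector(anger=0.8, neutral=0.2)
neutral = emotion_vector(neutral=1.0)  # candidate "negative prompt" vector
```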
---
### fmax
- **Type:** `FourierConditioner`
- **Attributes:**
  - **min_val:** `0`
  - **max_val:** `24000`
  - **uncond_type:** `learned`
- **Description:**
Specifies the maximum frequency of the audio. For best results select 22050 or 24000, as these correspond to 44.1 kHz and 48 kHz audio respectively. The two should not differ in actual maximum frequency, since the model's sampling rate is 44.1 kHz, but they represent different slices of data, which leads to slightly different voicing. Selecting a lower value generally produces lower-quality results in terms of both acoustics and voicing.
For voice cloning it is recommended to use 22050.
Supported by transformer and hybrid models.
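
The two recommended values are the Nyquist frequencies (half the sampling rate) of 44.1 kHz and 48 kHz audio. A small illustrative check, where `nyquist_hz` is a hypothetical helper name and not part of Zonos:

```python
def nyquist_hz(sample_rate_hz: float) -> float:
    """Highest frequency representable at a given sampling rate."""
    return sample_rate_hz / 2


# The two fmax values recommended above.
fmax_for_44k = nyquist_hz(44_100)
fmax_for_48k = nyquist_hz(48_000)
```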
---
### pitch_std
- **Type:** `FourierConditioner`
- **Attributes:**
  - **min_val:** `0`
  - **max_val:** `400`
  - **uncond_type:** `learned`
- **Description:**
Specifies the standard deviation of the pitch of the output audio. Wider pitch variation tends to correlate with more expressive speech. Good values are 20-45 for normal speech and 60-150 for expressive speech. Values above that generally produce increasingly erratic samples.
Supported by transformer and hybrid models.
---
### speaking_rate
- **Type:** `FourierConditioner`
- **Attributes:**
  - **min_val:** `0`
  - **max_val:** `40`
  - **uncond_type:** `learned`
- **Description:**
Specifies the number of phonemes to be read per second. When entering a long text, adjust the speaking rate so that all phonemes can be read within the generation length. For example, if your generation length is 10 seconds and your input is 300 phonemes, you would need either 30 phonemes per second (which is very fast) or a longer sample. The model's maximum generation length is 30 seconds. Note that unrealistic speaking rates can be out of distribution (OOD) for the model and create undesirable effects, so at the 30-second limit it can be better to cut the text short and do multiple generations than to feed the model the entire prompt at an unrealistic speaking rate.
Supported by transformer and hybrid models.
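
The arithmetic above can be sketched as follows. Both function names are hypothetical, and the target of 15 phonemes per second in the example is only the default of the Gradio speaking-rate slider, not a recommendation from this document.

```python
import math

MAX_SECONDS = 30.0  # maximum generation length stated above


def required_rate(num_phonemes: int, seconds: float) -> float:
    """Phonemes per second needed to fit the text into `seconds`."""
    return num_phonemes / seconds


def generations_needed(num_phonemes: int, target_rate: float,
                       max_seconds: float = MAX_SECONDS) -> int:
    """Number of separate generations needed to keep `target_rate`."""
    phonemes_per_generation = target_rate * max_seconds
    return math.ceil(num_phonemes / phonemes_per_generation)


rate = required_rate(300, 10.0)          # the example from the text
parts = generations_needed(1000, 15.0)   # 450 phonemes fit per generation
```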
---
### language_id
- **Type:** `IntegerConditioner`
- **Attributes:**
  - **min_val:** `-1`
  - **max_val:** `126`
  - **uncond_type:** `learned`
- **Description:**
Indicates which language the output should be in. A mapping for these values can be found in the [conditioning section](https://github.com/Zyphra/Zonos/blob/3807c8e04bd4beaadb9502b3df1ffa4b0350e3f7/zonos/conditioning.py#L308C1-L376C21) of Zonos.
Supported by transformer and hybrid models.
---
### vqscore_8
- **Type:** `FourierConditioner`
- **Attributes:**
  - **input_dim:** `8`
  - **min_val:** `0.5`
  - **max_val:** `0.8`
  - **uncond_type:** `learned`
- **Description:**
Encodes the desired [VQScore](https://github.com/JasonSWFu/VQscore) value for the output audio. VQScore is an unsupervised speech quality (cleanliness) estimation method that we found has superior generalization and reduced biases compared to supervised methods like DNSMOS. A good value for our model is 0.78 for high-quality speech. The eight dimensions correspond to consecutive 1/8th chunks of the audio (e.g. for an 8-second output, the first dimension represents the quality of the first second only). For inference, we generally set all 8 dimensions to the same value. This conditioning has an unfortunately strong correlation with expressiveness, so for expressive speech we recommend setting it to unconditional.
Only applicable for the hybrid model.
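
To illustrate how the eight dimensions line up with the audio, the sketch below computes the time span each dimension would cover for a given output length, and builds the uniform vector typically used at inference. `vqscore_chunk_spans` is a hypothetical helper written for this explanation, not a Zonos function.

```python
def vqscore_chunk_spans(duration_s: float, n_chunks: int = 8) -> list[tuple[float, float]]:
    """Return (start, end) in seconds for each consecutive 1/n chunk."""
    step = duration_s / n_chunks
    return [(i * step, (i + 1) * step) for i in range(n_chunks)]


spans = vqscore_chunk_spans(8.0)   # 8-second output: one second per dimension
uniform_vqscore = [0.78] * 8       # same value in all eight dimensions
```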
---
### ctc_loss
- **Type:** `FourierConditioner`
- **Attributes:**
  - **min_val:** `-1.0`
  - **max_val:** `1000`
  - **uncond_type:** `learned`
- **Description:**
Encodes loss values from a [CTC](https://en.wikipedia.org/wiki/Connectionist_temporal_classification) (Connectionist Temporal Classification) setup. This indicates how well the training-time transcription matched the audio according to a CTC model. For inference, always use low values (e.g. 0.0 or 1.0).
Only applicable for the hybrid model.
---
### dnsmos_ovrl
- **Type:** `FourierConditioner`
- **Attributes:**
  - **min_val:** `1`
  - **max_val:** `5`
  - **uncond_type:** `learned`
- **Description:**
A [MOS](https://arxiv.org/abs/2110.01763) score for the output audio. This is similar to VQScore and tends to have a stronger entanglement with emotions. It additionally has a strong entanglement with languages. Set to 4.0 for very clean and neutral English speech, else we recommend setting it to unconditional.
Only applicable for the hybrid model.
---
### speaker_noised
- **Type:** `IntegerConditioner`
- **Attributes:**
  - **min_val:** `0`
  - **max_val:** `1`
  - **uncond_type:** `learned`
- **Description:**
Indicates whether the speaker embedding is noisy. If set, this lets the model clean (denoise) the input speaker embedding. When this is set to True, VQScore and DNSMOS have a lot more power to clean the speaker embedding, so for very noisy input samples we recommend setting this to True and specifying a high VQScore value. If your speaker cloning outputs sound echo-y or behave strangely, setting this to True will help.
Only applicable for the hybrid model.
================================================
FILE: Dockerfile
================================================
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
RUN pip install uv
RUN apt update && \
    apt install -y espeak-ng && \
    rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY . ./
RUN uv pip install --system -e . && uv pip install --system -e .[compile]
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Zonos-v0.1
---
Zonos-v0.1 is a leading open-weight text-to-speech model trained on more than 200k hours of varied multilingual speech, delivering expressiveness and quality on par with—or even surpassing—top TTS providers.
Our model enables highly natural speech generation from text prompts when given a speaker embedding or audio prefix, and can accurately perform speech cloning when given a reference clip spanning just a few seconds. The conditioning setup also allows for fine control over speaking rate, pitch variation, audio quality, and emotions such as happiness, fear, sadness, and anger. The model outputs speech natively at 44kHz.
##### For more details and speech samples, check out our blog [here](https://www.zyphra.com/post/beta-release-of-zonos-v0-1)
##### We also have a hosted version available at [playground.zyphra.com/audio](https://playground.zyphra.com/audio)
---
Zonos follows a straightforward architecture: text normalization and phonemization via eSpeak, followed by DAC token prediction through a transformer or hybrid backbone. An overview of the architecture can be seen below.
---
## Usage
### Python
```python
import torch
import torchaudio
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
from zonos.utils import DEFAULT_DEVICE as device
# model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
speaker = model.make_speaker_embedding(wav, sampling_rate)
cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
conditioning = model.prepare_conditioning(cond_dict)
codes = model.generate(conditioning)
wavs = model.autoencoder.decode(codes).cpu()
torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
```
### Gradio interface (recommended)
```bash
uv run gradio_interface.py
# python gradio_interface.py
```
The Python example above should produce a `sample.wav` file in your project root directory.
_For repeated sampling we highly recommend using the gradio interface instead, as the minimal example needs to load the model every time it is run._
## Features
- Zero-shot TTS with voice cloning: Input desired text and a 10-30s speaker sample to generate high-quality TTS output
- Audio prefix inputs: Add text plus an audio prefix for even richer speaker matching. Audio prefixes can be used to elicit behaviours such as whispering which can otherwise be challenging to replicate when cloning from speaker embeddings
- Multilingual support: Zonos-v0.1 supports English, Japanese, Chinese, French, and German
- Audio quality and emotion control: Zonos offers fine-grained control of many aspects of the generated audio. These include speaking rate, pitch, maximum frequency, audio quality, and various emotions such as happiness, anger, sadness, and fear.
- Fast: our model runs with a real-time factor of ~2x on an RTX 4090 (i.e. generates 2 seconds of audio per 1 second of compute time)
- Gradio WebUI: Zonos comes packaged with an easy-to-use Gradio interface to generate speech
- Simple installation and deployment: Zonos can be installed and deployed using the Dockerfile packaged with our repository.
## Installation
#### System requirements
- **Operating System:** Linux (preferably Ubuntu 22.04/24.04), macOS
- **GPU:** 6GB+ VRAM, Hybrid additionally requires a 3000-series or newer Nvidia GPU
Note: Zonos can also run on CPU provided there is enough free RAM. However, this will be a lot slower than running on a dedicated GPU, and likely won't be sufficient for interactive use.
For experimental Windows support check out [this fork](https://github.com/sdbds/Zonos-for-windows).
See also [Docker Installation](#docker-installation)
#### System dependencies
Zonos depends on the eSpeak library for phonemization. You can install it with the following command:
```bash
apt install -y espeak-ng # For Ubuntu
# brew install espeak-ng # For macOS
```
#### Python dependencies
We highly recommend using a recent version of [uv](https://docs.astral.sh/uv/#installation) for installation. If you don't have uv installed, you can install it via pip: `pip install -U uv`.
##### Installing into a new uv virtual environment (recommended)
```bash
uv sync
uv sync --extra compile # optional but needed to run the hybrid
uv pip install -e .
```
##### Installing into the system/activated environment using uv
```bash
uv pip install -e .
uv pip install -e .[compile] # optional but needed to run the hybrid
```
##### Installing into the system/activated environment using pip
```bash
pip install -e .
pip install --no-build-isolation -e .[compile] # optional but needed to run the hybrid
```
##### Confirm that it's working
For convenience we provide a minimal example to check that the installation works:
```bash
uv run sample.py
# python sample.py
```
## Docker installation
```bash
git clone https://github.com/Zyphra/Zonos.git
cd Zonos
# For gradio
docker compose up
# Or for development you can do
docker build -t zonos .
docker run -it --gpus=all --net=host -v /path/to/Zonos:/Zonos -t zonos
cd /Zonos
python sample.py # this will generate a sample.wav in /Zonos
```
================================================
FILE: docker-compose.yml
================================================
version: '3.8'
services:
  zonos:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: zonos_container
    runtime: nvidia
    network_mode: "host"
    stdin_open: true
    tty: true
    command: ["python3", "gradio_interface.py"]
    environment:
      - NVIDIA_VISIBLE_DEVICES=0
      - GRADIO_SHARE=False
================================================
FILE: gradio_interface.py
================================================
import torch
import torchaudio
import gradio as gr
from os import getenv
from zonos.model import Zonos, DEFAULT_BACKBONE_CLS as ZonosBackbone
from zonos.conditioning import make_cond_dict, supported_language_codes
from zonos.utils import DEFAULT_DEVICE as device
CURRENT_MODEL_TYPE = None
CURRENT_MODEL = None
SPEAKER_EMBEDDING = None
SPEAKER_AUDIO_PATH = None
def load_model_if_needed(model_choice: str):
    global CURRENT_MODEL_TYPE, CURRENT_MODEL
    if CURRENT_MODEL_TYPE != model_choice:
        if CURRENT_MODEL is not None:
            del CURRENT_MODEL
            torch.cuda.empty_cache()
        print(f"Loading {model_choice} model...")
        CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device)
        CURRENT_MODEL.requires_grad_(False).eval()
        CURRENT_MODEL_TYPE = model_choice
        print(f"{model_choice} model loaded successfully!")
    return CURRENT_MODEL
def update_ui(model_choice):
    """
    Dynamically show/hide UI elements based on the model's conditioners.
    We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
    """
    model = load_model_if_needed(model_choice)
    cond_names = [c.name for c in model.prefix_conditioner.conditioners]
    print("Conditioners in this model:", cond_names)
    text_update = gr.update(visible=("espeak" in cond_names))
    language_update = gr.update(visible=("espeak" in cond_names))
    speaker_audio_update = gr.update(visible=("speaker" in cond_names))
    prefix_audio_update = gr.update(visible=True)
    emotion1_update = gr.update(visible=("emotion" in cond_names))
    emotion2_update = gr.update(visible=("emotion" in cond_names))
    emotion3_update = gr.update(visible=("emotion" in cond_names))
    emotion4_update = gr.update(visible=("emotion" in cond_names))
    emotion5_update = gr.update(visible=("emotion" in cond_names))
    emotion6_update = gr.update(visible=("emotion" in cond_names))
    emotion7_update = gr.update(visible=("emotion" in cond_names))
    emotion8_update = gr.update(visible=("emotion" in cond_names))
    vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
    fmax_slider_update = gr.update(visible=("fmax" in cond_names))
    pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
    speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
    dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
    speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
    unconditional_keys_update = gr.update(
        choices=[name for name in cond_names if name not in ("espeak", "language_id")]
    )
    return (
        text_update,
        language_update,
        speaker_audio_update,
        prefix_audio_update,
        emotion1_update,
        emotion2_update,
        emotion3_update,
        emotion4_update,
        emotion5_update,
        emotion6_update,
        emotion7_update,
        emotion8_update,
        vq_single_slider_update,
        fmax_slider_update,
        pitch_std_slider_update,
        speaking_rate_slider_update,
        dnsmos_slider_update,
        speaker_noised_checkbox_update,
        unconditional_keys_update,
    )
def generate_audio(
    model_choice,
    text,
    language,
    speaker_audio,
    prefix_audio,
    e1,
    e2,
    e3,
    e4,
    e5,
    e6,
    e7,
    e8,
    vq_single,
    fmax,
    pitch_std,
    speaking_rate,
    dnsmos_ovrl,
    speaker_noised,
    cfg_scale,
    top_p,
    top_k,
    min_p,
    linear,
    confidence,
    quadratic,
    seed,
    randomize_seed,
    unconditional_keys,
    progress=gr.Progress(),
):
    """
    Generates audio based on the provided UI parameters.
    We do NOT use language_id or ctc_loss even if the model has them.
    """
    selected_model = load_model_if_needed(model_choice)
    speaker_noised_bool = bool(speaker_noised)
    fmax = float(fmax)
    pitch_std = float(pitch_std)
    speaking_rate = float(speaking_rate)
    dnsmos_ovrl = float(dnsmos_ovrl)
    cfg_scale = float(cfg_scale)
    top_p = float(top_p)
    top_k = int(top_k)
    min_p = float(min_p)
    linear = float(linear)
    confidence = float(confidence)
    quadratic = float(quadratic)
    seed = int(seed)
    max_new_tokens = 86 * 30
    # This is a bit ew, but works for now.
    global SPEAKER_AUDIO_PATH, SPEAKER_EMBEDDING
    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    torch.manual_seed(seed)
    if speaker_audio is not None and "speaker" not in unconditional_keys:
        if speaker_audio != SPEAKER_AUDIO_PATH:
            print("Recomputed speaker embedding")
            wav, sr = torchaudio.load(speaker_audio)
            SPEAKER_EMBEDDING = selected_model.make_speaker_embedding(wav, sr)
            SPEAKER_EMBEDDING = SPEAKER_EMBEDDING.to(device, dtype=torch.bfloat16)
            SPEAKER_AUDIO_PATH = speaker_audio
    audio_prefix_codes = None
    if prefix_audio is not None:
        wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
        wav_prefix = wav_prefix.mean(0, keepdim=True)
        wav_prefix = selected_model.autoencoder.preprocess(wav_prefix, sr_prefix)
        wav_prefix = wav_prefix.to(device, dtype=torch.float32)
        audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))
    emotion_tensor = torch.tensor(list(map(float, [e1, e2, e3, e4, e5, e6, e7, e8])), device=device)
    vq_val = float(vq_single)
    vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)
    cond_dict = make_cond_dict(
        text=text,
        language=language,
        speaker=SPEAKER_EMBEDDING,
        emotion=emotion_tensor,
        vqscore_8=vq_tensor,
        fmax=fmax,
        pitch_std=pitch_std,
        speaking_rate=speaking_rate,
        dnsmos_ovrl=dnsmos_ovrl,
        speaker_noised=speaker_noised_bool,
        device=device,
        unconditional_keys=unconditional_keys,
    )
    conditioning = selected_model.prepare_conditioning(cond_dict)
    estimated_generation_duration = 30 * len(text) / 400
    estimated_total_steps = int(estimated_generation_duration * 86)

    def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
        progress((step, estimated_total_steps))
        return True

    codes = selected_model.generate(
        prefix_conditioning=conditioning,
        audio_prefix_codes=audio_prefix_codes,
        max_new_tokens=max_new_tokens,
        cfg_scale=cfg_scale,
        batch_size=1,
        sampling_params=dict(top_p=top_p, top_k=top_k, min_p=min_p, linear=linear, conf=confidence, quad=quadratic),
        callback=update_progress,
    )
    wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
    sr_out = selected_model.autoencoder.sampling_rate
    if wav_out.dim() == 2 and wav_out.size(0) > 1:
        wav_out = wav_out[0:1, :]
    return (sr_out, wav_out.squeeze().numpy()), seed
def build_interface():
supported_models = []
if "transformer" in ZonosBackbone.supported_architectures:
supported_models.append("Zyphra/Zonos-v0.1-transformer")
if "hybrid" in ZonosBackbone.supported_architectures:
supported_models.append("Zyphra/Zonos-v0.1-hybrid")
else:
print(
"| The current ZonosBackbone does not support the hybrid architecture, meaning only the transformer model will be available in the model selector.\n"
"| This probably means the mamba-ssm library has not been installed."
)
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
model_choice = gr.Dropdown(
choices=supported_models,
value=supported_models[0],
label="Zonos Model Type",
info="Select the model variant to use.",
)
text = gr.Textbox(
label="Text to Synthesize",
value="Zonos uses eSpeak for text to phoneme conversion!",
lines=4,
max_length=500, # approximately
)
language = gr.Dropdown(
choices=supported_language_codes,
value="en-us",
label="Language Code",
info="Select a language code.",
)
prefix_audio = gr.Audio(
value="assets/silence_100ms.wav",
label="Optional Prefix Audio (continue from this audio)",
type="filepath",
)
with gr.Column():
speaker_audio = gr.Audio(
label="Optional Speaker Audio (for cloning)",
type="filepath",
)
speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)
with gr.Row():
with gr.Column():
gr.Markdown("## Conditioning Parameters")
dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="Fmax (Hz)")
vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="Pitch Std")
speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate")
with gr.Column():
gr.Markdown("## Generation Parameters")
cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
seed_number = gr.Number(label="Seed", value=420, precision=0)
randomize_seed_toggle = gr.Checkbox(label="Randomize Seed (before generation)", value=True)
with gr.Accordion("Sampling", open=False):
with gr.Row():
with gr.Column():
gr.Markdown("### NovelAi's unified sampler")
linear_slider = gr.Slider(-2.0, 2.0, 0.5, 0.01, label="Linear (set to 0 to disable unified sampling)", info="High values make the output less random.")
# Conf's theoretical range is between -2 * Quad and 0.
confidence_slider = gr.Slider(-2.0, 2.0, 0.40, 0.01, label="Confidence", info="Low values make random outputs more random.")
quadratic_slider = gr.Slider(-2.0, 2.0, 0.00, 0.01, label="Quadratic", info="High values make low probabilities much lower.")
with gr.Column():
gr.Markdown("### Legacy sampling")
top_p_slider = gr.Slider(0.0, 1.0, 0, 0.01, label="Top P")
min_k_slider = gr.Slider(0.0, 1024, 0, 1, label="Top K")  # passed to the sampler as top_k
min_p_slider = gr.Slider(0.0, 1.0, 0, 0.01, label="Min P")
with gr.Accordion("Advanced Parameters", open=False):
gr.Markdown(
"### Unconditional Toggles\n"
"Checking a box will make the model ignore the corresponding conditioning value and make it unconditional.\n"
'Practically this means the given conditioning feature will be unconstrained and "filled in automatically".'
)
with gr.Row():
unconditional_keys = gr.CheckboxGroup(
[
"speaker",
"emotion",
"vqscore_8",
"fmax",
"pitch_std",
"speaking_rate",
"dnsmos_ovrl",
"speaker_noised",
],
value=["emotion"],
label="Unconditional Keys",
)
gr.Markdown(
"### Emotion Sliders\n"
"Warning: The way these sliders work is not intuitive and may require some trial and error to get the desired effect.\n"
"Certain configurations can cause the model to become unstable. Setting emotion to unconditional may help."
)
with gr.Row():
emotion1 = gr.Slider(0.0, 1.0, 1.0, 0.05, label="Happiness")
emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
with gr.Row():
emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
emotion7 = gr.Slider(0.0, 1.0, 0.1, 0.05, label="Other")
emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral")
with gr.Column():
generate_button = gr.Button("Generate Audio")
output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)
model_choice.change(
fn=update_ui,
inputs=[model_choice],
outputs=[
text,
language,
speaker_audio,
prefix_audio,
emotion1,
emotion2,
emotion3,
emotion4,
emotion5,
emotion6,
emotion7,
emotion8,
vq_single_slider,
fmax_slider,
pitch_std_slider,
speaking_rate_slider,
dnsmos_slider,
speaker_noised_checkbox,
unconditional_keys,
],
)
# On page load, trigger the same UI refresh
demo.load(
fn=update_ui,
inputs=[model_choice],
outputs=[
text,
language,
speaker_audio,
prefix_audio,
emotion1,
emotion2,
emotion3,
emotion4,
emotion5,
emotion6,
emotion7,
emotion8,
vq_single_slider,
fmax_slider,
pitch_std_slider,
speaking_rate_slider,
dnsmos_slider,
speaker_noised_checkbox,
unconditional_keys,
],
)
# Generate audio on button click
generate_button.click(
fn=generate_audio,
inputs=[
model_choice,
text,
language,
speaker_audio,
prefix_audio,
emotion1,
emotion2,
emotion3,
emotion4,
emotion5,
emotion6,
emotion7,
emotion8,
vq_single_slider,
fmax_slider,
pitch_std_slider,
speaking_rate_slider,
dnsmos_slider,
speaker_noised_checkbox,
cfg_scale_slider,
top_p_slider,
min_k_slider,
min_p_slider,
linear_slider,
confidence_slider,
quadratic_slider,
seed_number,
randomize_seed_toggle,
unconditional_keys,
],
outputs=[output_audio, seed_number],
)
return demo
if __name__ == "__main__":
demo = build_interface()
share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
================================================
FILE: pyproject.toml
================================================
[project]
name = "zonos"
version = "0.1.0"
description = "Text-to-speech by Zyphra"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch>=2.5.1",
"setuptools",
"packaging",
"inflect>=7.5.0",
"kanjize>=1.5.0",
"numpy>=2.2.2",
"phonemizer>=3.3.0",
"sudachidict-full>=20241021",
"sudachipy>=0.6.10",
"torchaudio>=2.5.1",
"transformers>=4.48.1",
"soundfile>=0.13.1",
"huggingface-hub>=0.28.1",
"gradio>=5.15.0",
]
# These are technically optional, but mamba-ssm is required to run hybrid models.
[project.optional-dependencies]
compile = [
"flash-attn>=2.7.3",
"mamba-ssm>=2.2.4",
"causal-conv1d>=1.5.0.post8",
]
[tool.setuptools.packages.find]
include = ["zonos"]
[tool.uv]
no-build-isolation-package = ["flash-attn", "mamba-ssm", "causal-conv1d"]
[tool.ruff]
line-length = 120
================================================
FILE: sample.py
================================================
import torch
import torchaudio
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
from zonos.utils import DEFAULT_DEVICE as device
# model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
speaker = model.make_speaker_embedding(wav, sampling_rate)
torch.manual_seed(421)
cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
conditioning = model.prepare_conditioning(cond_dict)
codes = model.generate(conditioning)
wavs = model.autoencoder.decode(codes).cpu()
torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
================================================
FILE: zonos/autoencoder.py
================================================
import math

import torch
import torchaudio
from transformers.models.dac import DacModel


class DACAutoencoder:
    def __init__(self):
        super().__init__()
        self.dac = DacModel.from_pretrained("descript/dac_44khz")
        self.dac.eval().requires_grad_(False)
        self.codebook_size = self.dac.config.codebook_size
        self.num_codebooks = self.dac.quantizer.n_codebooks
        self.sampling_rate = self.dac.config.sampling_rate

    def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
        wav = torchaudio.functional.resample(wav, sr, 44_100)
        # Right-pad so the length is a multiple of 512 samples.
        right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
        return torch.nn.functional.pad(wav, (0, right_pad))

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.dac.encode(wav).audio_codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        with torch.autocast(self.dac.device.type, torch.float16, enabled=self.dac.device.type != "cpu"):
            return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1).float()
================================================
FILE: zonos/backbone/__init__.py
================================================
BACKBONES = {}

try:
    from ._mamba_ssm import MambaSSMZonosBackbone

    BACKBONES["mamba_ssm"] = MambaSSMZonosBackbone
except ImportError:
    pass

from ._torch import TorchZonosBackbone

BACKBONES["torch"] = TorchZonosBackbone
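# Illustrative sketch of the optional-backend registry pattern used above. Backends are
# registered in order, an ImportError skips the optional one, and the first registered
# entry becomes the default (model.py picks it with next(iter(BACKBONES.values()))).
# build_registry and the module name "hypothetical_missing_accel_lib" are made up for
# this example; they stand in for the mamba_ssm import.

```python
import importlib


def build_registry(candidates):
    """Return {name: module} for each importable candidate, preserving the given order."""
    registry = {}
    for name, module_name in candidates:
        try:
            registry[name] = importlib.import_module(module_name)
        except ImportError:
            # Optional dependency missing: skip it and fall through to the next backend.
            continue
    return registry


# The accelerated backend is unavailable here, so only the fallback is registered.
registry = build_registry([("accelerated", "hypothetical_missing_accel_lib"), ("fallback", "json")])
default_name = next(iter(registry))
```

# Because dicts preserve insertion order, the optional backend wins whenever it imports
# successfully, and the pure-Python fallback is used otherwise.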
================================================
FILE: zonos/backbone/_mamba_ssm.py
================================================
import torch
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import create_block
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
from zonos.config import BackboneConfig, InferenceParams
class MambaSSMZonosBackbone(nn.Module):
supported_architectures = ["transformer", "hybrid"]
def __init__(self, config: BackboneConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
create_block(
d_model=config.d_model,
d_intermediate=config.d_intermediate
if (i not in config.attn_layer_idx)
else config.attn_mlp_d_intermediate,
ssm_cfg=config.ssm_cfg,
layer_idx=i,
attn_layer_idx=config.attn_layer_idx,
attn_cfg=config.attn_cfg,
norm_epsilon=config.norm_epsilon,
residual_in_fp32=config.residual_in_fp32,
fused_add_norm=True,
rms_norm=config.rms_norm,
)
for i in range(config.n_layer)
]
)
self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
return {
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
for i, layer in enumerate(self.layers)
}
def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
residual = None
for layer in self.layers:
hidden_states, residual = layer(hidden_states, residual, inference_params)
return layer_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
residual,
eps=self.norm_f.eps,
residual_in_fp32=self.config.residual_in_fp32,
is_rms_norm=self.config.rms_norm,
)
================================================
FILE: zonos/backbone/_torch.py
================================================
# Based on gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/095b2229ee3a40e379c11f05b94bd6923db63b4b/model.py
import torch
import torch.nn as nn
from torch.nn import functional as F
from zonos.config import BackboneConfig, InferenceParams
def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000) -> torch.Tensor:
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return cache
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)
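# Illustrative sketch of the rotary embedding computed by precompute_freqs_cis and
# apply_rotary_emb above, written with plain Python floats for a single vector.
# rotary_angles and rotate are hypothetical names; the math follows the functions
# above: each consecutive (even, odd) feature pair is rotated by position * frequency.

```python
import math


def rotary_angles(position: int, head_dim: int, base: float = 10000.0) -> list[float]:
    """One rotation angle per feature pair: position / base ** (2j / head_dim)."""
    return [position / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]


def rotate(x: list[float], position: int) -> list[float]:
    """Rotate each consecutive pair (x[2j], x[2j + 1]) by the angle for that pair."""
    out = []
    for j, theta in enumerate(rotary_angles(position, len(x))):
        a, b = x[2 * j], x[2 * j + 1]
        out += [a * math.cos(theta) - b * math.sin(theta), b * math.cos(theta) + a * math.sin(theta)]
    return out


vec = [1.0, 0.0, 0.5, -2.0]
at_zero = rotate(vec, 0)   # position 0 applies a zero-angle rotation
at_seven = rotate(vec, 7)  # same vector length, different direction
```

# Since every pair is only rotated, the vector norm is unchanged, and the dot product of
# a rotated query and key depends only on the difference of their positions.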
def _update_kv_cache(
k: torch.Tensor, v: torch.Tensor, inference_params: InferenceParams, layer_idx: int
) -> torch.Tensor:
"""k/v: (batch_size, seqlen, nheads, head_dim) or (batch_size, 1, nheads, head_dim)"""
assert layer_idx in inference_params.key_value_memory_dict
kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + k.shape[0]
sequence_start = inference_params.seqlen_offset
sequence_end = sequence_start + k.shape[1]
assert kv_cache is not None
assert batch_end <= kv_cache.shape[0]
assert sequence_end <= kv_cache.shape[1]
kv_cache[batch_start:batch_end, sequence_start:sequence_end, 0, ...] = k
kv_cache[batch_start:batch_end, sequence_start:sequence_end, 1, ...] = v
return kv_cache[batch_start:batch_end, :sequence_end, ...]
class TorchZonosBackbone(nn.Module):
supported_architectures = ["transformer"]
freqs_cis: torch.Tensor
def __init__(self, config: BackboneConfig):
assert not config.ssm_cfg, "This backbone implementation only supports the Transformer model."
super().__init__()
self.config = config
self.layers = nn.ModuleList(TransformerBlock(config, i) for i in range(config.n_layer))
self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
# TODO: This function should be pure
head_dim = self.config.d_model // self.config.attn_cfg["num_heads"]
self.freqs_cis = precompute_freqs_cis(16384, head_dim)
return {
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
for i, layer in enumerate(self.layers)
}
def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams) -> torch.Tensor:
input_pos = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
input_pos = input_pos + inference_params.lengths_per_sample.unsqueeze(-1)
freqs_cis = self.freqs_cis[input_pos].expand(hidden_states.shape[0], -1, -1, -1)
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, inference_params, freqs_cis)
return self.norm_f(hidden_states)
class TransformerBlock(nn.Module):
def __init__(self, config: BackboneConfig, layer_idx: int) -> None:
super().__init__()
self.config = config
self.norm = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
self.mixer = Attention(config, layer_idx)
self.norm2 = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
self.mlp = FeedForward(config)
self.num_heads_kv = config.attn_cfg["num_heads_kv"]
self.head_dim = config.d_model // config.attn_cfg["num_heads"]
def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype), None
def forward(self, x: torch.Tensor, inference_params: InferenceParams, freqs_cis: torch.Tensor) -> torch.Tensor:
x = x + self.mixer(self.norm(x), inference_params, freqs_cis)
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
def __init__(self, config: BackboneConfig, layer_idx: int):
super().__init__()
self.num_heads = config.attn_cfg["num_heads"]
self.num_heads_kv = config.attn_cfg["num_heads_kv"]
self.head_dim = config.d_model // self.num_heads
self.layer_idx = layer_idx
total_head_dim = (self.num_heads + 2 * self.num_heads_kv) * self.head_dim
self.in_proj = nn.Linear(config.d_model, total_head_dim, bias=False)
self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.d_model, bias=False)
def forward(self, x: torch.Tensor, inference_params: InferenceParams, freqs_cis: torch.Tensor) -> torch.Tensor:
batch_size, seqlen, _ = x.shape
q_size = self.num_heads * self.head_dim
kv_size = self.num_heads_kv * self.head_dim
q, k, v = self.in_proj(x).split([q_size, kv_size, kv_size], dim=-1)
q = q.view(batch_size, seqlen, self.num_heads, self.head_dim)
k = k.view(batch_size, seqlen, self.num_heads_kv, self.head_dim)
v = v.view(batch_size, seqlen, self.num_heads_kv, self.head_dim)
q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
kv = _update_kv_cache(k, v, inference_params, self.layer_idx)
k, v = kv.unbind(dim=-3)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
y = F.scaled_dot_product_attention(q, k, v, is_causal=seqlen > 1, enable_gqa=True)
y = y.transpose(1, 2).contiguous().view(batch_size, seqlen, q_size)
y = self.out_proj(y)
return y
class FeedForward(nn.Module):
def __init__(self, config: BackboneConfig) -> None:
super().__init__()
self.fc1 = nn.Linear(config.d_model, 2 * config.attn_mlp_d_intermediate, bias=False)
self.fc2 = nn.Linear(config.attn_mlp_d_intermediate, config.d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y, gate = self.fc1(x).chunk(2, dim=-1)
return self.fc2(y * F.silu(gate))
================================================
FILE: zonos/codebook_pattern.py
================================================
import torch
import torch.nn.functional as F


def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
    codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
    return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)


def revert_delay_pattern(codes: torch.Tensor):
    _, n_q, seq_len = codes.shape
    return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
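# Illustrative sketch of the delay pattern above using nested lists instead of tensors
# (one sample, shape [n_q][T]). apply_delay and revert_delay are hypothetical names;
# the indexing mirrors the tensor code: pad each row with n_q mask tokens on the right,
# then roll codebook k to the right by k + 1.

```python
def apply_delay(codes, mask_token):
    """Shift codebook k right by k + 1 steps, filling the gaps with mask_token."""
    n_q = len(codes)
    delayed = []
    for k, row in enumerate(codes):
        padded = list(row) + [mask_token] * n_q
        shift = (k + 1) % len(padded)
        # Roll right: the trailing `shift` entries (all mask tokens) wrap to the front.
        delayed.append(padded[-shift:] + padded[:-shift])
    return delayed


def revert_delay(delayed):
    """Undo apply_delay by slicing each row back to its original window."""
    n_q = len(delayed)
    seq_len = len(delayed[0])
    return [row[k + 1 : seq_len - n_q + k + 1] for k, row in enumerate(delayed)]


codes = [[1, 2, 3], [4, 5, 6]]  # two codebooks, three frames
delayed = apply_delay(codes, mask_token=9)
restored = revert_delay(delayed)
```

# With this layout, codebook k at frame t sits one step later than codebook k - 1 at the
# same frame, so each decoding step can condition on the coarser codebooks already produced.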
================================================
FILE: zonos/conditioning.py
================================================
from functools import cache
from typing import Any, Literal, Iterable
import torch
import torch.nn as nn
from zonos.config import PrefixConditionerConfig
from zonos.utils import DEFAULT_DEVICE
class Conditioner(nn.Module):
def __init__(
self,
output_dim: int,
name: str,
cond_dim: int | None = None,
projection: Literal["none", "linear", "mlp"] = "none",
uncond_type: Literal["learned", "none"] = "none",
**kwargs,
):
super().__init__()
self.name = name
self.output_dim = output_dim
self.cond_dim = cond_dim = cond_dim or output_dim
if projection == "linear":
self.project = nn.Linear(cond_dim, output_dim)
elif projection == "mlp":
self.project = nn.Sequential(
nn.Linear(cond_dim, output_dim),
nn.SiLU(),
nn.Linear(output_dim, output_dim),
)
else:
self.project = nn.Identity()
self.uncond_vector = None
if uncond_type == "learned":
self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
def apply_cond(self, *inputs: Any) -> torch.Tensor:
raise NotImplementedError()
def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
if inputs is None:
assert self.uncond_vector is not None
return self.uncond_vector.data.view(1, 1, -1)
cond = self.apply_cond(*inputs)
cond = self.project(cond)
return cond
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
import os
import sys
import re
import unicodedata
import inflect
import torch
import torch.nn as nn
from kanjize import number2kanji
from phonemizer.backend import EspeakBackend
from sudachipy import Dictionary, SplitMode
if sys.platform == "darwin":
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "/opt/homebrew/lib/libespeak-ng.dylib"
# --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m: re.Match) -> str:
return m.group(1).replace(",", "")
def _expand_decimal_point(m: re.Match) -> str:
return m.group(1).replace(".", " point ")
def _expand_dollars(m: re.Match) -> str:
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(m: re.Match) -> str:
return _inflect.number_to_words(m.group(0))
def _expand_number(m: re.Match) -> str:
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text: str) -> str:
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
# --- Number normalization code end ---
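# Illustrative sketch of why the substitution order in normalize_numbers matters. It
# reproduces only the comma, dollar, and decimal stages so that it runs without inflect;
# the spelled-out number stages are omitted. The function names are hypothetical and
# expand_dollars is a simplified stand-in for _expand_dollars.

```python
import re

comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")


def expand_dollars(m: re.Match) -> str:
    whole, _, frac = m.group(1).partition(".")
    if "." in frac:
        return m.group(1) + " dollars"  # unexpected format such as "1.2.3"
    dollars, cents = int(whole or 0), int(frac or 0)
    parts = []
    if dollars:
        parts.append(f"{dollars} {'dollar' if dollars == 1 else 'dollars'}")
    if cents:
        parts.append(f"{cents} {'cent' if cents == 1 else 'cents'}")
    return ", ".join(parts) or "zero dollars"


def normalize_money_and_decimals(text: str) -> str:
    text = comma_number_re.sub(lambda m: m.group(1).replace(",", ""), text)  # 1,000 -> 1000
    text = dollars_re.sub(expand_dollars, text)  # must run before the decimal stage
    text = decimal_number_re.sub(lambda m: m.group(1).replace(".", " point "), text)
    return text


price = normalize_money_and_decimals("It costs $1,000.50")
ratio = normalize_money_and_decimals("pi is 3.14")
```

# If the decimal stage ran before the dollar stage, "$1000.50" would first become
# "$1000 point 50" and the cents would no longer be recognized as part of the amount.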
PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
_punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
)
symbols = [*_punctuation, *_letters, *_letters_ipa]
_symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
def _get_symbol_id(s: str) -> int:
return _symbol_to_id.get(s, UNK_ID)  # unknown symbols map to UNK_ID
def get_symbol_ids(text: str) -> list[int]:
return list(map(_get_symbol_id, text))
def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
phoneme_ids = [[BOS_ID, *get_symbol_ids(p), EOS_ID] for p in phonemes]
lengths = list(map(len, phoneme_ids))
longest = max(lengths)
phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
return torch.tensor(phoneme_ids), lengths
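# Illustrative sketch of tokenize_phonemes above with plain lists and a toy vocabulary.
# toy_symbols and tokenize are hypothetical; the real symbol table is much larger. As in
# the code above, symbol IDs start after the four special tokens, unknown symbols map to
# UNK, and shorter sequences are padded on the left.

```python
PAD, UNK, BOS, EOS = 0, 1, 2, 3
toy_symbols = "ab"  # hypothetical two-symbol vocabulary
toy_ids = {s: i for i, s in enumerate(toy_symbols, start=4)}


def tokenize(batch):
    """Wrap each string in BOS/EOS, then left-pad every row to the longest length."""
    rows = [[BOS, *(toy_ids.get(ch, UNK) for ch in text), EOS] for text in batch]
    lengths = [len(row) for row in rows]
    longest = max(lengths)
    return [[PAD] * (longest - len(row)) + row for row in rows], lengths


# "z" is not in the toy vocabulary, so it becomes UNK.
padded, lengths = tokenize(["ab", "z"])
```

# Left padding keeps the end of every phoneme sequence aligned at the last position of
# the batch.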
def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
text = unicodedata.normalize("NFKC", text)
text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
return final_text
def clean(texts: list[str], languages: list[str]) -> list[str]:
texts_out = []
for text, language in zip(texts, languages):
if "ja" in language:
text = normalize_jp_text(text)
else:
text = normalize_numbers(text)
texts_out.append(text)
return texts_out
@cache
def get_backend(language: str) -> "EspeakBackend":
import logging
from phonemizer.backend import EspeakBackend
logger = logging.getLogger("phonemizer")
backend = EspeakBackend(
language,
preserve_punctuation=True,
with_stress=True,
punctuation_marks=_punctuation,
logger=logger,
)
logger.setLevel(logging.ERROR)
return backend
def phonemize(texts: list[str], languages: list[str]) -> list[str]:
texts = clean(texts, languages)
batch_phonemes = []
for text, language in zip(texts, languages):
backend = get_backend(language)
phonemes = backend.phonemize([text], strip=True)
batch_phonemes.append(phonemes[0])
return batch_phonemes
class EspeakPhonemeConditioner(Conditioner):
def __init__(self, output_dim: int, **kwargs):
super().__init__(output_dim, **kwargs)
self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
"""
Args:
texts: list of texts to convert to phonemes
languages: ISO 639-1 (or otherwise eSpeak-compatible) language codes, one per text
"""
device = self.phoneme_embedder.weight.device
phonemes = phonemize(texts, languages)
phoneme_ids, _ = tokenize_phonemes(phonemes)
phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
return phoneme_embeds
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
class FourierConditioner(Conditioner):
def __init__(
self,
output_dim: int,
input_dim: int = 1,
std: float = 1.0,
min_val: float = 0.0,
max_val: float = 1.0,
**kwargs,
):
assert output_dim % 2 == 0
super().__init__(output_dim, **kwargs)
self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == self.input_dim
x = (x - self.min_val) / (self.max_val - self.min_val) # [batch_size, seq_len, input_dim]
f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T # [batch_size, seq_len, output_dim // 2]
return torch.cat([f.cos(), f.sin()], dim=-1) # [batch_size, seq_len, output_dim]
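# Illustrative sketch of the Fourier feature map in FourierConditioner.apply_cond for a
# single scalar input (input_dim = 1). fourier_features is a hypothetical name, and the
# weights are fixed by hand here, whereas the class draws them once from a normal
# distribution scaled by std and stores them as a buffer.

```python
import math


def fourier_features(x: float, weights: list[float], min_val: float = 0.0, max_val: float = 1.0) -> list[float]:
    """Normalize x to [0, 1], project onto each weight, and return [cos..., sin...]."""
    x_norm = (x - min_val) / (max_val - min_val)
    phases = [2 * math.pi * x_norm * w for w in weights]
    return [math.cos(p) for p in phases] + [math.sin(p) for p in phases]


# With x = 0.5 and weights 0.5 and 1.0, the phases are pi/2 and pi.
features = fourier_features(0.5, weights=[0.5, 1.0])
```

# The output has twice as many entries as there are weights, which is why the class
# asserts that output_dim is even.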
class IntegerConditioner(Conditioner):
def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
super().__init__(output_dim, **kwargs)
self.min_val = min_val
self.max_val = max_val
self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == 1
return self.int_embedder(x.squeeze(-1) - self.min_val) # [batch_size, seq_len, output_dim]
class PassthroughConditioner(Conditioner):
def __init__(self, output_dim: int, **kwargs):
super().__init__(output_dim, **kwargs)
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == self.cond_dim
return x
_cond_cls_map = {
"PassthroughConditioner": PassthroughConditioner,
"EspeakPhonemeConditioner": EspeakPhonemeConditioner,
"FourierConditioner": FourierConditioner,
"IntegerConditioner": IntegerConditioner,
}
def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
class PrefixConditioner(Conditioner):
def __init__(self, config: PrefixConditionerConfig, output_dim: int):
super().__init__(output_dim, "prefix", projection=config.projection)
self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
self.norm = nn.LayerNorm(output_dim)
self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
def forward(self, cond_dict: dict) -> torch.Tensor:
if not set(cond_dict).issuperset(self.required_keys):
raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
conds = []
for conditioner in self.conditioners:
conds.append(conditioner(cond_dict.get(conditioner.name)))
max_bsz = max(map(len, conds))
assert all(c.shape[0] in (max_bsz, 1) for c in conds)
conds = [c.expand(max_bsz, -1, -1) for c in conds]
return self.norm(self.project(torch.cat(conds, dim=-2)))
supported_language_codes = [
'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
'vi-vn-x-central', 'vi-vn-x-south', 'yue'
] # fmt: off
def make_cond_dict(
text: str = "It would be nice to have time for testing, indeed.",
language: str = "en-us",
speaker: torch.Tensor | None = None,
# Emotion vector; each entry ranges from 0.0 to 1.0.
# It is entangled with pitch_std (more emotion => more pitch variation)
# and with VQScore and DNSMOS (both favor neutral speech).
#
# Order: Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, Neutral
emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
# Maximum frequency (0 to 24000), should be 22050 or 24000 for 44.1 or 48 kHz audio
# For voice cloning use 22050
fmax: float = 22050.0,
# Standard deviation for pitch (0 to 400), should be
# 20-45 for normal speech,
# 60-150 for expressive speech,
# higher values => crazier samples
pitch_std: float = 20.0,
# Speaking rate in phonemes per second (0 to 40). 30 is very fast, 10 is slow.
speaking_rate: float = 15.0,
# Target VoiceQualityScore for the generated speech (0.5 to 0.8).
# A list of eight values must be provided, one for each 1/8th of the audio.
# Make it unconditional for expressive speech.
# According to Discord chat, this is only used by the hybrid model.
vqscore_8: list[float] = [0.78] * 8,
# CTC target loss
# Only used for the hybrid model
ctc_loss: float = 0.0,
# Only used for the hybrid model
dnsmos_ovrl: float = 4.0,
# Only used for the hybrid model
speaker_noised: bool = False,
unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
device: torch.device | str = DEFAULT_DEVICE,
) -> dict:
"""
A helper to build the 'cond_dict' that the model expects.
No speaker embedding is generated here: if `speaker` is None, it is passed through as None.
"""
assert language.lower() in supported_language_codes, "Please pick a supported language"
language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
cond_dict = {
"espeak": ([text], [language]),
"speaker": speaker,
"emotion": emotion,
"fmax": fmax,
"pitch_std": pitch_std,
"speaking_rate": speaking_rate,
"language_id": language_code_to_id[language],
"vqscore_8": vqscore_8,
"ctc_loss": ctc_loss,
"dnsmos_ovrl": dnsmos_ovrl,
"speaker_noised": int(speaker_noised),
}
for k in unconditional_keys:
cond_dict.pop(k, None)
for k, v in cond_dict.items():
if isinstance(v, (float, int, list)):
v = torch.tensor(v)
if isinstance(v, torch.Tensor):
cond_dict[k] = v.view(1, 1, -1).to(device)
if k == "emotion":
cond_dict[k] /= cond_dict[k].sum(dim=-1)
return cond_dict
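# Illustrative sketch of two steps in make_cond_dict above: dropping the keys listed in
# unconditional_keys and rescaling the emotion vector so that it sums to 1.
# build_toy_cond_dict is a hypothetical helper that uses plain lists instead of tensors
# and only a few of the real keys.

```python
def build_toy_cond_dict(emotion, unconditional_keys=("vqscore_8", "dnsmos_ovrl")):
    """Drop unconditional keys, then normalize the emotion weights to sum to 1."""
    cond = {"emotion": list(emotion), "vqscore_8": [0.78] * 8, "dnsmos_ovrl": 4.0, "pitch_std": 20.0}
    for key in unconditional_keys:
        cond.pop(key, None)  # a missing key falls back to the learned unconditional vector
    if "emotion" in cond:
        total = sum(cond["emotion"])
        cond["emotion"] = [e / total for e in cond["emotion"]]
    return cond


# Slider-style values do not need to sum to 1 beforehand.
cond = build_toy_cond_dict([1.0, 0.05, 0.05, 0.05, 0.05, 0.05, 0.1, 0.2])
no_emotion = build_toy_cond_dict([1.0] * 8, unconditional_keys=("emotion",))
```

# Only the relative sizes of the emotion entries matter after normalization, so doubling
# every entry leaves the conditioning unchanged.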
================================================
FILE: zonos/config.py
================================================
from dataclasses import dataclass, field
from typing import Literal
import torch
# https://github.com/state-spaces/mamba/blob//mamba_ssm/utils/generation.py#L18
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
seqlen_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
lengths_per_sample: torch.Tensor | None = None
def reset(self, max_seqlen, max_batch_size):
self.max_seqlen = max_seqlen
self.max_batch_size = max_batch_size
self.seqlen_offset = 0
if self.lengths_per_sample is not None:
self.lengths_per_sample.zero_()
@dataclass
class BackboneConfig:
d_model: int = 1024
d_intermediate: int = 0
attn_mlp_d_intermediate: int = 0
n_layer: int = 16
ssm_cfg: dict = field(default_factory=dict)
attn_layer_idx: list = field(default_factory=list)
attn_cfg: dict = field(default_factory=dict)
rms_norm: bool = False
residual_in_fp32: bool = False
norm_epsilon: float = 1e-5
@dataclass
class PrefixConditionerConfig:
conditioners: list[dict]
projection: Literal["none", "linear", "mlp"]
@dataclass
class ZonosConfig:
backbone: BackboneConfig
prefix_conditioner: PrefixConditionerConfig
eos_token_id: int = 1024
masked_token_id: int = 1025
pad_vocab_to_multiple_of: int = 8
@classmethod
def from_dict(cls, d: dict) -> "ZonosConfig":
d = d.copy()
backbone_config = BackboneConfig(**d.pop("backbone"))
prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
config = cls(backbone_config, prefix_conditioner_config, **d)
return config
================================================
FILE: zonos/model.py
================================================
import json
from typing import Callable
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from zonos.autoencoder import DACAutoencoder
from zonos.backbone import BACKBONES
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
from zonos.conditioning import PrefixConditioner
from zonos.config import InferenceParams, ZonosConfig
from zonos.sampling import sample_from_logits
from zonos.speaker_cloning import SpeakerEmbeddingLDA
from zonos.utils import DEFAULT_DEVICE, find_multiple, pad_weight_
DEFAULT_BACKBONE_CLS = next(iter(BACKBONES.values()))
class Zonos(nn.Module):
def __init__(self, config: ZonosConfig, backbone_cls=DEFAULT_BACKBONE_CLS):
super().__init__()
self.config = config
dim = config.backbone.d_model
self.eos_token_id = config.eos_token_id
self.masked_token_id = config.masked_token_id
self.autoencoder = DACAutoencoder()
self.backbone = backbone_cls(config.backbone)
self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
self.spk_clone_model = None
# TODO: pad to multiple of at least 8
self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
self._cg_graph = None
self._cg_batch_size = None
self._cg_input_ids = None
self._cg_logits = None
self._cg_inference_params = None
self._cg_scale = None
if config.pad_vocab_to_multiple_of:
self.register_load_state_dict_post_hook(self._pad_embeddings_and_heads)
def _pad_embeddings_and_heads(self, *args, **kwargs):
for w in [*self.embeddings, *self.heads]:
pad_weight_(w, self.config.pad_vocab_to_multiple_of)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
@classmethod
def from_pretrained(
cls, repo_id: str, revision: str | None = None, device: str = DEFAULT_DEVICE, **kwargs
) -> "Zonos":
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
return cls.from_local(config_path, model_path, device, **kwargs)
@classmethod
def from_local(
cls, config_path: str, model_path: str, device: str = DEFAULT_DEVICE, backbone: str | None = None
) -> "Zonos":
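        # Use a context manager so the config file handle is closed promptly.
        with open(config_path) as config_file:
            config = ZonosConfig.from_dict(json.load(config_file))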
if backbone:
backbone_cls = BACKBONES[backbone]
else:
is_transformer = not bool(config.backbone.ssm_cfg)
backbone_cls = DEFAULT_BACKBONE_CLS
# Preferentially route to pure torch backbone for increased performance and lower latency.
if is_transformer and "torch" in BACKBONES:
backbone_cls = BACKBONES["torch"]
model = cls(config, backbone_cls).to(device, torch.bfloat16)
model.autoencoder.dac.to(device)
sd = model.state_dict()
with safetensors.safe_open(model_path, framework="pt") as f:
for k in f.keys():
sd[k] = f.get_tensor(k)
model.load_state_dict(sd)
return model
def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
"""Generate a speaker embedding from an audio clip."""
if self.spk_clone_model is None:
self.spk_clone_model = SpeakerEmbeddingLDA()
_, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
return spk_embedding.unsqueeze(0).bfloat16()
def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
return torch.stack([head(hidden_states) for head in self.heads], dim=1)
def _compute_logits(
self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
) -> torch.Tensor:
"""
        Pass `hidden_states` through the backbone and the per-codebook heads
        (`apply_heads`), applying classifier-free guidance if `cfg_scale != 1.0`.
"""
last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
logits = self.apply_heads(last_hidden_states).squeeze(2).float()
if cfg_scale != 1.0:
cond_logits, uncond_logits = logits.chunk(2)
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
logits[..., 1025:].fill_(-torch.inf) # ensures padding is ignored
return logits
def _decode_one_token(
self,
input_ids: torch.Tensor,
inference_params: InferenceParams,
cfg_scale: float,
allow_cudagraphs: bool = True,
) -> torch.Tensor:
"""
Single-step decode. Prepares the hidden states, possibly replicates them
for CFG, and then delegates to `_compute_logits`.
Below we wrap this function with a simple CUDA Graph capturing mechanism,
doing 3 warmup steps if needed and then capturing or replaying the graph.
We only recapture if the batch size changes.
"""
# TODO: support cfg_scale==1
if cfg_scale == 1.0:
hidden_states = self.embed_codes(input_ids)
return self._compute_logits(hidden_states, inference_params, cfg_scale)
bsz = input_ids.size(0)
if not allow_cudagraphs or input_ids.device.type != "cuda":
hidden_states_local = self.embed_codes(input_ids)
hidden_states_local = hidden_states_local.repeat(2, 1, 1)
return self._compute_logits(hidden_states_local, inference_params, cfg_scale)
need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
if need_capture:
self._cg_graph = None
self._cg_batch_size = bsz
self._cg_inference_params = inference_params
self._cg_scale = cfg_scale
for _ in range(3):
hidden_states = self.embed_codes(input_ids)
hidden_states = hidden_states.repeat(2, 1, 1) # because cfg != 1.0
logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
self._cg_input_ids = input_ids.clone()
self._cg_logits = torch.empty_like(logits)
g = torch.cuda.CUDAGraph()
def capture_region():
hidden_states_local = self.embed_codes(self._cg_input_ids)
hidden_states_local = hidden_states_local.repeat(2, 1, 1)
self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
with torch.cuda.graph(g):
capture_region()
self._cg_graph = g
else:
self._cg_input_ids.copy_(input_ids)
self._cg_graph.replay()
return self._cg_logits
def _prefill(
self,
prefix_hidden_states: torch.Tensor,
input_ids: torch.Tensor,
inference_params: InferenceParams,
cfg_scale: float,
) -> torch.Tensor:
"""
"Prefill" mode: we already have `prefix_hidden_states`, and we want
to append new embeddings, then compute the logits.
"""
# Replicate input_ids if CFG is enabled
if cfg_scale != 1.0:
input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
return self._compute_logits(hidden_states, inference_params, cfg_scale)
def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
max_seqlen = find_multiple(max_seqlen, 8)
key_value_memory_dict = self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32)
return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
if uncond_dict is None:
uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
return torch.cat(
[
self.prefix_conditioner(cond_dict),
self.prefix_conditioner(uncond_dict),
]
)
def can_use_cudagraphs(self) -> bool:
# Only the mamba-ssm backbone supports CUDA Graphs at the moment
return self.device.type == "cuda" and "_mamba_ssm" in str(self.backbone.__class__)
@torch.inference_mode()
def generate(
self,
prefix_conditioning: torch.Tensor, # [bsz, cond_seq_len, d_model]
audio_prefix_codes: torch.Tensor | None = None, # [bsz, 9, prefix_audio_seq_len]
max_new_tokens: int = 86 * 30,
cfg_scale: float = 2.0,
batch_size: int = 1,
sampling_params: dict = dict(min_p=0.1),
progress_bar: bool = True,
disable_torch_compile: bool = False,
callback: Callable[[torch.Tensor, int, int], bool] | None = None,
):
assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
device = self.device
# Use CUDA Graphs if supported, and torch.compile otherwise.
cg = self.can_use_cudagraphs()
decode_one_token = self._decode_one_token
decode_one_token = torch.compile(decode_one_token, dynamic=True, disable=cg or disable_torch_compile)
unknown_token = -1
audio_seq_len = prefix_audio_len + max_new_tokens
seq_len = prefix_conditioning.shape[1] + audio_seq_len + 9
with torch.device(device):
inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
codes = torch.full((batch_size, 9, audio_seq_len), unknown_token)
if audio_prefix_codes is not None:
codes[..., :prefix_audio_len] = audio_prefix_codes
delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
next_token = sample_from_logits(logits, **sampling_params)
offset = delayed_prefix_audio_codes.shape[2]
frame = delayed_codes[..., offset : offset + 1]
frame.masked_scatter_(frame == unknown_token, next_token)
prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
inference_params.seqlen_offset += prefix_length
inference_params.lengths_per_sample[:] += prefix_length
logit_bias = torch.zeros_like(logits)
logit_bias[:, 1:, self.eos_token_id] = -torch.inf # only allow codebook 0 to predict EOS
stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
max_steps = delayed_codes.shape[2] - offset
remaining_steps = torch.full((batch_size,), max_steps, device=device)
progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
cfg_scale = torch.tensor(cfg_scale)
step = 0
while torch.max(remaining_steps) > 0:
offset += 1
input_ids = delayed_codes[..., offset - 1 : offset]
logits = decode_one_token(input_ids, inference_params, cfg_scale, allow_cudagraphs=cg)
logits += logit_bias
next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
eos_in_cb0 = next_token[:, 0] == self.eos_token_id
remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
stopping |= eos_in_cb0[:, 0]
eos_codebook_idx = 9 - remaining_steps
eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
for i in range(next_token.shape[0]):
if stopping[i]:
idx = eos_codebook_idx[i].item()
next_token[i, :idx] = self.masked_token_id
next_token[i, idx] = self.eos_token_id
frame = delayed_codes[..., offset : offset + 1]
frame.masked_scatter_(frame == unknown_token, next_token)
inference_params.seqlen_offset += 1
inference_params.lengths_per_sample[:] += 1
remaining_steps -= 1
progress.update()
step += 1
if callback is not None and not callback(frame, step, max_steps):
break
out_codes = revert_delay_pattern(delayed_codes)
out_codes.masked_fill_(out_codes >= 1024, 0)
out_codes = out_codes[..., : offset - 9]
self._cg_graph = None # reset cuda graph to avoid cache changes
return out_codes
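# Illustrative sketch (not part of the library API): the classifier-free guidance
# step in `_compute_logits` combines conditional and unconditional logits as
# `uncond + (cond - uncond) * cfg_scale`. The helper name `_cfg_mix` and the use of
# plain Python lists instead of tensors are assumptions made for this demonstration.
def _cfg_mix(cond_logits: list[float], uncond_logits: list[float], cfg_scale: float) -> list[float]:
    """Extrapolate from the unconditional logits towards the conditional ones."""
    return [u + (c - u) * cfg_scale for c, u in zip(cond_logits, uncond_logits)]
# With cfg_scale == 1.0 the result equals the conditional logits; larger scales push
# further away from the unconditional prediction.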
================================================
FILE: zonos/sampling.py
================================================
import torch
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
Args:
input (torch.Tensor): The input tensor containing probabilities.
num_samples (int): Number of samples to draw.
replacement (bool): Whether to draw with replacement or not.
    Keyword args:
generator (torch.Generator): A pseudorandom number generator for sampling.
Returns:
torch.Tensor: Last dimension contains num_samples indices
sampled from the multinomial probability distribution
located in the last dimension of tensor input.
"""
if num_samples == 1:
q = torch.empty_like(input).exponential_(1, generator=generator)
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
input_ = input.reshape(-1, input.shape[-1])
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
output = output_.reshape(*list(input.shape[:-1]), -1)
return output
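# Illustrative sketch (not part of the library API): the `num_samples == 1` branch of
# `multinomial` appears to use the "exponential race" trick, where
# argmax(p_i / q_i) with q_i ~ Exponential(1) is distributed like a categorical draw
# from p. The helper name `_exponential_race_sample` and the use of the stdlib
# `random` module are assumptions made for this pure-Python demonstration.
import random


def _exponential_race_sample(probs: list[float], rng: random.Random) -> int:
    """Draw one index from `probs` by taking the argmax of p_i / q_i."""
    # Guard against the (extremely unlikely) zero draw to avoid dividing by zero.
    scores = [p / max(rng.expovariate(1.0), 1e-300) for p in probs]
    return max(range(len(probs)), key=scores.__getitem__)
# Entries with probability 0 get a score of 0 and can never win against an entry
# with positive probability.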
def apply_unified(probs: torch.Tensor, linear: float, conf: float, quad: float) -> torch.Tensor:
    """Reweight a probability distribution using the unified sampling approach, which combines linear scaling, confidence, and quadratic terms.
Args:
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
linear (float): Linear scaling factor applied to log probabilities.
conf (float): Confidence factor that scales the entropy term.
quad (float): Quadratic penalty factor applied to squared log probabilities.
Returns:
torch.Tensor: Modified probability distribution after applying unified sampling.
"""
logprobs = torch.log(probs.clamp_min(1e-20))
entropy = -torch.sum(probs * logprobs, dim=-1, keepdim=True)
raw = logprobs * (linear + entropy * conf) - logprobs**2 * quad
return raw.softmax(dim=-1)
def apply_top_k(
    probs: torch.Tensor,
    k: int,
) -> torch.Tensor:
    """Keep only the top K probabilities along the last dimension of `probs` and renormalize.
    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Filtered and renormalized probability distribution.
    """
v, _ = torch.topk(probs, min(k, probs.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
probs = torch.where(probs < pivot, 0.0, probs)
probs.div_(probs.sum(dim=-1, keepdim=True))
return probs
def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Keep only the top P probability mass (nucleus) along the last dimension of `probs` and renormalize.
    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: Filtered and renormalized probability distribution.
    """
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort *= (~mask).float()
probs = probs.scatter(-1, probs_idx, probs_sort)
probs.div_(probs.sum(dim=-1, keepdim=True))
return probs
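# Illustrative sketch (not part of the library API): a pure-Python version of the
# nucleus filter in `apply_top_p`. A token is kept while the probability mass of
# the more likely tokens before it does not exceed `p`, so the most likely token
# always survives. The helper name `_top_p_filter` is an assumption for this example.
def _top_p_filter(probs: list[float], p: float) -> list[float]:
    """Zero out tokens outside the nucleus and renormalize the rest."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = [0.0] * len(probs)
    mass_before = 0.0
    for i in order:
        if mass_before > p:
            break  # every remaining token is less likely and also outside the nucleus
        kept[i] = probs[i]
        mass_before += probs[i]
    total = sum(kept)
    return [x / total for x in kept]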
def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    """Remove tokens whose probability is below `min_p` times the top probability, then renormalize.
    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
            Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
    Returns:
        torch.Tensor: Filtered and renormalized probability distribution.
    """
top_probs, _ = probs.max(dim=-1, keepdim=True)
tokens_to_remove = probs < (min_p * top_probs)
probs = probs.masked_fill(tokens_to_remove, 0.0)
probs.div_(probs.sum(dim=-1, keepdim=True))
return probs
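# Illustrative sketch (not part of the library API): a pure-Python version of the
# min-p filter in `apply_min_p`. Tokens below `min_p * max(probs)` are dropped and
# the survivors are renormalized. The helper name `_min_p_filter` is an assumption
# for this example.
def _min_p_filter(probs: list[float], min_p: float) -> list[float]:
    """Zero out tokens below the scaled threshold and renormalize the rest."""
    threshold = min_p * max(probs)
    kept = [x if x >= threshold else 0.0 for x in probs]
    total = sum(kept)
    return [x / total for x in kept]
# Example: with probs [0.5, 0.3, 0.15, 0.05] and min_p=0.2 the threshold is 0.1,
# so only the last token is removed.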
def modify_logit_for_repetition_penalty(
logits: torch.Tensor,
generated_tokens: torch.Tensor,
repetition_penalty: float,
repetition_penalty_window: int,
):
"""See https://arxiv.org/abs/1909.05858
Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
logits: (batch_size, n_codebooks, vocab_size)
generated_tokens: (batch_size, n_codebooks, seq_len)
"""
generated_tokens = generated_tokens[..., -repetition_penalty_window:]
generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
rp = torch.full_like(logits, repetition_penalty)
factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
return torch.where(logits <= 0, logits * factors, logits / factors)
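# Illustrative sketch (not part of the library API): a pure-Python, single-codebook
# version of the windowed repetition penalty in `modify_logit_for_repetition_penalty`.
# Each occurrence of a token in the window multiplies its factor by the penalty;
# non-positive logits are multiplied by the factor and positive logits are divided
# by it. The helper name `_repetition_penalty` is an assumption, and token ids are
# assumed to be non-negative.
def _repetition_penalty(logits: list[float], recent_tokens: list[int], penalty: float, window: int) -> list[float]:
    """Penalize tokens that occur in the last `window` generated tokens."""
    factors = [1.0] * len(logits)
    for tok in recent_tokens[-window:]:
        # Clamp out-of-vocabulary ids to the last logit index.
        factors[min(tok, len(logits) - 1)] *= penalty
    return [x * f if x <= 0 else x / f for x, f in zip(logits, factors)]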
def sample_from_logits(
logits: torch.Tensor,
temperature: float = 1.0,
top_p: float = 0.0,
top_k: int = 0,
min_p: float = 0.0,
linear: float = 0.0,
conf: float = 0.0,
quad: float = 0.0,
generated_tokens: torch.Tensor | None = None,
repetition_penalty: float = 3.0,
repetition_penalty_window: int = 2,
) -> torch.Tensor:
"""Sample next token from logits using either top_k/p/min_p OR using NovelAI's Unified Sampler.
Args:
logits (torch.Tensor): Input logits with token candidates on the last dimension.
temperature (float): Randomness of the sampling. Lower temperature results in more deterministic samples.
To disable sampling entirely, set it to 0. For NovelAI's Unified Sampler, set it to 1.0
top_p (float): Only sample from the most probable tokens whose cumulative probability is less than p.
This is called nucleus sampling. Must be between 0 and 1. Typical values are in the 0.1-0.9 range.
Set to 0 to disable.
top_k (int): Only sample from the top k most probable tokens. Set to 0 to disable.
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
            Higher values prune more aggressively; the most likely token always survives.
linear (float): NovelAI's Unified Sampler -> 0.0 to 1.0, default from gradio 0.5
Set Linear between 0 and 1 according to how unusual you want tokens to be.
Lower numbers will produce more unusual/creative outputs,
but you will have to reroll or edit more.
conf (float): Confidence - Low values make random outputs more random. -> -2.0 * Quad to 2.0, default from gradio 0.4
As a starting point, set Quad = 1/3 - Linear * 4 / 15, and Conf = -Quad / 2.
        quad (float): Quadratic - High values make low probabilities much lower. -> -2.0 to 2.0, default from gradio 0.0
Returns:
torch.Tensor: Sampled tokens.
"""
if repetition_penalty != 1.0 and generated_tokens is not None:
logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
if linear > 0.0:
probs = apply_unified(probs, linear, conf, quad)
if top_p > 0:
probs = apply_top_p(probs, top_p)
if top_k > 0:
probs = apply_top_k(probs, top_k)
if min_p > 0:
probs = apply_min_p(probs, min_p)
next_token = multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
return next_token # [batch_size, num_codebooks, 1]
================================================
FILE: zonos/speaker_cloning.py
================================================
import math
from functools import cache
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from huggingface_hub import hf_hub_download
from zonos.utils import DEFAULT_DEVICE
class logFbankCal(nn.Module):
def __init__(
self,
sample_rate: int = 16_000,
n_fft: int = 512,
win_length: float = 0.025,
hop_length: float = 0.01,
n_mels: int = 80,
):
super().__init__()
self.fbankCal = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=int(win_length * sample_rate),
hop_length=int(hop_length * sample_rate),
n_mels=n_mels,
)
def forward(self, x):
out = self.fbankCal(x)
out = torch.log(out + 1e-6)
out = out - out.mean(axis=2).unsqueeze(dim=2)
return out
class ASP(nn.Module):
# Attentive statistics pooling
def __init__(self, in_planes, acoustic_dim):
super(ASP, self).__init__()
outmap_size = int(acoustic_dim / 8)
self.out_dim = in_planes * 8 * outmap_size * 2
self.attention = nn.Sequential(
nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1d(128),
nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
nn.Softmax(dim=2),
)
def forward(self, x):
x = x.reshape(x.size()[0], -1, x.size()[-1])
w = self.attention(x)
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1)
return x
class SimAMBasicBlock(nn.Module):
expansion = 1
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
super(SimAMBasicBlock, self).__init__()
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = NormLayer(planes)
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = NormLayer(planes)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
self.downsample = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.downsample = nn.Sequential(
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
NormLayer(self.expansion * planes),
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = self.SimAM(out)
out += self.downsample(x)
out = self.relu(out)
return out
def SimAM(self, X, lambda_p=1e-4):
n = X.shape[2] * X.shape[3] - 1
d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
v = d.sum(dim=[2, 3], keepdim=True) / n
E_inv = d / (4 * (v + lambda_p)) + 0.5
return X * self.sigmoid(E_inv)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
super(BasicBlock, self).__init__()
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = NormLayer(planes)
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = NormLayer(planes)
self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.downsample = nn.Sequential(
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
NormLayer(self.expansion * planes),
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.downsample(x)
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
super(ResNet, self).__init__()
if feat_dim == "1d":
self.NormLayer = nn.BatchNorm1d
self.ConvLayer = nn.Conv1d
elif feat_dim == "2d":
self.NormLayer = nn.BatchNorm2d
self.ConvLayer = nn.Conv2d
elif feat_dim == "3d":
self.NormLayer = nn.BatchNorm3d
self.ConvLayer = nn.Conv3d
        else:
            raise ValueError(f"Unsupported feat_dim: {feat_dim!r} (expected '1d', '2d', or '3d')")
self.in_planes = in_planes
self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = self.NormLayer(in_planes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)
def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def ResNet293(in_planes: int, **kwargs):
return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)
class ResNet293_based(nn.Module):
def __init__(
self,
in_planes: int = 64,
embd_dim: int = 256,
acoustic_dim: int = 80,
featCal=None,
dropout: float = 0,
**kwargs,
):
super(ResNet293_based, self).__init__()
self.featCal = featCal
self.front = ResNet293(in_planes)
block_expansion = SimAMBasicBlock.expansion
self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
self.drop = nn.Dropout(dropout) if dropout else None
def forward(self, x):
x = self.featCal(x)
x = self.front(x.unsqueeze(dim=1))
x = self.pooling(x)
if self.drop:
x = self.drop(x)
x = self.bottleneck(x)
return x
class SEModule(nn.Module):
def __init__(self, channels, bottleneck=128):
super(SEModule, self).__init__()
self.se = nn.Sequential(
nn.AdaptiveAvgPool1d(1),
nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
nn.ReLU(),
# nn.BatchNorm1d(bottleneck), # Removed
nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
nn.Sigmoid(),
)
def forward(self, input):
x = self.se(input)
return input * x
class Bottle2neck(nn.Module):
def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
super(Bottle2neck, self).__init__()
width = int(math.floor(planes / scale))
self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
self.bn1 = nn.BatchNorm1d(width * scale)
self.nums = scale - 1
convs = []
bns = []
num_pad = math.floor(kernel_size / 2) * dilation
for i in range(self.nums):
convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
bns.append(nn.BatchNorm1d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
self.bn3 = nn.BatchNorm1d(planes)
self.relu = nn.ReLU()
self.width = width
self.se = SEModule(planes)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.relu(out)
out = self.bn1(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(sp)
sp = self.bns[i](sp)
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = torch.cat((out, spx[self.nums]), 1)
out = self.conv3(out)
out = self.relu(out)
out = self.bn3(out)
out = self.se(out)
out += residual
return out
class ECAPA_TDNN(nn.Module):
def __init__(self, C, featCal):
super(ECAPA_TDNN, self).__init__()
self.featCal = featCal
self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm1d(C)
self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
# I fixed the shape of the output from MFA layer, that is close to the setting from ECAPA paper.
self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
self.attention = nn.Sequential(
nn.Conv1d(4608, 256, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1d(256),
nn.Tanh(), # Added
nn.Conv1d(256, 1536, kernel_size=1),
nn.Softmax(dim=2),
)
self.bn5 = nn.BatchNorm1d(3072)
self.fc6 = nn.Linear(3072, 192)
self.bn6 = nn.BatchNorm1d(192)
def forward(self, x):
x = self.featCal(x)
x = self.conv1(x)
x = self.relu(x)
x = self.bn1(x)
x1 = self.layer1(x)
x2 = self.layer2(x + x1)
x3 = self.layer3(x + x1 + x2)
x = self.layer4(torch.cat((x1, x2, x3), dim=1))
x = self.relu(x)
t = x.size()[-1]
global_x = torch.cat(
(
x,
torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
),
dim=1,
)
w = self.attention(global_x)
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
x = torch.cat((mu, sg), 1)
x = self.bn5(x)
x = self.fc6(x)
x = self.bn6(x)
return x
class SpeakerEmbedding(nn.Module):
def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = DEFAULT_DEVICE):
super().__init__()
self.device = device
with torch.device(device):
self.model = ResNet293_based()
state_dict = torch.load(ckpt_path, weights_only=True, mmap=True, map_location="cpu")
self.model.load_state_dict(state_dict)
self.model.featCal = logFbankCal()
self.requires_grad_(False).eval()
@property
def dtype(self):
return next(self.parameters()).dtype
@cache
def _get_resampler(self, orig_sample_rate: int):
return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)
def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
assert wav.ndim < 3
if wav.ndim == 2:
wav = wav.mean(0, keepdim=True)
wav = self._get_resampler(sample_rate)(wav)
return wav
def forward(self, wav: torch.Tensor, sample_rate: int):
wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
return self.model(wav).to(wav.device)
class SpeakerEmbeddingLDA(nn.Module):
def __init__(self, device: str = DEFAULT_DEVICE):
super().__init__()
spk_model_path = hf_hub_download(
repo_id="Zyphra/Zonos-v0.1-speaker-embedding",
filename="ResNet293_SimAM_ASP_base.pt",
)
lda_spk_model_path = hf_hub_download(
repo_id="Zyphra/Zonos-v0.1-speaker-embedding",
filename="ResNet293_SimAM_ASP_base_LDA-128.pt",
)
self.device = device
with torch.device(device):
self.model = SpeakerEmbedding(spk_model_path, device)
lda_sd = torch.load(lda_spk_model_path, weights_only=True)
out_features, in_features = lda_sd["weight"].shape
self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
self.lda.load_state_dict(lda_sd)
self.requires_grad_(False).eval()
def forward(self, wav: torch.Tensor, sample_rate: int):
emb = self.model(wav, sample_rate).to(torch.float32)
return emb, self.lda(emb)
================================================
FILE: zonos/utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
def find_multiple(n: int, k: int) -> int:
if k == 0 or n % k == 0:
return n
return n + k - (n % k)
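# Illustrative sketch (not part of the library API): `find_multiple` rounds `n` up
# to the next multiple of `k`; the number of extra entries needed to get there is
# `(-n) % k`. The helper name `_entries_to_pad` is an assumption for this example.
def _entries_to_pad(n: int, k: int) -> int:
    """Return how many entries must be appended to reach the next multiple of `k`."""
    if k == 0:
        return 0
    return -n % k
# Example: a vocabulary of 1025 entries needs 7 more to reach 1032, a multiple of 8.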
def pad_weight_(w: nn.Embedding | nn.Linear, multiple: int):
    """Pad the weight of an embedding or linear layer to a multiple of `multiple`."""
    if isinstance(w, nn.Embedding):
        # Pad input dim (number of embeddings, i.e. rows of the weight matrix)
        pad = -w.weight.shape[0] % multiple
        if pad == 0:
            return
        w.weight.data = F.pad(w.weight.data, (0, 0, 0, pad))
        w.num_embeddings, w.embedding_dim = w.weight.shape
    elif isinstance(w, nn.Linear):
        # Pad output dim (rows of the weight matrix)
        pad = -w.weight.shape[0] % multiple
        if pad == 0:
            return
        w.weight.data = F.pad(w.weight.data, (0, 0, 0, pad))
        w.out_features, w.in_features = w.weight.shape
    else:
        raise ValueError(f"Unsupported weight type: {type(w)}")
def get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device(torch.cuda.current_device())
# MPS breaks for whatever reason. Uncomment when it's working.
# if torch.mps.is_available():
# return torch.device("mps")
return torch.device("cpu")
DEFAULT_DEVICE = get_device()