Repository: m-bain/whisperX
Branch: main
Commit: 646f511e6bb4
Files: 28
Total size: 145.8 KB
Directory structure:
gitextract_hiro3isg/
├── .github/
│ ├── FUNDING.yml
│ └── workflows/
│ ├── build-and-release.yml
│ └── python-compatibility.yml
├── .gitignore
├── .python-version
├── CUDNN_TROUBLESHOOTING.md
├── EXAMPLES.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── pyproject.toml
└── whisperx/
├── SubtitlesProcessor.py
├── __init__.py
├── __main__.py
├── alignment.py
├── asr.py
├── assets/
│ └── mel_filters.npz
├── audio.py
├── conjunctions.py
├── diarize.py
├── log_utils.py
├── schema.py
├── transcribe.py
├── utils.py
└── vads/
├── __init__.py
├── pyannote.py
├── silero.py
└── vad.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/FUNDING.yml
================================================
custom: https://www.buymeacoffee.com/maxhbain
================================================
FILE: .github/workflows/build-and-release.yml
================================================
# Builds the package with uv when a GitHub release is published, attaches the
# built wheel(s) to that release, and publishes the package with `uv publish`.
name: Build and release

on:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest
    # The "Release to Github" step uploads files to the release, which needs
    # write access to repository contents. Declaring it here keeps the job
    # independent of the repository's default GITHUB_TOKEN permissions.
    permissions:
      contents: write

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          # Quoted so the values stay strings ("3.10" must not become 3.1).
          version: "0.5.14"
          python-version: "3.10"

      # Stop before building if the lockfile check fails.
      - name: Check if lockfile is up to date
        run: uv lock --check

      - name: Build package
        run: uv build

      - name: Release to Github
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*.whl

      - name: Publish package to PyPi
        run: uv publish
        env:
          # Token is read from the repository secret named PYPI_API_TOKEN.
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .github/workflows/python-compatibility.yml
================================================
# Import smoke test: for every Python version in the matrix, install the
# project with uv and check that `import whisperx` succeeds.
name: Python Compatibility Test

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  # Allows manual triggering from the GitHub UI.
  workflow_dispatch:

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so each version stays a string ("3.10" must not become 3.1).
        python-version:
          - "3.10"
          - "3.11"
          - "3.12"
          - "3.13"
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.5.14"
          python-version: ${{ matrix.python-version }}

      - name: Check if lockfile is up to date
        run: uv lock --check

      - name: Install the project
        run: uv sync --all-extras

      - name: Test import
        run: |
          uv run python -c "import whisperx; print('Successfully imported whisperx')"
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
================================================
FILE: .python-version
================================================
3.10
================================================
FILE: CUDNN_TROUBLESHOOTING.md
================================================
# Troubleshooting cuDNN Loading Errors
This guide helps resolve common cuDNN-related errors when running WhisperX on GPU. These issues typically occur when the system can't locate cuDNN libraries or finds conflicting versions.
## Unable to Load cuDNN Libraries
If you encounter the following error when running WhisperX:
`Unable to load any of {libcudnn_cnn.so.9.1.0, libcudnn_cnn.so.9.1, libcudnn_cnn.so.9, libcudnn_cnn.so}`
This means the cuDNN libraries are installed (via whisperx dependencies) but aren't in a location where the system's dynamic linker can find them.
### Solution 1: Add to LD_LIBRARY_PATH (Recommended)
Add this at the start of your Python script or notebook:
```python
import os
# Get current LD_LIBRARY_PATH
original = os.environ.get("LD_LIBRARY_PATH", "")
cudnn_path = "/usr/local/lib/python3.12/dist-packages/nvidia/cudnn/lib/"
os.environ['LD_LIBRARY_PATH'] = original + ":" + cudnn_path
```
**Note:** Adjust the Python version (`python3.12`) to match your environment.
### Solution 2: Symlink to LD_LIBRARY_PATH Directory
If Solution 1 didn't work and you still get the "unable to load" error, symlink the libraries to a directory that's already in your `LD_LIBRARY_PATH`:
1. Check what's in your LD_LIBRARY_PATH: `echo "$LD_LIBRARY_PATH"`
2. Assuming only one path is set, symlink the downloaded libcudnn files to that path:
`ln -s /usr/local/lib/python3.12/dist-packages/nvidia/cudnn/lib/libcudnn* "$LD_LIBRARY_PATH"/`
**Note:** If `LD_LIBRARY_PATH` contains multiple paths (separated by `:`), pick one directory and use it instead of `"$LD_LIBRARY_PATH"`. For example: `/usr/lib/x86_64-linux-gnu/`
## cuDNN Version Incompatibility
If you encounter this error:
```
RuntimeError: cuDNN version incompatibility: PyTorch was compiled against (9, 10, 2) but found runtime version (9, 2, 1)
```
This means PyTorch is finding a different cuDNN version than the one it was compiled with. **PyTorch comes bundled with its own cuDNN**, but a conflicting cuDNN in `LD_LIBRARY_PATH` is taking precedence.
### Solution: Remove Conflicting cuDNN from Path
Check if there's a conflicting cuDNN path:
```bash
echo $LD_LIBRARY_PATH
```
If you see paths pointing to older cuDNN installations (e.g., system-installed cuDNN or manually downloaded), try one of these:
**Option 1: Clear LD_LIBRARY_PATH temporarily**
```python
import os
# Let PyTorch use its bundled cuDNN
os.environ.pop('LD_LIBRARY_PATH', None)
```
**Option 2: Set LD_LIBRARY_PATH to only the correct version**
```python
import os
# Point only to the cuDNN that matches PyTorch's compiled version
os.environ['LD_LIBRARY_PATH'] = "/usr/local/lib/python3.12/dist-packages/nvidia/cudnn/lib/"
```
**Note:** This error is unlikely on a clean install. If it occurs anyway, [open an issue](https://github.com/m-bain/whisperX/issues). If you've modified system libraries or CUDA/cuDNN, the options above should help resolve most cases.
================================================
FILE: EXAMPLES.md
================================================
# More Examples
## Other Languages
For non-English ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).
Default models are currently provided and tested for {en, fr, de, es, it, ja, zh, nl}.
If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
### French
whisperx --model large --language fr examples/sample_fr_01.wav
https://user-images.githubusercontent.com/36994049/208298804-31c49d6f-6787-444e-a53f-e93c52706752.mov
### German
whisperx --model large --language de examples/sample_de_01.wav
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
### Italian
whisperx --model large --language it examples/sample_it_01.wav
https://user-images.githubusercontent.com/36994049/208298819-6f462b2c-8cae-4c54-b8e1-90855794efc7.mov
### Japanese
whisperx --model large --language ja examples/sample_ja_01.wav
https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-809d-aaf3cd7e06f4.mov
================================================
FILE: LICENSE
================================================
BSD 2-Clause License
Copyright (c) 2024, Max Bain
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: MANIFEST.in
================================================
include whisperx/assets/*
include LICENSE
include requirements.txt
================================================
FILE: README.md
================================================
<h1 align="center">WhisperX</h1>
## Recall.ai - Meeting Transcription API
If you’re looking for a transcription API for meetings, consider checking out [Recall.ai's Meeting Transcription API](https://www.recall.ai/product/meeting-transcription-api?utm_source=github&utm_medium=sponsorship&utm_campaign=mbain-whisperx), an API that works with Zoom, Google Meet, Microsoft Teams, and more. Recall.ai diarizes by pulling the speaker data and separate audio streams from the meeting platforms, which means 100% accurate speaker diarization with actual speaker names.
<p align="center">
<a href="https://github.com/m-bain/whisperX/stargazers">
<img src="https://img.shields.io/github/stars/m-bain/whisperX.svg?colorA=orange&colorB=orange&logo=github"
alt="GitHub stars">
</a>
<a href="https://github.com/m-bain/whisperX/issues">
<img src="https://img.shields.io/github/issues/m-bain/whisperx.svg"
alt="GitHub issues">
</a>
<a href="https://github.com/m-bain/whisperX/blob/main/LICENSE">
<img src="https://img.shields.io/github/license/m-bain/whisperX.svg"
alt="GitHub license">
</a>
<a href="https://arxiv.org/abs/2303.00747">
<img src="http://img.shields.io/badge/Arxiv-2303.00747-B31B1B.svg"
alt="ArXiv paper">
</a>
<a href="https://twitter.com/intent/tweet?text=&url=https%3A%2F%2Fgithub.com%2Fm-bain%2FwhisperX">
<img src="https://img.shields.io/twitter/url/https/github.com/m-bain/whisperX.svg?style=social" alt="Twitter">
</a>
</p>
<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">
<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->
This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
- 🎯 Accurate word-level timestamps using wav2vec2 alignment
- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produce highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.
**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
**Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone level segmentation.
**Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.
**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
<h2 align="left" id="highlights">New🚨</h2>
- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
- _WhisperX_ accepted at INTERSPEECH 2023
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with \*60-70x REAL TIME speed.
<h2 align="left" id="setup">Setup ⚙️</h2>
### 0. CUDA Installation
To use WhisperX with GPU acceleration, install the CUDA toolkit 12.8 before WhisperX. Skip this step if using only the CPU.
- For **Linux** users, install the CUDA toolkit 12.8 following this guide:
[CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/).
- For **Windows** users, download and install the CUDA toolkit 12.8:
[CUDA Downloads](https://developer.nvidia.com/cuda-12-8-1-download-archive).
### 1. Simple Installation (Recommended)
The easiest way to install WhisperX is through PyPi:
```bash
pip install whisperx
```
Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):
```bash
uvx whisperx
```
### 2. Advanced Installation Options
These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
#### Option A: Install from GitHub
To install directly from the GitHub repository:
```bash
uvx git+https://github.com/m-bain/whisperX.git
```
#### Option B: Developer Installation
If you want to modify the code or contribute to the project:
```bash
git clone https://github.com/m-bain/whisperX.git
cd whisperX
uv sync --all-extras --dev
```
> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
### Speaker Diarization
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the [speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) model.
<h2 align="left" id="example">Usage 💬 (command line)</h2>
### English
Run whisper on an example segment (using default params, whisper small). Add `--highlight_words True` to visualise word timings in the .srt file.
whisperx path/to/audio.wav
Result using _WhisperX_ with forced alignment to wav2vec2.0 large:
https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
Compare this to original whisper out the box, where many transcriptions are out of sync:
https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
To run on CPU instead of GPU (and for running on Mac OS X):
whisperx path/to/audio.wav --compute_type int8 --device cpu
### Other languages
The phoneme ASR alignment model is _language-specific_, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
Just pass in the `--language` code, and use the whisper `--model large`.
Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
#### E.g. German
whisperx --model large-v2 --language de path/to/audio.wav
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
See more examples in other languages [here](EXAMPLES.md).
## Python usage 🐍
```python
import whisperx
import gc
from whisperx.diarize import DiarizationPipeline
device = "cuda"
audio_file = "audio.mp3"
batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
# 1. Transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)
# save model to local path (optional)
# model_dir = "/path/"
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment
# delete model if low on GPU resources
# import gc; import torch; del model; gc.collect(); torch.cuda.empty_cache()
# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
print(result["segments"]) # after alignment
# delete model if low on GPU resources
# import gc; import torch; del model_a; gc.collect(); torch.cuda.empty_cache()
# 3. Assign speaker labels
diarize_model = DiarizationPipeline(token=YOUR_HF_TOKEN, device=device)
# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```
## Demos 🚀
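- [victor-upmeet/whisperx on Replicate](https://replicate.com/victor-upmeet/whisperx)
- [daanelson/whisperx on Replicate](https://replicate.com/daanelson/whisperx)
- [carnifexer/whisperx on Replicate](https://replicate.com/carnifexer/whisperx)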
If you don't have access to your own GPUs, use the links above to try out WhisperX.
<h2 align="left" id="whisper-mod">Technical Details 👷♂️</h2>
For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).
To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):
1. reduce batch size, e.g. `--batch_size 4`
2. use a smaller ASR model `--model base`
3. Use lighter compute type `--compute_type int8`
Transcription differences from openai's whisper:
1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies with the default whisper output.
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
<h2 align="left" id="limitations">Limitations ⚠️</h2>
- Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
- Overlapping speech is not handled particularly well by whisper nor whisperx
- Diarization is far from perfect
- Language specific wav2vec2 model is needed
<h2 align="left" id="contribute">Contribute 🧑🏫</h2>
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope.
<h2 align="left" id="coming-soon">TODO 🗓</h2>
- [x] Multilingual init
- [x] Automatic align model selection based on language detection
- [x] Python usage
- [x] Incorporating speaker diarization
- [x] Model flush, for low gpu mem resources
- [x] Faster-whisper backend
- [x] Add max-line etc. see (openai's whisper utils.py)
- [x] Sentence-level segments (nltk toolbox)
- [x] Improve alignment logic
- [ ] update examples with diarization and word highlighting
- [ ] Subtitle .ass output <- bring this back (removed in v3)
- [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
- [x] Allow silero-vad as alternative VAD option
- [ ] Improve diarization (word level). _Harder than first thought..._
<h2 align="left" id="contact">Contact/Support 📇</h2>
Contact maxhbain@gmail.com for queries.
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
<h2 align="left" id="acks">Acknowledgements 🙏</h2>
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
Of course, this builds on [openAI's whisper](https://github.com/openai/whisper).
Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
Valuable VAD & Diarization Models from:
- [pyannote-audio](https://github.com/pyannote/pyannote-audio) — Speaker diarization powered by the [speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) model, licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) by [pyannoteAI](https://www.pyannote.ai)
- [silero-vad](https://github.com/snakers4/silero-vad)
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏
Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.
<h2 align="left" id="cite">Citation</h2>
If you use this in your research, please cite the paper:
```bibtex
@article{bain2022whisperx,
title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
journal={INTERSPEECH 2023},
year={2023}
}
```
================================================
FILE: pyproject.toml
================================================
[project]
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.8.2"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.10, <3.14"
license = { text = "BSD-2-Clause" }
dependencies = [
"ctranslate2>=4.5.0",
"faster-whisper>=1.1.1",
"nltk>=3.9.1",
"numpy>=2.1.0",
"omegaconf>=2.3.0",
"pandas>=2.2.3",
"pyannote-audio>=4.0.0",
"huggingface-hub<1.0.0",
"torch~=2.8.0",
"torchaudio~=2.8.0",
"transformers>=4.48.0",
"triton>=3.3.0; sys_platform == 'linux' and platform_machine == 'x86_64'" # only install triton on x86_64 Linux
]
[project.scripts]
whisperx = "whisperx.__main__:cli"
[build-system]
requires = ["setuptools"]
[tool.setuptools]
include-package-data = true
[tool.setuptools.packages.find]
where = ["."]
include = ["whisperx*"]
# torchcodec (transitive dep of pyannote-audio >=4) has no wheels for Linux aarch64
[tool.uv]
override-dependencies = [
"torchcodec>=0.6.0; (sys_platform == 'linux' and platform_machine == 'x86_64') or sys_platform == 'darwin' or sys_platform == 'win32'",
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" },
]
triton = [
{ index = "pytorch", marker = "sys_platform == 'linux'" },
]
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
================================================
FILE: whisperx/SubtitlesProcessor.py
================================================
import math
from whisperx.conjunctions import get_conjunctions, get_comma
def normal_round(n):
    """Round ``n`` to an integer, sending a fractional part of exactly 0.5 upwards.

    The fractional part is measured from ``math.floor(n)``: anything below 0.5
    yields the floor, everything else yields ``math.ceil(n)``.
    """
    lower = math.floor(n)
    return lower if n - lower < 0.5 else math.ceil(n)
def format_timestamp(seconds: float, is_vtt: bool = False):
    """Render a non-negative duration in seconds as ``HH:MM:SS<sep>mmm``.

    Args:
        seconds: Duration to format; must be >= 0 (checked with ``assert``).
        is_vtt: When true the millisecond separator is ``.``, otherwise ``,``.

    Returns:
        The formatted string. Hours are always present and zero-padded to at
        least two digits.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    # Work in whole milliseconds so the split below is exact integer maths.
    remaining_ms = round(seconds * 1000.0)
    hours, remaining_ms = divmod(remaining_ms, 3_600_000)
    minutes, remaining_ms = divmod(remaining_ms, 60_000)
    whole_seconds, millis = divmod(remaining_ms, 1_000)

    separator = '.' if is_vtt else ','
    return f"{hours:02d}:{minutes:02d}:{whole_seconds:02d}{separator}{millis:03d}"
class SubtitlesProcessor:
def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
    """Store the segments and the line-splitting settings for ``lang``.

    Args:
        segments: Segments to turn into subtitles; kept as ``self.segments``.
        lang: Language code, passed to ``get_comma`` and ``get_conjunctions``.
        max_line_length: Character count at which a line is considered full.
        min_char_length_splitter: Minimum character count required on a side
            of a split before the split is accepted.
        is_vtt: Stored as ``self.is_vtt``.

    For the language codes listed below, the two length arguments are
    ignored and replaced by 30 and 20 respectively.
    """
    self.comma = get_comma(lang)
    self.conjunctions = set(get_conjunctions(lang))
    self.segments = segments
    self.lang = lang
    self.is_vtt = is_vtt

    # Languages that get the tighter fixed limits regardless of the arguments.
    complex_script_languages = (
        'th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta',
        'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka',
    )
    if lang in complex_script_languages:
        max_line_length, min_char_length_splitter = 30, 20

    self.max_line_length = max_line_length
    self.min_char_length_splitter = min_char_length_splitter
def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
    """Fill in ``words[i]['start']`` and ``words[i]['end']`` from neighbouring timings.

    The word dict is modified in place; nothing is returned. Sources are
    tried in this order:

    1. The previous word's ``end`` becomes this word's ``start``. The ``end``
       is then the next word's ``start``, or a value derived from
       ``next_segment_start_time``, or ``start + len(word) * k``.
    2. Otherwise the next word's ``start`` becomes this word's ``end`` and the
       ``start`` is back-computed as ``end - len(word) * k``.
    3. Otherwise the word is placed just before ``next_segment_start_time``.
    4. Otherwise both values are set to 0.

    Back-computed values are clamped at 0, because ``format_timestamp``
    rejects negative timestamps.

    Args:
        words: List of word dicts; ``words[i]`` must contain ``'word'`` when a
            length-based estimate is needed.
        i: Index of the word to fill in.
        next_segment_start_time: Start of the following segment, or ``None``.
            Any falsy value (including 0) is treated as "not available".
    """
    k = 0.25  # estimated duration contributed by each character of the word
    word = words[i]
    has_prev_end = i > 0 and 'end' in words[i - 1]
    has_next_start = i < len(words) - 1 and 'start' in words[i + 1]

    if has_prev_end:
        prev_end = words[i - 1]['end']
        word['start'] = prev_end
        if has_next_start:
            word['end'] = words[i + 1]['start']
        elif next_segment_start_time:
            # Run up to the next segment when it is at most 1 away from the
            # previous word's end; otherwise stop 0.5 before it.
            if next_segment_start_time - prev_end <= 1:
                word['end'] = next_segment_start_time
            else:
                word['end'] = next_segment_start_time - 0.5
        else:
            word['end'] = word['start'] + len(word['word']) * k
    elif has_next_start:
        next_start = words[i + 1]['start']
        # Clamp: a long word right at the beginning would otherwise get a
        # negative start.
        word['start'] = max(0, next_start - len(word['word']) * k)
        word['end'] = next_start
    elif next_segment_start_time:
        # Clamp both ends so a next segment starting before 1 cannot push the
        # estimate below zero, and keep end >= start.
        word['start'] = max(0, next_segment_start_time - 1)
        word['end'] = max(word['start'], next_segment_start_time - 0.5)
    else:
        word['start'] = 0
        word['end'] = 0
def process_segments(self, advanced_splitting=True):
subtitles = []
for i, segment in enumerate(self.segments):
next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None
if advanced_splitting:
split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
else:
words = segment['words']
for i, word in enumerate(words):
if 'start' not in word or 'end' not in word:
self.estimate_timestamp_for_word(words, i, next_segment_start_time)
subtitles.append({
'start': segment['start'],
'end': segment['end'],
'text': segment['text']
})
return subtitles
def determine_advanced_split_points(self, segment, next_segment_start_time=None):
    """
    Choose the word indices after which a segment should be cut into cues.

    Words come from ``segment['words']`` (dicts with a 'word' key) or, when
    that key is absent, from whitespace-splitting ``segment['text']``.
    Character counts include one separator per word except for 'zh'/'ja'.
    A split is recorded when:
      * the running line reaches ``self.max_line_length`` (cut at the
        midpoint between the last split and the current word), or
      * the word ends with ``self.comma``, or
      * the word is in ``self.conjunctions`` (cut before it),
    with the comma/conjunction rules requiring at least
    ``self.min_char_length_splitter`` characters on both sides.

    Word dicts lacking 'start'/'end' get estimated timestamps in place.

    Args:
        segment: Segment dict with 'text' and optionally 'words'.
        next_segment_start_time: Start time of the following segment, if any.

    Returns:
        List of word indices; each is the last word of a cue.
    """
    split_points = []
    last_split_point = 0
    char_count = 0

    words = segment.get('words', segment['text'].split())
    add_space = 0 if self.lang in ['zh', 'ja'] else 1

    def padded_length(word):
        # Length of one word plus its separator. The original inline
        # conditional only added the separator for plain-string words
        # (operator precedence), so dict words were under-counted in the
        # totals while being fully counted in the running sum.
        text = word['word'] if isinstance(word, dict) else word
        return len(text) + add_space

    char_count_after = sum(padded_length(word) for word in words)

    for i, word in enumerate(words):
        word_text = word['word'] if isinstance(word, dict) else word
        word_length = padded_length(word)
        char_count += word_length
        char_count_after -= word_length

        char_count_before = char_count - word_length

        if isinstance(word, dict) and ('start' not in word or 'end' not in word):
            self.estimate_timestamp_for_word(words, i, next_segment_start_time)

        if char_count >= self.max_line_length:
            midpoint = normal_round((last_split_point + i) / 2)
            if char_count_before >= self.min_char_length_splitter:
                split_points.append(midpoint)
                last_split_point = midpoint + 1
                # Restart the running count from the words already past the cut.
                char_count = sum(padded_length(words[j]) for j in range(last_split_point, i + 1))

        elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
            split_points.append(i)
            last_split_point = i + 1
            char_count = 0

        elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
            # The conjunction opens the next cue, so cut after the previous word.
            split_points.append(i - 1)
            last_split_point = i
            char_count = word_length

    return split_points
def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
    """
    Build cue dicts for one segment from previously chosen split points.

    Each split point is the index of the last word in a cue. For word dicts
    the cue takes the first word's 'start' and the last word's 'end'; for
    plain-string words the segment duration is shared out in proportion to
    the word count. A cue's end is stretched to the next start time when
    the gap is at most 0.8 s.

    Args:
        segment: Segment dict with 'start', 'end', 'text', optionally 'words'.
        split_points: Word indices after which to cut.
        next_start_time: Start time of the following segment, if any.

    Returns:
        List of dicts with 'start', 'end' and 'text'.
    """
    words = segment.get('words', segment['text'].split())
    word_total = len(words)
    segment_duration = segment['end'] - segment['start']
    clock = segment['start']
    joiner = '' if self.lang in ['zh', 'ja'] else ' '

    def render(chunk):
        # Dict words are joined as-is; plain strings are joined then stripped.
        if isinstance(chunk[0], dict):
            return joiner.join(entry['word'] for entry in chunk)
        return joiner.join(chunk).strip()

    cues = []
    cursor = 0
    for split_point in split_points:
        chunk = words[cursor:split_point + 1]

        if isinstance(chunk[0], dict):
            cue_start = chunk[0]['start']
            cue_end = chunk[-1]['end']
            if split_point + 1 < len(words):
                upcoming_start = words[split_point + 1]['start']
                # Close small gaps so the cue stays up until the next one begins.
                if upcoming_start and (upcoming_start - cue_end) <= 0.8:
                    cue_end = upcoming_start
        else:
            share = (len(chunk) / word_total) * segment_duration
            cue_start = clock
            cue_end = clock + share
            clock += share

        cues.append({'start': cue_start, 'end': cue_end, 'text': render(chunk)})
        cursor = split_point + 1

    # Whatever follows the final split point forms the last cue.
    if cursor < len(words):
        chunk = words[cursor:]

        if isinstance(chunk[0], dict):
            cue_start = chunk[0]['start']
            cue_end = chunk[-1]['end']
        else:
            share = (len(chunk) / word_total) * segment_duration
            cue_start = clock
            cue_end = clock + share

        if next_start_time and (next_start_time - cue_end) <= 0.8:
            cue_end = next_start_time

        if cue_end is None:
            cue_end = segment['end']
        cues.append({'start': cue_start, 'end': cue_end, 'text': render(chunk)})

    return cues
def save(self, filename="subtitles.srt", advanced_splitting=True):
    """
    Write the processed subtitles to ``filename`` and return the cue count.

    A "WEBVTT" header is written first when ``self.is_vtt`` is set; every
    cue is then written as an index line, a "start --> end" line and the
    stripped text followed by a blank line.

    Args:
        filename: Output path (opened for writing as UTF-8).
        advanced_splitting: Forwarded to ``process_segments``.

    Returns:
        Number of cues written.
    """
    subtitles = self.process_segments(advanced_splitting)

    def write_subtitle(file, idx, start_time, end_time, text):
        file.write(f"{idx}\n")
        file.write(f"{start_time} --> {end_time}\n")
        file.write(text + "\n\n")

    with open(filename, 'w', encoding='utf-8') as file:
        if self.is_vtt:
            file.write("WEBVTT\n\n")

        # Cues are written in both modes. Previously nothing was written
        # when advanced_splitting was False, yet the cue count was still
        # returned as if the file had been populated.
        for idx, subtitle in enumerate(subtitles, 1):
            start_time = format_timestamp(subtitle['start'], self.is_vtt)
            end_time = format_timestamp(subtitle['end'], self.is_vtt)
            text = subtitle['text'].strip()
            write_subtitle(file, idx, start_time, end_time, text)

    return len(subtitles)
================================================
FILE: whisperx/__init__.py
================================================
import importlib
def _lazy_import(name):
module = importlib.import_module(f"whisperx.{name}")
return module
def load_align_model(*args, **kwargs):
    """Forward to ``whisperx.alignment.load_align_model``, importing the module on first use."""
    return _lazy_import("alignment").load_align_model(*args, **kwargs)
def align(*args, **kwargs):
    """Forward to ``whisperx.alignment.align``, importing the module on first use."""
    return _lazy_import("alignment").align(*args, **kwargs)
def load_model(*args, **kwargs):
    """Forward to ``whisperx.asr.load_model``, importing the module on first use."""
    return _lazy_import("asr").load_model(*args, **kwargs)
def load_audio(*args, **kwargs):
    """Forward to ``whisperx.audio.load_audio``, importing the module on first use."""
    return _lazy_import("audio").load_audio(*args, **kwargs)
def assign_word_speakers(*args, **kwargs):
    """Forward to ``whisperx.diarize.assign_word_speakers``, importing the module on first use."""
    return _lazy_import("diarize").assign_word_speakers(*args, **kwargs)
def setup_logging(*args, **kwargs):
    """
    Set up WhisperX logging by delegating to ``whisperx.log_utils.setup_logging``.

    Args:
        level: Logging level (debug, info, warning, error, critical). Default: warning
        log_file: Optional path to a log file. If None, logs only go to the console.
    """
    return _lazy_import("log_utils").setup_logging(*args, **kwargs)
def get_logger(*args, **kwargs):
    """
    Fetch a logger by delegating to ``whisperx.log_utils.get_logger``.

    Args:
        name: Logger name (typically __name__ of the calling module)

    Returns:
        Logger instance configured with WhisperX settings
    """
    return _lazy_import("log_utils").get_logger(*args, **kwargs)
================================================
FILE: whisperx/__main__.py
================================================
import argparse
import importlib.metadata
import platform
import torch
from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
optional_int, str2bool)
from whisperx.log_utils import setup_logging
def cli():
    """
    Command-line entry point.

    Builds the argument parser, parses ``sys.argv``, configures logging
    (``--log-level`` takes precedence over ``--verbose``) and hands the
    parsed arguments to ``whisperx.transcribe.transcribe_task``.
    """
    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", help="name of the Whisper model to use")
    parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device type to use for PyTorch inference (e.g. cpu, cuda)")
    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
    parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
    parser.add_argument("--compute_type", default="default", type=str, choices=["default", "float16", "float32", "int8"], help="compute type for computation; 'default' uses float16 on GPU, float32 on CPU")

    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
    parser.add_argument("--log-level", type=str, default=None, choices=["debug", "info", "warning", "error", "critical"], help="logging level (overrides --verbose if set)")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

    # alignment params
    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
    parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
    parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")

    # vad params
    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
    parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")

    # diarization params
    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
    parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers in audio file")
    parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers in audio file")
    parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-community-1", type=str, help="Name of the speaker diarization model to use")
    parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)")

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")

    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--hotwords", type=str, default=None, help="hotwords/hint phrases to the model (e.g. \"WhisperX, PyAnnote, GPU\"); improves recognition of rare/technical terms")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")

    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
    # The help text here used to be a copy of --max_line_width's.
    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the resolution of the output segments: 'sentence' or 'chunk'")

    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")

    parser.add_argument("--print_progress", type=str2bool, default=False, help="if True, progress will be printed in transcribe() and align() methods.")
    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}", help="Show whisperx version information and exit")
    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})", help="Show python version information and exit")
    # fmt: on

    args = vars(parser.parse_args())

    # An explicit --log-level wins; otherwise --verbose selects info vs. warning.
    log_level = args.get("log_level")
    verbose = args.get("verbose")

    if log_level is not None:
        setup_logging(level=log_level)
    elif verbose:
        setup_logging(level="info")
    else:
        setup_logging(level="warning")

    # Imported only after logging is configured.
    from whisperx.transcribe import transcribe_task

    transcribe_task(args, parser)
# Run the command-line interface only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    cli()
================================================
FILE: whisperx/alignment.py
================================================
"""
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
import numpy as np
import pandas as pd
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from whisperx.audio import SAMPLE_RATE, load_audio
from whisperx.utils import interpolate_nans, PUNKT_LANGUAGES
from whisperx.schema import (
AlignedTranscriptionResult,
SingleSegment,
SingleAlignedSegment,
SingleWordSegment,
SegmentData,
ProgressCallback,
)
import nltk
from nltk.data import load as nltk_load
from whisperx.log_utils import get_logger
logger = get_logger(__name__)
# Language codes whose text is not split on spaces during alignment; for these
# languages align() treats every character as its own word.
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
# Default alignment model per language code, given as torchaudio pipeline
# bundle names (load_align_model looks these up in torchaudio.pipelines).
DEFAULT_ALIGN_MODELS_TORCH = {
"en": "WAV2VEC2_ASR_BASE_960H",
"fr": "VOXPOPULI_ASR_BASE_10K_FR",
"de": "VOXPOPULI_ASR_BASE_10K_DE",
"es": "VOXPOPULI_ASR_BASE_10K_ES",
"it": "VOXPOPULI_ASR_BASE_10K_IT",
}
# Default alignment model per language code, given as Hugging Face model ids
# (load_align_model loads these with Wav2Vec2Processor / Wav2Vec2ForCTC).
# Consulted only when the language has no entry in DEFAULT_ALIGN_MODELS_TORCH.
DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
"cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
"fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
"vi": 'nguyenvulebinh/wav2vec2-base-vi-vlsp2020',
"ko": "kresnik/wav2vec2-large-xlsr-korean",
"ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
"te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
"hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
"ca": "softcatala/wav2vec2-large-xlsr-catala",
"ml": "gvs/wav2vec2-large-xlsr-malayalam",
"no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
"nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
"sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
"sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
"hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
"ro": "gigant/romanian-wav2vec2",
"eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
"gl": "ifrz/wav2vec2-large-xlsr-galician",
"ka": "xsway/wav2vec2-large-xlsr-georgian",
"lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
"tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
"sv": "KBLab/wav2vec2-large-voxrex-swedish",
}
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None, model_cache_only: bool = False):
    """
    Load a wav2vec2 alignment model and the metadata ``align`` needs.

    Args:
        language_code: Language code used to pick a default model and
            recorded in the returned metadata.
        device: Device the model is moved to.
        model_name: torchaudio pipeline name or Hugging Face model id. When
            None, the default for ``language_code`` is used.
        model_dir: Download/cache directory passed to the model loader.
        model_cache_only: Passed as ``local_files_only`` to the Hugging Face
            loaders; not used for torchaudio pipelines.

    Returns:
        Tuple ``(align_model, align_metadata)`` where the metadata dict has
        keys "language", "dictionary" (lower-cased label -> id) and "type"
        ("torchaudio" or "huggingface").

    Raises:
        ValueError: If no default model exists for the language, or the
            Hugging Face model cannot be loaded.
    """
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            logger.error(f"No default alignment model for language: {language_code}. "
                         f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
                         f"then pass the model name via --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir, local_files_only=model_cache_only)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir, local_files_only=model_cache_only)
        except Exception as e:
            # Report through the module logger (as the branch above does)
            # and keep the original exception chained for debugging.
            logger.error(f"Error loading model from huggingface: {e}")
            logger.error("Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)') from e
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        vocab = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in vocab.items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
    progress_callback: ProgressCallback = None,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.

    Each transcript segment is reduced to the characters the alignment
    model knows, the model's frame-wise emissions over the segment's audio
    are computed, and a trellis/backtrack pass assigns a time span to each
    character. Characters are then grouped into words and sentences.

    Args:
        transcript: Segments with "text", "start" and "end" (``len()`` is
            taken, so it must be sized).
        model: Alignment model returned by ``load_align_model``.
        align_model_metadata: Dict with "dictionary", "language" and "type".
        audio: File path, numpy array or tensor of audio samples.
        device: Device the audio slices are moved to for inference.
        interpolate_method: Passed to ``interpolate_nans`` for sentence
            start/end values that could not be aligned.
        return_char_alignments: Also return per-character timings.
        print_progress: Print a percentage while preprocessing segments.
        combined_progress: Map printed progress into the 50-100% range.
        progress_callback: Called with a percentage after each segment that
            was aligned successfully.

    Returns:
        Dict with "segments" (aligned segments) and "word_segments" (all
        words of all segments, flattened).
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    # Store temporary processing values
    segment_data: dict[int, SegmentData] = {}
    # The Punkt sentence splitter depends only on the language, so it is
    # loaded once (on first use) instead of once per segment.
    sentence_splitter = None
    for sdx, segment in enumerate(transcript):
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        # count the spaces at beginning / end so they can be skipped below
        num_leading = len(segment["text"]) - len(segment["text"].lstrip())
        num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
        text = segment["text"]

        # split into words
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            per_word = text.split(" ")
        else:
            per_word = text

        clean_char, clean_cdx = [], []
        for cdx, char in enumerate(text):
            char_ = char.lower()
            # wav2vec2 models use "|" character to represent spaces
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                char_ = char_.replace(" ", "|")

            # ignore whitespace at beginning and end of transcript
            if cdx < num_leading:
                pass
            elif cdx > len(text) - num_trailing - 1:
                pass
            elif char_ in model_dictionary.keys():
                clean_char.append(char_)
                clean_cdx.append(cdx)

        clean_wdx = []
        for wdx, wrd in enumerate(per_word):
            if any([c in model_dictionary.keys() for c in wrd.lower()]):
                clean_wdx.append(wdx)

        if sentence_splitter is None:
            # Use language-specific Punkt model if available otherwise we fallback to English.
            punkt_lang = PUNKT_LANGUAGES.get(model_lang, 'english')
            try:
                sentence_splitter = nltk_load(f'tokenizers/punkt_tab/{punkt_lang}.pickle')
            except LookupError:
                nltk.download('punkt_tab', quiet=True)
                sentence_splitter = nltk_load(f'tokenizers/punkt_tab/{punkt_lang}.pickle')
        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment_data[sdx] = {
            "clean_char": clean_char,
            "clean_cdx": clean_cdx,
            "clean_wdx": clean_wdx,
            "sentence_spans": sentence_spans
        }

    aligned_segments: List[SingleAlignedSegment] = []

    # The blank token id depends only on the model dictionary: use the pad
    # token's id when the dictionary has one, otherwise 0.
    blank_id = 0
    for char, code in model_dictionary.items():
        if char == '[pad]' or char == '<pad>':
            blank_id = code

    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):

        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]

        avg_logprob = segment.get("avg_logprob")

        # Fallback result: the original segment without word timings.
        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
            "words": [],
            "chars": None,
        }

        if avg_logprob is not None:
            aligned_seg["avg_logprob"] = avg_logprob

        if return_char_alignments:
            aligned_seg["chars"] = []

        # check we can align
        if len(segment_data[sdx]["clean_char"]) == 0:
            logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original')
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION:
            logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping')
            aligned_segments.append(aligned_seg)
            continue

        text_clean = "".join(segment_data[sdx]["clean_char"])
        tokens = [model_dictionary[c] for c in text_clean]

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        # TODO: Probably can get some speedup gain with batched inference here
        waveform_segment = audio[:, f1:f2]
        # Handle the minimum input length for wav2vec2 models
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        trellis = get_trellis(emission, tokens, blank_id)
        path = backtrack(trellis, emission, tokens, blank_id)

        if path is None:
            logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        duration = t2 - t1
        # Seconds per trellis frame, used to turn frame indices into times.
        ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

        # Map a character's index in ``text`` to its position among the
        # cleaned characters. A dict replaces the per-character list
        # membership test and ``list.index`` scan.
        clean_positions = {cdx: pos for pos, cdx in enumerate(segment_data[sdx]["clean_cdx"])}

        # assign timestamps to aligned characters
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            if cdx in clean_positions:
                char_seg = char_segments[clean_positions[cdx]]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # increment word_idx, nltk word tokenization would probably be more robust here, but use space for now...
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx+1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
            curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != ' ']
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue

                # dont use space character for alignment
                word_chars = word_chars[word_chars["char"] != " "]

                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # keys are left out (rather than set to a sentinel) when a value is unalignable
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            subsegment = {
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            }
            if avg_logprob is not None:
                subsegment["avg_logprob"] = avg_logprob
            aligned_subsegments.append(subsegment)

            if return_char_alignments:
                # Non-inplace fillna: avoids mutating a slice of another frame.
                curr_chars = curr_chars[["char", "start", "end", "score"]].fillna(-1)
                curr_chars = curr_chars.to_dict("records")
                # -1 marks missing values; drop those keys from the output
                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
                aligned_subsegments[-1]["chars"] = curr_chars

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        if avg_logprob is not None:
            agg_dict["avg_logprob"] = "first"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict('records')
        if progress_callback is not None:
            progress_callback(((sdx + 1) / total_segments) * 100)
        aligned_segments += aligned_subsegments

    # create word_segments list
    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
def get_trellis(emission, tokens, blank_id=0):
    """
    Build the score trellis over (frame, token) positions.

    The result has one more row than ``emission`` has frames and one more
    column than there are tokens; column 0 accumulates the blank score.

    Args:
        emission: Frame-wise scores, indexed ``[frame, label]``.
        tokens: Label ids of the transcript, in order.
        blank_id: Label id treated as blank.

    Returns:
        Tensor of shape ``(num_frames + 1, num_tokens + 1)``.
    """
    n_frames = emission.size(0)
    n_tokens = len(tokens)

    # One extra row (time) and one extra column (start-of-sentence).
    grid = torch.empty((n_frames + 1, n_tokens + 1))
    grid[0, 0] = 0
    grid[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    grid[0, -n_tokens:] = -float("inf")
    grid[-n_tokens:, 0] = float("inf")

    for frame in range(n_frames):
        # Either remain on the same token (emit blank) ...
        stay_scores = grid[frame, 1:] + emission[frame, blank_id]
        # ... or advance to the next token.
        advance_scores = grid[frame, :-1] + emission[frame, tokens]
        grid[frame + 1, 1:] = torch.maximum(stay_scores, advance_scores)
    return grid
@dataclass
class Point:
    """One step of the alignment path produced by ``backtrack``."""
    # Index of the token in the token sequence (not the trellis column).
    token_index: int
    # Index of the frame in the emission matrix (not the trellis row).
    time_index: int
    # exp() of the emission value chosen for this step.
    score: float
def backtrack(trellis, emission, tokens, blank_id=0):
    """
    Walk the trellis backwards to recover the best alignment path.

    Trellis indices are offset by one against emission frames and
    transcript tokens: trellis row ``T`` corresponds to emission frame
    ``T-1`` and trellis column ``J`` to token ``J-1``.

    Args:
        trellis: Output of ``get_trellis``.
        emission: Frame-wise scores, indexed ``[frame, label]``.
        tokens: Label ids of the transcript, in order.
        blank_id: Label id treated as blank.

    Returns:
        List of ``Point`` in chronological order, or None when the walk
        reaches the first frame before the first token has been consumed.
    """
    column = trellis.size(1) - 1
    row = torch.argmax(trellis[:, column]).item()

    reversed_path = []
    while row > 0:
        frame = row - 1
        token_id = tokens[column - 1]

        # Score if the token was held over this frame (blank emitted).
        stay_score = trellis[frame, column] + emission[frame, blank_id]
        # Score if the path moved onto this token at this frame.
        move_score = trellis[frame, column - 1] + emission[frame, token_id]
        moved = bool(move_score > stay_score)

        # Record the step in emission/transcript coordinates.
        chosen_label = token_id if moved else blank_id
        reversed_path.append(Point(column - 1, frame, emission[frame, chosen_label].exp().item()))

        if moved:
            column -= 1
            if column == 0:
                # Every token consumed: done.
                return reversed_path[::-1]
        row -= 1

    # Ran out of frames with tokens still unassigned.
    return None
# Merge the labels
@dataclass
class Segment:
    """A labelled span of frames with a score, as built by ``merge_repeats`` / ``merge_words``."""
    label: str
    start: int
    end: int
    score: float

    @property
    def length(self):
        """Span size, i.e. ``end - start``."""
        return self.end - self.start

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
def merge_repeats(path, transcript):
    """
    Collapse consecutive path points that share a token into one ``Segment``.

    Args:
        path: Sequence of points with ``token_index``, ``time_index``, ``score``.
        transcript: Indexable by token index to obtain the segment label.

    Returns:
        List of ``Segment``; each spans from the first point's time index to
        one past the last point's, scored by the mean point score.
    """
    merged = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token = path[run_start].token_index

        # Extend the run while the token stays the same.
        run_end = run_start
        while run_end < total and path[run_end].token_index == token:
            run_end += 1

        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / (run_end - run_start)
        merged.append(
            Segment(
                transcript[token],
                run[0].time_index,
                run[-1].time_index + 1,
                mean_score,
            )
        )
        run_start = run_end
    return merged
def merge_words(segments, separator="|"):
    """
    Join character segments into word segments, cutting at ``separator``.

    Args:
        segments: Sequence of segments with ``label``, ``start``, ``end``,
            ``score`` and ``length``.
        separator: Label that marks a word boundary; it is not included in
            any word.

    Returns:
        List of ``Segment``, one per non-empty word, scored by the
        length-weighted mean of its characters' scores.
    """
    merged = []
    count = len(segments)
    word_begin = 0
    cursor = 0
    while word_begin < count:
        at_boundary = cursor >= count or segments[cursor].label == separator
        if not at_boundary:
            cursor += 1
            continue

        if word_begin != cursor:
            chars = segments[word_begin:cursor]
            text = "".join([seg.label for seg in chars])
            weighted_score = sum(seg.score * seg.length for seg in chars) / sum(seg.length for seg in chars)
            merged.append(Segment(text, segments[word_begin].start, segments[cursor - 1].end, weighted_score))

        # Skip past the separator and start the next word.
        word_begin = cursor + 1
        cursor = word_begin
    return merged
================================================
FILE: whisperx/asr.py
================================================
import os
from typing import List, Optional, Union
from dataclasses import replace
import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.schema import SingleSegment, TranscriptionResult, ProgressCallback
from whisperx.vads import Vad, Silero, Pyannote
from whisperx.log_utils import get_logger
logger = get_logger(__name__)
def find_numeral_symbol_tokens(tokenizer):
    """Collect ids of vocabulary tokens containing a digit or one of ``%``, ``$``, ``£``.

    Every id in ``range(tokenizer.eot)`` is decoded on its own, one leading
    space is removed, and the id is kept when any character of the decoded
    text belongs to the numeral/symbol set.

    Returns:
        The matching token ids in ascending order.
    """
    flagged_chars = "0123456789%$£"
    return [
        token_id
        for token_id in range(tokenizer.eot)
        if any(ch in flagged_chars for ch in tokenizer.decode([token_id]).removeprefix(" "))
    ]
class WhisperModel(faster_whisper.WhisperModel):
    '''
    FasterWhisperModel provides batched inference for faster-whisper.
    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
    '''

    def generate_segment_batched(
        self,
        features: np.ndarray,
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
        encoder_output=None,
    ):
        """Decode a batch of feature chunks using one shared prompt.

        Args:
            features: Batched input features; the size of the first axis is
                used as the batch size.
            tokenizer: Tokenizer used to encode the initial prompt, build the
                decoding prompt and decode the generated ids.
            options: Decoding options (beam size, penalties, suppressed tokens,
                initial prompt, prefix, hotwords, ...).
            encoder_output: Optional pre-computed encoder output for
                ``features``. When ``None`` (the default) the features are
                encoded here.

        Returns:
            Dict with ``'text'`` (list of decoded strings, one per batch item)
            and ``'avg_logprob'`` (list of per-item average log-probabilities).
        """
        batch_size = features.shape[0]

        # The only "previous" tokens are those of the optional initial prompt.
        prompt_tokens = []
        if options.initial_prompt is not None:
            initial_prompt = " " + options.initial_prompt.strip()
            prompt_tokens.extend(tokenizer.encode(initial_prompt))

        prompt = self.get_prompt(
            tokenizer,
            prompt_tokens,
            without_timestamps=options.without_timestamps,
            prefix=options.prefix,
            hotwords=options.hotwords
        )

        # Honour a caller-supplied encoding instead of always recomputing it.
        if encoder_output is None:
            encoder_output = self.encode(features)

        result = self.model.generate(
            encoder_output,
            [prompt] * batch_size,
            beam_size=options.beam_size,
            patience=options.patience,
            length_penalty=options.length_penalty,
            max_length=self.max_length,
            suppress_blank=options.suppress_blank,
            suppress_tokens=options.suppress_tokens,
            no_repeat_ngram_size=options.no_repeat_ngram_size,
            repetition_penalty=options.repetition_penalty,
            return_scores=True,
        )

        tokens_batch = [x.sequences_ids[0] for x in result]

        avg_logprobs = []
        for res in result:
            seq_len = len(res.sequences_ids[0])
            # Scale the returned score back by the length penalty, then
            # average over seq_len + 1 positions.
            cum_logprob = res.scores[0] * (seq_len ** options.length_penalty)
            avg_logprobs.append(cum_logprob / (seq_len + 1))

        def decode_batch(tokens: List[List[int]]) -> List[str]:
            # Keep only ids below the end-of-text id before decoding.
            res = []
            for tk in tokens:
                res.append([token for token in tk if token < tokenizer.eot])
            return tokenizer.tokenizer.decode_batch(res)

        text = decode_batch(tokens_batch)
        return {'text': text, 'avg_logprob': avg_logprobs}

    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
        """Run the encoder on ``features`` and return its output.

        A 2-D input is given a leading batch axis first.
        """
        # When the model is running on multiple GPUs, the encoder output should be moved
        # to the CPU since we don't know which GPU will handle the next job.
        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
        # unsqueeze if batch size = 1
        if len(features.shape) == 2:
            features = np.expand_dims(features, 0)
        features = get_ctranslate2_storage(features)
        return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
    """
    Huggingface Pipeline wrapper for FasterWhisperModel.
    """
    # TODO:
    # - add support for timestamp mode
    # - add support for custom inference kwargs

    def __init__(
        self,
        model: WhisperModel,
        vad,
        vad_params: dict,
        options: TranscriptionOptions,
        tokenizer: Optional[Tokenizer] = None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        """Store the model, VAD, decoding options and device for later calls.

        Args:
            model: Whisper model used for encoding and generation.
            vad: Voice-activity-detection callable applied in ``transcribe``.
            vad_params: Must provide ``"vad_onset"`` and ``"vad_offset"``.
            options: Decoding options handed to the model.
            tokenizer: Optional preset tokenizer; when ``None`` one is built
                per ``transcribe`` call.
            device: ``torch.device``, device string, or integer index
                (negative means CPU) when ``framework`` is ``"pt"``.
            framework: Framework tag; only ``"pt"`` triggers device conversion.
            language: Preset language, or ``None`` for per-call detection.
            suppress_numerals: Whether numeral/symbol tokens are suppressed
                during ``transcribe``.
            **kwargs: ``batch_size`` is popped as the default batch size; the
                rest goes to ``_sanitize_parameters``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.options = options
        self.preset_language = language
        self.suppress_numerals = suppress_numerals
        self._batch_size = kwargs.pop("batch_size", None)
        self._num_workers = 1
        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
        self.call_count = 0
        self.framework = framework
        if self.framework == "pt":
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
            elif device < 0:
                self.device = torch.device("cpu")
            else:
                self.device = torch.device(f"cuda:{device}")
        else:
            self.device = device

        # super(Pipeline, self) starts the MRO lookup *after* Pipeline, so
        # Pipeline.__init__ itself is not executed here.
        super(Pipeline, self).__init__()
        self.vad_model = vad
        self._vad_params = vad_params

    def _sanitize_parameters(self, **kwargs):
        """Split call kwargs into (preprocess, forward, postprocess) dicts."""
        preprocess_kwargs = {}
        # NOTE(review): the guard tests "tokenizer" but reads "maybe_arg"; this
        # raises KeyError if "tokenizer" is passed without "maybe_arg". Looks
        # like leftover template code - confirm before changing.
        if "tokenizer" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, audio):
        """Turn one ``{'inputs': waveform}`` item into log-Mel features.

        The waveform is right-padded by ``N_SAMPLES - len(waveform)`` samples;
        the mel-band count comes from the model's ``feature_size`` (80 if unset).
        """
        audio = audio['inputs']
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        features = log_mel_spectrogram(
            audio,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=N_SAMPLES - audio.shape[0],
        )
        return {'inputs': features}

    def _forward(self, model_inputs):
        """Run batched generation on the collated features."""
        outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
        return outputs

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs

    def get_iterator(
        self,
        inputs,
        num_workers: int,
        batch_size: int,
        preprocess_params: dict,
        forward_params: dict,
        postprocess_params: dict,
    ):
        """Chain preprocess -> batched forward -> postprocess as iterators."""
        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
        if "TOKENIZERS_PARALLELISM" not in os.environ:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # TODO hack by collating feature_extractor and image_processor
        def stack(items):
            # Collate per-item feature tensors into one batch tensor.
            return {'inputs': torch.stack([x['inputs'] for x in items])}

        dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
        model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
        final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
        return final_iterator

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        batch_size: Optional[int] = None,
        num_workers=0,
        language: Optional[str] = None,
        task: Optional[str] = None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
        verbose=False,
        progress_callback: ProgressCallback = None,
    ) -> TranscriptionResult:
        """Transcribe audio: VAD-segment it, then decode the segments in batches.

        Args:
            audio: Path to an audio file (loaded via ``load_audio``) or a waveform array.
            batch_size: Batch size for inference; falls back to the value given
                at construction time.
            num_workers: Forwarded to the pipeline call.
            language: Language code; when omitted the tokenizer's language is
                used, or the language is detected if there is no tokenizer.
            task: Task name; defaults to the tokenizer's task or ``"transcribe"``.
            chunk_size: Passed to the VAD ``merge_chunks`` step.
            print_progress: Print a percentage after every processed segment.
            combined_progress: Halve the printed percentage.
            verbose: Print each segment's time range and text.
            progress_callback: Optional callable receiving the percentage
                (0-100) after every processed segment.

        Returns:
            Dict with ``"segments"`` (text, start, end, avg_logprob per VAD
            segment) and ``"language"``.
        """
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            # Slice the waveform per segment; start/end are seconds.
            for seg in segments:
                f1 = int(seg['start'] * SAMPLE_RATE)
                f2 = int(seg['end'] * SAMPLE_RATE)
                yield {'inputs': audio[f1:f2]}

        # Pre-process audio and merge chunks as defined by the respective VAD child class
        # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
        if issubclass(type(self.vad_model), Vad):
            waveform = self.vad_model.preprocess_audio(audio)
            merge_chunks = self.vad_model.merge_chunks
        else:
            waveform = Pyannote.preprocess_audio(audio)
            merge_chunks = Pyannote.merge_chunks

        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
        vad_segments = merge_chunks(
            vad_segments,
            chunk_size,
            onset=self._vad_params["vad_onset"],
            offset=self._vad_params["vad_offset"],
        )

        if self.tokenizer is None:
            language = language or self.detect_language(audio)
            task = task or "transcribe"
            self.tokenizer = Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        previous_suppress_tokens = self.options.suppress_tokens
        segments: List[SingleSegment] = []
        # try/finally so per-call state (tokenizer, suppressed tokens) is also
        # restored when inference raises; otherwise a failed call would leave
        # the detected-language tokenizer in place for the next call.
        try:
            if self.suppress_numerals:
                numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
                logger.info("Suppressing numeral and symbol tokens")
                new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
                new_suppressed_tokens = list(set(new_suppressed_tokens))
                self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)

            batch_size = batch_size or self._batch_size
            total_segments = len(vad_segments)
            for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
                if print_progress:
                    base_progress = ((idx + 1) / total_segments) * 100
                    percent_complete = base_progress / 2 if combined_progress else base_progress
                    print(f"Progress: {percent_complete:.2f}%...")
                if progress_callback is not None:
                    progress_callback(((idx + 1) / total_segments) * 100)

                text = out['text']
                avg_logprob = out['avg_logprob']
                # Without real batching the outputs are still one-element lists.
                if batch_size in [0, 1, None]:
                    text = text[0]
                    avg_logprob = avg_logprob[0]
                if verbose:
                    print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
                segments.append(
                    {
                        "text": text,
                        "start": round(vad_segments[idx]['start'], 3),
                        "end": round(vad_segments[idx]['end'], 3),
                        "avg_logprob": avg_logprob,
                    }
                )
        finally:
            # revert the tokenizer if multilingual inference is enabled
            if self.preset_language is None:
                self.tokenizer = None
            # revert suppressed tokens if suppress_numerals is enabled
            if self.suppress_numerals:
                self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)

        return {"segments": segments, "language": language}

    def detect_language(self, audio: np.ndarray) -> str:
        """Detect the language from the first ``N_SAMPLES`` samples of ``audio``.

        Shorter audio is right-padded to ``N_SAMPLES`` and a warning is logged.

        Returns:
            The top-ranked language code.
        """
        if audio.shape[0] < N_SAMPLES:
            logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        segment = log_mel_spectrogram(audio[: N_SAMPLES],
                                      n_mels=model_n_mels if model_n_mels is not None else 80,
                                      padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        # Drops two characters on each side - presumably the "<|" / "|>"
        # wrapper around the language code; confirm against the model output.
        language = language_token[2:-2]
        logger.info(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio")
        return language
def load_model(
    whisper_arch: str,
    device: str,
    device_index=0,
    compute_type="default",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model: Optional[Vad]= None,
    vad_method: Optional[str] = "pyannote",
    vad_options: Optional[dict] = None,
    model: Optional[WhisperModel] = None,
    task="transcribe",
    download_root: Optional[str] = None,
    local_files_only=False,
    threads=4,
    use_auth_token: Optional[Union[str, bool]] = None,
) -> FasterWhisperPipeline:
    """Build a ``FasterWhisperPipeline`` from a Whisper model and a VAD model.

    Args:
        whisper_arch - Name of the Whisper model to load; a name ending in ".en" forces language "en".
        device - Device to load the model on.
        device_index - Device index passed to the model; also used for the pyannote VAD device on cuda.
        compute_type - Compute type for the model. "default" selects float16 when device is "cuda", float32 otherwise.
        asr_options - Overrides merged into the default decoding options.
        language - Language to preset; when None the language is detected per audio file.
        vad_model - Manually assigned VAD model; takes priority over vad_method.
        vad_method - "pyannote" or "silero"; anything else raises ValueError.
        vad_options - Overrides merged into the default VAD options.
        model - Existing WhisperModel instance to use instead of loading one.
        task - Task used when a tokenizer is created up front.
        download_root - Root directory to download the model to.
        local_files_only - Passed through to WhisperModel.
        threads - Passed to WhisperModel as cpu_threads.
        use_auth_token - Passed through to WhisperModel.
    Returns:
        A Whisper pipeline.
    Raises:
        ValueError - If no vad_model is given and vad_method is not recognised.
    """
    if compute_type == "default":
        compute_type = "float32" if device != "cuda" else "float16"
        logger.info(f"Compute type not specified, defaulting to {compute_type} for device {device}")

    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(whisper_arch,
                                  device=device,
                                  device_index=device_index,
                                  compute_type=compute_type,
                                  download_root=download_root,
                                  local_files_only=local_files_only,
                                  cpu_threads=threads,
                                  use_auth_token=use_auth_token)

    if language is None:
        logger.info("No language specified, language will be detected for each audio file (increases inference time)")
        tokenizer = None
    else:
        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "multilingual": model.model.is_multilingual,
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
        "hotwords": None,
    }
    if asr_options is not None:
        default_asr_options.update(asr_options)

    # suppress_numerals is a pipeline setting, so it is removed from the dict
    # before the remaining keys are passed to TranscriptionOptions.
    suppress_numerals = default_asr_options.pop("suppress_numerals")
    transcription_options = TranscriptionOptions(**default_asr_options)

    default_vad_options = {
        "chunk_size": 30,  # needed by silero since binarization happens before merge_chunks
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }
    if vad_options is not None:
        default_vad_options.update(vad_options)

    # Note: manually assigned vad_model has higher priority than vad_method!
    if vad_model is not None:
        print("Use manually assigned vad_model. vad_method is ignored.")
    elif vad_method == "silero":
        vad_model = Silero(**default_vad_options)
    elif vad_method == "pyannote":
        device_vad = f'cuda:{device_index}' if device == 'cuda' else device
        vad_model = Pyannote(torch.device(device_vad), token=None, **default_vad_options)
    else:
        raise ValueError(f"Invalid vad_method: {vad_method}")

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
================================================
FILE: whisperx/audio.py
================================================
import os
import subprocess
from functools import lru_cache
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from whisperx.utils import exact_div
# Hard-coded audio hyperparameters shared by the loading and feature code below.
# SAMPLE_RATE is the default resampling rate; N_FFT and HOP_LENGTH are the STFT
# size and hop passed to torch.stft in log_mel_spectrogram.
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolution has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Decode an audio file into a mono float32 waveform at the requested rate.

    Parameters
    ----------
    file: str
        The audio file to open
    sr: int
        The sample rate to resample the audio to

    Returns
    -------
    A 1-D NumPy float32 array: the decoded 16-bit samples divided by 32768.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries its stderr.
    """
    # ffmpeg (must be installed) decodes, down-mixes to one channel, resamples
    # and writes raw signed 16-bit little-endian PCM to stdout.
    ffmpeg_cmd = [
        "ffmpeg", "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-",
    ]
    try:
        completed = subprocess.run(ffmpeg_cmd, capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    pcm = np.frombuffer(completed.stdout, np.int16)
    return pcm.flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Force ``array`` to have exactly ``length`` entries along ``axis``.

    Longer inputs keep their first ``length`` entries; shorter inputs are
    zero-padded at the end. Works for both torch tensors and NumPy arrays and
    returns the input unchanged when it already has the requested size.
    """
    size = array.shape[axis]
    if size == length:
        return array

    tensor_input = torch.is_tensor(array)

    if size > length:
        if tensor_input:
            keep = torch.arange(length, device=array.device)
            return array.index_select(dim=axis, index=keep)
        return array.take(indices=range(length), axis=axis)

    # size < length: pad only the trailing side of the chosen axis.
    pad_widths = [(0, 0)] * array.ndim
    pad_widths[axis] = (0, length - size)
    if tensor_input:
        # F.pad wants a flat list ordered from the last dimension backwards.
        flat_widths = [amount for pair in pad_widths[::-1] for amount in pair]
        return F.pad(array, flat_widths)
    return np.pad(array, pad_widths)
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the Mel filterbank matrix used to project an STFT onto Mel bands.

    The matrix is read from ``assets/mel_filters.npz`` next to this module
    under the key ``mel_<n_mels>`` and moved to ``device``. Results are cached
    per ``(device, n_mels)`` pair. Shipping the matrix avoids a librosa
    dependency; the 80-band entry was saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    Raises AssertionError when ``n_mels`` is neither 80 nor 128.
    """
    assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"

    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    archive_path = os.path.join(assets_dir, "mel_filters.npz")
    with np.load(archive_path) as archive:
        bank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(bank).to(device)
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the normalised log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters (``mel_filters`` accepts 80 or 128)

    padding: int
        Number of zero samples to pad to the right; ignored unless positive

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor
        A Tensor that contains the log-Mel spectrogram
    """
    if isinstance(audio, str):
        audio = load_audio(audio)
    if not torch.is_tensor(audio):
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    hann = torch.hann_window(N_FFT).to(audio.device)
    spectrum = torch.stft(audio, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Power spectrum, with the final STFT frame dropped.
    power = spectrum[..., :-1].abs() ** 2

    mel = mel_filters(audio.device, n_mels) @ power

    log_mel = torch.clamp(mel, min=1e-10).log10()
    # Limit the dynamic range to 8 log10 units below the global peak, then rescale.
    floor = log_mel.max() - 8.0
    log_mel = torch.maximum(log_mel, floor)
    return (log_mel + 4.0) / 4.0
================================================
FILE: whisperx/conjunctions.py
================================================
# conjunctions.py
from typing import Set
# Per-language sets of conjunction-like words, keyed by language code.
# Entries are set literals, so a word repeated inside a literal is stored once.
# A few entries are multi-word phrases (e.g. 'pokud ne', 'jak tylko').
# Languages without an entry are handled by get_conjunctions(), which returns
# an empty set for them.
conjunctions_by_language = {
'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'},
'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'},
'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', 'né', 'ossia', 'cioè'},
'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'},
'zh': {'和', '或', '但是', '因为', '任何', '也', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'},
'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'},
'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'},
'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', 'já', 'pois', 'nem', 'senão'},
'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'},
'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'},
'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'},
'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'},
'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'},
'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'},
'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'},
'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'},
'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'},
'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'},
'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'},
'vi': {'và', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', 'ở', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', 'vì', 'giống', 'cũng', 'tức'},
'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '고', '그래서', '랑', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '그', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', '도'},
'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'},
'hi': {'और', 'या', 'पर', 'तो', 'न', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', 'न', 'हालांकि'}
}
# Comma character per language code, for languages that need an explicit entry.
# Any language not listed here gets "," from get_comma().
# NOTE(review): the 'zh' value is the same ASCII comma as the get_comma()
# default, which makes the entry a no-op - possibly a fullwidth comma was
# intended; confirm against the upstream file.
commas_by_language = {
'ja': '、',
'zh': ',',
'fa': '،',
'ur': '،'
}
def get_conjunctions(lang_code: str) -> Set[str]:
    """Return the conjunction set for ``lang_code``, or an empty set if unknown."""
    if lang_code in conjunctions_by_language:
        return conjunctions_by_language[lang_code]
    return set()
def get_comma(lang_code: str) -> str:
    """Return the comma character for ``lang_code``, defaulting to ``","``."""
    if lang_code in commas_by_language:
        return commas_by_language[lang_code]
    return ","
================================================
FILE: whisperx/diarize.py
================================================
import numpy as np
import pandas as pd
from pyannote.audio import Pipeline
from typing import Optional, Union, List, Tuple
import torch
from whisperx.audio import load_audio, SAMPLE_RATE
from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult, ProgressCallback
from whisperx.log_utils import get_logger
logger = get_logger(__name__)
class IntervalTree:
    """
    Overlap lookup over labelled intervals stored in start-sorted NumPy arrays.

    A binary search on the sorted start times bounds the candidate set to the
    intervals that begin before the query ends; that prefix is then filtered
    with a vectorised comparison.
    """

    def __init__(self, intervals: List[Tuple[float, float, str]]):
        """
        Build the lookup structure from diarization segments.

        Args:
            intervals: List of (start, end, speaker) tuples; may be empty
        """
        # Sort by start time so query() can binary-search the starts.
        ordered = sorted(intervals, key=lambda item: item[0]) if intervals else []
        self.starts = np.array([item[0] for item in ordered], dtype=np.float64)
        self.ends = np.array([item[1] for item in ordered], dtype=np.float64)
        self.speakers: List[str] = [item[2] for item in ordered]

    def query(self, start: float, end: float) -> List[Tuple[str, float]]:
        """
        Find every stored interval overlapping [start, end] with positive duration.

        Args:
            start: Query interval start time
            end: Query interval end time

        Returns:
            List of (speaker, intersection_duration) tuples, in start-time order
        """
        if self.starts.size == 0:
            return []

        # Intervals at or beyond this position start at/after `end` and cannot overlap.
        upper = np.searchsorted(self.starts, end, side='left')
        if upper == 0:
            return []

        head_starts = self.starts[:upper]
        head_ends = self.ends[:upper]
        hit_mask = (head_starts < end) & (head_ends > start)

        hits = []
        for pos in np.flatnonzero(hit_mask):
            shared = min(self.ends[pos], end) - max(self.starts[pos], start)
            if shared > 0:
                hits.append((self.speakers[pos], shared))
        return hits

    def find_nearest(self, time: float) -> Optional[str]:
        """
        Return the speaker whose interval midpoint lies closest to ``time``.

        Args:
            time: Time point to look up

        Returns:
            Speaker ID of the nearest interval, or None if no intervals exist
        """
        if self.starts.size == 0:
            return None

        midpoints = (self.starts + self.ends) / 2
        closest = np.argmin(np.abs(midpoints - time))
        return self.speakers[closest]
class DiarizationPipeline:
    """Wrapper around a pretrained pyannote ``Pipeline`` that returns a DataFrame."""

    def __init__(
        self,
        model_name=None,
        token=None,
        device: Optional[Union[str, torch.device]] = "cpu",
        cache_dir=None,
    ):
        """Load the diarization model and move it to ``device``.

        Args:
            model_name: Model identifier; defaults to
                "pyannote/speaker-diarization-community-1" when falsy.
            token: Passed to ``Pipeline.from_pretrained``.
            device: Device string or ``torch.device``.
            cache_dir: Passed to ``Pipeline.from_pretrained``.
        """
        target_device = torch.device(device) if isinstance(device, str) else device
        model_config = model_name or "pyannote/speaker-diarization-community-1"
        logger.info(f"Loading diarization model: {model_config}")
        pipeline = Pipeline.from_pretrained(model_config, token=token, cache_dir=cache_dir)
        self.model = pipeline.to(target_device)

    def __call__(
        self,
        audio: Union[str, np.ndarray],
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
        return_embeddings: bool = False,
        progress_callback: ProgressCallback = None,
    ) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]:
        """
        Perform speaker diarization on audio.

        Args:
            audio: Path to audio file or audio array
            num_speakers: Exact number of speakers (if known)
            min_speakers: Minimum number of speakers to detect
            max_speakers: Maximum number of speakers to detect
            return_embeddings: Whether to return speaker embeddings
            progress_callback: Optional callable receiving a float (0-100) with progress percentage

        Returns:
            If return_embeddings is True:
                Tuple of (diarization dataframe, speaker embeddings dictionary or None)
            Otherwise:
                Just the diarization dataframe
        """
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio_data = {
            'waveform': torch.from_numpy(audio[None, :]),
            'sample_rate': SAMPLE_RATE
        }

        pipeline_kwargs = {
            "num_speakers": num_speakers,
            "min_speakers": min_speakers,
            "max_speakers": max_speakers,
        }

        if progress_callback is not None:
            # pyannote's diarization has two progress-trackable steps, each with
            # its own completed/total counter that resets between steps. Map each
            # step into a sub-range so progress is monotonic and meaningful.
            step_ranges = {
                "segmentation": (0.0, 50.0),
                "embeddings": (50.0, 99.0),
            }
            highest_pct = 0.0

            def hook(step_name, step_artifact, file=None, total=None, completed=None):
                nonlocal highest_pct
                if total is None or completed is None or not total > 0:
                    return
                low, high = step_ranges.get(step_name, (0.0, 99.0))
                pct = low + min(completed / total, 1.0) * (high - low)
                # Only report increases so the callback never goes backwards.
                if pct > highest_pct:
                    highest_pct = pct
                    progress_callback(pct)

            pipeline_kwargs["hook"] = hook

        output = self.model(audio_data, **pipeline_kwargs)

        if progress_callback is not None:
            progress_callback(100.0)

        diarization = output.speaker_diarization
        embeddings = output.speaker_embeddings if return_embeddings else None

        diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
        diarize_df['start'] = diarize_df['segment'].apply(lambda track: track.start)
        diarize_df['end'] = diarize_df['segment'].apply(lambda track: track.end)

        if not return_embeddings:
            return diarize_df
        if embeddings is None:
            # For backwards compatibility
            return diarize_df, None

        speaker_embeddings = {speaker: embeddings[s].tolist() for s, speaker in enumerate(diarization.labels())}
        return diarize_df, speaker_embeddings
def assign_word_speakers(
    diarize_df: pd.DataFrame,
    transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
    speaker_embeddings: Optional[dict[str, list[float]]] = None,
    fill_nearest: bool = False,
) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
    """
    Label transcript segments (and their words) with the dominant speaker.

    For each segment and each timed word, the speaker with the largest total
    overlap with the diarization intervals is chosen. The transcript is
    modified in place and also returned.

    Args:
        diarize_df: Diarization dataframe with 'start', 'end' and 'speaker' columns
        transcript_result: Transcription result to augment with speaker labels
        speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
        fill_nearest: If True, fall back to the nearest diarization interval
            when there is no time overlap

    Returns:
        transcript_result with 'speaker' keys added where a speaker was found,
        plus 'speaker_embeddings' when embeddings were provided. Returned
        untouched when there are no segments or no diarization rows.
    """
    transcript_segments = transcript_result.get("segments", [])
    if not transcript_segments or diarize_df is None or len(diarize_df) == 0:
        return transcript_result

    tree = IntervalTree([
        (row['start'], row['end'], row['speaker'])
        for _, row in diarize_df.iterrows()
    ])

    def _dominant_speaker(overlaps):
        # Total the overlap per speaker and return the one with the largest sum.
        totals = {}
        for speaker, intersection in overlaps:
            totals[speaker] = totals.get(speaker, 0.0) + intersection
        return max(totals.items(), key=lambda x: x[1])[0]

    for seg in transcript_segments:
        seg_start = seg.get('start', 0.0)
        seg_end = seg.get('end', 0.0)

        seg_overlaps = tree.query(seg_start, seg_end)
        if seg_overlaps:
            seg['speaker'] = _dominant_speaker(seg_overlaps)
        elif fill_nearest:
            nearest_speaker = tree.find_nearest((seg_start + seg_end) / 2)
            if nearest_speaker:
                seg['speaker'] = nearest_speaker

        if 'words' not in seg:
            continue

        for word in seg['words']:
            # Words without a start time cannot be placed on the timeline.
            if 'start' not in word:
                continue
            word_start = word['start']
            word_end = word.get('end', word_start)

            word_overlaps = tree.query(word_start, word_end)
            if word_overlaps:
                word['speaker'] = _dominant_speaker(word_overlaps)
            elif fill_nearest:
                nearest_speaker = tree.find_nearest((word_start + word_end) / 2)
                if nearest_speaker:
                    word['speaker'] = nearest_speaker

    if speaker_embeddings is not None:
        transcript_result["speaker_embeddings"] = speaker_embeddings

    return transcript_result
class Segment:
    """Plain holder for a start/end pair with an optional speaker label."""

    def __init__(self, start:int, end:int, speaker:Optional[str]=None):
        self.start, self.end = start, end
        self.speaker = speaker
================================================
FILE: whisperx/log_utils.py
================================================
import logging
import sys
from typing import Optional
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


def setup_logging(
    level: str = "info",
    log_file: Optional[str] = None,
) -> None:
    """
    Configure logging for WhisperX.

    Replaces all handlers on the "whisperx" logger with a stdout handler and,
    optionally, a file handler. Handlers from a previous configuration are
    closed so their file descriptors are released.

    Args:
        level: Logging level (debug, info, warning, error, critical). Default: info.
            Unrecognised values fall back to WARNING.
        log_file: Optional path to log file. If None, logs only to console.
    """
    logger = logging.getLogger("whisperx")

    # Remove *and close* previous handlers; merely clearing the list would
    # leave an earlier FileHandler's stream open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    try:
        log_level = getattr(logging, level.upper())
    except AttributeError:
        log_level = logging.WARNING
    # Some upper-case attributes of the logging module are not levels
    # (e.g. BASIC_FORMAT is a string); treat those as unrecognised too.
    if not isinstance(log_level, int):
        log_level = logging.WARNING
    logger.setLevel(log_level)

    formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_file:
        try:
            file_handler = logging.FileHandler(log_file)
            file_handler.setLevel(log_level)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        except (OSError) as e:
            # Best effort: keep going with the console handler only.
            logger.warning(f"Failed to create log file '{log_file}': {e}")
            logger.warning("Continuing with console logging only")

    # Don't propagate to root logger to avoid duplicate messages
    logger.propagate = False
def get_logger(name: str) -> logging.Logger:
    """
    Return the logger to use for a module, configuring WhisperX logging on first use.

    If the ``"whisperx"`` logger has no handlers yet, ``setup_logging()`` is
    invoked with its defaults before the logger is looked up.

    Args:
        name: Logger name (typically __name__ from calling module). The
            special name ``"__main__"`` is mapped to ``"whisperx"``.

    Returns:
        The ``logging.Logger`` registered under the resolved name.
    """
    package_logger = logging.getLogger("whisperx")
    if len(package_logger.handlers) == 0:
        setup_logging()

    if name == "__main__":
        return logging.getLogger("whisperx")
    return logging.getLogger(name)
================================================
FILE: whisperx/schema.py
================================================
from typing import Callable, TypedDict, Optional, List, Tuple
ProgressCallback = Optional[Callable[[float], None]]
try:
from typing import NotRequired
except ImportError:
from typing_extensions import NotRequired
class SingleWordSegment(TypedDict):
"""
A single word of a speech.
"""
word: str
start: float
end: float
score: float
class SingleCharSegment(TypedDict):
    """
    One character of the transcript together with its timing and score.
    """
    char: str  # the character itself
    start: float  # start of the character
    end: float  # end of the character
    score: float  # score attached to this character
class SingleSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech.
"""
start: float
end: float
text: str
avg_logprob: NotRequired[float]
class SegmentData(TypedDict):
    """
    Scratch data kept per segment while alignment is running.

    Holds the cleaned / preprocessed view of a segment's text.
    """
    # Cleaned characters that exist in model dictionary
    clean_char: List[str]
    # Original indices of cleaned characters
    clean_cdx: List[int]
    # Indices of words containing valid characters
    clean_wdx: List[int]
    # Start and end indices of sentences
    sentence_spans: List[Tuple[int, int]]
class SingleAlignedSegment(TypedDict):
    """
    One transcribed segment (may span several sentences) after word alignment.
    """
    start: float  # segment start
    end: float  # segment end
    text: str  # transcribed text of the segment
    avg_logprob: NotRequired[float]  # optional; may be absent from the dict
    words: List[SingleWordSegment]  # per-word entries for this segment
    chars: Optional[List[SingleCharSegment]]  # per-character entries, or None
class TranscriptionResult(TypedDict):
    """
    Output of transcription: the list of segments plus the language.
    """
    segments: List[SingleSegment]  # transcribed segments
    language: str  # language associated with the transcription
class AlignedTranscriptionResult(TypedDict):
    """
    Output of alignment: aligned segments plus a list of word entries.
    """
    segments: List[SingleAlignedSegment]  # segments carrying word-level data
    word_segments: List[SingleWordSegment]  # word entries
================================================
FILE: whisperx/transcribe.py
================================================
import argparse
import gc
import os
import warnings
import numpy as np
import torch
from whisperx.alignment import align, load_align_model
from whisperx.asr import load_model
from whisperx.audio import load_audio
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
from whisperx.log_utils import get_logger
logger = get_logger(__name__)
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
    """Transcription task to be called from CLI.

    Runs the pipeline in stages over all input files: VAD + ASR, then
    (unless disabled) forced alignment, then (if requested) diarization, and
    finally writes every result with the configured writer. Models are
    released between stages.

    The language reported by transcription is remembered per input file and
    written back into the final result. Alignment and diarization return
    objects that no longer carry it, and the language passed on the command
    line (or the "en" default) is not necessarily the language of the audio.

    Args:
        args: Dictionary of command-line arguments. Consumed options are
            popped from it.
        parser: argparse.ArgumentParser object, used to report invalid
            option combinations.

    Raises:
        ValueError: If ``--language`` is neither a known code nor a known name.
    """
    # fmt: off
    model_name: str = args.pop("model")
    batch_size: int = args.pop("batch_size")
    model_dir: str = args.pop("model_dir")
    model_cache_only: bool = args.pop("model_cache_only")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    device_index: int = args.pop("device_index")
    compute_type: str = args.pop("compute_type")
    verbose: bool = args.pop("verbose")

    # model_flush: bool = args.pop("model_flush")
    os.makedirs(output_dir, exist_ok=True)

    align_model: str = args.pop("align_model")
    interpolate_method: str = args.pop("interpolate_method")
    no_align: bool = args.pop("no_align")
    task: str = args.pop("task")
    if task == "translate":
        # translation cannot be aligned
        no_align = True

    return_char_alignments: bool = args.pop("return_char_alignments")

    hf_token: str = args.pop("hf_token")
    vad_method: str = args.pop("vad_method")
    vad_onset: float = args.pop("vad_onset")
    vad_offset: float = args.pop("vad_offset")

    chunk_size: int = args.pop("chunk_size")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")
    diarize_model_name: str = args.pop("diarize_model")
    print_progress: bool = args.pop("print_progress")
    return_speaker_embeddings: bool = args.pop("speaker_embeddings")

    if return_speaker_embeddings and not diarize:
        warnings.warn("--speaker_embeddings has no effect without --diarize")

    # Normalise --language to a language code (accepts codes and names).
    if args["language"] is not None:
        args["language"] = args["language"].lower()
        if args["language"] not in LANGUAGES:
            if args["language"] in TO_LANGUAGE_CODE:
                args["language"] = TO_LANGUAGE_CODE[args["language"]]
            else:
                raise ValueError(f"Unsupported language: {args['language']}")

    if model_name.endswith(".en") and args["language"] != "en":
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"
    align_language = (
        args["language"] if args["language"] is not None else "en"
    )  # default to loading english if not specified

    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    faster_whisper_threads = 4
    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)
        faster_whisper_threads = threads

    asr_options = {
        "beam_size": args.pop("beam_size"),
        "patience": args.pop("patience"),
        "length_penalty": args.pop("length_penalty"),
        "temperatures": temperature,
        "compression_ratio_threshold": args.pop("compression_ratio_threshold"),
        "log_prob_threshold": args.pop("logprob_threshold"),
        "no_speech_threshold": args.pop("no_speech_threshold"),
        "condition_on_previous_text": False,
        "initial_prompt": args.pop("initial_prompt"),
        "hotwords": args.pop("hotwords"),
        "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
        "suppress_numerals": args.pop("suppress_numerals"),
    }

    writer = get_writer(output_format, output_dir)
    word_options = ["highlight_words", "max_line_count", "max_line_width"]
    if no_align:
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} not possible with --no_align")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    writer_args = {arg: args.pop(arg) for arg in word_options}

    # Part 1: VAD & ASR Loop
    results = []
    # Language of each input, in the same order as `results`. Every later
    # stage preserves that order, so the two lists can be zipped at any point.
    languages = []
    # model = load_model(model_name, device=device, download_root=model_dir)
    model = load_model(
        model_name,
        device=device,
        device_index=device_index,
        download_root=model_dir,
        compute_type=compute_type,
        language=args["language"],
        asr_options=asr_options,
        vad_method=vad_method,
        vad_options={
            "chunk_size": chunk_size,
            "vad_onset": vad_onset,
            "vad_offset": vad_offset,
        },
        task=task,
        local_files_only=model_cache_only,
        threads=faster_whisper_threads,
        use_auth_token=hf_token,
    )

    for audio_path in args.pop("audio"):
        audio = load_audio(audio_path)
        # >> VAD & ASR
        logger.info("Performing transcription...")
        result: TranscriptionResult = model.transcribe(
            audio,
            batch_size=batch_size,
            chunk_size=chunk_size,
            print_progress=print_progress,
            verbose=verbose,
        )
        results.append((result, audio_path))
        # Fall back to the requested/default language only when the
        # transcription result does not report one.
        languages.append(result.get("language") or align_language)

    # Unload Whisper and VAD
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Part 2: Align Loop
    if not no_align:
        tmp_results = results
        results = []
        align_model, align_metadata = load_align_model(
            align_language, device, model_name=align_model, model_dir=model_dir, model_cache_only=model_cache_only
        )
        for (result, audio_path), language in zip(tmp_results, languages):
            # >> Align
            if len(tmp_results) > 1:
                input_audio = audio_path
            else:
                # lazily load audio from part 1
                input_audio = audio

            if align_model is not None and len(result["segments"]) > 0:
                if language != align_metadata["language"]:
                    # load new language
                    logger.info(
                        f"New language found ({language})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
                    )
                    align_model, align_metadata = load_align_model(
                        language, device, model_dir=model_dir, model_cache_only=model_cache_only
                    )
                logger.info("Performing alignment...")
                result: AlignedTranscriptionResult = align(
                    result["segments"],
                    align_model,
                    align_metadata,
                    input_audio,
                    device,
                    interpolate_method=interpolate_method,
                    return_char_alignments=return_char_alignments,
                    print_progress=print_progress,
                )

            results.append((result, audio_path))

        # Unload align model
        del align_model
        gc.collect()
        torch.cuda.empty_cache()

    # >> Diarize
    if diarize:
        if hf_token is None:
            logger.warning(
                "No --hf_token provided, needs to be saved in environment variable, otherwise will throw error loading diarization model"
            )
        tmp_results = results
        logger.info("Performing diarization...")
        logger.info(f"Using model: {diarize_model_name}")
        results = []
        diarize_model = DiarizationPipeline(model_name=diarize_model_name, token=hf_token, device=device, cache_dir=model_dir)
        for result, input_audio_path in tmp_results:
            diarize_result = diarize_model(
                input_audio_path,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                return_embeddings=return_speaker_embeddings
            )

            if return_speaker_embeddings:
                diarize_segments, speaker_embeddings = diarize_result
            else:
                diarize_segments = diarize_result
                speaker_embeddings = None

            result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
            results.append((result, input_audio_path))

    # >> Write
    for (result, audio_path), language in zip(results, languages):
        # Restore the per-file language; writers consult it (e.g. to decide
        # how words are joined) and it is part of the serialized output.
        result["language"] = language
        writer(result, audio_path, writer_args)
================================================
FILE: whisperx/utils.py
================================================
import json
import os
import re
import sys
import zlib
from typing import Callable, Optional, TextIO
# Supported languages: language code -> lowercase English language name.
# The CLI validates its --language value against the keys of this table.
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# Language code lookup by language name: the inverse of LANGUAGES, extended
# with a few alternative names (aliases) that map onto an existing code.
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
# Language codes whose words are concatenated without a separating space
# when subtitle text is assembled from word entries.
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
# Mapping of language codes to NLTK Punkt tokenizer model names.
# NOTE(review): the consumer of this table is not visible here; presumably a
# language missing from it has no Punkt model available — confirm at the call site.
PUNKT_LANGUAGES = {
'cs': 'czech',
'da': 'danish',
'de': 'german',
'el': 'greek',
'en': 'english',
'es': 'spanish',
'et': 'estonian',
'fi': 'finnish',
'fr': 'french',
'it': 'italian',
'nl': 'dutch',
'no': 'norwegian',
'pl': 'polish',
'pt': 'portuguese',
'sl': 'slovene',
'sv': 'swedish',
'tr': 'turkish',
"ml": "malayalam",
"ru": "russian",
}
system_encoding = sys.getdefaultencoding()

if system_encoding == "utf-8":

    def make_safe(string):
        """Return ``string`` unchanged."""
        # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
        return string

else:

    def make_safe(string):
        """Return ``string`` with characters the system encoding cannot represent replaced."""
        # replaces any character not representable using the system default encoding with an '?',
        # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
        encoded = string.encode(system_encoding, errors="replace")
        return encoded.decode(system_encoding)
def exact_div(x, y):
    """Return ``x // y``, asserting first that ``y`` divides ``x`` with no remainder."""
    remainder = x % y
    assert remainder == 0
    return x // y
def str2bool(string):
    """Convert the exact strings ``"True"`` / ``"False"`` to the matching bool.

    Raises:
        ValueError: For any other input.
    """
    str2val = {"True": True, "False": False}
    if string not in str2val:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return str2val[string]
def optional_int(string):
    """Parse ``string`` as an int, mapping the literal ``"None"`` to ``None``."""
    if string == "None":
        return None
    return int(string)
def optional_float(string):
    """Parse ``string`` as a float, mapping the literal ``"None"`` to ``None``."""
    if string == "None":
        return None
    return float(string)
def compression_ratio(text) -> float:
    """Return the UTF-8 byte length of ``text`` divided by its zlib-compressed length."""
    raw = text.encode("utf-8")
    compressed = zlib.compress(raw)
    return len(raw) / len(compressed)
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Format a non-negative number of seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to whole milliseconds first. The hours field is
    emitted only when it is non-zero or ``always_include_hours`` is set.

    Args:
        seconds: Timestamp in seconds; must be >= 0.
        always_include_hours: Emit the ``HH:`` prefix even when it is ``00``.
        decimal_marker: Separator placed between seconds and milliseconds.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    total_ms = round(seconds * 1000.0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, milliseconds = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        hours_marker = f"{hours:02d}:"
    else:
        hours_marker = ""

    return (
        f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{milliseconds:03d}"
    )
class ResultWriter:
    """Base class for writers that serialise a result next to a target directory.

    Subclasses set ``extension`` and implement ``write_result``.
    """

    # File extension (without the dot) used for the output file.
    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str, options: dict):
        """Write ``result`` to ``<output_dir>/<audio file stem>.<extension>`` as UTF-8."""
        stem, _ = os.path.splitext(os.path.basename(audio_path))
        output_path = os.path.join(self.output_dir, stem + "." + self.extension)

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options)

    def write_result(self, result: dict, file: TextIO, options: dict):
        """Serialise ``result`` into the open text ``file``; must be overridden."""
        raise NotImplementedError
class WriteTXT(ResultWriter):
    """Plain-text writer: one line per segment, prefixed with ``[speaker]: `` when known."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO, options: dict):
        for segment in result["segments"]:
            speaker = segment.get("speaker")
            text = segment["text"].strip()
            line = text if speaker is None else f"[{speaker}]: {text}"
            print(line, file=file, flush=True)
class SubtitlesWriter(ResultWriter):
    """Shared cue generation for subtitle formats.

    Subclasses provide the timestamp style through ``always_include_hours``
    and ``decimal_marker`` and consume ``iterate_result`` to write cues.
    """

    # Whether timestamps always carry the hours field.
    always_include_hours: bool
    # Separator between seconds and milliseconds in timestamps.
    decimal_marker: str

    def iterate_result(self, result: dict, options: dict):
        """Yield ``(start, end, text)`` cues with formatted timestamps.

        When the first segment carries a ``"words"`` list, cues are built from
        word entries (optionally re-wrapped by ``max_line_width`` /
        ``max_line_count`` and optionally one cue per highlighted word).
        Otherwise one cue is emitted per segment.

        Args:
            result: Dict with ``"segments"``; the word-level path also reads
                ``result["language"]``.
            options: Dict with ``"max_line_width"``, ``"max_line_count"`` and
                ``"highlight_words"``.
        """
        raw_max_line_width: Optional[int] = options["max_line_width"]
        max_line_count: Optional[int] = options["max_line_count"]
        highlight_words: bool = options["highlight_words"]
        max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
        # Without both limits, cues follow the original segment boundaries.
        preserve_segments = max_line_count is None or raw_max_line_width is None

        if len(result["segments"]) == 0:
            return

        def iterate_subtitles():
            line_len = 0
            line_count = 1
            # the next subtitle to yield (a list of word timings with whitespace)
            subtitle: list[dict] = []
            times: list[tuple] = []
            last = result["segments"][0]["start"]
            for segment in result["segments"]:
                for i, original_timing in enumerate(segment["words"]):
                    timing = original_timing.copy()
                    long_pause = not preserve_segments
                    if "start" in timing:
                        long_pause = long_pause and timing["start"] - last > 3.0
                    else:
                        # words without a start time can never signal a pause
                        long_pause = False
                    has_room = line_len + len(timing["word"]) <= max_line_width
                    seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                    if line_len > 0 and has_room and not long_pause and not seg_break:
                        # line continuation
                        line_len += len(timing["word"])
                    else:
                        # new line
                        timing["word"] = timing["word"].strip()
                        # NOTE: `and` binds tighter than `or`, so a segment
                        # break alone is enough to start a new subtitle.
                        if (
                            len(subtitle) > 0
                            and max_line_count is not None
                            and (long_pause or line_count >= max_line_count)
                            or seg_break
                        ):
                            # subtitle break
                            yield subtitle, times
                            subtitle = []
                            times = []
                            line_count = 1
                        elif line_len > 0:
                            # line break
                            line_count += 1
                            timing["word"] = "\n" + timing["word"]
                        line_len = len(timing["word"].strip())
                    subtitle.append(timing)
                    times.append((segment["start"], segment["end"], segment.get("speaker")))
                    if "start" in timing:
                        last = timing["start"]
            if len(subtitle) > 0:
                yield subtitle, times

        if "words" in result["segments"][0]:
            # One separator for the whole result, so the plain cue text and the
            # highlighted variants of the same cue are assembled identically.
            separator = "" if result["language"] in LANGUAGES_WITHOUT_SPACES else " "

            for subtitle, times in iterate_subtitles():
                speaker = times[0][2]

                # Derive cue times from word-level timestamps when available,
                # falling back to segment-level times for fully unalignable subtitles.
                word_starts = [w["start"] for w in subtitle if "start" in w]
                word_ends = [w["end"] for w in subtitle if "end" in w]
                if word_starts and word_ends:
                    subtitle_start = self.format_timestamp(min(word_starts))
                    subtitle_end = self.format_timestamp(max(word_ends))
                else:
                    subtitle_start = self.format_timestamp(times[0][0])
                    subtitle_end = self.format_timestamp(times[0][1])

                subtitle_text = separator.join([word["word"] for word in subtitle])
                has_timing = any(["start" in word for word in subtitle])

                # add [$SPEAKER_ID]: to each subtitle if speaker is available
                prefix = ""
                if speaker is not None:
                    prefix = f"[{speaker}]: "

                if highlight_words and has_timing:
                    last = subtitle_start
                    all_words = [timing["word"] for timing in subtitle]
                    for i, this_word in enumerate(subtitle):
                        if "start" in this_word:
                            start = self.format_timestamp(this_word["start"])
                            end = self.format_timestamp(this_word["end"])
                            if last != start:
                                # gap before this word: show the cue without highlight
                                yield last, start, prefix + subtitle_text

                            yield start, end, prefix + separator.join(
                                [
                                    re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                    if j == i
                                    else word
                                    for j, word in enumerate(all_words)
                                ]
                            )
                            last = end
                else:
                    yield subtitle_start, subtitle_end, prefix + subtitle_text
        else:
            for segment in result["segments"]:
                segment_start = self.format_timestamp(segment["start"])
                segment_end = self.format_timestamp(segment["end"])
                # "-->" is the cue timing arrow; keep it out of the cue text
                segment_text = segment["text"].strip().replace("-->", "->")
                if "speaker" in segment:
                    segment_text = f"[{segment['speaker']}]: {segment_text}"
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        """Format ``seconds`` using this writer's hours / decimal-marker settings."""
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )
class WriteVTT(SubtitlesWriter):
    """WebVTT writer: ``WEBVTT`` header followed by one block per cue."""

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(self, result: dict, file: TextIO, options: dict):
        print("WEBVTT\n", file=file)
        for cue_start, cue_end, cue_text in self.iterate_result(result, options):
            print(f"{cue_start} --> {cue_end}\n{cue_text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
    """SubRip writer: numbered cues (starting at 1) with comma-decimal timestamps."""

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(self, result: dict, file: TextIO, options: dict):
        cues = self.iterate_result(result, options)
        for index, (cue_start, cue_end, cue_text) in enumerate(cues, start=1):
            print(f"{index}\n{cue_start} --> {cue_end}\n{cue_text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
    """
    Write a transcript as tab-separated values, one row per segment:
    start time in integer milliseconds, end time in integer milliseconds,
    transcript text. A ``start / end / text`` header row comes first.

    Integer milliseconds avoid any dependence on how the environment renders
    the decimal separator of floating point numbers, and are cheap to parse
    and store elsewhere.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO, options: dict):
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            start_ms = round(1000 * segment["start"])
            print(start_ms, file=file, end="\t")
            end_ms = round(1000 * segment["end"])
            print(end_ms, file=file, end="\t")
            # tabs inside the text would break the column layout
            text = segment["text"].strip().replace("\t", " ")
            print(text, file=file, flush=True)
class WriteAudacity(ResultWriter):
    """
    Write a transcript to a text file that Audacity can import as labels.

    The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
    This is not an Audacity project, only a label file.
    Audacity uses seconds (not milliseconds) in timestamps and expects no header.
    Each row is ``start<TAB>end<TAB>text``.
    If a speaker is provided it is prepended to the text between double square brackets [[]].
    """

    extension: str = "aud"

    def write_result(self, result: dict, file: TextIO, options: dict):
        # Field separator. Written as an explicit escape so the required TAB
        # cannot be silently turned into a space by editors or formatters.
        ARROW = "\t"
        for segment in result["segments"]:
            print(segment["start"], file=file, end=ARROW)
            print(segment["end"], file=file, end=ARROW)
            speaker_tag = ("[[" + segment["speaker"] + "]]") if "speaker" in segment else ""
            # tabs inside the text would be read as an extra column
            label = speaker_tag + segment["text"].strip().replace("\t", " ")
            print(label, file=file, flush=True)
class WriteJSON(ResultWriter):
    """JSON writer: dumps the whole result dict, keeping non-ASCII characters as-is."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO, options: dict):
        json.dump(result, file, ensure_ascii=False)
def get_writer(
    output_format: str, output_dir: str
) -> Callable[[dict, str, dict], None]:
    """Return a callable that writes a result in ``output_format`` into ``output_dir``.

    ``"all"`` returns a callable that runs every standard writer (txt, vtt,
    srt, tsv, json); the optional ``"aud"`` writer is only used when requested
    by name.

    Raises:
        KeyError: If ``output_format`` is not a known format.
    """
    standard = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }
    optional = {
        "aud": WriteAudacity,
    }

    if output_format == "all":
        instances = [writer_cls(output_dir) for writer_cls in standard.values()]

        def write_all(result: dict, file: str, options: dict):
            for instance in instances:
                instance(result, file, options)

        return write_all

    writer_cls = optional.get(output_format)
    if writer_cls is None:
        writer_cls = standard[output_format]
    return writer_cls(output_dir)
def interpolate_nans(x, method='nearest'):
    """Fill missing values in ``x``.

    With more than one non-null value, interpolates using ``method`` and then
    forward/backward fills what is left; otherwise only forward/backward fills.
    """
    if x.notnull().sum() <= 1:
        return x.ffill().bfill()
    return x.interpolate(method=method).ffill().bfill()
================================================
FILE: whisperx/vads/__init__.py
================================================
from whisperx.vads.pyannote import Pyannote as Pyannote
from whisperx.vads.silero import Silero as Silero
from whisperx.vads.vad import Vad as Vad
================================================
FILE: whisperx/vads/pyannote.py
================================================
import os
from typing import Callable, Text, Union
from typing import Optional
import numpy as np
import torch
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.core import Segment
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
from whisperx.log_utils import get_logger
logger = get_logger(__name__)
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, token=None, model_fp=None):
    """Load the segmentation model from disk and wrap it in a VAD pipeline.

    Args:
        device: Device specifier handed to ``torch.device``.
        vad_onset: Value stored as the pipeline's ``onset`` hyperparameter.
        vad_offset: Value stored as the pipeline's ``offset`` hyperparameter.
        token: Forwarded to ``Model.from_pretrained``.
        model_fp: Path of the model file; defaults to
            ``assets/pytorch_model.bin`` inside the package directory.

    Returns:
        An instantiated ``VoiceActivitySegmentation`` pipeline.

    Raises:
        FileNotFoundError: If the model file does not exist.
        RuntimeError: If the model path exists but is not a regular file.
    """
    model_dir = torch.hub._get_torch_home()
    main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    os.makedirs(model_dir, exist_ok=True)

    if model_fp is None:
        # Dynamically resolve the path to the bundled model file
        model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
    # Work with an absolute path in either case
    model_fp = os.path.abspath(model_fp)

    # Validate the resolved model path before loading
    if not os.path.exists(model_fp):
        raise FileNotFoundError(f"Model file not found at {model_fp}")
    if os.path.exists(model_fp) and not os.path.isfile(model_fp):
        raise RuntimeError(f"{model_fp} exists and is not a regular file")

    vad_model = Model.from_pretrained(model_fp, token=token)

    hyperparameters = {
        "onset": vad_onset,
        "offset": vad_offset,
        "min_duration_on": 0.1,
        "min_duration_off": 0.1,
    }
    vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
    vad_pipeline.instantiate(hyperparameters)
    return vad_pipeline
class Binarize:
    """Binarize detection scores using hysteresis thresholding, with a min-cut
    operation to ensure no segments are longer than max_duration.

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.
    max_duration: float
        The maximum length of an active segment; a longer segment is divided
        at the timestamp with the lowest score.

    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Modified by Max Bain to include WhisperX's min-cut operation
    https://arxiv.org/abs/2303.00747

    Pyannote-audio
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
        max_duration: float = float('inf')
    ):
        super().__init__()

        # thresholds (a falsy offset falls back to the onset value)
        self.onset = onset
        self.offset = offset or onset

        # padding applied to region boundaries
        self.pad_onset = pad_onset
        self.pad_offset = pad_offset

        # duration constraints
        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off
        self.max_duration = max_duration

    def __call__(self, scores: SlidingWindowFeature) -> Annotation:
        """Binarize detection scores

        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores.

        Returns
        -------
        active : Annotation
            Binarized scores.
        """
        num_frames, _num_classes = scores.data.shape
        frames = scores.sliding_window
        timestamps = [frames[idx].middle for idx in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()

        for k, k_scores in enumerate(scores.data.T):
            label = scores.labels[k] if scores.labels is not None else k

            # initial state
            start = timestamps[0]
            is_active = k_scores[0] > self.onset
            curr_scores = [k_scores[0]]
            curr_timestamps = [start]
            t = start

            for t, y in zip(timestamps[1:], k_scores[1:]):
                if not is_active:
                    # currently inactive: only watch for the onset threshold
                    if y > self.onset:
                        start = t
                        is_active = True
                    continue

                # currently active
                if t - start > self.max_duration:
                    # divide segment at the lowest score of its second half
                    search_after = len(curr_scores) // 2
                    cut_idx = search_after + np.argmin(curr_scores[search_after:])
                    cut_time = curr_timestamps[cut_idx]
                    region = Segment(start - self.pad_onset, cut_time + self.pad_offset)
                    active[region, k] = label
                    start = cut_time
                    curr_scores = curr_scores[cut_idx + 1:]
                    curr_timestamps = curr_timestamps[cut_idx + 1:]
                elif y < self.offset:
                    # switching from active to inactive
                    region = Segment(start - self.pad_onset, t + self.pad_offset)
                    active[region, k] = label
                    start = t
                    is_active = False
                    curr_scores = []
                    curr_timestamps = []

                curr_scores.append(y)
                curr_timestamps.append(t)

            # if active at the end, add final region
            if is_active:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
                active[region, k] = label

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same speaker gaps shorter than min_duration_off
        needs_merge = self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0
        if needs_merge:
            if self.max_duration < float("inf"):
                raise NotImplementedError(f"This would break current max_duration param")
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            # iterate over a snapshot because tracks are deleted while looping
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active
class VoiceActivitySegmentation(VoiceActivityDetection):
    """Voice-activity pipeline whose ``apply`` returns the raw segmentation scores."""

    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        token: Union[Text, None] = None,
        **inference_kwargs,
    ):
        super().__init__(segmentation=segmentation, fscore=fscore, token=token, **inference_kwargs)

    def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
        """Run the segmentation model on ``file`` and return its scores.

        Parameters
        ----------
        file : AudioFile
            Processed file.
        hook : callable, optional
            Hook called after each major step of the pipeline with the following
            signature: hook("step_name", step_artefact, file=file)

        Returns
        -------
        segmentations
            Output of the segmentation model for ``file``. While
            ``self.training`` is set, the output is cached in ``file`` and
            reused on later calls.
        """
        # setup hook (e.g. for debugging purposes)
        hook = self.setup_hook(file, hook=hook)

        # apply segmentation model (only if needed)
        # output shape is (num_chunks, num_frames, 1)
        if not self.training:
            segmentations: SlidingWindowFeature = self._segmentation(file)
            return segmentations

        if self.CACHED_SEGMENTATION in file:
            return file[self.CACHED_SEGMENTATION]

        segmentations = self._segmentation(file)
        file[self.CACHED_SEGMENTATION] = segmentations
        return segmentations
class Pyannote(Vad):
    """VAD backend built on the pyannote segmentation pipeline."""

    def __init__(self, device, token=None, model_fp=None, **kwargs):
        """Validate ``kwargs['vad_onset']`` and load the VAD pipeline.

        Only ``device``, ``token`` and ``model_fp`` are passed on to
        ``load_vad_model``; ``vad_onset`` is used for validation only.
        """
        logger.info("Performing voice activity detection using Pyannote...")
        super().__init__(kwargs['vad_onset'])
        self.vad_pipeline = load_vad_model(device, token=token, model_fp=model_fp)

    def __call__(self, audio: AudioFile, **kwargs):
        """Run the VAD pipeline on ``audio``; extra keyword arguments are ignored."""
        return self.vad_pipeline(audio)

    @staticmethod
    def preprocess_audio(audio):
        """Turn a numpy array into a tensor with a new leading dimension."""
        return torch.from_numpy(audio).unsqueeze(0)

    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        """Binarize the scores, then merge the speech turns into chunks.

        Returns an empty list (after logging a warning) when no speech turn
        is found.
        """
        assert chunk_size > 0

        annotation = Binarize(max_duration=chunk_size, onset=onset, offset=offset)(segments)
        segments_list = [
            SegmentX(turn.start, turn.end, "UNKNOWN")
            for turn in annotation.get_timeline()
        ]

        if not segments_list:
            logger.warning("No active speech found in audio")
            return []

        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
================================================
FILE: whisperx/vads/silero.py
================================================
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union
import torch
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
from whisperx.log_utils import get_logger
logger = get_logger(__name__)
AudioFile = Union[Text, Path, IOBase, Mapping]
class Silero(Vad):
    """VAD backend built on the Silero model loaded through ``torch.hub``."""

    # check again default values
    def __init__(self, **kwargs):
        """Validate ``vad_onset``, remember ``vad_onset`` / ``chunk_size`` and load the model."""
        logger.info("Performing voice activity detection using Silero...")
        super().__init__(kwargs['vad_onset'])

        self.vad_onset = kwargs['vad_onset']
        self.chunk_size = kwargs['chunk_size']

        model, vad_utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=False,
            trust_repo=True,
        )
        self.vad_pipeline = model
        # only the first and third helper of the returned tuple are kept
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils

    def __call__(self, audio: AudioFile, **kwargs):
        """use silero to get segments of speech"""
        # Only accept 16000 Hz for now.
        # Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
        # multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
        sample_rate = audio["sample_rate"]
        if sample_rate != 16000:
            raise ValueError("Only 16000Hz sample rate is allowed")

        timestamps = self.get_speech_timestamps(
            audio["waveform"],
            model=self.vad_pipeline,
            sampling_rate=sample_rate,
            max_speech_duration_s=self.chunk_size,
            threshold=self.vad_onset,
            # min_silence_duration_ms = self.min_duration_off/1000
            # min_speech_duration_ms = self.min_duration_on/1000
            # ...
            # See silero documentation for full option list
        )

        # sample offsets -> seconds
        speech_segments = []
        for stamp in timestamps:
            speech_segments.append(
                SegmentX(stamp['start'] / sample_rate, stamp['end'] / sample_rate, "UNKNOWN")
            )
        return speech_segments

    @staticmethod
    def preprocess_audio(audio):
        """Return ``audio`` unchanged."""
        return audio

    @staticmethod
    def merge_chunks(segments_list,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        """Merge speech segments into chunks; returns ``[]`` (with a warning) when there are none."""
        assert chunk_size > 0

        if not len(segments_list):
            logger.warning("No active speech found in audio")
            return []

        assert segments_list, "segments_list is empty."
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
================================================
FILE: whisperx/vads/vad.py
================================================
from typing import Optional
import pandas as pd
from pyannote.core import Annotation, Segment
class Vad:
    """Base class for the voice activity detection back-ends.

    Provides onset validation, a no-op ``preprocess_audio`` hook and the shared
    ``merge_chunks`` routine.
    """

    def __init__(self, vad_onset):
        """Validate the onset threshold.

        Args:
            vad_onset: Threshold that must lie strictly between 0 and 1.

        Raises:
            ValueError: If ``vad_onset`` is not strictly between 0 and 1.
        """
        if not (0 < vad_onset < 1):
            raise ValueError(
                "vad_onset is a decimal value between 0 and 1."
            )

    @staticmethod
    def preprocess_audio(audio):
        """Placeholder hook; this base implementation does nothing and returns ``None``."""
        pass

    # keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float,
                     offset: Optional[float]):
        """
        Merge operation described in paper

        Walks ``segments`` in order and groups consecutive segments into chunks.
        A chunk is closed as soon as adding the next segment would make it span
        more than ``chunk_size`` (measured from the chunk start to that
        segment's end), provided the chunk already has a positive length.

        Args:
            segments: Sequence of objects exposing ``start`` and ``end``.
            chunk_size: Maximum span of a merged chunk, in the same unit as
                ``start``/``end``.
            onset: Not used by this implementation; kept for interface
                compatibility.
            offset: Not used by this implementation; kept for interface
                compatibility.

        Returns:
            A list of dicts with keys ``"start"``, ``"end"`` and ``"segments"``
            (the ``(start, end)`` pairs merged into that chunk). Returns an
            empty list when ``segments`` is empty.
        """
        # Nothing to merge; avoids an IndexError on segments[0] below.
        if len(segments) == 0:
            return []

        curr_end = 0
        merged_segments = []
        seg_idxs: list[tuple] = []

        curr_start = segments[0].start
        for seg in segments:
            # Close the current chunk only if it already covers some audio;
            # a single over-long segment therefore still forms its own chunk.
            if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
                merged_segments.append({
                    "start": curr_start,
                    "end": curr_end,
                    "segments": seg_idxs,
                })
                curr_start = seg.start
                seg_idxs = []
            curr_end = seg.end
            seg_idxs.append((seg.start, seg.end))
        # add final
        merged_segments.append({
            "start": curr_start,
            "end": curr_end,
            "segments": seg_idxs,
        })

        return merged_segments
gitextract_hiro3isg/
├── .github/
│ ├── FUNDING.yml
│ └── workflows/
│ ├── build-and-release.yml
│ └── python-compatibility.yml
├── .gitignore
├── .python-version
├── CUDNN_TROUBLESHOOTING.md
├── EXAMPLES.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── pyproject.toml
└── whisperx/
├── SubtitlesProcessor.py
├── __init__.py
├── __main__.py
├── alignment.py
├── asr.py
├── assets/
│ └── mel_filters.npz
├── audio.py
├── conjunctions.py
├── diarize.py
├── log_utils.py
├── schema.py
├── transcribe.py
├── utils.py
└── vads/
├── __init__.py
├── pyannote.py
├── silero.py
└── vad.py
SYMBOL INDEX (118 symbols across 15 files)
FILE: whisperx/SubtitlesProcessor.py
function normal_round (line 4) | def normal_round(n):
function format_timestamp (line 10) | def format_timestamp(seconds: float, is_vtt: bool = False):
class SubtitlesProcessor (line 33) | class SubtitlesProcessor:
method __init__ (line 34) | def __init__(self, segments, lang, max_line_length = 45, min_char_leng...
method estimate_timestamp_for_word (line 47) | def estimate_timestamp_for_word(self, words, i, next_segment_start_tim...
method process_segments (line 76) | def process_segments(self, advanced_splitting=True):
method determine_advanced_split_points (line 99) | def determine_advanced_split_points(self, segment, next_segment_start_...
method generate_subtitles_from_split_points (line 141) | def generate_subtitles_from_split_points(self, segment, split_points, ...
method save (line 205) | def save(self, filename="subtitles.srt", advanced_splitting=True):
FILE: whisperx/__init__.py
function _lazy_import (line 4) | def _lazy_import(name):
function load_align_model (line 9) | def load_align_model(*args, **kwargs):
function align (line 14) | def align(*args, **kwargs):
function load_model (line 19) | def load_model(*args, **kwargs):
function load_audio (line 24) | def load_audio(*args, **kwargs):
function assign_word_speakers (line 29) | def assign_word_speakers(*args, **kwargs):
function setup_logging (line 34) | def setup_logging(*args, **kwargs):
function get_logger (line 46) | def get_logger(*args, **kwargs):
FILE: whisperx/__main__.py
function cli (line 12) | def cli():
FILE: whisperx/alignment.py
function load_align_model (line 79) | def load_align_model(language_code: str, device: str, model_name: Option...
function align (line 116) | def align(
function get_trellis (line 398) | def get_trellis(emission, tokens, blank_id=0):
class Point (line 422) | class Point:
function backtrack (line 428) | def backtrack(trellis, emission, tokens, blank_id=0):
class Segment (line 468) | class Segment:
method __repr__ (line 474) | def __repr__(self):
method length (line 478) | def length(self):
function merge_repeats (line 481) | def merge_repeats(path, transcript):
function merge_words (line 499) | def merge_words(segments, separator="|"):
FILE: whisperx/asr.py
function find_numeral_symbol_tokens (line 22) | def find_numeral_symbol_tokens(tokenizer):
class WhisperModel (line 31) | class WhisperModel(faster_whisper.WhisperModel):
method generate_segment_batched (line 37) | def generate_segment_batched(
method encode (line 95) | def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
class FasterWhisperPipeline (line 106) | class FasterWhisperPipeline(Pipeline):
method __init__ (line 114) | def __init__(
method _sanitize_parameters (line 153) | def _sanitize_parameters(self, **kwargs):
method preprocess (line 159) | def preprocess(self, audio):
method _forward (line 169) | def _forward(self, model_inputs):
method postprocess (line 173) | def postprocess(self, model_outputs):
method get_iterator (line 176) | def get_iterator(
method transcribe (line 197) | def transcribe(
method detect_language (line 300) | def detect_language(self, audio: np.ndarray) -> str:
function load_model (line 315) | def load_model(
FILE: whisperx/audio.py
function load_audio (line 25) | def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
function pad_or_trim (line 68) | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
function mel_filters (line 95) | def mel_filters(device, n_mels: int) -> torch.Tensor:
function log_mel_spectrogram (line 112) | def log_mel_spectrogram(
FILE: whisperx/conjunctions.py
function get_conjunctions (line 42) | def get_conjunctions(lang_code: str) -> Set[str]:
function get_comma (line 46) | def get_comma(lang_code: str) -> str:
FILE: whisperx/diarize.py
class IntervalTree (line 14) | class IntervalTree:
method __init__ (line 22) | def __init__(self, intervals: List[Tuple[float, float, str]]):
method query (line 41) | def query(self, start: float, end: float) -> List[Tuple[str, float]]:
method find_nearest (line 72) | def find_nearest(self, time: float) -> Optional[str]:
class DiarizationPipeline (line 91) | class DiarizationPipeline:
method __init__ (line 92) | def __init__(
method __call__ (line 105) | def __call__(
function assign_word_speakers (line 185) | def assign_word_speakers(
class Segment (line 266) | class Segment:
method __init__ (line 267) | def __init__(self, start:int, end:int, speaker:Optional[str]=None):
FILE: whisperx/log_utils.py
function setup_logging (line 9) | def setup_logging(
function get_logger (line 52) | def get_logger(name: str) -> logging.Logger:
FILE: whisperx/schema.py
class SingleWordSegment (line 11) | class SingleWordSegment(TypedDict):
class SingleCharSegment (line 20) | class SingleCharSegment(TypedDict):
class SingleSegment (line 30) | class SingleSegment(TypedDict):
class SegmentData (line 41) | class SegmentData(TypedDict):
class SingleAlignedSegment (line 52) | class SingleAlignedSegment(TypedDict):
class TranscriptionResult (line 65) | class TranscriptionResult(TypedDict):
class AlignedTranscriptionResult (line 73) | class AlignedTranscriptionResult(TypedDict):
FILE: whisperx/transcribe.py
function transcribe_task (line 20) | def transcribe_task(args: dict, parser: argparse.ArgumentParser):
FILE: whisperx/utils.py
function make_safe (line 156) | def make_safe(string):
function make_safe (line 163) | def make_safe(string):
function exact_div (line 168) | def exact_div(x, y):
function str2bool (line 173) | def str2bool(string):
function optional_int (line 181) | def optional_int(string):
function optional_float (line 185) | def optional_float(string):
function compression_ratio (line 189) | def compression_ratio(text) -> float:
function format_timestamp (line 194) | def format_timestamp(
class ResultWriter (line 215) | class ResultWriter:
method __init__ (line 218) | def __init__(self, output_dir: str):
method __call__ (line 221) | def __call__(self, result: dict, audio_path: str, options: dict):
method write_result (line 231) | def write_result(self, result: dict, file: TextIO, options: dict):
class WriteTXT (line 235) | class WriteTXT(ResultWriter):
method write_result (line 238) | def write_result(self, result: dict, file: TextIO, options: dict):
class SubtitlesWriter (line 248) | class SubtitlesWriter(ResultWriter):
method iterate_result (line 252) | def iterate_result(self, result: dict, options: dict):
method format_timestamp (line 363) | def format_timestamp(self, seconds: float):
class WriteVTT (line 371) | class WriteVTT(SubtitlesWriter):
method write_result (line 376) | def write_result(self, result: dict, file: TextIO, options: dict):
class WriteSRT (line 382) | class WriteSRT(SubtitlesWriter):
method write_result (line 387) | def write_result(self, result: dict, file: TextIO, options: dict):
class WriteTSV (line 394) | class WriteTSV(ResultWriter):
method write_result (line 406) | def write_result(self, result: dict, file: TextIO, options: dict):
class WriteAudacity (line 413) | class WriteAudacity(ResultWriter):
method write_result (line 427) | def write_result(self, result: dict, file: TextIO, options: dict):
class WriteJSON (line 436) | class WriteJSON(ResultWriter):
method write_result (line 439) | def write_result(self, result: dict, file: TextIO, options: dict):
function get_writer (line 443) | def get_writer(
function interpolate_nans (line 470) | def interpolate_nans(x, method='nearest'):
FILE: whisperx/vads/pyannote.py
function load_vad_model (line 21) | def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, token=None...
class Binarize (line 51) | class Binarize:
method __init__ (line 84) | def __init__(
method __call__ (line 108) | def __call__(self, scores: SlidingWindowFeature) -> Annotation:
class VoiceActivitySegmentation (line 188) | class VoiceActivitySegmentation(VoiceActivityDetection):
method __init__ (line 189) | def __init__(
method apply (line 199) | def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> A...
class Pyannote (line 233) | class Pyannote(Vad):
method __init__ (line 235) | def __init__(self, device, token=None, model_fp=None, **kwargs):
method __call__ (line 240) | def __call__(self, audio: AudioFile, **kwargs):
method preprocess_audio (line 244) | def preprocess_audio(audio):
method merge_chunks (line 248) | def merge_chunks(segments,
FILE: whisperx/vads/silero.py
class Silero (line 18) | class Silero(Vad):
method __init__ (line 20) | def __init__(self, **kwargs):
method __call__ (line 33) | def __call__(self, audio: AudioFile, **kwargs):
method preprocess_audio (line 55) | def preprocess_audio(audio):
method merge_chunks (line 59) | def merge_chunks(segments_list,
FILE: whisperx/vads/vad.py
class Vad (line 7) | class Vad:
method __init__ (line 8) | def __init__(self, vad_onset):
method preprocess_audio (line 15) | def preprocess_audio(audio):
method merge_chunks (line 20) | def merge_chunks(segments,
Condensed preview — 28 files, each showing path, character count, and a content snippet. Download the .json file or copy it to get the full structured content (157K chars).
[
{
"path": ".github/FUNDING.yml",
"chars": 46,
"preview": "custom: https://www.buymeacoffee.com/maxhbain\n"
},
{
"path": ".github/workflows/build-and-release.yml",
"chars": 692,
"preview": "name: Build and release\n\non:\n release:\n types: [published]\n\njobs:\n build:\n runs-on: ubuntu-latest\n steps:\n "
},
{
"path": ".github/workflows/python-compatibility.yml",
"chars": 771,
"preview": "name: Python Compatibility Test\n\non:\n push:\n branches: [main]\n pull_request:\n branches: [main]\n workflow_dispat"
},
{
"path": ".gitignore",
"chars": 3414,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": ".python-version",
"chars": 5,
"preview": "3.10\n"
},
{
"path": "CUDNN_TROUBLESHOOTING.md",
"chars": 2957,
"preview": "# Troubleshooting cuDNN Loading Errors\n\nThis guide helps resolve common cuDNN-related errors when running WhisperX on GP"
},
{
"path": "EXAMPLES.md",
"chars": 1259,
"preview": "# More Examples\n\n## Other Languages\n\nFor non-english ASR, it is best to use the `large` whisper model. Alignment models "
},
{
"path": "LICENSE",
"chars": 1297,
"preview": "BSD 2-Clause License\n\nCopyright (c) 2024, Max Bain\n\nRedistribution and use in source and binary forms, with or without\nm"
},
{
"path": "MANIFEST.in",
"chars": 67,
"preview": "include whisperx/assets/*\ninclude LICENSE\ninclude requirements.txt\n"
},
{
"path": "README.md",
"chars": 15455,
"preview": "<h1 align=\"center\">WhisperX</h1>\n\n## Recall.ai - Meeting Transcription API\n\nIf you’re looking for a transcription API fo"
},
{
"path": "pyproject.toml",
"chars": 2032,
"preview": "[project]\nurls = { repository = \"https://github.com/m-bain/whisperx\" }\nauthors = [{ name = \"Max Bain\" }]\nname = \"whisper"
},
{
"path": "whisperx/SubtitlesProcessor.py",
"chars": 9408,
"preview": "import math\r\nfrom whisperx.conjunctions import get_conjunctions, get_comma\r\n\r\ndef normal_round(n):\r\n if n - math.floo"
},
{
"path": "whisperx/__init__.py",
"chars": 1452,
"preview": "import importlib\n\n\ndef _lazy_import(name):\n module = importlib.import_module(f\"whisperx.{name}\")\n return module\n\n\n"
},
{
"path": "whisperx/__main__.py",
"chars": 9227,
"preview": "import argparse\nimport importlib.metadata\nimport platform\n\nimport torch\n\nfrom whisperx.utils import (LANGUAGES, TO_LANGU"
},
{
"path": "whisperx/alignment.py",
"chars": 20087,
"preview": "\"\"\"\nForced Alignment with Whisper\nC. Max Bain\n\"\"\"\nfrom dataclasses import dataclass\nfrom typing import Iterable, Optiona"
},
{
"path": "whisperx/asr.py",
"chars": 17602,
"preview": "import os\nfrom typing import List, Optional, Union\nfrom dataclasses import replace\n\nimport ctranslate2\nimport faster_whi"
},
{
"path": "whisperx/audio.py",
"chars": 4926,
"preview": "import os\nimport subprocess\nfrom functools import lru_cache\nfrom typing import Optional, Union\n\nimport numpy as np\nimpor"
},
{
"path": "whisperx/conjunctions.py",
"chars": 5132,
"preview": "# conjunctions.py\r\n\r\nfrom typing import Set\r\n\r\n\r\nconjunctions_by_language = {\r\n 'en': {'and', 'whether', 'or', 'as', "
},
{
"path": "whisperx/diarize.py",
"chars": 10438,
"preview": "import numpy as np\nimport pandas as pd\nfrom pyannote.audio import Pipeline\nfrom typing import Optional, Union, List, Tup"
},
{
"path": "whisperx/log_utils.py",
"chars": 1926,
"preview": "import logging\nimport sys\nfrom typing import Optional\n\n_LOG_FORMAT = \"%(asctime)s - %(name)s - %(levelname)s - %(message"
},
{
"path": "whisperx/schema.py",
"chars": 1839,
"preview": "from typing import Callable, TypedDict, Optional, List, Tuple\n\nProgressCallback = Optional[Callable[[float], None]]\n\ntry"
},
{
"path": "whisperx/transcribe.py",
"chars": 9006,
"preview": "import argparse\nimport gc\nimport os\nimport warnings\n\nimport numpy as np\nimport torch\n\nfrom whisperx.alignment import ali"
},
{
"path": "whisperx/utils.py",
"chars": 15667,
"preview": "import json\nimport os\nimport re\nimport sys\nimport zlib\nfrom typing import Callable, Optional, TextIO\n\nLANGUAGES = {\n "
},
{
"path": "whisperx/vads/__init__.py",
"chars": 147,
"preview": "from whisperx.vads.pyannote import Pyannote as Pyannote\nfrom whisperx.vads.silero import Silero as Silero\nfrom whisperx."
},
{
"path": "whisperx/vads/pyannote.py",
"chars": 9880,
"preview": "import os\nfrom typing import Callable, Text, Union\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom pya"
},
{
"path": "whisperx/vads/silero.py",
"chars": 3043,
"preview": "from io import IOBase\nfrom pathlib import Path\nfrom typing import Mapping, Text\nfrom typing import Optional\nfrom typing "
},
{
"path": "whisperx/vads/vad.py",
"chars": 1575,
"preview": "from typing import Optional\n\nimport pandas as pd\nfrom pyannote.core import Annotation, Segment\n\n\nclass Vad:\n def __in"
}
]
// ... and 1 more file (download for full content)
About this extraction
This page contains the full source code of the m-bain/whisperX GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 28 files (145.8 KB), approximately 37.2k tokens, and a symbol index with 118 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.