Repository: wenet-e2e/wesep
Branch: master
Commit: 99eca54b6030
Files: 128
Total size: 591.2 KB
Directory structure:
gitextract_s23ej_br/
├── .clang-format
├── .flake8
├── .github/
│ └── workflows/
│ └── lint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CPPLINT.cfg
├── README.md
├── examples/
│ ├── librimix/
│ │ └── tse/
│ │ ├── README.md
│ │ ├── v1/
│ │ │ ├── README.md
│ │ │ ├── confs/
│ │ │ │ ├── bsrnn.yaml
│ │ │ │ ├── dpcc_init_gan.yaml
│ │ │ │ ├── dpccn.yaml
│ │ │ │ └── tfgridnet.yaml
│ │ │ ├── local/
│ │ │ │ ├── prepare_data.sh
│ │ │ │ ├── prepare_librimix_enroll.py
│ │ │ │ └── prepare_spk2enroll_librispeech.py
│ │ │ ├── path.sh
│ │ │ └── run.sh
│ │ └── v2/
│ │ ├── README.md
│ │ ├── confs/
│ │ │ ├── bsrnn.yaml
│ │ │ ├── bsrnn_feats.yaml
│ │ │ ├── bsrnn_multi_optim.yaml
│ │ │ ├── dpcc_init_gan.yaml
│ │ │ ├── dpccn.yaml
│ │ │ ├── spexplus.yaml
│ │ │ └── tfgridnet.yaml
│ │ ├── local/
│ │ │ ├── prepare_data.sh
│ │ │ ├── prepare_librimix_enroll.py
│ │ │ └── prepare_spk2enroll_librispeech.py
│ │ ├── path.sh
│ │ └── run.sh
│ └── voxceleb1/
│ └── v2/
│ ├── confs/
│ │ └── bsrnn_online.yaml
│ ├── local/
│ │ ├── prepare_data.sh
│ │ ├── prepare_librimix_enroll.py
│ │ ├── prepare_spk2enroll_librispeech.py
│ │ └── prepare_spk2enroll_vox.py
│ ├── path.sh
│ └── run_online.sh
├── requirements.txt
├── runtime/
│ ├── .gitignore
│ ├── CMakeLists.txt
│ ├── README.md
│ ├── bin/
│ │ ├── CMakeLists.txt
│ │ └── separate_main.cc
│ ├── cmake/
│ │ ├── gflags.cmake
│ │ ├── glog.cmake
│ │ └── libtorch.cmake
│ ├── frontend/
│ │ ├── CMakeLists.txt
│ │ ├── fbank.h
│ │ ├── feature_pipeline.cc
│ │ ├── feature_pipeline.h
│ │ ├── fft.cc
│ │ ├── fft.h
│ │ └── wav.h
│ ├── separate/
│ │ ├── CMakeLists.txt
│ │ ├── separate_engine.cc
│ │ └── separate_engine.h
│ └── utils/
│ ├── CMakeLists.txt
│ ├── blocking_queue.h
│ ├── timer.h
│ ├── utils.cc
│ └── utils.h
├── setup.py
├── tools/
│ ├── extract_embed_depreciated.py
│ ├── make_lmdb.py
│ ├── make_shard_list_premix.py
│ ├── make_shard_online.py
│ ├── parse_options.sh
│ ├── print_train_val_curve.py
│ ├── run.pl
│ ├── score.sh
│ ├── show_enh_score.sh
│ ├── split_scp.pl
│ └── test_dataset.py
└── wesep/
├── __init__.py
├── bin/
│ ├── average_model.py
│ ├── export_jit.py
│ ├── infer.py
│ ├── score.py
│ ├── train.py
│ └── train_gan.py
├── cli/
│ ├── __init__.py
│ ├── extractor.py
│ ├── hub.py
│ └── utils.py
├── dataset/
│ ├── FRAM_RIR.py
│ ├── dataset.py
│ ├── lmdb_data.py
│ ├── processor.py
│ └── vad.py
├── models/
│ ├── __init__.py
│ ├── bsrnn.py
│ ├── bsrnn_feats.py
│ ├── bsrnn_multi_optim.py
│ ├── convtasnet.py
│ ├── dpccn.py
│ ├── sep_model.py
│ └── tfgridnet.py
├── modules/
│ ├── __init__.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── norm.py
│ │ └── speaker.py
│ ├── dpccn/
│ │ ├── __init__.py
│ │ └── convs.py
│ ├── metric_gan/
│ │ ├── __init__.py
│ │ └── discriminator.py
│ ├── tasnet/
│ │ ├── __init__.py
│ │ ├── convs.py
│ │ ├── decoder.py
│ │ ├── encoder.py
│ │ ├── separation.py
│ │ ├── separator.py
│ │ └── speaker.py
│ └── tfgridnet/
│ ├── __init__.py
│ └── gridnet_block.py
└── utils/
├── abs_loss.py
├── checkpoint.py
├── datadir_writer.py
├── dnsmos.py
├── executor.py
├── executor_gan.py
├── file_utils.py
├── funcs.py
├── losses.py
├── schedulers.py
├── score.py
├── signal.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .clang-format
================================================
---
Language: Cpp
# BasedOnStyle: Google
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^<.*\.h>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Auto
TabWidth: 8
UseTab: Never
...
================================================
FILE: .flake8
================================================
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 80
max-doc-length = 80
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,B905,
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
exclude =
================================================
FILE: .github/workflows/lint.yml
================================================
name: Lint
on:
push:
branches:
- main
pull_request:
jobs:
quick-checks:
runs-on: ubuntu-latest
steps:
- name: Fetch Wenet
uses: actions/checkout@v1
- name: Checkout PR tip
run: |
set -eux
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
# We are on a PR, so actions/checkout leaves us on a merge commit.
# Check out the actual tip of the branch.
git checkout ${{ github.event.pull_request.head.sha }}
fi
echo ::set-output name=commit_sha::$(git rev-parse HEAD)
id: get_pr_tip
- name: Ensure no tabs
run: |
(! git grep -I -l $'\t' -- . ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false))
- name: Ensure no trailing whitespace
run: |
(! git grep -I -n $' $' -- . ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false))
flake8-py3:
runs-on: ubuntu-latest
steps:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.9
architecture: x64
- name: Fetch Wenet
uses: actions/checkout@v1
- name: Checkout PR tip
run: |
set -eux
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
# We are on a PR, so actions/checkout leaves us on a merge commit.
# Check out the actual tip of the branch.
git checkout ${{ github.event.pull_request.head.sha }}
fi
echo ::set-output name=commit_sha::$(git rev-parse HEAD)
id: get_pr_tip
- name: Run flake8
run: |
set -eux
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
flake8 --version
flake8
if [ $? != 0 ]; then exit 1; fi
cpplint:
runs-on: ubuntu-latest
steps:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.x
architecture: x64
- name: Fetch Wenet
uses: actions/checkout@v1
- name: Checkout PR tip
run: |
set -eux
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
# We are on a PR, so actions/checkout leaves us on a merge commit.
# Check out the actual tip of the branch.
git checkout ${{ github.event.pull_request.head.sha }}
fi
echo ::set-output name=commit_sha::$(git rev-parse HEAD)
id: get_pr_tip
- name: Run cpplint
run: |
set -eux
pip install cpplint==1.6.1
cpplint --version
cpplint --recursive .
if [ $? != 0 ]; then exit 1; fi
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.egg-info
# Visual Studio Code files
.vscode
.vs
# PyCharm files
.idea
venv
# Eclipse Project settings
*.*project
.settings
# Sublime Text settings
*.sublime-workspace
*.sublime-project
# Editor temporaries
*.swn
*.swo
*.swp
*.swm
*~
# IPython notebook checkpoints
.ipynb_checkpoints
# macOS dir files
.DS_Store
exp
data
raw_wav
tensorboard
**/*build*
wespeaker_models
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-yapf
rev: 'v0.32.0'
hooks:
- id: yapf
- repo: https://github.com/pycqa/flake8
rev: '3.8.2'
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: 'v17.0.6'
hooks:
- id: clang-format
- repo: https://github.com/cpplint/cpplint
rev: '1.6.1'
hooks:
- id: cpplint
================================================
FILE: CPPLINT.cfg
================================================
root=runtime
filter=-build/c++11
================================================
FILE: README.md
================================================
# WeSep
> We aim to build a toolkit focusing on front-end processing in the cocktail party setup, including target speaker extraction and ~~speech separation (Future work)~~
### Install for development & deployment
* Clone this repo
``` sh
git clone https://github.com/wenet-e2e/wesep.git
```
* Create a conda env: PyTorch version >= 1.12.0 is required!
``` sh
conda create -n wesep python=3.9
conda activate wesep
conda install pytorch=1.12.1 torchaudio=0.12.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install -r requirements.txt
pre-commit install # for clean and tidy code
```
## The Target Speaker Extraction Task
> Target speaker extraction (TSE) focuses on isolating the speech of a specific target speaker from overlapped multi-talker speech, which is a typical setup in the cocktail party problem.
WeSep features flexible target speaker modeling, scalable data management, effective on-the-fly data simulation, structured recipes, and deployment support.
<img src="resources/tse.png" width="600px">
## Features (To Do List)
- [x] On-the-fly data simulation
- [x] Dynamic Mixture simulation
- [x] Dynamic Reverb simulation
- [x] Dynamic Noise simulation
- [x] Support time- and frequency-domain models
  - Time-domain
    - [x] Conv-TasNet-based models
    - [x] SpEx+
  - Frequency-domain
    - [x] pBSRNN
    - [x] pDPCCN
    - [x] TF-GridNet (extremely slow, needs double-checking)
- [ ] Training Criteria
- [x] SISNR loss
- [x] GAN loss (Need further investigation)
- [ ] Datasets
- [x] Libri2Mix (Illustration for pre-mixed speech)
- [x] VoxCeleb (Illustration for online training)
- [ ] WSJ0-2Mix
- [ ] Speaker Embedding
  - [x] WeSpeaker Integration
  - [x] Jointly Learned Speaker Embedding
- [x] Different fusion methods
- [ ] Pretrained models
- [ ] CLI Usage
- [x] Runtime
## Data Pipe Design
Following WeNet and WeSpeaker, WeSep organizes its data processing modules as a pipeline of processors. The following figure shows such a pipeline with the essential processors.
<img src="resources/datapipe.png" width="800px">
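Conceptually, each processor consumes the sample stream produced by the previous one and yields transformed samples, so a pipeline is just a chain of generators. The sketch below illustrates this idea with made-up processors; the real implementations live in `wesep/dataset/processor.py` and differ in detail.
```python
# A minimal sketch of the processor-chain idea; the processor bodies here
# are illustrative, not the actual wesep.dataset.processor API.
import random


def resample(samples, target_sr=16000):
    for sample in samples:  # each sample is a dict
        # A real processor would resample sample["wav"] here.
        sample["sample_rate"] = target_sr
        yield sample


def shuffle(samples, shuffle_size=2500):
    buf = []
    for sample in samples:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    random.shuffle(buf)  # flush the remainder
    yield from buf


def build_pipeline(source):
    stream = resample(source)  # processors compose by wrapping
    stream = shuffle(stream)   # the upstream iterator
    return stream


for sample in build_pipeline({"key": f"utt{i}"} for i in range(5)):
    print(sample["key"], sample["sample_rate"])
```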
## Discussion
For Chinese users, you can scan the QR code on the left to join our group directly. If it has expired, please scan the personal WeChat QR code on the right.
|<img src='resources/Wechat_group.jpg' style=" width: 200px; height: 300px;">|<img src='resources/Wechat.jpg' style=" width: 200px; height: 300px;">|
| ---- | ---- |
## Citations
If you find WeSep useful, please cite it as
```bibtex
@inproceedings{wang24fa_interspeech,
title = {WeSep: A Scalable and Flexible Toolkit Towards Generalizable Target Speaker Extraction},
author = {Shuai Wang and Ke Zhang and Shaoxiong Lin and Junjie Li and Xuefei Wang and Meng Ge and Jianwei Yu and Yanmin Qian and Haizhou Li},
year = {2024},
booktitle = {Interspeech 2024},
pages = {4273--4277},
doi = {10.21437/Interspeech.2024-1840},
}
```
================================================
FILE: examples/librimix/tse/README.md
================================================
# Libri2Mix Recipe
## Goal of this recipe
This recipe illustrates how to use WeSep to perform the target speaker extraction task on a pre-defined training set such as Libri2Mix, where the mixtures have already been prepared on disk. If you want to see online data processing and scaling to larger training sets, please check the voxceleb1 recipe.
## Difference of V1 and V2
The difference between v1 and v2 lies in the approach to speaker modeling.
- v1 outlines a process where WeSpeaker is used to extract embeddings beforehand, which are then saved to disk and sampled during training. This setup allows you to use other toolkits or existing speaker embeddings freely.
- v2 demonstrates a more integrated approach with the WeSpeaker toolkit (recommended). You can choose to either fix the speaker encoder or train it jointly with the separation backbone.
================================================
FILE: examples/librimix/tse/v1/README.md
================================================
## Tutorial on LibriMix
If you encounter any problems when going through this tutorial, please feel free to ask in the GitHub issues. Thanks for any kind of feedback.
NOTE: WE DON'T RECOMMEND THIS VERSION, IT'S JUST FOR ILLUSTRATING HOW TO USE YOUR OWN EXTRACTOR
YOU NEED TO INSTALL WESPEAKER FIRST, CHECK `https://github.com/wenet-e2e/wespeaker` FOR THE INSTRUCTION
### First Experiment
We provide a recipe `examples/librimix/tse/v1/run.sh` on LibriMix data.
The recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process.
```bash
cd examples/librimix/tse/v1
bash run.sh --stage 1 --stop_stage 1
bash run.sh --stage 2 --stop_stage 2
bash run.sh --stage 3 --stop_stage 3
bash run.sh --stage 4 --stop_stage 4
bash run.sh --stage 5 --stop_stage 5
bash run.sh --stage 6 --stop_stage 6
```
You could also just run the whole script
```bash
bash run.sh --stage 1 --stop_stage 6
```
------
### Stage 1: Prepare Training Data
Before executing this stage, we assume that you have the LibriMix dataset stored locally or otherwise accessible; assign its path to `Libri2Mix_dir`.
As the LibriMix dataset is available in multiple versions, each determined by factors like the number of sources in the mixtures and the sampling rate, you can choose the desired version by adjusting the following variables in `run.sh`:
+ `fs`: the sample rate of the dataset, valid options are `16k` and `8k`.
+ `min_max`: the mode of mixtures, valid options are `min` and `max`.
+ `noise_type`: the type of mixture, valid options are `clean` and `both`.
In our recipe, we opt for the Libri2Mix data with a sampling rate of 16 kHz, in 'min' mode, and without noise, configured as follows:
``` bash
fs=16k
min_max=min
noise_type="clean"
Libri2Mix_dir=/path/to/Libri2Mix
```
After configuring the desired dataset version, running the script for the first stage will generate the prepared data files. By default, these files are stored in the `data` directory in the current location. However, you can customize the `data` variable in `run.sh` to save these files in any desired location.
```bash
data=data # you can change this to any directory
```
In this stage, `local/prepare_data.sh` accomplishes three tasks:
1. Organizes the original Libri2Mix dataset into three directories `dev`, `test` and `train_100`, each containing the following files:
+ `single.utt2spk`: each line records two space-separated columns: `clean_wav_id` and `speaker_id`
```text
s1/103-1240-0003_1235-135887-0017.wav 103
s1/103-1240-0004_4195-186237-0003.wav 103
...
```
+ `utt2spk`: each line records three space-separated columns: `mixture_wav_id`, `speaker1_id` and `speaker2_id`.
```
103-1240-0003_1235-135887-0017 103 1235
103-1240-0004_4195-186237-0003 103 4195
...
```
+ `single.wav.scp`: each line records two space-separated columns: `clean_wav_id` and `clean_wav_path`
```
s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0003_1235-135887-0017.wav
s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0004_4195-186237-0003.wav
...
```
+ `wav.scp`: each line records four space-separated columns: `mixture_wav_id`, `mixture_wav_path`, `clean_wav1_path` and `clean_wav2_path`.
```
103-1240-0003_1235-135887-0017 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0003_1235-135887-0017.wav
103-1240-0004_4195-186237-0003 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0004_4195-186237-0003.wav
...
```
2. Prepares the speaker embeddings using WeSpeaker pretrained models. This step will generate two files in the `dev`, `test`, and `train_100` directories respectively:
+ `embed.ark`: Kaldi ark file that stores the speaker embeddings.
+ `embed.scp`: each line records two space-separated columns: `clean_wav_id` and `spk_embed_path`
```
s1/103-1240-0003_1235-135887-0017.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:1450569
s1/103-1240-0004_4195-186237-0003.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:10622715
...
```
3. Prepares the LibriMix target-speaker enrollment signals. This step will generate four files in the `dev` and `test` directories respectively:
+ `mixture2enrollment`: each line records three space-separated columns: `mixture_wav_id`, `clean_wav_id` and `enrollment_wav_id`.
```
4077-13754-0001_5142-33396-0065 4077-13754-0001 s1/4077-13754-0004_5142-36377-0020
4077-13754-0001_5142-33396-0065 5142-33396-0065 s1/5142-36377-0003_1320-122612-0014
...
```
+ `spk1.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.
```
1272-128104-0000_2035-147961-0014 s1/1272-135031-0015_2277-149896-0006.wav
1272-128104-0003_2035-147961-0016 s1/1272-135031-0013_1988-147956-0016.wav
...
```
+ `spk2.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.
```
1272-128104-0000_2035-147961-0014 s1/2035-152373-0009_3000-15664-0016.wav
1272-128104-0003_2035-147961-0016 s2/6313-66129-0013_2035-152373-0012.wav
...
```
+ `spk2enroll.json`: A JSON file, where the format of the stored key-value pairs is `{spk_id: [[spk_id_with_prefix_or_suffix, wav_path], ...]}`.
```
"652": [["652-129742-0010", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0010_3081-166546-0071.wav"],
...,
["652-129742-0000", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0000_1993-147966-0004.wav"]],
...
```
At the end of this stage, the directory structure of `data` should look like this:
```
data/
|__ clean/ # the noise_type you chose
|__ dev/
| |__ embed.ark
| |__ embed.scp
| |__ mixture2enrollment
| |__ single.utt2spk
| |__ single.wav.scp
| |__ spk1.enroll
| |__ spk2.enroll
| |__ spk2enroll.json # empty
| |__ utt2spk
| |__ wav.scp
|
|__ test/ # the same as dev/
|
|__ train_100/
|__ embed.ark
|__ embed.scp
|__ single.utt2spk
|__ single.wav.scp
|__ utt2spk
|__ wav.scp
```
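As a quick sanity check of the prepared files, note that they are all plain-text Kaldi-style tables and easy to parse; a minimal sketch (paths assume the default `data` directory):
```python
# Minimal sketch: load wav.scp and single.utt2spk into dicts for inspection.
def load_table(path):
    """Each line: <id> followed by one or more space-separated fields."""
    table = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if parts:
                table[parts[0]] = parts[1:]
    return table


wav_scp = load_table("data/clean/train-100/wav.scp")         # id -> [mix, s1, s2]
utt2spk = load_table("data/clean/train-100/single.utt2spk")  # id -> [spk]
print(len(wav_scp), "mixtures,", len(utt2spk), "single-speaker utterances")
```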
------
### Stage 2: Convert Data Format
This stage transforms the data into the `shard` format, which is better suited for large datasets. The core idea is to pack the audio and labels of many small samples (e.g., 1000 utterances) into compressed archives (tar files) and read them through PyTorch's IterableDataset. For a detailed explanation of the `shard` format, please refer to the [documentation](https://github.com/wenet-e2e/wenet/blob/main/docs/UIO.md) available in WeNet.
This stage will generate a subdirectory and a file in the `dev`, `test`, and `train_100` directories respectively:
+ `shards/`: this directory stores the compressed packets (tar) files.
```bash
ls shards
shards_000000000.tar shards_000000001.tar shards_000000002.tar ...
```
+ `shard.list`: each line records the path to the corresponding tar file.
```
data/clean/dev/shards/shards_000000000.tar
data/clean/dev/shards/shards_000000001.tar
data/clean/dev/shards/shards_000000002.tar
...
```
At the end of this stage, the directory structure of `data` should look like this:
```
data/
|__ clean/ # the noise_type you chose
|__ dev/
| |__ embed.ark, embed.scp, ... # files generated by Stage 1
| |__ shard.list
| |__ shards/
| |__ shards_000000000.tar
| |__ shards_000000001.tar
| |__ shards_000000002.tar
|
|__ test/ # the same as dev/
|
|__ train_100/
|__ embed.ark, embed.scp, ... # files generated by Stage 1
|__ shard.list
|__ shards/
|__ shards_000000000.tar
|__ ...
|__ shards_000000013.tar
```
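If you want to peek inside a shard without going through the WeSep dataloader, the tar can be listed directly. The per-sample entry naming inside the archives is determined by `tools/make_shard_list_premix.py`, so treat the layout as something to inspect rather than assume:
```python
# List the first few entries of one shard tar to inspect its layout.
import tarfile

with tarfile.open("data/clean/dev/shards/shards_000000000.tar") as tar:
    for member in tar.getmembers()[:10]:
        print(member.name, member.size)
```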
------
### Stage 3: Neural Network Training
You can configure training-related parameters through the configuration file. We provide some ready-to-use configuration files in the recipe. If you wish to write your own configuration files or understand the meaning of certain parameters, you can refer to the following information:
+ **overall training process related**
```yaml
seed: 42
exp_dir: exp/BSRNN
enable_amp: false
gpus: '0,1'
log_batch_interval: 100
save_epoch_interval: 1
joint_training: false
```
Explanations for some of the parameters mentioned above:
+ `seed`: specify a random seed.
+ `exp_dir`: specify the experiment directory.
+ `enable_amp`: whether to enable automatic mixed precision.
+ `gpus`: specify the visible GPUs during training.
+ `log_batch_interval`: specify after how many batch iterations to record a log entry.
+ `save_epoch_interval`: specify after how many epochs to save a checkpoint.
+ `joint_training`: specify whether the model for extracting speaker embeddings is jointly trained with the TSE model. Defaults to `false`.
+ **dataset and dataloader related**
```yaml
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
whole_utt: false
chunk_len: 48000
online_mix: false
data_type: "shard"
train_data: "data/clean/train_100/shard.list"
train_spk_embeds: "data/clean/train_100/embed.scp"
train_utt2spk: "data/clean/train_100/single.utt2spk"
train_spk2utt: "data/clean/train_100/spk2enroll.json"
val_data: "data/clean/dev/shard.list"
val_spk_embeds: "data/clean/dev/embed.scp"
val_utt2spk: "data/clean/dev/single.utt2spk"
val_spk1_enroll: "data/clean/dev/spk1.enroll"
val_spk2_enroll: "data/clean/dev/spk2.enroll"
val_spk2utt: "data/clean/dev/single.wav.scp"
dataloader_args:
batch_size: 16 # A800
drop_last: true
num_workers: 6
pin_memory: false
prefetch_factor: 6
```
Explanations for some of the parameters mentioned above:
+ `resample_rate`: All audio in the dataset will be resampled to this specified sample rate. Defaults to `16000`.
+ `sample_num_per_epoch`: Specifies how many samples from the full training set will be iterated over in each epoch during training. The default is `0`, which means iterating over the entire training set.
+ `shuffle`: Whether to perform *global* shuffle, i.e., shuffling at shards tar/raw/feat file level. Defaults to `true`.
+ `shuffle_size`: Parameter for *local* shuffling. Local shuffle maintains a buffer, and shuffling is only performed when the number of data items in the buffer reaches `shuffle_size`. Defaults to `2500`.
+ `whole_utt`: Whether the network input and training target are the entire audio segment. Defaults to `false`.
+ `chunk_len`: This parameter only takes effect when `whole_utt` is set to `false`. It indicates the length of the segment extracted from the complete audio as the network input and training target (see the sketch after this parameter list). Defaults to `48000`.
+ `online_mix`: Whether to dynamically mix speakers when loading data; `shuffle` will not take effect if this parameter is set to `true`. Defaults to `false`.
+ `data_type`: Specify the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.
+ `train_data`: File containing paths to the training set files.
+ `train_spk_embeds`: File containing paths to the speaker embeddings of training set.
+ `train_utt2spk`: Each line of the file specified by this parameter consists of `clean_wav_id` and `speaker_id`, separated by a space (e.g., `single.utt2spk` generated in Stage 1).
+ `train_spk2utt`: The file specified by this parameter is only used when the `joint_training` parameter is set to `true`. Each line of the file contains `speaker_id` and `enrollment_wav_id`.
+ `val_data`: File containing paths to the validation set files.
+ `val_spk_embeds`: Similar to `train_spk_embeds`.
+ `val_utt2spk`: Similar to `train_utt2spk`.
+ `val_spk1_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker1_enrollment_wav_id`, separated by a space.
+ `val_spk2_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker2_enrollment_wav_id`, separated by a space.
+ `val_spk2utt`: Each line of the file specified by this parameter consists of `clean_wav_id` and `clean_wav_path`, separated by a space (e.g., `single.wav.scp` generated in Stage 1).
+ We have denoted this parameter as `val_spk2utt`, but it is actually assigned the `single.wav.scp` file as its value. This might be perplexing for users familiar with file formats in Kaldi or ESPnet, where the `spk2utt` file typically consists of lines containing `spk_id` and `wav_id`, whereas the `wav.scp` file's lines contain `wav_id` and `wav_path`.
+ Nevertheless, upon closer examination of its role in subsequent procedures, it becomes evident that it is indeed employed to create a dictionary mapping speaker IDs to audio samples.
+ `batch_size`: how many samples per batch to load. Please note that the batch size mentioned here refers to the **batch size per GPU**. So, if you are training on two GPUs within a single node and set the batch size to 16, it is equivalent to setting the batch size to 32 in a single-GPU, single-node scenario.
+ `drop_last`: set to `true` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If `false` and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
+ `num_workers`: how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process.
+ `pin_memory`: If `true`, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
+ `prefetch_factor`: number of batches loaded in advance by each worker.
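To make the `whole_utt`/`chunk_len` behavior concrete, the following sketch shows the kind of chunking those parameters describe (an illustration, not the exact WeSep implementation):
```python
# Illustrative chunking for whole_utt=false: crop a random chunk_len window,
# zero-padding when the utterance is shorter. Not the exact WeSep code.
import torch


def random_chunk(wav: torch.Tensor, chunk_len: int = 48000) -> torch.Tensor:
    n = wav.shape[-1]
    if n < chunk_len:
        return torch.nn.functional.pad(wav, (0, chunk_len - n))
    start = torch.randint(0, n - chunk_len + 1, (1,)).item()
    return wav[..., start:start + chunk_len]


wav = torch.randn(1, 70000)     # a fake 70000-sample utterance
print(random_chunk(wav).shape)  # torch.Size([1, 48000])
```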
+ **loss function related**
```yaml
loss: SISDR
loss_args: { }
```
Explanations for some of the parameters mentioned above:
+ `loss`: the loss function used for training.
+ `loss_args`: the required arguments for the loss function.
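For reference, the scale-invariant SDR behind the `SISDR`/`SISNR` options follows a standard textbook formulation; below is a minimal PyTorch sketch of the negative SI-SDR used as a loss (the actual implementation lives in `wesep/utils/losses.py` and may differ in detail):
```python
# Textbook negative SI-SDR; lower is better when used as a training loss.
import torch


def neg_si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8):
    ref = ref - ref.mean(dim=-1, keepdim=True)
    est = est - est.mean(dim=-1, keepdim=True)
    # Project the estimate onto the reference (scale-invariant target).
    scale = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = scale * ref
    noise = est - target
    ratio = target.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps)
    return -(10 * torch.log10(ratio + eps)).mean()


est, ref = torch.randn(2, 48000), torch.randn(2, 48000)
print(neg_si_sdr(est, ref))
```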
In addition to the common loss functions, we also support a GAN loss. You can enable this feature by setting `use_gan_loss` to `true` in `run.sh`. Once enabled, the TSE model serves as the generator, and another convolutional neural network acts as the discriminator, engaging in adversarial training. The final loss of the TSE model is a combination of the loss specified in the configuration file and the GAN loss. By default, the weight for the former is set to `0.95` and for the latter to `0.05`.
Because of the GAN-loss support, many of the parameters below are split into `tse_model` and `discriminator` entries under a single key. In such cases, we do not explain each entry separately.
+ **neural network structure related**
```yaml
model:
tse_model: BSRNN
model_args:
tse_model:
sr: 16000
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spk_emb_dim: 256
spk_fuse_type: 'multiply'
use_spk_transform: False
model_init:
tse_model: exp/BSRNN/no_spk_transform-multiply_fuse/models/latest_checkpoint.pt
discriminator: null
```
Explanations for some of the parameters mentioned above:
+ `model`: specify the neural network used for training.
+ `model_args`: specify model-specific parameters.
+ `model_init`: whether to initialize the model with an existing checkpoint. Use `null` for no initialization. If you want to initialize, provide the checkpoint path. Defaults to `null`.
+ **model optimization related**
```yaml
num_epochs: 150
clip_grad: 5.0
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001
weight_decay: 0.0001
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
```
Explanations for some of the parameters mentioned above:
+ `num_epochs`: total number of training epochs.
+ `clip_grad`: set the threshold for gradient clipping.
+ `optimizer`: set the optimizer.
+ `optimizer_args`: the required arguments for optimizer.
+ `scheduler`: set the scheduler.
+ `scheduler_args`: the required arguments for the scheduler; a sketch of the `ExponentialDecrease` rule follows below.
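The `ExponentialDecrease` arguments suggest a learning rate that decays geometrically from `initial_lr` to `final_lr` over training; one parameterization consistent with those arguments is sketched below (an assumption; see `wesep/utils/schedulers.py` for the actual rule):
```python
# Hypothetical geometric decay from initial_lr to final_lr, with an optional
# linear warm-up, matching the argument names in the config above.
import math


def exponential_decrease(epoch, num_epochs, initial_lr=1e-3, final_lr=2.5e-5,
                         warm_up_epoch=0, warm_from_zero=False):
    if epoch < warm_up_epoch:
        start = 0.0 if warm_from_zero else final_lr
        return start + (initial_lr - start) * epoch / max(warm_up_epoch, 1)
    frac = (epoch - warm_up_epoch) / max(num_epochs - warm_up_epoch, 1)
    return initial_lr * math.exp(frac * math.log(final_lr / initial_lr))


for epoch in (0, 75, 150):
    print(epoch, exponential_decrease(epoch, 150))
```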
+ **others**
```yaml
num_avg: 2
```
Explanations for some of the parameters mentioned above:
+ `num_avg`: numbers for averaged model.
To avoid frequent changes to the configuration file, we support **overwriting values in the configuration file** directly from `run.sh`. For example, the following command in `run.sh` will overwrite the visible GPUs from `'0,1'` to `'0'` in the above configuration file:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
${train_script} --config confs/config.yaml \
--gpus "[0]" \
```
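Under the hood, this kind of override is essentially a dictionary update on the parsed YAML; a minimal sketch of the idea (illustrative, not the actual WeSep argument handling):
```python
# Illustrative: command-line values take precedence over the YAML config.
import yaml

with open("confs/bsrnn.yaml") as f:
    configs = yaml.safe_load(f)

overrides = {"gpus": "[0]"}  # values passed on the command line
configs.update(overrides)
print(configs["gpus"])  # '[0]' instead of the YAML's '0,1'
```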
At the end of this stage, an experiment directory will be created in the current directory, containing the following files:
```
${exp_dir}/
|__ train.log
|__ config.yaml
|__ models/
|__ checkpoint_1.pt
|__ ...
|__ checkpoint_150.pt
|__ final_checkpoint.pt -> checkpoint_150.pt
|__ latest_checkpoint.pt -> checkpoint_150.pt
```
------
### Stage 4: Extract Speech Using the Trained Model
After training is complete, you can execute stage 4 to extract the target speaker's speech using the trained model. In this stage, it mainly calls `wesep/bin/infer.py`, and you need to provide the following parameters for this script:
+ `config`: the configuration file used in Stage 3.
+ `fs`: the sample rate of the audio data.
+ `gpus`: the index of the visible GPU.
+ `exp_dir`: the experiment directory.
+ `data_type`: the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.
+ `test_data`: similar to `train_data`.
+ `test_spk_embeds`: similar to `train_spk_embeds`.
+ `test_spk1_enroll`: similar to `val_spk1_enroll`.
+ `test_spk2_enroll`: similar to `val_spk2_enroll`.
+ `test_spk2utt`: similar to `val_spk2utt`.
+ `checkpoint`: the path to the checkpoint used for extracting the target speaker's speech.
At the end of this stage, the structure of the experiment directory should look like this:
```
${exp_dir}/
|__ train.log
|__ config.yaml
|__ models/
|__ infer.log
|__ audio/
|__ spk1.scp # each line records two space-separated columns: `target_wav_id` and `target_wav_path`
|__ Utt1001-4992-41806-0008_6930-75918-0015-T4992.wav
|__ ...
|__ Utt999-61-70968-0003_2830-3980-0008-T61.wav
```
------
### Stage 5: Scoring
In this stage, we evaluate the quality of the generated speech using common objective metrics. The default metrics include **STOI**, **SDR**, **SAR**, **SIR**, and **SI_SNR**. In addition to these, you can include **PESQ** and **DNS_MOS** by setting `use_pesq` and `use_dnsmos` to `true`. Please be aware that DNS_MOS only supports audio with a **16 kHz** sampling rate; for other sampling rates, do not use DNS_MOS.
At the end of this stage, a markdown file `RESULTS.md` will be created under `exp` directory, the directory structure of `exp` should look like this:
```
exp/BSRNN/
|__ ${exp_dir}
| |__ train.log, ... # files and directories generated in Stage 4
| |__ scoring/
|
|__ RESULTS.md
```
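If you want to compute a couple of these metrics by hand on a single pair of files, the `pystoi` and `pesq` PyPI packages expose simple functions; a sketch is shown below (the recipe's own `tools/score.sh` pipeline may use different backends, and `ref.wav`/`est.wav` are placeholder paths):
```python
# Standalone metric check for one utterance pair.
import soundfile as sf
from pesq import pesq
from pystoi import stoi

ref, fs = sf.read("ref.wav")  # ground-truth target speech
est, _ = sf.read("est.wav")   # extracted speech from Stage 4
print("STOI:", stoi(ref, est, fs, extended=False))
print("PESQ (wb):", pesq(fs, ref, est, "wb"))  # 'wb' requires fs == 16000
```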
------
### Stage 6: Apply Model Average
In this stage, we perform model averaging, and you need to specify the following parameters in `run.sh`:
+ `dst_model`: the path to save the averaged model.
+ `src_path`: source models path for average.
+ `num`: number of source models for the averaged model.
+ `mode`: the mode for model averaging. Valid options are `final` and `best`.
+ `final`: filters and sorts the latest PyTorch model files in the source directory. Averages the states of the last `num` models based on a numerical sorting of their filenames.
+ `best`: directly uses user-specified epochs to select specific model checkpoint files. Averages the states of these selected models.
+ `epochs`: this parameter only takes effect when `mode` is set to `best`, and specifies the epoch indices of the checkpoints used as source models (checkpoint averaging itself is sketched below).
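Model averaging itself is just an element-wise mean of checkpoint state dicts; a minimal sketch of the idea, assuming each checkpoint file holds a flat state dict of tensors (see `wesep/bin/average_model.py` for the actual implementation):
```python
# Minimal checkpoint averaging: element-wise mean over state-dict tensors.
import torch


def average_checkpoints(paths):
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}


ckpts = [f"exp/BSRNN/models/checkpoint_{i}.pt" for i in (138, 141)]
torch.save(average_checkpoints(ckpts), "exp/BSRNN/models/avg_best_model.pt")
```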
================================================
FILE: examples/librimix/tse/v1/confs/bsrnn.yaml
================================================
dataloader_args:
batch_size: 16 # A800: 16
drop_last: true
num_workers: 6
pin_memory: false
prefetch_factor: 6
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
enable_amp: false
exp_dir: exp/BSRNN
gpus: '0,1'
log_batch_interval: 100
loss: SISDR
loss_args: { }
model:
tse_model: BSRNN
model_args:
tse_model:
sr: 16000
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spk_emb_dim: 256
spk_fuse_type: 'multiply'
use_spk_transform: False
multi_fuse: False # Multi speaker fuse with seperation modules
joint_training: False
model_init:
tse_model: null
discriminator: null
num_avg: 2
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001
weight_decay: 0.0001
clip_grad: 5.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v1/confs/dpcc_init_gan.yaml
================================================
use_metric_loss: true
dataloader_args:
batch_size: 4
drop_last: true
num_workers: 4
pin_memory: false
prefetch_factor: 4
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
enable_amp: false
exp_dir: exp/DPCNN
gpus: '0,1'
log_batch_interval: 100
loss: SISNR
loss_args: { }
gan_loss_weight: 0.05
model:
tse_model: DPCCN
discriminator: CMGAN_Discriminator
model_args:
tse_model:
win: 512
stride: 128
feature_dim: 257
tcn_blocks: 10
tcn_layers: 2
spk_emb_dim: 256
causal: False
spk_fuse_type: 'multiply'
use_spk_transform: False
discriminator: {}
model_init:
tse_model: exp/DPCCN/no_spk_transform-multiply_fuse/models/final_model.pt
discriminator: null
num_avg: 5
num_epochs: 50
optimizer:
tse_model: Adam
discriminator: Adam
optimizer_args:
tse_model:
lr: 0.0001
weight_decay: 0.0001
discriminator:
lr: 0.001
weight_decay: 0.0001
clip_grad: 3.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
discriminator: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.0001
warm_from_zero: false
warm_up_epoch: 0
discriminator:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v1/confs/dpccn.yaml
================================================
dataloader_args:
batch_size: 4
drop_last: true
num_workers: 4
pin_memory: false
prefetch_factor: 4
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
enable_amp: false
exp_dir: exp/DPCNN
gpus: '0,1'
log_batch_interval: 100
loss: SISDR
loss_args: { }
model:
tse_model: DPCCN
model_args:
tse_model:
win: 512
stride: 128
feature_dim: 257
tcn_blocks: 10
tcn_layers: 2
spk_emb_dim: 256
causal: False
spk_fuse_type: 'multiply'
use_spk_transform: False
joint_training: False
model_init:
tse_model: null
discriminator: null
num_avg: 5
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001
weight_decay: 0.0001
clip_grad: 3.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v1/confs/tfgridnet.yaml
================================================
dataloader_args:
batch_size: 4
drop_last: true
num_workers: 4
pin_memory: false
prefetch_factor: 4
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 16000
enable_amp: false
exp_dir: exp/TFGridNet
gpus: '0,1'
log_batch_interval: 100
loss: SI_SNR
loss_args: { }
model:
tse_model: TFGridNet
model_args:
tse_model:
n_srcs: 1
n_fft: 128
stride: 64
window: "hann"
n_imics: 1
n_layers: 6
lstm_hidden_units: 192
attn_n_head: 4
attn_approx_qk_dim: 512
emb_dim: 128
emb_ks: 1
emb_hs: 1
activation: "prelu"
eps: 1.0e-5
spk_emb_dim: 256
use_spk_transform: False
spk_fuse_type: "multiply"
joint_training: False
model_init:
tse_model: null
num_avg: 2
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001
weight_decay: 0.0001
clip_grad: 5.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v1/local/prepare_data.sh
================================================
#!/bin/bash
# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)
stage=-1
stop_stage=-1
mix_data_path='/Data/Libri2Mix/wav16k/min/'
data=data
noise_type=clean
num_spk=2
. tools/parse_options.sh || exit 1
data=$(realpath ${data})
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Prepare the meta files for the datasets"
for dataset in dev test train-100; do
echo "Preparing files for" $dataset
# Prepare the meta data for the mixed data
dataset_path=$mix_data_path/$dataset/mix_${noise_type}
mkdir -p "${data}"/$noise_type/${dataset}
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print $NF}' |
awk -v path="${dataset_path}" '{print $1 , path "/" $1 , path "/../s1/" $1 , path "/../s2/" $1}' |
sed 's#.wav##' | sort -k1,1 >"${data}"/$noise_type/${dataset}/wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp |
awk -F[_-] '{print $0, $1,$4}' >"${data}"/$noise_type/${dataset}/utt2spk
# Prepare the meta data for single speakers
dataset_path=$mix_data_path/$dataset/s1
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s1/" $NF, $0}' | sort -k1,1 >"${data}"/$noise_type/${dataset}/single.wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's1' |
awk -F[-_/] '{print $0, $2}' >"${data}"/$noise_type/${dataset}/single.utt2spk
dataset_path=$mix_data_path/$dataset/s2
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s2/" $NF, $0}' | sort -k1,1 >>"${data}"/$noise_type/${dataset}/single.wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's2' |
awk -F[-_/] '{print $0, $5}' >>"${data}"/$noise_type/${dataset}/single.utt2spk
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Prepare the speaker embeddings using wespeaker pretrained models"
mkdir wespeaker_resnet34
wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.zip
unzip voxceleb_resnet34_LM.zip -d wespeaker_resnet34
mv wespeaker_resnet34/voxceleb_resnet34_LM.yaml wespeaker_resnet34/config.yaml
mv wespeaker_resnet34/voxceleb_resnet34_LM.pt wespeaker_resnet34/avg_model.pt
for dataset in dev test train-100; do
mkdir -p "${data}"/$noise_type/${dataset}
echo "Preparing files for" $dataset
wespeaker --task embedding_kaldi \
--wav_scp "${data}"/$noise_type/${dataset}/single.wav.scp \
--output_file "${data}"/$noise_type/${dataset}/embed \
-p wespeaker_resnet34 \
--device cuda:0 # GPU idx
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: Prepare LibriMix target-speaker enroll signal"
for dset in dev test train-100; do
python local/prepare_spk2enroll_librispeech.py \
"${mix_data_path}/${dset}" \
--is_librimix True \
--outfile "${data}"/$noise_type/${dset}/spk2enroll.json \
--audio_format wav
done
for dset in dev test; do
if [ $num_spk -eq 2 ]; then
url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment"
else
url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment"
fi
output_file="${data}/${noise_type}/${dset}/mixture2enrollment"
wget -O "$output_file" "$url"
done
for dset in dev test; do
python local/prepare_librimix_enroll.py \
"${data}"/$noise_type/${dset}/wav.scp \
"${data}"/$noise_type/${dset}/spk2enroll.json \
--mix2enroll "${data}/${noise_type}/${dset}/mixture2enrollment" \
--num_spk ${num_spk} \
--train False \
--output_dir "${data}"/${noise_type}/${dset} \
--outfile_prefix "spk"
done
fi
================================================
FILE: examples/librimix/tse/v1/local/prepare_librimix_enroll.py
================================================
import json
import random
from pathlib import Path
from wesep.utils.datadir_writer import DatadirWriter
from wesep.utils.utils import str2bool
def prepare_librimix_enroll(wav_scp,
spk2utts,
output_dir,
num_spk=2,
train=True,
prefix="enroll_spk"):
mixtures = []
with Path(wav_scp).open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
mixtureID = line.strip().split(maxsplit=1)[0]
mixtures.append(mixtureID)
with Path(spk2utts).open("r", encoding="utf-8") as f:
# {spkID: [(uid1, path1), (uid2, path2), ...]}
spk2utt = json.load(f)
with DatadirWriter(Path(output_dir)) as writer:
for mixtureID in mixtures:
# 100-121669-0004_3180-138043-0053
uttIDs = mixtureID.split("_")
for spk in range(num_spk):
uttID = uttIDs[spk]
spkID = uttID.split("-")[0]
if train:
# For training, we choose the auxiliary signal on the fly.
# Here we use the pattern f"*{uttID} {spkID}".
writer[f"{prefix}{spk + 1}.enroll"][
mixtureID] = f"*{uttID} {spkID}"
else:
enrollID = random.choice(spk2utt[spkID])[1]
while enrollID == uttID and len(spk2utt[spkID]) > 1:
enrollID = random.choice(spk2utt[spkID])[1]
writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = enrollID
def prepare_librimix_enroll_v2(wav_scp,
map_mix2enroll,
output_dir,
num_spk=2,
prefix="spk"):
# noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py
mixtures = []
with Path(wav_scp).open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
mixtureID = line.strip().split(maxsplit=1)[0]
mixtures.append(mixtureID)
mix2enroll = {}
with open(map_mix2enroll) as f:
for line in f:
mix_id, utt_id, enroll_id = line.strip().split()
sid = mix_id.split("_").index(utt_id) + 1
mix2enroll[mix_id, f"s{sid}"] = enroll_id
with DatadirWriter(Path(output_dir)) as writer:
for mixtureID in mixtures:
# 100-121669-0004_3180-138043-0053
for spk in range(num_spk):
enroll_id = mix2enroll[mixtureID, f"s{spk + 1}"]
writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = (enroll_id +
".wav")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"wav_scp",
type=str,
help="Path to the wav.scp file",
)
parser.add_argument("spk2utts",
type=str,
help="Path to the json file containing mapping "
"from speaker ID to utterances")
parser.add_argument(
"--num_spk",
type=int,
default=2,
choices=(2, 3),
help="Number of speakers in each mixture sample",
)
parser.add_argument(
"--train",
type=str2bool,
default=True,
help="Whether is the training set or not",
)
parser.add_argument(
"--mix2enroll",
type=str,
default=None,
help="Path to the downloaded map_mixture2enrollment file. "
"If `train` is False, this value is required.",
)
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Path to the directory for storing output files",
)
parser.add_argument(
"--outfile_prefix",
type=str,
default="spk",
help="Prefix of the output files",
)
args = parser.parse_args()
random.seed(args.seed)
if args.train:
prepare_librimix_enroll(
args.wav_scp,
args.spk2utts,
args.output_dir,
num_spk=args.num_spk,
train=args.train,
prefix=args.outfile_prefix,
)
else:
prepare_librimix_enroll_v2(
args.wav_scp,
args.mix2enroll,
args.output_dir,
num_spk=args.num_spk,
prefix=args.outfile_prefix,
)
================================================
FILE: examples/librimix/tse/v1/local/prepare_spk2enroll_librispeech.py
================================================
import json
from collections import defaultdict
from itertools import chain
from pathlib import Path
from wesep.utils.utils import str2bool
def get_spk2utt(paths, audio_format="flac"):
spk2utt = defaultdict(list)
for path in paths:
for audio in Path(path).rglob("*.{}".format(audio_format)):
readerID = audio.parent.parent.stem
uid = audio.stem
assert uid.split("-")[0] == readerID, audio
spk2utt[readerID].append((uid, str(audio)))
return spk2utt
def get_spk2utt_librimix(paths, audio_format="flac"):
spk2utt = defaultdict(list)
for path in paths:
for audio in chain(
Path(path).rglob("s1/*.{}".format(audio_format)),
Path(path).rglob("s2/*.{}".format(audio_format)),
Path(path).rglob("s3/*.{}".format(audio_format)),
):
spk_idx = int(audio.parent.stem[1:]) - 1
mix_uid = audio.stem
uid = mix_uid.split("_")[spk_idx]
sid = uid.split("-")[0]
spk2utt[sid].append((uid, str(audio)))
return spk2utt
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"audio_paths",
type=str,
nargs="+",
help="Paths to Librispeech subsets",
)
parser.add_argument(
"--is_librimix",
type=str2bool,
default=False,
help="Whether the provided audio_paths points to LibriMix data",
)
parser.add_argument(
"--outfile",
type=str,
default="spk2utt_tse.json",
help="Path to the output spk2utt json file",
)
parser.add_argument("--audio_format", type=str, default="flac")
args = parser.parse_args()
if args.is_librimix:
# use clean sources from LibriMix as enrollment
spk2utt = get_spk2utt_librimix(args.audio_paths,
audio_format=args.audio_format)
else:
# use Librispeech as enrollment
spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format)
outfile = Path(args.outfile)
outfile.parent.mkdir(parents=True, exist_ok=True)
with outfile.open("w", encoding="utf-8") as f:
json.dump(spk2utt, f, indent=4)
================================================
FILE: examples/librimix/tse/v1/path.sh
================================================
export PATH=$PWD:$PATH
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../../:$PYTHONPATH
================================================
FILE: examples/librimix/tse/v1/run.sh
================================================
#!/bin/bash
# Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn)
. ./path.sh || exit 1
# General configuration
stage=-1
stop_stage=1
# Data preparation related
data=data
fs=16k
min_max=min
noise_type="clean"
data_type="shard" # shard/raw
Libri2Mix_dir=/YourPath/librimix/Libri2Mix
mix_data_path="${Libri2Mix_dir}/wav${fs}/${min_max}"
# Training related
gpus="[0,1]"
use_gan_loss=false
config=confs/bsrnn.yaml
exp_dir=exp/BSRNN/resnet34-pre_extract-multiply_fuse
if [ -z "${config}" ] && [ -f "${exp_dir}/config.yaml" ]; then
config="${exp_dir}/config.yaml"
fi
# TSE model initialization related
checkpoint=
if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then
checkpoint="${exp_dir}/models/latest_checkpoint.pt"
fi
# Inferencing and scoring related
use_pesq=true
use_dnsmos=true
dnsmos_use_gpu=true
# Model average related
num_avg=2
. tools/parse_options.sh || exit 1
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Prepare datasets ..."
./local/prepare_data.sh --mix_data_path ${mix_data_path} \
--data ${data} \
--noise_type ${noise_type} \
--stage 2 \
--stop-stage 2
fi
data=${data}/${noise_type}
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Covert train and test data to ${data_type}..."
for dset in train-100 dev test; do
# for dset in train-360; do
python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \
--num_threads 16 \
--prefix shards \
--shuffle \
${data}/$dset/wav.scp ${data}/$dset/utt2spk \
${data}/$dset/shards ${data}/$dset/shard.list
done
fi
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Start training ..."
if ${use_gan_loss}; then
train_script=wesep/bin/train_gan.py
else
train_script=wesep/bin/train.py
fi
export OMP_NUM_THREADS=8
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
${train_script} --config $config \
--exp_dir ${exp_dir} \
--gpus $gpus \
--num_avg ${num_avg} \
--data_type "${data_type}" \
--train_data ${data}/train-100/${data_type}.list \
--train_spk_embeds ${data}/train-100/embed.scp \
--train_utt2spk ${data}/train-100/single.utt2spk \
--train_spk2utt ${data}/train-100/spk2enroll.json \
--val_data ${data}/dev/${data_type}.list \
--val_spk_embeds ${data}/dev/embed.scp \
--val_utt2spk ${data}/dev/single.utt2spk \
--val_spk1_enroll ${data}/dev/spk1.enroll \
--val_spk2_enroll ${data}/dev/spk2.enroll \
--val_spk2utt ${data}/dev/single.wav.scp \
${checkpoint:+--checkpoint $checkpoint}
fi
# shellcheck disable=SC2215
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Start inferencing ..."
python wesep/bin/infer.py --config $config \
--fs ${fs} \
--gpus 0 \
--exp_dir ${exp_dir} \
--data_type "${data_type}" \
--test_data ${data}/test/${data_type}.list \
--test_spk_embeds ${data}/test/embed.scp \
--test_spk1_enroll ${data}/test/spk1.enroll \
--test_spk2_enroll ${data}/test/spk2.enroll \
--test_spk2utt ${data}/test/single.wav.scp \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Start scoring ..."
./tools/score.sh --dset "${data}/test" \
--exp_dir "${exp_dir}" \
--fs ${fs} \
--use_pesq "${use_pesq}" \
--use_dnsmos "${use_dnsmos}" \
--dnsmos_use_gpu "${dnsmos_use_gpu}" \
--n_gpu "${num_gpus}"
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
echo "Do model average ..."
avg_model=$exp_dir/models/avg_best_model.pt
python wesep/bin/average_model.py \
--dst_model $avg_model \
--src_path $exp_dir/models \
--num ${num_avg} \
--mode best \
--epochs "138,141"
fi
================================================
FILE: examples/librimix/tse/v2/README.md
================================================
## Tutorial on LibriMix
If you encounter any problems when going through this tutorial, please feel free to ask in the GitHub issues. Thanks for any kind of feedback.
### First Experiment
We provide a recipe `examples/librimix/tse/v2/run.sh` on LibriMix data.
The recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process.
```bash
cd examples/librimix/tse/v2
bash run.sh --stage 1 --stop_stage 1
bash run.sh --stage 2 --stop_stage 2
bash run.sh --stage 3 --stop_stage 3
bash run.sh --stage 4 --stop_stage 4
bash run.sh --stage 5 --stop_stage 5
bash run.sh --stage 6 --stop_stage 6
```
You could also just run the whole script
```bash
bash run.sh --stage 1 --stop_stage 6
```
------
### Stage 1: Prepare Training Data
Before executing this stage, we assume that you have the LibriMix dataset stored locally or otherwise accessible; assign its path to `Libri2Mix_dir`.
As the LibriMix dataset is available in multiple versions, each determined by factors like the number of sources in the mixtures and the sampling rate, you can choose the desired version by adjusting the following variables in `run.sh`:
+ `fs`: the sample rate of the dataset, valid options are `16k` and `8k`.
+ `min_max`: the mode of mixtures, valid options are `min` and `max`.
+ `noise_type`: the type of mixture, valid options are `clean` and `both`.
In our recipe, we opt for the Libri2Mix data with a sampling rate of 16 kHz, in 'min' mode, and without noise, configured as follows:
``` bash
fs=16k
min_max=min
noise_type="clean"
Libri2Mix_dir=/path/to/Libri2Mix
```
After configuring the desired dataset version, running the script for the first stage will generate the prepared data files. By default, these files are stored in the `data` directory in the current location.
```bash
data=data # you can change this to any directory
```
In this stage, `local/prepare_data.sh` accomplishes the following tasks (the main differences from the v1 version):
1. Organizes the original Libri2Mix dataset into three directories `dev`, `test` and `train_100`/`train_360`, each containing the following files:
+ `single.utt2spk`: each line records two space-separated columns: `clean_wav_id` and `speaker_id`
```text
s1/103-1240-0003_1235-135887-0017.wav 103
s1/103-1240-0004_4195-186237-0003.wav 103
...
```
+ `utt2spk`: each line records three space-separated columns: `mixture_wav_id`, `speaker1_id` and `speaker2_id`.
```
103-1240-0003_1235-135887-0017 103 1235
103-1240-0004_4195-186237-0003 103 4195
...
```
+ `single.wav.scp`: each line records two space-separated columns: `clean_wav_id` and `clean_wav_path`
```
s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0003_1235-135887-0017.wav
s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0004_4195-186237-0003.wav
...
```
+ `wav.scp`: each line records four space-separated columns: `mixture_wav_id`, `mixture_wav_path`, `clean_wav1_path` and `clean_wav2_path`.
```
103-1240-0003_1235-135887-0017 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0003_1235-135887-0017.wav
103-1240-0004_4195-186237-0003 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0004_4195-186237-0003.wav
...
```
2. Prepares the LibriMix target-speaker enrollment signals. This step will generate one `json` file in the `dev`, `test` and `train_100`/`train_360` directories, and three additional files in the `dev` and `test` directories respectively (a sketch of sampling enrollments from this JSON follows the list):
+ `spk2enroll.json`: A JSON file, where the format of the stored key-value pairs is `{spk_id: [[spk_id_with_prefix_or_suffix, wav_path], ...]}`.
```
"652": [["652-129742-0010", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0010_3081-166546-0071.wav"],
...,
["652-129742-0000", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0000_1993-147966-0004.wav"]],
...
```
+ `mixture2enrollment`: each line records three space-separated columns: `mixture_wav_id`, `clean_wav_id` and `enrollment_wav_id`.
```
4077-13754-0001_5142-33396-0065 4077-13754-0001 s1/4077-13754-0004_5142-36377-0020
4077-13754-0001_5142-33396-0065 5142-33396-0065 s1/5142-36377-0003_1320-122612-0014
...
```
+ `spk1.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.
```
1272-128104-0000_2035-147961-0014 s1/1272-135031-0015_2277-149896-0006.wav
1272-128104-0003_2035-147961-0016 s1/1272-135031-0013_1988-147956-0016.wav
...
```
+ `spk2.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.
```
1272-128104-0000_2035-147961-0014 s1/2035-152373-0009_3000-15664-0016.wav
1272-128104-0003_2035-147961-0016 s2/6313-66129-0013_2035-152373-0012.wav
...
```
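To make the enrollment layout concrete, here is a minimal Python sketch (the data path and speaker ID are placeholders for your own setup) that loads `spk2enroll.json` and samples one enrollment utterance for a speaker, mirroring what `local/prepare_librimix_enroll.py` does for the dev/test sets:
```python
import json
import random

# Placeholder path; adjust to your own data directory.
with open("data/clean/dev/spk2enroll.json", encoding="utf-8") as f:
    spk2enroll = json.load(f)  # {spk_id: [[utt_id, wav_path], ...]}

spk_id = "652"  # hypothetical speaker taken from the example above
utt_id, wav_path = random.choice(spk2enroll[spk_id])
print(utt_id, wav_path)
```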
At the end of this stage, the directory structure of `data` should look like this:
```
data/
|__ clean/ # the noise_type you chose
|__ dev/
| |__ mixture2enrollment
| |__ single.utt2spk
| |__ single.wav.scp
| |__ spk1.enroll
| |__ spk2.enroll
| |__ spk2enroll.json
| |__ utt2spk
| |__ wav.scp
|
|__ test/ # the same as dev/
|
|__ train_100/
|__ single.utt2spk
|__ single.wav.scp
|__ spk2enroll.json
|__ utt2spk
|__ wav.scp
```
3. Downloads the speaker encoders (ResNet34 & ECAPA-TDNN512) from wespeaker for training the TSE model with a pretrained speaker encoder. The models will be unzipped into `wespeaker_models/`.
Find more speaker models at https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md.
4. Prepares the speaker embeddings using wespeaker pretrained models. (Not needed; commented out by default in the v2 version.)
This step will generate two files in the `dev`, `test`, and `train_100` directories respectively:
+ `embed.ark`: Kaldi ark file that stores the speaker embeddings.
+ `embed.scp`: each line records two space-separated columns: `clean_wav_id` and `spk_embed_path` (see the reading sketch after this list).
```
s1/103-1240-0003_1235-135887-0017.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:1450569
s1/103-1240-0004_4195-186237-0003.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:10622715
...
```
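If you enable this step, the resulting scp/ark pair can be read lazily with the third-party `kaldiio` package. A minimal sketch (the path and example key below are placeholders) follows:
```python
import kaldiio

# Lazily index the embeddings by clean_wav_id via embed.scp (placeholder path).
embeddings = kaldiio.load_scp("data/clean/train-100/embed.scp")
vec = embeddings["s1/103-1240-0003_1235-135887-0017.wav"]
print(vec.shape)  # e.g. (256,) for a ResNet34 embedding
```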
------
### Stage 2: Convert Data Format
This stage involves transforming the data into `shard` format, which is better suited for large datasets. Its core idea is to pack the audio and labels of many small samples (e.g., 1000 of them) into compressed archives (tar) and read them through PyTorch's IterableDataset; a small inspection sketch follows the examples below. For a detailed explanation of the `shard` format, please refer to the [documentation](https://github.com/wenet-e2e/wenet/blob/main/docs/UIO.md) available in Wenet.
This stage will generate a subdirectory and a file in the `dev`, `test`, and `train_100` directories respectively:
+ `shards/`: this directory stores the compressed packets (tar) files.
```bash
ls shards
shards_000000000.tar shards_000000001.tar shards_000000002.tar ...
```
+ `shard.list`: each line records the path to the corresponding tar file.
```
data/clean/dev/shards/shards_000000000.tar
data/clean/dev/shards/shards_000000001.tar
data/clean/dev/shards/shards_000000002.tar
...
```
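To peek inside one shard, you can list its members with Python's standard `tarfile` module, as in the minimal sketch below (the shard path is a placeholder). Each utterance is stored as a group of entries sharing the same prefix (audio plus metadata), which the IterableDataset-based loader consumes sequentially.
```python
import tarfile

# Placeholder path; adjust to one of your generated shards.
with tarfile.open("data/clean/dev/shards/shards_000000000.tar") as tar:
    for member in tar.getmembers()[:6]:  # first few entries only
        print(member.name, member.size)
```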
At the end of this stage, the directory structure of `data` should look like this:
```
data/
|__ clean/ # the noise_type you chose
|__ dev/
| |__ single.utt2spk, single.wav.scp, ... # files generated by Stage 1
| |__ shard.list
| |__ shards/
| |__ shards_000000000.tar
| |__ shards_000000001.tar
| |__ shards_000000002.tar
|
|__ test/ # the same as dev/
|
|__ train_100/
|__ single.utt2spk, single.wav.scp, ... # files generated by Stage 1
|__ shard.list
|__ shards/
|__ shards_000000000.tar
|__ ...
|__ shards_000000013.tar
```
------
### Stage 3: Neural Network Training
You can configure network training related parameters through the configuration file. We provide some ready-to-use configuration files in the recipe. If you wish to write your own configuration files or understand the meaning of certain parameters in the configuration files, you can refer to the following information:
+ **overall training process related**
```yaml
seed: 42
exp_dir: exp/BSRNN
enable_amp: false
gpus: '0,1'
log_batch_interval: 100
save_epoch_interval: 1
```
Explanations for some of the parameters mentioned above:
+ `seed`: specify a random seed.
+ `exp_dir`: specify the experiment directory.
+ `enable_amp`: whether to enable automatic mixed precision.
+ `gpus`: specify the visible GPUs during training.
+ `log_batch_interval`: specify how many batch iterations elapse between log entries.
+ `save_epoch_interval`: specify how many epochs elapse between checkpoint saves.
+ **dataset and dataloader related**
```yaml
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
whole_utt: false
chunk_len: 48000
online_mix: false
speaker_feat: &speaker_feat true
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
"Usually you don't need to manually write the data part of the configuration into the config file, it will be automatically generated."
data_type: "shard"
train_data: "data/clean/train_100/shard.list"
train_utt2spk: "data/clean/train_100/single.utt2spk"
train_spk2utt: "data/clean/train_100/spk2enroll.json"
val_data: "data/clean/dev/shard.list"
val_utt2spk: "data/clean/dev/single.utt2spk"
val_spk1_enroll: "data/clean/dev/spk1.enroll"
val_spk2_enroll: "data/clean/dev/spk2.enroll"
val_spk2utt: "data/clean/dev/single.wav.scp"
dataloader_args:
batch_size: 12 # A800
drop_last: true
num_workers: 6
pin_memory: false
prefetch_factor: 6
```
Explanations for some of the parameters mentioned above:
+ `resample_rate`: All audio in the dataset will be resampled to this specified sample rate. Defaults to `16000`.
+ `sample_num_per_epoch`: Specifies how many samples from the full training set will be iterated over in each epoch during training. The default is `0`, which means iterating over the entire training set.
+ `shuffle`: Whether to perform *global* shuffling, i.e., shuffling at the shard tar/raw/feat file level. Defaults to `true`.
+ `shuffle_size`: Parameter for *local* shuffling. Local shuffling maintains a buffer, and shuffling is performed only when the number of data items in the buffer reaches `shuffle_size`. Defaults to `2500`.
+ `whole_utt`: Whether the network input and training target are the entire audio segment. Defaults to `false`.
+ `chunk_len`: This parameter only takes effect when `whole_utt` is set to `false`. It indicates the length of the segment to be extracted from the complete audio as the network input and training target. Defaults to `48000`.
+ `online_mix`: Whether to dynamically mix speakers when loading data; `shuffle` will not take effect if this parameter is set to `true`. Defaults to `false`.
+ `speaker_feat`: Whether to transform the enrollment from waveform to fbank. Setting this to `true` is recommended. Defaults to `false`.
+ `num_mel_bins`: The parameter of fbank. The feature dimension of the fbank. Defaults to `80`.
+ `frame_shift`: The parameter of fbank. The time of frame shift in `ms`. Defaults to `10`.
+ `frame_length`: The parameter of fbank. The frame length in `ms`. Defaults to `25`.
+ `dither`: The parameter of fbank. The amount of dithering noise added when computing the fbank features (`0.0` disables it). Defaults to `1.0`.
+ `data_type`: Specify the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.
+ `train_data`: File containing paths to the training set files.
+ `train_utt2spk`: Each line of the file specified by this parameter consists of `clean_wav_id` and `speaker_id`, separated by a space (e.g., `single.utt2spk` generated in Stage 1).
+ `train_spk2utt`: The file specified by this parameter is only used when the `joint_training` parameter is set to `true`. Each line of the file contains `speaker_id` and `enrollment_wav_id`.
+ `val_data`: File containing paths to the validation set files.
+ `val_utt2spk`: Similar to `train_utt2spk`.
+ `val_spk1_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker1_enrollment_wav_id`, separated by a space.
+ `val_spk2_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker2_enrollment_wav_id`, separated by a space.
+ `val_spk2utt`: Each line of the file specified by this parameter consists of `clean_wav_id` and `clean_wav_path`, separated by a space (e.g., `single.wav.scp` generated in Stage 1).
+ We have denoted this parameter as `val_spk2utt`, but it is actually assigned the `single.wav.scp` file as its value. This might be perplexing for users familiar with file formats in Kaldi or ESPnet, where the `spk2utt` file typically consists of lines containing `spk_id` and `wav_id`, whereas the `wav.scp` file's lines contain `wav_id` and `wav_path`.
+ Nevertheless, upon closer examination of its role in subsequent procedures, it becomes evident that it is indeed employed to create a dictionary mapping speaker IDs to audio samples.
+ `batch_size`: how many samples per batch to load (see the DataLoader sketch after this list). Please note that the batch size mentioned here refers to the **batch size per GPU**. So, if you are training on two GPUs within a single node and set the batch size to 16, it is equivalent to setting the batch size to 32 in a single-GPU, single-node scenario.
+ `drop_last`: set to `true` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If `false` and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
+ `num_workers`: how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process.
+ `pin_memory`: If `true`, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
+ `prefetch_factor`: number of batches loaded in advance by each worker.
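The dataloader settings above map directly onto the arguments of PyTorch's `DataLoader`. Below is a minimal, self-contained sketch; `ToyShards` is a hypothetical stand-in for the shard-backed dataset wesep builds from `shard.list`:
```python
from torch.utils.data import DataLoader, IterableDataset

class ToyShards(IterableDataset):
    """Hypothetical stand-in for wesep's shard-backed dataset."""

    def __iter__(self):
        # NOTE: wesep's real dataset splits work across workers;
        # this toy yields the same integers in every worker.
        return iter(range(100))

loader = DataLoader(
    ToyShards(),
    batch_size=12,      # per-GPU batch size
    drop_last=True,     # drop the last incomplete batch
    num_workers=6,      # data-loading subprocesses
    pin_memory=False,   # copy batches into pinned (page-locked) memory
    prefetch_factor=6,  # batches prefetched in advance by each worker
)
for batch in loader:
    print(batch.shape)  # torch.Size([12])
    break
```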
+ **loss function related**
```yaml
loss: SISDR
loss_args: { }
### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set.
loss: [SISDR, CE]
loss_args:
loss_posi: [[0],[1]]
loss_weight: [[1.0],[1.0]]
```
Explanations for some of the parameters mentioned above:
+ `loss`: the loss function used for training.
+ `loss_args`: the required arguments for the loss function.
+ `loss_posi`: Select which outputs from the TSE model the loss function works on.
+ `loss_weight`: The weight of loss calculated from corresponding loss function.
In addition to some common loss functions, we also support the use of GAN loss. You can enable this feature by setting `use_gan_loss` to `true` in `run.sh`. Once enabled, the TSE model serves as the generator, and another convolutional neural network acts as the discriminator, engaging in adversarial training. The final loss of the TSE model is a combination of the losses specified in the configuration file and the GAN loss. By default, the weight for the former is set to `0.95`, while the latter is set to `0.05`.
Due to the compatibility with GAN loss, the parameters mentioned below often differentiate between `tse_model` and `discriminator` under a single parameter. In such cases, we no longer provide separate explanations for each parameter.
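As a minimal illustration of this default weighting (the two loss values below are placeholders, not wesep's actual computation):
```python
import torch

sisdr_loss = torch.tensor(-12.3)  # placeholder for the configured loss
adv_loss = torch.tensor(0.8)      # placeholder for the adversarial (GAN) loss
total = 0.95 * sisdr_loss + 0.05 * adv_loss  # default 0.95/0.05 weighting
print(total.item())
```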
+ **neural network structure related**
```yaml
model:
tse_model: BSRNN
model_args:
tse_model:
sr: 16000
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spk_emb_dim: 256
spk_fuse_type: 'multiply'
use_spk_transform: False
multi_fuse: False
joint_training: True ### You should always set this para to `True` when using v2 version.
spk_model: ResNet34
spk_model_init: None
spk_args: None
spk_emb_dim: 256
spk_model_freeze: False
spk_feat: *speaker_feat
feat_type: "consistent"
multi_task: False
spksInTrain: 251
model_init:
tse_model: exp/BSRNN/no_spk_transform-multiply_fuse/models/latest_checkpoint.pt
discriminator: null
```
Explanations for some of the parameters mentioned above:
+ `model`: specify the neural network used for training.
+ `model_args`: specify model-specific parameters.
+ `spk_fuse_type`: specify the fusion method for the speaker embedding. Supported values are `concat`, `additive`, `multiply` and `FiLM` (see the fusion sketch after this list).
+ `multi_fuse`: whether to fuse the speaker embedding multiple times.
+ `joint_training`: specify whether the speaker encoder for extracting speaker embeddings is jointly trained with the TSE model. Always set this to `true`. Do NOT use it to control whether training uses a pretrained speaker encoder. Defaults to `false`.
+ `spk_model`: specify the speaker model. Supports most speaker models in wespeaker: https://github.com/wenet-e2e/wespeaker/tree/master.
+ `spk_model_init`: the path of the pre-trained speaker model. Find more pretrained models at https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md. Set `None` to train the speaker model from scratch together with the TSE model.
+ `spk_args`: specify speaker model-specific parameters.
+ `spk_emb_dim`: the feature dimension of speaker embedding extracted from the speaker encoder.
+ `spk_model_freeze`: whether to freeze the weights of the speaker encoder. Set `True` when using a pretrained speaker encoder.
+ `spk_feat`: reuses the value defined in `dataset_args` to determine whether feature extraction of the enrollment is performed inside the model.
+ `feat_type`: specify the type of enrollment's feature, when `spk_feat` is `False`.
+ `multi_task`: whether to use an additional loss such as `CE` for jointly training the speaker encoder. This parameter needs to be coordinated with `loss`.
+ `spksInTrain`: specify the number of speakers in the training dataset. wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360: 921.
+ `model_init`: whether to initialize the model with an existing checkpoint. Use `null` for no initialization. If you want to initialize, provide the checkpoint path. Defaults to `null`.
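For intuition, here is a minimal sketch of `multiply`-style fusion (an assumption for illustration, not wesep's exact code): the speaker embedding gates the separator's internal features channel-wise, broadcast over all time frames.
```python
import torch

feats = torch.randn(4, 256, 300)       # (batch, channels, frames)
spk_emb = torch.randn(4, 256)           # (batch, spk_emb_dim)
fused = feats * spk_emb.unsqueeze(-1)   # channel-wise gating over all frames
print(fused.shape)                      # torch.Size([4, 256, 300])
```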
+ **model optimization related**
```yaml
num_epochs: 150
clip_grad: 5.0
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001
weight_decay: 0.0001
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
```
Explanations for some of the parameters mentioned above:
+ `num_epochs`: total number of training epochs.
+ `clip_grad`: set the threshold for gradient clipping.
+ `optimizer`: set the optimizer.
+ `optimizer_args`: the required arguments for the optimizer. Not used in the current version; the learning rate and its schedule are determined by `scheduler_args`.
+ `scheduler`: set the scheduler.
+ `scheduler_args`: the required arguments for the scheduler (see the schedule sketch after this list).
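The `ExponentialDecrease` settings above imply a learning rate that decays exponentially from `initial_lr` to `final_lr` over training. The closed form below is an assumed sketch for intuition (ignoring warm-up); the authoritative implementation is in `wesep/utils/schedulers.py`:
```python
import math

initial_lr, final_lr, num_epochs = 1e-3, 2.5e-5, 150

def lr_at(epoch: float) -> float:
    # Geometric interpolation between initial_lr and final_lr (assumed form).
    frac = epoch / num_epochs
    return initial_lr * math.exp(frac * math.log(final_lr / initial_lr))

print(lr_at(0), lr_at(75), lr_at(150))  # 1e-3 -> ~1.58e-4 -> 2.5e-5
```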
+ **others**
```yaml
num_avg: 2
```
Explanations for some of the parameters mentioned above:
+ `num_avg`: the number of checkpoints averaged into the final model.
To avoid frequent changes to the configuration file, we support **overwriting values in the configuration file** directly within `run.sh`. For example, running the following command in `run.sh` will overwrite the visible GPUs from `'0,1'` to `'0'` in the above configuration file:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
${train_script} --config confs/config.yaml \
--gpus "[0]" \
```
At the end of this stage, an experiment directory will be created in the current directory, containing the following files:
```
${exp_dir}/
|__ train.log
|__ config.yaml
|__ models/
|__ checkpoint_1.pt
|__ ...
|__ checkpoint_150.pt
|__ final_checkpoint.pt -> checkpoint_150.pt
|__ latest_checkpoint.pt -> checkpoint_150.pt
```
------
### Stage 4: Apply Model Average
In this stage, we perform model averaging, and you need to specify the following parameters in `run.sh`:
+ `dst_model`: the path to save the averaged model.
+ `src_path`: the directory containing the source models to average.
+ `num`: number of source models for the averaged model.
+ `mode`: the mode for model averaging. Valid options are `final` and `best`.
+ `final`: filters and sorts the latest PyTorch model files in the source directory. Averages the states of the last `num` models based on a numerical sorting of their filenames.
+ `best`: directly uses user-specified epochs to select specific model checkpoint files. Averages the states of these selected models.
+ `epochs`: this parameter only takes effect when `mode` is set to `best` and specifies the epoch indices of the checkpoints that will be used as source models. (A minimal averaging sketch follows.)
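In spirit, `wesep/bin/average_model.py` averages the parameter tensors of the selected checkpoints. A minimal sketch (the checkpoint paths are placeholders, and the real script also handles checkpoint selection and bookkeeping) looks like this:
```python
import torch

paths = ["models/checkpoint_149.pt", "models/checkpoint_150.pt"]  # placeholders
avg = None
for p in paths:
    state = torch.load(p, map_location="cpu")
    if avg is None:
        avg = {k: v.clone().float() for k, v in state.items()}
    else:
        for k in avg:
            avg[k] += state[k].float()
avg = {k: v / len(paths) for k, v in avg.items()}  # elementwise mean
torch.save(avg, "models/avg_best_model.pt")
```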
------
### Stage 5: Extract Speech Using the Trained Model
After training is complete, you can execute stage 5 to extract the target speaker's speech using the trained model. In this stage, it mainly calls `wesep/bin/infer.py`, and you need to provide the following parameters for this script:
+ `config`: the configuration file used in Stage 3.
+ `fs`: the sample rate of the audio data.
+ `gpus`: the index of the visible GPU.
+ `exp_dir`: the experiment directory.
+ `data_type`: the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.
+ `test_data`: similar to `train_data`.
+ `test_spk1_enroll`: similar to `val_spk1_enroll`.
+ `test_spk2_enroll`: similar to `val_spk2_enroll`.
+ `test_spk2utt`: similar to `val_spk2utt`.
+ `save_wav`: controls whether the extracted speech is saved in `exp_dir/audio`.
+ `checkpoint`: the path to the checkpoint used for extracting the target speaker's speech.
At the end of this stage, the structure of the experiment directory should look like this:
```
${exp_dir}/
|__ train.log
|__ config.yaml
|__ models/
|__ infer.log
|__ audio/
|__ spk1.scp # each line records two space-separated columns: `target_wav_id` and `target_wav_path`
|__ Utt1001-4992-41806-0008_6930-75918-0015-T4992.wav
|__ ...
|__ Utt999-61-70968-0003_2830-3980-0008-T61.wav
```
------
### Stage 6: Scoring
In this stage, we evaluate the quality of the generated speech using common objective metrics. The default metrics include **STOI**, **SDR**, **SAR**, **SIR**, and **SI_SNR**. In addition to these metrics, you can also include **PESQ** and **DNS_MOS** by setting the values of `use_pesq` and `use_dnsmos` to `true`. Please be aware that DNS_MOS is exclusively supported for audio samples with a **16 kHz** sampling rate. For audio with different sampling rates, refrain from employing DNS_MOS for assessment.
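For reference, SI-SNR (scale-invariant signal-to-noise ratio) projects the estimate onto the reference before measuring the residual. The sketch below is a minimal illustrative implementation, not the one wesep's scoring actually uses:
```python
import torch

def si_snr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """SI-SNR in dB for 1-D signals (illustrative sketch)."""
    est, ref = est - est.mean(), ref - ref.mean()
    proj = (torch.dot(est, ref) / (torch.dot(ref, ref) + eps)) * ref
    noise = est - proj
    return 10 * torch.log10(proj.pow(2).sum() / (noise.pow(2).sum() + eps))

ref = torch.randn(16000)              # 1 s of audio at 16 kHz
est = ref + 0.1 * torch.randn(16000)  # noisy estimate
print(si_snr(est, ref).item())        # roughly 20 dB at this noise level
```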
At the end of this stage, a markdown file `RESULTS.md` will be created under the `exp` directory, and the directory structure of `exp` should look like this:
```
exp/BSRNN/
|__ ${exp_dir}
| |__ train.log, ... # files and directories generated in Stage 5
| |__ scoring/
|
|__ RESULTS.md
```
================================================
FILE: examples/librimix/tse/v2/confs/bsrnn.yaml
================================================
dataloader_args:
batch_size: 8 #RTX2080 1, V100: 8, A800: 16
drop_last: true
num_workers: 6
pin_memory: true
prefetch_factor: 6
dataset_args:
resample_rate: &sr 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
speaker_feat: &speaker_feat True
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
# Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589
# only Single-optimization method is supported here.
# if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml
SSA_enroll_prob: 0 # prob to add SSA on enrollment speech
enable_amp: false
exp_dir: exp/BSRNN
gpus: '0,1'
log_batch_interval: 100
loss: SISDR
loss_args: { }
# loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set
# loss_args:
# loss_posi: [[0],[1]]
# loss_weight: [[1.0],[1.0]]
model:
tse_model: BSRNN
model_args:
tse_model:
sr: *sr
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spk_fuse_type: 'multiply'
use_spk_transform: False
multi_fuse: False # Fuse the speaker embedding multiple times.
joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders
####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt
spk_args:
feat_dim: 80
embed_dim: &embed_dim 256
pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
two_emb_layer: False
####### Ecapa_TDNN
# spk_model: ECAPA_TDNN_GLOB_c512
# spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
# spk_args:
# embed_dim: &embed_dim 192
# feat_dim: 80
# pooling_func: ASTP
####### CAMPPlus
# spk_model: CAMPPlus
# spk_model_init: False
# spk_args:
# feat_dim: 80
# embed_dim: &embed_dim 192
# pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
#################################################################
spk_emb_dim: *embed_dim
spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder
spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used
feat_type: "consistent"
multi_task: False
spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921
# find_unused_parameters: True
model_init:
tse_model: null
discriminator: null
spk_model: null
num_avg: 2
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!
weight_decay: 0.0001
clip_grad: 5.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v2/confs/bsrnn_feats.yaml
================================================
dataloader_args:
batch_size: 4 #RTX2080 1, V100: 4, A800: 12
drop_last: true
num_workers: 6
pin_memory: true
prefetch_factor: 6
dataset_args:
resample_rate: &sr 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
speaker_feat: &speaker_feat False
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
# Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589
# only Single-optimization method is supported here.
# if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml
SSA_enroll_prob: 0 # prob to add SSA on enrollment speech
enable_amp: false
exp_dir: exp/BSRNN
gpus: '0,1'
log_batch_interval: 100
loss: SISDR
loss_args: { }
# loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set
# loss_args:
# loss_posi: [[0],[1]]
# loss_weight: [[1.0],[1.0]]
model:
tse_model: BSRNN_Feats
model_args:
tse_model:
sr: *sr
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spectral_feat: 'tfmap_emb' # 'tfmap_spec' 'tfmap_emb' False
spk_fuse_type: 'cross_multiply' #'cross_multiply' 'multiply' False
use_spk_transform: False
multi_fuse: False # Fuse the speaker embedding multiple times.
joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders
#################################################################
###### Ecapa_TDNN
spk_model: ECAPA_TDNN_GLOB_c512
spk_model_init: ./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
spk_args:
embed_dim: &embed_dim 192
feat_dim: 80
pooling_func: ASTP
#################################################################
spk_emb_dim: *embed_dim
spk_model_freeze: True # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder
spk_feat: *speaker_feat # if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used
feat_type: "consistent"
multi_task: False
spksInTrain: 251 # wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921
# find_unused_parameters: True
model_init:
tse_model: null
discriminator: null
spk_model: null
num_avg: 2
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!
weight_decay: 0.0001
clip_grad: 5.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v2/confs/bsrnn_multi_optim.yaml
================================================
dataloader_args:
batch_size: 8 #RTX2080 1, V100: 8, A800: 16
drop_last: true
num_workers: 6
pin_memory: true
prefetch_factor: 6
dataset_args:
resample_rate: &sr 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
speaker_feat: &speaker_feat False
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
enable_amp: false
exp_dir: exp/BSRNN
gpus: '0,1'
log_batch_interval: 100
#Please refer to our SLT paper https://www.arxiv.org/abs/2409.09589
# to check our parameter settings.
loss: SISDR
loss_args:
loss_posi: [[0,1]]
loss_weight: [[0.4,0.6]]
#if you wanna use CE loss, multi_task needs to be set True
# loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set
# loss_args:
# loss_posi: [[0,1],[2,3]]
# loss_weight: [[0.36,0.54],[0.04,0.06]]
model:
tse_model: BSRNN_Multi
model_args:
tse_model:
sr: *sr
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spk_fuse_type: 'multiply'
use_spk_transform: False
multi_fuse: False # Fuse the speaker embedding multiple times.
joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders
####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt
spk_args:
feat_dim: 80
embed_dim: &embed_dim 256
pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
two_emb_layer: False
####### Ecapa_TDNN
# spk_model: ECAPA_TDNN_GLOB_c512
# spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
# spk_args:
# embed_dim: &embed_dim 192
# feat_dim: 80
# pooling_func: ASTP
####### CAMPPlus
# spk_model: CAMPPlus
# spk_model_init: False
# spk_args:
# feat_dim: 80
# embed_dim: &embed_dim 192
# pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
#################################################################
spk_emb_dim: *embed_dim
spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder
spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used
feat_type: "consistent"
multi_task: False
spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921
# find_unused_parameters: True
model_init:
tse_model: null
discriminator: null
spk_model: null
num_avg: 2
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!
weight_decay: 0.0001
clip_grad: 5.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v2/confs/dpcc_init_gan.yaml
================================================
use_metric_loss: true
dataloader_args:
batch_size: 4
drop_last: true
num_workers: 4
pin_memory: false
prefetch_factor: 4
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
enable_amp: false
exp_dir: exp/DPCNN
gpus: '0,1'
log_batch_interval: 100
loss: SISNR
loss_args: { }
gan_loss_weight: 0.05
model:
tse_model: DPCCN
discriminator: CMGAN_Discriminator
model_args:
tse_model:
win: 512
stride: 128
feature_dim: 257
tcn_blocks: 10
tcn_layers: 2
spk_emb_dim: 256
causal: False
spk_fuse_type: 'multiply'
use_spk_transform: False
discriminator: {}
model_init:
tse_model: exp/DPCCN/no_spk_transform-multiply_fuse/models/final_model.pt
discriminator: null
num_avg: 5
num_epochs: 50
optimizer:
tse_model: Adam
discriminator: Adam
optimizer_args:
tse_model:
lr: 0.0001
weight_decay: 0.0001
discriminator:
lr: 0.001
weight_decay: 0.0001
clip_grad: 3.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
discriminator: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.0001
warm_from_zero: false
warm_up_epoch: 0
discriminator:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v2/confs/dpccn.yaml
================================================
dataloader_args:
batch_size: 6
drop_last: true
num_workers: 6
pin_memory: false
prefetch_factor: 6
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
speaker_feat: &speaker_feat True
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
enable_amp: false
exp_dir: exp/DPCNN
gpus: '0,1'
log_batch_interval: 100
loss: SISDR
loss_args: { }
# loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set
# loss_args:
# loss_posi: [[0],[1]]
# loss_weight: [[1.0],[1.0]]
model:
tse_model: DPCCN
model_args:
tse_model:
win: 512
stride: 128
feature_dim: 257
tcn_blocks: 10
tcn_layers: 2
causal: False
spk_fuse_type: 'multiply'
use_spk_transform: False
multi_fuse: False # Fuse the speaker embedding multiple times.
joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders
####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt
spk_args:
feat_dim: 80
embed_dim: &embed_dim 256
pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
two_emb_layer: False
####### Ecapa_TDNN
# spk_model: ECAPA_TDNN_GLOB_c512
# spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
# spk_args:
# embed_dim: &embed_dim 192
# feat_dim: 80
# pooling_func: ASTP
####### CAMPPlus
# spk_model: CAMPPlus
# spk_model_init: False
# spk_args:
# feat_dim: 80
# embed_dim: &embed_dim 192
# pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
#################################################################
spk_emb_dim: *embed_dim
spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder
spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used
feat_type: "consistent"
multi_task: False
spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921
model_init:
tse_model: null
discriminator: null
num_avg: 5
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!
weight_decay: 0.0001
clip_grad: 3.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v2/confs/spexplus.yaml
================================================
dataloader_args:
batch_size: 8 #A800: 8
drop_last: true
num_workers: 4
pin_memory: true
prefetch_factor: 6
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
enable_amp: false
exp_dir: exp/SpExplus/
gpus: ['0']
log_batch_interval: 100
# joint_training: True
loss: [SISDR, CE] ###SI_SNR, SDR, sisnr, CE
loss_args:
loss_posi: [[0,1,2],[3]]
loss_weight: [[0.8,0.1,0.1],[0.5]]
model:
tse_model: ConvTasNet
model_args:
tse_model:
B: 256
H: 512
L: 20
N: 256
P: 3
R: 4
X: 8
spk_emb_dim: 256
activate: relu
causal: false
norm: gLN
skip_con: False
spk_fuse_type: concatConv # "concat", "additive", "multiply", "FiLM", "None", ("concatConv" only for convtasnet)
use_spk_transform: False
multi_fuse: True # Multi speaker fuse with seperation modules
encoder_type: Multi # Multi, Deep, False
decoder_type: Multi # Multi, Deep, False
joint_training: True
multi_task: True
spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921
model_init:
tse_model: null
discriminator: null
num_avg: 5
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!
weight_decay: 0.0001
save_epoch_interval: 5
clip_grad: 5.0 # False
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: False
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v2/confs/tfgridnet.yaml
================================================
dataloader_args:
batch_size: 1
drop_last: true
num_workers: 6
pin_memory: false
prefetch_factor: 6
dataset_args:
resample_rate: &sr 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
speaker_feat: &speaker_feat True
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
enable_amp: false
exp_dir: exp/TFGridNet
gpus: '0,1'
log_batch_interval: 100
loss: SISDR
loss_args: { }
# loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set
# loss_args:
# loss_posi: [[0],[1]]
# loss_weight: [[1.0],[1.0]]
model:
tse_model: TFGridNet
model_args:
tse_model:
n_srcs: 1
sr: *sr
n_fft: 128
stride: 64
window: "hann"
n_imics: 1
n_layers: 6
lstm_hidden_units: 192
attn_n_head: 4
attn_approx_qk_dim: 512
emb_dim: 128
emb_ks: 1
emb_hs: 1
activation: "prelu"
eps: 1.0e-5
use_spk_transform: False
spk_fuse_type: "multiply"
multi_fuse: False # Fuse the speaker embedding multiple times.
joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders
####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt
spk_args:
feat_dim: 80
embed_dim: &embed_dim 256
pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
two_emb_layer: False
####### Ecapa_TDNN
# spk_model: ECAPA_TDNN_GLOB_c512
# spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
# spk_args:
# embed_dim: &embed_dim 192
# feat_dim: 80
# pooling_func: ASTP
####### CAMPPlus
# spk_model: CAMPPlus
# spk_model_init: False
# spk_args:
# feat_dim: 80
# embed_dim: &embed_dim 192
# pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
#################################################################
spk_emb_dim: *embed_dim
spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder
spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used
feat_type: "consistent"
multi_task: False
spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921
model_init:
tse_model: null
num_avg: 5
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!
weight_decay: 0.0001
clip_grad: 5.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/librimix/tse/v2/local/prepare_data.sh
================================================
#!/bin/bash
# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)
stage=-1
stop_stage=-1
mix_data_path='./Libri2Mix/wav16k/min/'
data=data
noise_type=clean
num_spk=2
. tools/parse_options.sh || exit 1
data=$(realpath ${data})
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Prepare the meta files for the datasets"
for dataset in dev test train-100; do
# for dataset in train-360; do
echo "Preparing files for" $dataset
# Prepare the meta data for the mixed data
dataset_path=$mix_data_path/$dataset/mix_${noise_type}
mkdir -p "${data}"/$noise_type/${dataset}
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print $NF}' |
awk -v path="${dataset_path}" '{print $1 , path "/" $1 , path "/../s1/" $1 , path "/../s2/" $1}' |
sed 's#.wav##' | sort -k1,1 >"${data}"/$noise_type/${dataset}/wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp |
awk -F[_-] '{print $0, $1,$4}' >"${data}"/$noise_type/${dataset}/utt2spk
# Prepare the meta data for single speakers
dataset_path=$mix_data_path/$dataset/s1
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s1/" $NF, $0}' | sort -k1,1 >"${data}"/$noise_type/${dataset}/single.wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's1' |
awk -F[-_/] '{print $0, $2}' >"${data}"/$noise_type/${dataset}/single.utt2spk
dataset_path=$mix_data_path/$dataset/s2
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s2/" $NF, $0}' | sort -k1,1 >>"${data}"/$noise_type/${dataset}/single.wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's2' |
awk -F[-_/] '{print $0, $5}' >>"${data}"/$noise_type/${dataset}/single.utt2spk
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Prepare LibriMix target-speaker enroll signal"
for dset in dev test train-100; do
# for dset in train-360; do
python local/prepare_spk2enroll_librispeech.py \
"${mix_data_path}/${dset}" \
--is_librimix True \
--outfile "${data}"/$noise_type/${dset}/spk2enroll.json \
--audio_format wav
done
for dset in dev test; do
if [ $num_spk -eq 2 ]; then
url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment"
else
url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment"
fi
output_file="${data}/${noise_type}/${dset}/mixture2enrollment"
wget -O "$output_file" "$url"
done
for dset in dev test; do
python local/prepare_librimix_enroll.py \
"${data}"/$noise_type/${dset}/wav.scp \
"${data}"/$noise_type/${dset}/spk2enroll.json \
--mix2enroll "${data}/${noise_type}/${dset}/mixture2enrollment" \
--num_spk ${num_spk} \
--train False \
--output_dir "${data}"/${noise_type}/${dset} \
--outfile_prefix "spk"
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Download the pre-trained speaker encoders (Resnet34 & Ecapa-TDNN512) from wespeaker..."
mkdir wespeaker_models
wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip
unzip voxceleb_resnet34.zip -d wespeaker_models
wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip
unzip voxceleb_ECAPA512.zip -d wespeaker_models
fi
# if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# echo "Prepare the speaker embeddings using wespeaker pretrained models"
# for dataset in dev test train-100; do
# mkdir -p "${data}"/$noise_type/${dataset}
# echo "Preparing files for" $dataset
# wespeaker --task embedding_kaldi \
# --wav_scp "${data}"/$noise_type/${dataset}/single.wav.scp \
# --output_file "${data}"/$noise_type/${dataset}/embed \
# -p wespeaker_models/voxceleb_resnet34 \
# -g 0 # GPU idx
# done
# fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
if [ ! -d "${data}/raw_data/musan" ]; then
mkdir -p ${data}/raw_data/musan
#
echo "Downloading musan.tar.gz ..."
echo "This may take a long time. Thus we recommand you to download all archives above in your own way first."
wget --no-check-certificate https://openslr.elda.org/resources/17/musan.tar.gz -P ${data}/raw_data
md5=$(md5sum ${data}/raw_data/musan.tar.gz | awk '{print $1}')
[ $md5 != "0c472d4fc0c5141eca47ad1ffeb2a7df" ] && echo "Wrong md5sum of musan.tar.gz" && exit 1
echo "Decompress all archives ..."
tar -xzvf ${data}/raw_data/musan.tar.gz -C ${data}/raw_data
rm -rf ${data}/raw_data/musan.tar.gz
fi
echo "Prepare wav.scp for musan ..."
mkdir -p ${data}/musan
find ${data}/raw_data/musan -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' >${data}/musan/wav.scp
# Convert all musan data to LMDB
echo "conver musan data to LMDB ..."
python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb
fi
================================================
FILE: examples/librimix/tse/v2/local/prepare_librimix_enroll.py
================================================
import json
import random
from pathlib import Path
from wesep.utils.datadir_writer import DatadirWriter
from wesep.utils.utils import str2bool
def prepare_librimix_enroll(wav_scp,
spk2utts,
output_dir,
num_spk=2,
train=True,
prefix="enroll_spk"):
mixtures = []
with Path(wav_scp).open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
mixtureID = line.strip().split(maxsplit=1)[0]
mixtures.append(mixtureID)
with Path(spk2utts).open("r", encoding="utf-8") as f:
# {spkID: [(uid1, path1), (uid2, path2), ...]}
spk2utt = json.load(f)
with DatadirWriter(Path(output_dir)) as writer:
for mixtureID in mixtures:
# 100-121669-0004_3180-138043-0053
uttIDs = mixtureID.split("_")
for spk in range(num_spk):
uttID = uttIDs[spk]
spkID = uttID.split("-")[0]
if train:
# For training, we choose the auxiliary signal on the fly.
# Here we use the pattern f"*{uttID} {spkID}".
writer[f"{prefix}{spk + 1}.enroll"][
mixtureID] = f"*{uttID} {spkID}"
else:
enrollID = random.choice(spk2utt[spkID])[1]
while enrollID == uttID and len(spk2utt[spkID]) > 1:
enrollID = random.choice(spk2utt[spkID])[1]
writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = enrollID
def prepare_librimix_enroll_v2(wav_scp,
map_mix2enroll,
output_dir,
num_spk=2,
prefix="spk"):
# noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py
mixtures = []
with Path(wav_scp).open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
mixtureID = line.strip().split(maxsplit=1)[0]
mixtures.append(mixtureID)
mix2enroll = {}
with open(map_mix2enroll) as f:
for line in f:
mix_id, utt_id, enroll_id = line.strip().split()
sid = mix_id.split("_").index(utt_id) + 1
mix2enroll[mix_id, f"s{sid}"] = enroll_id
with DatadirWriter(Path(output_dir)) as writer:
for mixtureID in mixtures:
# 100-121669-0004_3180-138043-0053
for spk in range(num_spk):
enroll_id = mix2enroll[mixtureID, f"s{spk + 1}"]
writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = (enroll_id +
".wav")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"wav_scp",
type=str,
help="Path to the wav.scp file",
)
parser.add_argument(
"spk2utts",
type=str,
help="Path to the json, mapping from speaker ID to utterances",
)
parser.add_argument(
"--num_spk",
type=int,
default=2,
choices=(2, 3),
help="Number of speakers in each mixture sample",
)
parser.add_argument(
"--train",
type=str2bool,
default=True,
help="Whether is the training set or not",
)
parser.add_argument(
"--mix2enroll",
type=str,
default=None,
help="Path to the downloaded map_mixture2enrollment file. "
"If `train` is False, this value is required.",
)
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Path to the directory for storing output files",
)
parser.add_argument(
"--outfile_prefix",
type=str,
default="spk",
help="Prefix of the output files",
)
args = parser.parse_args()
random.seed(args.seed)
if args.train:
prepare_librimix_enroll(
args.wav_scp,
args.spk2utts,
args.output_dir,
num_spk=args.num_spk,
train=args.train,
prefix=args.outfile_prefix,
)
else:
prepare_librimix_enroll_v2(
args.wav_scp,
args.mix2enroll,
args.output_dir,
num_spk=args.num_spk,
prefix=args.outfile_prefix,
)
================================================
FILE: examples/librimix/tse/v2/local/prepare_spk2enroll_librispeech.py
================================================
import json
from collections import defaultdict
from itertools import chain
from pathlib import Path
from wesep.utils.utils import str2bool
def get_spk2utt(paths, audio_format="flac"):
spk2utt = defaultdict(list)
for path in paths:
for audio in Path(path).rglob("*.{}".format(audio_format)):
readerID = audio.parent.parent.stem
uid = audio.stem
assert uid.split("-")[0] == readerID, audio
spk2utt[readerID].append((uid, str(audio)))
return spk2utt
def get_spk2utt_librimix(paths, audio_format="flac"):
spk2utt = defaultdict(list)
for path in paths:
for audio in chain(
Path(path).rglob("s1/*.{}".format(audio_format)),
Path(path).rglob("s2/*.{}".format(audio_format)),
Path(path).rglob("s3/*.{}".format(audio_format)),
):
spk_idx = int(audio.parent.stem[1:]) - 1
mix_uid = audio.stem
uid = mix_uid.split("_")[spk_idx]
sid = uid.split("-")[0]
spk2utt[sid].append((uid, str(audio)))
return spk2utt
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"audio_paths",
type=str,
nargs="+",
help="Paths to Librispeech subsets",
)
parser.add_argument(
"--is_librimix",
type=str2bool,
default=False,
help="Whether the provided audio_paths points to LibriMix data",
)
parser.add_argument(
"--outfile",
type=str,
default="spk2utt_tse.json",
help="Path to the output spk2utt json file",
)
parser.add_argument("--audio_format", type=str, default="flac")
args = parser.parse_args()
if args.is_librimix:
# use clean sources from LibriMix as enrollment
spk2utt = get_spk2utt_librimix(args.audio_paths,
audio_format=args.audio_format)
else:
# use Librispeech as enrollment
spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format)
outfile = Path(args.outfile)
outfile.parent.mkdir(parents=True, exist_ok=True)
with outfile.open("w", encoding="utf-8") as f:
json.dump(spk2utt, f, indent=4)
================================================
FILE: examples/librimix/tse/v2/path.sh
================================================
export PATH=$PWD:$PATH
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../../:$PYTHONPATH
================================================
FILE: examples/librimix/tse/v2/run.sh
================================================
#!/bin/bash
# Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn)
. ./path.sh || exit 1
# General configuration
stage=-1
stop_stage=-1
# Data preparation related
data=data
fs=16k
min_max=min
noise_type="clean"
data_type="shard" # shard/raw
Libri2Mix_dir=/YourPATH/librimix/Libri2Mix
mix_data_path="${Libri2Mix_dir}/wav${fs}/${min_max}"
# Training related
gpus="[0]"
use_gan_loss=false
config=confs/bsrnn.yaml
exp_dir=exp/BSRNN/no_spk_transform-multiply_fuse
if [ -z "${config}" ] && [ -f "${exp_dir}/config.yaml" ]; then
config="${exp_dir}/config.yaml"
fi
# TSE model initialization related
checkpoint=
# Inferencing and scoring related
save_results=true
use_pesq=true
use_dnsmos=true
dnsmos_use_gpu=true
# Model average related
num_avg=10
. tools/parse_options.sh || exit 1
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Prepare datasets ..."
./local/prepare_data.sh --mix_data_path ${mix_data_path} \
--data ${data} \
--noise_type ${noise_type} \
--stage 1 \
--stop-stage 3
fi
data=${data}/${noise_type}
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Covert train and test data to ${data_type}..."
for dset in train-100 dev test; do
# for dset in train-360; do
python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \
--num_threads 16 \
--prefix shards \
--shuffle \
${data}/$dset/wav.scp ${data}/$dset/utt2spk \
${data}/$dset/shards ${data}/$dset/shard.list
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Start training ..."
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then
checkpoint="${exp_dir}/models/latest_checkpoint.pt"
fi
if ${use_gan_loss}; then
train_script=wesep/bin/train_gan.py
else
train_script=wesep/bin/train.py
fi
export OMP_NUM_THREADS=8
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
${train_script} --config $config \
--exp_dir ${exp_dir} \
--gpus $gpus \
--num_avg ${num_avg} \
--data_type "${data_type}" \
--train_data ${data}/train-100/${data_type}.list \
--train_utt2spk ${data}/train-100/single.utt2spk \
--train_spk2utt ${data}/train-100/spk2enroll.json \
--val_data ${data}/dev/${data_type}.list \
--val_spk1_enroll ${data}/dev/spk1.enroll \
--val_spk2_enroll ${data}/dev/spk2.enroll \
--val_spk2utt ${data}/dev/single.wav.scp \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Do model average ..."
avg_model=$exp_dir/models/avg_best_model.pt
python wesep/bin/average_model.py \
--dst_model $avg_model \
--src_path $exp_dir/models \
--num ${num_avg} \
--mode best \
--epochs "138,141"
fi
if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/avg_best_model.pt" ]; then
checkpoint="${exp_dir}/models/avg_best_model.pt"
fi
# shellcheck disable=SC2215
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Start inferencing ..."
python wesep/bin/infer.py --config $config \
--fs ${fs} \
--gpus 0 \
--exp_dir ${exp_dir} \
--data_type "${data_type}" \
--test_data ${data}/test/${data_type}.list \
--test_spk1_enroll ${data}/test/spk1.enroll \
--test_spk2_enroll ${data}/test/spk2.enroll \
--test_spk2utt ${data}/test/single.wav.scp \
--save_wav ${save_results} \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
echo "Start scoring ..."
./tools/score.sh --dset "${data}/test" \
--exp_dir "${exp_dir}" \
--fs ${fs} \
--use_pesq "${use_pesq}" \
--use_dnsmos "${use_dnsmos}" \
--dnsmos_use_gpu "${dnsmos_use_gpu}" \
--n_gpu "${num_gpus}"
fi
================================================
FILE: examples/voxceleb1/v2/confs/bsrnn_online.yaml
================================================
dataloader_args:
batch_size: 8
drop_last: true
num_workers: 6
pin_memory: false
prefetch_factor: 6
dataset_args:
resample_rate: 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
filter_len: true
filter_len_args:
min_num_seconds: 1.0
max_num_seconds: 100.0
chunk_len: 48000
online_mix: true
num_speakers: 2
use_random_snr: true
speaker_feat: &speaker_feat True
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0
reverb_prob: 0
enable_amp: false
exp_dir: exp/BSRNN
gpus: '0,1'
log_batch_interval: 100
loss: SISDR
loss_args: { }
model:
tse_model: BSRNN
model_args:
tse_model:
sr: 16000
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spk_fuse_type: 'multiply'
use_spk_transform: False
multi_fuse: False # Fuse the speaker embedding multiple times.
joint_training: True
####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt
spk_args:
feat_dim: 80
embed_dim: &embed_dim 256
pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
two_emb_layer: False
####### Ecapa_TDNN
# spk_model: ECAPA_TDNN_GLOB_c512
# spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
# spk_args:
# embed_dim: &embed_dim 192
# feat_dim: 80
# pooling_func: ASTP
####### CAMPPlus
# spk_model: CAMPPlus
# spk_model_init: False
# spk_args:
# feat_dim: 80
# embed_dim: &embed_dim 192
# pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
#################################################################
spk_emb_dim: *embed_dim
spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder
spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used
feat_type: "consistent"
multi_task: False
spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921
# find_unused_parameters: True
model_init:
tse_model: null
discriminator: null
num_avg: 2
num_epochs: 150
optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001
weight_decay: 0.0001
clip_grad: 5.0
save_epoch_interval: 1
scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0
seed: 42
================================================
FILE: examples/voxceleb1/v2/local/prepare_data.sh
================================================
#!/bin/bash
# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)
stage=-1
stop_stage=-1
single_data_path='./voxceleb/VoxCeleb1/wav/'
mix_data_path='./Libri2Mix/wav16k/min/'
data=data
noise_type=clean
num_spk=2
. tools/parse_options.sh || exit 1
data=$(realpath ${data})
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Prepare the meta files for the vox1 single speaker datasets"
for dataset in train-vox1; do
echo "Preparing files for" $dataset
# Prepare the meta data for the online mix data
mkdir -p "${data}"/$noise_type/${dataset}
find ${single_data_path} -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' | sort >"${data}"/$noise_type/${dataset}/wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp | awk -F "/" '{print $0,$1}' >"${data}"/$noise_type/${dataset}/utt2spk
python local/prepare_spk2enroll_vox.py \
"${data}/$noise_type/${dataset}/wav.scp" \
--outfile "${data}"/$noise_type/${dataset}/spk2enroll.json
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Prepare the meta files for the val and test datasets"
for dataset in dev test; do
echo "Preparing files for" $dataset
# Prepare the meta data for the mixed data
dataset_path=$mix_data_path/$dataset/mix_${noise_type}
mkdir -p "${data}"/$noise_type/${dataset}
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print $NF}' |
awk -v path="${dataset_path}" '{print $1 , path "/" $1 , path "/../s1/" $1 , path "/../s2/" $1}' |
sed 's#.wav##' | sort -k1,1 >"${data}"/$noise_type/${dataset}/wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp |
awk -F[_-] '{print $0, $1,$4}' >"${data}"/$noise_type/${dataset}/utt2spk
# Prepare the meta data for single speakers
dataset_path=$mix_data_path/$dataset/s1
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s1/" $NF, $0}' | sort -k1,1 >"${data}"/$noise_type/${dataset}/single.wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's1' |
awk -F[-_/] '{print $0, $2}' >"${data}"/$noise_type/${dataset}/single.utt2spk
dataset_path=$mix_data_path/$dataset/s2
find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s2/" $NF, $0}' | sort -k1,1 >>"${data}"/$noise_type/${dataset}/single.wav.scp
awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's2' |
awk -F[-_/] '{print $0, $5}' >>"${data}"/$noise_type/${dataset}/single.utt2spk
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: Prepare LibriMix target-speaker enroll signal"
for dset in dev test; do
python local/prepare_spk2enroll_librispeech.py \
"${mix_data_path}/${dset}" \
--is_librimix True \
--outfile "${data}"/$noise_type/${dset}/spk2enroll.json \
--audio_format wav
done
for dset in dev test; do
if [ $num_spk -eq 2 ]; then
url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment"
else
url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment"
fi
output_file="${data}/${noise_type}/${dset}/mixture2enrollment"
wget -O "$output_file" "$url"
done
for dset in dev test; do
python local/prepare_librimix_enroll.py \
"${data}"/$noise_type/${dset}/wav.scp \
"${data}"/$noise_type/${dset}/spk2enroll.json \
--mix2enroll "${data}/${noise_type}/${dset}/mixture2enrollment" \
--num_spk ${num_spk} \
--train False \
--output_dir "${data}"/${noise_type}/${dset} \
--outfile_prefix "spk"
done
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Download the pre-trained speaker encoders (Resnet34 & Ecapa-TDNN512) from wespeaker..."
mkdir -p wespeaker_models
wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip
unzip voxceleb_resnet34.zip -d wespeaker_models
wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip
unzip voxceleb_ECAPA512.zip -d wespeaker_models
fi
# if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# echo "Prepare the speaker embeddings using wespeaker pretrained models"
# for dataset in dev test train-100; do
# mkdir -p "${data}"/$noise_type/${dataset}
# echo "Preparing files for" $dataset
# wespeaker --task embedding_kaldi \
# --wav_scp "${data}"/$noise_type/${dataset}/single.wav.scp \
# --output_file "${data}"/$noise_type/${dataset}/embed \
# -p wespeaker_models/voxceleb_resnet34 \
# -g 0 # GPU idx
# done
# fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
if [ ! -d "${data}/raw_data/musan" ]; then
mkdir -p ${data}/raw_data/musan
#
echo "Downloading musan.tar.gz ..."
echo "This may take a long time. Thus we recommand you to download all archives above in your own way first."
wget --no-check-certificate https://openslr.elda.org/resources/17/musan.tar.gz -P ${data}/raw_data
md5=$(md5sum ${data}/raw_data/musan.tar.gz | awk '{print $1}')
[ $md5 != "0c472d4fc0c5141eca47ad1ffeb2a7df" ] && echo "Wrong md5sum of musan.tar.gz" && exit 1
echo "Decompress all archives ..."
tar -xzvf ${data}/raw_data/musan.tar.gz -C ${data}/raw_data
rm -rf ${data}/raw_data/musan.tar.gz
fi
echo "Prepare wav.scp for musan ..."
mkdir -p ${data}/musan
find ${data}/raw_data/musan -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' >${data}/musan/wav.scp
# Convert all musan data to LMDB
echo "conver musan data to LMDB ..."
python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb
fi
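# Example invocation (illustrative paths):
#   ./local/prepare_data.sh --stage 1 --stop-stage 4 \
#       --single_data_path /path/to/VoxCeleb1/wav \
#       --mix_data_path /path/to/Libri2Mix/wav16k/min \
#       --data data --noise_type clean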
================================================
FILE: examples/voxceleb1/v2/local/prepare_librimix_enroll.py
================================================
import json
import random
from pathlib import Path
from wesep.utils.datadir_writer import DatadirWriter
from wesep.utils.utils import str2bool
def prepare_librimix_enroll(wav_scp,
spk2utts,
output_dir,
num_spk=2,
train=True,
prefix="enroll_spk"):
mixtures = []
with Path(wav_scp).open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
mixtureID = line.strip().split(maxsplit=1)[0]
mixtures.append(mixtureID)
with Path(spk2utts).open("r", encoding="utf-8") as f:
# {spkID: [(uid1, path1), (uid2, path2), ...]}
spk2utt = json.load(f)
with DatadirWriter(Path(output_dir)) as writer:
for mixtureID in mixtures:
# 100-121669-0004_3180-138043-0053
uttIDs = mixtureID.split("_")
for spk in range(num_spk):
uttID = uttIDs[spk]
spkID = uttID.split("-")[0]
if train:
# For training, we choose the auxiliary signal on the fly.
# Thus, here we use the pattern f"*{uttID} {spkID}" to indicate it. # noqa
writer[f"{prefix}{spk + 1}.enroll"][
mixtureID] = f"*{uttID} {spkID}"
else:
                # Compare utterance IDs (not paths) to avoid picking the
                # target utterance itself as its own enrollment.
                enroll_uid, enroll_path = random.choice(spk2utt[spkID])
                while enroll_uid == uttID and len(spk2utt[spkID]) > 1:
                    enroll_uid, enroll_path = random.choice(spk2utt[spkID])
                writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = enroll_path
def prepare_librimix_enroll_v2(wav_scp,
map_mix2enroll,
output_dir,
num_spk=2,
prefix="spk"):
# noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py
mixtures = []
with Path(wav_scp).open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
mixtureID = line.strip().split(maxsplit=1)[0]
mixtures.append(mixtureID)
mix2enroll = {}
with open(map_mix2enroll) as f:
for line in f:
mix_id, utt_id, enroll_id = line.strip().split()
sid = mix_id.split("_").index(utt_id) + 1
mix2enroll[mix_id, f"s{sid}"] = enroll_id
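    # Each line of map_mixture2enrollment is expected to look like
    # (illustrative): <mix_id> <target_utt_id> <enroll_utt_id>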
with DatadirWriter(Path(output_dir)) as writer:
for mixtureID in mixtures:
# 100-121669-0004_3180-138043-0053
for spk in range(num_spk):
enroll_id = mix2enroll[mixtureID, f"s{spk + 1}"]
writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = (enroll_id +
".wav")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"wav_scp",
type=str,
help="Path to the wav.scp file",
)
parser.add_argument("spk2utts",
type=str,
help="Path to the json file containing mapping "
"from speaker ID to utterances")
parser.add_argument(
"--num_spk",
type=int,
default=2,
choices=(2, 3),
help="Number of speakers in each mixture sample",
)
parser.add_argument(
"--train",
type=str2bool,
default=True,
help="Whether is the training set or not",
)
parser.add_argument(
"--mix2enroll",
type=str,
default=None,
help="Path to the downloaded map_mixture2enrollment file. "
"If `train` is False, this value is required.",
)
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Path to the directory for storing output files",
)
parser.add_argument(
"--outfile_prefix",
type=str,
default="spk",
help="Prefix of the output files",
)
args = parser.parse_args()
random.seed(args.seed)
if args.train:
prepare_librimix_enroll(
args.wav_scp,
args.spk2utts,
args.output_dir,
num_spk=args.num_spk,
train=args.train,
prefix=args.outfile_prefix,
)
else:
prepare_librimix_enroll_v2(
args.wav_scp,
args.mix2enroll,
args.output_dir,
num_spk=args.num_spk,
prefix=args.outfile_prefix,
)
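# Example invocation (illustrative, mirroring local/prepare_data.sh):
#   python local/prepare_librimix_enroll.py data/clean/test/wav.scp \
#       data/clean/test/spk2enroll.json \
#       --mix2enroll data/clean/test/mixture2enrollment \
#       --num_spk 2 --train False \
#       --output_dir data/clean/test --outfile_prefix spk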
================================================
FILE: examples/voxceleb1/v2/local/prepare_spk2enroll_librispeech.py
================================================
import json
from collections import defaultdict
from itertools import chain
from pathlib import Path
from wesep.utils.utils import str2bool
def get_spk2utt_vox1(paths, audio_format="flac"):
spk2utt = defaultdict(list)
for path in paths:
path = Path(path)
for subdir in path.iterdir():
if subdir.is_dir():
readerID = subdir.name
for audio in subdir.rglob("*.{}".format(audio_format)):
spk2utt[readerID].append(str(audio))
return spk2utt
def get_spk2utt(paths, audio_format="flac"):
spk2utt = defaultdict(list)
for path in paths:
for audio in Path(path).rglob("*.{}".format(audio_format)):
readerID = audio.parent.parent.stem
uid = audio.stem
assert uid.split("-")[0] == readerID, audio
spk2utt[readerID].append((uid, str(audio)))
return spk2utt
def get_spk2utt_librimix(paths, audio_format="flac"):
spk2utt = defaultdict(list)
for path in paths:
for audio in chain(
Path(path).rglob("s1/*.{}".format(audio_format)),
Path(path).rglob("s2/*.{}".format(audio_format)),
Path(path).rglob("s3/*.{}".format(audio_format)),
):
spk_idx = int(audio.parent.stem[1:]) - 1
mix_uid = audio.stem
uid = mix_uid.split("_")[spk_idx]
sid = uid.split("-")[0]
spk2utt[sid].append((uid, str(audio)))
return spk2utt
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"audio_paths",
type=str,
nargs="+",
help="Paths to Librispeech subsets",
)
parser.add_argument(
"--is_librimix",
type=str2bool,
default=False,
help="Whether the provided audio_paths points to LibriMix data",
)
parser.add_argument(
"--is_vox1",
type=str2bool,
default=False,
help="Whether the provided audio_paths points to vox1 data",
)
parser.add_argument(
"--outfile",
type=str,
default="spk2utt_tse.json",
help="Path to the output spk2utt json file",
)
parser.add_argument("--audio_format", type=str, default="flac")
args = parser.parse_args()
if args.is_librimix:
# use clean sources from LibriMix as enrollment
spk2utt = get_spk2utt_librimix(args.audio_paths,
audio_format=args.audio_format)
elif args.is_vox1:
spk2utt = get_spk2utt_vox1(args.audio_paths,
audio_format=args.audio_format)
else:
# use Librispeech as enrollment
spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format)
outfile = Path(args.outfile)
outfile.parent.mkdir(parents=True, exist_ok=True)
with outfile.open("w", encoding="utf-8") as f:
json.dump(spk2utt, f, indent=4)
================================================
FILE: examples/voxceleb1/v2/local/prepare_spk2enroll_vox.py
================================================
import json
from collections import defaultdict
from pathlib import Path
def get_spk2utt_from_wavscp(wav_scp_path):
spk2utt = defaultdict(list)
with open(wav_scp_path, "r") as readin:
for line in readin:
speaker_id = line.split("/")[0]
uid, audio_path = line.strip().split()
spk2utt[speaker_id].append((uid, str(audio_path)))
return spk2utt
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"wav_scp_path",
type=str,
help="Paths to Librispeech subsets",
)
parser.add_argument(
"--outfile",
type=str,
default="spk2utt_tse.json",
help="Path to the output spk2utt json file",
)
args = parser.parse_args()
spk2utt = get_spk2utt_from_wavscp(args.wav_scp_path)
outfile = Path(args.outfile)
outfile.parent.mkdir(parents=True, exist_ok=True)
with outfile.open("w", encoding="utf-8") as f:
json.dump(spk2utt, f)
================================================
FILE: examples/voxceleb1/v2/path.sh
================================================
export PATH=$PWD:$PATH
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:$PYTHONPATH
================================================
FILE: examples/voxceleb1/v2/run_online.sh
================================================
#!/bin/bash
# Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn)
. ./path.sh || exit 1
stage=-1
stop_stage=-1
data=data
fs=16k
min_max=min
noise_type="clean"
data_type="shard" # shard/raw
Vox1_dir=/YourPATH/voxceleb/VoxCeleb1/wav
Libri2Mix_dir=/YourPATH/librimix/Libri2Mix # For validating and testing the TSE model.
mix_data_path="${Libri2Mix_dir}/wav${fs}/${min_max}"
gpus="[0,1]"
num_avg=10
checkpoint=
config=confs/bsrnn_online.yaml
exp_dir=exp/BSRNN_Online/no_spk_transform_multiply
save_results=true
. tools/parse_options.sh || exit 1
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Prepare datasets ..."
./local/prepare_data.sh --single_data_path ${Vox1_dir} \
--mix_data_path ${mix_data_path} \
--data ${data} \
--noise_type ${noise_type} \
--stage 1 \
--stop-stage 4
fi
data=${data}/${noise_type}
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Covert train and test data to ${data_type}..."
for dset in train-vox1; do
python tools/make_shard_online.py --num_utts_per_shard 1000 \
--num_threads 16 \
--prefix shards \
--shuffle \
${data}/$dset/wav.scp ${data}/$dset/utt2spk \
${data}/$dset/shards_online ${data}/$dset/shard_online.list
done
for dset in dev test; do
python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \
--num_threads 16 \
--prefix shards \
--shuffle \
${data}/$dset/wav.scp ${data}/$dset/utt2spk \
${data}/$dset/shards ${data}/$dset/shard.list
done
fi
if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then
checkpoint="${exp_dir}/models/latest_checkpoint.pt"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# rm -r $exp_dir
echo "Start training ..."
export OMP_NUM_THREADS=8
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then
checkpoint="${exp_dir}/models/latest_checkpoint.pt"
fi
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wesep/bin/train.py --config $config \
--exp_dir ${exp_dir} \
--gpus $gpus \
--num_avg ${num_avg} \
--data_type "${data_type}" \
--train_data ${data}/train-vox1/${data_type}_online.list \
--train_utt2spk ${data}/train-vox1/utt2spk \
--train_spk2utt ${data}/train-vox1/spk2enroll.json \
--val_data ${data}/dev/${data_type}.list \
--val_spk2utt ${data}/dev/single.wav.scp \
--val_spk1_enroll ${data}/dev/spk1.enroll \
--val_spk2_enroll ${data}/dev/spk2.enroll \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Do model average ..."
avg_model=$exp_dir/models/avg_best_model.pt
python wesep/bin/average_model.py \
--dst_model $avg_model \
--src_path $exp_dir/models \
--num ${num_avg} \
--mode best \
--epochs "138,141"
fi
if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/avg_best_model.pt" ]; then
checkpoint="${exp_dir}/models/avg_best_model.pt"
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
python wesep/bin/infer.py --config $config \
--gpus 0 \
--exp_dir ${exp_dir} \
--data_type "${data_type}" \
--test_data ${data}/test/${data_type}.list \
--test_spk1_enroll ${data}/test/spk1.enroll \
--test_spk2_enroll ${data}/test/spk2.enroll \
--test_spk2utt ${data}/test/single.wav.scp \
--save_wav ${save_results} \
${checkpoint:+--checkpoint $checkpoint}
fi
#./run.sh --stage 4 --stop-stage 4 --config exp/BSRNN/train_clean_460/multiply_no_spk_transform/config.yaml --exp_dir exp/BSRNN/train_clean_460/multiply_no_spk_transform/ --checkpoint exp/BSRNN/train_clean_460/multiply_no_spk_transform/models/avg_best_model.pt
================================================
FILE: requirements.txt
================================================
fast_bss_eval==0.1.4
fire==0.4.0
joblib==1.1.0
kaldiio==2.18.0
librosa==0.10.1
lmdb==1.3.0
matplotlib==3.5.1
mir_eval==0.7
silero-vad==5.1.2
numpy==1.22.4
pesq==0.0.4
pystoi==0.3.3
PyYAML==6.0
Requests==2.31.0
scipy==1.7.3
soundfile==0.12.1
tableprint==0.9.1
thop==0.1.1.post2209072238
torchnet==0.0.4
tqdm==4.64.0
flake8==3.8.2
flake8-bugbear
flake8-comprehensions
flake8-executable
flake8-pyi==20.5.0
auraloss
torchmetrics==1.2.0
h5py
pre-commit==3.5.0
================================================
FILE: runtime/.gitignore
================================================
fc_base
build*
================================================
FILE: runtime/CMakeLists.txt
================================================
cmake_minimum_required(VERSION 3.14)
project(wesep VERSION 0.1)
option(CXX11_ABI "whether to use CXX11_ABI libtorch" OFF)
set(CMAKE_VERBOSE_MAKEFILE OFF)
include(FetchContent)
set(FETCHCONTENT_QUIET OFF)
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_base})
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC")
include(libtorch)
include(glog)
include(gflags)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
# build all libraries
add_subdirectory(utils)
add_subdirectory(frontend)
add_subdirectory(separate)
add_subdirectory(bin)
================================================
FILE: runtime/README.md
================================================
# LibTorch backend for wesep
* Build. The build requires cmake 3.14 or above, and gcc/g++ 5.4 or above.
``` sh
mkdir build && cd build
cmake ..
cmake --build .
```
* Testing. The RTF (real-time factor) is printed to the console, and the separated outputs are written to wav files.
``` sh
export GLOG_logtostderr=1
export GLOG_v=2
./build/bin/separate_main \
--wav_scp $wav_scp \
--model /path/to/model.zip \
--output_dir /output/dir/
```
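Each line of `$wav_scp` is expected to contain four whitespace-separated fields, matching what `bin/separate_main.cc` parses: an utterance key, the mixture wav, and one enrollment wav per speaker. An illustrative line:
```
utt1 /data/mix/utt1.wav /data/enroll/spk1.wav /data/enroll/spk2.wav
```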
================================================
FILE: runtime/bin/CMakeLists.txt
================================================
add_executable(separate_main separate_main.cc)
target_link_libraries(separate_main PUBLIC frontend separate)
================================================
FILE: runtime/bin/separate_main.cc
================================================
// Copyright (c) 2024 wesep team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "frontend/wav.h"
#include "separate/separate_engine.h"
#include "utils/timer.h"
#include "utils/utils.h"
DEFINE_string(wav_path, "", "the path of mixing audio.");
DEFINE_string(spk1_emb, "", "the emb of spk1.");
DEFINE_string(spk2_emb, "", "the emb of spk2.");
DEFINE_string(wav_scp, "", "input wav scp.");
DEFINE_string(model, "", "the path of wesep model.");
DEFINE_string(output_dir, "", "output path.");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(feat_dim, 80, "fbank feature dimension.");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
std::vector<std::vector<std::string>> waves;
if (!FLAGS_wav_path.empty() && !FLAGS_spk1_emb.empty() &&
!FLAGS_spk2_emb.empty()) {
waves.push_back(std::vector<std::string>(
{"test", FLAGS_wav_path, FLAGS_spk1_emb, FLAGS_spk2_emb}));
} else {
std::ifstream wav_scp(FLAGS_wav_scp);
std::string line;
while (getline(wav_scp, line)) {
std::vector<std::string> strs;
wesep::SplitString(line, &strs);
CHECK_EQ(strs.size(), 4);
waves.push_back(
std::vector<std::string>({strs[0], strs[1], strs[2], strs[3]}));
}
if (waves.empty()) {
LOG(FATAL) << "Please provide non-empty wav scp.";
}
}
if (FLAGS_output_dir.empty()) {
LOG(FATAL) << "Invalid output path.";
}
int g_total_waves_dur = 0;
int g_total_process_time = 0;
auto model = std::make_shared<wesep::SeparateEngine>(
FLAGS_model, FLAGS_feat_dim, FLAGS_sample_rate);
for (auto wav : waves) {
// mix wav
wenet::WavReader wav_reader(wav[1]);
CHECK_EQ(wav_reader.sample_rate(), 16000);
int16_t* mix_wav_data = const_cast<int16_t*>(wav_reader.data());
int wave_dur =
static_cast<int>(static_cast<float>(wav_reader.num_sample()) /
wav_reader.sample_rate() * 1000);
// spk1
wenet::WavReader spk1_reader(wav[2]);
CHECK_EQ(spk1_reader.sample_rate(), 16000);
int16_t* spk1_data = const_cast<int16_t*>(spk1_reader.data());
// spk2
wenet::WavReader spk2_reader(wav[3]);
CHECK_EQ(spk2_reader.sample_rate(), 16000);
int16_t* spk2_data = const_cast<int16_t*>(spk2_reader.data());
// forward
std::vector<std::vector<float>> outputs;
int process_time = 0;
wenet::Timer timer;
model->ForwardFunc(
std::vector<int16_t>(mix_wav_data,
mix_wav_data + wav_reader.num_sample()),
spk1_data, spk2_data,
std::min(spk1_reader.num_sample(), spk2_reader.num_sample()), &outputs);
process_time = timer.Elapsed();
LOG(INFO) << "process: " << wav[0]
<< " RTF: " << static_cast<float>(process_time) / wave_dur;
// Save the separated audio
wenet::WriteWavFile(outputs[0].data(), outputs[0].size(), 16000,
FLAGS_output_dir + "/" + wav[0] + "-spk1.wav");
wenet::WriteWavFile(outputs[1].data(), outputs[1].size(), 16000,
FLAGS_output_dir + "/" + wav[0] + "-spk2.wav");
g_total_process_time += process_time;
g_total_waves_dur += wave_dur;
}
LOG(INFO) << "Total: process " << g_total_waves_dur << "ms audio taken "
<< g_total_process_time << "ms.";
LOG(INFO) << "RTF: " << std::setprecision(4)
<< static_cast<float>(g_total_process_time) / g_total_waves_dur;
return 0;
}
================================================
FILE: runtime/cmake/gflags.cmake
================================================
FetchContent_Declare(gflags
URL https://github.com/gflags/gflags/archive/v2.2.2.zip
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
)
FetchContent_MakeAvailable(gflags)
include_directories(${gflags_BINARY_DIR}/include)
================================================
FILE: runtime/cmake/glog.cmake
================================================
FetchContent_Declare(glog
URL https://github.com/google/glog/archive/v0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})
================================================
FILE: runtime/cmake/libtorch.cmake
================================================
if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
if(CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcpu.zip")
set(URL_HASH "SHA256=d52f63577a07adb0bfd6d77c90f7da21896e94f71eb7dcd55ed7835ccb3b2b59")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip")
set(URL_HASH "SHA256=bee1b7be308792aa60fc95a4f5274d9658cb7248002d0e333d49eb81ec88430c")
endif()
else()
message(FATAL_ERROR "Unsupported system '${CMAKE_SYSTEM_NAME}' (expected 'Linux')")
endif()
FetchContent_Declare(libtorch
URL ${LIBTORCH_URL}
URL_HASH ${URL_HASH}
)
FetchContent_MakeAvailable(libtorch)
find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
include_directories(${TORCH_INCLUDE_DIRS})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG")
================================================
FILE: runtime/frontend/CMakeLists.txt
================================================
add_library(frontend STATIC
feature_pipeline.cc
fft.cc
)
target_link_libraries(frontend PUBLIC utils)
================================================
FILE: runtime/frontend/fbank.h
================================================
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FBANK_H_
#define FRONTEND_FBANK_H_
#include <cstring>
#include <limits>
#include <random>
#include <utility>
#include <vector>
#include "frontend/fft.h"
#include "glog/logging.h"
namespace wenet {
// This code is based on the Kaldi Fbank implementation, please see
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
class Fbank {
public:
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
: num_bins_(num_bins),
sample_rate_(sample_rate),
frame_length_(frame_length),
frame_shift_(frame_shift),
use_log_(true),
remove_dc_offset_(true),
generator_(0),
distribution_(0, 1.0),
dither_(0.0) {
fft_points_ = UpperPowerOfTwo(frame_length_);
// generate bit reversal table and trigonometric function table
const int fft_points_4 = fft_points_ / 4;
bitrev_.resize(fft_points_);
sintbl_.resize(fft_points_ + fft_points_4);
make_sintbl(fft_points_, sintbl_.data());
make_bitrev(fft_points_, bitrev_.data());
int num_fft_bins = fft_points_ / 2;
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
int low_freq = 20, high_freq = sample_rate_ / 2;
float mel_low_freq = MelScale(low_freq);
float mel_high_freq = MelScale(high_freq);
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
bins_.resize(num_bins_);
center_freqs_.resize(num_bins_);
for (int bin = 0; bin < num_bins; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
center_freqs_[bin] = InverseMelScale(center_mel);
std::vector<float> this_bin(num_fft_bins);
int first_index = -1, last_index = -1;
for (int i = 0; i < num_fft_bins; ++i) {
float freq = (fft_bin_width * i); // Center frequency of this fft
// bin.
float mel = MelScale(freq);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else
weight = (right_mel - mel) / (right_mel - center_mel);
this_bin[i] = weight;
if (first_index == -1) first_index = i;
last_index = i;
}
}
CHECK(first_index != -1 && last_index >= first_index);
bins_[bin].first = first_index;
int size = last_index + 1 - first_index;
bins_[bin].second.resize(size);
for (int i = 0; i < size; ++i) {
bins_[bin].second[i] = this_bin[first_index + i];
}
}
// NOTE(cdliang): add hamming window
hamming_window_.resize(frame_length_);
double a = M_2PI / (frame_length - 1);
for (int i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
hamming_window_[i] = 0.54 - 0.46 * cos(a * i_fl);
}
}
void set_use_log(bool use_log) { use_log_ = use_log; }
void set_remove_dc_offset(bool remove_dc_offset) {
remove_dc_offset_ = remove_dc_offset;
}
void set_dither(float dither) { dither_ = dither; }
int num_bins() const { return num_bins_; }
static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
}
static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
}
static int UpperPowerOfTwo(int n) {
return static_cast<int>(pow(2, ceil(log(n) / log(2))));
}
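  // e.g. a 25 ms frame at 16 kHz is 400 samples, so fft_points_ becomes 512.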
// preemphasis
void PreEmphasis(float coeff, std::vector<float>* data) const {
if (coeff == 0.0) return;
for (int i = data->size() - 1; i > 0; i--)
(*data)[i] -= coeff * (*data)[i - 1];
(*data)[0] -= coeff * (*data)[0];
}
// add hamming window
void Hamming(std::vector<float>* data) const {
CHECK_GE(data->size(), hamming_window_.size());
for (size_t i = 0; i < hamming_window_.size(); ++i) {
(*data)[i] *= hamming_window_[i];
}
}
// Compute fbank feat, return num frames
int Compute(const std::vector<float>& wave,
std::vector<std::vector<float>>* feat) {
int num_samples = wave.size();
if (num_samples < frame_length_) return 0;
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
feat->resize(num_frames);
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
std::vector<float> power(fft_points_ / 2);
for (int i = 0; i < num_frames; ++i) {
std::vector<float> data(wave.data() + i * frame_shift_,
wave.data() + i * frame_shift_ + frame_length_);
      // optionally add dither noise
if (dither_ != 0.0) {
for (size_t j = 0; j < data.size(); ++j)
data[j] += dither_ * distribution_(generator_);
}
      // optionally remove DC offset
if (remove_dc_offset_) {
float mean = 0.0;
for (size_t j = 0; j < data.size(); ++j) mean += data[j];
mean /= data.size();
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
}
PreEmphasis(0.97, &data);
// Povey(&data);
Hamming(&data);
// copy data to fft_real
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
memset(fft_real.data() + frame_length_, 0,
sizeof(float) * (fft_points_ - frame_length_));
memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_);
fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(),
fft_points_);
// power
for (int j = 0; j < fft_points_ / 2; ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
}
(*feat)[i].resize(num_bins_);
// cepstral coefficients, triangle filter array
for (int j = 0; j < num_bins_; ++j) {
float mel_energy = 0.0;
int s = bins_[j].first;
for (size_t k = 0; k < bins_[j].second.size(); ++k) {
mel_energy += bins_[j].second[k] * power[s + k];
}
        // optionally apply log
if (use_log_) {
if (mel_energy < std::numeric_limits<float>::epsilon())
mel_energy = std::numeric_limits<float>::epsilon();
mel_energy = logf(mel_energy);
}
(*feat)[i][j] = mel_energy;
// printf("%f ", mel_energy);
}
// printf("\n");
}
return num_frames;
}
private:
int num_bins_;
int sample_rate_;
int frame_length_, frame_shift_;
int fft_points_;
bool use_log_;
bool remove_dc_offset_;
std::vector<float> center_freqs_;
std::vector<std::pair<int, std::vector<float>>> bins_;
std::vector<float> hamming_window_;
std::default_random_engine generator_;
std::normal_distribution<float> distribution_;
float dither_;
// bit reversal table
std::vector<int> bitrev_;
// trigonometric function table
std::vector<float> sintbl_;
};
} // namespace wenet
#endif // FRONTEND_FBANK_H_
================================================
FILE: runtime/frontend/feature_pipeline.cc
================================================
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/feature_pipeline.h"
#include <algorithm>
#include <utility>
namespace wenet {
FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
: config_(config),
feature_dim_(config.num_bins),
fbank_(config.num_bins, config.sample_rate, config.frame_length,
config.frame_shift),
num_frames_(0),
input_finished_(false) {}
void FeaturePipeline::AcceptWaveform(const std::vector<float>& wav) {
std::vector<std::vector<float>> feats;
std::vector<float> waves;
waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end());
waves.insert(waves.end(), wav.begin(), wav.end());
int num_frames = fbank_.Compute(waves, &feats);
for (size_t i = 0; i < feats.size(); ++i) {
feature_queue_.Push(std::move(feats[i]));
}
num_frames_ += num_frames;
int left_samples = waves.size() - config_.frame_shift * num_frames;
remained_wav_.resize(left_samples);
std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(),
remained_wav_.begin());
// We are still adding wave, notify input is not finished
finish_condition_.notify_one();
}
void FeaturePipeline::AcceptWaveform(const std::vector<int16_t>& wav) {
std::vector<float> float_wav(wav.size());
for (size_t i = 0; i < wav.size(); i++) {
float_wav[i] = static_cast<float>(wav[i]);
}
this->AcceptWaveform(float_wav);
}
void FeaturePipeline::set_input_finished() {
CHECK(!input_finished_);
{
std::lock_guard<std::mutex> lock(mutex_);
input_finished_ = true;
}
finish_condition_.notify_one();
}
bool FeaturePipeline::ReadOne(std::vector<float>* feat) {
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
} else {
std::unique_lock<std::mutex> lock(mutex_);
while (!input_finished_) {
// This will release the lock and wait for notify_one()
// from AcceptWaveform() or set_input_finished()
finish_condition_.wait(lock);
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
}
}
CHECK(input_finished_);
// Double check queue.empty, see issue#893 for detailed discussions.
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
} else {
return false;
}
}
}
bool FeaturePipeline::Read(int num_frames,
std::vector<std::vector<float>>* feats) {
feats->clear();
std::vector<float> feat;
while (feats->size() < num_frames) {
if (ReadOne(&feat)) {
feats->push_back(std::move(feat));
} else {
return false;
}
}
return true;
}
void FeaturePipeline::Reset() {
input_finished_ = false;
num_frames_ = 0;
remained_wav_.clear();
feature_queue_.Clear();
}
} // namespace wenet
================================================
FILE: runtime/frontend/feature_pipeline.h
================================================
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FEATURE_PIPELINE_H_
#define FRONTEND_FEATURE_PIPELINE_H_
#include <mutex>
#include <queue>
#include <string>
#include <vector>
#include "frontend/fbank.h"
#include "glog/logging.h"
#include "utils/blocking_queue.h"
namespace wenet {
struct FeaturePipelineConfig {
int num_bins;
int sample_rate;
int frame_length;
int frame_shift;
FeaturePipelineConfig(int num_bins, int sample_rate)
: num_bins(num_bins), // 80 dim fbank
sample_rate(sample_rate) { // 16k sample rate
frame_length = sample_rate / 1000 * 25; // frame length 25ms
frame_shift = sample_rate / 1000 * 10; // frame shift 10ms
}
void Info() const {
LOG(INFO) << "feature pipeline config"
<< " num_bins " << num_bins << " frame_length " << frame_length
<< " frame_shift " << frame_shift;
}
};
// Typically, FeaturePipeline is used in two threads: one thread A calls
// AcceptWaveform() to add raw wav data and set_input_finished() to notice
// the end of input wav, another thread B (decoder thread) calls Read() to
// consume features. So a BlockingQueue is used to make this class thread-safe.
// The Read() is designed as a blocking method when there is no feature
// in feature_queue_ and the input is not finished.
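//
// A minimal usage sketch (illustrative only, not part of the codebase):
//
//   wenet::FeaturePipelineConfig config(80 /* num_bins */, 16000 /* Hz */);
//   wenet::FeaturePipeline pipeline(config);
//   pipeline.AcceptWaveform(wav);      // thread A: feed float samples
//   pipeline.set_input_finished();     // thread A: signal end of input
//   std::vector<std::vector<float>> feats;
//   pipeline.Read(pipeline.num_frames(), &feats);  // thread B: consume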
class FeaturePipeline {
public:
explicit FeaturePipeline(const FeaturePipelineConfig& config);
// The feature extraction is done in AcceptWaveform().
void AcceptWaveform(const std::vector<float>& wav);
void AcceptWaveform(const std::vector<int16_t>& wav);
// Current extracted frames number.
int num_frames() const { return num_frames_; }
int feature_dim() const { return feature_dim_; }
const FeaturePipelineConfig& config() const { return config_; }
// The caller should call this method when the speech input ends.
// Never call AcceptWaveform() after calling set_input_finished() !
void set_input_finished();
bool input_finished() const { return input_finished_; }
// Return False if input is finished and no feature could be read.
// Return True if a feature is read.
// This function is a blocking method. It will block the thread when
// there is no feature in feature_queue_ and the input is not finished.
bool ReadOne(std::vector<float>* feat);
// Read #num_frames frame features.
// Return False if fewer than #num_frames features are read and the
// input is finished.
// Return True if #num_frames features are read.
// This function is a blocking method when there is no feature
// in feature_queue_ and the input is not finished.
bool Read(int num_frames, std::vector<std::vector<float>>* feats);
void Reset();
bool IsLastFrame(int frame) const {
return input_finished_ && (frame == num_frames_ - 1);
}
int NumQueuedFrames() const { return feature_queue_.Size(); }
private:
const FeaturePipelineConfig& config_;
int feature_dim_;
Fbank fbank_;
BlockingQueue<std::vector<float>> feature_queue_;
int num_frames_;
bool input_finished_;
// The feature extraction is done in AcceptWaveform().
  // The waveform sample points are consumed frame by frame.
  // The residual waveform sample points left after framing are
  // kept to be used in the next AcceptWaveform() call.
std::vector<float> remained_wav_;
// Used to block the Read when there is no feature in feature_queue_
// and the input is not finished.
mutable std::mutex mutex_;
std::condition_variable finish_condition_;
};
} // namespace wenet
#endif // FRONTEND_FEATURE_PIPELINE_H_
================================================
FILE: runtime/frontend/fft.cc
================================================
// Copyright (c) 2016 HR
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "frontend/fft.h"
namespace wenet {
void make_sintbl(int n, float* sintbl) {
int i, n2, n4, n8;
float c, s, dc, ds, t;
n2 = n / 2;
n4 = n / 4;
n8 = n / 8;
t = sin(M_PI / n);
dc = 2 * t * t;
ds = sqrt(dc * (2 - dc));
t = 2 * dc;
c = sintbl[n4] = 1;
s = sintbl[0] = 0;
for (i = 1; i < n8; ++i) {
c -= dc;
dc += t * c;
s += ds;
ds -= t * s;
sintbl[i] = s;
sintbl[n4 - i] = c;
}
if (n8 != 0) sintbl[n8] = sqrt(0.5);
for (i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i];
for (i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i];
}
void make_bitrev(int n, int* bitrev) {
int i, j, k, n2;
n2 = n / 2;
i = j = 0;
for (;;) {
bitrev[i] = j;
if (++i >= n) break;
k = n2;
while (k <= j) {
j -= k;
k /= 2;
}
j += k;
}
}
// bitrev: bit reversal table
// sintbl: trigonometric function table
// x: real part
// y: imaginary part
// n: fft length
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) {
int i, j, k, ik, h, d, k2, n4, inverse;
float t, s, c, dx, dy;
/* preparation */
if (n < 0) {
n = -n;
inverse = 1; /* inverse transform */
} else {
inverse = 0;
}
n4 = n / 4;
if (n == 0) {
return 0;
}
/* bit reversal */
for (i = 0; i < n; ++i) {
j = bitrev[i];
if (i < j) {
t = x[i];
x[i] = x[j];
x[j] = t;
t = y[i];
y[i] = y[j];
y[j] = t;
}
}
/* transformation */
for (k = 1; k < n; k = k2) {
h = 0;
k2 = k + k;
d = n / k2;
for (j = 0; j < k; ++j) {
c = sintbl[h + n4];
if (inverse)
s = -sintbl[h];
else
s = sintbl[h];
for (i = j; i < n; i += k2) {
ik = i + k;
dx = s * y[ik] + c * x[ik];
dy = c * y[ik] - s * x[ik];
x[ik] = x[i] - dx;
x[i] += dx;
y[ik] = y[i] - dy;
y[i] += dy;
}
h += d;
}
}
if (inverse) {
/* divide by n in case of the inverse transformation */
for (i = 0; i < n; ++i) {
x[i] /= n;
y[i] /= n;
}
}
return 0; /* finished successfully */
}
} // namespace wenet
================================================
FILE: runtime/frontend/fft.h
================================================
// Copyright (c) 2016 HR
#ifndef FRONTEND_FFT_H_
#define FRONTEND_FFT_H_
#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
namespace wenet {
// Fast Fourier Transform
void make_sintbl(int n, float* sintbl);
void make_bitrev(int n, int* bitrev);
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n);
} // namespace wenet
#endif // FRONTEND_FFT_H_
================================================
FILE: runtime/frontend/wav.h
================================================
// Copyright (c) 2016 Personal (Binbin Zhang)
// Created on 2016-08-15
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <memory>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
DEFINE_int32(pcm_sample_rate, 16000, "pcm data sample rate");
namespace wenet {
class AudioReader {
public:
AudioReader() {}
explicit AudioReader(const std::string& filename) {}
virtual ~AudioReader() {}
virtual int num_channel() const = 0;
virtual int sample_rate() const = 0;
virtual int bits_per_sample() const = 0;
virtual int num_sample() const = 0;
virtual const int16_t* data() const = 0;
};
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
class WavReader : public AudioReader {
public:
WavReader() {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb");
if (NULL == fp) {
LOG(WARNING) << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
fprintf(stderr,
"WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "riff" "WAVE" "fmt " "data"
// Skip any subchunks between "fmt" and "data". Usually there will
// be a single "fact" subchunk, but on Windows there can also be a
// "list" subchunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next subchunk
fread(header.data, 8, sizeof(char), fp);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_.resize(num_data);
int num_read = fread(&data_[0], 1, header.data_size, fp);
if (num_read < header.data_size) {
// If the header size is wrong, adjust
header.data_size = num_read;
num_data = header.data_size / (bits_per_sample_ / 8);
data_.resize(num_data);
}
num_sample_ = num_data / num_channel_;
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_sample() const { return num_sample_; }
const int16_t* data() const { return data_.data(); }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_sample_; // sample points per channel
std::vector<int16_t> data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_sample, int num_channel, int sample_rate,
int bits_per_sample)
: data_(data),
num_sample_(num_sample),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "wb");  // binary mode for wav data
// init char 'riff' 'WAVE' 'fmt ' 'data'
WavHeader header;
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
memcpy(&header, wav_header, sizeof(header));
header.channels = num_channel_;
header.bit = bits_per_sample_;
header.sample_rate = sample_rate_;
header.data_size = num_sample_ * num_channel_ * (bits_per_sample_ / 8);
header.size = sizeof(header) - 8 + header.data_size;
header.bytes_per_second =
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
header.block_size = num_channel_ * (bits_per_sample_ / 8);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_sample_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_sample_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
class PcmReader : public AudioReader {
public:
PcmReader() {}
explicit PcmReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb");
if (NULL == fp) {
LOG(WARNING) << "Error in read " << filename;
return false;
}
num_channel_ = 1;
sample_rate_ = FLAGS_pcm_sample_rate;
bits_per_sample_ = 16;
fseek(fp, 0, SEEK_END);
int data_size = ftell(fp);
fseek(fp, 0, SEEK_SET);
num_sample_ = data_size / sizeof(int16_t);
data_.resize(num_sample_);
fread(&data_[0], data_size, 1, fp);
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_sample() const { return num_sample_; }
const int16_t* data() const { return data_.data(); }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_sample_; // sample points per channel
std::vector<int16_t> data_;
};
std::shared_ptr<AudioReader> ReadAudioFile(const std::string& filename) {
size_t pos = filename.rfind('.');
std::string suffix = filename.substr(pos);
if (suffix == ".wav" || suffix == ".WAV") {
return std::make_shared<WavReader>(filename);
} else {
return std::make_shared<PcmReader>(filename);
}
}
void WriteWavFile(const float* data, int data_size, int sample_rate,
const std::string& wav_path) {
std::vector<float> tmp_wav(data, data + data_size);
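  // Scale float samples (assumed in [-1, 1)) to the 16-bit PCM range.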
for (int i = 0; i < tmp_wav.size(); i++) {
tmp_wav[i] *= (1 << 15);
}
WavWriter wav_write(tmp_wav.data(), tmp_wav.size(), 1, sample_rate, 16);
wav_write.Write(wav_path);
}
} // namespace wenet
#endif // FRONTEND_WAV_H_
================================================
FILE: runtime/separate/CMakeLists.txt
================================================
add_library(separate STATIC separate_engine.cc)
target_link_libraries(separate PUBLIC frontend ${TORCH_LIBRARIES})
================================================
FILE: runtime/separate/separate_engine.cc
================================================
// Copyright (c) 2024 wesep team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "separate/separate_engine.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "torch/script.h"
#include "torch/torch.h"
namespace wesep {
void SeparateEngine::InitEngineThreads(int num_threads) {
// for multi-thread performance
at::set_num_threads(num_threads);
VLOG(1) << "Num intra-op threads: " << at::get_num_threads();
}
SeparateEngine::SeparateEngine(const std::string& model_path,
const int feat_dim, const int sample_rate) {
sample_rate_ = sample_rate;
feat_dim_ = feat_dim;
feature_config_ =
std::make_shared<wenet::FeaturePipelineConfig>(feat_dim, sample_rate);
feature_pipeline_ =
std::make_shared<wenet::FeaturePipeline>(*feature_config_);
feature_pipeline_->Reset();
InitEngineThreads(1);
torch::jit::script::Module model = torch::jit::load(model_path);
model_ = std::make_shared<torch::jit::script::Module>(std::move(model));
model_->eval();
}
void SeparateEngine::ExtractFeature(const int16_t* data, int data_size,
std::vector<std::vector<float>>* feat) {
feature_pipeline_->AcceptWaveform(
std::vector<int16_t>(data, data + data_size));
feature_pipeline_->set_input_finished();
feature_pipeline_->Read(feature_pipeline_->num_frames(), feat);
feature_pipeline_->Reset();
this->ApplyMean(feat);
}
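// Subtract the per-dimension mean over all frames from each frame,
// i.e. per-utterance cepstral mean normalization (CMN).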
void SeparateEngine::ApplyMean(std::vector<std::vector<float>>* feat) {
std::vector<float> mean(feat_dim_, 0);
for (auto& i : *feat) {
std::transform(i.begin(), i.end(), mean.begin(), mean.begin(),
std::plus<>{});
}
std::transform(mean.begin(), mean.end(), mean.begin(),
[&](const float d) { return d / feat->size(); });
for (auto& i : *feat) {
std::transform(i.begin(), i.end(), mean.begin(), i.begin(), std::minus<>{});
}
}
void SeparateEngine::ForwardFunc(const std::vector<int16_t>& mix_wav,
const int16_t* spk1_emb,
const int16_t* spk2_emb, int data_size,
std::vector<std::vector<float>>* output) {
// pre-process
std::vector<float> input_wav(mix_wav.size());
for (int i = 0; i < mix_wav.size(); i++) {
input_wav[i] = static_cast<float>(mix_wav[i]) / (1 << 15);
}
std::vector<std::vector<float>> spk1_emb_feat;
this->ExtractFeature(spk1_emb, data_size, &spk1_emb_feat);
std::vector<std::vector<float>> spk2_emb_feat;
this->ExtractFeature(spk2_emb, data_size, &spk2_emb_feat);
// torch mix_wav
torch::Tensor torch_wav = torch::zeros({2, mix_wav.size()}, torch::kFloat32);
for (size_t i = 0; i < 2; i++) {
torch::Tensor row =
torch::from_blob(input_wav.data(), {input_wav.size()}, torch::kFloat32)
.clone();
torch_wav[i] = std::move(row);
}
// torch spk_emb_feat
torch::Tensor torch_spk_emb_feat =
torch::zeros({2, spk1_emb_feat.size(), feat_dim_}, torch::kFloat32);
for (size_t i = 0; i < spk1_emb_feat.size(); i++) {
torch::Tensor row1 =
torch::from_blob(spk1_emb_feat[i].data(), {feat_dim_}, torch::kFloat32);
torch_spk_emb_feat[0][i] = std::move(row1);
torch::Tensor row2 =
torch::from_blob(spk2_emb_feat[i].data(), {feat_dim_}, torch::kFloat32);
torch_spk_emb_feat[1][i] = std::move(row2);
}
// forward
torch::NoGradGuard no_grad;
auto outputs =
model_->forward({torch_wav, torch_spk_emb_feat}).toTuple()->elements();
torch::Tensor wav_out = outputs[0].toTensor();
auto accessor = wav_out.accessor<float, 2>();
output->resize(2, std::vector<float>(wav_out.size(1), 0.0));
for (int i = 0; i < wav_out.size(1); i++) {
(*output)[0][i] = accessor[0][i];
(*output)[1][i] = accessor[1][i];
}
}
} // namespace wesep
================================================
FILE: runtime/separate/separate_engine.h
================================================
// Copyright (c) 2024 wesep team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SEPARATE_SEPARATE_ENGINE_H_
#define SEPARATE_SEPARATE_ENGINE_H_
#include <memory>
#include <string>
#include <vector>
#include "torch/script.h"
#include "torch/torch.h"
#include "frontend/feature_pipeline.h"
namespace wesep {
class SeparateEngine {
public:
explicit SeparateEngine(const std::string& model_path, const int feat_dim,
const int sample_rate);
void InitEngineThreads(int num_threads = 1);
void ForwardFunc(const std::vector<int16_t>& mix_wav, const int16_t* spk1_emb,
const int16_t* spk2_emb, int data_size,
std::vector<std::vector<float>>* output);
void ExtractFeature(const int16_t* data, int data_size,
std::vector<std::vector<float>>* feat);
void ApplyMean(std::vector<std::vector<float>>* feat);
private:
std::shared_ptr<torch::jit::script::Module> model_ = nullptr;
std::shared_ptr<wenet::FeaturePipelineConfig> feature_config_ = nullptr;
std::shared_ptr<wenet::FeaturePipeline> feature_pipeline_ = nullptr;
int sample_rate_ = 16000;
int feat_dim_ = 80;
};
} // namespace wesep
#endif // SEPARATE_SEPARATE_ENGINE_H_
================================================
FILE: runtime/utils/CMakeLists.txt
================================================
add_library(utils STATIC
utils.cc
)
target_link_libraries(utils PUBLIC glog gflags frontend)
================================================
FILE: runtime/utils/blocking_queue.h
================================================
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UTILS_BLOCKING_QUEUE_H_
#define UTILS_BLOCKING_QUEUE_H_
#include <condition_variable>
#include <limits>
#include <mutex>
#include <queue>
#include <utility>
namespace wenet {
#define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \
Type(const Type&) = delete; \
Type& operator=(const Type&) = delete;
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity = std::numeric_limits<int>::max())
: capacity_(capacity) {}
void Push(const T& value) {
{
std::unique_lock<std::mutex> lock(mutex_);
while (queue_.size() >= capacity_) {
not_full_condition_.wait(lock);
}
queue_.push(value);
}
not_empty_condition_.notify_one();
}
void Push(T&& value) {
{
std::unique_lock<std::mutex> lock(mutex_);
while (queue_.size() >= capacity_) {
not_full_condition_.wait(lock);
}
queue_.push(std::move(value));
}
not_empty_condition_.notify_one();
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
while (queue_.empty()) {
not_empty_condition_.wait(lock);
}
T t(std::move(queue_.front()));
queue_.pop();
not_full_condition_.notify_one();
return t;
}
bool Empty() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.empty();
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
void Clear() {
while (!Empty()) {
Pop();
}
}
private:
size_t capacity_;
mutable std::mutex mutex_;
std::condition_variable not_full_condition_;
std::condition_variable not_empty_condition_;
std::queue<T> queue_;
public:
WENET_DISALLOW_COPY_AND_ASSIGN(BlockingQueue);
};
} // namespace wenet
#endif // UTILS_BLOCKING_QUEUE_H_
================================================
FILE: runtime/utils/timer.h
================================================
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UTILS_TIMER_H_
#define UTILS_TIMER_H_
#include <chrono>
namespace wenet {
class Timer {
public:
Timer() : time_start_(std::chrono::steady_clock::now()) {}
void Reset() { time_start_ = std::chrono::steady_clock::now(); }
// return int in milliseconds
int Elapsed() const {
auto time_now = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(time_now -
time_start_)
.count();
}
private:
std::chrono::time_point<std::chrono::steady_clock> time_start_;
};
} // namespace wenet
#endif // UTILS_TIMER_H_
================================================
FILE: runtime/utils/utils.cc
================================================
// Copyright (c) 2023 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <fstream>
#include <limits>
#include <numeric>
#include <sstream>
#include <vector>
#include "glog/logging.h"
#include "utils/utils.h"
namespace wesep {
std::string Ltrim(const std::string& str) {
size_t start = str.find_first_not_of(WHITESPACE);
return (start == std::string::npos) ? "" : str.substr(start);
}
std::string Rtrim(const std::string& str) {
size_t end = str.find_last_not_of(WHITESPACE);
return (end == std::string::npos) ? "" : str.substr(0, end + 1);
}
std::string Trim(const std::string& str) { return Rtrim(Ltrim(str)); }
void SplitString(const std::string& str, std::vector<std::string>* strs) {
SplitStringToVector(Trim(str), " \t", true, strs);
}
void SplitStringToVector(const std::string& full, const char* delim,
bool omit_empty_strings,
std::vector<std::string>* out) {
size_t start = 0, found = 0, end = full.size();
out->clear();
while (found != std::string::npos) {
found = full.find_first_of(delim, start);
// start != end condition is for when the delimiter is at the end
if (!omit_empty_strings || (found != start && start != end))
out->push_back(full.substr(start, found - start));
start = found + 1;
}
}
#ifdef _MSC_VER
std::wstring ToWString(const std::string& str) {
unsigned len = str.size() * 2;
setlocale(LC_CTYPE, "");
wchar_t* p = new wchar_t[len];
mbstowcs(p, str.c_str(), len);
std::wstring wstr(p);
delete[] p;
return wstr;
}
#endif
} // namespace wesep
================================================
FILE: runtime/utils/utils.h
================================================
// Copyright (c) 2023 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UTILS_UTILS_H_
#define UTILS_UTILS_H_
#include <string>
#include <unordered_map>
#include <vector>
namespace wesep {
const char WHITESPACE[] = " \n\r\t\f\v";
// Split the string with space or tab.
void SplitString(const std::string& str, std::vector<std::string>* strs);
void SplitStringToVector(const std::string& full, const char* delim,
bool omit_empty_strings,
std::vector<std::string>* out);
#ifdef _MSC_VER
std::wstring ToWString(const std::string& str);
#endif
} // namespace wesep
#endif // UTILS_UTILS_H_
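For illustration only (not part of the repository): a minimal sketch using the string helpers declared above. SplitString() first trims the input, then splits on spaces and tabs while dropping empty fields.
#include <iostream>
#include <string>
#include <vector>
#include "utils/utils.h"
int main() {
  std::vector<std::string> fields;
  wesep::SplitString("  hello\tworld  ", &fields);
  for (const auto& f : fields) {
    std::cout << "[" << f << "]" << std::endl;  // prints [hello] then [world]
  }
  return 0;
}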
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
requirements = [
"tqdm",
"kaldiio",
"torch>=1.12.0",
"torchaudio>=0.12.0",
"silero-vad",
]
setup(
name="wesep",
install_requires=requirements,
packages=find_packages(),
entry_points={
"console_scripts": [
"wesep = wesep.cli.extractor:main",
],
},
)
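The entry_points block maps the `wesep` console script to wesep.cli.extractor:main, so after `pip install .` the installed command and the snippet below (a minimal sketch, for illustration only) are equivalent.
# Running this file is equivalent to invoking the installed `wesep` script.
from wesep.cli.extractor import main

if __name__ == "__main__":
    main()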
================================================
FILE: tools/extract_embed_depreciated.py
================================================
# Copyright (c) 2022, Shuai Wang (wsstriving@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import kaldiio
import onnxruntime as ort
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser(description="infer example using onnx")
parser.add_argument("--onnx_path", required=True, help="onnx path")
parser.add_argument("--wav_scp", required=True, help="wav path")
parser.add_argument("--out_path",
required=True,
help="output path of the embeddings")
args = parser.parse_args()
return args
def compute_fbank(wav_path,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0):
"""Extract fbank, simlilar to the one in wespeaker.dataset.processor,
While integrating the wave reading and CMN.
"""
waveform, sample_rate = torchaudio.load(wav_path)
waveform = waveform * (1 << 15)
mat = kaldi.fbank(
waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
sample_frequency=sample_rate,
window_type="hamming",
use_energy=False,
)
# CMN, without CVN
mat = mat - torch.mean(mat, dim=0)
return mat
def main():
args = get_args()
so = ort.SessionOptions()
so.inter_op_num_threads = 1
so.intra_op_num_threads = 1
session = ort.InferenceSession(args.onnx_path, sess_options=so)
embed_ark = os.path.join(args.out_path, "embed.ark")
embed_scp = os.path.join(args.out_path, "embed.scp")
with kaldiio.WriteHelper("ark,scp:" + embed_ark + "," +
embed_scp) as writer:
with open(args.wav_scp, "r") as read_scp:
for line in tqdm(read_scp):
tokens = line.strip().split(" ")
name, wav_path = tokens[0], tokens[1]
feats = compute_fbank(wav_path)
feats = feats.unsqueeze(0).numpy() # add batch dimension
embed = session.run(output_names=["embs"],
input_feed={"feats": feats})
writer(name, embed[0])
if __name__ == "__main__":
main()
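A minimal sketch (illustration only; the model and wav paths are hypothetical) that embeds a single utterance without the wav.scp loop, reusing compute_fbank() defined above.
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")  # hypothetical path
feats = compute_fbank("speaker1.wav")  # hypothetical path; (T, 80) fbank matrix
feats = feats.unsqueeze(0).numpy()  # add batch dimension -> (1, T, 80)
embed = session.run(output_names=["embs"], input_feed={"feats": feats})[0]
print(embed.shape)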
================================================
FILE: tools/make_lmdb.py
================================================
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import pickle
import lmdb
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser(description="")
parser.add_argument("in_scp_file", help="input scp file")
parser.add_argument("out_lmdb", help="output lmdb")
args = parser.parse_args()
return args
def main():
args = get_args()
db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4))) # 1TB
# txn is short for Transaction
txn = db.begin(write=True)
keys = []
with open(args.in_scp_file, "r", encoding="utf8") as fin:
lines = fin.readlines()
for i, line in enumerate(tqdm(lines)):
arr = line.strip().split()
assert len(arr) == 2
key, wav = arr[0], arr[1]
keys.append(key)
with open(wav, "rb") as fin:
data = fin.read()
txn.put(key.encode(), data)
# Write flush to disk
if i % 100 == 0:
txn.commit()
txn = db.begin(write=True)
txn.commit()
with db.begin(write=True) as txn:
txn.put(b"__keys__", pickle.dumps(keys))
db.sync()
db.close()
if __name__ == "__main__":
main()
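For illustration only: a minimal read-back sketch for the database written above. The full key list is stored under "__keys__", and each key maps to the raw bytes of one wav file (the lmdb path here is hypothetical).
import pickle

import lmdb

env = lmdb.open("data.lmdb", readonly=True)  # hypothetical path
with env.begin() as txn:
    keys = pickle.loads(txn.get(b"__keys__"))
    wav_bytes = txn.get(keys[0].encode())
print(len(keys), "keys; first wav is", len(wav_bytes), "bytes")
env.close()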
================================================
FILE: tools/make_shard_list_premix.py
================================================
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang
# 2023 SRIBD Shuai Wang )
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import logging
import multiprocessing
import os
import random
import tarfile
import time
import sys
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def write_tar_file(data_list, tar_file, index=0, total=1):
logging.info('Processing {} {}/{}'.format(tar_file, index, total))
read_time = 0.0
write_time = 0.0
with tarfile.open(tar_file, "w") as tar:
for item in data_list:
assert len(
item) == 3, 'item should have 3 elements: Key, Speaker, Wav'
key, spks, wavs = item
spk_idx = 1
for spk in spks:
assert isinstance(spk, str)
spk_file = key + '.spk' + str(spk_idx)
spk = spk.encode('utf8')
spk_data = io.BytesIO(spk)
spk_info = tarfile.TarInfo(spk_file)
spk_info.size = len(spk)
tar.addfile(spk_info, spk_data)
spk_idx = spk_idx + 1
spk_idx = 0
for wav in wavs:
suffix = wav.split('.')[-1]
assert suffix in AUDIO_FORMAT_SETS
ts = time.time()
try:
with open(wav, 'rb') as fin:
data = fin.read()
except FileNotFoundError as e:
print(e)
sys.exit()
read_time += (time.time() - ts)
ts = time.time()
if spk_idx > 0:
wav_file = key + '_spk' + str(spk_idx) + '.' + suffix
else:
wav_file = key + '.' + suffix
wav_data = io.BytesIO(data)
wav_info = tarfile.TarInfo(wav_file)
wav_info.size = len(data)
tar.addfile(wav_info, wav_data)
write_time += (time.time() - ts)
spk_idx = spk_idx + 1
logging.info('read {} write {}'.format(read_time, write_time))
def get_args():
parser = argparse.ArgumentParser(description='')
parser.add_argument('--num_utts_per_shard',
type=int,
default=1000,
help='num utts per shard')
parser.add_argument('--num_threads',
type=int,
default=1,
help='num threads for make shards')
parser.add_argument('--prefix',
default='shards',
help='prefix of shards tar file')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--shuffle',
action='store_true',
help='whether to shuffle data')
parser.add_argument('wav_file', help='wav file')
parser.add_argument('utt2spk_file', help='utt2spk file')
parser.add_argument('shards_dir', help='output shards dir')
parser.add_argument('shards_list', help='output shards list file')
args = parser.parse_args()
return args
def main():
args = get_args()
random.seed(args.seed)
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s')
wav_table = {}
with open(args.wav_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0] # key = os.path.splitext(arr[0])[0]
wav_table[key] = [arr[i + 1] for i in range(len(arr) - 1)]
data = []
with open(args.utt2spk_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0] # key = os.path.splitext(arr[0])[0]
spks = [arr[i + 1] for i in range(len(arr) - 1)]
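For illustration only (the shard filename is hypothetical): a minimal sketch of reading one generated shard back. Each tar entry is either a UTF-8 speaker label (*.spkN) or raw audio bytes, mirroring how write_tar_file() above packs them.
import tarfile

with tarfile.open("shards/shards_000000000.tar") as tar:
    for info in tar.getmembers():
        data = tar.extractfile(info).read()
        if ".spk" in info.name:
            print(info.name, "->", data.decode("utf8"))
        else:
            print(info.name, "->", len(data), "audio bytes")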
================================================
SYMBOL INDEX (411 symbols across 71 files)
================================================
FILE: examples/librimix/tse/v1/local/prepare_librimix_enroll.py
function prepare_librimix_enroll (line 9) | def prepare_librimix_enroll(wav_scp,
function prepare_librimix_enroll_v2 (line 46) | def prepare_librimix_enroll_v2(wav_scp,
FILE: examples/librimix/tse/v1/local/prepare_spk2enroll_librispeech.py
function get_spk2utt (line 9) | def get_spk2utt(paths, audio_format="flac"):
function get_spk2utt_librimix (line 21) | def get_spk2utt_librimix(paths, audio_format="flac"):
FILE: examples/librimix/tse/v2/local/prepare_librimix_enroll.py
function prepare_librimix_enroll (line 9) | def prepare_librimix_enroll(wav_scp,
function prepare_librimix_enroll_v2 (line 46) | def prepare_librimix_enroll_v2(wav_scp,
FILE: examples/librimix/tse/v2/local/prepare_spk2enroll_librispeech.py
function get_spk2utt (line 9) | def get_spk2utt(paths, audio_format="flac"):
function get_spk2utt_librimix (line 21) | def get_spk2utt_librimix(paths, audio_format="flac"):
FILE: examples/voxceleb1/v2/local/prepare_librimix_enroll.py
function prepare_librimix_enroll (line 9) | def prepare_librimix_enroll(wav_scp,
function prepare_librimix_enroll_v2 (line 46) | def prepare_librimix_enroll_v2(wav_scp,
FILE: examples/voxceleb1/v2/local/prepare_spk2enroll_librispeech.py
function get_spk2utt_vox1 (line 9) | def get_spk2utt_vox1(paths, audio_format="flac"):
function get_spk2utt (line 22) | def get_spk2utt(paths, audio_format="flac"):
function get_spk2utt_librimix (line 34) | def get_spk2utt_librimix(paths, audio_format="flac"):
FILE: examples/voxceleb1/v2/local/prepare_spk2enroll_vox.py
function get_spk2utt_from_wavscp (line 6) | def get_spk2utt_from_wavscp(wav_scp_path):
FILE: runtime/bin/separate_main.cc
function main (line 36) | int main(int argc, char* argv[]) {
FILE: runtime/frontend/fbank.h
function namespace (line 27) | namespace wenet {
FILE: runtime/frontend/feature_pipeline.cc
type wenet (line 20) | namespace wenet {
FILE: runtime/frontend/feature_pipeline.h
function namespace (line 27) | namespace wenet {
function class (line 56) | class FeaturePipeline {
FILE: runtime/frontend/fft.cc
type wenet (line 9) | namespace wenet {
function make_sintbl (line 11) | void make_sintbl(int n, float* sintbl) {
function make_bitrev (line 37) | void make_bitrev(int n, int* bitrev) {
function fft (line 59) | int fft(const int* bitrev, const float* sintbl, float* x, float* y, in...
FILE: runtime/frontend/fft.h
function namespace (line 13) | namespace wenet {
FILE: runtime/frontend/wav.h
function namespace (line 34) | namespace wenet {
FILE: runtime/separate/separate_engine.cc
type wesep (line 29) | namespace wesep {
FILE: runtime/separate/separate_engine.h
function namespace (line 27) | namespace wesep {
FILE: runtime/utils/blocking_queue.h
function namespace (line 24) | namespace wenet {
FILE: runtime/utils/timer.h
function namespace (line 20) | namespace wenet {
FILE: runtime/utils/utils.cc
type wesep (line 26) | namespace wesep {
function Ltrim (line 28) | std::string Ltrim(const std::string& str) {
function Rtrim (line 33) | std::string Rtrim(const std::string& str) {
function Trim (line 38) | std::string Trim(const std::string& str) { return Rtrim(Ltrim(str)); }
function SplitString (line 40) | void SplitString(const std::string& str, std::vector<std::string>* str...
function SplitStringToVector (line 44) | void SplitStringToVector(const std::string& full, const char* delim,
function ToWString (line 59) | std::wstring ToWString(const std::string& str) {
FILE: runtime/utils/utils.h
function namespace (line 22) | namespace wesep {
FILE: tools/extract_embed_depreciated.py
function get_args (line 26) | def get_args():
function compute_fbank (line 37) | def compute_fbank(wav_path,
function main (line 62) | def main():
FILE: tools/make_lmdb.py
function get_args (line 23) | def get_args():
function main (line 31) | def main():
FILE: tools/make_shard_list_premix.py
function write_tar_file (line 29) | def write_tar_file(data_list, tar_file, index=0, total=1):
function get_args (line 76) | def get_args():
function main (line 101) | def main():
FILE: tools/make_shard_online.py
function write_tar_file (line 28) | def write_tar_file(data_list, tar_file, index=0, total=1):
function get_args (line 65) | def get_args():
function main (line 90) | def main():
FILE: tools/test_dataset.py
function test_premixed_dataset (line 8) | def test_premixed_dataset():
function test_online_dataset (line 31) | def test_online_dataset():
FILE: wesep/bin/average_model.py
function get_args (line 25) | def get_args():
function main (line 64) | def main():
FILE: wesep/bin/export_jit.py
function get_args (line 13) | def get_args():
function main (line 22) | def main():
FILE: wesep/bin/infer.py
function infer (line 27) | def infer(config="confs/conf.yaml", **kwargs):
FILE: wesep/bin/score.py
function get_readers (line 19) | def get_readers(scps: List[str], dtype: str):
function read_audio (line 25) | def read_audio(reader, key, audio_format="sound"):
function scoring (line 32) | def scoring(
function get_parser (line 218) | def get_parser():
function main (line 316) | def main(cmd=None):
FILE: wesep/bin/train.py
function train (line 51) | def train(config="conf/config.yaml", **kwargs):
FILE: wesep/bin/train_gan.py
function train (line 51) | def train(config="conf/config.yaml", **kwargs):
FILE: wesep/cli/extractor.py
class Extractor (line 18) | class Extractor:
method __init__ (line 20) | def __init__(self, model_dir: str):
method set_wavform_norm (line 47) | def set_wavform_norm(self, wavform_norm: bool):
method set_resample_rate (line 50) | def set_resample_rate(self, resample_rate: int):
method set_vad (line 53) | def set_vad(self, apply_vad: bool):
method set_device (line 56) | def set_device(self, device: str):
method set_output_norm (line 60) | def set_output_norm(self, output_norm: bool):
method compute_fbank (line 63) | def compute_fbank(
method extract_speech (line 83) | def extract_speech(self, audio_path: str, audio_path_2: str):
method extract_speech_from_pcm (line 95) | def extract_speech_from_pcm(self,
function load_model (line 162) | def load_model(language: str) -> Extractor:
function load_model_local (line 167) | def load_model_local(model_dir: str) -> Extractor:
function main (line 171) | def main():
FILE: wesep/cli/hub.py
function download (line 27) | def download(url: str, dest: str, only_child=True):
class Hub (line 87) | class Hub(object):
method __init__ (line 99) | def __init__(self) -> None:
method get_model (line 103) | def get_model(lang: str) -> str:
FILE: wesep/cli/utils.py
function get_args (line 4) | def get_args():
FILE: wesep/dataset/FRAM_RIR.py
function calc_cos (line 26) | def calc_cos(orientation_rad):
function freq_invariant_decay_func (line 43) | def freq_invariant_decay_func(cos_theta, pattern="cardioid"):
function freq_invariant_src_decay_func (line 73) | def freq_invariant_src_decay_func(mic_pos,
function freq_invariant_mic_decay_func (line 99) | def freq_invariant_mic_decay_func(mic_pos,
function FRAM_RIR (line 126) | def FRAM_RIR(
function sample_mic_arch (line 358) | def sample_mic_arch(n_mic, mic_spacing=None, bounding_box=None):
function sample_src_pos (line 390) | def sample_src_pos(
function sample_mic_array_pos (line 414) | def sample_mic_array_pos(mic_arch, room_dim, min_dis_wall=None):
function sample_a_config (line 482) | def sample_a_config(simu_config):
function single_channel (line 514) | def single_channel(simu_config):
function multi_channel_array (line 526) | def multi_channel_array(simu_config):
function multi_channel_adhoc (line 539) | def multi_channel_adhoc(simu_config):
function multi_channel_src_orientation (line 554) | def multi_channel_src_orientation():
function multi_channel_mic_orientation (line 592) | def multi_channel_mic_orientation():
FILE: wesep/dataset/dataset.py
class Processor (line 26) | class Processor(IterableDataset):
method __init__ (line 28) | def __init__(self, source, f, *args, **kw):
method set_epoch (line 35) | def set_epoch(self, epoch):
method __iter__ (line 38) | def __iter__(self):
method apply (line 46) | def apply(self, f):
class DistributedSampler (line 51) | class DistributedSampler:
method __init__ (line 53) | def __init__(self, shuffle=True, partition=True):
method update (line 59) | def update(self):
method set_epoch (line 81) | def set_epoch(self, epoch):
method sample (line 84) | def sample(self, data):
class DataList (line 106) | class DataList(IterableDataset):
method __init__ (line 108) | def __init__(self,
method set_epoch (line 117) | def set_epoch(self, epoch):
method __iter__ (line 120) | def __iter__(self):
function tse_collate_fn_2spk (line 139) | def tse_collate_fn_2spk(batch, mode="min"):
function tse_collate_fn (line 206) | def tse_collate_fn(batch, mode="min"):
function Dataset (line 267) | def Dataset(
FILE: wesep/dataset/lmdb_data.py
class LmdbData (line 21) | class LmdbData:
method __init__ (line 23) | def __init__(self, lmdb_file):
method random_one (line 34) | def random_one(self):
method __del__ (line 43) | def __del__(self):
FILE: wesep/dataset/processor.py
function url_opener (line 32) | def url_opener(data):
function tar_file_and_group (line 63) | def tar_file_and_group(data):
function tar_file_and_group_single_spk (line 128) | def tar_file_and_group_single_spk(data):
function parse_raw_single_spk (line 180) | def parse_raw_single_spk(data):
function mix_speakers (line 210) | def mix_speakers(data, num_speaker=2, shuffle_size=1000):
function snr_mixer (line 277) | def snr_mixer(data, use_random_snr: bool = False):
function shuffle (line 323) | def shuffle(data, shuffle_size=2500):
function spk_to_id (line 347) | def spk_to_id(data, spk2id):
function resample (line 367) | def resample(data, resample_rate=16000):
function sample_spk_embedding (line 391) | def sample_spk_embedding(data, spk_embeds):
function sample_fix_spk_embedding (line 407) | def sample_fix_spk_embedding(data, spk2embed_dict, spk1_embed, spk2_embed):
function sample_enrollment (line 428) | def sample_enrollment(data, spk_embeds, dict_spk):
function sample_fix_spk_enrollment (line 450) | def sample_fix_spk_enrollment(data,
function compute_fbank (line 480) | def compute_fbank(data,
function apply_cmvn (line 515) | def apply_cmvn(data, norm_mean=True, norm_var=False):
function get_random_chunk (line 538) | def get_random_chunk(data_list, chunk_len):
function filter_len (line 581) | def filter_len(
function random_chunk (line 612) | def random_chunk(data, chunk_len):
function fix_chunk (line 631) | def fix_chunk(data, chunk_len):
function add_noise (line 650) | def add_noise(
function add_reverb (line 746) | def add_reverb(data, reverb_prob=0):
function add_noise_on_enroll (line 785) | def add_noise_on_enroll(
function add_reverb_on_enroll (line 892) | def add_reverb_on_enroll(data, reverb_enroll_prob=0):
function spec_aug (line 928) | def spec_aug(data, num_t_mask=1, num_f_mask=1, max_t=10, max_f=8, prob=0):
FILE: wesep/dataset/vad.py
class VoiceActivityDetection (line 5) | class VoiceActivityDetection:
method __init__ (line 7) | def __init__(self, wave):
method segmentation (line 10) | def segmentation(self, overlap, slice_len):
method calc_energy (line 31) | def calc_energy(self, audio):
method select (line 42) | def select(self):
FILE: wesep/models/__init__.py
function get_model (line 10) | def get_model(model_name: str):
FILE: wesep/models/bsrnn.py
class ResRNN (line 16) | class ResRNN(nn.Module):
method __init__ (line 18) | def __init__(self, input_size, hidden_size, bidirectional=True):
method forward (line 38) | def forward(self, input):
class BSNet (line 55) | class BSNet(nn.Module):
method __init__ (line 57) | def __init__(self, in_channel, nband=7, bidirectional=True):
method forward (line 69) | def forward(self, input, dummy: Optional[torch.Tensor] = None):
class FuseSeparation (line 86) | class FuseSeparation(nn.Module):
method __init__ (line 88) | def __init__(
method forward (line 125) | def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
class BSRNN (line 151) | class BSRNN(nn.Module):
method __init__ (line 154) | def __init__(
method pad_input (line 284) | def pad_input(self, input, window, stride):
method forward (line 300) | def forward(self, input, embeddings):
FILE: wesep/models/bsrnn_feats.py
class ResRNN (line 18) | class ResRNN(nn.Module):
method __init__ (line 20) | def __init__(self, input_size, hidden_size, bidirectional=True):
method forward (line 40) | def forward(self, input):
class BSNet (line 57) | class BSNet(nn.Module):
method __init__ (line 59) | def __init__(self, in_channel, nband=7, bidirectional=True):
method forward (line 71) | def forward(self, input, dummy: Optional[torch.Tensor] = None):
class CrossAtt (line 87) | class CrossAtt(nn.Module):
method __init__ (line 88) | def __init__(self, embed_dim, num_heads, *args, **kwargs):
method forward (line 93) | def forward(self, query, key, value):
class FuseSeparation (line 110) | class FuseSeparation(nn.Module):
method __init__ (line 112) | def __init__(
method forward (line 160) | def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
class BSRNN_Feats (line 201) | class BSRNN_Feats(nn.Module):
method __init__ (line 204) | def __init__(
method pad_input (line 340) | def pad_input(self, input, window, stride):
method forward (line 356) | def forward(self, input, embeddings):
FILE: wesep/models/bsrnn_multi_optim.py
class ResRNN (line 15) | class ResRNN(nn.Module):
method __init__ (line 17) | def __init__(self, input_size, hidden_size, bidirectional=True):
method forward (line 38) | def forward(self, input):
class BSNet (line 55) | class BSNet(nn.Module):
method __init__ (line 57) | def __init__(self, in_channel, nband=7, bidirectional=True):
method forward (line 69) | def forward(self, input, dummy: Optional[torch.Tensor] = None):
class FuseSeparation (line 91) | class FuseSeparation(nn.Module):
method __init__ (line 93) | def __init__(
method forward (line 132) | def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
class BSRNN_Multi (line 156) | class BSRNN_Multi(nn.Module):
method __init__ (line 159) | def __init__(
method pad_input (line 290) | def pad_input(self, input, window, stride):
method forward (line 306) | def forward(self, input, embeddings):
FILE: wesep/models/convtasnet.py
class ConvTasNet (line 14) | class ConvTasNet(nn.Module):
method __init__ (line 16) | def __init__(
method forward (line 162) | def forward(self, x, embeddings):
function check_parameters (line 222) | def check_parameters(net):
function test_convtasnet (line 230) | def test_convtasnet():
FILE: wesep/models/dpccn.py
class DPCCN (line 16) | class DPCCN(nn.Module):
method __init__ (line 18) | def __init__(
method _build_encoder (line 131) | def _build_encoder(self, **enc_kargs):
method _build_decoder (line 151) | def _build_decoder(self, **dec_kargs):
method _build_tcn_blocks (line 174) | def _build_tcn_blocks(self, tcn_blocks, **tcn_kargs):
method _build_tcn_layers (line 184) | def _build_tcn_layers(self, tcn_layers, tcn_blocks, **tcn_kargs):
method _build_avg_pool (line 195) | def _build_avg_pool(self, pool_size):
method forward (line 206) | def forward(self, input, aux):
FILE: wesep/models/sep_model.py
function get_model (line 7) | def get_model(model_name: str):
FILE: wesep/models/tfgridnet.py
class TFGridNet (line 29) | class TFGridNet(nn.Module):
method __init__ (line 76) | def __init__(
method forward (line 197) | def forward(
method num_spk (line 305) | def num_spk(self):
method pad2 (line 309) | def pad2(input_tensor, target_len):
FILE: wesep/modules/common/norm.py
class GlobalChannelLayerNorm (line 7) | class GlobalChannelLayerNorm(nn.Module):
method __init__ (line 18) | def __init__(self, dim, eps=1e-05, elementwise_affine=True):
method forward (line 31) | def forward(self, x):
class ChannelWiseLayerNorm (line 51) | class ChannelWiseLayerNorm(nn.LayerNorm):
method __init__ (line 56) | def __init__(self, *args, **kwargs):
method forward (line 59) | def forward(self, x):
function select_norm (line 69) | def select_norm(norm, dim):
class FiLM (line 84) | class FiLM(nn.Module):
method __init__ (line 89) | def __init__(self,
method init_weights (line 111) | def init_weights(self):
method forward (line 118) | def forward(self, embed, x):
class ConditionalLayerNorm (line 142) | class ConditionalLayerNorm(nn.Module):
method __init__ (line 147) | def __init__(self,
method reset_parameters (line 172) | def reset_parameters(self):
method forward (line 176) | def forward(self, input, embed):
method extra_repr (line 189) | def extra_repr(self):
FILE: wesep/modules/common/speaker.py
class PreEmphasis (line 10) | class PreEmphasis(torch.nn.Module):
method __init__ (line 12) | def __init__(self, coef: float = 0.97):
method forward (line 20) | def forward(self, input: torch.tensor) -> torch.tensor:
class SpeakerTransform (line 26) | class SpeakerTransform(nn.Module):
method __init__ (line 28) | def __init__(self, embed_dim=256, num_layers=3, hid_dim=128):
method forward (line 45) | def forward(self, x):
class LinearLayer (line 52) | class LinearLayer(nn.Module):
method __init__ (line 54) | def __init__(self, in_features, out_features, bias=True):
method forward (line 59) | def forward(self, x, dummy: Optional[torch.Tensor] = None):
class SpeakerFuseLayer (line 63) | class SpeakerFuseLayer(nn.Module):
method __init__ (line 65) | def __init__(self, embed_dim=256, feat_dim=512, fuse_type="concat"):
method forward (line 81) | def forward(self, x, embed):
function test_speaker_fuse (line 128) | def test_speaker_fuse():
FILE: wesep/modules/dpccn/convs.py
class Conv1D (line 7) | class Conv1D(nn.Conv1d):
method __init__ (line 12) | def __init__(self, *args, **kwargs):
method forward (line 15) | def forward(self, x, squeeze=False):
class Conv2dBlock (line 28) | class Conv2dBlock(nn.Module):
method __init__ (line 30) | def __init__(
method forward (line 44) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class ConvTrans2dBlock (line 50) | class ConvTrans2dBlock(nn.Module):
method __init__ (line 52) | def __init__(
method forward (line 67) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class DenseBlock (line 73) | class DenseBlock(nn.Module):
method __init__ (line 75) | def __init__(self, in_dims, out_dims, mode="enc", **kargs):
method forward (line 97) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class TCNBlock (line 106) | class TCNBlock(nn.Module):
method __init__ (line 112) | def __init__(
method forward (line 144) | def forward(self, x: torch.Tensor) -> torch.Tensor:
FILE: wesep/modules/metric_gan/discriminator.py
class LearnableSigmoid (line 6) | class LearnableSigmoid(nn.Module):
method __init__ (line 8) | def __init__(self, in_features, beta=1):
method forward (line 14) | def forward(self, x):
class CMGAN_Discriminator (line 19) | class CMGAN_Discriminator(nn.Module):
method __init__ (line 21) | def __init__(
method forward (line 106) | def forward(self, ref_wav, est_wav):
function test_CMGAN_Discriminator (line 145) | def test_CMGAN_Discriminator():
FILE: wesep/modules/tasnet/convs.py
class Conv1D (line 9) | class Conv1D(nn.Conv1d):
method __init__ (line 11) | def __init__(self, *args, **kwargs):
method forward (line 14) | def forward(self, x, squeeze=False):
class ConvTrans1D (line 25) | class ConvTrans1D(nn.ConvTranspose1d):
method __init__ (line 27) | def __init__(self, *args, **kwargs):
method forward (line 30) | def forward(self, x, squeeze=False):
class Conv1DBlock (line 43) | class Conv1DBlock(nn.Module):
method __init__ (line 48) | def __init__(
method forward (line 84) | def forward(self, x):
class Conv1DBlock4Fuse (line 107) | class Conv1DBlock4Fuse(nn.Module):
method __init__ (line 113) | def __init__(
method forward (line 148) | def forward(self, x, aux):
FILE: wesep/modules/tasnet/decoder.py
class DeepDecoder (line 7) | class DeepDecoder(nn.Module):
method __init__ (line 9) | def __init__(self, N, kernel_size=16, stride=16 // 2):
method forward (line 47) | def forward(self, x):
class MultiDecoder (line 60) | class MultiDecoder(nn.Module):
method __init__ (line 62) | def __init__(self, in_channels, middle_channels, out_channels, kernel_...
method forward (line 92) | def forward(self, x, w1, w2, w3, actLayer):
FILE: wesep/modules/tasnet/encoder.py
class DeepEncoder (line 9) | class DeepEncoder(nn.Module):
method __init__ (line 11) | def __init__(self, in_channels, out_channels, kernel_size, stride):
method forward (line 53) | def forward(self, x):
class MultiEncoder (line 63) | class MultiEncoder(nn.Module):
method __init__ (line 65) | def __init__(self, in_channels, middle_channels, out_channels, kernel_...
method forward (line 95) | def forward(self, x):
FILE: wesep/modules/tasnet/separation.py
class Separation (line 8) | class Separation(nn.Module):
method __init__ (line 10) | def __init__(
method forward (line 41) | def forward(self, x):
class FuseSeparation (line 60) | class FuseSeparation(nn.Module):
method __init__ (line 62) | def __init__(
method forward (line 166) | def forward(self, x, spk_embedding):
FILE: wesep/modules/tasnet/separator.py
class Separation (line 6) | class Separation(nn.Module):
method __init__ (line 18) | def __init__(self, R, X, B, H, P, norm="gln", causal=False, skip_con=T...
method forward (line 27) | def forward(self, x):
FILE: wesep/modules/tasnet/speaker.py
class ResBlock (line 7) | class ResBlock(nn.Module):
method __init__ (line 15) | def __init__(self, in_dims, out_dims):
method forward (line 34) | def forward(self, x):
class ResNet4SpExplus (line 48) | class ResNet4SpExplus(nn.Module):
method __init__ (line 50) | def __init__(self, in_channel=256, C_embedding=256):
method forward (line 61) | def forward(self, x):
FILE: wesep/modules/tfgridnet/gridnet_block.py
class GridNetBlock (line 26) | class GridNetBlock(nn.Module):
method __getitem__ (line 28) | def __getitem__(self, key):
method __init__ (line 31) | def __init__(
method forward (line 118) | def forward(self, x):
class LayerNormalization4DCF (line 230) | class LayerNormalization4DCF(nn.Module):
method __init__ (line 232) | def __init__(self, input_dimension, eps=1e-5):
method forward (line 242) | def forward(self, x):
class AllHeadPReLULayerNormalization4DCF (line 256) | class AllHeadPReLULayerNormalization4DCF(nn.Module):
method __init__ (line 258) | def __init__(self, input_dimension, eps=1e-5):
method forward (line 273) | def forward(self, x):
FILE: wesep/utils/abs_loss.py
class AbsEnhLoss (line 8) | class AbsEnhLoss(torch.nn.Module, ABC):
method name (line 13) | def name(self) -> str:
method only_for_test (line 19) | def only_for_test(self) -> bool:
method forward (line 23) | def forward(
FILE: wesep/utils/checkpoint.py
function load_pretrained_model (line 8) | def load_pretrained_model(model: torch.nn.Module,
function load_checkpoint (line 30) | def load_checkpoint(
function save_checkpoint (line 81) | def save_checkpoint(
FILE: wesep/utils/datadir_writer.py
class DatadirWriter (line 8) | class DatadirWriter:
method __init__ (line 21) | def __init__(self, p: Union[Path, str]):
method __enter__ (line 28) | def __enter__(self):
method __getitem__ (line 31) | def __getitem__(self, key: str) -> "DatadirWriter":
method __setitem__ (line 43) | def __setitem__(self, key: str, value: str):
method __exit__ (line 56) | def __exit__(self, exc_type, exc_val, exc_tb):
method close (line 59) | def close(self):
FILE: wesep/utils/dnsmos.py
function poly1d (line 18) | def poly1d(coefficients, use_numpy=False):
class DNSMOS_web (line 29) | class DNSMOS_web:
method __init__ (line 32) | def __init__(self, auth_key):
method __call__ (line 35) | def __call__(self, aud, input_fs, fname="", method="p808"):
class DNSMOS_local (line 60) | class DNSMOS_local:
method __init__ (line 63) | def __init__(
method audio_melspec (line 121) | def audio_melspec(
method get_polyfit_val (line 157) | def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
method __call__ (line 177) | def __call__(self, aud, input_fs, is_personalized_MOS=False):
FILE: wesep/utils/executor.py
class Executor (line 27) | class Executor:
method __init__ (line 29) | def __init__(self):
method train (line 32) | def train(
method cv (line 154) | def cv(
FILE: wesep/utils/executor_gan.py
class ExecutorGAN (line 28) | class ExecutorGAN:
method __init__ (line 30) | def __init__(self):
method train (line 33) | def train(
method cv (line 191) | def cv(
method mse_loss (line 274) | def mse_loss(self, output, target):
method _calculate_discriminator_loss (line 277) | def _calculate_discriminator_loss(
FILE: wesep/utils/file_utils.py
function read_lists (line 11) | def read_lists(list_file):
function read_vec_scp_file (line 20) | def read_vec_scp_file(scp_file):
function norm_embeddings (line 35) | def norm_embeddings(embeddings, kaldi_style=True):
function read_label_file (line 50) | def read_label_file(label_file):
function load_speaker_embeddings (line 64) | def load_speaker_embeddings(scp_file, utt2spk_file):
function read_2columns_text (line 86) | def read_2columns_text(path: Union[Path, str]) -> Dict[str, str]:
function read_multi_columns_text (line 116) | def read_multi_columns_text(
function soundfile_read (line 164) | def soundfile_read(
class SoundScpReader (line 233) | class SoundScpReader(collections.abc.Mapping):
method __init__ (line 270) | def __init__(
method __getitem__ (line 289) | def __getitem__(self, key) -> Tuple[int, np.ndarray]:
method get_path (line 301) | def get_path(self, key):
method __contains__ (line 304) | def __contains__(self, item):
method __len__ (line 307) | def __len__(self):
method __iter__ (line 310) | def __iter__(self):
method keys (line 313) | def keys(self):
FILE: wesep/utils/funcs.py
function overlap_and_add (line 10) | def overlap_and_add(signal, frame_step):
function remove_pad (line 59) | def remove_pad(inputs, inputs_lengths):
function clip_gradients (line 79) | def clip_gradients(model, clip):
function compute_fbank (line 91) | def compute_fbank(
function apply_cmvn (line 119) | def apply_cmvn(data, norm_mean=True, norm_var=False):
FILE: wesep/utils/losses.py
function parse_loss (line 34) | def parse_loss(loss):
FILE: wesep/utils/schedulers.py
class MarginScheduler (line 20) | class MarginScheduler:
method __init__ (line 22) | def __init__(
method init_margin (line 54) | def init_margin(self):
method get_increase_margin (line 58) | def get_increase_margin(self):
method step (line 73) | def step(self, current_iter=None):
method get_margin (line 90) | def get_margin(self):
class BaseClass (line 99) | class BaseClass:
method __init__ (line 104) | def __init__(
method get_multi_process_coeff (line 130) | def get_multi_process_coeff(self):
method get_current_lr (line 142) | def get_current_lr(self):
method get_lr (line 148) | def get_lr(self):
method set_lr (line 151) | def set_lr(self):
method step (line 156) | def step(self, current_iter=None):
method step_return_lr (line 163) | def step_return_lr(self, current_iter=None):
method state_dict (line 172) | def state_dict(self):
method load_state_dict (line 183) | def load_state_dict(self, state_dict):
class ExponentialDecrease (line 193) | class ExponentialDecrease(BaseClass):
method __init__ (line 195) | def __init__(
method get_current_lr (line 217) | def get_current_lr(self):
class TriAngular2 (line 225) | class TriAngular2(BaseClass):
method __init__ (line 230) | def __init__(
method get_current_lr (line 260) | def get_current_lr(self):
function show_lr_curve (line 280) | def show_lr_curve(scheduler):
FILE: wesep/utils/score.py
function cal_SISNR (line 7) | def cal_SISNR(est, ref, eps=1e-8):
function cal_SISNRi (line 24) | def cal_SISNRi(est, ref, mix, eps=1e-8):
function cal_PESQ (line 39) | def cal_PESQ(est, ref):
function cal_PESQ_norm (line 46) | def cal_PESQ_norm(est, ref):
function cal_PESQi (line 58) | def cal_PESQi(est, ref, mix):
function cal_STOI (line 73) | def cal_STOI(est, ref):
function cal_STOIi (line 79) | def cal_STOIi(est, ref, mix):
function batch_evaluation (line 94) | def batch_evaluation(metric, est, ref, lengths=None, parallel=False, n_j...
FILE: wesep/utils/signal.py
function init_kernels (line 8) | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
class ConvSTFT (line 38) | class ConvSTFT(nn.Module):
method __init__ (line 40) | def __init__(
method forward (line 62) | def forward(self, inputs):
class ConviSTFT (line 80) | class ConviSTFT(nn.Module):
method __init__ (line 82) | def __init__(
method forward (line 110) | def forward(self, inputs, phase=None):
FILE: wesep/utils/utils.py
function str2bool (line 31) | def str2bool(value: str) -> bool:
function get_logger (line 35) | def get_logger(outdir, fname):
function setup_logger (line 50) | def setup_logger(rank, exp_dir, device_ids, MAX_NUM_LOG_FILES: int = 100):
function parse_config_or_kwargs (line 73) | def parse_config_or_kwargs(config_file, **kwargs):
function validate_path (line 93) | def validate_path(dir_name):
function set_seed (line 103) | def set_seed(seed=42):
function generate_enahnced_scp (line 115) | def generate_enahnced_scp(directory: str, extension: str = "wav"):
function get_commandline_args (line 139) | def get_commandline_args():
class ArgumentParser (line 176) | class ArgumentParser(argparse.ArgumentParser):
method __init__ (line 189) | def __init__(self, *args, **kwargs):
method parse_known_args (line 193) | def parse_known_args(self, args=None, namespace=None):
function get_layer (line 221) | def get_layer(l_name, library=torch.nn):
},
{
"path": "wesep/dataset/dataset.py",
"chars": 14385,
"preview": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n# 2023 Shuai Wang (wsstriving@gmail.com)\n# Licens"
},
{
"path": "wesep/dataset/lmdb_data.py",
"chars": 1636,
"preview": "# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
},
{
"path": "wesep/dataset/processor.py",
"chars": 35652,
"preview": "import io\nimport json\nimport logging\nimport random\nimport tarfile\nfrom subprocess import PIPE, Popen\nfrom urllib.parse i"
},
{
"path": "wesep/dataset/vad.py",
"chars": 3427,
"preview": "import numpy as np\nimport soundfile as sf\n\n\nclass VoiceActivityDetection:\n\n def __init__(self, wave):\n self.wa"
},
{
"path": "wesep/models/__init__.py",
"chars": 1079,
"preview": "import wesep.models.bsrnn as bsrnn\nimport wesep.models.convtasnet as convtasnet\nimport wesep.models.dpccn as dpccn\nimpor"
},
{
"path": "wesep/models/bsrnn.py",
"chars": 15568,
"preview": "from __future__ import print_function\r\n\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.n"
},
{
"path": "wesep/models/bsrnn_feats.py",
"chars": 24483,
"preview": "from __future__ import print_function\r\n\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.n"
},
{
"path": "wesep/models/bsrnn_multi_optim.py",
"chars": 18592,
"preview": "from __future__ import print_function\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn "
},
{
"path": "wesep/models/convtasnet.py",
"chars": 8456,
"preview": "import torch\r\nimport torch.nn as nn\r\n\r\nfrom wesep.modules.common import select_norm\r\nfrom wesep.modules.common.speaker i"
},
{
"path": "wesep/models/dpccn.py",
"chars": 10658,
"preview": "import torch\r\nimport torch.nn as nn\r\nimport torchaudio\r\n\r\nfrom wespeaker.models.speaker_model import get_speaker_model\r\n"
},
{
"path": "wesep/models/sep_model.py",
"chars": 699,
"preview": "import wesep.models.bsrnn as bsrnn\nimport wesep.models.convtasnet as convtasnet\nimport wesep.models.dpccn as dpccn\nimpor"
},
{
"path": "wesep/models/tfgridnet.py",
"chars": 11696,
"preview": "# The implementation is based on:\n# https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separ"
},
{
"path": "wesep/modules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "wesep/modules/common/__init__.py",
"chars": 245,
"preview": "from wesep.modules.common.norm import ChannelWiseLayerNorm # noqa\nfrom wesep.modules.common.norm import FiLM # noqa\nfr"
},
{
"path": "wesep/modules/common/norm.py",
"chars": 6481,
"preview": "import numbers\n\nimport torch\nimport torch.nn as nn\n\n\nclass GlobalChannelLayerNorm(nn.Module):\n \"\"\"\n Calculate Glob"
},
{
"path": "wesep/modules/common/speaker.py",
"chars": 4969,
"preview": "from typing import Optional\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nfrom wesep.module"
},
{
"path": "wesep/modules/dpccn/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "wesep/modules/dpccn/convs.py",
"chars": 4837,
"preview": "from typing import Tuple\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\n\r\nclass Conv1D(nn.Conv1d):\r\n \"\"\"\r\n 1D conv in C"
},
{
"path": "wesep/modules/metric_gan/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "wesep/modules/metric_gan/discriminator.py",
"chars": 5149,
"preview": "import torch\nimport torch.nn as nn\n\n\n# utility functions/classes used in the implementation of discriminators.\nclass Lea"
},
{
"path": "wesep/modules/tasnet/__init__.py",
"chars": 396,
"preview": "from wesep.modules.tasnet.decoder import DeepDecoder # noqa\r\nfrom wesep.modules.tasnet.decoder import MultiDecoder # n"
},
{
"path": "wesep/modules/tasnet/convs.py",
"chars": 4684,
"preview": "import torch\nimport torch.nn as nn\n\nfrom wesep.modules.common import select_norm\n\n# from wesep.modules.common.spkadapt i"
},
{
"path": "wesep/modules/tasnet/decoder.py",
"chars": 3972,
"preview": "import torch\r\nimport torch.nn as nn\r\n\r\nfrom wesep.modules.tasnet.convs import Conv1D, ConvTrans1D\r\n\r\n\r\nclass DeepDecoder"
},
{
"path": "wesep/modules/tasnet/encoder.py",
"chars": 3709,
"preview": "import torch as th\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nfrom wesep.modules.common import select_no"
},
{
"path": "wesep/modules/tasnet/separation.py",
"chars": 6477,
"preview": "import torch.nn as nn\r\n\r\nfrom wesep.modules.common import select_norm\r\nfrom wesep.modules.common.speaker import SpeakerF"
},
{
"path": "wesep/modules/tasnet/separator.py",
"chars": 1437,
"preview": "import torch.nn as nn\r\n\r\nfrom wesep.modules.tasnet.convs import Conv1DBlock\r\n\r\n\r\nclass Separation(nn.Module):\r\n \"\"\"\r\n"
},
{
"path": "wesep/modules/tasnet/speaker.py",
"chars": 1984,
"preview": "import torch.nn as nn\n\nfrom wesep.modules.common.norm import ChannelWiseLayerNorm\nfrom wesep.modules.tasnet.convs import"
},
{
"path": "wesep/modules/tfgridnet/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "wesep/modules/tfgridnet/gridnet_block.py",
"chars": 10452,
"preview": "# The implementation is based on:\n# https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separ"
},
{
"path": "wesep/utils/abs_loss.py",
"chars": 705,
"preview": "from abc import ABC, abstractmethod\n\nimport torch\n\nEPS = torch.finfo(torch.get_default_dtype()).eps\n\n\nclass AbsEnhLoss(t"
},
{
"path": "wesep/utils/checkpoint.py",
"chars": 3547,
"preview": "from typing import List, Optional\r\n\r\nimport torch\r\n\r\nfrom wesep.utils.schedulers import BaseClass\r\n\r\n\r\ndef load_pretrain"
},
{
"path": "wesep/utils/datadir_writer.py",
"chars": 2245,
"preview": "import warnings\nfrom pathlib import Path\nfrom typing import Union\n\n\n# ported from\n# https://github.com/espnet/espnet/blo"
},
{
"path": "wesep/utils/dnsmos.py",
"chars": 10612,
"preview": "import json\nimport math\n\nimport librosa\nimport numpy as np\nimport requests\nimport torch\nimport torchaudio\n\nSAMPLING_RATE"
},
{
"path": "wesep/utils/executor.py",
"chars": 7825,
"preview": "# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)\r\n# 2022 Chengdong Liang (liangchengdong@mail.nwpu"
},
{
"path": "wesep/utils/executor_gan.py",
"chars": 13270,
"preview": "# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)\n# 2022 Chengdong Liang (liangchengdong@mail.nwpu."
},
{
"path": "wesep/utils/file_utils.py",
"chars": 9203,
"preview": "import collections\r\nimport math\r\nfrom pathlib import Path\r\nfrom typing import Dict, List, Optional, Tuple, Union\r\n\r\nimpo"
},
{
"path": "wesep/utils/funcs.py",
"chars": 4935,
"preview": "# Created on 2018/12\r\n# Author: Kaituo XU\r\n\r\nimport math\r\n\r\nimport torch\r\nimport torchaudio.compliance.kaldi as kaldi\r\n\r"
},
{
"path": "wesep/utils/losses.py",
"chars": 1100,
"preview": "import auraloss\nimport torch.nn as nn\nimport torchmetrics.audio as audio_metrics\nfrom torchmetrics.functional.audio impo"
},
{
"path": "wesep/utils/schedulers.py",
"chars": 9640,
"preview": "# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)\n# 2021 Zhengyang Chen (chenzhengyang117@gmail.com)\n"
},
{
"path": "wesep/utils/score.py",
"chars": 4009,
"preview": "import numpy as np\nfrom joblib import Parallel, delayed\nfrom pesq import pesq\nfrom pystoi.stoi import stoi\n\n\ndef cal_SIS"
},
{
"path": "wesep/utils/signal.py",
"chars": 4096,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom scipy.signal import get_windo"
},
{
"path": "wesep/utils/utils.py",
"chars": 8749,
"preview": "# Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com)\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licens"
}
]
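The listing above closes a JSON array in which each entry records a file's repository path, its size in characters ("chars"), and a short truncated "preview" of its contents. Below is a minimal sketch of how such a manifest could be consumed with the Python standard library, assuming the array has been saved to a file named manifest.json (that file name, like the summary printed, is an illustrative assumption, not something the extraction itself provides):

import json

# Load the per-file manifest produced by the extraction.
# "manifest.json" is a hypothetical file name for this sketch.
with open("manifest.json", encoding="utf-8") as f:
    entries = json.load(f)

# Aggregate the character counts across all listed files.
total_chars = sum(e["chars"] for e in entries)
print(f"{len(entries)} files, {total_chars} characters in total")

# Index the entries by path to look up a single file's preview,
# e.g. the training entry point listed above.
by_path = {e["path"]: e for e in entries}
print(by_path["wesep/bin/train.py"]["preview"])

Since each preview is truncated, the manifest works as an index into the full file bodies rather than a substitute for them.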