Repository: wenet-e2e/wesep Branch: master Commit: 99eca54b6030 Files: 128 Total size: 591.2 KB Directory structure: gitextract_s23ej_br/ ├── .clang-format ├── .flake8 ├── .github/ │ └── workflows/ │ └── lint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CPPLINT.cfg ├── README.md ├── examples/ │ ├── librimix/ │ │ └── tse/ │ │ ├── README.md │ │ ├── v1/ │ │ │ ├── README.md │ │ │ ├── confs/ │ │ │ │ ├── bsrnn.yaml │ │ │ │ ├── dpcc_init_gan.yaml │ │ │ │ ├── dpccn.yaml │ │ │ │ └── tfgridnet.yaml │ │ │ ├── local/ │ │ │ │ ├── prepare_data.sh │ │ │ │ ├── prepare_librimix_enroll.py │ │ │ │ └── prepare_spk2enroll_librispeech.py │ │ │ ├── path.sh │ │ │ └── run.sh │ │ └── v2/ │ │ ├── README.md │ │ ├── confs/ │ │ │ ├── bsrnn.yaml │ │ │ ├── bsrnn_feats.yaml │ │ │ ├── bsrnn_multi_optim.yaml │ │ │ ├── dpcc_init_gan.yaml │ │ │ ├── dpccn.yaml │ │ │ ├── spexplus.yaml │ │ │ └── tfgridnet.yaml │ │ ├── local/ │ │ │ ├── prepare_data.sh │ │ │ ├── prepare_librimix_enroll.py │ │ │ └── prepare_spk2enroll_librispeech.py │ │ ├── path.sh │ │ └── run.sh │ └── voxceleb1/ │ └── v2/ │ ├── confs/ │ │ └── bsrnn_online.yaml │ ├── local/ │ │ ├── prepare_data.sh │ │ ├── prepare_librimix_enroll.py │ │ ├── prepare_spk2enroll_librispeech.py │ │ └── prepare_spk2enroll_vox.py │ ├── path.sh │ └── run_online.sh ├── requirements.txt ├── runtime/ │ ├── .gitignore │ ├── CMakeLists.txt │ ├── README.md │ ├── bin/ │ │ ├── CMakeLists.txt │ │ └── separate_main.cc │ ├── cmake/ │ │ ├── gflags.cmake │ │ ├── glog.cmake │ │ └── libtorch.cmake │ ├── frontend/ │ │ ├── CMakeLists.txt │ │ ├── fbank.h │ │ ├── feature_pipeline.cc │ │ ├── feature_pipeline.h │ │ ├── fft.cc │ │ ├── fft.h │ │ └── wav.h │ ├── separate/ │ │ ├── CMakeLists.txt │ │ ├── separate_engine.cc │ │ └── separate_engine.h │ └── utils/ │ ├── CMakeLists.txt │ ├── blocking_queue.h │ ├── timer.h │ ├── utils.cc │ └── utils.h ├── setup.py ├── tools/ │ ├── extract_embed_depreciated.py │ ├── make_lmdb.py │ ├── make_shard_list_premix.py │ ├── make_shard_online.py │ ├── parse_options.sh │ ├── print_train_val_curve.py │ ├── run.pl │ ├── score.sh │ ├── show_enh_score.sh │ ├── split_scp.pl │ └── test_dataset.py └── wesep/ ├── __init__.py ├── bin/ │ ├── average_model.py │ ├── export_jit.py │ ├── infer.py │ ├── score.py │ ├── train.py │ └── train_gan.py ├── cli/ │ ├── __init__.py │ ├── extractor.py │ ├── hub.py │ └── utils.py ├── dataset/ │ ├── FRAM_RIR.py │ ├── dataset.py │ ├── lmdb_data.py │ ├── processor.py │ └── vad.py ├── models/ │ ├── __init__.py │ ├── bsrnn.py │ ├── bsrnn_feats.py │ ├── bsrnn_multi_optim.py │ ├── convtasnet.py │ ├── dpccn.py │ ├── sep_model.py │ └── tfgridnet.py ├── modules/ │ ├── __init__.py │ ├── common/ │ │ ├── __init__.py │ │ ├── norm.py │ │ └── speaker.py │ ├── dpccn/ │ │ ├── __init__.py │ │ └── convs.py │ ├── metric_gan/ │ │ ├── __init__.py │ │ └── discriminator.py │ ├── tasnet/ │ │ ├── __init__.py │ │ ├── convs.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── separation.py │ │ ├── separator.py │ │ └── speaker.py │ └── tfgridnet/ │ ├── __init__.py │ └── gridnet_block.py └── utils/ ├── abs_loss.py ├── checkpoint.py ├── datadir_writer.py ├── dnsmos.py ├── executor.py ├── executor_gan.py ├── file_utils.py ├── funcs.py ├── losses.py ├── schedulers.py ├── score.py ├── signal.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ --- Language: Cpp # BasedOnStyle: Google 
AccessModifierOffset: -1 AlignAfterOpenBracket: Align AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false AlignEscapedNewlinesLeft: true AlignOperands: true AlignTrailingComments: true AllowAllParametersOfDeclarationOnNextLine: true AllowShortBlocksOnASingleLine: false AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: All AllowShortIfStatementsOnASingleLine: true AllowShortLoopsOnASingleLine: true AlwaysBreakAfterDefinitionReturnType: None AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: true AlwaysBreakTemplateDeclarations: true BinPackArguments: true BinPackParameters: true BraceWrapping: AfterClass: false AfterControlStatement: false AfterEnum: false AfterFunction: false AfterNamespace: false AfterObjCDeclaration: false AfterStruct: false AfterUnion: false BeforeCatch: false BeforeElse: false IndentBraces: false BreakBeforeBinaryOperators: None BreakBeforeBraces: Attach BreakBeforeTernaryOperators: true BreakConstructorInitializersBeforeComma: false BreakAfterJavaFieldAnnotations: false BreakStringLiterals: true ColumnLimit: 80 CommentPragmas: '^ IWYU pragma:' ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true DisableFormat: false ExperimentalAutoDetectBinPacking: false ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] IncludeCategories: - Regex: '^<.*\.h>' Priority: 1 - Regex: '^<.*' Priority: 2 - Regex: '.*' Priority: 3 IncludeIsMainRegex: '([-_](test|unittest))?$' IndentCaseLabels: true IndentWidth: 2 IndentWrappedFunctionNames: false JavaScriptQuotes: Leave JavaScriptWrapImports: true KeepEmptyLinesAtTheStartOfBlocks: false MacroBlockBegin: '' MacroBlockEnd: '' MaxEmptyLinesToKeep: 1 NamespaceIndentation: None ObjCBlockIndentWidth: 2 ObjCSpaceAfterProperty: false ObjCSpaceBeforeProtocolList: false PenaltyBreakBeforeFirstCallParameter: 1 PenaltyBreakComment: 300 PenaltyBreakFirstLessLess: 120 PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 PointerAlignment: Left ReflowComments: true SortIncludes: true SpaceAfterCStyleCast: false SpaceBeforeAssignmentOperators: true SpaceBeforeParens: ControlStatements SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 2 SpacesInAngles: false SpacesInContainerLiterals: true SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false Standard: Auto TabWidth: 8 UseTab: Never ... ================================================ FILE: .flake8 ================================================ [flake8] select = B,C,E,F,P,T4,W,B9 max-line-length = 80 max-doc-length = 80 # C408 ignored because we like the dict keyword argument syntax # E501 is not flexible enough, we're using B950 instead ignore = E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, # shebang has extra meaning in fbcode lints, so I think it's not worth trying # to line this up with executable bit EXE001, # these ignores are from flake8-bugbear; please fix! B007,B008,B905, # these ignores are from flake8-comprehensions; please fix! 
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 exclude = ================================================ FILE: .github/workflows/lint.yml ================================================ name: Lint on: push: branches: - main pull_request: jobs: quick-checks: runs-on: ubuntu-latest steps: - name: Fetch Wenet uses: actions/checkout@v1 - name: Checkout PR tip run: | set -eux if [[ "${{ github.event_name }}" == "pull_request" ]]; then # We are on a PR, so actions/checkout leaves us on a merge commit. # Check out the actual tip of the branch. git checkout ${{ github.event.pull_request.head.sha }} fi echo ::set-output name=commit_sha::$(git rev-parse HEAD) id: get_pr_tip - name: Ensure no tabs run: | (! git grep -I -l $'\t' -- . ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false)) - name: Ensure no trailing whitespace run: | (! git grep -I -n $' $' -- . ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false)) flake8-py3: runs-on: ubuntu-latest steps: - name: Setup Python uses: actions/setup-python@v1 with: python-version: 3.9 architecture: x64 - name: Fetch Wenet uses: actions/checkout@v1 - name: Checkout PR tip run: | set -eux if [[ "${{ github.event_name }}" == "pull_request" ]]; then # We are on a PR, so actions/checkout leaves us on a merge commit. # Check out the actual tip of the branch. git checkout ${{ github.event.pull_request.head.sha }} fi echo ::set-output name=commit_sha::$(git rev-parse HEAD) id: get_pr_tip - name: Run flake8 run: | set -eux pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 flake8 --version flake8 if [ $? != 0 ]; then exit 1; fi cpplint: runs-on: ubuntu-latest steps: - name: Setup Python uses: actions/setup-python@v1 with: python-version: 3.x architecture: x64 - name: Fetch Wenet uses: actions/checkout@v1 - name: Checkout PR tip run: | set -eux if [[ "${{ github.event_name }}" == "pull_request" ]]; then # We are on a PR, so actions/checkout leaves us on a merge commit. # Check out the actual tip of the branch. git checkout ${{ github.event.pull_request.head.sha }} fi echo ::set-output name=commit_sha::$(git rev-parse HEAD) id: get_pr_tip - name: Run cpplint run: | set -eux pip install cpplint==1.6.1 cpplint --version cpplint --recursive . if [ $? 
!= 0 ]; then exit 1; fi ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class *.egg-info # Visual Studio Code files .vscode .vs # PyCharm files .idea venv # Eclipse Project settings *.*project .settings # Sublime Text settings *.sublime-workspace *.sublime-project # Editor temporaries *.swn *.swo *.swp *.swm *~ # IPython notebook checkpoints .ipynb_checkpoints # macOS dir files .DS_Store exp data raw_wav tensorboard **/*build* wespeaker_models ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-yapf rev: 'v0.32.0' hooks: - id: yapf - repo: https://github.com/pycqa/flake8 rev: '3.8.2' hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-clang-format rev: 'v17.0.6' hooks: - id: clang-format - repo: https://github.com/cpplint/cpplint rev: '1.6.1' hooks: - id: cpplint ================================================ FILE: CPPLINT.cfg ================================================ root=runtime filter=-build/c++11 ================================================ FILE: README.md ================================================ # Wesep > We aim to build a toolkit focusing on front-end processing in the cocktail party set up, including target speaker extraction and ~~speech separation (Future work)~~ ### Install for development & deployment * Clone this repo ``` sh https://github.com/wenet-e2e/wesep.git ``` * Create conda env: pytorch version >= 1.12.0 is required !!! ``` sh conda create -n wesep python=3.9 conda activate wesep conda install pytorch=1.12.1 torchaudio=0.12.1 cudatoolkit=11.3 -c pytorch -c conda-forge pip install -r requirements.txt pre-commit install # for clean and tidy code ``` ## The Target Speaker Extraction Task > Target speaker extraction (TSE) focuses on isolating the speech of a specific target speaker from overlapped multi-talker speech, which is a typical setup in the cocktail party problem. WeSep is featured with flexible target speaker modeling, scalable data management, effective on-the-fly data simulation, structured recipes and deployment support. ## Features (To Do List) - [x] On the fly data simulation - [x] Dynamic Mixture simulation - [x] Dynamic Reverb simulation - [x] Dynamic Noise simulation - [x] Support time- and frequency- domain models - Time-domain - [x] conv-tasnet based models - [x] Spex+ - Frequency domain - [x] pBSRNN - [x] pDPCCN - [x] tf-gridnet (Extremely slow, need double check) - [ ] Training Criteria - [x] SISNR loss - [x] GAN loss (Need further investigation) - [ ] Datasets - [x] Libri2Mix (Illustration for pre-mixed speech) - [x] VoxCeleb (Illustration for online training) - [ ] WSJ0-2Mix - [ ] Speaker Embedding - [x] Wespeaker Intergration - [x] Joint Learned Speaker Embedding - [x] Different fusion methods - [ ] Pretrained models - [ ] CLI Usage - [x] Runtime ## Data Pipe Design Following Wenet and Wespeaker, WeSep organizes the data processing modules as a pipeline of a set of different processors. The following figure shows such a pipeline with essential processors. ## Discussion For Chinese users, you can scan the QR code on the left to join our group directly. If it has expired, please scan the personal Wechat QR code on the right. 
||| | ---- | ---- | ## Citations If you find wespeaker useful, please cite it as ```bibtex @inproceedings{wang24fa_interspeech, title = {WeSep: A Scalable and Flexible Toolkit Towards Generalizable Target Speaker Extraction}, author = {Shuai Wang and Ke Zhang and Shaoxiong Lin and Junjie Li and Xuefei Wang and Meng Ge and Jianwei Yu and Yanmin Qian and Haizhou Li}, year = {2024}, booktitle = {Interspeech 2024}, pages = {4273--4277}, doi = {10.21437/Interspeech.2024-1840}, } ``` ================================================ FILE: examples/librimix/tse/README.md ================================================ # Libri2Mix Recipe ## Goal of this recipe This recipe aims to illustrate how to use WeSep to perform the target speaker extraction task on a pre-defined training set such as Libri2Mix, the mixtures have been prepared on the disk. If you want to check the online data processing and scale to larger training set, please check the voxceleb1 recipe. ## Difference of V1 and V2 The difference between v1 and v2 lies in the approach to speaker modeling. - v1 outlines a process where WeSpeaker is used to extract embeddings beforehand, which are then saved to disk and sampled during training. This setup allows you to use other toolkits or existing speaker embeddings freely. - v2 demonstrates a more integrated approach with the WeSpeaker toolkit (recommended). You can choose to either fix the speaker encoder or train it jointly with the separation backbone. ================================================ FILE: examples/librimix/tse/v1/README.md ================================================ ## Tutorial on LibriMix If you meet any problems when going through this tutorial, please feel free to ask in github issues. Thanks for any kind of feedback. NOTE: WE DON'T RECOMMEND THIS VERSION, IT'S JUST FOR ILLUSTRATING HOW TO USE YOUR OWN EXTRACTOR YOU NEED TO INSTALL WESPEAKER FIRST, CHECK `https://github.com/wenet-e2e/wespeaker` FOR THE INSTRUCTION ### First Experiment We provide a recipe `examples/librimix/tse/v1/run.sh` on LibriMix data. The recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process. ```bash cd examples/librimix/tse/v1 bash run.sh --stage 1 --stop_stage 1 bash run.sh --stage 2 --stop_stage 2 bash run.sh --stage 3 --stop_stage 3 bash run.sh --stage 4 --stop_stage 4 bash run.sh --stage 5 --stop_stage 5 bash run.sh --stage 6 --stop_stage 6 ``` You could also just run the whole script ```bash bash run.sh --stage 1 --stop_stage 6 ``` ------ ### Stage 1: Prepare Training Data Prior to executing this phase, we assume that you have locally stored or can access the LibriMix dataset and you should assign the data path to `Libri2Mix_dir`. As the LibriMix dataset is available in multiple versions, each determined by factors like the number of sources in the mixtures and the sampling rate, you can choose the desired version by adjusting the following variables in `run.sh`: + `fs`: the sample rate of the dataset, valid options are `16k` and `8k`. + `min_max`: the mode of mixtures, valiad options are `min` and `max`. + `noise_type`: the type of mixture, valiad options are `clean` and `both`. 
In our recipe, we opt for the Libri2Mix data with a sampling rate of 16kHz, in 'min' mode, and without noise, thus configuring as follows: ``` bash fs=16k min_max=min noise_type="clean" Libri2Mix_dir=/path/to/Libri2Mix ``` After configuring the desired dataset version, running the script for the first phase will generate the prepared data files. By default, these files are stored in the `data` directory in the current location. However, you can customize the `data` variable in `enh.sh` to save these files in any desired location. ```bash data=data # you can change this to any directory ``` In this stage, `local/prepare_data.sh`accomplishes three tasks: 1. Organizes the original Libri2Mix dataset into three directoies `dev`, `test` and `train_100`, each containing the following files: + `single.utt2spk`: each line records two space-separated columns: `clean_wav_id` and `speaker_id` ```text s1/103-1240-0003_1235-135887-0017.wav 103 s1/103-1240-0004_4195-186237-0003.wav 103 ... ``` + `utt2spk`: each line records three space-separated columns: `mixture_wav_id`, `speaker1_id` and `speaker2_id`. ``` 103-1240-0003_1235-135887-0017 103 1235 103-1240-0004_4195-186237-0003 103 4195 ... ``` + `single.wav.scp`: each line records two space-separated columns: `clean_wav_id` and `clean_wav_path` ``` s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0003_1235-135887-0017.wav s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0004_4195-186237-0003.wav ... ``` + `wav.scp`: each line records four space-separated columns: `mixture_wav_id`, `mixtrue_wav_path`, `clean_wav1_path` and `clean_wav2_path`. ``` 103-1240-0003_1235-135887-0017 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0003_1235-135887-0017.wav 103-1240-0004_4195-186237-0003 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0004_4195-186237-0003.wav ... ``` 2. Prepare the speaker embeddings using wespeaker pretrained models. This step will generate two files in the `dev`, `test`, and `train_100` directories respectively: + `embed.ark`: Kaldi ark file that stores the speaker embeddings. + `embed.scp`: each line records two space-separated columns: `clean_wav_id` and `spk_embed_path` ``` s1/103-1240-0003_1235-135887-0017.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:1450569 s1/103-1240-0004_4195-186237-0003.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:10622715 ... ``` 3. Prepare LibriMix target-speaker enroll signal. This step will generate four files in the `dev` and `test` directories respectively: + `mixture2enrollment`: each line records three space-separated columns: `mixture_wav_id`, `clean_wav_id` and `enrollment_wav_id`. ``` 4077-13754-0001_5142-33396-0065 4077-13754-0001 s1/4077-13754-0004_5142-36377-0020 4077-13754-0001_5142-33396-0065 5142-33396-0065 s1/5142-36377-0003_1320-122612-0014 ... ``` + `spk1.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`. 
``` 1272-128104-0000_2035-147961-0014 s1/1272-135031-0015_2277-149896-0006.wav 1272-128104-0003_2035-147961-0016 s1/1272-135031-0013_1988-147956-0016.wav ... ``` + `spk2.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`. ``` 1272-128104-0000_2035-147961-0014 s1/2035-152373-0009_3000-15664-0016.wav 1272-128104-0003_2035-147961-0016 s2/6313-66129-0013_2035-152373-0012.wav ... ``` + `spk2enroll.json`: A JSON file, where the format of the stored key-value pairs is `{spk_id: [[spk_id_with_prefix_or_suffix, wav_path], ...]}`. ``` "652": [["652-129742-0010", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0010_3081-166546-0071.wav"], ..., ["652-129742-0000", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0000_1993-147966-0004.wav"]], ... ``` At the end of this stage, the directory structure of `data` should look like this: ``` data/ |__ clean/ # the noise_type you chose    |__ dev/ | |__ embed.ark | |__ embed.scp | |__ mixture2enrollment | |__ single.utt2spk | |__ single.wav.scp | |__ spk1.enroll | |__ spk2.enroll | |__ spk2enroll.json # empty | |__ utt2spk | |__ wav.scp |    |__ test/ # the same as dev/ | |__ train_100/ |__ embed.ark |__ embed.scp |__ single.utt2spk |__ single.wav.scp |__ utt2spk |__ wav.scp ``` ------ ### Stage 2: Convert Data Format This stage involves transforming the data into `shard` format, which is better suited for large datasets. Its core idea is to make the audio and labels of multiple small data(such as 1000 pieces), into compressed packets (tar) and read them based on the IterableDataset of Pytorch. For a detailed explanation of the `shard` format, please refer to the [documentation](https://github.com/wenet-e2e/wenet/blob/main/docs/UIO.md) available in Wenet. This stage will generate a subdirectory and a file in the `dev`, `test`, and `train_100` directories respectively: + `shards/`: this directory stores the compressed packets (tar) files. ```bash ls shards shards_000000000.tar shards_000000001.tar shards_000000002.tar ... ``` + `shard.list`: each line records the path to the corresponding tar file. ``` data/clean/dev/shards/shards_000000000.tar data/clean/dev/shards/shards_000000001.tar data/clean/dev/shards/shards_000000002.tar ... ``` At the end of this stage, the directory structure of `data` should look like this: ``` data/ |__ clean/ # the noise_type you chose    |__ dev/ | |__ embed.ark, embed.scp, ... # files generated by Stage 1 | |__ shard.list | |__ shards/ | |__ shards_000000000.tar | |__ shards_000000001.tar | |__ shards_000000002.tar |    |__ test/ # the same as dev/ | |__ train_100/ |__ embed.ark, embed.scp, ... # files generated by Stage 1 |__ shard.list |__ shards/ |__ shards_000000000.tar |__ ... |__ shards_000000013.tar ``` ------ ### Stage 3: Neural Networking Training You can configure network training related parameters through the configuration file. We provide some ready-to-use configuration files in the recipe. If you wish to write your own configuration files or understand the meaning of certain parameters in the configuration files, you can refer to the following information: + **overall training process related** ```yaml seed: 42 exp_dir: exp/BSRNN enable_amp: false gpus: '0,1' log_batch_interval: 100 save_epoch_interval: 1 joint_training: false ``` Explanations for some of the parameters mentioned above: + `seed`: specify a random seed. + `exp_dir`: specify the experiment directory. + `enable_amp`: whether enable automatic mixed precision. + `gpus`: specify the visible GPUs during training. 
+ `log_batch_interval`: specify after how many batch iterations to record in the log. + `save_epoch_interval`: specify after how many batch epoches to save a checkpoint. + `joint_training`: specify whether the model for extracting speaker embeddings is jointly trained with the TSE model. Defaluts to `false`. + **dataset and dataloader realted** ```yaml dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 whole_utt: false chunk_len: 48000 online_mix: false data_type: "shard" train_data: "data/clean/train_100/shard.list" train_spk_embeds: "data/clean/train_100/embed.scp" train_utt2spk: "data/clean/train_100/single.utt2spk" train_spk2utt: "data/clean/train_100/spk2enroll.json" val_data: "data/clean/dev/shard.list" val_spk_embeds: "data/clean/dev/embed.scp" val_utt2spk: "data/clean/dev/single.utt2spk" val_spk1_enroll: "data/clean/dev/spk1.enroll" val_spk2_enroll: "data/clean/dev/spk2.enroll" val_spk2utt: "data/clean/dev/single.wav.scp" dataloader_args: batch_size: 16 # A800 drop_last: true num_workers: 6 pin_memory: false prefetch_factor: 6 ``` Explanations for some of the parameters mentioned above: + `resample_rate`: All audio in the dataset will be resampled to this specified sample rate. Defaults to `16000`. + `sample_num_per_epoch`: Specifies how many samples from the full training set will be iterated over in each epoch during training. The default is `0`, which means iterating over the entire training set. + `shuffle`: Whether to perform *global* shuffle, i.e., shuffling at shards tar/raw/feat file level. Defaults to `true`. + `shuffle_size`: Parameters related to *local* shuffle. Local shuffle maintains a buffer, and shuffling is only performed when the number of data items in the buffer reaches the s`shuffle_size`. Defaults to `2500`. + `whole_utt`: Whether the network input and training target are the entire audio segment. Defaults to `false`. + `chunk_len`: This parameter only takes effect when `whole_utt` is set to `false`. It indicates the length of the segment to be extracted from the complete audio as the network input and training target. Defaults to `48000`. + `online_mix`: Whether dynamic mixing speakers when loading data, `shuffle` will not take effect if this parameter is set to `true`. Defaults to `false`. + `data_type`: Specify the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`. + `train_data`: File containing paths to the training set files. + `train_spk_embeds`: File containing paths to the speaker embeddings of training set. + `train_utt2spk`: Each line of the file specified by this parameter consists of `clean_wav_id` and `speaker_id`, separated by a space(e.g. `single.utt2spk` generated in Stage 1). + `train_spk2utt`: The file specified by this parameter is only used when the `joint_training` parameter is set to `true`. Each line of the file contains `speaker_id` and `enrollment_wav_id`. + `val_data`: File containing paths to the validation set files. + `val_spk_embeds`: Similiar to `train_spk_embeds`. + `val_utt2spk`: Similiar to `train_utt2spk`. + `val_spk1_enroll`: Each line of the file specified by this parameter consists of `mixtrue_wav_id` and `speaker1_enrollment_wav_id`, separated by a space. + `val_spk2_enroll`: Each line of the file specified by this parameter consists of `mixtrue_wav_id` and `speaker2_enrollment_wav_id`, separated by a space. 
+ `val_spk2utt`: Each line of the file specified by this parameter consists of `clean_wav_id` and `clean_wav_path`, separated by a space(e.g. `single.wav.scp` generated in Stage 1). + We have denoted this parameter as `val_spk2utt`, but it is actually assigned the `single.wav.scp` file as its value. This might be perplexing for users familiar with file formats in Kaldi or ESPnet, where the `spk2utt` file typically consists of lines containing `spk_id` and `wav_id`, whereas the `wav.scp` file's lines contain `wav_id` and `wav_path`. + Nevertheless, upon closer examination of its role in subsequent procedures, it becomes evident that it is indeed employed to create a dictionary mapping speaker IDs to audio samples. + `batch_size`: how many samples per batch to load. Please note that the batch size mentioned here refers to the **batch size per GPU**. So, if you are training on two GPUs within a single node and set the batch size to 16, it is equivalent to setting the batch size to 32 in a single-GPU, single-node scenario. + `drop_last`: set to `true` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If `false` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + `num_workers`: how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process. + `pin_memory`: If `true`, the data loader will copy Tensors into device/CUDA pinned memory before returning them. + `prefetch_factor`: number of batches loaded in advance by each worker. + **loss function related** ```yaml loss: SISDR loss_args: { } ``` Explanations for some of the parameters mentioned above: + `loss`: the loss function used for training. + `loss_args`: the required arguments for the loss function. In addition to some common loss functions, we also support the use of GAN loss. You can enable this feature by setting `use_gan_loss` to `true` in `run.sh`. Once enabled, the TSE model serves as the generator, and another convolutional neural network acts as the discriminator, engaging in adversarial training. The final loss of the TSE model is a combination of the losses specified in the configuration file and the GAN loss. By default, the weight for the former is set to` 0.95`, while the latter is set to `0.05`. Due to the compatibility with GAN loss, the parameters mentioned below often differentiate between `tse_model` and `discriminator` under a single parameter. In such cases, we no longer provide separate explanations for each parameter. + **neural network structure related** ```yaml model: tse_model: BSRNN model_args: tse_model: sr: 16000 win: 512 stride: 128 feature_dim: 128 num_repeat: 6 spk_emb_dim: 256 spk_fuse_type: 'multiply' use_spk_transform: False model_init: tse_model: exp/BSRNN/no_spk_transform-multiply_fuse/models/latest_checkpoint.pt discriminator: null ``` Explanations for some of the parameters mentioned above: + `model`: specify the neural network used for training. + `model_args`: specify model-specific parameters. + `model_init`: whether to initialize the model with an existing checkpoint. Use `null` for no initialization. If you want to initialize, provide the checkpoint path. Defaults to `null`. 
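Before moving on to the optimization settings, here is a hypothetical, self-contained sketch (not WeSep's actual factory code) of how the `model` / `model_args` / `model_init` entries above are typically consumed: parse the YAML, build the named model with its arguments, and, when `model_init` is not `null`, initialize it from the given checkpoint. `DummyBSRNN` and `MODEL_REGISTRY` are stand-ins for illustration only.

```python
# Hypothetical sketch of consuming the model-related config entries.
import yaml
import torch
import torch.nn as nn

CONFIG = """
model:
  tse_model: BSRNN
model_args:
  tse_model:
    spk_emb_dim: 256
model_init:
  tse_model: null
"""


class DummyBSRNN(nn.Module):  # stand-in for the real BSRNN in wesep/models
    def __init__(self, spk_emb_dim=256):
        super().__init__()
        self.proj = nn.Linear(spk_emb_dim, spk_emb_dim)


MODEL_REGISTRY = {"BSRNN": DummyBSRNN}  # hypothetical name-to-class mapping

configs = yaml.safe_load(CONFIG)
model_cls = MODEL_REGISTRY[configs["model"]["tse_model"]]
model = model_cls(**configs["model_args"]["tse_model"])

ckpt = configs["model_init"]["tse_model"]
if ckpt is not None:  # YAML `null` becomes Python None: no initialization
    model.load_state_dict(torch.load(ckpt, map_location="cpu"))
print(model)
```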
+ **model optimization related** ```yaml num_epochs: 150 clip_grad: 5.0 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 weight_decay: 0.0001 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 ``` Explanations for some of the parameters mentioned above: + `num_epochs`: total number of training epochs. + `clip_grad`: set the threshold for gradient clipping. + `optimizer`: set the optimizer. + `optimizer_args`: the required arguments for optimizer. + `scheduler`: set the scheduler. + `scheduler_args`: the required arguments for scheduler. + **others** ```yaml num_avg: 2 ``` Explanations for some of the parameters mentioned above: + `num_avg`: numbers for averaged model. To avoid frequent changes to the configuration file, we support **overwriting values in the configuration file** directly within `run.sh`. For example, running the following command in `run.sh` will overwrite the visible GPU from `'0,1'` to ``'0'`` in the above configuration file: ```bash torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ ${train_script} --config confs/config.yaml \ --gpus "[0]" \ ``` At the end of this stage, an experiment directory will be created in the current directory, containing the following files: ``` ${exp_dir}/ |__ train.log |__ config.yaml |__ models/ |__ checkpoint_1.pt |__ ... |__ checkpoint_150.pt |__ final_checkpoint.pt -> checkpoint_150.pt |__ latest_checkpoint.pt -> checkpoint_150.pt ``` ------ ### Stage 4: Extract Speech Using the Trained Model After training is complete, you can execute stage 4 to extract the target speaker's speech using the trained model. In this stage, it mainly calls `wesep/bin/infer.py`, and you need to provide the following parameters for this script: + `config`: the configuration file used in Stage 3. + `fs`: the sample rate of the audio data. + `gpus`: the index of the visible GPU. + `exp_dir`: the experiment directory. + `data_type`: the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`. + `test_data`: similiar to `train_data`. + `test_spk_embeds`: similiar to `train_spk_embeds`. + `test_spk1_enroll`: similiar to `dev_spk1_enroll`. + `test_spk2_enroll`: similiar to `dev_spk2_enroll`. + `test_spk2utt`: similiar to `dev_spk2utt`. + `checkpoint`: the path to the checkpoint used for extracting the target speaker's speech. At the end of this stage, the structure of the experiment directory should look like this: ``` ${exp_dir}/ |__ train.log |__ config.yaml |__ models/ |__ infer.log |__ audio/ |__ spk1.scp # each line records two space-separated columns: `target_wav_id` and `target_wav_path` |__ Utt1001-4992-41806-0008_6930-75918-0015-T4992.wav |__ ... |__ Utt999-61-70968-0003_2830-3980-0008-T61.wav ``` ------ ### Stage 5: Scoring In this stage, we evaluate the quality of the generated speech using common objective metrics. The default metrics include **STOI**, **SDR**, **SAR**, **SIR**, and **SI_SNR**. In addition to these metrics, you can also include **PESQ** and **DNS_MOS** by setting the values of `use_pesq` and `use_dnsmos` to `true`. Please be aware that DNS_MOS is exclusively supported for audio samples with a **16 kHz** sampling rate. For audio with different sampling rates, refrain from employing DNS_MOS for assessment. 
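For reference, below is a minimal NumPy sketch of the SI-SNR metric reported in this stage (the same quantity that the `SISDR`/`SI_SNR` loss optimizes). It is for illustration only and may differ in detail from the actual scoring code in `wesep/utils/score.py` and `tools/score.sh`.

```python
# Minimal SI-SNR sketch: project the estimate onto the reference, then
# compare the energies of the target component and the residual noise.
import numpy as np


def si_snr(estimate: np.ndarray, reference: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SNR (dB) between an estimated and a reference signal."""
    estimate = estimate - estimate.mean()    # remove DC offset
    reference = reference - reference.mean()
    scale = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = scale * reference               # component aligned with the reference
    noise = estimate - target                 # everything else is "noise"
    return 10 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))


if __name__ == "__main__":
    t = np.linspace(0, 1, 16000)
    clean = np.sin(2 * np.pi * 440 * t)
    noisy = clean + 0.1 * np.random.randn(len(t))
    print(f"SI-SNR: {si_snr(noisy, clean):.2f} dB")
```

A higher SI-SNR means the extracted signal is closer to the reference clean source, independent of its overall scale.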
At the end of this stage, a markdown file `RESULTS.md` will be created under `exp` directory, the directory structure of `exp` should look like this: ``` exp/BSRNN/ |__ ${exp_dir} | |__ train.log, ... # files and directories generated in Stage 4 | |__ scoring/ | |__ RESULTS.md ``` ------ ### Stage 6: Apply Model Average In this stage, we perform model averaging, and you need to specify the following parameters in `run.sh`: + `dst_model`: the path to save the averaged model. + `src_path`: source models path for average. + `num`: number of source models for the averaged model. + `mode`: the mode for model averaging. Validate options are `final` and `best`. + `final`: filters and sorts the latest PyTorch model files in the source directory. Averages the states of the last `num` models based on a numerical sorting of their filenames. + `best`: directly uses user-specified epochs to select specific model checkpoint files. Averages the states of these selected models. + `epochs`: this parameter only takes effect when `mode` is set to `best` and is used to specify the epoch index of the checkpoint that will be used as source models. ================================================ FILE: examples/librimix/tse/v1/confs/bsrnn.yaml ================================================ dataloader_args: batch_size: 16 # A800: 16 drop_last: true num_workers: 6 pin_memory: false prefetch_factor: 6 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 enable_amp: false exp_dir: exp/BSRNN gpus: '0,1' log_batch_interval: 100 loss: SISDR loss_args: { } model: tse_model: BSRNN model_args: tse_model: sr: 16000 win: 512 stride: 128 feature_dim: 128 num_repeat: 6 spk_emb_dim: 256 spk_fuse_type: 'multiply' use_spk_transform: False multi_fuse: False # Multi speaker fuse with seperation modules joint_training: False model_init: tse_model: null discriminator: null num_avg: 2 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 weight_decay: 0.0001 clip_grad: 5.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v1/confs/dpcc_init_gan.yaml ================================================ use_metric_loss: true dataloader_args: batch_size: 4 drop_last: true num_workers: 4 pin_memory: false prefetch_factor: 4 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 enable_amp: false exp_dir: exp/DPCNN gpus: '0,1' log_batch_interval: 100 loss: SISNR loss_args: { } gan_loss_weight: 0.05 model: tse_model: DPCCN discriminator: CMGAN_Discriminator model_args: tse_model: win: 512 stride: 128 feature_dim: 257 tcn_blocks: 10 tcn_layers: 2 spk_emb_dim: 256 causal: False spk_fuse_type: 'multiply' use_spk_transform: False discriminator: {} model_init: tse_model: exp/DPCCN/no_spk_transform-multiply_fuse/models/final_model.pt discriminator: null num_avg: 5 num_epochs: 50 optimizer: tse_model: Adam discriminator: Adam optimizer_args: tse_model: lr: 0.0001 weight_decay: 0.0001 discriminator: lr: 0.001 weight_decay: 0.0001 clip_grad: 3.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease discriminator: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.0001 warm_from_zero: false warm_up_epoch: 0 discriminator: final_lr: 2.5e-05 initial_lr: 
0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v1/confs/dpccn.yaml ================================================ dataloader_args: batch_size: 4 drop_last: true num_workers: 4 pin_memory: false prefetch_factor: 4 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 enable_amp: false exp_dir: exp/DPCNN gpus: '0,1' log_batch_interval: 100 loss: SISDR loss_args: { } model: tse_model: DPCCN model_args: tse_model: win: 512 stride: 128 feature_dim: 257 tcn_blocks: 10 tcn_layers: 2 spk_emb_dim: 256 causal: False spk_fuse_type: 'multiply' use_spk_transform: False joint_training: False model_init: tse_model: null discriminator: null num_avg: 5 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 weight_decay: 0.0001 clip_grad: 3.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v1/confs/tfgridnet.yaml ================================================ dataloader_args: batch_size: 4 drop_last: true num_workers: 4 pin_memory: false prefetch_factor: 4 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 16000 enable_amp: false exp_dir: exp/TFGridNet gpus: '0,1' log_batch_interval: 100 loss: SI_SNR loss_args: { } model: tse_model: TFGridNet model_args: tse_model: n_srcs: 1 n_fft: 128 stride: 64 window: "hann" n_imics: 1 n_layers: 6 lstm_hidden_units: 192 attn_n_head: 4 attn_approx_qk_dim: 512 emb_dim: 128 emb_ks: 1 emb_hs: 1 activation: "prelu" eps: 1.0e-5 spk_emb_dim: 256 use_spk_transform: False spk_fuse_type: "multiply" joint_training: False model_init: tse_model: null num_avg: 2 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 weight_decay: 0.0001 clip_grad: 5.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v1/local/prepare_data.sh ================================================ #!/bin/bash # Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com) stage=-1 stop_stage=-1 mix_data_path='/Data/Libri2Mix/wav16k/min/' data=data noise_type=clean num_spk=2 . 
tools/parse_options.sh || exit 1 data=$(realpath ${data}) if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "Prepare the meta files for the datasets" for dataset in dev test train-100; do echo "Preparing files for" $dataset # Prepare the meta data for the mixed data dataset_path=$mix_data_path/$dataset/mix_${noise_type} mkdir -p "${data}"/$noise_type/${dataset} find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print $NF}' | awk -v path="${dataset_path}" '{print $1 , path "/" $1 , path "/../s1/" $1 , path "/../s2/" $1}' | sed 's#.wav##' | sort -k1,1 >"${data}"/$noise_type/${dataset}/wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp | awk -F[_-] '{print $0, $1,$4}' >"${data}"/$noise_type/${dataset}/utt2spk # Prepare the meta data for single speakers dataset_path=$mix_data_path/$dataset/s1 find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s1/" $NF, $0}' | sort -k1,1 >"${data}"/$noise_type/${dataset}/single.wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's1' | awk -F[-_/] '{print $0, $2}' >"${data}"/$noise_type/${dataset}/single.utt2spk dataset_path=$mix_data_path/$dataset/s2 find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s2/" $NF, $0}' | sort -k1,1 >>"${data}"/$noise_type/${dataset}/single.wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's2' | awk -F[-_/] '{print $0, $5}' >>"${data}"/$noise_type/${dataset}/single.utt2spk done fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then echo "Prepare the speaker embeddings using wespeaker pretrained models" mkdir wespeaker_resnet34 wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.zip unzip voxceleb_resnet34_LM.zip -d wespeaker_resnet34 mv wespeaker_resnet34/voxceleb_resnet34_LM.yaml wespeaker_resnet34/config.yaml mv wespeaker_resnet34/voxceleb_resnet34_LM.pt wespeaker_resnet34/avg_model.pt for dataset in dev test train-100; do mkdir -p "${data}"/$noise_type/${dataset} echo "Preparing files for" $dataset wespeaker --task embedding_kaldi \ --wav_scp "${data}"/$noise_type/${dataset}/single.wav.scp \ --output_file "${data}"/$noise_type/${dataset}/embed \ -p wespeaker_resnet34 \ --device cuda:0 # GPU idx done fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then echo "stage 3: Prepare LibriMix target-speaker enroll signal" for dset in dev test train-100; do python local/prepare_spk2enroll_librispeech.py \ "${mix_data_path}/${dset}" \ --is_librimix True \ --outfile "${data}"/$noise_type/${dset}/spk2enroll.json \ --audio_format wav done for dset in dev test; do if [ $num_spk -eq 2 ]; then url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment" else url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment" fi output_file="${data}/${noise_type}/${dset}/mixture2enrollment" wget -O "$output_file" "$url" done for dset in dev test; do python local/prepare_librimix_enroll.py \ "${data}"/$noise_type/${dset}/wav.scp \ "${data}"/$noise_type/${dset}/spk2enroll.json \ --mix2enroll "${data}/${noise_type}/${dset}/mixture2enrollment" \ --num_spk ${num_spk} \ --train False \ --output_dir "${data}"/${noise_type}/${dset} \ --outfile_prefix "spk" done fi ================================================ FILE: examples/librimix/tse/v1/local/prepare_librimix_enroll.py ================================================ import json import random from pathlib 
import Path from wesep.utils.datadir_writer import DatadirWriter from wesep.utils.utils import str2bool def prepare_librimix_enroll(wav_scp, spk2utts, output_dir, num_spk=2, train=True, prefix="enroll_spk"): mixtures = [] with Path(wav_scp).open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue mixtureID = line.strip().split(maxsplit=1)[0] mixtures.append(mixtureID) with Path(spk2utts).open("r", encoding="utf-8") as f: # {spkID: [(uid1, path1), (uid2, path2), ...]} spk2utt = json.load(f) with DatadirWriter(Path(output_dir)) as writer: for mixtureID in mixtures: # 100-121669-0004_3180-138043-0053 uttIDs = mixtureID.split("_") for spk in range(num_spk): uttID = uttIDs[spk] spkID = uttID.split("-")[0] if train: # For training, we choose the auxiliary signal on the fly. # Here we use the pattern f"*{uttID} {spkID}". writer[f"{prefix}{spk + 1}.enroll"][ mixtureID] = f"*{uttID} {spkID}" else: enrollID = random.choice(spk2utt[spkID])[1] while enrollID == uttID and len(spk2utt[spkID]) > 1: enrollID = random.choice(spk2utt[spkID])[1] writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = enrollID def prepare_librimix_enroll_v2(wav_scp, map_mix2enroll, output_dir, num_spk=2, prefix="spk"): # noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py mixtures = [] with Path(wav_scp).open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue mixtureID = line.strip().split(maxsplit=1)[0] mixtures.append(mixtureID) mix2enroll = {} with open(map_mix2enroll) as f: for line in f: mix_id, utt_id, enroll_id = line.strip().split() sid = mix_id.split("_").index(utt_id) + 1 mix2enroll[mix_id, f"s{sid}"] = enroll_id with DatadirWriter(Path(output_dir)) as writer: for mixtureID in mixtures: # 100-121669-0004_3180-138043-0053 for spk in range(num_spk): enroll_id = mix2enroll[mixtureID, f"s{spk + 1}"] writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = (enroll_id + ".wav") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "wav_scp", type=str, help="Path to the wav.scp file", ) parser.add_argument("spk2utts", type=str, help="Path to the json file containing mapping " "from speaker ID to utterances") parser.add_argument( "--num_spk", type=int, default=2, choices=(2, 3), help="Number of speakers in each mixture sample", ) parser.add_argument( "--train", type=str2bool, default=True, help="Whether is the training set or not", ) parser.add_argument( "--mix2enroll", type=str, default=None, help="Path to the downloaded map_mixture2enrollment file. 
" "If `train` is False, this value is required.", ) parser.add_argument("--seed", type=int, default=1, help="Random seed") parser.add_argument( "--output_dir", type=str, required=True, help="Path to the directory for storing output files", ) parser.add_argument( "--outfile_prefix", type=str, default="spk", help="Prefix of the output files", ) args = parser.parse_args() random.seed(args.seed) if args.train: prepare_librimix_enroll( args.wav_scp, args.spk2utts, args.output_dir, num_spk=args.num_spk, train=args.train, prefix=args.outfile_prefix, ) else: prepare_librimix_enroll_v2( args.wav_scp, args.mix2enroll, args.output_dir, num_spk=args.num_spk, prefix=args.outfile_prefix, ) ================================================ FILE: examples/librimix/tse/v1/local/prepare_spk2enroll_librispeech.py ================================================ import json from collections import defaultdict from itertools import chain from pathlib import Path from wesep.utils.utils import str2bool def get_spk2utt(paths, audio_format="flac"): spk2utt = defaultdict(list) for path in paths: for audio in Path(path).rglob("*.{}".format(audio_format)): readerID = audio.parent.parent.stem uid = audio.stem assert uid.split("-")[0] == readerID, audio spk2utt[readerID].append((uid, str(audio))) return spk2utt def get_spk2utt_librimix(paths, audio_format="flac"): spk2utt = defaultdict(list) for path in paths: for audio in chain( Path(path).rglob("s1/*.{}".format(audio_format)), Path(path).rglob("s2/*.{}".format(audio_format)), Path(path).rglob("s3/*.{}".format(audio_format)), ): spk_idx = int(audio.parent.stem[1:]) - 1 mix_uid = audio.stem uid = mix_uid.split("_")[spk_idx] sid = uid.split("-")[0] spk2utt[sid].append((uid, str(audio))) return spk2utt if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "audio_paths", type=str, nargs="+", help="Paths to Librispeech subsets", ) parser.add_argument( "--is_librimix", type=str2bool, default=False, help="Whether the provided audio_paths points to LibriMix data", ) parser.add_argument( "--outfile", type=str, default="spk2utt_tse.json", help="Path to the output spk2utt json file", ) parser.add_argument("--audio_format", type=str, default="flac") args = parser.parse_args() if args.is_librimix: # use clean sources from LibriMix as enrollment spk2utt = get_spk2utt_librimix(args.audio_paths, audio_format=args.audio_format) else: # use Librispeech as enrollment spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format) outfile = Path(args.outfile) outfile.parent.mkdir(parents=True, exist_ok=True) with outfile.open("w", encoding="utf-8") as f: json.dump(spk2utt, f, indent=4) ================================================ FILE: examples/librimix/tse/v1/path.sh ================================================ export PATH=$PWD:$PATH # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C export PYTHONIOENCODING=UTF-8 export PYTHONPATH=../../../../:$PYTHONPATH ================================================ FILE: examples/librimix/tse/v1/run.sh ================================================ #!/bin/bash # Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn) . 
./path.sh || exit 1 # General configuration stage=-1 stop_stage=1 # Data preparation related data=data fs=16k min_max=min noise_type="clean" data_type="shard" # shard/raw Libri2Mix_dir=/YourPath/librimix/Libri2Mix mix_data_path="${Libri2Mix_dir}/wav${fs}/${min_max}" # Training related gpus="[0,1]" use_gan_loss=false config=confs/bsrnn.yaml exp_dir=exp/BSRNN/resnet34-pre_extract-multiply_fuse if [ -z "${config}" ] && [ -f "${exp_dir}/config.yaml" ]; then config="${exp_dir}/config.yaml" fi # TSE model initialization related checkpoint= if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then checkpoint="${exp_dir}/models/latest_checkpoint.pt" fi # Inferencing and scoring related use_pesq=true use_dnsmos=true dnsmos_use_gpu=true # Model average related num_avg=2 . tools/parse_options.sh || exit 1 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "Prepare datasets ..." ./local/prepare_data.sh --mix_data_path ${mix_data_path} \ --data ${data} \ --noise_type ${noise_type} \ --stage 2 \ --stop-stage 2 fi data=${data}/${noise_type} if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then echo "Covert train and test data to ${data_type}..." for dset in train-100 dev test; do # for dset in train-360; do python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \ --num_threads 16 \ --prefix shards \ --shuffle \ ${data}/$dset/wav.scp ${data}/$dset/utt2spk \ ${data}/$dset/shards ${data}/$dset/shard.list done fi num_gpus=$(echo $gpus | awk -F ',' '{print NF}') if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then echo "Start training ..." if ${use_gan_loss}; then train_script=wesep/bin/train_gan.py else train_script=wesep/bin/train.py fi export OMP_NUM_THREADS=8 torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ ${train_script} --config $config \ --exp_dir ${exp_dir} \ --gpus $gpus \ --num_avg ${num_avg} \ --data_type "${data_type}" \ --train_data ${data}/train-100/${data_type}.list \ --train_spk_embeds ${data}/train-100/embed.scp \ --train_utt2spk ${data}/train-100/single.utt2spk \ --train_spk2utt ${data}/train-100/spk2enroll.json \ --val_data ${data}/dev/${data_type}.list \ --val_spk_embeds ${data}/dev/embed.scp \ --val_utt2spk ${data}/dev/single.utt2spk \ --val_spk1_enroll ${data}/dev/spk1.enroll \ --val_spk2_enroll ${data}/dev/spk2.enroll \ --val_spk2utt ${data}/dev/single.wav.scp \ ${checkpoint:+--checkpoint $checkpoint} fi # shellcheck disable=SC2215 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then echo "Start inferencing ..." python wesep/bin/infer.py --config $config \ --fs ${fs} \ --gpus 0 \ --exp_dir ${exp_dir} \ --data_type "${data_type}" \ --test_data ${data}/test/${data_type}.list \ --test_spk_embeds ${data}/test/embed.scp \ --test_spk1_enroll ${data}/test/spk1.enroll \ --test_spk2_enroll ${data}/test/spk2.enroll \ --test_spk2utt ${data}/test/single.wav.scp \ ${checkpoint:+--checkpoint $checkpoint} fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then echo "Start scoring ..." ./tools/score.sh --dset "${data}/test" \ --exp_dir "${exp_dir}" \ --fs ${fs} \ --use_pesq "${use_pesq}" \ --use_dnsmos "${use_dnsmos}" \ --dnsmos_use_gpu "${dnsmos_use_gpu}" \ --n_gpu "${num_gpus}" fi if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then echo "Do model average ..." 
avg_model=$exp_dir/models/avg_best_model.pt python wesep/bin/average_model.py \ --dst_model $avg_model \ --src_path $exp_dir/models \ --num ${num_avg} \ --mode best \ --epochs "138,141" fi ================================================ FILE: examples/librimix/tse/v2/README.md ================================================ ## Tutorial on LibriMix If you meet any problems when going through this tutorial, please feel free to ask in github issues. Thanks for any kind of feedback. ### First Experiment We provide a recipe `examples/librimix/tse/v2/run.sh` on LibriMix data. The recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process. ```bash cd examples/librimix/tse/v2 bash run.sh --stage 1 --stop_stage 1 bash run.sh --stage 2 --stop_stage 2 bash run.sh --stage 3 --stop_stage 3 bash run.sh --stage 4 --stop_stage 4 bash run.sh --stage 5 --stop_stage 5 bash run.sh --stage 6 --stop_stage 6 ``` You could also just run the whole script ```bash bash run.sh --stage 1 --stop_stage 6 ``` ------ ### Stage 1: Prepare Training Data Prior to executing this phase, we assume that you have locally stored or can access the LibriMix dataset and you should assign the data path to `Libri2Mix_dir`. As the LibriMix dataset is available in multiple versions, each determined by factors like the number of sources in the mixtures and the sampling rate, you can choose the desired version by adjusting the following variables in `run.sh`: + `fs`: the sample rate of the dataset, valid options are `16k` and `8k`. + `min_max`: the mode of mixtures, valiad options are `min` and `max`. + `noise_type`: the type of mixture, valiad options are `clean` and `both`. In our recipe, we opt for the Libri2Mix data with a sampling rate of 16kHz, in 'min' mode, and without noise, thus configuring as follows: ``` bash fs=16k min_max=min noise_type="clean" Libri2Mix_dir=/path/to/Libri2Mix ``` After configuring the desired dataset version, running the script for the first phase will generate the prepared data files. By default, these files are stored in the `data` directory in the current location. ```bash data=data # you can change this to any directory ``` In this stage, `local/prepare_data.sh`accomplishes three tasks (Main differences with v1 version): 1. Organizes the original Libri2Mix dataset into three directoies `dev`, `test` and `train_100`/`train_360`, each containing the following files: + `single.utt2spk`: each line records two space-separated columns: `clean_wav_id` and `speaker_id` ```text s1/103-1240-0003_1235-135887-0017.wav 103 s1/103-1240-0004_4195-186237-0003.wav 103 ... ``` + `utt2spk`: each line records three space-separated columns: `mixture_wav_id`, `speaker1_id` and `speaker2_id`. ``` 103-1240-0003_1235-135887-0017 103 1235 103-1240-0004_4195-186237-0003 103 4195 ... ``` + `single.wav.scp`: each line records two space-separated columns: `clean_wav_id` and `clean_wav_path` ``` s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0003_1235-135887-0017.wav s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0004_4195-186237-0003.wav ... ``` + `wav.scp`: each line records four space-separated columns: `mixture_wav_id`, `mixtrue_wav_path`, `clean_wav1_path` and `clean_wav2_path`. 
``` 103-1240-0003_1235-135887-0017 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0003_1235-135887-0017.wav 103-1240-0004_4195-186237-0003 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0004_4195-186237-0003.wav ... ``` 2. Prepare LibriMix target-speaker enroll signal. This step will generate one `json` file in the `dev`, `test` and `train_100`/`train_360` directories, and additional three files in the `dev` and `test` directories respectively: + `spk2enroll.json`: A JSON file, where the format of the stored key-value pairs is `{spk_id: [[spk_id_with_prefix_or_suffix, wav_path], ...]}`. ``` "652": [["652-129742-0010", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0010_3081-166546-0071.wav"], ..., ["652-129742-0000", "/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0000_1993-147966-0004.wav"]], ... ``` + `mixture2enrollment`: each line records three space-separated columns: `mixture_wav_id`, `clean_wav_id` and `enrollment_wav_id`. ``` 4077-13754-0001_5142-33396-0065 4077-13754-0001 s1/4077-13754-0004_5142-36377-0020 4077-13754-0001_5142-33396-0065 5142-33396-0065 s1/5142-36377-0003_1320-122612-0014 ... ``` + `spk1.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`. ``` 1272-128104-0000_2035-147961-0014 s1/1272-135031-0015_2277-149896-0006.wav 1272-128104-0003_2035-147961-0016 s1/1272-135031-0013_1988-147956-0016.wav ... ``` + `spk2.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`. ``` 1272-128104-0000_2035-147961-0014 s1/2035-152373-0009_3000-15664-0016.wav 1272-128104-0003_2035-147961-0016 s2/6313-66129-0013_2035-152373-0012.wav ... ``` At the end of this stage, the directory structure of `data` should look like this: ``` data/ |__ clean/ # the noise_type you chose    |__ dev/ | |__ mixture2enrollment | |__ single.utt2spk | |__ single.wav.scp | |__ spk1.enroll | |__ spk2.enroll | |__ spk2enroll.json | |__ utt2spk | |__ wav.scp |    |__ test/ # the same as dev/ | |__ train_100/ |__ single.utt2spk |__ single.wav.scp |__ spk2enroll.json |__ utt2spk |__ wav.scp ``` 3. Download the speaker encoders (Resnet34 & Ecapa-TDNN512) from wespeaker for training the TSE model with pretrained speaker encoder. The models will be unzipped into `wespeaker_models/`. Find more speaker models in https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md. 4. Prepare the speaker embeddings using wespeaker pretrained models. (Not needed, and comment off in v2 version by default.) This step will generate two files in the `dev`, `test`, and `train_100` directories respectively: + `embed.ark`: Kaldi ark file that stores the speaker embeddings. + `embed.scp`: each line records two space-separated columns: `clean_wav_id` and `spk_embed_path` ``` s1/103-1240-0003_1235-135887-0017.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:1450569 s1/103-1240-0004_4195-186237-0003.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:10622715 ... ``` ------ ### Stage 2: Convert Data Format This stage involves transforming the data into `shard` format, which is better suited for large datasets. 
The core idea is to pack the audio and labels of many small samples (e.g., 1000 items) into compressed packets (tar files) and to read them via PyTorch's IterableDataset. For a detailed explanation of the `shard` format, please refer to the [documentation](https://github.com/wenet-e2e/wenet/blob/main/docs/UIO.md) available in Wenet. This stage will generate a subdirectory and a file in the `dev`, `test`, and `train_100` directories respectively: + `shards/`: this directory stores the compressed packet (tar) files. ```bash ls shards shards_000000000.tar shards_000000001.tar shards_000000002.tar ... ``` + `shard.list`: each line records the path to the corresponding tar file. ``` data/clean/dev/shards/shards_000000000.tar data/clean/dev/shards/shards_000000001.tar data/clean/dev/shards/shards_000000002.tar ... ``` At the end of this stage, the directory structure of `data` should look like this: ``` data/ |__ clean/ # the noise_type you chose |__ dev/ | |__ single.utt2spk, single.wav.scp, ... # files generated by Stage 1 | |__ shard.list | |__ shards/ | |__ shards_000000000.tar | |__ shards_000000001.tar | |__ shards_000000002.tar | |__ test/ # the same as dev/ | |__ train_100/ |__ single.utt2spk, single.wav.scp, ... # files generated by Stage 1 |__ shard.list |__ shards/ |__ shards_000000000.tar |__ ... |__ shards_000000013.tar ``` ------ ### Stage 3: Neural Network Training You can configure network-training-related parameters through the configuration file. We provide some ready-to-use configuration files in the recipe. If you wish to write your own configuration files or understand the meaning of certain parameters in the configuration files, you can refer to the following information: + **overall training process related** ```yaml seed: 42 exp_dir: exp/BSRNN enable_amp: false gpus: '0,1' log_batch_interval: 100 save_epoch_interval: 1 ``` Explanations for some of the parameters mentioned above: + `seed`: specify a random seed. + `exp_dir`: specify the experiment directory. + `enable_amp`: whether to enable automatic mixed precision. + `gpus`: specify the visible GPUs during training. + `log_batch_interval`: specify after how many batch iterations to record in the log. + `save_epoch_interval`: specify after how many epochs to save a checkpoint. + **dataset and dataloader related** ```yaml dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 whole_utt: false chunk_len: 48000 online_mix: false speaker_feat: &speaker_feat true fbank_args: num_mel_bins: 80 frame_shift: 10 frame_length: 25 dither: 1.0 # Usually you don't need to manually write the data part of the configuration into the config file; it will be generated automatically. data_type: "shard" train_data: "data/clean/train_100/shard.list" train_utt2spk: "data/clean/train_100/single.utt2spk" train_spk2utt: "data/clean/train_100/spk2enroll.json" val_data: "data/clean/dev/shard.list" val_utt2spk: "data/clean/dev/single.utt2spk" val_spk1_enroll: "data/clean/dev/spk1.enroll" val_spk2_enroll: "data/clean/dev/spk2.enroll" val_spk2utt: "data/clean/dev/single.wav.scp" dataloader_args: batch_size: 12 # A800 drop_last: true num_workers: 6 pin_memory: false prefetch_factor: 6 ``` Explanations for some of the parameters mentioned above: + `resample_rate`: All audio in the dataset will be resampled to this specified sample rate. Defaults to `16000`.
+ `sample_num_per_epoch`: Specifies how many samples from the full training set will be iterated over in each epoch during training. The default is `0`, which means iterating over the entire training set. + `shuffle`: Whether to perform *global* shuffle, i.e., shuffling at the shard tar/raw/feat file level. Defaults to `true`. + `shuffle_size`: Parameter related to *local* shuffle. Local shuffle maintains a buffer, and shuffling is only performed when the number of data items in the buffer reaches `shuffle_size`. Defaults to `2500`. + `whole_utt`: Whether the network input and training target are the entire audio segment. Defaults to `false`. + `chunk_len`: This parameter only takes effect when `whole_utt` is set to `false`. It indicates the length of the segment to be extracted from the complete audio as the network input and training target. Defaults to `48000`. + `online_mix`: Whether to dynamically mix speakers when loading data; `shuffle` will not take effect if this parameter is set to `true`. Defaults to `false`. + `speaker_feat`: Whether to transform the enrollment from waveform to fbank. Recommended to set to `true`. Defaults to `false`. + `num_mel_bins`: A parameter of fbank. The feature dimension of the fbank. Defaults to `80`. + `frame_shift`: A parameter of fbank. The frame shift in `ms`. Defaults to `10`. + `frame_length`: A parameter of fbank. The frame length in `ms`. Defaults to `25`. + `dither`: A parameter of fbank. Whether to add dithering noise to the fbank feature. Defaults to `1.0`. + `data_type`: Specify the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`. + `train_data`: File containing paths to the training set files. + `train_utt2spk`: Each line of the file specified by this parameter consists of `clean_wav_id` and `speaker_id`, separated by a space (e.g. `single.utt2spk` generated in Stage 1). + `train_spk2utt`: The file specified by this parameter is only used when the `joint_training` parameter is set to `true`. It maps each `speaker_id` to its candidate enrollment utterances (e.g. `spk2enroll.json` generated in Stage 1). + `val_data`: File containing paths to the validation set files. + `val_utt2spk`: Similar to `train_utt2spk`. + `val_spk1_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker1_enrollment_wav_id`, separated by a space. + `val_spk2_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker2_enrollment_wav_id`, separated by a space. + `val_spk2utt`: Each line of the file specified by this parameter consists of `clean_wav_id` and `clean_wav_path`, separated by a space (e.g. `single.wav.scp` generated in Stage 1). + We have denoted this parameter as `val_spk2utt`, but it is actually assigned the `single.wav.scp` file as its value. This might be perplexing for users familiar with file formats in Kaldi or ESPnet, where the `spk2utt` file typically consists of lines containing `spk_id` and `wav_id`, whereas the `wav.scp` file's lines contain `wav_id` and `wav_path`. + Nevertheless, upon closer examination of its role in subsequent procedures, it becomes evident that it is indeed employed to create a dictionary mapping speaker IDs to audio samples. + `batch_size`: how many samples per batch to load. Please note that the batch size mentioned here refers to the **batch size per GPU**. So, if you are training on two GPUs within a single node and set the batch size to 16, it is equivalent to setting the batch size to 32 in a single-GPU, single-node scenario.
+ `drop_last`: set to `true` to drop the last incomplete batch if the dataset size is not divisible by the batch size. If `false` and the size of the dataset is not divisible by the batch size, then the last batch will be smaller. + `num_workers`: how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process. + `pin_memory`: If `true`, the data loader will copy Tensors into device/CUDA pinned memory before returning them. + `prefetch_factor`: number of batches loaded in advance by each worker. + **loss function related** ```yaml loss: SISDR loss_args: { } ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set. loss: [SISDR, CE] loss_args: loss_posi: [[0],[1]] loss_weight: [[1.0],[1.0]] ``` Explanations for some of the parameters mentioned above: + `loss`: the loss function used for training. + `loss_args`: the required arguments for the loss function. + `loss_posi`: Select which outputs from the TSE model the loss function works on. + `loss_weight`: The weight of the loss computed by the corresponding loss function. In addition to some common loss functions, we also support the use of a GAN loss. You can enable this feature by setting `use_gan_loss` to `true` in `run.sh`. Once enabled, the TSE model serves as the generator, and another convolutional neural network acts as the discriminator, engaging in adversarial training. The final loss of the TSE model is a combination of the losses specified in the configuration file and the GAN loss. By default, the weight for the former is set to `0.95`, while the latter is set to `0.05`. To remain compatible with the GAN loss, the parameters mentioned below often have separate `tse_model` and `discriminator` entries under a single key. In such cases, we no longer provide separate explanations for each entry. + **neural network structure related** ```yaml model: tse_model: BSRNN model_args: tse_model: sr: 16000 win: 512 stride: 128 feature_dim: 128 num_repeat: 6 spk_emb_dim: 256 spk_fuse_type: 'multiply' use_spk_transform: False multi_fuse: False joint_training: True ### You should always set this parameter to `True` when using the v2 version. spk_model: ResNet34 spk_model_init: None spk_args: None spk_emb_dim: 256 spk_model_freeze: False spk_feat: *speaker_feat feat_type: "consistent" multi_task: False spksInTrain: 251 model_init: tse_model: exp/BSRNN/no_spk_transform-multiply_fuse/models/latest_checkpoint.pt discriminator: null ``` Explanations for some of the parameters mentioned above: + `model`: specify the neural network used for training. + `model_args`: specify model-specific parameters. + `spk_fuse_type`: specify the fusion method of the speaker embedding. Supported options are `concat`, `additive`, `multiply` and `FiLM`. + `multi_fuse`: whether to fuse the speaker embedding multiple times. + `joint_training`: specify whether the speaker encoder for extracting speaker embeddings is jointly trained with the TSE model. Always set this to `true`. Do NOT use it to control whether to train with a pretrained speaker encoder. Defaults to `false`. + `spk_model`: specify the speaker model. Supports most speaker models in wespeaker: https://github.com/wenet-e2e/wespeaker/tree/master. + `spk_model_init`: the path of the pre-trained speaker model. Find more pretrained models in https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md. Set `None` to train the speaker model from scratch together with the TSE model. + `spk_args`: specify speaker model-specific parameters.
+ `spk_emb_dim`: the feature dimension of the speaker embedding extracted from the speaker encoder. + `spk_model_freeze`: whether to freeze the weights of the speaker encoder. Set it to `True` when using a pretrained speaker encoder. + `spk_feat`: Uses the corresponding parameter defined in `dataset_args` to determine whether feature extraction of the enrollment is performed within the model. + `feat_type`: specify the type of the enrollment's feature when `spk_feat` is `False`. + `multi_task`: whether to use an additional loss such as `CE` for jointly training the speaker encoder. This parameter needs to be coordinated with `loss`. + `spksInTrain`: specify the number of speakers in the training dataset. wsj0-2mix: 101, Libri2mix-100: 251, Libri2mix-360: 921. + `model_init`: whether to initialize the model with an existing checkpoint. Use `null` for no initialization. If you want to initialize, provide the checkpoint path. Defaults to `null`. + **model optimization related** ```yaml num_epochs: 150 clip_grad: 5.0 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 weight_decay: 0.0001 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 ``` Explanations for some of the parameters mentioned above: + `num_epochs`: total number of training epochs. + `clip_grad`: set the threshold for gradient clipping. + `optimizer`: set the optimizer. + `optimizer_args`: the required arguments for the optimizer. Not used in the current version. The learning rate and schedule are determined by `scheduler_args`. + `scheduler`: set the scheduler. + `scheduler_args`: the required arguments for the scheduler. + **others** ```yaml num_avg: 2 ``` Explanations for some of the parameters mentioned above: + `num_avg`: the number of checkpoints used for model averaging. To avoid frequent changes to the configuration file, we support **overwriting values in the configuration file** directly within `run.sh`. For example, running the following command in `run.sh` will overwrite the visible GPUs from `'0,1'` to `'0'` in the above configuration file: ```bash torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ ${train_script} --config confs/config.yaml \ --gpus "[0]" \ ``` At the end of this stage, an experiment directory will be created in the current directory, containing the following files: ``` ${exp_dir}/ |__ train.log |__ config.yaml |__ models/ |__ checkpoint_1.pt |__ ... |__ checkpoint_150.pt |__ final_checkpoint.pt -> checkpoint_150.pt |__ latest_checkpoint.pt -> checkpoint_150.pt ``` ------ ### Stage 4: Apply Model Average In this stage, we perform model averaging, and you need to specify the following parameters in `run.sh`: + `dst_model`: the path to save the averaged model. + `src_path`: the path of the source models used for averaging. + `num`: the number of source models used for the averaged model. + `mode`: the mode for model averaging. Valid options are `final` and `best`. + `final`: filters and sorts the latest PyTorch model files in the source directory. Averages the states of the last `num` models based on a numerical sorting of their filenames. + `best`: directly uses user-specified epochs to select specific model checkpoint files. Averages the states of these selected models. + `epochs`: this parameter only takes effect when `mode` is set to `best` and is used to specify the epoch indices of the checkpoints that will be used as source models.
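To give an intuition of what this stage does, the snippet below sketches `best`-mode averaging: load the checkpoints of the user-specified epochs and take the element-wise mean of their parameters. This is only an illustration that assumes each checkpoint is a flat state dict of tensors; the file paths are examples, and the actual logic lives in `wesep/bin/average_model.py`.

```python
# Illustrative sketch of "best"-mode checkpoint averaging (not the exact
# implementation in wesep/bin/average_model.py). Assumes each checkpoint
# is a flat state dict of tensors.
import torch


def average_checkpoints(ckpt_paths, dst_path):
    avg_state = None
    for path in ckpt_paths:
        state = torch.load(path, map_location="cpu")
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg_state:
                avg_state[k] += state[k].float()
    for k in avg_state:
        avg_state[k] /= len(ckpt_paths)  # element-wise mean over checkpoints
    torch.save(avg_state, dst_path)


# e.g., average the checkpoints of epochs 138 and 141, as selected in run.sh
average_checkpoints(
    ["exp/BSRNN/models/checkpoint_138.pt", "exp/BSRNN/models/checkpoint_141.pt"],
    "exp/BSRNN/models/avg_best_model.pt",
)
```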
------ ### Stage 5: Extract Speech Using the Trained Model After training is complete, you can execute stage 5 to extract the target speaker's speech using the trained model. This stage mainly calls `wesep/bin/infer.py`, and you need to provide the following parameters for this script: + `config`: the configuration file used in Stage 3. + `fs`: the sample rate of the audio data. + `gpus`: the index of the visible GPU. + `exp_dir`: the experiment directory. + `data_type`: the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`. + `test_data`: similar to `train_data`. + `test_spk1_enroll`: similar to `dev_spk1_enroll`. + `test_spk2_enroll`: similar to `dev_spk2_enroll`. + `test_spk2utt`: similar to `dev_spk2utt`. + `save_wav`: controls whether to save the extracted speech in `exp_dir/audio`. + `checkpoint`: the path to the checkpoint used for extracting the target speaker's speech. At the end of this stage, the structure of the experiment directory should look like this: ``` ${exp_dir}/ |__ train.log |__ config.yaml |__ models/ |__ infer.log |__ audio/ |__ spk1.scp # each line records two space-separated columns: `target_wav_id` and `target_wav_path` |__ Utt1001-4992-41806-0008_6930-75918-0015-T4992.wav |__ ... |__ Utt999-61-70968-0003_2830-3980-0008-T61.wav ``` ------ ### Stage 6: Scoring In this stage, we evaluate the quality of the generated speech using common objective metrics. The default metrics include **STOI**, **SDR**, **SAR**, **SIR**, and **SI_SNR**. In addition to these metrics, you can also include **PESQ** and **DNS_MOS** by setting the values of `use_pesq` and `use_dnsmos` to `true`. Please be aware that DNS_MOS is only supported for audio samples with a **16 kHz** sampling rate. For audio with other sampling rates, do not use DNS_MOS for assessment. At the end of this stage, a markdown file `RESULTS.md` will be created under the `exp` directory, and the directory structure of `exp` should look like this: ``` exp/BSRNN/ |__ ${exp_dir} | |__ train.log, ... # files and directories generated in Stage 5 | |__ scoring/ | |__ RESULTS.md ``` ================================================ FILE: examples/librimix/tse/v2/confs/bsrnn.yaml ================================================ dataloader_args: batch_size: 8 #RTX2080 1, V100: 8, A800: 16 drop_last: true num_workers: 6 pin_memory: true prefetch_factor: 6 dataset_args: resample_rate: &sr 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 speaker_feat: &speaker_feat True fbank_args: num_mel_bins: 80 frame_shift: 10 frame_length: 25 dither: 1.0 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 # prob to add noise aug per sample specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech noise_enroll_prob: 0 # prob to add noise aug on enrollment speech # Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589 # only Single-optimization method is supported here. # if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml SSA_enroll_prob: 0 # prob to add SSA on enrollment speech enable_amp: false exp_dir: exp/BSRNN gpus: '0,1' log_batch_interval: 100 loss: SISDR loss_args: { } # loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss.
Put SISDR in the first position for validation set # loss_args: # loss_posi: [[0],[1]] # loss_weight: [[1.0],[1.0]] model: tse_model: BSRNN model_args: tse_model: sr: *sr win: 512 stride: 128 feature_dim: 128 num_repeat: 6 spk_fuse_type: 'multiply' use_spk_transform: False multi_fuse: False # Fuse the speaker embedding multiple times. joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders ####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt spk_args: feat_dim: 80 embed_dim: &embed_dim 256 pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP two_emb_layer: False ####### Ecapa_TDNN # spk_model: ECAPA_TDNN_GLOB_c512 # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt # spk_args: # embed_dim: &embed_dim 192 # feat_dim: 80 # pooling_func: ASTP ####### CAMPPlus # spk_model: CAMPPlus # spk_model_init: False # spk_args: # feat_dim: 80 # embed_dim: &embed_dim 192 # pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP ################################################################# spk_emb_dim: *embed_dim spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used feat_type: "consistent" multi_task: False spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 # find_unused_parameters: True model_init: tse_model: null discriminator: null spk_model: null num_avg: 2 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently! weight_decay: 0.0001 clip_grad: 5.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v2/confs/bsrnn_feats.yaml ================================================ dataloader_args: batch_size: 4 #RTX2080 1, V100: 4, A800: 12 drop_last: true num_workers: 6 pin_memory: true prefetch_factor: 6 dataset_args: resample_rate: &sr 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 speaker_feat: &speaker_feat False fbank_args: num_mel_bins: 80 frame_shift: 10 frame_length: 25 dither: 1.0 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 # prob to add noise aug per sample specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech noise_enroll_prob: 0 # prob to add noise aug on enrollment speech # Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589 # only Single-optimization method is supported here. # if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml SSA_enroll_prob: 0 # prob to add SSA on enrollment speech enable_amp: false exp_dir: exp/BSRNN gpus: '0,1' log_batch_interval: 100 loss: SISDR loss_args: { } # loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. 
Put SISDR in the first position for validation set # loss_args: # loss_posi: [[0],[1]] # loss_weight: [[1.0],[1.0]] model: tse_model: BSRNN_Feats model_args: tse_model: sr: *sr win: 512 stride: 128 feature_dim: 128 num_repeat: 6 spectral_feat: 'tfmap_emb' # 'tfmap_spec' 'tfmap_emb' False spk_fuse_type: 'cross_multiply' #'cross_multiply' 'multiply' False use_spk_transform: False multi_fuse: False # Fuse the speaker embedding multiple times. joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders ################################################################# ###### Ecapa_TDNN spk_model: ECAPA_TDNN_GLOB_c512 spk_model_init: ./wespeaker_models/voxceleb_ECAPA512/avg_model.pt spk_args: embed_dim: &embed_dim 192 feat_dim: 80 pooling_func: ASTP ################################################################# spk_emb_dim: *embed_dim spk_model_freeze: True # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder spk_feat: *speaker_feat # if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used feat_type: "consistent" multi_task: False spksInTrain: 251 # wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 # find_unused_parameters: True model_init: tse_model: null discriminator: null spk_model: null num_avg: 2 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently! weight_decay: 0.0001 clip_grad: 5.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v2/confs/bsrnn_multi_optim.yaml ================================================ dataloader_args: batch_size: 8 #RTX2080 1, V100: 8, A800: 16 drop_last: true num_workers: 6 pin_memory: true prefetch_factor: 6 dataset_args: resample_rate: &sr 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 speaker_feat: &speaker_feat False fbank_args: num_mel_bins: 80 frame_shift: 10 frame_length: 25 dither: 1.0 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 # prob to add noise aug per sample specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech noise_enroll_prob: 0 # prob to add noise aug on enrollment speech enable_amp: false exp_dir: exp/BSRNN gpus: '0,1' log_batch_interval: 100 #Please refer to our SLT paper https://www.arxiv.org/abs/2409.09589 # to check our parameter settings. loss: SISDR loss_args: loss_posi: [[0,1]] loss_weight: [[0.4,0.6]] #if you wanna use CE loss, multi_task needs to be set True # loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set # loss_args: # loss_posi: [[0,1],[2,3]] # loss_weight: [[0.36,0.54],[0.04,0.06]] model: tse_model: BSRNN_Multi model_args: tse_model: sr: *sr win: 512 stride: 128 feature_dim: 128 num_repeat: 6 spk_fuse_type: 'multiply' use_spk_transform: False multi_fuse: False # Fuse the speaker embedding multiple times. 
joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders ####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt spk_args: feat_dim: 80 embed_dim: &embed_dim 256 pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP two_emb_layer: False ####### Ecapa_TDNN # spk_model: ECAPA_TDNN_GLOB_c512 # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt # spk_args: # embed_dim: &embed_dim 192 # feat_dim: 80 # pooling_func: ASTP ####### CAMPPlus # spk_model: CAMPPlus # spk_model_init: False # spk_args: # feat_dim: 80 # embed_dim: &embed_dim 192 # pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP ################################################################# spk_emb_dim: *embed_dim spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used feat_type: "consistent" multi_task: False spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 # find_unused_parameters: True model_init: tse_model: null discriminator: null spk_model: null num_avg: 2 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently! weight_decay: 0.0001 clip_grad: 5.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v2/confs/dpcc_init_gan.yaml ================================================ use_metric_loss: true dataloader_args: batch_size: 4 drop_last: true num_workers: 4 pin_memory: false prefetch_factor: 4 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 # prob to add noise aug per sample specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech noise_enroll_prob: 0 # prob to add noise aug on enrollment speech enable_amp: false exp_dir: exp/DPCNN gpus: '0,1' log_batch_interval: 100 loss: SISNR loss_args: { } gan_loss_weight: 0.05 model: tse_model: DPCCN discriminator: CMGAN_Discriminator model_args: tse_model: win: 512 stride: 128 feature_dim: 257 tcn_blocks: 10 tcn_layers: 2 spk_emb_dim: 256 causal: False spk_fuse_type: 'multiply' use_spk_transform: False discriminator: {} model_init: tse_model: exp/DPCCN/no_spk_transform-multiply_fuse/models/final_model.pt discriminator: null num_avg: 5 num_epochs: 50 optimizer: tse_model: Adam discriminator: Adam optimizer_args: tse_model: lr: 0.0001 weight_decay: 0.0001 discriminator: lr: 0.001 weight_decay: 0.0001 clip_grad: 3.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease discriminator: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.0001 warm_from_zero: false warm_up_epoch: 0 discriminator: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 
================================================ FILE: examples/librimix/tse/v2/confs/dpccn.yaml ================================================ dataloader_args: batch_size: 6 drop_last: true num_workers: 6 pin_memory: false prefetch_factor: 6 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 speaker_feat: &speaker_feat True fbank_args: num_mel_bins: 80 frame_shift: 10 frame_length: 25 dither: 1.0 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 # prob to add noise aug per sample specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech noise_enroll_prob: 0 # prob to add noise aug on enrollment speech enable_amp: false exp_dir: exp/DPCNN gpus: '0,1' log_batch_interval: 100 loss: SISDR loss_args: { } # loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set # loss_args: # loss_posi: [[0],[1]] # loss_weight: [[1.0],[1.0]] model: tse_model: DPCCN model_args: tse_model: win: 512 stride: 128 feature_dim: 257 tcn_blocks: 10 tcn_layers: 2 causal: False spk_fuse_type: 'multiply' use_spk_transform: False multi_fuse: False # Fuse the speaker embedding multiple times. joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders ####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt spk_args: feat_dim: 80 embed_dim: &embed_dim 256 pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP two_emb_layer: False ####### Ecapa_TDNN # spk_model: ECAPA_TDNN_GLOB_c512 # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt # spk_args: # embed_dim: &embed_dim 192 # feat_dim: 80 # pooling_func: ASTP ####### CAMPPlus # spk_model: CAMPPlus # spk_model_init: False # spk_args: # feat_dim: 80 # embed_dim: &embed_dim 192 # pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP ################################################################# spk_emb_dim: *embed_dim spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used feat_type: "consistent" multi_task: False spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 model_init: tse_model: null discriminator: null num_avg: 5 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently! 
weight_decay: 0.0001 clip_grad: 3.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v2/confs/spexplus.yaml ================================================ dataloader_args: batch_size: 8 #A800: 8 drop_last: true num_workers: 4 pin_memory: true prefetch_factor: 6 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 # prob to add noise aug per sample specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech noise_enroll_prob: 0 # prob to add noise aug on enrollment speech enable_amp: false exp_dir: exp/SpExplus/ gpus: ['0'] log_batch_interval: 100 # joint_training: True loss: [SISDR, CE] ###SI_SNR, SDR, sisnr, CE loss_args: loss_posi: [[0,1,2],[3]] loss_weight: [[0.8,0.1,0.1],[0.5]] model: tse_model: ConvTasNet model_args: tse_model: B: 256 H: 512 L: 20 N: 256 P: 3 R: 4 X: 8 spk_emb_dim: 256 activate: relu causal: false norm: gLN skip_con: False spk_fuse_type: concatConv # "concat", "additive", "multiply", "FiLM", "None", ("concatConv" only for convtasnet) use_spk_transform: False multi_fuse: True # Multi speaker fuse with seperation modules encoder_type: Multi # Multi, Deep, False decoder_type: Multi # Multi, Deep, False joint_training: True multi_task: True spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 model_init: tse_model: null discriminator: null num_avg: 5 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently! weight_decay: 0.0001 save_epoch_interval: 5 clip_grad: 5.0 # False scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: False warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v2/confs/tfgridnet.yaml ================================================ dataloader_args: batch_size: 1 drop_last: true num_workers: 6 pin_memory: false prefetch_factor: 6 dataset_args: resample_rate: &sr 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 chunk_len: 48000 speaker_feat: &speaker_feat True fbank_args: num_mel_bins: 80 frame_shift: 10 frame_length: 25 dither: 1.0 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 # prob to add noise aug per sample specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech noise_enroll_prob: 0 # prob to add noise aug on enrollment speech enable_amp: false exp_dir: exp/TFGridNet gpus: '0,1' log_batch_interval: 100 loss: SISDR loss_args: { } # loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. 
Put SISDR in the first position for validation set # loss_args: # loss_posi: [[0],[1]] # loss_weight: [[1.0],[1.0]] model: tse_model: TFGridNet model_args: tse_model: n_srcs: 1 sr: *sr n_fft: 128 stride: 64 window: "hann" n_imics: 1 n_layers: 6 lstm_hidden_units: 192 attn_n_head: 4 attn_approx_qk_dim: 512 emb_dim: 128 emb_ks: 1 emb_hs: 1 activation: "prelu" eps: 1.0e-5 use_spk_transform: False spk_fuse_type: "multiply" multi_fuse: False # Fuse the speaker embedding multiple times. joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders ####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt spk_args: feat_dim: 80 embed_dim: &embed_dim 256 pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP two_emb_layer: False ####### Ecapa_TDNN # spk_model: ECAPA_TDNN_GLOB_c512 # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt # spk_args: # embed_dim: &embed_dim 192 # feat_dim: 80 # pooling_func: ASTP ####### CAMPPlus # spk_model: CAMPPlus # spk_model_init: False # spk_args: # feat_dim: 80 # embed_dim: &embed_dim 192 # pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP ################################################################# spk_emb_dim: *embed_dim spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used feat_type: "consistent" multi_task: False spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 model_init: tse_model: null num_avg: 5 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently! weight_decay: 0.0001 clip_grad: 5.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/librimix/tse/v2/local/prepare_data.sh ================================================ #!/bin/bash # Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com) stage=-1 stop_stage=-1 mix_data_path='./Libri2Mix/wav16k/min/' data=data noise_type=clean num_spk=2 . 
tools/parse_options.sh || exit 1 data=$(realpath ${data}) if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "Prepare the meta files for the datasets" for dataset in dev test train-100; do # for dataset in train-360; do echo "Preparing files for" $dataset # Prepare the meta data for the mixed data dataset_path=$mix_data_path/$dataset/mix_${noise_type} mkdir -p "${data}"/$noise_type/${dataset} find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print $NF}' | awk -v path="${dataset_path}" '{print $1 , path "/" $1 , path "/../s1/" $1 , path "/../s2/" $1}' | sed 's#.wav##' | sort -k1,1 >"${data}"/$noise_type/${dataset}/wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp | awk -F[_-] '{print $0, $1,$4}' >"${data}"/$noise_type/${dataset}/utt2spk # Prepare the meta data for single speakers dataset_path=$mix_data_path/$dataset/s1 find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s1/" $NF, $0}' | sort -k1,1 >"${data}"/$noise_type/${dataset}/single.wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's1' | awk -F[-_/] '{print $0, $2}' >"${data}"/$noise_type/${dataset}/single.utt2spk dataset_path=$mix_data_path/$dataset/s2 find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s2/" $NF, $0}' | sort -k1,1 >>"${data}"/$noise_type/${dataset}/single.wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's2' | awk -F[-_/] '{print $0, $5}' >>"${data}"/$noise_type/${dataset}/single.utt2spk done fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then echo "stage 2: Prepare LibriMix target-speaker enroll signal" for dset in dev test train-100; do # for dset in train-360; do python local/prepare_spk2enroll_librispeech.py \ "${mix_data_path}/${dset}" \ --is_librimix True \ --outfile "${data}"/$noise_type/${dset}/spk2enroll.json \ --audio_format wav done for dset in dev test; do if [ $num_spk -eq 2 ]; then url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment" else url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment" fi output_file="${data}/${noise_type}/${dset}/mixture2enrollment" wget -O "$output_file" "$url" done for dset in dev test; do python local/prepare_librimix_enroll.py \ "${data}"/$noise_type/${dset}/wav.scp \ "${data}"/$noise_type/${dset}/spk2enroll.json \ --mix2enroll "${data}/${noise_type}/${dset}/mixture2enrollment" \ --num_spk ${num_spk} \ --train False \ --output_dir "${data}"/${noise_type}/${dset} \ --outfile_prefix "spk" done fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then echo "Download the pre-trained speaker encoders (Resnet34 & Ecapa-TDNN512) from wespeaker..." 
mkdir wespeaker_models wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip unzip voxceleb_resnet34.zip -d wespeaker_models wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip unzip voxceleb_ECAPA512.zip -d wespeaker_models fi # if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # echo "Prepare the speaker embeddings using wespeaker pretrained models" # for dataset in dev test train-100; do # mkdir -p "${data}"/$noise_type/${dataset} # echo "Preparing files for" $dataset # wespeaker --task embedding_kaldi \ # --wav_scp "${data}"/$noise_type/${dataset}/single.wav.scp \ # --output_file "${data}"/$noise_type/${dataset}/embed \ # -p wespeaker_models/voxceleb_resnet34 \ # -g 0 # GPU idx # done # fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then if [ ! -d "${data}/raw_data/musan" ]; then mkdir -p ${data}/raw_data/musan # echo "Downloading musan.tar.gz ..." echo "This may take a long time. Thus we recommand you to download all archives above in your own way first." wget --no-check-certificate https://openslr.elda.org/resources/17/musan.tar.gz -P ${data}/raw_data md5=$(md5sum ${data}/raw_data/musan.tar.gz | awk '{print $1}') [ $md5 != "0c472d4fc0c5141eca47ad1ffeb2a7df" ] && echo "Wrong md5sum of musan.tar.gz" && exit 1 echo "Decompress all archives ..." tar -xzvf ${data}/raw_data/musan.tar.gz -C ${data}/raw_data rm -rf ${data}/raw_data/musan.tar.gz fi echo "Prepare wav.scp for musan ..." mkdir -p ${data}/musan find ${data}/raw_data/musan -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' >${data}/musan/wav.scp # Convert all musan data to LMDB echo "conver musan data to LMDB ..." python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb fi ================================================ FILE: examples/librimix/tse/v2/local/prepare_librimix_enroll.py ================================================ import json import random from pathlib import Path from wesep.utils.datadir_writer import DatadirWriter from wesep.utils.utils import str2bool def prepare_librimix_enroll(wav_scp, spk2utts, output_dir, num_spk=2, train=True, prefix="enroll_spk"): mixtures = [] with Path(wav_scp).open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue mixtureID = line.strip().split(maxsplit=1)[0] mixtures.append(mixtureID) with Path(spk2utts).open("r", encoding="utf-8") as f: # {spkID: [(uid1, path1), (uid2, path2), ...]} spk2utt = json.load(f) with DatadirWriter(Path(output_dir)) as writer: for mixtureID in mixtures: # 100-121669-0004_3180-138043-0053 uttIDs = mixtureID.split("_") for spk in range(num_spk): uttID = uttIDs[spk] spkID = uttID.split("-")[0] if train: # For training, we choose the auxiliary signal on the fly. # Here we use the pattern f"*{uttID} {spkID}". 
writer[f"{prefix}{spk + 1}.enroll"][ mixtureID] = f"*{uttID} {spkID}" else: enrollID = random.choice(spk2utt[spkID])[1] while enrollID == uttID and len(spk2utt[spkID]) > 1: enrollID = random.choice(spk2utt[spkID])[1] writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = enrollID def prepare_librimix_enroll_v2(wav_scp, map_mix2enroll, output_dir, num_spk=2, prefix="spk"): # noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py mixtures = [] with Path(wav_scp).open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue mixtureID = line.strip().split(maxsplit=1)[0] mixtures.append(mixtureID) mix2enroll = {} with open(map_mix2enroll) as f: for line in f: mix_id, utt_id, enroll_id = line.strip().split() sid = mix_id.split("_").index(utt_id) + 1 mix2enroll[mix_id, f"s{sid}"] = enroll_id with DatadirWriter(Path(output_dir)) as writer: for mixtureID in mixtures: # 100-121669-0004_3180-138043-0053 for spk in range(num_spk): enroll_id = mix2enroll[mixtureID, f"s{spk + 1}"] writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = (enroll_id + ".wav") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "wav_scp", type=str, help="Path to the wav.scp file", ) parser.add_argument( "spk2utts", type=str, help="Path to the json, mapping from speaker ID to utterances", ) parser.add_argument( "--num_spk", type=int, default=2, choices=(2, 3), help="Number of speakers in each mixture sample", ) parser.add_argument( "--train", type=str2bool, default=True, help="Whether is the training set or not", ) parser.add_argument( "--mix2enroll", type=str, default=None, help="Path to the downloaded map_mixture2enrollment file. " "If `train` is False, this value is required.", ) parser.add_argument("--seed", type=int, default=1, help="Random seed") parser.add_argument( "--output_dir", type=str, required=True, help="Path to the directory for storing output files", ) parser.add_argument( "--outfile_prefix", type=str, default="spk", help="Prefix of the output files", ) args = parser.parse_args() random.seed(args.seed) if args.train: prepare_librimix_enroll( args.wav_scp, args.spk2utts, args.output_dir, num_spk=args.num_spk, train=args.train, prefix=args.outfile_prefix, ) else: prepare_librimix_enroll_v2( args.wav_scp, args.mix2enroll, args.output_dir, num_spk=args.num_spk, prefix=args.outfile_prefix, ) ================================================ FILE: examples/librimix/tse/v2/local/prepare_spk2enroll_librispeech.py ================================================ import json from collections import defaultdict from itertools import chain from pathlib import Path from wesep.utils.utils import str2bool def get_spk2utt(paths, audio_format="flac"): spk2utt = defaultdict(list) for path in paths: for audio in Path(path).rglob("*.{}".format(audio_format)): readerID = audio.parent.parent.stem uid = audio.stem assert uid.split("-")[0] == readerID, audio spk2utt[readerID].append((uid, str(audio))) return spk2utt def get_spk2utt_librimix(paths, audio_format="flac"): spk2utt = defaultdict(list) for path in paths: for audio in chain( Path(path).rglob("s1/*.{}".format(audio_format)), Path(path).rglob("s2/*.{}".format(audio_format)), Path(path).rglob("s3/*.{}".format(audio_format)), ): spk_idx = int(audio.parent.stem[1:]) - 1 mix_uid = audio.stem uid = mix_uid.split("_")[spk_idx] sid = uid.split("-")[0] spk2utt[sid].append((uid, str(audio))) return spk2utt if __name__ == "__main__": import argparse parser = 
argparse.ArgumentParser() parser.add_argument( "audio_paths", type=str, nargs="+", help="Paths to Librispeech subsets", ) parser.add_argument( "--is_librimix", type=str2bool, default=False, help="Whether the provided audio_paths points to LibriMix data", ) parser.add_argument( "--outfile", type=str, default="spk2utt_tse.json", help="Path to the output spk2utt json file", ) parser.add_argument("--audio_format", type=str, default="flac") args = parser.parse_args() if args.is_librimix: # use clean sources from LibriMix as enrollment spk2utt = get_spk2utt_librimix(args.audio_paths, audio_format=args.audio_format) else: # use Librispeech as enrollment spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format) outfile = Path(args.outfile) outfile.parent.mkdir(parents=True, exist_ok=True) with outfile.open("w", encoding="utf-8") as f: json.dump(spk2utt, f, indent=4) ================================================ FILE: examples/librimix/tse/v2/path.sh ================================================ export PATH=$PWD:$PATH # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C export PYTHONIOENCODING=UTF-8 export PYTHONPATH=../../../../:$PYTHONPATH ================================================ FILE: examples/librimix/tse/v2/run.sh ================================================ #!/bin/bash # Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn) . ./path.sh || exit 1 # General configuration stage=-1 stop_stage=-1 # Data preparation related data=data fs=16k min_max=min noise_type="clean" data_type="shard" # shard/raw Libri2Mix_dir=/YourPATH/librimix/Libri2Mix mix_data_path="${Libri2Mix_dir}/wav${fs}/${min_max}" # Training related gpus="[0]" use_gan_loss=false config=confs/bsrnn.yaml exp_dir=exp/BSRNN/no_spk_transform-multiply_fuse if [ -z "${config}" ] && [ -f "${exp_dir}/config.yaml" ]; then config="${exp_dir}/config.yaml" fi # TSE model initialization related checkpoint= # Inferencing and scoring related save_results=true use_pesq=true use_dnsmos=true dnsmos_use_gpu=true # Model average related num_avg=10 . tools/parse_options.sh || exit 1 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "Prepare datasets ..." ./local/prepare_data.sh --mix_data_path ${mix_data_path} \ --data ${data} \ --noise_type ${noise_type} \ --stage 1 \ --stop-stage 3 fi data=${data}/${noise_type} if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then echo "Covert train and test data to ${data_type}..." for dset in train-100 dev test; do # for dset in train-360; do python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \ --num_threads 16 \ --prefix shards \ --shuffle \ ${data}/$dset/wav.scp ${data}/$dset/utt2spk \ ${data}/$dset/shards ${data}/$dset/shard.list done fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then echo "Start training ..." 
num_gpus=$(echo $gpus | awk -F ',' '{print NF}') if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then checkpoint="${exp_dir}/models/latest_checkpoint.pt" fi if ${use_gan_loss}; then train_script=wesep/bin/train_gan.py else train_script=wesep/bin/train.py fi export OMP_NUM_THREADS=8 torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ ${train_script} --config $config \ --exp_dir ${exp_dir} \ --gpus $gpus \ --num_avg ${num_avg} \ --data_type "${data_type}" \ --train_data ${data}/train-100/${data_type}.list \ --train_utt2spk ${data}/train-100/single.utt2spk \ --train_spk2utt ${data}/train-100/spk2enroll.json \ --val_data ${data}/dev/${data_type}.list \ --val_spk1_enroll ${data}/dev/spk1.enroll \ --val_spk2_enroll ${data}/dev/spk2.enroll \ --val_spk2utt ${data}/dev/single.wav.scp \ ${checkpoint:+--checkpoint $checkpoint} fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then echo "Do model average ..." avg_model=$exp_dir/models/avg_best_model.pt python wesep/bin/average_model.py \ --dst_model $avg_model \ --src_path $exp_dir/models \ --num ${num_avg} \ --mode best \ --epochs "138,141" fi if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/avg_best_model.pt" ]; then checkpoint="${exp_dir}/models/avg_best_model.pt" fi # shellcheck disable=SC2215 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then echo "Start inferencing ..." python wesep/bin/infer.py --config $config \ --fs ${fs} \ --gpus 0 \ --exp_dir ${exp_dir} \ --data_type "${data_type}" \ --test_data ${data}/test/${data_type}.list \ --test_spk1_enroll ${data}/test/spk1.enroll \ --test_spk2_enroll ${data}/test/spk2.enroll \ --test_spk2utt ${data}/test/single.wav.scp \ --save_wav ${save_results} \ ${checkpoint:+--checkpoint $checkpoint} fi if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then echo "Start scoring ..." ./tools/score.sh --dset "${data}/test" \ --exp_dir "${exp_dir}" \ --fs ${fs} \ --use_pesq "${use_pesq}" \ --use_dnsmos "${use_dnsmos}" \ --dnsmos_use_gpu "${dnsmos_use_gpu}" \ --n_gpu "${num_gpus}" fi ================================================ FILE: examples/voxceleb1/v2/confs/bsrnn_online.yaml ================================================ dataloader_args: batch_size: 8 drop_last: true num_workers: 6 pin_memory: false prefetch_factor: 6 dataset_args: resample_rate: 16000 sample_num_per_epoch: 0 shuffle: true shuffle_args: shuffle_size: 2500 filter_len: true filter_len_args: min_num_seconds: 1.0 max_num_seconds: 100.0 chunk_len: 48000 online_mix: true num_speakers: 2 use_random_snr: true speaker_feat: &speaker_feat True fbank_args: num_mel_bins: 80 frame_shift: 10 frame_length: 25 dither: 1.0 noise_lmdb_file: './data/musan/lmdb' noise_prob: 0 reverb_prob: 0 enable_amp: false exp_dir: exp/BSRNN gpus: '0,1' log_batch_interval: 100 loss: SISDR loss_args: { } model: tse_model: BSRNN model_args: tse_model: sr: 16000 win: 512 stride: 128 feature_dim: 128 num_repeat: 6 spk_fuse_type: 'multiply' use_spk_transform: False multi_fuse: False # Fuse the speaker embedding multiple times. 
joint_training: True ####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt spk_args: feat_dim: 80 embed_dim: &embed_dim 256 pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP two_emb_layer: False ####### Ecapa_TDNN # spk_model: ECAPA_TDNN_GLOB_c512 # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt # spk_args: # embed_dim: &embed_dim 192 # feat_dim: 80 # pooling_func: ASTP ####### CAMPPlus # spk_model: CAMPPlus # spk_model_init: False # spk_args: # feat_dim: 80 # embed_dim: &embed_dim 192 # pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP ################################################################# spk_emb_dim: *embed_dim spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used feat_type: "consistent" multi_task: False spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 # find_unused_parameters: True model_init: tse_model: null discriminator: null num_avg: 2 num_epochs: 150 optimizer: tse_model: Adam optimizer_args: tse_model: lr: 0.001 weight_decay: 0.0001 clip_grad: 5.0 save_epoch_interval: 1 scheduler: tse_model: ExponentialDecrease scheduler_args: tse_model: final_lr: 2.5e-05 initial_lr: 0.001 warm_from_zero: false warm_up_epoch: 0 seed: 42 ================================================ FILE: examples/voxceleb1/v2/local/prepare_data.sh ================================================ #!/bin/bash # Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com) stage=-1 stop_stage=-1 single_data_path='./voxceleb/VoxCeleb1/wav/' mix_data_path='./Libri2Mix/wav16k/min/' data=data noise_type=clean num_spk=2 . 
tools/parse_options.sh || exit 1 data=$(realpath ${data}) if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "Prepare the meta files for the vox1 single speaker datasets" for dataset in train-vox1; do echo "Preparing files for" $dataset # Prepare the meta data for the online mix data mkdir -p "${data}"/$noise_type/${dataset} find ${single_data_path} -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' | sort >"${data}"/$noise_type/${dataset}/wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp | awk -F "/" '{print $0,$1}' >"${data}"/$noise_type/${dataset}/utt2spk python local/prepare_spk2enroll_vox.py \ "${data}/$noise_type/${dataset}/wav.scp" \ --outfile "${data}"/$noise_type/${dataset}/spk2enroll.json done fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then echo "Prepare the meta files for the val and test datasets" for dataset in dev test; do echo "Preparing files for" $dataset # Prepare the meta data for the mixed data dataset_path=$mix_data_path/$dataset/mix_${noise_type} mkdir -p "${data}"/$noise_type/${dataset} find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print $NF}' | awk -v path="${dataset_path}" '{print $1 , path "/" $1 , path "/../s1/" $1 , path "/../s2/" $1}' | sed 's#.wav##' | sort -k1,1 >"${data}"/$noise_type/${dataset}/wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/wav.scp | awk -F[_-] '{print $0, $1,$4}' >"${data}"/$noise_type/${dataset}/utt2spk # Prepare the meta data for single speakers dataset_path=$mix_data_path/$dataset/s1 find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s1/" $NF, $0}' | sort -k1,1 >"${data}"/$noise_type/${dataset}/single.wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's1' | awk -F[-_/] '{print $0, $2}' >"${data}"/$noise_type/${dataset}/single.utt2spk dataset_path=$mix_data_path/$dataset/s2 find ${dataset_path}/ -type f -name "*.wav" | awk -F/ '{print "s2/" $NF, $0}' | sort -k1,1 >>"${data}"/$noise_type/${dataset}/single.wav.scp awk '{print $1}' "${data}"/$noise_type/${dataset}/single.wav.scp | grep 's2' | awk -F[-_/] '{print $0, $5}' >>"${data}"/$noise_type/${dataset}/single.utt2spk done fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then echo "stage 3: Prepare LibriMix target-speaker enroll signal" for dset in dev test; do python local/prepare_spk2enroll_librispeech.py \ "${mix_data_path}/${dset}" \ --is_librimix True \ --outfile "${data}"/$noise_type/${dset}/spk2enroll.json \ --audio_format wav done for dset in dev test; do if [ $num_spk -eq 2 ]; then url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment" else url="https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment" fi output_file="${data}/${noise_type}/${dset}/mixture2enrollment" wget -O "$output_file" "$url" done for dset in dev test; do python local/prepare_librimix_enroll.py \ "${data}"/$noise_type/${dset}/wav.scp \ "${data}"/$noise_type/${dset}/spk2enroll.json \ --mix2enroll "${data}/${noise_type}/${dset}/mixture2enrollment" \ --num_spk ${num_spk} \ --train False \ --output_dir "${data}"/${noise_type}/${dset} \ --outfile_prefix "spk" done fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then echo "Download the pre-trained speaker encoders (Resnet34 & Ecapa-TDNN512) from wespeaker..." 
mkdir wespeaker_models wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip unzip voxceleb_resnet34.zip -d wespeaker_models wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip unzip voxceleb_ECAPA512.zip -d wespeaker_models fi # if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # echo "Prepare the speaker embeddings using wespeaker pretrained models" # for dataset in dev test train-100; do # mkdir -p "${data}"/$noise_type/${dataset} # echo "Preparing files for" $dataset # wespeaker --task embedding_kaldi \ # --wav_scp "${data}"/$noise_type/${dataset}/single.wav.scp \ # --output_file "${data}"/$noise_type/${dataset}/embed \ # -p wespeaker_models/voxceleb_resnet34 \ # -g 0 # GPU idx # done # fi if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then if [ ! -d "${data}/raw_data/musan" ]; then mkdir -p ${data}/raw_data/musan # echo "Downloading musan.tar.gz ..." echo "This may take a long time. Thus we recommand you to download all archives above in your own way first." wget --no-check-certificate https://openslr.elda.org/resources/17/musan.tar.gz -P ${data}/raw_data md5=$(md5sum ${data}/raw_data/musan.tar.gz | awk '{print $1}') [ $md5 != "0c472d4fc0c5141eca47ad1ffeb2a7df" ] && echo "Wrong md5sum of musan.tar.gz" && exit 1 echo "Decompress all archives ..." tar -xzvf ${data}/raw_data/musan.tar.gz -C ${data}/raw_data rm -rf ${data}/raw_data/musan.tar.gz fi echo "Prepare wav.scp for musan ..." mkdir -p ${data}/musan find ${data}/raw_data/musan -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' >${data}/musan/wav.scp # Convert all musan data to LMDB echo "conver musan data to LMDB ..." python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb fi ================================================ FILE: examples/voxceleb1/v2/local/prepare_librimix_enroll.py ================================================ import json import random from pathlib import Path from wesep.utils.datadir_writer import DatadirWriter from wesep.utils.utils import str2bool def prepare_librimix_enroll(wav_scp, spk2utts, output_dir, num_spk=2, train=True, prefix="enroll_spk"): mixtures = [] with Path(wav_scp).open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue mixtureID = line.strip().split(maxsplit=1)[0] mixtures.append(mixtureID) with Path(spk2utts).open("r", encoding="utf-8") as f: # {spkID: [(uid1, path1), (uid2, path2), ...]} spk2utt = json.load(f) with DatadirWriter(Path(output_dir)) as writer: for mixtureID in mixtures: # 100-121669-0004_3180-138043-0053 uttIDs = mixtureID.split("_") for spk in range(num_spk): uttID = uttIDs[spk] spkID = uttID.split("-")[0] if train: # For training, we choose the auxiliary signal on the fly. # Thus, here we use the pattern f"*{uttID} {spkID}" to indicate it. 
# noqa writer[f"{prefix}{spk + 1}.enroll"][ mixtureID] = f"*{uttID} {spkID}" else: enrollID = random.choice(spk2utt[spkID])[1] while enrollID == uttID and len(spk2utt[spkID]) > 1: enrollID = random.choice(spk2utt[spkID])[1] writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = enrollID def prepare_librimix_enroll_v2(wav_scp, map_mix2enroll, output_dir, num_spk=2, prefix="spk"): # noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py mixtures = [] with Path(wav_scp).open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue mixtureID = line.strip().split(maxsplit=1)[0] mixtures.append(mixtureID) mix2enroll = {} with open(map_mix2enroll) as f: for line in f: mix_id, utt_id, enroll_id = line.strip().split() sid = mix_id.split("_").index(utt_id) + 1 mix2enroll[mix_id, f"s{sid}"] = enroll_id with DatadirWriter(Path(output_dir)) as writer: for mixtureID in mixtures: # 100-121669-0004_3180-138043-0053 for spk in range(num_spk): enroll_id = mix2enroll[mixtureID, f"s{spk + 1}"] writer[f"{prefix}{spk + 1}.enroll"][mixtureID] = (enroll_id + ".wav") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "wav_scp", type=str, help="Path to the wav.scp file", ) parser.add_argument("spk2utts", type=str, help="Path to the json file containing mapping " "from speaker ID to utterances") parser.add_argument( "--num_spk", type=int, default=2, choices=(2, 3), help="Number of speakers in each mixture sample", ) parser.add_argument( "--train", type=str2bool, default=True, help="Whether is the training set or not", ) parser.add_argument( "--mix2enroll", type=str, default=None, help="Path to the downloaded map_mixture2enrollment file. " "If `train` is False, this value is required.", ) parser.add_argument("--seed", type=int, default=1, help="Random seed") parser.add_argument( "--output_dir", type=str, required=True, help="Path to the directory for storing output files", ) parser.add_argument( "--outfile_prefix", type=str, default="spk", help="Prefix of the output files", ) args = parser.parse_args() random.seed(args.seed) if args.train: prepare_librimix_enroll( args.wav_scp, args.spk2utts, args.output_dir, num_spk=args.num_spk, train=args.train, prefix=args.outfile_prefix, ) else: prepare_librimix_enroll_v2( args.wav_scp, args.mix2enroll, args.output_dir, num_spk=args.num_spk, prefix=args.outfile_prefix, ) ================================================ FILE: examples/voxceleb1/v2/local/prepare_spk2enroll_librispeech.py ================================================ import json from collections import defaultdict from itertools import chain from pathlib import Path from wesep.utils.utils import str2bool def get_spk2utt_vox1(paths, audio_format="flac"): spk2utt = defaultdict(list) for path in paths: path = Path(path) for subdir in path.iterdir(): if subdir.is_dir(): readerID = subdir.name for audio in subdir.rglob("*.{}".format(audio_format)): spk2utt[readerID].append(str(audio)) return spk2utt def get_spk2utt(paths, audio_format="flac"): spk2utt = defaultdict(list) for path in paths: for audio in Path(path).rglob("*.{}".format(audio_format)): readerID = audio.parent.parent.stem uid = audio.stem assert uid.split("-")[0] == readerID, audio spk2utt[readerID].append((uid, str(audio))) return spk2utt def get_spk2utt_librimix(paths, audio_format="flac"): spk2utt = defaultdict(list) for path in paths: for audio in chain( Path(path).rglob("s1/*.{}".format(audio_format)), 
Path(path).rglob("s2/*.{}".format(audio_format)), Path(path).rglob("s3/*.{}".format(audio_format)), ): spk_idx = int(audio.parent.stem[1:]) - 1 mix_uid = audio.stem uid = mix_uid.split("_")[spk_idx] sid = uid.split("-")[0] spk2utt[sid].append((uid, str(audio))) return spk2utt if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "audio_paths", type=str, nargs="+", help="Paths to Librispeech subsets", ) parser.add_argument( "--is_librimix", type=str2bool, default=False, help="Whether the provided audio_paths points to LibriMix data", ) parser.add_argument( "--is_vox1", type=str2bool, default=False, help="Whether the provided audio_paths points to vox1 data", ) parser.add_argument( "--outfile", type=str, default="spk2utt_tse.json", help="Path to the output spk2utt json file", ) parser.add_argument("--audio_format", type=str, default="flac") args = parser.parse_args() if args.is_librimix: # use clean sources from LibriMix as enrollment spk2utt = get_spk2utt_librimix(args.audio_paths, audio_format=args.audio_format) elif args.is_vox1: spk2utt = get_spk2utt_vox1(args.audio_paths, audio_format=args.audio_format) else: # use Librispeech as enrollment spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format) outfile = Path(args.outfile) outfile.parent.mkdir(parents=True, exist_ok=True) with outfile.open("w", encoding="utf-8") as f: json.dump(spk2utt, f, indent=4) ================================================ FILE: examples/voxceleb1/v2/local/prepare_spk2enroll_vox.py ================================================ import json from collections import defaultdict from pathlib import Path def get_spk2utt_from_wavscp(wav_scp_path): spk2utt = defaultdict(list) with open(wav_scp_path, "r") as readin: for line in readin: speaker_id = line.split("/")[0] uid, audio_path = line.strip().split() spk2utt[speaker_id].append((uid, str(audio_path))) return spk2utt if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "wav_scp_path", type=str, help="Paths to Librispeech subsets", ) parser.add_argument( "--outfile", type=str, default="spk2utt_tse.json", help="Path to the output spk2utt json file", ) args = parser.parse_args() spk2utt = get_spk2utt_from_wavscp(args.wav_scp_path) outfile = Path(args.outfile) outfile.parent.mkdir(parents=True, exist_ok=True) with outfile.open("w", encoding="utf-8") as f: json.dump(spk2utt, f) ================================================ FILE: examples/voxceleb1/v2/path.sh ================================================ export PATH=$PWD:$PATH # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C export PYTHONIOENCODING=UTF-8 export PYTHONPATH=../../../:$PYTHONPATH ================================================ FILE: examples/voxceleb1/v2/run_online.sh ================================================ #!/bin/bash # Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn) . ./path.sh || exit 1 stage=-1 stop_stage=-1 data=data fs=16k min_max=min noise_type="clean" data_type="shard" # shard/raw Vox1_dir=/YourPATH/voxceleb/VoxCeleb1/wav Libri2Mix_dir=/YourPATH/librimix/Libri2Mix #For validate and test the TSE model. mix_data_path="${Libri2Mix_dir}/wav${fs}/${min_max}" gpus="[0,1]" num_avg=10 checkpoint= config=confs/bsrnn_online.yaml exp_dir=exp/BSRNN_Online/no_spk_transform_multiply save_results=true . tools/parse_options.sh || exit 1 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "Prepare datasets ..." 
./local/prepare_data.sh --single_data_path ${Vox1_dir} \ --mix_data_path ${mix_data_path} \ --data ${data} \ --noise_type ${noise_type} \ --stage 1 \ --stop-stage 4 fi data=${data}/${noise_type} if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then echo "Convert train and test data to ${data_type}..." for dset in train-vox1; do python tools/make_shard_online.py --num_utts_per_shard 1000 \ --num_threads 16 \ --prefix shards \ --shuffle \ ${data}/$dset/wav.scp ${data}/$dset/utt2spk \ ${data}/$dset/shards_online ${data}/$dset/shard_online.list done for dset in dev test; do python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \ --num_threads 16 \ --prefix shards \ --shuffle \ ${data}/$dset/wav.scp ${data}/$dset/utt2spk \ ${data}/$dset/shards ${data}/$dset/shard.list done fi if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then checkpoint="${exp_dir}/models/latest_checkpoint.pt" fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # rm -r $exp_dir echo "Start training ..." export OMP_NUM_THREADS=8 num_gpus=$(echo $gpus | awk -F ',' '{print NF}') if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/latest_checkpoint.pt" ]; then checkpoint="${exp_dir}/models/latest_checkpoint.pt" fi torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ wesep/bin/train.py --config $config \ --exp_dir ${exp_dir} \ --gpus $gpus \ --num_avg ${num_avg} \ --data_type "${data_type}" \ --train_data ${data}/train-vox1/${data_type}_online.list \ --train_utt2spk ${data}/train-vox1/utt2spk \ --train_spk2utt ${data}/train-vox1/spk2enroll.json \ --val_data ${data}/dev/${data_type}.list \ --val_spk2utt ${data}/dev/single.wav.scp \ --val_spk1_enroll ${data}/dev/spk1.enroll \ --val_spk2_enroll ${data}/dev/spk2.enroll \ ${checkpoint:+--checkpoint $checkpoint} fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then echo "Do model average ..."
avg_model=$exp_dir/models/avg_best_model.pt python wesep/bin/average_model.py \ --dst_model $avg_model \ --src_path $exp_dir/models \ --num ${num_avg} \ --mode best \ --epochs "138,141" fi if [ -z "${checkpoint}" ] && [ -f "${exp_dir}/models/avg_best_model.pt" ]; then checkpoint="${exp_dir}/models/avg_best_model.pt" fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then python wesep/bin/infer.py --config $config \ --gpus 0 \ --exp_dir ${exp_dir} \ --data_type "${data_type}" \ --test_data ${data}/test/${data_type}.list \ --test_spk1_enroll ${data}/test/spk1.enroll \ --test_spk2_enroll ${data}/test/spk2.enroll \ --test_spk2utt ${data}/test/single.wav.scp \ --save_wav ${save_results} \ ${checkpoint:+--checkpoint $checkpoint} fi #./run.sh --stage 4 --stop-stage 4 --config exp/BSRNN/train_clean_460/multiply_no_spk_transform/config.yaml --exp_dir exp/BSRNN/train_clean_460/multiply_no_spk_transform/ --checkpoint exp/BSRNN/train_clean_460/multiply_no_spk_transform/models/avg_best_model.pt ================================================ FILE: requirements.txt ================================================ fast_bss_eval==0.1.4 fire==0.4.0 joblib==1.1.0 kaldiio==2.18.0 librosa==0.10.1 lmdb==1.3.0 matplotlib==3.5.1 mir_eval==0.7 silero-vad==5.1.2 numpy==1.22.4 pesq==0.0.4 pystoi==0.3.3 PyYAML==6.0 Requests==2.31.0 scipy==1.7.3 soundfile==0.12.1 tableprint==0.9.1 thop==0.1.1.post2209072238 torchnet==0.0.4 tqdm==4.64.0 flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 auraloss torchmetrics==1.2.0 h5py pre-commit==3.5.0 ================================================ FILE: runtime/.gitignore ================================================ fc_base build* ================================================ FILE: runtime/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.14) project(wesep VERSION 0.1) option(CXX11_ABI "whether to use CXX11_ABI libtorch" OFF) set(CMAKE_VERBOSE_MAKEFILE OFF) include(FetchContent) set(FETCHCONTENT_QUIET OFF) get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(FETCHCONTENT_BASE_DIR ${fc_base}) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC") include(libtorch) include(glog) include(gflags) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) # build all libraries add_subdirectory(utils) add_subdirectory(frontend) add_subdirectory(separate) add_subdirectory(bin) ================================================ FILE: runtime/README.md ================================================ # Libtorch backend for wesep * Build. The build requires cmake 3.14 or above, and gcc/g++ 5.4 or above. ``` sh mkdir build && cd build cmake .. cmake --build . ``` * Testing. The RTF (real-time factor) is shown in the console, and the separated outputs are written as wav files to `--output_dir`. Each line of `$wav_scp` has four columns: `utt_id mix_wav spk1_enroll_wav spk2_enroll_wav` (see `separate_main.cc`). ``` sh export GLOG_logtostderr=1 export GLOG_v=2 ./build/bin/separate_main \ --wav_scp $wav_scp \ --model /path/to/model.zip \ --output_dir /output/dir/ ``` ================================================ FILE: runtime/bin/CMakeLists.txt ================================================ add_executable(separate_main separate_main.cc) target_link_libraries(separate_main PUBLIC frontend separate) ================================================ FILE: runtime/bin/separate_main.cc ================================================ // Copyright (c) 2024 wesep team. All rights reserved.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include <algorithm> #include <fstream> #include <iomanip> #include <string> #include <vector> #include "gflags/gflags.h" #include "glog/logging.h" #include "frontend/wav.h" #include "separate/separate_engine.h" #include "utils/timer.h" #include "utils/utils.h" DEFINE_string(wav_path, "", "the path of the mixture audio."); DEFINE_string(spk1_emb, "", "the enrollment wav of spk1."); DEFINE_string(spk2_emb, "", "the enrollment wav of spk2."); DEFINE_string(wav_scp, "", "input wav scp."); DEFINE_string(model, "", "the path of the wesep model."); DEFINE_string(output_dir, "", "output path."); DEFINE_int32(sample_rate, 16000, "sample rate"); DEFINE_int32(feat_dim, 80, "fbank feature dimension."); int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); std::vector<std::vector<std::string>> waves; if (!FLAGS_wav_path.empty() && !FLAGS_spk1_emb.empty() && !FLAGS_spk2_emb.empty()) { waves.push_back(std::vector<std::string>( {"test", FLAGS_wav_path, FLAGS_spk1_emb, FLAGS_spk2_emb})); } else { std::ifstream wav_scp(FLAGS_wav_scp); std::string line; while (getline(wav_scp, line)) { std::vector<std::string> strs; wesep::SplitString(line, &strs); CHECK_EQ(strs.size(), 4); waves.push_back( std::vector<std::string>({strs[0], strs[1], strs[2], strs[3]})); } if (waves.empty()) { LOG(FATAL) << "Please provide a non-empty wav scp."; } } if (FLAGS_output_dir.empty()) { LOG(FATAL) << "Invalid output path."; } int g_total_waves_dur = 0; int g_total_process_time = 0; auto model = std::make_shared<wesep::SeparateEngine>( FLAGS_model, FLAGS_feat_dim, FLAGS_sample_rate); for (auto wav : waves) { // mix wav wenet::WavReader wav_reader(wav[1]); CHECK_EQ(wav_reader.sample_rate(), 16000); int16_t* mix_wav_data = const_cast<int16_t*>(wav_reader.data()); int wave_dur = static_cast<int>(static_cast<float>(wav_reader.num_sample()) / wav_reader.sample_rate() * 1000); // spk1 wenet::WavReader spk1_reader(wav[2]); CHECK_EQ(spk1_reader.sample_rate(), 16000); int16_t* spk1_data = const_cast<int16_t*>(spk1_reader.data()); // spk2 wenet::WavReader spk2_reader(wav[3]); CHECK_EQ(spk2_reader.sample_rate(), 16000); int16_t* spk2_data = const_cast<int16_t*>(spk2_reader.data()); // forward std::vector<std::vector<float>> outputs; int process_time = 0; wenet::Timer timer; model->ForwardFunc( std::vector<int16_t>(mix_wav_data, mix_wav_data + wav_reader.num_sample()), spk1_data, spk2_data, std::min(spk1_reader.num_sample(), spk2_reader.num_sample()), &outputs); process_time = timer.Elapsed(); LOG(INFO) << "process: " << wav[0] << " RTF: " << static_cast<float>(process_time) / wave_dur; // Save the separated audio wenet::WriteWavFile(outputs[0].data(), outputs[0].size(), 16000, FLAGS_output_dir + "/" + wav[0] + "-spk1.wav"); wenet::WriteWavFile(outputs[1].data(), outputs[1].size(), 16000, FLAGS_output_dir + "/" + wav[0] + "-spk2.wav"); g_total_process_time += process_time; g_total_waves_dur += wave_dur; } LOG(INFO) << "Total: processed " << g_total_waves_dur << "ms of audio in " << g_total_process_time << "ms."; LOG(INFO) << "RTF: " << std::setprecision(4) << static_cast<float>(g_total_process_time) / g_total_waves_dur; return 0; } ================================================ FILE: runtime/cmake/gflags.cmake
================================================ FetchContent_Declare(gflags URL https://github.com/gflags/gflags/archive/v2.2.2.zip URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5 ) FetchContent_MakeAvailable(gflags) include_directories(${gflags_BINARY_DIR}/include) ================================================ FILE: runtime/cmake/glog.cmake ================================================ FetchContent_Declare(glog URL https://github.com/google/glog/archive/v0.4.0.zip URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc ) FetchContent_MakeAvailable(glog) include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR}) ================================================ FILE: runtime/cmake/libtorch.cmake ================================================ if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") if(CXX11_ABI) set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcpu.zip") set(URL_HASH "SHA256=d52f63577a07adb0bfd6d77c90f7da21896e94f71eb7dcd55ed7835ccb3b2b59") else() set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip") set(URL_HASH "SHA256=bee1b7be308792aa60fc95a4f5274d9658cb7248002d0e333d49eb81ec88430c") endif() else() message(FATAL_ERROR "Unsupported system '${CMAKE_SYSTEM_NAME}' (expected 'Linux')") endif() FetchContent_Declare(libtorch URL ${LIBTORCH_URL} URL_HASH ${URL_HASH} ) FetchContent_MakeAvailable(libtorch) find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH) include_directories(${TORCH_INCLUDE_DIRS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG") ================================================ FILE: runtime/frontend/CMakeLists.txt ================================================ add_library(frontend STATIC feature_pipeline.cc fft.cc ) target_link_libraries(frontend PUBLIC utils) ================================================ FILE: runtime/frontend/fbank.h ================================================ // Copyright (c) 2017 Personal (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#ifndef FRONTEND_FBANK_H_ #define FRONTEND_FBANK_H_ #include #include #include #include #include #include "frontend/fft.h" #include "glog/logging.h" namespace wenet { // This code is based on kaldi Fbank implentation, please see // https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc class Fbank { public: Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift) : num_bins_(num_bins), sample_rate_(sample_rate), frame_length_(frame_length), frame_shift_(frame_shift), use_log_(true), remove_dc_offset_(true), generator_(0), distribution_(0, 1.0), dither_(0.0) { fft_points_ = UpperPowerOfTwo(frame_length_); // generate bit reversal table and trigonometric function table const int fft_points_4 = fft_points_ / 4; bitrev_.resize(fft_points_); sintbl_.resize(fft_points_ + fft_points_4); make_sintbl(fft_points_, sintbl_.data()); make_bitrev(fft_points_, bitrev_.data()); int num_fft_bins = fft_points_ / 2; float fft_bin_width = static_cast(sample_rate_) / fft_points_; int low_freq = 20, high_freq = sample_rate_ / 2; float mel_low_freq = MelScale(low_freq); float mel_high_freq = MelScale(high_freq); float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1); bins_.resize(num_bins_); center_freqs_.resize(num_bins_); for (int bin = 0; bin < num_bins; ++bin) { float left_mel = mel_low_freq + bin * mel_freq_delta, center_mel = mel_low_freq + (bin + 1) * mel_freq_delta, right_mel = mel_low_freq + (bin + 2) * mel_freq_delta; center_freqs_[bin] = InverseMelScale(center_mel); std::vector this_bin(num_fft_bins); int first_index = -1, last_index = -1; for (int i = 0; i < num_fft_bins; ++i) { float freq = (fft_bin_width * i); // Center frequency of this fft // bin. float mel = MelScale(freq); if (mel > left_mel && mel < right_mel) { float weight; if (mel <= center_mel) weight = (mel - left_mel) / (center_mel - left_mel); else weight = (right_mel - mel) / (right_mel - center_mel); this_bin[i] = weight; if (first_index == -1) first_index = i; last_index = i; } } CHECK(first_index != -1 && last_index >= first_index); bins_[bin].first = first_index; int size = last_index + 1 - first_index; bins_[bin].second.resize(size); for (int i = 0; i < size; ++i) { bins_[bin].second[i] = this_bin[first_index + i]; } } // NOTE(cdliang): add hamming window hamming_window_.resize(frame_length_); double a = M_2PI / (frame_length - 1); for (int i = 0; i < frame_length; i++) { double i_fl = static_cast(i); hamming_window_[i] = 0.54 - 0.46 * cos(a * i_fl); } } void set_use_log(bool use_log) { use_log_ = use_log; } void set_remove_dc_offset(bool remove_dc_offset) { remove_dc_offset_ = remove_dc_offset; } void set_dither(float dither) { dither_ = dither; } int num_bins() const { return num_bins_; } static inline float InverseMelScale(float mel_freq) { return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f); } static inline float MelScale(float freq) { return 1127.0f * logf(1.0f + freq / 700.0f); } static int UpperPowerOfTwo(int n) { return static_cast(pow(2, ceil(log(n) / log(2)))); } // preemphasis void PreEmphasis(float coeff, std::vector* data) const { if (coeff == 0.0) return; for (int i = data->size() - 1; i > 0; i--) (*data)[i] -= coeff * (*data)[i - 1]; (*data)[0] -= coeff * (*data)[0]; } // add hamming window void Hamming(std::vector* data) const { CHECK_GE(data->size(), hamming_window_.size()); for (size_t i = 0; i < hamming_window_.size(); ++i) { (*data)[i] *= hamming_window_[i]; } } // Compute fbank feat, return num frames int Compute(const std::vector& wave, std::vector>* feat) { int 
num_samples = wave.size(); if (num_samples < frame_length_) return 0; int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_); feat->resize(num_frames); std::vector fft_real(fft_points_, 0), fft_img(fft_points_, 0); std::vector power(fft_points_ / 2); for (int i = 0; i < num_frames; ++i) { std::vector data(wave.data() + i * frame_shift_, wave.data() + i * frame_shift_ + frame_length_); // optional add noise if (dither_ != 0.0) { for (size_t j = 0; j < data.size(); ++j) data[j] += dither_ * distribution_(generator_); } // optinal remove dc offset if (remove_dc_offset_) { float mean = 0.0; for (size_t j = 0; j < data.size(); ++j) mean += data[j]; mean /= data.size(); for (size_t j = 0; j < data.size(); ++j) data[j] -= mean; } PreEmphasis(0.97, &data); // Povey(&data); Hamming(&data); // copy data to fft_real memset(fft_img.data(), 0, sizeof(float) * fft_points_); memset(fft_real.data() + frame_length_, 0, sizeof(float) * (fft_points_ - frame_length_)); memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_); fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(), fft_points_); // power for (int j = 0; j < fft_points_ / 2; ++j) { power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; } (*feat)[i].resize(num_bins_); // cepstral coefficients, triangle filter array for (int j = 0; j < num_bins_; ++j) { float mel_energy = 0.0; int s = bins_[j].first; for (size_t k = 0; k < bins_[j].second.size(); ++k) { mel_energy += bins_[j].second[k] * power[s + k]; } // optional use log if (use_log_) { if (mel_energy < std::numeric_limits::epsilon()) mel_energy = std::numeric_limits::epsilon(); mel_energy = logf(mel_energy); } (*feat)[i][j] = mel_energy; // printf("%f ", mel_energy); } // printf("\n"); } return num_frames; } private: int num_bins_; int sample_rate_; int frame_length_, frame_shift_; int fft_points_; bool use_log_; bool remove_dc_offset_; std::vector center_freqs_; std::vector>> bins_; std::vector hamming_window_; std::default_random_engine generator_; std::normal_distribution distribution_; float dither_; // bit reversal table std::vector bitrev_; // trigonometric function table std::vector sintbl_; }; } // namespace wenet #endif // FRONTEND_FBANK_H_ ================================================ FILE: runtime/frontend/feature_pipeline.cc ================================================ // Copyright (c) 2017 Personal (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
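As a compact reference for the Fbank::Compute() routine shown above (frontend/fbank.h), here is an illustrative NumPy sketch of the same per-frame pipeline: DC removal, pre-emphasis, Hamming window, power spectrum, triangular mel weighting, and a log floor (dither omitted, as it defaults to 0.0). This is a simplified restatement for readability, not code from the repository; `mel_filters` stands in for the triangular weights built in the Fbank constructor.

``` python
import numpy as np

def fbank_frames(wave, mel_filters, frame_len=400, frame_shift=160, n_fft=512):
    """Illustrative NumPy mirror of Fbank::Compute (16 kHz, 25 ms / 10 ms)."""
    eps = np.finfo(np.float32).eps
    num_frames = 1 + (len(wave) - frame_len) // frame_shift
    window = 0.54 - 0.46 * np.cos(2 * np.pi * np.arange(frame_len) / (frame_len - 1))
    feats = []
    for i in range(num_frames):
        frame = wave[i * frame_shift:i * frame_shift + frame_len].astype(np.float32)
        frame = frame - frame.mean()                # remove DC offset
        frame[1:] -= 0.97 * frame[:-1]              # pre-emphasis (uses original samples)
        frame[0] -= 0.97 * frame[0]
        frame *= window                             # Hamming window
        power = np.abs(np.fft.rfft(frame, n_fft)[:n_fft // 2]) ** 2
        mel = np.maximum(mel_filters @ power, eps)  # (num_bins,) triangular weighting
        feats.append(np.log(mel))
    return np.stack(feats)
```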
#include "frontend/feature_pipeline.h" #include #include namespace wenet { FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config) : config_(config), feature_dim_(config.num_bins), fbank_(config.num_bins, config.sample_rate, config.frame_length, config.frame_shift), num_frames_(0), input_finished_(false) {} void FeaturePipeline::AcceptWaveform(const std::vector& wav) { std::vector> feats; std::vector waves; waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end()); waves.insert(waves.end(), wav.begin(), wav.end()); int num_frames = fbank_.Compute(waves, &feats); for (size_t i = 0; i < feats.size(); ++i) { feature_queue_.Push(std::move(feats[i])); } num_frames_ += num_frames; int left_samples = waves.size() - config_.frame_shift * num_frames; remained_wav_.resize(left_samples); std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(), remained_wav_.begin()); // We are still adding wave, notify input is not finished finish_condition_.notify_one(); } void FeaturePipeline::AcceptWaveform(const std::vector& wav) { std::vector float_wav(wav.size()); for (size_t i = 0; i < wav.size(); i++) { float_wav[i] = static_cast(wav[i]); } this->AcceptWaveform(float_wav); } void FeaturePipeline::set_input_finished() { CHECK(!input_finished_); { std::lock_guard lock(mutex_); input_finished_ = true; } finish_condition_.notify_one(); } bool FeaturePipeline::ReadOne(std::vector* feat) { if (!feature_queue_.Empty()) { *feat = std::move(feature_queue_.Pop()); return true; } else { std::unique_lock lock(mutex_); while (!input_finished_) { // This will release the lock and wait for notify_one() // from AcceptWaveform() or set_input_finished() finish_condition_.wait(lock); if (!feature_queue_.Empty()) { *feat = std::move(feature_queue_.Pop()); return true; } } CHECK(input_finished_); // Double check queue.empty, see issue#893 for detailed discussions. if (!feature_queue_.Empty()) { *feat = std::move(feature_queue_.Pop()); return true; } else { return false; } } } bool FeaturePipeline::Read(int num_frames, std::vector>* feats) { feats->clear(); std::vector feat; while (feats->size() < num_frames) { if (ReadOne(&feat)) { feats->push_back(std::move(feat)); } else { return false; } } return true; } void FeaturePipeline::Reset() { input_finished_ = false; num_frames_ = 0; remained_wav_.clear(); feature_queue_.Clear(); } } // namespace wenet ================================================ FILE: runtime/frontend/feature_pipeline.h ================================================ // Copyright (c) 2017 Personal (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef FRONTEND_FEATURE_PIPELINE_H_ #define FRONTEND_FEATURE_PIPELINE_H_ #include #include #include #include #include "frontend/fbank.h" #include "glog/logging.h" #include "utils/blocking_queue.h" namespace wenet { struct FeaturePipelineConfig { int num_bins; int sample_rate; int frame_length; int frame_shift; FeaturePipelineConfig(int num_bins, int sample_rate) : num_bins(num_bins), // 80 dim fbank sample_rate(sample_rate) { // 16k sample rate frame_length = sample_rate / 1000 * 25; // frame length 25ms frame_shift = sample_rate / 1000 * 10; // frame shift 10ms } void Info() const { LOG(INFO) << "feature pipeline config" << " num_bins " << num_bins << " frame_length " << frame_length << " frame_shift " << frame_shift; } }; // Typically, FeaturePipeline is used in two threads: one thread A calls // AcceptWaveform() to add raw wav data and set_input_finished() to notice // the end of input wav, another thread B (decoder thread) calls Read() to // consume features.So a BlockingQueue is used to make this class thread safe. // The Read() is designed as a blocking method when there is no feature // in feature_queue_ and the input is not finished. class FeaturePipeline { public: explicit FeaturePipeline(const FeaturePipelineConfig& config); // The feature extraction is done in AcceptWaveform(). void AcceptWaveform(const std::vector& wav); void AcceptWaveform(const std::vector& wav); // Current extracted frames number. int num_frames() const { return num_frames_; } int feature_dim() const { return feature_dim_; } const FeaturePipelineConfig& config() const { return config_; } // The caller should call this method when speech input is end. // Never call AcceptWaveform() after calling set_input_finished() ! void set_input_finished(); bool input_finished() const { return input_finished_; } // Return False if input is finished and no feature could be read. // Return True if a feature is read. // This function is a blocking method. It will block the thread when // there is no feature in feature_queue_ and the input is not finished. bool ReadOne(std::vector* feat); // Read #num_frames frame features. // Return False if less then #num_frames features are read and the // input is finished. // Return True if #num_frames features are read. // This function is a blocking method when there is no feature // in feature_queue_ and the input is not finished. bool Read(int num_frames, std::vector>* feats); void Reset(); bool IsLastFrame(int frame) const { return input_finished_ && (frame == num_frames_ - 1); } int NumQueuedFrames() const { return feature_queue_.Size(); } private: const FeaturePipelineConfig& config_; int feature_dim_; Fbank fbank_; BlockingQueue> feature_queue_; int num_frames_; bool input_finished_; // The feature extraction is done in AcceptWaveform(). // This wavefrom sample points are consumed by frame size. // The residual wavefrom sample points after framing are // kept to be used in next AcceptWaveform() calling. std::vector remained_wav_; // Used to block the Read when there is no feature in feature_queue_ // and the input is not finished. 
mutable std::mutex mutex_; std::condition_variable finish_condition_; }; } // namespace wenet #endif // FRONTEND_FEATURE_PIPELINE_H_ ================================================ FILE: runtime/frontend/fft.cc ================================================ // Copyright (c) 2016 HR #include #include #include #include "frontend/fft.h" namespace wenet { void make_sintbl(int n, float* sintbl) { int i, n2, n4, n8; float c, s, dc, ds, t; n2 = n / 2; n4 = n / 4; n8 = n / 8; t = sin(M_PI / n); dc = 2 * t * t; ds = sqrt(dc * (2 - dc)); t = 2 * dc; c = sintbl[n4] = 1; s = sintbl[0] = 0; for (i = 1; i < n8; ++i) { c -= dc; dc += t * c; s += ds; ds -= t * s; sintbl[i] = s; sintbl[n4 - i] = c; } if (n8 != 0) sintbl[n8] = sqrt(0.5); for (i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i]; for (i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i]; } void make_bitrev(int n, int* bitrev) { int i, j, k, n2; n2 = n / 2; i = j = 0; for (;;) { bitrev[i] = j; if (++i >= n) break; k = n2; while (k <= j) { j -= k; k /= 2; } j += k; } } // bitrev: bit reversal table // sintbl: trigonometric function table // x:real part // y:image part // n: fft length int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) { int i, j, k, ik, h, d, k2, n4, inverse; float t, s, c, dx, dy; /* preparation */ if (n < 0) { n = -n; inverse = 1; /* inverse transform */ } else { inverse = 0; } n4 = n / 4; if (n == 0) { return 0; } /* bit reversal */ for (i = 0; i < n; ++i) { j = bitrev[i]; if (i < j) { t = x[i]; x[i] = x[j]; x[j] = t; t = y[i]; y[i] = y[j]; y[j] = t; } } /* transformation */ for (k = 1; k < n; k = k2) { h = 0; k2 = k + k; d = n / k2; for (j = 0; j < k; ++j) { c = sintbl[h + n4]; if (inverse) s = -sintbl[h]; else s = sintbl[h]; for (i = j; i < n; i += k2) { ik = i + k; dx = s * y[ik] + c * x[ik]; dy = c * y[ik] - s * x[ik]; x[ik] = x[i] - dx; x[i] += dx; y[ik] = y[i] - dy; y[i] += dy; } h += d; } } if (inverse) { /* divide by n in case of the inverse transformation */ for (i = 0; i < n; ++i) { x[i] /= n; y[i] /= n; } } return 0; /* finished successfully */ } } // namespace wenet ================================================ FILE: runtime/frontend/fft.h ================================================ // Copyright (c) 2016 HR #ifndef FRONTEND_FFT_H_ #define FRONTEND_FFT_H_ #ifndef M_PI #define M_PI 3.1415926535897932384626433832795 #endif #ifndef M_2PI #define M_2PI 6.283185307179586476925286766559005 #endif namespace wenet { // Fast Fourier Transform void make_sintbl(int n, float* sintbl); void make_bitrev(int n, int* bitrev); int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n); } // namespace wenet #endif // FRONTEND_FFT_H_ ================================================ FILE: runtime/frontend/wav.h ================================================ // Copyright (c) 2016 Personal (Binbin Zhang) // Created on 2016-08-15 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
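For reference alongside the WavReader below: this is roughly how the canonical 44-byte RIFF/WAVE header it assumes can be unpacked in Python (field names mirror the WavHeader struct; real files may carry extra subchunks between "fmt " and "data", which WavReader::Open skips). This sketch is illustrative and not part of the repository.

``` python
import struct

def read_wav_header(path):
    """Unpack the canonical 44-byte PCM WAV header (no extra subchunks)."""
    with open(path, "rb") as f:
        raw = f.read(44)
    (riff, size, wave, fmt, fmt_size, audio_format, channels, sample_rate,
     bytes_per_second, block_size, bits, data, data_size) = struct.unpack(
        "<4sI4s4sIHHIIHH4sI", raw)
    assert riff == b"RIFF" and wave == b"WAVE" and data == b"data"
    return dict(channels=channels, sample_rate=sample_rate,
                bits_per_sample=bits, data_size=data_size)
```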
#ifndef FRONTEND_WAV_H_ #define FRONTEND_WAV_H_ #include #include #include #include #include #include #include #include #include "gflags/gflags.h" #include "glog/logging.h" DEFINE_int32(pcm_sample_rate, 16000, "pcm data sample rate"); namespace wenet { class AudioReader { public: AudioReader() {} explicit AudioReader(const std::string& filename) {} virtual ~AudioReader() {} virtual int num_channel() const = 0; virtual int sample_rate() const = 0; virtual int bits_per_sample() const = 0; virtual int num_sample() const = 0; virtual const int16_t* data() const = 0; }; struct WavHeader { char riff[4]; // "riff" unsigned int size; char wav[4]; // "WAVE" char fmt[4]; // "fmt " unsigned int fmt_size; uint16_t format; uint16_t channels; unsigned int sample_rate; unsigned int bytes_per_second; uint16_t block_size; uint16_t bit; char data[4]; // "data" unsigned int data_size; }; class WavReader : public AudioReader { public: WavReader() {} explicit WavReader(const std::string& filename) { Open(filename); } bool Open(const std::string& filename) { FILE* fp = fopen(filename.c_str(), "rb"); if (NULL == fp) { LOG(WARNING) << "Error in read " << filename; return false; } WavHeader header; fread(&header, 1, sizeof(header), fp); if (header.fmt_size < 16) { fprintf(stderr, "WaveData: expect PCM format data " "to have fmt chunk of at least size 16.\n"); return false; } else if (header.fmt_size > 16) { int offset = 44 - 8 + header.fmt_size - 16; fseek(fp, offset, SEEK_SET); fread(header.data, 8, sizeof(char), fp); } // check "riff" "WAVE" "fmt " "data" // Skip any subchunks between "fmt" and "data". Usually there will // be a single "fact" subchunk, but on Windows there can also be a // "list" subchunk. while (0 != strncmp(header.data, "data", 4)) { // We will just ignore the data in these chunks. 
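// Each skipped subchunk is laid out as a 4-byte ID, a 4-byte size and then
// `size` bytes of payload, so seeking forward by data_size and re-reading
// 8 bytes lands on the next subchunk header.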
fseek(fp, header.data_size, SEEK_CUR); // read next subchunk fread(header.data, 8, sizeof(char), fp); } num_channel_ = header.channels; sample_rate_ = header.sample_rate; bits_per_sample_ = header.bit; int num_data = header.data_size / (bits_per_sample_ / 8); data_.resize(num_data); int num_read = fread(&data_[0], 1, header.data_size, fp); if (num_read < header.data_size) { // If the header size is wrong, adjust header.data_size = num_read; num_data = header.data_size / (bits_per_sample_ / 8); data_.resize(num_data); } num_sample_ = num_data / num_channel_; fclose(fp); return true; } int num_channel() const { return num_channel_; } int sample_rate() const { return sample_rate_; } int bits_per_sample() const { return bits_per_sample_; } int num_sample() const { return num_sample_; } const int16_t* data() const { return data_.data(); } private: int num_channel_; int sample_rate_; int bits_per_sample_; int num_sample_; // sample points per channel std::vector data_; }; class WavWriter { public: WavWriter(const float* data, int num_sample, int num_channel, int sample_rate, int bits_per_sample) : data_(data), num_sample_(num_sample), num_channel_(num_channel), sample_rate_(sample_rate), bits_per_sample_(bits_per_sample) {} void Write(const std::string& filename) { FILE* fp = fopen(filename.c_str(), "w"); // init char 'riff' 'WAVE' 'fmt ' 'data' WavHeader header; char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; memcpy(&header, wav_header, sizeof(header)); header.channels = num_channel_; header.bit = bits_per_sample_; header.sample_rate = sample_rate_; header.data_size = num_sample_ * num_channel_ * (bits_per_sample_ / 8); header.size = sizeof(header) - 8 + header.data_size; header.bytes_per_second = sample_rate_ * num_channel_ * (bits_per_sample_ / 8); header.block_size = num_channel_ * (bits_per_sample_ / 8); fwrite(&header, 1, sizeof(header), fp); for (int i = 0; i < num_sample_; ++i) { for (int j = 0; j < num_channel_; ++j) { switch (bits_per_sample_) { case 8: { char sample = static_cast(data_[i * num_channel_ + j]); fwrite(&sample, 1, sizeof(sample), fp); break; } case 16: { int16_t sample = static_cast(data_[i * num_channel_ + j]); fwrite(&sample, 1, sizeof(sample), fp); break; } case 32: { int sample = static_cast(data_[i * num_channel_ + j]); fwrite(&sample, 1, sizeof(sample), fp); break; } } } } fclose(fp); } private: const float* data_; int num_sample_; // total float points in data_ int num_channel_; int sample_rate_; int bits_per_sample_; }; class PcmReader : public AudioReader { public: PcmReader() {} explicit PcmReader(const std::string& filename) { Open(filename); } bool Open(const std::string& filename) { FILE* fp = fopen(filename.c_str(), "rb"); if (NULL == fp) { LOG(WARNING) << "Error in read " << filename; return false; } num_channel_ = 1; sample_rate_ = FLAGS_pcm_sample_rate; bits_per_sample_ = 16; fseek(fp, 0, SEEK_END); int data_size = ftell(fp); fseek(fp, 0, SEEK_SET); num_sample_ = data_size / sizeof(int16_t); data_.resize(num_sample_); fread(&data_[0], data_size, 1, fp); fclose(fp); return true; } int num_channel() const { return num_channel_; } int sample_rate() const { return sample_rate_; } int bits_per_sample() const { return bits_per_sample_; } int num_sample() const { return num_sample_; } const int16_t* data() const { return 
data_.data(); } private: int num_channel_; int sample_rate_; int bits_per_sample_; int num_sample_; // sample points per channel std::vector data_; }; std::shared_ptr ReadAudioFile(const std::string& filename) { size_t pos = filename.rfind('.'); std::string suffix = filename.substr(pos); if (suffix == ".wav" || suffix == ".WAV") { return std::make_shared(filename); } else { return std::make_shared(filename); } } void WriteWavFile(const float* data, int data_size, int sample_rate, const std::string& wav_path) { std::vector tmp_wav(data, data + data_size); for (int i = 0; i < tmp_wav.size(); i++) { tmp_wav[i] *= (1 << 15); } WavWriter wav_write(tmp_wav.data(), tmp_wav.size(), 1, sample_rate, 16); wav_write.Write(wav_path); } } // namespace wenet #endif // FRONTEND_WAV_H_ ================================================ FILE: runtime/separate/CMakeLists.txt ================================================ add_library(separate STATIC separate_engine.cc) target_link_libraries(separate PUBLIC frontend ${TORCH_LIBRARIES}) ================================================ FILE: runtime/separate/separate_engine.cc ================================================ // Copyright (c) 2024 wesep team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
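Before the engine implementation below, it may help to see the model I/O contract it relies on, sketched in Python. This is an illustrative reconstruction inferred from ForwardFunc(): a TorchScript model (e.g. one exported via wesep/bin/export_jit.py; the path "model.zip" is an assumption) is called with a (2, T) mixture batch and a (2, num_frames, 80) batch of mean-normalized fbank features from the two enrollment wavs, and the first element of the returned tuple holds the two separated waveforms. The fbank recipe only approximates the C++ frontend.

``` python
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

def enroll_fbank(path, num_mel_bins=80):
    wav, sr = torchaudio.load(path)                      # (1, T) float in [-1, 1]
    feats = kaldi.fbank(wav * (1 << 15), num_mel_bins=num_mel_bins,
                        frame_length=25, frame_shift=10, dither=0.0,
                        sample_frequency=sr, window_type="hamming",
                        use_energy=False)
    return feats - feats.mean(dim=0)                     # CMN, as in ApplyMean()

model = torch.jit.load("model.zip").eval()               # assumed export path
mix, _ = torchaudio.load("mix.wav")                      # (1, T) mixture, 16 kHz expected
mix = mix.repeat(2, 1)                                   # one row per target speaker
f1, f2 = enroll_fbank("spk1.wav"), enroll_fbank("spk2.wav")
n = min(f1.shape[0], f2.shape[0])                        # C++ uses the shorter enrollment
emb_feats = torch.stack([f1[:n], f2[:n]])                # (2, num_frames, 80)
with torch.no_grad():
    wavs = model(mix, emb_feats)[0]                      # (2, T) separated estimates
torchaudio.save("spk1_est.wav", wavs[0:1], 16000)
torchaudio.save("spk2_est.wav", wavs[1:2], 16000)
```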
#include "separate/separate_engine.h" #include #include #include #include #include #include #include "gflags/gflags.h" #include "glog/logging.h" #include "torch/script.h" #include "torch/torch.h" namespace wesep { void SeparateEngine::InitEngineThreads(int num_threads) { // for multi-thread performance at::set_num_threads(num_threads); VLOG(1) << "Num intra-op threads: " << at::get_num_threads(); } SeparateEngine::SeparateEngine(const std::string& model_path, const int feat_dim, const int sample_rate) { sample_rate_ = sample_rate; feat_dim_ = feat_dim; feature_config_ = std::make_shared(feat_dim, sample_rate); feature_pipeline_ = std::make_shared(*feature_config_); feature_pipeline_->Reset(); InitEngineThreads(1); torch::jit::script::Module model = torch::jit::load(model_path); model_ = std::make_shared(std::move(model)); model_->eval(); } void SeparateEngine::ExtractFeature(const int16_t* data, int data_size, std::vector>* feat) { feature_pipeline_->AcceptWaveform( std::vector(data, data + data_size)); feature_pipeline_->set_input_finished(); feature_pipeline_->Read(feature_pipeline_->num_frames(), feat); feature_pipeline_->Reset(); this->ApplyMean(feat); } void SeparateEngine::ApplyMean(std::vector>* feat) { std::vector mean(feat_dim_, 0); for (auto& i : *feat) { std::transform(i.begin(), i.end(), mean.begin(), mean.begin(), std::plus<>{}); } std::transform(mean.begin(), mean.end(), mean.begin(), [&](const float d) { return d / feat->size(); }); for (auto& i : *feat) { std::transform(i.begin(), i.end(), mean.begin(), i.begin(), std::minus<>{}); } } void SeparateEngine::ForwardFunc(const std::vector& mix_wav, const int16_t* spk1_emb, const int16_t* spk2_emb, int data_size, std::vector>* output) { // pre-process std::vector input_wav(mix_wav.size()); for (int i = 0; i < mix_wav.size(); i++) { input_wav[i] = static_cast(mix_wav[i]) / (1 << 15); } std::vector> spk1_emb_feat; this->ExtractFeature(spk1_emb, data_size, &spk1_emb_feat); std::vector> spk2_emb_feat; this->ExtractFeature(spk2_emb, data_size, &spk2_emb_feat); // torch mix_wav torch::Tensor torch_wav = torch::zeros({2, mix_wav.size()}, torch::kFloat32); for (size_t i = 0; i < 2; i++) { torch::Tensor row = torch::from_blob(input_wav.data(), {input_wav.size()}, torch::kFloat32) .clone(); torch_wav[i] = std::move(row); } // torch spk_emb_feat torch::Tensor torch_spk_emb_feat = torch::zeros({2, spk1_emb_feat.size(), feat_dim_}, torch::kFloat32); for (size_t i = 0; i < spk1_emb_feat.size(); i++) { torch::Tensor row1 = torch::from_blob(spk1_emb_feat[i].data(), {feat_dim_}, torch::kFloat32); torch_spk_emb_feat[0][i] = std::move(row1); torch::Tensor row2 = torch::from_blob(spk2_emb_feat[i].data(), {feat_dim_}, torch::kFloat32); torch_spk_emb_feat[1][i] = std::move(row2); } // forward torch::NoGradGuard no_grad; auto outputs = model_->forward({torch_wav, torch_spk_emb_feat}).toTuple()->elements(); torch::Tensor wav_out = outputs[0].toTensor(); auto accessor = wav_out.accessor(); output->resize(2, std::vector(wav_out.size(1), 0.0)); for (int i = 0; i < wav_out.size(1); i++) { (*output)[0][i] = accessor[0][i]; (*output)[1][i] = accessor[1][i]; } } } // namespace wesep ================================================ FILE: runtime/separate/separate_engine.h ================================================ // Copyright (c) 2024 wesep team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef SEPARATE_SEPARATE_ENGINE_H_ #define SEPARATE_SEPARATE_ENGINE_H_ #include #include #include #include "torch/script.h" #include "torch/torch.h" #include "frontend/feature_pipeline.h" namespace wesep { class SeparateEngine { public: explicit SeparateEngine(const std::string& model_path, const int feat_dim, const int sample_rate); void InitEngineThreads(int num_threads = 1); void ForwardFunc(const std::vector& mix_wav, const int16_t* spk1_emb, const int16_t* spk2_emb, int data_size, std::vector>* output); void ExtractFeature(const int16_t* data, int data_size, std::vector>* feat); void ApplyMean(std::vector>* feat); private: std::shared_ptr model_ = nullptr; std::shared_ptr feature_config_ = nullptr; std::shared_ptr feature_pipeline_ = nullptr; int sample_rate_ = 16000; int feat_dim_ = 80; }; } // namespace wesep #endif // SEPARATE_SEPARATE_ENGINE_H_ ================================================ FILE: runtime/utils/CMakeLists.txt ================================================ add_library(utils STATIC utils.cc ) target_link_libraries(utils PUBLIC glog gflags frontend) ================================================ FILE: runtime/utils/blocking_queue.h ================================================ // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef UTILS_BLOCKING_QUEUE_H_ #define UTILS_BLOCKING_QUEUE_H_ #include #include #include #include #include namespace wenet { #define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \ Type(const Type&) = delete; \ Type& operator=(const Type&) = delete; template class BlockingQueue { public: explicit BlockingQueue(size_t capacity = std::numeric_limits::max()) : capacity_(capacity) {} void Push(const T& value) { { std::unique_lock lock(mutex_); while (queue_.size() >= capacity_) { not_full_condition_.wait(lock); } queue_.push(value); } not_empty_condition_.notify_one(); } void Push(T&& value) { { std::unique_lock lock(mutex_); while (queue_.size() >= capacity_) { not_full_condition_.wait(lock); } queue_.push(std::move(value)); } not_empty_condition_.notify_one(); } T Pop() { std::unique_lock lock(mutex_); while (queue_.empty()) { not_empty_condition_.wait(lock); } T t(std::move(queue_.front())); queue_.pop(); not_full_condition_.notify_one(); return t; } bool Empty() const { std::lock_guard lock(mutex_); return queue_.empty(); } size_t Size() const { std::lock_guard lock(mutex_); return queue_.size(); } void Clear() { while (!Empty()) { Pop(); } } private: size_t capacity_; mutable std::mutex mutex_; std::condition_variable not_full_condition_; std::condition_variable not_empty_condition_; std::queue queue_; public: WENET_DISALLOW_COPY_AND_ASSIGN(BlockingQueue); }; } // namespace wenet #endif // UTILS_BLOCKING_QUEUE_H_ ================================================ FILE: runtime/utils/timer.h ================================================ // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef UTILS_TIMER_H_ #define UTILS_TIMER_H_ #include namespace wenet { class Timer { public: Timer() : time_start_(std::chrono::steady_clock::now()) {} void Reset() { time_start_ = std::chrono::steady_clock::now(); } // return int in milliseconds int Elapsed() const { auto time_now = std::chrono::steady_clock::now(); return std::chrono::duration_cast(time_now - time_start_) .count(); } private: std::chrono::time_point time_start_; }; } // namespace wenet #endif // UTILS_TIMER_H_ ================================================ FILE: runtime/utils/utils.cc ================================================ // Copyright (c) 2023 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include #include #include #include "glog/logging.h" #include "utils/utils.h" namespace wesep { std::string Ltrim(const std::string& str) { size_t start = str.find_first_not_of(WHITESPACE); return (start == std::string::npos) ? "" : str.substr(start); } std::string Rtrim(const std::string& str) { size_t end = str.find_last_not_of(WHITESPACE); return (end == std::string::npos) ? "" : str.substr(0, end + 1); } std::string Trim(const std::string& str) { return Rtrim(Ltrim(str)); } void SplitString(const std::string& str, std::vector* strs) { SplitStringToVector(Trim(str), " \t", true, strs); } void SplitStringToVector(const std::string& full, const char* delim, bool omit_empty_strings, std::vector* out) { size_t start = 0, found = 0, end = full.size(); out->clear(); while (found != std::string::npos) { found = full.find_first_of(delim, start); // start != end condition is for when the delimiter is at the end if (!omit_empty_strings || (found != start && start != end)) out->push_back(full.substr(start, found - start)); start = found + 1; } } #ifdef _MSC_VER std::wstring ToWString(const std::string& str) { unsigned len = str.size() * 2; setlocale(LC_CTYPE, ""); wchar_t* p = new wchar_t[len]; mbstowcs(p, str.c_str(), len); std::wstring wstr(p); delete[] p; return wstr; } #endif } // namespace wesep ================================================ FILE: runtime/utils/utils.h ================================================ // Copyright (c) 2023 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef UTILS_UTILS_H_ #define UTILS_UTILS_H_ #include #include #include namespace wesep { const char WHITESPACE[] = " \n\r\t\f\v"; // Split the string with space or tab. void SplitString(const std::string& str, std::vector* strs); void SplitStringToVector(const std::string& full, const char* delim, bool omit_empty_strings, std::vector* out); #ifdef _MSC_VER std::wstring ToWString(const std::string& str); #endif } // namespace wesep #endif // UTILS_UTILS_H_ ================================================ FILE: setup.py ================================================ from setuptools import setup, find_packages requirements = [ "tqdm", "kaldiio", "torch>=1.12.0", "torchaudio>=0.12.0", "silero-vad", ] setup( name="wesep", install_requires=requirements, packages=find_packages(), entry_points={ "console_scripts": [ "wesep = wesep.cli.extractor:main", ], }, ) ================================================ FILE: tools/extract_embed_depreciated.py ================================================ # Copyright (c) 2022, Shuai Wang (wsstriving@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os import kaldiio import onnxruntime as ort import torch import torchaudio import torchaudio.compliance.kaldi as kaldi from tqdm import tqdm def get_args(): parser = argparse.ArgumentParser(description="infer example using onnx") parser.add_argument("--onnx_path", required=True, help="onnx path") parser.add_argument("--wav_scp", required=True, help="wav path") parser.add_argument("--out_path", required=True, help="output path of the embeddings") args = parser.parse_args() return args def compute_fbank(wav_path, num_mel_bins=80, frame_length=25, frame_shift=10, dither=0.0): """Extract fbank, simlilar to the one in wespeaker.dataset.processor, While integrating the wave reading and CMN. """ waveform, sample_rate = torchaudio.load(wav_path) waveform = waveform * (1 << 15) mat = kaldi.fbank( waveform, num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither, sample_frequency=sample_rate, window_type="hamming", use_energy=False, ) # CMN, without CVN mat = mat - torch.mean(mat, dim=0) return mat def main(): args = get_args() so = ort.SessionOptions() so.inter_op_num_threads = 1 so.intra_op_num_threads = 1 session = ort.InferenceSession(args.onnx_path, sess_options=so) embed_ark = os.path.join(args.out_path, "embed.ark") embed_scp = os.path.join(args.out_path, "embed.scp") with kaldiio.WriteHelper("ark,scp:" + embed_ark + "," + embed_scp) as writer: with open(args.wav_scp, "r") as read_scp: for line in tqdm(read_scp): tokens = line.strip().split(" ") name, wav_path = tokens[0], tokens[1] feats = compute_fbank(wav_path) feats = feats.unsqueeze(0).numpy() # add batch dimension embed = session.run(output_names=["embs"], input_feed={"feats": feats}) writer(name, embed[0]) if __name__ == "__main__": main() ================================================ FILE: tools/make_lmdb.py ================================================ # Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
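tools/make_lmdb.py below packs every wav listed in the input scp into a single LMDB: the raw file bytes are stored under the utterance key, and the full key list is stored under the `__keys__` entry. A matching reader might look like this (illustrative sketch, not part of the repository; soundfile is used here only to decode the stored bytes):

``` python
import io
import pickle

import lmdb
import soundfile as sf

def read_lmdb_wav(lmdb_dir, key=None):
    """Read one utterance back from an LMDB written by tools/make_lmdb.py."""
    env = lmdb.open(lmdb_dir, readonly=True, lock=False)
    with env.begin() as txn:
        keys = pickle.loads(txn.get(b"__keys__"))   # list of utterance keys
        key = key or keys[0]
        wav_bytes = txn.get(key.encode())           # raw wav file contents
    data, sample_rate = sf.read(io.BytesIO(wav_bytes))
    return key, data, sample_rate
```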
import argparse import math import pickle import lmdb from tqdm import tqdm def get_args(): parser = argparse.ArgumentParser(description="") parser.add_argument("in_scp_file", help="input scp file") parser.add_argument("out_lmdb", help="output lmdb") args = parser.parse_args() return args def main(): args = get_args() db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4))) # 1TB # txn is for Transaciton txn = db.begin(write=True) keys = [] with open(args.in_scp_file, "r", encoding="utf8") as fin: lines = fin.readlines() for i, line in enumerate(tqdm(lines)): arr = line.strip().split() assert len(arr) == 2 key, wav = arr[0], arr[1] keys.append(key) with open(wav, "rb") as fin: data = fin.read() txn.put(key.encode(), data) # Write flush to disk if i % 100 == 0: txn.commit() txn = db.begin(write=True) txn.commit() with db.begin(write=True) as txn: txn.put(b"__keys__", pickle.dumps(keys)) db.sync() db.close() if __name__ == "__main__": main() ================================================ FILE: tools/make_shard_list_premix.py ================================================ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang # 2023 SRIBD Shuai Wang ) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import io import logging import multiprocessing import os import random import tarfile import time import sys AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} def write_tar_file(data_list, tar_file, index=0, total=1): logging.info('Processing {} {}/{}'.format(tar_file, index, total)) read_time = 0.0 write_time = 0.0 with tarfile.open(tar_file, "w") as tar: for item in data_list: assert len( item) == 3, 'item should have 3 elements: Key, Speaker, Wav' key, spks, wavs = item spk_idx = 1 for spk in spks: assert isinstance(spk, str) spk_file = key + '.spk' + str(spk_idx) spk = spk.encode('utf8') spk_data = io.BytesIO(spk) spk_info = tarfile.TarInfo(spk_file) spk_info.size = len(spk) tar.addfile(spk_info, spk_data) spk_idx = spk_idx + 1 spk_idx = 0 for wav in wavs: suffix = wav.split('.')[-1] assert suffix in AUDIO_FORMAT_SETS ts = time.time() try: with open(wav, 'rb') as fin: data = fin.read() except FileNotFoundError as e: print(e) sys.exit() read_time += (time.time() - ts) ts = time.time() if spk_idx > 0: wav_file = key + '_spk' + str(spk_idx) + '.' + suffix else: wav_file = key + '.' 
+ suffix wav_data = io.BytesIO(data) wav_info = tarfile.TarInfo(wav_file) wav_info.size = len(data) tar.addfile(wav_info, wav_data) write_time += (time.time() - ts) spk_idx = spk_idx + 1 logging.info('read {} write {}'.format(read_time, write_time)) def get_args(): parser = argparse.ArgumentParser(description='') parser.add_argument('--num_utts_per_shard', type=int, default=1000, help='num utts per shard') parser.add_argument('--num_threads', type=int, default=1, help='num threads for make shards') parser.add_argument('--prefix', default='shards', help='prefix of shards tar file') parser.add_argument('--seed', type=int, default=42, help='random seed') parser.add_argument('--shuffle', action='store_true', help='whether to shuffle data') parser.add_argument('wav_file', help='wav file') parser.add_argument('utt2spk_file', help='utt2spk file') parser.add_argument('shards_dir', help='output shards dir') parser.add_argument('shards_list', help='output shards list file') args = parser.parse_args() return args def main(): args = get_args() random.seed(args.seed) logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') wav_table = {} with open(args.wav_file, 'r', encoding='utf8') as fin: for line in fin: arr = line.strip().split() key = arr[0] # key = os.path.splitext(arr[0])[0] wav_table[key] = [arr[i + 1] for i in range(len(arr) - 1)] data = [] with open(args.utt2spk_file, 'r', encoding='utf8') as fin: for line in fin: arr = line.strip().split() key = arr[0] # key = os.path.splitext(arr[0])[0] spks = [arr[i + 1] for i in range(len(arr) - 1)] assert key in wav_table wavs = wav_table[key] data.append((key, spks, wavs)) if args.shuffle: random.shuffle(data) num = args.num_utts_per_shard chunks = [data[i:i + num] for i in range(0, len(data), num)] os.makedirs(args.shards_dir, exist_ok=True) # Using thread pool to speedup pool = multiprocessing.Pool(processes=args.num_threads) shards_list = [] num_chunks = len(chunks) for i, chunk in enumerate(chunks): tar_file = os.path.join(args.shards_dir, '{}_{:09d}.tar'.format(args.prefix, i)) shards_list.append(tar_file) pool.apply_async(write_tar_file, (chunk, tar_file, i, num_chunks)) pool.close() pool.join() with open(args.shards_list, 'w', encoding='utf8') as fout: for name in shards_list: fout.write(name + '\n') if __name__ == '__main__': main() ================================================ FILE: tools/make_shard_online.py ================================================ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang # 2023 Shuai Wang ) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import argparse import io import logging import multiprocessing import os import random import tarfile import time AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} def write_tar_file(data_list, tar_file, index=0, total=1): logging.info('Processing {} {}/{}'.format(tar_file, index, total)) read_time = 0.0 write_time = 0.0 with tarfile.open(tar_file, "w") as tar: for item in data_list: assert len( item) == 3, 'item should have 3 elements: Key, Speaker, Wav' key, spk, wav = item suffix = wav.split('.')[-1] assert suffix in AUDIO_FORMAT_SETS ts = time.time() with open(wav, 'rb') as fin: data = fin.read() read_time += (time.time() - ts) assert isinstance(spk, str) ts = time.time() spk_file = key + '.spk' spk = spk.encode('utf8') spk_data = io.BytesIO(spk) spk_info = tarfile.TarInfo(spk_file) spk_info.size = len(spk) tar.addfile(spk_info, spk_data) wav_file = key + '.' + suffix wav_data = io.BytesIO(data) wav_info = tarfile.TarInfo(wav_file) wav_info.size = len(data) tar.addfile(wav_info, wav_data) write_time += (time.time() - ts) logging.info('read {} write {}'.format(read_time, write_time)) def get_args(): parser = argparse.ArgumentParser(description='') parser.add_argument('--num_utts_per_shard', type=int, default=1000, help='num utts per shard') parser.add_argument('--num_threads', type=int, default=1, help='num threads for make shards') parser.add_argument('--prefix', default='shards', help='prefix of shards tar file') parser.add_argument('--seed', type=int, default=42, help='random seed') parser.add_argument('--shuffle', action='store_true', help='whether to shuffle data') parser.add_argument('wav_file', help='wav file') parser.add_argument('utt2spk_file', help='utt2spk file') parser.add_argument('shards_dir', help='output shards dir') parser.add_argument('shards_list', help='output shards list file') args = parser.parse_args() return args def main(): args = get_args() random.seed(args.seed) logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') wav_table = {} with open(args.wav_file, 'r', encoding='utf8') as fin: for line in fin: arr = line.strip().split() key = arr[0] # key = os.path.splitext(arr[0])[0] wav_table[key] = ' '.join(arr[1:]) data = [] with open(args.utt2spk_file, 'r', encoding='utf8') as fin: for line in fin: arr = line.strip().split(maxsplit=1) key = arr[0] # key = os.path.splitext(arr[0])[0] spk = arr[1] assert key in wav_table wav = wav_table[key] data.append((key, spk, wav)) if args.shuffle: random.shuffle(data) num = args.num_utts_per_shard chunks = [data[i:i + num] for i in range(0, len(data), num)] os.makedirs(args.shards_dir, exist_ok=True) # Using thread pool to speedup pool = multiprocessing.Pool(processes=args.num_threads) shards_list = [] num_chunks = len(chunks) for i, chunk in enumerate(chunks): tar_file = os.path.join(args.shards_dir, '{}_{:09d}.tar'.format(args.prefix, i)) shards_list.append(tar_file) pool.apply_async(write_tar_file, (chunk, tar_file, i, num_chunks)) pool.close() pool.join() with open(args.shards_list, 'w', encoding='utf8') as fout: for name in shards_list: fout.write(name + '\n') if __name__ == '__main__': main() ================================================ FILE: tools/parse_options.sh ================================================ #!/bin/bash # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); # Arnab Ghoshal, Karel Vesely # 2022 Hongji Wang (jijijiang77@gmail.com) # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except 
in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, # MERCHANTABLITY OR NON-INFRINGEMENT. # See the Apache 2 License for the specific language governing permissions and # limitations under the License. # Parse command-line options. # To be sourced by another script (as in ". parse_options.sh"). # Option format is: --option-name arg # and shell variable "option_name" gets set to value "arg." # The exception is --help, which takes no arguments, but prints the # $help_message variable (if defined). ### ### The --conf file options have lower priority to command line ### options, so we need to import them first... ### # Now import all the confs specified by command-line, in left-to-right order for ((argpos = 1; argpos < $#; argpos++)); do if [ "${!argpos}" == "--conf" ]; then argpos_plus1=$((argpos + 1)) conf=${!argpos_plus1} [ ! -r $conf ] && echo "$0: missing conf '$conf'" && exit 1 . $conf # source the conf file. fi done ### ### No we process the command line options ### while true; do [ -z "${1:-}" ] && break # break if there are no arguments case "$1" in # If the enclosing script is called with --help option, print the help # message and exit. Scripts should put help messages in $help_message --help | -h) if [ -z "$help_message" ]; then echo "No help found." 1>&2 else printf "$help_message\n" 1>&2; fi exit 0 ;; --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" exit 1 ;; # If the first command-line argument begins with "--" (e.g. --foo-bar), # then work out the variable name as $name, which will equal "foo_bar". --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g) # Next we test whether the variable in question is undefned-- if so it's # an invalid option and we die. Note: $0 evaluates to the name of the # enclosing script. # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar # is undefined. We then have to wrap this test inside "eval" because # foo_bar is itself inside a variable ($name). eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1 oldval="$(eval echo \$$name)" # Work out whether we seem to be expecting a Boolean argument. if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then was_bool=true else was_bool=false fi # Set the variable to the right value-- the escaped quotes make it work if # the option had spaces, like --cmd "queue.pl -sync y" eval $name=\"$2\" # Check that Boolean-valued arguments are really Boolean. if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 exit 1 fi shift 2 ;; *) break ;; esac done # Check for an empty argument to the --cmd option, which can easily occur as a # result of scripting errors. [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1 true # so this script returns exit code 0. 
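The block below is a minimal, hypothetical sketch of how a recipe script (for example, one of the run.sh scripts under examples/) typically consumes tools/parse_options.sh: defaults are declared first, the helper is sourced, and any "--option-name value" pair on the command line then overrides the shell variable of the same name. The variable names stage and exp_dir here are illustrative placeholders, not taken from this file.

#!/bin/bash
# Hypothetical caller of tools/parse_options.sh (illustrative variable names).
stage=1                # default; can be overridden with --stage N
exp_dir=exp/bsrnn      # default; can be overridden with --exp_dir DIR
. tools/parse_options.sh || exit 1
echo "stage=${stage} exp_dir=${exp_dir}"
# Example invocation: bash run.sh --stage 2 --exp_dir exp/dpccn
# parse_options.sh maps --stage to $stage and --exp_dir to $exp_dir; options
# for variables that were not defined beforehand cause an "invalid option" error.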
================================================ FILE: tools/print_train_val_curve.py ================================================ import re import matplotlib.pyplot as plt # Initialize lists to store epochs, train losses and validation losses epochs = [] train_loss = [] val_loss = [] # Open the log file prev_epoch = 0 with open("train.log", "r") as f: for line in f: # Find lines with epoch info if "info" in line: # Extract epoch number epoch = int(re.search(r"Epoch (\d+)", line).group(1)) if epoch != prev_epoch: print(prev_epoch, epoch) # Extract loss values # pattern = r'loss (.*?)\n' pattern = r"[-+]?\d*\.\d+" loss = float(re.search(pattern, line).group()) if "Train" in line: epochs.append(epoch) train_loss.append(loss) elif "Val" in line: val_loss.append(loss) prev_epoch = epoch # Create the plot plt.figure(figsize=(10, 5)) # Plot training and validation loss plt.plot(epochs, train_loss, label="Training Loss", color="blue") plt.plot(epochs, val_loss, label="Validation Loss", color="red") # Add horizontal lines at the minimum values plt.axhline(min(train_loss), color="blue", linestyle="--", label="Min Training Loss") plt.axhline(min(val_loss), color="red", linestyle="--", label="Min Validation Loss") # Annotate the minimum values on the y-axis plt.text( 0, min(train_loss), "{:.2f}".format(min(train_loss)), va="center", ha="left", backgroundcolor="w", ) plt.text( 0, min(val_loss), "{:.2f}".format(min(val_loss)), va="center", ha="left", backgroundcolor="w", ) # Add legend, title, and x, y labels plt.legend(loc="upper right") plt.title("Training and Validation Loss Over Epochs") plt.ylabel("Loss Value") plt.xlabel("Epochs") # Save the plot as a .png file plt.savefig("train_val_loss.png") # Show the plot # plt.show() ================================================ FILE: tools/run.pl ================================================ #!/usr/bin/env perl use warnings; #sed replacement for -w perl parameter # In general, doing # run.pl some.log a b c is like running the command a b c in # the bash shell, and putting the standard error and output into some.log. # To run parallel jobs (backgrounded on the host machine), you can do (e.g.) # run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB # and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier]. # If any of the jobs fails, this script will fail. # A typical example is: # run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz # and run.pl will run something like: # ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log # # Basically it takes the command-line arguments, quotes them # as necessary to preserve spaces, and evaluates them with bash. # In addition it puts the command line at the top of the log, and # the start and end times of the command at the beginning and end. # The reason why this is useful is so that we can create a different # version of this program that uses a queueing system instead. #use Data::Dumper; @ARGV < 2 && die "usage: run.pl log-file command-line arguments..."; #print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n"; $job_pick = 'all'; $max_jobs_run = -1; $jobstart = 1; $jobend = 1; $ignored_opts = ""; # These will be ignored. # First parse an option like JOB=1:4, and any # options that would normally be given to # queue.pl, which we will just discard. for (my $x = 1; $x <= 2; $x++) { # This for-loop is to # allow the JOB=1:n option to be interleaved with the # options to qsub. 
while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) { # parse any options that would normally go to qsub, but which will be ignored here. my $switch = shift @ARGV; if ($switch eq "-V") { $ignored_opts .= "-V "; } elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") { # we do support the option --max-jobs-run n, and its GridEngine form -tc n. # if the command appears multiple times uses the smallest option. if ( $max_jobs_run <= 0 ) { $max_jobs_run = shift @ARGV; } else { my $new_constraint = shift @ARGV; if ( ($new_constraint < $max_jobs_run) ) { $max_jobs_run = $new_constraint; } } if (! ($max_jobs_run > 0)) { die "run.pl: invalid option --max-jobs-run $max_jobs_run"; } } else { my $argument = shift @ARGV; if ($argument =~ m/^--/) { print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n"; } if ($switch eq "-sync" && $argument =~ m/^[yY]/) { $ignored_opts .= "-sync "; # Note: in the # corresponding code in queue.pl it says instead, just "$sync = 1;". } elsif ($switch eq "-pe") { # e.g. -pe smp 5 my $argument2 = shift @ARGV; $ignored_opts .= "$switch $argument $argument2 "; } elsif ($switch eq "--gpu") { $using_gpu = $argument; } elsif ($switch eq "--pick") { if($argument =~ m/^(all|failed|incomplete)$/) { $job_pick = $argument; } else { print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'" } } else { # Ignore option. $ignored_opts .= "$switch $argument "; } } } if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20 $jobname = $1; $jobstart = $2; $jobend = $3; if ($jobstart > $jobend) { die "run.pl: invalid job range $ARGV[0]"; } if ($jobstart <= 0) { die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility)."; } shift; } elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1. $jobname = $1; $jobstart = $2; $jobend = $2; shift; } elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) { print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n"; } } # Users found this message confusing so we are removing it. # if ($ignored_opts ne "") { # print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n"; # } if ($max_jobs_run == -1) { # If --max-jobs-run option not set, # then work out the number of processors if possible, # and set it based on that. $max_jobs_run = 0; if ($using_gpu) { if (open(P, "nvidia-smi -L |")) { $max_jobs_run++ while (
<P>);
      close(P);
    }
    if ($max_jobs_run == 0) {
      $max_jobs_run = 1;
      print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
    }
  } elsif (open(P, "</proc/cpuinfo")) {
    while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
    if ($max_jobs_run == 0) {
      print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
      $max_jobs_run = 10;  # reasonable default.
    }
    close(P);
  } elsif (open(P, "sysctl -a |")) {  # BSD/Darwin
    while (<P>
) { if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4 $max_jobs_run = $1; last; } } close(P); if ($max_jobs_run == 0) { print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n"; $max_jobs_run = 10; # reasonable default. } } else { # allow at most 32 jobs at once, on non-UNIX systems; change this code # if you need to change this default. $max_jobs_run = 32; } # The just-computed value of $max_jobs_run is just the number of processors # (or our best guess); and if it happens that the number of jobs we need to # run is just slightly above $max_jobs_run, it will make sense to increase # $max_jobs_run to equal the number of jobs, so we don't have a small number # of leftover jobs. $num_jobs = $jobend - $jobstart + 1; if (!$using_gpu && $num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) { $max_jobs_run = $num_jobs; } } sub pick_or_exit { # pick_or_exit ( $logfile ) # Invoked before each job is started helps to run jobs selectively. # # Given the name of the output logfile decides whether the job must be # executed (by returning from the subroutine) or not (by terminating the # process calling exit) # # PRE: $job_pick is a global variable set by command line switch --pick # and indicates which class of jobs must be executed. # # 1) If a failed job is not executed the process exit code will indicate # failure, just as if the task was just executed and failed. # # 2) If a task is incomplete it will be executed. Incomplete may be either # a job whose log file does not contain the accounting notes in the end, # or a job whose log file does not exist. # # 3) If the $job_pick is set to 'all' (default behavior) a task will be # executed regardless of the result of previous attempts. # # This logic could have been implemented in the main execution loop # but a subroutine to preserve the current level of readability of # that part of the code. # # Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020 # if($job_pick eq 'all'){ return; # no need to bother with the previous log } open my $fh, "<", $_[0] or return; # job not executed yet my $log_line; my $cur_line; while ($cur_line = <$fh>) { if( $cur_line =~ m/# Ended \(code .*/ ) { $log_line = $cur_line; } } close $fh; if (! defined($log_line)){ return; # incomplete } if ( $log_line =~ m/# Ended \(code 0\).*/ ) { exit(0); # complete } elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){ if ($job_pick !~ m/^(failed|all)$/) { exit(1); # failed but not going to run } else { return; # failed } } elsif ( $log_line =~ m/.*\S.*/ ) { return; # incomplete jobs are always run } } $logfile = shift @ARGV; if (defined $jobname && $logfile !~ m/$jobname/ && $jobend > $jobstart) { print STDERR "run.pl: you are trying to run a parallel job but " . "you are putting the output into just one log file ($logfile)\n"; exit(1); } $cmd = ""; foreach $x (@ARGV) { if ($x =~ m/^\S+$/) { $cmd .= $x . " "; } elsif ($x =~ m:\":) { $cmd .= "'$x' "; } else { $cmd .= "\"$x\" "; } } #$Data::Dumper::Indent=0; $ret = 0; $numfail = 0; %active_pids=(); use POSIX ":sys_wait_h"; for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) { if (scalar(keys %active_pids) >= $max_jobs_run) { # Lets wait for a change in any child's status # Then we have to work out which child finished $r = waitpid(-1, 0); $code = $?; if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen. 
if ( defined $active_pids{$r} ) { $jid=$active_pids{$r}; $fail[$jid]=$code; if ($code !=0) { $numfail++;} delete $active_pids{$r}; # print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n"; } else { die "run.pl: Cannot find the PID of the child process that just finished."; } # In theory we could do a non-blocking waitpid over all jobs running just # to find out if only one or more jobs finished during the previous waitpid() # However, we just omit this and will reap the next one in the next pass # through the for(;;) cycle } $childpid = fork(); if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; } if ($childpid == 0) { # We're in the child... this branch # executes the job and returns (possibly with an error status). if (defined $jobname) { $cmd =~ s/$jobname/$jobid/g; $logfile =~ s/$jobname/$jobid/g; } # exit if the job does not need to be executed pick_or_exit( $logfile ); system("mkdir -p `dirname $logfile` 2>/dev/null"); open(F, ">$logfile") || die "run.pl: Error opening log file $logfile"; print F "# " . $cmd . "\n"; print F "# Started at " . `date`; $starttime = `date +'%s'`; print F "#\n"; close(F); # Pipe into bash.. make sure we're not using any other shell. open(B, "|bash") || die "run.pl: Error opening shell command"; print B "( " . $cmd . ") 2>>$logfile >> $logfile"; close(B); # If there was an error, exit status is in $? $ret = $?; $lowbits = $ret & 127; $highbits = $ret >> 8; if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" } else { $return_str = "code $highbits"; } $endtime = `date +'%s'`; open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)"; $enddate = `date`; chop $enddate; print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n"; print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n"; close(F); exit($ret == 0 ? 0 : 1); } else { $pid[$jobid] = $childpid; $active_pids{$childpid} = $jobid; # print STDERR "Queued: " . Dumper(\%active_pids) . "\n"; } } # Now we have submitted all the jobs, lets wait until all the jobs finish foreach $child (keys %active_pids) { $jobid=$active_pids{$child}; $r = waitpid($pid[$jobid], 0); $code = $?; if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen. if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully } # Some sanity checks: # The $fail array should not contain undefined codes # The number of non-zeros in that array should be equal to $numfail # We cannot do foreach() here, as the JOB ids do not start at zero $failed_jids=0; for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) { $job_return = $fail[$jobid]; if (not defined $job_return ) { # print Dumper(\@fail); die "run.pl: Sanity check failed: we have indication that some jobs are running " . "even after we waited for all jobs to finish" ; } if ($job_return != 0 ){ $failed_jids++;} } if ($failed_jids != $numfail) { die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)." } if ($numfail > 0) { $ret = 1; } if ($ret != 0) { $njobs = $jobend - $jobstart + 1; if ($njobs == 1) { if (defined $jobname) { $logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with # that job. 
} print STDERR "run.pl: job failed, log is in $logfile\n"; if ($logfile =~ m/JOB/) { print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script."; } } else { $logfile =~ s/$jobname/*/g; print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n"; } } exit ($ret); ================================================ FILE: tools/score.sh ================================================ #!/bin/bash min() { local a b a=$1 for b in "$@"; do if [ "${b}" -le "${a}" ]; then a="${b}" fi done echo "${a}" } # Set default values dset= exp_dir= scoring_opts= n_gpu=1 score_nj=16 ref_channel=0 use_pesq=false use_dnsmos=false dnsmos_use_gpu=true fs=16k scoring_protocol="STOI SDR SAR SIR SI_SNR" # Parse command line options . tools/parse_options.sh || exit 1 if [ ! ${fs} = 16k ] && ${use_dnsmos}; then echo "Warning: DNSMOS only supports 16k sampling rate." echo "--use_dnsmos will be set to false automatically." use_dnsmos=false fi # Set scoring options scoring_opts="" if ${use_dnsmos}; then # Set model path primary_model_path=DNSMOS/sig_bak_ovr.onnx p808_model_path=DNSMOS/model_v8.onnx if [ ! -f ${primary_model_path} ] || [ ! -f ${p808_model_path} ]; then echo "==========================================" echo "Warning: DNSMOS model files are not found." echo "Trying to download them from the official repository." echo "If this takes too long," echo "please manually download the model files" echo "and put them in the DNSMOS directory." echo "==========================================" # creat directory for DNSMOS model files mkdir -p DNSMOS # download DNSMOS model files and save them to the directory wget -P DNSMOS https://github.com/microsoft/DNS-Challenge/raw/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx wget -P DNSMOS https://github.com/microsoft/DNS-Challenge/raw/master/DNSMOS/DNSMOS/model_v8.onnx # check if the model files are downloaded successfully if [ ! -f ${primary_model_path} ] || [ ! -f ${p808_model_path} ]; then echo "Error: DNSMOS model files are not downloaded successfully." exit 1 fi fi scoring_opts+="--dnsmos_mode local " scoring_opts+="--dnsmos_primary_model ${primary_model_path} " scoring_opts+="--dnsmos_p808_model ${p808_model_path} " if ${dnsmos_use_gpu}; then score_nj=$(min "${score_nj}" "${n_gpu}") scoring_opts+="--dnsmos_use_gpu ${dnsmos_use_gpu} " fi fi # Set directories and log directory _dir="${exp_dir}/scoring" _logdir="${_dir}/logdir" mkdir -p "${_logdir}" # 0. Check the inference file inf_scp=${exp_dir}/audio/spk1.scp if [ ! -s "${inf_scp}" ] || [ -z "$(cat "${inf_scp}")" ]; then echo "Error: ${inf_scp} does not exist or is empty!" exit 1 fi # 1. Split the key file key_file=${dset}/single.wav.scp split_scps="" _nj=$(min "${score_nj}" "$(wc <${key_file} -l)") for n in $(seq "${_nj}"); do split_scps+=" ${_logdir}/keys.${n}.scp" done # shellcheck disable=SC2086 ./tools/split_scp.pl "${key_file}" ${split_scps} _ref_scp="--ref_scp ${dset}/single.wav.scp " _inf_scp="--inf_scp ${exp_dir}/audio/spk1.scp " # 2. 
Submit scoring jobs echo "log: '${_logdir}/tse_scoring.*.log'" if ${use_dnsmos} && ${dnsmos_use_gpu}; then cmd="./tools/run.pl --gpu ${n_gpu}" else cmd="./tools/run.pl" fi # shellcheck disable=SC2086 ${cmd} JOB=1:"${_nj}" "${_logdir}"/tse_scoring.JOB.log \ python -m wesep.bin.score \ --key_file "${_logdir}"/keys.JOB.scp \ --output_dir "${_logdir}"/output.JOB \ ${_ref_scp} \ ${_inf_scp} \ --ref_channel ${ref_channel} \ --use_pesq ${use_pesq} \ --use_dnsmos ${use_dnsmos} \ --dnsmos_gpu_device JOB \ ${scoring_opts} # Check if PESQ is used if "${use_pesq}"; then if [ ${fs} = 16k ]; then scoring_protocol+=" PESQ_WB" else scoring_protocol+=" PESQ_NB" fi fi # Check if dnsmos is used if "${use_dnsmos}"; then scoring_protocol+=" BAK SIG OVRL P808_MOS" fi # Merge and sort result files for protocol in ${scoring_protocol} wav; do for i in $(seq "${_nj}"); do cat "${_logdir}/output.${i}/${protocol}_spk1" done | LC_ALL=C sort -k1 >"${_dir}/${protocol}_spk1" done # Calculate and save results for protocol in ${scoring_protocol}; do # shellcheck disable=SC2046 paste $(printf "%s/%s_spk1 " "${_dir}" "${protocol}") | awk 'BEGIN{sum=0} {n=0;score=0;for (i=2; i<=NF; i+=2){n+=1;score+=$i}; sum+=score/n} END{printf ("%.2f\n",sum/NR)}' >"${_dir}/result_${protocol,,}.txt" done # show the result ./tools/show_enh_score.sh "${_dir}/../.." > \ "${_dir}/../../RESULTS.md" ================================================ FILE: tools/show_enh_score.sh ================================================ #!/usr/bin/env bash mindepth=0 maxdepth=1 . tools/parse_options.sh if [ $# -gt 1 ]; then echo "Usage: $0 --mindepth 0 --maxdepth 1 [exp]" 1>&2 echo "" echo "Show the system environments and the evaluation results in Markdown format." echo 'The default of is "exp/".' exit 1 fi [ -f ./path.sh ] && . ./path.sh set -euo pipefail if [ $# -eq 1 ]; then exp=$(realpath "$1") else exp=exp fi cat < # RESULTS ## Environments - date: \`$(LC_ALL=C date)\` EOF cat </dev/null; then echo -e "\n## $(basename ${expdir})\n" [ -e "${expdir}"/config.yaml ] && grep ^config "${expdir}"/config.yaml metrics=() heading="\n|dataset|" sep="|---|" for type in pesq pesq_wb pesq_nb estoi stoi sar sdr sir si_snr ovrl sig bak p808_mos; do if ls "${expdir}"/*/scoring/result_${type}.txt &>/dev/null; then metrics+=("$type") heading+="${type^^}|" sep+="---|" fi done echo -e "${heading}\n${sep}" setnames=() for dirname in "${expdir}"/*/scoring/result_stoi.txt; do dset=$(echo $dirname | sed -e "s#${expdir}/\([^/]*\)/scoring/result_stoi.txt#\1#g") setnames+=("$dset") done for dset in "${setnames[@]}"; do line="|${dset}|" for ((i = 0; i < ${#metrics[@]}; i++)); do type=${metrics[$i]} if [ -f "${expdir}"/${dset}/scoring/result_${type}.txt ]; then score=$(head -n1 "${expdir}"/${dset}/scoring/result_${type}.txt) else score="" fi line+="${score}|" done echo $line done echo "" fi done < <(find ${exp} -mindepth ${mindepth} -maxdepth ${maxdepth} -type d) ================================================ FILE: tools/split_scp.pl ================================================ #!/usr/bin/env perl # Copyright 2010-2011 Microsoft Corporation # See ../../COPYING for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, # MERCHANTABLITY OR NON-INFRINGEMENT. # See the Apache 2 License for the specific language governing permissions and # limitations under the License. # This program splits up any kind of .scp or archive-type file. # If there is no utt2spk option it will work on any text file and # will split it up with an approximately equal number of lines in # each but. # With the --utt2spk option it will work on anything that has the # utterance-id as the first entry on each line; the utt2spk file is # of the form "utterance speaker" (on each line). # It splits it into equal size chunks as far as it can. If you use the utt2spk # option it will make sure these chunks coincide with speaker boundaries. In # this case, if there are more chunks than speakers (and in some other # circumstances), some of the resulting chunks will be empty and it will print # an error message and exit with nonzero status. # You will normally call this like: # split_scp.pl scp scp.1 scp.2 scp.3 ... # or # split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... # Note that you can use this script to split the utt2spk file itself, # e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... # You can also call the scripts like: # split_scp.pl -j 3 0 scp scp.0 # [note: with this option, it assumes zero-based indexing of the split parts, # i.e. the second number must be 0 <= n < num-jobs.] use warnings; $num_jobs = 0; $job_id = 0; $utt2spk_file = ""; $one_based = 0; for ($x = 1; $x <= 3 && @ARGV > 0; $x++) { if ($ARGV[0] eq "-j") { shift @ARGV; $num_jobs = shift @ARGV; $job_id = shift @ARGV; } if ($ARGV[0] =~ /--utt2spk=(.+)/) { $utt2spk_file=$1; shift; } if ($ARGV[0] eq '--one-based') { $one_based = 1; shift @ARGV; } } if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 || $job_id - $one_based >= $num_jobs)) { die "$0: Invalid job number/index values for '-j $num_jobs $job_id" . ($one_based ? " --one-based" : "") . "'\n" } $one_based and $job_id--; if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) { die "Usage: split_scp.pl [--utt2spk=] in.scp out1.scp out2.scp ... or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] in.scp [out.scp] ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n"; } $error = 0; $inscp = shift @ARGV; if ($num_jobs == 0) { # without -j option @OUTPUTS = @ARGV; } else { for ($j = 0; $j < $num_jobs; $j++) { if ($j == $job_id) { if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; } else { push @OUTPUTS, "-"; } } else { push @OUTPUTS, "/dev/null"; } } } if ($utt2spk_file ne "") { # We have the --utt2spk option... 
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n"; while(<$u_fh>) { @A = split; @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n"; ($u,$s) = @A; $utt2spk{$u} = $s; } close $u_fh; open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; @spkrs = (); while(<$i_fh>) { @A = split; if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; } $u = $A[0]; $s = $utt2spk{$u}; defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n"; if(!defined $spk_count{$s}) { push @spkrs, $s; $spk_count{$s} = 0; $spk_data{$s} = []; # ref to new empty array. } $spk_count{$s}++; push @{$spk_data{$s}}, $_; } # Now split as equally as possible .. # First allocate spks to files by allocating an approximately # equal number of speakers. $numspks = @spkrs; # number of speakers. $numscps = @OUTPUTS; # number of output files. if ($numspks < $numscps) { die "$0: Refusing to split data because number of speakers $numspks " . "is less than the number of output .scp files $numscps\n"; } for($scpidx = 0; $scpidx < $numscps; $scpidx++) { $scparray[$scpidx] = []; # [] is array reference. } for ($spkidx = 0; $spkidx < $numspks; $spkidx++) { $scpidx = int(($spkidx*$numscps) / $numspks); $spk = $spkrs[$spkidx]; push @{$scparray[$scpidx]}, $spk; $scpcount[$scpidx] += $spk_count{$spk}; } # Now will try to reassign beginning + ending speakers # to different scp's and see if it gets more balanced. # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2. # We can show that if considering changing just 2 scp's, we minimize # this by minimizing the squared difference in sizes. This is # equivalent to minimizing the absolute difference in sizes. This # shows this method is bound to converge. $changed = 1; while($changed) { $changed = 0; for($scpidx = 0; $scpidx < $numscps; $scpidx++) { # First try to reassign ending spk of this scp. if($scpidx < $numscps-1) { $sz = @{$scparray[$scpidx]}; if($sz > 0) { $spk = $scparray[$scpidx]->[$sz-1]; $count = $spk_count{$spk}; $nutt1 = $scpcount[$scpidx]; $nutt2 = $scpcount[$scpidx+1]; if( abs( ($nutt2+$count) - ($nutt1-$count)) < abs($nutt2 - $nutt1)) { # Would decrease # size-diff by reassigning spk... $scpcount[$scpidx+1] += $count; $scpcount[$scpidx] -= $count; pop @{$scparray[$scpidx]}; unshift @{$scparray[$scpidx+1]}, $spk; $changed = 1; } } } if($scpidx > 0 && @{$scparray[$scpidx]} > 0) { $spk = $scparray[$scpidx]->[0]; $count = $spk_count{$spk}; $nutt1 = $scpcount[$scpidx-1]; $nutt2 = $scpcount[$scpidx]; if( abs( ($nutt2-$count) - ($nutt1+$count)) < abs($nutt2 - $nutt1)) { # Would decrease # size-diff by reassigning spk... $scpcount[$scpidx-1] += $count; $scpcount[$scpidx] -= $count; shift @{$scparray[$scpidx]}; push @{$scparray[$scpidx-1]}, $spk; $changed = 1; } } } } # Now print out the files... for($scpidx = 0; $scpidx < $numscps; $scpidx++) { $scpfile = $OUTPUTS[$scpidx]; ($scpfile ne '-' ? open($f_fh, '>', $scpfile) : open($f_fh, '>&', \*STDOUT)) || die "$0: Could not open scp file $scpfile for writing: $!\n"; $count = 0; if(@{$scparray[$scpidx]} == 0) { print STDERR "$0: eError: split_scp.pl producing empty .scp file " . 
"$scpfile (too many splits and too few speakers?)\n"; $error = 1; } else { foreach $spk ( @{$scparray[$scpidx]} ) { print $f_fh @{$spk_data{$spk}}; $count += $spk_count{$spk}; } $count == $scpcount[$scpidx] || die "Count mismatch [code error]"; } close($f_fh); } } else { # This block is the "normal" case where there is no --utt2spk # option and we just break into equal size chunks. open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; $numscps = @OUTPUTS; # size of array. @F = (); while(<$i_fh>) { push @F, $_; } $numlines = @F; if($numlines == 0) { print STDERR "$0: error: empty input scp file $inscp\n"; $error = 1; } $linesperscp = int( $numlines / $numscps); # the "whole part".. $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n"; $remainder = $numlines - ($linesperscp * $numscps); ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder"; # [just doing int() rounds down]. $n = 0; for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) { $scpfile = $OUTPUTS[$scpidx]; ($scpfile ne '-' ? open($o_fh, '>', $scpfile) : open($o_fh, '>&', \*STDOUT)) || die "$0: Could not open scp file $scpfile for writing: $!\n"; for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) { print $o_fh $F[$n++]; } close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n"; } $n == $numlines || die "$n != $numlines [code error]"; } exit ($error); ================================================ FILE: tools/test_dataset.py ================================================ from torch.utils.data import DataLoader from wesep.dataset.dataset import Dataset from wesep.dataset.dataset import tse_collate_fn from wesep.utils.file_utils import load_speaker_embeddings def test_premixed_dataset(): configs = { "shuffle": False, "shuffle_args": { "shuffle_size": 2500 }, "resample_rate": 16000, "chunk_len": 32000, } spk2embed_dict = load_speaker_embeddings("data/clean/test/embed.scp", "data/clean/test/single.utt2spk") dataset = Dataset( "shard", "data/clean/test/shard.list", configs=configs, spk2embed_dict=spk2embed_dict, whole_utt=False, ) return dataset def test_online_dataset(): # Implementation to test the online speaker mixing dataloader configs = { "shuffle": True, "resample_rate": 16000, "chunk_len": 64000, "num_speakers": 2, "online_mix": True, "reverb": False, } spk2embed_dict = load_speaker_embeddings("mydata/clean/test/embed.scp", "mydata/clean/test/utt2spk") dataset = Dataset( "shard", "mydata/clean/test/shard.list", configs=configs, spk2embed_dict=spk2embed_dict, whole_utt=False, ) return dataset if __name__ == "__main__": dataset = test_online_dataset() dataloader = DataLoader(dataset, batch_size=4, num_workers=1, collate_fn=tse_collate_fn) for i, batch in enumerate(dataloader): print( batch["wav_mix"].size(), batch["wav_targets"].size(), batch["spk_embeds"].size(), ) if i == 0: break ================================================ FILE: wesep/__init__.py ================================================ from wesep.cli.extractor import load_model # noqa from wesep.cli.extractor import load_model_local # noqa ================================================ FILE: wesep/bin/average_model.py ================================================ # Copyright (c) 2020 Mobvoi Inc (Di Wu) # 2021 Hongji Wang (jijijiang77@gmail.com) # 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not 
use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import glob import os.path import re import torch def get_args(): parser = argparse.ArgumentParser(description="average model") parser.add_argument("--dst_model", required=True, help="averaged model") parser.add_argument("--src_path", required=True, help="src model path for average") parser.add_argument("--num", default=5, type=int, help="nums for averaged model") parser.add_argument( "--min_epoch", default=0, type=int, help="min epoch used for averaging model", ) parser.add_argument( "--max_epoch", default=65536, # Big enough type=int, help="max epoch used for averaging model", ) parser.add_argument( "--mode", default="final", type=str, help="use last epochs for average or best epochs", ) parser.add_argument( "--epochs", default="1,2,3,4,5", type=str, help="use last epochs for average or best epochs", ) args = parser.parse_args() print(args) return args def main(): args = get_args() if args.mode == "final": path_list = glob.glob("{}/*[!avg][!final][!latest].pt".format( args.src_path)) path_list = sorted( path_list, key=lambda p: int(re.findall(r"(?<=checkpoint_)\d*(?=.pt)", p)[0]), ) path_list = path_list[-args.num:] else: epoch_indexes = list(args.epochs.split(",")) path_list = [ os.path.join(args.src_path, "checkpoint_" + x + ".pt") for x in epoch_indexes ] print(path_list) avg = None num = args.num assert num == len(path_list) for path in path_list: print("Processing {}".format(path)) states = torch.load(path, map_location=torch.device("cpu")) states = states["models"][0] if "models" in states else states if avg is None: avg = states else: for k in avg.keys(): avg[k] += states[k] # average for k in avg.keys(): if avg[k] is not None: # pytorch 1.6 use true_divide instead of /= avg[k] = torch.true_divide(avg[k], num) avg = {"models": [avg]} print("Saving to {}".format(args.dst_model)) torch.save(avg, args.dst_model) if __name__ == "__main__": main() ================================================ FILE: wesep/bin/export_jit.py ================================================ from __future__ import print_function import argparse import os import torch import yaml from wesep.models import get_model from wesep.utils.checkpoint import load_pretrained_model def get_args(): parser = argparse.ArgumentParser(description="export your script model") parser.add_argument("--config", required=True, help="config file") parser.add_argument("--checkpoint", required=True, help="checkpoint model") parser.add_argument("--output_model", required=True, help="output file") args = parser.parse_args() return args def main(): args = get_args() os.environ["CUDA_VISIBLE_DEVICES"] = "-1" with open(args.config, "r") as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) print(configs) model = get_model( configs["model"]["tse_model"])(**configs["model_args"]["tse_model"]) print(model) load_pretrained_model(model, args.checkpoint) model.eval() speaker_feat_dim = configs["dataset_args"]["fbank_args"].get( "num_mel_bins", 80) speaker_dummy_input = torch.ones(2, 300, speaker_feat_dim) mix_dummy_input = torch.ones(2, 81280) script_model = 
torch.jit.script(model, (mix_dummy_input, speaker_dummy_input)) script_model.save(args.output_model) print("Export model successfully, see {}".format(args.output_model)) if __name__ == "__main__": main() ================================================ FILE: wesep/bin/infer.py ================================================ from __future__ import print_function import os import time import fire import soundfile import torch from torch.utils.data import DataLoader from wesep.dataset.dataset import Dataset, tse_collate_fn_2spk from wesep.models import get_model from wesep.utils.checkpoint import load_pretrained_model from wesep.utils.file_utils import read_label_file, read_vec_scp_file from wesep.utils.score import cal_SISNRi from wesep.utils.utils import ( generate_enahnced_scp, get_logger, parse_config_or_kwargs, set_seed, ) os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["TORCH_USE_CUDA_DSA"] = "1" def infer(config="confs/conf.yaml", **kwargs): start = time.time() total_SISNR = 0 total_SISNRi = 0 total_cnt = 0 accept_cnt = 0 configs = parse_config_or_kwargs(config, **kwargs) sign_save_wav = configs.get( "save_wav", True) # Control if save the extracted speech as .wav rank = 0 set_seed(configs["seed"] + rank) gpu = configs["gpus"] device = (torch.device("cuda:{}".format(gpu)) if gpu >= 0 else torch.device("cpu")) sample_rate = configs.get("fs", None) if sample_rate is None or sample_rate == "16k": sample_rate = 16000 else: sample_rate = 8000 if 'spk_model_init' in configs['model_args']['tse_model']: configs['model_args']['tse_model']['spk_model_init'] = False model = get_model( configs["model"]["tse_model"])(**configs["model_args"]["tse_model"]) model_path = os.path.join(configs["checkpoint"]) load_pretrained_model(model, model_path) logger = get_logger(configs["exp_dir"], "infer.log") logger.info("Load checkpoint from {}".format(model_path)) save_audio_dir = os.path.join(configs["exp_dir"], "audio") if sign_save_wav: if not os.path.exists(save_audio_dir): try: os.makedirs(save_audio_dir) print(f"Directory {save_audio_dir} created successfully.") except OSError as e: print(f"Error creating directory {save_audio_dir}: {e}") else: print(f"Directory {save_audio_dir} already exists.") else: print("Do NOT save the results in wav.") model = model.to(device) model.eval() test_spk_embeds = configs.get("test_spk_embeds", None) test_spk1_embed_scp = configs["test_spk1_enroll"] test_spk2_embed_scp = configs["test_spk2_enroll"] joint_training = configs["model_args"]["tse_model"].get( "joint_training", None) if not joint_training and test_spk_embeds: test_spk2embed_dict = read_vec_scp_file(test_spk_embeds) else: test_spk2embed_dict = read_label_file(configs["test_spk2utt"]) test_spk1_embed = read_label_file(test_spk1_embed_scp) test_spk2_embed = read_label_file(test_spk2_embed_scp) lines = len(test_spk2embed_dict) test_dataset = Dataset( configs["data_type"], configs["test_data"], configs["dataset_args"], test_spk2embed_dict, test_spk1_embed, test_spk2_embed, state="test", joint_training=joint_training, whole_utt=configs.get("whole_utt", True), repeat_dataset=configs.get("repeat_dataset", False), ) test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=tse_collate_fn_2spk) test_iter = lines // 2 logger.info("test number: {}".format(test_iter)) with torch.no_grad(): for i, batch in enumerate(test_dataloader): features = batch["wav_mix"] targets = batch["wav_targets"] enroll = batch["spk_embeds"] spk = batch["spk"] key = batch["key"] features = features.float().to(device) # (B,T,F) 
targets = targets.float().to(device) enroll = enroll.float().to(device) outputs = model(features, enroll) if isinstance(outputs, (list, tuple)): outputs = outputs[0] if torch.min(outputs.max(dim=1).values) > 0: outputs = ((outputs / abs(outputs).max(dim=1, keepdim=True)[0] * 0.9).cpu().numpy()) else: outputs = outputs.cpu().numpy() if sign_save_wav: file1 = os.path.join( save_audio_dir, f"Utt{total_cnt + 1}-{key[0]}-T{spk[0]}.wav", ) soundfile.write(file1, outputs[0], sample_rate) file2 = os.path.join( save_audio_dir, f"Utt{total_cnt + 1}-{key[1]}-T{spk[1]}.wav", ) soundfile.write(file2, outputs[1], sample_rate) ref = targets.cpu().numpy() ests = outputs mix = features.cpu().numpy() if ests[0].size != ref[0].size: end = min(ests[0].size, ref[0].size, mix[0].size) ests_1 = ests[0][:end] ref_1 = ref[0][:end] mix_1 = mix[0][:end] SISNR1, delta1 = cal_SISNRi(ests_1, ref_1, mix_1) else: SISNR1, delta1 = cal_SISNRi(ests[0], ref[0], mix[0]) logger.info( "Num={} | Utt={} | Target speaker={} | SI-SNR={:.2f} | SI-SNRi={:.2f}" .format(total_cnt + 1, key[0], spk[0], SISNR1, delta1)) total_SISNR += SISNR1 total_SISNRi += delta1 total_cnt += 1 if delta1 > 1: accept_cnt += 1 if ests[1].size != ref[1].size: end = min(ests[1].size, ref[1].size, mix[1].size) ests_2 = ests[1][:end] ref_2 = ref[1][:end] mix_2 = mix[1][:end] SISNR2, delta2 = cal_SISNRi(ests_2, ref_2, mix_2) else: SISNR2, delta2 = cal_SISNRi(ests[1], ref[1], mix[1]) logger.info( "Num={} | Utt={} | Target speaker={} | SI-SNR={:.2f} | SI-SNRi={:.2f}" .format(total_cnt + 1, key[1], spk[1], SISNR2, delta2)) total_SISNR += SISNR2 total_SISNRi += delta2 total_cnt += 1 if delta2 > 1: accept_cnt += 1 # if (i + 1) == test_iter: # break end = time.time() # generate the scp file of the enhanced speech for scoring if sign_save_wav: generate_enahnced_scp(os.path.abspath(save_audio_dir), extension="wav") logger.info("Time Elapsed: {:.1f}s".format(end - start)) logger.info("Average SI-SNR: {:.2f}".format(total_SISNR / total_cnt)) logger.info("Average SI-SNRi: {:.2f}".format(total_SISNRi / total_cnt)) logger.info( "Acceptance rate of Utterances with SI-SDRi > 1 dB: {:.2f}".format( accept_cnt / total_cnt * 100)) if __name__ == "__main__": fire.Fire(infer) ================================================ FILE: wesep/bin/score.py ================================================ # ported from # https://github.com/espnet/espnet/blob/master/espnet2/bin/enh_scoring.py import argparse import logging import sys from pathlib import Path from typing import Dict, List, Union import numpy as np from mir_eval.separation import bss_eval_sources from pystoi import stoi from wesep.utils.datadir_writer import DatadirWriter from wesep.utils.file_utils import SoundScpReader from wesep.utils.score import cal_SISNR from wesep.utils.utils import ArgumentParser, get_commandline_args, str2bool def get_readers(scps: List[str], dtype: str): readers = [SoundScpReader(f, dtype=dtype) for f in scps] audio_format = "sound" return readers, audio_format def read_audio(reader, key, audio_format="sound"): if audio_format == "sound": return reader[key][1] else: raise ValueError(f"Unknown audio format: {audio_format}") def scoring( output_dir: str, dtype: str, log_level: Union[int, str], key_file: str, ref_scp: List[str], inf_scp: List[str], ref_channel: int, use_dnsmos: bool, dnsmos_args: Dict, use_pesq: bool, ): logging.basicConfig( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) if use_dnsmos: if dnsmos_args["mode"] == "local": from 
wesep.utils.dnsmos import DNSMOS_local if not Path(dnsmos_args["primary_model"]).exists(): raise ValueError( f"The primary model {dnsmos_args['primary_model']} doesn't exist." " You can download the model from https://github.com/microsoft/" "DNS-Challenge/tree/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx") if not Path(dnsmos_args["p808_model"]).exists(): raise ValueError( f"The P808 model {dnsmos_args['p808_model']} doesn't exist." " You can download the model from https://github.com/microsoft/" "DNS-Challenge/tree/master/DNSMOS/DNSMOS/model_v8.onnx") dnsmos = DNSMOS_local( dnsmos_args["primary_model"], dnsmos_args["p808_model"], use_gpu=dnsmos_args["use_gpu"], convert_to_torch=dnsmos_args["convert_to_torch"], gpu_device=dnsmos_args["gpu_device"] - 1, ) logging.warning("Using local DNSMOS models for evaluation") elif dnsmos_args["mode"] == "web": from wesep.utils.dnsmos import DNSMOS_web if not dnsmos_args["auth_key"]: raise ValueError( "Please specify the authentication key for access to the Web-API. " "You can apply for the AUTH_KEY at https://github.com/microsoft/" "DNS-Challenge/blob/master/DNSMOS/README.md#to-use-the-web-api" ) dnsmos = DNSMOS_web(dnsmos_args["auth_key"]) logging.warning("Using the DNSMOS Web-API for evaluation") else: dnsmos = None if use_pesq: try: from pesq import PesqError, pesq logging.warning("Using the PESQ package for evaluation") except ImportError: raise ImportError( "Please install pesq and retry: pip install pesq") from None else: pesq = None assert len(ref_scp) == len(inf_scp), "len(ref_scp) != len(inf_scp)" num_spk = len(ref_scp) keys = [ line.rstrip().split(maxsplit=1)[0] for line in open(key_file, encoding="utf-8") ] ref_readers, ref_audio_format = get_readers(ref_scp, dtype) inf_readers, inf_audio_format = get_readers(inf_scp, dtype) # get sample rate retval = ref_readers[0][keys[0]] if ref_audio_format == "kaldi_ark": sample_rate = ref_readers[0].rate elif ref_audio_format == "sound": sample_rate = retval[0] else: raise NotImplementedError(ref_audio_format) assert sample_rate is not None, (sample_rate, ref_audio_format) # check keys for inf_reader, ref_reader in zip(inf_readers, ref_readers): assert inf_reader.keys() == ref_reader.keys() with DatadirWriter(output_dir) as writer: for n, key in enumerate(keys): logging.info(f"[{n}] Scoring {key}") ref_audios = [ read_audio(ref_reader, key, audio_format=ref_audio_format) for ref_reader in ref_readers ] inf_audios = [ read_audio(inf_reader, key, audio_format=inf_audio_format) for inf_reader in inf_readers ] ref = np.array(ref_audios) inf = np.array(inf_audios) if ref.ndim > inf.ndim: # multi-channel reference and single-channel output ref = ref[..., ref_channel] elif ref.ndim < inf.ndim: # single-channel reference and multi-channel output inf = inf[..., ref_channel] elif ref.ndim == inf.ndim == 3: # multi-channel reference and output ref = ref[..., ref_channel] inf = inf[..., ref_channel] assert ref.shape == inf.shape, (ref.shape, inf.shape) sdr, sir, sar, perm = bss_eval_sources(ref, inf, compute_permutation=True) for i in range(num_spk): stoi_score = stoi(ref[i], inf[int(perm[i])], fs_sig=sample_rate) estoi_score = stoi( ref[i], inf[int(perm[i])], fs_sig=sample_rate, extended=True, ) si_snr_score = cal_SISNR( ref[i], inf[int(perm[i])], ) if dnsmos: dnsmos_score = dnsmos(inf[int(perm[i])], sample_rate) writer[f"OVRL_spk{i + 1}"][key] = str(dnsmos_score["OVRL"]) writer[f"SIG_spk{i + 1}"][key] = str(dnsmos_score["SIG"]) writer[f"BAK_spk{i + 1}"][key] = str(dnsmos_score["BAK"]) writer[f"P808_MOS_spk{i + 1}"][key] 
= str( dnsmos_score["P808_MOS"]) if pesq: if sample_rate == 8000: mode = "nb" elif sample_rate == 16000: mode = "wb" else: raise ValueError( "sample rate must be 8000 or 16000 for PESQ evaluation, " f"but got {sample_rate}") pesq_score = pesq( sample_rate, ref[i], inf[int(perm[i])], mode=mode, on_error=PesqError.RETURN_VALUES, ) if pesq_score == PesqError.NO_UTTERANCES_DETECTED: logging.warning( f"[PESQ] Error: No utterances detected for {key}. " "Skipping this utterance.") else: writer[f"PESQ_{mode.upper()}_spk{i + 1}"][key] = str( pesq_score) writer[f"STOI_spk{i + 1}"][key] = str(stoi_score * 100) # in percentage writer[f"ESTOI_spk{i + 1}"][key] = str(estoi_score * 100) writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score) writer[f"SDR_spk{i + 1}"][key] = str(sdr[i]) writer[f"SAR_spk{i + 1}"][key] = str(sar[i]) writer[f"SIR_spk{i + 1}"][key] = str(sir[i]) # save permutation assigned script file if i < len(ref_scp): if inf_audio_format == "sound": writer[f"wav_spk{i + 1}"][key] = inf_readers[ perm[i]].data[key] elif inf_audio_format == "kaldi_ark": # NOTE: SegmentsExtractor is not supported writer[f"wav_spk{i + 1}"][key] = inf_readers[ perm[i]].loader._dict[key] else: raise ValueError( f"Unknown audio format: {inf_audio_format}") def get_parser(): parser = ArgumentParser( description="Frontend inference", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # Note(kamo): Use '_' instead of '-' as separator. # '-' is confusing if written in yaml. parser.add_argument( "--log_level", type=lambda x: x.upper(), default="INFO", choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), help="The verbose level of logging", ) parser.add_argument("--output_dir", type=str, required=True) parser.add_argument( "--dtype", default="float32", choices=["float16", "float32", "float64"], help="Data type", ) group = parser.add_argument_group("Input data related") group.add_argument( "--ref_scp", type=str, required=True, action="append", ) group.add_argument( "--inf_scp", type=str, required=True, action="append", ) group.add_argument("--key_file", type=str) group.add_argument("--ref_channel", type=int, default=0) group = parser.add_argument_group("DNSMOS related") group.add_argument("--use_dnsmos", type=str2bool, default=False) group.add_argument( "--dnsmos_mode", type=str, choices=("local", "web"), default="local", help="Use local DNSMOS model or web API for DNSMOS calculation", ) group.add_argument( "--dnsmos_auth_key", type=str, default="", help="Required if dnsmsos_mode='web'", ) group.add_argument( "--dnsmos_use_gpu", type=str2bool, default=False, help="used when dnsmsos_mode='local'", ) group.add_argument( "--dnsmos_convert_to_torch", type=str2bool, default=False, help="used when dnsmsos_mode='local'", ) group.add_argument("--dnsmos_primary_model", type=str, default="./DNSMOS/sig_bak_ovr.onnx", help="Path to the primary DNSMOS model. " "Required if dnsmsos_mode='local'") group.add_argument( "--dnsmos_p808_model", type=str, default="./DNSMOS/model_v8.onnx", help="Path to the p808 model. Required if dnsmsos_mode='local'", ) group.add_argument("--dnsmos_gpu_device", type=int, default=None, help="gpu device to use for DNSMOS evaluation. 
" "Used when dnsmsos_mode='local'") group = parser.add_argument_group("PESQ related") group.add_argument( "--use_pesq", type=str2bool, default=False, help="Bebore setting this to True, please make sure that you or " "your institution have the license " "(check https://www.itu.int/rec/T-REC-P.862-200511-I!Amd2/en) to report PESQ", ) return parser def main(cmd=None): print(get_commandline_args(), file=sys.stderr) parser = get_parser() args = parser.parse_args(cmd) kwargs = vars(args) kwargs.pop("config", None) dnsmos_args = { "mode": kwargs.pop("dnsmos_mode"), "auth_key": kwargs.pop("dnsmos_auth_key"), "primary_model": kwargs.pop("dnsmos_primary_model"), "p808_model": kwargs.pop("dnsmos_p808_model"), "use_gpu": kwargs.pop("dnsmos_use_gpu"), "convert_to_torch": kwargs.pop("dnsmos_convert_to_torch"), "gpu_device": kwargs.pop("dnsmos_gpu_device"), } kwargs["dnsmos_args"] = dnsmos_args scoring(**kwargs) if __name__ == "__main__": main() ================================================ FILE: wesep/bin/train.py ================================================ # Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com) # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import logging import os import re from pprint import pformat import fire import matplotlib.pyplot as plt import tableprint as tp import torch import torch.distributed as dist import yaml from torch.utils.data import DataLoader import wesep.utils.schedulers as schedulers from wesep.dataset.dataset import Dataset, tse_collate_fn, tse_collate_fn_2spk from wesep.models import get_model from wesep.utils.checkpoint import ( load_checkpoint, load_pretrained_model, save_checkpoint, ) from wesep.utils.executor import Executor from wesep.utils.file_utils import ( load_speaker_embeddings, read_label_file, read_vec_scp_file, ) from wesep.utils.losses import parse_loss from wesep.utils.utils import parse_config_or_kwargs, set_seed, setup_logger MAX_NUM_log_files = 100 # The maximum number of log-files to be kept logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR) def train(config="conf/config.yaml", **kwargs): """Trains a model on the given features and spk labels. :config: A training configuration. 
Note that all parameters in the config can also be manually adjusted with --ARG VALUE :returns: None """ # print(kwargs) configs = parse_config_or_kwargs(config, **kwargs) checkpoint = configs.get("checkpoint", None) if checkpoint is not None: checkpoint = os.path.realpath(checkpoint) find_unused_parameters = configs.get("find_unused_parameters", False) # dist configs rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) gpu = int(configs["gpus"][rank]) torch.cuda.set_device(gpu) dist.init_process_group(backend="nccl") # Log rotation model_dir = os.path.join(configs["exp_dir"], "models") logger = setup_logger(rank, configs["exp_dir"], gpu, MAX_NUM_log_files) print("-------------------", dist.get_rank(), world_size) if world_size > 1: logger.info("training on multiple gpus, this gpu {}".format(gpu)) if rank == 0: logger.info("exp_dir is: {}".format(configs["exp_dir"])) logger.info("<== Passed Arguments ==>") # Print arguments into logs for line in pformat(configs).split("\n"): logger.info(line) # seed set_seed(configs["seed"] + rank) # loss criterion = configs.get("loss", None) if criterion: criterion = parse_loss(criterion) else: criterion = [ parse_loss("SISDR"), ] loss_posi = configs["loss_args"].get( "loss_posi", [[ 0, ]], ) loss_weight = configs["loss_args"].get( "loss_weight", [[ 1.0, ]], ) loss_args = (loss_posi, loss_weight) # embeds tr_spk_embeds = configs.get("train_spk_embeds", None) tr_single_utt2spk = configs["train_utt2spk"] joint_training = configs["model_args"]["tse_model"].get( "joint_training", False) multi_task = configs["model_args"]["tse_model"].get("multi_task", False) dict_spk = {} if not joint_training and tr_spk_embeds: tr_spk2embed_dict = load_speaker_embeddings(tr_spk_embeds, tr_single_utt2spk) multi_task = None else: with open(configs["train_spk2utt"], "r") as f: tr_spk2embed_dict = json.load(f) if multi_task: for i, j in enumerate(tr_spk2embed_dict.keys( )): # Generate the dictionary for speakers in training set dict_spk[j] = i with open(tr_single_utt2spk, "r") as f: tr_lines = f.readlines() val_spk_embeds = configs.get("val_spk_embeds", None) val_spk1_enroll = configs["val_spk1_enroll"] val_spk2_enroll = configs["val_spk2_enroll"] if not joint_training and val_spk_embeds: val_spk2embed_dict = read_vec_scp_file(val_spk_embeds) else: val_spk2embed_dict = read_label_file(configs["val_spk2utt"]) val_lines = len(val_spk2embed_dict) val_spk1_embed = read_label_file(val_spk1_enroll) val_spk2_embed = read_label_file(val_spk2_enroll) # dataset and dataloader train_dataset = Dataset( configs["data_type"], configs["train_data"], configs["dataset_args"], tr_spk2embed_dict, None, None, state="train", joint_training=joint_training, dict_spk=dict_spk, whole_utt=configs.get("whole_utt", False), repeat_dataset=configs.get("repeat_dataset", True), noise_prob=configs["dataset_args"].get("noise_prob", 0), reverb_prob=configs["dataset_args"].get("reverb_prob", 0), noise_enroll_prob=configs["dataset_args"].get("noise_enroll_prob", 0), reverb_enroll_prob=configs["dataset_args"].get("reverb_enroll_prob", 0), specaug_enroll_prob=configs["dataset_args"].get( "specaug_enroll_prob", 0), online_mix=configs["dataset_args"].get("online_mix", False), noise_lmdb_file=configs["dataset_args"].get("noise_lmdb_file", None), ) val_dataset = Dataset(configs["data_type"], configs["val_data"], configs["dataset_args"], val_spk2embed_dict, val_spk1_embed, val_spk2_embed, state="val", joint_training=joint_training, whole_utt=configs.get("whole_utt", False), repeat_dataset=True, 
online_mix=False, noise_prob=0, reverb_prob=0, noise_enroll_prob=0, reverb_enroll_prob=0, specaug_enroll_prob=0) train_dataloader = DataLoader(train_dataset, **configs["dataloader_args"], collate_fn=tse_collate_fn) val_dataloader = DataLoader( val_dataset, **configs["dataloader_args"], collate_fn=tse_collate_fn_2spk, ) batch_size = configs["dataloader_args"]["batch_size"] if configs["dataset_args"].get("sample_num_per_epoch", 0) > 0: sample_num_per_epoch = configs["dataset_args"]["sample_num_per_epoch"] else: sample_num_per_epoch = len(tr_lines) // 2 epoch_iter = sample_num_per_epoch // world_size // batch_size val_iter = val_lines // 2 // world_size // batch_size if rank == 0: logger.info("<== Dataloaders ==>") logger.info("train dataloaders created") logger.info("epoch iteration number: {}".format(epoch_iter)) logger.info("val iteration number: {}".format(val_iter)) # model model_list = [] scheduler_list = [] optimizer_list = [] logger.info("<== Model ==>") model = get_model( configs["model"]["tse_model"])(**configs["model_args"]["tse_model"]) num_params = sum(param.numel() for param in model.parameters()) if rank == 0: logger.info("tse_model size: {:.2f} M".format(num_params / 1e6)) # print model for line in pformat(model).split("\n"): logger.info(line) # ddp_model model.cuda() ddp_model = torch.nn.parallel.DistributedDataParallel( model, find_unused_parameters=find_unused_parameters) device = torch.device("cuda") if rank == 0: logger.info("<== TSE Model Loss ==>") logger.info("loss criterion is: " + str(configs["loss"])) configs["optimizer_args"]["tse_model"]["lr"] = configs["scheduler_args"][ "tse_model"]["initial_lr"] optimizer = getattr(torch.optim, configs["optimizer"]["tse_model"])( ddp_model.parameters(), **configs["optimizer_args"]["tse_model"]) if rank == 0: logger.info("<== TSE Model Optimizer ==>") logger.info("optimizer is: " + configs["optimizer"]["tse_model"]) # scheduler configs["scheduler_args"]["tse_model"]["num_epochs"] = configs[ "num_epochs"] configs["scheduler_args"]["tse_model"]["epoch_iter"] = epoch_iter configs["scheduler_args"]["scale_ratio"] = 1.0 scheduler = getattr(schedulers, configs["scheduler"]["tse_model"])( optimizer, **configs["scheduler_args"]["tse_model"]) if rank == 0: logger.info("<== TSE Model Scheduler ==>") logger.info("scheduler is: " + configs["scheduler"]["tse_model"]) if configs["model_init"]["tse_model"] is not None: logger.info("Load initial model from {}".format( configs["model_init"]["tse_model"])) load_pretrained_model(ddp_model, configs["model_init"]["tse_model"]) elif checkpoint is None: logger.info("Train model from scratch ...") for c in criterion: c = c.to(device) # append to list model_list.append(ddp_model) optimizer_list.append(optimizer) scheduler_list.append(scheduler) scaler = torch.cuda.amp.GradScaler(enabled=configs["enable_amp"]) # If specify checkpoint, load some info from checkpoint. 
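# A minimal resume sketch (hypothetical paths; assumes a torchrun-style
# launcher, since this script reads RANK and WORLD_SIZE from the environment):
#
#     torchrun --nproc_per_node=2 wesep/bin/train.py \
#         --config exp/bsrnn/config.yaml \
#         --checkpoint exp/bsrnn/models/checkpoint_10.pt
#
# The epoch index is recovered from the checkpoint file name below, so the
# example above would continue training from epoch 11.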
if checkpoint is not None: load_checkpoint(model_list, optimizer_list, scheduler_list, scaler, checkpoint) start_epoch = ( int(re.findall(r"(?<=checkpoint_)\d*(?=.pt)", checkpoint)[0]) + 1) logger.info("Load checkpoint: {}".format(checkpoint)) else: start_epoch = 1 logger.info("start_epoch: {}".format(start_epoch)) # save config.yaml if rank == 0: saved_config_path = os.path.join(configs["exp_dir"], "config.yaml") with open(saved_config_path, "w") as fout: data = yaml.dump(configs) fout.write(data) # training dist.barrier(device_ids=[gpu]) # synchronize here if rank == 0: logger.info("<========== Training process ==========>") header = ["Train/Val", "Epoch", "iter", "Loss", "LR"] for line in tp.header(header, width=10, style="grid").split("\n"): logger.info(line) dist.barrier(device_ids=[gpu]) # synchronize here executor = Executor() executor.step = 0 train_losses = [] val_losses = [] for epoch in range(start_epoch, configs["num_epochs"] + 1): train_dataset.set_epoch(epoch) # train_loss_com train_loss, _ = executor.train( train_dataloader, model_list, epoch_iter, optimizer_list, criterion, scheduler_list, scaler=scaler, epoch=epoch, logger=logger, enable_amp=configs["enable_amp"], clip_grad=configs["clip_grad"], log_batch_interval=configs["log_batch_interval"], device=device, se_loss_weight=loss_args, multi_task=multi_task, SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", 0), fbank_args=configs["dataset_args"].get('fbank_args', None), sample_rate=configs["dataset_args"]['resample_rate'], speaker_feat=configs["dataset_args"].get('speaker_feat', True) ) val_loss, _ = executor.cv( val_dataloader, model_list, val_iter, criterion, epoch=epoch, logger=logger, enable_amp=configs["enable_amp"], log_batch_interval=configs["log_batch_interval"], device=device, ) if rank == 0: logger.info("Epoch {} Train info train_loss {}".format( epoch, train_loss)) logger.info("Epoch {} Val info val_loss {}".format( epoch, val_loss)) train_losses.append(train_loss) val_losses.append(val_loss) best_loss = val_loss scheduler.best = best_loss # plot plt.figure() plt.title("Loss of Train and Validation") x = list(range(start_epoch, epoch + 1)) plt.plot(x, train_losses, "b-", label="Train Loss", linewidth=0.8) plt.plot(x, val_losses, "c-", label="Validation Loss", linewidth=0.8) plt.legend() plt.xlabel("Epoch") plt.ylabel("Loss") plt.xticks(range(start_epoch, epoch + 1, 1)) plt.savefig( f"{configs['exp_dir']}/{configs['model']['tse_model']}.png") plt.close() if rank == 0: if (epoch % configs["save_epoch_interval"] == 0 or epoch >= configs["num_epochs"] - configs["num_avg"]): save_checkpoint( model_list, optimizer_list, scheduler_list, scaler, os.path.join(model_dir, "checkpoint_{}.pt".format(epoch)), ) try: os.symlink( "checkpoint_{}.pt".format(epoch), os.path.join(model_dir, "latest_checkpoint.pt"), ) except FileExistsError: os.remove(os.path.join(model_dir, "latest_checkpoint.pt")) os.symlink( "checkpoint_{}.pt".format(epoch), os.path.join(model_dir, "latest_checkpoint.pt"), ) if rank == 0: os.symlink( "checkpoint_{}.pt".format(configs["num_epochs"]), os.path.join(model_dir, "final_checkpoint.pt"), ) logger.info(tp.bottom(len(header), width=10, style="grid")) if __name__ == "__main__": fire.Fire(train) ================================================ FILE: wesep/bin/train_gan.py ================================================ # Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com) # 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) # # Licensed under the Apache License, Version 2.0 (the 
"License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import logging import os import re from pprint import pformat import fire import matplotlib.pyplot as plt import tableprint as tp import torch import torch.distributed as dist import yaml from torch.utils.data import DataLoader import wesep.utils.schedulers as schedulers from wesep.dataset.dataset import Dataset, tse_collate_fn, tse_collate_fn_2spk from wesep.models import get_model from wesep.utils.checkpoint import ( load_checkpoint, load_pretrained_model, save_checkpoint, ) from wesep.utils.executor_gan import ExecutorGAN from wesep.utils.file_utils import ( load_speaker_embeddings, read_label_file, read_vec_scp_file, ) from wesep.utils.losses import parse_loss from wesep.utils.utils import parse_config_or_kwargs, set_seed, setup_logger MAX_NUM_log_files = 100 # The maximum number of log-files to be kept logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR) def train(config="conf/config.yaml", **kwargs): """Trains a model on the given features and spk labels. :config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG VALUE :returns: None """ configs = parse_config_or_kwargs(config, **kwargs) checkpoint = configs.get("checkpoint", None) if checkpoint is not None: checkpoint = os.path.realpath(checkpoint) find_unused_parameters = configs.get("find_unused_parameters", False) gan_loss_weight = configs.get("gan_loss_weight", 0.05) # dist configs rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) gpu = int(configs["gpus"][rank]) torch.cuda.set_device(gpu) dist.init_process_group(backend="nccl") # Log rotation model_dir = os.path.join(configs["exp_dir"], "models") logger = setup_logger(rank, configs["exp_dir"], gpu, MAX_NUM_log_files) print("-------------------", dist.get_rank(), world_size) if world_size > 1: logger.info("training on multiple gpus, this gpu {}".format(gpu)) if rank == 0: logger.info("exp_dir is: {}".format(configs["exp_dir"])) logger.info("<== Passed Arguments ==>") # Print arguments into logs for line in pformat(configs).split("\n"): logger.info(line) # seed set_seed(configs["seed"] + rank) # support multiple losses, e.g., criterion = [SISNR, CE] criterion = configs.get("loss", None) if criterion: criterion = parse_loss(criterion) else: criterion = [ parse_loss("SISNR"), ] # loss_posi is used to store the indices when the model has multiple outputs # loss_posi[i][j] stores the index of the output used for i-th criterion, # that is, output[loss_posi[i][j]] is used for the i-th criterion. loss_posi = configs["loss_args"].get( "loss_posi", [[ 0, ]], ) # loss_weight[i][j] stores the loss weight of output[loss_posi[i][j]] for the i-th criterion. 
# noqa loss_weight = configs["loss_args"].get( "loss_weight", [[ 1.0, ]], ) loss_args = (loss_posi, loss_weight) # embeds tr_spk_embeds = configs["train_spk_embeds"] tr_single_utt2spk = configs["train_utt2spk"] joint_training = configs["model_args"]["tse_model"].get( "joint_training", False) multi_task = configs["model_args"]["tse_model"].get("multi_task", False) # dict_spk: {spk_id: int_label} dict_spk = {} if not joint_training: tr_spk2embed_dict = load_speaker_embeddings(tr_spk_embeds, tr_single_utt2spk) multi_task = False else: with open(configs["train_spk2utt"], "r") as f: tr_spk2embed_dict = json.load(f) # tr_spk2embed_dict: {spk_id: [[spk_id, wav_path], ...]} if multi_task: for i, j in enumerate(tr_spk2embed_dict.keys( )): # Generate the dictionary for speakers in training set dict_spk[j] = i with open(tr_single_utt2spk, "r") as f: tr_lines = f.readlines() val_spk_embeds = configs["val_spk_embeds"] val_spk1_enroll = configs["val_spk1_enroll"] val_spk2_enroll = configs["val_spk2_enroll"] if not joint_training: val_spk2embed_dict = read_vec_scp_file(val_spk_embeds) else: val_spk2embed_dict = read_label_file(configs["val_spk2utt"]) val_spk1_embed = read_label_file(val_spk1_enroll) val_spk2_embed = read_label_file(val_spk2_enroll) with open(val_spk_embeds, "r") as f: val_lines = f.readlines() # dataset and dataloader train_dataset = Dataset( configs["data_type"], configs["train_data"], configs["dataset_args"], tr_spk2embed_dict, None, None, state="train", joint_training=joint_training, dict_spk=dict_spk, whole_utt=configs.get("whole_utt", False), repeat_dataset=configs.get("repeat_dataset", True), reverb=configs["dataset_args"].get("reverb", False), noise=configs["dataset_args"].get("noise", False), noise_lmdb_file=configs["dataset_args"].get("noise_lmdb_file", None), online_mix=configs["dataset_args"].get("online_mix", False), ) val_dataset = Dataset( configs["data_type"], configs["val_data"], configs["dataset_args"], val_spk2embed_dict, val_spk1_embed, val_spk2_embed, state="val", joint_training=joint_training, whole_utt=configs.get("whole_utt", False), repeat_dataset=True, reverb=False, online_mix=False, ) train_dataloader = DataLoader(train_dataset, **configs["dataloader_args"], collate_fn=tse_collate_fn) val_dataloader = DataLoader( val_dataset, **configs["dataloader_args"], collate_fn=tse_collate_fn_2spk, ) batch_size = configs["dataloader_args"]["batch_size"] if configs["dataset_args"].get("sample_num_per_epoch", 0) > 0: sample_num_per_epoch = configs["dataset_args"]["sample_num_per_epoch"] else: sample_num_per_epoch = len(tr_lines) // 2 epoch_iter = sample_num_per_epoch // world_size // batch_size val_iter = len(val_lines) // 2 // world_size // batch_size if rank == 0: logger.info("<== Dataloaders ==>") logger.info("train dataloaders created") logger.info("epoch iteration number: {}".format(epoch_iter)) logger.info("val iteration number: {}".format(val_iter)) # model model_list = [] scheduler_list = [] optimizer_list = [] logger.info("<== Model ==>") model = get_model( configs["model"]["tse_model"])(**configs["model_args"]["tse_model"]) num_params = sum(param.numel() for param in model.parameters()) if rank == 0: logger.info("tse_model size: {}".format(num_params)) # print model for line in pformat(model).split("\n"): logger.info(line) # ddp_model model.cuda() ddp_model = torch.nn.parallel.DistributedDataParallel( model, find_unused_parameters=find_unused_parameters) device = torch.device("cuda") if rank == 0: logger.info("<== TSE Model Loss ==>") logger.info("loss criterion is: 
" + str(configs["loss"])) configs["optimizer_args"]["tse_model"]["lr"] = configs["scheduler_args"][ "tse_model"]["initial_lr"] optimizer = getattr(torch.optim, configs["optimizer"]["tse_model"])( ddp_model.parameters(), **configs["optimizer_args"]["tse_model"]) if rank == 0: logger.info("<== TSE Model Optimizer ==>") logger.info("optimizer is: " + configs["optimizer"]["tse_model"]) # scheduler configs["scheduler_args"]["tse_model"]["num_epochs"] = configs[ "num_epochs"] configs["scheduler_args"]["tse_model"]["epoch_iter"] = epoch_iter configs["scheduler_args"]["scale_ratio"] = 1.0 scheduler = getattr(schedulers, configs["scheduler"]["tse_model"])( optimizer, **configs["scheduler_args"]["tse_model"]) if rank == 0: logger.info("<== TSE Model Scheduler ==>") logger.info("scheduler is: " + configs["scheduler"]["tse_model"]) if configs["model_init"]["tse_model"] is not None: logger.info("Load initial model from {}".format( configs["model_init"]["tse_model"])) load_pretrained_model(ddp_model, configs["model_init"]["tse_model"]) elif checkpoint is None: logger.info("Train model from scratch ...") for c in criterion: c = c.to(device) # append to list model_list.append(ddp_model) optimizer_list.append(optimizer) scheduler_list.append(scheduler) scaler = torch.cuda.amp.GradScaler(enabled=configs["enable_amp"]) # discriminator discriminator = get_model(configs["model"]["discriminator"])( **configs["model_args"]["discriminator"]) num_params = sum(param.numel() for param in discriminator.parameters()) # optimizer configs["optimizer_args"]["discriminator"]["lr"] = configs[ "scheduler_args"]["discriminator"]["initial_lr"] # scheduler configs["scheduler_args"]["discriminator"]["num_epochs"] = configs[ "num_epochs"] configs["scheduler_args"]["discriminator"]["epoch_iter"] = epoch_iter configs["scheduler_args"]["discriminator"]["scale_ratio"] = 1.0 # ddp model discriminator.cuda() ddp_discriminator = torch.nn.parallel.DistributedDataParallel( discriminator, find_unused_parameters=find_unused_parameters) optimizer_d = getattr(torch.optim, configs["optimizer"]["discriminator"])( ddp_discriminator.parameters(), **configs["optimizer_args"]["discriminator"], ) scheduler_d = getattr(schedulers, configs["scheduler"]["discriminator"])( optimizer_d, **configs["scheduler_args"]["discriminator"]) # initialize discriminator if configs["model_init"]["discriminator"] is not None: logger.info("Load initial discriminator from {}".format( configs["model_init"]["discriminator"])) load_pretrained_model( ddp_discriminator, configs["model_init"]["discriminator"], type="discriminator", ) elif checkpoint is None: logger.info("Train discriminator from scratch ...") # If specify checkpoint, load some info from checkpoint. 
if checkpoint is not None: load_checkpoint(model_list, optimizer_list, scheduler_list, scaler, checkpoint) start_epoch = ( int(re.findall(r"(?<=checkpoint_)\d*(?=.pt)", checkpoint)[0]) + 1) logger.info("Load checkpoint: {}".format(checkpoint)) else: start_epoch = 1 model_list.append(ddp_discriminator) optimizer_list.append(optimizer_d) scheduler_list.append(scheduler_d) if rank == 0: logger.info("<== Discriminator Model ==>") logger.info("discriminator size: {}".format(num_params)) for line in pformat(discriminator).split("\n"): logger.info(line) logger.info("<== Discriminator Optimizer ==>") logger.info("optimizer is: " + configs["optimizer"]["discriminator"]) logger.info("<== Discriminator Scheduler ==>") logger.info("scheduler is: " + configs["scheduler"]["discriminator"]) # save config.yaml saved_config_path = os.path.join(configs["exp_dir"], "config.yaml") with open(saved_config_path, "w") as fout: data = yaml.dump(configs) fout.write(data) logger.info("start_epoch: {}".format(start_epoch)) # training dist.barrier(device_ids=[gpu]) # synchronize here if rank == 0: logger.info("<========== Training process ==========>") header = [ "Train/Val", "Epoch", "iter", "SE_Loss", "G_Loss", "D_Loss", "LR", ] for line in tp.header(header, width=10, style="grid").split("\n"): logger.info(line) dist.barrier(device_ids=[gpu]) # synchronize here executor = ExecutorGAN() executor.step = 0 train_losses = [] val_losses = [] train_d_losses = [] val_d_losses = [] for epoch in range(start_epoch, configs["num_epochs"] + 1): train_dataset.set_epoch(epoch) train_loss, train_d_loss = executor.train( train_dataloader, model_list, epoch_iter, optimizer_list, criterion, scheduler_list, scaler=scaler, epoch=epoch, logger=logger, enable_amp=configs["enable_amp"], clip_grad=configs["clip_grad"], log_batch_interval=configs["log_batch_interval"], device=device, se_loss_weight=loss_args, gan_loss_weight=gan_loss_weight, multi_task=multi_task, ) val_loss, val_d_loss = executor.cv( val_dataloader, model_list, val_iter, criterion, epoch=epoch, logger=logger, enable_amp=configs["enable_amp"], log_batch_interval=configs["log_batch_interval"], device=device, ) if rank == 0: logger.info( "Epoch {} Train info train_loss {}, train_d_loss {}".format( epoch, train_loss, train_d_loss)) logger.info("Epoch {} Val info val_loss {}, val_d_loss {}".format( epoch, val_loss, val_d_loss)) train_losses.append(train_loss) train_d_losses.append(train_d_loss) val_losses.append(val_loss) val_d_losses.append(val_d_loss) best_loss = val_loss scheduler.best = best_loss # plot plt.figure() plt.title("Loss of Train and Validation") x = list(range(start_epoch, epoch + 1)) plt.plot(x, train_losses, "b-", label="train_G_loss", linewidth=0.8) plt.plot(x, train_d_losses, "r-", label="train_D_loss", linewidth=0.8) plt.plot(x, val_losses, "c-", label="val_G_loss", linewidth=0.8) plt.plot(x, val_d_losses, "m-", label="val_D_loss", linewidth=0.8) plt.legend() plt.xlabel("Epoch") plt.ylabel("Loss") plt.xticks(range(start_epoch, epoch + 1, 1)) plt.savefig( f"{configs['exp_dir']}/{configs['model']['tse_model']}.png") plt.close() if rank == 0: if (epoch % configs["save_epoch_interval"] == 0 or epoch >= configs["num_epochs"] - configs["num_avg"]): save_checkpoint( model_list, optimizer_list, scheduler_list, scaler, os.path.join(model_dir, "checkpoint_{}.pt".format(epoch)), ) try: os.symlink( "checkpoint_{}.pt".format(epoch), os.path.join(model_dir, "latest_checkpoint.pt"), ) except FileExistsError: os.remove(os.path.join(model_dir, "latest_checkpoint.pt")) 
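# the symlink already exists: remove the stale link and re-create it pointing at the newest checkpoint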
os.symlink( "checkpoint_{}.pt".format(epoch), os.path.join(model_dir, "latest_checkpoint.pt"), ) if rank == 0: os.symlink( "checkpoint_{}.pt".format(configs["num_epochs"]), os.path.join(model_dir, "final_checkpoint.pt"), ) logger.info(tp.bottom(len(header), width=10, style="grid")) if __name__ == "__main__": fire.Fire(train) ================================================ FILE: wesep/cli/__init__.py ================================================ ================================================ FILE: wesep/cli/extractor.py ================================================ import os import sys from silero_vad import load_silero_vad, get_speech_timestamps import torch import torchaudio import torchaudio.compliance.kaldi as kaldi import yaml import soundfile from wesep.cli.hub import Hub from wesep.cli.utils import get_args from wesep.models import get_model from wesep.utils.checkpoint import load_pretrained_model from wesep.utils.utils import set_seed class Extractor: def __init__(self, model_dir: str): set_seed() config_path = os.path.join(model_dir, "config.yaml") model_path = os.path.join(model_dir, "avg_model.pt") with open(config_path, "r") as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) if 'spk_model_init' in configs['model_args']['tse_model']: configs['model_args']['tse_model']['spk_model_init'] = False self.model = get_model(configs["model"]["tse_model"])( **configs["model_args"]["tse_model"] ) load_pretrained_model(self.model, model_path) self.model.eval() self.vad = load_silero_vad() self.table = {} self.resample_rate = configs["dataset_args"].get("resample_rate", 16000) self.apply_vad = False self.device = torch.device("cpu") self.wavform_norm = True self.output_norm = True self.speaker_feat = configs["model_args"]["tse_model"].get("spk_feat", False) self.joint_training = configs["model_args"]["tse_model"].get( "joint_training", False ) def set_wavform_norm(self, wavform_norm: bool): self.wavform_norm = wavform_norm def set_resample_rate(self, resample_rate: int): self.resample_rate = resample_rate def set_vad(self, apply_vad: bool): self.apply_vad = apply_vad def set_device(self, device: str): self.device = torch.device(device) self.model = self.model.to(self.device) def set_output_norm(self, output_norm: bool): self.output_norm = output_norm def compute_fbank( self, wavform, sample_rate=16000, num_mel_bins=80, frame_length=25, frame_shift=10, cmn=True, ): feat = kaldi.fbank( wavform, num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, sample_frequency=sample_rate, ) if cmn: feat = feat - torch.mean(feat, 0) return feat def extract_speech(self, audio_path: str, audio_path_2: str): pcm_mix, sample_rate_mix = torchaudio.load( audio_path, normalize=self.wavform_norm ) pcm_enroll, sample_rate_enroll = torchaudio.load( audio_path_2, normalize=self.wavform_norm ) return self.extract_speech_from_pcm(pcm_mix, sample_rate_mix, pcm_enroll, sample_rate_enroll) def extract_speech_from_pcm(self, pcm_mix: torch.Tensor, sample_rate_mix: int, pcm_enroll: torch.Tensor, sample_rate_enroll: int): if self.apply_vad: # TODO(Binbin Zhang): Refine the segments logic, here we just # suppose there is only silence at the start/end of the speech # Only do vad on the enrollment vad_sample_rate = 16000 wav = pcm_enroll if wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True) if sample_rate_enroll != vad_sample_rate: transform = torchaudio.transforms.Resample( orig_freq=sample_rate_enroll, new_freq=vad_sample_rate ) wav = transform(wav) segments = get_speech_timestamps(wav, 
self.vad, return_seconds=True) pcmTotal = torch.Tensor() if len(segments) > 0: # remove all the silence for segment in segments: start = int(segment["start"] * sample_rate_enroll) end = int(segment["end"] * sample_rate_enroll) pcmTemp = pcm_enroll[0, start:end] pcmTotal = torch.cat([pcmTotal, pcmTemp], 0) pcm_enroll = pcmTotal.unsqueeze(0) else: # all silence, nospeech return None pcm_mix = pcm_mix.to(torch.float) if sample_rate_mix != self.resample_rate: pcm_mix = torchaudio.transforms.Resample( orig_freq=sample_rate_mix, new_freq=self.resample_rate )(pcm_mix) pcm_enroll = pcm_enroll.to(torch.float) if sample_rate_enroll != self.resample_rate: pcm_enroll = torchaudio.transforms.Resample( orig_freq=sample_rate_enroll, new_freq=self.resample_rate )(pcm_enroll) if self.joint_training: if self.speaker_feat: feats = self.compute_fbank( pcm_enroll, sample_rate=self.resample_rate, cmn=True ) feats = feats.unsqueeze(0) else: feats = pcm_enroll feats = feats.to(self.device) pcm_mix = pcm_mix.to(self.device) with torch.no_grad(): outputs = self.model(pcm_mix, feats) outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs target_speech = outputs.to(torch.device("cpu")) if self.output_norm: target_speech = ( target_speech / abs(target_speech).max(dim=1, keepdim=True).values * 0.9 ) return target_speech else: return None def load_model(language: str) -> Extractor: model_path = Hub.get_model(language) return Extractor(model_path) def load_model_local(model_dir: str) -> Extractor: return Extractor(model_dir) def main(): args = get_args() if args.pretrain == "": if args.bsrnn: model = load_model("bsrnn") else: model = load_model(args.language) else: model = load_model_local(args.pretrain) model.set_resample_rate(args.resample_rate) model.set_vad(args.vad) model.set_device(args.device) model.set_output_norm(args.output_norm) if args.task == "extraction": speech = model.extract_speech(args.audio_file, args.audio_file2) if speech is not None: if args.normalize_output: speech = speech / abs(speech).max(dim=1, keepdim=True).values * 0.9 soundfile.write(args.output_file, speech[0], args.resample_rate) print("Succeed, see {}".format(args.output_file)) else: print("Fails to extract the target speech") else: print("Unsupported task {}".format(args.task)) sys.exit(-1) if __name__ == "__main__": main() ================================================ FILE: wesep/cli/hub.py ================================================ # Copyright (c) 2022 Mddct(hamddct@gmail.com) # 2023 Binbin Zhang(binbzha@qq.com) # 2024 Shuai Wang(wsstriving@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
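# A minimal usage sketch (illustrative; the language key and target directory
# follow the Assets table and get_model() below):
#
#     from wesep.cli.hub import Hub
#     model_dir = Hub.get_model("english")   # downloads to ~/.wesep/english on first use
#     # model_dir then contains avg_model.pt and config.yaml used by the CLI Extractor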
import os import sys from pathlib import Path import tarfile import zipfile from urllib.request import urlretrieve import tqdm def download(url: str, dest: str, only_child=True): """download from url to dest""" assert os.path.exists(dest) print("Downloading {} to {}".format(url, dest)) def progress_hook(t): last_b = [0] def update_to(b=1, bsize=1, tsize=None): if tsize not in (None, -1): t.total = tsize displayed = t.update((b - last_b[0]) * bsize) last_b[0] = b return displayed return update_to # *.tar.gz name = url.split("?")[0].split("/")[-1] file_path = os.path.join(dest, name) with tqdm.tqdm( unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=(name) ) as t: urlretrieve( url, filename=file_path, reporthook=progress_hook(t), data=None ) t.total = t.n if name.endswith((".tar.gz", ".tar")): with tarfile.open(file_path) as f: if not only_child: f.extractall(dest) else: for tarinfo in f: if "/" not in tarinfo.name: continue name = os.path.basename(tarinfo.name) fileobj = f.extractfile(tarinfo) with open(os.path.join(dest, name), "wb") as writer: writer.write(fileobj.read()) elif name.endswith(".zip"): with zipfile.ZipFile(file_path, "r") as zip_ref: if not only_child: zip_ref.extractall(dest) else: for member in zip_ref.namelist(): member_path = os.path.relpath( member, start=os.path.commonpath(zip_ref.namelist()) ) print(member_path) if "/" not in member_path: continue name = os.path.basename(member_path) with zip_ref.open(member_path) as source, open( os.path.join(dest, name), "wb" ) as target: target.write(source.read()) class Hub(object): Assets = { "english": "bsrnn_ecapa_vox1.tar.gz", } # Hard coding of the URL ModelURLs = { "bsrnn_ecapa_vox1.tar.gz": ( "https://www.modelscope.cn/datasets/wenet/wesep_pretrained_models/" "resolve/master/bsrnn_ecapa_vox1.tar.gz" ), } def __init__(self) -> None: pass @staticmethod def get_model(lang: str) -> str: if lang not in Hub.Assets.keys(): print("ERROR: Unsupported lang {} !!!".format(lang)) sys.exit(1) # model = Hub.Assets[lang] model_name = Hub.Assets[lang] model_dir = os.path.join(Path.home(), ".wesep", lang) if not os.path.exists(model_dir): os.makedirs(model_dir) if set(["avg_model.pt", "config.yaml"]).issubset( set(os.listdir(model_dir)) ): return model_dir else: if model_name in Hub.ModelURLs: model_url = Hub.ModelURLs[model_name] download(model_url, model_dir) return model_dir else: print(f"ERROR: No URL found for model {model_name}") return None ================================================ FILE: wesep/cli/utils.py ================================================ import argparse def get_args(): parser = argparse.ArgumentParser(description="") parser.add_argument( "-t", "--task", choices=[ "extraction", ], default="extraction", help="task type", ) parser.add_argument( "-l", "--language", choices=[ # "chinese", "english", ], default="english", help="language type", ) parser.add_argument( "--bsrnn", action="store_true", help="whether to use the bsrnn model", ) parser.add_argument( "-p", "--pretrain", type=str, default="", help="model directory" ) parser.add_argument( "--device", type=str, default="cpu", help="device type (most commonly cpu or cuda," "but also potentially mps, xpu, xla or meta)" "and optional device ordinal for the device type.", ) parser.add_argument("--audio_file", help="mixture's audio file") parser.add_argument("--audio_file2", help="enroll's audio file") parser.add_argument( "--resample_rate", type=int, default=16000, help="resampling rate" ) parser.add_argument( "--vad", action="store_true", help="whether 
to do VAD or not" ) parser.add_argument( "--output_file", default='./extracted_speech.wav', help="extracted speech saved in .wav" ) parser.add_argument( "--output_norm", default=True, help="Control if normalize the output audio in .wav" ) args = parser.parse_args() return args ================================================ FILE: wesep/dataset/FRAM_RIR.py ================================================ # Author: Rongzhi Gu, Yi Luo # Copyright: Tencent AI Lab # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import torch from torchaudio.functional import highpass_biquad from torchaudio.transforms import Resample # set random seed seed = 20231 np.random.seed(seed) torch.manual_seed(seed) def calc_cos(orientation_rad): """ cos_theta: tensor, [azimuth, elevation] with shape [..., 2] return: [..., 3] """ return torch.stack( [ torch.cos( orientation_rad[..., 0] * torch.sin(orientation_rad[..., 1])), torch.sin( orientation_rad[..., 0] * torch.sin(orientation_rad[..., 1])), torch.cos(orientation_rad[..., 1]), ], -1, ) def freq_invariant_decay_func(cos_theta, pattern="cardioid"): """ cos_theta: tensor Return: amplitude: tensor with same shape as cos_theta """ if pattern == "cardioid": return 0.5 + 0.5 * cos_theta elif pattern == "omni": return torch.ones_like(cos_theta) elif pattern == "bidirectional": return cos_theta elif pattern == "hyper_cardioid": return 0.25 + 0.75 * cos_theta elif pattern == "sub_cardioid": return 0.75 + 0.25 * cos_theta elif pattern == "half_omni": c = torch.clamp(cos_theta, 0) c[c > 0] = 1.0 return c else: raise NotImplementedError def freq_invariant_src_decay_func(mic_pos, src_pos, src_orientation_rad, pattern="cardioid"): """ mic_pos: [n_mic, 3] (tensor) src_pos: [n_src, 3] (tensor) src_orientation_rad: [n_src, 2] (tensor), elevation, azimuth Return: amplitude: [n_mic, n_src, n_image] """ # Steering vector of source(s) orV_src = calc_cos(src_orientation_rad).unsqueeze(0) # [nsrc, 3] # receiver to src vector rcv_to_src_vec = mic_pos.unsqueeze(1) - src_pos.unsqueeze( 0) # [n_mic, n_src, 3] cos_theta = (rcv_to_src_vec * orV_src).sum(-1) # [n_mic, n_src] cos_theta /= torch.sqrt(rcv_to_src_vec.pow(2).sum(-1)) cos_theta /= torch.sqrt(orV_src.pow(2).sum(-1)) return freq_invariant_decay_func(cos_theta, pattern) def freq_invariant_mic_decay_func(mic_pos, img_pos, mic_orientation_rad, pattern="cardioid"): """ mic_pos: [n_mic, 3] (tensor) img_pos: [n_src, n_image, 3] (tensor) mic_orientation_rad: [n_mic, 2] (tensor), azimuth, elevation Return: amplitude: [n_mic, n_src, n_image] """ # Steering vector of source(s) orV_src = calc_cos(mic_orientation_rad) # [nmic, 3] orV_src = orV_src.view(-1, 1, 1, 3) # [n_mic, 1, 1, 3] # image to receiver vector # [1, n_src, n_image, 3] - [n_mic, 1, 1, 3] => [n_mic, n_src, n_image, 3] img_to_rcv_vec = img_pos.unsqueeze(0) - mic_pos.unsqueeze(1).unsqueeze(1) cos_theta = (img_to_rcv_vec * orV_src).sum(-1) # [n_mic, n_src, n_image] cos_theta /= torch.sqrt(img_to_rcv_vec.pow(2).sum(-1)) cos_theta /= torch.sqrt(orV_src.pow(2).sum(-1)) return 
freq_invariant_decay_func(cos_theta, pattern) def FRAM_RIR( mic_pos, sr, T60, room_dim, src_pos, num_src=1, direct_range=(-6, 50), n_image=(1024, 4097), a=-2.0, b=2.0, tau=0.25, src_pattern="omni", src_orientation_rad=None, mic_pattern="omni", mic_orientation_rad=None, ): """Fast Random Appoximation of Multi-channel Room Impulse Response (FRAM-RIR) # noqa Args: mic_pos: The microphone(s) position with respect to the room coordinates, # noqa with shape [num_mic, 3] (in meters). Room coordinate system must be defined in advance, # noqa with the constraint that the origin of the coordinate is on the floor(so positive z axis points up). # noqa sr: RIR sampling rate (Hz). T60: RT60 (second). room_dim: Room size with shape [3] (meters). src_pos: The source(s) position with respect to the room coordinate system, with shape [num_src, 3] (meters). # noqa num_src: Number of sources. Defaults to 1. direct_range: 2-element tuple, range of early reflection time (milliseconds, # noqa defined as the context around the direct path signal) of RIRs. # noqa Defaults to (-6, 50). n_image: 2-element tuple, minimum and maximum number of images to sample from. # noqa Defaults to (1024, 4097). a: controlling the random perturbation added to each virtual sound source. Defaults to -2.0. # noqa b: controlling the random perturbation added to each virtual sound source. Defaults to 2.0. # noqa tau: controlling the relationship between the distance and the number of reflections of each # noqa virtual sound source. Defaults to 0.25. src_pattern: Polar pattern for all of the sources. Defaults to "omni". src_orientation_rad: Array-like with shape [num_src, 2]. Orientation (rad) of all # noqa the sources, where the first column indicate azimuth and the # noqa second column indicate elevation. Defaults to None. # noqa mic_pattern: Polar pattern for all of the receivers. Defaults to "omni". mic_orientation_rad: Array-like with shape [num_mic, 2]. Orientation (rad) of all # noqa the microphones, where the first column indicate azimuth and # noqa the second column indicate elevation. Defaults to None. # noqa Returns: rir: RIR filters for all mic-source pairs, with shape [num_mic, num_src, rir_length]. # noqa early_rir: Early reflection (direct path) RIR filters for all mic-source pairs, # noqa with shape [num_mic, num_src, rir_length]. 
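Example:
    An illustrative sketch (the geometry values below are hypothetical, not
    taken from the recipes):

        import numpy as np
        mic_pos = np.array([[1.0, 1.5, 1.2], [1.1, 1.5, 1.2]])  # 2 mics (meters)
        src_pos = np.array([[2.5, 3.0, 1.5]])                   # 1 source
        rir, early_rir = FRAM_RIR(mic_pos, sr=16000, T60=0.5,
                                  room_dim=[6.0, 5.0, 3.0], src_pos=src_pos)
        # rir and early_rir are numpy arrays shaped [num_mic, num_src, rir_length]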
""" # sample image image = np.random.choice(range(n_image[0], n_image[1])) R = torch.tensor( 1.0 / (2 * (1.0 / room_dim[0] + 1.0 / room_dim[1] + 1.0 / room_dim[2]))) eps = np.finfo(np.float16).eps mic_position = torch.from_numpy(mic_pos) src_position = torch.from_numpy(src_pos) # [nsource, 3] n_mic = mic_position.shape[0] num_src = src_position.shape[0] # [nmic, nsource] direct_dist = ((mic_position.unsqueeze(1) - src_position.unsqueeze(0)).pow(2).sum(-1) + 1e-3).sqrt() # [nsource] nearest_dist, nearest_mic_idx = direct_dist.min(0) # [nsource, 3] nearest_mic_position = mic_position[nearest_mic_idx] ns = n_mic * num_src ratio = 64 sample_sr = sr * ratio velocity = 340.0 T60 = torch.tensor(T60) direct_idx = (torch.ceil(direct_dist * sample_sr / velocity).long().view( ns, )) rir_length = int(np.ceil(sample_sr * T60)) resample1 = Resample(sample_sr, sample_sr // int(np.sqrt(ratio))) resample2 = Resample(sample_sr // int(np.sqrt(ratio)), sr) reflect_coef = (1 - (1 - torch.exp(-0.16 * R / T60)).pow(2)).sqrt() dist_range = [ torch.linspace(1.0, velocity * T60 / nearest_dist[i] - 1, rir_length) for i in range(num_src) ] dist_prob = torch.linspace(0.0, 1.0, rir_length) dist_prob /= dist_prob.sum() dist_select_idx = dist_prob.multinomial(num_samples=int(image * num_src), replacement=True).view( num_src, image) dist_nearest_ratio = torch.stack( [dist_range[i][dist_select_idx[i]] for i in range(num_src)], 0) # apply different dist ratios to mirophones azm = torch.FloatTensor(num_src, image).uniform_(-np.pi, np.pi) ele = torch.FloatTensor(num_src, image).uniform_(-np.pi / 2, np.pi / 2) # [nsource, nimage, 3] unit_3d = torch.stack( [ torch.sin(ele) * torch.cos(azm), torch.sin(ele) * torch.sin(azm), torch.cos(ele), ], -1, ) # [nsource] x [nsource, T] x [nsource, nimage, 3] => [nsource, nimage, 3] image2nearest_dist = nearest_dist.view( -1, 1, 1) * dist_nearest_ratio.unsqueeze(-1) image_position = (nearest_mic_position.unsqueeze(1) + image2nearest_dist * unit_3d) # [nmic, nsource, nimage] dist = ((mic_position.view(-1, 1, 1, 3) - image_position.unsqueeze(0)).pow(2).sum(-1) + 1e-3).sqrt() # reflection perturbation reflect_max = (torch.log10(velocity * T60) - 3) / torch.log10(reflect_coef) reflect_ratio = (dist / (velocity * T60)) * (reflect_max.view(1, -1, 1) - 1) + 1 reflect_pertub = torch.FloatTensor(num_src, image).uniform_( a, b) * dist_nearest_ratio.pow(tau) reflect_ratio = torch.maximum(reflect_ratio + reflect_pertub.unsqueeze(0), torch.ones(1)) # [nmic, nsource, 1 + nimage] dist = torch.cat([direct_dist.unsqueeze(2), dist], 2) reflect_ratio = torch.cat([torch.zeros(n_mic, num_src, 1), reflect_ratio], 2) delta_idx = (torch.minimum( torch.ceil(dist * sample_sr / velocity), torch.ones(1) * rir_length - 1, ).long().view(ns, -1)) delta_decay = reflect_coef.pow(reflect_ratio) / dist ################################# # source orientation simulation # ################################# if src_pattern != "omni": # randomly sample each image's relative orientation with respect to the original source # noqa # equivalent to a random decay corresponds to the source's orientation pattern decay # noqa img_orientation_rad = torch.FloatTensor(num_src, image, 2).uniform_(-np.pi, np.pi) img_cos_theta = torch.cos(img_orientation_rad[..., 0]) * torch.cos( img_orientation_rad[..., 1]) # [nsource, nimage] img_orientation_decay = freq_invariant_decay_func( img_cos_theta, pattern=src_pattern) # [nsource, nimage] # direct path orientation should use the provided parameter if src_orientation_rad is None: # assume random 
orientation if not given src_orientation_azi = torch.FloatTensor(num_src).uniform_( -np.pi, np.pi) src_orientation_ele = torch.FloatTensor(num_src).uniform_( -np.pi, np.pi) src_orientation_rad = torch.stack( [src_orientation_azi, src_orientation_ele], -1) else: src_orientation_rad = torch.from_numpy( src_orientation_rad) # [nsource, 2] src_orientation_decay = freq_invariant_src_decay_func( mic_position, src_position, src_orientation_rad, pattern=src_pattern, ) # [nmic, nsource] # apply decay delta_decay[:, :, 0] *= src_orientation_decay delta_decay[:, :, 1:] *= img_orientation_decay.unsqueeze(0) if mic_pattern != "omni": # mic orientation simulation # # when not given, assume that all mics facing up (positive z axis) if mic_orientation_rad is None: mic_orientation_rad = torch.stack( [torch.zeros(n_mic), torch.zeros(n_mic)], -1) # [nmic, 2] else: mic_orientation_rad = torch.from_numpy(mic_orientation_rad) all_src_img_pos = torch.cat( (src_position.unsqueeze(1), image_position), 1) # [nsource, nimage+1, 3] mic_orientation_decay = freq_invariant_mic_decay_func( mic_position, all_src_img_pos, mic_orientation_rad, pattern=mic_pattern, ) # [nmic, nsource, nimage+1] # apply decay delta_decay *= mic_orientation_decay rir = torch.zeros(ns, rir_length) delta_decay = delta_decay.view(ns, -1) for i in range(ns): remainder_idx = delta_idx[i] valid_mask = np.ones(len(remainder_idx)) while np.sum(valid_mask) > 0: valid_remainder_idx, unique_remainder_idx = np.unique( remainder_idx, return_index=True) rir[i][valid_remainder_idx] += ( delta_decay[i][unique_remainder_idx] * valid_mask[unique_remainder_idx]) valid_mask[unique_remainder_idx] = 0 remainder_idx[unique_remainder_idx] = 0 direct_mask = torch.zeros(ns, rir_length).float() for i in range(ns): direct_mask[ i, max(direct_idx[i] + sample_sr * direct_range[0] // 1000, 0 ):min(direct_idx[i] + sample_sr * direct_range[1] // 1000, rir_length), ] = 1.0 rir_direct = rir * direct_mask all_rir = torch.stack([rir, rir_direct], 1).view(ns * 2, -1) rir_downsample = resample1(all_rir) rir_hp = highpass_biquad(rir_downsample, sample_sr // int(np.sqrt(ratio)), 80.0) rir = resample2(rir_hp).float().view(n_mic, num_src, 2, -1) return rir[:, :, 0].data.numpy(), rir[:, :, 1].data.numpy() def sample_mic_arch(n_mic, mic_spacing=None, bounding_box=None): if mic_spacing is None: mic_spacing = [0.02, 0.10] if bounding_box is None: bounding_box = [0.08, 0.12, 0] sample_n_mic = np.random.randint(n_mic[0], n_mic[1] + 1) if sample_n_mic == 1: mic_arch = np.array([[0, 0, 0]]) else: mic_arch = [] while len(mic_arch) < sample_n_mic: this_mic_pos = np.random.uniform(np.array([0, 0, 0]), np.array(bounding_box)) if len(mic_arch) != 0: ok = True for other_mic_pos in mic_arch: this_mic_spacing = np.linalg.norm(this_mic_pos - other_mic_pos) if (this_mic_spacing < mic_spacing[0] or this_mic_spacing > mic_spacing[1]): ok = False break if ok: mic_arch.append(this_mic_pos) else: mic_arch.append(this_mic_pos) mic_arch = np.stack(mic_arch, 0) # [nmic, 3] return mic_arch def sample_src_pos( room_dim, num_src, array_pos, min_mic_dis=0.5, max_mic_dis=5, min_dis_wall=None, ): if min_dis_wall is None: min_dis_wall = [0.5, 0.5, 0.5] # random sample the source positon src_pos = [] while len(src_pos) < num_src: pos = np.random.uniform(np.array(min_dis_wall), np.array(room_dim) - np.array(min_dis_wall)) dis = np.linalg.norm(pos - np.array(array_pos)) if dis >= min_mic_dis and dis <= max_mic_dis: src_pos.append(pos) return np.stack(src_pos, 0) def sample_mic_array_pos(mic_arch, room_dim, 
min_dis_wall=None): """ Generate the microphone array position according to the given microphone architecture (geometry) # noqa :param mic_arch: np.array with shape [n_mic, 3] the relative 3D coordinate to the array_pos in (m) e.g., 2-mic linear array [[-0.1, 0, 0], [0.1, 0, 0]]; e.g., 4-mic circular array [[0, 0.035, 0], [0.035, 0, 0], [0, -0.035, 0], [-0.035, 0, 0]] # noqa :param min_dis_wall: minimum distance from the wall in (m) :return mic_pos: microphone array position in (m) with shape [n_mic, 3] array_pos: array CENTER / REFERENCE position in (m) with shape [1, 3] """ def rotate(angle, valuex, valuey): rotate_x = valuex * np.cos(angle) + valuey * np.sin(angle) # [nmic] rotate_y = valuey * np.cos(angle) - valuex * np.sin(angle) return np.stack( [rotate_x, rotate_y, np.zeros_like(rotate_x)], -1) # [nmic, 3] if min_dis_wall is None: min_dis_wall = [0.5, 0.5, 0.5] if isinstance(mic_arch, dict): # ADHOC ARRAY n_mic, mic_spacing, bounding_box = ( mic_arch["n_mic"], mic_arch["spacing"], mic_arch["bounding_box"], ) sample_n_mic = np.random.randint(n_mic[0], n_mic[1] + 1) if sample_n_mic == 1: mic_arch = np.array([[0, 0, 0]]) else: mic_arch = [ np.random.uniform(np.array([0, 0, 0]), np.array(bounding_box)) ] while len(mic_arch) < sample_n_mic: this_mic_pos = np.random.uniform(np.array([0, 0, 0]), np.array(bounding_box)) ok = True for other_mic_pos in mic_arch: this_mic_spacing = np.linalg.norm(this_mic_pos - other_mic_pos) if (this_mic_spacing < mic_spacing[0] or this_mic_spacing > mic_spacing[1]): ok = False break if ok: mic_arch.append(this_mic_pos) mic_arch = np.stack(mic_arch, 0) # [nmic, 3] else: mic_arch = np.array(mic_arch) mic_array_center = np.mean(mic_arch, 0, keepdims=True) # [1, 3] max_radius = max(np.linalg.norm(mic_arch - mic_array_center, axis=-1)) array_pos = np.random.uniform( np.array(min_dis_wall) + max_radius, np.array(room_dim) - np.array(min_dis_wall) - max_radius, ).reshape(1, 3) mic_pos = array_pos + mic_arch # assume the array is always horizontal rotate_azm = np.random.uniform(-np.pi, np.pi) mic_pos = array_pos + rotate(rotate_azm, mic_arch[:, 0], mic_arch[:, 1]) # [n_mic, 3] return mic_pos, array_pos def sample_a_config(simu_config): room_config = simu_config["min_max_room"] rt60_config = simu_config["rt60"] mic_dist_config = simu_config["mic_dist"] num_src = simu_config["num_src"] room_dim = np.random.uniform(np.array(room_config[0]), np.array(room_config[1])) rt60 = np.random.uniform(rt60_config[0], rt60_config[1]) sr = simu_config["sr"] if ("array_pos" not in simu_config.keys()): # mic_arch must be given in this case mic_arch = simu_config["mic_arch"] mic_pos, array_pos = sample_mic_array_pos(mic_arch, room_dim) else: array_pos = simu_config["array_pos"] if "src_pos" not in simu_config.keys(): src_pos = sample_src_pos( room_dim, num_src, array_pos, min_mic_dis=mic_dist_config[0], max_mic_dis=mic_dist_config[1], ) else: src_pos = np.array(simu_config["src_pos"]) return mic_pos, sr, rt60, room_dim, src_pos, array_pos # === single-channel FRA-RIR === def single_channel(simu_config): mic_arch = {"n_mic": [1, 1], "spacing": None, "bounding_box": None} simu_config["mic_arch"] = mic_arch mic_pos, sr, rt60, room_dim, src_pos, array_pos = sample_a_config( simu_config) rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos, array_pos) # with shape [1, n_src, rir_len] return rir, rir_direct # === multi-channel (fixed) === def multi_channel_array(simu_config): mic_arch = [[-0.05, 0, 0], [0.05, 0, 0]] simu_config["mic_arch"] = mic_arch mic_pos, sr, rt60, room_dim, 
src_pos, array_pos = sample_a_config( simu_config) rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos) # with shape [n_mic, n_src, rir_len] return rir, rir_direct # === multi-channel (adhoc) === def multi_channel_adhoc(simu_config): mic_arch = { "n_mic": [1, 3], "spacing": [0.02, 0.05], "bounding_box": [0.5, 1.0, 0], # x, y, z } simu_config["mic_arch"] = mic_arch mic_pos, sr, rt60, room_dim, src_pos, array_pos = sample_a_config( simu_config) rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos) # with shape [sample_n_mic, n_src, rir_len] return rir, rir_direct def multi_channel_src_orientation(): """ ========================= → y axis | | | *1 *2 | | | | ↑ | | | | *3 *4 | | | ========================= ↓ x axis """ sr = 16000 rt60 = 0.6 room_dim = [8, 8, 3] src_pos = np.array([[4, 4, 1.5]]) # middle of the room mic_pos = np.array([[2, 2, 1.5], [2, 6, 1.5], [6, 2, 1.5], [6, 6, 1.5]] # mic 1, 2 ) # mic 3, 4 src_pattern = "sub_cardioid" src_orientation_rad = (np.array([180, 90]) / 180.0 * np.pi ) # facing *front* (negative x axis) rir, rir_direct = FRAM_RIR( mic_pos, sr, rt60, room_dim=room_dim, src_pos=src_pos, src_pattern=src_pattern, src_orientation_rad=src_orientation_rad, ) return rir, rir_direct def multi_channel_mic_orientation(): """ ========================= → y axis | | | ↑1 ↓2 | | | | o | | | | ↑3 ↓4 | | | ========================= ↓ x axis """ sr = 16000 rt60 = 0.6 room_dim = [8, 8, 3] src_pos = np.array([[4, 4, 1.5]]) # middle of the room mic_pos = np.array([[2, 2, 1.5], [2, 6, 1.5], [6, 2, 1.5], [6, 6, 1.5]] # mic 1, 2 ) # mic 3, 4 mic_pattern = "sub_cardioid" mic_orientation_rad = ( np.array([ [180, 90], [0, 90], # mic 1 (negative x axis), 2 (positive x axis) [180, 90], [0, 90], ]) / 180.0 * np.pi) # mic 3 (negative x axis), 4 (positive x axis) rir, rir_direct = FRAM_RIR( mic_pos, sr, rt60, room_dim=room_dim, src_pos=src_pos, mic_pattern=mic_pattern, mic_orientation_rad=mic_orientation_rad, ) return rir, rir_direct ================================================ FILE: wesep/dataset/dataset.py ================================================ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) # 2023 Shuai Wang (wsstriving@gmail.com) # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import random import torch import torch.distributed as dist import torch.nn.functional as tf from torch.utils.data import IterableDataset import wesep.dataset.processor as processor from wesep.utils.file_utils import read_lists class Processor(IterableDataset): def __init__(self, source, f, *args, **kw): assert callable(f) self.source = source self.f = f self.args = args self.kw = kw def set_epoch(self, epoch): self.source.set_epoch(epoch) def __iter__(self): """Return an iterator over the source dataset processed by the given processor. 
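The processing function f is applied lazily: it receives an iterator over the source dataset together with any extra positional/keyword arguments, and yields the processed samples one by one.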
""" assert self.source is not None assert callable(self.f) return self.f(iter(self.source), *self.args, **self.kw) def apply(self, f): assert callable(f) return Processor(self, f, *self.args, **self.kw) class DistributedSampler: def __init__(self, shuffle=True, partition=True): self.epoch = -1 self.update() self.shuffle = shuffle self.partition = partition def update(self): assert dist.is_available() if dist.is_initialized(): self.rank = dist.get_rank() self.world_size = dist.get_world_size() else: self.rank = 0 self.world_size = 1 worker_info = torch.utils.data.get_worker_info() if worker_info is None: self.worker_id = 0 self.num_workers = 1 else: self.worker_id = worker_info.id self.num_workers = worker_info.num_workers return dict( rank=self.rank, world_size=self.world_size, worker_id=self.worker_id, num_workers=self.num_workers, ) def set_epoch(self, epoch): self.epoch = epoch def sample(self, data): """Sample data according to rank/world_size/num_workers Args: data(List): input data list Returns: List: data list after sample """ data = list(range(len(data))) if len(data) <= self.num_workers: if self.shuffle: random.Random(self.epoch).shuffle(data) else: if self.partition: if self.shuffle: random.Random(self.epoch).shuffle(data) data = data[self.rank::self.world_size] data = data[self.worker_id::self.num_workers] return data class DataList(IterableDataset): def __init__(self, lists, shuffle=True, partition=True, repeat_dataset=False): self.lists = lists self.repeat_dataset = repeat_dataset self.sampler = DistributedSampler(shuffle, partition) def set_epoch(self, epoch): self.sampler.set_epoch(epoch) def __iter__(self): sampler_info = self.sampler.update() indexes = self.sampler.sample(self.lists) if not self.repeat_dataset: for index in indexes: data = dict(src=self.lists[index]) data.update(sampler_info) yield data else: indexes_len = len(indexes) counter = 0 while True: index = indexes[counter % indexes_len] counter += 1 data = dict(src=self.lists[index]) data.update(sampler_info) yield data def tse_collate_fn_2spk(batch, mode="min"): # Warning: hard-coded for 2 speakers, will be deprecated in the future, # use tse_collate_fn instead new_batch = {} wav_mix = [] wav_targets = [] spk_embeds = [] spk = [] key = [] spk_label = [] length_spk_embeds = [] for s in batch: wav_mix.append(s["wav_mix"]) wav_targets.append(s["wav_spk1"]) spk.append(s["spk1"]) key.append(s["key"]) spk_embeds.append(torch.from_numpy(s["embed_spk1"].copy())) length_spk_embeds.append(spk_embeds[-1].shape[1]) if "spk1_label" in s.keys(): spk_label.append(s["spk1_label"]) wav_mix.append(s["wav_mix"]) wav_targets.append(s["wav_spk2"]) spk.append(s["spk2"]) key.append(s["key"]) spk_embeds.append(torch.from_numpy(s["embed_spk2"].copy())) length_spk_embeds.append(spk_embeds[-1].shape[1]) if "spk2_label" in s.keys(): spk_label.append(s["spk2_label"]) if not (len(set(length_spk_embeds)) == 1): if mode == "max": max_len = max(length_spk_embeds) for i in range(len(length_spk_embeds)): if len(spk_embeds[i].shape) == 2: spk_embeds[i] = tf.pad( spk_embeds[i], (0, max_len - length_spk_embeds[i]), "constant", 0, ) elif len(spk_embeds[i].shape) == 3: spk_embeds[i] = tf.pad( spk_embeds[i], (0, 0, 0, max_len - length_spk_embeds[i]), "constant", 0, ) if mode == "min": min_len = min(length_spk_embeds) for i in range(len(length_spk_embeds)): if len(spk_embeds[i].shape) == 2: spk_embeds[i] = spk_embeds[i][:, :min_len] elif len(spk_embeds[i].shape) == 3: spk_embeds[i] = spk_embeds[i][:, :min_len, :] new_batch["wav_mix"] = 
torch.concat(wav_mix) new_batch["wav_targets"] = torch.concat(wav_targets) new_batch["spk_embeds"] = torch.concat(spk_embeds) new_batch["length_spk_embeds"] = length_spk_embeds new_batch["spk"] = spk new_batch["key"] = key new_batch["spk_label"] = torch.as_tensor(spk_label) return new_batch def tse_collate_fn(batch, mode="min"): # This is a more generalizable implementation for target speaker extraction # Support arbitrary number of speakers new_batch = {} wav_mix = [] wav_targets = [] spk_embeds = [] spk = [] key = [] spk_label = [] length_spk_embeds = [] for s in batch: for i in range(s["num_speaker"]): wav_mix.append(s["wav_mix"]) wav_targets.append(s["wav_spk{}".format(i + 1)]) spk.append(s["spk{}".format(i + 1)]) key.append(s["key"]) spk_embeds.append( torch.from_numpy(s["embed_spk{}".format(i + 1)].copy())) length_spk_embeds.append(spk_embeds[-1].shape[1]) if "spk{}_label".format(i + 1) in s.keys(): spk_label.append(s["spk{}_label".format(i + 1)]) if not (len(set(length_spk_embeds)) == 1): if mode == "max": max_len = max(length_spk_embeds) for i in range(len(length_spk_embeds)): if len(spk_embeds[i].shape) == 2: spk_embeds[i] = tf.pad( spk_embeds[i], (0, max_len - length_spk_embeds[i]), "constant", 0, ) elif len(spk_embeds[i].shape) == 3: spk_embeds[i] = tf.pad( spk_embeds[i], (0, 0, 0, max_len - length_spk_embeds[i]), "constant", 0, ) if mode == "min": min_len = min(length_spk_embeds) for i in range(len(length_spk_embeds)): if len(spk_embeds[i].shape) == 2: spk_embeds[i] = spk_embeds[i][:, :min_len] elif len(spk_embeds[i].shape) == 3: spk_embeds[i] = spk_embeds[i][:, :min_len, :] new_batch["wav_mix"] = torch.concat(wav_mix) new_batch["wav_targets"] = torch.concat(wav_targets) new_batch["spk_embeds"] = torch.concat(spk_embeds) new_batch["length_spk_embeds"] = ( length_spk_embeds # Not used, but maybe needed when using the enrollment utterance # noqa ) new_batch["spk"] = spk new_batch["key"] = key new_batch["spk_label"] = torch.as_tensor(spk_label) return new_batch def Dataset( data_type, data_list_file, configs, spk2embed_dict=None, spk1_embed=None, spk2_embed=None, state="train", joint_training=False, dict_spk=None, whole_utt=False, repeat_dataset=False, noise_prob=0, reverb_prob=0, noise_enroll_prob=0, reverb_enroll_prob=0, specaug_enroll_prob=0, noise_lmdb_file=None, online_mix=False, ): """Construct dataset from arguments We have two shuffle stage in the Dataset. The first is global shuffle at shards tar/raw/feat file level. The second is local shuffle at training samples level. 
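    A minimal usage sketch (illustrative only; the config keys shown, e.g.
    "shuffle", "resample_rate" and "chunk_len", are the ones consumed below,
    while the DataLoader wiring is an assumption about how callers batch the
    yielded samples):

        dataset = Dataset("shard", "data/train/shard.list", configs,
                          spk2embed_dict=spk2embed_dict, state="train",
                          noise_prob=0.5, noise_lmdb_file="data/noise.lmdb")
        loader = torch.utils.data.DataLoader(dataset, batch_size=8,
                                             collate_fn=tse_collate_fn)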
Args: :param spk2_embed: :param online_mix: :param spk1_embed: :param data_type(str): shard/raw/feat :param data_list_file: data list file :param configs: dataset configs :param noise_prob:probility to add noise on mixture :param reverb_prob:probility to add reverb on mixture :param noise_enroll_prob:probility to add noise on enrollment speech :param reverb_enroll_prob:probility to add reverb on enrollment speech :param specaug_enroll_prob: probility to apply SpecAug on fbank of enrollment speech # noqa :param noise_lmdb_file: noise data source lmdb file :param whole_utt: use whole utt or random chunk :param repeat_dataset: """ assert data_type in ["shard", "raw"] lists = read_lists(data_list_file) shuffle = configs.get("shuffle", False) # Global shuffle dataset = DataList(lists, shuffle=shuffle, repeat_dataset=repeat_dataset) if data_type == "shard": dataset = Processor(dataset, processor.url_opener) if not online_mix: dataset = Processor(dataset, processor.tar_file_and_group) else: dataset = Processor(dataset, processor.tar_file_and_group_single_spk) else: dataset = Processor(dataset, processor.parse_raw) if configs.get("filter_len", False) and state == "train": # Filter the data with unwanted length filter_conf = configs.get("filter_args", {}) dataset = Processor(dataset, processor.filter_len, **filter_conf) # Local shuffle if shuffle and not online_mix: dataset = Processor(dataset, processor.shuffle, **configs["shuffle_args"]) # resample resample_rate = configs.get("resample_rate", 16000) dataset = Processor(dataset, processor.resample, resample_rate) if not whole_utt: # random chunk chunk_len = configs.get("chunk_len", resample_rate * 3) dataset = Processor(dataset, processor.random_chunk, chunk_len) if online_mix: dataset = Processor( dataset, processor.mix_speakers, configs.get("num_speakers", 2), configs.get("online_buffer_size", 1000), ) if reverb_prob > 0: dataset = Processor(dataset, processor.add_reverb, reverb_prob) dataset = Processor( dataset, processor.snr_mixer, configs.get("use_random_snr", False), ) if noise_prob > 0: assert noise_lmdb_file is not None dataset = Processor(dataset, processor.add_noise, noise_lmdb_file, noise_prob) speaker_feat = configs.get("speaker_feat", False) if state == "train": if not joint_training: dataset = Processor(dataset, processor.sample_spk_embedding, spk2embed_dict) else: dataset = Processor(dataset, processor.sample_enrollment, spk2embed_dict, dict_spk) if reverb_enroll_prob > 0: dataset = Processor(dataset, processor.add_reverb_on_enroll, reverb_enroll_prob) if noise_enroll_prob > 0: assert noise_lmdb_file is not None dataset = Processor( dataset, processor.add_noise_on_enroll, noise_lmdb_file, noise_enroll_prob, ) if speaker_feat: dataset = Processor(dataset, processor.compute_fbank, **configs["fbank_args"]) dataset = Processor(dataset, processor.apply_cmvn) if specaug_enroll_prob > 0: dataset = Processor(dataset, processor.spec_aug, prob=specaug_enroll_prob) else: if not joint_training: dataset = Processor( dataset, processor.sample_fix_spk_embedding, spk2embed_dict, spk1_embed, spk2_embed, ) else: dataset = Processor( dataset, processor.sample_fix_spk_enrollment, spk2embed_dict, spk1_embed, spk2_embed, dict_spk, ) if speaker_feat: dataset = Processor(dataset, processor.compute_fbank, **configs["fbank_args"]) dataset = Processor(dataset, processor.apply_cmvn) return dataset ================================================ FILE: wesep/dataset/lmdb_data.py ================================================ # Copyright (c) 2022 Binbin Zhang 
(binbzha@qq.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pickle import random import lmdb class LmdbData: def __init__(self, lmdb_file): self.db = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False) with self.db.begin(write=False) as txn: obj = txn.get(b"__keys__") assert obj is not None self.keys = pickle.loads(obj) assert isinstance(self.keys, list) def random_one(self): assert len(self.keys) > 0 index = random.randint(0, len(self.keys) - 1) key = self.keys[index] with self.db.begin(write=False) as txn: value = txn.get(key.encode()) assert value is not None return key, value def __del__(self): self.db.close() if __name__ == "__main__": import sys db = LmdbData(sys.argv[1]) key, _ = db.random_one() print(key) key, _ = db.random_one() print(key) ================================================ FILE: wesep/dataset/processor.py ================================================ import io import json import logging import random import tarfile from subprocess import PIPE, Popen from urllib.parse import urlparse import librosa import numpy as np import soundfile as sf import torch import torchaudio import torchaudio.compliance.kaldi as kaldi from scipy import signal from wesep.dataset.FRAM_RIR import single_channel as RIR_sim from wesep.dataset.lmdb_data import LmdbData AUDIO_FORMAT_SETS = {"flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"} # set the simulation configuration simu_config = { "min_max_room": [[3, 3, 2.5], [10, 6, 4]], "rt60": [0.1, 0.7], "sr": 16000, "mic_dist": [0.2, 5.0], "num_src": 1, } def url_opener(data): """Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert "src" in sample # TODO(Binbin Zhang): support HTTP url = sample["src"] try: pr = urlparse(url) # local file if pr.scheme == "" or pr.scheme == "file": stream = open(url, "rb") # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP else: cmd = f"wget -q -O - {url}" process = Popen(cmd, shell=True, stdout=PIPE) sample.update(process=process) stream = process.stdout sample.update(stream=stream) yield sample except Exception as ex: logging.warning("Failed to open {}".format(url)) def tar_file_and_group(data): """Expand a stream of open tar files into a stream of tar file contents. And groups the file with same prefix Args: data: Iterable[{src, stream}] Returns: Iterable[{key, mix_wav, spk1_wav, spk2_wav, ..., sample_rate}] """ for sample in data: assert "stream" in sample stream = tarfile.open(fileobj=sample["stream"], mode="r:*") # TODO: The mode need to be validated # In order to be compatible with the torch 2.x version, # the file reading method here does not use streaming. 
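        # Expected shard layout, inferred from the parsing rules below
        # (illustrative, not enforced):
        #   <key>.wav       -> mixture waveform, stored as example["wav_mix"]
        #   <key>_spk1.wav  -> source of speaker 1, stored as example["wav_spk1"]
        #   <key>.spk1      -> text file with speaker 1's id, stored as example["spk1"]
        # Consecutive entries sharing the same <key> prefix are grouped into one
        # example, which is yielded when the prefix changes.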
prev_prefix = None example = {} num_speakers = 0 valid = True for tarinfo in stream: name = tarinfo.name pos = name.rfind(".") assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if prev_prefix is not None and prev_prefix not in prefix: example["key"] = prev_prefix if valid: example["num_speaker"] = num_speakers num_speakers = 0 yield example example = {} valid = True with stream.extractfile(tarinfo) as file_obj: try: if "spk" in postfix: example[postfix] = ( file_obj.read().decode("utf8").strip()) num_speakers += 1 elif postfix in AUDIO_FORMAT_SETS: waveform, sample_rate = torchaudio.load(file_obj) if prefix[-5:-1] == "_spk": example["wav" + prefix[-5:]] = waveform prefix = prefix[:-5] else: example["wav_mix"] = waveform example["sample_rate"] = sample_rate else: example[postfix] = file_obj.read() except Exception as ex: valid = False logging.warning("error to parse {}".format(name)) prev_prefix = prefix if prev_prefix is not None: example["key"] = prev_prefix example["num_speaker"] = num_speakers num_speakers = 0 yield example stream.close() if "process" in sample: sample["process"].communicate() sample["stream"].close() def tar_file_and_group_single_spk(data): """Expand a stream of open tar files into a stream of tar file contents. And groups the file with same prefix Args: data: Iterable[{src, stream}] Returns: Iterable[{key, wav, spk, sample_rate}] """ for sample in data: assert "stream" in sample stream = tarfile.open(fileobj=sample["stream"], mode="r|*") # Only support pytorch version <2.0 prev_prefix = None example = {} valid = True for tarinfo in stream: name = tarinfo.name pos = name.rfind(".") assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if prev_prefix is not None and prefix != prev_prefix: example["key"] = prev_prefix if valid: yield example example = {} valid = True with stream.extractfile(tarinfo) as file_obj: try: if postfix in ["spk"]: example[postfix] = ( file_obj.read().decode("utf8").strip()) elif postfix in AUDIO_FORMAT_SETS: waveform, sample_rate = torchaudio.load(file_obj) example["wav"] = waveform example["sample_rate"] = sample_rate else: example[postfix] = file_obj.read() except Exception as ex: valid = False logging.warning("error to parse {}".format(name)) prev_prefix = prefix if prev_prefix is not None: example["key"] = prev_prefix yield example stream.close() if "process" in sample: sample["process"].communicate() sample["stream"].close() def parse_raw_single_spk(data): """Parse key/wav/spk from json line Args: data: Iterable[str], str is a json line has key/wav/spk Returns: Iterable[{key, wav, spk, sample_rate}] """ for sample in data: assert "src" in sample json_line = sample["src"] obj = json.loads(json_line) assert "key" in obj assert "wav" in obj assert "spk" in obj key = obj["key"] wav_file = obj["wav"] spk = obj["spk"] try: waveform, sample_rate = torchaudio.load(wav_file) example = dict(key=key, spk=spk, wav=waveform, sample_rate=sample_rate) yield example except Exception as ex: logging.warning("Failed to read {}".format(wav_file)) def mix_speakers(data, num_speaker=2, shuffle_size=1000): """Dynamic mixing speakers when loading data, shuffle is not needed if this function is used Args: :param data: Iterable[{key, wavs, spks}] :param num_speaker: :param use_random_snr: :param shuffle_size: Returns: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] """ buf = [] for sample in data: buf.append(sample) if len(buf) >= shuffle_size: random.shuffle(buf) for x in buf: cur_spk = x["spk"] example = { "key": x["key"], 
"wav_spk1": x["wav"], "spk1": x["spk"], "sample_rate": x["sample_rate"], } key = "mix_" + x["key"] interference_idx = 1 while interference_idx < num_speaker: interference = random.choice(buf) while interference["spk"] == cur_spk: interference = random.choice(buf) key = key + "_" + interference["key"] interference_idx += 1 example["wav_spk" + str(interference_idx)] = interference["wav"] example["spk" + str(interference_idx)] = interference["spk"] example["key"] = key example["num_speaker"] = num_speaker yield example buf = [] # The samples left over random.shuffle(buf) for x in buf: cur_spk = x["spk"] example = { "key": x["key"], "wav_spk1": x["wav"], "spk1": x["spk"], "sample_rate": x["sample_rate"], } key = "mix_" + x["key"] interference_idx = 1 while interference_idx < num_speaker: interference = random.choice(buf) while interference["spk"] == cur_spk: interference = random.choice(buf) key = key + "_" + interference["key"] interference_idx += 1 example["wav_spk" + str(interference_idx)] = interference["wav"] example["spk" + str(interference_idx)] = interference["spk"] example["key"] = key example["num_speaker"] = num_speaker yield example def snr_mixer(data, use_random_snr: bool = False): """Dynamic mixing speakers when loading data, shuffle is not needed if this function is used. # noqa Args: data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] use_random_snr (bool, optional): Whether use random SNR to mix speeches. Defaults to False. # noqa Returns: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] """ for sample in data: assert "num_speaker" in sample.keys() if "wav_spk1_reverb" in sample.keys(): suffix = "_reverb" else: suffix = "" num_speaker = sample["num_speaker"] wavs_to_mix = [sample["wav_spk1" + suffix]] target_energy = torch.sum(wavs_to_mix[0]**2, dim=-1, keepdim=True) for i in range(1, num_speaker): interference = sample[f"wav_spk{i + 1}" + suffix] if use_random_snr: snr = random.uniform(-10, 10) else: snr = 0 energy = torch.sum(interference**2, dim=-1, keepdim=True) interference *= torch.sqrt(target_energy / energy) * 10**(snr / 20) wavs_to_mix.append(interference) wavs_to_mix = torch.stack(wavs_to_mix) sample["wav_mix"] = torch.sum(wavs_to_mix, 0) max_amp = max( torch.abs(sample["wav_mix"]).max().item(), *[x.item() for x in torch.abs(wavs_to_mix).max(dim=-1)[0]], ) if max_amp != 0: mix_scaling = 1 / max_amp else: mix_scaling = 1 sample["wav_mix"] = sample["wav_mix"] * mix_scaling for i in range(0, num_speaker): sample[f"wav_spk{i + 1}" + suffix] *= mix_scaling yield sample def shuffle(data, shuffle_size=2500): """Local shuffle the data Args: data: Iterable[{key, wavs, spks}] shuffle_size: buffer size for shuffle Returns: Iterable[{key, wavs, spks}] """ buf = [] for sample in data: buf.append(sample) if len(buf) >= shuffle_size: random.shuffle(buf) for x in buf: yield x buf = [] # The sample left over random.shuffle(buf) for x in buf: yield x def spk_to_id(data, spk2id): """Parse spk id Args: data: Iterable[{key, wav/feat, spk}] spk2id: Dict[str, int] Returns: Iterable[{key, wav/feat, label}] """ for sample in data: assert "spk" in sample if sample["spk"] in spk2id: label = spk2id[sample["spk"]] else: label = -1 sample["label"] = label yield sample def resample(data, resample_rate=16000): """Resample data. Inplace operation. 
Args: data: Iterable[{key, wavs, spks, sample_rate}] resample_rate: target resample rate Returns: Iterable[{key, wavs, spks, sample_rate}] """ for sample in data: assert "sample_rate" in sample sample_rate = sample["sample_rate"] if sample_rate != resample_rate: all_keys = list(sample.keys()) sample["sample_rate"] = resample_rate for key in all_keys: if "wav" in key: waveform = sample[key] sample[key] = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=resample_rate)(waveform) yield sample def sample_spk_embedding(data, spk_embeds): """sample reference speaker embeddings for the target speaker Args: data: Iterable[{key, wav, label, sample_rate}] spk_embeds: dict which stores all potential embeddings for the speaker Returns: Iterable[{key, wav, label, sample_rate}] """ for sample in data: all_keys = list(sample.keys()) for key in all_keys: if key.startswith("spk"): sample["embed_" + key] = random.choice(spk_embeds[sample[key]]) yield sample def sample_fix_spk_embedding(data, spk2embed_dict, spk1_embed, spk2_embed): """sample reference speaker embeddings for the target speaker Args: data: Iterable[{key, wav, label, sample_rate}] spk_embeds: dict which stores all potential embeddings for the speaker Returns: Iterable[{key, wav, label, sample_rate}] """ for sample in data: all_keys = list(sample.keys()) for key in all_keys: if key.startswith("spk"): if key == "spk1": sample["embed_" + key] = spk2embed_dict[spk1_embed[sample["key"]]] else: sample["embed_" + key] = spk2embed_dict[spk2_embed[sample["key"]]] yield sample def sample_enrollment(data, spk_embeds, dict_spk): """sample reference speech for the target speaker Args: data: Iterable[{key, wav, label, sample_rate}] spk_embeds: dict which stores all potential enrollment utterance files(/.wav) for the speaker # noqa dict_spk: dict of speakers in the enrollment sets [Order: spkID] Returns: Iterable[{key, wav, label, sample_rate, spk_embed(raw waveform of enrollment), # noqa spk_lable(when multi-task training)}] """ for sample in data: all_keys = list(sample.keys()) for key in all_keys: if key.startswith("spk"): enrollment, _ = sf.read( random.choice(spk_embeds[sample[key]])[1]) sample["embed_" + key] = np.expand_dims(enrollment, axis=0) if dict_spk: sample[key + "_label"] = dict_spk[sample[key]] yield sample def sample_fix_spk_enrollment(data, spk2embed_dict, spk1_embed, spk2_embed, dict_spk=None): """sample reference speaker embeddings for the target speaker Args: data: Iterable[{key, wav, label, sample_rate}] spk_embeds: dict which stores all potential enrollment utterance files(/.wav) for the speaker # noqa dict_spk: dict of speakers in the enrollment sets [Order: spkID] Returns: Iterable[{key, wav, label, sample_rate, spk_embed(raw waveform of enrollment), # noqa spk_lable(when multi-task training)}] """ for sample in data: all_keys = list(sample.keys()) for key in all_keys: if key.startswith("spk"): if key == "spk1": enrollment, _ = sf.read( spk2embed_dict[spk1_embed[sample["key"]]]) else: enrollment, _ = sf.read( spk2embed_dict[spk2_embed[sample["key"]]]) sample["embed_" + key] = np.expand_dims(enrollment, axis=0) if dict_spk: sample[key + "_label"] = dict_spk[sample[key]] yield sample def compute_fbank(data, num_mel_bins=80, frame_length=25, frame_shift=10, dither=1.0): """Extract fbank Args: data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa Returns: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 
'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa """ for sample in data: assert "sample_rate" in sample sample_rate = sample["sample_rate"] all_keys = list(sample.keys()) for key in all_keys: if key.startswith("embed"): waveform = torch.from_numpy(sample[key]) waveform = waveform * (1 << 15) mat = kaldi.fbank( waveform, num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither, sample_frequency=sample_rate, window_type="hamming", use_energy=False, ) sample[key] = mat yield sample def apply_cmvn(data, norm_mean=True, norm_var=False): """Apply CMVN Args: data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa Returns: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa """ for sample in data: all_keys = list(sample.keys()) for key in all_keys: if key.startswith("embed"): mat = sample[key] if norm_mean: mat = mat - torch.mean(mat, dim=0) if norm_var: mat = mat / torch.sqrt(torch.var(mat, dim=0) + 1e-8) mat = mat.unsqueeze(0) sample[key] = mat.detach().numpy() yield sample def get_random_chunk(data_list, chunk_len): """Get random chunk Args: data_list: [torch.Tensor: 1XT] (random len) chunk_len: chunk length Returns: [torch.Tensor] (exactly chunk_len) """ # Assert all entries in the list share the same length assert False not in [len(i) == len(data_list[0]) for i in data_list] data_list = [data[0] for data in data_list] data_len = len(data_list[0]) # random chunk if data_len >= chunk_len: chunk_start = random.randint(0, data_len - chunk_len) for i in range(len(data_list)): temp_data = data_list[i][chunk_start:chunk_start + chunk_len] while torch.equal(temp_data, torch.zeros_like(temp_data)): chunk_start = random.randint(0, data_len - chunk_len) temp_data = data_list[i][chunk_start:chunk_start + chunk_len] data_list[i] = temp_data # re-clone the data to avoid memory leakage if type(data_list[i]) == torch.Tensor: data_list[i] = data_list[i].clone() else: # np.array data_list[i] = data_list[i].copy() else: # padding repeat_factor = chunk_len // data_len + 1 for i in range(len(data_list)): if type(data_list[i]) == torch.Tensor: data_list[i] = data_list[i].repeat(repeat_factor) else: # np.array data_list[i] = np.tile(data_list[i], repeat_factor) data_list[i] = data_list[i][:chunk_len] data_list = [data.unsqueeze(0) for data in data_list] return data_list def filter_len( data, min_num_seconds=1, max_num_seconds=1000, ): """Filter the utterance with very short duration and random chunk the utterance with very long duration. 
Args: data: Iterable[{key, wav, label, sample_rate}] min_num_seconds: minimum number of seconds of wav file max_num_seconds: maximum number of seconds of wav file Returns: Iterable[{key, wav, label, sample_rate}] """ for sample in data: assert "key" in sample assert "sample_rate" in sample assert "wav" in sample sample_rate = sample["sample_rate"] wav = sample["wav"] min_len = min_num_seconds * sample_rate max_len = max_num_seconds * sample_rate if wav.size(1) < min_len: continue elif wav.size(1) > max_len: wav = get_random_chunk([wav], max_len)[0] sample["wav"] = wav yield sample def random_chunk(data, chunk_len): """Random chunk the data into chunk_len Args: data: Iterable[{key, wav/feat, label}] chunk_len: chunk length for each sample Returns: Iterable[{key, wav/feat, label}] """ for sample in data: assert "key" in sample wav_keys = [key for key in list(sample.keys()) if "wav" in key] wav_data_list = [sample[key] for key in wav_keys] wav_data_list = get_random_chunk(wav_data_list, chunk_len) sample.update(zip(wav_keys, wav_data_list)) yield sample def fix_chunk(data, chunk_len): """Random chunk the data into chunk_len Args: data: Iterable[{key, wav/feat, label}] chunk_len: chunk length for each sample Returns: Iterable[{key, wav/feat, label}] """ for sample in data: assert "key" in sample all_keys = list(sample.keys()) for key in all_keys: if key.startswith("wav"): sample[key] = sample[key][:, :chunk_len] yield sample def add_noise( data, noise_lmdb_file, noise_prob: float = 0.0, noise_db_low: int = -5, noise_db_high: int = 25, single_channel: bool = True, ): """Add noise to mixture Args: data: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] noise_lmdb_file: noise LMDB data source. noise_db_low (int, optional): SNR lower bound. Defaults to -5. noise_db_high (int, optional): SNR upper bound. Defaults to 25. single_channel (bool, optional): Whether to force the noise file to be single channel. # noqa Defaults to True. 
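    Noise scaling, as implemented below: with speech power P_s and noise power
    P_n, the sampled noise is multiplied by sqrt(P_s / P_n) * 10 ** (-snr_db / 20),
    so the resulting mixture has approximately the sampled SNR in dB. For noise
    keys starting with "speech" (interference speech used as additive noise) the
    SNR is drawn from [10, 30] dB instead of [noise_db_low, noise_db_high].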
Returns: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ..., noise, snr}] # noqa """ noise_source = LmdbData(noise_lmdb_file) for sample in data: if noise_prob > random.random(): assert "sample_rate" in sample.keys() tgt_fs = sample["sample_rate"] speech = sample["wav_mix"].numpy() # [1, nsamples] nsamples = speech.shape[1] power = (speech**2).mean() noise_key, noise_data = noise_source.random_one() if noise_key.startswith( "speech"): # using interference speech as additive noise snr_range = [10, 30] else: snr_range = [noise_db_low, noise_db_high] noise_db = np.random.uniform(snr_range[0], snr_range[1]) with sf.SoundFile(io.BytesIO(noise_data)) as f: fs = f.samplerate if tgt_fs and fs != tgt_fs: nsamples_ = int(nsamples / tgt_fs * fs) + 1 else: nsamples_ = nsamples if f.frames == nsamples_: noise = f.read(dtype=np.float64, always_2d=True) elif f.frames < nsamples_: offset = np.random.randint(0, nsamples_ - f.frames) # noise: (Time, Nmic) noise = f.read(dtype=np.float64, always_2d=True) # Repeat noise noise = np.pad( noise, [(offset, nsamples_ - f.frames - offset), (0, 0)], mode="wrap", ) else: offset = np.random.randint(0, f.frames - nsamples_) f.seek(offset) # noise: (Time, Nmic) noise = f.read(nsamples_, dtype=np.float64, always_2d=True) if len(noise) != nsamples_: raise RuntimeError( f"Something wrong: {noise_lmdb_file}") if single_channel: num_ch = noise.shape[1] chs = [np.random.randint(num_ch)] noise = noise[:, chs] # noise: (Nmic, Time) noise = noise.T if tgt_fs and fs != tgt_fs: logging.warning( f"Resampling noise to match the sampling rate ({fs} -> {tgt_fs} Hz)" # noqa ) noise = librosa.resample(noise, orig_sr=fs, target_sr=tgt_fs, res_type="kaiser_fast") if noise.shape[1] < nsamples: noise = np.pad( noise, [(0, 0), (0, nsamples - noise.shape[1])], mode="wrap", ) else: noise = noise[:, :nsamples] noise_power = (noise**2).mean() scale = (10**(-noise_db / 20) * np.sqrt(power) / np.sqrt(max(noise_power, 1e-10))) scaled_noise = scale * noise speech = speech + scaled_noise sample["wav_mix"] = torch.from_numpy(speech) sample["noise"] = torch.from_numpy(scaled_noise) sample["snr"] = noise_db yield sample def add_reverb(data, reverb_prob=0): """ Args: data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] Returns: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] Note: This function is implemented with reference to Fast Random Appoximation of Multi-channel Room Impulse Response (FRAM-RIR) https://arxiv.org/pdf/2304.08052 This function is only used when online_mixing. """ for sample in data: assert "num_speaker" in sample.keys() assert "sample_rate" in sample.keys() simu_config["num_src"] = sample["num_speaker"] simu_config["sr"] = sample["sample_rate"] rirs, _ = RIR_sim(simu_config) # [n_mic, nsource, nsamples] rirs = rirs[0] # [nsource, nsamples] for i in range(sample["num_speaker"]): if reverb_prob > random.random(): # [1, audio_len], currently only support single-channel audio audio = sample[f"wav_spk{i + 1}"].numpy() rir = rirs[i:i + 1, :] # [1, nsamples] rir_audio = signal.convolve( audio, rir, mode="full")[:, :audio.shape[1]] # [1, audio_len] max_scale = np.max(np.abs(rir_audio)) out_audio = rir_audio / max_scale * 0.9 # Note: Here, we do not replace the dry audio with the reverberant audio, # noqa # which means we hope the model to perform dereverberation and # TSE simultaneously. 
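                # rir_audio is the dry source convolved with its simulated RIR and
                # truncated to the original length; dividing by its peak and scaling
                # by 0.9 keeps the reverberant signal just below full scale before
                # it is written back into the sample.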
sample[f"wav_spk{i + 1}"] = torch.from_numpy(out_audio) yield sample def add_noise_on_enroll( data, noise_lmdb_file, noise_enroll_prob: float = 0.0, noise_db_low: int = 0, noise_db_high: int = 25, single_channel: bool = True, ): """Add noise to mixture Args: data: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] noise_lmdb_file: noise LMDB data source. noise_db_low (int, optional): SNR lower bound. Defaults to 0. noise_db_high (int, optional): SNR upper bound. Defaults to 25. single_channel (bool, optional): Whether to force the noise file to be single channel. # noqa Defaults to True. Returns: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ..., noise, snr}] # noqa """ noise_source = LmdbData(noise_lmdb_file) for sample in data: assert "sample_rate" in sample.keys() tgt_fs = sample["sample_rate"] all_keys = list(sample.keys()) for key in all_keys: if key.startswith("spk") and "label" not in key: if noise_enroll_prob > random.random(): speech = sample["embed_" + key] nsamples = speech.shape[1] power = (speech**2).mean() noise_key, noise_data = noise_source.random_one() if noise_key.startswith( "speech" ): # using interference speech as additive noise snr_range = [10, 30] else: snr_range = [noise_db_low, noise_db_high] noise_db = np.random.uniform(snr_range[0], snr_range[1]) _, noise_data = noise_source.random_one() with sf.SoundFile(io.BytesIO(noise_data)) as f: fs = f.samplerate if tgt_fs and fs != tgt_fs: nsamples_ = int(nsamples / tgt_fs * fs) + 1 else: nsamples_ = nsamples if f.frames == nsamples_: noise = f.read(dtype=np.float64, always_2d=True) elif f.frames < nsamples_: offset = np.random.randint(0, nsamples_ - f.frames) # noise: (Time, Nmic) noise = f.read(dtype=np.float64, always_2d=True) # Repeat noise noise = np.pad( noise, [ (offset, nsamples_ - f.frames - offset), (0, 0), ], mode="wrap", ) else: offset = np.random.randint(0, f.frames - nsamples_) f.seek(offset) # noise: (Time, Nmic) noise = f.read(nsamples_, dtype=np.float64, always_2d=True) if len(noise) != nsamples_: raise RuntimeError( f"Something wrong: {noise_lmdb_file}") if single_channel: num_ch = noise.shape[1] chs = [np.random.randint(num_ch)] noise = noise[:, chs] # noise: (Nmic, Time) noise = noise.T if tgt_fs and fs != tgt_fs: logging.warning( f"Resampling noise to match the sampling rate ({fs} -> {tgt_fs} Hz)" # noqa ) noise = librosa.resample( noise, orig_sr=fs, target_sr=tgt_fs, res_type="kaiser_fast", ) if noise.shape[1] < nsamples: noise = np.pad( noise, [(0, 0), (0, nsamples - noise.shape[1])], mode="wrap", ) else: noise = noise[:, :nsamples] noise_power = (noise**2).mean() scale = (10**(-noise_db / 20) * np.sqrt(power) / np.sqrt(max(noise_power, 1e-10))) scaled_noise = scale * noise speech = speech + scaled_noise sample["embed_" + key] = speech yield sample def add_reverb_on_enroll(data, reverb_enroll_prob=0): """ Args: data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] Returns: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}] """ for sample in data: assert "num_speaker" in sample.keys() assert "sample_rate" in sample.keys() for i in range(sample["num_speaker"]): simu_config["sr"] = sample["sample_rate"] simu_config["num_src"] = 1 rirs, _ = RIR_sim(simu_config) # [n_mic, nsource, nsamples] rirs = rirs[0] # [nsource, nsamples] if reverb_enroll_prob > random.random(): # [1, audio_len], currently only support single-channel audio audio = sample[f"embed_spk{i+1}"] # rir = rirs[i : i + 1, :] # [1, nsamples] rir = rirs rir_audio = signal.convolve( audio, rir, 
mode="full")[:, :audio.shape[1]] # [1, audio_len] max_scale = np.max(np.abs(rir_audio)) out_audio = rir_audio / max_scale * 0.9 # Note: Here, we do not replace the dry audio with the reverberant audio, # noqa # which means we hope the model to perform dereverberation and # TSE simultaneously. sample[f"embed_spk{i+1}"] = out_audio yield sample def spec_aug(data, num_t_mask=1, num_f_mask=1, max_t=10, max_f=8, prob=0): """Do spec augmentation Inplace operation Args: data: Iterable[{key, feat, label}] num_t_mask: number of time mask to apply num_f_mask: number of freq mask to apply max_t: max width of time mask max_f: max width of freq mask prob: prob of spec_aug Returns Iterable[{key, feat, label}] """ for sample in data: if random.random() < prob: all_keys = list(sample.keys()) for key in all_keys: if key.startswith("embed"): y = sample[key] max_frames = y.shape[1] max_freq = y.shape[2] # time mask for i in range(num_t_mask): start = random.randint(0, max_frames - 1) length = random.randint(1, max_t) end = min(max_frames, start + length) y[:, start:end, :] = 0 # freq mask for i in range(num_f_mask): start = random.randint(0, max_freq - 1) length = random.randint(1, max_f) end = min(max_freq, start + length) y[:, :, start:end] = 0 sample[key] = y yield sample ================================================ FILE: wesep/dataset/vad.py ================================================ import numpy as np import soundfile as sf class VoiceActivityDetection: def __init__(self, wave): self.wave = wave def segmentation(self, overlap, slice_len): frequency = 16000 signal = self.wave self.seg_len = len(signal) / frequency self.slice_len = slice_len overlap = 2 slices = np.arange(0, self.seg_len, slice_len - overlap, dtype=np.intc) # print(slices) audio_slices = [] for start, end in zip(slices[:-1], slices[1:]): start_audio = start * frequency end_audio = (end + overlap) * frequency audio_slice = signal[int(start_audio):int(end_audio)] # print(len(audio_slice)) audio_slices.append(audio_slice) # wavfile.write('slices{}.wav'.format(start), 16000, audio_slice) # print(len(audio_slices)) return audio_slices def calc_energy(self, audio): # for a in enumerate(audio): # if (a == 0.0): # a = 0.00001 # print(np.sum(np.sum(audio**2))) energy = audio / np.sum(np.sum(audio**2) + 1e-8) * 1e2 # print(len(audio)) return energy def select(self): audio_slices = self.segmentation(overlap=1, slice_len=4) energies = [] for audio in audio_slices: chunk_len = len(audio) / 10 chunk_slice = np.arange(0, len(audio) + chunk_len, chunk_len, dtype=np.intc) for start, end in zip(chunk_slice[:-1], chunk_slice[1:]): energy = self.calc_energy(audio[start:end]) # print(energy) for i, _ in enumerate(energy): if (energy[i]) == 0: energy[i] = 0.00001 # print(energy[i]) energies.append(sum(energy)) # print(energies) threshold = np.quantile(energies, 0.25) print(threshold) if threshold < 0.0001: threshold = 0.0001 fin_audios = [] i = 0 for audio in audio_slices: chunk_len = len(audio) / 10 chunk_slice = np.arange(0, len(audio) + chunk_len, chunk_len, dtype=np.intc) count = 0 for start, end in zip(chunk_slice[:-1], chunk_slice[1:]): energy = self.calc_energy(audio[start:end]) # if 50% enenrgy > threshold # print(energy) print(sum(i >= threshold for i in energy)) if sum(i >= threshold for i in energy) >= chunk_len // 2: count += 1 # save seg # print(count) if count >= 5: sf.write("output{}.wav".format(i), audio, 16000) if len(audio) < self.slice_len * 16000: # print(self.slice_len*16000-len(audio)) audio = np.concatenate( [audio, 
np.zeros(self.slice_len * 16000 - len(audio))]) fin_audios.append(audio) i += 1 if len(fin_audios) == 0: fin_audios.append(np.zeros(self.slice_len * 16000)) return fin_audios ================================================ FILE: wesep/models/__init__.py ================================================ import wesep.models.bsrnn as bsrnn import wesep.models.convtasnet as convtasnet import wesep.models.dpccn as dpccn import wesep.models.tfgridnet as tfgridnet import wesep.modules.metric_gan.discriminator as discriminator import wesep.models.bsrnn_multi_optim as bsrnn_multi import wesep.models.bsrnn_feats as bsrnn_feats def get_model(model_name: str): if model_name.startswith("ConvTasNet"): return getattr(convtasnet, model_name) elif model_name.startswith("BSRNN_Multi"): return getattr(bsrnn_multi, model_name) elif model_name.startswith("BSRNN_Feats"): return getattr(bsrnn_feats, model_name) elif model_name.startswith("BSRNN"): return getattr(bsrnn, model_name) elif model_name.startswith("DPCCN"): return getattr(dpccn, model_name) elif model_name.startswith("TFGridNet"): return getattr(tfgridnet, model_name) elif model_name.startswith("CMGAN"): return getattr(discriminator, model_name) else: # model_name error !!! print(model_name + " not found !!!") exit(1) ================================================ FILE: wesep/models/bsrnn.py ================================================ from __future__ import print_function from typing import Optional import numpy as np import torch import torch.nn as nn import torchaudio from wespeaker.models.speaker_model import get_speaker_model from wesep.modules.common.speaker import PreEmphasis from wesep.modules.common.speaker import SpeakerFuseLayer from wesep.modules.common.speaker import SpeakerTransform class ResRNN(nn.Module): def __init__(self, input_size, hidden_size, bidirectional=True): super(ResRNN, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.eps = torch.finfo(torch.float32).eps self.norm = nn.GroupNorm(1, input_size, self.eps) self.rnn = nn.LSTM( input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional, ) # linear projection layer self.proj = nn.Linear(hidden_size * 2, input_size) # hidden_size = feature_dim * 2 def forward(self, input): # input shape: batch, dim, seq rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous()) rnn_output = self.proj(rnn_output.contiguous().view( -1, rnn_output.shape[2])).view(input.shape[0], input.shape[2], input.shape[1]) return input + rnn_output.transpose(1, 2).contiguous() """ TODO : attach the speaker embedding to each input Input shape:(B,feature_dim + spk_emb_dim , T) """ class BSNet(nn.Module): def __init__(self, in_channel, nband=7, bidirectional=True): super(BSNet, self).__init__() self.nband = nband self.feature_dim = in_channel // nband self.band_rnn = ResRNN(self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional) self.band_comm = ResRNN(self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional) def forward(self, input, dummy: Optional[torch.Tensor] = None): # input shape: B, nband*N, T B, N, T = input.shape band_output = self.band_rnn( input.view(B * self.nband, self.feature_dim, -1)).view(B, self.nband, -1, T) # band comm band_output = (band_output.permute(0, 3, 2, 1).contiguous().view( B * T, -1, self.nband)) output = (self.band_comm(band_output).view( B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous()) return output.view(B, N, T) class FuseSeparation(nn.Module): def __init__( self, nband=7, num_repeat=6, 
feature_dim=128, spk_emb_dim=256, spk_fuse_type="concat", multi_fuse=True, ): """ :param nband : len(self.band_width) """ super(FuseSeparation, self).__init__() self.multi_fuse = multi_fuse self.nband = nband self.feature_dim = feature_dim self.separation = nn.ModuleList([]) if self.multi_fuse: for _ in range(num_repeat): self.separation.append( SpeakerFuseLayer( embed_dim=spk_emb_dim, feat_dim=feature_dim, fuse_type=spk_fuse_type, )) self.separation.append(BSNet(nband * feature_dim, nband)) else: self.separation.append( SpeakerFuseLayer( embed_dim=spk_emb_dim, feat_dim=feature_dim, fuse_type=spk_fuse_type, )) for _ in range(num_repeat): self.separation.append(BSNet(nband * feature_dim, nband)) def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)): """ x: [B, nband, feature_dim, T] out: [B, nband, feature_dim, T] """ batch_size = x.shape[0] if self.multi_fuse: for i, sep_func in enumerate(self.separation): x = sep_func(x, spk_embedding) if i % 2 == 0: x = x.view(batch_size * nch, self.nband * self.feature_dim, -1) else: x = x.view(batch_size * nch, self.nband, self.feature_dim, -1) else: x = self.separation[0](x, spk_embedding) x = x.view(batch_size * nch, self.nband * self.feature_dim, -1) for idx, sep in enumerate(self.separation): if idx > 0: x = sep(x, spk_embedding) x = x.view(batch_size * nch, self.nband, self.feature_dim, -1) return x class BSRNN(nn.Module): # self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, # use_bidirectional=True def __init__( self, spk_emb_dim=256, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, use_spk_transform=True, use_bidirectional=True, spk_fuse_type="concat", multi_fuse=True, joint_training=True, multi_task=False, spksInTrain=251, spk_model=None, spk_model_init=None, spk_model_freeze=False, spk_args=None, spk_feat=False, feat_type="consistent", ): super(BSRNN, self).__init__() self.sr = sr self.win = win self.stride = stride self.group = self.win // 2 self.enc_dim = self.win // 2 + 1 self.feature_dim = feature_dim self.eps = torch.finfo(torch.float32).eps self.spk_emb_dim = spk_emb_dim self.joint_training = joint_training self.spk_feat = spk_feat self.feat_type = feat_type self.spk_model_freeze = spk_model_freeze self.multi_task = multi_task # 0-1k (100 hop), 1k-4k (250 hop), # 4k-8k (500 hop), 8k-16k (1k hop), # 16k-20k (2k hop), 20k-inf # 0-8k (1k hop), 8k-16k (2k hop), 16k bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim)) bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim)) bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim)) bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim)) # add up to 8k self.band_width = [bandwidth_100] * 15 self.band_width += [bandwidth_200] * 10 self.band_width += [bandwidth_500] * 5 self.band_width += [bandwidth_2k] * 1 self.band_width.append(self.enc_dim - int(np.sum(self.band_width))) self.nband = len(self.band_width) if use_spk_transform: self.spk_transform = SpeakerTransform() else: self.spk_transform = nn.Identity() if joint_training: self.spk_model = get_speaker_model(spk_model)(**spk_args) if spk_model_init: pretrained_model = torch.load(spk_model_init) state = self.spk_model.state_dict() for key in state.keys(): if key in pretrained_model.keys(): state[key] = pretrained_model[key] # print(key) else: print("not %s loaded" % key) self.spk_model.load_state_dict(state) if spk_model_freeze: for param in self.spk_model.parameters(): param.requires_grad = False if not spk_feat: if feat_type == "consistent": self.preEmphasis = 
PreEmphasis() self.spk_encoder = torchaudio.transforms.MelSpectrogram( sample_rate=sr, n_fft=win, win_length=win, hop_length=stride, f_min=20, window_fn=torch.hamming_window, n_mels=spk_args["feat_dim"], ) else: self.preEmphasis = nn.Identity() self.spk_encoder = nn.Identity() if multi_task: self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain) else: self.pred_linear = nn.Identity() self.BN = nn.ModuleList([]) for i in range(self.nband): self.BN.append( nn.Sequential( nn.GroupNorm(1, self.band_width[i] * 2, self.eps), nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1), )) self.separator = FuseSeparation( nband=self.nband, num_repeat=num_repeat, feature_dim=feature_dim, spk_emb_dim=spk_emb_dim, spk_fuse_type=spk_fuse_type, multi_fuse=multi_fuse, ) # self.proj = nn.Linear(hidden_size*2, input_size) self.mask = nn.ModuleList([]) for i in range(self.nband): self.mask.append( nn.Sequential( nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps), nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1), nn.Tanh(), nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1), nn.Tanh(), nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1), )) def pad_input(self, input, window, stride): """ Zero-padding input according to window/stride size. """ batch_size, nsample = input.shape # pad the signals at the end for matching the window/stride size rest = window - (stride + nsample % window) % window if rest > 0: pad = torch.zeros(batch_size, rest).type(input.type()) input = torch.cat([input, pad], 1) pad_aux = torch.zeros(batch_size, stride).type(input.type()) input = torch.cat([pad_aux, input, pad_aux], 1) return input, rest def forward(self, input, embeddings): # input shape: (B, C, T) wav_input = input spk_emb_input = embeddings batch_size, nsample = wav_input.shape nch = 1 # frequency-domain separation spec = torch.stft( wav_input, n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win).to(wav_input.device).type( wav_input.type()), return_complex=True, ) # concat real and imag, split to subbands spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T subband_spec = [] subband_mix_spec = [] band_idx = 0 for i in range(len(self.band_width)): subband_spec.append(spec_RI[:, :, band_idx:band_idx + self.band_width[i]].contiguous()) subband_mix_spec.append(spec[:, band_idx:band_idx + self.band_width[i]]) # B*nch, BW, T band_idx += self.band_width[i] # normalization and bottleneck subband_feature = [] for i, bn_func in enumerate(self.BN): subband_feature.append( bn_func(subband_spec[i].view(batch_size * nch, self.band_width[i] * 2, -1))) subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T # print(subband_feature.size(), spk_emb_input.size()) predict_speaker_lable = torch.tensor(0.0).to( spk_emb_input.device) # dummy if self.joint_training: if not self.spk_feat: if self.feat_type == "consistent": with torch.no_grad(): spk_emb_input = self.preEmphasis(spk_emb_input) spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8 spk_emb_input = spk_emb_input.log() spk_emb_input = spk_emb_input - torch.mean( spk_emb_input, dim=-1, keepdim=True) spk_emb_input = spk_emb_input.permute(0, 2, 1) tmp_spk_emb_input = self.spk_model(spk_emb_input) if isinstance(tmp_spk_emb_input, tuple): spk_emb_input = tmp_spk_emb_input[-1] else: spk_emb_input = tmp_spk_emb_input predict_speaker_lable = self.pred_linear(spk_emb_input) spk_embedding = self.spk_transform(spk_emb_input) spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3) sep_output = self.separator(subband_feature, 
spk_embedding, torch.tensor(nch)) sep_subband_spec = [] for i, mask_func in enumerate(self.mask): this_output = mask_func(sep_output[:, i]).view( batch_size * nch, 2, 2, self.band_width[i], -1) this_mask = this_output[:, 0] * torch.sigmoid( this_output[:, 1]) # B*nch, 2, K, BW, T this_mask_real = this_mask[:, 0] # B*nch, K, BW, T this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T est_spec_real = (subband_mix_spec[i].real * this_mask_real - subband_mix_spec[i].imag * this_mask_imag ) # B*nch, BW, T est_spec_imag = (subband_mix_spec[i].real * this_mask_imag + subband_mix_spec[i].imag * this_mask_real ) # B*nch, BW, T sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag)) est_spec = torch.cat(sep_subband_spec, 1) # B*nch, F, T output = torch.istft( est_spec.view(batch_size * nch, self.enc_dim, -1), n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win).to(wav_input.device).type( wav_input.type()), length=nsample, ) output = output.view(batch_size, nch, -1) s = torch.squeeze(output, dim=1) return s, predict_speaker_lable if __name__ == "__main__": from thop import profile, clever_format model = BSRNN( spk_emb_dim=256, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, spk_fuse_type="additive", ) s = 0 for param in model.parameters(): s += np.product(param.size()) print("# of parameters: " + str(s / 1024.0 / 1024.0)) x = torch.randn(4, 32000) spk_embeddings = torch.randn(4, 256) output = model(x, spk_embeddings) print(output.shape) macs, params = profile(model, inputs=(x, spk_embeddings)) macs, params = clever_format([macs, params], "%.3f") print(macs, params) ================================================ FILE: wesep/models/bsrnn_feats.py ================================================ from __future__ import print_function from typing import Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchaudio from wespeaker.models.speaker_model import get_speaker_model from wesep.modules.common.speaker import PreEmphasis from wesep.modules.common.speaker import SpeakerFuseLayer from wesep.modules.common.speaker import SpeakerTransform from wesep.utils.funcs import compute_fbank, apply_cmvn class ResRNN(nn.Module): def __init__(self, input_size, hidden_size, bidirectional=True): super(ResRNN, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.eps = torch.finfo(torch.float32).eps self.norm = nn.GroupNorm(1, input_size, self.eps) self.rnn = nn.LSTM( input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional, ) # linear projection layer self.proj = nn.Linear(hidden_size * 2, input_size) # hidden_size = feature_dim * 2 def forward(self, input): # input shape: batch, dim, seq rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous()) rnn_output = self.proj(rnn_output.contiguous().view( -1, rnn_output.shape[2])).view(input.shape[0], input.shape[2], input.shape[1]) return input + rnn_output.transpose(1, 2).contiguous() """ TODO : attach the speaker embedding to each input Input shape:(B,feature_dim + spk_emb_dim , T) """ class BSNet(nn.Module): def __init__(self, in_channel, nband=7, bidirectional=True): super(BSNet, self).__init__() self.nband = nband self.feature_dim = in_channel // nband self.band_rnn = ResRNN(self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional) self.band_comm = ResRNN(self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional) def forward(self, input, dummy: Optional[torch.Tensor] = None): # input shape: B, 
nband*N, T B, N, T = input.shape band_output = self.band_rnn( input.view(B * self.nband, self.feature_dim, -1)).view(B, self.nband, -1, T) # band comm band_output = (band_output.permute(0, 3, 2, 1).contiguous().view( B * T, -1, self.nband)) output = (self.band_comm(band_output).view( B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous()) return output.view(B, N, T) class CrossAtt(nn.Module): def __init__(self, embed_dim, num_heads, *args, **kwargs): super(CrossAtt, self).__init__() self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, *args, **kwargs) def forward(self, query, key, value): if query.dim() == 4: spk_embeddings = [] for i in range(query.shape[1]): x = query[:, i, :, :].squeeze(dim=1) # (batch, feature, time) x, _ = self.multihead_attn(x.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)) spk_embeddings.append(x.transpose(1, 2)) spk_embeddings = torch.stack(spk_embeddings, dim=1) elif query.dim() == 3: x, _ = self.multihead_attn(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)) spk_embeddings = x.transpose(1, 2) return spk_embeddings class FuseSeparation(nn.Module): def __init__( self, nband=7, num_repeat=6, feature_dim=128, spk_emb_dim=256, spk_fuse_type="concat", multi_fuse=True, ): """ :param nband : len(self.band_width) """ super(FuseSeparation, self).__init__() self.spk_fuse_type = spk_fuse_type self.multi_fuse = multi_fuse self.nband = nband self.feature_dim = feature_dim self.attenFuse = nn.ModuleList([]) if spk_fuse_type and spk_fuse_type.startswith("cross_"): spk_emb_frame_dim = 512 # Ecapa_TDNN spk_emb_dim = feature_dim self.attenFuse.append(nn.Linear(spk_emb_frame_dim, feature_dim)) self.attenFuse.append(CrossAtt(embed_dim=feature_dim, num_heads=2, batch_first=True)) self.separation = nn.ModuleList([]) if self.multi_fuse and self.spk_fuse_type: for _ in range(num_repeat): self.separation.append( SpeakerFuseLayer( embed_dim=spk_emb_dim, feat_dim=feature_dim, fuse_type=spk_fuse_type.removeprefix("cross_"), )) self.separation.append(BSNet(nband * feature_dim, nband)) else: if self.spk_fuse_type: self.separation.append( SpeakerFuseLayer( embed_dim=spk_emb_dim, feat_dim=feature_dim, fuse_type=spk_fuse_type.removeprefix("cross_"), )) for _ in range(num_repeat): self.separation.append(BSNet(nband * feature_dim, nband)) def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)): """ x: [B, nband, feature_dim, T] out: [B, nband, feature_dim, T] """ batch_size = x.shape[0] if self.spk_fuse_type and self.spk_fuse_type.startswith('cross_'): spk_embedding = spk_embedding.transpose(1, 2) spk_embedding = self.attenFuse[0](spk_embedding) spk_embedding = spk_embedding.transpose(1, 2) spk_embedding = self.attenFuse[1](x, spk_embedding, spk_embedding) if self.multi_fuse and self.spk_fuse_type: for i, sep_func in enumerate(self.separation): x = sep_func(x, spk_embedding) if i % 2 == 0: x = x.view(batch_size * nch, self.nband * self.feature_dim, -1) else: x = x.view(batch_size * nch, self.nband, self.feature_dim, -1) if self.spk_fuse_type.startswith('cross_'): spk_embedding = spk_embedding.transpose(1, 2) spk_embedding = self.attenFuse[0](spk_embedding) spk_embedding = spk_embedding.transpose(1, 2) spk_embedding = self.attenFuse[1](x, spk_embedding, spk_embedding) else: idx_start = -1 if self.spk_fuse_type: x = self.separation[0](x, spk_embedding) idx_start += 1 x = x.view(batch_size * nch, self.nband * self.feature_dim, -1) for idx, sep in enumerate(self.separation): if idx > idx_start: x = sep(x, spk_embedding) x = x.view(batch_size 
* nch, self.nband, self.feature_dim, -1) return x class BSRNN_Feats(nn.Module): # self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, # use_bidirectional=True def __init__( self, spk_emb_dim=256, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, use_spk_transform=False, use_bidirectional=True, spectral_feat=False, spk_fuse_type="concat", multi_fuse=False, joint_training=True, multi_task=False, spksInTrain=251, spk_model=None, spk_model_init=None, spk_model_freeze=False, spk_args=None, spk_feat=False, feat_type="consistent", ): super(BSRNN_Feats, self).__init__() self.sr = sr self.win = win self.stride = stride self.group = self.win // 2 self.enc_dim = self.win // 2 + 1 self.feature_dim = feature_dim self.eps = torch.finfo(torch.float32).eps self.spk_emb_dim = spk_emb_dim self.spk_fuse_type = spk_fuse_type self.joint_training = joint_training self.spk_feat = spk_feat self.feat_type = feat_type self.spk_model_freeze = spk_model_freeze self.multi_task = multi_task # 0-1k (100 hop), 1k-4k (250 hop), # 4k-8k (500 hop), 8k-16k (1k hop), # 16k-20k (2k hop), 20k-inf # 0-8k (1k hop), 8k-16k (2k hop), 16k bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim)) bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim)) bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim)) bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim)) # add up to 8k self.band_width = [bandwidth_100] * 15 self.band_width += [bandwidth_200] * 10 self.band_width += [bandwidth_500] * 5 self.band_width += [bandwidth_2k] * 1 self.band_width.append(self.enc_dim - int(np.sum(self.band_width))) self.nband = len(self.band_width) if use_spk_transform: self.spk_transform = SpeakerTransform() else: self.spk_transform = nn.Identity() if joint_training and (spk_fuse_type or spectral_feat == 'tfmap_emb'): self.spk_model = get_speaker_model(spk_model)(**spk_args) if spk_model_init: pretrained_model = torch.load(spk_model_init) state = self.spk_model.state_dict() for key in state.keys(): if key in pretrained_model.keys(): state[key] = pretrained_model[key] # print(key) else: print("not %s loaded" % key) self.spk_model.load_state_dict(state) if spk_model_freeze: for param in self.spk_model.parameters(): param.requires_grad = False if not spk_feat: if feat_type == "consistent": self.preEmphasis = PreEmphasis() self.spk_encoder = torchaudio.transforms.MelSpectrogram( sample_rate=sr, n_fft=win, win_length=win, hop_length=stride, f_min=20, window_fn=torch.hamming_window, n_mels=spk_args["feat_dim"], ) else: self.preEmphasis = nn.Identity() self.spk_encoder = nn.Identity() if multi_task: self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain) else: self.pred_linear = nn.Identity() spec_map = 2 if spectral_feat: spec_map += 1 self.spectral_feat = spectral_feat self.spec_map = spec_map self.BN = nn.ModuleList([]) for i in range(self.nband): self.BN.append( nn.Sequential( nn.GroupNorm(1, self.band_width[i] * spec_map, self.eps), nn.Conv1d(self.band_width[i] * spec_map, self.feature_dim, 1), )) self.separator = FuseSeparation( nband=self.nband, num_repeat=num_repeat, feature_dim=feature_dim, spk_emb_dim=spk_emb_dim, spk_fuse_type=spk_fuse_type, multi_fuse=multi_fuse, ) self.mask = nn.ModuleList([]) for i in range(self.nband): self.mask.append( nn.Sequential( nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps), nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1), nn.Tanh(), nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1), nn.Tanh(), nn.Conv1d(self.feature_dim * 4, 
self.band_width[i] * 4, 1), )) def pad_input(self, input, window, stride): """ Zero-padding input according to window/stride size. """ batch_size, nsample = input.shape # pad the signals at the end for matching the window/stride size rest = window - (stride + nsample % window) % window if rest > 0: pad = torch.zeros(batch_size, rest).type(input.type()) input = torch.cat([input, pad], 1) pad_aux = torch.zeros(batch_size, stride).type(input.type()) input = torch.cat([pad_aux, input, pad_aux], 1) return input, rest def forward(self, input, embeddings): # input shape: (B, C, T) wav_input = input spk_emb_input = embeddings batch_size, nsample = wav_input.shape nch = 1 # frequency-domain separation spec = torch.stft( wav_input, n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win).to(wav_input.device).type( wav_input.type()), return_complex=True, ) spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T # Calculate the spectral level feature if self.spectral_feat: aux_c = torch.stft( spk_emb_input, n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win).to(spk_emb_input.device).type( spk_emb_input.type()), return_complex=True, ) if self.spectral_feat == 'tfmap_spec': mix_mag_ori = torch.abs(spec) enroll_mag = torch.abs(aux_c) mix_mag = F.normalize(mix_mag_ori, p=2, dim=1) enroll_mag = F.normalize(enroll_mag, p=2, dim=1) mix_mag = mix_mag.permute(0, 2, 1).contiguous() att_scores = torch.matmul(mix_mag, enroll_mag) att_weights = F.softmax(att_scores, dim=-1) enroll_mag = enroll_mag.permute(0, 2, 1).contiguous() tf_map = torch.matmul(att_weights, enroll_mag) tf_map = tf_map.permute(0, 2, 1).contiguous() tf_map = tf_map / tf_map.norm(dim=1, keepdim=True) # Recover the energy of estimated tfmap feature tf_map = ( torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True) * tf_map ) # Another kind of nomalization for tf_map feature # tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True) spec_RI = torch.cat((spec_RI, tf_map.unsqueeze(1)), dim=1) if self.spectral_feat == 'tfmap_emb': # Only Ecapa-TDNN. 
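            # Sketch of this branch: frame-level speaker embeddings are extracted,
            # without gradients, from both the mixture and the enrollment via
            # self.spk_model; after L2 normalisation their inner products give
            # attention weights, which re-synthesise an enrollment-magnitude "TF map"
            # that is rescaled with the mixture magnitude (energy recovery) and
            # appended to spec_RI as an extra input channel (accounted for by
            # spec_map in the bottleneck layers).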
with torch.no_grad(): signal_dim = wav_input.dim() extended_shape = ( [1] * (3 - signal_dim) + list(wav_input.size()) ) pad = int(self.win // 2) mix_emb = F.pad( wav_input.view(extended_shape), [pad, pad], mode="reflect" ) mix_emb = mix_emb.view(mix_emb.shape[-signal_dim:]) signal_dim = spk_emb_input.dim() extended_shape = ( [1] * (3 - signal_dim) + list(spk_emb_input.size()) ) pad = int(self.win // 2) spk_emb = F.pad( spk_emb_input.view(extended_shape), [pad, pad], mode="reflect" ) spk_emb = spk_emb.view(spk_emb.shape[-signal_dim:]) spk_emb = compute_fbank( spk_emb, frame_length=self.win * 1e3 / self.sr, frame_shift=self.stride * 1e3 / self.sr, dither=0.0, sample_rate=self.sr ) mix_emb = compute_fbank( mix_emb, frame_length=self.win * 1e3 / self.sr, frame_shift=self.stride * 1e3 / self.sr, dither=0.0, sample_rate=self.sr ) mix_emb = apply_cmvn(mix_emb) spk_emb = apply_cmvn(spk_emb) spk_emb = self.spk_model(spk_emb) if isinstance(spk_emb, tuple): spk_emb_frame = spk_emb[0] else: spk_emb_frame = spk_emb mix_emb = self.spk_model(mix_emb) if isinstance(mix_emb, tuple): mix_emb_frame = mix_emb[0] else: mix_emb_frame = mix_emb mix_emb_frame_ = F.normalize(mix_emb_frame, p=2, dim=1) spk_emb_frame_ = F.normalize(spk_emb_frame, p=2, dim=1) mix_emb_frame_ = mix_emb_frame_.transpose(1, 2) att_scores = torch.matmul(mix_emb_frame_, spk_emb_frame_) att_weights = F.softmax(att_scores, dim=-1) mix_mag_ori = torch.abs(spec) enroll_mag = torch.abs(aux_c) enroll_mag = enroll_mag.transpose(1, 2) # enroll_mag = F.normalize(enroll_mag, p=2, dim=1) tf_map = torch.matmul(att_weights, enroll_mag) tf_map = tf_map.transpose(1, 2) tf_map = tf_map / tf_map.norm(dim=1, keepdim=True) # Recover the energy of estimated tfmap feature tf_map = ( torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True) * tf_map ) # Another kind of nomalization for tf_map feature # tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True) spec_RI = torch.cat((spec_RI, tf_map.unsqueeze(1)), dim=1) # concat real and imag, split to subbands subband_spec = [] subband_mix_spec = [] band_idx = 0 for i in range(len(self.band_width)): subband_spec.append(spec_RI[:, :, band_idx:band_idx + self.band_width[i]].contiguous()) subband_mix_spec.append(spec[:, band_idx:band_idx + self.band_width[i]]) # B*nch, BW, T band_idx += self.band_width[i] # normalization and bottleneck subband_feature = [] for i, bn_func in enumerate(self.BN): subband_feature.append( bn_func(subband_spec[i].view(batch_size * nch, self.band_width[i] * self.spec_map, -1))) subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T # print(subband_feature.size(), spk_emb_input.size()) predict_speaker_lable = torch.tensor(0.0).to( spk_emb_input.device) # dummy if ( (self.spectral_feat and self.spectral_feat == "tfmap_emb") and (self.spk_fuse_type and self.spk_fuse_type.startswith("cross_")) ): spk_emb_input = spk_emb_frame elif self.joint_training and self.spk_fuse_type: if not self.spk_feat: if self.feat_type == "consistent": with torch.no_grad(): spk_emb_input = self.preEmphasis(spk_emb_input) spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8 spk_emb_input = spk_emb_input.log() spk_emb_input = spk_emb_input - torch.mean( spk_emb_input, dim=-1, keepdim=True) spk_emb_input = spk_emb_input.permute(0, 2, 1) if self.spk_fuse_type and self.spk_fuse_type.startswith("cross_"): tmp_spk_emb_input = self.spk_model._get_frame_level_feat( spk_emb_input) else: tmp_spk_emb_input = self.spk_model(spk_emb_input) if isinstance(tmp_spk_emb_input, tuple): spk_emb_input = tmp_spk_emb_input[-1] else: 
spk_emb_input = tmp_spk_emb_input predict_speaker_lable = self.pred_linear(spk_emb_input) spk_embedding = self.spk_transform(spk_emb_input) if self.spk_fuse_type and not self.spk_fuse_type.startswith("cross_"): spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3) sep_output = self.separator(subband_feature, spk_embedding, torch.tensor(nch)) sep_subband_spec = [] for i, mask_func in enumerate(self.mask): this_output = mask_func(sep_output[:, i]).view( batch_size * nch, 2, 2, self.band_width[i], -1) this_mask = this_output[:, 0] * torch.sigmoid( this_output[:, 1]) # B*nch, 2, K, BW, T this_mask_real = this_mask[:, 0] # B*nch, K, BW, T this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T est_spec_real = (subband_mix_spec[i].real * this_mask_real - subband_mix_spec[i].imag * this_mask_imag ) # B*nch, BW, T est_spec_imag = (subband_mix_spec[i].real * this_mask_imag + subband_mix_spec[i].imag * this_mask_real ) # B*nch, BW, T sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag)) est_spec = torch.cat(sep_subband_spec, 1) # B*nch, F, T output = torch.istft( est_spec.view(batch_size * nch, self.enc_dim, -1), n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win).to(wav_input.device).type( wav_input.type()), length=nsample, ) output = output.view(batch_size, nch, -1) s = torch.squeeze(output, dim=1) return s, predict_speaker_lable if __name__ == "__main__": from thop import profile, clever_format model = BSRNN_Feats( spk_emb_dim=256, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, spectral_feat='tfmap_emb', spk_fuse_type='cross_multiply', spk_model="ECAPA_TDNN_GLOB_c512", spk_args={ "embed_dim": 192, "feat_dim": 80, "pooling_func": "ASTP", } ) s = 0 for param in model.parameters(): s += np.product(param.size()) print("# of parameters: " + str(s / 1024.0 / 1024.0)) x = torch.randn(4, 32000) spk_embeddings = torch.randn(4, 16000) output = model(x, spk_embeddings) print(output[0].shape) macs, params = profile(model, inputs=(x, spk_embeddings)) macs, params = clever_format([macs, params], "%.3f") print(macs, params) ================================================ FILE: wesep/models/bsrnn_multi_optim.py ================================================ from __future__ import print_function from typing import Optional import numpy as np import torch import torch.nn as nn import torchaudio from wespeaker.models.speaker_model import get_speaker_model from wesep.modules.common.speaker import PreEmphasis from wesep.modules.common.speaker import SpeakerFuseLayer from wesep.modules.common.speaker import SpeakerTransform class ResRNN(nn.Module): def __init__(self, input_size, hidden_size, bidirectional=True): super(ResRNN, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.eps = torch.finfo(torch.float32).eps self.norm = nn.GroupNorm(1, input_size, self.eps) self.rnn = nn.LSTM( input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional, ) # linear projection layer self.proj = nn.Linear( hidden_size * 2, input_size ) # hidden_size = feature_dim * 2 def forward(self, input): # input shape: batch, dim, seq rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous()) rnn_output = self.proj( rnn_output.contiguous().view(-1, rnn_output.shape[2]) ).view(input.shape[0], input.shape[2], input.shape[1]) return input + rnn_output.transpose(1, 2).contiguous() """ TODO : attach the speaker embedding to each input Input shape:(B,feature_dim + spk_emb_dim , T) """ class BSNet(nn.Module): def __init__(self, in_channel, 
nband=7, bidirectional=True): super(BSNet, self).__init__() self.nband = nband self.feature_dim = in_channel // nband self.band_rnn = ResRNN( self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional ) self.band_comm = ResRNN( self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional ) def forward(self, input, dummy: Optional[torch.Tensor] = None): # input shape: B, nband*N, T B, N, T = input.shape band_output = self.band_rnn( input.view(B * self.nband, self.feature_dim, -1) ).view(B, self.nband, -1, T) # band comm band_output = ( band_output.permute(0, 3, 2, 1).contiguous().view(B * T, -1, self.nband) ) output = ( self.band_comm(band_output) .view(B, T, -1, self.nband) .permute(0, 3, 2, 1) .contiguous() ) return output.view(B, N, T) class FuseSeparation(nn.Module): def __init__( self, nband=7, num_repeat=6, feature_dim=128, spk_emb_dim=256, spk_fuse_type="concat", multi_fuse=True, ): """ :param nband : len(self.band_width) """ super(FuseSeparation, self).__init__() self.multi_fuse = multi_fuse self.nband = nband self.feature_dim = feature_dim self.separation = nn.ModuleList([]) if self.multi_fuse: for _ in range(num_repeat): self.separation.append( SpeakerFuseLayer( embed_dim=spk_emb_dim, feat_dim=feature_dim, fuse_type=spk_fuse_type, ) ) self.separation.append(BSNet(nband * feature_dim, nband)) else: self.separation.append( SpeakerFuseLayer( embed_dim=spk_emb_dim, feat_dim=feature_dim, fuse_type=spk_fuse_type, ) ) for _ in range(num_repeat): self.separation.append(BSNet(nband * feature_dim, nband)) def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)): """ x: [B, nband, feature_dim, T] out: [B, nband, feature_dim, T] """ batch_size = x.shape[0] if self.multi_fuse: for i, sep_func in enumerate(self.separation): x = sep_func(x, spk_embedding) if i % 2 == 0: x = x.view(batch_size * nch, self.nband * self.feature_dim, -1) else: x = x.view(batch_size * nch, self.nband, self.feature_dim, -1) else: x = self.separation[0](x, spk_embedding) x = x.view(batch_size * nch, self.nband * self.feature_dim, -1) for idx, sep in enumerate(self.separation): if idx > 0: x = sep(x, spk_embedding) x = x.view(batch_size * nch, self.nband, self.feature_dim, -1) return x class BSRNN_Multi(nn.Module): # self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, # use_bidirectional=True def __init__( self, spk_emb_dim=256, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, use_spk_transform=True, use_bidirectional=True, spk_fuse_type="concat", multi_fuse=True, joint_training=True, multi_task=False, spksInTrain=251, spk_model=None, spk_model_init=None, spk_model_freeze=False, spk_args=None, spk_feat=False, feat_type="consistent", ): super(BSRNN_Multi, self).__init__() self.sr = sr self.win = win self.stride = stride self.group = self.win // 2 self.enc_dim = self.win // 2 + 1 self.feature_dim = feature_dim self.eps = torch.finfo(torch.float32).eps self.spk_emb_dim = spk_emb_dim self.joint_training = joint_training self.spk_feat = spk_feat self.feat_type = feat_type self.spk_model_freeze = spk_model_freeze self.multi_task = multi_task # 0-1k (100 hop), 1k-4k (250 hop), # 4k-8k (500 hop), 8k-16k (1k hop), # 16k-20k (2k hop), 20k-inf # 0-8k (1k hop), 8k-16k (2k hop), 16k bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim)) bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim)) bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim)) bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim)) # add up to 8k self.band_width = 
[bandwidth_100] * 15 self.band_width += [bandwidth_200] * 10 self.band_width += [bandwidth_500] * 5 self.band_width += [bandwidth_2k] * 1 self.band_width.append(self.enc_dim - int(np.sum(self.band_width))) self.nband = len(self.band_width) if use_spk_transform: self.spk_transform = SpeakerTransform() else: self.spk_transform = nn.Identity() if joint_training: self.spk_model = get_speaker_model(spk_model)(**spk_args) if spk_model_init: pretrained_model = torch.load(spk_model_init) state = self.spk_model.state_dict() for key in state.keys(): if key in pretrained_model.keys(): state[key] = pretrained_model[key] # print(key) else: print("not %s loaded" % key) self.spk_model.load_state_dict(state) if spk_model_freeze: for param in self.spk_model.parameters(): param.requires_grad = False if not spk_feat: if feat_type == "consistent": self.preEmphasis = PreEmphasis() self.spk_encoder = torchaudio.transforms.MelSpectrogram( sample_rate=sr, n_fft=win, win_length=win, hop_length=stride, f_min=20, window_fn=torch.hamming_window, n_mels=spk_args["feat_dim"], ) else: self.preEmphasis = nn.Identity() self.spk_encoder = nn.Identity() if multi_task: self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain) else: self.pred_linear = nn.Identity() self.BN = nn.ModuleList([]) for i in range(self.nband): self.BN.append( nn.Sequential( nn.GroupNorm(1, self.band_width[i] * 2, self.eps), nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1), ) ) self.separator = FuseSeparation( nband=self.nband, num_repeat=num_repeat, feature_dim=feature_dim, spk_emb_dim=spk_emb_dim, spk_fuse_type=spk_fuse_type, multi_fuse=multi_fuse, ) # self.proj = nn.Linear(hidden_size*2, input_size) self.mask = nn.ModuleList([]) for i in range(self.nband): self.mask.append( nn.Sequential( nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps), nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1), nn.Tanh(), nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1), nn.Tanh(), nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1), ) ) def pad_input(self, input, window, stride): """ Zero-padding input according to window/stride size. 
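        For example (illustrative numbers): with window=512, stride=128 and
        nsample=1000, rest = 512 - (128 + 1000 % 512) % 512 = 408, so 408 zeros
        are appended and one extra block of `stride` zeros is added at each end,
        giving 1000 + 408 + 2 * 128 = 1664 padded samples.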
""" batch_size, nsample = input.shape # pad the signals at the end for matching the window/stride size rest = window - (stride + nsample % window) % window if rest > 0: pad = torch.zeros(batch_size, rest).type(input.type()) input = torch.cat([input, pad], 1) pad_aux = torch.zeros(batch_size, stride).type(input.type()) input = torch.cat([pad_aux, input, pad_aux], 1) return input, rest def forward(self, input, embeddings): # input shape: (B, C, T) wav_input = input spk_emb_input = embeddings batch_size, nsample = wav_input.shape nch = 1 # frequency-domain separation spec = torch.stft( wav_input, n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win) .to(wav_input.device) .type(wav_input.type()), return_complex=True, ) # concat real and imag, split to subbands spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T subband_spec = [] subband_mix_spec = [] band_idx = 0 for i in range(len(self.band_width)): subband_spec.append( spec_RI[:, :, band_idx : band_idx + self.band_width[i]].contiguous() ) subband_mix_spec.append( spec[:, band_idx : band_idx + self.band_width[i]] ) # B*nch, BW, T band_idx += self.band_width[i] # normalization and bottleneck subband_feature = [] for i, bn_func in enumerate(self.BN): subband_feature.append( bn_func( subband_spec[i].view(batch_size * nch, self.band_width[i] * 2, -1) ) ) subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T # print(subband_feature.size(), spk_emb_input.size()) predict_speaker_lable = torch.tensor(0.0).to(spk_emb_input.device) # dummy if self.joint_training: if not self.spk_feat: if self.feat_type == "consistent": with torch.no_grad(): spk_emb_input = self.preEmphasis(spk_emb_input) spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8 spk_emb_input = spk_emb_input.log() spk_emb_input = spk_emb_input - torch.mean( spk_emb_input, dim=-1, keepdim=True ) spk_emb_input = spk_emb_input.permute(0, 2, 1) tmp_spk_emb_input = self.spk_model(spk_emb_input) if isinstance(tmp_spk_emb_input, tuple): spk_emb_input = tmp_spk_emb_input[-1] else: spk_emb_input = tmp_spk_emb_input predict_speaker_lable = self.pred_linear(spk_emb_input) spk_embedding = self.spk_transform(spk_emb_input) spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3) sep_output = self.separator(subband_feature, spk_embedding, torch.tensor(nch)) sep_subband_spec = [] for i, mask_func in enumerate(self.mask): this_output = mask_func(sep_output[:, i]).view( batch_size * nch, 2, 2, self.band_width[i], -1 ) this_mask = this_output[:, 0] * torch.sigmoid( this_output[:, 1] ) # B*nch, 2, K, BW, T this_mask_real = this_mask[:, 0] # B*nch, K, BW, T this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T est_spec_real = ( subband_mix_spec[i].real * this_mask_real - subband_mix_spec[i].imag * this_mask_imag ) # B*nch, BW, T est_spec_imag = ( subband_mix_spec[i].real * this_mask_imag + subband_mix_spec[i].imag * this_mask_real ) # B*nch, BW, T sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag)) est_spec = torch.cat(sep_subband_spec, 1) # B*nch, F, T output = torch.istft( est_spec.view(batch_size * nch, self.enc_dim, -1), n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win) .to(wav_input.device) .type(wav_input.type()), length=nsample, ) output = output.view(batch_size, nch, -1) s = torch.squeeze(output, dim=1) if torch.is_grad_enabled(): self_embedding = s.detach() self_predict_speaker_lable = torch.tensor(0.0).to( self_embedding.device ) # dummy if self.joint_training: if self.feat_type == "consistent": with 
torch.no_grad(): self_embedding = self.preEmphasis(self_embedding) self_embedding = self.spk_encoder(self_embedding) + 1e-8 self_embedding = self_embedding.log() self_embedding = self_embedding - torch.mean( self_embedding, dim=-1, keepdim=True ) self_embedding = self_embedding.permute(0, 2, 1) self_tmp_spk_emb_input = self.spk_model(self_embedding) if isinstance(self_tmp_spk_emb_input, tuple): self_spk_emb_input = self_tmp_spk_emb_input[-1] else: self_spk_emb_input = self_tmp_spk_emb_input self_predict_speaker_lable = self.pred_linear(self_spk_emb_input) self_spk_embedding = self.spk_transform(self_spk_emb_input) self_spk_embedding = self_spk_embedding.unsqueeze(1).unsqueeze(3) self_sep_output = self.separator( subband_feature, self_spk_embedding, torch.tensor(nch) ) self_sep_subband_spec = [] for i, mask_func in enumerate(self.mask): this_output = mask_func(self_sep_output[:, i]).view( batch_size * nch, 2, 2, self.band_width[i], -1 ) this_mask = this_output[:, 0] * torch.sigmoid( this_output[:, 1] ) # B*nch, 2, K, BW, T this_mask_real = this_mask[:, 0] # B*nch, K, BW, T this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T est_spec_real = ( subband_mix_spec[i].real * this_mask_real - subband_mix_spec[i].imag * this_mask_imag ) # B*nch, BW, T est_spec_imag = ( subband_mix_spec[i].real * this_mask_imag + subband_mix_spec[i].imag * this_mask_real ) # B*nch, BW, T self_sep_subband_spec.append( torch.complex(est_spec_real, est_spec_imag) ) self_est_spec = torch.cat(self_sep_subband_spec, 1) # B*nch, F, T self_output = torch.istft( self_est_spec.view(batch_size * nch, self.enc_dim, -1), n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win) .to(wav_input.device) .type(wav_input.type()), length=nsample, ) self_output = self_output.view(batch_size, nch, -1) self_s = torch.squeeze(self_output, dim=1) return s, self_s, predict_speaker_lable, self_predict_speaker_lable return s, predict_speaker_lable if __name__ == "__main__": from thop import profile, clever_format model = BSRNN_Multi( spk_emb_dim=256, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, spk_fuse_type="additive", ) s = 0 for param in model.parameters(): s += np.product(param.size()) print("# of parameters: " + str(s / 1024.0 / 1024.0)) x = torch.randn(4, 32000) spk_embeddings = torch.randn(4, 256) output = model(x, spk_embeddings) print(output.shape) macs, params = profile(model, inputs=(x, spk_embeddings)) macs, params = clever_format([macs, params], "%.3f") print(macs, params) ================================================ FILE: wesep/models/convtasnet.py ================================================ import torch import torch.nn as nn from wesep.modules.common import select_norm from wesep.modules.common.speaker import SpeakerTransform from wesep.modules.tasnet import DeepEncoder, DeepDecoder from wesep.modules.tasnet import MultiEncoder, MultiDecoder from wesep.modules.tasnet import FuseSeparation from wesep.modules.tasnet.convs import Conv1D, ConvTrans1D from wesep.modules.tasnet.speaker import ResNet4SpExplus from wespeaker.models.speaker_model import get_speaker_model class ConvTasNet(nn.Module): def __init__( self, N=512, L=16, B=128, H=512, P=3, X=8, R=3, spk_emb_dim=256, norm="gLN", activate="relu", causal=False, skip_con=False, spk_fuse_type="concatConv", # "concat", "additive", "multiply", "FiLM", "None", # ("concatConv" only for convtasnet) multi_fuse=True, use_spk_transform=True, encoder_type="Multi", # 'Multi', 'Deep', None decoder_type="Multi", joint_training=True, multi_task=False, 
spksInTrain=251, spk_model=None, spk_model_init=None, spk_model_freeze=False, spk_args=None, spk_feat=False, feat_type="consistent", ): """ :param N: Number of filters in autoencoder :param L: Length of the filters (in samples) :param B: Number of channels in bottleneck and the residual paths :param H: Number of channels in convolutional blocks :param P: Kernel size in convolutional blocks :param X: Number of convolutional blocks in each repeat :param R: Number of repeats :param norm: :param activate: :param causal: :param skip_con: :param spk_fuse_type: concat/addition/FiLM :param use_spk_transform: :param use_deep_enc: :param use_deep_dec: """ super(ConvTasNet, self).__init__() self.encoder_type = encoder_type self.decoder_type = decoder_type # n x 1 x T => n x N x T if encoder_type == "Multi": self.encoder = MultiEncoder( in_channels=1, middle_channels=N, out_channels=B, kernel_size=L, stride=L // 2, ) elif encoder_type == "Deep": self.encoder = DeepEncoder(1, N, L, stride=L // 2) self.LayerN_S = select_norm(norm, N) self.BottleN_S = Conv1D(N, B, 1) else: self.encoder = nn.Sequential( Conv1D(1, N, L, stride=L // 2, padding=0), nn.ReLU()) self.LayerN_S = select_norm(norm, N) self.BottleN_S = Conv1D(N, B, 1) self.joint_training = joint_training self.spk_feat = spk_feat self.feat_type = feat_type self.spk_model_freeze = spk_model_freeze self.multi_task = multi_task if joint_training: if not self.spk_feat: if self.feat_type == "consistent": self.spk_model = ResNet4SpExplus( in_channel=N, C_embedding=spk_emb_dim ) # The speaker model is fixed for SpEx+ currently else: self.spk_model = get_speaker_model(spk_model)(**spk_args) if spk_model_init: pretrained_model = torch.load(spk_model_init) state = self.spk_model.state_dict() for key in state.keys(): if key in pretrained_model.keys(): state[key] = pretrained_model[key] # print(key) else: print("not %s loaded" % key) self.spk_model.load_state_dict(state) if self.spk_model_freeze: for param in self.spk_model.parameters(): param.requires_grad = False if multi_task: self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain) if not use_spk_transform: self.spk_transform = nn.Identity() else: self.spk_transform = SpeakerTransform() # Separation block # n x B x T => n x B x T self.separation = FuseSeparation( R, X, B, H, P, norm=norm, causal=causal, skip_con=skip_con, C_embedding=spk_emb_dim, spk_fuse_type=spk_fuse_type, multi_fuse=multi_fuse, ) # n x N x T => n x 1 x L if decoder_type == "Multi": self.decoder = MultiDecoder( in_channels=B, middle_channels=N, out_channels=1, kernel_size=L, stride=L // 2, ) elif decoder_type == "Deep": self.decoder = DeepDecoder(N, L, stride=L // 2) self.gen_masks = Conv1D(B, N, 1) else: self.decoder = ConvTrans1D(N, 1, L, stride=L // 2) self.gen_masks = Conv1D(B, N, 1) # activation function active_f = { "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "softmax": nn.Softmax(dim=0), } # self.activation_type = activate self.activation = active_f[activate] def forward(self, x, embeddings): if x.dim() >= 3: raise RuntimeError( "{} accept 1/2D tensor as input, but got {:d}".format( self.__name__, x.dim())) if x.dim() == 1: x = torch.unsqueeze(x, 0) # x: n x 1 x L => n x N x T if self.encoder_type == "Multi": e, w1, w2, w3 = self.encoder(x) x = e # replace x with e, for asymmetric encoder-decoder else: x = self.encoder(x) e = self.LayerN_S(x) e = self.BottleN_S( e) # Embedding fuse after dimension changed fro N to B if (self.joint_training): # Only support sharing Encoder and ResNet in SpEx+ currently # Speaker Encoder if not 
self.spk_feat and self.feat_type == "consistent": if self.encoder_type == "Multi": _, aux_w1, aux_w2, aux_w3 = self.encoder(embeddings) embeddings = torch.cat([aux_w1, aux_w2, aux_w3], 1) else: aux_x = self.encoder(embeddings) aux_e = self.LayerN_S(aux_x) embeddings = self.BottleN_S(aux_e) embeddings = self.spk_model(embeddings) if isinstance(embeddings, tuple): embeddings = embeddings[-1] if self.multi_task: predict_speaker_lable = self.pred_linear(embeddings) spk_embeds = self.spk_transform(embeddings.unsqueeze(-1)) e = self.separation(e, spk_embeds) # decoder part n x L if self.decoder_type == "Multi": s = self.decoder( e, w1, w2, w3, actLayer=self.activation) # s is a tuple by using multiDecoder else: # n x B x L => n x N x L m = self.gen_masks(e) # n x N x L m = self.activation(m) x = x * m s = self.decoder(x) if self.joint_training and self.multi_task: if not isinstance(s, list): s = [ s, ] s.append(predict_speaker_lable) return s # s: N x Len Or List(N x Len,x3/x4) def check_parameters(net): """ Returns module parameters. Mb """ parameters = sum(param.numel() for param in net.parameters()) return parameters / 10**6 def test_convtasnet(): x = torch.randn(4, 32000) spk_embeddings = torch.randn(4, 256) net = ConvTasNet(use_spk_transform=False, spk_fuse_type="FiLM") s = net(x, spk_embeddings) print(str(check_parameters(net)) + " Mb") print(s[1].shape) if __name__ == "__main__": test_convtasnet() ================================================ FILE: wesep/models/dpccn.py ================================================ import torch import torch.nn as nn import torchaudio from wespeaker.models.speaker_model import get_speaker_model from wesep.modules.common.speaker import PreEmphasis from wesep.modules.common.speaker import SpeakerFuseLayer from wesep.modules.common.speaker import SpeakerTransform from wesep.modules.dpccn.convs import Conv2dBlock from wesep.modules.dpccn.convs import ConvTrans2dBlock from wesep.modules.dpccn.convs import DenseBlock from wesep.modules.dpccn.convs import TCNBlock class DPCCN(nn.Module): def __init__( self, win=512, stride=128, spk_emb_dim=256, sr=16000, use_spk_transform=False, spk_fuse_type="multiply", feature_dim=257, kernel_size=(3, 3), stride1=(1, 1), stride2=(1, 2), paddings=(1, 1), output_padding=(0, 0), tcn_dims=384, tcn_blocks=10, tcn_layers=2, causal=False, pool_size=(4, 8, 16, 32), multi_fuse=False, joint_training=True, multi_task=False, spksInTrain=251, spk_model=None, spk_model_init=None, spk_model_freeze=False, spk_args=None, spk_feat=False, feat_type="consistent", ) -> None: super(DPCCN, self).__init__() self.win_len = win self.hop_size = stride self.spk_emb_dim = spk_emb_dim self.joint_training = joint_training self.spk_feat = spk_feat self.feat_type = feat_type self.spk_model_freeze = spk_model_freeze self.multi_task = multi_task self.conv2d = nn.Conv2d(2, 16, kernel_size, stride1, paddings) self.encoder = self._build_encoder(kernel_size=kernel_size, stride=stride2, padding=paddings) if use_spk_transform: self.spk_transform = SpeakerTransform() else: self.spk_transform = nn.Identity() if joint_training: self.spk_model = get_speaker_model(spk_model)(**spk_args) if spk_model_init: pretrained_model = torch.load(spk_model_init) state = self.spk_model.state_dict() for key in state.keys(): if key in pretrained_model.keys(): state[key] = pretrained_model[key] # print(key) else: print("not %s loaded" % key) self.spk_model.load_state_dict(state) if spk_model_freeze: for param in self.spk_model.parameters(): param.requires_grad = False if not 
spk_feat: if feat_type == "consistent": self.preEmphasis = PreEmphasis() self.spk_encoder = torchaudio.transforms.MelSpectrogram( sample_rate=sr, n_fft=win, win_length=win, hop_length=stride, f_min=20, window_fn=torch.hamming_window, n_mels=spk_args["feat_dim"], ) else: self.preEmphasis = nn.Identity() self.spk_encoder = nn.Identity() if multi_task: self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain) else: self.pred_linear = nn.Identity() self.spk_fuse = SpeakerFuseLayer( embed_dim=self.spk_emb_dim, feat_dim=feature_dim, fuse_type=spk_fuse_type, ) self.tcn_layers = self._build_tcn_layers( tcn_layers, tcn_blocks, in_dims=tcn_dims, out_dims=tcn_dims, causal=causal, ) self.decoder = self._build_decoder( kernel_size=kernel_size, stride=stride2, padding=paddings, output_padding=output_padding, ) self.avg_pool = self._build_avg_pool(pool_size) self.avg_proj = nn.Conv2d(64, 32, 1, 1) self.deconv2d = nn.ConvTranspose2d(32, 2, kernel_size, stride1, paddings) def _build_encoder(self, **enc_kargs): """ Build encoder layers """ encoder = nn.ModuleList() encoder.append(DenseBlock(16, 16, "enc")) for i in range(4): encoder.append( nn.Sequential( Conv2dBlock(in_dims=16 if i == 0 else 32, out_dims=32, **enc_kargs), DenseBlock(32, 32, "enc"), )) encoder.append(Conv2dBlock(in_dims=32, out_dims=64, **enc_kargs)) encoder.append(Conv2dBlock(in_dims=64, out_dims=128, **enc_kargs)) encoder.append(Conv2dBlock(in_dims=128, out_dims=384, **enc_kargs)) return encoder def _build_decoder(self, **dec_kargs): """ Build decoder layers """ decoder = nn.ModuleList() decoder.append( ConvTrans2dBlock(in_dims=384 * 2, out_dims=128, **dec_kargs)) decoder.append( ConvTrans2dBlock(in_dims=128 * 2, out_dims=64, **dec_kargs)) decoder.append( ConvTrans2dBlock(in_dims=64 * 2, out_dims=32, **dec_kargs)) for i in range(4): decoder.append( nn.Sequential( DenseBlock(32, 64, "dec"), ConvTrans2dBlock(in_dims=64, out_dims=32 if i != 3 else 16, **dec_kargs), )) decoder.append(DenseBlock(16, 32, "dec")) return decoder def _build_tcn_blocks(self, tcn_blocks, **tcn_kargs): """ Build TCN blocks in each repeat (layer) """ blocks = [ TCNBlock(**tcn_kargs, dilation=(2**b)) for b in range(tcn_blocks) ] return nn.Sequential(*blocks) def _build_tcn_layers(self, tcn_layers, tcn_blocks, **tcn_kargs): """ Build TCN layers """ layers = [ self._build_tcn_blocks(tcn_blocks, **tcn_kargs) for _ in range(tcn_layers) ] return nn.Sequential(*layers) def _build_avg_pool(self, pool_size): """ Build avg pooling layers """ avg_pool = nn.ModuleList() for sz in pool_size: avg_pool.append( nn.Sequential(nn.AvgPool2d(sz), nn.Conv2d(32, 8, 1, 1))) return avg_pool def forward(self, input, aux): wav_input = input spk_emb_input = aux batch_size, nsample = wav_input.shape # frequency-domain separation spec = torch.stft( wav_input, n_fft=self.win_len, hop_length=self.hop_size, window=torch.hann_window(self.win_len).to(wav_input.device).type( wav_input.type()), return_complex=True, ) # concat real and imag, split to subbands spec_RI = torch.stack([spec.real, spec.imag], 1) # spec = torch.einsum("hijk->hikj", spec_RI) # batchsize, 2, T, F spec = torch.transpose(spec_RI, 2, 3) # batchsize, 2, T, F out = self.conv2d(spec) out_list = [] out = self.encoder[0](out) predict_speaker_lable = torch.tensor(0.0).to( spk_emb_input.device) # dummy if self.joint_training: if not self.spk_feat: if self.feat_type == "consistent": with torch.no_grad(): spk_emb_input = self.preEmphasis(spk_emb_input) spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8 spk_emb_input = spk_emb_input.log() 
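# Aside (standalone sketch, helper name is ours): the enrollment front-end applied
# here (pre-emphasis, log-Mel computed with the separator's win/stride, then the
# per-band mean normalization that follows just below) is the same "consistent"
# feature pipeline used by the BSRNN and TFGridNet models. With illustrative defaults:
import torch
import torch.nn.functional as F
import torchaudio

def consistent_spk_features(wav, sr=16000, win=512, stride=128, n_mels=80):
    # pre-emphasis, as in wesep.modules.common.speaker.PreEmphasis (coef=0.97)
    filt = torch.tensor([[[-0.97, 1.0]]], dtype=wav.dtype)
    wav = F.conv1d(F.pad(wav.unsqueeze(1), (1, 0), "reflect"), filt).squeeze(1)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=win, win_length=win, hop_length=stride,
        f_min=20, window_fn=torch.hamming_window, n_mels=n_mels)(wav)
    feat = (mel + 1e-8).log()
    feat = feat - feat.mean(dim=-1, keepdim=True)  # per-band mean normalization
    return feat.permute(0, 2, 1)                   # (B, T, n_mels), speaker-model input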
spk_emb_input = spk_emb_input - torch.mean( spk_emb_input, dim=-1, keepdim=True) spk_emb_input = spk_emb_input.permute(0, 2, 1) tmp_spk_emb_input = self.spk_model(spk_emb_input) if isinstance(tmp_spk_emb_input, tuple): spk_emb_input = tmp_spk_emb_input[-1] else: spk_emb_input = tmp_spk_emb_input predict_speaker_lable = self.pred_linear(spk_emb_input) spk_embedding = self.spk_transform(spk_emb_input) spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3) out = self.spk_fuse(out.transpose(2, 3), spk_embedding).transpose(2, 3) out_list.append(out) for _, enc in enumerate(self.encoder[1:]): out = enc(out) out_list.append(out) B, N, T, F = out.shape out = out.reshape(B, N, T * F) out = self.tcn_layers(out) out = out.reshape(B, N, T, F) out_list = out_list[::-1] for idx, dec in enumerate(self.decoder): out = dec(torch.cat([out_list[idx], out], 1)) # Pyramidal pooling B, N, T, F = out.shape upsample = nn.Upsample(size=(T, F), mode="bilinear") pool_list = [] for avg in self.avg_pool: pool_list.append(upsample(avg(out))) out = torch.cat([out, *pool_list], 1) out = self.avg_proj(out) out = self.deconv2d(out) est_spec = torch.transpose(out, 2, 3) # (batchsize, 2, F, T) B, N, F, T = est_spec.shape est_spec = torch.chunk(est_spec, 2, 1) # [(B, 1, F, T), (B, 1, F, T)]) est_spec = torch.complex(est_spec[0], est_spec[1]) output = torch.istft( est_spec.reshape(B, -1, T), n_fft=self.win_len, hop_length=self.hop_size, window=torch.hann_window(self.win_len).to(wav_input.device).type( wav_input.type()), length=nsample, ) return output, predict_speaker_lable if __name__ == "__main__": import numpy as np model = DPCCN() s = 0 for param in model.parameters(): s += np.product(param.size()) print("# of parameters: " + str(s / 1024.0 / 1024.0)) mix = torch.randn(4, 32000) aux = torch.randn(4, 256) est = model(mix, aux) print(est.size()) ================================================ FILE: wesep/models/sep_model.py ================================================ import wesep.models.bsrnn as bsrnn import wesep.models.convtasnet as convtasnet import wesep.models.dpccn as dpccn import wesep.models.tfgridnet as tfgridnet def get_model(model_name: str): if model_name.startswith("ConvTasNet"): return getattr(convtasnet, model_name) elif model_name.startswith("BSRNN"): return getattr(bsrnn, model_name) elif model_name.startswith("DPCNN"): return getattr(dpccn, model_name) elif model_name.startswith("TFGridNet"): return getattr(tfgridnet, model_name) else: # model_name error !!! print(model_name + " not found !!!") exit(1) if __name__ == "__main__": print(get_model("ConvTasNet")) ================================================ FILE: wesep/models/tfgridnet.py ================================================ # The implementation is based on: # https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
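# Usage sketch for the get_model() factory defined in sep_model.py above; this is a
# standalone snippet (not part of this file), and the constructor kwargs would
# normally come from the experiment's yaml config rather than being hard-coded:
from wesep.models.sep_model import get_model

model_cls = get_model("TFGridNet")        # resolved by name prefix to wesep.models.tfgridnet
model = model_cls(joint_training=False)   # e.g. use pre-extracted speaker embeddings;
                                          # remaining kwargs typically come from the config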
import torch import torch.nn as nn import torchaudio from packaging.version import parse as V from wespeaker.models.speaker_model import get_speaker_model from wesep.modules.common.speaker import PreEmphasis from wesep.modules.common.speaker import SpeakerFuseLayer, SpeakerTransform from wesep.modules.tfgridnet.gridnet_block import GridNetBlock is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") class TFGridNet(nn.Module): """Offline TFGridNetV2. Compared with TFGridNet, TFGridNetV2 speeds up the code by vectorizing multiple heads in self-attention, and better dealing with Deconv1D in each intra- and inter-block when emb_ks == emb_hs. Reference: [1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, "TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation", in TASLP, 2023. [2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, "TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural Speaker Separation", in ICASSP, 2023. NOTES: As outlined in the Reference, this model works best when trained with variance normalized mixture input and target, e.g., with mixture of shape [batch, samples, microphones], you normalize it by dividing with torch.std(mixture, (1, 2)). You must do the same for the target signals. It is encouraged to do so when not using scale-invariant loss functions such as SI-SDR. Specifically, use: std_ = std(mix) mix = mix / std_ tgt = tgt / std_ Args: n_srcs: number of output sources/speakers. n_fft: stft window size. stride: stft stride. window: stft window type choose between 'hamming', 'hanning' or None. n_imics: num of channels (only fixed-array geometry supported). n_layers: number of TFGridNetV2 blocks. lstm_hidden_units: number of hidden units in LSTM. attn_n_head: number of heads in self-attention attn_approx_qk_dim: approximate dim of frame-level key/value tensors emb_dim: embedding dimension emb_ks: kernel size for unfolding and deconv1D emb_hs: hop size for unfolding and deconv1D activation: activation function to use in the whole TFGridNetV2 model, you can use any torch supported activation e.g. 'relu' or 'elu'. eps: small epsilon for normalization layers. spk_emb_dim: the dimension of target speaker embeddings. use_spk_transform: whether use networks to transfer the speaker embeds. spk_fuse_type: the fusion method of speaker embeddings. 
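        Example (expanding the normalization recipe from the NOTES above; the
        [batch, samples, microphones] shape is illustrative):

            std_ = torch.std(mixture, dim=(1, 2), keepdim=True)
            mixture = mixture / std_
            target = target / std_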
""" def __init__( self, n_srcs=1, sr=16000, n_fft=128, stride=64, window="hann", n_imics=1, n_layers=6, lstm_hidden_units=192, attn_n_head=4, attn_approx_qk_dim=512, emb_dim=48, emb_ks=4, emb_hs=1, activation="prelu", eps=1.0e-5, spk_emb_dim=256, use_spk_transform=False, spk_fuse_type="multiply", joint_training=True, multi_task=False, spksInTrain=251, spk_model=None, spk_model_init=None, spk_model_freeze=False, spk_args=None, spk_feat=False, feat_type="consistent", ): super().__init__() self.n_srcs = n_srcs self.n_fft = n_fft self.stride = stride self.window = window self.n_imics = n_imics self.n_layers = n_layers self.spk_emb_dim = spk_emb_dim self.joint_training = joint_training self.spk_feat = spk_feat self.feat_type = feat_type self.spk_model_freeze = spk_model_freeze self.multi_task = multi_task assert n_fft % 2 == 0 n_freqs = n_fft // 2 + 1 if use_spk_transform: self.spk_transform = SpeakerTransform() else: self.spk_transform = nn.Identity() if joint_training: self.spk_model = get_speaker_model(spk_model)(**spk_args) if spk_model_init: pretrained_model = torch.load(spk_model_init) state = self.spk_model.state_dict() for key in state.keys(): if key in pretrained_model.keys(): state[key] = pretrained_model[key] # print(key) else: print("not %s loaded" % key) self.spk_model.load_state_dict(state) if spk_model_freeze: for param in self.spk_model.parameters(): param.requires_grad = False if not spk_feat: if feat_type == "consistent": self.preEmphasis = PreEmphasis() self.spk_encoder = torchaudio.transforms.MelSpectrogram( sample_rate=sr, n_fft=n_fft, win_length=n_fft, hop_length=stride, f_min=20, window_fn=torch.hamming_window, n_mels=spk_args["feat_dim"], ) else: self.preEmphasis = nn.Identity() self.spk_encoder = nn.Identity() if multi_task: self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain) else: self.pred_linear = nn.Identity() self.spk_fuse = SpeakerFuseLayer( embed_dim=spk_emb_dim, feat_dim=n_freqs, fuse_type=spk_fuse_type, ) t_ksize = 3 ks, padding = (t_ksize, 3), (t_ksize // 2, 1) self.conv = nn.Sequential( nn.Conv2d(2 * n_imics, emb_dim, ks, padding=padding), nn.GroupNorm(1, emb_dim, eps=eps), ) self.blocks = nn.ModuleList([]) for _ in range(n_layers): self.blocks.append( GridNetBlock( emb_dim, emb_ks, emb_hs, n_freqs, lstm_hidden_units, n_head=attn_n_head, approx_qk_dim=attn_approx_qk_dim, activation=activation, eps=eps, )) self.deconv = nn.ConvTranspose2d(emb_dim, n_srcs * 2, ks, padding=padding) def forward( self, input: torch.Tensor, embeddings: torch.Tensor, ) -> torch.Tensor: """Forward. Args: input (torch.Tensor): batched multi-channel audio tensor with M audio channels and N samples [B, N, M] embeddings (torch.Tensor): batched target speaker embeddings [B, D] Returns: enhanced (List[Union(torch.Tensor)]): [(B, T), ...] list of len n_srcs of mono audio tensors with T samples. 
""" batch_size, n_samples = input.shape[0], input.shape[1] spk_emb_input = embeddings if self.n_imics == 1: assert len(input.shape) == 2 input = input[..., None] # [B, N, M] mix_std_ = torch.std(input, dim=(1, 2), keepdim=True) # [B, 1, 1] input = input / mix_std_ # RMS normalization input = input.transpose(1, 2).reshape( -1, input.size(1)) # [B, N, M] -> [B*M, N] window_func = getattr(torch, f"{self.window}_window") window = window_func(self.n_fft, dtype=input.dtype, device=input.device) batch = torch.stft( input, n_fft=self.n_fft, win_length=self.n_fft, hop_length=self.stride, window=window, return_complex=True, onesided=True, ) # [B, F, T] batch = batch.transpose(1, 2) # [B, T, F] batch0 = batch.view(batch_size, -1, batch.size(1), batch.size(2)) # [B, M, T, F] # ilens = torch.full((batch_size,), n_samples, dtype=torch.long) batch = torch.cat((batch0.real, batch0.imag), dim=1) # [B, 2*M, T, F] n_batch, _, n_frames, n_freqs = batch.shape batch = self.conv(batch) # [B, -1, T, F] predict_speaker_label = torch.tensor(0.0).to( spk_emb_input.device) # dummy if self.joint_training: if not self.spk_feat: if self.feat_type == "consistent": with torch.no_grad(): spk_emb_input = self.preEmphasis(spk_emb_input) spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8 spk_emb_input = spk_emb_input.log() spk_emb_input = spk_emb_input - torch.mean( spk_emb_input, dim=-1, keepdim=True) spk_emb_input = spk_emb_input.permute(0, 2, 1) tmp_spk_emb_input = self.spk_model(spk_emb_input) if isinstance(tmp_spk_emb_input, tuple): spk_emb_input = tmp_spk_emb_input[-1] else: spk_emb_input = tmp_spk_emb_input predict_speaker_label = self.pred_linear(spk_emb_input) spk_embedding = self.spk_transform(spk_emb_input) # [B, D] spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3) # [B, 1, D, 1] for ii in range(self.n_layers): batch = torch.transpose( self.spk_fuse(batch.transpose(2, 3), spk_embedding), 2, 3) # [B, -1, T, F] batch = self.blocks[ii](batch) # [B, -1, T, F] batch = self.deconv(batch) # [B, n_srcs*2, T, F] batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs]) assert is_torch_1_9_plus, "Require torch 1.9.0+." 
batch = torch.complex(batch[:, :, 0], batch[:, :, 1]) batch = torch.istft( torch.transpose(batch.view(-1, n_frames, n_freqs), 1, 2), n_fft=self.n_fft, hop_length=self.stride, win_length=self.n_fft, window=window, onesided=True, length=n_samples, return_complex=False, ) # [B, n_srcs] batch = self.pad2(batch.view([n_batch, self.num_spk, -1]), n_samples) batch = batch * mix_std_ # reverse the RMS normalization # batch = [batch[:, src] for src in range(self.num_spk)] batch = batch.squeeze(1) return batch, predict_speaker_label @property def num_spk(self): return self.n_srcs @staticmethod def pad2(input_tensor, target_len): input_tensor = torch.nn.functional.pad( input_tensor, (0, target_len - input_tensor.shape[-1])) return input_tensor ================================================ FILE: wesep/modules/__init__.py ================================================ ================================================ FILE: wesep/modules/common/__init__.py ================================================ from wesep.modules.common.norm import ChannelWiseLayerNorm # noqa from wesep.modules.common.norm import FiLM # noqa from wesep.modules.common.norm import GlobalChannelLayerNorm # noqa from wesep.modules.common.norm import select_norm # noqa ================================================ FILE: wesep/modules/common/norm.py ================================================ import numbers import torch import torch.nn as nn class GlobalChannelLayerNorm(nn.Module): """ Calculate Global Layer Normalization dim: (int or list or torch.Size) – input shape from an expected input of size eps: a value added to the denominator for numerical stability. elementwise_affine: a boolean value that when set to True, this module has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). 
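    Example (illustrative shapes):

        gln = GlobalChannelLayerNorm(64)
        y = gln(torch.randn(4, 64, 100))  # mean/var computed over channels and time jointly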
""" def __init__(self, dim, eps=1e-05, elementwise_affine=True): super(GlobalChannelLayerNorm, self).__init__() self.dim = dim self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = nn.Parameter(torch.ones(self.dim, 1)) self.bias = nn.Parameter(torch.zeros(self.dim, 1)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) def forward(self, x): # x = N x C x L # N x 1 x 1 # cln: mean,var N x 1 x L # gln: mean,var N x 1 x 1 if x.dim() != 3: raise RuntimeError("{} accept 3D tensor as input".format( self.__name__)) mean = torch.mean(x, (1, 2), keepdim=True) var = torch.mean((x - mean)**2, (1, 2), keepdim=True) # N x C x L if self.elementwise_affine: x = (self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias) else: x = (x - mean) / torch.sqrt(var + self.eps) return x class ChannelWiseLayerNorm(nn.LayerNorm): """ Channel wise layer normalization """ def __init__(self, *args, **kwargs): super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs) def forward(self, x): """ x: N x C x T """ x = torch.transpose(x, 1, 2) x = super().forward(x) x = torch.transpose(x, 1, 2) return x def select_norm(norm, dim): """ Build normalize layer LN cost more memory than BN """ if norm not in ["cLN", "gLN", "BN"]: raise RuntimeError("Unsupported normalize layer: {}".format(norm)) if norm == "cLN": return ChannelWiseLayerNorm(dim, elementwise_affine=True) elif norm == "BN": return nn.BatchNorm1d(dim) else: return GlobalChannelLayerNorm(dim, elementwise_affine=True) class FiLM(nn.Module): """Feature-wise Linear Modulation (FiLM) layer https://github.com/HuangZiliAndy/fairseq/blob/multispk/fairseq/models/wavlm/WavLM.py#L1160 # noqa """ def __init__(self, feat_size, embed_size, num_film_layers=1, layer_norm=False): super(FiLM, self).__init__() self.feat_size = feat_size self.embed_size = embed_size self.num_film_layers = num_film_layers self.layer_norm = nn.LayerNorm(embed_size) if layer_norm else None gamma_fcs, beta_fcs = [], [] for i in range(num_film_layers): if i == 0: gamma_fcs.append(nn.Linear(embed_size, feat_size)) beta_fcs.append(nn.Linear(embed_size, feat_size)) else: gamma_fcs.append(nn.Linear(feat_size, feat_size)) beta_fcs.append(nn.Linear(feat_size, feat_size)) self.gamma_fcs = nn.ModuleList(gamma_fcs) self.beta_fcs = nn.ModuleList(beta_fcs) self.init_weights() def init_weights(self): for i in range(self.num_film_layers): nn.init.zeros_(self.gamma_fcs[i].weight) nn.init.zeros_(self.gamma_fcs[i].bias) nn.init.zeros_(self.beta_fcs[i].weight) nn.init.zeros_(self.beta_fcs[i].bias) def forward(self, embed, x): gamma, beta = None, None for i in range(len(self.gamma_fcs)): if i == 0: gamma = self.gamma_fcs[i](embed) beta = self.beta_fcs[i](embed) else: gamma = self.gamma_fcs[i](gamma) beta = self.beta_fcs[i](beta) if len(gamma.shape) < len(x.shape): gamma = gamma.unsqueeze(-1).expand_as(x) beta = beta.unsqueeze(-1).expand_as(x) else: gamma = gamma.expand_as(x) beta = beta.expand_as(x) # print(gamma.size(), beta.size()) x = (1 + gamma) * x + beta if self.layer_norm is not None: x = self.layer_norm(x) return x class ConditionalLayerNorm(nn.Module): """ https://github.com/HuangZiliAndy/fairseq/blob/multispk/fairseq/models/wavlm/WavLM.py#L1160 """ def __init__(self, normalized_shape, embed_dim, modulate_bias=False, eps=1e-5): super(ConditionalLayerNorm, self).__init__() if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape, ) self.normalized_shape = tuple(normalized_shape) self.embed_dim = 
embed_dim self.eps = eps self.weight = nn.Parameter(torch.empty(*normalized_shape)) self.bias = nn.Parameter(torch.empty(*normalized_shape)) assert len(normalized_shape) == 1 self.ln_weight_modulation = FiLM(normalized_shape[0], embed_dim) self.modulate_bias = modulate_bias if self.modulate_bias: self.ln_bias_modulation = FiLM(normalized_shape[0], embed_dim) else: self.ln_bias_modulation = None self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, input, embed): mean = torch.mean(input, -1, keepdim=True) var = torch.var(input, -1, unbiased=False, keepdim=True) weight = self.ln_weight_modulation( embed, self.weight.expand(embed.size(0), -1)) if self.ln_bias_modulation is None: bias = self.bias else: bias = self.ln_bias_modulation(embed, self.bias.expand(embed.size(0), -1)) res = (input - mean) / torch.sqrt(var + self.eps) * weight + bias return res def extra_repr(self): return "{normalized_shape}, {embed_dim}, \ modulate_bias={modulate_bias}, eps={eps}".format(**self.__dict__) ================================================ FILE: wesep/modules/common/speaker.py ================================================ from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from wesep.modules.common import FiLM class PreEmphasis(torch.nn.Module): def __init__(self, coef: float = 0.97): super().__init__() self.coef = coef self.register_buffer( "flipped_filter", torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0), ) def forward(self, input: torch.tensor) -> torch.tensor: input = input.unsqueeze(1) input = F.pad(input, (1, 0), "reflect") return F.conv1d(input, self.flipped_filter).squeeze(1) class SpeakerTransform(nn.Module): def __init__(self, embed_dim=256, num_layers=3, hid_dim=128): """ Transform the pretrained speaker embeddings, keep the dimension :param embed_dim: :param num_layers: :param hid_dim: :return: """ super(SpeakerTransform, self).__init__() self.transforms = [] self.transforms.append(nn.Conv1d(embed_dim, hid_dim, 1)) for _ in range(num_layers - 2): self.transforms.append(nn.Conv1d(hid_dim, hid_dim, 1)) self.transforms.append(nn.Tanh()) self.transforms.append(nn.Conv1d(hid_dim, embed_dim, 1)) self.transforms = nn.Sequential(*self.transforms) def forward(self, x): if len(x.size()) == 2: return self.transforms(x.unsqueeze(-1)).squeeze(-1) else: return self.transforms(x) class LinearLayer(nn.Module): def __init__(self, in_features, out_features, bias=True): super(LinearLayer, self).__init__() self.linear = nn.Linear(in_features, out_features, bias) def forward(self, x, dummy: Optional[torch.Tensor] = None): return self.linear(x) class SpeakerFuseLayer(nn.Module): def __init__(self, embed_dim=256, feat_dim=512, fuse_type="concat"): super(SpeakerFuseLayer, self).__init__() assert fuse_type in ["concat", "additive", "multiply", "FiLM", "None"] self.fuse_type = fuse_type if fuse_type == "concat": self.fc = LinearLayer(embed_dim + feat_dim, feat_dim) elif fuse_type == "additive": self.fc = LinearLayer(embed_dim, feat_dim) elif fuse_type == "multiply": self.fc = LinearLayer(embed_dim, feat_dim) elif fuse_type == "FiLM": self.fc = FiLM(feat_dim, embed_dim) else: raise ValueError("Fuse type not defined.") def forward(self, x, embed): """ :param x: batch x dimension x length :param embed: batch x dimension x 1 :return: """ if self.fuse_type == "concat": # For Conv if len(x.size()) == 3: embed_t = embed.expand(-1, -1, x.size(2)) y = torch.cat([x, embed_t], 1) y = torch.transpose(y, 
1, 2) x = torch.transpose(self.fc(y), 1, 2) else: # len(x.size() == 4 embed_t = embed.expand(-1, x.size(1), -1, x.size(3)) y = torch.cat([x, embed_t], 2) y = torch.transpose(y, 2, 3) x = torch.transpose(self.fc(y), 2, 3).contiguous() # print(x.size()) elif self.fuse_type == "additive": if len(x.size()) == 3: embed_t = embed.expand(-1, -1, x.size(2)) embed_t = torch.transpose(embed_t, 1, 2) x = x + torch.transpose(self.fc(embed_t), 1, 2) else: # len(x.size() == 4 embed_t = embed.expand(-1, x.size(1), -1, x.size(3)) embed_t = torch.transpose(embed_t, 2, 3) x = x + torch.transpose(self.fc(embed_t), 2, 3) elif self.fuse_type == "multiply": if len(x.size()) == 3: embed_t = embed.expand(-1, -1, x.size(2)) embed_t = torch.transpose(embed_t, 1, 2) x = x * torch.transpose(self.fc(embed_t), 1, 2) else: # len(x.size() == 4 embed_t = embed.expand(-1, x.size(1), -1, x.size(3)) embed_t = torch.transpose(embed_t, 2, 3) x = x * torch.transpose(self.fc(embed_t), 2, 3) else: embed = embed.squeeze(-1) x = self.fc(embed, x) return x def test_speaker_fuse(): st = SpeakerTransform(embed_dim=256, num_layers=3, hid_dim=128) sfl = SpeakerFuseLayer(fuse_type="multiply") embeds = torch.rand(4, 256) encoder_output = torch.rand(4, 512, 1000) print(embeds.size()) embeds = st(embeds) print(embeds.size()) output = sfl(encoder_output, embeds) print(output.size()) if __name__ == "__main__": test_speaker_fuse() ================================================ FILE: wesep/modules/dpccn/__init__.py ================================================ ================================================ FILE: wesep/modules/dpccn/convs.py ================================================ from typing import Tuple import torch import torch.nn as nn class Conv1D(nn.Conv1d): """ 1D conv in ConvTasNet """ def __init__(self, *args, **kwargs): super(Conv1D, self).__init__(*args, **kwargs) def forward(self, x, squeeze=False): """ x: N x L or N x C x L """ if x.dim() not in [2, 3]: raise RuntimeError("{} accept 2/3D tensor as input".format( self.__name__)) x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) if squeeze: x = torch.squeeze(x) return x class Conv2dBlock(nn.Module): def __init__( self, in_dims: int = 16, out_dims: int = 32, kernel_size: Tuple[int] = (3, 3), stride: Tuple[int] = (1, 1), padding: Tuple[int] = (1, 1), ) -> None: super(Conv2dBlock, self).__init__() self.conv2d = nn.Conv2d(in_dims, out_dims, kernel_size, stride, padding) self.elu = nn.ELU() self.norm = nn.InstanceNorm2d(out_dims) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv2d(x) x = self.elu(x) return self.norm(x) class ConvTrans2dBlock(nn.Module): def __init__( self, in_dims: int = 32, out_dims: int = 16, kernel_size: Tuple[int] = (3, 3), stride: Tuple[int] = (1, 2), padding: Tuple[int] = (1, 0), output_padding: Tuple[int] = (0, 0), ) -> None: super(ConvTrans2dBlock, self).__init__() self.convtrans2d = nn.ConvTranspose2d(in_dims, out_dims, kernel_size, stride, padding, output_padding) self.elu = nn.ELU() self.norm = nn.InstanceNorm2d(out_dims) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.convtrans2d(x) x = self.elu(x) return self.norm(x) class DenseBlock(nn.Module): def __init__(self, in_dims, out_dims, mode="enc", **kargs): super(DenseBlock, self).__init__() if mode not in ["enc", "dec"]: raise RuntimeError("The mode option must be 'enc' or 'dec'!") n = 1 if mode == "enc" else 2 self.conv1 = Conv2dBlock(in_dims=in_dims * n, out_dims=in_dims, **kargs) self.conv2 = Conv2dBlock(in_dims=in_dims * (n + 1), out_dims=in_dims, 
**kargs) self.conv3 = Conv2dBlock(in_dims=in_dims * (n + 2), out_dims=in_dims, **kargs) self.conv4 = Conv2dBlock(in_dims=in_dims * (n + 3), out_dims=in_dims, **kargs) self.conv5 = Conv2dBlock(in_dims=in_dims * (n + 4), out_dims=out_dims, **kargs) def forward(self, x: torch.Tensor) -> torch.Tensor: y1 = self.conv1(x) y2 = self.conv2(torch.cat([x, y1], 1)) y3 = self.conv3(torch.cat([x, y1, y2], 1)) y4 = self.conv4(torch.cat([x, y1, y2, y3], 1)) y5 = self.conv5(torch.cat([x, y1, y2, y3, y4], 1)) return y5 class TCNBlock(nn.Module): """ TCN block: IN - ELU - Conv1D - IN - ELU - Conv1D """ def __init__( self, in_dims: int = 384, out_dims: int = 384, kernel_size: int = 3, dilation: int = 1, causal: bool = False, ) -> None: super(TCNBlock, self).__init__() self.norm1 = nn.InstanceNorm1d(in_dims) self.elu1 = nn.ELU() dconv_pad = ((dilation * (kernel_size - 1)) // 2 if not causal else (dilation * (kernel_size - 1))) # dilated conv self.dconv1 = nn.Conv1d( in_dims, out_dims, kernel_size, padding=dconv_pad, dilation=dilation, groups=in_dims, bias=True, ) self.norm2 = nn.InstanceNorm1d(in_dims) self.elu2 = nn.ELU() self.dconv2 = nn.Conv1d(in_dims, out_dims, 1, bias=True) # different padding way self.causal = causal self.dconv_pad = dconv_pad def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.elu1(self.norm1(x)) y = self.dconv1(y) if self.causal: y = y[:, :, :-self.dconv_pad] y = self.elu2(self.norm2(y)) y = self.dconv2(y) x = x + y return x ================================================ FILE: wesep/modules/metric_gan/__init__.py ================================================ ================================================ FILE: wesep/modules/metric_gan/discriminator.py ================================================ import torch import torch.nn as nn # utility functions/classes used in the implementation of discriminators. class LearnableSigmoid(nn.Module): def __init__(self, in_features, beta=1): super().__init__() self.beta = beta self.slope = nn.Parameter(torch.ones(in_features)) self.slope.requiresGrad = True def forward(self, x): return self.beta * torch.sigmoid(self.slope * x) # discriminators class CMGAN_Discriminator(nn.Module): def __init__( self, n_fft=400, hop=100, in_channels=2, hid_chans=16, ksz=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, num_conv_blocks=4, num_linear_layers=2, ): """discriminator used in CMGAN (Interspeech 2022) paper: https://arxiv.org/pdf/2203.15149.pdf code: https://github.com/ruizhecao96/CMGAN Args: n_fft (int, optional): the windows length of stft. Defaults to 400. hop (int, optional): the hop length of stft. Defaults to 100. in_channels (int, optional): num of input channels. Defaults to 2. hid_chans (int, optional): num of hidden channels. Defaults to 16. ksz (tuple, optional): kernel size. Defaults to (4, 4). stride (tuple, optional): stride. Defaults to (2, 2). padding (tuple, optional): padding. Defaults to (1, 1). bias (bool, optional): bias. Defaults to False. num_conv_blocks (int, optional): num of conv blocks. Defaults to 4. num_linear_layers (int, optional): num of linear layers. Defaults to 2. 
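        Example (illustrative shapes, equivalent to the self-test at the bottom of
        this file):

            D = CMGAN_Discriminator()
            score = D(torch.randn(2, 16000), torch.randn(2, 16000))  # (2, 1), values in (0, 1)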
""" super(CMGAN_Discriminator, self).__init__() assert num_conv_blocks >= num_linear_layers self.n_fft = n_fft self.hop = hop self.num_conv_blocks = num_conv_blocks self.num_linear_layers = num_linear_layers self.conv = nn.ModuleList([]) in_chans = in_channels out_chans = hid_chans for i in range(num_conv_blocks): self.conv.append( nn.Sequential( nn.utils.spectral_norm( nn.Conv2d( in_chans, out_chans, ksz, stride, padding, bias=bias, )), nn.InstanceNorm2d(out_chans, affine=True), nn.PReLU(out_chans), )) in_chans = out_chans out_chans = hid_chans * (2**(i + 1)) self.pooling = nn.Sequential( nn.AdaptiveMaxPool2d(1), nn.Flatten(), ) self.fc = nn.ModuleList([]) for i in range(num_linear_layers - 1): self.fc.append( nn.Sequential( nn.utils.spectral_norm( nn.Linear( hid_chans * (2**(num_conv_blocks - 1 - i)), hid_chans * (2**(num_conv_blocks - 2 - i)), )), nn.Dropout(0.3), nn.PReLU(hid_chans * (2**(num_conv_blocks - 2 - i))), )) self.fc.append( nn.Sequential( nn.utils.spectral_norm( nn.Linear( hid_chans * (2**(num_conv_blocks - num_linear_layers)), 1, )), LearnableSigmoid(1), )) def forward(self, ref_wav, est_wav): """ Args: ref_wav (torch.Tensor): the reference signal. [B, T] est_wav (torch.Tensor): the estimated signal. [B, T] Return: estimated_scores (torch.Tensor): estimated scores, [B] """ ref_spec = torch.stft( ref_wav, self.n_fft, self.hop, window=torch.hann_window(self.n_fft).to(ref_wav.device).type( ref_wav.type()), return_complex=True, ).transpose(-1, -2) est_spec = torch.stft( est_wav, self.n_fft, self.hop, window=torch.hann_window(self.n_fft).to(est_wav.device).type( est_wav.type()), return_complex=True, ).transpose(-1, -2) # input shape: (B, 2, T, F) input = torch.stack((abs(ref_spec), abs(est_spec)), dim=1) for i in range(self.num_conv_blocks): input = self.conv[i](input) input = self.pooling(input) for i in range(self.num_linear_layers): input = self.fc[i](input) return input if __name__ == "__main__": # functions used to test discriminators def test_CMGAN_Discriminator(): B, T = 2, 16000 ref_spec = torch.randn(B, T) est_spec = torch.randn(B, T) D = CMGAN_Discriminator() metric = D(ref_spec, est_spec).detach() print(f"estimated metric score is {metric}") test_CMGAN_Discriminator() ================================================ FILE: wesep/modules/tasnet/__init__.py ================================================ from wesep.modules.tasnet.decoder import DeepDecoder # noqa from wesep.modules.tasnet.decoder import MultiDecoder # noqa from wesep.modules.tasnet.encoder import DeepEncoder # noqa from wesep.modules.tasnet.encoder import MultiEncoder # noqa from wesep.modules.tasnet.separation import Separation, FuseSeparation # noqa from wesep.modules.tasnet.speaker import ResNet4SpExplus # noqa ================================================ FILE: wesep/modules/tasnet/convs.py ================================================ import torch import torch.nn as nn from wesep.modules.common import select_norm # from wesep.modules.common.spkadapt import SpeakerFuseLayer class Conv1D(nn.Conv1d): def __init__(self, *args, **kwargs): super(Conv1D, self).__init__(*args, **kwargs) def forward(self, x, squeeze=False): # x: N x C x L if x.dim() not in [2, 3]: raise RuntimeError("{} accept 2/3D tensor as input".format( self.__name__)) x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) if squeeze: x = torch.squeeze(x) return x class ConvTrans1D(nn.ConvTranspose1d): def __init__(self, *args, **kwargs): super(ConvTrans1D, self).__init__(*args, **kwargs) def forward(self, x, 
squeeze=False): """ x: N x L or N x C x L """ if x.dim() not in [2, 3]: raise RuntimeError("{} accept 2/3D tensor as input".format( self.__name__)) x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) if squeeze: x = torch.squeeze(x) return x class Conv1DBlock(nn.Module): """ Consider only residual links """ def __init__( self, in_channels=256, out_channels=512, kernel_size=3, dilation=1, norm="gln", causal=False, skip_con=True, ): super(Conv1DBlock, self).__init__() # conv 1 x 1 self.conv1x1 = Conv1D(in_channels, out_channels, 1) self.PReLU_1 = nn.PReLU() self.norm_1 = select_norm(norm, out_channels) # not causal don't need to padding, causal need to pad+1 = kernel_size self.pad = ((dilation * (kernel_size - 1)) // 2 if not causal else (dilation * (kernel_size - 1))) # depthwise convolution # TODO: This is not depthwise seperable convolution self.dwconv = Conv1D( out_channels, out_channels, kernel_size, groups=out_channels, padding=self.pad, dilation=dilation, ) self.PReLU_2 = nn.PReLU() self.norm_2 = select_norm(norm, out_channels) if skip_con: self.Sc_conv = nn.Conv1d(out_channels, in_channels, 1, bias=True) self.Output = nn.Conv1d(out_channels, in_channels, 1, bias=True) self.causal = causal self.skip_con = skip_con def forward(self, x): # x: N x C x L # N x O_C x L c = self.conv1x1(x) # N x O_C x L c = self.PReLU_1(c) c = self.norm_1(c) # causal: N x O_C x (L+pad) # noncausal: N x O_C x L c = self.dwconv(c) if self.causal: c = c[:, :, :-self.pad] c = self.PReLU_2(c) c = self.norm_2(c) # N x O_C x L if self.skip_con: Sc = self.Sc_conv(c) c = self.Output(c) return Sc, c + x c = self.Output(c) return x + c class Conv1DBlock4Fuse(nn.Module): """ 1D convolutional block: Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv """ def __init__( self, in_channels=256, spk_embed_dim=100, conv_channels=512, kernel_size=3, dilation=1, norm="cLN", causal=False, ): super(Conv1DBlock4Fuse, self).__init__() # 1x1 conv self.conv1x1 = Conv1D(in_channels + spk_embed_dim, conv_channels, 1) self.prelu1 = nn.PReLU() self.lnorm1 = select_norm(norm, conv_channels) dconv_pad = ((dilation * (kernel_size - 1)) // 2 if not causal else (dilation * (kernel_size - 1))) # depthwise conv self.dconv = nn.Conv1d( conv_channels, conv_channels, kernel_size, groups=conv_channels, padding=dconv_pad, dilation=dilation, bias=True, ) self.prelu2 = nn.PReLU() self.lnorm2 = select_norm(norm, conv_channels) # 1x1 conv cross channel self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True) # different padding way self.causal = causal self.dconv_pad = dconv_pad def forward(self, x, aux): T = x.shape[-1] aux = aux.repeat(1, 1, T) y = torch.cat([x, aux], 1) y = self.conv1x1(y) y = self.lnorm1(self.prelu1(y)) y = self.dconv(y) if self.causal: y = y[:, :, :-self.dconv_pad] y = self.lnorm2(self.prelu2(y)) y = self.sconv(y) x = x + y return x ================================================ FILE: wesep/modules/tasnet/decoder.py ================================================ import torch import torch.nn as nn from wesep.modules.tasnet.convs import Conv1D, ConvTrans1D class DeepDecoder(nn.Module): def __init__(self, N, kernel_size=16, stride=16 // 2): super(DeepDecoder, self).__init__() self.sequential = nn.Sequential( nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=8, padding=8), nn.PReLU(), nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=4, padding=4), nn.PReLU(), nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=2, padding=2), nn.PReLU(), nn.ConvTranspose1d(N, N, kernel_size=3, 
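# --- Usage sketch (editorial illustration; assumes wesep is importable) ---
# Conv1DBlock is the standard ConvTasNet TCN block (optionally returning a
# skip path); Conv1DBlock4Fuse additionally broadcasts a speaker embedding
# [B, D, 1] to every frame and concatenates it before the 1x1 bottleneck conv.
import torch
from wesep.modules.tasnet.convs import Conv1DBlock, Conv1DBlock4Fuse

x = torch.randn(2, 256, 500)               # [B, C, T] bottleneck features
blk = Conv1DBlock(in_channels=256, out_channels=512, kernel_size=3, dilation=2)
skip, out = blk(x)                          # both [2, 256, 500]

spk = torch.randn(2, 100, 1)                # speaker embedding, [B, D, 1]
fuse = Conv1DBlock4Fuse(in_channels=256, spk_embed_dim=100, conv_channels=512)
print(fuse(x, spk).shape)                   # torch.Size([2, 256, 500])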
stride=1, dilation=1, padding=1), nn.PReLU(), nn.ConvTranspose1d(N, 1, kernel_size=kernel_size, stride=stride, bias=True), ) def forward(self, x): """ x: N x L or N x C x L """ x = self.sequential(x) if torch.squeeze(x).dim() == 1: x = torch.squeeze(x, dim=1) else: x = torch.squeeze(x) return x class MultiDecoder(nn.Module): def __init__(self, in_channels, middle_channels, out_channels, kernel_size, stride): super(MultiDecoder, self).__init__() B = in_channels N = middle_channels L = kernel_size # n x B x T => n x 2N x T self.mask1 = Conv1D(B, N, 1) self.mask2 = Conv1D(B, N, 1) self.mask3 = Conv1D(B, N, 1) # using ConvTrans1D: n x N x T => n x 1 x To # To = (T - 1) * L // 2 + L self.decoder_1d_1 = ConvTrans1D(N, out_channels, kernel_size=L, stride=stride, bias=True) self.decoder_1d_2 = ConvTrans1D(N, out_channels, kernel_size=80, stride=stride, bias=True) self.decoder_1d_3 = ConvTrans1D(N, out_channels, kernel_size=160, stride=stride, bias=True) def forward(self, x, w1, w2, w3, actLayer): """ x: N x L or N x C x L """ m1 = actLayer(self.mask1(x)) m2 = actLayer(self.mask2(x)) m3 = actLayer(self.mask3(x)) s1 = w1 * m1 s2 = w2 * m2 s3 = w3 * m3 est1 = self.decoder_1d_1(s1, squeeze=True) xlen = est1.shape[-1] if est1.dim() > 1: est2 = self.decoder_1d_2(s2, squeeze=True)[:, :xlen] est3 = self.decoder_1d_3(s3, squeeze=True)[:, :xlen] else: est1 = est1.unsqueeze(0) est2 = self.decoder_1d_2(s2, squeeze=True).unsqueeze(0)[:, :xlen] est3 = self.decoder_1d_3(s3, squeeze=True).unsqueeze(0)[:, :xlen] s = [est1, est2, est3] return s ================================================ FILE: wesep/modules/tasnet/encoder.py ================================================ import torch as th import torch.nn as nn import torch.nn.functional as F from wesep.modules.common import select_norm from wesep.modules.tasnet.convs import Conv1D class DeepEncoder(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride): super(DeepEncoder, self).__init__() self.sequential = nn.Sequential( Conv1D(in_channels, out_channels, kernel_size, stride=stride), Conv1D( out_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding=1, ), nn.PReLU(), Conv1D( out_channels, out_channels, kernel_size=3, stride=1, dilation=2, padding=2, ), nn.PReLU(), Conv1D( out_channels, out_channels, kernel_size=3, stride=1, dilation=4, padding=4, ), nn.PReLU(), Conv1D( out_channels, out_channels, kernel_size=3, stride=1, dilation=8, padding=8, ), nn.PReLU(), ) def forward(self, x): """ :param x: [B, T] :return: out: [B, N, T] """ x = self.sequential(x) return x class MultiEncoder(nn.Module): def __init__(self, in_channels, middle_channels, out_channels, kernel_size, stride): super(MultiEncoder, self).__init__() self.L1 = kernel_size self.L2 = 80 self.L3 = 160 self.encoder_1d_short = Conv1D(in_channels, middle_channels, self.L1, stride=stride, padding=0) self.encoder_1d_middle = Conv1D(in_channels, middle_channels, self.L2, stride=stride, padding=0) self.encoder_1d_long = Conv1D(in_channels, middle_channels, self.L3, stride=stride, padding=0) # keep T not change # T = int((xlen - L) / (L // 2)) + 1 # before repeat blocks, always cLN self.ln = select_norm( "cLN", 3 * middle_channels) # ChannelWiseLayerNorm(3 * middle_channels) # n x N x T => n x B x T self.proj = Conv1D(3 * middle_channels, out_channels, 1) def forward(self, x): """ :param x: [B, T] :return: out: [B, N, T] """ w1 = F.relu(self.encoder_1d_short(x)) T = w1.shape[-1] xlen1 = x.shape[-1] xlen2 = (T - 1) * (self.L1 // 2) + self.L2 xlen3 = (T - 1) * (self.L1 // 
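# --- Round-trip sketch (editorial illustration; assumes wesep is importable) ---
# DeepEncoder maps a waveform [B, T] to features [B, N, T'] with a strided
# Conv1D followed by dilated 1-D convs; DeepDecoder mirrors it back to a
# waveform with dilated transposed convs plus one strided transposed conv.
import torch
from wesep.modules.tasnet.encoder import DeepEncoder
from wesep.modules.tasnet.decoder import DeepDecoder

wav = torch.randn(2, 16000)                              # [B, T]
enc = DeepEncoder(in_channels=1, out_channels=256, kernel_size=16, stride=8)
feats = enc(wav)                                         # [2, 256, 1999]
dec = DeepDecoder(256, kernel_size=16, stride=8)
print(dec(feats).shape)                                  # torch.Size([2, 16000])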
2) + self.L3 w2 = F.relu( self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0))) w3 = F.relu( self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0))) # n x 3N x T x = self.ln(th.cat([w1, w2, w3], 1)) # n x B x T x = self.proj(x) return x, w1, w2, w3 ================================================ FILE: wesep/modules/tasnet/separation.py ================================================ import torch.nn as nn from wesep.modules.common import select_norm from wesep.modules.common.speaker import SpeakerFuseLayer from wesep.modules.tasnet.convs import Conv1DBlock, Conv1DBlock4Fuse class Separation(nn.Module): def __init__( self, R, X, B, H, P, norm="gLN", causal=False, skip_con=True, start_dilation=0, ): """ Args :param R: Number of repeats :param X: Number of convolutional blocks in each repeat :param B: Number of channels in bottleneck and the residual paths :param H: Number of channels in convolutional blocks :param P: Kernel size in convolutional blocks :param norm: The type of normalization(gln, cln, bn) :param causal: Two choice(causal or noncausal) :param skip_con: Whether to use skip connection """ super(Separation, self).__init__() self.separation = nn.ModuleList([]) for _ in range(R): for x in range(start_dilation, X): self.separation.append( Conv1DBlock(B, H, P, 2**x, norm, causal, skip_con)) self.skip_con = skip_con def forward(self, x): """ x: [B, N, L] out: [B, N, L] """ if self.skip_con: skip_connection = 0 for i in range(len(self.separation)): skip, out = self.separation[i](x) skip_connection = skip_connection + skip x = out return skip_connection else: for i in range(len(self.separation)): out = self.separation[i](x) x = out return x class FuseSeparation(nn.Module): def __init__( self, R, X, B, H, P, norm="gLN", causal=False, skip_con=False, C_embedding=256, spk_fuse_type="concatConv", multi_fuse=True, ): """ :param R: Number of repeats :param X: Number of convolutional blocks in each repeat :param B: Number of channels in bottleneck and the residual paths :param H: Number of channels in convolutional blocks :param P: Kernel size in convolutional blocks :param norm: The type of normalization(gln, cln, bn) :param causal: Two choice(causal or noncausal) :param skip_con: Whether to use skip connection """ super(FuseSeparation, self).__init__() self.multi_fuse = multi_fuse self.spk_fuse_type = spk_fuse_type self.separation = nn.ModuleList([]) if self.multi_fuse: for _ in range(R): if spk_fuse_type == "concatConv": self.separation.append( Conv1DBlock4Fuse( spk_embed_dim=C_embedding, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1, )) self.separation.append( Separation( 1, X, B, H, P, norm=norm, causal=causal, skip_con=skip_con, start_dilation=1, )) else: self.separation.append( SpeakerFuseLayer( embed_dim=C_embedding, feat_dim=B, fuse_type=spk_fuse_type, )) self.separation.append(nn.PReLU()) self.separation.append(select_norm(norm, B)) self.separation.append( Separation( 1, X, B, H, P, norm=norm, causal=causal, skip_con=skip_con, )) else: if spk_fuse_type == "concatConv": self.separation.append( Conv1DBlock4Fuse( spk_embed_dim=C_embedding, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1, )) else: self.separation.append( SpeakerFuseLayer( embed_dim=C_embedding, feat_dim=B, fuse_type=spk_fuse_type, )) self.separation.append(nn.PReLU()) self.separation.append(select_norm(norm, B)) self.separation = Separation(R, X, B, H, P, norm=norm, causal=causal, skip_con=skip_con) def forward(self, x, 
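# --- Multi-scale round-trip sketch (editorial illustration; assumes wesep is importable) ---
# MultiEncoder filters the waveform with three kernel lengths (L1, 80, 160)
# and returns the projected mixture representation plus the three raw encoder
# outputs; MultiDecoder applies one mask per scale and reconstructs three
# waveform estimates trimmed to the length of the short-kernel branch.
import torch
from wesep.modules.tasnet.encoder import MultiEncoder
from wesep.modules.tasnet.decoder import MultiDecoder

wav = torch.randn(2, 16000)
enc = MultiEncoder(in_channels=1, middle_channels=256, out_channels=256,
                   kernel_size=40, stride=20)
x, w1, w2, w3 = enc(wav)                    # x: [2, 256, 799]
dec = MultiDecoder(in_channels=256, middle_channels=256, out_channels=1,
                   kernel_size=40, stride=20)
est1, est2, est3 = dec(x, w1, w2, w3, torch.nn.ReLU())
print(est1.shape)                           # torch.Size([2, 16000])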
spk_embedding): """ x: [B, N, L] out: [B, N, L] """ if self.multi_fuse: if self.spk_fuse_type == "concatConv": round_num = 2 else: round_num = 4 for i in range(len(self.separation)): if i % round_num == 0: x = self.separation[i](x, spk_embedding) else: x = self.separation[i](x) else: x = self.separation[0](x, spk_embedding) for i in range(1, len(self.separation)): x = self.separation[i](x) return x ================================================ FILE: wesep/modules/tasnet/separator.py ================================================ import torch.nn as nn from wesep.modules.tasnet.convs import Conv1DBlock class Separation(nn.Module): """ R Number of repeats X Number of convolutional blocks in each repeat B Number of channels in bottleneck and the residual paths H Number of channels in convolutional blocks P Kernel size in convolutional blocks norm The type of normalization(gln, cl, bn) causal Two choice(causal or noncausal) skip_con Whether to use skip connection """ def __init__(self, R, X, B, H, P, norm="gln", causal=False, skip_con=True): super(Separation, self).__init__() self.separation = nn.ModuleList([]) for _ in range(R): for x in range(X): self.separation.append( Conv1DBlock(B, H, P, 2**x, norm, causal, skip_con)) self.skip_con = skip_con def forward(self, x): """ x: [B, N, L] out: [B, N, L] """ if self.skip_con: skip_connection = 0 for i in range(len(self.separation)): skip, out = self.separation[i](x) skip_connection = skip_connection + skip x = out return skip_connection else: for i in range(len(self.separation)): out = self.separation[i](x) x = out return x ================================================ FILE: wesep/modules/tasnet/speaker.py ================================================ import torch.nn as nn from wesep.modules.common.norm import ChannelWiseLayerNorm from wesep.modules.tasnet.convs import Conv1D class ResBlock(nn.Module): """ ref to https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py and https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py """ def __init__(self, in_dims, out_dims): super().__init__() self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False) self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False) self.batch_norm1 = nn.BatchNorm1d(out_dims) self.batch_norm2 = nn.BatchNorm1d(out_dims) self.prelu1 = nn.PReLU() self.prelu2 = nn.PReLU() self.mp = nn.MaxPool1d(3) if in_dims != out_dims: self.downsample = True self.conv_downsample = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False) else: self.downsample = False def forward(self, x): residual = x x = self.conv1(x) x = self.batch_norm1(x) x = self.prelu1(x) x = self.conv2(x) x = self.batch_norm2(x) if self.downsample: residual = self.conv_downsample(residual) x = x + residual x = self.prelu2(x) return self.mp(x) class ResNet4SpExplus(nn.Module): def __init__(self, in_channel=256, C_embedding=256): super().__init__() self.aux_enc3 = nn.Sequential( ChannelWiseLayerNorm(3 * in_channel), Conv1D(3 * 256, 256, 1), ResBlock(256, 256), ResBlock(256, 512), ResBlock(512, 512), Conv1D(512, C_embedding, 1), ) def forward(self, x): aux = self.aux_enc3(x) aux = aux.mean(dim=-1) return aux ================================================ FILE: wesep/modules/tfgridnet/__init__.py ================================================ ================================================ FILE: wesep/modules/tfgridnet/gridnet_block.py ================================================ # The implementation is based on: # 
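# --- Speaker-conditioning sketch (editorial illustration; assumes wesep is importable) ---
# ResNet4SpExplus pools a 3-scale enrollment encoding [B, 3*256, T] into a
# speaker embedding [B, 256]; FuseSeparation then injects that embedding into
# every repeat of the TCN separator (here with the default "concatConv" fuse).
import torch
from wesep.modules.tasnet.speaker import ResNet4SpExplus
from wesep.modules.tasnet.separation import FuseSeparation

enroll_feats = torch.randn(2, 768, 270)          # [B, 3*256, T_enroll]
spk_emb = ResNet4SpExplus()(enroll_feats)        # [2, 256]

mix_feats = torch.randn(2, 256, 500)             # [B, C, T] mixture features
fuse_sep = FuseSeparation(R=2, X=6, B=256, H=512, P=3, C_embedding=256)
out = fuse_sep(mix_feats, spk_emb.unsqueeze(-1))
print(out.shape)                                 # torch.Size([2, 256, 500])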
https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init from torch.nn.parameter import Parameter from wesep.utils.utils import get_layer class GridNetBlock(nn.Module): def __getitem__(self, key): return getattr(self, key) def __init__( self, emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels, n_head=4, approx_qk_dim=512, activation="prelu", eps=1e-5, ): super().__init__() assert activation == "prelu" in_channels = emb_dim * emb_ks self.intra_norm = nn.LayerNorm(emb_dim, eps=eps) self.intra_rnn = nn.LSTM( in_channels, hidden_channels, 1, batch_first=True, bidirectional=True, ) if emb_ks == emb_hs: self.intra_linear = nn.Linear(hidden_channels * 2, in_channels) else: self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs) self.inter_norm = nn.LayerNorm(emb_dim, eps=eps) self.inter_rnn = nn.LSTM( in_channels, hidden_channels, 1, batch_first=True, bidirectional=True, ) if emb_ks == emb_hs: self.inter_linear = nn.Linear(hidden_channels * 2, in_channels) else: self.inter_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs) E = math.ceil(approx_qk_dim * 1.0 / n_freqs) # approx_qk_dim is only approximate assert emb_dim % n_head == 0 self.add_module("attn_conv_Q", nn.Conv2d(emb_dim, n_head * E, 1)) self.add_module( "attn_norm_Q", AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps), ) self.add_module("attn_conv_K", nn.Conv2d(emb_dim, n_head * E, 1)) self.add_module( "attn_norm_K", AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps), ) self.add_module("attn_conv_V", nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1)) self.add_module( "attn_norm_V", AllHeadPReLULayerNormalization4DCF( (n_head, emb_dim // n_head, n_freqs), eps=eps), ) self.add_module( "attn_concat_proj", nn.Sequential( nn.Conv2d(emb_dim, emb_dim, 1), get_layer(activation)(), LayerNormalization4DCF((emb_dim, n_freqs), eps=eps), ), ) self.emb_dim = emb_dim self.emb_ks = emb_ks self.emb_hs = emb_hs self.n_head = n_head def forward(self, x): """GridNetBlock Forward. 
Args: x: [B, C, T, Q] out: [B, C, T, Q] """ B, C, old_T, old_Q = x.shape olp = self.emb_ks - self.emb_hs T = math.ceil((old_T + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks Q = math.ceil((old_Q + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks x = x.permute(0, 2, 3, 1) # [B, old_T, old_Q, C] x = F.pad( x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp)) # [B, T, Q, C] # intra RNN input_ = x intra_rnn = self.intra_norm(input_) # [B, T, Q, C] if self.emb_ks == self.emb_hs: intra_rnn = intra_rnn.view([B * T, -1, self.emb_ks * C]) # [BT, Q//I, I*C] intra_rnn, _ = self.intra_rnn(intra_rnn) # [BT, Q//I, H] intra_rnn = self.intra_linear(intra_rnn) # [BT, Q//I, I*C] intra_rnn = intra_rnn.view([B, T, Q, C]) else: intra_rnn = intra_rnn.view([B * T, Q, C]) # [BT, Q, C] intra_rnn = intra_rnn.transpose(1, 2) # [BT, C, Q] intra_rnn = F.unfold(intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)) # [BT, C*I, -1] intra_rnn = intra_rnn.transpose(1, 2) # [BT, -1, C*I] intra_rnn, _ = self.intra_rnn(intra_rnn) # [BT, -1, H] intra_rnn = intra_rnn.transpose(1, 2) # [BT, H, -1] intra_rnn = self.intra_linear(intra_rnn) # [BT, C, Q] intra_rnn = intra_rnn.view([B, T, C, Q]) intra_rnn = intra_rnn.transpose(-2, -1) # [B, T, Q, C] intra_rnn = intra_rnn + input_ # [B, T, Q, C] intra_rnn = intra_rnn.transpose(1, 2) # [B, Q, T, C] # inter RNN input_ = intra_rnn inter_rnn = self.inter_norm(input_) # [B, Q, T, C] if self.emb_ks == self.emb_hs: inter_rnn = inter_rnn.view([B * Q, -1, self.emb_ks * C]) # [BQ, T//I, I*C] inter_rnn, _ = self.inter_rnn(inter_rnn) # [BQ, T//I, H] inter_rnn = self.inter_linear(inter_rnn) # [BQ, T//I, I*C] inter_rnn = inter_rnn.view([B, Q, T, C]) else: inter_rnn = inter_rnn.view(B * Q, T, C) # [BQ, T, C] inter_rnn = inter_rnn.transpose(1, 2) # [BQ, C, T] inter_rnn = F.unfold(inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)) # [BQ, C*I, -1] inter_rnn = inter_rnn.transpose(1, 2) # [BQ, -1, C*I] inter_rnn, _ = self.inter_rnn(inter_rnn) # [BQ, -1, H] inter_rnn = inter_rnn.transpose(1, 2) # [BQ, H, -1] inter_rnn = self.inter_linear(inter_rnn) # [BQ, C, T] inter_rnn = inter_rnn.view([B, Q, C, T]) inter_rnn = inter_rnn.transpose(-2, -1) # [B, Q, T, C] inter_rnn = inter_rnn + input_ # [B, Q, T, C] inter_rnn = inter_rnn.permute(0, 3, 2, 1) # [B, C, T, Q] inter_rnn = inter_rnn[..., olp:olp + old_T, olp:olp + old_Q] batch = inter_rnn Q = self["attn_norm_Q"]( self["attn_conv_Q"](batch)) # [B, n_head, C, T, Q] K = self["attn_norm_K"]( self["attn_conv_K"](batch)) # [B, n_head, C, T, Q] V = self["attn_norm_V"]( self["attn_conv_V"](batch)) # [B, n_head, C, T, Q] Q = Q.view(-1, *Q.shape[2:]) # [B*n_head, C, T, Q] K = K.view(-1, *K.shape[2:]) # [B*n_head, C, T, Q] V = V.view(-1, *V.shape[2:]) # [B*n_head, C, T, Q] Q = Q.transpose(1, 2) Q = Q.flatten(start_dim=2) # [B', T, C*Q] K = K.transpose(2, 3) K = K.contiguous().view([B * self.n_head, -1, old_T]) # [B', C*Q, T] V = V.transpose(1, 2) # [B', T, C, Q] old_shape = V.shape V = V.flatten(start_dim=2) # [B', T, C*Q] emb_dim = Q.shape[-1] attn_mat = torch.matmul(Q, K) / (emb_dim**0.5) # [B', T, T] attn_mat = F.softmax(attn_mat, dim=2) # [B', T, T] V = torch.matmul(attn_mat, V) # [B', T, C*Q] V = V.reshape(old_shape) # [B', T, C, Q] V = V.transpose(1, 2) # [B', C, T, Q] emb_dim = V.shape[1] batch = V.contiguous().view([B, self.n_head * emb_dim, old_T, old_Q]) # [B, C, T, Q]) batch = self["attn_concat_proj"](batch) # [B, C, T, Q]) out = batch + inter_rnn return out class 
LayerNormalization4DCF(nn.Module): def __init__(self, input_dimension, eps=1e-5): super().__init__() assert len(input_dimension) == 2 param_size = [1, input_dimension[0], 1, input_dimension[1]] self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32)) self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32)) init.ones_(self.gamma) init.zeros_(self.beta) self.eps = eps def forward(self, x): if x.ndim == 4: stat_dim = (1, 3) else: raise ValueError( "Expect x to have 4 dimensions, but got {}".format(x.ndim)) mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,1,T,1] std_ = torch.sqrt( x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps) # [B,1,T,F] x_hat = ((x - mu_) / std_) * self.gamma + self.beta return x_hat class AllHeadPReLULayerNormalization4DCF(nn.Module): def __init__(self, input_dimension, eps=1e-5): super().__init__() assert len(input_dimension) == 3 H, E, n_freqs = input_dimension param_size = [1, H, E, 1, n_freqs] self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32)) self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32)) init.ones_(self.gamma) init.zeros_(self.beta) self.act = nn.PReLU(num_parameters=H, init=0.25) self.eps = eps self.H = H self.E = E self.n_freqs = n_freqs def forward(self, x): assert x.ndim == 4 B, _, T, _ = x.shape x = x.view([B, self.H, self.E, T, self.n_freqs]) x = self.act(x) # [B,H,E,T,F] stat_dim = (2, 4) mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,H,1,T,1] std_ = torch.sqrt( x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps) # [B,H,1,T,1] x = ((x - mu_) / std_) * self.gamma + self.beta # [B,H,E,T,F] return x ================================================ FILE: wesep/utils/abs_loss.py ================================================ from abc import ABC, abstractmethod import torch EPS = torch.finfo(torch.get_default_dtype()).eps class AbsEnhLoss(torch.nn.Module, ABC): """Base class for all Enhancement loss modules.""" # the name will be the key that appears in the reporter @property def name(self) -> str: return NotImplementedError # This property specifies whether the criterion will only # be evaluated during the inference stage @property def only_for_test(self) -> bool: return False @abstractmethod def forward( self, ref, inf, ) -> torch.Tensor: # the return tensor should be shape of (batch) raise NotImplementedError ================================================ FILE: wesep/utils/checkpoint.py ================================================ from typing import List, Optional import torch from wesep.utils.schedulers import BaseClass def load_pretrained_model(model: torch.nn.Module, path: str, type: str = "generator"): assert type in ["generator", "discriminator"] states = torch.load( path, map_location="cpu", ) if type == "generator": state = states["models"][0] else: assert len(states["models"]) == 2 state = states["models"][1] if isinstance(model, torch.nn.DataParallel): model.module.load_state_dict(state) elif isinstance(model, torch.nn.parallel.DistributedDataParallel): model.module.load_state_dict(state) else: model.load_state_dict(state) def load_checkpoint( models: List[torch.nn.Module], optimizers: List[torch.optim.Optimizer], schedulers: List[BaseClass], scaler: Optional[torch.cuda.amp.GradScaler], path: str, only_model: bool = False, mode: str = "all", ): assert mode in ["all", "generator", "discriminator"] states = torch.load( path, map_location="cpu", ) if mode == "generator": model_state, optimizer_state, scheduler_state = ( [states["models"][0]], [states["optimizers"][0]], 
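# --- Shape-check sketch (editorial illustration; assumes wesep is importable) ---
# GridNetBlock preserves its input shape [B, C, T, Q]: an intra-frequency
# BLSTM, an inter-frame BLSTM (both with residual connections), and full-band
# self-attention over the T axis. emb_dim must be divisible by n_head.
import torch
from wesep.modules.tfgridnet.gridnet_block import GridNetBlock

block = GridNetBlock(emb_dim=32, emb_ks=4, emb_hs=1, n_freqs=65,
                     hidden_channels=64, n_head=4)
x = torch.randn(1, 32, 50, 65)               # [B, C, T, Q]
print(block(x).shape)                        # torch.Size([1, 32, 50, 65])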
[states["schedulers"][0]], ) elif mode == "discriminator": model_state, optimizer_state, scheduler_state = ( [states["models"][1]], [states["optimizers"][1]], [states["schedulers"][1]], ) else: model_state, optimizer_state, scheduler_state = ( states["models"], states["optimizers"], states["schedulers"], ) for model, state in zip(models, model_state): if isinstance(model, torch.nn.DataParallel): model.module.load_state_dict(state, strict=False) elif isinstance(model, torch.nn.parallel.DistributedDataParallel): model.module.load_state_dict(state, strict=False) else: model.load_state_dict(state, strict=False) if not only_model: for optimizer, state in zip(optimizers, optimizer_state): optimizer.load_state_dict(state) for scheduler, state in zip(schedulers, scheduler_state): if scheduler is not None: scheduler.load_state_dict(state) if scaler is not None: if states["scaler"] is not None: scaler.load_state_dict(states["scaler"]) def save_checkpoint( models: List[torch.nn.Module], optimizers: List[torch.optim.Optimizer], schedulers: List[BaseClass], scaler: Optional[torch.cuda.amp.GradScaler], path: str, ): if isinstance(models[0], torch.nn.DataParallel): state_dict = [model.module.state_dict() for model in models] elif isinstance(models[0], torch.nn.parallel.DistributedDataParallel): state_dict = [model.module.state_dict() for model in models] else: state_dict = [model.state_dict() for model in models] torch.save( { "models": state_dict, "optimizers": [o.state_dict() for o in optimizers], "schedulers": [s.state_dict() if s is not None else None for s in schedulers], "scaler": scaler.state_dict() if scaler is not None else None, }, path, ) ================================================ FILE: wesep/utils/datadir_writer.py ================================================ import warnings from pathlib import Path from typing import Union # ported from # https://github.com/espnet/espnet/blob/master/espnet2/fileio/datadir_writer.py class DatadirWriter: """Writer class to create kaldi like data directory. Examples: >>> with DatadirWriter("output") as writer: ... # output/sub.txt is created here ... subwriter = writer["sub.txt"] ... # Write "uttidA some/where/a.wav" ... subwriter["uttidA"] = "some/where/a.wav" ... 
subwriter["uttidB"] = "some/where/b.wav" """ def __init__(self, p: Union[Path, str]): self.path = Path(p) self.chilidren = {} self.fd = None self.has_children = False self.keys = set() def __enter__(self): return self def __getitem__(self, key: str) -> "DatadirWriter": if self.fd is not None: raise RuntimeError("This writer points out a file") if key not in self.chilidren: w = DatadirWriter((self.path / key)) self.chilidren[key] = w self.has_children = True retval = self.chilidren[key] return retval def __setitem__(self, key: str, value: str): if self.has_children: raise RuntimeError("This writer points out a directory") if key in self.keys: warnings.warn(f"Duplicated: {key}", stacklevel=1) if self.fd is None: self.path.parent.mkdir(parents=True, exist_ok=True) self.fd = self.path.open("w", encoding="utf-8") self.keys.add(key) self.fd.write(f"{key} {value}\n") def __exit__(self, exc_type, exc_val, exc_tb): self.close() def close(self): if self.has_children: prev_child = None for child in self.chilidren.values(): child.close() if prev_child is not None and prev_child.keys != child.keys: warnings.warn( f"Ids are mismatching between " f"{prev_child.path} and {child.path}", stacklevel=1) prev_child = child elif self.fd is not None: self.fd.close() ================================================ FILE: wesep/utils/dnsmos.py ================================================ import json import math import librosa import numpy as np import requests import torch import torchaudio SAMPLING_RATE = 16000 INPUT_LENGTH = 9.01 # URL for the web service SCORING_URI_DNSMOS = "https://dnsmos.azurewebsites.net/score" SCORING_URI_DNSMOS_P835 = ( "https://dnsmos.azurewebsites.net/v1/dnsmosp835/score") def poly1d(coefficients, use_numpy=False): if use_numpy: return np.poly1d(coefficients) coefficients = tuple(reversed(coefficients)) def func(p): return sum(coef * p**i for i, coef in enumerate(coefficients)) return func class DNSMOS_web: # ported from # https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/dnsmos.py def __init__(self, auth_key): self.auth_key = auth_key def __call__(self, aud, input_fs, fname="", method="p808"): if input_fs != SAMPLING_RATE: audio = librosa.resample(aud, orig_sr=input_fs, target_sr=SAMPLING_RATE) else: audio = aud # Set the content type headers = {"Content-Type": "application/json"} # If authentication is enabled, set the authorization header headers["Authorization"] = f"Basic {self.auth_key}" fname = fname + ".wav" if fname else "audio.wav" data = {"data": audio.tolist(), "filename": fname} input_data = json.dumps(data) # Make the request and display the response if method == "p808": u = SCORING_URI_DNSMOS else: u = SCORING_URI_DNSMOS_P835 resp = requests.post(u, data=input_data, headers=headers) score_dict = resp.json() return score_dict class DNSMOS_local: # ported from # https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/dnsmos_local.py def __init__( self, primary_model_path, p808_model_path, use_gpu=False, convert_to_torch=False, gpu_device=None, ): self.convert_to_torch = convert_to_torch self.use_gpu = use_gpu self.gpu_device = gpu_device if convert_to_torch: try: from onnx2torch import convert except ModuleNotFoundError: raise RuntimeError( "Please install onnx2torch manually and retry!") from None if primary_model_path is not None: self.primary_model = convert(primary_model_path).eval() self.p808_model = convert(p808_model_path).eval() self.spectrogram = torchaudio.transforms.Spectrogram( n_fft=321, hop_length=160, pad_mode="constant") self.to_db = 
torchaudio.transforms.AmplitudeToDB("power", top_db=80.0) if use_gpu: if gpu_device is not None: torch.cuda.set_device(gpu_device) if primary_model_path is not None: self.primary_model = self.primary_model.cuda() self.p808_model = self.p808_model.cuda() self.spectrogram = self.spectrogram.cuda() else: try: import onnxruntime as ort except ModuleNotFoundError: raise RuntimeError( "Please install onnxruntime manually and retry!") from None prvd = ("CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider") if primary_model_path is not None: self.onnx_sess = ort.InferenceSession(primary_model_path, providers=[prvd]) self.p808_onnx_sess = ort.InferenceSession(p808_model_path, providers=[prvd]) if self.gpu_device is not None: self.onnx_sess.set_providers([prvd], [{ "device_id": gpu_device }]) self.p808_onnx_sess.set_providers( [prvd], [{ "device_id": gpu_device }]) def audio_melspec( self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True, ): if self.convert_to_torch: specgram = self.spectrogram(audio) fb = torch.as_tensor( librosa.filters.mel(sr=sr, n_fft=frame_size + 1, n_mels=n_mels).T, dtype=audio.dtype, device=audio.device, ) mel_spec = torch.matmul(specgram.transpose(-1, -2), fb).transpose(-1, -2) if to_db: self.to_db.db_multiplier = math.log10( max(self.to_db.amin, torch.max(mel_spec))) mel_spec = (self.to_db(mel_spec) + 40) / 40 else: mel_spec = librosa.feature.melspectrogram( y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels, ) if to_db: mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40 return mel_spec.T def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS): flag = not self.convert_to_torch if is_personalized_MOS: p_ovr = poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046], flag) p_sig = poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726], flag) p_bak = poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132], flag) else: p_ovr = poly1d([-0.06766283, 1.11546468, 0.04602535], flag) p_sig = poly1d([-0.08397278, 1.22083953, 0.0052439], flag) p_bak = poly1d([-0.13166888, 1.60915514, -0.39604546], flag) sig_poly = p_sig(sig) bak_poly = p_bak(bak) ovr_poly = p_ovr(ovr) return sig_poly, bak_poly, ovr_poly def __call__(self, aud, input_fs, is_personalized_MOS=False): if self.convert_to_torch: if self.use_gpu: if self.gpu_device is not None: device = f"cuda:{self.gpu_device}" else: device = "cuda" else: device = "cpu" if isinstance(aud, torch.Tensor): aud = aud.to(device=device) else: aud = torch.as_tensor(aud, dtype=torch.float32, device=device) else: aud = (aud.cpu().detach().numpy() if isinstance(aud, torch.Tensor) else aud) if input_fs != SAMPLING_RATE: if self.convert_to_torch: audio = torch.as_tensor( librosa.resample( aud.detach().cpu().numpy(), orig_sr=input_fs, target_sr=SAMPLING_RATE, ), dtype=aud.dtype, device=aud.device, ) else: audio = librosa.resample(aud, orig_sr=input_fs, target_sr=SAMPLING_RATE) else: audio = aud len_samples = int(INPUT_LENGTH * SAMPLING_RATE) while len(audio) < len_samples: if self.convert_to_torch: audio = torch.cat((audio, audio)) else: audio = np.append(audio, audio) num_hops = int(np.floor(len(audio) / SAMPLING_RATE) - INPUT_LENGTH) + 1 hop_len_samples = SAMPLING_RATE predicted_mos_sig_seg_raw = [] predicted_mos_bak_seg_raw = [] predicted_mos_ovr_seg_raw = [] predicted_mos_sig_seg = [] predicted_mos_bak_seg = [] predicted_mos_ovr_seg = [] predicted_p808_mos = [] for idx in range(num_hops): audio_seg = audio[int(idx * hop_len_samples):int((idx + INPUT_LENGTH) * 
hop_len_samples)] if len(audio_seg) < len_samples: continue if self.convert_to_torch: input_features = audio_seg.float()[None, :] p808_input_features = self.audio_melspec( audio=audio_seg[:-160]).float()[None, :, :] p808_mos = self.p808_model(p808_input_features) mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.primary_model( input_features)[0] else: input_features = np.array(audio_seg).astype("float32")[ np.newaxis, :] p808_input_features = np.array( self.audio_melspec(audio=audio_seg[:-160])).astype( "float32")[np.newaxis, :, :] p808_mos = self.p808_onnx_sess.run( None, {"input_1": p808_input_features})[0][0][0] mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run( None, {"input_1": input_features})[0][0] mos_sig, mos_bak, mos_ovr = self.get_polyfit_val( mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS) predicted_mos_sig_seg_raw.append(mos_sig_raw) predicted_mos_bak_seg_raw.append(mos_bak_raw) predicted_mos_ovr_seg_raw.append(mos_ovr_raw) predicted_mos_sig_seg.append(mos_sig) predicted_mos_bak_seg.append(mos_bak) predicted_mos_ovr_seg.append(mos_ovr) predicted_p808_mos.append(p808_mos) to_array = torch.stack if self.convert_to_torch else np.array return { "OVRL_raw": to_array(predicted_mos_ovr_seg_raw).mean(), "SIG_raw": to_array(predicted_mos_sig_seg_raw).mean(), "BAK_raw": to_array(predicted_mos_bak_seg_raw).mean(), "OVRL": to_array(predicted_mos_ovr_seg).mean(), "SIG": to_array(predicted_mos_sig_seg).mean(), "BAK": to_array(predicted_mos_bak_seg).mean(), "P808_MOS": to_array(predicted_p808_mos).mean(), } ================================================ FILE: wesep/utils/executor.py ================================================ # Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com) # 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
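# --- Usage sketch (editorial illustration; the ONNX model paths below are placeholders) ---
# DNSMOS_local wraps the DNS-Challenge DNSMOS models and returns SIG/BAK/OVRL
# and P808 MOS averaged over 9.01 s windows. The .onnx files must be obtained
# separately from the microsoft/DNS-Challenge repository; onnxruntime is
# required unless convert_to_torch=True (which needs onnx2torch instead).
import numpy as np
from wesep.utils.dnsmos import DNSMOS_local

dnsmos = DNSMOS_local(primary_model_path="sig_bak_ovr.onnx",
                      p808_model_path="model_v8.onnx", use_gpu=False)
audio = np.random.randn(16000 * 10).astype("float32")   # >= 9.01 s at 16 kHz
scores = dnsmos(audio, 16000)
print(scores["SIG"], scores["BAK"], scores["OVRL"], scores["P808_MOS"])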
from contextlib import nullcontext import tableprint as tp # if your python version < 3.7 use the below one import torch from wesep.utils.funcs import clip_gradients, compute_fbank, apply_cmvn import random class Executor: def __init__(self): self.step = 0 def train( self, dataloader, models, epoch_iter, optimizers, criterion, schedulers, scaler, epoch, enable_amp, logger, clip_grad=5.0, log_batch_interval=100, device=torch.device("cuda"), se_loss_weight=1.0, multi_task=False, SSA_enroll_prob=0, fbank_args=None, sample_rate=16000, speaker_feat=True ): """Train one epoch""" model = models[0] optimizer = optimizers[0] scheduler = schedulers[0] model.train() log_interval = log_batch_interval accum_grad = 1 losses = [] if isinstance(model, torch.nn.parallel.DistributedDataParallel): model_context = model.join else: model_context = nullcontext with model_context(): for i, batch in enumerate(dataloader): features = batch["wav_mix"] targets = batch["wav_targets"] # embeddings when not joint training, enrollment wavforms # when joint training enroll = batch["spk_embeds"] # spk_lable is an empty list when not joint training # and multi-task spk_label = batch["spk_label"] cur_iter = (epoch - 1) * epoch_iter + i scheduler.step(cur_iter) features = features.float().to(device) # (B,T,F) targets = targets.float().to(device) enroll = enroll.float().to(device) spk_label = spk_label.to(device) with torch.cuda.amp.autocast(enabled=enable_amp): if SSA_enroll_prob > 0: if SSA_enroll_prob > random.random(): with torch.no_grad(): outputs = model(features, enroll) est_speech = outputs[0] self_fbank = est_speech if fbank_args is not None and speaker_feat: self_fbank = compute_fbank( est_speech, **fbank_args, sample_rate=sample_rate) self_fbank = apply_cmvn(self_fbank) outputs = model(features, self_fbank) else: outputs = model(features, enroll) else: outputs = model(features, enroll) if not isinstance(outputs, (list, tuple)): outputs = [outputs] loss = 0 for ii in range(len(criterion)): # se_loss_weight: ([position in outputs[0], [1]], # [weights:[1.0], [0.5]]) for ji in range(len(se_loss_weight[0][ii])): if (multi_task and criterion[ii].__class__.__name__ == "CrossEntropyLoss"): loss += se_loss_weight[1][ii][ji] * ( criterion[ii]( outputs[se_loss_weight[0][ii][ji]], spk_label, ).mean() / accum_grad) continue loss += se_loss_weight[1][ii][ji] * (criterion[ii]( outputs[se_loss_weight[0][ii][ji]], targets).mean() / accum_grad) losses.append(loss.item()) total_loss_avg = sum(losses) / len(losses) # updata the model optimizer.zero_grad() # scaler does nothing here if enable_amp=False scaler.scale(loss).backward() scaler.unscale_(optimizer) clip_gradients(model, clip_grad) scaler.step(optimizer) scaler.update() if (i + 1) % log_interval == 0: logger.info( tp.row( ( "TRAIN", epoch, i + 1, total_loss_avg * accum_grad, optimizer.param_groups[0]["lr"], ), width=10, style="grid", )) if (i + 1) == epoch_iter: break total_loss_avg = sum(losses) / len(losses) return total_loss_avg, 0 def cv( self, dataloader, models, val_iter, criterion, epoch, enable_amp, logger, log_batch_interval=100, device=torch.device("cuda"), ): """Cross validation on""" model = models[0] model.eval() log_interval = log_batch_interval losses = [] with torch.no_grad(): for i, batch in enumerate(dataloader): features = batch["wav_mix"] targets = batch["wav_targets"] enroll = batch["spk_embeds"] features = features.float().to(device) # (B,T,F) targets = targets.float().to(device) enroll = enroll.float().to(device) with 
torch.cuda.amp.autocast(enabled=enable_amp): outputs = model(features, enroll) if not isinstance(outputs, (list, tuple)): outputs = [outputs] # By default, the first loss is used as the indicator # of the validation set. loss = criterion[0](outputs[0], targets).mean() losses.append(loss.item()) total_loss_avg = sum(losses) / len(losses) if (i + 1) % log_interval == 0: logger.info( tp.row( ("VAL", epoch, i + 1, total_loss_avg, "-"), width=10, style="grid", )) if (i + 1) == val_iter: break return total_loss_avg, 0 ================================================ FILE: wesep/utils/executor_gan.py ================================================ # Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com) # 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from contextlib import nullcontext import tableprint as tp # if your python version < 3.7 use the below one import torch import torch.nn.functional as F from wesep.utils.funcs import clip_gradients from wesep.utils.score import batch_evaluation, cal_PESQ_norm class ExecutorGAN: def __init__(self): self.step = 0 def train( self, dataloader, models, epoch_iter, optimizers, criterion, schedulers, scaler, epoch, enable_amp, logger, clip_grad=5.0, log_batch_interval=100, device=torch.device("cuda"), se_loss_weight=0.95, gan_loss_weight=0.05, multi_task=False, ): """Train one epoch""" assert (len(models) == len(optimizers) == len(schedulers) == 2), "Currently only support one discriminator" model, discriminator = models optimizer, optimizer_dis = optimizers scheduler, scheduler_dis = schedulers model.train() discriminator.train() log_interval = log_batch_interval accum_grad = 1 losses = [] se_losses = [] dis_losses = [] if isinstance(model, torch.nn.parallel.DistributedDataParallel): model_context = model.join else: model_context = nullcontext with model_context(): for i, batch in enumerate(dataloader): features = batch["wav_mix"] targets = batch["wav_targets"] # embeddings when when not joint training, enrollment # wavforms when joint training enroll = batch["spk_embeds"] # spk_lable is an empty list when not joint training # and multi-task spk_label = batch["spk_label"] one_labels = torch.ones(features.size(0)) cur_iter = (epoch - 1) * epoch_iter + i scheduler.step(cur_iter) scheduler_dis.step(cur_iter) features = features.float().to(device) targets = targets.float().to(device) enroll = enroll.float().to(device) spk_label = spk_label.to(device) one_labels = one_labels.float().to(device) # calculate discriminator loss with torch.cuda.amp.autocast(enabled=enable_amp): outputs = model(features, enroll) if not isinstance(outputs, (list, tuple)): outputs = [outputs] # outputs is a list of tensors, each tensor has shape # (Batch, samples) if multi_task: # remove the predicted spk_label from the outputs list enhanced_wavs = torch.stack(outputs[:-1], dim=0) else: # enhanced_wavs: [N, Batch, samples], N is the number # of output of the model enhanced_wavs = torch.stack(outputs, dim=0) d_loss = 
self._calculate_discriminator_loss( discriminator, targets, enhanced_wavs.detach(), features.detach(), ) dis_losses.append(d_loss.item()) total_dis_loss_avg = sum(dis_losses) / len(dis_losses) # updata discriminator optimizer_dis.zero_grad() # scaler does nothing here if enable_amp=False scaler.scale(d_loss).backward() scaler.unscale_(optimizer_dis) clip_gradients(discriminator, clip_grad) scaler.step(optimizer_dis) scaler.update() # calculate generator loss with torch.cuda.amp.autocast(enabled=enable_amp): se_loss = 0 for ii in range(len(criterion)): # se_loss_weight[0]: 2-D array,loss_posi; # se_loss_weight[1]: 2-D array,loss_weight. for ji in range(len(se_loss_weight[0][ii])): if multi_task and ii == (len(criterion) - 1): se_loss += se_loss_weight[1][ii][ji] * ( criterion[ii]( outputs[se_loss_weight[0][ii][ji]], spk_label, ).mean() / accum_grad) continue se_loss += se_loss_weight[1][ii][ji] * ( criterion[ii] (outputs[se_loss_weight[0][ii][ji]], targets).mean() / accum_grad) gan_loss = 0 len_output = (len(outputs) - 1 if multi_task else len(outputs)) for j in range(len_output): enhanced_fake_metric = discriminator( targets, outputs[j]) gan_loss += F.mse_loss( enhanced_fake_metric.flatten(), one_labels, ) g_loss = se_loss + gan_loss_weight * gan_loss losses.append(g_loss.item()) se_losses.append(se_loss.item()) total_loss_avg = sum(losses) / len(losses) total_se_loss_avg = sum(se_losses) / len(se_losses) # updata the generator optimizer.zero_grad() # scaler does nothing here if enable_amp=False scaler.scale(g_loss).backward() scaler.unscale_(optimizer) clip_gradients(model, clip_grad) scaler.step(optimizer) scaler.update() if (i + 1) % log_interval == 0: logger.info( tp.row( ( "TRAIN", epoch, i + 1, total_se_loss_avg, total_loss_avg * accum_grad, total_dis_loss_avg * accum_grad, optimizer.param_groups[0]["lr"], ), width=10, style="grid", )) if (i + 1) == epoch_iter: break total_loss_avg = sum(losses) / len(losses) total_dis_loss_avg = sum(dis_losses) / len(dis_losses) return total_loss_avg, total_dis_loss_avg def cv( self, dataloader, models, val_iter, criterion, epoch, enable_amp, logger, log_batch_interval=100, device=torch.device("cuda"), ): """Cross validation on""" assert len(models) == 2, "Currently only support one discriminator" model, discriminator = models model.eval() discriminator.eval() log_interval = log_batch_interval losses = [] se_losses = [] dis_losses = [] with torch.no_grad(): for i, batch in enumerate(dataloader): features = batch["wav_mix"] targets = batch["wav_targets"] enroll = batch["spk_embeds"] one_labels = torch.ones(features.size(0)) features = features.float().to(device) # (B,T,F) targets = targets.float().to(device) enroll = enroll.float().to(device) one_labels = one_labels.float().to(device) with torch.cuda.amp.autocast(enabled=enable_amp): outputs = model(features, enroll) if not isinstance(outputs, (list, tuple)): outputs = [outputs] # calculate discriminator loss d_loss = self._calculate_discriminator_loss( discriminator, targets, outputs[0].unsqueeze(0), features, ) dis_losses.append(d_loss.item()) total_dis_loss_avg = sum(dis_losses) / len(dis_losses) # calculate generator loss with torch.cuda.amp.autocast(enabled=enable_amp): se_loss = criterion[0](outputs[0], targets).mean() enhanced_fake_metric = discriminator(targets, outputs[0]) gan_loss = F.mse_loss( enhanced_fake_metric.flatten(), one_labels, ) g_loss = se_loss + gan_loss losses.append(g_loss.item()) se_losses.append(se_loss.item()) total_loss_avg = sum(losses) / len(losses) total_se_loss_avg = 
sum(se_losses) / len(se_losses) if (i + 1) % log_interval == 0: logger.info( tp.row( ( "VAL", epoch, i + 1, total_se_loss_avg, total_loss_avg, total_dis_loss_avg, "-", ), width=10, style="grid", )) if (i + 1) == val_iter: break return total_loss_avg, total_dis_loss_avg def mse_loss(self, output, target): return F.mse_loss(output.flatten(), target) def _calculate_discriminator_loss( self, discriminator, clean_wavs, enhanced_wavs, noisy_wavs, ): """Calculate the discriminator loss Args: discriminator (torch.nn.Module): the discriminator model clean_wavs (torch.Tensor): the clean waveforms, [Batch, samples] enhanced_wavs (torch.Tensor): the predicted waveforms, [N, Batch, samples] noisy_wavs (torch.Tensor): the noisy waveforms, [Batch, samples] Returns: torch.Tensor: the discriminator loss """ def calculate_mse_loss(output, target): if target is not None: target = torch.FloatTensor(target).to(device) return self.mse_loss(output, target) return 0 device = clean_wavs.device one_labels = torch.ones(clean_wavs.size(0)).float().to(device) noisy_fake_metric = discriminator(clean_wavs, noisy_wavs) clean_fake_metric = discriminator(clean_wavs, clean_wavs) audio_ref = clean_wavs.detach().cpu().numpy() audio_noisy = noisy_wavs.detach().cpu().numpy() noisy_real_metric = batch_evaluation(cal_PESQ_norm, audio_noisy, audio_ref, parallel=False) loss_d_clean = self.mse_loss(clean_fake_metric, one_labels) loss_d_noisy = calculate_mse_loss(noisy_fake_metric, noisy_real_metric) d_loss = loss_d_clean + loss_d_noisy # unbind enhanced_wavs to get a list of tensors, # each tensor has shape (Batch, samples) enhanced_wavs = torch.unbind(enhanced_wavs, dim=0) for enhanced_wav in enhanced_wavs: enhanced_fake_metric = discriminator(clean_wavs, enhanced_wav) audio_est = enhanced_wav.detach().cpu().numpy() enhanced_real_metric = batch_evaluation(cal_PESQ_norm, audio_est, audio_ref, parallel=False) loss_d_enhanced = calculate_mse_loss(enhanced_fake_metric, enhanced_real_metric) d_loss += loss_d_enhanced return d_loss ================================================ FILE: wesep/utils/file_utils.py ================================================ import collections import math from pathlib import Path from typing import Dict, List, Optional, Tuple, Union import kaldiio import numpy as np import soundfile def read_lists(list_file): """list_file: only 1 column""" lists = [] with open(list_file, "r", encoding="utf8") as fin: for line in fin: lists.append(line.strip()) return lists def read_vec_scp_file(scp_file): """ Read the pre-extracted kaldi-format speaker embeddings. 
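# --- Loss sketch (editorial illustration; assumes wesep is importable) ---
# The generator-side adversarial term used in ExecutorGAN above: the CMGAN
# discriminator predicts a normalized quality score in (0, 1) for a
# (clean, enhanced) pair, and the generator is pushed towards the "perfect"
# label 1 with an MSE loss; discriminator targets come from normalized PESQ
# (cal_PESQ_norm) computed on the fly in _calculate_discriminator_loss.
import torch
import torch.nn.functional as F
from wesep.modules.metric_gan.discriminator import CMGAN_Discriminator

discriminator = CMGAN_Discriminator()
clean = torch.randn(2, 16000)
enhanced = torch.randn(2, 16000)            # would come from the separator
score = discriminator(clean, enhanced)      # [B, 1]
gan_loss = F.mse_loss(score.flatten(), torch.ones(2))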
:param scp_file: path to xvector.scp :return: dict {wav_name: embedding} """ samples_dict = {} for key, vec in kaldiio.load_scp_sequential(scp_file): if len(vec.shape) == 1: vec = np.expand_dims(vec, 0) samples_dict[key] = vec return samples_dict def norm_embeddings(embeddings, kaldi_style=True): """ Norm embeddings to unit length :param embeddings: input embeddings :param kaldi_style: if true, the norm should be embedding dimension :return: """ scale = math.sqrt(embeddings.shape[-1]) if kaldi_style else 1.0 if len(embeddings.shape) == 2: return (scale * embeddings.transpose() / np.linalg.norm(embeddings, axis=1)).transpose() elif len(embeddings.shape) == 1: return scale * embeddings / np.linalg.norm(embeddings) def read_label_file(label_file): """ Read the utt2spk file :param label_file: the path to utt2spk :return: dict {wav_name: spk_id} """ labels_dict = {} with open(label_file, "r") as fin: for line in fin: tokens = line.strip().split() labels_dict[tokens[0]] = tokens[1] return labels_dict def load_speaker_embeddings(scp_file, utt2spk_file): """ :param scp_file: :param utt2spk_file: :return: {spk1: [emb1, emb2 ...], spk2: [emb1, emb2...]} """ samples_dict = read_vec_scp_file(scp_file) labels_dict = read_label_file(utt2spk_file) spk2embeds = {} for key, vec in samples_dict.items(): if len(vec.shape) == 1: vec = np.expand_dims(vec, 0) label = labels_dict[key] if label in spk2embeds.keys(): spk2embeds[label].append(vec) else: spk2embeds[label] = [vec] return spk2embeds # ported from # https://github.com/espnet/espnet/blob/master/espnet2/fileio/read_text.py def read_2columns_text(path: Union[Path, str]) -> Dict[str, str]: """Read a text file having 2 columns as dict object. Examples: wav.scp: key1 /some/path/a.wav key2 /some/path/b.wav >>> read_2columns_text('wav.scp') {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'} """ data = {} with Path(path).open("r", encoding="utf-8") as f: for linenum, line in enumerate(f, 1): sps = line.rstrip().split(maxsplit=1) if len(sps) == 1: k, v = sps[0], "" else: k, v = sps if k in data: raise RuntimeError(f"{k} is duplicated ({path}:{linenum})") data[k] = v return data # ported from # https://github.com/espnet/espnet/blob/master/espnet2/fileio/read_text.py def read_multi_columns_text( path: Union[Path, str], return_unsplit: bool = False ) -> Tuple[Dict[str, List[str]], Optional[Dict[str, str]]]: """Read a text file having 2 or more columns as dict object. Examples: wav.scp: key1 /some/path/a1.wav /some/path/a2.wav key2 /some/path/b1.wav /some/path/b2.wav /some/path/b3.wav key3 /some/path/c1.wav ... 
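# --- Usage sketch (editorial illustration; assumes wesep is importable) ---
# norm_embeddings rescales embeddings to unit length, or to length
# sqrt(embedding_dim) when kaldi_style=True (the default), matching Kaldi's
# x-vector convention; it accepts either a single vector or a 2-D matrix.
import numpy as np
from wesep.utils.file_utils import norm_embeddings

emb = np.random.randn(192).astype(np.float32)
print(np.linalg.norm(norm_embeddings(emb)))                      # ~= sqrt(192)
print(np.linalg.norm(norm_embeddings(emb, kaldi_style=False)))   # ~= 1.0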
>>> read_multi_columns_text('wav.scp') {'key1': ['/some/path/a1.wav', '/some/path/a2.wav'], 'key2': ['/some/path/b1.wav', '/some/path/b2.wav', '/some/path/b3.wav'], 'key3': ['/some/path/c1.wav']} """ data = {} if return_unsplit: unsplit_data = {} else: unsplit_data = None with Path(path).open("r", encoding="utf-8") as f: for linenum, line in enumerate(f, 1): sps = line.rstrip().split(maxsplit=1) if len(sps) == 1: k, v = sps[0], "" else: k, v = sps if k in data: raise RuntimeError(f"{k} is duplicated ({path}:{linenum})") data[k] = v.split() if v != "" else [""] if return_unsplit: unsplit_data[k] = v return data, unsplit_data # ported from # https://github.com/espnet/espnet/blob/master/espnet2/fileio/sound_scp.py def soundfile_read( wavs: Union[str, List[str]], dtype=None, always_2d: bool = False, concat_axis: int = 1, start: int = 0, end: int = None, return_subtype: bool = False, ) -> Tuple[np.array, int]: if isinstance(wavs, str): wavs = [wavs] arrays = [] subtypes = [] prev_rate = None prev_wav = None for wav in wavs: with soundfile.SoundFile(wav) as f: f.seek(start) if end is not None: frames = end - start else: frames = -1 if dtype == "float16": array = f.read( frames, dtype="float32", always_2d=always_2d, ).astype(dtype) else: array = f.read(frames, dtype=dtype, always_2d=always_2d) rate = f.samplerate subtype = f.subtype subtypes.append(subtype) if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1: # array: (Time, Channel) array = array[:, None] if prev_wav is not None: if prev_rate != rate: raise RuntimeError( f"{prev_wav} and {wav} have mismatched sampling rate: " f"{prev_rate} != {rate}") dim1 = arrays[0].shape[1 - concat_axis] dim2 = array.shape[1 - concat_axis] if dim1 != dim2: raise RuntimeError( "Shapes must match with " f"{1 - concat_axis} axis, but gut {dim1} and {dim2}") prev_rate = rate prev_wav = wav arrays.append(array) if len(arrays) == 1: array = arrays[0] else: array = np.concatenate(arrays, axis=concat_axis) if return_subtype: return array, rate, subtypes else: return array, rate # ported from # https://github.com/espnet/espnet/blob/master/espnet2/fileio/sound_scp.py class SoundScpReader(collections.abc.Mapping): """Reader class for 'wav.scp'. Examples: wav.scp is a text file that looks like the following: key1 /some/path/a.wav key2 /some/path/b.wav key3 /some/path/c.wav key4 /some/path/d.wav ... >>> reader = SoundScpReader('wav.scp') >>> rate, array = reader['key1'] If multi_columns=True is given and multiple files are given in one line with space delimiter, and the output array are concatenated along channel direction key1 /some/path/a.wav /some/path/a2.wav key2 /some/path/b.wav /some/path/b2.wav ... >>> reader = SoundScpReader('wav.scp', multi_columns=True) >>> rate, array = reader['key1'] In the above case, a.wav and a2.wav are concatenated. Note that even if multi_columns=True is given, SoundScpReader still supports a normal wav.scp, i.e., a wav file is given per line, but this option is disable by default because dict[str, list[str]] object is needed to be kept, but it increases the required amount of memory. 
""" def __init__( self, fname, dtype=None, always_2d: bool = False, multi_columns: bool = False, concat_axis=1, ): self.fname = fname self.dtype = dtype self.always_2d = always_2d if multi_columns: self.data, _ = read_multi_columns_text(fname) else: self.data = read_2columns_text(fname) self.multi_columns = multi_columns self.concat_axis = concat_axis def __getitem__(self, key) -> Tuple[int, np.ndarray]: wavs = self.data[key] array, rate = soundfile_read( wavs, dtype=self.dtype, always_2d=self.always_2d, concat_axis=self.concat_axis, ) # Returned as scipy.io.wavread's order return rate, array def get_path(self, key): return self.data[key] def __contains__(self, item): return item def __len__(self): return len(self.data) def __iter__(self): return iter(self.data) def keys(self): return self.data.keys() ================================================ FILE: wesep/utils/funcs.py ================================================ # Created on 2018/12 # Author: Kaituo XU import math import torch import torchaudio.compliance.kaldi as kaldi def overlap_and_add(signal, frame_step): """Reconstructs a signal from a framed representation. Adds potentially overlapping frames of a signal with shape `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. The resulting tensor has shape `[..., output_size]` where output_size = (frames - 1) * frame_step + frame_length Args: signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2. frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length. Returns: A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions. output_size = (frames - 1) * frame_step + frame_length Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/ contrib/signal/python/ops/reconstruction_ops.py """ outer_dimensions = signal.size()[:-2] frames, frame_length = signal.size()[-2:] subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor subframe_step = frame_step // subframe_length subframes_per_frame = frame_length // subframe_length output_size = frame_step * (frames - 1) + frame_length output_subframes = output_size // subframe_length subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step) frame = signal.new_tensor(frame).long() # signal may in GPU or CPU frame = frame.contiguous().view(-1) result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) result.index_add_(-2, frame, subframe_signal) result = result.view(*outer_dimensions, -1) return result def remove_pad(inputs, inputs_lengths): """ Args: inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size inputs_lengths: torch.Tensor, [B] Returns: results: a list containing B items, each item is [C, T], T varies """ results = [] dim = inputs.dim() if dim == 3: C = inputs.size(1) for input, length in zip(inputs, inputs_lengths): if dim == 3: # [B, C, T] results.append(input[:, :length].view(C, -1).cpu().numpy()) elif dim == 2: # [B, T] results.append(input[:length].view(-1).cpu().numpy()) return results def clip_gradients(model, clip): norms = [] for _, p in model.named_parameters(): if p.grad is not None: param_norm = p.grad.data.norm(2) norms.append(param_norm.item()) clip_coef = clip / (param_norm + 1e-6) if clip_coef < 1: p.grad.data.mul_(clip_coef) return norms def compute_fbank( data, num_mel_bins=80, frame_length=25, 
frame_shift=10, dither=1.0, sample_rate=16000, ): """Extract fbank""" fbank_list = [] for index_ in range(data.shape[0]): waveform = data[index_, :].unsqueeze(0) waveform = waveform * (1 << 15) mat = kaldi.fbank( waveform, num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither, sample_frequency=sample_rate, window_type="hamming", use_energy=False, ) fbank_list.append(mat.unsqueeze(0)) np_fbank = torch.cat(fbank_list, 0) return np_fbank def apply_cmvn(data, norm_mean=True, norm_var=False): """Apply CMVN Args: data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] Returns: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] """ mat_list = [] for index_ in range(data.shape[0]): mat = data[index_, :, :] if norm_mean: mat = mat - torch.mean(mat, dim=0) if norm_var: mat = mat / torch.sqrt(torch.var(mat, dim=0) + 1e-8) mat = mat.unsqueeze(0) mat_list.append(mat) np_mat = torch.cat(mat_list, 0) return np_mat if __name__ == "__main__": torch.manual_seed(123) M, C, K, N = 2, 2, 3, 4 frame_step = 2 signal = torch.randint(5, (M, C, K, N)) result = overlap_and_add(signal, frame_step) print(signal) print(result) ================================================ FILE: wesep/utils/losses.py ================================================ import auraloss import torch.nn as nn import torchmetrics.audio as audio_metrics from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio """Get a loss function with its name from the configuration file.""" valid_losses = {} torch_losses = { "L1": nn.L1Loss(), "L2": nn.MSELoss(), "CE": nn.CrossEntropyLoss(), } torchmetrics_losses = { # Not tested "PIT": audio_metrics.PermutationInvariantTraining( scale_invariant_signal_noise_ratio), } auraloss_losses = { "STFT": auraloss.freq.STFTLoss(), "MultiResolutionSTFT": auraloss.freq.MultiResolutionSTFTLoss(), "SISDR": auraloss.time.SISDRLoss(), "SISNR": auraloss.time.SISDRLoss(), "SNR": auraloss.time.SNRLoss(), } valid_losses.update(torch_losses) valid_losses.update(auraloss_losses) valid_losses.update(torchmetrics_losses) def parse_loss(loss): loss_functions = [] if not isinstance(loss, list): loss = [loss] for i in range(len(loss)): loss_name = loss[i] loss_functions.append(valid_losses.get(loss_name)) return loss_functions ================================================ FILE: wesep/utils/schedulers.py ================================================ # Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com) # 2021 Zhengyang Chen (chenzhengyang117@gmail.com) # 2022 Hongji Wang (jijijiang77@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
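

# A minimal usage sketch for the learning-rate schedulers defined below in
# this module, assuming a standard torch.optim optimizer and a dummy model
# (both are illustrative stand-ins, not part of this file's API): construct a
# scheduler with the epoch/iteration layout, then call step() once per
# training iteration so the optimizer's learning rate follows the
# warm-up/decay curve.
def _example_lr_schedule_usage():  # illustrative helper, never called here
    import torch

    model = torch.nn.Linear(4, 2)  # dummy model (assumption)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.6)
    scheduler = ExponentialDecrease(
        optimizer,
        num_epochs=6,
        epoch_iter=500,
        initial_lr=0.6,
        final_lr=0.1,
        warm_up_epoch=2,
        scale_ratio=1.0,
    )
    for it in range(scheduler.max_iter):
        scheduler.step(it)  # writes the current lr into optimizer.param_groups
        # ... forward / backward / optimizer.step() of the training loop ...
    return scheduler.get_lr()
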
import math class MarginScheduler: def __init__( self, model, epoch_iter, increase_start_epoch, fix_start_epoch, initial_margin, final_margin, update_margin, increase_type="exp", ): """ The margin is fixed as initial_margin before increase_start_epoch, between increase_start_epoch and fix_start_epoch, the margin is exponentially increasing from initial_margin to final_margin after fix_start_epoch, the margin is fixed as final_margin. """ self.model = model self.increase_start_iter = (increase_start_epoch - 1) * epoch_iter self.fix_start_iter = (fix_start_epoch - 1) * epoch_iter self.initial_margin = initial_margin self.final_margin = final_margin self.increase_type = increase_type self.fix_already = False self.current_iter = 0 self.update_margin = update_margin and hasattr(self.model.projection, "update") self.increase_iter = self.fix_start_iter - self.increase_start_iter self.init_margin() def init_margin(self): if hasattr(self.model.projection, "update"): self.model.projection.update(margin=self.initial_margin) def get_increase_margin(self): initial_val = 1.0 final_val = 1e-3 current_iter = self.current_iter - self.increase_start_iter if self.increase_type == "exp": # exponentially increase the margin ratio = (1.0 - math.exp( (current_iter / self.increase_iter) * math.log(final_val / (initial_val + 1e-6))) * initial_val) else: # linearly increase the margin ratio = 1.0 * current_iter / self.increase_iter return (self.initial_margin + (self.final_margin - self.initial_margin) * ratio) def step(self, current_iter=None): if not self.update_margin or self.fix_already: return if current_iter is not None: self.current_iter = current_iter if self.current_iter >= self.fix_start_iter: self.fix_already = True if hasattr(self.model.projection, "update"): self.model.projection.update(margin=self.final_margin) elif self.current_iter >= self.increase_start_iter: if hasattr(self.model.projection, "update"): self.model.projection.update(margin=self.get_increase_margin()) self.current_iter += 1 def get_margin(self): try: margin = self.model.projection.margin except Exception: margin = 0.0 return margin class BaseClass: """ Base Class for learning rate scheduler """ def __init__( self, optimizer, num_epochs, epoch_iter, initial_lr, final_lr, warm_up_epoch=6, scale_ratio=1.0, warm_from_zero=False, ): """ warm_up_epoch: the first warm_up_epoch is the multiprocess warm-up stage scale_ratio: multiplied to the current lr in the multiprocess training process """ self.optimizer = optimizer self.max_iter = num_epochs * epoch_iter self.initial_lr = initial_lr self.final_lr = final_lr self.scale_ratio = scale_ratio self.current_iter = 0 self.warm_up_iter = warm_up_epoch * epoch_iter self.warm_from_zero = warm_from_zero def get_multi_process_coeff(self): lr_coeff = 1.0 * self.scale_ratio if self.current_iter < self.warm_up_iter: if self.warm_from_zero: lr_coeff = (self.scale_ratio * self.current_iter / self.warm_up_iter) elif self.scale_ratio > 1: lr_coeff = (self.scale_ratio - 1) * self.current_iter / self.warm_up_iter + 1.0 return lr_coeff def get_current_lr(self): """ This function should be implemented in the child class """ return 0.0 def get_lr(self): return self.optimizer.param_groups[0]["lr"] def set_lr(self): current_lr = self.get_current_lr() for param_group in self.optimizer.param_groups: param_group["lr"] = current_lr def step(self, current_iter=None): if current_iter is not None: self.current_iter = current_iter self.set_lr() self.current_iter += 1 def step_return_lr(self, current_iter=None): if 
current_iter is not None: self.current_iter = current_iter current_lr = self.get_current_lr() self.current_iter += 1 return current_lr def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. """ return { key: value for key, value in self.__dict__.items() if key != "optimizer" } def load_state_dict(self, state_dict): """Loads the schedulers state. Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ self.__dict__.update(state_dict) class ExponentialDecrease(BaseClass): def __init__( self, optimizer, num_epochs, epoch_iter, initial_lr, final_lr, warm_up_epoch=6, scale_ratio=1.0, warm_from_zero=False, ): super().__init__( optimizer, num_epochs, epoch_iter, initial_lr, final_lr, warm_up_epoch, scale_ratio, warm_from_zero, ) def get_current_lr(self): lr_coeff = self.get_multi_process_coeff() current_lr = (lr_coeff * self.initial_lr * math.exp( (self.current_iter / self.max_iter) * math.log(self.final_lr / self.initial_lr))) return current_lr class TriAngular2(BaseClass): """ The implementation of https://arxiv.org/pdf/1506.01186.pdf """ def __init__( self, optimizer, num_epochs, epoch_iter, initial_lr, final_lr, warm_up_epoch=6, scale_ratio=1.0, cycle_step=2, reduce_lr_diff_ratio=0.5, ): super().__init__( optimizer, num_epochs, epoch_iter, initial_lr, final_lr, warm_up_epoch, scale_ratio, ) self.reduce_lr_diff_ratio = reduce_lr_diff_ratio self.cycle_iter = cycle_step * epoch_iter self.step_size = self.cycle_iter // 2 self.max_lr = initial_lr self.min_lr = final_lr self.gap = self.max_lr - self.min_lr def get_current_lr(self): lr_coeff = self.get_multi_process_coeff() point = self.current_iter % self.cycle_iter cycle_index = self.current_iter // self.cycle_iter self.max_lr = (self.min_lr + self.gap * self.reduce_lr_diff_ratio**cycle_index) if point <= self.step_size: current_lr = (self.min_lr + (self.max_lr - self.min_lr) * point / self.step_size) else: current_lr = (self.max_lr - (self.max_lr - self.min_lr) * (point - self.step_size) / self.step_size) current_lr = lr_coeff * current_lr return current_lr def show_lr_curve(scheduler): import matplotlib.pyplot as plt lr_list = [] for current_lr in range(0, scheduler.max_iter): lr_list.append(scheduler.step_return_lr(current_lr)) data_index = list(range(1, len(lr_list) + 1)) plt.plot(data_index, lr_list, "-o", markersize=1) plt.legend(loc="best") plt.xlabel("Iteration") plt.ylabel("LR") plt.show() if __name__ == "__main__": optimizer = None num_epochs = 6 epoch_iter = 500 initial_lr = 0.6 final_lr = 0.1 warm_up_epoch = 2 scale_ratio = 4 scheduler = ExponentialDecrease( optimizer, num_epochs, epoch_iter, initial_lr, final_lr, warm_up_epoch, scale_ratio, ) # scheduler = TriAngular2(optimizer, # num_epochs, # epoch_iter, # initial_lr, # final_lr, # warm_up_epoch, # scale_ratio, # cycle_step=2, # reduce_lr_diff_ratio=0.5) show_lr_curve(scheduler) ================================================ FILE: wesep/utils/score.py ================================================ import numpy as np from joblib import Parallel, delayed from pesq import pesq from pystoi.stoi import stoi def cal_SISNR(est, ref, eps=1e-8): """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR) Args: est: separated signal, numpy.ndarray, [T] ref: reference signal, numpy.ndarray, [T] Returns: SISNR """ assert len(est) == len(ref) est_zm = est - np.mean(est) ref_zm = ref - np.mean(ref) t = np.sum(est_zm * ref_zm) * 
ref_zm / (np.linalg.norm(ref_zm)**2 + eps) return 20 * np.log10(eps + np.linalg.norm(t) / (np.linalg.norm(est_zm - t) + eps)) def cal_SISNRi(est, ref, mix, eps=1e-8): """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR) Args: est: separated signal, numpy.ndarray, [T] ref: reference signal, numpy.ndarray, [T] Returns: SISNR """ assert len(est) == len(ref) == len(mix) sisnr1 = cal_SISNR(est, ref) sisnr2 = cal_SISNR(mix, ref) return sisnr1, sisnr1 - sisnr2 def cal_PESQ(est, ref): assert len(est) == len(ref) mode = "wb" p = pesq(16000, ref, est, mode) return p def cal_PESQ_norm(est, ref): assert len(est) == len(ref) mode = "wb" try: # normalize PESQ to (0, 1) p = (pesq(16000, ref, est, mode) + 0.5) / 5 except Exception: # error can happen due to silent estimated signal p = None return p def cal_PESQi(est, ref, mix): """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR) Args: est: separated signal, numpy.ndarray, [T] ref: reference signal, numpy.ndarray, [T] Returns: SISNR """ assert len(est) == len(ref) == len(mix) pesq1 = cal_PESQ(est, ref) pesq2 = cal_PESQ(mix, ref) return pesq1, pesq1 - pesq2 def cal_STOI(est, ref): assert len(est) == len(ref) p = stoi(ref, est, 16000) return p def cal_STOIi(est, ref, mix): """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR) Args: est: separated signal, numpy.ndarray, [T] ref: reference signal, numpy.ndarray, [T] Returns: SISNR """ assert len(est) == len(ref) == len(mix) stoi1 = cal_STOI(est, ref) stoi2 = cal_STOI(mix, ref) return stoi1, stoi1 - stoi2 def batch_evaluation(metric, est, ref, lengths=None, parallel=False, n_jobs=8): """Calculate specified evaluation metrics in batches Args: metric (Callable): the function to calculate metric est (np.ndarray): separated signal, numpy.ndarray, [B, T] ref (np.ndarray): reference signal, numpy.ndarray, [B, T] lengths (np.ndarray, optional): specify the length of each signal. Defaults to None. parallel (bool, optional): whether to calculate metric in parallel. Default to False. n_jobs (int, optional): number of jobs, used when `parallel` is True. Defaults to 8. Returns: scores (np.ndarray): batched metrics, [B] """ assert callable(metric) if lengths is not None: assert ((0 < lengths) & (lengths <= 1)).all() lengths = (lengths * est.size(1)).round().int().cpu() est = [p[:length].cpu() for p, length in zip(est, lengths)] ref = [t[:length].cpu() for t, length in zip(ref, lengths)] if parallel: while True: try: scores = Parallel(n_jobs=n_jobs, timeout=30)(delayed(metric)(p, t) for p, t in zip(est, ref)) break except Exception as e: print(e) print("Evaluation timeout...... 
(will try again)") else: scores = [] for p, t in zip(est, ref): score = metric(p, t) scores.append(score) if None in scores: return None return np.array(scores) ================================================ FILE: wesep/utils/signal.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from scipy.signal import get_window def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): """ Return window coefficient """ def sqrthann(win_len): return get_window("hann", win_len, fftbins=True)**0.5 if win_type == "None" or win_type is None: window = np.ones(win_len) elif win_type == "sqrthann": window = sqrthann(win_len) else: window = get_window(win_type, win_len, fftbins=True) # **0.5 N = fft_len fourier_basis = np.fft.rfft(np.eye(N))[:win_len] real_kernel = np.real(fourier_basis) imag_kernel = np.imag(fourier_basis) kernel = np.concatenate([real_kernel, imag_kernel], 1).T if invers: kernel = np.linalg.pinv(kernel).T kernel = kernel * window kernel = kernel[:, None, :] return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy( window[None, :, None].astype(np.float32)) class ConvSTFT(nn.Module): def __init__( self, win_len, win_inc, fft_len=None, win_type="hamming", feature_type="real", ): super(ConvSTFT, self).__init__() if fft_len is None: self.fft_len = np.int(2**np.ceil(np.log2(win_len))) else: self.fft_len = fft_len kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) self.register_buffer("weight", kernel) self.feature_type = feature_type self.stride = win_inc self.win_len = win_len self.dim = self.fft_len def forward(self, inputs): if inputs.dim() == 2: inputs = torch.unsqueeze(inputs, 1) inputs = F.pad( inputs, [self.win_len - self.stride, self.win_len - self.stride]) outputs = F.conv1d(inputs, self.weight, stride=self.stride) if self.feature_type == "complex": return outputs else: dim = self.dim // 2 + 1 real = outputs[:, :dim, :] imag = outputs[:, dim:, :] mags = torch.sqrt(real**2 + imag**2) phase = torch.atan2(imag, real) return mags, phase class ConviSTFT(nn.Module): def __init__( self, win_len, win_inc, fft_len=None, win_type="hamming", feature_type="real", ): super(ConviSTFT, self).__init__() if fft_len is None: self.fft_len = np.int(2**np.ceil(np.log2(win_len))) else: self.fft_len = fft_len kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) self.register_buffer("weight", kernel) self.feature_type = feature_type self.win_type = win_type self.win_len = win_len self.stride = win_inc self.stride = win_inc self.dim = self.fft_len self.register_buffer("window", window) self.register_buffer("enframe", torch.eye(win_len)[:, None, :]) def forward(self, inputs, phase=None): """ inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) phase: [B, N//2+1, T] (if not none) """ if phase is not None: real = inputs * torch.cos(phase) imag = inputs * torch.sin(phase) inputs = torch.cat([real, imag], 1) outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) # this is from torch-stft: https://github.com/pseeth/torch-stft t = self.window.repeat(1, 1, inputs.size(-1))**2 coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) outputs = outputs / (coff + 1e-8) # outputs = torch.where(coff == 0, outputs, outputs/coff) outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)] return outputs ================================================ FILE: wesep/utils/utils.py 
================================================ # Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import difflib import logging import os import random import shutil import sys from distutils.util import strtobool from pathlib import Path import numpy as np import torch import torch.distributed as dist import yaml def str2bool(value: str) -> bool: return bool(strtobool(value)) def get_logger(outdir, fname): formatter = logging.Formatter( "[ %(levelname)s : %(asctime)s ] - %(message)s") logging.basicConfig( level=logging.DEBUG, format="[ %(levelname)s : %(asctime)s ] - %(message)s", ) logger = logging.getLogger("Pyobj, f") # Dump log to file fh = logging.FileHandler(os.path.join(outdir, fname)) fh.setFormatter(formatter) logger.addHandler(fh) return logger def setup_logger(rank, exp_dir, device_ids, MAX_NUM_LOG_FILES: int = 100): model_dir = os.path.join(exp_dir, "models") file_name = "train.log" if rank == 0: os.makedirs(model_dir, exist_ok=True) for i in range(MAX_NUM_LOG_FILES - 1, -1, -1): if i == 0: p = Path(os.path.join(exp_dir, file_name)) pn = p.parent / (p.stem + ".1" + p.suffix) else: _p = Path(os.path.join(exp_dir, file_name)) p = _p.parent / (_p.stem + f".{i}" + _p.suffix) pn = _p.parent / (_p.stem + f".{i + 1}" + _p.suffix) if p.exists(): if i == MAX_NUM_LOG_FILES - 1: p.unlink() else: shutil.move(p, pn) dist.barrier(device_ids=[device_ids]) # let the rank 0 mkdir first return get_logger(exp_dir, file_name) def parse_config_or_kwargs(config_file, **kwargs): """parse_config_or_kwargs :param config_file: Config file that has parameters, yaml format :param **kwargs: Other alternative parameters or overwrites for conf """ with open(config_file) as con_read: yaml_config = yaml.load(con_read, Loader=yaml.FullLoader) # values from conf file are all possible params help_str = "Valid Parameters are:\n" help_str += "\n".join(list(yaml_config.keys())) # passed kwargs will override yaml conf # for key in kwargs.keys(): # assert key in yaml_config, "Parameter {} invalid!\n".format(key) # add the path of config file to dict if "config" not in kwargs: kwargs["config"] = config_file return dict(yaml_config, **kwargs) def validate_path(dir_name): """Create the directory if it doesn't exist :param dir_name :return: None """ dir_name = os.path.dirname(dir_name) # get the path if not os.path.exists(dir_name) and (dir_name != ""): os.makedirs(dir_name) def set_seed(seed=42): np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True def generate_enahnced_scp(directory: str, extension: str = "wav"): source_dir = Path(directory) spk_scp = source_dir.joinpath("spk1.scp") audio_list = [] for file_path in source_dir.rglob(f"*.{extension}"): audio_list.append(file_path) with open(spk_scp, "w") as f: for audio in audio_list: path = str(audio.resolve()) ori_filename = audio.stem spk1_id = 
ori_filename.split("-")[1] # spk2_id = ori_filename.split("_")[1].split("-")[0] curr_spk = ori_filename.split("T")[1] prefix = "s1" if curr_spk == spk1_id else "s2" f_dash_index = ori_filename.find("-") l_dash_index = ori_filename.rfind("-") filename = ori_filename[f_dash_index + 1:l_dash_index] final_filename = prefix + "/" + filename + ".wav" line = final_filename + " " + path f.write(line + "\n") def get_commandline_args(): # ported from # https://github.com/espnet/espnet/blob/master/espnet/utils/cli_utils.py extra_chars = [ " ", ";", "&", "(", ")", "|", "^", "<", ">", "?", "*", "[", "]", "$", "`", '"', "\\", "!", "{", "}", ] # Escape the extra characters for shell argv = [(arg.replace("'", "'\\''") if all( char not in arg for char in extra_chars) else "'" + arg.replace("'", "'\\''") + "'") for arg in sys.argv] return sys.executable + " " + " ".join(argv) # ported from # https://github.com/espnet/espnet/blob/master/espnet2/utils/config_argparse.py class ArgumentParser(argparse.ArgumentParser): """Simple implementation of ArgumentParser supporting config file This class is originated from https://github.com/bw2/ConfigArgParse, but this class is lack of some features that it has. - Not supporting multiple config files - Automatically adding "--config" as an option. - Not supporting any formats other than yaml - Not checking argument type """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_argument("--config", help="Give config file in yaml format") def parse_known_args(self, args=None, namespace=None): # Once parsing for setting from "--config" _args, _ = super().parse_known_args(args, namespace) if _args.config is not None: if not Path(_args.config).exists(): self.error(f"No such file: {_args.config}") with open(_args.config, "r", encoding="utf-8") as f: d = yaml.safe_load(f) if not isinstance(d, dict): self.error("Config file has non dict value: {_args.config}") for key in d: for action in self._actions: if key == action.dest: break else: self.error( f"unrecognized arguments: {key} (from {_args.config})") # NOTE(kamo): Ignore "--config" from a config file # NOTE(kamo): Unlike "configargparse", this module doesn't # check type. i.e. We can set any type value # regardless of argument type. self.set_defaults(**d) return super().parse_known_args(args, namespace) def get_layer(l_name, library=torch.nn): """Return layer object handler from library e.g. from torch.nn E.g. if l_name=="elu", returns torch.nn.ELU. Args: l_name (string): Case insensitive name for layer in library (e.g. .'elu'). library (module): Name of library/module where to search for object handler with l_name e.g. "torch.nn". Returns: layer_handler (object): handler for the requested layer e.g. 
            (torch.nn.ELU)
    """
    all_torch_layers = list(dir(torch.nn))
    match = [x for x in all_torch_layers if l_name.lower() == x.lower()]
    if len(match) == 0:
        close_matches = difflib.get_close_matches(
            l_name, [x.lower() for x in all_torch_layers])
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches))
    elif len(match) > 1:
        close_matches = difflib.get_close_matches(
            l_name, [x.lower() for x in all_torch_layers])
        raise NotImplementedError(
            "Multiple matches found for layer with name {} in {}.\n "
            "All matches: {}".format(l_name, str(library), close_matches))
    else:  # valid
        layer_handler = getattr(library, match[0])
        return layer_handler


# def spk2id(utt_spk_list):
#     _, spk_list = zip(*utt_spk_list)
#     spk_list = sorted(list(set(spk_list)))  # remove overlap and sort
#     spk2id_dict = {}
#     spk_list.sort()
#     for i, spk in enumerate(spk_list):
#         spk2id_dict[spk] = i
#     return spk2id_dict
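

# A minimal usage sketch for get_layer(): it resolves a torch.nn layer class
# by case-insensitive name and the caller instantiates it. The names used
# here ("elu", "prelu") are only examples; any attribute of torch.nn can be
# looked up the same way.
if __name__ == "__main__":
    elu_cls = get_layer("elu")  # -> torch.nn.ELU (per the docstring above)
    prelu_cls = get_layer("prelu")  # -> torch.nn.PReLU
    activation = elu_cls(alpha=1.0)
    print(elu_cls, prelu_cls, activation)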