[
  {
    "path": ".clang-format",
    "content": "---\nLanguage:        Cpp\n# BasedOnStyle:  Google\nAccessModifierOffset: -1\nAlignAfterOpenBracket: Align\nAlignConsecutiveAssignments: false\nAlignConsecutiveDeclarations: false\nAlignEscapedNewlinesLeft: true\nAlignOperands:   true\nAlignTrailingComments: true\nAllowAllParametersOfDeclarationOnNextLine: true\nAllowShortBlocksOnASingleLine: false\nAllowShortCaseLabelsOnASingleLine: false\nAllowShortFunctionsOnASingleLine: All\nAllowShortIfStatementsOnASingleLine: true\nAllowShortLoopsOnASingleLine: true\nAlwaysBreakAfterDefinitionReturnType: None\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: true\nAlwaysBreakTemplateDeclarations: true\nBinPackArguments: true\nBinPackParameters: true\nBraceWrapping:\n  AfterClass:      false\n  AfterControlStatement: false\n  AfterEnum:       false\n  AfterFunction:   false\n  AfterNamespace:  false\n  AfterObjCDeclaration: false\n  AfterStruct:     false\n  AfterUnion:      false\n  BeforeCatch:     false\n  BeforeElse:      false\n  IndentBraces:    false\nBreakBeforeBinaryOperators: None\nBreakBeforeBraces: Attach\nBreakBeforeTernaryOperators: true\nBreakConstructorInitializersBeforeComma: false\nBreakAfterJavaFieldAnnotations: false\nBreakStringLiterals: true\nColumnLimit:     80\nCommentPragmas:  '^ IWYU pragma:'\nConstructorInitializerAllOnOneLineOrOnePerLine: true\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nCpp11BracedListStyle: true\nDisableFormat:   false\nExperimentalAutoDetectBinPacking: false\nForEachMacros:   [ foreach, Q_FOREACH, BOOST_FOREACH ]\nIncludeCategories:\n  - Regex:           '^<.*\\.h>'\n    Priority:        1\n  - Regex:           '^<.*'\n    Priority:        2\n  - Regex:           '.*'\n    Priority:        3\nIncludeIsMainRegex: '([-_](test|unittest))?$'\nIndentCaseLabels: true\nIndentWidth:     2\nIndentWrappedFunctionNames: false\nJavaScriptQuotes: Leave\nJavaScriptWrapImports: true\nKeepEmptyLinesAtTheStartOfBlocks: false\nMacroBlockBegin: ''\nMacroBlockEnd:   ''\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nObjCBlockIndentWidth: 2\nObjCSpaceAfterProperty: false\nObjCSpaceBeforeProtocolList: false\nPenaltyBreakBeforeFirstCallParameter: 1\nPenaltyBreakComment: 300\nPenaltyBreakFirstLessLess: 120\nPenaltyBreakString: 1000\nPenaltyExcessCharacter: 1000000\nPenaltyReturnTypeOnItsOwnLine: 200\nPointerAlignment: Left\nReflowComments:  true\nSortIncludes:    true\nSpaceAfterCStyleCast: false\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeParens: ControlStatements\nSpaceInEmptyParentheses: false\nSpacesBeforeTrailingComments: 2\nSpacesInAngles:  false\nSpacesInContainerLiterals: true\nSpacesInCStyleCastParentheses: false\nSpacesInParentheses: false\nSpacesInSquareBrackets: false\nStandard:        Auto\nTabWidth:        8\nUseTab:          Never\n...\n"
  },
  {
    "path": ".flake8",
    "content": "[flake8]\nselect = B,C,E,F,P,T4,W,B9\nmax-line-length = 80\nmax-doc-length = 80\n# C408 ignored because we like the dict keyword argument syntax\n# E501 is not flexible enough, we're using B950 instead\nignore =\n    E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,\n    # shebang has extra meaning in fbcode lints, so I think it's not worth trying\n    # to line this up with executable bit\n    EXE001,\n    # these ignores are from flake8-bugbear; please fix!\n    B007,B008,B905,\n    # these ignores are from flake8-comprehensions; please fix!\n    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415\nexclude =\n"
  },
  {
    "path": ".github/workflows/lint.yml",
    "content": "name: Lint\n\non:\n  push:\n    branches:\n    - main\n  pull_request:\n\njobs:\n  quick-checks:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Fetch Wenet\n        uses: actions/checkout@v1\n      - name: Checkout PR tip\n        run: |\n          set -eux\n          if [[ \"${{ github.event_name }}\" == \"pull_request\" ]]; then\n            # We are on a PR, so actions/checkout leaves us on a merge commit.\n            # Check out the actual tip of the branch.\n            git checkout ${{ github.event.pull_request.head.sha }}\n          fi\n          echo ::set-output name=commit_sha::$(git rev-parse HEAD)\n        id: get_pr_tip\n      - name: Ensure no tabs\n        run: |\n          (! git grep -I -l $'\\t' -- . ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo \"The above files have tabs; please convert them to spaces\"; false))\n      - name: Ensure no trailing whitespace\n        run: |\n          (! git grep -I -n $' $' -- . ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo \"The above files have trailing whitespace; please remove them\"; false))\n\n  flake8-py3:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Setup Python\n        uses: actions/setup-python@v1\n        with:\n          python-version: 3.9\n          architecture: x64\n      - name: Fetch Wenet\n        uses: actions/checkout@v1\n      - name: Checkout PR tip\n        run: |\n          set -eux\n          if [[ \"${{ github.event_name }}\" == \"pull_request\" ]]; then\n            # We are on a PR, so actions/checkout leaves us on a merge commit.\n            # Check out the actual tip of the branch.\n            git checkout ${{ github.event.pull_request.head.sha }}\n          fi\n          echo ::set-output name=commit_sha::$(git rev-parse HEAD)\n        id: get_pr_tip\n      - name: Run flake8\n        run: |\n          set -eux\n          pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0\n          flake8 --version\n          flake8\n          if [ $? != 0 ]; then exit 1; fi\n\n  cpplint:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Setup Python\n        uses: actions/setup-python@v1\n        with:\n          python-version: 3.x\n          architecture: x64\n      - name: Fetch Wenet\n        uses: actions/checkout@v1\n      - name: Checkout PR tip\n        run: |\n          set -eux\n          if [[ \"${{ github.event_name }}\" == \"pull_request\" ]]; then\n            # We are on a PR, so actions/checkout leaves us on a merge commit.\n            # Check out the actual tip of the branch.\n            git checkout ${{ github.event.pull_request.head.sha }}\n          fi\n          echo ::set-output name=commit_sha::$(git rev-parse HEAD)\n        id: get_pr_tip\n      - name: Run cpplint\n        run: |\n          set -eux\n          pip install cpplint==1.6.1\n          cpplint --version\n          cpplint --recursive .\n          if [ $? != 0 ]; then exit 1; fi\n\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n*.egg-info\n\n# Visual Studio Code files\n.vscode\n.vs\n\n# PyCharm files\n.idea\nvenv\n\n# Eclipse Project settings\n*.*project\n.settings\n\n# Sublime Text settings\n*.sublime-workspace\n*.sublime-project\n\n# Editor temporaries\n*.swn\n*.swo\n*.swp\n*.swm\n*~\n\n# IPython notebook checkpoints\n.ipynb_checkpoints\n\n# macOS dir files\n.DS_Store\n\nexp\ndata\nraw_wav\ntensorboard\n**/*build*\nwespeaker_models\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.5.0\n    hooks:\n    - id: trailing-whitespace\n  - repo: https://github.com/pre-commit/mirrors-yapf\n    rev: 'v0.32.0'\n    hooks:\n    - id: yapf\n  - repo: https://github.com/pycqa/flake8\n    rev: '3.8.2'\n    hooks:\n    - id: flake8\n  - repo: https://github.com/pre-commit/mirrors-clang-format\n    rev: 'v17.0.6'\n    hooks:\n    - id: clang-format\n  - repo: https://github.com/cpplint/cpplint\n    rev: '1.6.1'\n    hooks:\n    - id: cpplint\n"
  },
  {
    "path": "CPPLINT.cfg",
    "content": "root=runtime\nfilter=-build/c++11\n"
  },
  {
    "path": "README.md",
    "content": "# Wesep\r\n\r\n> We aim to build a toolkit focusing on front-end processing in the cocktail party set up, including target speaker extraction and ~~speech separation (Future work)~~\r\n\r\n\r\n### Install for development & deployment\r\n* Clone this repo\r\n``` sh\r\nhttps://github.com/wenet-e2e/wesep.git\r\n```\r\n\r\n* Create conda env: pytorch version >= 1.12.0 is required !!!\r\n``` sh\r\nconda create -n wesep python=3.9\r\nconda activate wesep\r\nconda install pytorch=1.12.1 torchaudio=0.12.1 cudatoolkit=11.3 -c pytorch -c conda-forge\r\npip install -r requirements.txt\r\npre-commit install  # for clean and tidy code\r\n```\r\n\r\n## The Target Speaker Extraction Task\r\n\r\n> Target speaker extraction (TSE) focuses on isolating the speech of a specific target speaker from overlapped multi-talker speech, which is a typical setup in the cocktail party problem.\r\nWeSep is featured with flexible target speaker modeling, scalable data management, effective on-the-fly data simulation, structured recipes and deployment support.\r\n\r\n<img src=\"resources/tse.png\" width=\"600px\">\r\n\r\n## Features (To Do List)\r\n\r\n- [x] On the fly data simulation\r\n  - [x] Dynamic Mixture simulation\r\n  - [x] Dynamic Reverb simulation\r\n  - [x] Dynamic Noise simulation\r\n- [x] Support time- and frequency- domain models\r\n    - Time-domain\r\n        - [x] conv-tasnet based models\r\n            - [x] Spex+\r\n    - Frequency domain\r\n        - [x] pBSRNN\r\n        - [x] pDPCCN\r\n        - [x] tf-gridnet (Extremely slow, need double check)\r\n- [ ] Training Criteria\r\n    - [x] SISNR loss\r\n    - [x] GAN loss  (Need further investigation)\r\n- [ ] Datasets\r\n  - [x] Libri2Mix (Illustration for pre-mixed speech)\r\n  - [x] VoxCeleb (Illustration for online training)\r\n  - [ ] WSJ0-2Mix\r\n- [ ] Speaker Embedding\r\n  - [x] Wespeaker Intergration\r\n  - [x] Joint Learned Speaker Embedding\r\n  - [x] Different fusion methods\r\n- [ ] Pretrained models\r\n- [ ] CLI Usage\r\n- [x] Runtime\r\n\r\n## Data Pipe Design\r\n\r\nFollowing Wenet and Wespeaker, WeSep organizes the data processing modules as a pipeline of a set of different processors. The following figure shows such a pipeline with essential processors.\r\n\r\n<img src=\"resources/datapipe.png\" width=\"800px\">\r\n\r\n## Discussion\r\n\r\nFor Chinese users, you can scan the QR code on the left to join our group directly. If it has expired, please scan the personal Wechat QR code on the right.\r\n\r\n|<img src='resources/Wechat_group.jpg' style=\" width: 200px; height: 300px;\">|<img src='resources/Wechat.jpg' style=\" width: 200px; height: 300px;\">|\r\n| ---- | ---- |\r\n\r\n\r\n\r\n## Citations\r\nIf you find wespeaker useful, please cite it as\r\n\r\n```bibtex\r\n@inproceedings{wang24fa_interspeech,\r\n  title     = {WeSep: A Scalable and Flexible Toolkit Towards Generalizable Target Speaker Extraction},\r\n  author    = {Shuai Wang and Ke Zhang and Shaoxiong Lin and Junjie Li and Xuefei Wang and Meng Ge and Jianwei Yu and Yanmin Qian and Haizhou Li},\r\n  year      = {2024},\r\n  booktitle = {Interspeech 2024},\r\n  pages     = {4273--4277},\r\n  doi       = {10.21437/Interspeech.2024-1840},\r\n}\r\n```\r\n"
  },
  {
    "path": "examples/librimix/tse/README.md",
    "content": "# Libri2Mix Recipe\n\n\n## Goal of this recipe\nThis recipe aims to illustrate how to use WeSep to perform the target speaker extraction task on a pre-defined training set such as Libri2Mix, the mixtures have been prepared on the disk. If you want to check the online data processing and scale to larger training set, please check the voxceleb1 recipe.\n\n## Difference of V1 and V2\nThe difference between v1 and v2 lies in the approach to speaker modeling.\n\n- v1 outlines a process where WeSpeaker is used to extract embeddings beforehand, which are then saved to disk and sampled during training. This setup allows you to use other toolkits or existing speaker embeddings freely.\n- v2 demonstrates a more integrated approach with the WeSpeaker toolkit (recommended). You can choose to either fix the speaker encoder or train it jointly with the separation backbone.\n\n"
  },
  {
    "path": "examples/librimix/tse/v1/README.md",
    "content": "## Tutorial on LibriMix\n\nIf you meet any problems when going through this tutorial, please feel free to ask in github issues. Thanks for any kind of feedback.\n\nNOTE: WE DON'T RECOMMEND THIS VERSION, IT'S JUST FOR ILLUSTRATING HOW TO USE YOUR OWN EXTRACTOR\n\nYOU NEED TO INSTALL WESPEAKER FIRST, CHECK `https://github.com/wenet-e2e/wespeaker` FOR THE INSTRUCTION\n\n\n### First Experiment\n\nWe provide a recipe `examples/librimix/tse/v1/run.sh` on LibriMix data.\n\nThe recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process.\n\n```bash\ncd examples/librimix/tse/v1\nbash run.sh --stage 1 --stop_stage 1\nbash run.sh --stage 2 --stop_stage 2\nbash run.sh --stage 3 --stop_stage 3\nbash run.sh --stage 4 --stop_stage 4\nbash run.sh --stage 5 --stop_stage 5\nbash run.sh --stage 6 --stop_stage 6\n```\n\nYou could also just run the whole script\n```bash\nbash run.sh --stage 1 --stop_stage 6\n```\n\n------\n\n### Stage 1: Prepare Training Data\n\nPrior to executing this phase, we assume that you have locally stored or can access the LibriMix dataset and you should assign the data path to `Libri2Mix_dir`.\n\nAs the LibriMix dataset is available in multiple versions, each determined by factors like the number of sources in the mixtures and the sampling rate, you can choose the desired version by adjusting the following variables in `run.sh`:\n\n+ `fs`: the sample rate of the dataset, valid options are `16k` and `8k`.\n+ `min_max`: the mode of mixtures, valiad options are `min` and `max`.\n+ `noise_type`: the type of mixture, valiad options are `clean` and `both`.\n\nIn our recipe, we opt for the Libri2Mix data with a sampling rate of 16kHz, in 'min' mode, and without noise, thus configuring as follows:\n\n``` bash\nfs=16k\nmin_max=min\nnoise_type=\"clean\"\nLibri2Mix_dir=/path/to/Libri2Mix\n```\n\nAfter configuring the desired dataset version, running the script for the first phase will generate the prepared data files. By default, these files are stored in the `data` directory in the current location. However, you can customize the `data` variable in `enh.sh` to save these files in any desired location.\n\n```bash\ndata=data # you can change this to any directory\n```\n\nIn this stage, `local/prepare_data.sh`accomplishes three tasks:\n\n1. 
Organizes the original Libri2Mix dataset into three directoies `dev`, `test` and `train_100`, each containing the following files:\n\n    + `single.utt2spk`: each line records two space-separated columns: `clean_wav_id` and `speaker_id`\n\n        ```text\n        s1/103-1240-0003_1235-135887-0017.wav 103\n        s1/103-1240-0004_4195-186237-0003.wav 103\n        ...\n        ```\n\n    + `utt2spk`: each line records three space-separated columns: `mixture_wav_id`, `speaker1_id` and `speaker2_id`.\n\n        ```\n        103-1240-0003_1235-135887-0017 103 1235\n        103-1240-0004_4195-186237-0003 103 4195\n        ...\n        ```\n\n    + `single.wav.scp`: each line records two space-separated columns: `clean_wav_id` and `clean_wav_path`\n\n        ```\n        s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0003_1235-135887-0017.wav\n        s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0004_4195-186237-0003.wav\n        ...\n        ```\n\n    + `wav.scp`: each line records four space-separated columns: `mixture_wav_id`, `mixtrue_wav_path`, `clean_wav1_path` and `clean_wav2_path`.\n\n        ```\n        103-1240-0003_1235-135887-0017 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0003_1235-135887-0017.wav\n        103-1240-0004_4195-186237-0003 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0004_4195-186237-0003.wav\n        ...\n        ```\n\n2. Prepare the speaker embeddings using wespeaker pretrained models. This step will generate two files in the `dev`, `test`, and `train_100` directories respectively:\n\n    + `embed.ark`: Kaldi ark file that stores the speaker embeddings.\n\n    + `embed.scp`: each line records two space-separated columns: `clean_wav_id` and `spk_embed_path`\n\n        ```\n        s1/103-1240-0003_1235-135887-0017.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:1450569\n        s1/103-1240-0004_4195-186237-0003.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:10622715\n        ...\n        ```\n\n3. Prepare LibriMix target-speaker enroll signal. 
This step will generate four files in the `dev` and `test` directories respectively:\n\n    + `mixture2enrollment`: each line records three space-separated columns: `mixture_wav_id`, `clean_wav_id` and `enrollment_wav_id`.\n\n        ```\n        4077-13754-0001_5142-33396-0065 4077-13754-0001 s1/4077-13754-0004_5142-36377-0020\n        4077-13754-0001_5142-33396-0065 5142-33396-0065 s1/5142-36377-0003_1320-122612-0014\n        ...\n        ```\n\n    + `spk1.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.\n\n        ```\n        1272-128104-0000_2035-147961-0014 s1/1272-135031-0015_2277-149896-0006.wav\n        1272-128104-0003_2035-147961-0016 s1/1272-135031-0013_1988-147956-0016.wav\n        ...\n        ```\n\n    + `spk2.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.\n\n        ```\n        1272-128104-0000_2035-147961-0014 s1/2035-152373-0009_3000-15664-0016.wav\n        1272-128104-0003_2035-147961-0016 s2/6313-66129-0013_2035-152373-0012.wav\n        ...\n        ```\n\n    + `spk2enroll.json`: A JSON file, where the format of the stored key-value pairs is `{spk_id: [[spk_id_with_prefix_or_suffix, wav_path], ...]}`.\n\n        ```\n        \"652\": [[\"652-129742-0010\", \"/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0010_3081-166546-0071.wav\"],\n        ...,\n        [\"652-129742-0000\", \"/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0000_1993-147966-0004.wav\"]],\n        ...\n        ```\nAt the end of this stage, the directory structure of `data` should look like this:\n\n```\ndata/\n|__ clean/ # the noise_type you chose\n    |__ dev/\n    |   |__ embed.ark\n    |   |__ embed.scp\n    |   |__ mixture2enrollment\n    |   |__ single.utt2spk\n    |   |__ single.wav.scp\n    |   |__ spk1.enroll\n    |   |__ spk2.enroll\n    |   |__ spk2enroll.json # empty\n    |   |__ utt2spk\n    |   |__ wav.scp\n    |\n    |__ test/ # the same as dev/\n    |\n    |__ train_100/\n        |__ embed.ark\n        |__ embed.scp\n        |__ single.utt2spk\n        |__ single.wav.scp\n        |__ utt2spk\n        |__ wav.scp\n```\n\n------\n\n### Stage 2: Convert Data Format\n\nThis stage involves transforming the data into `shard` format, which is better suited for large datasets. Its core idea is to make the audio and labels of multiple small data(such as 1000 pieces), into compressed packets (tar) and read them based on the IterableDataset of Pytorch. For a detailed explanation of the  `shard` format, please refer to the [documentation](https://github.com/wenet-e2e/wenet/blob/main/docs/UIO.md) available in Wenet.\n\nThis stage will generate a subdirectory and a file in the `dev`, `test`, and `train_100` directories respectively:\n\n+ `shards/`: this directory stores the compressed packets (tar) files.\n\n    ```bash\n    ls shards\n    shards_000000000.tar  shards_000000001.tar  shards_000000002.tar ...\n    ```\n\n+ `shard.list`: each line records the path to the corresponding tar file.\n\n    ```\n    data/clean/dev/shards/shards_000000000.tar\n    data/clean/dev/shards/shards_000000001.tar\n    data/clean/dev/shards/shards_000000002.tar\n    ...\n    ```\n\nAt the end of this stage, the directory structure of `data` should look like this:\n\n```\ndata/\n|__ clean/ # the noise_type you chose\n    |__ dev/\n    |   |__ embed.ark, embed.scp, ... 
# files generated by Stage 1\n    |   |__ shard.list\n    |   |__ shards/\n    |       |__ shards_000000000.tar\n    |       |__ shards_000000001.tar\n    |       |__ shards_000000002.tar\n    |\n    |__ test/ # the same as dev/\n    |\n    |__ train_100/\n        |__ embed.ark, embed.scp, ... # files generated by Stage 1\n        |__ shard.list\n        |__ shards/\n            |__ shards_000000000.tar\n            |__ ...\n            |__ shards_000000013.tar\n```\n\n------\n\n### Stage 3: Neural Network Training\n\nYou can configure training-related parameters through the configuration file. We provide some ready-to-use configuration files in the recipe. If you wish to write your own configuration files or understand the meaning of certain parameters in the configuration files, you can refer to the following information:\n\n+ **overall training process related**\n\n    ```yaml\n    seed: 42\n    exp_dir: exp/BSRNN\n    enable_amp: false\n    gpus: '0,1'\n    log_batch_interval: 100\n    save_epoch_interval: 1\n    joint_training: false\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `seed`: specify a random seed.\n    + `exp_dir`: specify the experiment directory.\n    + `enable_amp`: whether to enable automatic mixed precision.\n    + `gpus`: specify the visible GPUs during training.\n    + `log_batch_interval`: specify after how many batch iterations to record in the log.\n    + `save_epoch_interval`: specify after how many epochs to save a checkpoint.\n    + `joint_training`: specify whether the model for extracting speaker embeddings is jointly trained with the TSE model. Defaults to `false`.\n\n+ **dataset and dataloader related**\n\n    ```yaml\n    dataset_args:\n      resample_rate: 16000\n      sample_num_per_epoch: 0\n      shuffle: true\n      shuffle_args:\n        shuffle_size: 2500\n      whole_utt: false\n      chunk_len: 48000\n      online_mix: false\n      data_type: \"shard\"\n      train_data: \"data/clean/train_100/shard.list\"\n      train_spk_embeds: \"data/clean/train_100/embed.scp\"\n      train_utt2spk: \"data/clean/train_100/single.utt2spk\"\n      train_spk2utt: \"data/clean/train_100/spk2enroll.json\"\n      val_data: \"data/clean/dev/shard.list\"\n      val_spk_embeds: \"data/clean/dev/embed.scp\"\n      val_utt2spk: \"data/clean/dev/single.utt2spk\"\n      val_spk1_enroll: \"data/clean/dev/spk1.enroll\"\n      val_spk2_enroll: \"data/clean/dev/spk2.enroll\"\n      val_spk2utt: \"data/clean/dev/single.wav.scp\"\n\n\n    dataloader_args:\n      batch_size: 16  # A800\n      drop_last: true\n      num_workers: 6\n      pin_memory: false\n      prefetch_factor: 6\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `resample_rate`: All audio in the dataset will be resampled to this specified sample rate. Defaults to `16000`.\n    + `sample_num_per_epoch`: Specifies how many samples from the full training set will be iterated over in each epoch during training. The default is `0`, which means iterating over the entire training set.\n    + `shuffle`: Whether to perform *global* shuffle, i.e., shuffling at the shards tar/raw/feat file level. Defaults to `true`.\n    + `shuffle_size`: Parameter related to *local* shuffle. Local shuffle maintains a buffer, and shuffling is only performed when the number of data items in the buffer reaches the `shuffle_size`. Defaults to `2500`.\n    + `whole_utt`: Whether the network input and training target are the entire audio segment. Defaults to `false`.\n
    + `chunk_len`: This parameter only takes effect when `whole_utt` is set to `false`. It indicates the length of the segment to be extracted from the complete audio as the network input and training target. Defaults to `48000`.\n    + `online_mix`: Whether to dynamically mix speakers when loading data; `shuffle` will not take effect if this parameter is set to `true`. Defaults to `false`.\n    + `data_type`: Specify the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.\n    + `train_data`: File containing paths to the training set files.\n    + `train_spk_embeds`: File containing paths to the speaker embeddings of the training set.\n    + `train_utt2spk`: Each line of the file specified by this parameter consists of `clean_wav_id` and `speaker_id`, separated by a space (e.g. `single.utt2spk` generated in Stage 1).\n    + `train_spk2utt`: The file specified by this parameter is only used when the `joint_training` parameter is set to `true`. Each line of the file contains `speaker_id` and `enrollment_wav_id`.\n    + `val_data`: File containing paths to the validation set files.\n    + `val_spk_embeds`: Similar to `train_spk_embeds`.\n    + `val_utt2spk`: Similar to `train_utt2spk`.\n    + `val_spk1_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker1_enrollment_wav_id`, separated by a space.\n    + `val_spk2_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker2_enrollment_wav_id`, separated by a space.\n    + `val_spk2utt`: Each line of the file specified by this parameter consists of `clean_wav_id` and `clean_wav_path`, separated by a space (e.g. `single.wav.scp` generated in Stage 1).\n        + We have denoted this parameter as `val_spk2utt`, but it is actually assigned the `single.wav.scp` file as its value. This might be perplexing for users familiar with file formats in Kaldi or ESPnet, where the `spk2utt` file typically consists of lines containing `spk_id` and `wav_id`, whereas the `wav.scp` file's lines contain `wav_id` and `wav_path`.\n        + Nevertheless, upon closer examination of its role in subsequent procedures, it becomes evident that it is indeed employed to create a dictionary mapping speaker IDs to audio samples (see the sketch after this list).\n    + `batch_size`: how many samples per batch to load. Please note that the batch size mentioned here refers to the **batch size per GPU**. So, if you are training on two GPUs within a single node and set the batch size to 16, it is equivalent to setting the batch size to 32 in a single-GPU, single-node scenario.\n    + `drop_last`: set to `true` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If `false` and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.\n    + `num_workers`: how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process.\n    + `pin_memory`: If `true`, the data loader will copy Tensors into device/CUDA pinned memory before returning them.\n    + `prefetch_factor`: number of batches loaded in advance by each worker.\n\n
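    As noted under `val_spk2utt`, this file is ultimately turned into a dictionary mapping each speaker ID to candidate audio samples. Below is only an illustrative sketch of how such a mapping could be built from the Stage 1 meta files (`single.utt2spk` and `single.wav.scp`); it is not WeSep's internal implementation, and the helper name is hypothetical:\n\n    ```python\n    from collections import defaultdict\n\n    def build_spk2utt(utt2spk_path, wav_scp_path):\n        '''Build {speaker_id: [wav_path, ...]} from the Stage 1 meta files.'''\n        # single.wav.scp: '<clean_wav_id> <clean_wav_path>'\n        utt2path = {}\n        with open(wav_scp_path, encoding='utf-8') as f:\n            for line in f:\n                utt, path = line.strip().split(maxsplit=1)\n                utt2path[utt] = path\n        # single.utt2spk: '<clean_wav_id> <speaker_id>'\n        spk2utt = defaultdict(list)\n        with open(utt2spk_path, encoding='utf-8') as f:\n            for line in f:\n                utt, spk = line.strip().split()\n                if utt in utt2path:\n                    spk2utt[spk].append(utt2path[utt])\n        return spk2utt\n    ```\n\n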
+ **loss function related**\n\n    ```yaml\n    loss: SISDR\n    loss_args: { }\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `loss`: the loss function used for training.\n    + `loss_args`: the required arguments for the loss function.\n\n    In addition to some common loss functions, we also support the use of GAN loss. You can enable this feature by setting `use_gan_loss` to `true` in `run.sh`. Once enabled, the TSE model serves as the generator, and another convolutional neural network acts as the discriminator, engaging in adversarial training. The final loss of the TSE model is a combination of the losses specified in the configuration file and the GAN loss. By default, the weight for the former is set to `0.95`, while the latter is set to `0.05`.\n\n    Due to the compatibility with GAN loss, the parameters mentioned below often differentiate between `tse_model` and `discriminator` under a single parameter. In such cases, we no longer provide separate explanations for each parameter.\n\n+ **neural network structure related**\n\n    ```yaml\n    model:\n      tse_model: BSRNN\n    model_args:\n      tse_model:\n        sr: 16000\n        win: 512\n        stride: 128\n        feature_dim: 128\n        num_repeat: 6\n        spk_emb_dim: 256\n        spk_fuse_type: 'multiply'\n        use_spk_transform: False\n\n    model_init:\n      tse_model: exp/BSRNN/no_spk_transform-multiply_fuse/models/latest_checkpoint.pt\n      discriminator: null\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `model`: specify the neural network used for training.\n    + `model_args`: specify model-specific parameters.\n    + `model_init`: whether to initialize the model with an existing checkpoint. Use `null` for no initialization. If you want to initialize, provide the checkpoint path. Defaults to `null`.\n\n+ **model optimization related**\n\n    ```yaml\n    num_epochs: 150\n    clip_grad: 5.0\n\n    optimizer:\n      tse_model: Adam\n    optimizer_args:\n      tse_model:\n        lr: 0.001\n        weight_decay: 0.0001\n\n    scheduler:\n      tse_model: ExponentialDecrease\n    scheduler_args:\n      tse_model:\n        final_lr: 2.5e-05\n        initial_lr: 0.001\n        warm_from_zero: false\n        warm_up_epoch: 0\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `num_epochs`: total number of training epochs.\n    + `clip_grad`: set the threshold for gradient clipping.\n    + `optimizer`: set the optimizer.\n    + `optimizer_args`: the required arguments for the optimizer.\n    + `scheduler`: set the scheduler.\n    + `scheduler_args`: the required arguments for the scheduler.\n\n+ **others**\n\n    ```yaml\n    num_avg: 2\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `num_avg`: the number of checkpoints used for model averaging.\n\n
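    `num_avg` is consumed by the model-averaging step in Stage 6 (`wesep/bin/average_model.py`). Conceptually, averaging takes the element-wise mean of the parameters of the selected checkpoints. The following is a minimal, hypothetical sketch of such weight averaging; it assumes each checkpoint file stores a plain `state_dict` and is not the project's actual script:\n\n    ```python\n    import torch\n\n    def average_checkpoints(checkpoint_paths):\n        '''Element-wise mean of the parameters of several checkpoints.'''\n        # Assumption: every checkpoint is a plain {name: tensor} state_dict.\n        state_dicts = [torch.load(p, map_location='cpu') for p in checkpoint_paths]\n        averaged = {}\n        for name in state_dicts[0]:\n            # Cast to float so integer buffers can also be averaged.\n            stacked = torch.stack([sd[name].float() for sd in state_dicts])\n            averaged[name] = stacked.mean(dim=0)\n        return averaged\n    ```\n\n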
To avoid frequent changes to the configuration file, we support **overwriting values in the configuration file** directly within `run.sh`. For example, running the following command in `run.sh` will overwrite the visible GPU from `'0,1'` to `'0'` in the above configuration file:\n\n```bash\n  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \\\n    ${train_script} --config confs/config.yaml \\\n    --gpus \"[0]\" \\\n```\n\nAt the end of this stage, an experiment directory will be created in the current directory, containing the following files:\n\n```\n${exp_dir}/\n|__ train.log\n|__ config.yaml\n|__ models/\n    |__ checkpoint_1.pt\n    |__ ...\n    |__ checkpoint_150.pt\n    |__ final_checkpoint.pt -> checkpoint_150.pt\n    |__ latest_checkpoint.pt -> checkpoint_150.pt\n```\n\n------\n\n### Stage 4: Extract Speech Using the Trained Model\n\nAfter training is complete, you can execute stage 4 to extract the target speaker's speech using the trained model. This stage mainly calls `wesep/bin/infer.py`, and you need to provide the following parameters for this script:\n\n+ `config`: the configuration file used in Stage 3.\n+ `fs`: the sample rate of the audio data.\n+ `gpus`: the index of the visible GPU.\n+ `exp_dir`: the experiment directory.\n+ `data_type`: the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.\n+ `test_data`: similar to `train_data`.\n+ `test_spk_embeds`: similar to `train_spk_embeds`.\n+ `test_spk1_enroll`: similar to `dev_spk1_enroll`.\n+ `test_spk2_enroll`: similar to `dev_spk2_enroll`.\n+ `test_spk2utt`: similar to `dev_spk2utt`.\n+ `checkpoint`: the path to the checkpoint used for extracting the target speaker's speech.\n\nAt the end of this stage, the structure of the experiment directory should look like this:\n\n```\n${exp_dir}/\n|__ train.log\n|__ config.yaml\n|__ models/\n|__ infer.log\n|__ audio/\n    |__ spk1.scp # each line records two space-separated columns: `target_wav_id` and `target_wav_path`\n    |__ Utt1001-4992-41806-0008_6930-75918-0015-T4992.wav\n    |__ ...\n    |__ Utt999-61-70968-0003_2830-3980-0008-T61.wav\n```\n\n------\n\n### Stage 5: Scoring\n\nIn this stage, we evaluate the quality of the generated speech using common objective metrics. The default metrics include **STOI**, **SDR**, **SAR**, **SIR**, and **SI_SNR**. In addition to these metrics, you can also include **PESQ** and **DNS_MOS** by setting the values of `use_pesq` and `use_dnsmos` to `true`. Please be aware that DNS_MOS is exclusively supported for audio samples with a **16 kHz** sampling rate. For audio with different sampling rates, refrain from employing DNS_MOS for assessment.\n\nAt the end of this stage, a markdown file `RESULTS.md` will be created under the `exp` directory, and the directory structure of `exp` should look like this:\n\n```\nexp/BSRNN/\n|__ ${exp_dir}\n|    |__ train.log, ... # files and directories generated in Stage 4\n|    |__ scoring/\n|\n|__ RESULTS.md\n```\n\n
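If you want to sanity-check a score such as SI-SNR on a single pair of signals, the following is a minimal NumPy sketch of the standard scale-invariant SNR definition. The recipe's actual scoring is done by `tools/score.sh`; this snippet is only an illustration:\n\n```python\nimport numpy as np\n\ndef si_snr(est, ref, eps=1e-8):\n    '''Scale-invariant SNR (in dB) between an estimated and a reference signal.'''\n    est = est - np.mean(est)\n    ref = ref - np.mean(ref)\n    # Project the estimate onto the reference to get the target component.\n    target = np.dot(est, ref) / (np.dot(ref, ref) + eps) * ref\n    noise = est - target\n    return 10 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))\n```\n\n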
------\n\n### Stage 6: Apply Model Average\n\nIn this stage, we perform model averaging, and you need to specify the following parameters in `run.sh`:\n\n+ `dst_model`: the path to save the averaged model.\n+ `src_path`: the path to the source models used for averaging.\n+ `num`: the number of source models used for the averaged model.\n+ `mode`: the mode for model averaging. Valid options are `final` and `best`.\n    + `final`: filters and sorts the latest PyTorch model files in the source directory. Averages the states of the last `num` models based on a numerical sorting of their filenames.\n    + `best`: directly uses user-specified epochs to select specific model checkpoint files. Averages the states of these selected models.\n+ `epochs`: this parameter only takes effect when `mode` is set to `best` and is used to specify the epoch indices of the checkpoints that will be used as source models.\n"
  },
  {
    "path": "examples/librimix/tse/v1/confs/bsrnn.yaml",
    "content": "dataloader_args:\n  batch_size: 16  # A800: 16\n  drop_last: true\n  num_workers: 6\n  pin_memory: false\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n\nenable_amp: false\nexp_dir: exp/BSRNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISDR\nloss_args: { }\n\nmodel:\n  tse_model: BSRNN\nmodel_args:\n  tse_model:\n    sr: 16000\n    win: 512\n    stride: 128\n    feature_dim: 128\n    num_repeat: 6\n    spk_emb_dim: 256\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n    multi_fuse: False        # Multi speaker fuse with seperation modules\n    joint_training: False\n\nmodel_init:\n  tse_model: null\n  discriminator: null\nnum_avg: 2\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001\n    weight_decay: 0.0001\n\nclip_grad: 5.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v1/confs/dpcc_init_gan.yaml",
    "content": "use_metric_loss: true\n\ndataloader_args:\n  batch_size: 4\n  drop_last: true\n  num_workers: 4\n  pin_memory: false\n  prefetch_factor: 4\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n\nenable_amp: false\nexp_dir: exp/DPCNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISNR\nloss_args: { }\ngan_loss_weight: 0.05\n\nmodel:\n  tse_model: DPCCN\n  discriminator: CMGAN_Discriminator\nmodel_args:\n  tse_model:\n    win: 512\n    stride: 128\n    feature_dim: 257\n    tcn_blocks: 10\n    tcn_layers: 2\n    spk_emb_dim: 256\n    causal: False\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n  discriminator: {}\n\nmodel_init:\n  tse_model: exp/DPCCN/no_spk_transform-multiply_fuse/models/final_model.pt\n  discriminator: null\nnum_avg: 5\nnum_epochs: 50\n\noptimizer:\n  tse_model: Adam\n  discriminator: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.0001\n    weight_decay: 0.0001\n  discriminator:\n    lr: 0.001\n    weight_decay: 0.0001\n\nclip_grad: 3.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\n  discriminator: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.0001\n    warm_from_zero: false\n    warm_up_epoch: 0\n  discriminator:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v1/confs/dpccn.yaml",
    "content": "dataloader_args:\n  batch_size: 4\n  drop_last: true\n  num_workers: 4\n  pin_memory: false\n  prefetch_factor: 4\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n\nenable_amp: false\nexp_dir: exp/DPCNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISDR\nloss_args: { }\n\nmodel:\n  tse_model: DPCCN\nmodel_args:\n  tse_model:\n    win: 512\n    stride: 128\n    feature_dim: 257\n    tcn_blocks: 10\n    tcn_layers: 2\n    spk_emb_dim: 256\n    causal: False\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n    joint_training: False\n\nmodel_init:\n  tse_model: null\n  discriminator: null\nnum_avg: 5\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001\n    weight_decay: 0.0001\n\nclip_grad: 3.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v1/confs/tfgridnet.yaml",
    "content": "dataloader_args:\n  batch_size: 4\n  drop_last: true\n  num_workers: 4\n  pin_memory: false\n  prefetch_factor: 4\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 16000\n\nenable_amp: false\nexp_dir: exp/TFGridNet\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SI_SNR\nloss_args: { }\n\nmodel:\n  tse_model: TFGridNet\nmodel_args:\n  tse_model:\n    n_srcs: 1\n    n_fft: 128\n    stride: 64\n    window: \"hann\"\n    n_imics: 1\n    n_layers: 6\n    lstm_hidden_units: 192\n    attn_n_head: 4\n    attn_approx_qk_dim: 512\n    emb_dim: 128\n    emb_ks: 1\n    emb_hs: 1\n    activation: \"prelu\"\n    eps: 1.0e-5\n    spk_emb_dim: 256\n    use_spk_transform: False\n    spk_fuse_type: \"multiply\"\n    joint_training: False\n\n\nmodel_init:\n  tse_model: null\nnum_avg: 2\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001\n    weight_decay: 0.0001\n\nclip_grad: 5.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v1/local/prepare_data.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)\n\nstage=-1\nstop_stage=-1\n\nmix_data_path='/Data/Libri2Mix/wav16k/min/'\n\ndata=data\nnoise_type=clean\nnum_spk=2\n\n. tools/parse_options.sh || exit 1\n\ndata=$(realpath ${data})\n\nif [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then\n  echo \"Prepare the meta files for the datasets\"\n\n  for dataset in dev test train-100; do\n    echo \"Preparing files for\" $dataset\n\n    # Prepare the meta data for the mixed data\n    dataset_path=$mix_data_path/$dataset/mix_${noise_type}\n    mkdir -p \"${data}\"/$noise_type/${dataset}\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print $NF}' |\n      awk -v path=\"${dataset_path}\" '{print $1 , path \"/\" $1 , path \"/../s1/\" $1 , path \"/../s2/\" $1}' |\n      sed 's#.wav##' | sort -k1,1 >\"${data}\"/$noise_type/${dataset}/wav.scp\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/wav.scp |\n      awk -F[_-] '{print $0, $1,$4}' >\"${data}\"/$noise_type/${dataset}/utt2spk\n\n    # Prepare the meta data for single speakers\n    dataset_path=$mix_data_path/$dataset/s1\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print \"s1/\" $NF, $0}' | sort -k1,1 >\"${data}\"/$noise_type/${dataset}/single.wav.scp\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/single.wav.scp | grep 's1' |\n      awk -F[-_/] '{print $0, $2}' >\"${data}\"/$noise_type/${dataset}/single.utt2spk\n\n    dataset_path=$mix_data_path/$dataset/s2\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print \"s2/\" $NF, $0}' | sort -k1,1 >>\"${data}\"/$noise_type/${dataset}/single.wav.scp\n\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/single.wav.scp | grep 's2' |\n      awk -F[-_/] '{print $0, $5}' >>\"${data}\"/$noise_type/${dataset}/single.utt2spk\n  done\nfi\n\nif [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then\n  echo \"Prepare the speaker embeddings using wespeaker pretrained models\"\n  mkdir wespeaker_resnet34\n  wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.zip\n  unzip voxceleb_resnet34_LM.zip -d wespeaker_resnet34\n  mv wespeaker_resnet34/voxceleb_resnet34_LM.yaml wespeaker_resnet34/config.yaml\n  mv wespeaker_resnet34/voxceleb_resnet34_LM.pt wespeaker_resnet34/avg_model.pt\n  for dataset in dev test train-100; do\n    mkdir -p \"${data}\"/$noise_type/${dataset}\n    echo \"Preparing files for\" $dataset\n    wespeaker --task embedding_kaldi \\\n              --wav_scp \"${data}\"/$noise_type/${dataset}/single.wav.scp \\\n              --output_file \"${data}\"/$noise_type/${dataset}/embed \\\n              -p wespeaker_resnet34 \\\n              --device cuda:0 # GPU idx\n  done\nfi\n\nif [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then\n  echo \"stage 3: Prepare LibriMix target-speaker enroll signal\"\n\n  for dset in dev test train-100; do\n    python local/prepare_spk2enroll_librispeech.py \\\n      \"${mix_data_path}/${dset}\" \\\n      --is_librimix True \\\n      --outfile \"${data}\"/$noise_type/${dset}/spk2enroll.json \\\n      --audio_format wav\n  done\n\n  for dset in dev test; do\n    if [ $num_spk -eq 2 ]; then\n      url=\"https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment\"\n    else\n      url=\"https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment\"\n    fi\n\n    
output_file=\"${data}/${noise_type}/${dset}/mixture2enrollment\"\n    wget -O \"$output_file\" \"$url\"\n  done\n\n  for dset in dev test; do\n    python local/prepare_librimix_enroll.py \\\n      \"${data}\"/$noise_type/${dset}/wav.scp \\\n      \"${data}\"/$noise_type/${dset}/spk2enroll.json \\\n      --mix2enroll \"${data}/${noise_type}/${dset}/mixture2enrollment\" \\\n      --num_spk ${num_spk} \\\n      --train False \\\n      --output_dir \"${data}\"/${noise_type}/${dset} \\\n      --outfile_prefix \"spk\"\n  done\nfi\n"
  },
  {
    "path": "examples/librimix/tse/v1/local/prepare_librimix_enroll.py",
    "content": "import json\nimport random\nfrom pathlib import Path\n\nfrom wesep.utils.datadir_writer import DatadirWriter\nfrom wesep.utils.utils import str2bool\n\n\ndef prepare_librimix_enroll(wav_scp,\n                            spk2utts,\n                            output_dir,\n                            num_spk=2,\n                            train=True,\n                            prefix=\"enroll_spk\"):\n    mixtures = []\n    with Path(wav_scp).open(\"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            if not line.strip():\n                continue\n            mixtureID = line.strip().split(maxsplit=1)[0]\n            mixtures.append(mixtureID)\n\n    with Path(spk2utts).open(\"r\", encoding=\"utf-8\") as f:\n        # {spkID: [(uid1, path1), (uid2, path2), ...]}\n        spk2utt = json.load(f)\n\n    with DatadirWriter(Path(output_dir)) as writer:\n        for mixtureID in mixtures:\n            # 100-121669-0004_3180-138043-0053\n            uttIDs = mixtureID.split(\"_\")\n            for spk in range(num_spk):\n                uttID = uttIDs[spk]\n                spkID = uttID.split(\"-\")[0]\n                if train:\n                    # For training, we choose the auxiliary signal on the fly.\n                    # Here we use the pattern f\"*{uttID} {spkID}\".\n                    writer[f\"{prefix}{spk + 1}.enroll\"][\n                        mixtureID] = f\"*{uttID} {spkID}\"\n                else:\n                    enrollID = random.choice(spk2utt[spkID])[1]\n                    while enrollID == uttID and len(spk2utt[spkID]) > 1:\n                        enrollID = random.choice(spk2utt[spkID])[1]\n                    writer[f\"{prefix}{spk + 1}.enroll\"][mixtureID] = enrollID\n\n\ndef prepare_librimix_enroll_v2(wav_scp,\n                               map_mix2enroll,\n                               output_dir,\n                               num_spk=2,\n                               prefix=\"spk\"):\n    # noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py\n    mixtures = []\n    with Path(wav_scp).open(\"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            if not line.strip():\n                continue\n            mixtureID = line.strip().split(maxsplit=1)[0]\n            mixtures.append(mixtureID)\n\n    mix2enroll = {}\n    with open(map_mix2enroll) as f:\n        for line in f:\n            mix_id, utt_id, enroll_id = line.strip().split()\n            sid = mix_id.split(\"_\").index(utt_id) + 1\n            mix2enroll[mix_id, f\"s{sid}\"] = enroll_id\n\n    with DatadirWriter(Path(output_dir)) as writer:\n        for mixtureID in mixtures:\n            # 100-121669-0004_3180-138043-0053\n            for spk in range(num_spk):\n                enroll_id = mix2enroll[mixtureID, f\"s{spk + 1}\"]\n                writer[f\"{prefix}{spk + 1}.enroll\"][mixtureID] = (enroll_id +\n                                                                  \".wav\")\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"wav_scp\",\n        type=str,\n        help=\"Path to the wav.scp file\",\n    )\n    parser.add_argument(\"spk2utts\",\n                        type=str,\n                        help=\"Path to the json file containing mapping \"\n                        \"from speaker ID to utterances\")\n    parser.add_argument(\n        \"--num_spk\",\n        type=int,\n        
default=2,\n        choices=(2, 3),\n        help=\"Number of speakers in each mixture sample\",\n    )\n    parser.add_argument(\n        \"--train\",\n        type=str2bool,\n        default=True,\n        help=\"Whether is the training set or not\",\n    )\n    parser.add_argument(\n        \"--mix2enroll\",\n        type=str,\n        default=None,\n        help=\"Path to the downloaded map_mixture2enrollment file. \"\n        \"If `train` is False, this value is required.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=1, help=\"Random seed\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        required=True,\n        help=\"Path to the directory for storing output files\",\n    )\n    parser.add_argument(\n        \"--outfile_prefix\",\n        type=str,\n        default=\"spk\",\n        help=\"Prefix of the output files\",\n    )\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n\n    if args.train:\n        prepare_librimix_enroll(\n            args.wav_scp,\n            args.spk2utts,\n            args.output_dir,\n            num_spk=args.num_spk,\n            train=args.train,\n            prefix=args.outfile_prefix,\n        )\n    else:\n        prepare_librimix_enroll_v2(\n            args.wav_scp,\n            args.mix2enroll,\n            args.output_dir,\n            num_spk=args.num_spk,\n            prefix=args.outfile_prefix,\n        )\n"
  },
  {
    "path": "examples/librimix/tse/v1/local/prepare_spk2enroll_librispeech.py",
    "content": "import json\nfrom collections import defaultdict\nfrom itertools import chain\nfrom pathlib import Path\n\nfrom wesep.utils.utils import str2bool\n\n\ndef get_spk2utt(paths, audio_format=\"flac\"):\n    spk2utt = defaultdict(list)\n    for path in paths:\n        for audio in Path(path).rglob(\"*.{}\".format(audio_format)):\n            readerID = audio.parent.parent.stem\n            uid = audio.stem\n            assert uid.split(\"-\")[0] == readerID, audio\n            spk2utt[readerID].append((uid, str(audio)))\n\n    return spk2utt\n\n\ndef get_spk2utt_librimix(paths, audio_format=\"flac\"):\n    spk2utt = defaultdict(list)\n    for path in paths:\n        for audio in chain(\n                Path(path).rglob(\"s1/*.{}\".format(audio_format)),\n                Path(path).rglob(\"s2/*.{}\".format(audio_format)),\n                Path(path).rglob(\"s3/*.{}\".format(audio_format)),\n        ):\n            spk_idx = int(audio.parent.stem[1:]) - 1\n            mix_uid = audio.stem\n            uid = mix_uid.split(\"_\")[spk_idx]\n            sid = uid.split(\"-\")[0]\n            spk2utt[sid].append((uid, str(audio)))\n\n    return spk2utt\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"audio_paths\",\n        type=str,\n        nargs=\"+\",\n        help=\"Paths to Librispeech subsets\",\n    )\n    parser.add_argument(\n        \"--is_librimix\",\n        type=str2bool,\n        default=False,\n        help=\"Whether the provided audio_paths points to LibriMix data\",\n    )\n    parser.add_argument(\n        \"--outfile\",\n        type=str,\n        default=\"spk2utt_tse.json\",\n        help=\"Path to the output spk2utt json file\",\n    )\n    parser.add_argument(\"--audio_format\", type=str, default=\"flac\")\n    args = parser.parse_args()\n\n    if args.is_librimix:\n        # use clean sources from LibriMix as enrollment\n        spk2utt = get_spk2utt_librimix(args.audio_paths,\n                                       audio_format=args.audio_format)\n    else:\n        # use Librispeech as enrollment\n        spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format)\n    outfile = Path(args.outfile)\n    outfile.parent.mkdir(parents=True, exist_ok=True)\n    with outfile.open(\"w\", encoding=\"utf-8\") as f:\n        json.dump(spk2utt, f, indent=4)\n"
  },
  {
    "path": "examples/librimix/tse/v1/path.sh",
    "content": "export PATH=$PWD:$PATH\n\n# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C\nexport PYTHONIOENCODING=UTF-8\nexport PYTHONPATH=../../../../:$PYTHONPATH\n"
  },
  {
    "path": "examples/librimix/tse/v1/run.sh",
    "content": "#!/bin/bash\n\n# Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn)\n\n. ./path.sh || exit 1\n\n# General configuration\nstage=-1\nstop_stage=1\n\n# Data preparation related\ndata=data\nfs=16k\nmin_max=min\nnoise_type=\"clean\"\ndata_type=\"shard\" # shard/raw\nLibri2Mix_dir=/YourPath/librimix/Libri2Mix\nmix_data_path=\"${Libri2Mix_dir}/wav${fs}/${min_max}\"\n\n# Training related\ngpus=\"[0,1]\"\nuse_gan_loss=false\nconfig=confs/bsrnn.yaml\nexp_dir=exp/BSRNN/resnet34-pre_extract-multiply_fuse\n\nif [ -z \"${config}\" ] && [ -f \"${exp_dir}/config.yaml\" ]; then\n  config=\"${exp_dir}/config.yaml\"\nfi\n\n# TSE model initialization related\ncheckpoint=\nif [ -z \"${checkpoint}\" ] && [ -f \"${exp_dir}/models/latest_checkpoint.pt\" ]; then\n  checkpoint=\"${exp_dir}/models/latest_checkpoint.pt\"\nfi\n\n# Inferencing and scoring related\nuse_pesq=true\nuse_dnsmos=true\ndnsmos_use_gpu=true\n\n# Model average related\nnum_avg=2\n\n. tools/parse_options.sh || exit 1\n\nif [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then\n  echo \"Prepare datasets ...\"\n  ./local/prepare_data.sh --mix_data_path ${mix_data_path} \\\n    --data ${data} \\\n    --noise_type ${noise_type} \\\n    --stage 2 \\\n    --stop-stage 2\nfi\n\ndata=${data}/${noise_type}\n\nif [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then\n  echo \"Covert train and test data to ${data_type}...\"\n  for dset in train-100 dev test; do\n    #  for dset in train-360; do\n    python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \\\n      --num_threads 16 \\\n      --prefix shards \\\n      --shuffle \\\n      ${data}/$dset/wav.scp ${data}/$dset/utt2spk \\\n      ${data}/$dset/shards ${data}/$dset/shard.list\n  done\nfi\n\nnum_gpus=$(echo $gpus | awk -F ',' '{print NF}')\n\nif [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then\n  echo \"Start training ...\"\n  if ${use_gan_loss}; then\n    train_script=wesep/bin/train_gan.py\n  else\n    train_script=wesep/bin/train.py\n  fi\n  export OMP_NUM_THREADS=8\n  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \\\n    ${train_script} --config $config \\\n    --exp_dir ${exp_dir} \\\n    --gpus $gpus \\\n    --num_avg ${num_avg} \\\n    --data_type \"${data_type}\" \\\n    --train_data ${data}/train-100/${data_type}.list \\\n    --train_spk_embeds ${data}/train-100/embed.scp \\\n    --train_utt2spk ${data}/train-100/single.utt2spk \\\n    --train_spk2utt ${data}/train-100/spk2enroll.json \\\n    --val_data ${data}/dev/${data_type}.list \\\n    --val_spk_embeds ${data}/dev/embed.scp \\\n    --val_utt2spk ${data}/dev/single.utt2spk \\\n    --val_spk1_enroll ${data}/dev/spk1.enroll \\\n    --val_spk2_enroll ${data}/dev/spk2.enroll \\\n    --val_spk2utt ${data}/dev/single.wav.scp \\\n    ${checkpoint:+--checkpoint $checkpoint}\nfi\n\n# shellcheck disable=SC2215\nif [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then\n  echo \"Start inferencing ...\"\n  python wesep/bin/infer.py --config $config \\\n    --fs ${fs} \\\n    --gpus 0 \\\n    --exp_dir ${exp_dir} \\\n    --data_type \"${data_type}\" \\\n    --test_data ${data}/test/${data_type}.list \\\n    --test_spk_embeds ${data}/test/embed.scp \\\n    --test_spk1_enroll ${data}/test/spk1.enroll \\\n    --test_spk2_enroll ${data}/test/spk2.enroll \\\n    --test_spk2utt ${data}/test/single.wav.scp \\\n    ${checkpoint:+--checkpoint $checkpoint}\nfi\n\nif [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then\n  echo \"Start scoring ...\"\n  ./tools/score.sh --dset \"${data}/test\" \\\n    --exp_dir \"${exp_dir}\" \\\n 
   --fs ${fs} \\\n    --use_pesq \"${use_pesq}\" \\\n    --use_dnsmos \"${use_dnsmos}\" \\\n    --dnsmos_use_gpu \"${dnsmos_use_gpu}\" \\\n    --n_gpu \"${num_gpus}\"\nfi\n\nif [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then\n  echo \"Do model average ...\"\n  avg_model=$exp_dir/models/avg_best_model.pt\n  python wesep/bin/average_model.py \\\n    --dst_model $avg_model \\\n    --src_path $exp_dir/models \\\n    --num ${num_avg} \\\n    --mode best \\\n    --epochs \"138,141\"\nfi\n"
  },
  {
    "path": "examples/librimix/tse/v2/README.md",
    "content": "## Tutorial on LibriMix\n\nIf you meet any problems when going through this tutorial, please feel free to ask in github issues. Thanks for any kind of feedback.\n\n\n### First Experiment\n\nWe provide a recipe `examples/librimix/tse/v2/run.sh` on LibriMix data.\n\nThe recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process.\n\n```bash\ncd examples/librimix/tse/v2\nbash run.sh --stage 1 --stop_stage 1\nbash run.sh --stage 2 --stop_stage 2\nbash run.sh --stage 3 --stop_stage 3\nbash run.sh --stage 4 --stop_stage 4\nbash run.sh --stage 5 --stop_stage 5\nbash run.sh --stage 6 --stop_stage 6\n```\n\nYou could also just run the whole script\n```bash\nbash run.sh --stage 1 --stop_stage 6\n```\n\n------\n\n### Stage 1: Prepare Training Data\n\nPrior to executing this phase, we assume that you have locally stored or can access the LibriMix dataset and you should assign the data path to `Libri2Mix_dir`.\n\nAs the LibriMix dataset is available in multiple versions, each determined by factors like the number of sources in the mixtures and the sampling rate, you can choose the desired version by adjusting the following variables in `run.sh`:\n\n+ `fs`: the sample rate of the dataset, valid options are `16k` and `8k`.\n+ `min_max`: the mode of mixtures, valiad options are `min` and `max`.\n+ `noise_type`: the type of mixture, valiad options are `clean` and `both`.\n\nIn our recipe, we opt for the Libri2Mix data with a sampling rate of 16kHz, in 'min' mode, and without noise, thus configuring as follows:\n\n``` bash\nfs=16k\nmin_max=min\nnoise_type=\"clean\"\nLibri2Mix_dir=/path/to/Libri2Mix\n```\n\nAfter configuring the desired dataset version, running the script for the first phase will generate the prepared data files. By default, these files are stored in the `data` directory in the current location.\n\n```bash\ndata=data # you can change this to any directory\n```\n\nIn this stage, `local/prepare_data.sh`accomplishes three tasks (Main differences with v1 version):\n\n1. 
\nIn this stage, `local/prepare_data.sh` accomplishes the following tasks (these are the main differences from the v1 recipe):\n\n1. Organizes the original Libri2Mix dataset into three directories `dev`, `test` and `train_100`/`train_360`, each containing the following files:\n\n    + `single.utt2spk`: each line records two space-separated columns: `clean_wav_id` and `speaker_id`\n\n        ```text\n        s1/103-1240-0003_1235-135887-0017.wav 103\n        s1/103-1240-0004_4195-186237-0003.wav 103\n        ...\n        ```\n\n    + `utt2spk`: each line records three space-separated columns: `mixture_wav_id`, `speaker1_id` and `speaker2_id`.\n\n        ```\n        103-1240-0003_1235-135887-0017 103 1235\n        103-1240-0004_4195-186237-0003 103 4195\n        ...\n        ```\n\n    + `single.wav.scp`: each line records two space-separated columns: `clean_wav_id` and `clean_wav_path`\n\n        ```\n        s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0003_1235-135887-0017.wav\n        s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/s1/103-1240-0004_4195-186237-0003.wav\n        ...\n        ```\n\n    + `wav.scp`: each line records four space-separated columns: `mixture_wav_id`, `mixture_wav_path`, `clean_wav1_path` and `clean_wav2_path`.\n\n        ```\n        103-1240-0003_1235-135887-0017 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0003_1235-135887-0017.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0003_1235-135887-0017.wav\n        103-1240-0004_4195-186237-0003 /Data/Libri2Mix/wav16k/min/train-100/mix_clean/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s1/103-1240-0004_4195-186237-0003.wav /Data/Libri2Mix/wav16k/min/train-100/mix_clean/../s2/103-1240-0004_4195-186237-0003.wav\n        ...\n        ```\n
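\n    As a quick sanity check after this step, you can inspect the generated lists (the paths assume the default `data=data` and `noise_type=\"clean\"` settings, so everything ends up under `data/clean/`):\n\n    ```bash\n    wc -l data/clean/train-100/wav.scp data/clean/train-100/utt2spk\n    head -n 2 data/clean/train-100/single.wav.scp\n    ```\n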
\n2. Prepare LibriMix target-speaker enroll signal. This step will generate one `json` file in the `dev`, `test` and `train_100`/`train_360` directories, and three additional files in the `dev` and `test` directories respectively:\n\n    + `spk2enroll.json`: A JSON file, where the format of the stored key-value pairs is `{spk_id: [[spk_id_with_prefix_or_suffix, wav_path], ...]}`.\n\n        ```\n        \"652\": [[\"652-129742-0010\", \"/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0010_3081-166546-0071.wav\"],\n        ...,\n        [\"652-129742-0000\", \"/Data/Libri2Mix/wav16k/min/dev/s1/652-129742-0000_1993-147966-0004.wav\"]],\n        ...\n        ```\n\n    + `mixture2enrollment`: each line records three space-separated columns: `mixture_wav_id`, `clean_wav_id` and `enrollment_wav_id`.\n\n        ```\n        4077-13754-0001_5142-33396-0065 4077-13754-0001 s1/4077-13754-0004_5142-36377-0020\n        4077-13754-0001_5142-33396-0065 5142-33396-0065 s1/5142-36377-0003_1320-122612-0014\n        ...\n        ```\n\n    + `spk1.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.\n\n        ```\n        1272-128104-0000_2035-147961-0014 s1/1272-135031-0015_2277-149896-0006.wav\n        1272-128104-0003_2035-147961-0016 s1/1272-135031-0013_1988-147956-0016.wav\n        ...\n        ```\n\n    + `spk2.enroll`: each line records two space-separated columns: `mixture_wav_id` and `enrollment_wav_id`.\n\n        ```\n        1272-128104-0000_2035-147961-0014 s1/2035-152373-0009_3000-15664-0016.wav\n        1272-128104-0003_2035-147961-0016 s2/6313-66129-0013_2035-152373-0012.wav\n        ...\n        ```\n\nAt the end of this stage, the directory structure of `data` should look like this:\n\n```\ndata/\n|__ clean/ # the noise_type you chose\n    |__ dev/\n    |   |__ mixture2enrollment\n    |   |__ single.utt2spk\n    |   |__ single.wav.scp\n    |   |__ spk1.enroll\n    |   |__ spk2.enroll\n    |   |__ spk2enroll.json\n    |   |__ utt2spk\n    |   |__ wav.scp\n    |\n    |__ test/ # the same as dev/\n    |\n    |__ train_100/\n        |__ single.utt2spk\n        |__ single.wav.scp\n        |__ spk2enroll.json\n        |__ utt2spk\n        |__ wav.scp\n```\n\n3. Download the pretrained speaker encoders (ResNet34 & ECAPA-TDNN512) from wespeaker for training the TSE model with a pretrained speaker encoder. The models will be unzipped into `wespeaker_models/`.\nFind more speaker models at https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md.\n\n4. Prepare the speaker embeddings using wespeaker pretrained models. (Not needed in v2; this step is commented out by default.)\n\nThis step will generate two files in the `dev`, `test`, and `train_100` directories respectively:\n\n    + `embed.ark`: Kaldi ark file that stores the speaker embeddings.\n\n    + `embed.scp`: each line records two space-separated columns: `clean_wav_id` and `spk_embed_path`\n\n        ```\n        s1/103-1240-0003_1235-135887-0017.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:1450569\n        s1/103-1240-0004_4195-186237-0003.wav workspace/wesep/examples/librimix/tse/v1/data/clean/train-100/embed.ark:10622715\n        ...\n        ```\n\n------\n\n### Stage 2: Convert Data Format\n\nThis stage transforms the data into the `shard` format, which is better suited for large datasets. Its core idea is to pack the audio and labels of many small samples (e.g., 1000 utterances) into compressed packets (tar files) and read them through PyTorch's `IterableDataset`.
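\n\nAs a quick check after this stage finishes, you can list the members of one of the generated shards (the path below again assumes the default `data=data` and `noise_type=\"clean\"` settings):\n\n```bash\ntar -tf data/clean/dev/shards/shards_000000000.tar | head\n```\n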
\nFor a detailed explanation of the `shard` format, please refer to the [documentation](https://github.com/wenet-e2e/wenet/blob/main/docs/UIO.md) available in WeNet.\n\nThis stage will generate a subdirectory and a file in the `dev`, `test`, and `train_100` directories respectively:\n\n+ `shards/`: this directory stores the compressed packet (tar) files.\n\n    ```bash\n    ls shards\n    shards_000000000.tar  shards_000000001.tar  shards_000000002.tar ...\n    ```\n\n+ `shard.list`: each line records the path to the corresponding tar file.\n\n    ```\n    data/clean/dev/shards/shards_000000000.tar\n    data/clean/dev/shards/shards_000000001.tar\n    data/clean/dev/shards/shards_000000002.tar\n    ...\n    ```\n\nAt the end of this stage, the directory structure of `data` should look like this:\n\n```\ndata/\n|__ clean/ # the noise_type you chose\n    |__ dev/\n    |   |__ single.utt2spk, single.wav.scp, ... # files generated by Stage 1\n    |   |__ shard.list\n    |   |__ shards/\n    |       |__ shards_000000000.tar\n    |       |__ shards_000000001.tar\n    |       |__ shards_000000002.tar\n    |\n    |__ test/ # the same as dev/\n    |\n    |__ train_100/\n        |__ single.utt2spk, single.wav.scp, ... # files generated by Stage 1\n        |__ shard.list\n        |__ shards/\n          |__ shards_000000000.tar\n          |__ ...\n          |__ shards_000000013.tar\n```\n\n------\n\n### Stage 3: Neural Network Training\n\nYou can configure network-training-related parameters through the configuration file. We provide some ready-to-use configuration files in the recipe. If you wish to write your own configuration files or understand the meaning of certain parameters in the configuration files, you can refer to the following information:\n\n+ **overall training process related**\n\n    ```yaml\n    seed: 42\n    exp_dir: exp/BSRNN\n    enable_amp: false\n    gpus: '0,1'\n    log_batch_interval: 100\n    save_epoch_interval: 1\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `seed`: specify a random seed.\n    + `exp_dir`: specify the experiment directory.\n    + `enable_amp`: whether to enable automatic mixed precision.\n    + `gpus`: specify the visible GPUs during training.\n    + `log_batch_interval`: specify after how many batch iterations a log entry is written.\n    + `save_epoch_interval`: specify after how many epochs a checkpoint is saved.\n\n+ **dataset and dataloader related**\n\n    ```yaml\n    dataset_args:\n      resample_rate: 16000\n      sample_num_per_epoch: 0\n      shuffle: true\n      shuffle_args:\n        shuffle_size: 2500\n      whole_utt: false\n      chunk_len: 48000\n      online_mix: false\n      speaker_feat: &speaker_feat true\n      fbank_args:\n        num_mel_bins: 80\n        frame_shift: 10\n        frame_length: 25\n        dither: 1.0\n\n      # Usually you don't need to write the data part of the configuration into\n      # the config file manually; it will be generated automatically.\n      data_type: \"shard\"\n      train_data: \"data/clean/train_100/shard.list\"\n      train_utt2spk: \"data/clean/train_100/single.utt2spk\"\n      train_spk2utt: \"data/clean/train_100/spk2enroll.json\"\n      val_data: \"data/clean/dev/shard.list\"\n      val_utt2spk: \"data/clean/dev/single.utt2spk\"\n      val_spk1_enroll: \"data/clean/dev/spk1.enroll\"\n      val_spk2_enroll: \"data/clean/dev/spk2.enroll\"\n      val_spk2utt: \"data/clean/dev/single.wav.scp\"\n\n\n    dataloader_args:\n      batch_size: 12  # A800
\n      drop_last: true\n      num_workers: 6\n      pin_memory: false\n      prefetch_factor: 6\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `resample_rate`: All audio in the dataset will be resampled to this specified sample rate. Defaults to `16000`.\n    + `sample_num_per_epoch`: Specifies how many samples from the full training set will be iterated over in each epoch during training. The default is `0`, which means iterating over the entire training set.\n    + `shuffle`: Whether to perform a *global* shuffle, i.e., shuffling at the shard tar/raw/feat file level. Defaults to `true`.\n    + `shuffle_size`: Parameter related to the *local* shuffle. The local shuffle maintains a buffer, and shuffling is only performed when the number of data items in the buffer reaches `shuffle_size`. Defaults to `2500`.\n    + `whole_utt`: Whether the network input and training target are the entire audio segment. Defaults to `false`.\n    + `chunk_len`: This parameter only takes effect when `whole_utt` is set to `false`. It indicates the length of the segment to be extracted from the complete audio as the network input and training target. Defaults to `48000`.\n    + `online_mix`: Whether to dynamically mix speakers when loading data; `shuffle` will not take effect if this parameter is set to `true`. Defaults to `false`.\n    + `speaker_feat`: Whether to transform the enrollment from waveform to fbank. Setting it to `true` is recommended. Defaults to `false`.\n    + `num_mel_bins`: An fbank parameter. The feature dimension of the fbank. Defaults to `80`.\n    + `frame_shift`: An fbank parameter. The frame shift in `ms`. Defaults to `10`.\n    + `frame_length`: An fbank parameter. The frame length in `ms`. Defaults to `25`.\n    + `dither`: An fbank parameter. The dithering constant added when computing the fbank features. Defaults to `1.0`.\n    + `data_type`: Specify the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.\n    + `train_data`: File containing paths to the training set files.\n    + `train_utt2spk`: Each line of the file specified by this parameter consists of `clean_wav_id` and `speaker_id`, separated by a space (e.g. `single.utt2spk` generated in Stage 1).\n    + `train_spk2utt`: The file specified by this parameter is only used when the `joint_training` parameter is set to `true`. Each line of the file contains `speaker_id` and `enrollment_wav_id`.\n    + `val_data`: File containing paths to the validation set files.\n    + `val_utt2spk`: Similar to `train_utt2spk`.\n    + `val_spk1_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker1_enrollment_wav_id`, separated by a space.\n    + `val_spk2_enroll`: Each line of the file specified by this parameter consists of `mixture_wav_id` and `speaker2_enrollment_wav_id`, separated by a space.\n    + `val_spk2utt`: Each line of the file specified by this parameter consists of `clean_wav_id` and `clean_wav_path`, separated by a space (e.g. `single.wav.scp` generated in Stage 1).
\n        + We have denoted this parameter as `val_spk2utt`, but it is actually assigned the `single.wav.scp` file as its value. This might be perplexing for users familiar with file formats in Kaldi or ESPnet, where the `spk2utt` file typically consists of lines containing `spk_id` and `wav_id`, whereas the `wav.scp` file's lines contain `wav_id` and `wav_path`.\n        + Nevertheless, upon closer examination of its role in subsequent procedures, it becomes evident that it is indeed employed to create a dictionary mapping speaker IDs to audio samples.\n    + `batch_size`: how many samples per batch to load. Please note that the batch size mentioned here refers to the **batch size per GPU**. So, if you are training on two GPUs within a single node and set the batch size to 16, it is equivalent to setting the batch size to 32 in a single-GPU, single-node scenario.\n    + `drop_last`: set to `true` to drop the last incomplete batch if the dataset size is not divisible by the batch size. If `false` and the dataset size is not divisible by the batch size, the last batch will be smaller.\n    + `num_workers`: how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process.\n    + `pin_memory`: If `true`, the data loader will copy Tensors into device/CUDA pinned memory before returning them.\n    + `prefetch_factor`: number of batches loaded in advance by each worker.\n\n+ **loss function related**\n\n    ```yaml\n    loss: SISDR\n    loss_args: { }\n\n    ### For jointly training the speaker encoder with a CE loss. Put SISDR in the first position for the validation set.\n    loss: [SISDR, CE]\n    loss_args:\n      loss_posi: [[0],[1]]\n      loss_weight: [[1.0],[1.0]]\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `loss`: the loss function used for training.\n    + `loss_args`: the required arguments for the loss function.\n    + `loss_posi`: selects which outputs of the TSE model each loss function works on.\n    + `loss_weight`: the weight of the loss computed by the corresponding loss function.\n\n    In addition to some common loss functions, we also support the use of a GAN loss. You can enable this feature by setting `use_gan_loss` to `true` in `run.sh`. Once enabled, the TSE model serves as the generator, and another convolutional neural network acts as the discriminator, engaging in adversarial training. The final loss of the TSE model is a combination of the losses specified in the configuration file and the GAN loss. By default, the weight for the former is set to `0.95`, while the latter is set to `0.05`.\n\n    Due to the compatibility with the GAN loss, the parameters mentioned below often differentiate between `tse_model` and `discriminator` under a single parameter. In such cases, we no longer provide separate explanations for each parameter.
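\n\n    For example, to run the training stage with the GAN loss enabled, a command along these lines can be used together with the GAN-style configuration shipped in this recipe (`confs/dpcc_init_gan.yaml` defines both a `tse_model` and a `discriminator`):\n\n    ```bash\n    bash run.sh --stage 3 --stop_stage 3 --use_gan_loss true --config confs/dpcc_init_gan.yaml\n    ```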
\n\n+ **neural network structure related**\n\n    ```yaml\n    model:\n      tse_model: BSRNN\n    model_args:\n      tse_model:\n        sr: 16000\n        win: 512\n        stride: 128\n        feature_dim: 128\n        num_repeat: 6\n        spk_fuse_type: 'multiply'\n        use_spk_transform: False\n        multi_fuse: False\n        joint_training: True       ### Always set this parameter to `True` when using the v2 version.\n        spk_model: ResNet34\n        spk_model_init: None\n        spk_args: None\n        spk_emb_dim: 256\n        spk_model_freeze: False\n        spk_feat: *speaker_feat\n        feat_type: \"consistent\"\n        multi_task: False\n        spksInTrain: 251\n\n    model_init:\n      tse_model: exp/BSRNN/no_spk_transform-multiply_fuse/models/latest_checkpoint.pt\n      discriminator: null\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `model`: specify the neural network used for training.\n    + `model_args`: specify model-specific parameters.\n    + `spk_fuse_type`: specify the fusion method for the speaker embedding. Supported options are `concat`, `additive`, `multiply` and `FiLM`.\n    + `multi_fuse`: whether to fuse the speaker embedding multiple times.\n    + `joint_training`: specify whether the speaker encoder for extracting speaker embeddings is jointly trained with the TSE model. Always set this to `true`. Do NOT use it to control whether a pretrained speaker encoder is used (use `spk_model_freeze` for that). Defaults to `false`.\n    + `spk_model`: specify the speaker model. Supports most speaker models in wespeaker: https://github.com/wenet-e2e/wespeaker/tree/master.\n    + `spk_model_init`: the path to the pretrained speaker model. Find more pretrained models at https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md. Set `None` to train the speaker model together with the TSE model from scratch.\n    + `spk_args`: specify speaker-model-specific parameters.\n    + `spk_emb_dim`: the feature dimension of the speaker embedding extracted from the speaker encoder.\n    + `spk_model_freeze`: whether to freeze the weights of the speaker encoder. Set it to `True` when using a pretrained speaker encoder.\n    + `spk_feat`: uses the `speaker_feat` value defined in `dataset_args` to determine whether feature extraction for the enrollment is performed inside the model.\n    + `feat_type`: specify the type of the enrollment feature when `spk_feat` is `False`.\n    + `multi_task`: whether to use an additional loss function such as `CE` for jointly training the speaker encoder. This parameter needs to be coordinated with `loss`.\n    + `spksInTrain`: specify the number of speakers in the training dataset (wsj0-2mix: 101, Libri2Mix-100: 251, Libri2Mix-360: 921).\n    + `model_init`: whether to initialize the model with an existing checkpoint. Use `null` for no initialization. If you want to initialize, provide the checkpoint path. Defaults to `null`.
\n\n+ **model optimization related**\n\n    ```yaml\n    num_epochs: 150\n    clip_grad: 5.0\n\n    optimizer:\n      tse_model: Adam\n    optimizer_args:\n      tse_model:\n        lr: 0.001\n        weight_decay: 0.0001\n\n    scheduler:\n      tse_model: ExponentialDecrease\n    scheduler_args:\n      tse_model:\n        final_lr: 2.5e-05\n        initial_lr: 0.001\n        warm_from_zero: false\n        warm_up_epoch: 0\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `num_epochs`: total number of training epochs.\n    + `clip_grad`: set the threshold for gradient clipping.\n    + `optimizer`: set the optimizer.\n    + `optimizer_args`: the required arguments for the optimizer. Not used in the current version; the learning rate and its schedule are determined by `scheduler_args`.\n    + `scheduler`: set the scheduler.\n    + `scheduler_args`: the required arguments for the scheduler.\n\n+ **others**\n\n    ```yaml\n    num_avg: 2\n    ```\n\n    Explanations for some of the parameters mentioned above:\n\n    + `num_avg`: the number of checkpoints averaged into the final model.\n\nTo avoid frequent changes to the configuration file, we support **overwriting values in the configuration file** directly within `run.sh`. For example, running the following command in `run.sh` will overwrite the visible GPUs from `'0,1'` to `'0'` in the above configuration file:\n\n```bash\n  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \\\n    ${train_script} --config confs/config.yaml \\\n    --gpus \"[0]\" \\\n```\n\nAt the end of this stage, an experiment directory will be created in the current directory, containing the following files:\n\n```\n${exp_dir}/\n|__ train.log\n|__ config.yaml\n|__ models/\n  |__ checkpoint_1.pt\n  |__ ...\n  |__ checkpoint_150.pt\n  |__ final_checkpoint.pt -> checkpoint_150.pt\n  |__ latest_checkpoint.pt -> checkpoint_150.pt\n```\n\n------\n\n### Stage 4: Apply Model Average\n\nIn this stage, we perform model averaging, and you need to specify the following parameters in `run.sh`:\n\n+ `dst_model`: the path to save the averaged model.\n+ `src_path`: the directory containing the source models to average.\n+ `num`: the number of source models used to build the averaged model.\n+ `mode`: the mode for model averaging. Valid options are `final` and `best`.\n    + `final`: filters and sorts the latest PyTorch model files in the source directory. Averages the states of the last `num` models based on a numerical sorting of their filenames.\n    + `best`: directly uses user-specified epochs to select specific model checkpoint files. Averages the states of these selected models.\n+ `epochs`: this parameter only takes effect when `mode` is set to `best` and specifies the epoch indices of the checkpoints used as source models.
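\n\nFor reference, the corresponding call in `run.sh` invokes the averaging script directly as follows (the epoch list is just an example):\n\n```bash\npython wesep/bin/average_model.py \\\n  --dst_model ${exp_dir}/models/avg_best_model.pt \\\n  --src_path ${exp_dir}/models \\\n  --num ${num_avg} \\\n  --mode best \\\n  --epochs \"138,141\"\n```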
\n\n------\n\n### Stage 5: Extract Speech Using the Trained Model\n\nAfter training is complete, you can execute stage 5 to extract the target speaker's speech using the trained model. This stage mainly calls `wesep/bin/infer.py`, and you need to provide the following parameters for this script:\n\n+ `config`: the configuration file used in Stage 3.\n+ `fs`: the sample rate of the audio data.\n+ `gpus`: the index of the visible GPU.\n+ `exp_dir`: the experiment directory.\n+ `data_type`: the type of dataset, with valid options being `shard` and `raw`. Defaults to `shard`.\n+ `test_data`: similar to `train_data`.\n+ `test_spk1_enroll`: similar to `val_spk1_enroll`.\n+ `test_spk2_enroll`: similar to `val_spk2_enroll`.\n+ `test_spk2utt`: similar to `val_spk2utt`.\n+ `save_wav`: controls whether the extracted speech is saved to `exp_dir/audio`.\n+ `checkpoint`: the path to the checkpoint used for extracting the target speaker's speech.\n\nAt the end of this stage, the structure of the experiment directory should look like this:\n\n```\n${exp_dir}/\n|__ train.log\n|__ config.yaml\n|__ models/\n|__ infer.log\n|__ audio/\n  |__ spk1.scp # each line records two space-separated columns: `target_wav_id` and `target_wav_path`\n  |__ Utt1001-4992-41806-0008_6930-75918-0015-T4992.wav\n  |__ ...\n  |__ Utt999-61-70968-0003_2830-3980-0008-T61.wav\n```\n\n------\n\n### Stage 6: Scoring\n\nIn this stage, we evaluate the quality of the generated speech using common objective metrics. The default metrics include **STOI**, **SDR**, **SAR**, **SIR**, and **SI_SNR**. In addition to these metrics, you can also include **PESQ** and **DNS_MOS** by setting the values of `use_pesq` and `use_dnsmos` to `true`. Please be aware that DNS_MOS is only supported for audio with a **16 kHz** sampling rate; for audio with other sampling rates, do not use DNS_MOS for assessment.\n\nAt the end of this stage, a markdown file `RESULTS.md` will be created under the `exp` directory, and the directory structure of `exp` should look like this:\n\n```\nexp/BSRNN/\n|__ ${exp_dir}\n|  |__ train.log, ... # files and directories generated in Stage 5\n|  |__ scoring/\n|\n|__ RESULTS.md\n```\n"
  },
  {
    "path": "examples/librimix/tse/v2/confs/bsrnn.yaml",
    "content": "dataloader_args:\n  batch_size: 8 #RTX2080 1, V100: 8, A800: 16\n  drop_last: true\n  num_workers: 6\n  pin_memory: true\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: &sr 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n  speaker_feat: &speaker_feat True\n  fbank_args:\n    num_mel_bins: 80\n    frame_shift: 10\n    frame_length: 25\n    dither: 1.0\n  noise_lmdb_file: './data/musan/lmdb'\n  noise_prob: 0 # prob to add noise aug per sample\n  specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech\n  reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech\n  noise_enroll_prob: 0 # prob to add noise aug on enrollment speech\n  # Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589\n  # only Single-optimization method is supported here.\n  # if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml\n  SSA_enroll_prob:  0  # prob to add SSA on enrollment speech\n\nenable_amp: false\nexp_dir: exp/BSRNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISDR\nloss_args: { }\n\n# loss: [SISDR, CE]               ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set\n# loss_args:\n#   loss_posi: [[0],[1]]\n#   loss_weight: [[1.0],[1.0]]\n\nmodel:\n  tse_model: BSRNN\nmodel_args:\n  tse_model:\n    sr: *sr\n    win: 512\n    stride: 128\n    feature_dim: 128\n    num_repeat: 6\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n    multi_fuse: False        # Fuse the speaker embedding multiple times.\n    joint_training: True  # Always set True, use \"spk_model_freeze\" to control if use pre-trained speaker encoders\n    ####### ResNet    The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md\n    spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152\n    spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt\n    spk_args:\n      feat_dim: 80\n      embed_dim: &embed_dim 256\n      pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n      two_emb_layer: False\n    ####### Ecapa_TDNN\n    # spk_model: ECAPA_TDNN_GLOB_c512\n    # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt\n    # spk_args:\n    #   embed_dim: &embed_dim 192\n    #   feat_dim: 80\n    #   pooling_func: ASTP\n    ####### CAMPPlus\n    # spk_model: CAMPPlus\n    # spk_model_init: False\n    # spk_args:\n    #   feat_dim: 80\n    #   embed_dim: &embed_dim 192\n    #   pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n    #################################################################\n    spk_emb_dim: *embed_dim\n    spk_model_freeze: False    # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder\n    spk_feat: *speaker_feat   #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used\n    feat_type: \"consistent\"\n    multi_task: False\n    spksInTrain: 251    #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921\n\n# find_unused_parameters: True\n\nmodel_init:\n  tse_model: null\n  discriminator: null\n  spk_model: null\n\nnum_avg: 2\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001  # NOTICE: These args do NOT work! 
The initial lr is determined in the scheduler_args currently!\n    weight_decay: 0.0001\n\nclip_grad: 5.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v2/confs/bsrnn_feats.yaml",
    "content": "dataloader_args:\n  batch_size: 4 #RTX2080 1, V100: 4, A800: 12\n  drop_last: true\n  num_workers: 6\n  pin_memory: true\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: &sr 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n  speaker_feat: &speaker_feat False\n  fbank_args:\n    num_mel_bins: 80\n    frame_shift: 10\n    frame_length: 25\n    dither: 1.0\n  noise_lmdb_file: './data/musan/lmdb'\n  noise_prob: 0 # prob to add noise aug per sample\n  specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech\n  reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech\n  noise_enroll_prob: 0 # prob to add noise aug on enrollment speech\n  # Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589\n  # only Single-optimization method is supported here.\n  # if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml\n  SSA_enroll_prob:  0  # prob to add SSA on enrollment speech\n\nenable_amp: false\nexp_dir: exp/BSRNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISDR\nloss_args: { }\n\n# loss: [SISDR, CE]               ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set\n# loss_args:\n#   loss_posi: [[0],[1]]\n#   loss_weight: [[1.0],[1.0]]\n\nmodel:\n  tse_model: BSRNN_Feats\nmodel_args:\n  tse_model:\n    sr: *sr\n    win: 512\n    stride: 128\n    feature_dim: 128\n    num_repeat: 6\n    spectral_feat: 'tfmap_emb'  # 'tfmap_spec' 'tfmap_emb' False\n    spk_fuse_type: 'cross_multiply'    #'cross_multiply' 'multiply' False\n    use_spk_transform: False\n    multi_fuse: False        # Fuse the speaker embedding multiple times.\n    joint_training: True     # Always set True, use \"spk_model_freeze\" to control if use pre-trained speaker encoders\n    #################################################################\n    ###### Ecapa_TDNN\n    spk_model: ECAPA_TDNN_GLOB_c512\n    spk_model_init: ./wespeaker_models/voxceleb_ECAPA512/avg_model.pt\n    spk_args:\n      embed_dim: &embed_dim 192\n      feat_dim: 80\n      pooling_func: ASTP\n    #################################################################\n    spk_emb_dim: *embed_dim\n    spk_model_freeze: True    # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder\n    spk_feat: *speaker_feat   # if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used\n    feat_type: \"consistent\"\n    multi_task: False\n    spksInTrain: 251          # wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921\n\n# find_unused_parameters: True\n\nmodel_init:\n  tse_model: null\n  discriminator: null\n  spk_model: null\n\nnum_avg: 2\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001  # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!\n    weight_decay: 0.0001\n\nclip_grad: 5.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v2/confs/bsrnn_multi_optim.yaml",
    "content": "dataloader_args:\n  batch_size: 8 #RTX2080 1, V100: 8, A800: 16\n  drop_last: true\n  num_workers: 6\n  pin_memory: true\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: &sr 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n  speaker_feat: &speaker_feat False\n  fbank_args:\n    num_mel_bins: 80\n    frame_shift: 10\n    frame_length: 25\n    dither: 1.0\n  noise_lmdb_file: './data/musan/lmdb'\n  noise_prob: 0 # prob to add noise aug per sample\n  specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech\n  reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech\n  noise_enroll_prob: 0 # prob to add noise aug on enrollment speech\n\n\nenable_amp: false\nexp_dir: exp/BSRNN\ngpus: '0,1'\nlog_batch_interval: 100\n\n#Please refer to  our SLT paper https://www.arxiv.org/abs/2409.09589\n# to check our parameter settings.\nloss: SISDR\nloss_args:\n  loss_posi: [[0,1]]\n  loss_weight: [[0.4,0.6]]\n\n#if you wanna use CE loss, multi_task needs to be set True\n# loss: [SISDR, CE]               ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set\n# loss_args:\n#   loss_posi: [[0,1],[2,3]]\n#   loss_weight: [[0.36,0.54],[0.04,0.06]]\n\nmodel:\n  tse_model: BSRNN_Multi\nmodel_args:\n  tse_model:\n    sr: *sr\n    win: 512\n    stride: 128\n    feature_dim: 128\n    num_repeat: 6\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n    multi_fuse: False        # Fuse the speaker embedding multiple times.\n    joint_training: True  # Always set True, use \"spk_model_freeze\" to control if use pre-trained speaker encoders\n    ####### ResNet    The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md\n    spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152\n    spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt\n    spk_args:\n      feat_dim: 80\n      embed_dim: &embed_dim 256\n      pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n      two_emb_layer: False\n    ####### Ecapa_TDNN\n    # spk_model: ECAPA_TDNN_GLOB_c512\n    # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt\n    # spk_args:\n    #   embed_dim: &embed_dim 192\n    #   feat_dim: 80\n    #   pooling_func: ASTP\n    ####### CAMPPlus\n    # spk_model: CAMPPlus\n    # spk_model_init: False\n    # spk_args:\n    #   feat_dim: 80\n    #   embed_dim: &embed_dim 192\n    #   pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n    #################################################################\n    spk_emb_dim: *embed_dim\n    spk_model_freeze: False    # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder\n    spk_feat: *speaker_feat   #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used\n    feat_type: \"consistent\"\n    multi_task: False\n    spksInTrain: 251    #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921\n\n# find_unused_parameters: True\n\nmodel_init:\n  tse_model: null\n  discriminator: null\n  spk_model: null\n\nnum_avg: 2\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001  # NOTICE: These args do NOT work! 
The initial lr is determined in the scheduler_args currently!\n    weight_decay: 0.0001\n\nclip_grad: 5.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v2/confs/dpcc_init_gan.yaml",
    "content": "use_metric_loss: true\n\ndataloader_args:\n  batch_size: 4\n  drop_last: true\n  num_workers: 4\n  pin_memory: false\n  prefetch_factor: 4\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n  noise_lmdb_file: './data/musan/lmdb'\n  noise_prob: 0 # prob to add noise aug per sample\n  specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech\n  reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech\n  noise_enroll_prob: 0 # prob to add noise aug on enrollment speech\n\nenable_amp: false\nexp_dir: exp/DPCNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISNR\nloss_args: { }\ngan_loss_weight: 0.05\n\nmodel:\n  tse_model: DPCCN\n  discriminator: CMGAN_Discriminator\nmodel_args:\n  tse_model:\n    win: 512\n    stride: 128\n    feature_dim: 257\n    tcn_blocks: 10\n    tcn_layers: 2\n    spk_emb_dim: 256\n    causal: False\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n  discriminator: {}\n\nmodel_init:\n  tse_model: exp/DPCCN/no_spk_transform-multiply_fuse/models/final_model.pt\n  discriminator: null\nnum_avg: 5\nnum_epochs: 50\n\noptimizer:\n  tse_model: Adam\n  discriminator: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.0001\n    weight_decay: 0.0001\n  discriminator:\n    lr: 0.001\n    weight_decay: 0.0001\n\nclip_grad: 3.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\n  discriminator: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.0001\n    warm_from_zero: false\n    warm_up_epoch: 0\n  discriminator:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v2/confs/dpccn.yaml",
    "content": "dataloader_args:\n  batch_size: 6\n  drop_last: true\n  num_workers: 6\n  pin_memory: false\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n  speaker_feat: &speaker_feat True\n  fbank_args:\n    num_mel_bins: 80\n    frame_shift: 10\n    frame_length: 25\n    dither: 1.0\n  noise_lmdb_file: './data/musan/lmdb'\n  noise_prob: 0 # prob to add noise aug per sample\n  specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech\n  reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech\n  noise_enroll_prob: 0 # prob to add noise aug on enrollment speech\n\nenable_amp: false\nexp_dir: exp/DPCNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISDR\nloss_args: { }\n\n# loss: [SISDR, CE]               ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set\n# loss_args:\n#   loss_posi: [[0],[1]]\n#   loss_weight: [[1.0],[1.0]]\n\nmodel:\n  tse_model: DPCCN\nmodel_args:\n  tse_model:\n    win: 512\n    stride: 128\n    feature_dim: 257\n    tcn_blocks: 10\n    tcn_layers: 2\n    causal: False\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n    multi_fuse: False        # Fuse the speaker embedding multiple times.\n    joint_training: True      # Always set True, use \"spk_model_freeze\" to control if use pre-trained speaker encoders\n    ####### ResNet    The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md\n    spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152\n    spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt\n    spk_args:\n      feat_dim: 80\n      embed_dim: &embed_dim 256\n      pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n      two_emb_layer: False\n    ####### Ecapa_TDNN\n    # spk_model: ECAPA_TDNN_GLOB_c512\n    # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt\n    # spk_args:\n    #   embed_dim: &embed_dim 192\n    #   feat_dim: 80\n    #   pooling_func: ASTP\n    ####### CAMPPlus\n    # spk_model: CAMPPlus\n    # spk_model_init: False\n    # spk_args:\n    #   feat_dim: 80\n    #   embed_dim: &embed_dim 192\n    #   pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n    #################################################################\n    spk_emb_dim: *embed_dim\n    spk_model_freeze: False    # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder\n    spk_feat: *speaker_feat   #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used\n    feat_type: \"consistent\"\n    multi_task: False\n    spksInTrain: 251    #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921\n\nmodel_init:\n  tse_model: null\n  discriminator: null\nnum_avg: 5\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!\n    weight_decay: 0.0001\n\nclip_grad: 3.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v2/confs/spexplus.yaml",
    "content": "dataloader_args:\n  batch_size: 8       #A800: 8\n  drop_last: true\n  num_workers: 4\n  pin_memory: true\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n  noise_lmdb_file: './data/musan/lmdb'\n  noise_prob: 0 # prob to add noise aug per sample\n  specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech\n  reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech\n  noise_enroll_prob: 0 # prob to add noise aug on enrollment speech\n\nenable_amp: false\nexp_dir: exp/SpExplus/\ngpus: ['0']\nlog_batch_interval: 100\n\n# joint_training: True\nloss: [SISDR, CE]               ###SI_SNR, SDR, sisnr, CE\nloss_args:\n  loss_posi: [[0,1,2],[3]]\n  loss_weight: [[0.8,0.1,0.1],[0.5]]\n\n\nmodel:\n  tse_model: ConvTasNet\nmodel_args:\n  tse_model:\n    B: 256\n    H: 512\n    L: 20\n    N: 256\n    P: 3\n    R: 4\n    X: 8\n    spk_emb_dim: 256\n    activate: relu\n    causal: false\n    norm: gLN\n    skip_con: False\n    spk_fuse_type: concatConv # \"concat\", \"additive\", \"multiply\", \"FiLM\", \"None\", (\"concatConv\" only for convtasnet)\n    use_spk_transform: False\n    multi_fuse: True        # Multi speaker fuse with seperation modules\n    encoder_type: Multi # Multi, Deep, False\n    decoder_type: Multi # Multi, Deep, False\n    joint_training: True\n    multi_task: True\n    spksInTrain: 251    #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921\n\nmodel_init:\n  tse_model: null\n  discriminator: null\nnum_avg: 5\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001  # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!\n    weight_decay: 0.0001\nsave_epoch_interval: 5\nclip_grad: 5.0 # False\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: False\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v2/confs/tfgridnet.yaml",
    "content": "dataloader_args:\n  batch_size: 1\n  drop_last: true\n  num_workers: 6\n  pin_memory: false\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: &sr 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  chunk_len: 48000\n  speaker_feat: &speaker_feat True\n  fbank_args:\n    num_mel_bins: 80\n    frame_shift: 10\n    frame_length: 25\n    dither: 1.0\n  noise_lmdb_file: './data/musan/lmdb'\n  noise_prob: 0 # prob to add noise aug per sample\n  specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech\n  reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech\n  noise_enroll_prob: 0 # prob to add noise aug on enrollment speech\n\nenable_amp: false\nexp_dir: exp/TFGridNet\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISDR\nloss_args: { }\n\n# loss: [SISDR, CE]               ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set\n# loss_args:\n#   loss_posi: [[0],[1]]\n#   loss_weight: [[1.0],[1.0]]\n\nmodel:\n  tse_model: TFGridNet\nmodel_args:\n  tse_model:\n    n_srcs: 1\n    sr: *sr\n    n_fft: 128\n    stride: 64\n    window: \"hann\"\n    n_imics: 1\n    n_layers: 6\n    lstm_hidden_units: 192\n    attn_n_head: 4\n    attn_approx_qk_dim: 512\n    emb_dim: 128\n    emb_ks: 1\n    emb_hs: 1\n    activation: \"prelu\"\n    eps: 1.0e-5\n    use_spk_transform: False\n    spk_fuse_type: \"multiply\"\n    multi_fuse: False        # Fuse the speaker embedding multiple times.\n    joint_training: True      # Always set True, use \"spk_model_freeze\" to control if use pre-trained speaker encoders\n    ####### ResNet    The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md\n    spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152\n    spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt\n    spk_args:\n      feat_dim: 80\n      embed_dim: &embed_dim 256\n      pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n      two_emb_layer: False\n    ####### Ecapa_TDNN\n    # spk_model: ECAPA_TDNN_GLOB_c512\n    # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt\n    # spk_args:\n    #   embed_dim: &embed_dim 192\n    #   feat_dim: 80\n    #   pooling_func: ASTP\n    ####### CAMPPlus\n    # spk_model: CAMPPlus\n    # spk_model_init: False\n    # spk_args:\n    #   feat_dim: 80\n    #   embed_dim: &embed_dim 192\n    #   pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n    #################################################################\n    spk_emb_dim: *embed_dim\n    spk_model_freeze: False    # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder\n    spk_feat: *speaker_feat   #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used\n    feat_type: \"consistent\"\n    multi_task: False\n    spksInTrain: 251    #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921\n\n\nmodel_init:\n  tse_model: null\nnum_avg: 5\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001 # NOTICE: These args do NOT work! 
The initial lr is determined in the scheduler_args currently!\n    weight_decay: 0.0001\n\nclip_grad: 5.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/librimix/tse/v2/local/prepare_data.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)\n\nstage=-1\nstop_stage=-1\n\nmix_data_path='./Libri2Mix/wav16k/min/'\n\ndata=data\nnoise_type=clean\nnum_spk=2\n\n. tools/parse_options.sh || exit 1\n\ndata=$(realpath ${data})\n\nif [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then\n  echo \"Prepare the meta files for the datasets\"\n\n  for dataset in dev test train-100; do\n  # for dataset in train-360; do\n    echo \"Preparing files for\" $dataset\n\n    # Prepare the meta data for the mixed data\n    dataset_path=$mix_data_path/$dataset/mix_${noise_type}\n    mkdir -p \"${data}\"/$noise_type/${dataset}\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print $NF}' |\n      awk -v path=\"${dataset_path}\" '{print $1 , path \"/\" $1 , path \"/../s1/\" $1 , path \"/../s2/\" $1}' |\n      sed 's#.wav##' | sort -k1,1 >\"${data}\"/$noise_type/${dataset}/wav.scp\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/wav.scp |\n      awk -F[_-] '{print $0, $1,$4}' >\"${data}\"/$noise_type/${dataset}/utt2spk\n\n    # Prepare the meta data for single speakers\n    dataset_path=$mix_data_path/$dataset/s1\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print \"s1/\" $NF, $0}' | sort -k1,1 >\"${data}\"/$noise_type/${dataset}/single.wav.scp\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/single.wav.scp | grep 's1' |\n      awk -F[-_/] '{print $0, $2}' >\"${data}\"/$noise_type/${dataset}/single.utt2spk\n\n    dataset_path=$mix_data_path/$dataset/s2\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print \"s2/\" $NF, $0}' | sort -k1,1 >>\"${data}\"/$noise_type/${dataset}/single.wav.scp\n\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/single.wav.scp | grep 's2' |\n      awk -F[-_/] '{print $0, $5}' >>\"${data}\"/$noise_type/${dataset}/single.utt2spk\n  done\nfi\n\nif [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then\n  echo \"stage 2: Prepare LibriMix target-speaker enroll signal\"\n\n  for dset in dev test train-100; do\n  # for dset in train-360; do\n    python local/prepare_spk2enroll_librispeech.py \\\n      \"${mix_data_path}/${dset}\" \\\n      --is_librimix True \\\n      --outfile \"${data}\"/$noise_type/${dset}/spk2enroll.json \\\n      --audio_format wav\n  done\n\n  for dset in dev test; do\n    if [ $num_spk -eq 2 ]; then\n      url=\"https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment\"\n    else\n      url=\"https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment\"\n    fi\n\n    output_file=\"${data}/${noise_type}/${dset}/mixture2enrollment\"\n    wget -O \"$output_file\" \"$url\"\n  done\n\n  for dset in dev test; do\n    python local/prepare_librimix_enroll.py \\\n      \"${data}\"/$noise_type/${dset}/wav.scp \\\n      \"${data}\"/$noise_type/${dset}/spk2enroll.json \\\n      --mix2enroll \"${data}/${noise_type}/${dset}/mixture2enrollment\" \\\n      --num_spk ${num_spk} \\\n      --train False \\\n      --output_dir \"${data}\"/${noise_type}/${dset} \\\n      --outfile_prefix \"spk\"\n  done\nfi\n\nif [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then\n  echo \"Download the pre-trained speaker encoders (Resnet34 & Ecapa-TDNN512) from wespeaker...\"\n  mkdir wespeaker_models\n  wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip\n  unzip voxceleb_resnet34.zip -d wespeaker_models\n  wget 
https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip\n  unzip voxceleb_ECAPA512.zip -d wespeaker_models\nfi\n\n# if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then\n#   echo \"Prepare the speaker embeddings using wespeaker pretrained models\"\n#   for dataset in dev test train-100; do\n#     mkdir -p \"${data}\"/$noise_type/${dataset}\n#     echo \"Preparing files for\" $dataset\n#     wespeaker --task embedding_kaldi \\\n#               --wav_scp \"${data}\"/$noise_type/${dataset}/single.wav.scp \\\n#               --output_file \"${data}\"/$noise_type/${dataset}/embed \\\n#               -p wespeaker_models/voxceleb_resnet34 \\\n#               -g 0 # GPU idx\n#   done\n# fi\n\nif [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then\n  if [ ! -d \"${data}/raw_data/musan\" ]; then\n    mkdir -p ${data}/raw_data/musan\n    #\n    echo \"Downloading musan.tar.gz ...\"\n    echo \"This may take a long time. Thus we recommand you to download all archives above in your own way first.\"\n    wget --no-check-certificate https://openslr.elda.org/resources/17/musan.tar.gz -P ${data}/raw_data\n    md5=$(md5sum ${data}/raw_data/musan.tar.gz | awk '{print $1}')\n    [ $md5 != \"0c472d4fc0c5141eca47ad1ffeb2a7df\" ] && echo \"Wrong md5sum of musan.tar.gz\" && exit 1\n\n    echo \"Decompress all archives ...\"\n    tar -xzvf ${data}/raw_data/musan.tar.gz -C ${data}/raw_data\n\n    rm -rf ${data}/raw_data/musan.tar.gz\n  fi\n\n  echo \"Prepare wav.scp for musan ...\"\n  mkdir -p ${data}/musan\n  find ${data}/raw_data/musan -name \"*.wav\" | awk -F\"/\" '{print $(NF-2)\"/\"$(NF-1)\"/\"$NF,$0}' >${data}/musan/wav.scp\n\n  # Convert all musan data to LMDB\n  echo \"conver musan data to LMDB ...\"\n  python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb\nfi"
  },
  {
    "path": "examples/librimix/tse/v2/local/prepare_librimix_enroll.py",
    "content": "import json\nimport random\nfrom pathlib import Path\n\nfrom wesep.utils.datadir_writer import DatadirWriter\nfrom wesep.utils.utils import str2bool\n\n\ndef prepare_librimix_enroll(wav_scp,\n                            spk2utts,\n                            output_dir,\n                            num_spk=2,\n                            train=True,\n                            prefix=\"enroll_spk\"):\n    mixtures = []\n    with Path(wav_scp).open(\"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            if not line.strip():\n                continue\n            mixtureID = line.strip().split(maxsplit=1)[0]\n            mixtures.append(mixtureID)\n\n    with Path(spk2utts).open(\"r\", encoding=\"utf-8\") as f:\n        # {spkID: [(uid1, path1), (uid2, path2), ...]}\n        spk2utt = json.load(f)\n\n    with DatadirWriter(Path(output_dir)) as writer:\n        for mixtureID in mixtures:\n            # 100-121669-0004_3180-138043-0053\n            uttIDs = mixtureID.split(\"_\")\n            for spk in range(num_spk):\n                uttID = uttIDs[spk]\n                spkID = uttID.split(\"-\")[0]\n                if train:\n                    # For training, we choose the auxiliary signal on the fly.\n                    # Here we use the pattern f\"*{uttID} {spkID}\".\n                    writer[f\"{prefix}{spk + 1}.enroll\"][\n                        mixtureID] = f\"*{uttID} {spkID}\"\n                else:\n                    enrollID = random.choice(spk2utt[spkID])[1]\n                    while enrollID == uttID and len(spk2utt[spkID]) > 1:\n                        enrollID = random.choice(spk2utt[spkID])[1]\n                    writer[f\"{prefix}{spk + 1}.enroll\"][mixtureID] = enrollID\n\n\ndef prepare_librimix_enroll_v2(wav_scp,\n                               map_mix2enroll,\n                               output_dir,\n                               num_spk=2,\n                               prefix=\"spk\"):\n    # noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py\n    mixtures = []\n    with Path(wav_scp).open(\"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            if not line.strip():\n                continue\n            mixtureID = line.strip().split(maxsplit=1)[0]\n            mixtures.append(mixtureID)\n\n    mix2enroll = {}\n    with open(map_mix2enroll) as f:\n        for line in f:\n            mix_id, utt_id, enroll_id = line.strip().split()\n            sid = mix_id.split(\"_\").index(utt_id) + 1\n            mix2enroll[mix_id, f\"s{sid}\"] = enroll_id\n\n    with DatadirWriter(Path(output_dir)) as writer:\n        for mixtureID in mixtures:\n            # 100-121669-0004_3180-138043-0053\n            for spk in range(num_spk):\n                enroll_id = mix2enroll[mixtureID, f\"s{spk + 1}\"]\n                writer[f\"{prefix}{spk + 1}.enroll\"][mixtureID] = (enroll_id +\n                                                                  \".wav\")\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"wav_scp\",\n        type=str,\n        help=\"Path to the wav.scp file\",\n    )\n    parser.add_argument(\n        \"spk2utts\",\n        type=str,\n        help=\"Path to the json, mapping from speaker ID to utterances\",\n    )\n    parser.add_argument(\n        \"--num_spk\",\n        type=int,\n        default=2,\n        choices=(2, 3),\n        help=\"Number of 
speakers in each mixture sample\",\n    )\n    parser.add_argument(\n        \"--train\",\n        type=str2bool,\n        default=True,\n        help=\"Whether is the training set or not\",\n    )\n    parser.add_argument(\n        \"--mix2enroll\",\n        type=str,\n        default=None,\n        help=\"Path to the downloaded map_mixture2enrollment file. \"\n        \"If `train` is False, this value is required.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=1, help=\"Random seed\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        required=True,\n        help=\"Path to the directory for storing output files\",\n    )\n    parser.add_argument(\n        \"--outfile_prefix\",\n        type=str,\n        default=\"spk\",\n        help=\"Prefix of the output files\",\n    )\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n\n    if args.train:\n        prepare_librimix_enroll(\n            args.wav_scp,\n            args.spk2utts,\n            args.output_dir,\n            num_spk=args.num_spk,\n            train=args.train,\n            prefix=args.outfile_prefix,\n        )\n    else:\n        prepare_librimix_enroll_v2(\n            args.wav_scp,\n            args.mix2enroll,\n            args.output_dir,\n            num_spk=args.num_spk,\n            prefix=args.outfile_prefix,\n        )\n"
  },
  {
    "path": "examples/librimix/tse/v2/local/prepare_spk2enroll_librispeech.py",
    "content": "import json\nfrom collections import defaultdict\nfrom itertools import chain\nfrom pathlib import Path\n\nfrom wesep.utils.utils import str2bool\n\n\ndef get_spk2utt(paths, audio_format=\"flac\"):\n    spk2utt = defaultdict(list)\n    for path in paths:\n        for audio in Path(path).rglob(\"*.{}\".format(audio_format)):\n            readerID = audio.parent.parent.stem\n            uid = audio.stem\n            assert uid.split(\"-\")[0] == readerID, audio\n            spk2utt[readerID].append((uid, str(audio)))\n\n    return spk2utt\n\n\ndef get_spk2utt_librimix(paths, audio_format=\"flac\"):\n    spk2utt = defaultdict(list)\n    for path in paths:\n        for audio in chain(\n                Path(path).rglob(\"s1/*.{}\".format(audio_format)),\n                Path(path).rglob(\"s2/*.{}\".format(audio_format)),\n                Path(path).rglob(\"s3/*.{}\".format(audio_format)),\n        ):\n            spk_idx = int(audio.parent.stem[1:]) - 1\n            mix_uid = audio.stem\n            uid = mix_uid.split(\"_\")[spk_idx]\n            sid = uid.split(\"-\")[0]\n            spk2utt[sid].append((uid, str(audio)))\n\n    return spk2utt\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"audio_paths\",\n        type=str,\n        nargs=\"+\",\n        help=\"Paths to Librispeech subsets\",\n    )\n    parser.add_argument(\n        \"--is_librimix\",\n        type=str2bool,\n        default=False,\n        help=\"Whether the provided audio_paths points to LibriMix data\",\n    )\n    parser.add_argument(\n        \"--outfile\",\n        type=str,\n        default=\"spk2utt_tse.json\",\n        help=\"Path to the output spk2utt json file\",\n    )\n    parser.add_argument(\"--audio_format\", type=str, default=\"flac\")\n    args = parser.parse_args()\n\n    if args.is_librimix:\n        # use clean sources from LibriMix as enrollment\n        spk2utt = get_spk2utt_librimix(args.audio_paths,\n                                       audio_format=args.audio_format)\n    else:\n        # use Librispeech as enrollment\n        spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format)\n    outfile = Path(args.outfile)\n    outfile.parent.mkdir(parents=True, exist_ok=True)\n    with outfile.open(\"w\", encoding=\"utf-8\") as f:\n        json.dump(spk2utt, f, indent=4)\n"
  },
  {
    "path": "examples/librimix/tse/v2/path.sh",
    "content": "export PATH=$PWD:$PATH\n\n# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C\nexport PYTHONIOENCODING=UTF-8\nexport PYTHONPATH=../../../../:$PYTHONPATH\n"
  },
  {
    "path": "examples/librimix/tse/v2/run.sh",
    "content": "#!/bin/bash\n\n# Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn)\n\n. ./path.sh || exit 1\n\n# General configuration\nstage=-1\nstop_stage=-1\n\n# Data preparation related\ndata=data\nfs=16k\nmin_max=min\nnoise_type=\"clean\"\ndata_type=\"shard\" # shard/raw\nLibri2Mix_dir=/YourPATH/librimix/Libri2Mix\nmix_data_path=\"${Libri2Mix_dir}/wav${fs}/${min_max}\"\n\n# Training related\ngpus=\"[0]\"\nuse_gan_loss=false\nconfig=confs/bsrnn.yaml\nexp_dir=exp/BSRNN/no_spk_transform-multiply_fuse\nif [ -z \"${config}\" ] && [ -f \"${exp_dir}/config.yaml\" ]; then\n  config=\"${exp_dir}/config.yaml\"\nfi\n\n# TSE model initialization related\ncheckpoint=\n\n# Inferencing and scoring related\nsave_results=true\nuse_pesq=true\nuse_dnsmos=true\ndnsmos_use_gpu=true\n\n# Model average related\nnum_avg=10\n\n. tools/parse_options.sh || exit 1\n\nif [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then\n  echo \"Prepare datasets ...\"\n  ./local/prepare_data.sh --mix_data_path ${mix_data_path} \\\n    --data ${data} \\\n    --noise_type ${noise_type} \\\n    --stage 1 \\\n    --stop-stage 3\nfi\n\ndata=${data}/${noise_type}\n\nif [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then\n  echo \"Covert train and test data to ${data_type}...\"\n  for dset in train-100 dev test; do\n    #  for dset in train-360; do\n    python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \\\n      --num_threads 16 \\\n      --prefix shards \\\n      --shuffle \\\n      ${data}/$dset/wav.scp ${data}/$dset/utt2spk \\\n      ${data}/$dset/shards ${data}/$dset/shard.list\n  done\nfi\n\n\n\nif [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then\n  echo \"Start training ...\"\n  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')\n  if [ -z \"${checkpoint}\" ] && [ -f \"${exp_dir}/models/latest_checkpoint.pt\" ]; then\n    checkpoint=\"${exp_dir}/models/latest_checkpoint.pt\"\n  fi\n  if ${use_gan_loss}; then\n    train_script=wesep/bin/train_gan.py\n  else\n    train_script=wesep/bin/train.py\n  fi\n  export OMP_NUM_THREADS=8\n  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \\\n    ${train_script} --config $config \\\n    --exp_dir ${exp_dir} \\\n    --gpus $gpus \\\n    --num_avg ${num_avg} \\\n    --data_type \"${data_type}\" \\\n    --train_data ${data}/train-100/${data_type}.list \\\n    --train_utt2spk ${data}/train-100/single.utt2spk \\\n    --train_spk2utt ${data}/train-100/spk2enroll.json \\\n    --val_data ${data}/dev/${data_type}.list \\\n    --val_spk1_enroll ${data}/dev/spk1.enroll \\\n    --val_spk2_enroll ${data}/dev/spk2.enroll \\\n    --val_spk2utt ${data}/dev/single.wav.scp \\\n    ${checkpoint:+--checkpoint $checkpoint}\nfi\n\nif [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then\n  echo \"Do model average ...\"\n  avg_model=$exp_dir/models/avg_best_model.pt\n  python wesep/bin/average_model.py \\\n    --dst_model $avg_model \\\n    --src_path $exp_dir/models \\\n    --num ${num_avg} \\\n    --mode best \\\n    --epochs \"138,141\"\nfi\nif [ -z \"${checkpoint}\" ] && [ -f \"${exp_dir}/models/avg_best_model.pt\" ]; then\n  checkpoint=\"${exp_dir}/models/avg_best_model.pt\"\nfi\n\n\n# shellcheck disable=SC2215\nif [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then\n  echo \"Start inferencing ...\"\n  python wesep/bin/infer.py --config $config \\\n    --fs ${fs} \\\n    --gpus 0 \\\n    --exp_dir ${exp_dir} \\\n    --data_type \"${data_type}\" \\\n    --test_data ${data}/test/${data_type}.list \\\n    --test_spk1_enroll ${data}/test/spk1.enroll \\\n    --test_spk2_enroll 
${data}/test/spk2.enroll \\\n    --test_spk2utt ${data}/test/single.wav.scp \\\n    --save_wav ${save_results} \\\n    ${checkpoint:+--checkpoint $checkpoint}\nfi\n\nif [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then\n  echo \"Start scoring ...\"\n  ./tools/score.sh --dset \"${data}/test\" \\\n    --exp_dir \"${exp_dir}\" \\\n    --fs ${fs} \\\n    --use_pesq \"${use_pesq}\" \\\n    --use_dnsmos \"${use_dnsmos}\" \\\n    --dnsmos_use_gpu \"${dnsmos_use_gpu}\" \\\n    --n_gpu \"${num_gpus}\"\nfi\n"
  },
  {
    "path": "examples/voxceleb1/v2/confs/bsrnn_online.yaml",
    "content": "dataloader_args:\n  batch_size: 8\n  drop_last: true\n  num_workers: 6\n  pin_memory: false\n  prefetch_factor: 6\n\ndataset_args:\n  resample_rate: 16000\n  sample_num_per_epoch: 0\n  shuffle: true\n  shuffle_args:\n    shuffle_size: 2500\n  filter_len: true\n  filter_len_args:\n      min_num_seconds: 1.0\n      max_num_seconds: 100.0\n  chunk_len: 48000\n  online_mix: true\n  num_speakers: 2\n  use_random_snr: true\n  speaker_feat: &speaker_feat True\n  fbank_args:\n    num_mel_bins: 80\n    frame_shift: 10\n    frame_length: 25\n    dither: 1.0\n\n  noise_lmdb_file:  './data/musan/lmdb'\n  noise_prob: 0\n  reverb_prob: 0\n\nenable_amp: false\nexp_dir: exp/BSRNN\ngpus: '0,1'\nlog_batch_interval: 100\n\nloss: SISDR\nloss_args: { }\n\nmodel:\n  tse_model: BSRNN\nmodel_args:\n  tse_model:\n    sr: 16000\n    win: 512\n    stride: 128\n    feature_dim: 128\n    num_repeat: 6\n    spk_fuse_type: 'multiply'\n    use_spk_transform: False\n    multi_fuse: False        # Fuse the speaker embedding multiple times.\n    joint_training: True\n    ####### ResNet    The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md\n    spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152\n    spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt\n    spk_args:\n      feat_dim: 80\n      embed_dim: &embed_dim 256\n      pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n      two_emb_layer: False\n    ####### Ecapa_TDNN\n    # spk_model: ECAPA_TDNN_GLOB_c512\n    # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt\n    # spk_args:\n    #   embed_dim: &embed_dim 192\n    #   feat_dim: 80\n    #   pooling_func: ASTP\n    ####### CAMPPlus\n    # spk_model: CAMPPlus\n    # spk_model_init: False\n    # spk_args:\n    #   feat_dim: 80\n    #   embed_dim: &embed_dim 192\n    #   pooling_func: \"TSTP\" # TSTP, ASTP, MQMHASTP\n    #################################################################\n    spk_emb_dim: *embed_dim\n    spk_model_freeze: False    # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder\n    spk_feat: *speaker_feat   #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used\n    feat_type: \"consistent\"\n    multi_task: False\n    spksInTrain: 251    #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921\n\n# find_unused_parameters: True\n\nmodel_init:\n  tse_model: null\n  discriminator: null\nnum_avg: 2\nnum_epochs: 150\n\noptimizer:\n  tse_model: Adam\noptimizer_args:\n  tse_model:\n    lr: 0.001\n    weight_decay: 0.0001\n\nclip_grad: 5.0\nsave_epoch_interval: 1\n\nscheduler:\n  tse_model: ExponentialDecrease\nscheduler_args:\n  tse_model:\n    final_lr: 2.5e-05\n    initial_lr: 0.001\n    warm_from_zero: false\n    warm_up_epoch: 0\n\nseed: 42\n"
  },
  {
    "path": "examples/voxceleb1/v2/local/prepare_data.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)\n\nstage=-1\nstop_stage=-1\n\nsingle_data_path='./voxceleb/VoxCeleb1/wav/'\nmix_data_path='./Libri2Mix/wav16k/min/'\n\ndata=data\nnoise_type=clean\nnum_spk=2\n\n. tools/parse_options.sh || exit 1\n\ndata=$(realpath ${data})\n\nif [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then\n  echo \"Prepare the meta files for the vox1 single speaker datasets\"\n\n  for dataset in train-vox1; do\n    echo \"Preparing files for\" $dataset\n\n    # Prepare the meta data for the online mix data\n    mkdir -p \"${data}\"/$noise_type/${dataset}\n    find ${single_data_path} -name \"*.wav\" | awk -F\"/\" '{print $(NF-2)\"/\"$(NF-1)\"/\"$NF,$0}' | sort >\"${data}\"/$noise_type/${dataset}/wav.scp\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/wav.scp | awk -F \"/\" '{print $0,$1}' >\"${data}\"/$noise_type/${dataset}/utt2spk\n\n    python local/prepare_spk2enroll_vox.py \\\n      \"${data}/$noise_type/${dataset}/wav.scp\" \\\n      --outfile \"${data}\"/$noise_type/${dataset}/spk2enroll.json\n  done\nfi\n\nif [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then\n  echo \"Prepare the meta files for the val and test datasets\"\n\n  for dataset in dev test; do\n    echo \"Preparing files for\" $dataset\n\n    # Prepare the meta data for the mixed data\n    dataset_path=$mix_data_path/$dataset/mix_${noise_type}\n    mkdir -p \"${data}\"/$noise_type/${dataset}\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print $NF}' |\n      awk -v path=\"${dataset_path}\" '{print $1 , path \"/\" $1 , path \"/../s1/\" $1 , path \"/../s2/\" $1}' |\n      sed 's#.wav##' | sort -k1,1 >\"${data}\"/$noise_type/${dataset}/wav.scp\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/wav.scp |\n      awk -F[_-] '{print $0, $1,$4}' >\"${data}\"/$noise_type/${dataset}/utt2spk\n\n    # Prepare the meta data for single speakers\n    dataset_path=$mix_data_path/$dataset/s1\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print \"s1/\" $NF, $0}' | sort -k1,1 >\"${data}\"/$noise_type/${dataset}/single.wav.scp\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/single.wav.scp | grep 's1' |\n      awk -F[-_/] '{print $0, $2}' >\"${data}\"/$noise_type/${dataset}/single.utt2spk\n\n    dataset_path=$mix_data_path/$dataset/s2\n    find ${dataset_path}/ -type f -name \"*.wav\" | awk -F/ '{print \"s2/\" $NF, $0}' | sort -k1,1 >>\"${data}\"/$noise_type/${dataset}/single.wav.scp\n\n    awk '{print $1}' \"${data}\"/$noise_type/${dataset}/single.wav.scp | grep 's2' |\n      awk -F[-_/] '{print $0, $5}' >>\"${data}\"/$noise_type/${dataset}/single.utt2spk\n  done\nfi\n\nif [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then\n  echo \"stage 3: Prepare LibriMix target-speaker enroll signal\"\n\n  for dset in dev test; do\n    python local/prepare_spk2enroll_librispeech.py \\\n      \"${mix_data_path}/${dset}\" \\\n      --is_librimix True \\\n      --outfile \"${data}\"/$noise_type/${dset}/spk2enroll.json \\\n      --audio_format wav\n  done\n\n  for dset in dev test; do\n    if [ $num_spk -eq 2 ]; then\n      url=\"https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri2mix/data/wav8k/min/${dset}/map_mixture2enrollment\"\n    else\n      url=\"https://raw.githubusercontent.com/BUTSpeechFIT/speakerbeam/main/egs/libri3mix/data/wav8k/min/${dset}/map_mixture2enrollment\"\n    fi\n\n    output_file=\"${data}/${noise_type}/${dset}/mixture2enrollment\"\n    wget -O \"$output_file\" \"$url\"\n  done\n\n  for dset in 
dev test; do\n    python local/prepare_librimix_enroll.py \\\n      \"${data}\"/$noise_type/${dset}/wav.scp \\\n      \"${data}\"/$noise_type/${dset}/spk2enroll.json \\\n      --mix2enroll \"${data}/${noise_type}/${dset}/mixture2enrollment\" \\\n      --num_spk ${num_spk} \\\n      --train False \\\n      --output_dir \"${data}\"/${noise_type}/${dset} \\\n      --outfile_prefix \"spk\"\n  done\nfi\n\nif [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then\n  echo \"Download the pre-trained speaker encoders (Resnet34 & Ecapa-TDNN512) from wespeaker...\"\n  mkdir wespeaker_models\n  wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip\n  unzip voxceleb_resnet34.zip -d wespeaker_models\n  wget https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip\n  unzip voxceleb_ECAPA512.zip -d wespeaker_models\nfi\n\n# if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then\n#   echo \"Prepare the speaker embeddings using wespeaker pretrained models\"\n#   for dataset in dev test train-100; do\n#     mkdir -p \"${data}\"/$noise_type/${dataset}\n#     echo \"Preparing files for\" $dataset\n#     wespeaker --task embedding_kaldi \\\n#               --wav_scp \"${data}\"/$noise_type/${dataset}/single.wav.scp \\\n#               --output_file \"${data}\"/$noise_type/${dataset}/embed \\\n#               -p wespeaker_models/voxceleb_resnet34 \\\n#               -g 0 # GPU idx\n#   done\n# fi\n\n\nif [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then\n  if [ ! -d \"${data}/raw_data/musan\" ]; then\n    mkdir -p ${data}/raw_data/musan\n    #\n    echo \"Downloading musan.tar.gz ...\"\n    echo \"This may take a long time. Thus we recommand you to download all archives above in your own way first.\"\n    wget --no-check-certificate https://openslr.elda.org/resources/17/musan.tar.gz -P ${data}/raw_data\n    md5=$(md5sum ${data}/raw_data/musan.tar.gz | awk '{print $1}')\n    [ $md5 != \"0c472d4fc0c5141eca47ad1ffeb2a7df\" ] && echo \"Wrong md5sum of musan.tar.gz\" && exit 1\n\n    echo \"Decompress all archives ...\"\n    tar -xzvf ${data}/raw_data/musan.tar.gz -C ${data}/raw_data\n\n    rm -rf ${data}/raw_data/musan.tar.gz\n  fi\n\n  echo \"Prepare wav.scp for musan ...\"\n  mkdir -p ${data}/musan\n  find ${data}/raw_data/musan -name \"*.wav\" | awk -F\"/\" '{print $(NF-2)\"/\"$(NF-1)\"/\"$NF,$0}' >${data}/musan/wav.scp\n\n  # Convert all musan data to LMDB\n  echo \"conver musan data to LMDB ...\"\n  python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb\nfi"
  },
  {
    "path": "examples/voxceleb1/v2/local/prepare_librimix_enroll.py",
    "content": "import json\nimport random\nfrom pathlib import Path\n\nfrom wesep.utils.datadir_writer import DatadirWriter\nfrom wesep.utils.utils import str2bool\n\n\ndef prepare_librimix_enroll(wav_scp,\n                            spk2utts,\n                            output_dir,\n                            num_spk=2,\n                            train=True,\n                            prefix=\"enroll_spk\"):\n    mixtures = []\n    with Path(wav_scp).open(\"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            if not line.strip():\n                continue\n            mixtureID = line.strip().split(maxsplit=1)[0]\n            mixtures.append(mixtureID)\n\n    with Path(spk2utts).open(\"r\", encoding=\"utf-8\") as f:\n        # {spkID: [(uid1, path1), (uid2, path2), ...]}\n        spk2utt = json.load(f)\n\n    with DatadirWriter(Path(output_dir)) as writer:\n        for mixtureID in mixtures:\n            # 100-121669-0004_3180-138043-0053\n            uttIDs = mixtureID.split(\"_\")\n            for spk in range(num_spk):\n                uttID = uttIDs[spk]\n                spkID = uttID.split(\"-\")[0]\n                if train:\n                    # For training, we choose the auxiliary signal on the fly.\n                    # Thus, here we use the pattern f\"*{uttID} {spkID}\" to indicate it.  # noqa\n                    writer[f\"{prefix}{spk + 1}.enroll\"][\n                        mixtureID] = f\"*{uttID} {spkID}\"\n                else:\n                    enrollID = random.choice(spk2utt[spkID])[1]\n                    while enrollID == uttID and len(spk2utt[spkID]) > 1:\n                        enrollID = random.choice(spk2utt[spkID])[1]\n                    writer[f\"{prefix}{spk + 1}.enroll\"][mixtureID] = enrollID\n\n\ndef prepare_librimix_enroll_v2(wav_scp,\n                               map_mix2enroll,\n                               output_dir,\n                               num_spk=2,\n                               prefix=\"spk\"):\n    # noqa E501: ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/egs/libri2mix/local/create_enrollment_csv_fixed.py\n    mixtures = []\n    with Path(wav_scp).open(\"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            if not line.strip():\n                continue\n            mixtureID = line.strip().split(maxsplit=1)[0]\n            mixtures.append(mixtureID)\n\n    mix2enroll = {}\n    with open(map_mix2enroll) as f:\n        for line in f:\n            mix_id, utt_id, enroll_id = line.strip().split()\n            sid = mix_id.split(\"_\").index(utt_id) + 1\n            mix2enroll[mix_id, f\"s{sid}\"] = enroll_id\n\n    with DatadirWriter(Path(output_dir)) as writer:\n        for mixtureID in mixtures:\n            # 100-121669-0004_3180-138043-0053\n            for spk in range(num_spk):\n                enroll_id = mix2enroll[mixtureID, f\"s{spk + 1}\"]\n                writer[f\"{prefix}{spk + 1}.enroll\"][mixtureID] = (enroll_id +\n                                                                  \".wav\")\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"wav_scp\",\n        type=str,\n        help=\"Path to the wav.scp file\",\n    )\n    parser.add_argument(\"spk2utts\",\n                        type=str,\n                        help=\"Path to the json file containing mapping \"\n                        \"from speaker ID to utterances\")\n    parser.add_argument(\n        \"--num_spk\",\n  
      type=int,\n        default=2,\n        choices=(2, 3),\n        help=\"Number of speakers in each mixture sample\",\n    )\n    parser.add_argument(\n        \"--train\",\n        type=str2bool,\n        default=True,\n        help=\"Whether is the training set or not\",\n    )\n    parser.add_argument(\n        \"--mix2enroll\",\n        type=str,\n        default=None,\n        help=\"Path to the downloaded map_mixture2enrollment file. \"\n        \"If `train` is False, this value is required.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=1, help=\"Random seed\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        required=True,\n        help=\"Path to the directory for storing output files\",\n    )\n    parser.add_argument(\n        \"--outfile_prefix\",\n        type=str,\n        default=\"spk\",\n        help=\"Prefix of the output files\",\n    )\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n\n    if args.train:\n        prepare_librimix_enroll(\n            args.wav_scp,\n            args.spk2utts,\n            args.output_dir,\n            num_spk=args.num_spk,\n            train=args.train,\n            prefix=args.outfile_prefix,\n        )\n    else:\n        prepare_librimix_enroll_v2(\n            args.wav_scp,\n            args.mix2enroll,\n            args.output_dir,\n            num_spk=args.num_spk,\n            prefix=args.outfile_prefix,\n        )\n"
  },
  {
    "path": "examples/voxceleb1/v2/local/prepare_spk2enroll_librispeech.py",
    "content": "import json\nfrom collections import defaultdict\nfrom itertools import chain\nfrom pathlib import Path\n\nfrom wesep.utils.utils import str2bool\n\n\ndef get_spk2utt_vox1(paths, audio_format=\"flac\"):\n    spk2utt = defaultdict(list)\n    for path in paths:\n        path = Path(path)\n        for subdir in path.iterdir():\n            if subdir.is_dir():\n                readerID = subdir.name\n                for audio in subdir.rglob(\"*.{}\".format(audio_format)):\n                    spk2utt[readerID].append(str(audio))\n\n    return spk2utt\n\n\ndef get_spk2utt(paths, audio_format=\"flac\"):\n    spk2utt = defaultdict(list)\n    for path in paths:\n        for audio in Path(path).rglob(\"*.{}\".format(audio_format)):\n            readerID = audio.parent.parent.stem\n            uid = audio.stem\n            assert uid.split(\"-\")[0] == readerID, audio\n            spk2utt[readerID].append((uid, str(audio)))\n\n    return spk2utt\n\n\ndef get_spk2utt_librimix(paths, audio_format=\"flac\"):\n    spk2utt = defaultdict(list)\n    for path in paths:\n        for audio in chain(\n                Path(path).rglob(\"s1/*.{}\".format(audio_format)),\n                Path(path).rglob(\"s2/*.{}\".format(audio_format)),\n                Path(path).rglob(\"s3/*.{}\".format(audio_format)),\n        ):\n            spk_idx = int(audio.parent.stem[1:]) - 1\n            mix_uid = audio.stem\n            uid = mix_uid.split(\"_\")[spk_idx]\n            sid = uid.split(\"-\")[0]\n            spk2utt[sid].append((uid, str(audio)))\n\n    return spk2utt\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"audio_paths\",\n        type=str,\n        nargs=\"+\",\n        help=\"Paths to Librispeech subsets\",\n    )\n    parser.add_argument(\n        \"--is_librimix\",\n        type=str2bool,\n        default=False,\n        help=\"Whether the provided audio_paths points to LibriMix data\",\n    )\n    parser.add_argument(\n        \"--is_vox1\",\n        type=str2bool,\n        default=False,\n        help=\"Whether the provided audio_paths points to vox1 data\",\n    )\n    parser.add_argument(\n        \"--outfile\",\n        type=str,\n        default=\"spk2utt_tse.json\",\n        help=\"Path to the output spk2utt json file\",\n    )\n    parser.add_argument(\"--audio_format\", type=str, default=\"flac\")\n    args = parser.parse_args()\n\n    if args.is_librimix:\n        # use clean sources from LibriMix as enrollment\n        spk2utt = get_spk2utt_librimix(args.audio_paths,\n                                       audio_format=args.audio_format)\n    elif args.is_vox1:\n        spk2utt = get_spk2utt_vox1(args.audio_paths,\n                                   audio_format=args.audio_format)\n    else:\n        # use Librispeech as enrollment\n        spk2utt = get_spk2utt(args.audio_paths, audio_format=args.audio_format)\n    outfile = Path(args.outfile)\n    outfile.parent.mkdir(parents=True, exist_ok=True)\n    with outfile.open(\"w\", encoding=\"utf-8\") as f:\n        json.dump(spk2utt, f, indent=4)\n"
  },
  {
    "path": "examples/voxceleb1/v2/local/prepare_spk2enroll_vox.py",
    "content": "import json\nfrom collections import defaultdict\nfrom pathlib import Path\n\n\ndef get_spk2utt_from_wavscp(wav_scp_path):\n    spk2utt = defaultdict(list)\n    with open(wav_scp_path, \"r\") as readin:\n        for line in readin:\n            speaker_id = line.split(\"/\")[0]\n            uid, audio_path = line.strip().split()\n            spk2utt[speaker_id].append((uid, str(audio_path)))\n\n    return spk2utt\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"wav_scp_path\",\n        type=str,\n        help=\"Paths to Librispeech subsets\",\n    )\n    parser.add_argument(\n        \"--outfile\",\n        type=str,\n        default=\"spk2utt_tse.json\",\n        help=\"Path to the output spk2utt json file\",\n    )\n    args = parser.parse_args()\n\n    spk2utt = get_spk2utt_from_wavscp(args.wav_scp_path)\n\n    outfile = Path(args.outfile)\n    outfile.parent.mkdir(parents=True, exist_ok=True)\n    with outfile.open(\"w\", encoding=\"utf-8\") as f:\n        json.dump(spk2utt, f)\n"
  },
  {
    "path": "examples/voxceleb1/v2/path.sh",
    "content": "export PATH=$PWD:$PATH\n\n# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C\nexport PYTHONIOENCODING=UTF-8\nexport PYTHONPATH=../../../:$PYTHONPATH\n"
  },
  {
    "path": "examples/voxceleb1/v2/run_online.sh",
    "content": "#!/bin/bash\n\n# Copyright 2023 Shuai Wang (wangshuai@cuhk.edu.cn)\n\n. ./path.sh || exit 1\n\nstage=-1\nstop_stage=-1\n\ndata=data\nfs=16k\nmin_max=min\nnoise_type=\"clean\"\ndata_type=\"shard\"  # shard/raw\nVox1_dir=/YourPATH/voxceleb/VoxCeleb1/wav\nLibri2Mix_dir=/YourPATH/librimix/Libri2Mix          #For validate and test the TSE model.\nmix_data_path=\"${Libri2Mix_dir}/wav${fs}/${min_max}\"\n\ngpus=\"[0,1]\"\nnum_avg=10\ncheckpoint=\nconfig=confs/bsrnn_online.yaml\nexp_dir=exp/BSRNN_Online/no_spk_transform_multiply\nsave_results=true\n\n\n. tools/parse_options.sh || exit 1\n\n\nif [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then\n  echo \"Prepare datasets ...\"\n  ./local/prepare_data.sh --single_data_path ${Vox1_dir} \\\n    --mix_data_path ${mix_data_path} \\\n    --data ${data} \\\n    --noise_type ${noise_type} \\\n    --stage 1 \\\n    --stop-stage 4\nfi\n\ndata=${data}/${noise_type}\n\nif [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then\n  echo \"Covert train and test data to ${data_type}...\"\n  for dset in train-vox1; do\n    python tools/make_shard_online.py --num_utts_per_shard 1000 \\\n        --num_threads 16 \\\n        --prefix shards \\\n        --shuffle \\\n        ${data}/$dset/wav.scp ${data}/$dset/utt2spk \\\n        ${data}/$dset/shards_online ${data}/$dset/shard_online.list\n  done\n  for dset in dev test; do\n    python tools/make_shard_list_premix.py --num_utts_per_shard 1000 \\\n      --num_threads 16 \\\n      --prefix shards \\\n      --shuffle \\\n      ${data}/$dset/wav.scp ${data}/$dset/utt2spk \\\n      ${data}/$dset/shards ${data}/$dset/shard.list\n  done\nfi\n\nif [ -z \"${checkpoint}\" ] && [ -f \"${exp_dir}/models/latest_checkpoint.pt\" ]; then\n  checkpoint=\"${exp_dir}/models/latest_checkpoint.pt\"\nfi\n\nif [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then\n  #  rm -r $exp_dir\n  echo \"Start training ...\"\n  export OMP_NUM_THREADS=8\n  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')\n  if [ -z \"${checkpoint}\" ] && [ -f \"${exp_dir}/models/latest_checkpoint.pt\" ]; then\n    checkpoint=\"${exp_dir}/models/latest_checkpoint.pt\"\n  fi\n  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \\\n    wesep/bin/train.py --config $config \\\n    --exp_dir ${exp_dir} \\\n    --gpus $gpus \\\n    --num_avg ${num_avg} \\\n    --data_type \"${data_type}\" \\\n    --train_data ${data}/train-vox1/${data_type}_online.list \\\n    --train_utt2spk ${data}/train-vox1/utt2spk \\\n    --train_spk2utt ${data}/train-vox1/spk2enroll.json \\\n    --val_data ${data}/dev/${data_type}.list \\\n    --val_spk2utt ${data}/dev/single.wav.scp \\\n    --val_spk1_enroll ${data}/dev/spk1.enroll \\\n    --val_spk2_enroll ${data}/dev/spk2.enroll \\\n    ${checkpoint:+--checkpoint $checkpoint}\nfi\n\n\nif [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then\n  echo \"Do model average ...\"\n  avg_model=$exp_dir/models/avg_best_model.pt\n  python wesep/bin/average_model.py \\\n    --dst_model $avg_model \\\n    --src_path $exp_dir/models \\\n    --num ${num_avg} \\\n    --mode best \\\n    --epochs \"138,141\"\nfi\nif [ -z \"${checkpoint}\" ] && [ -f \"${exp_dir}/models/avg_best_model.pt\" ]; then\n  checkpoint=\"${exp_dir}/models/avg_best_model.pt\"\nfi\n\nif [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then\n  python wesep/bin/infer.py --config $config \\\n      --gpus 0 \\\n      --exp_dir ${exp_dir} \\\n      --data_type \"${data_type}\" \\\n      --test_data ${data}/test/${data_type}.list \\\n      --test_spk1_enroll ${data}/test/spk1.enroll 
\\\n      --test_spk2_enroll ${data}/test/spk2.enroll \\\n      --test_spk2utt ${data}/test/single.wav.scp \\\n      --save_wav ${save_results} \\\n      ${checkpoint:+--checkpoint $checkpoint}\nfi\n\n#./run.sh --stage 4 --stop-stage 4 --config exp/BSRNN/train_clean_460/multiply_no_spk_transform/config.yaml  --exp_dir exp/BSRNN/train_clean_460/multiply_no_spk_transform/ --checkpoint exp/BSRNN/train_clean_460/multiply_no_spk_transform/models/avg_best_model.pt\n"
  },
  {
    "path": "requirements.txt",
    "content": "fast_bss_eval==0.1.4\nfire==0.4.0\njoblib==1.1.0\nkaldiio==2.18.0\nlibrosa==0.10.1\nlmdb==1.3.0\nmatplotlib==3.5.1\nmir_eval==0.7\nsilero-vad==5.1.2\nnumpy==1.22.4\npesq==0.0.4\npystoi==0.3.3\nPyYAML==6.0\nRequests==2.31.0\nscipy==1.7.3\nsoundfile==0.12.1\ntableprint==0.9.1\nthop==0.1.1.post2209072238\ntorchnet==0.0.4\ntqdm==4.64.0\nflake8==3.8.2\nflake8-bugbear\nflake8-comprehensions\nflake8-executable\nflake8-pyi==20.5.0\nauraloss\ntorchmetrics==1.2.0\nh5py\npre-commit==3.5.0\n"
  },
  {
    "path": "runtime/.gitignore",
    "content": "fc_base\nbuild*\n"
  },
  {
    "path": "runtime/CMakeLists.txt",
    "content": "\ncmake_minimum_required(VERSION 3.14)\nproject(wesep VERSION 0.1)\n\noption(CXX11_ABI \"whether to use CXX11_ABI libtorch\" OFF)\n\nset(CMAKE_VERBOSE_MAKEFILE OFF)\n\ninclude(FetchContent)\nset(FETCHCONTENT_QUIET OFF)\nget_filename_component(fc_base \"fc_base\" REALPATH BASE_DIR \"${CMAKE_CURRENT_SOURCE_DIR}\")\nset(FETCHCONTENT_BASE_DIR ${fc_base})\n\nlist(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)\n\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC\")\n\ninclude(libtorch)\ninclude(glog)\ninclude(gflags)\ninclude_directories(${CMAKE_CURRENT_SOURCE_DIR})\n\n# build all libraries\nadd_subdirectory(utils)\nadd_subdirectory(frontend)\nadd_subdirectory(separate)\nadd_subdirectory(bin)\n"
  },
  {
    "path": "runtime/README.md",
    "content": "# Libtorch backend on wesep\n\n\n* Build. The build requires cmake 3.14 or above, and gcc/g++ 5.4 or above.\n\n``` sh\nmkdir build && cd build\ncmake ..\ncmake --build .\n```\n\n* Testing.\n\n1. the RTF(real time factor) is shown in the console, and outputs will be written to the wav file.\n``` sh\nexport GLOG_logtostderr=1\nexport GLOG_v=2\n\n./build/bin/separate_main \\\n  --wav_scp $wav_scp \\\n  --model /path/to/model.zip \\\n  --output_dir /output/dir/\n```\n"
  },
  {
    "path": "runtime/bin/CMakeLists.txt",
    "content": "add_executable(separate_main separate_main.cc)\ntarget_link_libraries(separate_main PUBLIC frontend separate)\n"
  },
  {
    "path": "runtime/bin/separate_main.cc",
    "content": "// Copyright (c) 2024 wesep team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <fstream>\n#include <iostream>\n#include <string>\n\n#include \"gflags/gflags.h\"\n#include \"glog/logging.h\"\n\n#include \"frontend/wav.h\"\n#include \"separate/separate_engine.h\"\n#include \"utils/timer.h\"\n#include \"utils/utils.h\"\n\nDEFINE_string(wav_path, \"\", \"the path of mixing audio.\");\nDEFINE_string(spk1_emb, \"\", \"the emb of spk1.\");\nDEFINE_string(spk2_emb, \"\", \"the emb of spk2.\");\nDEFINE_string(wav_scp, \"\", \"input wav scp.\");\nDEFINE_string(model, \"\", \"the path of wesep model.\");\nDEFINE_string(output_dir, \"\", \"output path.\");\nDEFINE_int32(sample_rate, 16000, \"sample rate\");\nDEFINE_int32(feat_dim, 80, \"fbank feature dimension.\");\n\nint main(int argc, char* argv[]) {\n  gflags::ParseCommandLineFlags(&argc, &argv, false);\n  google::InitGoogleLogging(argv[0]);\n\n  std::vector<std::vector<std::string>> waves;\n  if (!FLAGS_wav_path.empty() && !FLAGS_spk1_emb.empty() &&\n      !FLAGS_spk2_emb.empty()) {\n    waves.push_back(std::vector<std::string>(\n        {\"test\", FLAGS_wav_path, FLAGS_spk1_emb, FLAGS_spk2_emb}));\n  } else {\n    std::ifstream wav_scp(FLAGS_wav_scp);\n    std::string line;\n    while (getline(wav_scp, line)) {\n      std::vector<std::string> strs;\n      wesep::SplitString(line, &strs);\n      CHECK_EQ(strs.size(), 4);\n      waves.push_back(\n          std::vector<std::string>({strs[0], strs[1], strs[2], strs[3]}));\n    }\n    if (waves.empty()) {\n      LOG(FATAL) << \"Please provide non-empty wav scp.\";\n    }\n  }\n\n  if (FLAGS_output_dir.empty()) {\n    LOG(FATAL) << \"Invalid output path.\";\n  }\n\n  int g_total_waves_dur = 0;\n  int g_total_process_time = 0;\n\n  auto model = std::make_shared<wesep::SeparateEngine>(\n      FLAGS_model, FLAGS_feat_dim, FLAGS_sample_rate);\n\n  for (auto wav : waves) {\n    // mix wav\n    wenet::WavReader wav_reader(wav[1]);\n    CHECK_EQ(wav_reader.sample_rate(), 16000);\n    int16_t* mix_wav_data = const_cast<int16_t*>(wav_reader.data());\n\n    int wave_dur =\n        static_cast<int>(static_cast<float>(wav_reader.num_sample()) /\n                         wav_reader.sample_rate() * 1000);\n\n    // spk1\n    wenet::WavReader spk1_reader(wav[2]);\n    CHECK_EQ(spk1_reader.sample_rate(), 16000);\n    int16_t* spk1_data = const_cast<int16_t*>(spk1_reader.data());\n\n    // spk2\n    wenet::WavReader spk2_reader(wav[3]);\n    CHECK_EQ(spk2_reader.sample_rate(), 16000);\n    int16_t* spk2_data = const_cast<int16_t*>(spk2_reader.data());\n\n    // forward\n    std::vector<std::vector<float>> outputs;\n    int process_time = 0;\n    wenet::Timer timer;\n    model->ForwardFunc(\n        std::vector<int16_t>(mix_wav_data,\n                             mix_wav_data + wav_reader.num_sample()),\n        spk1_data, spk2_data,\n        std::min(spk1_reader.num_sample(), spk2_reader.num_sample()), &outputs);\n    process_time = timer.Elapsed();\n    
LOG(INFO) << \"process: \" << wav[0]\n              << \" RTF: \" << static_cast<float>(process_time) / wave_dur;\n    // 保存音频\n    wenet::WriteWavFile(outputs[0].data(), outputs[0].size(), 16000,\n                        FLAGS_output_dir + \"/\" + wav[0] + \"-spk1.wav\");\n    wenet::WriteWavFile(outputs[1].data(), outputs[1].size(), 16000,\n                        FLAGS_output_dir + \"/\" + wav[0] + \"-spk2.wav\");\n    g_total_process_time += process_time;\n    g_total_waves_dur += wave_dur;\n  }\n  LOG(INFO) << \"Total: process \" << g_total_waves_dur << \"ms audio taken \"\n            << g_total_process_time << \"ms.\";\n  LOG(INFO) << \"RTF: \" << std::setprecision(4)\n            << static_cast<float>(g_total_process_time) / g_total_waves_dur;\n  return 0;\n}\n"
  },
  {
    "path": "runtime/cmake/gflags.cmake",
    "content": "FetchContent_Declare(gflags\n  URL      https://github.com/gflags/gflags/archive/v2.2.2.zip\n  URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5\n)\nFetchContent_MakeAvailable(gflags)\ninclude_directories(${gflags_BINARY_DIR}/include)\n"
  },
  {
    "path": "runtime/cmake/glog.cmake",
    "content": "FetchContent_Declare(glog\n  URL      https://github.com/google/glog/archive/v0.4.0.zip\n  URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc\n)\nFetchContent_MakeAvailable(glog)\ninclude_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})\n"
  },
  {
    "path": "runtime/cmake/libtorch.cmake",
    "content": "if(${CMAKE_SYSTEM_NAME} STREQUAL \"Linux\")\n  if(CXX11_ABI)\n    set(LIBTORCH_URL \"https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcpu.zip\")\n    set(URL_HASH \"SHA256=d52f63577a07adb0bfd6d77c90f7da21896e94f71eb7dcd55ed7835ccb3b2b59\")\n  else()\n    set(LIBTORCH_URL \"https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip\")\n    set(URL_HASH \"SHA256=bee1b7be308792aa60fc95a4f5274d9658cb7248002d0e333d49eb81ec88430c\")\n  endif()\nelse()\n  message(FATAL_ERROR \"Unsported System '${CMAKE_SYSTEM_NAME}' (expected 'Linux')\")\nendif()\n\nFetchContent_Declare(libtorch\n  URL         ${LIBTORCH_URL}\n  URL_HASH    ${URL_HASH}\n)\nFetchContent_MakeAvailable(libtorch)\nfind_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)\ninclude_directories(${TORCH_INCLUDE_DIRS})\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG\")\n"
  },
  {
    "path": "runtime/frontend/CMakeLists.txt",
    "content": "add_library(frontend STATIC\n  feature_pipeline.cc\n  fft.cc\n)\ntarget_link_libraries(frontend PUBLIC utils)\n"
  },
  {
    "path": "runtime/frontend/fbank.h",
    "content": "// Copyright (c) 2017 Personal (Binbin Zhang)\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#ifndef FRONTEND_FBANK_H_\n#define FRONTEND_FBANK_H_\n\n#include <cstring>\n#include <limits>\n#include <random>\n#include <utility>\n#include <vector>\n\n#include \"frontend/fft.h\"\n#include \"glog/logging.h\"\n\nnamespace wenet {\n\n// This code is based on kaldi Fbank implentation, please see\n// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc\nclass Fbank {\n public:\n  Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)\n      : num_bins_(num_bins),\n        sample_rate_(sample_rate),\n        frame_length_(frame_length),\n        frame_shift_(frame_shift),\n        use_log_(true),\n        remove_dc_offset_(true),\n        generator_(0),\n        distribution_(0, 1.0),\n        dither_(0.0) {\n    fft_points_ = UpperPowerOfTwo(frame_length_);\n    // generate bit reversal table and trigonometric function table\n    const int fft_points_4 = fft_points_ / 4;\n    bitrev_.resize(fft_points_);\n    sintbl_.resize(fft_points_ + fft_points_4);\n    make_sintbl(fft_points_, sintbl_.data());\n    make_bitrev(fft_points_, bitrev_.data());\n\n    int num_fft_bins = fft_points_ / 2;\n    float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;\n    int low_freq = 20, high_freq = sample_rate_ / 2;\n    float mel_low_freq = MelScale(low_freq);\n    float mel_high_freq = MelScale(high_freq);\n    float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);\n    bins_.resize(num_bins_);\n    center_freqs_.resize(num_bins_);\n    for (int bin = 0; bin < num_bins; ++bin) {\n      float left_mel = mel_low_freq + bin * mel_freq_delta,\n            center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,\n            right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;\n      center_freqs_[bin] = InverseMelScale(center_mel);\n      std::vector<float> this_bin(num_fft_bins);\n      int first_index = -1, last_index = -1;\n      for (int i = 0; i < num_fft_bins; ++i) {\n        float freq = (fft_bin_width * i);  // Center frequency of this fft\n        // bin.\n        float mel = MelScale(freq);\n        if (mel > left_mel && mel < right_mel) {\n          float weight;\n          if (mel <= center_mel)\n            weight = (mel - left_mel) / (center_mel - left_mel);\n          else\n            weight = (right_mel - mel) / (right_mel - center_mel);\n          this_bin[i] = weight;\n          if (first_index == -1) first_index = i;\n          last_index = i;\n        }\n      }\n      CHECK(first_index != -1 && last_index >= first_index);\n      bins_[bin].first = first_index;\n      int size = last_index + 1 - first_index;\n      bins_[bin].second.resize(size);\n      for (int i = 0; i < size; ++i) {\n        bins_[bin].second[i] = this_bin[first_index + i];\n      }\n    }\n\n    // NOTE(cdliang): add hamming window\n    hamming_window_.resize(frame_length_);\n    double a = M_2PI / (frame_length - 1);\n    
for (int i = 0; i < frame_length; i++) {\n      double i_fl = static_cast<double>(i);\n      hamming_window_[i] = 0.54 - 0.46 * cos(a * i_fl);\n    }\n  }\n\n  void set_use_log(bool use_log) { use_log_ = use_log; }\n\n  void set_remove_dc_offset(bool remove_dc_offset) {\n    remove_dc_offset_ = remove_dc_offset;\n  }\n\n  void set_dither(float dither) { dither_ = dither; }\n\n  int num_bins() const { return num_bins_; }\n\n  static inline float InverseMelScale(float mel_freq) {\n    return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);\n  }\n\n  static inline float MelScale(float freq) {\n    return 1127.0f * logf(1.0f + freq / 700.0f);\n  }\n\n  static int UpperPowerOfTwo(int n) {\n    return static_cast<int>(pow(2, ceil(log(n) / log(2))));\n  }\n\n  // preemphasis\n  void PreEmphasis(float coeff, std::vector<float>* data) const {\n    if (coeff == 0.0) return;\n    for (int i = data->size() - 1; i > 0; i--)\n      (*data)[i] -= coeff * (*data)[i - 1];\n    (*data)[0] -= coeff * (*data)[0];\n  }\n\n  // add hamming window\n  void Hamming(std::vector<float>* data) const {\n    CHECK_GE(data->size(), hamming_window_.size());\n    for (size_t i = 0; i < hamming_window_.size(); ++i) {\n      (*data)[i] *= hamming_window_[i];\n    }\n  }\n\n  // Compute fbank feat, return num frames\n  int Compute(const std::vector<float>& wave,\n              std::vector<std::vector<float>>* feat) {\n    int num_samples = wave.size();\n    if (num_samples < frame_length_) return 0;\n    int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);\n    feat->resize(num_frames);\n    std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);\n    std::vector<float> power(fft_points_ / 2);\n    for (int i = 0; i < num_frames; ++i) {\n      std::vector<float> data(wave.data() + i * frame_shift_,\n                              wave.data() + i * frame_shift_ + frame_length_);\n      // optional add noise\n      if (dither_ != 0.0) {\n        for (size_t j = 0; j < data.size(); ++j)\n          data[j] += dither_ * distribution_(generator_);\n      }\n      // optinal remove dc offset\n      if (remove_dc_offset_) {\n        float mean = 0.0;\n        for (size_t j = 0; j < data.size(); ++j) mean += data[j];\n        mean /= data.size();\n        for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;\n      }\n\n      PreEmphasis(0.97, &data);\n      // Povey(&data);\n      Hamming(&data);\n      // copy data to fft_real\n      memset(fft_img.data(), 0, sizeof(float) * fft_points_);\n      memset(fft_real.data() + frame_length_, 0,\n             sizeof(float) * (fft_points_ - frame_length_));\n      memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_);\n      fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(),\n          fft_points_);\n      // power\n      for (int j = 0; j < fft_points_ / 2; ++j) {\n        power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];\n      }\n\n      (*feat)[i].resize(num_bins_);\n      // cepstral coefficients, triangle filter array\n      for (int j = 0; j < num_bins_; ++j) {\n        float mel_energy = 0.0;\n        int s = bins_[j].first;\n        for (size_t k = 0; k < bins_[j].second.size(); ++k) {\n          mel_energy += bins_[j].second[k] * power[s + k];\n        }\n        // optional use log\n        if (use_log_) {\n          if (mel_energy < std::numeric_limits<float>::epsilon())\n            mel_energy = std::numeric_limits<float>::epsilon();\n          mel_energy = logf(mel_energy);\n        }\n\n        
(*feat)[i][j] = mel_energy;\n        // printf(\"%f \", mel_energy);\n      }\n      // printf(\"\\n\");\n    }\n    return num_frames;\n  }\n\n private:\n  int num_bins_;\n  int sample_rate_;\n  int frame_length_, frame_shift_;\n  int fft_points_;\n  bool use_log_;\n  bool remove_dc_offset_;\n  std::vector<float> center_freqs_;\n  std::vector<std::pair<int, std::vector<float>>> bins_;\n  std::vector<float> hamming_window_;\n  std::default_random_engine generator_;\n  std::normal_distribution<float> distribution_;\n  float dither_;\n\n  // bit reversal table\n  std::vector<int> bitrev_;\n  // trigonometric function table\n  std::vector<float> sintbl_;\n};\n\n}  // namespace wenet\n\n#endif  // FRONTEND_FBANK_H_\n"
  },
  {
    "path": "runtime/frontend/feature_pipeline.cc",
    "content": "// Copyright (c) 2017 Personal (Binbin Zhang)\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"frontend/feature_pipeline.h\"\n\n#include <algorithm>\n#include <utility>\n\nnamespace wenet {\n\nFeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)\n    : config_(config),\n      feature_dim_(config.num_bins),\n      fbank_(config.num_bins, config.sample_rate, config.frame_length,\n             config.frame_shift),\n      num_frames_(0),\n      input_finished_(false) {}\n\nvoid FeaturePipeline::AcceptWaveform(const std::vector<float>& wav) {\n  std::vector<std::vector<float>> feats;\n  std::vector<float> waves;\n  waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end());\n  waves.insert(waves.end(), wav.begin(), wav.end());\n  int num_frames = fbank_.Compute(waves, &feats);\n  for (size_t i = 0; i < feats.size(); ++i) {\n    feature_queue_.Push(std::move(feats[i]));\n  }\n  num_frames_ += num_frames;\n\n  int left_samples = waves.size() - config_.frame_shift * num_frames;\n  remained_wav_.resize(left_samples);\n  std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(),\n            remained_wav_.begin());\n  // We are still adding wave, notify input is not finished\n  finish_condition_.notify_one();\n}\n\nvoid FeaturePipeline::AcceptWaveform(const std::vector<int16_t>& wav) {\n  std::vector<float> float_wav(wav.size());\n  for (size_t i = 0; i < wav.size(); i++) {\n    float_wav[i] = static_cast<float>(wav[i]);\n  }\n  this->AcceptWaveform(float_wav);\n}\n\nvoid FeaturePipeline::set_input_finished() {\n  CHECK(!input_finished_);\n  {\n    std::lock_guard<std::mutex> lock(mutex_);\n    input_finished_ = true;\n  }\n  finish_condition_.notify_one();\n}\n\nbool FeaturePipeline::ReadOne(std::vector<float>* feat) {\n  if (!feature_queue_.Empty()) {\n    *feat = std::move(feature_queue_.Pop());\n    return true;\n  } else {\n    std::unique_lock<std::mutex> lock(mutex_);\n    while (!input_finished_) {\n      // This will release the lock and wait for notify_one()\n      // from AcceptWaveform() or set_input_finished()\n      finish_condition_.wait(lock);\n      if (!feature_queue_.Empty()) {\n        *feat = std::move(feature_queue_.Pop());\n        return true;\n      }\n    }\n    CHECK(input_finished_);\n    // Double check queue.empty, see issue#893 for detailed discussions.\n    if (!feature_queue_.Empty()) {\n      *feat = std::move(feature_queue_.Pop());\n      return true;\n    } else {\n      return false;\n    }\n  }\n}\n\nbool FeaturePipeline::Read(int num_frames,\n                           std::vector<std::vector<float>>* feats) {\n  feats->clear();\n  std::vector<float> feat;\n  while (feats->size() < num_frames) {\n    if (ReadOne(&feat)) {\n      feats->push_back(std::move(feat));\n    } else {\n      return false;\n    }\n  }\n  return true;\n}\n\nvoid FeaturePipeline::Reset() {\n  input_finished_ = false;\n  num_frames_ = 0;\n  remained_wav_.clear();\n  feature_queue_.Clear();\n}\n\n}  // 
namespace wenet\n"
  },
  {
    "path": "runtime/frontend/feature_pipeline.h",
    "content": "// Copyright (c) 2017 Personal (Binbin Zhang)\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#ifndef FRONTEND_FEATURE_PIPELINE_H_\n#define FRONTEND_FEATURE_PIPELINE_H_\n\n#include <mutex>\n#include <queue>\n#include <string>\n#include <vector>\n\n#include \"frontend/fbank.h\"\n#include \"glog/logging.h\"\n#include \"utils/blocking_queue.h\"\n\nnamespace wenet {\n\nstruct FeaturePipelineConfig {\n  int num_bins;\n  int sample_rate;\n  int frame_length;\n  int frame_shift;\n  FeaturePipelineConfig(int num_bins, int sample_rate)\n      : num_bins(num_bins),                  // 80 dim fbank\n        sample_rate(sample_rate) {           // 16k sample rate\n    frame_length = sample_rate / 1000 * 25;  // frame length 25ms\n    frame_shift = sample_rate / 1000 * 10;   // frame shift 10ms\n  }\n\n  void Info() const {\n    LOG(INFO) << \"feature pipeline config\"\n              << \" num_bins \" << num_bins << \" frame_length \" << frame_length\n              << \" frame_shift \" << frame_shift;\n  }\n};\n\n// Typically, FeaturePipeline is used in two threads: one thread A calls\n// AcceptWaveform() to add raw wav data and set_input_finished() to notice\n// the end of input wav, another thread B (decoder thread) calls Read() to\n// consume features.So a BlockingQueue is used to make this class thread safe.\n\n// The Read() is designed as a blocking method when there is no feature\n// in feature_queue_ and the input is not finished.\n\nclass FeaturePipeline {\n public:\n  explicit FeaturePipeline(const FeaturePipelineConfig& config);\n\n  // The feature extraction is done in AcceptWaveform().\n  void AcceptWaveform(const std::vector<float>& wav);\n  void AcceptWaveform(const std::vector<int16_t>& wav);\n\n  // Current extracted frames number.\n  int num_frames() const { return num_frames_; }\n  int feature_dim() const { return feature_dim_; }\n  const FeaturePipelineConfig& config() const { return config_; }\n\n  // The caller should call this method when speech input is end.\n  // Never call AcceptWaveform() after calling set_input_finished() !\n  void set_input_finished();\n  bool input_finished() const { return input_finished_; }\n\n  // Return False if input is finished and no feature could be read.\n  // Return True if a feature is read.\n  // This function is a blocking method. 
It will block the thread when\n  // there is no feature in feature_queue_ and the input is not finished.\n  bool ReadOne(std::vector<float>* feat);\n\n  // Read #num_frames frame features.\n  // Return False if less then #num_frames features are read and the\n  // input is finished.\n  // Return True if #num_frames features are read.\n  // This function is a blocking method when there is no feature\n  // in feature_queue_ and the input is not finished.\n  bool Read(int num_frames, std::vector<std::vector<float>>* feats);\n\n  void Reset();\n  bool IsLastFrame(int frame) const {\n    return input_finished_ && (frame == num_frames_ - 1);\n  }\n\n  int NumQueuedFrames() const { return feature_queue_.Size(); }\n\n private:\n  const FeaturePipelineConfig& config_;\n  int feature_dim_;\n  Fbank fbank_;\n\n  BlockingQueue<std::vector<float>> feature_queue_;\n  int num_frames_;\n  bool input_finished_;\n\n  // The feature extraction is done in AcceptWaveform().\n  // This wavefrom sample points are consumed by frame size.\n  // The residual wavefrom sample points after framing are\n  // kept to be used in next AcceptWaveform() calling.\n  std::vector<float> remained_wav_;\n\n  // Used to block the Read when there is no feature in feature_queue_\n  // and the input is not finished.\n  mutable std::mutex mutex_;\n  std::condition_variable finish_condition_;\n};\n\n}  // namespace wenet\n\n#endif  // FRONTEND_FEATURE_PIPELINE_H_\n"
  },
  {
    "path": "runtime/frontend/fft.cc",
    "content": "// Copyright (c) 2016 HR\n\n#include <math.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include \"frontend/fft.h\"\n\nnamespace wenet {\n\nvoid make_sintbl(int n, float* sintbl) {\n  int i, n2, n4, n8;\n  float c, s, dc, ds, t;\n\n  n2 = n / 2;\n  n4 = n / 4;\n  n8 = n / 8;\n  t = sin(M_PI / n);\n  dc = 2 * t * t;\n  ds = sqrt(dc * (2 - dc));\n  t = 2 * dc;\n  c = sintbl[n4] = 1;\n  s = sintbl[0] = 0;\n  for (i = 1; i < n8; ++i) {\n    c -= dc;\n    dc += t * c;\n    s += ds;\n    ds -= t * s;\n    sintbl[i] = s;\n    sintbl[n4 - i] = c;\n  }\n  if (n8 != 0) sintbl[n8] = sqrt(0.5);\n  for (i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i];\n  for (i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i];\n}\n\nvoid make_bitrev(int n, int* bitrev) {\n  int i, j, k, n2;\n\n  n2 = n / 2;\n  i = j = 0;\n  for (;;) {\n    bitrev[i] = j;\n    if (++i >= n) break;\n    k = n2;\n    while (k <= j) {\n      j -= k;\n      k /= 2;\n    }\n    j += k;\n  }\n}\n\n// bitrev: bit reversal table\n// sintbl: trigonometric function table\n// x:real part\n// y:image part\n// n: fft length\nint fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) {\n  int i, j, k, ik, h, d, k2, n4, inverse;\n  float t, s, c, dx, dy;\n\n  /* preparation */\n  if (n < 0) {\n    n = -n;\n    inverse = 1; /* inverse transform */\n  } else {\n    inverse = 0;\n  }\n  n4 = n / 4;\n  if (n == 0) {\n    return 0;\n  }\n\n  /* bit reversal */\n  for (i = 0; i < n; ++i) {\n    j = bitrev[i];\n    if (i < j) {\n      t = x[i];\n      x[i] = x[j];\n      x[j] = t;\n      t = y[i];\n      y[i] = y[j];\n      y[j] = t;\n    }\n  }\n\n  /* transformation */\n  for (k = 1; k < n; k = k2) {\n    h = 0;\n    k2 = k + k;\n    d = n / k2;\n    for (j = 0; j < k; ++j) {\n      c = sintbl[h + n4];\n      if (inverse)\n        s = -sintbl[h];\n      else\n        s = sintbl[h];\n      for (i = j; i < n; i += k2) {\n        ik = i + k;\n        dx = s * y[ik] + c * x[ik];\n        dy = c * y[ik] - s * x[ik];\n        x[ik] = x[i] - dx;\n        x[i] += dx;\n        y[ik] = y[i] - dy;\n        y[i] += dy;\n      }\n      h += d;\n    }\n  }\n  if (inverse) {\n    /* divide by n in case of the inverse transformation */\n    for (i = 0; i < n; ++i) {\n      x[i] /= n;\n      y[i] /= n;\n    }\n  }\n  return 0; /* finished successfully */\n}\n\n}  // namespace wenet\n"
  },
  {
    "path": "runtime/frontend/fft.h",
    "content": "// Copyright (c) 2016 HR\n\n#ifndef FRONTEND_FFT_H_\n#define FRONTEND_FFT_H_\n\n#ifndef M_PI\n#define M_PI 3.1415926535897932384626433832795\n#endif\n#ifndef M_2PI\n#define M_2PI 6.283185307179586476925286766559005\n#endif\n\nnamespace wenet {\n\n// Fast Fourier Transform\n\nvoid make_sintbl(int n, float* sintbl);\n\nvoid make_bitrev(int n, int* bitrev);\n\nint fft(const int* bitrev, const float* sintbl, float* x, float* y, int n);\n\n}  // namespace wenet\n\n#endif  // FRONTEND_FFT_H_\n"
  },
  {
    "path": "runtime/frontend/wav.h",
    "content": "// Copyright (c) 2016 Personal (Binbin Zhang)\n// Created on 2016-08-15\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#ifndef FRONTEND_WAV_H_\n#define FRONTEND_WAV_H_\n\n#include <assert.h>\n#include <stdint.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include <memory>\n#include <string>\n#include <vector>\n\n#include \"gflags/gflags.h\"\n#include \"glog/logging.h\"\n\nDEFINE_int32(pcm_sample_rate, 16000, \"pcm data sample rate\");\n\nnamespace wenet {\n\nclass AudioReader {\n public:\n  AudioReader() {}\n  explicit AudioReader(const std::string& filename) {}\n  virtual ~AudioReader() {}\n\n  virtual int num_channel() const = 0;\n  virtual int sample_rate() const = 0;\n  virtual int bits_per_sample() const = 0;\n  virtual int num_sample() const = 0;\n  virtual const int16_t* data() const = 0;\n};\n\nstruct WavHeader {\n  char riff[4];  // \"riff\"\n  unsigned int size;\n  char wav[4];  // \"WAVE\"\n  char fmt[4];  // \"fmt \"\n  unsigned int fmt_size;\n  uint16_t format;\n  uint16_t channels;\n  unsigned int sample_rate;\n  unsigned int bytes_per_second;\n  uint16_t block_size;\n  uint16_t bit;\n  char data[4];  // \"data\"\n  unsigned int data_size;\n};\n\nclass WavReader : public AudioReader {\n public:\n  WavReader() {}\n  explicit WavReader(const std::string& filename) { Open(filename); }\n\n  bool Open(const std::string& filename) {\n    FILE* fp = fopen(filename.c_str(), \"rb\");\n    if (NULL == fp) {\n      LOG(WARNING) << \"Error in read \" << filename;\n      return false;\n    }\n\n    WavHeader header;\n    fread(&header, 1, sizeof(header), fp);\n    if (header.fmt_size < 16) {\n      fprintf(stderr,\n              \"WaveData: expect PCM format data \"\n              \"to have fmt chunk of at least size 16.\\n\");\n      return false;\n    } else if (header.fmt_size > 16) {\n      int offset = 44 - 8 + header.fmt_size - 16;\n      fseek(fp, offset, SEEK_SET);\n      fread(header.data, 8, sizeof(char), fp);\n    }\n    // check \"riff\" \"WAVE\" \"fmt \" \"data\"\n\n    // Skip any subchunks between \"fmt\" and \"data\".  
Usually there will\n    // be a single \"fact\" subchunk, but on Windows there can also be a\n    // \"list\" subchunk.\n    while (0 != strncmp(header.data, \"data\", 4)) {\n      // We will just ignore the data in these chunks.\n      fseek(fp, header.data_size, SEEK_CUR);\n      // read next subchunk\n      fread(header.data, 8, sizeof(char), fp);\n    }\n\n    num_channel_ = header.channels;\n    sample_rate_ = header.sample_rate;\n    bits_per_sample_ = header.bit;\n    int num_data = header.data_size / (bits_per_sample_ / 8);\n    data_.resize(num_data);\n    int num_read = fread(&data_[0], 1, header.data_size, fp);\n    if (num_read < header.data_size) {\n      // If the header size is wrong, adjust\n      header.data_size = num_read;\n      num_data = header.data_size / (bits_per_sample_ / 8);\n      data_.resize(num_data);\n    }\n    num_sample_ = num_data / num_channel_;\n    fclose(fp);\n    return true;\n  }\n\n  int num_channel() const { return num_channel_; }\n  int sample_rate() const { return sample_rate_; }\n  int bits_per_sample() const { return bits_per_sample_; }\n  int num_sample() const { return num_sample_; }\n  const int16_t* data() const { return data_.data(); }\n\n private:\n  int num_channel_;\n  int sample_rate_;\n  int bits_per_sample_;\n  int num_sample_;  // sample points per channel\n  std::vector<int16_t> data_;\n};\n\nclass WavWriter {\n public:\n  WavWriter(const float* data, int num_sample, int num_channel, int sample_rate,\n            int bits_per_sample)\n      : data_(data),\n        num_sample_(num_sample),\n        num_channel_(num_channel),\n        sample_rate_(sample_rate),\n        bits_per_sample_(bits_per_sample) {}\n\n  void Write(const std::string& filename) {\n    FILE* fp = fopen(filename.c_str(), \"w\");\n    // init char 'riff' 'WAVE' 'fmt ' 'data'\n    WavHeader header;\n    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,\n                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,\n                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,\n                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,\n                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};\n    memcpy(&header, wav_header, sizeof(header));\n    header.channels = num_channel_;\n    header.bit = bits_per_sample_;\n    header.sample_rate = sample_rate_;\n    header.data_size = num_sample_ * num_channel_ * (bits_per_sample_ / 8);\n    header.size = sizeof(header) - 8 + header.data_size;\n    header.bytes_per_second =\n        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);\n    header.block_size = num_channel_ * (bits_per_sample_ / 8);\n\n    fwrite(&header, 1, sizeof(header), fp);\n\n    for (int i = 0; i < num_sample_; ++i) {\n      for (int j = 0; j < num_channel_; ++j) {\n        switch (bits_per_sample_) {\n          case 8: {\n            char sample = static_cast<char>(data_[i * num_channel_ + j]);\n            fwrite(&sample, 1, sizeof(sample), fp);\n            break;\n          }\n          case 16: {\n            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);\n            fwrite(&sample, 1, sizeof(sample), fp);\n            break;\n          }\n          case 32: {\n            int sample = static_cast<int>(data_[i * num_channel_ + j]);\n            fwrite(&sample, 1, sizeof(sample), fp);\n            break;\n          }\n        }\n      }\n    }\n    fclose(fp);\n  }\n\n private:\n  const float* data_;\n  int num_sample_;  
// total float points in data_\n  int num_channel_;\n  int sample_rate_;\n  int bits_per_sample_;\n};\n\nclass PcmReader : public AudioReader {\n public:\n  PcmReader() {}\n  explicit PcmReader(const std::string& filename) { Open(filename); }\n\n  bool Open(const std::string& filename) {\n    FILE* fp = fopen(filename.c_str(), \"rb\");\n    if (NULL == fp) {\n      LOG(WARNING) << \"Error in read \" << filename;\n      return false;\n    }\n\n    num_channel_ = 1;\n    sample_rate_ = FLAGS_pcm_sample_rate;\n    bits_per_sample_ = 16;\n    fseek(fp, 0, SEEK_END);\n    int data_size = ftell(fp);\n    fseek(fp, 0, SEEK_SET);\n    num_sample_ = data_size / sizeof(int16_t);\n    data_.resize(num_sample_);\n    fread(&data_[0], data_size, 1, fp);\n    fclose(fp);\n    return true;\n  }\n\n  int num_channel() const { return num_channel_; }\n  int sample_rate() const { return sample_rate_; }\n  int bits_per_sample() const { return bits_per_sample_; }\n  int num_sample() const { return num_sample_; }\n\n  const int16_t* data() const { return data_.data(); }\n\n private:\n  int num_channel_;\n  int sample_rate_;\n  int bits_per_sample_;\n  int num_sample_;  // sample points per channel\n  std::vector<int16_t> data_;\n};\n\nstd::shared_ptr<AudioReader> ReadAudioFile(const std::string& filename) {\n  size_t pos = filename.rfind('.');\n  std::string suffix = filename.substr(pos);\n  if (suffix == \".wav\" || suffix == \".WAV\") {\n    return std::make_shared<WavReader>(filename);\n  } else {\n    return std::make_shared<PcmReader>(filename);\n  }\n}\n\nvoid WriteWavFile(const float* data, int data_size, int sample_rate,\n                  const std::string& wav_path) {\n  std::vector<float> tmp_wav(data, data + data_size);\n  for (int i = 0; i < tmp_wav.size(); i++) {\n    tmp_wav[i] *= (1 << 15);\n  }\n  WavWriter wav_write(tmp_wav.data(), tmp_wav.size(), 1, sample_rate, 16);\n  wav_write.Write(wav_path);\n}\n\n}  // namespace wenet\n\n#endif  // FRONTEND_WAV_H_\n"
  },
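  {
    "path": "runtime/frontend/wav_io_sketch.cc",
    "content": "// Illustrative sketch added alongside the original sources; not part of the\n// upstream code. It shows how the helpers in frontend/wav.h fit together:\n// ReadAudioFile() picks WavReader or PcmReader from the file suffix, and\n// WriteWavFile() expects a mono waveform normalised to [-1, 1), which it\n// rescales by 1 << 15 before handing it to WavWriter.\n\n#include <cstdint>\n#include <memory>\n#include <vector>\n\n#include \"frontend/wav.h\"\n\nint main(int argc, char* argv[]) {\n  if (argc != 3) return 1;  // usage: <in.wav|in.pcm> <out.wav>\n  std::shared_ptr<wenet::AudioReader> reader = wenet::ReadAudioFile(argv[1]);\n\n  // Assumes a mono input; samples are interleaved for multi-channel files.\n  const int16_t* pcm = reader->data();\n  int num = reader->num_sample() * reader->num_channel();\n  std::vector<float> wav(num);\n  for (int i = 0; i < num; ++i) {\n    wav[i] = static_cast<float>(pcm[i]) / (1 << 15);  // normalise to [-1, 1)\n  }\n\n  wenet::WriteWavFile(wav.data(), num, reader->sample_rate(), argv[2]);\n  return 0;\n}\n"
  },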
  {
    "path": "runtime/separate/CMakeLists.txt",
    "content": "add_library(separate STATIC separate_engine.cc)\ntarget_link_libraries(separate PUBLIC frontend ${TORCH_LIBRARIES})\n"
  },
  {
    "path": "runtime/separate/separate_engine.cc",
    "content": "// Copyright (c) 2024 wesep team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"separate/separate_engine.h\"\n\n#include <algorithm>\n#include <functional>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"gflags/gflags.h\"\n#include \"glog/logging.h\"\n#include \"torch/script.h\"\n#include \"torch/torch.h\"\n\nnamespace wesep {\n\nvoid SeparateEngine::InitEngineThreads(int num_threads) {\n  // for multi-thread performance\n  at::set_num_threads(num_threads);\n  VLOG(1) << \"Num intra-op threads: \" << at::get_num_threads();\n}\n\nSeparateEngine::SeparateEngine(const std::string& model_path,\n                               const int feat_dim, const int sample_rate) {\n  sample_rate_ = sample_rate;\n  feat_dim_ = feat_dim;\n  feature_config_ =\n      std::make_shared<wenet::FeaturePipelineConfig>(feat_dim, sample_rate);\n  feature_pipeline_ =\n      std::make_shared<wenet::FeaturePipeline>(*feature_config_);\n  feature_pipeline_->Reset();\n\n  InitEngineThreads(1);\n  torch::jit::script::Module model = torch::jit::load(model_path);\n  model_ = std::make_shared<torch::jit::script::Module>(std::move(model));\n  model_->eval();\n}\n\nvoid SeparateEngine::ExtractFeature(const int16_t* data, int data_size,\n                                    std::vector<std::vector<float>>* feat) {\n  feature_pipeline_->AcceptWaveform(\n      std::vector<int16_t>(data, data + data_size));\n  feature_pipeline_->set_input_finished();\n  feature_pipeline_->Read(feature_pipeline_->num_frames(), feat);\n  feature_pipeline_->Reset();\n  this->ApplyMean(feat);\n}\n\nvoid SeparateEngine::ApplyMean(std::vector<std::vector<float>>* feat) {\n  std::vector<float> mean(feat_dim_, 0);\n  for (auto& i : *feat) {\n    std::transform(i.begin(), i.end(), mean.begin(), mean.begin(),\n                   std::plus<>{});\n  }\n  std::transform(mean.begin(), mean.end(), mean.begin(),\n                 [&](const float d) { return d / feat->size(); });\n  for (auto& i : *feat) {\n    std::transform(i.begin(), i.end(), mean.begin(), i.begin(), std::minus<>{});\n  }\n}\n\nvoid SeparateEngine::ForwardFunc(const std::vector<int16_t>& mix_wav,\n                                 const int16_t* spk1_emb,\n                                 const int16_t* spk2_emb, int data_size,\n                                 std::vector<std::vector<float>>* output) {\n  // pre-process\n  std::vector<float> input_wav(mix_wav.size());\n  for (int i = 0; i < mix_wav.size(); i++) {\n    input_wav[i] = static_cast<float>(mix_wav[i]) / (1 << 15);\n  }\n  std::vector<std::vector<float>> spk1_emb_feat;\n  this->ExtractFeature(spk1_emb, data_size, &spk1_emb_feat);\n  std::vector<std::vector<float>> spk2_emb_feat;\n  this->ExtractFeature(spk2_emb, data_size, &spk2_emb_feat);\n\n  // torch mix_wav\n  torch::Tensor torch_wav = torch::zeros({2, mix_wav.size()}, torch::kFloat32);\n  for (size_t i = 0; i < 2; i++) {\n    torch::Tensor row =\n        
torch::from_blob(input_wav.data(), {input_wav.size()}, torch::kFloat32)\n            .clone();\n    torch_wav[i] = std::move(row);\n  }\n\n  // torch spk_emb_feat\n  torch::Tensor torch_spk_emb_feat =\n      torch::zeros({2, spk1_emb_feat.size(), feat_dim_}, torch::kFloat32);\n  for (size_t i = 0; i < spk1_emb_feat.size(); i++) {\n    torch::Tensor row1 =\n        torch::from_blob(spk1_emb_feat[i].data(), {feat_dim_}, torch::kFloat32);\n    torch_spk_emb_feat[0][i] = std::move(row1);\n    torch::Tensor row2 =\n        torch::from_blob(spk2_emb_feat[i].data(), {feat_dim_}, torch::kFloat32);\n    torch_spk_emb_feat[1][i] = std::move(row2);\n  }\n\n  // forward\n  torch::NoGradGuard no_grad;\n  auto outputs =\n      model_->forward({torch_wav, torch_spk_emb_feat}).toTuple()->elements();\n  torch::Tensor wav_out = outputs[0].toTensor();\n  auto accessor = wav_out.accessor<float, 2>();\n\n  output->resize(2, std::vector<float>(wav_out.size(1), 0.0));\n  for (int i = 0; i < wav_out.size(1); i++) {\n    (*output)[0][i] = accessor[0][i];\n    (*output)[1][i] = accessor[1][i];\n  }\n}\n\n}  // namespace wesep\n"
  },
  {
    "path": "runtime/separate/separate_engine.h",
    "content": "// Copyright (c) 2024 wesep team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#ifndef SEPARATE_SEPARATE_ENGINE_H_\n#define SEPARATE_SEPARATE_ENGINE_H_\n\n#include <memory>\n#include <string>\n#include <vector>\n\n#include \"torch/script.h\"\n#include \"torch/torch.h\"\n\n#include \"frontend/feature_pipeline.h\"\n\nnamespace wesep {\n\nclass SeparateEngine {\n public:\n  explicit SeparateEngine(const std::string& model_path, const int feat_dim,\n                          const int sample_rate);\n\n  void InitEngineThreads(int num_threads = 1);\n\n  void ForwardFunc(const std::vector<int16_t>& mix_wav, const int16_t* spk1_emb,\n                   const int16_t* spk2_emb, int data_size,\n                   std::vector<std::vector<float>>* output);\n\n  void ExtractFeature(const int16_t* data, int data_size,\n                      std::vector<std::vector<float>>* feat);\n\n  void ApplyMean(std::vector<std::vector<float>>* feat);\n\n private:\n  std::shared_ptr<torch::jit::script::Module> model_ = nullptr;\n  std::shared_ptr<wenet::FeaturePipelineConfig> feature_config_ = nullptr;\n  std::shared_ptr<wenet::FeaturePipeline> feature_pipeline_ = nullptr;\n  int sample_rate_ = 16000;\n  int feat_dim_ = 80;\n};\n\n}  // namespace wesep\n\n#endif  // SEPARATE_SEPARATE_ENGINE_H_\n"
  },
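  {
    "path": "runtime/separate/separate_engine_usage_sketch.cc",
    "content": "// Illustrative sketch added alongside the original sources; not part of the\n// upstream code. It shows the call sequence that SeparateEngine exposes. The\n// 80-dim fbank / 16 kHz settings, the mono-input assumption and the output\n// file names below are assumptions made for the sketch, not requirements of\n// the class.\n\n#include <algorithm>\n#include <cstdint>\n#include <vector>\n\n#include \"frontend/wav.h\"\n#include \"separate/separate_engine.h\"\n\nint main(int argc, char* argv[]) {\n  // usage: <model.pt> <mix.wav> <enroll1.wav> <enroll2.wav>\n  if (argc != 5) return 1;\n  wesep::SeparateEngine engine(argv[1], 80, 16000);\n\n  auto mix = wenet::ReadAudioFile(argv[2]);\n  auto enroll1 = wenet::ReadAudioFile(argv[3]);\n  auto enroll2 = wenet::ReadAudioFile(argv[4]);\n\n  std::vector<int16_t> mix_wav(mix->data(), mix->data() + mix->num_sample());\n  // ForwardFunc() extracts fbank features from both enrollment signals and\n  // runs the TorchScript model on the mixture; it fills one waveform per\n  // speaker, on the same normalised scale as its input.\n  std::vector<std::vector<float>> separated;\n  int enroll_len = std::min(enroll1->num_sample(), enroll2->num_sample());\n  engine.ForwardFunc(mix_wav, enroll1->data(), enroll2->data(), enroll_len,\n                     &separated);\n\n  wenet::WriteWavFile(separated[0].data(),\n                      static_cast<int>(separated[0].size()), 16000,\n                      \"spk1.wav\");\n  wenet::WriteWavFile(separated[1].data(),\n                      static_cast<int>(separated[1].size()), 16000,\n                      \"spk2.wav\");\n  return 0;\n}\n"
  },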
  {
    "path": "runtime/utils/CMakeLists.txt",
    "content": "add_library(utils STATIC\n  utils.cc\n)\ntarget_link_libraries(utils PUBLIC glog gflags frontend)\n"
  },
  {
    "path": "runtime/utils/blocking_queue.h",
    "content": "// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#ifndef UTILS_BLOCKING_QUEUE_H_\n#define UTILS_BLOCKING_QUEUE_H_\n\n#include <condition_variable>\n#include <limits>\n#include <mutex>\n#include <queue>\n#include <utility>\n\nnamespace wenet {\n\n#define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \\\n  Type(const Type&) = delete;                \\\n  Type& operator=(const Type&) = delete;\n\ntemplate <typename T>\nclass BlockingQueue {\n public:\n  explicit BlockingQueue(size_t capacity = std::numeric_limits<int>::max())\n      : capacity_(capacity) {}\n\n  void Push(const T& value) {\n    {\n      std::unique_lock<std::mutex> lock(mutex_);\n      while (queue_.size() >= capacity_) {\n        not_full_condition_.wait(lock);\n      }\n      queue_.push(value);\n    }\n    not_empty_condition_.notify_one();\n  }\n\n  void Push(T&& value) {\n    {\n      std::unique_lock<std::mutex> lock(mutex_);\n      while (queue_.size() >= capacity_) {\n        not_full_condition_.wait(lock);\n      }\n      queue_.push(std::move(value));\n    }\n    not_empty_condition_.notify_one();\n  }\n\n  T Pop() {\n    std::unique_lock<std::mutex> lock(mutex_);\n    while (queue_.empty()) {\n      not_empty_condition_.wait(lock);\n    }\n    T t(std::move(queue_.front()));\n    queue_.pop();\n    not_full_condition_.notify_one();\n    return t;\n  }\n\n  bool Empty() const {\n    std::lock_guard<std::mutex> lock(mutex_);\n    return queue_.empty();\n  }\n\n  size_t Size() const {\n    std::lock_guard<std::mutex> lock(mutex_);\n    return queue_.size();\n  }\n\n  void Clear() {\n    while (!Empty()) {\n      Pop();\n    }\n  }\n\n private:\n  size_t capacity_;\n  mutable std::mutex mutex_;\n  std::condition_variable not_full_condition_;\n  std::condition_variable not_empty_condition_;\n  std::queue<T> queue_;\n\n public:\n  WENET_DISALLOW_COPY_AND_ASSIGN(BlockingQueue);\n};\n\n}  // namespace wenet\n\n#endif  // UTILS_BLOCKING_QUEUE_H_\n"
  },
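  {
    "path": "runtime/utils/blocking_queue_sketch.cc",
    "content": "// Illustrative sketch added alongside the original sources; not part of the\n// upstream code. One producer and one consumer share the bounded\n// BlockingQueue declared in utils/blocking_queue.h.\n\n#include <iostream>\n#include <thread>\n\n#include \"utils/blocking_queue.h\"\n\nint main() {\n  wenet::BlockingQueue<int> queue(4);  // Push() blocks once 4 items are queued\n\n  std::thread producer([&queue]() {\n    for (int i = 0; i < 8; ++i) {\n      queue.Push(i);\n    }\n    queue.Push(-1);  // sentinel marking the end of the stream\n  });\n\n  std::thread consumer([&queue]() {\n    while (true) {\n      int value = queue.Pop();  // blocks while the queue is empty\n      if (value == -1) break;\n      std::cout << value << std::endl;\n    }\n  });\n\n  producer.join();\n  consumer.join();\n  return 0;\n}\n"
  },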
  {
    "path": "runtime/utils/timer.h",
    "content": "// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#ifndef UTILS_TIMER_H_\n#define UTILS_TIMER_H_\n\n#include <chrono>\n\nnamespace wenet {\n\nclass Timer {\n public:\n  Timer() : time_start_(std::chrono::steady_clock::now()) {}\n  void Reset() { time_start_ = std::chrono::steady_clock::now(); }\n  // return int in milliseconds\n  int Elapsed() const {\n    auto time_now = std::chrono::steady_clock::now();\n    return std::chrono::duration_cast<std::chrono::milliseconds>(time_now -\n                                                                 time_start_)\n        .count();\n  }\n\n private:\n  std::chrono::time_point<std::chrono::steady_clock> time_start_;\n};\n}  // namespace wenet\n\n#endif  // UTILS_TIMER_H_\n"
  },
  {
    "path": "runtime/utils/utils.cc",
    "content": "// Copyright (c) 2023 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <algorithm>\n#include <cmath>\n#include <fstream>\n#include <limits>\n#include <numeric>\n#include <sstream>\n#include <vector>\n\n#include \"glog/logging.h\"\n#include \"utils/utils.h\"\n\nnamespace wesep {\n\nstd::string Ltrim(const std::string& str) {\n  size_t start = str.find_first_not_of(WHITESPACE);\n  return (start == std::string::npos) ? \"\" : str.substr(start);\n}\n\nstd::string Rtrim(const std::string& str) {\n  size_t end = str.find_last_not_of(WHITESPACE);\n  return (end == std::string::npos) ? \"\" : str.substr(0, end + 1);\n}\n\nstd::string Trim(const std::string& str) { return Rtrim(Ltrim(str)); }\n\nvoid SplitString(const std::string& str, std::vector<std::string>* strs) {\n  SplitStringToVector(Trim(str), \" \\t\", true, strs);\n}\n\nvoid SplitStringToVector(const std::string& full, const char* delim,\n                         bool omit_empty_strings,\n                         std::vector<std::string>* out) {\n  size_t start = 0, found = 0, end = full.size();\n  out->clear();\n  while (found != std::string::npos) {\n    found = full.find_first_of(delim, start);\n    // start != end condition is for when the delimiter is at the end\n    if (!omit_empty_strings || (found != start && start != end))\n      out->push_back(full.substr(start, found - start));\n    start = found + 1;\n  }\n}\n\n#ifdef _MSC_VER\nstd::wstring ToWString(const std::string& str) {\n  unsigned len = str.size() * 2;\n  setlocale(LC_CTYPE, \"\");\n  wchar_t* p = new wchar_t[len];\n  mbstowcs(p, str.c_str(), len);\n  std::wstring wstr(p);\n  delete[] p;\n  return wstr;\n}\n#endif\n\n}  // namespace wesep\n"
  },
  {
    "path": "runtime/utils/utils.h",
    "content": "// Copyright (c) 2023 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#ifndef UTILS_UTILS_H_\n#define UTILS_UTILS_H_\n\n#include <string>\n#include <unordered_map>\n#include <vector>\n\nnamespace wesep {\n\nconst char WHITESPACE[] = \" \\n\\r\\t\\f\\v\";\n\n// Split the string with space or tab.\nvoid SplitString(const std::string& str, std::vector<std::string>* strs);\n\nvoid SplitStringToVector(const std::string& full, const char* delim,\n                         bool omit_empty_strings,\n                         std::vector<std::string>* out);\n\n#ifdef _MSC_VER\nstd::wstring ToWString(const std::string& str);\n#endif\n\n}  // namespace wesep\n\n#endif  // UTILS_UTILS_H_\n"
  },
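  {
    "path": "runtime/utils/utils_usage_sketch.cc",
    "content": "// Illustrative sketch added alongside the original sources; not part of the\n// upstream code. It exercises the string helpers declared in utils/utils.h.\n\n#include <iostream>\n#include <string>\n#include <vector>\n\n#include \"utils/utils.h\"\n\nint main() {\n  std::vector<std::string> fields;\n  // SplitString() trims the line first and drops empty fields, so the extra\n  // whitespace below is ignored.\n  wesep::SplitString(\"  utt1 spk1  spk2 \", &fields);\n  for (const auto& field : fields) {\n    std::cout << field << std::endl;  // prints utt1, spk1, spk2\n  }\n  return 0;\n}\n"
  },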
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nrequirements = [\n    \"tqdm\",\n    \"kaldiio\",\n    \"torch>=1.12.0\",\n    \"torchaudio>=0.12.0\",\n    \"silero-vad\",\n]\n\nsetup(\n    name=\"wesep\",\n    install_requires=requirements,\n    packages=find_packages(),\n    entry_points={\n        \"console_scripts\": [\n            \"wesep = wesep.cli.extractor:main\",\n        ],\n    },\n)\n"
  },
  {
    "path": "tools/extract_embed_depreciated.py",
    "content": "# Copyright (c) 2022, Shuai Wang (wsstriving@gmail.com)\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport argparse\r\nimport os\r\n\r\nimport kaldiio\r\nimport onnxruntime as ort\r\nimport torch\r\nimport torchaudio\r\nimport torchaudio.compliance.kaldi as kaldi\r\nfrom tqdm import tqdm\r\n\r\n\r\ndef get_args():\r\n    parser = argparse.ArgumentParser(description=\"infer example using onnx\")\r\n    parser.add_argument(\"--onnx_path\", required=True, help=\"onnx path\")\r\n    parser.add_argument(\"--wav_scp\", required=True, help=\"wav path\")\r\n    parser.add_argument(\"--out_path\",\r\n                        required=True,\r\n                        help=\"output path of the embeddings\")\r\n    args = parser.parse_args()\r\n    return args\r\n\r\n\r\ndef compute_fbank(wav_path,\r\n                  num_mel_bins=80,\r\n                  frame_length=25,\r\n                  frame_shift=10,\r\n                  dither=0.0):\r\n    \"\"\"Extract fbank, simlilar to the one in wespeaker.dataset.processor,\r\n    While integrating the wave reading and CMN.\r\n    \"\"\"\r\n    waveform, sample_rate = torchaudio.load(wav_path)\r\n    waveform = waveform * (1 << 15)\r\n    mat = kaldi.fbank(\r\n        waveform,\r\n        num_mel_bins=num_mel_bins,\r\n        frame_length=frame_length,\r\n        frame_shift=frame_shift,\r\n        dither=dither,\r\n        sample_frequency=sample_rate,\r\n        window_type=\"hamming\",\r\n        use_energy=False,\r\n    )\r\n    # CMN, without CVN\r\n    mat = mat - torch.mean(mat, dim=0)\r\n    return mat\r\n\r\n\r\ndef main():\r\n    args = get_args()\r\n\r\n    so = ort.SessionOptions()\r\n    so.inter_op_num_threads = 1\r\n    so.intra_op_num_threads = 1\r\n    session = ort.InferenceSession(args.onnx_path, sess_options=so)\r\n\r\n    embed_ark = os.path.join(args.out_path, \"embed.ark\")\r\n    embed_scp = os.path.join(args.out_path, \"embed.scp\")\r\n\r\n    with kaldiio.WriteHelper(\"ark,scp:\" + embed_ark + \",\" +\r\n                             embed_scp) as writer:\r\n        with open(args.wav_scp, \"r\") as read_scp:\r\n            for line in tqdm(read_scp):\r\n                tokens = line.strip().split(\" \")\r\n                name, wav_path = tokens[0], tokens[1]\r\n\r\n                feats = compute_fbank(wav_path)\r\n                feats = feats.unsqueeze(0).numpy()  # add batch dimension\r\n                embed = session.run(output_names=[\"embs\"],\r\n                                    input_feed={\"feats\": feats})\r\n                writer(name, embed[0])\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tools/make_lmdb.py",
    "content": "# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport math\nimport pickle\n\nimport lmdb\nfrom tqdm import tqdm\n\n\ndef get_args():\n    parser = argparse.ArgumentParser(description=\"\")\n    parser.add_argument(\"in_scp_file\", help=\"input scp file\")\n    parser.add_argument(\"out_lmdb\", help=\"output lmdb\")\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    args = get_args()\n    db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4)))  # 1TB\n    # txn is for Transaciton\n    txn = db.begin(write=True)\n    keys = []\n    with open(args.in_scp_file, \"r\", encoding=\"utf8\") as fin:\n        lines = fin.readlines()\n        for i, line in enumerate(tqdm(lines)):\n            arr = line.strip().split()\n            assert len(arr) == 2\n            key, wav = arr[0], arr[1]\n            keys.append(key)\n            with open(wav, \"rb\") as fin:\n                data = fin.read()\n            txn.put(key.encode(), data)\n            # Write flush to disk\n            if i % 100 == 0:\n                txn.commit()\n                txn = db.begin(write=True)\n    txn.commit()\n    with db.begin(write=True) as txn:\n        txn.put(b\"__keys__\", pickle.dumps(keys))\n    db.sync()\n    db.close()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tools/make_shard_list_premix.py",
    "content": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang\r\n#               2023    SRIBD              Shuai Wang )\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport argparse\r\nimport io\r\nimport logging\r\nimport multiprocessing\r\nimport os\r\nimport random\r\nimport tarfile\r\nimport time\r\nimport sys\r\n\r\nAUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}\r\n\r\n\r\ndef write_tar_file(data_list, tar_file, index=0, total=1):\r\n    logging.info('Processing {} {}/{}'.format(tar_file, index, total))\r\n    read_time = 0.0\r\n    write_time = 0.0\r\n    with tarfile.open(tar_file, \"w\") as tar:\r\n        for item in data_list:\r\n            assert len(\r\n                item) == 3, 'item should have 3 elements: Key, Speaker, Wav'\r\n            key, spks, wavs = item\r\n            spk_idx = 1\r\n            for spk in spks:\r\n                assert isinstance(spk, str)\r\n                spk_file = key + '.spk' + str(spk_idx)\r\n                spk = spk.encode('utf8')\r\n                spk_data = io.BytesIO(spk)\r\n                spk_info = tarfile.TarInfo(spk_file)\r\n                spk_info.size = len(spk)\r\n                tar.addfile(spk_info, spk_data)\r\n                spk_idx = spk_idx + 1\r\n\r\n            spk_idx = 0\r\n            for wav in wavs:\r\n                suffix = wav.split('.')[-1]\r\n                assert suffix in AUDIO_FORMAT_SETS\r\n                ts = time.time()\r\n                try:\r\n                    with open(wav, 'rb') as fin:\r\n                        data = fin.read()\r\n                except FileNotFoundError as e:\r\n                    print(e)\r\n                    sys.exit()\r\n                read_time += (time.time() - ts)\r\n                ts = time.time()\r\n                if spk_idx > 0:\r\n                    wav_file = key + '_spk' + str(spk_idx) + '.' + suffix\r\n                else:\r\n                    wav_file = key + '.' 
+ suffix\r\n                wav_data = io.BytesIO(data)\r\n                wav_info = tarfile.TarInfo(wav_file)\r\n                wav_info.size = len(data)\r\n                tar.addfile(wav_info, wav_data)\r\n                write_time += (time.time() - ts)\r\n                spk_idx = spk_idx + 1\r\n\r\n        logging.info('read {} write {}'.format(read_time, write_time))\r\n\r\n\r\ndef get_args():\r\n    parser = argparse.ArgumentParser(description='')\r\n    parser.add_argument('--num_utts_per_shard',\r\n                        type=int,\r\n                        default=1000,\r\n                        help='num utts per shard')\r\n    parser.add_argument('--num_threads',\r\n                        type=int,\r\n                        default=1,\r\n                        help='num threads for make shards')\r\n    parser.add_argument('--prefix',\r\n                        default='shards',\r\n                        help='prefix of shards tar file')\r\n    parser.add_argument('--seed', type=int, default=42, help='random seed')\r\n    parser.add_argument('--shuffle',\r\n                        action='store_true',\r\n                        help='whether to shuffle data')\r\n    parser.add_argument('wav_file', help='wav file')\r\n    parser.add_argument('utt2spk_file', help='utt2spk file')\r\n    parser.add_argument('shards_dir', help='output shards dir')\r\n    parser.add_argument('shards_list', help='output shards list file')\r\n    args = parser.parse_args()\r\n    return args\r\n\r\n\r\ndef main():\r\n    args = get_args()\r\n    random.seed(args.seed)\r\n    logging.basicConfig(level=logging.INFO,\r\n                        format='%(asctime)s %(levelname)s %(message)s')\r\n\r\n    wav_table = {}\r\n    with open(args.wav_file, 'r', encoding='utf8') as fin:\r\n        for line in fin:\r\n            arr = line.strip().split()\r\n            key = arr[0]  # key = os.path.splitext(arr[0])[0]\r\n            wav_table[key] = [arr[i + 1] for i in range(len(arr) - 1)]\r\n\r\n    data = []\r\n    with open(args.utt2spk_file, 'r', encoding='utf8') as fin:\r\n        for line in fin:\r\n            arr = line.strip().split()\r\n            key = arr[0]  # key = os.path.splitext(arr[0])[0]\r\n            spks = [arr[i + 1] for i in range(len(arr) - 1)]\r\n            assert key in wav_table\r\n            wavs = wav_table[key]\r\n            data.append((key, spks, wavs))\r\n\r\n    if args.shuffle:\r\n        random.shuffle(data)\r\n\r\n    num = args.num_utts_per_shard\r\n    chunks = [data[i:i + num] for i in range(0, len(data), num)]\r\n    os.makedirs(args.shards_dir, exist_ok=True)\r\n\r\n    # Using thread pool to speedup\r\n    pool = multiprocessing.Pool(processes=args.num_threads)\r\n    shards_list = []\r\n    num_chunks = len(chunks)\r\n    for i, chunk in enumerate(chunks):\r\n        tar_file = os.path.join(args.shards_dir,\r\n                                '{}_{:09d}.tar'.format(args.prefix, i))\r\n        shards_list.append(tar_file)\r\n        pool.apply_async(write_tar_file, (chunk, tar_file, i, num_chunks))\r\n\r\n    pool.close()\r\n    pool.join()\r\n\r\n    with open(args.shards_list, 'w', encoding='utf8') as fout:\r\n        for name in shards_list:\r\n            fout.write(name + '\\n')\r\n\r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  },
  {
    "path": "tools/make_shard_online.py",
    "content": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang\n#                                      2023 Shuai Wang )\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport io\nimport logging\nimport multiprocessing\nimport os\nimport random\nimport tarfile\nimport time\n\nAUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}\n\n\ndef write_tar_file(data_list, tar_file, index=0, total=1):\n    logging.info('Processing {} {}/{}'.format(tar_file, index, total))\n    read_time = 0.0\n    write_time = 0.0\n    with tarfile.open(tar_file, \"w\") as tar:\n        for item in data_list:\n            assert len(\n                item) == 3, 'item should have 3 elements: Key, Speaker, Wav'\n            key, spk, wav = item\n\n            suffix = wav.split('.')[-1]\n            assert suffix in AUDIO_FORMAT_SETS\n\n            ts = time.time()\n\n            with open(wav, 'rb') as fin:\n                data = fin.read()\n\n            read_time += (time.time() - ts)\n            assert isinstance(spk, str)\n            ts = time.time()\n            spk_file = key + '.spk'\n            spk = spk.encode('utf8')\n            spk_data = io.BytesIO(spk)\n            spk_info = tarfile.TarInfo(spk_file)\n            spk_info.size = len(spk)\n            tar.addfile(spk_info, spk_data)\n\n            wav_file = key + '.' 
+ suffix\n            wav_data = io.BytesIO(data)\n            wav_info = tarfile.TarInfo(wav_file)\n            wav_info.size = len(data)\n            tar.addfile(wav_info, wav_data)\n            write_time += (time.time() - ts)\n        logging.info('read {} write {}'.format(read_time, write_time))\n\n\ndef get_args():\n    parser = argparse.ArgumentParser(description='')\n    parser.add_argument('--num_utts_per_shard',\n                        type=int,\n                        default=1000,\n                        help='num utts per shard')\n    parser.add_argument('--num_threads',\n                        type=int,\n                        default=1,\n                        help='num threads for make shards')\n    parser.add_argument('--prefix',\n                        default='shards',\n                        help='prefix of shards tar file')\n    parser.add_argument('--seed', type=int, default=42, help='random seed')\n    parser.add_argument('--shuffle',\n                        action='store_true',\n                        help='whether to shuffle data')\n    parser.add_argument('wav_file', help='wav file')\n    parser.add_argument('utt2spk_file', help='utt2spk file')\n    parser.add_argument('shards_dir', help='output shards dir')\n    parser.add_argument('shards_list', help='output shards list file')\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    args = get_args()\n    random.seed(args.seed)\n    logging.basicConfig(level=logging.INFO,\n                        format='%(asctime)s %(levelname)s %(message)s')\n\n    wav_table = {}\n    with open(args.wav_file, 'r', encoding='utf8') as fin:\n        for line in fin:\n            arr = line.strip().split()\n            key = arr[0]  # key = os.path.splitext(arr[0])[0]\n            wav_table[key] = ' '.join(arr[1:])\n\n    data = []\n    with open(args.utt2spk_file, 'r', encoding='utf8') as fin:\n        for line in fin:\n            arr = line.strip().split(maxsplit=1)\n            key = arr[0]  # key = os.path.splitext(arr[0])[0]\n            spk = arr[1]\n            assert key in wav_table\n            wav = wav_table[key]\n            data.append((key, spk, wav))\n\n    if args.shuffle:\n        random.shuffle(data)\n\n    num = args.num_utts_per_shard\n    chunks = [data[i:i + num] for i in range(0, len(data), num)]\n    os.makedirs(args.shards_dir, exist_ok=True)\n\n    # Using thread pool to speedup\n    pool = multiprocessing.Pool(processes=args.num_threads)\n    shards_list = []\n    num_chunks = len(chunks)\n    for i, chunk in enumerate(chunks):\n        tar_file = os.path.join(args.shards_dir,\n                                '{}_{:09d}.tar'.format(args.prefix, i))\n        shards_list.append(tar_file)\n        pool.apply_async(write_tar_file, (chunk, tar_file, i, num_chunks))\n\n    pool.close()\n    pool.join()\n\n    with open(args.shards_list, 'w', encoding='utf8') as fout:\n        for name in shards_list:\n            fout.write(name + '\\n')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/parse_options.sh",
    "content": "#!/bin/bash\n\n# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);\n#                 Arnab Ghoshal, Karel Vesely\n#           2022  Hongji Wang (jijijiang77@gmail.com)\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#  http://www.apache.org/licenses/LICENSE-2.0\n#\n# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED\n# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,\n# MERCHANTABLITY OR NON-INFRINGEMENT.\n# See the Apache 2 License for the specific language governing permissions and\n# limitations under the License.\n\n# Parse command-line options.\n# To be sourced by another script (as in \". parse_options.sh\").\n# Option format is: --option-name arg\n# and shell variable \"option_name\" gets set to value \"arg.\"\n# The exception is --help, which takes no arguments, but prints the\n# $help_message variable (if defined).\n\n###\n### The --conf file options have lower priority to command line\n### options, so we need to import them first...\n###\n\n# Now import all the confs specified by command-line, in left-to-right order\nfor ((argpos = 1; argpos < $#; argpos++)); do\n  if [ \"${!argpos}\" == \"--conf\" ]; then\n    argpos_plus1=$((argpos + 1))\n    conf=${!argpos_plus1}\n    [ ! -r $conf ] && echo \"$0: missing conf '$conf'\" && exit 1\n    . $conf # source the conf file.\n  fi\ndone\n\n###\n### No we process the command line options\n###\nwhile true; do\n  [ -z \"${1:-}\" ] && break # break if there are no arguments\n  case \"$1\" in\n  # If the enclosing script is called with --help option, print the help\n  # message and exit.  Scripts should put help messages in $help_message\n  --help | -h)\n    if [ -z \"$help_message\" ]; then\n      echo \"No help found.\" 1>&2\n    else printf \"$help_message\\n\" 1>&2; fi\n    exit 0\n    ;;\n  --*=*)\n    echo \"$0: options to scripts must be of the form --name value, got '$1'\"\n    exit 1\n    ;;\n  # If the first command-line argument begins with \"--\" (e.g. --foo-bar),\n  # then work out the variable name as $name, which will equal \"foo_bar\".\n  --*)\n    name=$(echo \"$1\" | sed s/^--// | sed s/-/_/g)\n    # Next we test whether the variable in question is undefned-- if so it's\n    # an invalid option and we die.  Note: $0 evaluates to the name of the\n    # enclosing script.\n    # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar\n    # is undefined.  
We then have to wrap this test inside \"eval\" because\n    # foo_bar is itself inside a variable ($name).\n    eval '[ -z \"${'$name'+xxx}\" ]' && echo \"$0: invalid option $1\" 1>&2 && exit 1\n\n    oldval=\"$(eval echo \\$$name)\"\n    # Work out whether we seem to be expecting a Boolean argument.\n    if [ \"$oldval\" == \"true\" ] || [ \"$oldval\" == \"false\" ]; then\n      was_bool=true\n    else\n      was_bool=false\n    fi\n\n    # Set the variable to the right value-- the escaped quotes make it work if\n    # the option had spaces, like --cmd \"queue.pl -sync y\"\n    eval $name=\\\"$2\\\"\n\n    # Check that Boolean-valued arguments are really Boolean.\n    if $was_bool && [[ \"$2\" != \"true\" && \"$2\" != \"false\" ]]; then\n      echo \"$0: expected \\\"true\\\" or \\\"false\\\": $1 $2\" 1>&2\n      exit 1\n    fi\n    shift 2\n    ;;\n  *) break ;;\n  esac\ndone\n\n# Check for an empty argument to the --cmd option, which can easily occur as a\n# result of scripting errors.\n[ ! -z \"${cmd+xxx}\" ] && [ -z \"$cmd\" ] && echo \"$0: empty argument to --cmd option\" 1>&2 && exit 1\n\ntrue # so this script returns exit code 0."
  },
  {
    "path": "tools/print_train_val_curve.py",
    "content": "import re\r\n\r\nimport matplotlib.pyplot as plt\r\n\r\n# Initialize lists to store epochs, train losses and validation losses\r\nepochs = []\r\ntrain_loss = []\r\nval_loss = []\r\n\r\n# Open the log file\r\nprev_epoch = 0\r\n\r\nwith open(\"train.log\", \"r\") as f:\r\n    for line in f:\r\n        # Find lines with epoch info\r\n        if \"info\" in line:\r\n            # Extract epoch number\r\n            epoch = int(re.search(r\"Epoch (\\d+)\", line).group(1))\r\n            if epoch != prev_epoch:\r\n                print(prev_epoch, epoch)\r\n                # Extract loss values\r\n                # pattern = r'loss (.*?)\\n'\r\n                pattern = r\"[-+]?\\d*\\.\\d+\"\r\n                loss = float(re.search(pattern, line).group())\r\n                if \"Train\" in line:\r\n                    epochs.append(epoch)\r\n                    train_loss.append(loss)\r\n                elif \"Val\" in line:\r\n                    val_loss.append(loss)\r\n                    prev_epoch = epoch\r\n\r\n# Create the plot\r\nplt.figure(figsize=(10, 5))\r\n\r\n# Plot training and validation loss\r\nplt.plot(epochs, train_loss, label=\"Training Loss\", color=\"blue\")\r\nplt.plot(epochs, val_loss, label=\"Validation Loss\", color=\"red\")\r\n\r\n# Add horizontal lines at the minimum values\r\nplt.axhline(min(train_loss),\r\n            color=\"blue\",\r\n            linestyle=\"--\",\r\n            label=\"Min Training Loss\")\r\nplt.axhline(min(val_loss),\r\n            color=\"red\",\r\n            linestyle=\"--\",\r\n            label=\"Min Validation Loss\")\r\n\r\n# Annotate the minimum values on the y-axis\r\nplt.text(\r\n    0,\r\n    min(train_loss),\r\n    \"{:.2f}\".format(min(train_loss)),\r\n    va=\"center\",\r\n    ha=\"left\",\r\n    backgroundcolor=\"w\",\r\n)\r\nplt.text(\r\n    0,\r\n    min(val_loss),\r\n    \"{:.2f}\".format(min(val_loss)),\r\n    va=\"center\",\r\n    ha=\"left\",\r\n    backgroundcolor=\"w\",\r\n)\r\n\r\n# Add legend, title, and x, y labels\r\nplt.legend(loc=\"upper right\")\r\nplt.title(\"Training and Validation Loss Over Epochs\")\r\nplt.ylabel(\"Loss Value\")\r\nplt.xlabel(\"Epochs\")\r\n\r\n# Save the plot as a .png file\r\nplt.savefig(\"train_val_loss.png\")\r\n\r\n# Show the plot\r\n# plt.show()\r\n"
  },
  {
    "path": "tools/run.pl",
    "content": "#!/usr/bin/env perl\nuse warnings; #sed replacement for -w perl parameter\n# In general, doing\n#  run.pl some.log a b c is like running the command a b c in\n# the bash shell, and putting the standard error and output into some.log.\n# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)\n#  run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB\n# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].\n# If any of the jobs fails, this script will fail.\n\n# A typical example is:\n#  run.pl some.log my-prog \"--opt=foo bar\" foo \\|  other-prog baz\n# and run.pl will run something like:\n# ( my-prog '--opt=foo bar' foo |  other-prog baz ) >& some.log\n#\n# Basically it takes the command-line arguments, quotes them\n# as necessary to preserve spaces, and evaluates them with bash.\n# In addition it puts the command line at the top of the log, and\n# the start and end times of the command at the beginning and end.\n# The reason why this is useful is so that we can create a different\n# version of this program that uses a queueing system instead.\n\n#use Data::Dumper;\n\n@ARGV < 2 && die \"usage: run.pl log-file command-line arguments...\";\n\n#print STDERR \"COMMAND-LINE: \" .  Dumper(\\@ARGV) . \"\\n\";\n$job_pick = 'all';\n$max_jobs_run = -1;\n$jobstart = 1;\n$jobend = 1;\n$ignored_opts = \"\"; # These will be ignored.\n\n# First parse an option like JOB=1:4, and any\n# options that would normally be given to\n# queue.pl, which we will just discard.\n\nfor (my $x = 1; $x <= 2; $x++) { # This for-loop is to\n  # allow the JOB=1:n option to be interleaved with the\n  # options to qsub.\n  while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {\n    # parse any options that would normally go to qsub, but which will be ignored here.\n    my $switch = shift @ARGV;\n    if ($switch eq \"-V\") {\n      $ignored_opts .= \"-V \";\n    } elsif ($switch eq \"--max-jobs-run\" || $switch eq \"-tc\") {\n      # we do support the option --max-jobs-run n, and its GridEngine form -tc n.\n      # if the command appears multiple times uses the smallest option.\n      if ( $max_jobs_run <= 0 ) {\n          $max_jobs_run =  shift @ARGV;\n      } else {\n        my $new_constraint = shift @ARGV;\n        if ( ($new_constraint < $max_jobs_run) ) {\n          $max_jobs_run = $new_constraint;\n        }\n      }\n\n      if (! ($max_jobs_run > 0)) {\n        die \"run.pl: invalid option --max-jobs-run $max_jobs_run\";\n      }\n    } else {\n      my $argument = shift @ARGV;\n      if ($argument =~ m/^--/) {\n        print STDERR \"run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\\n\";\n      }\n      if ($switch eq \"-sync\" && $argument =~ m/^[yY]/) {\n        $ignored_opts .= \"-sync \"; # Note: in the\n        # corresponding code in queue.pl it says instead, just \"$sync = 1;\".\n      } elsif ($switch eq \"-pe\") { # e.g. 
-pe smp 5\n        my $argument2 = shift @ARGV;\n        $ignored_opts .= \"$switch $argument $argument2 \";\n      } elsif ($switch eq \"--gpu\") {\n        $using_gpu = $argument;\n      } elsif ($switch eq \"--pick\") {\n        if($argument =~ m/^(all|failed|incomplete)$/) {\n          $job_pick = $argument;\n        } else {\n          print STDERR \"run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'\"\n        }\n      } else {\n        # Ignore option.\n        $ignored_opts .= \"$switch $argument \";\n      }\n    }\n  }\n  if ($ARGV[0] =~ m/^([\\w_][\\w\\d_]*)+=(\\d+):(\\d+)$/) { # e.g. JOB=1:20\n    $jobname = $1;\n    $jobstart = $2;\n    $jobend = $3;\n    if ($jobstart > $jobend) {\n      die \"run.pl: invalid job range $ARGV[0]\";\n    }\n    if ($jobstart <= 0) {\n      die \"run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).\";\n    }\n    shift;\n  } elsif ($ARGV[0] =~ m/^([\\w_][\\w\\d_]*)+=(\\d+)$/) { # e.g. JOB=1.\n    $jobname = $1;\n    $jobstart = $2;\n    $jobend = $2;\n    shift;\n  } elsif ($ARGV[0] =~ m/.+\\=.*\\:.*$/) {\n    print STDERR \"run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\\n\";\n  }\n}\n\n# Users found this message confusing so we are removing it.\n# if ($ignored_opts ne \"\") {\n#   print STDERR \"run.pl: Warning: ignoring options \\\"$ignored_opts\\\"\\n\";\n# }\n\nif ($max_jobs_run == -1) { # If --max-jobs-run option not set,\n                           # then work out the number of processors if possible,\n                           # and set it based on that.\n  $max_jobs_run = 0;\n  if ($using_gpu) {\n    if (open(P, \"nvidia-smi -L |\")) {\n      $max_jobs_run++ while (<P>);\n      close(P);\n    }\n    if ($max_jobs_run == 0) {\n      $max_jobs_run = 1;\n      print STDERR \"run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\\n\";\n    }\n  } elsif (open(P, \"</proc/cpuinfo\")) {  # Linux\n    while (<P>) { if (m/^processor/) { $max_jobs_run++; } }\n    if ($max_jobs_run == 0) {\n      print STDERR \"run.pl: Warning: failed to detect any processors from /proc/cpuinfo\\n\";\n      $max_jobs_run = 10;  # reasonable default.\n    }\n    close(P);\n  } elsif (open(P, \"sysctl -a |\")) {  # BSD/Darwin\n    while (<P>) {\n      if (m/hw\\.ncpu\\s*[:=]\\s*(\\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4\n        $max_jobs_run = $1;\n        last;\n      }\n    }\n    close(P);\n    if ($max_jobs_run == 0) {\n      print STDERR \"run.pl: Warning: failed to detect any processors from sysctl -a\\n\";\n      $max_jobs_run = 10;  # reasonable default.\n    }\n  } else {\n    # allow at most 32 jobs at once, on non-UNIX systems; change this code\n    # if you need to change this default.\n    $max_jobs_run = 32;\n  }\n  # The just-computed value of $max_jobs_run is just the number of processors\n  # (or our best guess); and if it happens that the number of jobs we need to\n  # run is just slightly above $max_jobs_run, it will make sense to increase\n  # $max_jobs_run to equal the number of jobs, so we don't have a small number\n  # of leftover jobs.\n  $num_jobs = $jobend - $jobstart + 1;\n  if (!$using_gpu &&\n      $num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {\n    $max_jobs_run = $num_jobs;\n  }\n}\n\nsub pick_or_exit {\n  # pick_or_exit ( $logfile )\n  # Invoked before each job is started helps to run jobs selectively.\n  #\n  # Given the name of the output logfile decides whether the job must 
be\n  # executed (by returning from the subroutine) or not (by terminating the\n  # process calling exit)\n  #\n  # PRE: $job_pick is a global variable set by command line switch --pick\n  #      and indicates which class of jobs must be executed.\n  #\n  # 1) If a failed job is not executed the process exit code will indicate\n  #    failure, just as if the task was just executed  and failed.\n  #\n  # 2) If a task is incomplete it will be executed. Incomplete may be either\n  #    a job whose log file does not contain the accounting notes in the end,\n  #    or a job whose log file does not exist.\n  #\n  # 3) If the $job_pick is set to 'all' (default behavior) a task will be\n  #    executed regardless of the result of previous attempts.\n  #\n  # This logic could have been implemented in the main execution loop\n  # but a subroutine to preserve the current level of readability of\n  # that part of the code.\n  #\n  # Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020\n  #\n  if($job_pick eq 'all'){\n    return; # no need to bother with the previous log\n  }\n  open my $fh, \"<\", $_[0] or return; # job not executed yet\n  my $log_line;\n  my $cur_line;\n  while ($cur_line = <$fh>) {\n    if( $cur_line =~ m/# Ended \\(code .*/ ) {\n      $log_line = $cur_line;\n    }\n  }\n  close $fh;\n  if (! defined($log_line)){\n    return; # incomplete\n  }\n  if ( $log_line =~ m/# Ended \\(code 0\\).*/ ) {\n    exit(0); # complete\n  } elsif ( $log_line =~ m/# Ended \\(code \\d+(; signal \\d+)?\\).*/ ){\n    if ($job_pick !~ m/^(failed|all)$/) {\n      exit(1); # failed but not going to run\n    } else {\n      return; # failed\n    }\n  } elsif ( $log_line =~ m/.*\\S.*/ ) {\n    return; # incomplete jobs are always run\n  }\n}\n\n\n$logfile = shift @ARGV;\n\nif (defined $jobname && $logfile !~ m/$jobname/ &&\n    $jobend > $jobstart) {\n  print STDERR \"run.pl: you are trying to run a parallel job but \"\n    . \"you are putting the output into just one log file ($logfile)\\n\";\n  exit(1);\n}\n\n$cmd = \"\";\n\nforeach $x (@ARGV) {\n    if ($x =~ m/^\\S+$/) { $cmd .=  $x . \" \"; }\n    elsif ($x =~ m:\\\":) { $cmd .= \"'$x' \"; }\n    else { $cmd .= \"\\\"$x\\\" \"; }\n}\n\n#$Data::Dumper::Indent=0;\n$ret = 0;\n$numfail = 0;\n%active_pids=();\n\nuse POSIX \":sys_wait_h\";\nfor ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {\n  if (scalar(keys %active_pids) >= $max_jobs_run) {\n\n    # Lets wait for a change in any child's status\n    # Then we have to work out which child finished\n    $r = waitpid(-1, 0);\n    $code = $?;\n    if ($r < 0 ) { die \"run.pl: Error waiting for child process\"; } # should never happen.\n    if ( defined $active_pids{$r} ) {\n        $jid=$active_pids{$r};\n        $fail[$jid]=$code;\n        if ($code !=0) { $numfail++;}\n        delete $active_pids{$r};\n        # print STDERR \"Finished: $r/$jid \" .  Dumper(\\%active_pids) . \"\\n\";\n    } else {\n        die \"run.pl: Cannot find the PID of the child process that just finished.\";\n    }\n\n    # In theory we could do a non-blocking waitpid over all jobs running just\n    # to find out if only one or more jobs finished during the previous waitpid()\n    # However, we just omit this and will reap the next one in the next pass\n    # through the for(;;) cycle\n  }\n  $childpid = fork();\n  if (!defined $childpid) { die \"run.pl: Error forking in run.pl (writing to $logfile)\"; }\n  if ($childpid == 0) { # We're in the child... 
this branch\n    # executes the job and returns (possibly with an error status).\n    if (defined $jobname) {\n      $cmd =~ s/$jobname/$jobid/g;\n      $logfile =~ s/$jobname/$jobid/g;\n    }\n    # exit if the job does not need to be executed\n    pick_or_exit( $logfile );\n\n    system(\"mkdir -p `dirname $logfile` 2>/dev/null\");\n    open(F, \">$logfile\") || die \"run.pl: Error opening log file $logfile\";\n    print F \"# \" . $cmd . \"\\n\";\n    print F \"# Started at \" . `date`;\n    $starttime = `date +'%s'`;\n    print F \"#\\n\";\n    close(F);\n\n    # Pipe into bash.. make sure we're not using any other shell.\n    open(B, \"|bash\") || die \"run.pl: Error opening shell command\";\n    print B \"( \" . $cmd . \") 2>>$logfile >> $logfile\";\n    close(B);                   # If there was an error, exit status is in $?\n    $ret = $?;\n\n    $lowbits = $ret & 127;\n    $highbits = $ret >> 8;\n    if ($lowbits != 0) { $return_str = \"code $highbits; signal $lowbits\" }\n    else { $return_str = \"code $highbits\"; }\n\n    $endtime = `date +'%s'`;\n    open(F, \">>$logfile\") || die \"run.pl: Error opening log file $logfile (again)\";\n    $enddate = `date`;\n    chop $enddate;\n    print F \"# Accounting: time=\" . ($endtime - $starttime) . \" threads=1\\n\";\n    print F \"# Ended ($return_str) at \" . $enddate . \", elapsed time \" . ($endtime-$starttime) . \" seconds\\n\";\n    close(F);\n    exit($ret == 0 ? 0 : 1);\n  } else {\n    $pid[$jobid] = $childpid;\n    $active_pids{$childpid} = $jobid;\n    # print STDERR \"Queued: \" .  Dumper(\\%active_pids) . \"\\n\";\n  }\n}\n\n# Now we have submitted all the jobs, lets wait until all the jobs finish\nforeach $child (keys %active_pids) {\n    $jobid=$active_pids{$child};\n    $r = waitpid($pid[$jobid], 0);\n    $code = $?;\n    if ($r == -1) { die \"run.pl: Error waiting for child process\"; } # should never happen.\n    if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully\n}\n\n# Some sanity checks:\n# The $fail array should not contain undefined codes\n# The number of non-zeros in that array  should be equal to $numfail\n# We cannot do foreach() here, as the JOB ids do not start at zero\n$failed_jids=0;\nfor ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {\n  $job_return = $fail[$jobid];\n  if (not defined $job_return ) {\n    # print Dumper(\\@fail);\n\n    die \"run.pl: Sanity check failed: we have indication that some jobs are running \" .\n      \"even after we waited for all jobs to finish\" ;\n  }\n  if ($job_return != 0 ){ $failed_jids++;}\n}\nif ($failed_jids != $numfail) {\n  die \"run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail).\"\n}\nif ($numfail > 0) { $ret = 1; }\n\nif ($ret != 0) {\n  $njobs = $jobend - $jobstart + 1;\n  if ($njobs == 1) {\n    if (defined $jobname) {\n      $logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with\n                                         # that job.\n    }\n    print STDERR \"run.pl: job failed, log is in $logfile\\n\";\n    if ($logfile =~ m/JOB/) {\n      print STDERR \"run.pl: probably you forgot to put JOB=1:\\$nj in your script.\";\n    }\n  }\n  else {\n    $logfile =~ s/$jobname/*/g;\n    print STDERR \"run.pl: $numfail / $njobs failed, log is in $logfile\\n\";\n  }\n}\n\n\nexit ($ret);\n"
  },
  {
    "path": "tools/score.sh",
    "content": "#!/bin/bash\n\nmin() {\n    local a b\n    a=$1\n    for b in \"$@\"; do\n        if [ \"${b}\" -le \"${a}\" ]; then\n            a=\"${b}\"\n        fi\n    done\n    echo \"${a}\"\n}\n\n# Set default values\ndset=\nexp_dir=\nscoring_opts=\n\nn_gpu=1\nscore_nj=16\nref_channel=0\nuse_pesq=false\nuse_dnsmos=false\ndnsmos_use_gpu=true\nfs=16k\nscoring_protocol=\"STOI SDR SAR SIR SI_SNR\"\n\n# Parse command line options\n. tools/parse_options.sh || exit 1\n\nif [ ! ${fs} = 16k ] && ${use_dnsmos}; then\n    echo \"Warning: DNSMOS only supports 16k sampling rate.\"\n    echo \"--use_dnsmos will be set to false automatically.\"\n    use_dnsmos=false\nfi\n\n# Set scoring options\nscoring_opts=\"\"\nif ${use_dnsmos}; then\n    # Set model path\n    primary_model_path=DNSMOS/sig_bak_ovr.onnx\n    p808_model_path=DNSMOS/model_v8.onnx\n\n    if [ ! -f ${primary_model_path} ] || [ ! -f ${p808_model_path} ]; then\n        echo \"==========================================\"\n        echo \"Warning: DNSMOS model files are not found.\"\n        echo \"Trying to download them from the official repository.\"\n        echo \"If this takes too long,\"\n        echo \"please manually download the model files\"\n        echo \"and put them in the DNSMOS directory.\"\n        echo \"==========================================\"\n\n        # creat directory for DNSMOS model files\n        mkdir -p DNSMOS\n        # download DNSMOS model files and save them to the directory\n        wget -P DNSMOS https://github.com/microsoft/DNS-Challenge/raw/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx\n        wget -P DNSMOS https://github.com/microsoft/DNS-Challenge/raw/master/DNSMOS/DNSMOS/model_v8.onnx\n        # check if the model files are downloaded successfully\n        if [ ! -f ${primary_model_path} ] || [ ! -f ${p808_model_path} ]; then\n            echo \"Error: DNSMOS model files are not downloaded successfully.\"\n            exit 1\n        fi\n    fi\n    scoring_opts+=\"--dnsmos_mode local \"\n    scoring_opts+=\"--dnsmos_primary_model ${primary_model_path} \"\n    scoring_opts+=\"--dnsmos_p808_model ${p808_model_path} \"\n    if ${dnsmos_use_gpu}; then\n        score_nj=$(min \"${score_nj}\" \"${n_gpu}\")\n        scoring_opts+=\"--dnsmos_use_gpu ${dnsmos_use_gpu} \"\n    fi\nfi\n\n# Set directories and log directory\n_dir=\"${exp_dir}/scoring\"\n_logdir=\"${_dir}/logdir\"\nmkdir -p \"${_logdir}\"\n\n# 0. Check the inference file\ninf_scp=${exp_dir}/audio/spk1.scp\nif [ ! -s \"${inf_scp}\" ] || [ -z \"$(cat \"${inf_scp}\")\" ]; then\n    echo \"Error: ${inf_scp} does not exist or is empty!\"\n    exit 1\nfi\n# 1. Split the key file\nkey_file=${dset}/single.wav.scp\nsplit_scps=\"\"\n_nj=$(min \"${score_nj}\" \"$(wc <${key_file} -l)\")\nfor n in $(seq \"${_nj}\"); do\n    split_scps+=\" ${_logdir}/keys.${n}.scp\"\ndone\n# shellcheck disable=SC2086\n./tools/split_scp.pl \"${key_file}\" ${split_scps}\n\n_ref_scp=\"--ref_scp ${dset}/single.wav.scp \"\n_inf_scp=\"--inf_scp ${exp_dir}/audio/spk1.scp \"\n\n# 2. 
Submit scoring jobs\necho \"log: '${_logdir}/tse_scoring.*.log'\"\nif ${use_dnsmos} && ${dnsmos_use_gpu}; then\n    cmd=\"./tools/run.pl --gpu ${n_gpu}\"\nelse\n    cmd=\"./tools/run.pl\"\nfi\n# shellcheck disable=SC2086\n${cmd} JOB=1:\"${_nj}\" \"${_logdir}\"/tse_scoring.JOB.log \\\n    python -m wesep.bin.score \\\n    --key_file \"${_logdir}\"/keys.JOB.scp \\\n    --output_dir \"${_logdir}\"/output.JOB \\\n    ${_ref_scp} \\\n    ${_inf_scp} \\\n    --ref_channel ${ref_channel} \\\n    --use_pesq ${use_pesq} \\\n    --use_dnsmos ${use_dnsmos} \\\n    --dnsmos_gpu_device JOB \\\n    ${scoring_opts}\n\n# Check if PESQ is used\nif \"${use_pesq}\"; then\n    if [ ${fs} = 16k ]; then\n        scoring_protocol+=\" PESQ_WB\"\n    else\n        scoring_protocol+=\" PESQ_NB\"\n    fi\nfi\n\n# Check if dnsmos is used\nif \"${use_dnsmos}\"; then\n    scoring_protocol+=\" BAK SIG OVRL P808_MOS\"\nfi\n\n# Merge and sort result files\nfor protocol in ${scoring_protocol} wav; do\n    for i in $(seq \"${_nj}\"); do\n        cat \"${_logdir}/output.${i}/${protocol}_spk1\"\n    done | LC_ALL=C sort -k1 >\"${_dir}/${protocol}_spk1\"\ndone\n\n# Calculate and save results\nfor protocol in ${scoring_protocol}; do\n    # shellcheck disable=SC2046\n    paste $(printf \"%s/%s_spk1 \" \"${_dir}\" \"${protocol}\") |\n        awk 'BEGIN{sum=0}\n            {n=0;score=0;for (i=2; i<=NF; i+=2){n+=1;score+=$i}; sum+=score/n}\n            END{printf (\"%.2f\\n\",sum/NR)}' >\"${_dir}/result_${protocol,,}.txt\"\ndone\n\n# show the result\n./tools/show_enh_score.sh \"${_dir}/../..\" > \\\n    \"${_dir}/../../RESULTS.md\"\n"
  },
  {
    "path": "tools/show_enh_score.sh",
    "content": "#!/usr/bin/env bash\nmindepth=0\nmaxdepth=1\n\n. tools/parse_options.sh\n\nif [ $# -gt 1 ]; then\n    echo \"Usage: $0 --mindepth 0 --maxdepth 1 [exp]\" 1>&2\n    echo \"\"\n    echo \"Show the system environments and the evaluation results in Markdown format.\"\n    echo 'The default of <exp> is \"exp/\".'\n    exit 1\nfi\n\n[ -f ./path.sh ] && . ./path.sh\nset -euo pipefail\nif [ $# -eq 1 ]; then\n    exp=$(realpath \"$1\")\nelse\n    exp=exp\nfi\n\ncat <<EOF\n<!-- Generated by $0 -->\n# RESULTS\n## Environments\n- date: \\`$(LC_ALL=C date)\\`\nEOF\n\ncat <<EOF\n- Git hash: \\`$(git rev-parse HEAD)\\`\n  - Commit date: \\`$(git log -1 --format='%cd')\\`\n\nEOF\n\nwhile IFS= read -r expdir; do\n    if ls \"${expdir}\"/*/scoring/result_stoi.txt &>/dev/null; then\n        echo -e \"\\n## $(basename ${expdir})\\n\"\n        [ -e \"${expdir}\"/config.yaml ] && grep ^config \"${expdir}\"/config.yaml\n        metrics=()\n        heading=\"\\n|dataset|\"\n        sep=\"|---|\"\n        for type in pesq pesq_wb pesq_nb estoi stoi sar sdr sir si_snr ovrl sig bak p808_mos; do\n            if ls \"${expdir}\"/*/scoring/result_${type}.txt &>/dev/null; then\n                metrics+=(\"$type\")\n                heading+=\"${type^^}|\"\n                sep+=\"---|\"\n            fi\n        done\n        echo -e \"${heading}\\n${sep}\"\n\n        setnames=()\n        for dirname in \"${expdir}\"/*/scoring/result_stoi.txt; do\n            dset=$(echo $dirname | sed -e \"s#${expdir}/\\([^/]*\\)/scoring/result_stoi.txt#\\1#g\")\n            setnames+=(\"$dset\")\n        done\n        for dset in \"${setnames[@]}\"; do\n            line=\"|${dset}|\"\n            for ((i = 0; i < ${#metrics[@]}; i++)); do\n                type=${metrics[$i]}\n                if [ -f \"${expdir}\"/${dset}/scoring/result_${type}.txt ]; then\n                    score=$(head -n1 \"${expdir}\"/${dset}/scoring/result_${type}.txt)\n                else\n                    score=\"\"\n                fi\n                line+=\"${score}|\"\n            done\n            echo $line\n        done\n        echo \"\"\n    fi\n\ndone < <(find ${exp} -mindepth ${mindepth} -maxdepth ${maxdepth} -type d)\n"
  },
  {
    "path": "tools/split_scp.pl",
    "content": "#!/usr/bin/env perl\n\n# Copyright 2010-2011 Microsoft Corporation\n\n# See ../../COPYING for clarification regarding multiple authors\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#  http://www.apache.org/licenses/LICENSE-2.0\n#\n# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED\n# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,\n# MERCHANTABLITY OR NON-INFRINGEMENT.\n# See the Apache 2 License for the specific language governing permissions and\n# limitations under the License.\n\n\n# This program splits up any kind of .scp or archive-type file.\n# If there is no utt2spk option it will work on any text  file and\n# will split it up with an approximately equal number of lines in\n# each but.\n# With the --utt2spk option it will work on anything that has the\n# utterance-id as the first entry on each line; the utt2spk file is\n# of the form \"utterance speaker\" (on each line).\n# It splits it into equal size chunks as far as it can.  If you use the utt2spk\n# option it will make sure these chunks coincide with speaker boundaries.  In\n# this case, if there are more chunks than speakers (and in some other\n# circumstances), some of the resulting chunks will be empty and it will print\n# an error message and exit with nonzero status.\n# You will normally call this like:\n# split_scp.pl scp scp.1 scp.2 scp.3 ...\n# or\n# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...\n# Note that you can use this script to split the utt2spk file itself,\n# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...\n\n# You can also call the scripts like:\n# split_scp.pl -j 3 0 scp scp.0\n# [note: with this option, it assumes zero-based indexing of the split parts,\n# i.e. the second number must be 0 <= n < num-jobs.]\n\nuse warnings;\n\n$num_jobs = 0;\n$job_id = 0;\n$utt2spk_file = \"\";\n$one_based = 0;\n\nfor ($x = 1; $x <= 3 && @ARGV > 0; $x++) {\n    if ($ARGV[0] eq \"-j\") {\n        shift @ARGV;\n        $num_jobs = shift @ARGV;\n        $job_id = shift @ARGV;\n    }\n    if ($ARGV[0] =~ /--utt2spk=(.+)/) {\n        $utt2spk_file=$1;\n        shift;\n    }\n    if ($ARGV[0] eq '--one-based') {\n        $one_based = 1;\n        shift @ARGV;\n    }\n}\n\nif ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||\n                       $job_id - $one_based >= $num_jobs)) {\n  die \"$0: Invalid job number/index values for '-j $num_jobs $job_id\" .\n      ($one_based ? \" --one-based\" : \"\") . \"'\\n\"\n}\n\n$one_based\n    and $job_id--;\n\nif(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {\n    die\n\"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...\n   or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]\n ... 
where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\\n\";\n}\n\n$error = 0;\n$inscp = shift @ARGV;\nif ($num_jobs == 0) { # without -j option\n    @OUTPUTS = @ARGV;\n} else {\n    for ($j = 0; $j < $num_jobs; $j++) {\n        if ($j == $job_id) {\n            if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }\n            else { push @OUTPUTS, \"-\"; }\n        } else {\n            push @OUTPUTS, \"/dev/null\";\n        }\n    }\n}\n\nif ($utt2spk_file ne \"\") {  # We have the --utt2spk option...\n    open($u_fh, '<', $utt2spk_file) || die \"$0: Error opening utt2spk file $utt2spk_file: $!\\n\";\n    while(<$u_fh>) {\n        @A = split;\n        @A == 2 || die \"$0: Bad line $_ in utt2spk file $utt2spk_file\\n\";\n        ($u,$s) = @A;\n        $utt2spk{$u} = $s;\n    }\n    close $u_fh;\n    open($i_fh, '<', $inscp) || die \"$0: Error opening input scp file $inscp: $!\\n\";\n    @spkrs = ();\n    while(<$i_fh>) {\n        @A = split;\n        if(@A == 0) { die \"$0: Empty or space-only line in scp file $inscp\\n\"; }\n        $u = $A[0];\n        $s = $utt2spk{$u};\n        defined $s || die \"$0: No utterance $u in utt2spk file $utt2spk_file\\n\";\n        if(!defined $spk_count{$s}) {\n            push @spkrs, $s;\n            $spk_count{$s} = 0;\n            $spk_data{$s} = [];  # ref to new empty array.\n        }\n        $spk_count{$s}++;\n        push @{$spk_data{$s}}, $_;\n    }\n    # Now split as equally as possible ..\n    # First allocate spks to files by allocating an approximately\n    # equal number of speakers.\n    $numspks = @spkrs;  # number of speakers.\n    $numscps = @OUTPUTS; # number of output files.\n    if ($numspks < $numscps) {\n      die \"$0: Refusing to split data because number of speakers $numspks \" .\n          \"is less than the number of output .scp files $numscps\\n\";\n    }\n    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {\n        $scparray[$scpidx] = []; # [] is array reference.\n    }\n    for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {\n        $scpidx = int(($spkidx*$numscps) / $numspks);\n        $spk = $spkrs[$spkidx];\n        push @{$scparray[$scpidx]}, $spk;\n        $scpcount[$scpidx] += $spk_count{$spk};\n    }\n\n    # Now will try to reassign beginning + ending speakers\n    # to different scp's and see if it gets more balanced.\n    # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.\n    # We can show that if considering changing just 2 scp's, we minimize\n    # this by minimizing the squared difference in sizes.  This is\n    # equivalent to minimizing the absolute difference in sizes.  
This\n    # shows this method is bound to converge.\n\n    $changed = 1;\n    while($changed) {\n        $changed = 0;\n        for($scpidx = 0; $scpidx < $numscps; $scpidx++) {\n            # First try to reassign ending spk of this scp.\n            if($scpidx < $numscps-1) {\n                $sz = @{$scparray[$scpidx]};\n                if($sz > 0) {\n                    $spk = $scparray[$scpidx]->[$sz-1];\n                    $count = $spk_count{$spk};\n                    $nutt1 = $scpcount[$scpidx];\n                    $nutt2 = $scpcount[$scpidx+1];\n                    if( abs( ($nutt2+$count) - ($nutt1-$count))\n                        < abs($nutt2 - $nutt1))  { # Would decrease\n                        # size-diff by reassigning spk...\n                        $scpcount[$scpidx+1] += $count;\n                        $scpcount[$scpidx] -= $count;\n                        pop @{$scparray[$scpidx]};\n                        unshift @{$scparray[$scpidx+1]}, $spk;\n                        $changed = 1;\n                    }\n                }\n            }\n            if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {\n                $spk = $scparray[$scpidx]->[0];\n                $count = $spk_count{$spk};\n                $nutt1 = $scpcount[$scpidx-1];\n                $nutt2 = $scpcount[$scpidx];\n                if( abs( ($nutt2-$count) - ($nutt1+$count))\n                    < abs($nutt2 - $nutt1))  { # Would decrease\n                    # size-diff by reassigning spk...\n                    $scpcount[$scpidx-1] += $count;\n                    $scpcount[$scpidx] -= $count;\n                    shift @{$scparray[$scpidx]};\n                    push @{$scparray[$scpidx-1]}, $spk;\n                    $changed = 1;\n                }\n            }\n        }\n    }\n    # Now print out the files...\n    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {\n        $scpfile = $OUTPUTS[$scpidx];\n        ($scpfile ne '-' ? open($f_fh, '>', $scpfile)\n                         : open($f_fh, '>&', \\*STDOUT)) ||\n            die \"$0: Could not open scp file $scpfile for writing: $!\\n\";\n        $count = 0;\n        if(@{$scparray[$scpidx]} == 0) {\n            print STDERR \"$0: eError: split_scp.pl producing empty .scp file \" .\n                         \"$scpfile (too many splits and too few speakers?)\\n\";\n            $error = 1;\n        } else {\n            foreach $spk ( @{$scparray[$scpidx]} ) {\n                print $f_fh @{$spk_data{$spk}};\n                $count += $spk_count{$spk};\n            }\n            $count == $scpcount[$scpidx] || die \"Count mismatch [code error]\";\n        }\n        close($f_fh);\n    }\n} else {\n   # This block is the \"normal\" case where there is no --utt2spk\n   # option and we just break into equal size chunks.\n\n    open($i_fh, '<', $inscp) || die \"$0: Error opening input scp file $inscp: $!\\n\";\n\n    $numscps = @OUTPUTS;  # size of array.\n    @F = ();\n    while(<$i_fh>) {\n        push @F, $_;\n    }\n    $numlines = @F;\n    if($numlines == 0) {\n        print STDERR \"$0: error: empty input scp file $inscp\\n\";\n        $error = 1;\n    }\n    $linesperscp = int( $numlines / $numscps); # the \"whole part\"..\n    $linesperscp >= 1 || die \"$0: You are splitting into too many pieces! 
[reduce \\$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\\n\";\n    $remainder = $numlines - ($linesperscp * $numscps);\n    ($remainder >= 0 && $remainder < $numlines) || die \"bad remainder $remainder\";\n    # [just doing int() rounds down].\n    $n = 0;\n    for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {\n        $scpfile = $OUTPUTS[$scpidx];\n        ($scpfile ne '-' ? open($o_fh, '>', $scpfile)\n                         : open($o_fh, '>&', \\*STDOUT)) ||\n            die \"$0: Could not open scp file $scpfile for writing: $!\\n\";\n        for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {\n            print $o_fh $F[$n++];\n        }\n        close($o_fh) || die \"$0: Error closing scp file $scpfile: $!\\n\";\n    }\n    $n == $numlines || die \"$n != $numlines [code error]\";\n}\n\nexit ($error);\n"
  },
  {
    "path": "tools/test_dataset.py",
    "content": "from torch.utils.data import DataLoader\r\n\r\nfrom wesep.dataset.dataset import Dataset\r\nfrom wesep.dataset.dataset import tse_collate_fn\r\nfrom wesep.utils.file_utils import load_speaker_embeddings\r\n\r\n\r\ndef test_premixed_dataset():\r\n    configs = {\r\n        \"shuffle\": False,\r\n        \"shuffle_args\": {\r\n            \"shuffle_size\": 2500\r\n        },\r\n        \"resample_rate\": 16000,\r\n        \"chunk_len\": 32000,\r\n    }\r\n\r\n    spk2embed_dict = load_speaker_embeddings(\"data/clean/test/embed.scp\",\r\n                                             \"data/clean/test/single.utt2spk\")\r\n\r\n    dataset = Dataset(\r\n        \"shard\",\r\n        \"data/clean/test/shard.list\",\r\n        configs=configs,\r\n        spk2embed_dict=spk2embed_dict,\r\n        whole_utt=False,\r\n    )\r\n    return dataset\r\n\r\n\r\ndef test_online_dataset():\r\n    # Implementation to test the online speaker mixing dataloader\r\n    configs = {\r\n        \"shuffle\": True,\r\n        \"resample_rate\": 16000,\r\n        \"chunk_len\": 64000,\r\n        \"num_speakers\": 2,\r\n        \"online_mix\": True,\r\n        \"reverb\": False,\r\n    }\r\n\r\n    spk2embed_dict = load_speaker_embeddings(\"mydata/clean/test/embed.scp\",\r\n                                             \"mydata/clean/test/utt2spk\")\r\n    dataset = Dataset(\r\n        \"shard\",\r\n        \"mydata/clean/test/shard.list\",\r\n        configs=configs,\r\n        spk2embed_dict=spk2embed_dict,\r\n        whole_utt=False,\r\n    )\r\n\r\n    return dataset\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    dataset = test_online_dataset()\r\n\r\n    dataloader = DataLoader(dataset,\r\n                            batch_size=4,\r\n                            num_workers=1,\r\n                            collate_fn=tse_collate_fn)\r\n\r\n    for i, batch in enumerate(dataloader):\r\n        print(\r\n            batch[\"wav_mix\"].size(),\r\n            batch[\"wav_targets\"].size(),\r\n            batch[\"spk_embeds\"].size(),\r\n        )\r\n        if i == 0:\r\n            break\r\n"
  },
  {
    "path": "wesep/__init__.py",
    "content": "from wesep.cli.extractor import load_model  # noqa\nfrom wesep.cli.extractor import load_model_local  # noqa"
  },
  {
    "path": "wesep/bin/average_model.py",
    "content": "# Copyright (c) 2020 Mobvoi Inc (Di Wu)\r\n#               2021 Hongji Wang (jijijiang77@gmail.com)\r\n#               2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#   http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport argparse\r\nimport glob\r\nimport os.path\r\nimport re\r\n\r\nimport torch\r\n\r\n\r\ndef get_args():\r\n    parser = argparse.ArgumentParser(description=\"average model\")\r\n    parser.add_argument(\"--dst_model\", required=True, help=\"averaged model\")\r\n    parser.add_argument(\"--src_path\",\r\n                        required=True,\r\n                        help=\"src model path for average\")\r\n    parser.add_argument(\"--num\",\r\n                        default=5,\r\n                        type=int,\r\n                        help=\"nums for averaged model\")\r\n    parser.add_argument(\r\n        \"--min_epoch\",\r\n        default=0,\r\n        type=int,\r\n        help=\"min epoch used for averaging model\",\r\n    )\r\n    parser.add_argument(\r\n        \"--max_epoch\",\r\n        default=65536,  # Big enough\r\n        type=int,\r\n        help=\"max epoch used for averaging model\",\r\n    )\r\n    parser.add_argument(\r\n        \"--mode\",\r\n        default=\"final\",\r\n        type=str,\r\n        help=\"use last epochs for average or best epochs\",\r\n    )\r\n    parser.add_argument(\r\n        \"--epochs\",\r\n        default=\"1,2,3,4,5\",\r\n        type=str,\r\n        help=\"use last epochs for average or best epochs\",\r\n    )\r\n    args = parser.parse_args()\r\n    print(args)\r\n    return args\r\n\r\n\r\ndef main():\r\n    args = get_args()\r\n    if args.mode == \"final\":\r\n        path_list = glob.glob(\"{}/*[!avg][!final][!latest].pt\".format(\r\n            args.src_path))\r\n        path_list = sorted(\r\n            path_list,\r\n            key=lambda p: int(re.findall(r\"(?<=checkpoint_)\\d*(?=.pt)\", p)[0]),\r\n        )\r\n        path_list = path_list[-args.num:]\r\n    else:\r\n        epoch_indexes = list(args.epochs.split(\",\"))\r\n        path_list = [\r\n            os.path.join(args.src_path, \"checkpoint_\" + x + \".pt\")\r\n            for x in epoch_indexes\r\n        ]\r\n    print(path_list)\r\n    avg = None\r\n    num = args.num\r\n    assert num == len(path_list)\r\n    for path in path_list:\r\n        print(\"Processing {}\".format(path))\r\n        states = torch.load(path, map_location=torch.device(\"cpu\"))\r\n        states = states[\"models\"][0] if \"models\" in states else states\r\n        if avg is None:\r\n            avg = states\r\n        else:\r\n            for k in avg.keys():\r\n                avg[k] += states[k]\r\n    # average\r\n    for k in avg.keys():\r\n        if avg[k] is not None:\r\n            # pytorch 1.6 use true_divide instead of /=\r\n            avg[k] = torch.true_divide(avg[k], num)\r\n    avg = {\"models\": [avg]}\r\n    print(\"Saving to {}\".format(args.dst_model))\r\n    torch.save(avg, 
args.dst_model)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "wesep/bin/export_jit.py",
    "content": "from __future__ import print_function\n\nimport argparse\nimport os\n\nimport torch\nimport yaml\n\nfrom wesep.models import get_model\nfrom wesep.utils.checkpoint import load_pretrained_model\n\n\ndef get_args():\n    parser = argparse.ArgumentParser(description=\"export your script model\")\n    parser.add_argument(\"--config\", required=True, help=\"config file\")\n    parser.add_argument(\"--checkpoint\", required=True, help=\"checkpoint model\")\n    parser.add_argument(\"--output_model\", required=True, help=\"output file\")\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    args = get_args()\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n\n    with open(args.config, \"r\") as fin:\n        configs = yaml.load(fin, Loader=yaml.FullLoader)\n    print(configs)\n\n    model = get_model(\n        configs[\"model\"][\"tse_model\"])(**configs[\"model_args\"][\"tse_model\"])\n    print(model)\n\n    load_pretrained_model(model, args.checkpoint)\n    model.eval()\n\n    speaker_feat_dim = configs[\"dataset_args\"][\"fbank_args\"].get(\n        \"num_mel_bins\", 80)\n\n    speaker_dummy_input = torch.ones(2, 300, speaker_feat_dim)\n    mix_dummy_input = torch.ones(2, 81280)\n    script_model = torch.jit.script(model,\n                                    (mix_dummy_input, speaker_dummy_input))\n    script_model.save(args.output_model)\n    print(\"Export model successfully, see {}\".format(args.output_model))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "wesep/bin/infer.py",
    "content": "from __future__ import print_function\n\nimport os\nimport time\n\nimport fire\nimport soundfile\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom wesep.dataset.dataset import Dataset, tse_collate_fn_2spk\nfrom wesep.models import get_model\nfrom wesep.utils.checkpoint import load_pretrained_model\nfrom wesep.utils.file_utils import read_label_file, read_vec_scp_file\nfrom wesep.utils.score import cal_SISNRi\nfrom wesep.utils.utils import (\n    generate_enahnced_scp,\n    get_logger,\n    parse_config_or_kwargs,\n    set_seed,\n)\n\nos.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\nos.environ[\"TORCH_USE_CUDA_DSA\"] = \"1\"\n\n\ndef infer(config=\"confs/conf.yaml\", **kwargs):\n    start = time.time()\n    total_SISNR = 0\n    total_SISNRi = 0\n    total_cnt = 0\n    accept_cnt = 0\n\n    configs = parse_config_or_kwargs(config, **kwargs)\n    sign_save_wav = configs.get(\n        \"save_wav\", True)  # Control if save the extracted speech as .wav\n\n    rank = 0\n    set_seed(configs[\"seed\"] + rank)\n    gpu = configs[\"gpus\"]\n    device = (torch.device(\"cuda:{}\".format(gpu))\n              if gpu >= 0 else torch.device(\"cpu\"))\n\n    sample_rate = configs.get(\"fs\", None)\n    if sample_rate is None or sample_rate == \"16k\":\n        sample_rate = 16000\n    else:\n        sample_rate = 8000\n\n    if 'spk_model_init' in configs['model_args']['tse_model']:\n        configs['model_args']['tse_model']['spk_model_init'] = False\n    model = get_model(\n        configs[\"model\"][\"tse_model\"])(**configs[\"model_args\"][\"tse_model\"])\n    model_path = os.path.join(configs[\"checkpoint\"])\n    load_pretrained_model(model, model_path)\n\n    logger = get_logger(configs[\"exp_dir\"], \"infer.log\")\n    logger.info(\"Load checkpoint from {}\".format(model_path))\n    save_audio_dir = os.path.join(configs[\"exp_dir\"], \"audio\")\n    if sign_save_wav:\n        if not os.path.exists(save_audio_dir):\n            try:\n                os.makedirs(save_audio_dir)\n                print(f\"Directory {save_audio_dir} created successfully.\")\n            except OSError as e:\n                print(f\"Error creating directory {save_audio_dir}: {e}\")\n        else:\n            print(f\"Directory {save_audio_dir} already exists.\")\n    else:\n        print(\"Do NOT save the results in wav.\")\n\n    model = model.to(device)\n    model.eval()\n\n    test_spk_embeds = configs.get(\"test_spk_embeds\", None)\n    test_spk1_embed_scp = configs[\"test_spk1_enroll\"]\n    test_spk2_embed_scp = configs[\"test_spk2_enroll\"]\n    joint_training = configs[\"model_args\"][\"tse_model\"].get(\n        \"joint_training\", None)\n    if not joint_training and test_spk_embeds:\n        test_spk2embed_dict = read_vec_scp_file(test_spk_embeds)\n    else:\n        test_spk2embed_dict = read_label_file(configs[\"test_spk2utt\"])\n\n    test_spk1_embed = read_label_file(test_spk1_embed_scp)\n    test_spk2_embed = read_label_file(test_spk2_embed_scp)\n\n    lines = len(test_spk2embed_dict)\n\n    test_dataset = Dataset(\n        configs[\"data_type\"],\n        configs[\"test_data\"],\n        configs[\"dataset_args\"],\n        test_spk2embed_dict,\n        test_spk1_embed,\n        test_spk2_embed,\n        state=\"test\",\n        joint_training=joint_training,\n        whole_utt=configs.get(\"whole_utt\", True),\n        repeat_dataset=configs.get(\"repeat_dataset\", False),\n    )\n    test_dataloader = DataLoader(test_dataset,\n                                 
batch_size=1,\n                                 collate_fn=tse_collate_fn_2spk)\n    test_iter = lines // 2\n    logger.info(\"test number: {}\".format(test_iter))\n\n    with torch.no_grad():\n        for i, batch in enumerate(test_dataloader):\n            features = batch[\"wav_mix\"]\n            targets = batch[\"wav_targets\"]\n            enroll = batch[\"spk_embeds\"]\n            spk = batch[\"spk\"]\n            key = batch[\"key\"]\n\n            features = features.float().to(device)  # (B,T,F)\n            targets = targets.float().to(device)\n            enroll = enroll.float().to(device)\n\n            outputs = model(features, enroll)\n            if isinstance(outputs, (list, tuple)):\n                outputs = outputs[0]\n\n            if torch.min(outputs.max(dim=1).values) > 0:\n                outputs = ((outputs /\n                            abs(outputs).max(dim=1, keepdim=True)[0] *\n                            0.9).cpu().numpy())\n            else:\n                outputs = outputs.cpu().numpy()\n\n            if sign_save_wav:\n                file1 = os.path.join(\n                    save_audio_dir,\n                    f\"Utt{total_cnt + 1}-{key[0]}-T{spk[0]}.wav\",\n                )\n                soundfile.write(file1, outputs[0], sample_rate)\n                file2 = os.path.join(\n                    save_audio_dir,\n                    f\"Utt{total_cnt + 1}-{key[1]}-T{spk[1]}.wav\",\n                )\n                soundfile.write(file2, outputs[1], sample_rate)\n\n            ref = targets.cpu().numpy()\n            ests = outputs\n            mix = features.cpu().numpy()\n\n            if ests[0].size != ref[0].size:\n                end = min(ests[0].size, ref[0].size, mix[0].size)\n                ests_1 = ests[0][:end]\n                ref_1 = ref[0][:end]\n                mix_1 = mix[0][:end]\n                SISNR1, delta1 = cal_SISNRi(ests_1, ref_1, mix_1)\n            else:\n                SISNR1, delta1 = cal_SISNRi(ests[0], ref[0], mix[0])\n\n            logger.info(\n                \"Num={} | Utt={} | Target speaker={} | SI-SNR={:.2f} | SI-SNRi={:.2f}\"\n                .format(total_cnt + 1, key[0], spk[0], SISNR1, delta1))\n            total_SISNR += SISNR1\n            total_SISNRi += delta1\n            total_cnt += 1\n            if delta1 > 1:\n                accept_cnt += 1\n\n            if ests[1].size != ref[1].size:\n                end = min(ests[1].size, ref[1].size, mix[1].size)\n                ests_2 = ests[1][:end]\n                ref_2 = ref[1][:end]\n                mix_2 = mix[1][:end]\n                SISNR2, delta2 = cal_SISNRi(ests_2, ref_2, mix_2)\n            else:\n                SISNR2, delta2 = cal_SISNRi(ests[1], ref[1], mix[1])\n            logger.info(\n                \"Num={} | Utt={} | Target speaker={} | SI-SNR={:.2f} | SI-SNRi={:.2f}\"\n                .format(total_cnt + 1, key[1], spk[1], SISNR2, delta2))\n            total_SISNR += SISNR2\n            total_SISNRi += delta2\n            total_cnt += 1\n            if delta2 > 1:\n                accept_cnt += 1\n\n            # if (i + 1) == test_iter:\n            #     break\n        end = time.time()\n    # generate the scp file of the enhanced speech for scoring\n    if sign_save_wav:\n        generate_enahnced_scp(os.path.abspath(save_audio_dir), extension=\"wav\")\n\n    logger.info(\"Time Elapsed: {:.1f}s\".format(end - start))\n    logger.info(\"Average SI-SNR: {:.2f}\".format(total_SISNR / total_cnt))\n    logger.info(\"Average 
SI-SNRi: {:.2f}\".format(total_SISNRi / total_cnt))\n    logger.info(\n        \"Acceptance rate of utterances with SI-SNRi > 1 dB: {:.2f}%\".format(\n            accept_cnt / total_cnt * 100))\n\n\nif __name__ == \"__main__\":\n    fire.Fire(infer)\n"
  },
  {
    "path": "wesep/bin/score.py",
    "content": "# ported from\n# https://github.com/espnet/espnet/blob/master/espnet2/bin/enh_scoring.py\nimport argparse\nimport logging\nimport sys\nfrom pathlib import Path\nfrom typing import Dict, List, Union\n\nimport numpy as np\nfrom mir_eval.separation import bss_eval_sources\nfrom pystoi import stoi\n\nfrom wesep.utils.datadir_writer import DatadirWriter\nfrom wesep.utils.file_utils import SoundScpReader\nfrom wesep.utils.score import cal_SISNR\nfrom wesep.utils.utils import ArgumentParser, get_commandline_args, str2bool\n\n\ndef get_readers(scps: List[str], dtype: str):\n    readers = [SoundScpReader(f, dtype=dtype) for f in scps]\n    audio_format = \"sound\"\n    return readers, audio_format\n\n\ndef read_audio(reader, key, audio_format=\"sound\"):\n    if audio_format == \"sound\":\n        return reader[key][1]\n    else:\n        raise ValueError(f\"Unknown audio format: {audio_format}\")\n\n\ndef scoring(\n    output_dir: str,\n    dtype: str,\n    log_level: Union[int, str],\n    key_file: str,\n    ref_scp: List[str],\n    inf_scp: List[str],\n    ref_channel: int,\n    use_dnsmos: bool,\n    dnsmos_args: Dict,\n    use_pesq: bool,\n):\n    logging.basicConfig(\n        level=log_level,\n        format=\"%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s\",\n    )\n\n    if use_dnsmos:\n        if dnsmos_args[\"mode\"] == \"local\":\n            from wesep.utils.dnsmos import DNSMOS_local\n\n            if not Path(dnsmos_args[\"primary_model\"]).exists():\n                raise ValueError(\n                    f\"The primary model {dnsmos_args['primary_model']} doesn't exist.\"\n                    \" You can download the model from https://github.com/microsoft/\"\n                    \"DNS-Challenge/tree/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx\")\n            if not Path(dnsmos_args[\"p808_model\"]).exists():\n                raise ValueError(\n                    f\"The P808 model {dnsmos_args['p808_model']} doesn't exist.\"\n                    \" You can download the model from https://github.com/microsoft/\"\n                    \"DNS-Challenge/tree/master/DNSMOS/DNSMOS/model_v8.onnx\")\n            dnsmos = DNSMOS_local(\n                dnsmos_args[\"primary_model\"],\n                dnsmos_args[\"p808_model\"],\n                use_gpu=dnsmos_args[\"use_gpu\"],\n                convert_to_torch=dnsmos_args[\"convert_to_torch\"],\n                gpu_device=dnsmos_args[\"gpu_device\"] - 1,\n            )\n            logging.warning(\"Using local DNSMOS models for evaluation\")\n\n        elif dnsmos_args[\"mode\"] == \"web\":\n            from wesep.utils.dnsmos import DNSMOS_web\n\n            if not dnsmos_args[\"auth_key\"]:\n                raise ValueError(\n                    \"Please specify the authentication key for access to the Web-API. 
\"\n                    \"You can apply for the AUTH_KEY at https://github.com/microsoft/\"\n                    \"DNS-Challenge/blob/master/DNSMOS/README.md#to-use-the-web-api\"\n                )\n            dnsmos = DNSMOS_web(dnsmos_args[\"auth_key\"])\n            logging.warning(\"Using the DNSMOS Web-API for evaluation\")\n    else:\n        dnsmos = None\n\n    if use_pesq:\n        try:\n            from pesq import PesqError, pesq\n\n            logging.warning(\"Using the PESQ package for evaluation\")\n        except ImportError:\n            raise ImportError(\n                \"Please install pesq and retry: pip install pesq\") from None\n    else:\n        pesq = None\n\n    assert len(ref_scp) == len(inf_scp), \"len(ref_scp) != len(inf_scp)\"\n    num_spk = len(ref_scp)\n\n    keys = [\n        line.rstrip().split(maxsplit=1)[0]\n        for line in open(key_file, encoding=\"utf-8\")\n    ]\n\n    ref_readers, ref_audio_format = get_readers(ref_scp, dtype)\n    inf_readers, inf_audio_format = get_readers(inf_scp, dtype)\n\n    # get sample rate\n    retval = ref_readers[0][keys[0]]\n    if ref_audio_format == \"kaldi_ark\":\n        sample_rate = ref_readers[0].rate\n    elif ref_audio_format == \"sound\":\n        sample_rate = retval[0]\n    else:\n        raise NotImplementedError(ref_audio_format)\n    assert sample_rate is not None, (sample_rate, ref_audio_format)\n\n    # check keys\n    for inf_reader, ref_reader in zip(inf_readers, ref_readers):\n        assert inf_reader.keys() == ref_reader.keys()\n\n    with DatadirWriter(output_dir) as writer:\n        for n, key in enumerate(keys):\n            logging.info(f\"[{n}] Scoring {key}\")\n            ref_audios = [\n                read_audio(ref_reader, key, audio_format=ref_audio_format)\n                for ref_reader in ref_readers\n            ]\n            inf_audios = [\n                read_audio(inf_reader, key, audio_format=inf_audio_format)\n                for inf_reader in inf_readers\n            ]\n            ref = np.array(ref_audios)\n            inf = np.array(inf_audios)\n            if ref.ndim > inf.ndim:\n                # multi-channel reference and single-channel output\n                ref = ref[..., ref_channel]\n            elif ref.ndim < inf.ndim:\n                # single-channel reference and multi-channel output\n                inf = inf[..., ref_channel]\n            elif ref.ndim == inf.ndim == 3:\n                # multi-channel reference and output\n                ref = ref[..., ref_channel]\n                inf = inf[..., ref_channel]\n\n            assert ref.shape == inf.shape, (ref.shape, inf.shape)\n\n            sdr, sir, sar, perm = bss_eval_sources(ref,\n                                                   inf,\n                                                   compute_permutation=True)\n\n            for i in range(num_spk):\n                stoi_score = stoi(ref[i],\n                                  inf[int(perm[i])],\n                                  fs_sig=sample_rate)\n                estoi_score = stoi(\n                    ref[i],\n                    inf[int(perm[i])],\n                    fs_sig=sample_rate,\n                    extended=True,\n                )\n                si_snr_score = cal_SISNR(\n                    ref[i],\n                    inf[int(perm[i])],\n                )\n\n                if dnsmos:\n                    dnsmos_score = dnsmos(inf[int(perm[i])], sample_rate)\n                    writer[f\"OVRL_spk{i + 1}\"][key] = 
str(dnsmos_score[\"OVRL\"])\n                    writer[f\"SIG_spk{i + 1}\"][key] = str(dnsmos_score[\"SIG\"])\n                    writer[f\"BAK_spk{i + 1}\"][key] = str(dnsmos_score[\"BAK\"])\n                    writer[f\"P808_MOS_spk{i + 1}\"][key] = str(\n                        dnsmos_score[\"P808_MOS\"])\n                if pesq:\n                    if sample_rate == 8000:\n                        mode = \"nb\"\n                    elif sample_rate == 16000:\n                        mode = \"wb\"\n                    else:\n                        raise ValueError(\n                            \"sample rate must be 8000 or 16000 for PESQ evaluation, \"\n                            f\"but got {sample_rate}\")\n                    pesq_score = pesq(\n                        sample_rate,\n                        ref[i],\n                        inf[int(perm[i])],\n                        mode=mode,\n                        on_error=PesqError.RETURN_VALUES,\n                    )\n                    if pesq_score == PesqError.NO_UTTERANCES_DETECTED:\n                        logging.warning(\n                            f\"[PESQ] Error: No utterances detected for {key}. \"\n                            \"Skipping this utterance.\")\n                    else:\n                        writer[f\"PESQ_{mode.upper()}_spk{i + 1}\"][key] = str(\n                            pesq_score)\n                writer[f\"STOI_spk{i + 1}\"][key] = str(stoi_score *\n                                                      100)  # in percentage\n                writer[f\"ESTOI_spk{i + 1}\"][key] = str(estoi_score * 100)\n                writer[f\"SI_SNR_spk{i + 1}\"][key] = str(si_snr_score)\n                writer[f\"SDR_spk{i + 1}\"][key] = str(sdr[i])\n                writer[f\"SAR_spk{i + 1}\"][key] = str(sar[i])\n                writer[f\"SIR_spk{i + 1}\"][key] = str(sir[i])\n                # save permutation assigned script file\n                if i < len(ref_scp):\n                    if inf_audio_format == \"sound\":\n                        writer[f\"wav_spk{i + 1}\"][key] = inf_readers[\n                            perm[i]].data[key]\n                    elif inf_audio_format == \"kaldi_ark\":\n                        # NOTE: SegmentsExtractor is not supported\n                        writer[f\"wav_spk{i + 1}\"][key] = inf_readers[\n                            perm[i]].loader._dict[key]\n                    else:\n                        raise ValueError(\n                            f\"Unknown audio format: {inf_audio_format}\")\n\n\ndef get_parser():\n    parser = ArgumentParser(\n        description=\"Frontend inference\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n\n    # Note(kamo): Use '_' instead of '-' as separator.\n    # '-' is confusing if written in yaml.\n\n    parser.add_argument(\n        \"--log_level\",\n        type=lambda x: x.upper(),\n        default=\"INFO\",\n        choices=(\"CRITICAL\", \"ERROR\", \"WARNING\", \"INFO\", \"DEBUG\", \"NOTSET\"),\n        help=\"The verbose level of logging\",\n    )\n\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n\n    parser.add_argument(\n        \"--dtype\",\n        default=\"float32\",\n        choices=[\"float16\", \"float32\", \"float64\"],\n        help=\"Data type\",\n    )\n\n    group = parser.add_argument_group(\"Input data related\")\n    group.add_argument(\n        \"--ref_scp\",\n        type=str,\n        required=True,\n        action=\"append\",\n    )\n    
group.add_argument(\n        \"--inf_scp\",\n        type=str,\n        required=True,\n        action=\"append\",\n    )\n    group.add_argument(\"--key_file\", type=str)\n    group.add_argument(\"--ref_channel\", type=int, default=0)\n\n    group = parser.add_argument_group(\"DNSMOS related\")\n    group.add_argument(\"--use_dnsmos\", type=str2bool, default=False)\n    group.add_argument(\n        \"--dnsmos_mode\",\n        type=str,\n        choices=(\"local\", \"web\"),\n        default=\"local\",\n        help=\"Use local DNSMOS model or web API for DNSMOS calculation\",\n    )\n    group.add_argument(\n        \"--dnsmos_auth_key\",\n        type=str,\n        default=\"\",\n        help=\"Required if dnsmos_mode='web'\",\n    )\n    group.add_argument(\n        \"--dnsmos_use_gpu\",\n        type=str2bool,\n        default=False,\n        help=\"used when dnsmos_mode='local'\",\n    )\n    group.add_argument(\n        \"--dnsmos_convert_to_torch\",\n        type=str2bool,\n        default=False,\n        help=\"used when dnsmos_mode='local'\",\n    )\n    group.add_argument(\"--dnsmos_primary_model\",\n                       type=str,\n                       default=\"./DNSMOS/sig_bak_ovr.onnx\",\n                       help=\"Path to the primary DNSMOS model. \"\n                       \"Required if dnsmos_mode='local'\")\n    group.add_argument(\n        \"--dnsmos_p808_model\",\n        type=str,\n        default=\"./DNSMOS/model_v8.onnx\",\n        help=\"Path to the p808 model. Required if dnsmos_mode='local'\",\n    )\n    group.add_argument(\"--dnsmos_gpu_device\",\n                       type=int,\n                       default=None,\n                       help=\"gpu device to use for DNSMOS evaluation. \"\n                       \"Used when dnsmos_mode='local'\")\n\n    group = parser.add_argument_group(\"PESQ related\")\n    group.add_argument(\n        \"--use_pesq\",\n        type=str2bool,\n        default=False,\n        help=\"Before setting this to True, please make sure that you or \"\n        \"your institution have the license \"\n        \"(check https://www.itu.int/rec/T-REC-P.862-200511-I!Amd2/en) to report PESQ\",\n    )\n    return parser\n\n\ndef main(cmd=None):\n    print(get_commandline_args(), file=sys.stderr)\n    parser = get_parser()\n    args = parser.parse_args(cmd)\n    kwargs = vars(args)\n    kwargs.pop(\"config\", None)\n\n    dnsmos_args = {\n        \"mode\": kwargs.pop(\"dnsmos_mode\"),\n        \"auth_key\": kwargs.pop(\"dnsmos_auth_key\"),\n        \"primary_model\": kwargs.pop(\"dnsmos_primary_model\"),\n        \"p808_model\": kwargs.pop(\"dnsmos_p808_model\"),\n        \"use_gpu\": kwargs.pop(\"dnsmos_use_gpu\"),\n        \"convert_to_torch\": kwargs.pop(\"dnsmos_convert_to_torch\"),\n        \"gpu_device\": kwargs.pop(\"dnsmos_gpu_device\"),\n    }\n    kwargs[\"dnsmos_args\"] = dnsmos_args\n    scoring(**kwargs)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "wesep/bin/train.py",
    "content": "# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport re\nfrom pprint import pformat\n\nimport fire\nimport matplotlib.pyplot as plt\nimport tableprint as tp\nimport torch\nimport torch.distributed as dist\nimport yaml\nfrom torch.utils.data import DataLoader\n\nimport wesep.utils.schedulers as schedulers\nfrom wesep.dataset.dataset import Dataset, tse_collate_fn, tse_collate_fn_2spk\nfrom wesep.models import get_model\nfrom wesep.utils.checkpoint import (\n    load_checkpoint,\n    load_pretrained_model,\n    save_checkpoint,\n)\nfrom wesep.utils.executor import Executor\nfrom wesep.utils.file_utils import (\n    load_speaker_embeddings,\n    read_label_file,\n    read_vec_scp_file,\n)\nfrom wesep.utils.losses import parse_loss\nfrom wesep.utils.utils import parse_config_or_kwargs, set_seed, setup_logger\n\nMAX_NUM_log_files = 100  # The maximum number of log-files to be kept\nlogging.getLogger(\"matplotlib.font_manager\").setLevel(logging.ERROR)\n\n\ndef train(config=\"conf/config.yaml\", **kwargs):\n    \"\"\"Trains a model on the given features and spk labels.\n\n    :config: A training configuration. Note that all parameters in the\n             config can also be manually adjusted with --ARG VALUE\n    :returns: None\n    \"\"\"\n    # print(kwargs)\n    configs = parse_config_or_kwargs(config, **kwargs)\n    checkpoint = configs.get(\"checkpoint\", None)\n    if checkpoint is not None:\n        checkpoint = os.path.realpath(checkpoint)\n    find_unused_parameters = configs.get(\"find_unused_parameters\", False)\n\n    # dist configs\n    rank = int(os.environ[\"RANK\"])\n    world_size = int(os.environ[\"WORLD_SIZE\"])\n    gpu = int(configs[\"gpus\"][rank])\n    torch.cuda.set_device(gpu)\n    dist.init_process_group(backend=\"nccl\")\n\n    # Log rotation\n    model_dir = os.path.join(configs[\"exp_dir\"], \"models\")\n    logger = setup_logger(rank, configs[\"exp_dir\"], gpu, MAX_NUM_log_files)\n\n    print(\"-------------------\", dist.get_rank(), world_size)\n    if world_size > 1:\n        logger.info(\"training on multiple gpus, this gpu {}\".format(gpu))\n\n    if rank == 0:\n        logger.info(\"exp_dir is: {}\".format(configs[\"exp_dir\"]))\n        logger.info(\"<== Passed Arguments ==>\")\n        # Print arguments into logs\n        for line in pformat(configs).split(\"\\n\"):\n            logger.info(line)\n\n    # seed\n    set_seed(configs[\"seed\"] + rank)\n\n    # loss\n    criterion = configs.get(\"loss\", None)\n    if criterion:\n        criterion = parse_loss(criterion)\n    else:\n        criterion = [\n            parse_loss(\"SISDR\"),\n        ]\n    loss_posi = configs[\"loss_args\"].get(\n        \"loss_posi\",\n        [[\n            0,\n        ]],\n    )\n    loss_weight = configs[\"loss_args\"].get(\n        \"loss_weight\",\n        [[\n            1.0,\n        ]],\n    )\n    loss_args = (loss_posi, loss_weight)\n\n    # 
embeds\n    tr_spk_embeds = configs.get(\"train_spk_embeds\", None)\n    tr_single_utt2spk = configs[\"train_utt2spk\"]\n    joint_training = configs[\"model_args\"][\"tse_model\"].get(\n        \"joint_training\", False)\n    multi_task = configs[\"model_args\"][\"tse_model\"].get(\"multi_task\", False)\n\n    dict_spk = {}\n    if not joint_training and tr_spk_embeds:\n        tr_spk2embed_dict = load_speaker_embeddings(tr_spk_embeds,\n                                                    tr_single_utt2spk)\n        multi_task = None\n    else:\n        with open(configs[\"train_spk2utt\"], \"r\") as f:\n            tr_spk2embed_dict = json.load(f)\n            if multi_task:\n                for i, j in enumerate(tr_spk2embed_dict.keys(\n                )):  # Generate the dictionary for speakers in training set\n                    dict_spk[j] = i\n\n    with open(tr_single_utt2spk, \"r\") as f:\n        tr_lines = f.readlines()\n\n    val_spk_embeds = configs.get(\"val_spk_embeds\", None)\n    val_spk1_enroll = configs[\"val_spk1_enroll\"]\n    val_spk2_enroll = configs[\"val_spk2_enroll\"]\n\n    if not joint_training and val_spk_embeds:\n        val_spk2embed_dict = read_vec_scp_file(val_spk_embeds)\n    else:\n        val_spk2embed_dict = read_label_file(configs[\"val_spk2utt\"])\n\n    val_lines = len(val_spk2embed_dict)\n\n    val_spk1_embed = read_label_file(val_spk1_enroll)\n    val_spk2_embed = read_label_file(val_spk2_enroll)\n\n    # dataset and dataloader\n    train_dataset = Dataset(\n        configs[\"data_type\"],\n        configs[\"train_data\"],\n        configs[\"dataset_args\"],\n        tr_spk2embed_dict,\n        None,\n        None,\n        state=\"train\",\n        joint_training=joint_training,\n        dict_spk=dict_spk,\n        whole_utt=configs.get(\"whole_utt\", False),\n        repeat_dataset=configs.get(\"repeat_dataset\", True),\n        noise_prob=configs[\"dataset_args\"].get(\"noise_prob\", 0),\n        reverb_prob=configs[\"dataset_args\"].get(\"reverb_prob\", 0),\n        noise_enroll_prob=configs[\"dataset_args\"].get(\"noise_enroll_prob\", 0),\n        reverb_enroll_prob=configs[\"dataset_args\"].get(\"reverb_enroll_prob\",\n                                                       0),\n        specaug_enroll_prob=configs[\"dataset_args\"].get(\n            \"specaug_enroll_prob\", 0),\n        online_mix=configs[\"dataset_args\"].get(\"online_mix\", False),\n        noise_lmdb_file=configs[\"dataset_args\"].get(\"noise_lmdb_file\", None),\n    )\n    val_dataset = Dataset(configs[\"data_type\"],\n                          configs[\"val_data\"],\n                          configs[\"dataset_args\"],\n                          val_spk2embed_dict,\n                          val_spk1_embed,\n                          val_spk2_embed,\n                          state=\"val\",\n                          joint_training=joint_training,\n                          whole_utt=configs.get(\"whole_utt\", False),\n                          repeat_dataset=True,\n                          online_mix=False,\n                          noise_prob=0,\n                          reverb_prob=0,\n                          noise_enroll_prob=0,\n                          reverb_enroll_prob=0,\n                          specaug_enroll_prob=0)\n    train_dataloader = DataLoader(train_dataset,\n                                  **configs[\"dataloader_args\"],\n                                  collate_fn=tse_collate_fn)\n    val_dataloader = DataLoader(\n        val_dataset,\n      
  **configs[\"dataloader_args\"],\n        collate_fn=tse_collate_fn_2spk,\n    )\n    batch_size = configs[\"dataloader_args\"][\"batch_size\"]\n    if configs[\"dataset_args\"].get(\"sample_num_per_epoch\", 0) > 0:\n        sample_num_per_epoch = configs[\"dataset_args\"][\"sample_num_per_epoch\"]\n    else:\n        sample_num_per_epoch = len(tr_lines) // 2\n    epoch_iter = sample_num_per_epoch // world_size // batch_size\n    val_iter = val_lines // 2 // world_size // batch_size\n    if rank == 0:\n        logger.info(\"<== Dataloaders ==>\")\n        logger.info(\"train dataloaders created\")\n        logger.info(\"epoch iteration number: {}\".format(epoch_iter))\n        logger.info(\"val iteration number: {}\".format(val_iter))\n\n    # model\n    model_list = []\n    scheduler_list = []\n    optimizer_list = []\n\n    logger.info(\"<== Model ==>\")\n    model = get_model(\n        configs[\"model\"][\"tse_model\"])(**configs[\"model_args\"][\"tse_model\"])\n    num_params = sum(param.numel() for param in model.parameters())\n\n    if rank == 0:\n        logger.info(\"tse_model size: {:.2f} M\".format(num_params / 1e6))\n        # print model\n        for line in pformat(model).split(\"\\n\"):\n            logger.info(line)\n\n    # ddp_model\n    model.cuda()\n    ddp_model = torch.nn.parallel.DistributedDataParallel(\n        model, find_unused_parameters=find_unused_parameters)\n    device = torch.device(\"cuda\")\n\n    if rank == 0:\n        logger.info(\"<== TSE Model Loss ==>\")\n        logger.info(\"loss criterion is: \" + str(configs[\"loss\"]))\n\n    configs[\"optimizer_args\"][\"tse_model\"][\"lr\"] = configs[\"scheduler_args\"][\n        \"tse_model\"][\"initial_lr\"]\n    optimizer = getattr(torch.optim, configs[\"optimizer\"][\"tse_model\"])(\n        ddp_model.parameters(), **configs[\"optimizer_args\"][\"tse_model\"])\n    if rank == 0:\n        logger.info(\"<== TSE Model Optimizer ==>\")\n        logger.info(\"optimizer is: \" + configs[\"optimizer\"][\"tse_model\"])\n\n    # scheduler\n    configs[\"scheduler_args\"][\"tse_model\"][\"num_epochs\"] = configs[\n        \"num_epochs\"]\n    configs[\"scheduler_args\"][\"tse_model\"][\"epoch_iter\"] = epoch_iter\n    configs[\"scheduler_args\"][\"scale_ratio\"] = 1.0\n\n    scheduler = getattr(schedulers, configs[\"scheduler\"][\"tse_model\"])(\n        optimizer, **configs[\"scheduler_args\"][\"tse_model\"])\n    if rank == 0:\n        logger.info(\"<== TSE Model Scheduler ==>\")\n        logger.info(\"scheduler is: \" + configs[\"scheduler\"][\"tse_model\"])\n\n    if configs[\"model_init\"][\"tse_model\"] is not None:\n        logger.info(\"Load initial model from {}\".format(\n            configs[\"model_init\"][\"tse_model\"]))\n        load_pretrained_model(ddp_model, configs[\"model_init\"][\"tse_model\"])\n    elif checkpoint is None:\n        logger.info(\"Train model from scratch ...\")\n\n    for c in criterion:\n        c = c.to(device)\n\n    # append to list\n    model_list.append(ddp_model)\n    optimizer_list.append(optimizer)\n    scheduler_list.append(scheduler)\n    scaler = torch.cuda.amp.GradScaler(enabled=configs[\"enable_amp\"])\n\n    # If specify checkpoint, load some info from checkpoint.\n    if checkpoint is not None:\n        load_checkpoint(model_list, optimizer_list, scheduler_list, scaler,\n                        checkpoint)\n        start_epoch = (\n            int(re.findall(r\"(?<=checkpoint_)\\d*(?=.pt)\", checkpoint)[0]) + 1)\n        logger.info(\"Load checkpoint: 
{}\".format(checkpoint))\n    else:\n        start_epoch = 1\n    logger.info(\"start_epoch: {}\".format(start_epoch))\n\n    # save config.yaml\n    if rank == 0:\n        saved_config_path = os.path.join(configs[\"exp_dir\"], \"config.yaml\")\n        with open(saved_config_path, \"w\") as fout:\n            data = yaml.dump(configs)\n            fout.write(data)\n\n    # training\n    dist.barrier(device_ids=[gpu])  # synchronize here\n    if rank == 0:\n        logger.info(\"<========== Training process ==========>\")\n        header = [\"Train/Val\", \"Epoch\", \"iter\", \"Loss\", \"LR\"]\n        for line in tp.header(header, width=10, style=\"grid\").split(\"\\n\"):\n            logger.info(line)\n    dist.barrier(device_ids=[gpu])  # synchronize here\n\n    executor = Executor()\n    executor.step = 0\n\n    train_losses = []\n    val_losses = []\n    for epoch in range(start_epoch, configs[\"num_epochs\"] + 1):\n        train_dataset.set_epoch(epoch)\n\n        # train_loss_com\n        train_loss, _ = executor.train(\n            train_dataloader,\n            model_list,\n            epoch_iter,\n            optimizer_list,\n            criterion,\n            scheduler_list,\n            scaler=scaler,\n            epoch=epoch,\n            logger=logger,\n            enable_amp=configs[\"enable_amp\"],\n            clip_grad=configs[\"clip_grad\"],\n            log_batch_interval=configs[\"log_batch_interval\"],\n            device=device,\n            se_loss_weight=loss_args,\n            multi_task=multi_task,\n            SSA_enroll_prob=configs[\"dataset_args\"].get(\"SSA_enroll_prob\", 0),\n            fbank_args=configs[\"dataset_args\"].get('fbank_args', None),\n            sample_rate=configs[\"dataset_args\"]['resample_rate'],\n            speaker_feat=configs[\"dataset_args\"].get('speaker_feat', True)\n        )\n\n        val_loss, _ = executor.cv(\n            val_dataloader,\n            model_list,\n            val_iter,\n            criterion,\n            epoch=epoch,\n            logger=logger,\n            enable_amp=configs[\"enable_amp\"],\n            log_batch_interval=configs[\"log_batch_interval\"],\n            device=device,\n        )\n\n        if rank == 0:\n            logger.info(\"Epoch {} Train info train_loss {}\".format(\n                epoch, train_loss))\n            logger.info(\"Epoch {} Val info val_loss {}\".format(\n                epoch, val_loss))\n            train_losses.append(train_loss)\n            val_losses.append(val_loss)\n\n            best_loss = val_loss\n            scheduler.best = best_loss\n            # plot\n            plt.figure()\n            plt.title(\"Loss of Train and Validation\")\n            x = list(range(start_epoch, epoch + 1))\n            plt.plot(x, train_losses, \"b-\", label=\"Train Loss\", linewidth=0.8)\n            plt.plot(x,\n                     val_losses,\n                     \"c-\",\n                     label=\"Validation Loss\",\n                     linewidth=0.8)\n            plt.legend()\n            plt.xlabel(\"Epoch\")\n            plt.ylabel(\"Loss\")\n            plt.xticks(range(start_epoch, epoch + 1, 1))\n            plt.savefig(\n                f\"{configs['exp_dir']}/{configs['model']['tse_model']}.png\")\n            plt.close()\n\n        if rank == 0:\n            if (epoch % configs[\"save_epoch_interval\"] == 0\n                    or epoch >= configs[\"num_epochs\"] - configs[\"num_avg\"]):\n                save_checkpoint(\n                    model_list,\n      
              optimizer_list,\n                    scheduler_list,\n                    scaler,\n                    os.path.join(model_dir, \"checkpoint_{}.pt\".format(epoch)),\n                )\n                try:\n                    os.symlink(\n                        \"checkpoint_{}.pt\".format(epoch),\n                        os.path.join(model_dir, \"latest_checkpoint.pt\"),\n                    )\n                except FileExistsError:\n                    os.remove(os.path.join(model_dir, \"latest_checkpoint.pt\"))\n                    os.symlink(\n                        \"checkpoint_{}.pt\".format(epoch),\n                        os.path.join(model_dir, \"latest_checkpoint.pt\"),\n                    )\n\n    if rank == 0:\n        os.symlink(\n            \"checkpoint_{}.pt\".format(configs[\"num_epochs\"]),\n            os.path.join(model_dir, \"final_checkpoint.pt\"),\n        )\n        logger.info(tp.bottom(len(header), width=10, style=\"grid\"))\n\n\nif __name__ == \"__main__\":\n    fire.Fire(train)\n"
  },
  {
    "path": "wesep/bin/train_gan.py",
    "content": "# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)\n#               2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport re\nfrom pprint import pformat\n\nimport fire\nimport matplotlib.pyplot as plt\nimport tableprint as tp\nimport torch\nimport torch.distributed as dist\nimport yaml\nfrom torch.utils.data import DataLoader\n\nimport wesep.utils.schedulers as schedulers\nfrom wesep.dataset.dataset import Dataset, tse_collate_fn, tse_collate_fn_2spk\nfrom wesep.models import get_model\nfrom wesep.utils.checkpoint import (\n    load_checkpoint,\n    load_pretrained_model,\n    save_checkpoint,\n)\nfrom wesep.utils.executor_gan import ExecutorGAN\nfrom wesep.utils.file_utils import (\n    load_speaker_embeddings,\n    read_label_file,\n    read_vec_scp_file,\n)\nfrom wesep.utils.losses import parse_loss\nfrom wesep.utils.utils import parse_config_or_kwargs, set_seed, setup_logger\n\nMAX_NUM_log_files = 100  # The maximum number of log-files to be kept\nlogging.getLogger(\"matplotlib.font_manager\").setLevel(logging.ERROR)\n\n\ndef train(config=\"conf/config.yaml\", **kwargs):\n    \"\"\"Trains a model on the given features and spk labels.\n\n    :config: A training configuration. 
Note that all parameters in the\n             config can also be manually adjusted with --ARG VALUE\n    :returns: None\n    \"\"\"\n    configs = parse_config_or_kwargs(config, **kwargs)\n    checkpoint = configs.get(\"checkpoint\", None)\n    if checkpoint is not None:\n        checkpoint = os.path.realpath(checkpoint)\n    find_unused_parameters = configs.get(\"find_unused_parameters\", False)\n    gan_loss_weight = configs.get(\"gan_loss_weight\", 0.05)\n\n    # dist configs\n    rank = int(os.environ[\"RANK\"])\n    world_size = int(os.environ[\"WORLD_SIZE\"])\n    gpu = int(configs[\"gpus\"][rank])\n    torch.cuda.set_device(gpu)\n    dist.init_process_group(backend=\"nccl\")\n\n    # Log rotation\n    model_dir = os.path.join(configs[\"exp_dir\"], \"models\")\n    logger = setup_logger(rank, configs[\"exp_dir\"], gpu, MAX_NUM_log_files)\n\n    print(\"-------------------\", dist.get_rank(), world_size)\n    if world_size > 1:\n        logger.info(\"training on multiple gpus, this gpu {}\".format(gpu))\n\n    if rank == 0:\n        logger.info(\"exp_dir is: {}\".format(configs[\"exp_dir\"]))\n        logger.info(\"<== Passed Arguments ==>\")\n        # Print arguments into logs\n        for line in pformat(configs).split(\"\\n\"):\n            logger.info(line)\n\n    # seed\n    set_seed(configs[\"seed\"] + rank)\n\n    # support multiple losses, e.g., criterion = [SISNR, CE]\n    criterion = configs.get(\"loss\", None)\n    if criterion:\n        criterion = parse_loss(criterion)\n    else:\n        criterion = [\n            parse_loss(\"SISNR\"),\n        ]\n    # loss_posi is used to store the indices when the model has multiple outputs\n    # loss_posi[i][j] stores the index of the output used for i-th criterion,\n    # that is, output[loss_posi[i][j]] is used for the i-th criterion.\n    loss_posi = configs[\"loss_args\"].get(\n        \"loss_posi\",\n        [[\n            0,\n        ]],\n    )\n    # loss_weight[i][j] stores the loss weight of output[loss_posi[i][j]] for the i-th criterion.  
# noqa\n    loss_weight = configs[\"loss_args\"].get(\n        \"loss_weight\",\n        [[\n            1.0,\n        ]],\n    )\n    loss_args = (loss_posi, loss_weight)\n\n    # embeds\n    tr_spk_embeds = configs[\"train_spk_embeds\"]\n    tr_single_utt2spk = configs[\"train_utt2spk\"]\n    joint_training = configs[\"model_args\"][\"tse_model\"].get(\n        \"joint_training\", False)\n    multi_task = configs[\"model_args\"][\"tse_model\"].get(\"multi_task\", False)\n\n    # dict_spk: {spk_id: int_label}\n    dict_spk = {}\n    if not joint_training:\n        tr_spk2embed_dict = load_speaker_embeddings(tr_spk_embeds,\n                                                    tr_single_utt2spk)\n        multi_task = False\n    else:\n        with open(configs[\"train_spk2utt\"], \"r\") as f:\n            tr_spk2embed_dict = json.load(f)\n            # tr_spk2embed_dict: {spk_id: [[spk_id, wav_path], ...]}\n            if multi_task:\n                for i, j in enumerate(tr_spk2embed_dict.keys(\n                )):  # Generate the dictionary for speakers in training set\n                    dict_spk[j] = i\n\n    with open(tr_single_utt2spk, \"r\") as f:\n        tr_lines = f.readlines()\n\n    val_spk_embeds = configs[\"val_spk_embeds\"]\n    val_spk1_enroll = configs[\"val_spk1_enroll\"]\n    val_spk2_enroll = configs[\"val_spk2_enroll\"]\n\n    if not joint_training:\n        val_spk2embed_dict = read_vec_scp_file(val_spk_embeds)\n    else:\n        val_spk2embed_dict = read_label_file(configs[\"val_spk2utt\"])\n\n    val_spk1_embed = read_label_file(val_spk1_enroll)\n    val_spk2_embed = read_label_file(val_spk2_enroll)\n\n    with open(val_spk_embeds, \"r\") as f:\n        val_lines = f.readlines()\n\n    # dataset and dataloader\n    train_dataset = Dataset(\n        configs[\"data_type\"],\n        configs[\"train_data\"],\n        configs[\"dataset_args\"],\n        tr_spk2embed_dict,\n        None,\n        None,\n        state=\"train\",\n        joint_training=joint_training,\n        dict_spk=dict_spk,\n        whole_utt=configs.get(\"whole_utt\", False),\n        repeat_dataset=configs.get(\"repeat_dataset\", True),\n        reverb=configs[\"dataset_args\"].get(\"reverb\", False),\n        noise=configs[\"dataset_args\"].get(\"noise\", False),\n        noise_lmdb_file=configs[\"dataset_args\"].get(\"noise_lmdb_file\", None),\n        online_mix=configs[\"dataset_args\"].get(\"online_mix\", False),\n    )\n    val_dataset = Dataset(\n        configs[\"data_type\"],\n        configs[\"val_data\"],\n        configs[\"dataset_args\"],\n        val_spk2embed_dict,\n        val_spk1_embed,\n        val_spk2_embed,\n        state=\"val\",\n        joint_training=joint_training,\n        whole_utt=configs.get(\"whole_utt\", False),\n        repeat_dataset=True,\n        reverb=False,\n        online_mix=False,\n    )\n    train_dataloader = DataLoader(train_dataset,\n                                  **configs[\"dataloader_args\"],\n                                  collate_fn=tse_collate_fn)\n    val_dataloader = DataLoader(\n        val_dataset,\n        **configs[\"dataloader_args\"],\n        collate_fn=tse_collate_fn_2spk,\n    )\n    batch_size = configs[\"dataloader_args\"][\"batch_size\"]\n    if configs[\"dataset_args\"].get(\"sample_num_per_epoch\", 0) > 0:\n        sample_num_per_epoch = configs[\"dataset_args\"][\"sample_num_per_epoch\"]\n    else:\n        sample_num_per_epoch = len(tr_lines) // 2\n    epoch_iter = sample_num_per_epoch // world_size // batch_size\n    
val_iter = len(val_lines) // 2 // world_size // batch_size\n    if rank == 0:\n        logger.info(\"<== Dataloaders ==>\")\n        logger.info(\"train dataloaders created\")\n        logger.info(\"epoch iteration number: {}\".format(epoch_iter))\n        logger.info(\"val iteration number: {}\".format(val_iter))\n\n    # model\n    model_list = []\n    scheduler_list = []\n    optimizer_list = []\n\n    logger.info(\"<== Model ==>\")\n    model = get_model(\n        configs[\"model\"][\"tse_model\"])(**configs[\"model_args\"][\"tse_model\"])\n    num_params = sum(param.numel() for param in model.parameters())\n\n    if rank == 0:\n        logger.info(\"tse_model size: {}\".format(num_params))\n        # print model\n        for line in pformat(model).split(\"\\n\"):\n            logger.info(line)\n\n    # ddp_model\n    model.cuda()\n    ddp_model = torch.nn.parallel.DistributedDataParallel(\n        model, find_unused_parameters=find_unused_parameters)\n    device = torch.device(\"cuda\")\n\n    if rank == 0:\n        logger.info(\"<== TSE Model Loss ==>\")\n        logger.info(\"loss criterion is: \" + str(configs[\"loss\"]))\n\n    configs[\"optimizer_args\"][\"tse_model\"][\"lr\"] = configs[\"scheduler_args\"][\n        \"tse_model\"][\"initial_lr\"]\n    optimizer = getattr(torch.optim, configs[\"optimizer\"][\"tse_model\"])(\n        ddp_model.parameters(), **configs[\"optimizer_args\"][\"tse_model\"])\n    if rank == 0:\n        logger.info(\"<== TSE Model Optimizer ==>\")\n        logger.info(\"optimizer is: \" + configs[\"optimizer\"][\"tse_model\"])\n\n    # scheduler\n    configs[\"scheduler_args\"][\"tse_model\"][\"num_epochs\"] = configs[\n        \"num_epochs\"]\n    configs[\"scheduler_args\"][\"tse_model\"][\"epoch_iter\"] = epoch_iter\n    configs[\"scheduler_args\"][\"scale_ratio\"] = 1.0\n\n    scheduler = getattr(schedulers, configs[\"scheduler\"][\"tse_model\"])(\n        optimizer, **configs[\"scheduler_args\"][\"tse_model\"])\n    if rank == 0:\n        logger.info(\"<== TSE Model Scheduler ==>\")\n        logger.info(\"scheduler is: \" + configs[\"scheduler\"][\"tse_model\"])\n\n    if configs[\"model_init\"][\"tse_model\"] is not None:\n        logger.info(\"Load initial model from {}\".format(\n            configs[\"model_init\"][\"tse_model\"]))\n        load_pretrained_model(ddp_model, configs[\"model_init\"][\"tse_model\"])\n    elif checkpoint is None:\n        logger.info(\"Train model from scratch ...\")\n\n    for c in criterion:\n        c = c.to(device)\n\n    # append to list\n    model_list.append(ddp_model)\n    optimizer_list.append(optimizer)\n    scheduler_list.append(scheduler)\n    scaler = torch.cuda.amp.GradScaler(enabled=configs[\"enable_amp\"])\n\n    # discriminator\n    discriminator = get_model(configs[\"model\"][\"discriminator\"])(\n        **configs[\"model_args\"][\"discriminator\"])\n    num_params = sum(param.numel() for param in discriminator.parameters())\n    # optimizer\n    configs[\"optimizer_args\"][\"discriminator\"][\"lr\"] = configs[\n        \"scheduler_args\"][\"discriminator\"][\"initial_lr\"]\n    # scheduler\n    configs[\"scheduler_args\"][\"discriminator\"][\"num_epochs\"] = configs[\n        \"num_epochs\"]\n    configs[\"scheduler_args\"][\"discriminator\"][\"epoch_iter\"] = epoch_iter\n    configs[\"scheduler_args\"][\"discriminator\"][\"scale_ratio\"] = 1.0\n    # ddp model\n    discriminator.cuda()\n    ddp_discriminator = torch.nn.parallel.DistributedDataParallel(\n        discriminator, 
find_unused_parameters=find_unused_parameters)\n    optimizer_d = getattr(torch.optim, configs[\"optimizer\"][\"discriminator\"])(\n        ddp_discriminator.parameters(),\n        **configs[\"optimizer_args\"][\"discriminator\"],\n    )\n    scheduler_d = getattr(schedulers, configs[\"scheduler\"][\"discriminator\"])(\n        optimizer_d, **configs[\"scheduler_args\"][\"discriminator\"])\n\n    # initialize discriminator\n    if configs[\"model_init\"][\"discriminator\"] is not None:\n        logger.info(\"Load initial discriminator from {}\".format(\n            configs[\"model_init\"][\"discriminator\"]))\n        load_pretrained_model(\n            ddp_discriminator,\n            configs[\"model_init\"][\"discriminator\"],\n            type=\"discriminator\",\n        )\n    elif checkpoint is None:\n        logger.info(\"Train discriminator from scratch ...\")\n\n    # If specify checkpoint, load some info from checkpoint.\n    if checkpoint is not None:\n        load_checkpoint(model_list, optimizer_list, scheduler_list, scaler,\n                        checkpoint)\n        start_epoch = (\n            int(re.findall(r\"(?<=checkpoint_)\\d*(?=.pt)\", checkpoint)[0]) + 1)\n        logger.info(\"Load checkpoint: {}\".format(checkpoint))\n    else:\n        start_epoch = 1\n\n    model_list.append(ddp_discriminator)\n    optimizer_list.append(optimizer_d)\n    scheduler_list.append(scheduler_d)\n\n    if rank == 0:\n        logger.info(\"<== Discriminator Model ==>\")\n        logger.info(\"discriminator size: {}\".format(num_params))\n        for line in pformat(discriminator).split(\"\\n\"):\n            logger.info(line)\n        logger.info(\"<== Discriminator Optimizer ==>\")\n        logger.info(\"optimizer is: \" + configs[\"optimizer\"][\"discriminator\"])\n        logger.info(\"<== Discriminator Scheduler ==>\")\n        logger.info(\"scheduler is: \" + configs[\"scheduler\"][\"discriminator\"])\n\n        # save config.yaml\n        saved_config_path = os.path.join(configs[\"exp_dir\"], \"config.yaml\")\n        with open(saved_config_path, \"w\") as fout:\n            data = yaml.dump(configs)\n            fout.write(data)\n\n    logger.info(\"start_epoch: {}\".format(start_epoch))\n\n    # training\n    dist.barrier(device_ids=[gpu])  # synchronize here\n    if rank == 0:\n        logger.info(\"<========== Training process ==========>\")\n        header = [\n            \"Train/Val\",\n            \"Epoch\",\n            \"iter\",\n            \"SE_Loss\",\n            \"G_Loss\",\n            \"D_Loss\",\n            \"LR\",\n        ]\n        for line in tp.header(header, width=10, style=\"grid\").split(\"\\n\"):\n            logger.info(line)\n    dist.barrier(device_ids=[gpu])  # synchronize here\n\n    executor = ExecutorGAN()\n    executor.step = 0\n\n    train_losses = []\n    val_losses = []\n    train_d_losses = []\n    val_d_losses = []\n    for epoch in range(start_epoch, configs[\"num_epochs\"] + 1):\n        train_dataset.set_epoch(epoch)\n\n        train_loss, train_d_loss = executor.train(\n            train_dataloader,\n            model_list,\n            epoch_iter,\n            optimizer_list,\n            criterion,\n            scheduler_list,\n            scaler=scaler,\n            epoch=epoch,\n            logger=logger,\n            enable_amp=configs[\"enable_amp\"],\n            clip_grad=configs[\"clip_grad\"],\n            log_batch_interval=configs[\"log_batch_interval\"],\n            device=device,\n            se_loss_weight=loss_args,\n    
        gan_loss_weight=gan_loss_weight,\n            multi_task=multi_task,\n        )\n\n        val_loss, val_d_loss = executor.cv(\n            val_dataloader,\n            model_list,\n            val_iter,\n            criterion,\n            epoch=epoch,\n            logger=logger,\n            enable_amp=configs[\"enable_amp\"],\n            log_batch_interval=configs[\"log_batch_interval\"],\n            device=device,\n        )\n\n        if rank == 0:\n            logger.info(\n                \"Epoch {} Train info train_loss {}, train_d_loss {}\".format(\n                    epoch, train_loss, train_d_loss))\n            logger.info(\"Epoch {} Val info val_loss {}, val_d_loss {}\".format(\n                epoch, val_loss, val_d_loss))\n            train_losses.append(train_loss)\n            train_d_losses.append(train_d_loss)\n            val_losses.append(val_loss)\n            val_d_losses.append(val_d_loss)\n\n            best_loss = val_loss\n            scheduler.best = best_loss\n            # plot\n            plt.figure()\n            plt.title(\"Loss of Train and Validation\")\n            x = list(range(start_epoch, epoch + 1))\n            plt.plot(x,\n                     train_losses,\n                     \"b-\",\n                     label=\"train_G_loss\",\n                     linewidth=0.8)\n            plt.plot(x,\n                     train_d_losses,\n                     \"r-\",\n                     label=\"train_D_loss\",\n                     linewidth=0.8)\n            plt.plot(x, val_losses, \"c-\", label=\"val_G_loss\", linewidth=0.8)\n            plt.plot(x, val_d_losses, \"m-\", label=\"val_D_loss\", linewidth=0.8)\n            plt.legend()\n            plt.xlabel(\"Epoch\")\n            plt.ylabel(\"Loss\")\n            plt.xticks(range(start_epoch, epoch + 1, 1))\n            plt.savefig(\n                f\"{configs['exp_dir']}/{configs['model']['tse_model']}.png\")\n            plt.close()\n\n        if rank == 0:\n            if (epoch % configs[\"save_epoch_interval\"] == 0\n                    or epoch >= configs[\"num_epochs\"] - configs[\"num_avg\"]):\n                save_checkpoint(\n                    model_list,\n                    optimizer_list,\n                    scheduler_list,\n                    scaler,\n                    os.path.join(model_dir, \"checkpoint_{}.pt\".format(epoch)),\n                )\n            try:\n                os.symlink(\n                    \"checkpoint_{}.pt\".format(epoch),\n                    os.path.join(model_dir, \"latest_checkpoint.pt\"),\n                )\n            except FileExistsError:\n                os.remove(os.path.join(model_dir, \"latest_checkpoint.pt\"))\n                os.symlink(\n                    \"checkpoint_{}.pt\".format(epoch),\n                    os.path.join(model_dir, \"latest_checkpoint.pt\"),\n                )\n\n    if rank == 0:\n        os.symlink(\n            \"checkpoint_{}.pt\".format(configs[\"num_epochs\"]),\n            os.path.join(model_dir, \"final_checkpoint.pt\"),\n        )\n        logger.info(tp.bottom(len(header), width=10, style=\"grid\"))\n\n\nif __name__ == \"__main__\":\n    fire.Fire(train)\n"
  },
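  {
    "path": "wesep/bin/train_gan_loss_example.py",
    "content": "# Illustrative sketch of the loss_posi / loss_weight convention documented\n# in wesep/bin/train_gan.py. The real combination happens inside the\n# executor; this standalone sketch only demonstrates the indexing, using\n# made-up tensors and a placeholder criterion.\nimport torch\n\n\ndef combine_losses(outputs, targets, criteria, loss_posi, loss_weight):\n    # criteria[i] is applied to outputs[loss_posi[i][j]], scaled by\n    # loss_weight[i][j], and all terms are summed into a single scalar.\n    total = 0.0\n    for i, criterion in enumerate(criteria):\n        for j, out_idx in enumerate(loss_posi[i]):\n            total = total + loss_weight[i][j] * criterion(\n                outputs[out_idx], targets)\n    return total\n\n\nif __name__ == \"__main__\":\n    outputs = [torch.randn(2, 16000), torch.randn(2, 16000)]\n    targets = torch.randn(2, 16000)\n    criteria = [torch.nn.functional.mse_loss]\n    # One criterion applied to both model outputs, the second down-weighted.\n    print(combine_losses(outputs, targets, criteria, [[0, 1]], [[1.0, 0.5]]))\n"
  },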
  {
    "path": "wesep/cli/__init__.py",
    "content": ""
  },
  {
    "path": "wesep/cli/extractor.py",
    "content": "import os\nimport sys\n\nfrom silero_vad import load_silero_vad, get_speech_timestamps\nimport torch\nimport torchaudio\nimport torchaudio.compliance.kaldi as kaldi\nimport yaml\nimport soundfile\n\nfrom wesep.cli.hub import Hub\nfrom wesep.cli.utils import get_args\nfrom wesep.models import get_model\nfrom wesep.utils.checkpoint import load_pretrained_model\nfrom wesep.utils.utils import set_seed\n\n\nclass Extractor:\n\n    def __init__(self, model_dir: str):\n        set_seed()\n\n        config_path = os.path.join(model_dir, \"config.yaml\")\n        model_path = os.path.join(model_dir, \"avg_model.pt\")\n        with open(config_path, \"r\") as fin:\n            configs = yaml.load(fin, Loader=yaml.FullLoader)\n            if 'spk_model_init' in configs['model_args']['tse_model']:\n                configs['model_args']['tse_model']['spk_model_init'] = False\n        self.model = get_model(configs[\"model\"][\"tse_model\"])(\n            **configs[\"model_args\"][\"tse_model\"]\n        )\n        load_pretrained_model(self.model, model_path)\n        self.model.eval()\n        self.vad = load_silero_vad()\n        self.table = {}\n        self.resample_rate = configs[\"dataset_args\"].get(\"resample_rate\", 16000)\n        self.apply_vad = False\n        self.device = torch.device(\"cpu\")\n        self.wavform_norm = True\n        self.output_norm = True\n\n        self.speaker_feat = configs[\"model_args\"][\"tse_model\"].get(\"spk_feat\", False)\n        self.joint_training = configs[\"model_args\"][\"tse_model\"].get(\n            \"joint_training\", False\n        )\n\n    def set_wavform_norm(self, wavform_norm: bool):\n        self.wavform_norm = wavform_norm\n\n    def set_resample_rate(self, resample_rate: int):\n        self.resample_rate = resample_rate\n\n    def set_vad(self, apply_vad: bool):\n        self.apply_vad = apply_vad\n\n    def set_device(self, device: str):\n        self.device = torch.device(device)\n        self.model = self.model.to(self.device)\n\n    def set_output_norm(self, output_norm: bool):\n        self.output_norm = output_norm\n\n    def compute_fbank(\n        self,\n        wavform,\n        sample_rate=16000,\n        num_mel_bins=80,\n        frame_length=25,\n        frame_shift=10,\n        cmn=True,\n    ):\n        feat = kaldi.fbank(\n            wavform,\n            num_mel_bins=num_mel_bins,\n            frame_length=frame_length,\n            frame_shift=frame_shift,\n            sample_frequency=sample_rate,\n        )\n        if cmn:\n            feat = feat - torch.mean(feat, 0)\n        return feat\n\n    def extract_speech(self, audio_path: str, audio_path_2: str):\n        pcm_mix, sample_rate_mix = torchaudio.load(\n            audio_path, normalize=self.wavform_norm\n        )\n        pcm_enroll, sample_rate_enroll = torchaudio.load(\n            audio_path_2, normalize=self.wavform_norm\n        )\n        return self.extract_speech_from_pcm(pcm_mix,\n                                            sample_rate_mix,\n                                            pcm_enroll,\n                                            sample_rate_enroll)\n\n    def extract_speech_from_pcm(self,\n                                pcm_mix: torch.Tensor,\n                                sample_rate_mix: int,\n                                pcm_enroll: torch.Tensor,\n                                sample_rate_enroll: int):\n        if self.apply_vad:\n            # TODO(Binbin Zhang): Refine the segments logic, here we just\n          
  # suppose there is only silence at the start/end of the speech\n            # Only do vad on the enrollment\n            vad_sample_rate = 16000\n            wav = pcm_enroll\n            if wav.size(0) > 1:\n                wav = wav.mean(dim=0, keepdim=True)\n            if sample_rate_enroll != vad_sample_rate:\n                transform = torchaudio.transforms.Resample(\n                    orig_freq=sample_rate_enroll, new_freq=vad_sample_rate\n                )\n                wav = transform(wav)\n\n            segments = get_speech_timestamps(wav, self.vad, return_seconds=True)\n            pcmTotal = torch.Tensor()\n            if len(segments) > 0:  # remove all the silence\n                for segment in segments:\n                    start = int(segment[\"start\"] * sample_rate_enroll)\n                    end = int(segment[\"end\"] * sample_rate_enroll)\n                    pcmTemp = pcm_enroll[0, start:end]\n                    pcmTotal = torch.cat([pcmTotal, pcmTemp], 0)\n                pcm_enroll = pcmTotal.unsqueeze(0)\n            else:  # all silence, nospeech\n                return None\n\n        pcm_mix = pcm_mix.to(torch.float)\n        if sample_rate_mix != self.resample_rate:\n            pcm_mix = torchaudio.transforms.Resample(\n                orig_freq=sample_rate_mix, new_freq=self.resample_rate\n            )(pcm_mix)\n        pcm_enroll = pcm_enroll.to(torch.float)\n        if sample_rate_enroll != self.resample_rate:\n            pcm_enroll = torchaudio.transforms.Resample(\n                orig_freq=sample_rate_enroll, new_freq=self.resample_rate\n            )(pcm_enroll)\n\n        if self.joint_training:\n            if self.speaker_feat:\n                feats = self.compute_fbank(\n                    pcm_enroll, sample_rate=self.resample_rate, cmn=True\n                )\n                feats = feats.unsqueeze(0)\n            else:\n                feats = pcm_enroll\n\n            feats = feats.to(self.device)\n            pcm_mix = pcm_mix.to(self.device)\n            with torch.no_grad():\n                outputs = self.model(pcm_mix, feats)\n                outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs\n            target_speech = outputs.to(torch.device(\"cpu\"))\n            if self.output_norm:\n                target_speech = (\n                    target_speech\n                    / abs(target_speech).max(dim=1, keepdim=True).values * 0.9\n                )\n            return target_speech\n        else:\n            return None\n\n\ndef load_model(language: str) -> Extractor:\n    model_path = Hub.get_model(language)\n    return Extractor(model_path)\n\n\ndef load_model_local(model_dir: str) -> Extractor:\n    return Extractor(model_dir)\n\n\ndef main():\n    args = get_args()\n    if args.pretrain == \"\":\n        if args.bsrnn:\n            model = load_model(\"bsrnn\")\n        else:\n            model = load_model(args.language)\n    else:\n        model = load_model_local(args.pretrain)\n    model.set_resample_rate(args.resample_rate)\n    model.set_vad(args.vad)\n    model.set_device(args.device)\n    model.set_output_norm(args.output_norm)\n    if args.task == \"extraction\":\n        speech = model.extract_speech(args.audio_file, args.audio_file2)\n        if speech is not None:\n            if args.normalize_output:\n                speech = speech / abs(speech).max(dim=1, keepdim=True).values * 0.9\n            soundfile.write(args.output_file, speech[0], args.resample_rate)\n            
print(\"Succeed, see {}\".format(args.output_file))\n        else:\n            print(\"Fails to extract the target speech\")\n    else:\n        print(\"Unsupported task {}\".format(args.task))\n        sys.exit(-1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
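  {
    "path": "wesep/cli/extractor_usage_example.py",
    "content": "# Minimal usage sketch for the Python API in wesep/cli/extractor.py.\n# The model directory and the wav paths below are placeholders: point them\n# to a real pretrained model (config.yaml + avg_model.pt) and real audio.\nimport soundfile\n\nfrom wesep.cli.extractor import load_model_local\n\n\ndef main():\n    extractor = load_model_local(\"exp/bsrnn/models\")  # placeholder path\n    extractor.set_device(\"cpu\")\n    extractor.set_resample_rate(16000)\n    extractor.set_vad(False)\n\n    # mixture.wav holds the mixed speech, enroll.wav the target speaker's\n    # enrollment; extract_speech returns a [1, T] tensor or None on failure.\n    speech = extractor.extract_speech(\"mixture.wav\", \"enroll.wav\")\n    if speech is not None:\n        soundfile.write(\"extracted_speech.wav\", speech[0], 16000)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },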
  {
    "path": "wesep/cli/hub.py",
    "content": "# Copyright (c) 2022  Mddct(hamddct@gmail.com)\n#               2023  Binbin Zhang(binbzha@qq.com)\n#               2024  Shuai Wang(wsstriving@gmail.com)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nfrom pathlib import Path\nimport tarfile\nimport zipfile\nfrom urllib.request import urlretrieve\n\nimport tqdm\n\n\ndef download(url: str, dest: str, only_child=True):\n    \"\"\"download from url to dest\"\"\"\n    assert os.path.exists(dest)\n    print(\"Downloading {} to {}\".format(url, dest))\n\n    def progress_hook(t):\n        last_b = [0]\n\n        def update_to(b=1, bsize=1, tsize=None):\n            if tsize not in (None, -1):\n                t.total = tsize\n            displayed = t.update((b - last_b[0]) * bsize)\n            last_b[0] = b\n            return displayed\n\n        return update_to\n\n    # *.tar.gz\n    name = url.split(\"?\")[0].split(\"/\")[-1]\n    file_path = os.path.join(dest, name)\n    with tqdm.tqdm(\n        unit=\"B\", unit_scale=True, unit_divisor=1024, miniters=1, desc=(name)\n    ) as t:\n        urlretrieve(\n            url, filename=file_path, reporthook=progress_hook(t), data=None\n        )\n        t.total = t.n\n\n    if name.endswith((\".tar.gz\", \".tar\")):\n        with tarfile.open(file_path) as f:\n            if not only_child:\n                f.extractall(dest)\n            else:\n                for tarinfo in f:\n                    if \"/\" not in tarinfo.name:\n                        continue\n                    name = os.path.basename(tarinfo.name)\n                    fileobj = f.extractfile(tarinfo)\n                    with open(os.path.join(dest, name), \"wb\") as writer:\n                        writer.write(fileobj.read())\n\n    elif name.endswith(\".zip\"):\n        with zipfile.ZipFile(file_path, \"r\") as zip_ref:\n            if not only_child:\n                zip_ref.extractall(dest)\n            else:\n                for member in zip_ref.namelist():\n                    member_path = os.path.relpath(\n                        member, start=os.path.commonpath(zip_ref.namelist())\n                    )\n                    print(member_path)\n                    if \"/\" not in member_path:\n                        continue\n                    name = os.path.basename(member_path)\n                    with zip_ref.open(member_path) as source, open(\n                        os.path.join(dest, name), \"wb\"\n                    ) as target:\n                        target.write(source.read())\n\n\nclass Hub(object):\n    Assets = {\n        \"english\": \"bsrnn_ecapa_vox1.tar.gz\",\n    }\n    #   Hard coding of the URL\n    ModelURLs = {\n        \"bsrnn_ecapa_vox1.tar.gz\": (\n            \"https://www.modelscope.cn/datasets/wenet/wesep_pretrained_models/\"\n            \"resolve/master/bsrnn_ecapa_vox1.tar.gz\"\n        ),\n    }\n\n    def __init__(self) -> None:\n        pass\n\n    @staticmethod\n    def get_model(lang: str) -> str:\n        if lang not in 
Hub.Assets.keys():\n            print(\"ERROR: Unsupported lang {} !!!\".format(lang))\n            sys.exit(1)\n        # model = Hub.Assets[lang]\n        model_name = Hub.Assets[lang]\n        model_dir = os.path.join(Path.home(), \".wesep\", lang)\n        if not os.path.exists(model_dir):\n            os.makedirs(model_dir)\n        if set([\"avg_model.pt\", \"config.yaml\"]).issubset(\n            set(os.listdir(model_dir))\n        ):\n            return model_dir\n        else:\n            if model_name in Hub.ModelURLs:\n                model_url = Hub.ModelURLs[model_name]\n                download(model_url, model_dir)\n                return model_dir\n            else:\n                print(f\"ERROR: No URL found for model {model_name}\")\n                return None\n"
  },
  {
    "path": "wesep/cli/utils.py",
    "content": "import argparse\n\n\ndef get_args():\n    parser = argparse.ArgumentParser(description=\"\")\n    parser.add_argument(\n        \"-t\",\n        \"--task\",\n        choices=[\n            \"extraction\",\n        ],\n        default=\"extraction\",\n        help=\"task type\",\n    )\n    parser.add_argument(\n        \"-l\",\n        \"--language\",\n        choices=[\n            # \"chinese\",\n            \"english\",\n        ],\n        default=\"english\",\n        help=\"language type\",\n    )\n    parser.add_argument(\n        \"--bsrnn\",\n        action=\"store_true\",\n        help=\"whether to use the bsrnn model\",\n    )\n    parser.add_argument(\n        \"-p\", \"--pretrain\", type=str, default=\"\", help=\"model directory\"\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"cpu\",\n        help=\"device type (most commonly cpu or cuda,\"\n        \"but also potentially mps, xpu, xla or meta)\"\n        \"and optional device ordinal for the device type.\",\n    )\n    parser.add_argument(\"--audio_file\", help=\"mixture's audio file\")\n    parser.add_argument(\"--audio_file2\", help=\"enroll's audio file\")\n    parser.add_argument(\n        \"--resample_rate\", type=int, default=16000, help=\"resampling rate\"\n    )\n    parser.add_argument(\n        \"--vad\", action=\"store_true\", help=\"whether to do VAD or not\"\n    )\n    parser.add_argument(\n        \"--output_file\",\n        default='./extracted_speech.wav',\n        help=\"extracted speech saved in .wav\"\n    )\n    parser.add_argument(\n        \"--output_norm\",\n        default=True,\n        help=\"Control if normalize the output audio in .wav\"\n    )\n    args = parser.parse_args()\n    return args\n"
  },
  {
    "path": "wesep/dataset/FRAM_RIR.py",
    "content": "# Author: Rongzhi Gu, Yi Luo\n# Copyright: Tencent AI Lab\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\nfrom torchaudio.functional import highpass_biquad\nfrom torchaudio.transforms import Resample\n\n# set random seed\nseed = 20231\nnp.random.seed(seed)\ntorch.manual_seed(seed)\n\n\ndef calc_cos(orientation_rad):\n    \"\"\"\n    cos_theta: tensor, [azimuth, elevation] with shape [..., 2]\n    return: [..., 3]\n    \"\"\"\n    return torch.stack(\n        [\n            torch.cos(\n                orientation_rad[..., 0] * torch.sin(orientation_rad[..., 1])),\n            torch.sin(\n                orientation_rad[..., 0] * torch.sin(orientation_rad[..., 1])),\n            torch.cos(orientation_rad[..., 1]),\n        ],\n        -1,\n    )\n\n\ndef freq_invariant_decay_func(cos_theta, pattern=\"cardioid\"):\n    \"\"\"\n    cos_theta: tensor\n    Return:\n    amplitude: tensor with same shape as cos_theta\n    \"\"\"\n\n    if pattern == \"cardioid\":\n        return 0.5 + 0.5 * cos_theta\n\n    elif pattern == \"omni\":\n        return torch.ones_like(cos_theta)\n\n    elif pattern == \"bidirectional\":\n        return cos_theta\n\n    elif pattern == \"hyper_cardioid\":\n        return 0.25 + 0.75 * cos_theta\n\n    elif pattern == \"sub_cardioid\":\n        return 0.75 + 0.25 * cos_theta\n\n    elif pattern == \"half_omni\":\n        c = torch.clamp(cos_theta, 0)\n        c[c > 0] = 1.0\n        return c\n    else:\n        raise NotImplementedError\n\n\ndef freq_invariant_src_decay_func(mic_pos,\n                                  src_pos,\n                                  src_orientation_rad,\n                                  pattern=\"cardioid\"):\n    \"\"\"\n    mic_pos: [n_mic, 3] (tensor)\n    src_pos: [n_src, 3] (tensor)\n    src_orientation_rad: [n_src, 2] (tensor), elevation, azimuth\n\n    Return:\n    amplitude: [n_mic, n_src, n_image]\n    \"\"\"\n    # Steering vector of source(s)\n    orV_src = calc_cos(src_orientation_rad).unsqueeze(0)  # [nsrc, 3]\n\n    # receiver to src vector\n    rcv_to_src_vec = mic_pos.unsqueeze(1) - src_pos.unsqueeze(\n        0)  # [n_mic, n_src, 3]\n\n    cos_theta = (rcv_to_src_vec * orV_src).sum(-1)  # [n_mic, n_src]\n    cos_theta /= torch.sqrt(rcv_to_src_vec.pow(2).sum(-1))\n    cos_theta /= torch.sqrt(orV_src.pow(2).sum(-1))\n\n    return freq_invariant_decay_func(cos_theta, pattern)\n\n\ndef freq_invariant_mic_decay_func(mic_pos,\n                                  img_pos,\n                                  mic_orientation_rad,\n                                  pattern=\"cardioid\"):\n    \"\"\"\n    mic_pos: [n_mic, 3] (tensor)\n    img_pos: [n_src, n_image, 3] (tensor)\n    mic_orientation_rad: [n_mic, 2] (tensor), azimuth, elevation\n\n    Return:\n    amplitude: [n_mic, n_src, n_image]\n    \"\"\"\n    # Steering vector of source(s)\n    orV_src = calc_cos(mic_orientation_rad)  # [nmic, 3]\n    orV_src = orV_src.view(-1, 1, 1, 3)  # [n_mic, 1, 1, 3]\n\n    # image to receiver 
vector\n    # [1, n_src, n_image, 3] - [n_mic, 1, 1, 3] => [n_mic, n_src, n_image, 3]\n    img_to_rcv_vec = img_pos.unsqueeze(0) - mic_pos.unsqueeze(1).unsqueeze(1)\n\n    cos_theta = (img_to_rcv_vec * orV_src).sum(-1)  # [n_mic, n_src, n_image]\n    cos_theta /= torch.sqrt(img_to_rcv_vec.pow(2).sum(-1))\n    cos_theta /= torch.sqrt(orV_src.pow(2).sum(-1))\n\n    return freq_invariant_decay_func(cos_theta, pattern)\n\n\ndef FRAM_RIR(\n    mic_pos,\n    sr,\n    T60,\n    room_dim,\n    src_pos,\n    num_src=1,\n    direct_range=(-6, 50),\n    n_image=(1024, 4097),\n    a=-2.0,\n    b=2.0,\n    tau=0.25,\n    src_pattern=\"omni\",\n    src_orientation_rad=None,\n    mic_pattern=\"omni\",\n    mic_orientation_rad=None,\n):\n    \"\"\"Fast Random Appoximation of Multi-channel Room Impulse Response (FRAM-RIR)  # noqa\n\n    Args:\n        mic_pos: The microphone(s) position with respect to the room coordinates,  # noqa\n                 with shape [num_mic, 3] (in meters). Room coordinate system must be defined in advance,  # noqa\n                 with the constraint that the origin of the coordinate is on the floor(so positive z axis points up).  # noqa\n        sr: RIR sampling rate (Hz).\n        T60: RT60 (second).\n        room_dim: Room size with shape [3] (meters).\n        src_pos: The source(s) position with respect to the room coordinate system, with shape [num_src, 3] (meters).  # noqa\n        num_src: Number of sources. Defaults to 1.\n        direct_range: 2-element tuple, range of early reflection time (milliseconds,  # noqa\n                                        defined as the context around the direct path signal) of RIRs.  # noqa\n                                        Defaults to (-6, 50).\n        n_image: 2-element tuple, minimum and maximum number of images to sample from.  # noqa\n                                   Defaults to (1024, 4097).\n        a: controlling the random perturbation added to each virtual sound source.  Defaults to -2.0.  # noqa\n        b: controlling the random perturbation added to each virtual sound source. Defaults to 2.0.  # noqa\n        tau: controlling the relationship between the distance and the number of reflections of each  # noqa\n                               virtual sound source. Defaults to 0.25.\n        src_pattern: Polar pattern for all of the sources. Defaults to \"omni\".\n        src_orientation_rad: Array-like with shape [num_src, 2]. Orientation (rad) of all  # noqa\n                                                the sources, where the first column indicate azimuth and the  # noqa\n                                                second column indicate elevation. Defaults to None.  # noqa\n        mic_pattern: Polar pattern for all of the receivers. Defaults to \"omni\".\n        mic_orientation_rad: Array-like with shape [num_mic, 2]. Orientation (rad) of all  # noqa\n                                                the microphones, where the first column indicate azimuth and  # noqa\n                                                the second column indicate elevation. Defaults to None.  # noqa\n\n    Returns:\n        rir: RIR filters for all mic-source pairs, with shape [num_mic, num_src, rir_length].  
# noqa\n        early_rir: Early reflection (direct path) RIR filters for all mic-source pairs,  # noqa\n                   with shape [num_mic, num_src, rir_length].\n    \"\"\"\n\n    # sample image\n    image = np.random.choice(range(n_image[0], n_image[1]))\n\n    R = torch.tensor(\n        1.0 / (2 *\n               (1.0 / room_dim[0] + 1.0 / room_dim[1] + 1.0 / room_dim[2])))\n\n    eps = np.finfo(np.float16).eps\n    mic_position = torch.from_numpy(mic_pos)\n    src_position = torch.from_numpy(src_pos)  # [nsource, 3]\n    n_mic = mic_position.shape[0]\n    num_src = src_position.shape[0]\n\n    # [nmic, nsource]\n    direct_dist = ((mic_position.unsqueeze(1) -\n                    src_position.unsqueeze(0)).pow(2).sum(-1) + 1e-3).sqrt()\n    # [nsource]\n    nearest_dist, nearest_mic_idx = direct_dist.min(0)\n    # [nsource, 3]\n    nearest_mic_position = mic_position[nearest_mic_idx]\n\n    ns = n_mic * num_src\n    ratio = 64\n    sample_sr = sr * ratio\n    velocity = 340.0\n    T60 = torch.tensor(T60)\n\n    direct_idx = (torch.ceil(direct_dist * sample_sr / velocity).long().view(\n        ns, ))\n    rir_length = int(np.ceil(sample_sr * T60))\n\n    resample1 = Resample(sample_sr, sample_sr // int(np.sqrt(ratio)))\n    resample2 = Resample(sample_sr // int(np.sqrt(ratio)), sr)\n\n    reflect_coef = (1 - (1 - torch.exp(-0.16 * R / T60)).pow(2)).sqrt()\n    dist_range = [\n        torch.linspace(1.0, velocity * T60 / nearest_dist[i] - 1, rir_length)\n        for i in range(num_src)\n    ]\n\n    dist_prob = torch.linspace(0.0, 1.0, rir_length)\n    dist_prob /= dist_prob.sum()\n    dist_select_idx = dist_prob.multinomial(num_samples=int(image * num_src),\n                                            replacement=True).view(\n                                                num_src, image)\n\n    dist_nearest_ratio = torch.stack(\n        [dist_range[i][dist_select_idx[i]] for i in range(num_src)], 0)\n\n    # apply different dist ratios to mirophones\n    azm = torch.FloatTensor(num_src, image).uniform_(-np.pi, np.pi)\n    ele = torch.FloatTensor(num_src, image).uniform_(-np.pi / 2, np.pi / 2)\n    # [nsource, nimage, 3]\n    unit_3d = torch.stack(\n        [\n            torch.sin(ele) * torch.cos(azm),\n            torch.sin(ele) * torch.sin(azm),\n            torch.cos(ele),\n        ],\n        -1,\n    )\n    # [nsource] x [nsource, T] x [nsource, nimage, 3] => [nsource, nimage, 3]\n    image2nearest_dist = nearest_dist.view(\n        -1, 1, 1) * dist_nearest_ratio.unsqueeze(-1)\n    image_position = (nearest_mic_position.unsqueeze(1) +\n                      image2nearest_dist * unit_3d)\n    # [nmic, nsource, nimage]\n    dist = ((mic_position.view(-1, 1, 1, 3) -\n             image_position.unsqueeze(0)).pow(2).sum(-1) + 1e-3).sqrt()\n\n    # reflection perturbation\n    reflect_max = (torch.log10(velocity * T60) - 3) / torch.log10(reflect_coef)\n    reflect_ratio = (dist /\n                     (velocity * T60)) * (reflect_max.view(1, -1, 1) - 1) + 1\n    reflect_pertub = torch.FloatTensor(num_src, image).uniform_(\n        a, b) * dist_nearest_ratio.pow(tau)\n    reflect_ratio = torch.maximum(reflect_ratio + reflect_pertub.unsqueeze(0),\n                                  torch.ones(1))\n\n    # [nmic, nsource, 1 + nimage]\n    dist = torch.cat([direct_dist.unsqueeze(2), dist], 2)\n    reflect_ratio = torch.cat([torch.zeros(n_mic, num_src, 1), reflect_ratio],\n                              2)\n\n    delta_idx = (torch.minimum(\n        torch.ceil(dist * sample_sr / 
velocity),\n        torch.ones(1) * rir_length - 1,\n    ).long().view(ns, -1))\n    delta_decay = reflect_coef.pow(reflect_ratio) / dist\n\n    #################################\n    # source orientation simulation #\n    #################################\n    if src_pattern != \"omni\":\n        # randomly sample each image's relative orientation with respect to the original source  # noqa\n        # equivalent to a random decay corresponds to the source's orientation pattern decay  # noqa\n        img_orientation_rad = torch.FloatTensor(num_src, image,\n                                                2).uniform_(-np.pi, np.pi)\n        img_cos_theta = torch.cos(img_orientation_rad[..., 0]) * torch.cos(\n            img_orientation_rad[..., 1])  # [nsource, nimage]\n        img_orientation_decay = freq_invariant_decay_func(\n            img_cos_theta, pattern=src_pattern)  # [nsource, nimage]\n\n        # direct path orientation should use the provided parameter\n        if src_orientation_rad is None:\n            # assume random orientation if not given\n            src_orientation_azi = torch.FloatTensor(num_src).uniform_(\n                -np.pi, np.pi)\n            src_orientation_ele = torch.FloatTensor(num_src).uniform_(\n                -np.pi, np.pi)\n            src_orientation_rad = torch.stack(\n                [src_orientation_azi, src_orientation_ele], -1)\n        else:\n            src_orientation_rad = torch.from_numpy(\n                src_orientation_rad)  # [nsource, 2]\n\n        src_orientation_decay = freq_invariant_src_decay_func(\n            mic_position,\n            src_position,\n            src_orientation_rad,\n            pattern=src_pattern,\n        )  # [nmic, nsource]\n        # apply decay\n        delta_decay[:, :, 0] *= src_orientation_decay\n        delta_decay[:, :, 1:] *= img_orientation_decay.unsqueeze(0)\n\n    if mic_pattern != \"omni\":\n        # mic orientation simulation #\n        # when not given, assume that all mics facing up (positive z axis)\n        if mic_orientation_rad is None:\n            mic_orientation_rad = torch.stack(\n                [torch.zeros(n_mic), torch.zeros(n_mic)], -1)  # [nmic, 2]\n        else:\n            mic_orientation_rad = torch.from_numpy(mic_orientation_rad)\n        all_src_img_pos = torch.cat(\n            (src_position.unsqueeze(1), image_position),\n            1)  # [nsource, nimage+1, 3]\n        mic_orientation_decay = freq_invariant_mic_decay_func(\n            mic_position,\n            all_src_img_pos,\n            mic_orientation_rad,\n            pattern=mic_pattern,\n        )  # [nmic, nsource, nimage+1]\n        # apply decay\n        delta_decay *= mic_orientation_decay\n\n    rir = torch.zeros(ns, rir_length)\n    delta_decay = delta_decay.view(ns, -1)\n    for i in range(ns):\n        remainder_idx = delta_idx[i]\n        valid_mask = np.ones(len(remainder_idx))\n        while np.sum(valid_mask) > 0:\n            valid_remainder_idx, unique_remainder_idx = np.unique(\n                remainder_idx, return_index=True)\n            rir[i][valid_remainder_idx] += (\n                delta_decay[i][unique_remainder_idx] *\n                valid_mask[unique_remainder_idx])\n            valid_mask[unique_remainder_idx] = 0\n            remainder_idx[unique_remainder_idx] = 0\n\n    direct_mask = torch.zeros(ns, rir_length).float()\n\n    for i in range(ns):\n        direct_mask[\n            i,\n            max(direct_idx[i] + sample_sr * direct_range[0] // 1000, 0\n                
):min(direct_idx[i] +\n                      sample_sr * direct_range[1] // 1000, rir_length), ] = 1.0\n\n    rir_direct = rir * direct_mask\n\n    all_rir = torch.stack([rir, rir_direct], 1).view(ns * 2, -1)\n    rir_downsample = resample1(all_rir)\n    rir_hp = highpass_biquad(rir_downsample, sample_sr // int(np.sqrt(ratio)),\n                             80.0)\n    rir = resample2(rir_hp).float().view(n_mic, num_src, 2, -1)\n\n    return rir[:, :, 0].data.numpy(), rir[:, :, 1].data.numpy()\n\n\ndef sample_mic_arch(n_mic, mic_spacing=None, bounding_box=None):\n    if mic_spacing is None:\n        mic_spacing = [0.02, 0.10]\n    if bounding_box is None:\n        bounding_box = [0.08, 0.12, 0]\n\n    sample_n_mic = np.random.randint(n_mic[0], n_mic[1] + 1)\n    if sample_n_mic == 1:\n        mic_arch = np.array([[0, 0, 0]])\n    else:\n        mic_arch = []\n        while len(mic_arch) < sample_n_mic:\n            this_mic_pos = np.random.uniform(np.array([0, 0, 0]),\n                                             np.array(bounding_box))\n\n            if len(mic_arch) != 0:\n                ok = True\n                for other_mic_pos in mic_arch:\n                    this_mic_spacing = np.linalg.norm(this_mic_pos -\n                                                      other_mic_pos)\n                    if (this_mic_spacing < mic_spacing[0]\n                            or this_mic_spacing > mic_spacing[1]):\n                        ok = False\n                        break\n                if ok:\n                    mic_arch.append(this_mic_pos)\n            else:\n                mic_arch.append(this_mic_pos)\n        mic_arch = np.stack(mic_arch, 0)  # [nmic, 3]\n    return mic_arch\n\n\ndef sample_src_pos(\n    room_dim,\n    num_src,\n    array_pos,\n    min_mic_dis=0.5,\n    max_mic_dis=5,\n    min_dis_wall=None,\n):\n    if min_dis_wall is None:\n        min_dis_wall = [0.5, 0.5, 0.5]\n\n    # random sample the source positon\n    src_pos = []\n    while len(src_pos) < num_src:\n        pos = np.random.uniform(np.array(min_dis_wall),\n                                np.array(room_dim) - np.array(min_dis_wall))\n        dis = np.linalg.norm(pos - np.array(array_pos))\n\n        if dis >= min_mic_dis and dis <= max_mic_dis:\n            src_pos.append(pos)\n\n    return np.stack(src_pos, 0)\n\n\ndef sample_mic_array_pos(mic_arch, room_dim, min_dis_wall=None):\n    \"\"\"\n    Generate the microphone array position according to the given microphone architecture (geometry)  # noqa\n    :param mic_arch: np.array with shape [n_mic, 3]\n                    the relative 3D coordinate to the array_pos in (m)\n                    e.g., 2-mic linear array [[-0.1, 0, 0], [0.1, 0, 0]];\n                    e.g., 4-mic circular array [[0, 0.035, 0], [0.035, 0, 0], [0, -0.035, 0], [-0.035, 0, 0]]  # noqa\n    :param min_dis_wall: minimum distance from the wall in (m)\n    :return\n        mic_pos: microphone array position in (m) with shape [n_mic, 3]\n        array_pos: array CENTER / REFERENCE position in (m) with shape [1, 3]\n    \"\"\"\n\n    def rotate(angle, valuex, valuey):\n        rotate_x = valuex * np.cos(angle) + valuey * np.sin(angle)  # [nmic]\n        rotate_y = valuey * np.cos(angle) - valuex * np.sin(angle)\n        return np.stack(\n            [rotate_x, rotate_y, np.zeros_like(rotate_x)], -1)  # [nmic, 3]\n\n    if min_dis_wall is None:\n        min_dis_wall = [0.5, 0.5, 0.5]\n\n    if isinstance(mic_arch, dict):  # ADHOC ARRAY\n        n_mic, mic_spacing, bounding_box = (\n  
          mic_arch[\"n_mic\"],\n            mic_arch[\"spacing\"],\n            mic_arch[\"bounding_box\"],\n        )\n        sample_n_mic = np.random.randint(n_mic[0], n_mic[1] + 1)\n\n        if sample_n_mic == 1:\n            mic_arch = np.array([[0, 0, 0]])\n        else:\n            mic_arch = [\n                np.random.uniform(np.array([0, 0, 0]), np.array(bounding_box))\n            ]\n            while len(mic_arch) < sample_n_mic:\n                this_mic_pos = np.random.uniform(np.array([0, 0, 0]),\n                                                 np.array(bounding_box))\n                ok = True\n                for other_mic_pos in mic_arch:\n                    this_mic_spacing = np.linalg.norm(this_mic_pos -\n                                                      other_mic_pos)\n                    if (this_mic_spacing < mic_spacing[0]\n                            or this_mic_spacing > mic_spacing[1]):\n                        ok = False\n                        break\n                if ok:\n                    mic_arch.append(this_mic_pos)\n            mic_arch = np.stack(mic_arch, 0)  # [nmic, 3]\n    else:\n        mic_arch = np.array(mic_arch)\n\n    mic_array_center = np.mean(mic_arch, 0, keepdims=True)  # [1, 3]\n    max_radius = max(np.linalg.norm(mic_arch - mic_array_center, axis=-1))\n    array_pos = np.random.uniform(\n        np.array(min_dis_wall) + max_radius,\n        np.array(room_dim) - np.array(min_dis_wall) - max_radius,\n    ).reshape(1, 3)\n    mic_pos = array_pos + mic_arch\n    # assume the array is always horizontal\n    rotate_azm = np.random.uniform(-np.pi, np.pi)\n    mic_pos = array_pos + rotate(rotate_azm, mic_arch[:, 0],\n                                 mic_arch[:, 1])  # [n_mic, 3]\n\n    return mic_pos, array_pos\n\n\ndef sample_a_config(simu_config):\n    room_config = simu_config[\"min_max_room\"]\n    rt60_config = simu_config[\"rt60\"]\n    mic_dist_config = simu_config[\"mic_dist\"]\n    num_src = simu_config[\"num_src\"]\n    room_dim = np.random.uniform(np.array(room_config[0]),\n                                 np.array(room_config[1]))\n    rt60 = np.random.uniform(rt60_config[0], rt60_config[1])\n    sr = simu_config[\"sr\"]\n\n    if (\"array_pos\"\n            not in simu_config.keys()):  # mic_arch must be given in this case\n        mic_arch = simu_config[\"mic_arch\"]\n        mic_pos, array_pos = sample_mic_array_pos(mic_arch, room_dim)\n    else:\n        array_pos = simu_config[\"array_pos\"]\n\n    if \"src_pos\" not in simu_config.keys():\n        src_pos = sample_src_pos(\n            room_dim,\n            num_src,\n            array_pos,\n            min_mic_dis=mic_dist_config[0],\n            max_mic_dis=mic_dist_config[1],\n        )\n    else:\n        src_pos = np.array(simu_config[\"src_pos\"])\n\n    return mic_pos, sr, rt60, room_dim, src_pos, array_pos\n\n\n# === single-channel FRA-RIR ===\ndef single_channel(simu_config):\n    mic_arch = {\"n_mic\": [1, 1], \"spacing\": None, \"bounding_box\": None}\n    simu_config[\"mic_arch\"] = mic_arch\n    mic_pos, sr, rt60, room_dim, src_pos, array_pos = sample_a_config(\n        simu_config)\n\n    rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos, array_pos)\n    # with shape [1, n_src, rir_len]\n    return rir, rir_direct\n\n\n# === multi-channel (fixed) ===\ndef multi_channel_array(simu_config):\n    mic_arch = [[-0.05, 0, 0], [0.05, 0, 0]]\n\n    simu_config[\"mic_arch\"] = mic_arch\n    mic_pos, sr, rt60, room_dim, src_pos, array_pos = 
sample_a_config(\n        simu_config)\n\n    rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos)\n    # with shape [n_mic, n_src, rir_len]\n    return rir, rir_direct\n\n\n# === multi-channel (adhoc) ===\ndef multi_channel_adhoc(simu_config):\n    mic_arch = {\n        \"n_mic\": [1, 3],\n        \"spacing\": [0.02, 0.05],\n        \"bounding_box\": [0.5, 1.0, 0],  # x, y, z\n    }\n    simu_config[\"mic_arch\"] = mic_arch\n    mic_pos, sr, rt60, room_dim, src_pos, array_pos = sample_a_config(\n        simu_config)\n\n    rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos)\n    # with shape [sample_n_mic, n_src, rir_len]\n    return rir, rir_direct\n\n\ndef multi_channel_src_orientation():\n    \"\"\"\n    ========================= → y axis\n    |                       |\n    |    *1          *2     |\n    |                       |\n    |          ↑            |\n    |                       |\n    |    *3          *4     |\n    |                       |\n    =========================\n    ↓\n    x axis\n    \"\"\"\n    sr = 16000\n    rt60 = 0.6\n    room_dim = [8, 8, 3]\n    src_pos = np.array([[4, 4, 1.5]])  # middle of the room\n    mic_pos = np.array([[2, 2, 1.5], [2, 6, 1.5], [6, 2, 1.5],\n                        [6, 6, 1.5]]  # mic 1, 2\n                       )  # mic 3, 4\n    src_pattern = \"sub_cardioid\"\n    src_orientation_rad = (np.array([180, 90]) / 180.0 * np.pi\n                           )  # facing *front* (negative x axis)\n\n    rir, rir_direct = FRAM_RIR(\n        mic_pos,\n        sr,\n        rt60,\n        room_dim=room_dim,\n        src_pos=src_pos,\n        src_pattern=src_pattern,\n        src_orientation_rad=src_orientation_rad,\n    )\n\n    return rir, rir_direct\n\n\ndef multi_channel_mic_orientation():\n    \"\"\"\n    ========================= → y axis\n    |                       |\n    |    ↑1          ↓2     |\n    |                       |\n    |          o            |\n    |                       |\n    |    ↑3          ↓4     |\n    |                       |\n    =========================\n    ↓\n    x axis\n    \"\"\"\n\n    sr = 16000\n    rt60 = 0.6\n    room_dim = [8, 8, 3]\n    src_pos = np.array([[4, 4, 1.5]])  # middle of the room\n    mic_pos = np.array([[2, 2, 1.5], [2, 6, 1.5], [6, 2, 1.5],\n                        [6, 6, 1.5]]  # mic 1, 2\n                       )  # mic 3, 4\n    mic_pattern = \"sub_cardioid\"\n    mic_orientation_rad = (\n        np.array([\n            [180, 90],\n            [0, 90],  # mic 1 (negative x axis), 2 (positive x axis)\n            [180, 90],\n            [0, 90],\n        ]) / 180.0 * np.pi)  # mic 3 (negative x axis), 4 (positive x axis)\n\n    rir, rir_direct = FRAM_RIR(\n        mic_pos,\n        sr,\n        rt60,\n        room_dim=room_dim,\n        src_pos=src_pos,\n        mic_pattern=mic_pattern,\n        mic_orientation_rad=mic_orientation_rad,\n    )\n    return rir, rir_direct\n"
  },
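  {
    "path": "wesep/dataset/FRAM_RIR_usage_example.py",
    "content": "# Minimal usage sketch for FRAM_RIR with a single microphone and a single\n# source. All geometry values are illustrative placeholders; see the\n# docstring of FRAM_RIR in wesep/dataset/FRAM_RIR.py for the full argument\n# description.\nimport numpy as np\n\nfrom wesep.dataset.FRAM_RIR import FRAM_RIR\n\n\ndef main():\n    mic_pos = np.array([[2.0, 3.0, 1.5]])  # [num_mic, 3] in meters\n    src_pos = np.array([[4.0, 4.0, 1.6]])  # [num_src, 3] in meters\n    rir, early_rir = FRAM_RIR(\n        mic_pos,\n        sr=16000,\n        T60=0.4,\n        room_dim=[6.0, 8.0, 3.0],\n        src_pos=src_pos,\n    )\n    # Both returned arrays have shape [num_mic, num_src, rir_length].\n    print(rir.shape, early_rir.shape)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },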
  {
    "path": "wesep/dataset/dataset.py",
    "content": "# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)\n#               2023 Shuai Wang (wsstriving@gmail.com)\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as tf\nfrom torch.utils.data import IterableDataset\n\nimport wesep.dataset.processor as processor\nfrom wesep.utils.file_utils import read_lists\n\n\nclass Processor(IterableDataset):\n\n    def __init__(self, source, f, *args, **kw):\n        assert callable(f)\n        self.source = source\n        self.f = f\n        self.args = args\n        self.kw = kw\n\n    def set_epoch(self, epoch):\n        self.source.set_epoch(epoch)\n\n    def __iter__(self):\n        \"\"\"Return an iterator over the source dataset processed by the\n        given processor.\n        \"\"\"\n        assert self.source is not None\n        assert callable(self.f)\n        return self.f(iter(self.source), *self.args, **self.kw)\n\n    def apply(self, f):\n        assert callable(f)\n        return Processor(self, f, *self.args, **self.kw)\n\n\nclass DistributedSampler:\n\n    def __init__(self, shuffle=True, partition=True):\n        self.epoch = -1\n        self.update()\n        self.shuffle = shuffle\n        self.partition = partition\n\n    def update(self):\n        assert dist.is_available()\n        if dist.is_initialized():\n            self.rank = dist.get_rank()\n            self.world_size = dist.get_world_size()\n        else:\n            self.rank = 0\n            self.world_size = 1\n        worker_info = torch.utils.data.get_worker_info()\n        if worker_info is None:\n            self.worker_id = 0\n            self.num_workers = 1\n        else:\n            self.worker_id = worker_info.id\n            self.num_workers = worker_info.num_workers\n        return dict(\n            rank=self.rank,\n            world_size=self.world_size,\n            worker_id=self.worker_id,\n            num_workers=self.num_workers,\n        )\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n\n    def sample(self, data):\n        \"\"\"Sample data according to rank/world_size/num_workers\n\n        Args:\n            data(List): input data list\n\n        Returns:\n            List: data list after sample\n        \"\"\"\n        data = list(range(len(data)))\n        if len(data) <= self.num_workers:\n            if self.shuffle:\n                random.Random(self.epoch).shuffle(data)\n        else:\n            if self.partition:\n                if self.shuffle:\n                    random.Random(self.epoch).shuffle(data)\n                data = data[self.rank::self.world_size]\n            data = data[self.worker_id::self.num_workers]\n        return data\n\n\nclass DataList(IterableDataset):\n\n    def __init__(self,\n                 lists,\n                 shuffle=True,\n                 partition=True,\n                 repeat_dataset=False):\n        self.lists = lists\n        self.repeat_dataset = 
repeat_dataset\n        self.sampler = DistributedSampler(shuffle, partition)\n\n    def set_epoch(self, epoch):\n        self.sampler.set_epoch(epoch)\n\n    def __iter__(self):\n        sampler_info = self.sampler.update()\n        indexes = self.sampler.sample(self.lists)\n        if not self.repeat_dataset:\n            for index in indexes:\n                data = dict(src=self.lists[index])\n                data.update(sampler_info)\n                yield data\n        else:\n            indexes_len = len(indexes)\n            counter = 0\n            while True:\n                index = indexes[counter % indexes_len]\n                counter += 1\n                data = dict(src=self.lists[index])\n                data.update(sampler_info)\n                yield data\n\n\ndef tse_collate_fn_2spk(batch, mode=\"min\"):\n    # Warning: hard-coded for 2 speakers, will be deprecated in the future,\n    # use tse_collate_fn instead\n    new_batch = {}\n\n    wav_mix = []\n    wav_targets = []\n    spk_embeds = []\n    spk = []\n    key = []\n    spk_label = []\n    length_spk_embeds = []\n    for s in batch:\n        wav_mix.append(s[\"wav_mix\"])\n        wav_targets.append(s[\"wav_spk1\"])\n        spk.append(s[\"spk1\"])\n        key.append(s[\"key\"])\n        spk_embeds.append(torch.from_numpy(s[\"embed_spk1\"].copy()))\n        length_spk_embeds.append(spk_embeds[-1].shape[1])\n        if \"spk1_label\" in s.keys():\n            spk_label.append(s[\"spk1_label\"])\n\n        wav_mix.append(s[\"wav_mix\"])\n        wav_targets.append(s[\"wav_spk2\"])\n        spk.append(s[\"spk2\"])\n        key.append(s[\"key\"])\n        spk_embeds.append(torch.from_numpy(s[\"embed_spk2\"].copy()))\n        length_spk_embeds.append(spk_embeds[-1].shape[1])\n        if \"spk2_label\" in s.keys():\n            spk_label.append(s[\"spk2_label\"])\n\n    if not (len(set(length_spk_embeds)) == 1):\n        if mode == \"max\":\n            max_len = max(length_spk_embeds)\n            for i in range(len(length_spk_embeds)):\n                if len(spk_embeds[i].shape) == 2:\n                    spk_embeds[i] = tf.pad(\n                        spk_embeds[i],\n                        (0, max_len - length_spk_embeds[i]),\n                        \"constant\",\n                        0,\n                    )\n                elif len(spk_embeds[i].shape) == 3:\n                    spk_embeds[i] = tf.pad(\n                        spk_embeds[i],\n                        (0, 0, 0, max_len - length_spk_embeds[i]),\n                        \"constant\",\n                        0,\n                    )\n        if mode == \"min\":\n            min_len = min(length_spk_embeds)\n            for i in range(len(length_spk_embeds)):\n                if len(spk_embeds[i].shape) == 2:\n                    spk_embeds[i] = spk_embeds[i][:, :min_len]\n                elif len(spk_embeds[i].shape) == 3:\n                    spk_embeds[i] = spk_embeds[i][:, :min_len, :]\n\n    new_batch[\"wav_mix\"] = torch.concat(wav_mix)\n    new_batch[\"wav_targets\"] = torch.concat(wav_targets)\n    new_batch[\"spk_embeds\"] = torch.concat(spk_embeds)\n    new_batch[\"length_spk_embeds\"] = length_spk_embeds\n    new_batch[\"spk\"] = spk\n    new_batch[\"key\"] = key\n    new_batch[\"spk_label\"] = torch.as_tensor(spk_label)\n    return new_batch\n\n\ndef tse_collate_fn(batch, mode=\"min\"):\n    # This is a more generalizable implementation for target speaker extraction\n    # Support arbitrary number of speakers\n    new_batch = 
{}\n    wav_mix = []\n    wav_targets = []\n    spk_embeds = []\n    spk = []\n    key = []\n    spk_label = []\n    length_spk_embeds = []\n    for s in batch:\n        for i in range(s[\"num_speaker\"]):\n            wav_mix.append(s[\"wav_mix\"])\n            wav_targets.append(s[\"wav_spk{}\".format(i + 1)])\n            spk.append(s[\"spk{}\".format(i + 1)])\n            key.append(s[\"key\"])\n            spk_embeds.append(\n                torch.from_numpy(s[\"embed_spk{}\".format(i + 1)].copy()))\n            length_spk_embeds.append(spk_embeds[-1].shape[1])\n            if \"spk{}_label\".format(i + 1) in s.keys():\n                spk_label.append(s[\"spk{}_label\".format(i + 1)])\n\n    if not (len(set(length_spk_embeds)) == 1):\n        if mode == \"max\":\n            max_len = max(length_spk_embeds)\n            for i in range(len(length_spk_embeds)):\n                if len(spk_embeds[i].shape) == 2:\n                    spk_embeds[i] = tf.pad(\n                        spk_embeds[i],\n                        (0, max_len - length_spk_embeds[i]),\n                        \"constant\",\n                        0,\n                    )\n                elif len(spk_embeds[i].shape) == 3:\n                    spk_embeds[i] = tf.pad(\n                        spk_embeds[i],\n                        (0, 0, 0, max_len - length_spk_embeds[i]),\n                        \"constant\",\n                        0,\n                    )\n        if mode == \"min\":\n            min_len = min(length_spk_embeds)\n            for i in range(len(length_spk_embeds)):\n                if len(spk_embeds[i].shape) == 2:\n                    spk_embeds[i] = spk_embeds[i][:, :min_len]\n                elif len(spk_embeds[i].shape) == 3:\n                    spk_embeds[i] = spk_embeds[i][:, :min_len, :]\n\n    new_batch[\"wav_mix\"] = torch.concat(wav_mix)\n    new_batch[\"wav_targets\"] = torch.concat(wav_targets)\n    new_batch[\"spk_embeds\"] = torch.concat(spk_embeds)\n    new_batch[\"length_spk_embeds\"] = (\n        length_spk_embeds  # Not used, but maybe needed when using the enrollment utterance  # noqa\n    )\n    new_batch[\"spk\"] = spk\n    new_batch[\"key\"] = key\n    new_batch[\"spk_label\"] = torch.as_tensor(spk_label)\n    return new_batch\n\n\ndef Dataset(\n    data_type,\n    data_list_file,\n    configs,\n    spk2embed_dict=None,\n    spk1_embed=None,\n    spk2_embed=None,\n    state=\"train\",\n    joint_training=False,\n    dict_spk=None,\n    whole_utt=False,\n    repeat_dataset=False,\n    noise_prob=0,\n    reverb_prob=0,\n    noise_enroll_prob=0,\n    reverb_enroll_prob=0,\n    specaug_enroll_prob=0,\n    noise_lmdb_file=None,\n    online_mix=False,\n):\n    \"\"\"Construct dataset from arguments\n    We have two shuffle stage in the Dataset. The first is global\n    shuffle at shards tar/raw/feat file level. 
The second is local shuffle\n    at the training sample level.\n\n    Args:\n        :param spk2_embed:\n        :param online_mix:\n        :param spk1_embed:\n        :param data_type(str): shard/raw\n        :param data_list_file: data list file\n        :param configs: dataset configs\n        :param noise_prob: probability to add noise on mixture\n        :param reverb_prob: probability to add reverb on mixture\n        :param noise_enroll_prob: probability to add noise on enrollment speech\n        :param reverb_enroll_prob: probability to add reverb on enrollment speech  # noqa\n        :param specaug_enroll_prob: probability to apply SpecAug on fbank of enrollment speech  # noqa\n        :param noise_lmdb_file: noise data source lmdb file\n        :param whole_utt: use whole utt or random chunk\n        :param repeat_dataset:\n    \"\"\"\n    assert data_type in [\"shard\", \"raw\"]\n    lists = read_lists(data_list_file)\n    shuffle = configs.get(\"shuffle\", False)\n    # Global shuffle\n    dataset = DataList(lists, shuffle=shuffle, repeat_dataset=repeat_dataset)\n    if data_type == \"shard\":\n        dataset = Processor(dataset, processor.url_opener)\n        if not online_mix:\n            dataset = Processor(dataset, processor.tar_file_and_group)\n        else:\n            dataset = Processor(dataset,\n                                processor.tar_file_and_group_single_spk)\n    else:\n        dataset = Processor(dataset, processor.parse_raw)\n\n    if configs.get(\"filter_len\", False) and state == \"train\":\n        # Filter the data with unwanted length\n        filter_conf = configs.get(\"filter_args\", {})\n        dataset = Processor(dataset, processor.filter_len, **filter_conf)\n    # Local shuffle\n    if shuffle and not online_mix:\n        dataset = Processor(dataset, processor.shuffle,\n                            **configs[\"shuffle_args\"])\n\n    # resample\n    resample_rate = configs.get(\"resample_rate\", 16000)\n    dataset = Processor(dataset, processor.resample, resample_rate)\n\n    if not whole_utt:\n        # random chunk\n        chunk_len = configs.get(\"chunk_len\", resample_rate * 3)\n        dataset = Processor(dataset, processor.random_chunk, chunk_len)\n\n    if online_mix:\n        dataset = Processor(\n            dataset,\n            processor.mix_speakers,\n            configs.get(\"num_speakers\", 2),\n            configs.get(\"online_buffer_size\", 1000),\n        )\n        if reverb_prob > 0:\n            dataset = Processor(dataset, processor.add_reverb, reverb_prob)\n        dataset = Processor(\n            dataset,\n            processor.snr_mixer,\n            configs.get(\"use_random_snr\", False),\n        )\n    if noise_prob > 0:\n        assert noise_lmdb_file is not None\n        dataset = Processor(dataset, processor.add_noise, noise_lmdb_file,\n                            noise_prob)\n    speaker_feat = configs.get(\"speaker_feat\", False)\n    if state == \"train\":\n        if not joint_training:\n            dataset = Processor(dataset, processor.sample_spk_embedding,\n                                spk2embed_dict)\n        else:\n            dataset = Processor(dataset, processor.sample_enrollment,\n                                spk2embed_dict, dict_spk)\n            if reverb_enroll_prob > 0:\n                dataset = Processor(dataset, processor.add_reverb_on_enroll,\n                                    reverb_enroll_prob)\n            if noise_enroll_prob > 0:\n                assert noise_lmdb_file is not None\n                dataset = Processor(\n                    dataset,\n                    processor.add_noise_on_enroll,\n                    noise_lmdb_file,\n                    noise_enroll_prob,\n                )\n            if speaker_feat:\n                dataset = Processor(dataset, processor.compute_fbank,\n                                    **configs[\"fbank_args\"])\n                dataset = Processor(dataset, processor.apply_cmvn)\n                if specaug_enroll_prob > 0:\n                    dataset = Processor(dataset,\n                                        processor.spec_aug,\n                                        prob=specaug_enroll_prob)\n    else:\n        if not joint_training:\n            dataset = Processor(\n                dataset,\n                processor.sample_fix_spk_embedding,\n                spk2embed_dict,\n                spk1_embed,\n                spk2_embed,\n            )\n        else:\n            dataset = Processor(\n                dataset,\n                processor.sample_fix_spk_enrollment,\n                spk2embed_dict,\n                spk1_embed,\n                spk2_embed,\n                dict_spk,\n            )\n            if speaker_feat:\n                dataset = Processor(dataset, processor.compute_fbank,\n                                    **configs[\"fbank_args\"])\n                dataset = Processor(dataset, processor.apply_cmvn)\n\n    return dataset\n"
  },
  {
    "path": "wesep/dataset/lmdb_data.py",
    "content": "# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport pickle\r\nimport random\r\n\r\nimport lmdb\r\n\r\n\r\nclass LmdbData:\r\n\r\n    def __init__(self, lmdb_file):\r\n        self.db = lmdb.open(lmdb_file,\r\n                            readonly=True,\r\n                            lock=False,\r\n                            readahead=False)\r\n        with self.db.begin(write=False) as txn:\r\n            obj = txn.get(b\"__keys__\")\r\n            assert obj is not None\r\n            self.keys = pickle.loads(obj)\r\n            assert isinstance(self.keys, list)\r\n\r\n    def random_one(self):\r\n        assert len(self.keys) > 0\r\n        index = random.randint(0, len(self.keys) - 1)\r\n        key = self.keys[index]\r\n        with self.db.begin(write=False) as txn:\r\n            value = txn.get(key.encode())\r\n            assert value is not None\r\n        return key, value\r\n\r\n    def __del__(self):\r\n        self.db.close()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    import sys\r\n\r\n    db = LmdbData(sys.argv[1])\r\n    key, _ = db.random_one()\r\n    print(key)\r\n    key, _ = db.random_one()\r\n    print(key)\r\n"
  },
  {
    "path": "wesep/dataset/processor.py",
    "content": "import io\nimport json\nimport logging\nimport random\nimport tarfile\nfrom subprocess import PIPE, Popen\nfrom urllib.parse import urlparse\n\nimport librosa\nimport numpy as np\nimport soundfile as sf\nimport torch\nimport torchaudio\nimport torchaudio.compliance.kaldi as kaldi\nfrom scipy import signal\n\nfrom wesep.dataset.FRAM_RIR import single_channel as RIR_sim\nfrom wesep.dataset.lmdb_data import LmdbData\n\nAUDIO_FORMAT_SETS = {\"flac\", \"mp3\", \"m4a\", \"ogg\", \"opus\", \"wav\", \"wma\"}\n\n# set the simulation configuration\nsimu_config = {\n    \"min_max_room\": [[3, 3, 2.5], [10, 6, 4]],\n    \"rt60\": [0.1, 0.7],\n    \"sr\": 16000,\n    \"mic_dist\": [0.2, 5.0],\n    \"num_src\": 1,\n}\n\n\ndef url_opener(data):\n    \"\"\"Give url or local file, return file descriptor\n    Inplace operation.\n\n    Args:\n        data(Iterable[str]): url or local file list\n\n    Returns:\n        Iterable[{src, stream}]\n    \"\"\"\n    for sample in data:\n        assert \"src\" in sample\n        # TODO(Binbin Zhang): support HTTP\n        url = sample[\"src\"]\n        try:\n            pr = urlparse(url)\n            # local file\n            if pr.scheme == \"\" or pr.scheme == \"file\":\n                stream = open(url, \"rb\")\n            # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP\n            else:\n                cmd = f\"wget -q -O - {url}\"\n                process = Popen(cmd, shell=True, stdout=PIPE)\n                sample.update(process=process)\n                stream = process.stdout\n            sample.update(stream=stream)\n            yield sample\n        except Exception as ex:\n            logging.warning(\"Failed to open {}\".format(url))\n\n\ndef tar_file_and_group(data):\n    \"\"\"Expand a stream of open tar files into a stream of tar file contents.\n    And groups the file with same prefix\n\n    Args:\n        data: Iterable[{src, stream}]\n\n    Returns:\n        Iterable[{key, mix_wav, spk1_wav, spk2_wav, ..., sample_rate}]\n    \"\"\"\n    for sample in data:\n        assert \"stream\" in sample\n        stream = tarfile.open(fileobj=sample[\"stream\"], mode=\"r:*\")\n        # TODO: The mode need to be validated\n        # In order to be compatible with the torch 2.x version,\n        # the file reading method here does not use streaming.\n        prev_prefix = None\n        example = {}\n        num_speakers = 0\n        valid = True\n        for tarinfo in stream:\n            name = tarinfo.name\n            pos = name.rfind(\".\")\n            assert pos > 0\n            prefix, postfix = name[:pos], name[pos + 1:]\n            if prev_prefix is not None and prev_prefix not in prefix:\n                example[\"key\"] = prev_prefix\n                if valid:\n                    example[\"num_speaker\"] = num_speakers\n                    num_speakers = 0\n                    yield example\n                example = {}\n                valid = True\n            with stream.extractfile(tarinfo) as file_obj:\n                try:\n                    if \"spk\" in postfix:\n                        example[postfix] = (\n                            file_obj.read().decode(\"utf8\").strip())\n                        num_speakers += 1\n                    elif postfix in AUDIO_FORMAT_SETS:\n                        waveform, sample_rate = torchaudio.load(file_obj)\n                        if prefix[-5:-1] == \"_spk\":\n                            example[\"wav\" + prefix[-5:]] = waveform\n                            prefix = 
prefix[:-5]\n                        else:\n                            example[\"wav_mix\"] = waveform\n                            example[\"sample_rate\"] = sample_rate\n                    else:\n                        example[postfix] = file_obj.read()\n                except Exception as ex:\n                    valid = False\n                    logging.warning(\"error to parse {}\".format(name))\n            prev_prefix = prefix\n\n        if prev_prefix is not None:\n            example[\"key\"] = prev_prefix\n            example[\"num_speaker\"] = num_speakers\n            num_speakers = 0\n            yield example\n        stream.close()\n        if \"process\" in sample:\n            sample[\"process\"].communicate()\n        sample[\"stream\"].close()\n\n\ndef tar_file_and_group_single_spk(data):\n    \"\"\"Expand a stream of open tar files into a stream of tar file contents.\n    And groups the file with same prefix\n\n    Args:\n        data: Iterable[{src, stream}]\n\n    Returns:\n        Iterable[{key, wav, spk, sample_rate}]\n    \"\"\"\n    for sample in data:\n        assert \"stream\" in sample\n        stream = tarfile.open(fileobj=sample[\"stream\"],\n                              mode=\"r|*\")  # Only support pytorch version <2.0\n        prev_prefix = None\n        example = {}\n        valid = True\n        for tarinfo in stream:\n            name = tarinfo.name\n            pos = name.rfind(\".\")\n            assert pos > 0\n            prefix, postfix = name[:pos], name[pos + 1:]\n            if prev_prefix is not None and prefix != prev_prefix:\n                example[\"key\"] = prev_prefix\n                if valid:\n                    yield example\n                example = {}\n                valid = True\n            with stream.extractfile(tarinfo) as file_obj:\n                try:\n                    if postfix in [\"spk\"]:\n                        example[postfix] = (\n                            file_obj.read().decode(\"utf8\").strip())\n                    elif postfix in AUDIO_FORMAT_SETS:\n                        waveform, sample_rate = torchaudio.load(file_obj)\n                        example[\"wav\"] = waveform\n                        example[\"sample_rate\"] = sample_rate\n                    else:\n                        example[postfix] = file_obj.read()\n                except Exception as ex:\n                    valid = False\n                    logging.warning(\"error to parse {}\".format(name))\n            prev_prefix = prefix\n        if prev_prefix is not None:\n            example[\"key\"] = prev_prefix\n            yield example\n        stream.close()\n        if \"process\" in sample:\n            sample[\"process\"].communicate()\n        sample[\"stream\"].close()\n\n\ndef parse_raw_single_spk(data):\n    \"\"\"Parse key/wav/spk from json line\n\n    Args:\n        data: Iterable[str], str is a json line has key/wav/spk\n\n    Returns:\n        Iterable[{key, wav, spk, sample_rate}]\n    \"\"\"\n    for sample in data:\n        assert \"src\" in sample\n        json_line = sample[\"src\"]\n        obj = json.loads(json_line)\n        assert \"key\" in obj\n        assert \"wav\" in obj\n        assert \"spk\" in obj\n        key = obj[\"key\"]\n        wav_file = obj[\"wav\"]\n        spk = obj[\"spk\"]\n        try:\n            waveform, sample_rate = torchaudio.load(wav_file)\n            example = dict(key=key,\n                           spk=spk,\n                           wav=waveform,\n                         
  sample_rate=sample_rate)\n            yield example\n        except Exception as ex:\n            logging.warning(\"Failed to read {}\".format(wav_file))\n\n\ndef mix_speakers(data, num_speaker=2, shuffle_size=1000):\n    \"\"\"Dynamic mixing speakers when loading data,\n    shuffle is not needed if this function is used\n    Args:\n        :param data: Iterable[{key, wavs, spks}]\n        :param num_speaker:\n        :param use_random_snr:\n        :param shuffle_size:\n    Returns:\n        Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n    \"\"\"\n    buf = []\n    for sample in data:\n        buf.append(sample)\n        if len(buf) >= shuffle_size:\n            random.shuffle(buf)\n            for x in buf:\n                cur_spk = x[\"spk\"]\n                example = {\n                    \"key\": x[\"key\"],\n                    \"wav_spk1\": x[\"wav\"],\n                    \"spk1\": x[\"spk\"],\n                    \"sample_rate\": x[\"sample_rate\"],\n                }\n                key = \"mix_\" + x[\"key\"]\n                interference_idx = 1\n                while interference_idx < num_speaker:\n                    interference = random.choice(buf)\n                    while interference[\"spk\"] == cur_spk:\n                        interference = random.choice(buf)\n                    key = key + \"_\" + interference[\"key\"]\n                    interference_idx += 1\n                    example[\"wav_spk\" +\n                            str(interference_idx)] = interference[\"wav\"]\n                    example[\"spk\" +\n                            str(interference_idx)] = interference[\"spk\"]\n                example[\"key\"] = key\n                example[\"num_speaker\"] = num_speaker\n                yield example\n\n            buf = []\n\n    # The samples left over\n    random.shuffle(buf)\n    for x in buf:\n        cur_spk = x[\"spk\"]\n        example = {\n            \"key\": x[\"key\"],\n            \"wav_spk1\": x[\"wav\"],\n            \"spk1\": x[\"spk\"],\n            \"sample_rate\": x[\"sample_rate\"],\n        }\n        key = \"mix_\" + x[\"key\"]\n        interference_idx = 1\n        while interference_idx < num_speaker:\n            interference = random.choice(buf)\n            while interference[\"spk\"] == cur_spk:\n                interference = random.choice(buf)\n            key = key + \"_\" + interference[\"key\"]\n            interference_idx += 1\n            example[\"wav_spk\" + str(interference_idx)] = interference[\"wav\"]\n            example[\"spk\" + str(interference_idx)] = interference[\"spk\"]\n        example[\"key\"] = key\n        example[\"num_speaker\"] = num_speaker\n        yield example\n\n\ndef snr_mixer(data, use_random_snr: bool = False):\n    \"\"\"Dynamic mixing speakers when loading data, shuffle is not needed if this function is used.  # noqa\n\n    Args:\n        data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n        use_random_snr (bool, optional): Whether use random SNR to mix speeches. Defaults to False.  
# noqa\n\n    Returns:\n        Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n    \"\"\"\n    for sample in data:\n        assert \"num_speaker\" in sample.keys()\n        if \"wav_spk1_reverb\" in sample.keys():\n            suffix = \"_reverb\"\n        else:\n            suffix = \"\"\n        num_speaker = sample[\"num_speaker\"]\n        wavs_to_mix = [sample[\"wav_spk1\" + suffix]]\n        target_energy = torch.sum(wavs_to_mix[0]**2, dim=-1, keepdim=True)\n        for i in range(1, num_speaker):\n            interference = sample[f\"wav_spk{i + 1}\" + suffix]\n            if use_random_snr:\n                snr = random.uniform(-10, 10)\n            else:\n                snr = 0\n            energy = torch.sum(interference**2, dim=-1, keepdim=True)\n            interference *= torch.sqrt(target_energy / energy) * 10**(snr / 20)\n            wavs_to_mix.append(interference)\n        wavs_to_mix = torch.stack(wavs_to_mix)\n        sample[\"wav_mix\"] = torch.sum(wavs_to_mix, 0)\n        max_amp = max(\n            torch.abs(sample[\"wav_mix\"]).max().item(),\n            *[x.item() for x in torch.abs(wavs_to_mix).max(dim=-1)[0]],\n        )\n        if max_amp != 0:\n            mix_scaling = 1 / max_amp\n        else:\n            mix_scaling = 1\n\n        sample[\"wav_mix\"] = sample[\"wav_mix\"] * mix_scaling\n        for i in range(0, num_speaker):\n            sample[f\"wav_spk{i + 1}\" + suffix] *= mix_scaling\n\n        yield sample\n\n\ndef shuffle(data, shuffle_size=2500):\n    \"\"\"Local shuffle the data\n\n    Args:\n        data: Iterable[{key, wavs, spks}]\n        shuffle_size: buffer size for shuffle\n\n    Returns:\n        Iterable[{key, wavs, spks}]\n    \"\"\"\n    buf = []\n    for sample in data:\n        buf.append(sample)\n        if len(buf) >= shuffle_size:\n            random.shuffle(buf)\n            for x in buf:\n                yield x\n            buf = []\n    # The sample left over\n    random.shuffle(buf)\n    for x in buf:\n        yield x\n\n\ndef spk_to_id(data, spk2id):\n    \"\"\"Parse spk id\n\n    Args:\n        data: Iterable[{key, wav/feat, spk}]\n        spk2id: Dict[str, int]\n\n    Returns:\n        Iterable[{key, wav/feat, label}]\n    \"\"\"\n    for sample in data:\n        assert \"spk\" in sample\n        if sample[\"spk\"] in spk2id:\n            label = spk2id[sample[\"spk\"]]\n        else:\n            label = -1\n        sample[\"label\"] = label\n        yield sample\n\n\ndef resample(data, resample_rate=16000):\n    \"\"\"Resample data.\n    Inplace operation.\n    Args:\n        data: Iterable[{key, wavs, spks, sample_rate}]\n        resample_rate: target resample rate\n    Returns:\n        Iterable[{key, wavs, spks, sample_rate}]\n    \"\"\"\n    for sample in data:\n        assert \"sample_rate\" in sample\n        sample_rate = sample[\"sample_rate\"]\n        if sample_rate != resample_rate:\n            all_keys = list(sample.keys())\n            sample[\"sample_rate\"] = resample_rate\n            for key in all_keys:\n                if \"wav\" in key:\n                    waveform = sample[key]\n                    sample[key] = torchaudio.transforms.Resample(\n                        orig_freq=sample_rate,\n                        new_freq=resample_rate)(waveform)\n        yield sample\n\n\ndef sample_spk_embedding(data, spk_embeds):\n    \"\"\"sample reference speaker embeddings for the target speaker\n    Args:\n        data: Iterable[{key, wav, label, sample_rate}]\n        spk_embeds: 
dict which stores all potential embeddings for the speaker\n    Returns:\n        Iterable[{key, wav, label, sample_rate}]\n    \"\"\"\n    for sample in data:\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"spk\"):\n                sample[\"embed_\" + key] = random.choice(spk_embeds[sample[key]])\n        yield sample\n\n\ndef sample_fix_spk_embedding(data, spk2embed_dict, spk1_embed, spk2_embed):\n    \"\"\"sample reference speaker embeddings for the target speaker\n    Args:\n        data: Iterable[{key, wav, label, sample_rate}]\n        spk_embeds: dict which stores all potential embeddings for the speaker\n    Returns:\n        Iterable[{key, wav, label, sample_rate}]\n    \"\"\"\n    for sample in data:\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"spk\"):\n                if key == \"spk1\":\n                    sample[\"embed_\" +\n                           key] = spk2embed_dict[spk1_embed[sample[\"key\"]]]\n                else:\n                    sample[\"embed_\" +\n                           key] = spk2embed_dict[spk2_embed[sample[\"key\"]]]\n        yield sample\n\n\ndef sample_enrollment(data, spk_embeds, dict_spk):\n    \"\"\"sample reference speech for the target speaker\n    Args:\n        data: Iterable[{key, wav, label, sample_rate}]\n        spk_embeds: dict which stores all potential enrollment utterance files(/.wav) for the speaker  # noqa\n        dict_spk: dict of speakers in the enrollment sets [Order: spkID]\n    Returns:\n        Iterable[{key, wav, label, sample_rate, spk_embed(raw waveform of enrollment),  # noqa\n                  spk_lable(when multi-task training)}]\n    \"\"\"\n    for sample in data:\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"spk\"):\n                enrollment, _ = sf.read(\n                    random.choice(spk_embeds[sample[key]])[1])\n                sample[\"embed_\" + key] = np.expand_dims(enrollment, axis=0)\n                if dict_spk:\n                    sample[key + \"_label\"] = dict_spk[sample[key]]\n        yield sample\n\n\ndef sample_fix_spk_enrollment(data,\n                              spk2embed_dict,\n                              spk1_embed,\n                              spk2_embed,\n                              dict_spk=None):\n    \"\"\"sample reference speaker embeddings for the target speaker\n    Args:\n        data: Iterable[{key, wav, label, sample_rate}]\n        spk_embeds: dict which stores all potential enrollment utterance files(/.wav) for the speaker  # noqa\n        dict_spk: dict of speakers in the enrollment sets [Order: spkID]\n    Returns:\n        Iterable[{key, wav, label, sample_rate, spk_embed(raw waveform of enrollment),  # noqa\n                  spk_lable(when multi-task training)}]\n    \"\"\"\n    for sample in data:\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"spk\"):\n                if key == \"spk1\":\n                    enrollment, _ = sf.read(\n                        spk2embed_dict[spk1_embed[sample[\"key\"]]])\n                else:\n                    enrollment, _ = sf.read(\n                        spk2embed_dict[spk2_embed[sample[\"key\"]]])\n                sample[\"embed_\" + key] = np.expand_dims(enrollment, axis=0)\n                if dict_spk:\n                    sample[key + \"_label\"] = dict_spk[sample[key]]\n        yield 
sample\n\n\ndef compute_fbank(data,\n                  num_mel_bins=80,\n                  frame_length=25,\n                  frame_shift=10,\n                  dither=1.0):\n    \"\"\"Extract fbank\n\n    Args:\n        data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']  # noqa\n\n    Returns:\n        Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']  # noqa\n    \"\"\"\n    for sample in data:\n        assert \"sample_rate\" in sample\n        sample_rate = sample[\"sample_rate\"]\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"embed\"):\n                waveform = torch.from_numpy(sample[key])\n                waveform = waveform * (1 << 15)\n                mat = kaldi.fbank(\n                    waveform,\n                    num_mel_bins=num_mel_bins,\n                    frame_length=frame_length,\n                    frame_shift=frame_shift,\n                    dither=dither,\n                    sample_frequency=sample_rate,\n                    window_type=\"hamming\",\n                    use_energy=False,\n                )\n                sample[key] = mat\n        yield sample\n\n\ndef apply_cmvn(data, norm_mean=True, norm_var=False):\n    \"\"\"Apply CMVN\n\n    Args:\n        data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']  # noqa\n\n    Returns:\n        Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']  # noqa\n    \"\"\"\n    for sample in data:\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"embed\"):\n                mat = sample[key]\n                if norm_mean:\n                    mat = mat - torch.mean(mat, dim=0)\n                if norm_var:\n                    mat = mat / torch.sqrt(torch.var(mat, dim=0) + 1e-8)\n                mat = mat.unsqueeze(0)\n                sample[key] = mat.detach().numpy()\n        yield sample\n\n\ndef get_random_chunk(data_list, chunk_len):\n    \"\"\"Get random chunk\n\n    Args:\n        data_list: [torch.Tensor: 1XT] (random len)\n        chunk_len: chunk length\n\n    Returns:\n        [torch.Tensor] (exactly chunk_len)\n    \"\"\"\n    # Assert all entries in the list share the same length\n    assert False not in [len(i) == len(data_list[0]) for i in data_list]\n    data_list = [data[0] for data in data_list]\n\n    data_len = len(data_list[0])\n\n    # random chunk\n    if data_len >= chunk_len:\n        chunk_start = random.randint(0, data_len - chunk_len)\n        for i in range(len(data_list)):\n            temp_data = data_list[i][chunk_start:chunk_start + chunk_len]\n            while torch.equal(temp_data, torch.zeros_like(temp_data)):\n                chunk_start = random.randint(0, data_len - chunk_len)\n                temp_data = data_list[i][chunk_start:chunk_start + chunk_len]\n            data_list[i] = temp_data\n            # re-clone the data to avoid memory leakage\n            if type(data_list[i]) == torch.Tensor:\n                data_list[i] = data_list[i].clone()\n            else:  # np.array\n                data_list[i] = data_list[i].copy()\n    else:\n        # padding\n        repeat_factor = chunk_len // data_len + 1\n        for i in range(len(data_list)):\n            if 
type(data_list[i]) == torch.Tensor:\n                data_list[i] = data_list[i].repeat(repeat_factor)\n            else:  # np.array\n                data_list[i] = np.tile(data_list[i], repeat_factor)\n            data_list[i] = data_list[i][:chunk_len]\n    data_list = [data.unsqueeze(0) for data in data_list]\n    return data_list\n\n\ndef filter_len(\n    data,\n    min_num_seconds=1,\n    max_num_seconds=1000,\n):\n    \"\"\"Filter the utterance with very short duration and random chunk the\n    utterance with very long duration.\n\n    Args:\n        data: Iterable[{key, wav, label, sample_rate}]\n        min_num_seconds: minimum number of seconds of wav file\n        max_num_seconds: maximum number of seconds of wav file\n    Returns:\n        Iterable[{key, wav, label, sample_rate}]\n    \"\"\"\n    for sample in data:\n        assert \"key\" in sample\n        assert \"sample_rate\" in sample\n        assert \"wav\" in sample\n        sample_rate = sample[\"sample_rate\"]\n        wav = sample[\"wav\"]\n        min_len = min_num_seconds * sample_rate\n        max_len = max_num_seconds * sample_rate\n        if wav.size(1) < min_len:\n            continue\n        elif wav.size(1) > max_len:\n            wav = get_random_chunk([wav], max_len)[0]\n        sample[\"wav\"] = wav\n        yield sample\n\n\ndef random_chunk(data, chunk_len):\n    \"\"\"Random chunk the data into chunk_len\n\n    Args:\n        data: Iterable[{key, wav/feat, label}]\n        chunk_len: chunk length for each sample\n\n    Returns:\n        Iterable[{key, wav/feat, label}]\n    \"\"\"\n    for sample in data:\n        assert \"key\" in sample\n        wav_keys = [key for key in list(sample.keys()) if \"wav\" in key]\n        wav_data_list = [sample[key] for key in wav_keys]\n        wav_data_list = get_random_chunk(wav_data_list, chunk_len)\n        sample.update(zip(wav_keys, wav_data_list))\n        yield sample\n\n\ndef fix_chunk(data, chunk_len):\n    \"\"\"Random chunk the data into chunk_len\n\n    Args:\n        data: Iterable[{key, wav/feat, label}]\n        chunk_len: chunk length for each sample\n\n    Returns:\n        Iterable[{key, wav/feat, label}]\n    \"\"\"\n    for sample in data:\n        assert \"key\" in sample\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"wav\"):\n                sample[key] = sample[key][:, :chunk_len]\n        yield sample\n\n\ndef add_noise(\n    data,\n    noise_lmdb_file,\n    noise_prob: float = 0.0,\n    noise_db_low: int = -5,\n    noise_db_high: int = 25,\n    single_channel: bool = True,\n):\n    \"\"\"Add noise to mixture\n\n    Args:\n        data: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n        noise_lmdb_file: noise LMDB data source.\n        noise_db_low (int, optional): SNR lower bound. Defaults to -5.\n        noise_db_high (int, optional): SNR upper bound. Defaults to 25.\n        single_channel (bool, optional): Whether to force the noise file to be single channel.  
# noqa\n                                         Defaults to True.\n\n    Returns:\n        Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ..., noise, snr}]  # noqa\n    \"\"\"\n    noise_source = LmdbData(noise_lmdb_file)\n    for sample in data:\n        if noise_prob > random.random():\n            assert \"sample_rate\" in sample.keys()\n            tgt_fs = sample[\"sample_rate\"]\n            speech = sample[\"wav_mix\"].numpy()  # [1, nsamples]\n            nsamples = speech.shape[1]\n            power = (speech**2).mean()\n            noise_key, noise_data = noise_source.random_one()\n            if noise_key.startswith(\n                    \"speech\"):  # using interference speech as additive noise\n                snr_range = [10, 30]\n            else:\n                snr_range = [noise_db_low, noise_db_high]\n            noise_db = np.random.uniform(snr_range[0], snr_range[1])\n            with sf.SoundFile(io.BytesIO(noise_data)) as f:\n                fs = f.samplerate\n                if tgt_fs and fs != tgt_fs:\n                    nsamples_ = int(nsamples / tgt_fs * fs) + 1\n                else:\n                    nsamples_ = nsamples\n                if f.frames == nsamples_:\n                    noise = f.read(dtype=np.float64, always_2d=True)\n                elif f.frames < nsamples_:\n                    offset = np.random.randint(0, nsamples_ - f.frames)\n                    # noise: (Time, Nmic)\n                    noise = f.read(dtype=np.float64, always_2d=True)\n                    # Repeat noise\n                    noise = np.pad(\n                        noise,\n                        [(offset, nsamples_ - f.frames - offset), (0, 0)],\n                        mode=\"wrap\",\n                    )\n                else:\n                    offset = np.random.randint(0, f.frames - nsamples_)\n                    f.seek(offset)\n                    # noise: (Time, Nmic)\n                    noise = f.read(nsamples_, dtype=np.float64, always_2d=True)\n                    if len(noise) != nsamples_:\n                        raise RuntimeError(\n                            f\"Something wrong: {noise_lmdb_file}\")\n\n            if single_channel:\n                num_ch = noise.shape[1]\n                chs = [np.random.randint(num_ch)]\n                noise = noise[:, chs]\n            # noise: (Nmic, Time)\n            noise = noise.T\n            if tgt_fs and fs != tgt_fs:\n                logging.warning(\n                    f\"Resampling noise to match the sampling rate ({fs} -> {tgt_fs} Hz)\"  # noqa\n                )\n                noise = librosa.resample(noise,\n                                         orig_sr=fs,\n                                         target_sr=tgt_fs,\n                                         res_type=\"kaiser_fast\")\n                if noise.shape[1] < nsamples:\n                    noise = np.pad(\n                        noise,\n                        [(0, 0), (0, nsamples - noise.shape[1])],\n                        mode=\"wrap\",\n                    )\n                else:\n                    noise = noise[:, :nsamples]\n            noise_power = (noise**2).mean()\n            scale = (10**(-noise_db / 20) * np.sqrt(power) /\n                     np.sqrt(max(noise_power, 1e-10)))\n            scaled_noise = scale * noise\n            speech = speech + scaled_noise\n            sample[\"wav_mix\"] = torch.from_numpy(speech)\n            sample[\"noise\"] = torch.from_numpy(scaled_noise)\n            
sample[\"snr\"] = noise_db\n        yield sample\n\n\ndef add_reverb(data, reverb_prob=0):\n    \"\"\"\n    Args:\n        data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n\n    Returns:\n        Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n\n    Note: This function is implemented with reference to\n    Fast Random Appoximation of Multi-channel Room Impulse Response (FRAM-RIR)\n    https://arxiv.org/pdf/2304.08052\n        This function is only used when online_mixing.\n    \"\"\"\n    for sample in data:\n        assert \"num_speaker\" in sample.keys()\n        assert \"sample_rate\" in sample.keys()\n        simu_config[\"num_src\"] = sample[\"num_speaker\"]\n        simu_config[\"sr\"] = sample[\"sample_rate\"]\n        rirs, _ = RIR_sim(simu_config)  # [n_mic, nsource, nsamples]\n        rirs = rirs[0]  # [nsource, nsamples]\n\n        for i in range(sample[\"num_speaker\"]):\n            if reverb_prob > random.random():\n                # [1, audio_len], currently only support single-channel audio\n                audio = sample[f\"wav_spk{i + 1}\"].numpy()\n                rir = rirs[i:i + 1, :]  # [1, nsamples]\n                rir_audio = signal.convolve(\n                    audio, rir,\n                    mode=\"full\")[:, :audio.shape[1]]  # [1, audio_len]\n\n                max_scale = np.max(np.abs(rir_audio))\n                out_audio = rir_audio / max_scale * 0.9\n                # Note: Here, we do not replace the dry audio with the reverberant audio,  # noqa\n                # which means we hope the model to perform dereverberation and\n                # TSE simultaneously.\n                sample[f\"wav_spk{i + 1}\"] = torch.from_numpy(out_audio)\n        yield sample\n\n\ndef add_noise_on_enroll(\n    data,\n    noise_lmdb_file,\n    noise_enroll_prob: float = 0.0,\n    noise_db_low: int = 0,\n    noise_db_high: int = 25,\n    single_channel: bool = True,\n):\n    \"\"\"Add noise to mixture\n\n    Args:\n        data: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n        noise_lmdb_file: noise LMDB data source.\n        noise_db_low (int, optional): SNR lower bound. Defaults to 0.\n        noise_db_high (int, optional): SNR upper bound. Defaults to 25.\n        single_channel (bool, optional): Whether to force the noise file to be single channel.  
# noqa\n                                         Defaults to True.\n\n    Returns:\n        Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ..., noise, snr}]  # noqa\n    \"\"\"\n\n    noise_source = LmdbData(noise_lmdb_file)\n    for sample in data:\n        assert \"sample_rate\" in sample.keys()\n        tgt_fs = sample[\"sample_rate\"]\n        all_keys = list(sample.keys())\n        for key in all_keys:\n            if key.startswith(\"spk\") and \"label\" not in key:\n                if noise_enroll_prob > random.random():\n                    speech = sample[\"embed_\" + key]\n                    nsamples = speech.shape[1]\n                    power = (speech**2).mean()\n                    noise_key, noise_data = noise_source.random_one()\n                    if noise_key.startswith(\n                            \"speech\"\n                    ):  # using interference speech as additive noise\n                        snr_range = [10, 30]\n                    else:\n                        snr_range = [noise_db_low, noise_db_high]\n                    noise_db = np.random.uniform(snr_range[0], snr_range[1])\n                    _, noise_data = noise_source.random_one()\n                    with sf.SoundFile(io.BytesIO(noise_data)) as f:\n                        fs = f.samplerate\n                        if tgt_fs and fs != tgt_fs:\n                            nsamples_ = int(nsamples / tgt_fs * fs) + 1\n                        else:\n                            nsamples_ = nsamples\n                        if f.frames == nsamples_:\n                            noise = f.read(dtype=np.float64, always_2d=True)\n                        elif f.frames < nsamples_:\n                            offset = np.random.randint(0, nsamples_ - f.frames)\n                            # noise: (Time, Nmic)\n                            noise = f.read(dtype=np.float64, always_2d=True)\n                            # Repeat noise\n                            noise = np.pad(\n                                noise,\n                                [\n                                    (offset, nsamples_ - f.frames - offset),\n                                    (0, 0),\n                                ],\n                                mode=\"wrap\",\n                            )\n                        else:\n                            offset = np.random.randint(0, f.frames - nsamples_)\n                            f.seek(offset)\n                            # noise: (Time, Nmic)\n                            noise = f.read(nsamples_,\n                                           dtype=np.float64,\n                                           always_2d=True)\n                            if len(noise) != nsamples_:\n                                raise RuntimeError(\n                                    f\"Something wrong: {noise_lmdb_file}\")\n\n                    if single_channel:\n                        num_ch = noise.shape[1]\n                        chs = [np.random.randint(num_ch)]\n                        noise = noise[:, chs]\n                    # noise: (Nmic, Time)\n                    noise = noise.T\n                    if tgt_fs and fs != tgt_fs:\n                        logging.warning(\n                            f\"Resampling noise to match the sampling rate ({fs} -> {tgt_fs} Hz)\"  # noqa\n                        )\n                        noise = librosa.resample(\n                            noise,\n                            orig_sr=fs,\n                            
target_sr=tgt_fs,\n                            res_type=\"kaiser_fast\",\n                        )\n                        if noise.shape[1] < nsamples:\n                            noise = np.pad(\n                                noise,\n                                [(0, 0), (0, nsamples - noise.shape[1])],\n                                mode=\"wrap\",\n                            )\n                        else:\n                            noise = noise[:, :nsamples]\n                    noise_power = (noise**2).mean()\n                    scale = (10**(-noise_db / 20) * np.sqrt(power) /\n                             np.sqrt(max(noise_power, 1e-10)))\n                    scaled_noise = scale * noise\n                    speech = speech + scaled_noise\n                    sample[\"embed_\" + key] = speech\n        yield sample\n\n\ndef add_reverb_on_enroll(data, reverb_enroll_prob=0):\n    \"\"\"\n    Args:\n        data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n\n    Returns:\n        Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]\n\n    \"\"\"\n    for sample in data:\n        assert \"num_speaker\" in sample.keys()\n        assert \"sample_rate\" in sample.keys()\n        for i in range(sample[\"num_speaker\"]):\n            simu_config[\"sr\"] = sample[\"sample_rate\"]\n            simu_config[\"num_src\"] = 1\n            rirs, _ = RIR_sim(simu_config)  # [n_mic, nsource, nsamples]\n            rirs = rirs[0]  # [nsource, nsamples]\n            if reverb_enroll_prob > random.random():\n                # [1, audio_len], currently only support single-channel audio\n                audio = sample[f\"embed_spk{i+1}\"]\n                # rir = rirs[i : i + 1, :]  # [1, nsamples]\n                rir = rirs\n                rir_audio = signal.convolve(\n                    audio, rir,\n                    mode=\"full\")[:, :audio.shape[1]]  # [1, audio_len]\n\n                max_scale = np.max(np.abs(rir_audio))\n                out_audio = rir_audio / max_scale * 0.9\n                # Note: Here, we do not replace the dry audio with the reverberant audio,  # noqa\n                # which means we hope the model to perform dereverberation and\n                # TSE simultaneously.\n                sample[f\"embed_spk{i+1}\"] = out_audio\n\n        yield sample\n\n\ndef spec_aug(data, num_t_mask=1, num_f_mask=1, max_t=10, max_f=8, prob=0):\n    \"\"\"Do spec augmentation\n    Inplace operation\n\n    Args:\n        data: Iterable[{key, feat, label}]\n        num_t_mask: number of time mask to apply\n        num_f_mask: number of freq mask to apply\n        max_t: max width of time mask\n        max_f: max width of freq mask\n        prob: prob of spec_aug\n\n    Returns\n        Iterable[{key, feat, label}]\n    \"\"\"\n    for sample in data:\n        if random.random() < prob:\n            all_keys = list(sample.keys())\n            for key in all_keys:\n                if key.startswith(\"embed\"):\n                    y = sample[key]\n                    max_frames = y.shape[1]\n                    max_freq = y.shape[2]\n                    # time mask\n                    for i in range(num_t_mask):\n                        start = random.randint(0, max_frames - 1)\n                        length = random.randint(1, max_t)\n                        end = min(max_frames, start + length)\n                        y[:, start:end, :] = 0\n                    # freq mask\n                    for i in range(num_f_mask):\n                        start = 
random.randint(0, max_freq - 1)\n                        length = random.randint(1, max_f)\n                        end = min(max_freq, start + length)\n                        y[:, :, start:end] = 0\n                    sample[key] = y\n        yield sample\n"
  },
  {
    "path": "wesep/dataset/vad.py",
    "content": "import numpy as np\nimport soundfile as sf\n\n\nclass VoiceActivityDetection:\n\n    def __init__(self, wave):\n        self.wave = wave\n\n    def segmentation(self, overlap, slice_len):\n        frequency = 16000\n        signal = self.wave\n        self.seg_len = len(signal) / frequency\n        self.slice_len = slice_len\n        overlap = 2\n\n        slices = np.arange(0, self.seg_len, slice_len - overlap, dtype=np.intc)\n        # print(slices)\n        audio_slices = []\n        for start, end in zip(slices[:-1], slices[1:]):\n            start_audio = start * frequency\n            end_audio = (end + overlap) * frequency\n            audio_slice = signal[int(start_audio):int(end_audio)]\n            # print(len(audio_slice))\n            audio_slices.append(audio_slice)\n\n            # wavfile.write('slices{}.wav'.format(start), 16000, audio_slice)\n        # print(len(audio_slices))\n        return audio_slices\n\n    def calc_energy(self, audio):\n        # for a in enumerate(audio):\n        #     if (a == 0.0):\n        #         a = 0.00001\n\n        # print(np.sum(np.sum(audio**2)))\n\n        energy = audio / np.sum(np.sum(audio**2) + 1e-8) * 1e2\n        # print(len(audio))\n        return energy\n\n    def select(self):\n        audio_slices = self.segmentation(overlap=1, slice_len=4)\n        energies = []\n        for audio in audio_slices:\n            chunk_len = len(audio) / 10\n            chunk_slice = np.arange(0,\n                                    len(audio) + chunk_len,\n                                    chunk_len,\n                                    dtype=np.intc)\n\n            for start, end in zip(chunk_slice[:-1], chunk_slice[1:]):\n\n                energy = self.calc_energy(audio[start:end])\n                # print(energy)\n                for i, _ in enumerate(energy):\n                    if (energy[i]) == 0:\n                        energy[i] = 0.00001\n                        # print(energy[i])\n                energies.append(sum(energy))\n\n        # print(energies)\n\n        threshold = np.quantile(energies, 0.25)\n        print(threshold)\n\n        if threshold < 0.0001:\n            threshold = 0.0001\n\n        fin_audios = []\n        i = 0\n        for audio in audio_slices:\n            chunk_len = len(audio) / 10\n            chunk_slice = np.arange(0,\n                                    len(audio) + chunk_len,\n                                    chunk_len,\n                                    dtype=np.intc)\n            count = 0\n            for start, end in zip(chunk_slice[:-1], chunk_slice[1:]):\n                energy = self.calc_energy(audio[start:end])\n                # if 50% enenrgy > threshold\n                # print(energy)\n                print(sum(i >= threshold for i in energy))\n                if sum(i >= threshold for i in energy) >= chunk_len // 2:\n                    count += 1\n                # save seg\n            # print(count)\n            if count >= 5:\n                sf.write(\"output{}.wav\".format(i), audio, 16000)\n                if len(audio) < self.slice_len * 16000:\n                    # print(self.slice_len*16000-len(audio))\n                    audio = np.concatenate(\n                        [audio,\n                         np.zeros(self.slice_len * 16000 - len(audio))])\n                fin_audios.append(audio)\n\n            i += 1\n\n        if len(fin_audios) == 0:\n            fin_audios.append(np.zeros(self.slice_len * 16000))\n        return fin_audios\n"
  },
  {
    "path": "wesep/models/__init__.py",
    "content": "import wesep.models.bsrnn as bsrnn\nimport wesep.models.convtasnet as convtasnet\nimport wesep.models.dpccn as dpccn\nimport wesep.models.tfgridnet as tfgridnet\nimport wesep.modules.metric_gan.discriminator as discriminator\nimport wesep.models.bsrnn_multi_optim as bsrnn_multi\nimport wesep.models.bsrnn_feats as bsrnn_feats\n\n\ndef get_model(model_name: str):\n    if model_name.startswith(\"ConvTasNet\"):\n        return getattr(convtasnet, model_name)\n    elif model_name.startswith(\"BSRNN_Multi\"):\n        return getattr(bsrnn_multi, model_name)\n    elif model_name.startswith(\"BSRNN_Feats\"):\n        return getattr(bsrnn_feats, model_name)\n    elif model_name.startswith(\"BSRNN\"):\n        return getattr(bsrnn, model_name)\n    elif model_name.startswith(\"DPCCN\"):\n        return getattr(dpccn, model_name)\n    elif model_name.startswith(\"TFGridNet\"):\n        return getattr(tfgridnet, model_name)\n    elif model_name.startswith(\"CMGAN\"):\n        return getattr(discriminator, model_name)\n    else:  # model_name error !!!\n        print(model_name + \" not found !!!\")\n        exit(1)\n"
  },
  {
    "path": "wesep/models/bsrnn.py",
    "content": "from __future__ import print_function\r\n\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn as nn\r\nimport torchaudio\r\nfrom wespeaker.models.speaker_model import get_speaker_model\r\n\r\nfrom wesep.modules.common.speaker import PreEmphasis\r\nfrom wesep.modules.common.speaker import SpeakerFuseLayer\r\nfrom wesep.modules.common.speaker import SpeakerTransform\r\n\r\n\r\nclass ResRNN(nn.Module):\r\n\r\n    def __init__(self, input_size, hidden_size, bidirectional=True):\r\n        super(ResRNN, self).__init__()\r\n\r\n        self.input_size = input_size\r\n        self.hidden_size = hidden_size\r\n        self.eps = torch.finfo(torch.float32).eps\r\n\r\n        self.norm = nn.GroupNorm(1, input_size, self.eps)\r\n        self.rnn = nn.LSTM(\r\n            input_size,\r\n            hidden_size,\r\n            1,\r\n            batch_first=True,\r\n            bidirectional=bidirectional,\r\n        )\r\n\r\n        # linear projection layer\r\n        self.proj = nn.Linear(hidden_size * 2,\r\n                              input_size)  # hidden_size = feature_dim * 2\r\n\r\n    def forward(self, input):\r\n        # input shape: batch, dim, seq\r\n\r\n        rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())\r\n        rnn_output = self.proj(rnn_output.contiguous().view(\r\n            -1, rnn_output.shape[2])).view(input.shape[0], input.shape[2],\r\n                                           input.shape[1])\r\n\r\n        return input + rnn_output.transpose(1, 2).contiguous()\r\n\r\n\r\n\"\"\"\r\nTODO : attach the speaker embedding to each input\r\nInput shape:(B,feature_dim + spk_emb_dim , T)\r\n\"\"\"\r\n\r\n\r\nclass BSNet(nn.Module):\r\n\r\n    def __init__(self, in_channel, nband=7, bidirectional=True):\r\n        super(BSNet, self).__init__()\r\n\r\n        self.nband = nband\r\n        self.feature_dim = in_channel // nband\r\n        self.band_rnn = ResRNN(self.feature_dim,\r\n                               self.feature_dim * 2,\r\n                               bidirectional=bidirectional)\r\n        self.band_comm = ResRNN(self.feature_dim,\r\n                                self.feature_dim * 2,\r\n                                bidirectional=bidirectional)\r\n\r\n    def forward(self, input, dummy: Optional[torch.Tensor] = None):\r\n        # input shape: B, nband*N, T\r\n        B, N, T = input.shape\r\n\r\n        band_output = self.band_rnn(\r\n            input.view(B * self.nband, self.feature_dim,\r\n                       -1)).view(B, self.nband, -1, T)\r\n\r\n        # band comm\r\n        band_output = (band_output.permute(0, 3, 2, 1).contiguous().view(\r\n            B * T, -1, self.nband))\r\n        output = (self.band_comm(band_output).view(\r\n            B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous())\r\n\r\n        return output.view(B, N, T)\r\n\r\n\r\nclass FuseSeparation(nn.Module):\r\n\r\n    def __init__(\r\n        self,\r\n        nband=7,\r\n        num_repeat=6,\r\n        feature_dim=128,\r\n        spk_emb_dim=256,\r\n        spk_fuse_type=\"concat\",\r\n        multi_fuse=True,\r\n    ):\r\n        \"\"\"\r\n\r\n        :param nband : len(self.band_width)\r\n        \"\"\"\r\n        super(FuseSeparation, self).__init__()\r\n        self.multi_fuse = multi_fuse\r\n        self.nband = nband\r\n        self.feature_dim = feature_dim\r\n        self.separation = nn.ModuleList([])\r\n        if self.multi_fuse:\r\n            for _ in range(num_repeat):\r\n       
         self.separation.append(\r\n                    SpeakerFuseLayer(\r\n                        embed_dim=spk_emb_dim,\r\n                        feat_dim=feature_dim,\r\n                        fuse_type=spk_fuse_type,\r\n                    ))\r\n                self.separation.append(BSNet(nband * feature_dim, nband))\r\n        else:\r\n            self.separation.append(\r\n                SpeakerFuseLayer(\r\n                    embed_dim=spk_emb_dim,\r\n                    feat_dim=feature_dim,\r\n                    fuse_type=spk_fuse_type,\r\n                ))\r\n            for _ in range(num_repeat):\r\n                self.separation.append(BSNet(nband * feature_dim, nband))\r\n\r\n    def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):\r\n        \"\"\"\r\n        x: [B, nband, feature_dim, T]\r\n        out: [B, nband, feature_dim, T]\r\n        \"\"\"\r\n        batch_size = x.shape[0]\r\n\r\n        if self.multi_fuse:\r\n            for i, sep_func in enumerate(self.separation):\r\n                x = sep_func(x, spk_embedding)\r\n                if i % 2 == 0:\r\n                    x = x.view(batch_size * nch, self.nband * self.feature_dim,\r\n                               -1)\r\n                else:\r\n                    x = x.view(batch_size * nch, self.nband, self.feature_dim,\r\n                               -1)\r\n        else:\r\n            x = self.separation[0](x, spk_embedding)\r\n            x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)\r\n            for idx, sep in enumerate(self.separation):\r\n                if idx > 0:\r\n                    x = sep(x, spk_embedding)\r\n            x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)\r\n        return x\r\n\r\n\r\nclass BSRNN(nn.Module):\r\n    # self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6,\r\n    # use_bidirectional=True\r\n    def __init__(\r\n        self,\r\n        spk_emb_dim=256,\r\n        sr=16000,\r\n        win=512,\r\n        stride=128,\r\n        feature_dim=128,\r\n        num_repeat=6,\r\n        use_spk_transform=True,\r\n        use_bidirectional=True,\r\n        spk_fuse_type=\"concat\",\r\n        multi_fuse=True,\r\n        joint_training=True,\r\n        multi_task=False,\r\n        spksInTrain=251,\r\n        spk_model=None,\r\n        spk_model_init=None,\r\n        spk_model_freeze=False,\r\n        spk_args=None,\r\n        spk_feat=False,\r\n        feat_type=\"consistent\",\r\n    ):\r\n        super(BSRNN, self).__init__()\r\n\r\n        self.sr = sr\r\n        self.win = win\r\n        self.stride = stride\r\n        self.group = self.win // 2\r\n        self.enc_dim = self.win // 2 + 1\r\n        self.feature_dim = feature_dim\r\n        self.eps = torch.finfo(torch.float32).eps\r\n        self.spk_emb_dim = spk_emb_dim\r\n        self.joint_training = joint_training\r\n        self.spk_feat = spk_feat\r\n        self.feat_type = feat_type\r\n        self.spk_model_freeze = spk_model_freeze\r\n        self.multi_task = multi_task\r\n\r\n        # 0-1k (100 hop), 1k-4k (250 hop),\r\n        # 4k-8k (500 hop), 8k-16k (1k hop),\r\n        # 16k-20k (2k hop), 20k-inf\r\n\r\n        # 0-8k (1k hop), 8k-16k (2k hop), 16k\r\n        bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * 
self.enc_dim))\r\n\r\n        # add up to 8k\r\n        self.band_width = [bandwidth_100] * 15\r\n        self.band_width += [bandwidth_200] * 10\r\n        self.band_width += [bandwidth_500] * 5\r\n        self.band_width += [bandwidth_2k] * 1\r\n\r\n        self.band_width.append(self.enc_dim - int(np.sum(self.band_width)))\r\n        self.nband = len(self.band_width)\r\n\r\n        if use_spk_transform:\r\n            self.spk_transform = SpeakerTransform()\r\n        else:\r\n            self.spk_transform = nn.Identity()\r\n\r\n        if joint_training:\r\n            self.spk_model = get_speaker_model(spk_model)(**spk_args)\r\n            if spk_model_init:\r\n                pretrained_model = torch.load(spk_model_init)\r\n                state = self.spk_model.state_dict()\r\n                for key in state.keys():\r\n                    if key in pretrained_model.keys():\r\n                        state[key] = pretrained_model[key]\r\n                        # print(key)\r\n                    else:\r\n                        print(\"not %s loaded\" % key)\r\n                self.spk_model.load_state_dict(state)\r\n            if spk_model_freeze:\r\n                for param in self.spk_model.parameters():\r\n                    param.requires_grad = False\r\n            if not spk_feat:\r\n                if feat_type == \"consistent\":\r\n                    self.preEmphasis = PreEmphasis()\r\n                    self.spk_encoder = torchaudio.transforms.MelSpectrogram(\r\n                        sample_rate=sr,\r\n                        n_fft=win,\r\n                        win_length=win,\r\n                        hop_length=stride,\r\n                        f_min=20,\r\n                        window_fn=torch.hamming_window,\r\n                        n_mels=spk_args[\"feat_dim\"],\r\n                    )\r\n            else:\r\n                self.preEmphasis = nn.Identity()\r\n                self.spk_encoder = nn.Identity()\r\n\r\n            if multi_task:\r\n                self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)\r\n            else:\r\n                self.pred_linear = nn.Identity()\r\n\r\n        self.BN = nn.ModuleList([])\r\n        for i in range(self.nband):\r\n            self.BN.append(\r\n                nn.Sequential(\r\n                    nn.GroupNorm(1, self.band_width[i] * 2, self.eps),\r\n                    nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),\r\n                ))\r\n\r\n        self.separator = FuseSeparation(\r\n            nband=self.nband,\r\n            num_repeat=num_repeat,\r\n            feature_dim=feature_dim,\r\n            spk_emb_dim=spk_emb_dim,\r\n            spk_fuse_type=spk_fuse_type,\r\n            multi_fuse=multi_fuse,\r\n        )\r\n\r\n        # self.proj =  nn.Linear(hidden_size*2, input_size)\r\n\r\n        self.mask = nn.ModuleList([])\r\n        for i in range(self.nband):\r\n            self.mask.append(\r\n                nn.Sequential(\r\n                    nn.GroupNorm(1, self.feature_dim,\r\n                                 torch.finfo(torch.float32).eps),\r\n                    nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),\r\n                    nn.Tanh(),\r\n                    nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),\r\n                    nn.Tanh(),\r\n                    nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),\r\n                ))\r\n\r\n    def pad_input(self, input, window, stride):\r\n        \"\"\"\r\n        Zero-padding 
input according to window/stride size.\r\n        \"\"\"\r\n        batch_size, nsample = input.shape\r\n\r\n        # pad the signals at the end for matching the window/stride size\r\n        rest = window - (stride + nsample % window) % window\r\n        if rest > 0:\r\n            pad = torch.zeros(batch_size, rest).type(input.type())\r\n            input = torch.cat([input, pad], 1)\r\n        pad_aux = torch.zeros(batch_size, stride).type(input.type())\r\n        input = torch.cat([pad_aux, input, pad_aux], 1)\r\n\r\n        return input, rest\r\n\r\n    def forward(self, input, embeddings):\r\n        # input shape: (B, C, T)\r\n\r\n        wav_input = input\r\n        spk_emb_input = embeddings\r\n        batch_size, nsample = wav_input.shape\r\n        nch = 1\r\n\r\n        # frequency-domain separation\r\n        spec = torch.stft(\r\n            wav_input,\r\n            n_fft=self.win,\r\n            hop_length=self.stride,\r\n            window=torch.hann_window(self.win).to(wav_input.device).type(\r\n                wav_input.type()),\r\n            return_complex=True,\r\n        )\r\n\r\n        # concat real and imag, split to subbands\r\n        spec_RI = torch.stack([spec.real, spec.imag], 1)  # B*nch, 2, F, T\r\n        subband_spec = []\r\n        subband_mix_spec = []\r\n        band_idx = 0\r\n        for i in range(len(self.band_width)):\r\n            subband_spec.append(spec_RI[:, :, band_idx:band_idx +\r\n                                        self.band_width[i]].contiguous())\r\n            subband_mix_spec.append(spec[:, band_idx:band_idx +\r\n                                         self.band_width[i]])  # B*nch, BW, T\r\n            band_idx += self.band_width[i]\r\n\r\n        # normalization and bottleneck\r\n        subband_feature = []\r\n        for i, bn_func in enumerate(self.BN):\r\n            subband_feature.append(\r\n                bn_func(subband_spec[i].view(batch_size * nch,\r\n                                             self.band_width[i] * 2, -1)))\r\n        subband_feature = torch.stack(subband_feature, 1)  # B, nband, N, T\r\n        # print(subband_feature.size(), spk_emb_input.size())\r\n\r\n        predict_speaker_lable = torch.tensor(0.0).to(\r\n            spk_emb_input.device)  # dummy\r\n        if self.joint_training:\r\n            if not self.spk_feat:\r\n                if self.feat_type == \"consistent\":\r\n                    with torch.no_grad():\r\n                        spk_emb_input = self.preEmphasis(spk_emb_input)\r\n                        spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8\r\n                        spk_emb_input = spk_emb_input.log()\r\n                        spk_emb_input = spk_emb_input - torch.mean(\r\n                            spk_emb_input, dim=-1, keepdim=True)\r\n                        spk_emb_input = spk_emb_input.permute(0, 2, 1)\r\n\r\n            tmp_spk_emb_input = self.spk_model(spk_emb_input)\r\n            if isinstance(tmp_spk_emb_input, tuple):\r\n                spk_emb_input = tmp_spk_emb_input[-1]\r\n            else:\r\n                spk_emb_input = tmp_spk_emb_input\r\n            predict_speaker_lable = self.pred_linear(spk_emb_input)\r\n\r\n        spk_embedding = self.spk_transform(spk_emb_input)\r\n        spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)\r\n\r\n        sep_output = self.separator(subband_feature, spk_embedding,\r\n                                    torch.tensor(nch))\r\n\r\n        sep_subband_spec = []\r\n        for i, mask_func in 
enumerate(self.mask):\r\n            this_output = mask_func(sep_output[:, i]).view(\r\n                batch_size * nch, 2, 2, self.band_width[i], -1)\r\n            this_mask = this_output[:, 0] * torch.sigmoid(\r\n                this_output[:, 1])  # B*nch, 2, K, BW, T\r\n            this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T\r\n            this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T\r\n            est_spec_real = (subband_mix_spec[i].real * this_mask_real -\r\n                             subband_mix_spec[i].imag * this_mask_imag\r\n                             )  # B*nch, BW, T\r\n            est_spec_imag = (subband_mix_spec[i].real * this_mask_imag +\r\n                             subband_mix_spec[i].imag * this_mask_real\r\n                             )  # B*nch, BW, T\r\n            sep_subband_spec.append(torch.complex(est_spec_real,\r\n                                                  est_spec_imag))\r\n        est_spec = torch.cat(sep_subband_spec, 1)  # B*nch, F, T\r\n        output = torch.istft(\r\n            est_spec.view(batch_size * nch, self.enc_dim, -1),\r\n            n_fft=self.win,\r\n            hop_length=self.stride,\r\n            window=torch.hann_window(self.win).to(wav_input.device).type(\r\n                wav_input.type()),\r\n            length=nsample,\r\n        )\r\n\r\n        output = output.view(batch_size, nch, -1)\r\n        s = torch.squeeze(output, dim=1)\r\n\r\n        return s, predict_speaker_lable\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    from thop import profile, clever_format\r\n\r\n    model = BSRNN(\r\n        spk_emb_dim=256,\r\n        sr=16000,\r\n        win=512,\r\n        stride=128,\r\n        feature_dim=128,\r\n        num_repeat=6,\r\n        spk_fuse_type=\"additive\",\r\n    )\r\n\r\n    s = 0\r\n    for param in model.parameters():\r\n        s += np.prod(param.size())\r\n    print(\"# of parameters: \" + str(s / 1024.0 / 1024.0))\r\n    x = torch.randn(4, 32000)\r\n    spk_embeddings = torch.randn(4, 256)\r\n    output = model(x, spk_embeddings)\r\n    print(output[0].shape)\r\n\r\n    macs, params = profile(model, inputs=(x, spk_embeddings))\r\n    macs, params = clever_format([macs, params], \"%.3f\")\r\n    print(macs, params)\r\n"
  },
  {
    "path": "wesep/models/bsrnn_feats.py",
    "content": "from __future__ import print_function\r\n\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torchaudio\r\nfrom wespeaker.models.speaker_model import get_speaker_model\r\n\r\nfrom wesep.modules.common.speaker import PreEmphasis\r\nfrom wesep.modules.common.speaker import SpeakerFuseLayer\r\nfrom wesep.modules.common.speaker import SpeakerTransform\r\nfrom wesep.utils.funcs import compute_fbank, apply_cmvn\r\n\r\n\r\nclass ResRNN(nn.Module):\r\n\r\n    def __init__(self, input_size, hidden_size, bidirectional=True):\r\n        super(ResRNN, self).__init__()\r\n\r\n        self.input_size = input_size\r\n        self.hidden_size = hidden_size\r\n        self.eps = torch.finfo(torch.float32).eps\r\n\r\n        self.norm = nn.GroupNorm(1, input_size, self.eps)\r\n        self.rnn = nn.LSTM(\r\n            input_size,\r\n            hidden_size,\r\n            1,\r\n            batch_first=True,\r\n            bidirectional=bidirectional,\r\n        )\r\n\r\n        # linear projection layer\r\n        self.proj = nn.Linear(hidden_size * 2,\r\n                              input_size)  # hidden_size = feature_dim * 2\r\n\r\n    def forward(self, input):\r\n        # input shape: batch, dim, seq\r\n\r\n        rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())\r\n        rnn_output = self.proj(rnn_output.contiguous().view(\r\n            -1, rnn_output.shape[2])).view(input.shape[0], input.shape[2],\r\n                                           input.shape[1])\r\n\r\n        return input + rnn_output.transpose(1, 2).contiguous()\r\n\r\n\r\n\"\"\"\r\nTODO : attach the speaker embedding to each input\r\nInput shape:(B,feature_dim + spk_emb_dim , T)\r\n\"\"\"\r\n\r\n\r\nclass BSNet(nn.Module):\r\n\r\n    def __init__(self, in_channel, nband=7, bidirectional=True):\r\n        super(BSNet, self).__init__()\r\n\r\n        self.nband = nband\r\n        self.feature_dim = in_channel // nband\r\n        self.band_rnn = ResRNN(self.feature_dim,\r\n                               self.feature_dim * 2,\r\n                               bidirectional=bidirectional)\r\n        self.band_comm = ResRNN(self.feature_dim,\r\n                                self.feature_dim * 2,\r\n                                bidirectional=bidirectional)\r\n\r\n    def forward(self, input, dummy: Optional[torch.Tensor] = None):\r\n        # input shape: B, nband*N, T\r\n        B, N, T = input.shape\r\n\r\n        band_output = self.band_rnn(\r\n            input.view(B * self.nband, self.feature_dim,\r\n                       -1)).view(B, self.nband, -1, T)\r\n\r\n        # band comm\r\n        band_output = (band_output.permute(0, 3, 2, 1).contiguous().view(\r\n            B * T, -1, self.nband))\r\n        output = (self.band_comm(band_output).view(\r\n            B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous())\r\n\r\n        return output.view(B, N, T)\r\n\r\nclass CrossAtt(nn.Module):\r\n    def __init__(self, embed_dim, num_heads, *args, **kwargs):\r\n        super(CrossAtt, self).__init__()\r\n        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads,\r\n                                                    *args, **kwargs)\r\n\r\n    def forward(self, query, key, value):\r\n        if query.dim() == 4:\r\n            spk_embeddings = []\r\n            for i in range(query.shape[1]):\r\n                x = query[:, i, :, :].squeeze(dim=1)  # (batch, feature, time)\r\n       
         x, _ = self.multihead_attn(x.transpose(1, 2),\r\n                                           key.transpose(1, 2),\r\n                                           value.transpose(1, 2))\r\n                spk_embeddings.append(x.transpose(1, 2))\r\n            spk_embeddings = torch.stack(spk_embeddings, dim=1)\r\n        elif query.dim() == 3:\r\n            x, _ = self.multihead_attn(query.transpose(1, 2),\r\n                                       key.transpose(1, 2),\r\n                                       value.transpose(1, 2))\r\n            spk_embeddings = x.transpose(1, 2)\r\n        return spk_embeddings\r\n\r\nclass FuseSeparation(nn.Module):\r\n\r\n    def __init__(\r\n        self,\r\n        nband=7,\r\n        num_repeat=6,\r\n        feature_dim=128,\r\n        spk_emb_dim=256,\r\n        spk_fuse_type=\"concat\",\r\n        multi_fuse=True,\r\n    ):\r\n        \"\"\"\r\n\r\n        :param nband : len(self.band_width)\r\n        \"\"\"\r\n        super(FuseSeparation, self).__init__()\r\n        self.spk_fuse_type = spk_fuse_type\r\n        self.multi_fuse = multi_fuse\r\n        self.nband = nband\r\n        self.feature_dim = feature_dim\r\n\r\n        self.attenFuse = nn.ModuleList([])\r\n        if spk_fuse_type and spk_fuse_type.startswith(\"cross_\"):\r\n            spk_emb_frame_dim = 512     # Ecapa_TDNN\r\n            spk_emb_dim = feature_dim\r\n            self.attenFuse.append(nn.Linear(spk_emb_frame_dim, feature_dim))\r\n            self.attenFuse.append(CrossAtt(embed_dim=feature_dim, num_heads=2,\r\n                                           batch_first=True))\r\n\r\n        self.separation = nn.ModuleList([])\r\n        if self.multi_fuse and self.spk_fuse_type:\r\n            for _ in range(num_repeat):\r\n                self.separation.append(\r\n                    SpeakerFuseLayer(\r\n                        embed_dim=spk_emb_dim,\r\n                        feat_dim=feature_dim,\r\n                        fuse_type=spk_fuse_type.removeprefix(\"cross_\"),\r\n                    ))\r\n                self.separation.append(BSNet(nband * feature_dim, nband))\r\n        else:\r\n            if self.spk_fuse_type:\r\n                self.separation.append(\r\n                    SpeakerFuseLayer(\r\n                        embed_dim=spk_emb_dim,\r\n                        feat_dim=feature_dim,\r\n                        fuse_type=spk_fuse_type.removeprefix(\"cross_\"),\r\n                    ))\r\n            for _ in range(num_repeat):\r\n                self.separation.append(BSNet(nband * feature_dim, nband))\r\n\r\n    def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):\r\n        \"\"\"\r\n        x: [B, nband, feature_dim, T]\r\n        out: [B, nband, feature_dim, T]\r\n        \"\"\"\r\n        batch_size = x.shape[0]\r\n\r\n        if self.spk_fuse_type and self.spk_fuse_type.startswith('cross_'):\r\n            spk_embedding = spk_embedding.transpose(1, 2)\r\n            spk_embedding = self.attenFuse[0](spk_embedding)\r\n            spk_embedding = spk_embedding.transpose(1, 2)\r\n            spk_embedding = self.attenFuse[1](x, spk_embedding, spk_embedding)\r\n\r\n        if self.multi_fuse and self.spk_fuse_type:\r\n            for i, sep_func in enumerate(self.separation):\r\n                x = sep_func(x, spk_embedding)\r\n                if i % 2 == 0:\r\n                    x = x.view(batch_size * nch, self.nband * self.feature_dim,\r\n                               -1)\r\n                else:\r\n               
     x = x.view(batch_size * nch, self.nband, self.feature_dim,\r\n                               -1)\r\n                    if self.spk_fuse_type.startswith('cross_'):\r\n                        spk_embedding = spk_embedding.transpose(1, 2)\r\n                        spk_embedding = self.attenFuse[0](spk_embedding)\r\n                        spk_embedding = spk_embedding.transpose(1, 2)\r\n                        spk_embedding = self.attenFuse[1](x, spk_embedding,\r\n                                                          spk_embedding)\r\n        else:\r\n            idx_start = -1\r\n            if self.spk_fuse_type:\r\n                x = self.separation[0](x, spk_embedding)\r\n                idx_start += 1\r\n            x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)\r\n            for idx, sep in enumerate(self.separation):\r\n                if idx > idx_start:\r\n                    x = sep(x, spk_embedding)\r\n            x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)\r\n        return x\r\n\r\n\r\nclass BSRNN_Feats(nn.Module):\r\n    # self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6,\r\n    # use_bidirectional=True\r\n    def __init__(\r\n        self,\r\n        spk_emb_dim=256,\r\n        sr=16000,\r\n        win=512,\r\n        stride=128,\r\n        feature_dim=128,\r\n        num_repeat=6,\r\n        use_spk_transform=False,\r\n        use_bidirectional=True,\r\n        spectral_feat=False,\r\n        spk_fuse_type=\"concat\",\r\n        multi_fuse=False,\r\n        joint_training=True,\r\n        multi_task=False,\r\n        spksInTrain=251,\r\n        spk_model=None,\r\n        spk_model_init=None,\r\n        spk_model_freeze=False,\r\n        spk_args=None,\r\n        spk_feat=False,\r\n        feat_type=\"consistent\",\r\n    ):\r\n        super(BSRNN_Feats, self).__init__()\r\n\r\n        self.sr = sr\r\n        self.win = win\r\n        self.stride = stride\r\n        self.group = self.win // 2\r\n        self.enc_dim = self.win // 2 + 1\r\n        self.feature_dim = feature_dim\r\n        self.eps = torch.finfo(torch.float32).eps\r\n        self.spk_emb_dim = spk_emb_dim\r\n        self.spk_fuse_type = spk_fuse_type\r\n        self.joint_training = joint_training\r\n        self.spk_feat = spk_feat\r\n        self.feat_type = feat_type\r\n        self.spk_model_freeze = spk_model_freeze\r\n        self.multi_task = multi_task\r\n\r\n        # 0-1k (100 hop), 1k-4k (250 hop),\r\n        # 4k-8k (500 hop), 8k-16k (1k hop),\r\n        # 16k-20k (2k hop), 20k-inf\r\n\r\n        # 0-8k (1k hop), 8k-16k (2k hop), 16k\r\n        bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim))\r\n\r\n        # add up to 8k\r\n        self.band_width = [bandwidth_100] * 15\r\n        self.band_width += [bandwidth_200] * 10\r\n        self.band_width += [bandwidth_500] * 5\r\n        self.band_width += [bandwidth_2k] * 1\r\n\r\n        self.band_width.append(self.enc_dim - int(np.sum(self.band_width)))\r\n        self.nband = len(self.band_width)\r\n\r\n        if use_spk_transform:\r\n            self.spk_transform = SpeakerTransform()\r\n        else:\r\n            self.spk_transform = nn.Identity()\r\n\r\n        if joint_training and (spk_fuse_type or spectral_feat == 'tfmap_emb'):\r\n            
self.spk_model = get_speaker_model(spk_model)(**spk_args)\r\n            if spk_model_init:\r\n                pretrained_model = torch.load(spk_model_init)\r\n                state = self.spk_model.state_dict()\r\n                for key in state.keys():\r\n                    if key in pretrained_model.keys():\r\n                        state[key] = pretrained_model[key]\r\n                        # print(key)\r\n                    else:\r\n                        print(\"not %s loaded\" % key)\r\n                self.spk_model.load_state_dict(state)\r\n            if spk_model_freeze:\r\n                for param in self.spk_model.parameters():\r\n                    param.requires_grad = False\r\n            if not spk_feat:\r\n                if feat_type == \"consistent\":\r\n                    self.preEmphasis = PreEmphasis()\r\n                    self.spk_encoder = torchaudio.transforms.MelSpectrogram(\r\n                        sample_rate=sr,\r\n                        n_fft=win,\r\n                        win_length=win,\r\n                        hop_length=stride,\r\n                        f_min=20,\r\n                        window_fn=torch.hamming_window,\r\n                        n_mels=spk_args[\"feat_dim\"],\r\n                    )\r\n            else:\r\n                self.preEmphasis = nn.Identity()\r\n                self.spk_encoder = nn.Identity()\r\n\r\n            if multi_task:\r\n                self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)\r\n            else:\r\n                self.pred_linear = nn.Identity()\r\n\r\n        spec_map = 2\r\n        if spectral_feat:\r\n            spec_map += 1\r\n        self.spectral_feat = spectral_feat\r\n        self.spec_map = spec_map\r\n\r\n        self.BN = nn.ModuleList([])\r\n        for i in range(self.nband):\r\n            self.BN.append(\r\n                nn.Sequential(\r\n                    nn.GroupNorm(1, self.band_width[i] * spec_map, self.eps),\r\n                    nn.Conv1d(self.band_width[i] * spec_map, self.feature_dim, 1),\r\n                ))\r\n\r\n        self.separator = FuseSeparation(\r\n            nband=self.nband,\r\n            num_repeat=num_repeat,\r\n            feature_dim=feature_dim,\r\n            spk_emb_dim=spk_emb_dim,\r\n            spk_fuse_type=spk_fuse_type,\r\n            multi_fuse=multi_fuse,\r\n        )\r\n\r\n        self.mask = nn.ModuleList([])\r\n        for i in range(self.nband):\r\n            self.mask.append(\r\n                nn.Sequential(\r\n                    nn.GroupNorm(1, self.feature_dim,\r\n                                 torch.finfo(torch.float32).eps),\r\n                    nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),\r\n                    nn.Tanh(),\r\n                    nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),\r\n                    nn.Tanh(),\r\n                    nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),\r\n                ))\r\n\r\n    def pad_input(self, input, window, stride):\r\n        \"\"\"\r\n        Zero-padding input according to window/stride size.\r\n        \"\"\"\r\n        batch_size, nsample = input.shape\r\n\r\n        # pad the signals at the end for matching the window/stride size\r\n        rest = window - (stride + nsample % window) % window\r\n        if rest > 0:\r\n            pad = torch.zeros(batch_size, rest).type(input.type())\r\n            input = torch.cat([input, pad], 1)\r\n        pad_aux = torch.zeros(batch_size, stride).type(input.type())\r\n     
   input = torch.cat([pad_aux, input, pad_aux], 1)\r\n\r\n        return input, rest\r\n\r\n    def forward(self, input, embeddings):\r\n        # input shape: (B, C, T)\r\n\r\n        wav_input = input\r\n        spk_emb_input = embeddings\r\n        batch_size, nsample = wav_input.shape\r\n        nch = 1\r\n\r\n        # frequency-domain separation\r\n        spec = torch.stft(\r\n            wav_input,\r\n            n_fft=self.win,\r\n            hop_length=self.stride,\r\n            window=torch.hann_window(self.win).to(wav_input.device).type(\r\n                wav_input.type()),\r\n            return_complex=True,\r\n        )\r\n\r\n        spec_RI = torch.stack([spec.real, spec.imag], 1)  # B*nch, 2, F, T\r\n\r\n        # Calculate the spectral level feature\r\n        if self.spectral_feat:\r\n            aux_c = torch.stft(\r\n                spk_emb_input,\r\n                n_fft=self.win,\r\n                hop_length=self.stride,\r\n                window=torch.hann_window(self.win).to(spk_emb_input.device).type(\r\n                    spk_emb_input.type()),\r\n                return_complex=True,\r\n            )  \r\n            if self.spectral_feat == 'tfmap_spec':\r\n                mix_mag_ori = torch.abs(spec)\r\n                enroll_mag = torch.abs(aux_c)\r\n\r\n                mix_mag = F.normalize(mix_mag_ori, p=2, dim=1)\r\n                enroll_mag = F.normalize(enroll_mag, p=2, dim=1)\r\n\r\n                mix_mag = mix_mag.permute(0, 2, 1).contiguous()\r\n                att_scores = torch.matmul(mix_mag, enroll_mag)\r\n                att_weights = F.softmax(att_scores, dim=-1)\r\n                enroll_mag = enroll_mag.permute(0, 2, 1).contiguous()\r\n                tf_map = torch.matmul(att_weights, enroll_mag)\r\n                tf_map = tf_map.permute(0, 2, 1).contiguous()\r\n\r\n                tf_map = tf_map / tf_map.norm(dim=1, keepdim=True)\r\n                # Recover the energy of estimated tfmap feature\r\n                tf_map = (\r\n                    torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True) \r\n                    * tf_map\r\n                )\r\n                # Another kind of nomalization for tf_map feature\r\n                # tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True)\r\n\r\n                spec_RI = torch.cat((spec_RI, tf_map.unsqueeze(1)), dim=1)\r\n\r\n            if self.spectral_feat == 'tfmap_emb':  # Only Ecapa-TDNN.\r\n                with torch.no_grad():\r\n                    signal_dim = wav_input.dim()\r\n                    extended_shape = (\r\n                        [1] * (3 - signal_dim) \r\n                        + list(wav_input.size())\r\n                    )\r\n                    pad = int(self.win // 2)\r\n                    mix_emb = F.pad(\r\n                        wav_input.view(extended_shape),\r\n                        [pad, pad],\r\n                        mode=\"reflect\"\r\n                    )\r\n                    mix_emb = mix_emb.view(mix_emb.shape[-signal_dim:])\r\n\r\n                    signal_dim = spk_emb_input.dim()\r\n                    extended_shape = (\r\n                        [1] * (3 - signal_dim) \r\n                        + list(spk_emb_input.size())\r\n                    )\r\n                    pad = int(self.win // 2)\r\n                    spk_emb = F.pad(\r\n                        spk_emb_input.view(extended_shape),\r\n                        [pad, pad],\r\n                        mode=\"reflect\"\r\n                    )\r\n               
     spk_emb = spk_emb.view(spk_emb.shape[-signal_dim:])\r\n\r\n                    spk_emb = compute_fbank(\r\n                        spk_emb, \r\n                        frame_length=self.win * 1e3 / self.sr,\r\n                        frame_shift=self.stride * 1e3 / self.sr,\r\n                        dither=0.0, \r\n                        sample_rate=self.sr\r\n                    )\r\n                    mix_emb = compute_fbank(\r\n                        mix_emb, \r\n                        frame_length=self.win * 1e3 / self.sr,\r\n                        frame_shift=self.stride * 1e3 / self.sr,\r\n                        dither=0.0, \r\n                        sample_rate=self.sr\r\n                    )\r\n                    mix_emb = apply_cmvn(mix_emb)\r\n                    spk_emb = apply_cmvn(spk_emb)\r\n\r\n                spk_emb = self.spk_model(spk_emb)\r\n                if isinstance(spk_emb, tuple):\r\n                    spk_emb_frame = spk_emb[0]\r\n                else:\r\n                    spk_emb_frame = spk_emb\r\n                mix_emb = self.spk_model(mix_emb)\r\n                if isinstance(mix_emb, tuple):\r\n                    mix_emb_frame = mix_emb[0]\r\n                else:\r\n                    mix_emb_frame = mix_emb\r\n\r\n                mix_emb_frame_ = F.normalize(mix_emb_frame, p=2, dim=1)\r\n                spk_emb_frame_ = F.normalize(spk_emb_frame, p=2, dim=1)\r\n\r\n                mix_emb_frame_ = mix_emb_frame_.transpose(1, 2)\r\n                att_scores = torch.matmul(mix_emb_frame_, spk_emb_frame_)\r\n                att_weights = F.softmax(att_scores, dim=-1)\r\n\r\n                mix_mag_ori = torch.abs(spec)\r\n                enroll_mag = torch.abs(aux_c)\r\n\r\n                enroll_mag = enroll_mag.transpose(1, 2)\r\n                # enroll_mag = F.normalize(enroll_mag, p=2, dim=1)\r\n                tf_map = torch.matmul(att_weights, enroll_mag)\r\n                tf_map = tf_map.transpose(1, 2)\r\n\r\n                tf_map = tf_map / tf_map.norm(dim=1, keepdim=True)\r\n                # Recover the energy of estimated tfmap feature\r\n                tf_map = (\r\n                    torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True) \r\n                    * tf_map\r\n                )\r\n                # Another kind of nomalization for tf_map feature\r\n                # tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True)\r\n\r\n                spec_RI = torch.cat((spec_RI, tf_map.unsqueeze(1)), dim=1)\r\n\r\n        # concat real and imag, split to subbands\r\n        subband_spec = []\r\n        subband_mix_spec = []\r\n        band_idx = 0\r\n        for i in range(len(self.band_width)):\r\n            subband_spec.append(spec_RI[:, :, band_idx:band_idx +\r\n                                        self.band_width[i]].contiguous())\r\n            subband_mix_spec.append(spec[:, band_idx:band_idx +\r\n                                         self.band_width[i]])  # B*nch, BW, T\r\n            band_idx += self.band_width[i]\r\n\r\n        # normalization and bottleneck\r\n        subband_feature = []\r\n        for i, bn_func in enumerate(self.BN):\r\n            subband_feature.append(\r\n                bn_func(subband_spec[i].view(batch_size * nch,\r\n                                             self.band_width[i] * self.spec_map,\r\n                                             -1)))\r\n        subband_feature = torch.stack(subband_feature, 1)  # B, nband, N, T\r\n        # print(subband_feature.size(), 
spk_emb_input.size())\r\n\r\n        predict_speaker_lable = torch.tensor(0.0).to(\r\n            spk_emb_input.device)  # dummy\r\n        if (\r\n            (self.spectral_feat and self.spectral_feat == \"tfmap_emb\")\r\n            and (self.spk_fuse_type and self.spk_fuse_type.startswith(\"cross_\"))\r\n        ):\r\n            spk_emb_input = spk_emb_frame\r\n        elif self.joint_training and self.spk_fuse_type:\r\n            if not self.spk_feat:\r\n                if self.feat_type == \"consistent\":\r\n                    with torch.no_grad():\r\n                        spk_emb_input = self.preEmphasis(spk_emb_input)\r\n                        spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8\r\n                        spk_emb_input = spk_emb_input.log()\r\n                        spk_emb_input = spk_emb_input - torch.mean(\r\n                            spk_emb_input, dim=-1, keepdim=True)\r\n                        spk_emb_input = spk_emb_input.permute(0, 2, 1)\r\n\r\n            if self.spk_fuse_type and self.spk_fuse_type.startswith(\"cross_\"):\r\n                tmp_spk_emb_input = self.spk_model._get_frame_level_feat(\r\n                    spk_emb_input)\r\n            else:\r\n                tmp_spk_emb_input = self.spk_model(spk_emb_input)\r\n            if isinstance(tmp_spk_emb_input, tuple):\r\n                spk_emb_input = tmp_spk_emb_input[-1]\r\n            else:\r\n                spk_emb_input = tmp_spk_emb_input\r\n            predict_speaker_lable = self.pred_linear(spk_emb_input)\r\n\r\n        spk_embedding = self.spk_transform(spk_emb_input)\r\n        if self.spk_fuse_type and not self.spk_fuse_type.startswith(\"cross_\"):\r\n            spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)\r\n\r\n        sep_output = self.separator(subband_feature, spk_embedding,\r\n                                    torch.tensor(nch))\r\n\r\n        sep_subband_spec = []\r\n        for i, mask_func in enumerate(self.mask):\r\n            this_output = mask_func(sep_output[:, i]).view(\r\n                batch_size * nch, 2, 2, self.band_width[i], -1)\r\n            this_mask = this_output[:, 0] * torch.sigmoid(\r\n                this_output[:, 1])  # B*nch, 2, K, BW, T\r\n            this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T\r\n            this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T\r\n            est_spec_real = (subband_mix_spec[i].real * this_mask_real -\r\n                             subband_mix_spec[i].imag * this_mask_imag\r\n                             )  # B*nch, BW, T\r\n            est_spec_imag = (subband_mix_spec[i].real * this_mask_imag +\r\n                             subband_mix_spec[i].imag * this_mask_real\r\n                             )  # B*nch, BW, T\r\n            sep_subband_spec.append(torch.complex(est_spec_real,\r\n                                                  est_spec_imag))\r\n        est_spec = torch.cat(sep_subband_spec, 1)  # B*nch, F, T\r\n        output = torch.istft(\r\n            est_spec.view(batch_size * nch, self.enc_dim, -1),\r\n            n_fft=self.win,\r\n            hop_length=self.stride,\r\n            window=torch.hann_window(self.win).to(wav_input.device).type(\r\n                wav_input.type()),\r\n            length=nsample,\r\n        )\r\n\r\n        output = output.view(batch_size, nch, -1)\r\n        s = torch.squeeze(output, dim=1)\r\n        return s, predict_speaker_lable\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    from thop import profile, clever_format\r\n\r\n    
model = BSRNN_Feats(\r\n        spk_emb_dim=256,\r\n        sr=16000,\r\n        win=512,\r\n        stride=128,\r\n        feature_dim=128,\r\n        num_repeat=6,\r\n        spectral_feat='tfmap_emb',\r\n        spk_fuse_type='cross_multiply',\r\n        spk_model=\"ECAPA_TDNN_GLOB_c512\",\r\n        spk_args={\r\n            \"embed_dim\": 192,\r\n            \"feat_dim\": 80,\r\n            \"pooling_func\": \"ASTP\",\r\n        }\r\n    )\r\n\r\n    s = 0\r\n    for param in model.parameters():\r\n        s += np.prod(param.size())\r\n    print(\"# of parameters: \" + str(s / 1024.0 / 1024.0))\r\n    x = torch.randn(4, 32000)\r\n    spk_embeddings = torch.randn(4, 16000)\r\n    output = model(x, spk_embeddings)\r\n    print(output[0].shape)\r\n\r\n    macs, params = profile(model, inputs=(x, spk_embeddings))\r\n    macs, params = clever_format([macs, params], \"%.3f\")\r\n    print(macs, params)\r\n"
  },
  {
    "path": "wesep/models/bsrnn_multi_optim.py",
    "content": "from __future__ import print_function\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn as nn\r\nimport torchaudio\r\nfrom wespeaker.models.speaker_model import get_speaker_model\r\n\r\nfrom wesep.modules.common.speaker import PreEmphasis\r\nfrom wesep.modules.common.speaker import SpeakerFuseLayer\r\nfrom wesep.modules.common.speaker import SpeakerTransform\r\n\r\n\r\nclass ResRNN(nn.Module):\r\n\r\n    def __init__(self, input_size, hidden_size, bidirectional=True):\r\n        super(ResRNN, self).__init__()\r\n\r\n        self.input_size = input_size\r\n        self.hidden_size = hidden_size\r\n        self.eps = torch.finfo(torch.float32).eps\r\n\r\n        self.norm = nn.GroupNorm(1, input_size, self.eps)\r\n        self.rnn = nn.LSTM(\r\n            input_size,\r\n            hidden_size,\r\n            1,\r\n            batch_first=True,\r\n            bidirectional=bidirectional,\r\n        )\r\n\r\n        # linear projection layer\r\n        self.proj = nn.Linear(\r\n            hidden_size * 2, input_size\r\n        )  # hidden_size = feature_dim * 2\r\n\r\n    def forward(self, input):\r\n        # input shape: batch, dim, seq\r\n\r\n        rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())\r\n        rnn_output = self.proj(\r\n            rnn_output.contiguous().view(-1, rnn_output.shape[2])\r\n        ).view(input.shape[0], input.shape[2], input.shape[1])\r\n\r\n        return input + rnn_output.transpose(1, 2).contiguous()\r\n\r\n\r\n\"\"\"\r\nTODO : attach the speaker embedding to each input\r\nInput shape:(B,feature_dim + spk_emb_dim , T)\r\n\"\"\"\r\n\r\n\r\nclass BSNet(nn.Module):\r\n\r\n    def __init__(self, in_channel, nband=7, bidirectional=True):\r\n        super(BSNet, self).__init__()\r\n\r\n        self.nband = nband\r\n        self.feature_dim = in_channel // nband\r\n        self.band_rnn = ResRNN(\r\n            self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional\r\n        )\r\n        self.band_comm = ResRNN(\r\n            self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional\r\n        )\r\n\r\n    def forward(self, input, dummy: Optional[torch.Tensor] = None):\r\n        # input shape: B, nband*N, T\r\n        B, N, T = input.shape\r\n\r\n        band_output = self.band_rnn(\r\n            input.view(B * self.nband, self.feature_dim, -1)\r\n        ).view(B, self.nband, -1, T)\r\n\r\n        # band comm\r\n        band_output = (\r\n            band_output.permute(0, 3, 2, 1).contiguous().view(B * T, -1, self.nband)\r\n        )\r\n        output = (\r\n            self.band_comm(band_output)\r\n            .view(B, T, -1, self.nband)\r\n            .permute(0, 3, 2, 1)\r\n            .contiguous()\r\n        )\r\n\r\n        return output.view(B, N, T)\r\n\r\n\r\nclass FuseSeparation(nn.Module):\r\n\r\n    def __init__(\r\n        self,\r\n        nband=7,\r\n        num_repeat=6,\r\n        feature_dim=128,\r\n        spk_emb_dim=256,\r\n        spk_fuse_type=\"concat\",\r\n        multi_fuse=True,\r\n    ):\r\n        \"\"\"\r\n\r\n        :param nband : len(self.band_width)\r\n        \"\"\"\r\n        super(FuseSeparation, self).__init__()\r\n        self.multi_fuse = multi_fuse\r\n        self.nband = nband\r\n        self.feature_dim = feature_dim\r\n        self.separation = nn.ModuleList([])\r\n        if self.multi_fuse:\r\n            for _ in range(num_repeat):\r\n                self.separation.append(\r\n                    
SpeakerFuseLayer(\r\n                        embed_dim=spk_emb_dim,\r\n                        feat_dim=feature_dim,\r\n                        fuse_type=spk_fuse_type,\r\n                    )\r\n                )\r\n                self.separation.append(BSNet(nband * feature_dim, nband))\r\n        else:\r\n            self.separation.append(\r\n                SpeakerFuseLayer(\r\n                    embed_dim=spk_emb_dim,\r\n                    feat_dim=feature_dim,\r\n                    fuse_type=spk_fuse_type,\r\n                )\r\n            )\r\n            for _ in range(num_repeat):\r\n                self.separation.append(BSNet(nband * feature_dim, nband))\r\n\r\n    def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):\r\n        \"\"\"\r\n        x: [B, nband, feature_dim, T]\r\n        out: [B, nband, feature_dim, T]\r\n        \"\"\"\r\n        batch_size = x.shape[0]\r\n\r\n        if self.multi_fuse:\r\n            for i, sep_func in enumerate(self.separation):\r\n                x = sep_func(x, spk_embedding)\r\n                if i % 2 == 0:\r\n                    x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)\r\n                else:\r\n                    x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)\r\n        else:\r\n            x = self.separation[0](x, spk_embedding)\r\n            x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)\r\n            for idx, sep in enumerate(self.separation):\r\n                if idx > 0:\r\n                    x = sep(x, spk_embedding)\r\n            x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)\r\n        return x\r\n\r\n\r\nclass BSRNN_Multi(nn.Module):\r\n    # self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6,\r\n    # use_bidirectional=True\r\n    def __init__(\r\n        self,\r\n        spk_emb_dim=256,\r\n        sr=16000,\r\n        win=512,\r\n        stride=128,\r\n        feature_dim=128,\r\n        num_repeat=6,\r\n        use_spk_transform=True,\r\n        use_bidirectional=True,\r\n        spk_fuse_type=\"concat\",\r\n        multi_fuse=True,\r\n        joint_training=True,\r\n        multi_task=False,\r\n        spksInTrain=251,\r\n        spk_model=None,\r\n        spk_model_init=None,\r\n        spk_model_freeze=False,\r\n        spk_args=None,\r\n        spk_feat=False,\r\n        feat_type=\"consistent\",\r\n    ):\r\n        super(BSRNN_Multi, self).__init__()\r\n\r\n        self.sr = sr\r\n        self.win = win\r\n        self.stride = stride\r\n        self.group = self.win // 2\r\n        self.enc_dim = self.win // 2 + 1\r\n        self.feature_dim = feature_dim\r\n        self.eps = torch.finfo(torch.float32).eps\r\n        self.spk_emb_dim = spk_emb_dim\r\n        self.joint_training = joint_training\r\n        self.spk_feat = spk_feat\r\n        self.feat_type = feat_type\r\n        self.spk_model_freeze = spk_model_freeze\r\n        self.multi_task = multi_task\r\n\r\n        # 0-1k (100 hop), 1k-4k (250 hop),\r\n        # 4k-8k (500 hop), 8k-16k (1k hop),\r\n        # 16k-20k (2k hop), 20k-inf\r\n\r\n        # 0-8k (1k hop), 8k-16k (2k hop), 16k\r\n        bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim))\r\n        bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim))\r\n\r\n        # add up to 8k\r\n        self.band_width = 
[bandwidth_100] * 15\r\n        self.band_width += [bandwidth_200] * 10\r\n        self.band_width += [bandwidth_500] * 5\r\n        self.band_width += [bandwidth_2k] * 1\r\n\r\n        self.band_width.append(self.enc_dim - int(np.sum(self.band_width)))\r\n        self.nband = len(self.band_width)\r\n\r\n        if use_spk_transform:\r\n            self.spk_transform = SpeakerTransform()\r\n        else:\r\n            self.spk_transform = nn.Identity()\r\n\r\n        if joint_training:\r\n            self.spk_model = get_speaker_model(spk_model)(**spk_args)\r\n            if spk_model_init:\r\n                pretrained_model = torch.load(spk_model_init)\r\n                state = self.spk_model.state_dict()\r\n                for key in state.keys():\r\n                    if key in pretrained_model.keys():\r\n                        state[key] = pretrained_model[key]\r\n                        # print(key)\r\n                    else:\r\n                        print(\"not %s loaded\" % key)\r\n                self.spk_model.load_state_dict(state)\r\n            if spk_model_freeze:\r\n                for param in self.spk_model.parameters():\r\n                    param.requires_grad = False\r\n            if not spk_feat:\r\n                if feat_type == \"consistent\":\r\n                    self.preEmphasis = PreEmphasis()\r\n                    self.spk_encoder = torchaudio.transforms.MelSpectrogram(\r\n                        sample_rate=sr,\r\n                        n_fft=win,\r\n                        win_length=win,\r\n                        hop_length=stride,\r\n                        f_min=20,\r\n                        window_fn=torch.hamming_window,\r\n                        n_mels=spk_args[\"feat_dim\"],\r\n                    )\r\n            else:\r\n                self.preEmphasis = nn.Identity()\r\n                self.spk_encoder = nn.Identity()\r\n\r\n            if multi_task:\r\n                self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)\r\n            else:\r\n                self.pred_linear = nn.Identity()\r\n\r\n        self.BN = nn.ModuleList([])\r\n        for i in range(self.nband):\r\n            self.BN.append(\r\n                nn.Sequential(\r\n                    nn.GroupNorm(1, self.band_width[i] * 2, self.eps),\r\n                    nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),\r\n                )\r\n            )\r\n\r\n        self.separator = FuseSeparation(\r\n            nband=self.nband,\r\n            num_repeat=num_repeat,\r\n            feature_dim=feature_dim,\r\n            spk_emb_dim=spk_emb_dim,\r\n            spk_fuse_type=spk_fuse_type,\r\n            multi_fuse=multi_fuse,\r\n        )\r\n\r\n        # self.proj =  nn.Linear(hidden_size*2, input_size)\r\n\r\n        self.mask = nn.ModuleList([])\r\n        for i in range(self.nband):\r\n            self.mask.append(\r\n                nn.Sequential(\r\n                    nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),\r\n                    nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),\r\n                    nn.Tanh(),\r\n                    nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),\r\n                    nn.Tanh(),\r\n                    nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),\r\n                )\r\n            )\r\n\r\n    def pad_input(self, input, window, stride):\r\n        \"\"\"\r\n        Zero-padding input according to window/stride size.\r\n        \"\"\"\r\n        batch_size, 
nsample = input.shape\r\n\r\n        # pad the signals at the end for matching the window/stride size\r\n        rest = window - (stride + nsample % window) % window\r\n        if rest > 0:\r\n            pad = torch.zeros(batch_size, rest).type(input.type())\r\n            input = torch.cat([input, pad], 1)\r\n        pad_aux = torch.zeros(batch_size, stride).type(input.type())\r\n        input = torch.cat([pad_aux, input, pad_aux], 1)\r\n\r\n        return input, rest\r\n\r\n    def forward(self, input, embeddings):\r\n        # input shape: (B, C, T)\r\n\r\n        wav_input = input\r\n        spk_emb_input = embeddings\r\n        batch_size, nsample = wav_input.shape\r\n        nch = 1\r\n\r\n        # frequency-domain separation\r\n        spec = torch.stft(\r\n            wav_input,\r\n            n_fft=self.win,\r\n            hop_length=self.stride,\r\n            window=torch.hann_window(self.win)\r\n            .to(wav_input.device)\r\n            .type(wav_input.type()),\r\n            return_complex=True,\r\n        )\r\n\r\n        # concat real and imag, split to subbands\r\n        spec_RI = torch.stack([spec.real, spec.imag], 1)  # B*nch, 2, F, T\r\n        subband_spec = []\r\n        subband_mix_spec = []\r\n        band_idx = 0\r\n        for i in range(len(self.band_width)):\r\n            subband_spec.append(\r\n                spec_RI[:, :, band_idx : band_idx + self.band_width[i]].contiguous()\r\n            )\r\n            subband_mix_spec.append(\r\n                spec[:, band_idx : band_idx + self.band_width[i]]\r\n            )  # B*nch, BW, T\r\n            band_idx += self.band_width[i]\r\n\r\n        # normalization and bottleneck\r\n        subband_feature = []\r\n        for i, bn_func in enumerate(self.BN):\r\n            subband_feature.append(\r\n                bn_func(\r\n                    subband_spec[i].view(batch_size * nch, self.band_width[i] * 2, -1)\r\n                )\r\n            )\r\n        subband_feature = torch.stack(subband_feature, 1)  # B, nband, N, T\r\n        # print(subband_feature.size(), spk_emb_input.size())\r\n\r\n        predict_speaker_lable = torch.tensor(0.0).to(spk_emb_input.device)  # dummy\r\n        if self.joint_training:\r\n            if not self.spk_feat:\r\n                if self.feat_type == \"consistent\":\r\n                    with torch.no_grad():\r\n                        spk_emb_input = self.preEmphasis(spk_emb_input)\r\n                        spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8\r\n                        spk_emb_input = spk_emb_input.log()\r\n                        spk_emb_input = spk_emb_input - torch.mean(\r\n                            spk_emb_input, dim=-1, keepdim=True\r\n                        )\r\n                        spk_emb_input = spk_emb_input.permute(0, 2, 1)\r\n\r\n            tmp_spk_emb_input = self.spk_model(spk_emb_input)\r\n            if isinstance(tmp_spk_emb_input, tuple):\r\n                spk_emb_input = tmp_spk_emb_input[-1]\r\n            else:\r\n                spk_emb_input = tmp_spk_emb_input\r\n            predict_speaker_lable = self.pred_linear(spk_emb_input)\r\n\r\n        spk_embedding = self.spk_transform(spk_emb_input)\r\n        spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)\r\n\r\n        sep_output = self.separator(subband_feature, spk_embedding, torch.tensor(nch))\r\n\r\n        sep_subband_spec = []\r\n        for i, mask_func in enumerate(self.mask):\r\n            this_output = mask_func(sep_output[:, i]).view(\r\n          
      batch_size * nch, 2, 2, self.band_width[i], -1\r\n            )\r\n            this_mask = this_output[:, 0] * torch.sigmoid(\r\n                this_output[:, 1]\r\n            )  # B*nch, 2, K, BW, T\r\n            this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T\r\n            this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T\r\n            est_spec_real = (\r\n                subband_mix_spec[i].real * this_mask_real\r\n                - subband_mix_spec[i].imag * this_mask_imag\r\n            )  # B*nch, BW, T\r\n            est_spec_imag = (\r\n                subband_mix_spec[i].real * this_mask_imag\r\n                + subband_mix_spec[i].imag * this_mask_real\r\n            )  # B*nch, BW, T\r\n            sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))\r\n        est_spec = torch.cat(sep_subband_spec, 1)  # B*nch, F, T\r\n        output = torch.istft(\r\n            est_spec.view(batch_size * nch, self.enc_dim, -1),\r\n            n_fft=self.win,\r\n            hop_length=self.stride,\r\n            window=torch.hann_window(self.win)\r\n            .to(wav_input.device)\r\n            .type(wav_input.type()),\r\n            length=nsample,\r\n        )\r\n\r\n        output = output.view(batch_size, nch, -1)\r\n        s = torch.squeeze(output, dim=1)\r\n        if torch.is_grad_enabled():\r\n            self_embedding = s.detach()\r\n            self_predict_speaker_lable = torch.tensor(0.0).to(\r\n                self_embedding.device\r\n            )  # dummy\r\n            if self.joint_training:\r\n                if self.feat_type == \"consistent\":\r\n                    with torch.no_grad():\r\n                        self_embedding = self.preEmphasis(self_embedding)\r\n                        self_embedding = self.spk_encoder(self_embedding) + 1e-8\r\n                        self_embedding = self_embedding.log()\r\n                        self_embedding = self_embedding - torch.mean(\r\n                            self_embedding, dim=-1, keepdim=True\r\n                        )\r\n                        self_embedding = self_embedding.permute(0, 2, 1)\r\n\r\n                self_tmp_spk_emb_input = self.spk_model(self_embedding)\r\n                if isinstance(self_tmp_spk_emb_input, tuple):\r\n                    self_spk_emb_input = self_tmp_spk_emb_input[-1]\r\n                else:\r\n                    self_spk_emb_input = self_tmp_spk_emb_input\r\n                self_predict_speaker_lable = self.pred_linear(self_spk_emb_input)\r\n\r\n            self_spk_embedding = self.spk_transform(self_spk_emb_input)\r\n            self_spk_embedding = self_spk_embedding.unsqueeze(1).unsqueeze(3)\r\n\r\n            self_sep_output = self.separator(\r\n                subband_feature, self_spk_embedding, torch.tensor(nch)\r\n            )\r\n\r\n            self_sep_subband_spec = []\r\n            for i, mask_func in enumerate(self.mask):\r\n                this_output = mask_func(self_sep_output[:, i]).view(\r\n                    batch_size * nch, 2, 2, self.band_width[i], -1\r\n                )\r\n                this_mask = this_output[:, 0] * torch.sigmoid(\r\n                    this_output[:, 1]\r\n                )  # B*nch, 2, K, BW, T\r\n                this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T\r\n                this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T\r\n                est_spec_real = (\r\n                    subband_mix_spec[i].real * this_mask_real\r\n                    - subband_mix_spec[i].imag * 
this_mask_imag\r\n                )  # B*nch, BW, T\r\n                est_spec_imag = (\r\n                    subband_mix_spec[i].real * this_mask_imag\r\n                    + subband_mix_spec[i].imag * this_mask_real\r\n                )  # B*nch, BW, T\r\n                self_sep_subband_spec.append(\r\n                    torch.complex(est_spec_real, est_spec_imag)\r\n                )\r\n            self_est_spec = torch.cat(self_sep_subband_spec, 1)  # B*nch, F, T\r\n            self_output = torch.istft(\r\n                self_est_spec.view(batch_size * nch, self.enc_dim, -1),\r\n                n_fft=self.win,\r\n                hop_length=self.stride,\r\n                window=torch.hann_window(self.win)\r\n                .to(wav_input.device)\r\n                .type(wav_input.type()),\r\n                length=nsample,\r\n            )\r\n\r\n            self_output = self_output.view(batch_size, nch, -1)\r\n            self_s = torch.squeeze(self_output, dim=1)\r\n\r\n            return s, self_s, predict_speaker_lable, self_predict_speaker_lable\r\n\r\n        return s, predict_speaker_lable\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    from thop import profile, clever_format\r\n\r\n    model = BSRNN_Multi(\r\n        spk_emb_dim=256,\r\n        sr=16000,\r\n        win=512,\r\n        stride=128,\r\n        feature_dim=128,\r\n        num_repeat=6,\r\n        spk_fuse_type=\"additive\",\r\n    )\r\n\r\n    s = 0\r\n    for param in model.parameters():\r\n        s += np.prod(param.size())\r\n    print(\"# of parameters: \" + str(s / 1024.0 / 1024.0))\r\n    x = torch.randn(4, 32000)\r\n    spk_embeddings = torch.randn(4, 256)\r\n    output = model(x, spk_embeddings)\r\n    print(output[0].shape)\r\n\r\n    macs, params = profile(model, inputs=(x, spk_embeddings))\r\n    macs, params = clever_format([macs, params], \"%.3f\")\r\n    print(macs, params)\r\n"
  },
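Note: the BSRNN forward pass above applies a complex-valued ratio mask to each subband of the mixture STFT (real/imaginary cross terms combined with torch.complex). A minimal, self-contained sketch of that masking step, with illustrative tensor names and shapes that are not taken from the repo:

import torch

# Illustrative shapes only: batch B, subband width BW, frames T.
B, BW, T = 2, 16, 50
mix_spec = torch.randn(B, BW, T, dtype=torch.complex64)  # mixture subband spectrum
mask_real = torch.randn(B, BW, T)                        # real part of the estimated mask
mask_imag = torch.randn(B, BW, T)                        # imaginary part of the estimated mask

# Complex multiplication written out, as in the forward pass:
# (a + jb)(c + jd) = (ac - bd) + j(ad + bc)
est_real = mix_spec.real * mask_real - mix_spec.imag * mask_imag
est_imag = mix_spec.real * mask_imag + mix_spec.imag * mask_real
est_spec = torch.complex(est_real, est_imag)             # estimated subband spectrum
print(est_spec.shape, est_spec.dtype)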
  {
    "path": "wesep/models/convtasnet.py",
    "content": "import torch\r\nimport torch.nn as nn\r\n\r\nfrom wesep.modules.common import select_norm\r\nfrom wesep.modules.common.speaker import SpeakerTransform\r\nfrom wesep.modules.tasnet import DeepEncoder, DeepDecoder\r\nfrom wesep.modules.tasnet import MultiEncoder, MultiDecoder\r\nfrom wesep.modules.tasnet import FuseSeparation\r\nfrom wesep.modules.tasnet.convs import Conv1D, ConvTrans1D\r\nfrom wesep.modules.tasnet.speaker import ResNet4SpExplus\r\nfrom wespeaker.models.speaker_model import get_speaker_model\r\n\r\n\r\nclass ConvTasNet(nn.Module):\r\n\r\n    def __init__(\r\n        self,\r\n        N=512,\r\n        L=16,\r\n        B=128,\r\n        H=512,\r\n        P=3,\r\n        X=8,\r\n        R=3,\r\n        spk_emb_dim=256,\r\n        norm=\"gLN\",\r\n        activate=\"relu\",\r\n        causal=False,\r\n        skip_con=False,\r\n        spk_fuse_type=\"concatConv\",\r\n        # \"concat\", \"additive\", \"multiply\", \"FiLM\", \"None\",\r\n        # (\"concatConv\" only for convtasnet)\r\n        multi_fuse=True,\r\n        use_spk_transform=True,\r\n        encoder_type=\"Multi\",  # 'Multi', 'Deep', None\r\n        decoder_type=\"Multi\",\r\n        joint_training=True,\r\n        multi_task=False,\r\n        spksInTrain=251,\r\n        spk_model=None,\r\n        spk_model_init=None,\r\n        spk_model_freeze=False,\r\n        spk_args=None,\r\n        spk_feat=False,\r\n        feat_type=\"consistent\",\r\n    ):\r\n        \"\"\"\r\n        :param N: Number of filters in autoencoder\r\n        :param L: Length of the filters (in samples)\r\n        :param B: Number of channels in bottleneck and the residual paths\r\n        :param H: Number of channels in convolutional blocks\r\n        :param P: Kernel size in convolutional blocks\r\n        :param X: Number of convolutional blocks in each repeat\r\n        :param R: Number of repeats\r\n        :param norm:\r\n        :param activate:\r\n        :param causal:\r\n        :param skip_con:\r\n        :param spk_fuse_type: concat/addition/FiLM\r\n        :param use_spk_transform:\r\n        :param use_deep_enc:\r\n        :param use_deep_dec:\r\n        \"\"\"\r\n        super(ConvTasNet, self).__init__()\r\n\r\n        self.encoder_type = encoder_type\r\n        self.decoder_type = decoder_type\r\n        # n x 1 x T => n x N x T\r\n        if encoder_type == \"Multi\":\r\n            self.encoder = MultiEncoder(\r\n                in_channels=1,\r\n                middle_channels=N,\r\n                out_channels=B,\r\n                kernel_size=L,\r\n                stride=L // 2,\r\n            )\r\n        elif encoder_type == \"Deep\":\r\n            self.encoder = DeepEncoder(1, N, L, stride=L // 2)\r\n            self.LayerN_S = select_norm(norm, N)\r\n            self.BottleN_S = Conv1D(N, B, 1)\r\n        else:\r\n            self.encoder = nn.Sequential(\r\n                Conv1D(1, N, L, stride=L // 2, padding=0), nn.ReLU())\r\n            self.LayerN_S = select_norm(norm, N)\r\n            self.BottleN_S = Conv1D(N, B, 1)\r\n\r\n        self.joint_training = joint_training\r\n        self.spk_feat = spk_feat\r\n        self.feat_type = feat_type\r\n        self.spk_model_freeze = spk_model_freeze\r\n        self.multi_task = multi_task\r\n\r\n        if joint_training:\r\n            if not self.spk_feat:\r\n                if self.feat_type == \"consistent\":\r\n                    self.spk_model = ResNet4SpExplus(\r\n                        in_channel=N, C_embedding=spk_emb_dim\r\n           
         )  # The speaker model is fixed for SpEx+ currently\r\n            else:\r\n                self.spk_model = get_speaker_model(spk_model)(**spk_args)\r\n                if spk_model_init:\r\n                    pretrained_model = torch.load(spk_model_init)\r\n                    state = self.spk_model.state_dict()\r\n                    for key in state.keys():\r\n                        if key in pretrained_model.keys():\r\n                            state[key] = pretrained_model[key]\r\n                            # print(key)\r\n                        else:\r\n                            print(\"not %s loaded\" % key)\r\n                    self.spk_model.load_state_dict(state)\r\n                    if self.spk_model_freeze:\r\n                        for param in self.spk_model.parameters():\r\n                            param.requires_grad = False\r\n            if multi_task:\r\n                self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)\r\n\r\n        if not use_spk_transform:\r\n            self.spk_transform = nn.Identity()\r\n        else:\r\n            self.spk_transform = SpeakerTransform()\r\n\r\n        # Separation block\r\n        # n x B x T => n x B x T\r\n        self.separation = FuseSeparation(\r\n            R,\r\n            X,\r\n            B,\r\n            H,\r\n            P,\r\n            norm=norm,\r\n            causal=causal,\r\n            skip_con=skip_con,\r\n            C_embedding=spk_emb_dim,\r\n            spk_fuse_type=spk_fuse_type,\r\n            multi_fuse=multi_fuse,\r\n        )\r\n\r\n        # n x N x T => n x 1 x L\r\n        if decoder_type == \"Multi\":\r\n            self.decoder = MultiDecoder(\r\n                in_channels=B,\r\n                middle_channels=N,\r\n                out_channels=1,\r\n                kernel_size=L,\r\n                stride=L // 2,\r\n            )\r\n        elif decoder_type == \"Deep\":\r\n            self.decoder = DeepDecoder(N, L, stride=L // 2)\r\n            self.gen_masks = Conv1D(B, N, 1)\r\n        else:\r\n            self.decoder = ConvTrans1D(N, 1, L, stride=L // 2)\r\n            self.gen_masks = Conv1D(B, N, 1)\r\n        # activation function\r\n        active_f = {\r\n            \"relu\": nn.ReLU(),\r\n            \"sigmoid\": nn.Sigmoid(),\r\n            \"softmax\": nn.Softmax(dim=0),\r\n        }\r\n        # self.activation_type = activate\r\n        self.activation = active_f[activate]\r\n\r\n    def forward(self, x, embeddings):\r\n        if x.dim() >= 3:\r\n            raise RuntimeError(\r\n                \"{} accept 1/2D tensor as input, but got {:d}\".format(\r\n                    self.__name__, x.dim()))\r\n        if x.dim() == 1:\r\n            x = torch.unsqueeze(x, 0)\r\n        # x: n x 1 x L => n x N x T\r\n        if self.encoder_type == \"Multi\":\r\n            e, w1, w2, w3 = self.encoder(x)\r\n            x = e  # replace x with e, for asymmetric encoder-decoder\r\n        else:\r\n            x = self.encoder(x)\r\n            e = self.LayerN_S(x)\r\n            e = self.BottleN_S(\r\n                e)  # Embedding fuse after dimension changed fro N to B\r\n\r\n        if (self.joint_training):\r\n            # Only support sharing Encoder and ResNet in SpEx+ currently\r\n            # Speaker Encoder\r\n            if not self.spk_feat and self.feat_type == \"consistent\":\r\n                if self.encoder_type == \"Multi\":\r\n                    _, aux_w1, aux_w2, aux_w3 = self.encoder(embeddings)\r\n                    
embeddings = torch.cat([aux_w1, aux_w2, aux_w3], 1)\r\n                else:\r\n                    aux_x = self.encoder(embeddings)\r\n                    aux_e = self.LayerN_S(aux_x)\r\n                    embeddings = self.BottleN_S(aux_e)\r\n            embeddings = self.spk_model(embeddings)\r\n            if isinstance(embeddings, tuple):\r\n                embeddings = embeddings[-1]\r\n            if self.multi_task:\r\n                predict_speaker_lable = self.pred_linear(embeddings)\r\n\r\n        spk_embeds = self.spk_transform(embeddings.unsqueeze(-1))\r\n        e = self.separation(e, spk_embeds)\r\n\r\n        # decoder part  n x L\r\n        if self.decoder_type == \"Multi\":\r\n            s = self.decoder(\r\n                e, w1, w2, w3,\r\n                actLayer=self.activation)  # s is a tuple by using multiDecoder\r\n        else:\r\n            # n x B x L => n x N x L\r\n            m = self.gen_masks(e)\r\n            # n x N x L\r\n            m = self.activation(m)\r\n            x = x * m\r\n            s = self.decoder(x)\r\n\r\n        if self.joint_training and self.multi_task:\r\n            if not isinstance(s, list):\r\n                s = [\r\n                    s,\r\n                ]\r\n            s.append(predict_speaker_lable)\r\n\r\n        return s  # s: N x Len Or List(N  x Len,x3/x4)\r\n\r\n\r\ndef check_parameters(net):\r\n    \"\"\"\r\n    Returns module parameters. Mb\r\n    \"\"\"\r\n    parameters = sum(param.numel() for param in net.parameters())\r\n    return parameters / 10**6\r\n\r\n\r\ndef test_convtasnet():\r\n    x = torch.randn(4, 32000)\r\n    spk_embeddings = torch.randn(4, 256)\r\n    net = ConvTasNet(use_spk_transform=False, spk_fuse_type=\"FiLM\")\r\n    s = net(x, spk_embeddings)\r\n    print(str(check_parameters(net)) + \" Mb\")\r\n    print(s[1].shape)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    test_convtasnet()\r\n"
  },
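Note: ConvTasNet above separates in a learned filterbank domain: a Conv1d encoder, a mask estimated on the encoded representation, and a ConvTranspose1d decoder. A stripped-down sketch of that encode-mask-decode round trip; the mask network here is a stand-in, and the hyperparameters are only the defaults quoted in the constructor, not a faithful reimplementation:

import torch
import torch.nn as nn

N, L = 512, 16                                             # filters and filter length, as in the defaults above
encoder = nn.Conv1d(1, N, L, stride=L // 2)
decoder = nn.ConvTranspose1d(N, 1, L, stride=L // 2)
mask_net = nn.Sequential(nn.Conv1d(N, N, 1), nn.ReLU())    # stand-in for the TCN separator

x = torch.randn(4, 1, 32000)                               # batch of mixtures
w = encoder(x)                                             # B x N x T latent representation
m = mask_net(w)                                            # mask in the latent domain
s = decoder(w * m)                                         # masked representation back to waveform
print(s.shape)                                             # B x 1 x 32000 for this stride/kernel choice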
  {
    "path": "wesep/models/dpccn.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nimport torchaudio\r\n\r\nfrom wespeaker.models.speaker_model import get_speaker_model\r\n\r\nfrom wesep.modules.common.speaker import PreEmphasis\r\nfrom wesep.modules.common.speaker import SpeakerFuseLayer\r\nfrom wesep.modules.common.speaker import SpeakerTransform\r\nfrom wesep.modules.dpccn.convs import Conv2dBlock\r\nfrom wesep.modules.dpccn.convs import ConvTrans2dBlock\r\nfrom wesep.modules.dpccn.convs import DenseBlock\r\nfrom wesep.modules.dpccn.convs import TCNBlock\r\n\r\n\r\nclass DPCCN(nn.Module):\r\n\r\n    def __init__(\r\n        self,\r\n        win=512,\r\n        stride=128,\r\n        spk_emb_dim=256,\r\n        sr=16000,\r\n        use_spk_transform=False,\r\n        spk_fuse_type=\"multiply\",\r\n        feature_dim=257,\r\n        kernel_size=(3, 3),\r\n        stride1=(1, 1),\r\n        stride2=(1, 2),\r\n        paddings=(1, 1),\r\n        output_padding=(0, 0),\r\n        tcn_dims=384,\r\n        tcn_blocks=10,\r\n        tcn_layers=2,\r\n        causal=False,\r\n        pool_size=(4, 8, 16, 32),\r\n        multi_fuse=False,\r\n        joint_training=True,\r\n        multi_task=False,\r\n        spksInTrain=251,\r\n        spk_model=None,\r\n        spk_model_init=None,\r\n        spk_model_freeze=False,\r\n        spk_args=None,\r\n        spk_feat=False,\r\n        feat_type=\"consistent\",\r\n    ) -> None:\r\n        super(DPCCN, self).__init__()\r\n\r\n        self.win_len = win\r\n        self.hop_size = stride\r\n        self.spk_emb_dim = spk_emb_dim\r\n        self.joint_training = joint_training\r\n        self.spk_feat = spk_feat\r\n        self.feat_type = feat_type\r\n        self.spk_model_freeze = spk_model_freeze\r\n        self.multi_task = multi_task\r\n\r\n        self.conv2d = nn.Conv2d(2, 16, kernel_size, stride1, paddings)\r\n\r\n        self.encoder = self._build_encoder(kernel_size=kernel_size,\r\n                                           stride=stride2,\r\n                                           padding=paddings)\r\n\r\n        if use_spk_transform:\r\n            self.spk_transform = SpeakerTransform()\r\n        else:\r\n            self.spk_transform = nn.Identity()\r\n\r\n        if joint_training:\r\n            self.spk_model = get_speaker_model(spk_model)(**spk_args)\r\n            if spk_model_init:\r\n                pretrained_model = torch.load(spk_model_init)\r\n                state = self.spk_model.state_dict()\r\n                for key in state.keys():\r\n                    if key in pretrained_model.keys():\r\n                        state[key] = pretrained_model[key]\r\n                        # print(key)\r\n                    else:\r\n                        print(\"not %s loaded\" % key)\r\n                self.spk_model.load_state_dict(state)\r\n            if spk_model_freeze:\r\n                for param in self.spk_model.parameters():\r\n                    param.requires_grad = False\r\n            if not spk_feat:\r\n                if feat_type == \"consistent\":\r\n                    self.preEmphasis = PreEmphasis()\r\n                    self.spk_encoder = torchaudio.transforms.MelSpectrogram(\r\n                        sample_rate=sr,\r\n                        n_fft=win,\r\n                        win_length=win,\r\n                        hop_length=stride,\r\n                        f_min=20,\r\n                        window_fn=torch.hamming_window,\r\n                        n_mels=spk_args[\"feat_dim\"],\r\n                    )\r\n         
   else:\r\n                self.preEmphasis = nn.Identity()\r\n                self.spk_encoder = nn.Identity()\r\n\r\n            if multi_task:\r\n                self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)\r\n            else:\r\n                self.pred_linear = nn.Identity()\r\n\r\n        self.spk_fuse = SpeakerFuseLayer(\r\n            embed_dim=self.spk_emb_dim,\r\n            feat_dim=feature_dim,\r\n            fuse_type=spk_fuse_type,\r\n        )\r\n\r\n        self.tcn_layers = self._build_tcn_layers(\r\n            tcn_layers,\r\n            tcn_blocks,\r\n            in_dims=tcn_dims,\r\n            out_dims=tcn_dims,\r\n            causal=causal,\r\n        )\r\n        self.decoder = self._build_decoder(\r\n            kernel_size=kernel_size,\r\n            stride=stride2,\r\n            padding=paddings,\r\n            output_padding=output_padding,\r\n        )\r\n        self.avg_pool = self._build_avg_pool(pool_size)\r\n        self.avg_proj = nn.Conv2d(64, 32, 1, 1)\r\n\r\n        self.deconv2d = nn.ConvTranspose2d(32, 2, kernel_size, stride1,\r\n                                           paddings)\r\n\r\n    def _build_encoder(self, **enc_kargs):\r\n        \"\"\"\r\n        Build encoder layers\r\n        \"\"\"\r\n        encoder = nn.ModuleList()\r\n        encoder.append(DenseBlock(16, 16, \"enc\"))\r\n        for i in range(4):\r\n            encoder.append(\r\n                nn.Sequential(\r\n                    Conv2dBlock(in_dims=16 if i == 0 else 32,\r\n                                out_dims=32,\r\n                                **enc_kargs),\r\n                    DenseBlock(32, 32, \"enc\"),\r\n                ))\r\n        encoder.append(Conv2dBlock(in_dims=32, out_dims=64, **enc_kargs))\r\n        encoder.append(Conv2dBlock(in_dims=64, out_dims=128, **enc_kargs))\r\n        encoder.append(Conv2dBlock(in_dims=128, out_dims=384, **enc_kargs))\r\n\r\n        return encoder\r\n\r\n    def _build_decoder(self, **dec_kargs):\r\n        \"\"\"\r\n        Build decoder layers\r\n        \"\"\"\r\n        decoder = nn.ModuleList()\r\n        decoder.append(\r\n            ConvTrans2dBlock(in_dims=384 * 2, out_dims=128, **dec_kargs))\r\n        decoder.append(\r\n            ConvTrans2dBlock(in_dims=128 * 2, out_dims=64, **dec_kargs))\r\n        decoder.append(\r\n            ConvTrans2dBlock(in_dims=64 * 2, out_dims=32, **dec_kargs))\r\n        for i in range(4):\r\n            decoder.append(\r\n                nn.Sequential(\r\n                    DenseBlock(32, 64, \"dec\"),\r\n                    ConvTrans2dBlock(in_dims=64,\r\n                                     out_dims=32 if i != 3 else 16,\r\n                                     **dec_kargs),\r\n                ))\r\n        decoder.append(DenseBlock(16, 32, \"dec\"))\r\n\r\n        return decoder\r\n\r\n    def _build_tcn_blocks(self, tcn_blocks, **tcn_kargs):\r\n        \"\"\"\r\n        Build TCN blocks in each repeat (layer)\r\n        \"\"\"\r\n        blocks = [\r\n            TCNBlock(**tcn_kargs, dilation=(2**b)) for b in range(tcn_blocks)\r\n        ]\r\n\r\n        return nn.Sequential(*blocks)\r\n\r\n    def _build_tcn_layers(self, tcn_layers, tcn_blocks, **tcn_kargs):\r\n        \"\"\"\r\n        Build TCN layers\r\n        \"\"\"\r\n        layers = [\r\n            self._build_tcn_blocks(tcn_blocks, **tcn_kargs)\r\n            for _ in range(tcn_layers)\r\n        ]\r\n\r\n        return nn.Sequential(*layers)\r\n\r\n    def _build_avg_pool(self, pool_size):\r\n        
\"\"\"\r\n        Build avg pooling layers\r\n        \"\"\"\r\n        avg_pool = nn.ModuleList()\r\n        for sz in pool_size:\r\n            avg_pool.append(\r\n                nn.Sequential(nn.AvgPool2d(sz), nn.Conv2d(32, 8, 1, 1)))\r\n\r\n        return avg_pool\r\n\r\n    def forward(self, input, aux):\r\n        wav_input = input\r\n        spk_emb_input = aux\r\n        batch_size, nsample = wav_input.shape\r\n\r\n        # frequency-domain separation\r\n        spec = torch.stft(\r\n            wav_input,\r\n            n_fft=self.win_len,\r\n            hop_length=self.hop_size,\r\n            window=torch.hann_window(self.win_len).to(wav_input.device).type(\r\n                wav_input.type()),\r\n            return_complex=True,\r\n        )\r\n        # concat real and imag, split to subbands\r\n        spec_RI = torch.stack([spec.real, spec.imag], 1)\r\n\r\n        # spec = torch.einsum(\"hijk->hikj\", spec_RI)  # batchsize, 2, T, F\r\n        spec = torch.transpose(spec_RI, 2, 3)  # batchsize, 2, T, F\r\n        out = self.conv2d(spec)\r\n        out_list = []\r\n        out = self.encoder[0](out)\r\n\r\n        predict_speaker_lable = torch.tensor(0.0).to(\r\n            spk_emb_input.device)  # dummy\r\n        if self.joint_training:\r\n            if not self.spk_feat:\r\n                if self.feat_type == \"consistent\":\r\n                    with torch.no_grad():\r\n                        spk_emb_input = self.preEmphasis(spk_emb_input)\r\n                        spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8\r\n                        spk_emb_input = spk_emb_input.log()\r\n                        spk_emb_input = spk_emb_input - torch.mean(\r\n                            spk_emb_input, dim=-1, keepdim=True)\r\n                        spk_emb_input = spk_emb_input.permute(0, 2, 1)\r\n\r\n            tmp_spk_emb_input = self.spk_model(spk_emb_input)\r\n            if isinstance(tmp_spk_emb_input, tuple):\r\n                spk_emb_input = tmp_spk_emb_input[-1]\r\n            else:\r\n                spk_emb_input = tmp_spk_emb_input\r\n            predict_speaker_lable = self.pred_linear(spk_emb_input)\r\n\r\n        spk_embedding = self.spk_transform(spk_emb_input)\r\n        spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)\r\n\r\n        out = self.spk_fuse(out.transpose(2, 3), spk_embedding).transpose(2, 3)\r\n        out_list.append(out)\r\n        for _, enc in enumerate(self.encoder[1:]):\r\n            out = enc(out)\r\n            out_list.append(out)\r\n\r\n        B, N, T, F = out.shape\r\n        out = out.reshape(B, N, T * F)\r\n        out = self.tcn_layers(out)\r\n        out = out.reshape(B, N, T, F)\r\n        out_list = out_list[::-1]\r\n        for idx, dec in enumerate(self.decoder):\r\n            out = dec(torch.cat([out_list[idx], out], 1))\r\n            # Pyramidal pooling\r\n        B, N, T, F = out.shape\r\n        upsample = nn.Upsample(size=(T, F), mode=\"bilinear\")\r\n        pool_list = []\r\n        for avg in self.avg_pool:\r\n            pool_list.append(upsample(avg(out)))\r\n        out = torch.cat([out, *pool_list], 1)\r\n        out = self.avg_proj(out)\r\n\r\n        out = self.deconv2d(out)\r\n\r\n        est_spec = torch.transpose(out, 2, 3)  # (batchsize, 2, F, T)\r\n        B, N, F, T = est_spec.shape\r\n        est_spec = torch.chunk(est_spec, 2, 1)  # [(B, 1, F, T), (B, 1, F, T)])\r\n        est_spec = torch.complex(est_spec[0], est_spec[1])\r\n\r\n        output = torch.istft(\r\n            
est_spec.reshape(B, -1, T),\r\n            n_fft=self.win_len,\r\n            hop_length=self.hop_size,\r\n            window=torch.hann_window(self.win_len).to(wav_input.device).type(\r\n                wav_input.type()),\r\n            length=nsample,\r\n        )\r\n\r\n        return output, predict_speaker_lable\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    import numpy as np\r\n\r\n    model = DPCCN()\r\n    s = 0\r\n    for param in model.parameters():\r\n        s += np.product(param.size())\r\n    print(\"# of parameters: \" + str(s / 1024.0 / 1024.0))\r\n    mix = torch.randn(4, 32000)\r\n    aux = torch.randn(4, 256)\r\n    est = model(mix, aux)\r\n    print(est.size())\r\n"
  },
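Note: DPCCN operates on the stacked real/imaginary STFT and resynthesizes with the inverse STFT. A minimal sketch of that frequency-domain round trip, reusing the default win/stride from the constructor above; the "estimated" spectrum is just the input, since the point is only the tensor layout:

import torch

win, stride = 512, 128
wav = torch.randn(4, 32000)
window = torch.hann_window(win)

spec = torch.stft(wav, n_fft=win, hop_length=stride, window=window,
                  return_complex=True)                   # B x F x T, F = win // 2 + 1
spec_ri = torch.stack([spec.real, spec.imag], dim=1)     # B x 2 x F x T, the network's input layout

# The network would predict a complex spectrum here; reuse the input as a stand-in.
est_spec = torch.complex(spec_ri[:, 0], spec_ri[:, 1])   # B x F x T
out = torch.istft(est_spec, n_fft=win, hop_length=stride, window=window,
                  length=wav.shape[-1])                  # back to B x 32000
print(out.shape)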
  {
    "path": "wesep/models/sep_model.py",
    "content": "import wesep.models.bsrnn as bsrnn\nimport wesep.models.convtasnet as convtasnet\nimport wesep.models.dpccn as dpccn\nimport wesep.models.tfgridnet as tfgridnet\n\n\ndef get_model(model_name: str):\n    if model_name.startswith(\"ConvTasNet\"):\n        return getattr(convtasnet, model_name)\n    elif model_name.startswith(\"BSRNN\"):\n        return getattr(bsrnn, model_name)\n    elif model_name.startswith(\"DPCNN\"):\n        return getattr(dpccn, model_name)\n    elif model_name.startswith(\"TFGridNet\"):\n        return getattr(tfgridnet, model_name)\n    else:  # model_name error !!!\n        print(model_name + \" not found !!!\")\n        exit(1)\n\n\nif __name__ == \"__main__\":\n    print(get_model(\"ConvTasNet\"))\n"
  },
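Note: a short example of how the factory above is typically used: the model-name prefix selects the module, and the returned class is instantiated from configuration keyword arguments. The `configs` dict here is hypothetical; real values come from the experiment configuration:

from wesep.models.sep_model import get_model

# Hypothetical config for illustration only.
configs = {"use_spk_transform": False, "spk_fuse_type": "FiLM", "joint_training": False}
model_cls = get_model("ConvTasNet")   # resolves to wesep.models.convtasnet.ConvTasNet
model = model_cls(**configs)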
  {
    "path": "wesep/models/tfgridnet.py",
    "content": "# The implementation is based on:\n# https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn as nn\nimport torchaudio\nfrom packaging.version import parse as V\n\nfrom wespeaker.models.speaker_model import get_speaker_model\n\nfrom wesep.modules.common.speaker import PreEmphasis\nfrom wesep.modules.common.speaker import SpeakerFuseLayer, SpeakerTransform\nfrom wesep.modules.tfgridnet.gridnet_block import GridNetBlock\n\nis_torch_1_9_plus = V(torch.__version__) >= V(\"1.9.0\")\n\n\nclass TFGridNet(nn.Module):\n    \"\"\"Offline TFGridNetV2. Compared with TFGridNet, TFGridNetV2 speeds up\n        the code by vectorizing multiple heads in self-attention,\n        and better dealing with Deconv1D in each intra- and inter-block\n        when emb_ks == emb_hs.\n\n    Reference:\n    [1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,\n    \"TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation\",\n    in TASLP, 2023.\n    [2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,\n    \"TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural\n    Speaker Separation\", in ICASSP, 2023.\n\n    NOTES:\n    As outlined in the Reference, this model works best when trained with\n    variance normalized mixture input and target, e.g., with mixture of\n    shape [batch, samples, microphones], you normalize it by dividing\n    with torch.std(mixture, (1, 2)). You must do the same for the target\n    signals. It is encouraged to do so when not using\n    scale-invariant loss functions such as SI-SDR.\n    Specifically, use:\n        std_ = std(mix)\n        mix = mix / std_\n        tgt = tgt / std_\n\n    Args:\n        n_srcs: number of output sources/speakers.\n        n_fft: stft window size.\n        stride: stft stride.\n        window: stft window type choose between 'hamming', 'hanning' or None.\n        n_imics: num of channels (only fixed-array geometry supported).\n        n_layers: number of TFGridNetV2 blocks.\n        lstm_hidden_units: number of hidden units in LSTM.\n        attn_n_head: number of heads in self-attention\n        attn_approx_qk_dim: approximate dim of frame-level key/value tensors\n        emb_dim: embedding dimension\n        emb_ks: kernel size for unfolding and deconv1D\n        emb_hs: hop size for unfolding and deconv1D\n        activation: activation function to use in the whole TFGridNetV2 model,\n            you can use any torch supported activation e.g. 
'relu' or 'elu'.\n        eps: small epsilon for normalization layers.\n        spk_emb_dim: the dimension of target speaker embeddings.\n        use_spk_transform: whether use networks to transfer the speaker embeds.\n        spk_fuse_type: the fusion method of speaker embeddings.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_srcs=1,\n        sr=16000,\n        n_fft=128,\n        stride=64,\n        window=\"hann\",\n        n_imics=1,\n        n_layers=6,\n        lstm_hidden_units=192,\n        attn_n_head=4,\n        attn_approx_qk_dim=512,\n        emb_dim=48,\n        emb_ks=4,\n        emb_hs=1,\n        activation=\"prelu\",\n        eps=1.0e-5,\n        spk_emb_dim=256,\n        use_spk_transform=False,\n        spk_fuse_type=\"multiply\",\n        joint_training=True,\n        multi_task=False,\n        spksInTrain=251,\n        spk_model=None,\n        spk_model_init=None,\n        spk_model_freeze=False,\n        spk_args=None,\n        spk_feat=False,\n        feat_type=\"consistent\",\n    ):\n        super().__init__()\n        self.n_srcs = n_srcs\n        self.n_fft = n_fft\n        self.stride = stride\n        self.window = window\n        self.n_imics = n_imics\n        self.n_layers = n_layers\n        self.spk_emb_dim = spk_emb_dim\n        self.joint_training = joint_training\n        self.spk_feat = spk_feat\n        self.feat_type = feat_type\n        self.spk_model_freeze = spk_model_freeze\n        self.multi_task = multi_task\n\n        assert n_fft % 2 == 0\n        n_freqs = n_fft // 2 + 1\n\n        if use_spk_transform:\n            self.spk_transform = SpeakerTransform()\n        else:\n            self.spk_transform = nn.Identity()\n\n        if joint_training:\n            self.spk_model = get_speaker_model(spk_model)(**spk_args)\n            if spk_model_init:\n                pretrained_model = torch.load(spk_model_init)\n                state = self.spk_model.state_dict()\n                for key in state.keys():\n                    if key in pretrained_model.keys():\n                        state[key] = pretrained_model[key]\n                        # print(key)\n                    else:\n                        print(\"not %s loaded\" % key)\n                self.spk_model.load_state_dict(state)\n            if spk_model_freeze:\n                for param in self.spk_model.parameters():\n                    param.requires_grad = False\n            if not spk_feat:\n                if feat_type == \"consistent\":\n                    self.preEmphasis = PreEmphasis()\n                    self.spk_encoder = torchaudio.transforms.MelSpectrogram(\n                        sample_rate=sr,\n                        n_fft=n_fft,\n                        win_length=n_fft,\n                        hop_length=stride,\n                        f_min=20,\n                        window_fn=torch.hamming_window,\n                        n_mels=spk_args[\"feat_dim\"],\n                    )\n            else:\n                self.preEmphasis = nn.Identity()\n                self.spk_encoder = nn.Identity()\n\n            if multi_task:\n                self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)\n            else:\n                self.pred_linear = nn.Identity()\n\n        self.spk_fuse = SpeakerFuseLayer(\n            embed_dim=spk_emb_dim,\n            feat_dim=n_freqs,\n            fuse_type=spk_fuse_type,\n        )\n\n        t_ksize = 3\n        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)\n        self.conv = nn.Sequential(\n        
    nn.Conv2d(2 * n_imics, emb_dim, ks, padding=padding),\n            nn.GroupNorm(1, emb_dim, eps=eps),\n        )\n\n        self.blocks = nn.ModuleList([])\n        for _ in range(n_layers):\n            self.blocks.append(\n                GridNetBlock(\n                    emb_dim,\n                    emb_ks,\n                    emb_hs,\n                    n_freqs,\n                    lstm_hidden_units,\n                    n_head=attn_n_head,\n                    approx_qk_dim=attn_approx_qk_dim,\n                    activation=activation,\n                    eps=eps,\n                ))\n\n        self.deconv = nn.ConvTranspose2d(emb_dim,\n                                         n_srcs * 2,\n                                         ks,\n                                         padding=padding)\n\n    def forward(\n        self,\n        input: torch.Tensor,\n        embeddings: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Forward.\n\n        Args:\n            input (torch.Tensor): batched multi-channel audio tensor with\n                    M audio channels and N samples [B, N, M]\n            embeddings (torch.Tensor): batched target speaker embeddings [B, D]\n\n        Returns:\n            enhanced (List[Union(torch.Tensor)]):\n                    [(B, T), ...] list of len n_srcs\n                    of mono audio tensors with T samples.\n        \"\"\"\n        batch_size, n_samples = input.shape[0], input.shape[1]\n        spk_emb_input = embeddings\n        if self.n_imics == 1:\n            assert len(input.shape) == 2\n            input = input[..., None]  # [B, N, M]\n\n        mix_std_ = torch.std(input, dim=(1, 2), keepdim=True)  # [B, 1, 1]\n        input = input / mix_std_  # RMS normalization\n\n        input = input.transpose(1, 2).reshape(\n            -1, input.size(1))  # [B, N, M] -> [B*M, N]\n        window_func = getattr(torch, f\"{self.window}_window\")\n        window = window_func(self.n_fft,\n                             dtype=input.dtype,\n                             device=input.device)\n\n        batch = torch.stft(\n            input,\n            n_fft=self.n_fft,\n            win_length=self.n_fft,\n            hop_length=self.stride,\n            window=window,\n            return_complex=True,\n            onesided=True,\n        )  # [B, F, T]\n        batch = batch.transpose(1, 2)  # [B, T, F]\n\n        batch0 = batch.view(batch_size, -1, batch.size(1),\n                            batch.size(2))  # [B, M, T, F]\n        # ilens = torch.full((batch_size,), n_samples, dtype=torch.long)\n        batch = torch.cat((batch0.real, batch0.imag), dim=1)  # [B, 2*M, T, F]\n        n_batch, _, n_frames, n_freqs = batch.shape\n\n        batch = self.conv(batch)  # [B, -1, T, F]\n\n        predict_speaker_label = torch.tensor(0.0).to(\n            spk_emb_input.device)  # dummy\n        if self.joint_training:\n            if not self.spk_feat:\n                if self.feat_type == \"consistent\":\n                    with torch.no_grad():\n                        spk_emb_input = self.preEmphasis(spk_emb_input)\n                        spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8\n                        spk_emb_input = spk_emb_input.log()\n                        spk_emb_input = spk_emb_input - torch.mean(\n                            spk_emb_input, dim=-1, keepdim=True)\n                        spk_emb_input = spk_emb_input.permute(0, 2, 1)\n\n            tmp_spk_emb_input = self.spk_model(spk_emb_input)\n            if 
isinstance(tmp_spk_emb_input, tuple):\n                spk_emb_input = tmp_spk_emb_input[-1]\n            else:\n                spk_emb_input = tmp_spk_emb_input\n            predict_speaker_label = self.pred_linear(spk_emb_input)\n\n        spk_embedding = self.spk_transform(spk_emb_input)  # [B, D]\n        spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)  # [B, 1, D, 1]\n\n        for ii in range(self.n_layers):\n            batch = torch.transpose(\n                self.spk_fuse(batch.transpose(2, 3), spk_embedding), 2,\n                3)  # [B, -1, T, F]\n            batch = self.blocks[ii](batch)  # [B, -1, T, F]\n\n        batch = self.deconv(batch)  # [B, n_srcs*2, T, F]\n\n        batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs])\n        assert is_torch_1_9_plus, \"Require torch 1.9.0+.\"\n        batch = torch.complex(batch[:, :, 0], batch[:, :, 1])\n\n        batch = torch.istft(\n            torch.transpose(batch.view(-1, n_frames, n_freqs), 1, 2),\n            n_fft=self.n_fft,\n            hop_length=self.stride,\n            win_length=self.n_fft,\n            window=window,\n            onesided=True,\n            length=n_samples,\n            return_complex=False,\n        )  # [B, n_srcs]\n\n        batch = self.pad2(batch.view([n_batch, self.num_spk, -1]), n_samples)\n\n        batch = batch * mix_std_  # reverse the RMS normalization\n\n        # batch = [batch[:, src] for src in range(self.num_spk)]\n        batch = batch.squeeze(1)\n\n        return batch, predict_speaker_label\n\n    @property\n    def num_spk(self):\n        return self.n_srcs\n\n    @staticmethod\n    def pad2(input_tensor, target_len):\n        input_tensor = torch.nn.functional.pad(\n            input_tensor, (0, target_len - input_tensor.shape[-1]))\n        return input_tensor\n"
  },
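Note: the TFGridNet docstring above recommends variance-normalizing mixture and target with the same statistic when training with a non-scale-invariant loss. A minimal sketch of that normalization for the single-channel case (with [batch, samples, microphones] input the docstring uses torch.std(mixture, (1, 2)) instead):

import torch

mix = torch.randn(4, 32000)                   # batch of mixtures
tgt = torch.randn(4, 32000)                   # corresponding targets

std_ = torch.std(mix, dim=-1, keepdim=True)   # per-utterance standard deviation
mix = mix / std_
tgt = tgt / std_                              # normalize the target with the same statistic
# forward() divides the input by mix_std_ internally and multiplies it back on the output.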
  {
    "path": "wesep/modules/__init__.py",
    "content": ""
  },
  {
    "path": "wesep/modules/common/__init__.py",
    "content": "from wesep.modules.common.norm import ChannelWiseLayerNorm  # noqa\nfrom wesep.modules.common.norm import FiLM  # noqa\nfrom wesep.modules.common.norm import GlobalChannelLayerNorm  # noqa\nfrom wesep.modules.common.norm import select_norm  # noqa\n"
  },
  {
    "path": "wesep/modules/common/norm.py",
    "content": "import numbers\n\nimport torch\nimport torch.nn as nn\n\n\nclass GlobalChannelLayerNorm(nn.Module):\n    \"\"\"\n    Calculate Global Layer Normalization\n    dim: (int or list or torch.Size) –\n         input shape from an expected input of size\n    eps: a value added to the denominator for numerical stability.\n    elementwise_affine: a boolean value that when set to True,\n        this module has learnable per-element affine parameters\n        initialized to ones (for weights) and zeros (for biases).\n    \"\"\"\n\n    def __init__(self, dim, eps=1e-05, elementwise_affine=True):\n        super(GlobalChannelLayerNorm, self).__init__()\n        self.dim = dim\n        self.eps = eps\n        self.elementwise_affine = elementwise_affine\n\n        if self.elementwise_affine:\n            self.weight = nn.Parameter(torch.ones(self.dim, 1))\n            self.bias = nn.Parameter(torch.zeros(self.dim, 1))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n\n    def forward(self, x):\n        # x = N x C x L\n        # N x 1 x 1\n        # cln: mean,var N x 1 x L\n        # gln: mean,var N x 1 x 1\n        if x.dim() != 3:\n            raise RuntimeError(\"{} accept 3D tensor as input\".format(\n                self.__name__))\n\n        mean = torch.mean(x, (1, 2), keepdim=True)\n        var = torch.mean((x - mean)**2, (1, 2), keepdim=True)\n        # N x C x L\n        if self.elementwise_affine:\n            x = (self.weight * (x - mean) / torch.sqrt(var + self.eps) +\n                 self.bias)\n        else:\n            x = (x - mean) / torch.sqrt(var + self.eps)\n        return x\n\n\nclass ChannelWiseLayerNorm(nn.LayerNorm):\n    \"\"\"\n    Channel wise layer normalization\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)\n\n    def forward(self, x):\n        \"\"\"\n        x: N x C x T\n        \"\"\"\n        x = torch.transpose(x, 1, 2)\n        x = super().forward(x)\n        x = torch.transpose(x, 1, 2)\n        return x\n\n\ndef select_norm(norm, dim):\n    \"\"\"\n    Build normalize layer\n    LN cost more memory than BN\n    \"\"\"\n    if norm not in [\"cLN\", \"gLN\", \"BN\"]:\n        raise RuntimeError(\"Unsupported normalize layer: {}\".format(norm))\n    if norm == \"cLN\":\n        return ChannelWiseLayerNorm(dim, elementwise_affine=True)\n    elif norm == \"BN\":\n        return nn.BatchNorm1d(dim)\n    else:\n        return GlobalChannelLayerNorm(dim, elementwise_affine=True)\n\n\nclass FiLM(nn.Module):\n    \"\"\"Feature-wise Linear Modulation (FiLM) layer\n    https://github.com/HuangZiliAndy/fairseq/blob/multispk/fairseq/models/wavlm/WavLM.py#L1160  # noqa\n    \"\"\"\n\n    def __init__(self,\n                 feat_size,\n                 embed_size,\n                 num_film_layers=1,\n                 layer_norm=False):\n        super(FiLM, self).__init__()\n        self.feat_size = feat_size\n        self.embed_size = embed_size\n        self.num_film_layers = num_film_layers\n        self.layer_norm = nn.LayerNorm(embed_size) if layer_norm else None\n        gamma_fcs, beta_fcs = [], []\n        for i in range(num_film_layers):\n            if i == 0:\n                gamma_fcs.append(nn.Linear(embed_size, feat_size))\n                beta_fcs.append(nn.Linear(embed_size, feat_size))\n            else:\n                gamma_fcs.append(nn.Linear(feat_size, feat_size))\n                
beta_fcs.append(nn.Linear(feat_size, feat_size))\n        self.gamma_fcs = nn.ModuleList(gamma_fcs)\n        self.beta_fcs = nn.ModuleList(beta_fcs)\n        self.init_weights()\n\n    def init_weights(self):\n        for i in range(self.num_film_layers):\n            nn.init.zeros_(self.gamma_fcs[i].weight)\n            nn.init.zeros_(self.gamma_fcs[i].bias)\n            nn.init.zeros_(self.beta_fcs[i].weight)\n            nn.init.zeros_(self.beta_fcs[i].bias)\n\n    def forward(self, embed, x):\n        gamma, beta = None, None\n        for i in range(len(self.gamma_fcs)):\n            if i == 0:\n                gamma = self.gamma_fcs[i](embed)\n                beta = self.beta_fcs[i](embed)\n            else:\n                gamma = self.gamma_fcs[i](gamma)\n                beta = self.beta_fcs[i](beta)\n\n        if len(gamma.shape) < len(x.shape):\n            gamma = gamma.unsqueeze(-1).expand_as(x)\n            beta = beta.unsqueeze(-1).expand_as(x)\n        else:\n            gamma = gamma.expand_as(x)\n            beta = beta.expand_as(x)\n\n        # print(gamma.size(), beta.size())\n        x = (1 + gamma) * x + beta\n        if self.layer_norm is not None:\n            x = self.layer_norm(x)\n        return x\n\n\nclass ConditionalLayerNorm(nn.Module):\n    \"\"\"\n    https://github.com/HuangZiliAndy/fairseq/blob/multispk/fairseq/models/wavlm/WavLM.py#L1160\n    \"\"\"\n\n    def __init__(self,\n                 normalized_shape,\n                 embed_dim,\n                 modulate_bias=False,\n                 eps=1e-5):\n        super(ConditionalLayerNorm, self).__init__()\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape, )\n        self.normalized_shape = tuple(normalized_shape)\n\n        self.embed_dim = embed_dim\n        self.eps = eps\n\n        self.weight = nn.Parameter(torch.empty(*normalized_shape))\n        self.bias = nn.Parameter(torch.empty(*normalized_shape))\n        assert len(normalized_shape) == 1\n        self.ln_weight_modulation = FiLM(normalized_shape[0], embed_dim)\n        self.modulate_bias = modulate_bias\n        if self.modulate_bias:\n            self.ln_bias_modulation = FiLM(normalized_shape[0], embed_dim)\n        else:\n            self.ln_bias_modulation = None\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, input, embed):\n        mean = torch.mean(input, -1, keepdim=True)\n        var = torch.var(input, -1, unbiased=False, keepdim=True)\n        weight = self.ln_weight_modulation(\n            embed, self.weight.expand(embed.size(0), -1))\n        if self.ln_bias_modulation is None:\n            bias = self.bias\n        else:\n            bias = self.ln_bias_modulation(embed,\n                                           self.bias.expand(embed.size(0), -1))\n        res = (input - mean) / torch.sqrt(var + self.eps) * weight + bias\n        return res\n\n    def extra_repr(self):\n        return \"{normalized_shape}, {embed_dim}, \\\n            modulate_bias={modulate_bias}, eps={eps}\".format(**self.__dict__)\n"
  },
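Note: FiLM above predicts a per-feature scale and shift from a conditioning embedding and applies (1 + gamma) * x + beta, broadcasting over time. A small usage sketch of that module; the shapes are chosen for illustration (with the zero-initialized layers the output initially equals the input):

import torch
from wesep.modules.common.norm import FiLM

film = FiLM(feat_size=128, embed_size=256)
embed = torch.randn(4, 256)            # conditioning embedding, e.g. a speaker embedding
x = torch.randn(4, 128, 1000)          # features: batch x feat_size x frames

y = film(embed, x)                     # gamma/beta are broadcast over the time axis
print(y.shape)                         # torch.Size([4, 128, 1000])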
  {
    "path": "wesep/modules/common/speaker.py",
    "content": "from typing import Optional\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nfrom wesep.modules.common import FiLM\r\n\r\n\r\nclass PreEmphasis(torch.nn.Module):\r\n\r\n    def __init__(self, coef: float = 0.97):\r\n        super().__init__()\r\n        self.coef = coef\r\n        self.register_buffer(\r\n            \"flipped_filter\",\r\n            torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),\r\n        )\r\n\r\n    def forward(self, input: torch.tensor) -> torch.tensor:\r\n        input = input.unsqueeze(1)\r\n        input = F.pad(input, (1, 0), \"reflect\")\r\n        return F.conv1d(input, self.flipped_filter).squeeze(1)\r\n\r\n\r\nclass SpeakerTransform(nn.Module):\r\n\r\n    def __init__(self, embed_dim=256, num_layers=3, hid_dim=128):\r\n        \"\"\"\r\n        Transform the pretrained speaker embeddings, keep the dimension\r\n        :param embed_dim:\r\n        :param num_layers:\r\n        :param hid_dim:\r\n        :return:\r\n        \"\"\"\r\n        super(SpeakerTransform, self).__init__()\r\n        self.transforms = []\r\n        self.transforms.append(nn.Conv1d(embed_dim, hid_dim, 1))\r\n        for _ in range(num_layers - 2):\r\n            self.transforms.append(nn.Conv1d(hid_dim, hid_dim, 1))\r\n            self.transforms.append(nn.Tanh())\r\n        self.transforms.append(nn.Conv1d(hid_dim, embed_dim, 1))\r\n        self.transforms = nn.Sequential(*self.transforms)\r\n\r\n    def forward(self, x):\r\n        if len(x.size()) == 2:\r\n            return self.transforms(x.unsqueeze(-1)).squeeze(-1)\r\n        else:\r\n            return self.transforms(x)\r\n\r\n\r\nclass LinearLayer(nn.Module):\r\n\r\n    def __init__(self, in_features, out_features, bias=True):\r\n        super(LinearLayer, self).__init__()\r\n\r\n        self.linear = nn.Linear(in_features, out_features, bias)\r\n\r\n    def forward(self, x, dummy: Optional[torch.Tensor] = None):\r\n        return self.linear(x)\r\n\r\n\r\nclass SpeakerFuseLayer(nn.Module):\r\n\r\n    def __init__(self, embed_dim=256, feat_dim=512, fuse_type=\"concat\"):\r\n        super(SpeakerFuseLayer, self).__init__()\r\n        assert fuse_type in [\"concat\", \"additive\", \"multiply\", \"FiLM\", \"None\"]\r\n\r\n        self.fuse_type = fuse_type\r\n        if fuse_type == \"concat\":\r\n            self.fc = LinearLayer(embed_dim + feat_dim, feat_dim)\r\n        elif fuse_type == \"additive\":\r\n            self.fc = LinearLayer(embed_dim, feat_dim)\r\n        elif fuse_type == \"multiply\":\r\n            self.fc = LinearLayer(embed_dim, feat_dim)\r\n        elif fuse_type == \"FiLM\":\r\n            self.fc = FiLM(feat_dim, embed_dim)\r\n        else:\r\n            raise ValueError(\"Fuse type not defined.\")\r\n\r\n    def forward(self, x, embed):\r\n        \"\"\"\r\n\r\n        :param x: batch x dimension x length\r\n        :param embed: batch x dimension x 1\r\n        :return:\r\n        \"\"\"\r\n        if self.fuse_type == \"concat\":\r\n            # For Conv\r\n            if len(x.size()) == 3:\r\n                embed_t = embed.expand(-1, -1, x.size(2))\r\n                y = torch.cat([x, embed_t], 1)\r\n                y = torch.transpose(y, 1, 2)\r\n                x = torch.transpose(self.fc(y), 1, 2)\r\n            else:\r\n                # len(x.size() == 4\r\n                embed_t = embed.expand(-1, x.size(1), -1, x.size(3))\r\n                y = torch.cat([x, embed_t], 2)\r\n                y = torch.transpose(y, 2, 
3)\r\n                x = torch.transpose(self.fc(y), 2, 3).contiguous()\r\n                # print(x.size())\r\n        elif self.fuse_type == \"additive\":\r\n            if len(x.size()) == 3:\r\n                embed_t = embed.expand(-1, -1, x.size(2))\r\n                embed_t = torch.transpose(embed_t, 1, 2)\r\n                x = x + torch.transpose(self.fc(embed_t), 1, 2)\r\n            else:\r\n                # len(x.size() == 4\r\n                embed_t = embed.expand(-1, x.size(1), -1, x.size(3))\r\n                embed_t = torch.transpose(embed_t, 2, 3)\r\n                x = x + torch.transpose(self.fc(embed_t), 2, 3)\r\n        elif self.fuse_type == \"multiply\":\r\n            if len(x.size()) == 3:\r\n                embed_t = embed.expand(-1, -1, x.size(2))\r\n                embed_t = torch.transpose(embed_t, 1, 2)\r\n                x = x * torch.transpose(self.fc(embed_t), 1, 2)\r\n            else:\r\n                # len(x.size() == 4\r\n                embed_t = embed.expand(-1, x.size(1), -1, x.size(3))\r\n                embed_t = torch.transpose(embed_t, 2, 3)\r\n                x = x * torch.transpose(self.fc(embed_t), 2, 3)\r\n        else:\r\n            embed = embed.squeeze(-1)\r\n            x = self.fc(embed, x)\r\n        return x\r\n\r\n\r\ndef test_speaker_fuse():\r\n    st = SpeakerTransform(embed_dim=256, num_layers=3, hid_dim=128)\r\n    sfl = SpeakerFuseLayer(fuse_type=\"multiply\")\r\n\r\n    embeds = torch.rand(4, 256)\r\n    encoder_output = torch.rand(4, 512, 1000)\r\n\r\n    print(embeds.size())\r\n    embeds = st(embeds)\r\n    print(embeds.size())\r\n    output = sfl(encoder_output, embeds)\r\n    print(output.size())\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    test_speaker_fuse()\r\n"
  },
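Note: PreEmphasis above implements the standard pre-emphasis filter y[t] = x[t] - coef * x[t-1] as a fixed 1-D convolution with reflect padding. A short check of that behaviour against the direct formula (interior samples only, since the very first output sample depends on the reflected pad):

import torch
from wesep.modules.common.speaker import PreEmphasis

x = torch.randn(2, 16000)
pre = PreEmphasis(coef=0.97)
y = pre(x)

# Direct formula for t >= 1; the first output sample uses the reflected pad value.
y_ref = x[:, 1:] - 0.97 * x[:, :-1]
print(torch.allclose(y[:, 1:], y_ref, atol=1e-6))   # expected: True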
  {
    "path": "wesep/modules/dpccn/__init__.py",
    "content": ""
  },
  {
    "path": "wesep/modules/dpccn/convs.py",
    "content": "from typing import Tuple\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\n\r\nclass Conv1D(nn.Conv1d):\r\n    \"\"\"\r\n    1D conv in ConvTasNet\r\n    \"\"\"\r\n\r\n    def __init__(self, *args, **kwargs):\r\n        super(Conv1D, self).__init__(*args, **kwargs)\r\n\r\n    def forward(self, x, squeeze=False):\r\n        \"\"\"\r\n        x: N x L or N x C x L\r\n        \"\"\"\r\n        if x.dim() not in [2, 3]:\r\n            raise RuntimeError(\"{} accept 2/3D tensor as input\".format(\r\n                self.__name__))\r\n        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))\r\n        if squeeze:\r\n            x = torch.squeeze(x)\r\n        return x\r\n\r\n\r\nclass Conv2dBlock(nn.Module):\r\n\r\n    def __init__(\r\n            self,\r\n            in_dims: int = 16,\r\n            out_dims: int = 32,\r\n            kernel_size: Tuple[int] = (3, 3),\r\n            stride: Tuple[int] = (1, 1),\r\n            padding: Tuple[int] = (1, 1),\r\n    ) -> None:\r\n        super(Conv2dBlock, self).__init__()\r\n        self.conv2d = nn.Conv2d(in_dims, out_dims, kernel_size, stride,\r\n                                padding)\r\n        self.elu = nn.ELU()\r\n        self.norm = nn.InstanceNorm2d(out_dims)\r\n\r\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\r\n        x = self.conv2d(x)\r\n        x = self.elu(x)\r\n        return self.norm(x)\r\n\r\n\r\nclass ConvTrans2dBlock(nn.Module):\r\n\r\n    def __init__(\r\n            self,\r\n            in_dims: int = 32,\r\n            out_dims: int = 16,\r\n            kernel_size: Tuple[int] = (3, 3),\r\n            stride: Tuple[int] = (1, 2),\r\n            padding: Tuple[int] = (1, 0),\r\n            output_padding: Tuple[int] = (0, 0),\r\n    ) -> None:\r\n        super(ConvTrans2dBlock, self).__init__()\r\n        self.convtrans2d = nn.ConvTranspose2d(in_dims, out_dims, kernel_size,\r\n                                              stride, padding, output_padding)\r\n        self.elu = nn.ELU()\r\n        self.norm = nn.InstanceNorm2d(out_dims)\r\n\r\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\r\n        x = self.convtrans2d(x)\r\n        x = self.elu(x)\r\n        return self.norm(x)\r\n\r\n\r\nclass DenseBlock(nn.Module):\r\n\r\n    def __init__(self, in_dims, out_dims, mode=\"enc\", **kargs):\r\n        super(DenseBlock, self).__init__()\r\n        if mode not in [\"enc\", \"dec\"]:\r\n            raise RuntimeError(\"The mode option must be 'enc' or 'dec'!\")\r\n\r\n        n = 1 if mode == \"enc\" else 2\r\n        self.conv1 = Conv2dBlock(in_dims=in_dims * n,\r\n                                 out_dims=in_dims,\r\n                                 **kargs)\r\n        self.conv2 = Conv2dBlock(in_dims=in_dims * (n + 1),\r\n                                 out_dims=in_dims,\r\n                                 **kargs)\r\n        self.conv3 = Conv2dBlock(in_dims=in_dims * (n + 2),\r\n                                 out_dims=in_dims,\r\n                                 **kargs)\r\n        self.conv4 = Conv2dBlock(in_dims=in_dims * (n + 3),\r\n                                 out_dims=in_dims,\r\n                                 **kargs)\r\n        self.conv5 = Conv2dBlock(in_dims=in_dims * (n + 4),\r\n                                 out_dims=out_dims,\r\n                                 **kargs)\r\n\r\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\r\n        y1 = self.conv1(x)\r\n        y2 = self.conv2(torch.cat([x, y1], 1))\r\n        y3 = 
self.conv3(torch.cat([x, y1, y2], 1))\r\n        y4 = self.conv4(torch.cat([x, y1, y2, y3], 1))\r\n        y5 = self.conv5(torch.cat([x, y1, y2, y3, y4], 1))\r\n        return y5\r\n\r\n\r\nclass TCNBlock(nn.Module):\r\n    \"\"\"\r\n    TCN block:\r\n        IN - ELU - Conv1D - IN - ELU - Conv1D\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        in_dims: int = 384,\r\n        out_dims: int = 384,\r\n        kernel_size: int = 3,\r\n        dilation: int = 1,\r\n        causal: bool = False,\r\n    ) -> None:\r\n        super(TCNBlock, self).__init__()\r\n        self.norm1 = nn.InstanceNorm1d(in_dims)\r\n        self.elu1 = nn.ELU()\r\n        dconv_pad = ((dilation * (kernel_size - 1)) // 2 if not causal else\r\n                     (dilation * (kernel_size - 1)))\r\n        # dilated conv\r\n        self.dconv1 = nn.Conv1d(\r\n            in_dims,\r\n            out_dims,\r\n            kernel_size,\r\n            padding=dconv_pad,\r\n            dilation=dilation,\r\n            groups=in_dims,\r\n            bias=True,\r\n        )\r\n\r\n        self.norm2 = nn.InstanceNorm1d(in_dims)\r\n        self.elu2 = nn.ELU()\r\n        self.dconv2 = nn.Conv1d(in_dims, out_dims, 1, bias=True)\r\n\r\n        # different padding way\r\n        self.causal = causal\r\n        self.dconv_pad = dconv_pad\r\n\r\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\r\n        y = self.elu1(self.norm1(x))\r\n        y = self.dconv1(y)\r\n        if self.causal:\r\n            y = y[:, :, :-self.dconv_pad]\r\n        y = self.elu2(self.norm2(y))\r\n        y = self.dconv2(y)\r\n        x = x + y\r\n        return x\r\n"
  },
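Note: the TCN blocks above keep the frame count unchanged via symmetric padding of the dilated depthwise convolution (trimming the extra right context in the causal case) and add a residual connection. A small usage sketch with illustrative shapes:

import torch
from wesep.modules.dpccn.convs import TCNBlock

block = TCNBlock(in_dims=384, out_dims=384, kernel_size=3, dilation=4, causal=False)
x = torch.randn(2, 384, 100)     # batch x channels x frames
y = block(x)
print(y.shape)                   # torch.Size([2, 384, 100]) -- residual output, length preserved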
  {
    "path": "wesep/modules/metric_gan/__init__.py",
    "content": ""
  },
  {
    "path": "wesep/modules/metric_gan/discriminator.py",
    "content": "import torch\nimport torch.nn as nn\n\n\n# utility functions/classes used in the implementation of discriminators.\nclass LearnableSigmoid(nn.Module):\n\n    def __init__(self, in_features, beta=1):\n        super().__init__()\n        self.beta = beta\n        self.slope = nn.Parameter(torch.ones(in_features))\n        self.slope.requiresGrad = True\n\n    def forward(self, x):\n        return self.beta * torch.sigmoid(self.slope * x)\n\n\n# discriminators\nclass CMGAN_Discriminator(nn.Module):\n\n    def __init__(\n            self,\n            n_fft=400,\n            hop=100,\n            in_channels=2,\n            hid_chans=16,\n            ksz=(4, 4),\n            stride=(2, 2),\n            padding=(1, 1),\n            bias=False,\n            num_conv_blocks=4,\n            num_linear_layers=2,\n    ):\n        \"\"\"discriminator used in CMGAN (Interspeech 2022)\n            paper: https://arxiv.org/pdf/2203.15149.pdf\n            code: https://github.com/ruizhecao96/CMGAN\n\n        Args:\n        n_fft (int, optional): the windows length of stft. Defaults to 400.\n        hop (int, optional): the hop length of stft. Defaults to 100.\n        in_channels (int, optional): num of input channels. Defaults to 2.\n        hid_chans (int, optional): num of hidden channels. Defaults to 16.\n        ksz (tuple, optional): kernel size. Defaults to (4, 4).\n        stride (tuple, optional): stride. Defaults to (2, 2).\n        padding (tuple, optional): padding. Defaults to (1, 1).\n        bias (bool, optional): bias. Defaults to False.\n        num_conv_blocks (int, optional): num of conv blocks. Defaults to 4.\n        num_linear_layers (int, optional): num of linear layers. Defaults to 2.\n        \"\"\"\n        super(CMGAN_Discriminator, self).__init__()\n        assert num_conv_blocks >= num_linear_layers\n\n        self.n_fft = n_fft\n        self.hop = hop\n        self.num_conv_blocks = num_conv_blocks\n        self.num_linear_layers = num_linear_layers\n\n        self.conv = nn.ModuleList([])\n        in_chans = in_channels\n        out_chans = hid_chans\n        for i in range(num_conv_blocks):\n            self.conv.append(\n                nn.Sequential(\n                    nn.utils.spectral_norm(\n                        nn.Conv2d(\n                            in_chans,\n                            out_chans,\n                            ksz,\n                            stride,\n                            padding,\n                            bias=bias,\n                        )),\n                    nn.InstanceNorm2d(out_chans, affine=True),\n                    nn.PReLU(out_chans),\n                ))\n            in_chans = out_chans\n            out_chans = hid_chans * (2**(i + 1))\n\n        self.pooling = nn.Sequential(\n            nn.AdaptiveMaxPool2d(1),\n            nn.Flatten(),\n        )\n\n        self.fc = nn.ModuleList([])\n        for i in range(num_linear_layers - 1):\n            self.fc.append(\n                nn.Sequential(\n                    nn.utils.spectral_norm(\n                        nn.Linear(\n                            hid_chans * (2**(num_conv_blocks - 1 - i)),\n                            hid_chans * (2**(num_conv_blocks - 2 - i)),\n                        )),\n                    nn.Dropout(0.3),\n                    nn.PReLU(hid_chans * (2**(num_conv_blocks - 2 - i))),\n                ))\n        self.fc.append(\n            nn.Sequential(\n                nn.utils.spectral_norm(\n                    nn.Linear(\n  
                      hid_chans * (2**(num_conv_blocks - num_linear_layers)),\n                        1,\n                    )),\n                LearnableSigmoid(1),\n            ))\n\n    def forward(self, ref_wav, est_wav):\n        \"\"\"\n\n        Args:\n            ref_wav (torch.Tensor): the reference signal. [B, T]\n            est_wav (torch.Tensor): the estimated signal. [B, T]\n\n        Return:\n            estimated_scores (torch.Tensor): estimated scores, [B]\n        \"\"\"\n        ref_spec = torch.stft(\n            ref_wav,\n            self.n_fft,\n            self.hop,\n            window=torch.hann_window(self.n_fft).to(ref_wav.device).type(\n                ref_wav.type()),\n            return_complex=True,\n        ).transpose(-1, -2)\n        est_spec = torch.stft(\n            est_wav,\n            self.n_fft,\n            self.hop,\n            window=torch.hann_window(self.n_fft).to(est_wav.device).type(\n                est_wav.type()),\n            return_complex=True,\n        ).transpose(-1, -2)\n        # input shape: (B, 2, T, F)\n        input = torch.stack((abs(ref_spec), abs(est_spec)), dim=1)\n        for i in range(self.num_conv_blocks):\n            input = self.conv[i](input)\n\n        input = self.pooling(input)\n        for i in range(self.num_linear_layers):\n            input = self.fc[i](input)\n        return input\n\n\nif __name__ == \"__main__\":\n    # functions used to test discriminators\n    def test_CMGAN_Discriminator():\n        B, T = 2, 16000\n        ref_spec = torch.randn(B, T)\n        est_spec = torch.randn(B, T)\n        D = CMGAN_Discriminator()\n        metric = D(ref_spec, est_spec).detach()\n        print(f\"estimated metric score is {metric}\")\n\n    test_CMGAN_Discriminator()\n"
  },
  {
    "path": "wesep/modules/tasnet/__init__.py",
    "content": "from wesep.modules.tasnet.decoder import DeepDecoder  # noqa\r\nfrom wesep.modules.tasnet.decoder import MultiDecoder  # noqa\r\nfrom wesep.modules.tasnet.encoder import DeepEncoder  # noqa\r\nfrom wesep.modules.tasnet.encoder import MultiEncoder  # noqa\r\nfrom wesep.modules.tasnet.separation import Separation, FuseSeparation  # noqa\r\nfrom wesep.modules.tasnet.speaker import ResNet4SpExplus  # noqa\r\n"
  },
  {
    "path": "wesep/modules/tasnet/convs.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom wesep.modules.common import select_norm\n\n# from wesep.modules.common.spkadapt import SpeakerFuseLayer\n\n\nclass Conv1D(nn.Conv1d):\n\n    def __init__(self, *args, **kwargs):\n        super(Conv1D, self).__init__(*args, **kwargs)\n\n    def forward(self, x, squeeze=False):\n        # x: N x C x L\n        if x.dim() not in [2, 3]:\n            raise RuntimeError(\"{} accept 2/3D tensor as input\".format(\n                self.__name__))\n        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))\n        if squeeze:\n            x = torch.squeeze(x)\n        return x\n\n\nclass ConvTrans1D(nn.ConvTranspose1d):\n\n    def __init__(self, *args, **kwargs):\n        super(ConvTrans1D, self).__init__(*args, **kwargs)\n\n    def forward(self, x, squeeze=False):\n        \"\"\"\n        x: N x L or N x C x L\n        \"\"\"\n        if x.dim() not in [2, 3]:\n            raise RuntimeError(\"{} accept 2/3D tensor as input\".format(\n                self.__name__))\n        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))\n        if squeeze:\n            x = torch.squeeze(x)\n        return x\n\n\nclass Conv1DBlock(nn.Module):\n    \"\"\"\n    Consider only residual links\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels=256,\n        out_channels=512,\n        kernel_size=3,\n        dilation=1,\n        norm=\"gln\",\n        causal=False,\n        skip_con=True,\n    ):\n        super(Conv1DBlock, self).__init__()\n        # conv 1 x 1\n        self.conv1x1 = Conv1D(in_channels, out_channels, 1)\n        self.PReLU_1 = nn.PReLU()\n        self.norm_1 = select_norm(norm, out_channels)\n        # not causal don't need to padding, causal need to pad+1 = kernel_size\n        self.pad = ((dilation * (kernel_size - 1)) // 2 if not causal else\n                    (dilation * (kernel_size - 1)))\n        # depthwise convolution\n        # TODO: This is not depthwise seperable convolution\n        self.dwconv = Conv1D(\n            out_channels,\n            out_channels,\n            kernel_size,\n            groups=out_channels,\n            padding=self.pad,\n            dilation=dilation,\n        )\n        self.PReLU_2 = nn.PReLU()\n        self.norm_2 = select_norm(norm, out_channels)\n        if skip_con:\n            self.Sc_conv = nn.Conv1d(out_channels, in_channels, 1, bias=True)\n        self.Output = nn.Conv1d(out_channels, in_channels, 1, bias=True)\n        self.causal = causal\n        self.skip_con = skip_con\n\n    def forward(self, x):\n        # x: N x C x L\n        # N x O_C x L\n        c = self.conv1x1(x)\n        # N x O_C x L\n        c = self.PReLU_1(c)\n        c = self.norm_1(c)\n        # causal: N x O_C x (L+pad)\n        # noncausal: N x O_C x L\n        c = self.dwconv(c)\n        if self.causal:\n            c = c[:, :, :-self.pad]\n        c = self.PReLU_2(c)\n        c = self.norm_2(c)\n        # N x O_C x L\n        if self.skip_con:\n            Sc = self.Sc_conv(c)\n            c = self.Output(c)\n            return Sc, c + x\n        c = self.Output(c)\n        return x + c\n\n\nclass Conv1DBlock4Fuse(nn.Module):\n    \"\"\"\n    1D convolutional block:\n        Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels=256,\n        spk_embed_dim=100,\n        conv_channels=512,\n        kernel_size=3,\n        dilation=1,\n        norm=\"cLN\",\n        causal=False,\n    ):\n        
super(Conv1DBlock4Fuse, self).__init__()\n        # 1x1 conv\n        self.conv1x1 = Conv1D(in_channels + spk_embed_dim, conv_channels, 1)\n        self.prelu1 = nn.PReLU()\n        self.lnorm1 = select_norm(norm, conv_channels)\n        dconv_pad = ((dilation * (kernel_size - 1)) // 2 if not causal else\n                     (dilation * (kernel_size - 1)))\n        # depthwise conv\n        self.dconv = nn.Conv1d(\n            conv_channels,\n            conv_channels,\n            kernel_size,\n            groups=conv_channels,\n            padding=dconv_pad,\n            dilation=dilation,\n            bias=True,\n        )\n        self.prelu2 = nn.PReLU()\n        self.lnorm2 = select_norm(norm, conv_channels)\n        # 1x1 conv cross channel\n        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)\n        # different padding way\n        self.causal = causal\n        self.dconv_pad = dconv_pad\n\n    def forward(self, x, aux):\n        T = x.shape[-1]\n        aux = aux.repeat(1, 1, T)\n        y = torch.cat([x, aux], 1)\n        y = self.conv1x1(y)\n        y = self.lnorm1(self.prelu1(y))\n        y = self.dconv(y)\n        if self.causal:\n            y = y[:, :, :-self.dconv_pad]\n        y = self.lnorm2(self.prelu2(y))\n        y = self.sconv(y)\n        x = x + y\n        return x\n"
  },
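  {
    "path": "examples/sketch_tasnet_conv1dblock.py",
    "content": "# Hypothetical usage sketch (not part of the original wesep sources): a quick\n# shape check for the residual Conv1DBlock defined in\n# wesep/modules/tasnet/convs.py. The channel sizes and sequence length are\n# arbitrary assumptions.\nimport torch\n\nfrom wesep.modules.tasnet.convs import Conv1DBlock\n\n# With skip_con=True the block returns (skip, residual_output); with\n# skip_con=False it returns only the residual output. Both are shaped\n# [B, in_channels, L], i.e. the block preserves channels and length.\nblock = Conv1DBlock(in_channels=64,\n                    out_channels=128,\n                    kernel_size=3,\n                    dilation=2,\n                    norm=\"gln\",\n                    causal=False,\n                    skip_con=True)\nx = torch.randn(2, 64, 100)\nskip, out = block(x)\nprint(skip.shape, out.shape)  # both torch.Size([2, 64, 100])\n"
  },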
  {
    "path": "wesep/modules/tasnet/decoder.py",
    "content": "import torch\r\nimport torch.nn as nn\r\n\r\nfrom wesep.modules.tasnet.convs import Conv1D, ConvTrans1D\r\n\r\n\r\nclass DeepDecoder(nn.Module):\r\n\r\n    def __init__(self, N, kernel_size=16, stride=16 // 2):\r\n        super(DeepDecoder, self).__init__()\r\n        self.sequential = nn.Sequential(\r\n            nn.ConvTranspose1d(N,\r\n                               N,\r\n                               kernel_size=3,\r\n                               stride=1,\r\n                               dilation=8,\r\n                               padding=8),\r\n            nn.PReLU(),\r\n            nn.ConvTranspose1d(N,\r\n                               N,\r\n                               kernel_size=3,\r\n                               stride=1,\r\n                               dilation=4,\r\n                               padding=4),\r\n            nn.PReLU(),\r\n            nn.ConvTranspose1d(N,\r\n                               N,\r\n                               kernel_size=3,\r\n                               stride=1,\r\n                               dilation=2,\r\n                               padding=2),\r\n            nn.PReLU(),\r\n            nn.ConvTranspose1d(N,\r\n                               N,\r\n                               kernel_size=3,\r\n                               stride=1,\r\n                               dilation=1,\r\n                               padding=1),\r\n            nn.PReLU(),\r\n            nn.ConvTranspose1d(N,\r\n                               1,\r\n                               kernel_size=kernel_size,\r\n                               stride=stride,\r\n                               bias=True),\r\n        )\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        x: N x L or N x C x L\r\n        \"\"\"\r\n        x = self.sequential(x)\r\n        if torch.squeeze(x).dim() == 1:\r\n            x = torch.squeeze(x, dim=1)\r\n        else:\r\n            x = torch.squeeze(x)\r\n\r\n        return x\r\n\r\n\r\nclass MultiDecoder(nn.Module):\r\n\r\n    def __init__(self, in_channels, middle_channels, out_channels, kernel_size,\r\n                 stride):\r\n        super(MultiDecoder, self).__init__()\r\n\r\n        B = in_channels\r\n        N = middle_channels\r\n        L = kernel_size\r\n        # n x B x T => n x 2N x T\r\n        self.mask1 = Conv1D(B, N, 1)\r\n        self.mask2 = Conv1D(B, N, 1)\r\n        self.mask3 = Conv1D(B, N, 1)\r\n\r\n        # using ConvTrans1D: n x N x T => n x 1 x To\r\n        # To = (T - 1) * L // 2 + L\r\n        self.decoder_1d_1 = ConvTrans1D(N,\r\n                                        out_channels,\r\n                                        kernel_size=L,\r\n                                        stride=stride,\r\n                                        bias=True)\r\n        self.decoder_1d_2 = ConvTrans1D(N,\r\n                                        out_channels,\r\n                                        kernel_size=80,\r\n                                        stride=stride,\r\n                                        bias=True)\r\n        self.decoder_1d_3 = ConvTrans1D(N,\r\n                                        out_channels,\r\n                                        kernel_size=160,\r\n                                        stride=stride,\r\n                                        bias=True)\r\n\r\n    def forward(self, x, w1, w2, w3, actLayer):\r\n        \"\"\"\r\n        x: N x L or N x C x L\r\n        \"\"\"\r\n        m1 = actLayer(self.mask1(x))\r\n        m2 = 
actLayer(self.mask2(x))\r\n        m3 = actLayer(self.mask3(x))\r\n\r\n        s1 = w1 * m1\r\n        s2 = w2 * m2\r\n        s3 = w3 * m3\r\n\r\n        est1 = self.decoder_1d_1(s1, squeeze=True)\r\n        xlen = est1.shape[-1]\r\n        if est1.dim() > 1:\r\n            est2 = self.decoder_1d_2(s2, squeeze=True)[:, :xlen]\r\n            est3 = self.decoder_1d_3(s3, squeeze=True)[:, :xlen]\r\n        else:\r\n            est1 = est1.unsqueeze(0)\r\n            est2 = self.decoder_1d_2(s2, squeeze=True).unsqueeze(0)[:, :xlen]\r\n            est3 = self.decoder_1d_3(s3, squeeze=True).unsqueeze(0)[:, :xlen]\r\n        s = [est1, est2, est3]\r\n        return s\r\n"
  },
  {
    "path": "wesep/modules/tasnet/encoder.py",
    "content": "import torch as th\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nfrom wesep.modules.common import select_norm\r\nfrom wesep.modules.tasnet.convs import Conv1D\r\n\r\n\r\nclass DeepEncoder(nn.Module):\r\n\r\n    def __init__(self, in_channels, out_channels, kernel_size, stride):\r\n        super(DeepEncoder, self).__init__()\r\n        self.sequential = nn.Sequential(\r\n            Conv1D(in_channels, out_channels, kernel_size, stride=stride),\r\n            Conv1D(\r\n                out_channels,\r\n                out_channels,\r\n                kernel_size=3,\r\n                stride=1,\r\n                dilation=1,\r\n                padding=1,\r\n            ),\r\n            nn.PReLU(),\r\n            Conv1D(\r\n                out_channels,\r\n                out_channels,\r\n                kernel_size=3,\r\n                stride=1,\r\n                dilation=2,\r\n                padding=2,\r\n            ),\r\n            nn.PReLU(),\r\n            Conv1D(\r\n                out_channels,\r\n                out_channels,\r\n                kernel_size=3,\r\n                stride=1,\r\n                dilation=4,\r\n                padding=4,\r\n            ),\r\n            nn.PReLU(),\r\n            Conv1D(\r\n                out_channels,\r\n                out_channels,\r\n                kernel_size=3,\r\n                stride=1,\r\n                dilation=8,\r\n                padding=8,\r\n            ),\r\n            nn.PReLU(),\r\n        )\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        :param  x: [B, T]\r\n        :return: out: [B, N, T]\r\n        \"\"\"\r\n\r\n        x = self.sequential(x)\r\n        return x\r\n\r\n\r\nclass MultiEncoder(nn.Module):\r\n\r\n    def __init__(self, in_channels, middle_channels, out_channels, kernel_size,\r\n                 stride):\r\n        super(MultiEncoder, self).__init__()\r\n        self.L1 = kernel_size\r\n        self.L2 = 80\r\n        self.L3 = 160\r\n        self.encoder_1d_short = Conv1D(in_channels,\r\n                                       middle_channels,\r\n                                       self.L1,\r\n                                       stride=stride,\r\n                                       padding=0)\r\n        self.encoder_1d_middle = Conv1D(in_channels,\r\n                                        middle_channels,\r\n                                        self.L2,\r\n                                        stride=stride,\r\n                                        padding=0)\r\n        self.encoder_1d_long = Conv1D(in_channels,\r\n                                      middle_channels,\r\n                                      self.L3,\r\n                                      stride=stride,\r\n                                      padding=0)\r\n        # keep T not change\r\n        # T = int((xlen - L) / (L // 2)) + 1\r\n        # before repeat blocks, always cLN\r\n        self.ln = select_norm(\r\n            \"cLN\",\r\n            3 * middle_channels)  # ChannelWiseLayerNorm(3 * middle_channels)\r\n        # n x N x T => n x B x T\r\n        self.proj = Conv1D(3 * middle_channels, out_channels, 1)\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        :param  x: [B, T]\r\n        :return: out: [B, N, T]\r\n        \"\"\"\r\n        w1 = F.relu(self.encoder_1d_short(x))\r\n        T = w1.shape[-1]\r\n        xlen1 = x.shape[-1]\r\n        xlen2 = (T - 1) * (self.L1 // 2) + self.L2\r\n        xlen3 = (T - 1) * (self.L1 // 2) + self.L3\r\n 
       w2 = F.relu(\r\n            self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), \"constant\",\r\n                                         0)))\r\n        w3 = F.relu(\r\n            self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), \"constant\", 0)))\r\n        # n x 3N x T\r\n        x = self.ln(th.cat([w1, w2, w3], 1))\r\n        # n x B x T\r\n        x = self.proj(x)\r\n        return x, w1, w2, w3\r\n"
  },
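  {
    "path": "examples/sketch_tasnet_deep_codec.py",
    "content": "# Hypothetical round-trip sketch (not part of the original wesep sources):\n# pairing DeepEncoder with DeepDecoder from wesep/modules/tasnet. With\n# kernel_size=16 and stride=8 the reconstructed waveform matches the input\n# length; the waveform length below is an arbitrary assumption.\nimport torch\n\nfrom wesep.modules.tasnet.decoder import DeepDecoder\nfrom wesep.modules.tasnet.encoder import DeepEncoder\n\nencoder = DeepEncoder(in_channels=1, out_channels=256, kernel_size=16, stride=8)\ndecoder = DeepDecoder(256, kernel_size=16, stride=8)\n\nwav = torch.randn(2, 16000)  # [B, T]\nfeats = encoder(wav)  # [B, 256, (T - 16) // 8 + 1]\nrecon = decoder(feats)  # [B, T]\nprint(feats.shape, recon.shape)\n"
  },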
  {
    "path": "wesep/modules/tasnet/separation.py",
    "content": "import torch.nn as nn\r\n\r\nfrom wesep.modules.common import select_norm\r\nfrom wesep.modules.common.speaker import SpeakerFuseLayer\r\nfrom wesep.modules.tasnet.convs import Conv1DBlock, Conv1DBlock4Fuse\r\n\r\n\r\nclass Separation(nn.Module):\r\n\r\n    def __init__(\r\n        self,\r\n        R,\r\n        X,\r\n        B,\r\n        H,\r\n        P,\r\n        norm=\"gLN\",\r\n        causal=False,\r\n        skip_con=True,\r\n        start_dilation=0,\r\n    ):\r\n        \"\"\"\r\n        Args\r\n        :param R: Number of repeats\r\n        :param X: Number of convolutional blocks in each repeat\r\n        :param B: Number of channels in bottleneck and the residual paths\r\n        :param H: Number of channels in convolutional blocks\r\n        :param P: Kernel size in convolutional blocks\r\n        :param norm: The type of normalization(gln, cln, bn)\r\n        :param causal: Two choice(causal or noncausal)\r\n        :param skip_con: Whether to use skip connection\r\n        \"\"\"\r\n        super(Separation, self).__init__()\r\n        self.separation = nn.ModuleList([])\r\n        for _ in range(R):\r\n            for x in range(start_dilation, X):\r\n                self.separation.append(\r\n                    Conv1DBlock(B, H, P, 2**x, norm, causal, skip_con))\r\n        self.skip_con = skip_con\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        x: [B, N, L]\r\n        out: [B, N, L]\r\n        \"\"\"\r\n        if self.skip_con:\r\n            skip_connection = 0\r\n            for i in range(len(self.separation)):\r\n                skip, out = self.separation[i](x)\r\n                skip_connection = skip_connection + skip\r\n                x = out\r\n            return skip_connection\r\n        else:\r\n            for i in range(len(self.separation)):\r\n                out = self.separation[i](x)\r\n                x = out\r\n            return x\r\n\r\n\r\nclass FuseSeparation(nn.Module):\r\n\r\n    def __init__(\r\n        self,\r\n        R,\r\n        X,\r\n        B,\r\n        H,\r\n        P,\r\n        norm=\"gLN\",\r\n        causal=False,\r\n        skip_con=False,\r\n        C_embedding=256,\r\n        spk_fuse_type=\"concatConv\",\r\n        multi_fuse=True,\r\n    ):\r\n        \"\"\"\r\n\r\n        :param R: Number of repeats\r\n        :param X: Number of convolutional blocks in each repeat\r\n        :param B: Number of channels in bottleneck and the residual paths\r\n        :param H: Number of channels in convolutional blocks\r\n        :param P: Kernel size in convolutional blocks\r\n        :param norm: The type of normalization(gln, cln, bn)\r\n        :param causal: Two choice(causal or noncausal)\r\n        :param skip_con: Whether to use skip connection\r\n        \"\"\"\r\n        super(FuseSeparation, self).__init__()\r\n        self.multi_fuse = multi_fuse\r\n        self.spk_fuse_type = spk_fuse_type\r\n        self.separation = nn.ModuleList([])\r\n        if self.multi_fuse:\r\n            for _ in range(R):\r\n                if spk_fuse_type == \"concatConv\":\r\n                    self.separation.append(\r\n                        Conv1DBlock4Fuse(\r\n                            spk_embed_dim=C_embedding,\r\n                            in_channels=B,\r\n                            conv_channels=H,\r\n                            kernel_size=P,\r\n                            norm=norm,\r\n                            causal=causal,\r\n                            dilation=1,\r\n                        
))\r\n                    self.separation.append(\r\n                        Separation(\r\n                            1,\r\n                            X,\r\n                            B,\r\n                            H,\r\n                            P,\r\n                            norm=norm,\r\n                            causal=causal,\r\n                            skip_con=skip_con,\r\n                            start_dilation=1,\r\n                        ))\r\n                else:\r\n                    self.separation.append(\r\n                        SpeakerFuseLayer(\r\n                            embed_dim=C_embedding,\r\n                            feat_dim=B,\r\n                            fuse_type=spk_fuse_type,\r\n                        ))\r\n                    self.separation.append(nn.PReLU())\r\n                    self.separation.append(select_norm(norm, B))\r\n                    self.separation.append(\r\n                        Separation(\r\n                            1,\r\n                            X,\r\n                            B,\r\n                            H,\r\n                            P,\r\n                            norm=norm,\r\n                            causal=causal,\r\n                            skip_con=skip_con,\r\n                        ))\r\n        else:\r\n            if spk_fuse_type == \"concatConv\":\r\n                self.separation.append(\r\n                    Conv1DBlock4Fuse(\r\n                        spk_embed_dim=C_embedding,\r\n                        in_channels=B,\r\n                        conv_channels=H,\r\n                        kernel_size=P,\r\n                        norm=norm,\r\n                        causal=causal,\r\n                        dilation=1,\r\n                    ))\r\n            else:\r\n                self.separation.append(\r\n                    SpeakerFuseLayer(\r\n                        embed_dim=C_embedding,\r\n                        feat_dim=B,\r\n                        fuse_type=spk_fuse_type,\r\n                    ))\r\n                self.separation.append(nn.PReLU())\r\n                self.separation.append(select_norm(norm, B))\r\n            self.separation = Separation(R,\r\n                                         X,\r\n                                         B,\r\n                                         H,\r\n                                         P,\r\n                                         norm=norm,\r\n                                         causal=causal,\r\n                                         skip_con=skip_con)\r\n\r\n    def forward(self, x, spk_embedding):\r\n        \"\"\"\r\n        x: [B, N, L]\r\n        out: [B, N, L]\r\n        \"\"\"\r\n\r\n        if self.multi_fuse:\r\n            if self.spk_fuse_type == \"concatConv\":\r\n                round_num = 2\r\n            else:\r\n                round_num = 4\r\n            for i in range(len(self.separation)):\r\n                if i % round_num == 0:\r\n                    x = self.separation[i](x, spk_embedding)\r\n                else:\r\n                    x = self.separation[i](x)\r\n        else:\r\n            x = self.separation[0](x, spk_embedding)\r\n            for i in range(1, len(self.separation)):\r\n                x = self.separation[i](x)\r\n        return x\r\n"
  },
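  {
    "path": "examples/sketch_tasnet_separation.py",
    "content": "# Hypothetical smoke test (not part of the original wesep sources) for the\n# TCN stacks in wesep/modules/tasnet/separation.py. Sizes are arbitrary\n# assumptions. For the \"concatConv\" fusion, Conv1DBlock4Fuse repeats the\n# speaker embedding along time, so it is passed as [B, C_embedding, 1].\nimport torch\n\nfrom wesep.modules.tasnet.separation import FuseSeparation, Separation\n\nx = torch.randn(2, 64, 200)  # [B, N, L] bottleneck features\n\nsep = Separation(R=2, X=4, B=64, H=128, P=3)\nprint(sep(x).shape)  # torch.Size([2, 64, 200])\n\nfuse = FuseSeparation(R=2,\n                      X=4,\n                      B=64,\n                      H=128,\n                      P=3,\n                      C_embedding=256,\n                      spk_fuse_type=\"concatConv\",\n                      multi_fuse=True)\nspk_emb = torch.randn(2, 256, 1)  # [B, C_embedding, 1]\nprint(fuse(x, spk_emb).shape)  # torch.Size([2, 64, 200])\n"
  },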
  {
    "path": "wesep/modules/tasnet/separator.py",
    "content": "import torch.nn as nn\r\n\r\nfrom wesep.modules.tasnet.convs import Conv1DBlock\r\n\r\n\r\nclass Separation(nn.Module):\r\n    \"\"\"\r\n    R    Number of repeats\r\n    X    Number of convolutional blocks in each repeat\r\n    B    Number of channels in bottleneck and the residual paths\r\n    H    Number of channels in convolutional blocks\r\n    P    Kernel size in convolutional blocks\r\n    norm The type of normalization(gln, cl, bn)\r\n    causal  Two choice(causal or noncausal)\r\n    skip_con Whether to use skip connection\r\n    \"\"\"\r\n\r\n    def __init__(self, R, X, B, H, P, norm=\"gln\", causal=False, skip_con=True):\r\n        super(Separation, self).__init__()\r\n        self.separation = nn.ModuleList([])\r\n        for _ in range(R):\r\n            for x in range(X):\r\n                self.separation.append(\r\n                    Conv1DBlock(B, H, P, 2**x, norm, causal, skip_con))\r\n        self.skip_con = skip_con\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        x: [B, N, L]\r\n        out: [B, N, L]\r\n        \"\"\"\r\n        if self.skip_con:\r\n            skip_connection = 0\r\n            for i in range(len(self.separation)):\r\n                skip, out = self.separation[i](x)\r\n                skip_connection = skip_connection + skip\r\n                x = out\r\n            return skip_connection\r\n        else:\r\n            for i in range(len(self.separation)):\r\n                out = self.separation[i](x)\r\n                x = out\r\n            return x\r\n"
  },
  {
    "path": "wesep/modules/tasnet/speaker.py",
    "content": "import torch.nn as nn\n\nfrom wesep.modules.common.norm import ChannelWiseLayerNorm\nfrom wesep.modules.tasnet.convs import Conv1D\n\n\nclass ResBlock(nn.Module):\n    \"\"\"\n    ref to\n        https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py\n        and\n        https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py\n    \"\"\"\n\n    def __init__(self, in_dims, out_dims):\n        super().__init__()\n        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)\n        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)\n        self.batch_norm1 = nn.BatchNorm1d(out_dims)\n        self.batch_norm2 = nn.BatchNorm1d(out_dims)\n        self.prelu1 = nn.PReLU()\n        self.prelu2 = nn.PReLU()\n\n        self.mp = nn.MaxPool1d(3)\n        if in_dims != out_dims:\n            self.downsample = True\n            self.conv_downsample = nn.Conv1d(in_dims,\n                                             out_dims,\n                                             kernel_size=1,\n                                             bias=False)\n        else:\n            self.downsample = False\n\n    def forward(self, x):\n        residual = x\n        x = self.conv1(x)\n        x = self.batch_norm1(x)\n        x = self.prelu1(x)\n        x = self.conv2(x)\n        x = self.batch_norm2(x)\n        if self.downsample:\n            residual = self.conv_downsample(residual)\n        x = x + residual\n        x = self.prelu2(x)\n        return self.mp(x)\n\n\nclass ResNet4SpExplus(nn.Module):\n\n    def __init__(self, in_channel=256, C_embedding=256):\n        super().__init__()\n        self.aux_enc3 = nn.Sequential(\n            ChannelWiseLayerNorm(3 * in_channel),\n            Conv1D(3 * 256, 256, 1),\n            ResBlock(256, 256),\n            ResBlock(256, 512),\n            ResBlock(512, 512),\n            Conv1D(512, C_embedding, 1),\n        )\n\n    def forward(self, x):\n        aux = self.aux_enc3(x)\n        aux = aux.mean(dim=-1)\n        return aux\n"
  },
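  {
    "path": "examples/sketch_tasnet_speaker_encoder.py",
    "content": "# Hypothetical sketch (not part of the original wesep sources): the speaker\n# encoder in wesep/modules/tasnet/speaker.py consumes the concatenated\n# multi-scale encoder output, [B, 3 * in_channel, T], and pools it into one\n# embedding per utterance. Batch size and frame count are arbitrary.\nimport torch\n\nfrom wesep.modules.tasnet.speaker import ResNet4SpExplus\n\nspk_encoder = ResNet4SpExplus(in_channel=256, C_embedding=256)\naux_feats = torch.randn(2, 3 * 256, 120)  # enrollment features, [B, 768, T]\nembedding = spk_encoder(aux_feats)\nprint(embedding.shape)  # torch.Size([2, 256])\n"
  },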
  {
    "path": "wesep/modules/tfgridnet/__init__.py",
    "content": ""
  },
  {
    "path": "wesep/modules/tfgridnet/gridnet_block.py",
    "content": "# The implementation is based on:\n# https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nfrom torch.nn.parameter import Parameter\n\nfrom wesep.utils.utils import get_layer\n\n\nclass GridNetBlock(nn.Module):\n\n    def __getitem__(self, key):\n        return getattr(self, key)\n\n    def __init__(\n        self,\n        emb_dim,\n        emb_ks,\n        emb_hs,\n        n_freqs,\n        hidden_channels,\n        n_head=4,\n        approx_qk_dim=512,\n        activation=\"prelu\",\n        eps=1e-5,\n    ):\n        super().__init__()\n        assert activation == \"prelu\"\n\n        in_channels = emb_dim * emb_ks\n\n        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)\n        self.intra_rnn = nn.LSTM(\n            in_channels,\n            hidden_channels,\n            1,\n            batch_first=True,\n            bidirectional=True,\n        )\n        if emb_ks == emb_hs:\n            self.intra_linear = nn.Linear(hidden_channels * 2, in_channels)\n        else:\n            self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2,\n                                                   emb_dim,\n                                                   emb_ks,\n                                                   stride=emb_hs)\n\n        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)\n        self.inter_rnn = nn.LSTM(\n            in_channels,\n            hidden_channels,\n            1,\n            batch_first=True,\n            bidirectional=True,\n        )\n        if emb_ks == emb_hs:\n            self.inter_linear = nn.Linear(hidden_channels * 2, in_channels)\n        else:\n            self.inter_linear = nn.ConvTranspose1d(hidden_channels * 2,\n                                                   emb_dim,\n                                                   emb_ks,\n                                                   stride=emb_hs)\n\n        E = math.ceil(approx_qk_dim * 1.0 /\n                      n_freqs)  # approx_qk_dim is only approximate\n        assert emb_dim % n_head == 0\n\n        self.add_module(\"attn_conv_Q\", nn.Conv2d(emb_dim, n_head * E, 1))\n        self.add_module(\n            \"attn_norm_Q\",\n            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),\n        )\n\n        self.add_module(\"attn_conv_K\", nn.Conv2d(emb_dim, n_head * E, 1))\n        self.add_module(\n            \"attn_norm_K\",\n            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),\n        )\n\n        self.add_module(\"attn_conv_V\",\n                        nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1))\n        self.add_module(\n            \"attn_norm_V\",\n            AllHeadPReLULayerNormalization4DCF(\n                (n_head, emb_dim // n_head, n_freqs), eps=eps),\n        )\n\n        self.add_module(\n            
\"attn_concat_proj\",\n            nn.Sequential(\n                nn.Conv2d(emb_dim, emb_dim, 1),\n                get_layer(activation)(),\n                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),\n            ),\n        )\n\n        self.emb_dim = emb_dim\n        self.emb_ks = emb_ks\n        self.emb_hs = emb_hs\n        self.n_head = n_head\n\n    def forward(self, x):\n        \"\"\"GridNetBlock Forward.\n\n        Args:\n            x: [B, C, T, Q]\n            out: [B, C, T, Q]\n        \"\"\"\n        B, C, old_T, old_Q = x.shape\n\n        olp = self.emb_ks - self.emb_hs\n        T = math.ceil((old_T + 2 * olp - self.emb_ks) /\n                      self.emb_hs) * self.emb_hs + self.emb_ks\n        Q = math.ceil((old_Q + 2 * olp - self.emb_ks) /\n                      self.emb_hs) * self.emb_hs + self.emb_ks\n\n        x = x.permute(0, 2, 3, 1)  # [B, old_T, old_Q, C]\n        x = F.pad(\n            x,\n            (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp))  # [B, T, Q, C]\n\n        # intra RNN\n        input_ = x\n        intra_rnn = self.intra_norm(input_)  # [B, T, Q, C]\n        if self.emb_ks == self.emb_hs:\n            intra_rnn = intra_rnn.view([B * T, -1,\n                                        self.emb_ks * C])  # [BT, Q//I, I*C]\n            intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, Q//I, H]\n            intra_rnn = self.intra_linear(intra_rnn)  # [BT, Q//I, I*C]\n            intra_rnn = intra_rnn.view([B, T, Q, C])\n        else:\n            intra_rnn = intra_rnn.view([B * T, Q, C])  # [BT, Q, C]\n            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, C, Q]\n            intra_rnn = F.unfold(intra_rnn[..., None], (self.emb_ks, 1),\n                                 stride=(self.emb_hs, 1))  # [BT, C*I, -1]\n            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, C*I]\n\n            intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, H]\n\n            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, H, -1]\n            intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]\n            intra_rnn = intra_rnn.view([B, T, C, Q])\n            intra_rnn = intra_rnn.transpose(-2, -1)  # [B, T, Q, C]\n        intra_rnn = intra_rnn + input_  # [B, T, Q, C]\n\n        intra_rnn = intra_rnn.transpose(1, 2)  # [B, Q, T, C]\n\n        # inter RNN\n        input_ = intra_rnn\n        inter_rnn = self.inter_norm(input_)  # [B, Q, T, C]\n        if self.emb_ks == self.emb_hs:\n            inter_rnn = inter_rnn.view([B * Q, -1,\n                                        self.emb_ks * C])  # [BQ, T//I, I*C]\n            inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BQ, T//I, H]\n            inter_rnn = self.inter_linear(inter_rnn)  # [BQ, T//I, I*C]\n            inter_rnn = inter_rnn.view([B, Q, T, C])\n        else:\n            inter_rnn = inter_rnn.view(B * Q, T, C)  # [BQ, T, C]\n            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, C, T]\n            inter_rnn = F.unfold(inter_rnn[..., None], (self.emb_ks, 1),\n                                 stride=(self.emb_hs, 1))  # [BQ, C*I, -1]\n            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, -1, C*I]\n\n            inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BQ, -1, H]\n\n            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, H, -1]\n            inter_rnn = self.inter_linear(inter_rnn)  # [BQ, C, T]\n            inter_rnn = inter_rnn.view([B, Q, C, T])\n            inter_rnn = inter_rnn.transpose(-2, -1)  # [B, Q, T, C]\n        inter_rnn = inter_rnn + input_  # [B, Q, T, C]\n\n        
inter_rnn = inter_rnn.permute(0, 3, 2, 1)  # [B, C, T, Q]\n\n        inter_rnn = inter_rnn[..., olp:olp + old_T, olp:olp + old_Q]\n        batch = inter_rnn\n\n        Q = self[\"attn_norm_Q\"](\n            self[\"attn_conv_Q\"](batch))  # [B, n_head, C, T, Q]\n        K = self[\"attn_norm_K\"](\n            self[\"attn_conv_K\"](batch))  # [B, n_head, C, T, Q]\n        V = self[\"attn_norm_V\"](\n            self[\"attn_conv_V\"](batch))  # [B, n_head, C, T, Q]\n        Q = Q.view(-1, *Q.shape[2:])  # [B*n_head, C, T, Q]\n        K = K.view(-1, *K.shape[2:])  # [B*n_head, C, T, Q]\n        V = V.view(-1, *V.shape[2:])  # [B*n_head, C, T, Q]\n\n        Q = Q.transpose(1, 2)\n        Q = Q.flatten(start_dim=2)  # [B', T, C*Q]\n\n        K = K.transpose(2, 3)\n        K = K.contiguous().view([B * self.n_head, -1, old_T])  # [B', C*Q, T]\n\n        V = V.transpose(1, 2)  # [B', T, C, Q]\n        old_shape = V.shape\n        V = V.flatten(start_dim=2)  # [B', T, C*Q]\n        emb_dim = Q.shape[-1]\n\n        attn_mat = torch.matmul(Q, K) / (emb_dim**0.5)  # [B', T, T]\n        attn_mat = F.softmax(attn_mat, dim=2)  # [B', T, T]\n        V = torch.matmul(attn_mat, V)  # [B', T, C*Q]\n\n        V = V.reshape(old_shape)  # [B', T, C, Q]\n        V = V.transpose(1, 2)  # [B', C, T, Q]\n        emb_dim = V.shape[1]\n\n        batch = V.contiguous().view([B, self.n_head * emb_dim, old_T,\n                                     old_Q])  # [B, C, T, Q])\n        batch = self[\"attn_concat_proj\"](batch)  # [B, C, T, Q])\n\n        out = batch + inter_rnn\n        return out\n\n\nclass LayerNormalization4DCF(nn.Module):\n\n    def __init__(self, input_dimension, eps=1e-5):\n        super().__init__()\n        assert len(input_dimension) == 2\n        param_size = [1, input_dimension[0], 1, input_dimension[1]]\n        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))\n        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))\n        init.ones_(self.gamma)\n        init.zeros_(self.beta)\n        self.eps = eps\n\n    def forward(self, x):\n        if x.ndim == 4:\n            stat_dim = (1, 3)\n        else:\n            raise ValueError(\n                \"Expect x to have 4 dimensions, but got {}\".format(x.ndim))\n        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,1,T,1]\n        std_ = torch.sqrt(\n            x.var(dim=stat_dim, unbiased=False, keepdim=True) +\n            self.eps)  # [B,1,T,F]\n        x_hat = ((x - mu_) / std_) * self.gamma + self.beta\n        return x_hat\n\n\nclass AllHeadPReLULayerNormalization4DCF(nn.Module):\n\n    def __init__(self, input_dimension, eps=1e-5):\n        super().__init__()\n        assert len(input_dimension) == 3\n        H, E, n_freqs = input_dimension\n        param_size = [1, H, E, 1, n_freqs]\n        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))\n        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))\n        init.ones_(self.gamma)\n        init.zeros_(self.beta)\n        self.act = nn.PReLU(num_parameters=H, init=0.25)\n        self.eps = eps\n        self.H = H\n        self.E = E\n        self.n_freqs = n_freqs\n\n    def forward(self, x):\n        assert x.ndim == 4\n        B, _, T, _ = x.shape\n        x = x.view([B, self.H, self.E, T, self.n_freqs])\n        x = self.act(x)  # [B,H,E,T,F]\n        stat_dim = (2, 4)\n        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,H,1,T,1]\n        std_ = torch.sqrt(\n            x.var(dim=stat_dim, unbiased=False, 
keepdim=True) +\n            self.eps)  # [B,H,1,T,1]\n        x = ((x - mu_) / std_) * self.gamma + self.beta  # [B,H,E,T,F]\n        return x\n"
  },
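  {
    "path": "examples/sketch_tfgridnet_block.py",
    "content": "# Hypothetical shape check (not part of the original wesep sources) for\n# GridNetBlock in wesep/modules/tfgridnet/gridnet_block.py. emb_dim must be\n# divisible by n_head; the time and frequency sizes are arbitrary assumptions.\nimport torch\n\nfrom wesep.modules.tfgridnet.gridnet_block import GridNetBlock\n\nblock = GridNetBlock(emb_dim=16,\n                     emb_ks=4,\n                     emb_hs=1,\n                     n_freqs=65,\n                     hidden_channels=32,\n                     n_head=4)\nx = torch.randn(1, 16, 50, 65)  # [B, C, T, Q] with C=emb_dim, Q=n_freqs\ny = block(x)\nprint(y.shape)  # torch.Size([1, 16, 50, 65])\n"
  },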
  {
    "path": "wesep/utils/abs_loss.py",
    "content": "from abc import ABC, abstractmethod\n\nimport torch\n\nEPS = torch.finfo(torch.get_default_dtype()).eps\n\n\nclass AbsEnhLoss(torch.nn.Module, ABC):\n    \"\"\"Base class for all Enhancement loss modules.\"\"\"\n\n    # the name will be the key that appears in the reporter\n    @property\n    def name(self) -> str:\n        return NotImplementedError\n\n    # This property specifies whether the criterion will only\n    # be evaluated during the inference stage\n    @property\n    def only_for_test(self) -> bool:\n        return False\n\n    @abstractmethod\n    def forward(\n        self,\n        ref,\n        inf,\n    ) -> torch.Tensor:\n        # the return tensor should be shape of (batch)\n        raise NotImplementedError\n"
  },
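  {
    "path": "examples/sketch_abs_loss_subclass.py",
    "content": "# Hypothetical sketch (not part of the original wesep sources): deriving a\n# concrete criterion from AbsEnhLoss in wesep/utils/abs_loss.py. The L1 loss\n# here only illustrates the expected per-utterance (batch,) output shape.\nimport torch\n\nfrom wesep.utils.abs_loss import AbsEnhLoss\n\n\nclass L1TimeDomainLoss(AbsEnhLoss):\n\n    @property\n    def name(self) -> str:\n        return \"l1_time_domain\"\n\n    def forward(self, ref, inf) -> torch.Tensor:\n        # mean absolute error per utterance -> shape (batch,)\n        return (ref - inf).abs().mean(dim=-1)\n\n\nloss_fn = L1TimeDomainLoss()\nref = torch.randn(4, 16000)\ninf = torch.randn(4, 16000)\nprint(loss_fn(ref, inf).shape)  # torch.Size([4])\n"
  },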
  {
    "path": "wesep/utils/checkpoint.py",
    "content": "from typing import List, Optional\r\n\r\nimport torch\r\n\r\nfrom wesep.utils.schedulers import BaseClass\r\n\r\n\r\ndef load_pretrained_model(model: torch.nn.Module,\r\n                          path: str,\r\n                          type: str = \"generator\"):\r\n    assert type in [\"generator\", \"discriminator\"]\r\n    states = torch.load(\r\n        path,\r\n        map_location=\"cpu\",\r\n    )\r\n    if type == \"generator\":\r\n        state = states[\"models\"][0]\r\n    else:\r\n        assert len(states[\"models\"]) == 2\r\n        state = states[\"models\"][1]\r\n\r\n    if isinstance(model, torch.nn.DataParallel):\r\n        model.module.load_state_dict(state)\r\n    elif isinstance(model, torch.nn.parallel.DistributedDataParallel):\r\n        model.module.load_state_dict(state)\r\n    else:\r\n        model.load_state_dict(state)\r\n\r\n\r\ndef load_checkpoint(\r\n    models: List[torch.nn.Module],\r\n    optimizers: List[torch.optim.Optimizer],\r\n    schedulers: List[BaseClass],\r\n    scaler: Optional[torch.cuda.amp.GradScaler],\r\n    path: str,\r\n    only_model: bool = False,\r\n    mode: str = \"all\",\r\n):\r\n    assert mode in [\"all\", \"generator\", \"discriminator\"]\r\n    states = torch.load(\r\n        path,\r\n        map_location=\"cpu\",\r\n    )\r\n    if mode == \"generator\":\r\n        model_state, optimizer_state, scheduler_state = (\r\n            [states[\"models\"][0]],\r\n            [states[\"optimizers\"][0]],\r\n            [states[\"schedulers\"][0]],\r\n        )\r\n    elif mode == \"discriminator\":\r\n        model_state, optimizer_state, scheduler_state = (\r\n            [states[\"models\"][1]],\r\n            [states[\"optimizers\"][1]],\r\n            [states[\"schedulers\"][1]],\r\n        )\r\n    else:\r\n        model_state, optimizer_state, scheduler_state = (\r\n            states[\"models\"],\r\n            states[\"optimizers\"],\r\n            states[\"schedulers\"],\r\n        )\r\n\r\n    for model, state in zip(models, model_state):\r\n        if isinstance(model, torch.nn.DataParallel):\r\n            model.module.load_state_dict(state, strict=False)\r\n        elif isinstance(model, torch.nn.parallel.DistributedDataParallel):\r\n            model.module.load_state_dict(state, strict=False)\r\n        else:\r\n            model.load_state_dict(state, strict=False)\r\n    if not only_model:\r\n        for optimizer, state in zip(optimizers, optimizer_state):\r\n            optimizer.load_state_dict(state)\r\n        for scheduler, state in zip(schedulers, scheduler_state):\r\n            if scheduler is not None:\r\n                scheduler.load_state_dict(state)\r\n        if scaler is not None:\r\n            if states[\"scaler\"] is not None:\r\n                scaler.load_state_dict(states[\"scaler\"])\r\n\r\n\r\ndef save_checkpoint(\r\n    models: List[torch.nn.Module],\r\n    optimizers: List[torch.optim.Optimizer],\r\n    schedulers: List[BaseClass],\r\n    scaler: Optional[torch.cuda.amp.GradScaler],\r\n    path: str,\r\n):\r\n    if isinstance(models[0], torch.nn.DataParallel):\r\n        state_dict = [model.module.state_dict() for model in models]\r\n    elif isinstance(models[0], torch.nn.parallel.DistributedDataParallel):\r\n        state_dict = [model.module.state_dict() for model in models]\r\n    else:\r\n        state_dict = [model.state_dict() for model in models]\r\n    torch.save(\r\n        {\r\n            \"models\":\r\n            state_dict,\r\n            \"optimizers\": 
[o.state_dict() for o in optimizers],\r\n            \"schedulers\":\r\n            [s.state_dict() if s is not None else None for s in schedulers],\r\n            \"scaler\":\r\n            scaler.state_dict() if scaler is not None else None,\r\n        },\r\n        path,\r\n    )\r\n"
  },
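  {
    "path": "examples/sketch_checkpoint_io.py",
    "content": "# Hypothetical sketch (not part of the original wesep sources): saving and\n# restoring a single generator with the helpers in wesep/utils/checkpoint.py.\n# The file name is a placeholder; schedulers and scaler may be None as shown.\nimport torch\n\nfrom wesep.utils.checkpoint import load_checkpoint, save_checkpoint\n\nmodel = torch.nn.Linear(8, 8)\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n\nsave_checkpoint([model], [optimizer], [None], None, \"demo_checkpoint.pt\")\nload_checkpoint([model], [optimizer], [None], None, \"demo_checkpoint.pt\",\n                only_model=True, mode=\"generator\")\n"
  },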
  {
    "path": "wesep/utils/datadir_writer.py",
    "content": "import warnings\nfrom pathlib import Path\nfrom typing import Union\n\n\n# ported from\n# https://github.com/espnet/espnet/blob/master/espnet2/fileio/datadir_writer.py\nclass DatadirWriter:\n    \"\"\"Writer class to create kaldi like data directory.\n\n    Examples:\n        >>> with DatadirWriter(\"output\") as writer:\n        ...     # output/sub.txt is created here\n        ...     subwriter = writer[\"sub.txt\"]\n        ...     # Write \"uttidA some/where/a.wav\"\n        ...     subwriter[\"uttidA\"] = \"some/where/a.wav\"\n        ...     subwriter[\"uttidB\"] = \"some/where/b.wav\"\n\n    \"\"\"\n\n    def __init__(self, p: Union[Path, str]):\n        self.path = Path(p)\n        self.chilidren = {}\n        self.fd = None\n        self.has_children = False\n        self.keys = set()\n\n    def __enter__(self):\n        return self\n\n    def __getitem__(self, key: str) -> \"DatadirWriter\":\n        if self.fd is not None:\n            raise RuntimeError(\"This writer points out a file\")\n\n        if key not in self.chilidren:\n            w = DatadirWriter((self.path / key))\n            self.chilidren[key] = w\n            self.has_children = True\n\n        retval = self.chilidren[key]\n        return retval\n\n    def __setitem__(self, key: str, value: str):\n        if self.has_children:\n            raise RuntimeError(\"This writer points out a directory\")\n        if key in self.keys:\n            warnings.warn(f\"Duplicated: {key}\", stacklevel=1)\n\n        if self.fd is None:\n            self.path.parent.mkdir(parents=True, exist_ok=True)\n            self.fd = self.path.open(\"w\", encoding=\"utf-8\")\n\n        self.keys.add(key)\n        self.fd.write(f\"{key} {value}\\n\")\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self.close()\n\n    def close(self):\n        if self.has_children:\n            prev_child = None\n            for child in self.chilidren.values():\n                child.close()\n                if prev_child is not None and prev_child.keys != child.keys:\n                    warnings.warn(\n                        f\"Ids are mismatching between \"\n                        f\"{prev_child.path} and {child.path}\",\n                        stacklevel=1)\n                prev_child = child\n\n        elif self.fd is not None:\n            self.fd.close()\n"
  },
  {
    "path": "wesep/utils/dnsmos.py",
    "content": "import json\nimport math\n\nimport librosa\nimport numpy as np\nimport requests\nimport torch\nimport torchaudio\n\nSAMPLING_RATE = 16000\nINPUT_LENGTH = 9.01\n# URL for the web service\nSCORING_URI_DNSMOS = \"https://dnsmos.azurewebsites.net/score\"\nSCORING_URI_DNSMOS_P835 = (\n    \"https://dnsmos.azurewebsites.net/v1/dnsmosp835/score\")\n\n\ndef poly1d(coefficients, use_numpy=False):\n    if use_numpy:\n        return np.poly1d(coefficients)\n    coefficients = tuple(reversed(coefficients))\n\n    def func(p):\n        return sum(coef * p**i for i, coef in enumerate(coefficients))\n\n    return func\n\n\nclass DNSMOS_web:\n    # ported from\n    # https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/dnsmos.py\n    def __init__(self, auth_key):\n        self.auth_key = auth_key\n\n    def __call__(self, aud, input_fs, fname=\"\", method=\"p808\"):\n        if input_fs != SAMPLING_RATE:\n            audio = librosa.resample(aud,\n                                     orig_sr=input_fs,\n                                     target_sr=SAMPLING_RATE)\n        else:\n            audio = aud\n\n        # Set the content type\n        headers = {\"Content-Type\": \"application/json\"}\n        # If authentication is enabled, set the authorization header\n        headers[\"Authorization\"] = f\"Basic {self.auth_key}\"\n        fname = fname + \".wav\" if fname else \"audio.wav\"\n        data = {\"data\": audio.tolist(), \"filename\": fname}\n        input_data = json.dumps(data)\n        # Make the request and display the response\n        if method == \"p808\":\n            u = SCORING_URI_DNSMOS\n        else:\n            u = SCORING_URI_DNSMOS_P835\n        resp = requests.post(u, data=input_data, headers=headers)\n        score_dict = resp.json()\n        return score_dict\n\n\nclass DNSMOS_local:\n    # ported from\n    # https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/dnsmos_local.py\n    def __init__(\n        self,\n        primary_model_path,\n        p808_model_path,\n        use_gpu=False,\n        convert_to_torch=False,\n        gpu_device=None,\n    ):\n        self.convert_to_torch = convert_to_torch\n        self.use_gpu = use_gpu\n        self.gpu_device = gpu_device\n\n        if convert_to_torch:\n            try:\n                from onnx2torch import convert\n            except ModuleNotFoundError:\n                raise RuntimeError(\n                    \"Please install onnx2torch manually and retry!\") from None\n\n            if primary_model_path is not None:\n                self.primary_model = convert(primary_model_path).eval()\n                self.p808_model = convert(p808_model_path).eval()\n            self.spectrogram = torchaudio.transforms.Spectrogram(\n                n_fft=321, hop_length=160, pad_mode=\"constant\")\n\n            self.to_db = torchaudio.transforms.AmplitudeToDB(\"power\",\n                                                             top_db=80.0)\n            if use_gpu:\n                if gpu_device is not None:\n                    torch.cuda.set_device(gpu_device)\n                if primary_model_path is not None:\n                    self.primary_model = self.primary_model.cuda()\n                    self.p808_model = self.p808_model.cuda()\n                self.spectrogram = self.spectrogram.cuda()\n        else:\n            try:\n                import onnxruntime as ort\n            except ModuleNotFoundError:\n                raise RuntimeError(\n                    \"Please install 
onnxruntime manually and retry!\") from None\n\n            prvd = (\"CUDAExecutionProvider\"\n                    if use_gpu else \"CPUExecutionProvider\")\n            if primary_model_path is not None:\n                self.onnx_sess = ort.InferenceSession(primary_model_path,\n                                                      providers=[prvd])\n                self.p808_onnx_sess = ort.InferenceSession(p808_model_path,\n                                                           providers=[prvd])\n                if self.gpu_device is not None:\n                    self.onnx_sess.set_providers([prvd],\n                                                 [{\n                                                     \"device_id\": gpu_device\n                                                 }])\n                    self.p808_onnx_sess.set_providers(\n                        [prvd], [{\n                            \"device_id\": gpu_device\n                        }])\n\n    def audio_melspec(\n        self,\n        audio,\n        n_mels=120,\n        frame_size=320,\n        hop_length=160,\n        sr=16000,\n        to_db=True,\n    ):\n        if self.convert_to_torch:\n            specgram = self.spectrogram(audio)\n            fb = torch.as_tensor(\n                librosa.filters.mel(sr=sr, n_fft=frame_size + 1,\n                                    n_mels=n_mels).T,\n                dtype=audio.dtype,\n                device=audio.device,\n            )\n            mel_spec = torch.matmul(specgram.transpose(-1, -2),\n                                    fb).transpose(-1, -2)\n            if to_db:\n                self.to_db.db_multiplier = math.log10(\n                    max(self.to_db.amin, torch.max(mel_spec)))\n                mel_spec = (self.to_db(mel_spec) + 40) / 40\n        else:\n            mel_spec = librosa.feature.melspectrogram(\n                y=audio,\n                sr=sr,\n                n_fft=frame_size + 1,\n                hop_length=hop_length,\n                n_mels=n_mels,\n            )\n            if to_db:\n                mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) +\n                            40) / 40\n        return mel_spec.T\n\n    def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):\n        flag = not self.convert_to_torch\n        if is_personalized_MOS:\n            p_ovr = poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046],\n                           flag)\n            p_sig = poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726],\n                           flag)\n            p_bak = poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132],\n                           flag)\n        else:\n            p_ovr = poly1d([-0.06766283, 1.11546468, 0.04602535], flag)\n            p_sig = poly1d([-0.08397278, 1.22083953, 0.0052439], flag)\n            p_bak = poly1d([-0.13166888, 1.60915514, -0.39604546], flag)\n\n        sig_poly = p_sig(sig)\n        bak_poly = p_bak(bak)\n        ovr_poly = p_ovr(ovr)\n\n        return sig_poly, bak_poly, ovr_poly\n\n    def __call__(self, aud, input_fs, is_personalized_MOS=False):\n        if self.convert_to_torch:\n            if self.use_gpu:\n                if self.gpu_device is not None:\n                    device = f\"cuda:{self.gpu_device}\"\n                else:\n                    device = \"cuda\"\n            else:\n                device = \"cpu\"\n            if isinstance(aud, torch.Tensor):\n                aud = aud.to(device=device)\n            else:\n         
       aud = torch.as_tensor(aud, dtype=torch.float32, device=device)\n        else:\n            aud = (aud.cpu().detach().numpy()\n                   if isinstance(aud, torch.Tensor) else aud)\n        if input_fs != SAMPLING_RATE:\n            if self.convert_to_torch:\n                audio = torch.as_tensor(\n                    librosa.resample(\n                        aud.detach().cpu().numpy(),\n                        orig_sr=input_fs,\n                        target_sr=SAMPLING_RATE,\n                    ),\n                    dtype=aud.dtype,\n                    device=aud.device,\n                )\n            else:\n                audio = librosa.resample(aud,\n                                         orig_sr=input_fs,\n                                         target_sr=SAMPLING_RATE)\n        else:\n            audio = aud\n        len_samples = int(INPUT_LENGTH * SAMPLING_RATE)\n        while len(audio) < len_samples:\n            if self.convert_to_torch:\n                audio = torch.cat((audio, audio))\n            else:\n                audio = np.append(audio, audio)\n\n        num_hops = int(np.floor(len(audio) / SAMPLING_RATE) - INPUT_LENGTH) + 1\n        hop_len_samples = SAMPLING_RATE\n        predicted_mos_sig_seg_raw = []\n        predicted_mos_bak_seg_raw = []\n        predicted_mos_ovr_seg_raw = []\n        predicted_mos_sig_seg = []\n        predicted_mos_bak_seg = []\n        predicted_mos_ovr_seg = []\n        predicted_p808_mos = []\n\n        for idx in range(num_hops):\n            audio_seg = audio[int(idx *\n                                  hop_len_samples):int((idx + INPUT_LENGTH) *\n                                                       hop_len_samples)]\n            if len(audio_seg) < len_samples:\n                continue\n\n            if self.convert_to_torch:\n                input_features = audio_seg.float()[None, :]\n                p808_input_features = self.audio_melspec(\n                    audio=audio_seg[:-160]).float()[None, :, :]\n                p808_mos = self.p808_model(p808_input_features)\n                mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.primary_model(\n                    input_features)[0]\n            else:\n                input_features = np.array(audio_seg).astype(\"float32\")[\n                    np.newaxis, :]\n                p808_input_features = np.array(\n                    self.audio_melspec(audio=audio_seg[:-160])).astype(\n                        \"float32\")[np.newaxis, :, :]\n                p808_mos = self.p808_onnx_sess.run(\n                    None, {\"input_1\": p808_input_features})[0][0][0]\n                mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(\n                    None, {\"input_1\": input_features})[0][0]\n            mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(\n                mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS)\n            predicted_mos_sig_seg_raw.append(mos_sig_raw)\n            predicted_mos_bak_seg_raw.append(mos_bak_raw)\n            predicted_mos_ovr_seg_raw.append(mos_ovr_raw)\n            predicted_mos_sig_seg.append(mos_sig)\n            predicted_mos_bak_seg.append(mos_bak)\n            predicted_mos_ovr_seg.append(mos_ovr)\n            predicted_p808_mos.append(p808_mos)\n\n        to_array = torch.stack if self.convert_to_torch else np.array\n        return {\n            \"OVRL_raw\": to_array(predicted_mos_ovr_seg_raw).mean(),\n            \"SIG_raw\": to_array(predicted_mos_sig_seg_raw).mean(),\n            
\"BAK_raw\": to_array(predicted_mos_bak_seg_raw).mean(),\n            \"OVRL\": to_array(predicted_mos_ovr_seg).mean(),\n            \"SIG\": to_array(predicted_mos_sig_seg).mean(),\n            \"BAK\": to_array(predicted_mos_bak_seg).mean(),\n            \"P808_MOS\": to_array(predicted_p808_mos).mean(),\n        }\n"
  },
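  {
    "path": "examples/sketch_dnsmos_local.py",
    "content": "# Hypothetical sketch (not part of the original wesep sources): offline DNSMOS\n# scoring with DNSMOS_local from wesep/utils/dnsmos.py. The two .onnx paths\n# are placeholders for the primary and P.808 models distributed with the\n# Microsoft DNS-Challenge repository; onnxruntime must be installed.\nimport numpy as np\n\nfrom wesep.utils.dnsmos import DNSMOS_local\n\ndnsmos = DNSMOS_local(\"DNSMOS/sig_bak_ovr.onnx\",\n                      \"DNSMOS/model_v8.onnx\",\n                      use_gpu=False)\naudio = np.random.randn(10 * 16000).astype(\"float32\")  # 10 s at 16 kHz\nscores = dnsmos(audio, 16000)\nprint(scores[\"P808_MOS\"], scores[\"SIG\"], scores[\"BAK\"], scores[\"OVRL\"])\n"
  },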
  {
    "path": "wesep/utils/executor.py",
    "content": "# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)\r\n#               2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#   http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nfrom contextlib import nullcontext\r\n\r\nimport tableprint as tp\r\n\r\n# if your python version < 3.7 use the below one\r\nimport torch\r\n\r\nfrom wesep.utils.funcs import clip_gradients, compute_fbank, apply_cmvn\r\nimport random\r\n\r\n\r\nclass Executor:\r\n\r\n    def __init__(self):\r\n        self.step = 0\r\n\r\n    def train(\r\n            self,\r\n            dataloader,\r\n            models,\r\n            epoch_iter,\r\n            optimizers,\r\n            criterion,\r\n            schedulers,\r\n            scaler,\r\n            epoch,\r\n            enable_amp,\r\n            logger,\r\n            clip_grad=5.0,\r\n            log_batch_interval=100,\r\n            device=torch.device(\"cuda\"),\r\n            se_loss_weight=1.0,\r\n            multi_task=False,\r\n            SSA_enroll_prob=0,\r\n            fbank_args=None,\r\n            sample_rate=16000,\r\n            speaker_feat=True\r\n    ):\r\n        \"\"\"Train one epoch\"\"\"\r\n        model = models[0]\r\n        optimizer = optimizers[0]\r\n        scheduler = schedulers[0]\r\n\r\n        model.train()\r\n        log_interval = log_batch_interval\r\n        accum_grad = 1\r\n        losses = []\r\n\r\n        if isinstance(model, torch.nn.parallel.DistributedDataParallel):\r\n            model_context = model.join\r\n        else:\r\n            model_context = nullcontext\r\n\r\n        with model_context():\r\n            for i, batch in enumerate(dataloader):\r\n                features = batch[\"wav_mix\"]\r\n                targets = batch[\"wav_targets\"]\r\n                # embeddings when not joint training, enrollment wavforms\r\n                # when joint training\r\n                enroll = batch[\"spk_embeds\"]\r\n                # spk_lable is an empty list when not joint training\r\n                # and multi-task\r\n                spk_label = batch[\"spk_label\"]\r\n\r\n                cur_iter = (epoch - 1) * epoch_iter + i\r\n                scheduler.step(cur_iter)\r\n\r\n                features = features.float().to(device)  # (B,T,F)\r\n                targets = targets.float().to(device)\r\n                enroll = enroll.float().to(device)\r\n                spk_label = spk_label.to(device)\r\n\r\n                with torch.cuda.amp.autocast(enabled=enable_amp):\r\n                    if SSA_enroll_prob > 0:\r\n                        if SSA_enroll_prob > random.random():\r\n                            with torch.no_grad():\r\n                                outputs = model(features, enroll)\r\n                                est_speech = outputs[0]\r\n                                self_fbank = est_speech\r\n                                if fbank_args is not None and speaker_feat:\r\n                                    self_fbank = 
compute_fbank(\r\n                                        est_speech, **fbank_args,\r\n                                        sample_rate=sample_rate)\r\n                                    self_fbank = apply_cmvn(self_fbank)\r\n                            outputs = model(features, self_fbank)\r\n                        else:\r\n                            outputs = model(features, enroll)\r\n                    else:\r\n                        outputs = model(features, enroll)\r\n                    if not isinstance(outputs, (list, tuple)):\r\n                        outputs = [outputs]\r\n                    loss = 0\r\n                    for ii in range(len(criterion)):\r\n                        # se_loss_weight: ([position in outputs[0], [1]],\r\n                        #                 [weights:[1.0], [0.5]])\r\n                        for ji in range(len(se_loss_weight[0][ii])):\r\n                            if (multi_task and criterion[ii].__class__.__name__\r\n                                    == \"CrossEntropyLoss\"):\r\n                                loss += se_loss_weight[1][ii][ji] * (\r\n                                    criterion[ii](\r\n                                        outputs[se_loss_weight[0][ii][ji]],\r\n                                        spk_label,\r\n                                    ).mean() / accum_grad)\r\n                                continue\r\n                            loss += se_loss_weight[1][ii][ji] * (criterion[ii](\r\n                                outputs[se_loss_weight[0][ii][ji]],\r\n                                targets).mean() / accum_grad)\r\n\r\n                losses.append(loss.item())\r\n                total_loss_avg = sum(losses) / len(losses)\r\n\r\n                # updata the model\r\n                optimizer.zero_grad()\r\n                # scaler does nothing here if enable_amp=False\r\n                scaler.scale(loss).backward()\r\n                scaler.unscale_(optimizer)\r\n                clip_gradients(model, clip_grad)\r\n                scaler.step(optimizer)\r\n                scaler.update()\r\n\r\n                if (i + 1) % log_interval == 0:\r\n                    logger.info(\r\n                        tp.row(\r\n                            (\r\n                                \"TRAIN\",\r\n                                epoch,\r\n                                i + 1,\r\n                                total_loss_avg * accum_grad,\r\n                                optimizer.param_groups[0][\"lr\"],\r\n                            ),\r\n                            width=10,\r\n                            style=\"grid\",\r\n                        ))\r\n                if (i + 1) == epoch_iter:\r\n                    break\r\n            total_loss_avg = sum(losses) / len(losses)\r\n            return total_loss_avg, 0\r\n\r\n    def cv(\r\n            self,\r\n            dataloader,\r\n            models,\r\n            val_iter,\r\n            criterion,\r\n            epoch,\r\n            enable_amp,\r\n            logger,\r\n            log_batch_interval=100,\r\n            device=torch.device(\"cuda\"),\r\n    ):\r\n        \"\"\"Cross validation on\"\"\"\r\n        model = models[0]\r\n\r\n        model.eval()\r\n        log_interval = log_batch_interval\r\n        losses = []\r\n\r\n        with torch.no_grad():\r\n            for i, batch in enumerate(dataloader):\r\n                features = batch[\"wav_mix\"]\r\n                targets = batch[\"wav_targets\"]\r\n                
enroll = batch[\"spk_embeds\"]\r\n\r\n                features = features.float().to(device)  # (B,T,F)\r\n                targets = targets.float().to(device)\r\n                enroll = enroll.float().to(device)\r\n\r\n                with torch.cuda.amp.autocast(enabled=enable_amp):\r\n                    outputs = model(features, enroll)\r\n                    if not isinstance(outputs, (list, tuple)):\r\n                        outputs = [outputs]\r\n                    # By default, the first loss is used as the indicator\r\n                    # of the validation set.\r\n                    loss = criterion[0](outputs[0], targets).mean()\r\n\r\n                losses.append(loss.item())\r\n                total_loss_avg = sum(losses) / len(losses)\r\n\r\n                if (i + 1) % log_interval == 0:\r\n                    logger.info(\r\n                        tp.row(\r\n                            (\"VAL\", epoch, i + 1, total_loss_avg, \"-\"),\r\n                            width=10,\r\n                            style=\"grid\",\r\n                        ))\r\n                if (i + 1) == val_iter:\r\n                    break\r\n        return total_loss_avg, 0\r\n"
  },
  {
    "path": "wesep/utils/executor_gan.py",
    "content": "# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)\n#               2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom contextlib import nullcontext\n\nimport tableprint as tp\n\n# if your python version < 3.7 use the below one\nimport torch\nimport torch.nn.functional as F\n\nfrom wesep.utils.funcs import clip_gradients\nfrom wesep.utils.score import batch_evaluation, cal_PESQ_norm\n\n\nclass ExecutorGAN:\n\n    def __init__(self):\n        self.step = 0\n\n    def train(\n        self,\n        dataloader,\n        models,\n        epoch_iter,\n        optimizers,\n        criterion,\n        schedulers,\n        scaler,\n        epoch,\n        enable_amp,\n        logger,\n        clip_grad=5.0,\n        log_batch_interval=100,\n        device=torch.device(\"cuda\"),\n        se_loss_weight=0.95,\n        gan_loss_weight=0.05,\n        multi_task=False,\n    ):\n        \"\"\"Train one epoch\"\"\"\n        assert (len(models) == len(optimizers) == len(schedulers) ==\n                2), \"Currently only support one discriminator\"\n        model, discriminator = models\n        optimizer, optimizer_dis = optimizers\n        scheduler, scheduler_dis = schedulers\n\n        model.train()\n        discriminator.train()\n        log_interval = log_batch_interval\n        accum_grad = 1\n        losses = []\n        se_losses = []\n        dis_losses = []\n\n        if isinstance(model, torch.nn.parallel.DistributedDataParallel):\n            model_context = model.join\n        else:\n            model_context = nullcontext\n\n        with model_context():\n            for i, batch in enumerate(dataloader):\n                features = batch[\"wav_mix\"]\n                targets = batch[\"wav_targets\"]\n                # embeddings when when not joint training, enrollment\n                # wavforms when joint training\n                enroll = batch[\"spk_embeds\"]\n                # spk_lable is an empty list when not joint training\n                # and multi-task\n                spk_label = batch[\"spk_label\"]\n                one_labels = torch.ones(features.size(0))\n\n                cur_iter = (epoch - 1) * epoch_iter + i\n                scheduler.step(cur_iter)\n                scheduler_dis.step(cur_iter)\n\n                features = features.float().to(device)\n                targets = targets.float().to(device)\n                enroll = enroll.float().to(device)\n                spk_label = spk_label.to(device)\n                one_labels = one_labels.float().to(device)\n\n                # calculate discriminator loss\n                with torch.cuda.amp.autocast(enabled=enable_amp):\n                    outputs = model(features, enroll)\n                    if not isinstance(outputs, (list, tuple)):\n                        outputs = [outputs]\n                    # outputs is a list of tensors, each tensor has shape\n                    # (Batch, samples)\n                    if 
multi_task:\n                        # remove the predicted spk_label from the outputs list\n                        enhanced_wavs = torch.stack(outputs[:-1], dim=0)\n                    else:\n                        # enhanced_wavs: [N, Batch, samples], N is the number\n                        # of output of the model\n                        enhanced_wavs = torch.stack(outputs, dim=0)\n                    d_loss = self._calculate_discriminator_loss(\n                        discriminator,\n                        targets,\n                        enhanced_wavs.detach(),\n                        features.detach(),\n                    )\n\n                dis_losses.append(d_loss.item())\n                total_dis_loss_avg = sum(dis_losses) / len(dis_losses)\n                # updata discriminator\n                optimizer_dis.zero_grad()\n                # scaler does nothing here if enable_amp=False\n                scaler.scale(d_loss).backward()\n                scaler.unscale_(optimizer_dis)\n                clip_gradients(discriminator, clip_grad)\n                scaler.step(optimizer_dis)\n                scaler.update()\n\n                # calculate generator loss\n                with torch.cuda.amp.autocast(enabled=enable_amp):\n                    se_loss = 0\n                    for ii in range(len(criterion)):\n                        # se_loss_weight[0]: 2-D array,loss_posi;\n                        # se_loss_weight[1]: 2-D array,loss_weight.\n                        for ji in range(len(se_loss_weight[0][ii])):\n                            if multi_task and ii == (len(criterion) - 1):\n                                se_loss += se_loss_weight[1][ii][ji] * (\n                                    criterion[ii](\n                                        outputs[se_loss_weight[0][ii][ji]],\n                                        spk_label,\n                                    ).mean() / accum_grad)\n                                continue\n                            se_loss += se_loss_weight[1][ii][ji] * (\n                                criterion[ii]\n                                (outputs[se_loss_weight[0][ii][ji]],\n                                 targets).mean() / accum_grad)\n                    gan_loss = 0\n                    len_output = (len(outputs) -\n                                  1 if multi_task else len(outputs))\n                    for j in range(len_output):\n                        enhanced_fake_metric = discriminator(\n                            targets, outputs[j])\n                        gan_loss += F.mse_loss(\n                            enhanced_fake_metric.flatten(),\n                            one_labels,\n                        )\n                    g_loss = se_loss + gan_loss_weight * gan_loss\n\n                losses.append(g_loss.item())\n                se_losses.append(se_loss.item())\n                total_loss_avg = sum(losses) / len(losses)\n                total_se_loss_avg = sum(se_losses) / len(se_losses)\n\n                # updata the generator\n                optimizer.zero_grad()\n                # scaler does nothing here if enable_amp=False\n                scaler.scale(g_loss).backward()\n                scaler.unscale_(optimizer)\n                clip_gradients(model, clip_grad)\n                scaler.step(optimizer)\n                scaler.update()\n\n                if (i + 1) % log_interval == 0:\n                    logger.info(\n                        tp.row(\n                            (\n                 
               \"TRAIN\",\n                                epoch,\n                                i + 1,\n                                total_se_loss_avg,\n                                total_loss_avg * accum_grad,\n                                total_dis_loss_avg * accum_grad,\n                                optimizer.param_groups[0][\"lr\"],\n                            ),\n                            width=10,\n                            style=\"grid\",\n                        ))\n                if (i + 1) == epoch_iter:\n                    break\n            total_loss_avg = sum(losses) / len(losses)\n            total_dis_loss_avg = sum(dis_losses) / len(dis_losses)\n            return total_loss_avg, total_dis_loss_avg\n\n    def cv(\n            self,\n            dataloader,\n            models,\n            val_iter,\n            criterion,\n            epoch,\n            enable_amp,\n            logger,\n            log_batch_interval=100,\n            device=torch.device(\"cuda\"),\n    ):\n        \"\"\"Cross validation on\"\"\"\n        assert len(models) == 2, \"Currently only support one discriminator\"\n        model, discriminator = models\n        model.eval()\n        discriminator.eval()\n        log_interval = log_batch_interval\n        losses = []\n        se_losses = []\n        dis_losses = []\n\n        with torch.no_grad():\n            for i, batch in enumerate(dataloader):\n                features = batch[\"wav_mix\"]\n                targets = batch[\"wav_targets\"]\n                enroll = batch[\"spk_embeds\"]\n                one_labels = torch.ones(features.size(0))\n\n                features = features.float().to(device)  # (B,T,F)\n                targets = targets.float().to(device)\n                enroll = enroll.float().to(device)\n                one_labels = one_labels.float().to(device)\n\n                with torch.cuda.amp.autocast(enabled=enable_amp):\n                    outputs = model(features, enroll)\n                    if not isinstance(outputs, (list, tuple)):\n                        outputs = [outputs]\n                    # calculate discriminator loss\n                    d_loss = self._calculate_discriminator_loss(\n                        discriminator,\n                        targets,\n                        outputs[0].unsqueeze(0),\n                        features,\n                    )\n\n                dis_losses.append(d_loss.item())\n                total_dis_loss_avg = sum(dis_losses) / len(dis_losses)\n\n                # calculate generator loss\n                with torch.cuda.amp.autocast(enabled=enable_amp):\n                    se_loss = criterion[0](outputs[0], targets).mean()\n                    enhanced_fake_metric = discriminator(targets, outputs[0])\n                    gan_loss = F.mse_loss(\n                        enhanced_fake_metric.flatten(),\n                        one_labels,\n                    )\n                    g_loss = se_loss + gan_loss\n\n                losses.append(g_loss.item())\n                se_losses.append(se_loss.item())\n                total_loss_avg = sum(losses) / len(losses)\n                total_se_loss_avg = sum(se_losses) / len(se_losses)\n\n                if (i + 1) % log_interval == 0:\n                    logger.info(\n                        tp.row(\n                            (\n                                \"VAL\",\n                                epoch,\n                                i + 1,\n                                
total_se_loss_avg,\n                                total_loss_avg,\n                                total_dis_loss_avg,\n                                \"-\",\n                            ),\n                            width=10,\n                            style=\"grid\",\n                        ))\n                if (i + 1) == val_iter:\n                    break\n        return total_loss_avg, total_dis_loss_avg\n\n    def mse_loss(self, output, target):\n        return F.mse_loss(output.flatten(), target)\n\n    def _calculate_discriminator_loss(\n        self,\n        discriminator,\n        clean_wavs,\n        enhanced_wavs,\n        noisy_wavs,\n    ):\n        \"\"\"Calculate the discriminator loss\n\n        Args:\n            discriminator (torch.nn.Module): the discriminator model\n            clean_wavs (torch.Tensor): the clean waveforms, [Batch, samples]\n            enhanced_wavs (torch.Tensor): the predicted waveforms,\n                                          [N, Batch, samples]\n            noisy_wavs (torch.Tensor): the noisy waveforms, [Batch, samples]\n\n        Returns:\n            torch.Tensor: the discriminator loss\n        \"\"\"\n\n        def calculate_mse_loss(output, target):\n            if target is not None:\n                target = torch.FloatTensor(target).to(device)\n                return self.mse_loss(output, target)\n            return 0\n\n        device = clean_wavs.device\n        one_labels = torch.ones(clean_wavs.size(0)).float().to(device)\n\n        noisy_fake_metric = discriminator(clean_wavs, noisy_wavs)\n        clean_fake_metric = discriminator(clean_wavs, clean_wavs)\n\n        audio_ref = clean_wavs.detach().cpu().numpy()\n        audio_noisy = noisy_wavs.detach().cpu().numpy()\n\n        noisy_real_metric = batch_evaluation(cal_PESQ_norm,\n                                             audio_noisy,\n                                             audio_ref,\n                                             parallel=False)\n\n        loss_d_clean = self.mse_loss(clean_fake_metric, one_labels)\n        loss_d_noisy = calculate_mse_loss(noisy_fake_metric, noisy_real_metric)\n        d_loss = loss_d_clean + loss_d_noisy\n\n        # unbind enhanced_wavs to get a list of tensors,\n        # each tensor has shape (Batch, samples)\n        enhanced_wavs = torch.unbind(enhanced_wavs, dim=0)\n\n        for enhanced_wav in enhanced_wavs:\n            enhanced_fake_metric = discriminator(clean_wavs, enhanced_wav)\n            audio_est = enhanced_wav.detach().cpu().numpy()\n\n            enhanced_real_metric = batch_evaluation(cal_PESQ_norm,\n                                                    audio_est,\n                                                    audio_ref,\n                                                    parallel=False)\n\n            loss_d_enhanced = calculate_mse_loss(enhanced_fake_metric,\n                                                 enhanced_real_metric)\n\n            d_loss += loss_d_enhanced\n\n        return d_loss\n"
  },
  {
    "path": "wesep/utils/file_utils.py",
    "content": "import collections\r\nimport math\r\nfrom pathlib import Path\r\nfrom typing import Dict, List, Optional, Tuple, Union\r\n\r\nimport kaldiio\r\nimport numpy as np\r\nimport soundfile\r\n\r\n\r\ndef read_lists(list_file):\r\n    \"\"\"list_file: only 1 column\"\"\"\r\n    lists = []\r\n    with open(list_file, \"r\", encoding=\"utf8\") as fin:\r\n        for line in fin:\r\n            lists.append(line.strip())\r\n    return lists\r\n\r\n\r\ndef read_vec_scp_file(scp_file):\r\n    \"\"\"\r\n    Read the pre-extracted kaldi-format speaker embeddings.\r\n    :param scp_file: path to xvector.scp\r\n    :return: dict {wav_name: embedding}\r\n    \"\"\"\r\n    samples_dict = {}\r\n    for key, vec in kaldiio.load_scp_sequential(scp_file):\r\n        if len(vec.shape) == 1:\r\n            vec = np.expand_dims(vec, 0)\r\n        samples_dict[key] = vec\r\n\r\n    return samples_dict\r\n\r\n\r\ndef norm_embeddings(embeddings, kaldi_style=True):\r\n    \"\"\"\r\n    Norm embeddings to unit length\r\n    :param embeddings: input embeddings\r\n    :param kaldi_style: if true, the norm should be embedding dimension\r\n    :return:\r\n    \"\"\"\r\n    scale = math.sqrt(embeddings.shape[-1]) if kaldi_style else 1.0\r\n    if len(embeddings.shape) == 2:\r\n        return (scale * embeddings.transpose() /\r\n                np.linalg.norm(embeddings, axis=1)).transpose()\r\n    elif len(embeddings.shape) == 1:\r\n        return scale * embeddings / np.linalg.norm(embeddings)\r\n\r\n\r\ndef read_label_file(label_file):\r\n    \"\"\"\r\n    Read the utt2spk file\r\n    :param label_file: the path to utt2spk\r\n    :return: dict {wav_name: spk_id}\r\n    \"\"\"\r\n    labels_dict = {}\r\n    with open(label_file, \"r\") as fin:\r\n        for line in fin:\r\n            tokens = line.strip().split()\r\n            labels_dict[tokens[0]] = tokens[1]\r\n    return labels_dict\r\n\r\n\r\ndef load_speaker_embeddings(scp_file, utt2spk_file):\r\n    \"\"\"\r\n    :param scp_file:\r\n    :param utt2spk_file:\r\n    :return: {spk1: [emb1, emb2 ...], spk2: [emb1, emb2...]}\r\n    \"\"\"\r\n    samples_dict = read_vec_scp_file(scp_file)\r\n    labels_dict = read_label_file(utt2spk_file)\r\n    spk2embeds = {}\r\n    for key, vec in samples_dict.items():\r\n        if len(vec.shape) == 1:\r\n            vec = np.expand_dims(vec, 0)\r\n        label = labels_dict[key]\r\n        if label in spk2embeds.keys():\r\n            spk2embeds[label].append(vec)\r\n        else:\r\n            spk2embeds[label] = [vec]\r\n    return spk2embeds\r\n\r\n\r\n# ported from\r\n# https://github.com/espnet/espnet/blob/master/espnet2/fileio/read_text.py\r\ndef read_2columns_text(path: Union[Path, str]) -> Dict[str, str]:\r\n    \"\"\"Read a text file having 2 columns as dict object.\r\n\r\n    Examples:\r\n        wav.scp:\r\n            key1 /some/path/a.wav\r\n            key2 /some/path/b.wav\r\n\r\n        >>> read_2columns_text('wav.scp')\r\n        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}\r\n\r\n    \"\"\"\r\n\r\n    data = {}\r\n    with Path(path).open(\"r\", encoding=\"utf-8\") as f:\r\n        for linenum, line in enumerate(f, 1):\r\n            sps = line.rstrip().split(maxsplit=1)\r\n            if len(sps) == 1:\r\n                k, v = sps[0], \"\"\r\n            else:\r\n                k, v = sps\r\n\r\n            if k in data:\r\n                raise RuntimeError(f\"{k} is duplicated ({path}:{linenum})\")\r\n            data[k] = v\r\n    return data\r\n\r\n\r\n# ported from\r\n# 
https://github.com/espnet/espnet/blob/master/espnet2/fileio/read_text.py\r\ndef read_multi_columns_text(\r\n    path: Union[Path, str],\r\n    return_unsplit: bool = False\r\n) -> Tuple[Dict[str, List[str]], Optional[Dict[str, str]]]:\r\n    \"\"\"Read a text file having 2 or more columns as dict object.\r\n\r\n    Examples:\r\n        wav.scp:\r\n            key1 /some/path/a1.wav /some/path/a2.wav\r\n            key2 /some/path/b1.wav /some/path/b2.wav  /some/path/b3.wav\r\n            key3 /some/path/c1.wav\r\n            ...\r\n\r\n        >>> read_multi_columns_text('wav.scp')\r\n        {'key1': ['/some/path/a1.wav', '/some/path/a2.wav'],\r\n         'key2': ['/some/path/b1.wav', '/some/path/b2.wav',\r\n                  '/some/path/b3.wav'],\r\n         'key3': ['/some/path/c1.wav']}\r\n\r\n    \"\"\"\r\n\r\n    data = {}\r\n\r\n    if return_unsplit:\r\n        unsplit_data = {}\r\n    else:\r\n        unsplit_data = None\r\n\r\n    with Path(path).open(\"r\", encoding=\"utf-8\") as f:\r\n        for linenum, line in enumerate(f, 1):\r\n            sps = line.rstrip().split(maxsplit=1)\r\n            if len(sps) == 1:\r\n                k, v = sps[0], \"\"\r\n            else:\r\n                k, v = sps\r\n\r\n            if k in data:\r\n                raise RuntimeError(f\"{k} is duplicated ({path}:{linenum})\")\r\n\r\n            data[k] = v.split() if v != \"\" else [\"\"]\r\n            if return_unsplit:\r\n                unsplit_data[k] = v\r\n\r\n    return data, unsplit_data\r\n\r\n\r\n# ported from\r\n# https://github.com/espnet/espnet/blob/master/espnet2/fileio/sound_scp.py\r\ndef soundfile_read(\r\n    wavs: Union[str, List[str]],\r\n    dtype=None,\r\n    always_2d: bool = False,\r\n    concat_axis: int = 1,\r\n    start: int = 0,\r\n    end: int = None,\r\n    return_subtype: bool = False,\r\n) -> Tuple[np.array, int]:\r\n    if isinstance(wavs, str):\r\n        wavs = [wavs]\r\n\r\n    arrays = []\r\n    subtypes = []\r\n    prev_rate = None\r\n    prev_wav = None\r\n    for wav in wavs:\r\n        with soundfile.SoundFile(wav) as f:\r\n            f.seek(start)\r\n            if end is not None:\r\n                frames = end - start\r\n            else:\r\n                frames = -1\r\n            if dtype == \"float16\":\r\n                array = f.read(\r\n                    frames,\r\n                    dtype=\"float32\",\r\n                    always_2d=always_2d,\r\n                ).astype(dtype)\r\n            else:\r\n                array = f.read(frames, dtype=dtype, always_2d=always_2d)\r\n            rate = f.samplerate\r\n            subtype = f.subtype\r\n            subtypes.append(subtype)\r\n\r\n        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:\r\n            # array: (Time, Channel)\r\n            array = array[:, None]\r\n\r\n        if prev_wav is not None:\r\n            if prev_rate != rate:\r\n                raise RuntimeError(\r\n                    f\"{prev_wav} and {wav} have mismatched sampling rate: \"\r\n                    f\"{prev_rate} != {rate}\")\r\n\r\n            dim1 = arrays[0].shape[1 - concat_axis]\r\n            dim2 = array.shape[1 - concat_axis]\r\n            if dim1 != dim2:\r\n                raise RuntimeError(\r\n                    \"Shapes must match with \"\r\n                    f\"{1 - concat_axis} axis, but gut {dim1} and {dim2}\")\r\n\r\n        prev_rate = rate\r\n        prev_wav = wav\r\n        arrays.append(array)\r\n\r\n    if len(arrays) == 1:\r\n        array = arrays[0]\r\n  
  else:\r\n        array = np.concatenate(arrays, axis=concat_axis)\r\n\r\n    if return_subtype:\r\n        return array, rate, subtypes\r\n    else:\r\n        return array, rate\r\n\r\n\r\n# ported from\r\n# https://github.com/espnet/espnet/blob/master/espnet2/fileio/sound_scp.py\r\nclass SoundScpReader(collections.abc.Mapping):\r\n    \"\"\"Reader class for 'wav.scp'.\r\n\r\n    Examples:\r\n        wav.scp is a text file that looks like the following:\r\n\r\n        key1 /some/path/a.wav\r\n        key2 /some/path/b.wav\r\n        key3 /some/path/c.wav\r\n        key4 /some/path/d.wav\r\n        ...\r\n\r\n        >>> reader = SoundScpReader('wav.scp')\r\n        >>> rate, array = reader['key1']\r\n\r\n        If multi_columns=True is given and multiple files are given\r\n        in one line with a space delimiter, the output arrays are\r\n        concatenated along the channel direction\r\n\r\n        key1 /some/path/a.wav /some/path/a2.wav\r\n        key2 /some/path/b.wav /some/path/b2.wav\r\n        ...\r\n\r\n        >>> reader = SoundScpReader('wav.scp', multi_columns=True)\r\n        >>> rate, array = reader['key1']\r\n\r\n        In the above case, a.wav and a2.wav are concatenated.\r\n\r\n        Note that even if multi_columns=True is given,\r\n        SoundScpReader still supports a normal wav.scp,\r\n        i.e., a wav file is given per line,\r\n        but this option is disabled by default\r\n        because a dict[str, list[str]] object needs to be kept,\r\n        which increases the required amount of memory.\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        fname,\r\n        dtype=None,\r\n        always_2d: bool = False,\r\n        multi_columns: bool = False,\r\n        concat_axis=1,\r\n    ):\r\n        self.fname = fname\r\n        self.dtype = dtype\r\n        self.always_2d = always_2d\r\n\r\n        if multi_columns:\r\n            self.data, _ = read_multi_columns_text(fname)\r\n        else:\r\n            self.data = read_2columns_text(fname)\r\n        self.multi_columns = multi_columns\r\n        self.concat_axis = concat_axis\r\n\r\n    def __getitem__(self, key) -> Tuple[int, np.ndarray]:\r\n        wavs = self.data[key]\r\n\r\n        array, rate = soundfile_read(\r\n            wavs,\r\n            dtype=self.dtype,\r\n            always_2d=self.always_2d,\r\n            concat_axis=self.concat_axis,\r\n        )\r\n        # Returned as scipy.io.wavread's order\r\n        return rate, array\r\n\r\n    def get_path(self, key):\r\n        return self.data[key]\r\n\r\n    def __contains__(self, item):\r\n        return item in self.data\r\n\r\n    def __len__(self):\r\n        return len(self.data)\r\n\r\n    def __iter__(self):\r\n        return iter(self.data)\r\n\r\n    def keys(self):\r\n        return self.data.keys()\r\n"
  },
  {
    "path": "wesep/utils/funcs.py",
    "content": "# Created on 2018/12\r\n# Author: Kaituo XU\r\n\r\nimport math\r\n\r\nimport torch\r\nimport torchaudio.compliance.kaldi as kaldi\r\n\r\n\r\ndef overlap_and_add(signal, frame_step):\r\n    \"\"\"Reconstructs a signal from a framed representation.\r\n\r\n    Adds potentially overlapping frames of a signal with shape\r\n    `[..., frames, frame_length]`, offsetting subsequent frames\r\n    by `frame_step`.\r\n    The resulting tensor has shape `[..., output_size]` where\r\n\r\n        output_size = (frames - 1) * frame_step + frame_length\r\n\r\n    Args:\r\n        signal: A [..., frames, frame_length] Tensor. All dimensions\r\n                may be unknown, and rank must be at least 2.\r\n        frame_step: An integer denoting overlap offsets. Must be\r\n                    less than or equal to frame_length.\r\n\r\n    Returns:\r\n        A Tensor with shape [..., output_size] containing the overlap-added\r\n        frames of signal's inner-most two dimensions.\r\n\r\n        output_size = (frames - 1) * frame_step + frame_length\r\n\r\n    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/\r\n             contrib/signal/python/ops/reconstruction_ops.py\r\n    \"\"\"\r\n    outer_dimensions = signal.size()[:-2]\r\n    frames, frame_length = signal.size()[-2:]\r\n\r\n    subframe_length = math.gcd(frame_length,\r\n                               frame_step)  # gcd=Greatest Common Divisor\r\n    subframe_step = frame_step // subframe_length\r\n    subframes_per_frame = frame_length // subframe_length\r\n    output_size = frame_step * (frames - 1) + frame_length\r\n    output_subframes = output_size // subframe_length\r\n\r\n    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)\r\n\r\n    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame,\r\n                                                     subframe_step)\r\n    frame = signal.new_tensor(frame).long()  # signal may in GPU or CPU\r\n    frame = frame.contiguous().view(-1)\r\n\r\n    result = signal.new_zeros(*outer_dimensions, output_subframes,\r\n                              subframe_length)\r\n    result.index_add_(-2, frame, subframe_signal)\r\n    result = result.view(*outer_dimensions, -1)\r\n    return result\r\n\r\n\r\ndef remove_pad(inputs, inputs_lengths):\r\n    \"\"\"\r\n    Args:\r\n        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size\r\n        inputs_lengths: torch.Tensor, [B]\r\n    Returns:\r\n        results: a list containing B items, each item is [C, T], T varies\r\n    \"\"\"\r\n    results = []\r\n    dim = inputs.dim()\r\n    if dim == 3:\r\n        C = inputs.size(1)\r\n    for input, length in zip(inputs, inputs_lengths):\r\n        if dim == 3:  # [B, C, T]\r\n            results.append(input[:, :length].view(C, -1).cpu().numpy())\r\n        elif dim == 2:  # [B, T]\r\n            results.append(input[:length].view(-1).cpu().numpy())\r\n    return results\r\n\r\n\r\ndef clip_gradients(model, clip):\r\n    norms = []\r\n    for _, p in model.named_parameters():\r\n        if p.grad is not None:\r\n            param_norm = p.grad.data.norm(2)\r\n            norms.append(param_norm.item())\r\n            clip_coef = clip / (param_norm + 1e-6)\r\n            if clip_coef < 1:\r\n                p.grad.data.mul_(clip_coef)\r\n    return norms\r\n\r\n\r\ndef compute_fbank(\r\n    data,\r\n    num_mel_bins=80,\r\n    frame_length=25,\r\n    frame_shift=10,\r\n    dither=1.0,\r\n    sample_rate=16000,\r\n):\r\n    \"\"\"Extract 
fbank\"\"\"\r\n    fbank_list = []\r\n    for index_ in range(data.shape[0]):\r\n        waveform = data[index_, :].unsqueeze(0)\r\n        waveform = waveform * (1 << 15)\r\n        mat = kaldi.fbank(\r\n            waveform,\r\n            num_mel_bins=num_mel_bins,\r\n            frame_length=frame_length,\r\n            frame_shift=frame_shift,\r\n            dither=dither,\r\n            sample_frequency=sample_rate,\r\n            window_type=\"hamming\",\r\n            use_energy=False,\r\n        )\r\n        fbank_list.append(mat.unsqueeze(0))\r\n    np_fbank = torch.cat(fbank_list, 0)\r\n    return np_fbank\r\n\r\n\r\ndef apply_cmvn(data, norm_mean=True, norm_var=False):\r\n    \"\"\"Apply CMVN\r\n\r\n    Args:\r\n        data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1',\r\n        'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']\r\n\r\n    Returns:\r\n        Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1',\r\n        'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']\r\n    \"\"\"\r\n    mat_list = []\r\n    for index_ in range(data.shape[0]):\r\n        mat = data[index_, :, :]\r\n        if norm_mean:\r\n            mat = mat - torch.mean(mat, dim=0)\r\n        if norm_var:\r\n            mat = mat / torch.sqrt(torch.var(mat, dim=0) + 1e-8)\r\n        mat = mat.unsqueeze(0)\r\n        mat_list.append(mat)\r\n    np_mat = torch.cat(mat_list, 0)\r\n    return np_mat\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    torch.manual_seed(123)\r\n    M, C, K, N = 2, 2, 3, 4\r\n    frame_step = 2\r\n    signal = torch.randint(5, (M, C, K, N))\r\n    result = overlap_and_add(signal, frame_step)\r\n    print(signal)\r\n    print(result)\r\n"
  },
  {
    "path": "wesep/utils/losses.py",
    "content": "import auraloss\nimport torch.nn as nn\nimport torchmetrics.audio as audio_metrics\nfrom torchmetrics.functional.audio import scale_invariant_signal_noise_ratio\n\"\"\"Get a loss function with its name from the configuration file.\"\"\"\nvalid_losses = {}\n\ntorch_losses = {\n    \"L1\": nn.L1Loss(),\n    \"L2\": nn.MSELoss(),\n    \"CE\": nn.CrossEntropyLoss(),\n}\n\ntorchmetrics_losses = {\n    # Not tested\n    \"PIT\":\n    audio_metrics.PermutationInvariantTraining(\n        scale_invariant_signal_noise_ratio),\n}\n\nauraloss_losses = {\n    \"STFT\": auraloss.freq.STFTLoss(),\n    \"MultiResolutionSTFT\": auraloss.freq.MultiResolutionSTFTLoss(),\n    \"SISDR\": auraloss.time.SISDRLoss(),\n    \"SISNR\": auraloss.time.SISDRLoss(),\n    \"SNR\": auraloss.time.SNRLoss(),\n}\n\nvalid_losses.update(torch_losses)\nvalid_losses.update(auraloss_losses)\nvalid_losses.update(torchmetrics_losses)\n\n\ndef parse_loss(loss):\n    loss_functions = []\n    if not isinstance(loss, list):\n        loss = [loss]\n    for i in range(len(loss)):\n        loss_name = loss[i]\n        loss_functions.append(valid_losses.get(loss_name))\n    return loss_functions\n"
  },
  {
    "path": "wesep/utils/schedulers.py",
    "content": "# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)\n#               2021 Zhengyang Chen (chenzhengyang117@gmail.com)\n#               2022 Hongji Wang (jijijiang77@gmail.com)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\n\n\nclass MarginScheduler:\n\n    def __init__(\n        self,\n        model,\n        epoch_iter,\n        increase_start_epoch,\n        fix_start_epoch,\n        initial_margin,\n        final_margin,\n        update_margin,\n        increase_type=\"exp\",\n    ):\n        \"\"\"\n        The margin is fixed as initial_margin before increase_start_epoch,\n        between increase_start_epoch and fix_start_epoch, the margin is\n        exponentially increasing from initial_margin to final_margin\n        after fix_start_epoch, the margin is fixed as final_margin.\n        \"\"\"\n        self.model = model\n        self.increase_start_iter = (increase_start_epoch - 1) * epoch_iter\n        self.fix_start_iter = (fix_start_epoch - 1) * epoch_iter\n        self.initial_margin = initial_margin\n        self.final_margin = final_margin\n        self.increase_type = increase_type\n\n        self.fix_already = False\n        self.current_iter = 0\n        self.update_margin = update_margin and hasattr(self.model.projection,\n                                                       \"update\")\n        self.increase_iter = self.fix_start_iter - self.increase_start_iter\n\n        self.init_margin()\n\n    def init_margin(self):\n        if hasattr(self.model.projection, \"update\"):\n            self.model.projection.update(margin=self.initial_margin)\n\n    def get_increase_margin(self):\n        initial_val = 1.0\n        final_val = 1e-3\n\n        current_iter = self.current_iter - self.increase_start_iter\n\n        if self.increase_type == \"exp\":  # exponentially increase the margin\n            ratio = (1.0 - math.exp(\n                (current_iter / self.increase_iter) *\n                math.log(final_val / (initial_val + 1e-6))) * initial_val)\n        else:  # linearly increase the margin\n            ratio = 1.0 * current_iter / self.increase_iter\n        return (self.initial_margin +\n                (self.final_margin - self.initial_margin) * ratio)\n\n    def step(self, current_iter=None):\n        if not self.update_margin or self.fix_already:\n            return\n\n        if current_iter is not None:\n            self.current_iter = current_iter\n\n        if self.current_iter >= self.fix_start_iter:\n            self.fix_already = True\n            if hasattr(self.model.projection, \"update\"):\n                self.model.projection.update(margin=self.final_margin)\n        elif self.current_iter >= self.increase_start_iter:\n            if hasattr(self.model.projection, \"update\"):\n                self.model.projection.update(margin=self.get_increase_margin())\n\n        self.current_iter += 1\n\n    def get_margin(self):\n        try:\n            margin = self.model.projection.margin\n        except Exception:\n 
           margin = 0.0\n\n        return margin\n\n\nclass BaseClass:\n    \"\"\"\n    Base Class for learning rate scheduler\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer,\n        num_epochs,\n        epoch_iter,\n        initial_lr,\n        final_lr,\n        warm_up_epoch=6,\n        scale_ratio=1.0,\n        warm_from_zero=False,\n    ):\n        \"\"\"\n        warm_up_epoch: the first warm_up_epoch is the multiprocess\n                       warm-up stage\n        scale_ratio: multiplied to the current lr in the multiprocess\n                     training process\n        \"\"\"\n        self.optimizer = optimizer\n        self.max_iter = num_epochs * epoch_iter\n        self.initial_lr = initial_lr\n        self.final_lr = final_lr\n        self.scale_ratio = scale_ratio\n        self.current_iter = 0\n        self.warm_up_iter = warm_up_epoch * epoch_iter\n        self.warm_from_zero = warm_from_zero\n\n    def get_multi_process_coeff(self):\n        lr_coeff = 1.0 * self.scale_ratio\n        if self.current_iter < self.warm_up_iter:\n            if self.warm_from_zero:\n                lr_coeff = (self.scale_ratio * self.current_iter /\n                            self.warm_up_iter)\n            elif self.scale_ratio > 1:\n                lr_coeff = (self.scale_ratio -\n                            1) * self.current_iter / self.warm_up_iter + 1.0\n\n        return lr_coeff\n\n    def get_current_lr(self):\n        \"\"\"\n        This function should be implemented in the child class\n        \"\"\"\n        return 0.0\n\n    def get_lr(self):\n        return self.optimizer.param_groups[0][\"lr\"]\n\n    def set_lr(self):\n        current_lr = self.get_current_lr()\n        for param_group in self.optimizer.param_groups:\n            param_group[\"lr\"] = current_lr\n\n    def step(self, current_iter=None):\n        if current_iter is not None:\n            self.current_iter = current_iter\n\n        self.set_lr()\n        self.current_iter += 1\n\n    def step_return_lr(self, current_iter=None):\n        if current_iter is not None:\n            self.current_iter = current_iter\n\n        current_lr = self.get_current_lr()\n        self.current_iter += 1\n\n        return current_lr\n\n    def state_dict(self):\n        \"\"\"Returns the state of the scheduler as a :class:`dict`.\n\n        It contains an entry for every variable in self.__dict__ which\n        is not the optimizer.\n        \"\"\"\n        return {\n            key: value\n            for key, value in self.__dict__.items() if key != \"optimizer\"\n        }\n\n    def load_state_dict(self, state_dict):\n        \"\"\"Loads the schedulers state.\n\n        Args:\n            state_dict (dict): scheduler state. 
Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        self.__dict__.update(state_dict)\n\n\nclass ExponentialDecrease(BaseClass):\n\n    def __init__(\n        self,\n        optimizer,\n        num_epochs,\n        epoch_iter,\n        initial_lr,\n        final_lr,\n        warm_up_epoch=6,\n        scale_ratio=1.0,\n        warm_from_zero=False,\n    ):\n        super().__init__(\n            optimizer,\n            num_epochs,\n            epoch_iter,\n            initial_lr,\n            final_lr,\n            warm_up_epoch,\n            scale_ratio,\n            warm_from_zero,\n        )\n\n    def get_current_lr(self):\n        lr_coeff = self.get_multi_process_coeff()\n        current_lr = (lr_coeff * self.initial_lr * math.exp(\n            (self.current_iter / self.max_iter) *\n            math.log(self.final_lr / self.initial_lr)))\n        return current_lr\n\n\nclass TriAngular2(BaseClass):\n    \"\"\"\n    The implementation of https://arxiv.org/pdf/1506.01186.pdf\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer,\n        num_epochs,\n        epoch_iter,\n        initial_lr,\n        final_lr,\n        warm_up_epoch=6,\n        scale_ratio=1.0,\n        cycle_step=2,\n        reduce_lr_diff_ratio=0.5,\n    ):\n        super().__init__(\n            optimizer,\n            num_epochs,\n            epoch_iter,\n            initial_lr,\n            final_lr,\n            warm_up_epoch,\n            scale_ratio,\n        )\n\n        self.reduce_lr_diff_ratio = reduce_lr_diff_ratio\n        self.cycle_iter = cycle_step * epoch_iter\n        self.step_size = self.cycle_iter // 2\n\n        self.max_lr = initial_lr\n        self.min_lr = final_lr\n        self.gap = self.max_lr - self.min_lr\n\n    def get_current_lr(self):\n        lr_coeff = self.get_multi_process_coeff()\n        point = self.current_iter % self.cycle_iter\n        cycle_index = self.current_iter // self.cycle_iter\n\n        self.max_lr = (self.min_lr +\n                       self.gap * self.reduce_lr_diff_ratio**cycle_index)\n\n        if point <= self.step_size:\n            current_lr = (self.min_lr +\n                          (self.max_lr - self.min_lr) * point / self.step_size)\n        else:\n            current_lr = (self.max_lr - (self.max_lr - self.min_lr) *\n                          (point - self.step_size) / self.step_size)\n\n        current_lr = lr_coeff * current_lr\n\n        return current_lr\n\n\ndef show_lr_curve(scheduler):\n    import matplotlib.pyplot as plt\n\n    lr_list = []\n    for current_lr in range(0, scheduler.max_iter):\n        lr_list.append(scheduler.step_return_lr(current_lr))\n    data_index = list(range(1, len(lr_list) + 1))\n\n    plt.plot(data_index, lr_list, \"-o\", markersize=1)\n    plt.legend(loc=\"best\")\n    plt.xlabel(\"Iteration\")\n    plt.ylabel(\"LR\")\n\n    plt.show()\n\n\nif __name__ == \"__main__\":\n    optimizer = None\n    num_epochs = 6\n    epoch_iter = 500\n    initial_lr = 0.6\n    final_lr = 0.1\n    warm_up_epoch = 2\n    scale_ratio = 4\n    scheduler = ExponentialDecrease(\n        optimizer,\n        num_epochs,\n        epoch_iter,\n        initial_lr,\n        final_lr,\n        warm_up_epoch,\n        scale_ratio,\n    )\n    # scheduler = TriAngular2(optimizer,\n    #                         num_epochs,\n    #                         epoch_iter,\n    #                         initial_lr,\n    #                         final_lr,\n    #                         
warm_up_epoch,\n    #                         scale_ratio,\n    #                         cycle_step=2,\n    #                         reduce_lr_diff_ratio=0.5)\n\n    show_lr_curve(scheduler)\n"
  },
  {
    "path": "wesep/utils/score.py",
    "content": "import numpy as np\nfrom joblib import Parallel, delayed\nfrom pesq import pesq\nfrom pystoi.stoi import stoi\n\n\ndef cal_SISNR(est, ref, eps=1e-8):\n    \"\"\"Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)\n    Args:\n        est: separated signal, numpy.ndarray, [T]\n        ref: reference signal, numpy.ndarray, [T]\n    Returns:\n        SISNR\n    \"\"\"\n    assert len(est) == len(ref)\n    est_zm = est - np.mean(est)\n    ref_zm = ref - np.mean(ref)\n\n    t = np.sum(est_zm * ref_zm) * ref_zm / (np.linalg.norm(ref_zm)**2 + eps)\n    return 20 * np.log10(eps + np.linalg.norm(t) /\n                         (np.linalg.norm(est_zm - t) + eps))\n\n\ndef cal_SISNRi(est, ref, mix, eps=1e-8):\n    \"\"\"Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)\n    Args:\n        est: separated signal, numpy.ndarray, [T]\n        ref: reference signal, numpy.ndarray, [T]\n    Returns:\n        SISNR\n    \"\"\"\n    assert len(est) == len(ref) == len(mix)\n    sisnr1 = cal_SISNR(est, ref)\n    sisnr2 = cal_SISNR(mix, ref)\n\n    return sisnr1, sisnr1 - sisnr2\n\n\ndef cal_PESQ(est, ref):\n    assert len(est) == len(ref)\n    mode = \"wb\"\n    p = pesq(16000, ref, est, mode)\n    return p\n\n\ndef cal_PESQ_norm(est, ref):\n    assert len(est) == len(ref)\n    mode = \"wb\"\n    try:\n        # normalize PESQ to (0, 1)\n        p = (pesq(16000, ref, est, mode) + 0.5) / 5\n    except Exception:\n        # error can happen due to silent estimated signal\n        p = None\n    return p\n\n\ndef cal_PESQi(est, ref, mix):\n    \"\"\"Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)\n    Args:\n        est: separated signal, numpy.ndarray, [T]\n        ref: reference signal, numpy.ndarray, [T]\n    Returns:\n        SISNR\n    \"\"\"\n    assert len(est) == len(ref) == len(mix)\n    pesq1 = cal_PESQ(est, ref)\n    pesq2 = cal_PESQ(mix, ref)\n\n    return pesq1, pesq1 - pesq2\n\n\ndef cal_STOI(est, ref):\n    assert len(est) == len(ref)\n    p = stoi(ref, est, 16000)\n    return p\n\n\ndef cal_STOIi(est, ref, mix):\n    \"\"\"Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)\n    Args:\n        est: separated signal, numpy.ndarray, [T]\n        ref: reference signal, numpy.ndarray, [T]\n    Returns:\n        SISNR\n    \"\"\"\n    assert len(est) == len(ref) == len(mix)\n    stoi1 = cal_STOI(est, ref)\n    stoi2 = cal_STOI(mix, ref)\n\n    return stoi1, stoi1 - stoi2\n\n\ndef batch_evaluation(metric, est, ref, lengths=None, parallel=False, n_jobs=8):\n    \"\"\"Calculate specified evaluation metrics in batches\n\n    Args:\n        metric (Callable): the function to calculate metric\n        est (np.ndarray): separated signal, numpy.ndarray, [B, T]\n        ref (np.ndarray): reference signal, numpy.ndarray, [B, T]\n        lengths (np.ndarray, optional): specify the length of each signal.\n                                        Defaults to None.\n        parallel (bool, optional): whether to calculate metric in parallel.\n                                   Default to False.\n        n_jobs (int, optional): number of jobs, used when `parallel` is True.\n                                Defaults to 8.\n\n    Returns:\n        scores (np.ndarray): batched metrics, [B]\n    \"\"\"\n    assert callable(metric)\n    if lengths is not None:\n        assert ((0 < lengths) & (lengths <= 1)).all()\n        lengths = (lengths * est.size(1)).round().int().cpu()\n        est = [p[:length].cpu() for p, length in zip(est, lengths)]\n        ref = [t[:length].cpu() for t, length 
in zip(ref, lengths)]\n\n    if parallel:\n        while True:\n            try:\n                scores = Parallel(n_jobs=n_jobs,\n                                  timeout=30)(delayed(metric)(p, t)\n                                              for p, t in zip(est, ref))\n                break\n            except Exception as e:\n                print(e)\n                print(\"Evaluation timeout...... (will try again)\")\n    else:\n        scores = []\n        for p, t in zip(est, ref):\n            score = metric(p, t)\n            scores.append(score)\n\n    if None in scores:\n        return None\n\n    return np.array(scores)\n"
  },
  {
    "path": "wesep/utils/signal.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom scipy.signal import get_window\n\n\ndef init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):\n    \"\"\"\n    Return window coefficient\n    \"\"\"\n\n    def sqrthann(win_len):\n        return get_window(\"hann\", win_len, fftbins=True)**0.5\n\n    if win_type == \"None\" or win_type is None:\n        window = np.ones(win_len)\n    elif win_type == \"sqrthann\":\n        window = sqrthann(win_len)\n    else:\n        window = get_window(win_type, win_len, fftbins=True)  # **0.5\n\n    N = fft_len\n    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]\n    real_kernel = np.real(fourier_basis)\n    imag_kernel = np.imag(fourier_basis)\n    kernel = np.concatenate([real_kernel, imag_kernel], 1).T\n\n    if invers:\n        kernel = np.linalg.pinv(kernel).T\n\n    kernel = kernel * window\n    kernel = kernel[:, None, :]\n    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(\n        window[None, :, None].astype(np.float32))\n\n\nclass ConvSTFT(nn.Module):\n\n    def __init__(\n        self,\n        win_len,\n        win_inc,\n        fft_len=None,\n        win_type=\"hamming\",\n        feature_type=\"real\",\n    ):\n        super(ConvSTFT, self).__init__()\n\n        if fft_len is None:\n            self.fft_len = np.int(2**np.ceil(np.log2(win_len)))\n        else:\n            self.fft_len = fft_len\n\n        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)\n        self.register_buffer(\"weight\", kernel)\n        self.feature_type = feature_type\n        self.stride = win_inc\n        self.win_len = win_len\n        self.dim = self.fft_len\n\n    def forward(self, inputs):\n        if inputs.dim() == 2:\n            inputs = torch.unsqueeze(inputs, 1)\n        inputs = F.pad(\n            inputs, [self.win_len - self.stride, self.win_len - self.stride])\n        outputs = F.conv1d(inputs, self.weight, stride=self.stride)\n\n        if self.feature_type == \"complex\":\n            return outputs\n        else:\n            dim = self.dim // 2 + 1\n            real = outputs[:, :dim, :]\n            imag = outputs[:, dim:, :]\n            mags = torch.sqrt(real**2 + imag**2)\n            phase = torch.atan2(imag, real)\n            return mags, phase\n\n\nclass ConviSTFT(nn.Module):\n\n    def __init__(\n        self,\n        win_len,\n        win_inc,\n        fft_len=None,\n        win_type=\"hamming\",\n        feature_type=\"real\",\n    ):\n        super(ConviSTFT, self).__init__()\n        if fft_len is None:\n            self.fft_len = np.int(2**np.ceil(np.log2(win_len)))\n        else:\n            self.fft_len = fft_len\n        kernel, window = init_kernels(win_len,\n                                      win_inc,\n                                      self.fft_len,\n                                      win_type,\n                                      invers=True)\n        self.register_buffer(\"weight\", kernel)\n        self.feature_type = feature_type\n        self.win_type = win_type\n        self.win_len = win_len\n        self.stride = win_inc\n        self.stride = win_inc\n        self.dim = self.fft_len\n        self.register_buffer(\"window\", window)\n        self.register_buffer(\"enframe\", torch.eye(win_len)[:, None, :])\n\n    def forward(self, inputs, phase=None):\n        \"\"\"\n        inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)\n        phase: [B, N//2+1, T] (if not none)\n    
    \"\"\"\n\n        if phase is not None:\n            real = inputs * torch.cos(phase)\n            imag = inputs * torch.sin(phase)\n            inputs = torch.cat([real, imag], 1)\n        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)\n\n        # this is from torch-stft: https://github.com/pseeth/torch-stft\n        t = self.window.repeat(1, 1, inputs.size(-1))**2\n        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)\n        outputs = outputs / (coff + 1e-8)\n        # outputs = torch.where(coff == 0, outputs, outputs/coff)\n        outputs = outputs[..., self.win_len -\n                          self.stride:-(self.win_len - self.stride)]\n\n        return outputs\n"
  },
  {
    "path": "wesep/utils/utils.py",
    "content": "# Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport difflib\nimport logging\nimport os\nimport random\nimport shutil\nimport sys\nfrom distutils.util import strtobool\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nimport yaml\n\n\ndef str2bool(value: str) -> bool:\n    return bool(strtobool(value))\n\n\ndef get_logger(outdir, fname):\n    formatter = logging.Formatter(\n        \"[ %(levelname)s : %(asctime)s ] - %(message)s\")\n    logging.basicConfig(\n        level=logging.DEBUG,\n        format=\"[ %(levelname)s : %(asctime)s ] - %(message)s\",\n    )\n    logger = logging.getLogger(\"Pyobj, f\")\n    # Dump log to file\n    fh = logging.FileHandler(os.path.join(outdir, fname))\n    fh.setFormatter(formatter)\n    logger.addHandler(fh)\n    return logger\n\n\ndef setup_logger(rank, exp_dir, device_ids, MAX_NUM_LOG_FILES: int = 100):\n    model_dir = os.path.join(exp_dir, \"models\")\n    file_name = \"train.log\"\n    if rank == 0:\n        os.makedirs(model_dir, exist_ok=True)\n        for i in range(MAX_NUM_LOG_FILES - 1, -1, -1):\n            if i == 0:\n                p = Path(os.path.join(exp_dir, file_name))\n                pn = p.parent / (p.stem + \".1\" + p.suffix)\n            else:\n                _p = Path(os.path.join(exp_dir, file_name))\n                p = _p.parent / (_p.stem + f\".{i}\" + _p.suffix)\n                pn = _p.parent / (_p.stem + f\".{i + 1}\" + _p.suffix)\n\n            if p.exists():\n                if i == MAX_NUM_LOG_FILES - 1:\n                    p.unlink()\n                else:\n                    shutil.move(p, pn)\n    dist.barrier(device_ids=[device_ids])  # let the rank 0 mkdir first\n    return get_logger(exp_dir, file_name)\n\n\ndef parse_config_or_kwargs(config_file, **kwargs):\n    \"\"\"parse_config_or_kwargs\n\n    :param config_file: Config file that has parameters, yaml format\n    :param **kwargs: Other alternative parameters or overwrites for conf\n    \"\"\"\n    with open(config_file) as con_read:\n        yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)\n    # values from conf file are all possible params\n    help_str = \"Valid Parameters are:\\n\"\n    help_str += \"\\n\".join(list(yaml_config.keys()))\n    # passed kwargs will override yaml conf\n    # for key in kwargs.keys():\n    #    assert key in yaml_config, \"Parameter {} invalid!\\n\".format(key)\n    # add the path of config file to dict\n    if \"config\" not in kwargs:\n        kwargs[\"config\"] = config_file\n    return dict(yaml_config, **kwargs)\n\n\ndef validate_path(dir_name):\n    \"\"\"Create the directory if it doesn't exist\n    :param dir_name\n    :return: None\n    \"\"\"\n    dir_name = os.path.dirname(dir_name)  # get the path\n    if not os.path.exists(dir_name) and (dir_name != \"\"):\n        os.makedirs(dir_name)\n\n\ndef set_seed(seed=42):\n    np.random.seed(seed)\n    
random.seed(seed)\n\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n    # torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = True\n\n\ndef generate_enahnced_scp(directory: str, extension: str = \"wav\"):\n    source_dir = Path(directory)\n    spk_scp = source_dir.joinpath(\"spk1.scp\")\n    audio_list = []\n\n    for file_path in source_dir.rglob(f\"*.{extension}\"):\n        audio_list.append(file_path)\n\n    with open(spk_scp, \"w\") as f:\n        for audio in audio_list:\n            path = str(audio.resolve())\n            ori_filename = audio.stem\n            spk1_id = ori_filename.split(\"-\")[1]\n            # spk2_id = ori_filename.split(\"_\")[1].split(\"-\")[0]\n            curr_spk = ori_filename.split(\"T\")[1]\n            prefix = \"s1\" if curr_spk == spk1_id else \"s2\"\n            f_dash_index = ori_filename.find(\"-\")\n            l_dash_index = ori_filename.rfind(\"-\")\n            filename = ori_filename[f_dash_index + 1:l_dash_index]\n            final_filename = prefix + \"/\" + filename + \".wav\"\n            line = final_filename + \" \" + path\n            f.write(line + \"\\n\")\n\n\ndef get_commandline_args():\n    # ported from\n    # https://github.com/espnet/espnet/blob/master/espnet/utils/cli_utils.py\n    extra_chars = [\n        \" \",\n        \";\",\n        \"&\",\n        \"(\",\n        \")\",\n        \"|\",\n        \"^\",\n        \"<\",\n        \">\",\n        \"?\",\n        \"*\",\n        \"[\",\n        \"]\",\n        \"$\",\n        \"`\",\n        '\"',\n        \"\\\\\",\n        \"!\",\n        \"{\",\n        \"}\",\n    ]\n\n    # Escape the extra characters for shell\n    argv = [(arg.replace(\"'\", \"'\\\\''\") if all(\n        char not in arg\n        for char in extra_chars) else \"'\" + arg.replace(\"'\", \"'\\\\''\") + \"'\")\n        for arg in sys.argv]\n\n    return sys.executable + \" \" + \" \".join(argv)\n\n\n# ported from\n# https://github.com/espnet/espnet/blob/master/espnet2/utils/config_argparse.py\nclass ArgumentParser(argparse.ArgumentParser):\n    \"\"\"Simple implementation of ArgumentParser supporting config file\n\n    This class originated from https://github.com/bw2/ConfigArgParse,\n    but it lacks some of the features that ConfigArgParse has.\n\n    - Not supporting multiple config files\n    - Automatically adding \"--config\" as an option.\n    - Not supporting any formats other than yaml\n    - Not checking argument type\n\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.add_argument(\"--config\", help=\"Give config file in yaml format\")\n\n    def parse_known_args(self, args=None, namespace=None):\n        # Once parsing for setting from \"--config\"\n        _args, _ = super().parse_known_args(args, namespace)\n        if _args.config is not None:\n            if not Path(_args.config).exists():\n                self.error(f\"No such file: {_args.config}\")\n\n            with open(_args.config, \"r\", encoding=\"utf-8\") as f:\n                d = yaml.safe_load(f)\n            if not isinstance(d, dict):\n                self.error(f\"Config file has non dict value: {_args.config}\")\n\n            for key in d:\n                for action in self._actions:\n                    if key == action.dest:\n                        break\n                else:\n                    self.error(\n                        f\"unrecognized arguments: {key} (from {_args.config})\")\n\n            # NOTE(kamo): Ignore \"--config\" from a config file\n            # NOTE(kamo): Unlike \"configargparse\", this module doesn't\n            #             check type. i.e. We can set any type value\n            #             regardless of argument type.\n            self.set_defaults(**d)\n        return super().parse_known_args(args, namespace)\n\n\ndef get_layer(l_name, library=torch.nn):\n    \"\"\"Return layer object handler from library e.g. from torch.nn\n\n    E.g. if l_name==\"elu\", returns torch.nn.ELU.\n\n    Args:\n        l_name (string): Case insensitive name for layer in library\n                        (e.g. 'elu').\n        library (module): Name of library/module where to search for\n                          object handler with l_name e.g. \"torch.nn\".\n\n    Returns:\n        layer_handler (object): handler for the requested layer\n                                e.g. (torch.nn.ELU)\n\n    \"\"\"\n\n    all_torch_layers = list(dir(torch.nn))\n    match = [x for x in all_torch_layers if l_name.lower() == x.lower()]\n    if len(match) == 0:\n        close_matches = difflib.get_close_matches(\n            l_name, [x.lower() for x in all_torch_layers])\n        raise NotImplementedError(\n            \"Layer with name {} not found in {}.\\n Closest matches: {}\".format(\n                l_name, str(library), close_matches))\n    elif len(match) > 1:\n        close_matches = difflib.get_close_matches(\n            l_name, [x.lower() for x in all_torch_layers])\n        raise NotImplementedError(\n            \"Multiple matches for layer with name {} found in {}.\\n \"\n            \"All matches: {}\".format(l_name, str(library), close_matches))\n    else:\n        # valid\n        layer_handler = getattr(library, match[0])\n        return layer_handler\n\n\n# def spk2id(utt_spk_list):\n#     _, spk_list = zip(*utt_spk_list)\n#     spk_list = sorted(list(set(spk_list)))  # remove overlap and sort\n\n#     spk2id_dict = {}\n#     spk_list.sort()\n#     for i, spk in enumerate(spk_list):\n#         spk2id_dict[spk] = i\n#     return spk2id_dict\n"
  }
]