* out);
#ifdef _MSC_VER
std::wstring ToWString(const std::string& str);
#endif
} // namespace wesep
#endif // UTILS_UTILS_H_
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
requirements = [
"tqdm",
"kaldiio",
"torch>=1.12.0",
"torchaudio>=0.12.0",
"silero-vad",
]
setup(
name="wesep",
install_requires=requirements,
packages=find_packages(),
entry_points={
"console_scripts": [
"wesep = wesep.cli.extractor:main",
],
},
)
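# Note: after `pip install .`, the `wesep` console command dispatches to
# wesep.cli.extractor:main (per the entry_points above).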
================================================
FILE: tools/extract_embed_depreciated.py
================================================
# Copyright (c) 2022, Shuai Wang (wsstriving@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import kaldiio
import onnxruntime as ort
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser(description="infer example using onnx")
parser.add_argument("--onnx_path", required=True, help="onnx path")
parser.add_argument("--wav_scp", required=True, help="wav path")
parser.add_argument("--out_path",
required=True,
help="output path of the embeddings")
args = parser.parse_args()
return args
def compute_fbank(wav_path,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0):
"""Extract fbank, simlilar to the one in wespeaker.dataset.processor,
While integrating the wave reading and CMN.
"""
waveform, sample_rate = torchaudio.load(wav_path)
waveform = waveform * (1 << 15)
mat = kaldi.fbank(
waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
sample_frequency=sample_rate,
window_type="hamming",
use_energy=False,
)
# CMN, without CVN
mat = mat - torch.mean(mat, dim=0)
return mat
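# Illustrative shapes (an assumption for a 16 kHz mono wav, not part of the
# original script): a 2.0 s input yields 1 + (2000 - 25) // 10 = 198 frames,
# so compute_fbank returns a [198, 80] tensor with per-utterance mean
# normalization already applied.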
def main():
args = get_args()
so = ort.SessionOptions()
so.inter_op_num_threads = 1
so.intra_op_num_threads = 1
session = ort.InferenceSession(args.onnx_path, sess_options=so)
embed_ark = os.path.join(args.out_path, "embed.ark")
embed_scp = os.path.join(args.out_path, "embed.scp")
with kaldiio.WriteHelper("ark,scp:" + embed_ark + "," +
embed_scp) as writer:
with open(args.wav_scp, "r") as read_scp:
for line in tqdm(read_scp):
tokens = line.strip().split(" ")
name, wav_path = tokens[0], tokens[1]
feats = compute_fbank(wav_path)
feats = feats.unsqueeze(0).numpy() # add batch dimension
embed = session.run(output_names=["embs"],
input_feed={"feats": feats})
writer(name, embed[0])
if __name__ == "__main__":
main()
================================================
FILE: tools/make_lmdb.py
================================================
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import pickle
import lmdb
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser(description="")
parser.add_argument("in_scp_file", help="input scp file")
parser.add_argument("out_lmdb", help="output lmdb")
args = parser.parse_args()
return args
def main():
args = get_args()
db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4))) # 1TB
    # txn is short for transaction
txn = db.begin(write=True)
keys = []
with open(args.in_scp_file, "r", encoding="utf8") as fin:
lines = fin.readlines()
for i, line in enumerate(tqdm(lines)):
arr = line.strip().split()
assert len(arr) == 2
key, wav = arr[0], arr[1]
keys.append(key)
with open(wav, "rb") as fin:
data = fin.read()
txn.put(key.encode(), data)
            # Periodically commit to flush writes to disk
if i % 100 == 0:
txn.commit()
txn = db.begin(write=True)
txn.commit()
with db.begin(write=True) as txn:
txn.put(b"__keys__", pickle.dumps(keys))
db.sync()
db.close()
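
# A minimal reader sketch (an illustration, not used by this script): the
# database maps each utterance key to raw wav bytes, and the special
# b"__keys__" entry to the pickled list of all keys.
def read_wav_bytes(lmdb_path, key):
    db = lmdb.open(lmdb_path, readonly=True)
    with db.begin() as txn:
        keys = pickle.loads(txn.get(b"__keys__"))
        assert key in keys, "unknown utterance key"
        data = txn.get(key.encode())
    db.close()
    return data
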
if __name__ == "__main__":
main()
================================================
FILE: tools/make_shard_list_premix.py
================================================
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang
# 2023 SRIBD Shuai Wang )
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import logging
import multiprocessing
import os
import random
import tarfile
import time
import sys
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def write_tar_file(data_list, tar_file, index=0, total=1):
logging.info('Processing {} {}/{}'.format(tar_file, index, total))
read_time = 0.0
write_time = 0.0
with tarfile.open(tar_file, "w") as tar:
for item in data_list:
assert len(
item) == 3, 'item should have 3 elements: Key, Speaker, Wav'
key, spks, wavs = item
spk_idx = 1
for spk in spks:
assert isinstance(spk, str)
spk_file = key + '.spk' + str(spk_idx)
spk = spk.encode('utf8')
spk_data = io.BytesIO(spk)
spk_info = tarfile.TarInfo(spk_file)
spk_info.size = len(spk)
tar.addfile(spk_info, spk_data)
spk_idx = spk_idx + 1
spk_idx = 0
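            # Naming convention below: the first wav (spk_idx == 0) is stored
            # as <key>.<suffix> (presumably the premixed audio), later ones as
            # <key>_spk<N>.<suffix> for each reference speaker.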
for wav in wavs:
suffix = wav.split('.')[-1]
assert suffix in AUDIO_FORMAT_SETS
ts = time.time()
try:
with open(wav, 'rb') as fin:
data = fin.read()
except FileNotFoundError as e:
print(e)
sys.exit()
read_time += (time.time() - ts)
ts = time.time()
if spk_idx > 0:
wav_file = key + '_spk' + str(spk_idx) + '.' + suffix
else:
wav_file = key + '.' + suffix
wav_data = io.BytesIO(data)
wav_info = tarfile.TarInfo(wav_file)
wav_info.size = len(data)
tar.addfile(wav_info, wav_data)
write_time += (time.time() - ts)
spk_idx = spk_idx + 1
logging.info('read {} write {}'.format(read_time, write_time))
def get_args():
parser = argparse.ArgumentParser(description='')
parser.add_argument('--num_utts_per_shard',
type=int,
default=1000,
help='num utts per shard')
parser.add_argument('--num_threads',
type=int,
default=1,
help='num threads for make shards')
parser.add_argument('--prefix',
default='shards',
help='prefix of shards tar file')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--shuffle',
action='store_true',
help='whether to shuffle data')
parser.add_argument('wav_file', help='wav file')
parser.add_argument('utt2spk_file', help='utt2spk file')
parser.add_argument('shards_dir', help='output shards dir')
parser.add_argument('shards_list', help='output shards list file')
args = parser.parse_args()
return args
def main():
args = get_args()
random.seed(args.seed)
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s')
wav_table = {}
with open(args.wav_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0] # key = os.path.splitext(arr[0])[0]
wav_table[key] = [arr[i + 1] for i in range(len(arr) - 1)]
data = []
with open(args.utt2spk_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0] # key = os.path.splitext(arr[0])[0]
spks = [arr[i + 1] for i in range(len(arr) - 1)]
assert key in wav_table
wavs = wav_table[key]
data.append((key, spks, wavs))
if args.shuffle:
random.shuffle(data)
num = args.num_utts_per_shard
chunks = [data[i:i + num] for i in range(0, len(data), num)]
os.makedirs(args.shards_dir, exist_ok=True)
    # Use a process pool to speed up shard writing
pool = multiprocessing.Pool(processes=args.num_threads)
shards_list = []
num_chunks = len(chunks)
for i, chunk in enumerate(chunks):
tar_file = os.path.join(args.shards_dir,
'{}_{:09d}.tar'.format(args.prefix, i))
shards_list.append(tar_file)
pool.apply_async(write_tar_file, (chunk, tar_file, i, num_chunks))
pool.close()
pool.join()
with open(args.shards_list, 'w', encoding='utf8') as fout:
for name in shards_list:
fout.write(name + '\n')
if __name__ == '__main__':
main()
================================================
FILE: tools/make_shard_online.py
================================================
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang
# 2023 Shuai Wang )
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import logging
import multiprocessing
import os
import random
import tarfile
import time
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def write_tar_file(data_list, tar_file, index=0, total=1):
logging.info('Processing {} {}/{}'.format(tar_file, index, total))
read_time = 0.0
write_time = 0.0
with tarfile.open(tar_file, "w") as tar:
for item in data_list:
assert len(
item) == 3, 'item should have 3 elements: Key, Speaker, Wav'
key, spk, wav = item
suffix = wav.split('.')[-1]
assert suffix in AUDIO_FORMAT_SETS
ts = time.time()
with open(wav, 'rb') as fin:
data = fin.read()
read_time += (time.time() - ts)
assert isinstance(spk, str)
ts = time.time()
spk_file = key + '.spk'
spk = spk.encode('utf8')
spk_data = io.BytesIO(spk)
spk_info = tarfile.TarInfo(spk_file)
spk_info.size = len(spk)
tar.addfile(spk_info, spk_data)
wav_file = key + '.' + suffix
wav_data = io.BytesIO(data)
wav_info = tarfile.TarInfo(wav_file)
wav_info.size = len(data)
tar.addfile(wav_info, wav_data)
write_time += (time.time() - ts)
logging.info('read {} write {}'.format(read_time, write_time))
def get_args():
parser = argparse.ArgumentParser(description='')
parser.add_argument('--num_utts_per_shard',
type=int,
default=1000,
help='num utts per shard')
parser.add_argument('--num_threads',
type=int,
default=1,
help='num threads for make shards')
parser.add_argument('--prefix',
default='shards',
help='prefix of shards tar file')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--shuffle',
action='store_true',
help='whether to shuffle data')
parser.add_argument('wav_file', help='wav file')
parser.add_argument('utt2spk_file', help='utt2spk file')
parser.add_argument('shards_dir', help='output shards dir')
parser.add_argument('shards_list', help='output shards list file')
args = parser.parse_args()
return args
def main():
args = get_args()
random.seed(args.seed)
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s')
wav_table = {}
with open(args.wav_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0] # key = os.path.splitext(arr[0])[0]
wav_table[key] = ' '.join(arr[1:])
data = []
with open(args.utt2spk_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split(maxsplit=1)
key = arr[0] # key = os.path.splitext(arr[0])[0]
spk = arr[1]
assert key in wav_table
wav = wav_table[key]
data.append((key, spk, wav))
if args.shuffle:
random.shuffle(data)
num = args.num_utts_per_shard
chunks = [data[i:i + num] for i in range(0, len(data), num)]
os.makedirs(args.shards_dir, exist_ok=True)
    # Use a process pool to speed up shard writing
pool = multiprocessing.Pool(processes=args.num_threads)
shards_list = []
num_chunks = len(chunks)
for i, chunk in enumerate(chunks):
tar_file = os.path.join(args.shards_dir,
'{}_{:09d}.tar'.format(args.prefix, i))
shards_list.append(tar_file)
pool.apply_async(write_tar_file, (chunk, tar_file, i, num_chunks))
pool.close()
pool.join()
with open(args.shards_list, 'w', encoding='utf8') as fout:
for name in shards_list:
fout.write(name + '\n')
if __name__ == '__main__':
main()
================================================
FILE: tools/parse_options.sh
================================================
#!/bin/bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# 2022 Hongji Wang (jijijiang77@gmail.com)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
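# Example (illustrative): a script that defines `stage=0` and then sources
# this file can be invoked as `./run.sh --stage 2`, after which $stage == "2".
# Variables initialized to "true"/"false" are treated as Booleans, and only
# "true" or "false" are accepted as their values.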
###
### The --conf file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the confs specified by command-line, in left-to-right order
for ((argpos = 1; argpos < $#; argpos++)); do
if [ "${!argpos}" == "--conf" ]; then
argpos_plus1=$((argpos + 1))
conf=${!argpos_plus1}
[ ! -r $conf ] && echo "$0: missing conf '$conf'" && exit 1
. $conf # source the conf file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help | -h)
if [ -z "$help_message" ]; then
echo "No help found." 1>&2
else printf "$help_message\n" 1>&2; fi
exit 0
;;
--*=*)
echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1
;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*)
name=$(echo "$1" | sed s/^--// | sed s/-/_/g)
  # Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1
oldval="$(eval echo \$$name)"
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true
else
was_bool=false
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\"
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1
fi
shift 2
;;
*) break ;;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1
true # so this script returns exit code 0.
================================================
FILE: tools/print_train_val_curve.py
================================================
import re
import matplotlib.pyplot as plt
# Initialize lists to store epochs, train losses and validation losses
epochs = []
train_loss = []
val_loss = []
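# Assumed log line format (an illustration; adjust the regexes if your logs
# differ), e.g.:
#   2024-01-01 10:00:00 info Epoch 3 Train ... loss 0.1234
# Each relevant line contains "info", an "Epoch <N>" field, a "Train" or
# "Val" marker, and the loss as the first float on the line.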
# Open the log file
prev_epoch = 0
with open("train.log", "r") as f:
for line in f:
# Find lines with epoch info
if "info" in line:
# Extract epoch number
epoch = int(re.search(r"Epoch (\d+)", line).group(1))
if epoch != prev_epoch:
print(prev_epoch, epoch)
# Extract loss values
# pattern = r'loss (.*?)\n'
pattern = r"[-+]?\d*\.\d+"
loss = float(re.search(pattern, line).group())
if "Train" in line:
epochs.append(epoch)
train_loss.append(loss)
elif "Val" in line:
val_loss.append(loss)
prev_epoch = epoch
# Create the plot
plt.figure(figsize=(10, 5))
# Plot training and validation loss
plt.plot(epochs, train_loss, label="Training Loss", color="blue")
plt.plot(epochs, val_loss, label="Validation Loss", color="red")
# Add horizontal lines at the minimum values
plt.axhline(min(train_loss),
color="blue",
linestyle="--",
label="Min Training Loss")
plt.axhline(min(val_loss),
color="red",
linestyle="--",
label="Min Validation Loss")
# Annotate the minimum values on the y-axis
plt.text(
0,
min(train_loss),
"{:.2f}".format(min(train_loss)),
va="center",
ha="left",
backgroundcolor="w",
)
plt.text(
0,
min(val_loss),
"{:.2f}".format(min(val_loss)),
va="center",
ha="left",
backgroundcolor="w",
)
# Add legend, title, and x, y labels
plt.legend(loc="upper right")
plt.title("Training and Validation Loss Over Epochs")
plt.ylabel("Loss Value")
plt.xlabel("Epochs")
# Save the plot as a .png file
plt.savefig("train_val_loss.png")
# Show the plot
# plt.show()
================================================
FILE: tools/run.pl
================================================
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# In general, doing
# run.pl some.log a b c is like running the command a b c in
# the bash shell, and putting the standard error and output into some.log.
# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
# run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
# If any of the jobs fails, this script will fail.
# A typical example is:
# run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz
# and run.pl will run something like:
# ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log
#
# Basically it takes the command-line arguments, quotes them
# as necessary to preserve spaces, and evaluates them with bash.
# In addition it puts the command line at the top of the log, and
# the start and end times of the command at the beginning and end.
# The reason why this is useful is so that we can create a different
# version of this program that uses a queueing system instead.
#use Data::Dumper;
@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
$job_pick = 'all';
$max_jobs_run = -1;
$jobstart = 1;
$jobend = 1;
$ignored_opts = ""; # These will be ignored.
# First parse an option like JOB=1:4, and any
# options that would normally be given to
# queue.pl, which we will just discard.
for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
# allow the JOB=1:n option to be interleaved with the
# options to qsub.
while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
# parse any options that would normally go to qsub, but which will be ignored here.
my $switch = shift @ARGV;
if ($switch eq "-V") {
$ignored_opts .= "-V ";
} elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
# we do support the option --max-jobs-run n, and its GridEngine form -tc n.
      # If the option appears multiple times, the smallest value is used.
if ( $max_jobs_run <= 0 ) {
$max_jobs_run = shift @ARGV;
} else {
my $new_constraint = shift @ARGV;
if ( ($new_constraint < $max_jobs_run) ) {
$max_jobs_run = $new_constraint;
}
}
if (! ($max_jobs_run > 0)) {
die "run.pl: invalid option --max-jobs-run $max_jobs_run";
}
} else {
my $argument = shift @ARGV;
if ($argument =~ m/^--/) {
print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
}
if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
$ignored_opts .= "-sync "; # Note: in the
# corresponding code in queue.pl it says instead, just "$sync = 1;".
} elsif ($switch eq "-pe") { # e.g. -pe smp 5
my $argument2 = shift @ARGV;
$ignored_opts .= "$switch $argument $argument2 ";
} elsif ($switch eq "--gpu") {
$using_gpu = $argument;
} elsif ($switch eq "--pick") {
if($argument =~ m/^(all|failed|incomplete)$/) {
$job_pick = $argument;
} else {
print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
}
} else {
# Ignore option.
$ignored_opts .= "$switch $argument ";
}
}
}
if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
$jobname = $1;
$jobstart = $2;
$jobend = $3;
if ($jobstart > $jobend) {
die "run.pl: invalid job range $ARGV[0]";
}
if ($jobstart <= 0) {
die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
}
shift;
} elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
$jobname = $1;
$jobstart = $2;
$jobend = $2;
shift;
} elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
}
}
# Users found this message confusing so we are removing it.
# if ($ignored_opts ne "") {
# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n";
# }
if ($max_jobs_run == -1) { # If --max-jobs-run option not set,
# then work out the number of processors if possible,
# and set it based on that.
$max_jobs_run = 0;
if ($using_gpu) {
if (open(P, "nvidia-smi -L |")) {
      $max_jobs_run++ while (<P>);
close(P);
}
if ($max_jobs_run == 0) {
$max_jobs_run = 1;
print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
}
} elsif (open(P, ") { if (m/^processor/) { $max_jobs_run++; } }
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
$max_jobs_run = 10; # reasonable default.
}
close(P);
} elsif (open(P, "sysctl -a |")) { # BSD/Darwin
    while (<P>) {
if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
$max_jobs_run = $1;
last;
}
}
close(P);
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
$max_jobs_run = 10; # reasonable default.
}
} else {
# allow at most 32 jobs at once, on non-UNIX systems; change this code
# if you need to change this default.
$max_jobs_run = 32;
}
# The just-computed value of $max_jobs_run is just the number of processors
# (or our best guess); and if it happens that the number of jobs we need to
# run is just slightly above $max_jobs_run, it will make sense to increase
# $max_jobs_run to equal the number of jobs, so we don't have a small number
# of leftover jobs.
$num_jobs = $jobend - $jobstart + 1;
if (!$using_gpu &&
$num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
$max_jobs_run = $num_jobs;
}
}
sub pick_or_exit {
# pick_or_exit ( $logfile )
  # Invoked before each job is started; helps to run jobs selectively.
#
# Given the name of the output logfile decides whether the job must be
# executed (by returning from the subroutine) or not (by terminating the
# process calling exit)
#
# PRE: $job_pick is a global variable set by command line switch --pick
# and indicates which class of jobs must be executed.
#
# 1) If a failed job is not executed the process exit code will indicate
# failure, just as if the task was just executed and failed.
#
# 2) If a task is incomplete it will be executed. Incomplete may be either
# a job whose log file does not contain the accounting notes in the end,
# or a job whose log file does not exist.
#
# 3) If the $job_pick is set to 'all' (default behavior) a task will be
# executed regardless of the result of previous attempts.
#
  # This logic could have been implemented in the main execution loop,
  # but it is kept in a subroutine to preserve the readability of
  # that part of the code.
#
# Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
#
if($job_pick eq 'all'){
return; # no need to bother with the previous log
}
open my $fh, "<", $_[0] or return; # job not executed yet
my $log_line;
my $cur_line;
while ($cur_line = <$fh>) {
if( $cur_line =~ m/# Ended \(code .*/ ) {
$log_line = $cur_line;
}
}
close $fh;
if (! defined($log_line)){
return; # incomplete
}
if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
exit(0); # complete
} elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
if ($job_pick !~ m/^(failed|all)$/) {
exit(1); # failed but not going to run
} else {
return; # failed
}
} elsif ( $log_line =~ m/.*\S.*/ ) {
return; # incomplete jobs are always run
}
}
$logfile = shift @ARGV;
if (defined $jobname && $logfile !~ m/$jobname/ &&
$jobend > $jobstart) {
print STDERR "run.pl: you are trying to run a parallel job but "
. "you are putting the output into just one log file ($logfile)\n";
exit(1);
}
$cmd = "";
foreach $x (@ARGV) {
if ($x =~ m/^\S+$/) { $cmd .= $x . " "; }
elsif ($x =~ m:\":) { $cmd .= "'$x' "; }
else { $cmd .= "\"$x\" "; }
}
#$Data::Dumper::Indent=0;
$ret = 0;
$numfail = 0;
%active_pids=();
use POSIX ":sys_wait_h";
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
if (scalar(keys %active_pids) >= $max_jobs_run) {
    # Let's wait for a change in any child's status.
    # Then we have to work out which child finished.
$r = waitpid(-1, 0);
$code = $?;
if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen.
if ( defined $active_pids{$r} ) {
$jid=$active_pids{$r};
$fail[$jid]=$code;
if ($code !=0) { $numfail++;}
delete $active_pids{$r};
# print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n";
} else {
die "run.pl: Cannot find the PID of the child process that just finished.";
}
# In theory we could do a non-blocking waitpid over all jobs running just
# to find out if only one or more jobs finished during the previous waitpid()
# However, we just omit this and will reap the next one in the next pass
# through the for(;;) cycle
}
$childpid = fork();
if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; }
if ($childpid == 0) { # We're in the child... this branch
# executes the job and returns (possibly with an error status).
if (defined $jobname) {
$cmd =~ s/$jobname/$jobid/g;
$logfile =~ s/$jobname/$jobid/g;
}
# exit if the job does not need to be executed
pick_or_exit( $logfile );
system("mkdir -p `dirname $logfile` 2>/dev/null");
open(F, ">$logfile") || die "run.pl: Error opening log file $logfile";
print F "# " . $cmd . "\n";
print F "# Started at " . `date`;
$starttime = `date +'%s'`;
print F "#\n";
close(F);
# Pipe into bash.. make sure we're not using any other shell.
open(B, "|bash") || die "run.pl: Error opening shell command";
print B "( " . $cmd . ") 2>>$logfile >> $logfile";
close(B); # If there was an error, exit status is in $?
$ret = $?;
$lowbits = $ret & 127;
$highbits = $ret >> 8;
if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" }
else { $return_str = "code $highbits"; }
$endtime = `date +'%s'`;
open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)";
$enddate = `date`;
chop $enddate;
print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n";
print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n";
close(F);
exit($ret == 0 ? 0 : 1);
} else {
$pid[$jobid] = $childpid;
$active_pids{$childpid} = $jobid;
# print STDERR "Queued: " . Dumper(\%active_pids) . "\n";
}
}
# Now that we have submitted all the jobs, let's wait until they all finish.
foreach $child (keys %active_pids) {
$jobid=$active_pids{$child};
$r = waitpid($pid[$jobid], 0);
$code = $?;
if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen.
  if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # record the exit code
}
# Some sanity checks:
# The $fail array should not contain undefined codes
# The number of non-zeros in that array should be equal to $numfail
# We cannot do foreach() here, as the JOB ids do not start at zero
$failed_jids=0;
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
$job_return = $fail[$jobid];
if (not defined $job_return ) {
# print Dumper(\@fail);
die "run.pl: Sanity check failed: we have indication that some jobs are running " .
"even after we waited for all jobs to finish" ;
}
if ($job_return != 0 ){ $failed_jids++;}
}
if ($failed_jids != $numfail) {
die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)."
}
if ($numfail > 0) { $ret = 1; }
if ($ret != 0) {
$njobs = $jobend - $jobstart + 1;
if ($njobs == 1) {
if (defined $jobname) {
$logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with
# that job.
}
print STDERR "run.pl: job failed, log is in $logfile\n";
if ($logfile =~ m/JOB/) {
print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script.";
}
}
else {
$logfile =~ s/$jobname/*/g;
print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n";
}
}
exit ($ret);
================================================
FILE: tools/score.sh
================================================
#!/bin/bash
min() {
local a b
a=$1
for b in "$@"; do
if [ "${b}" -le "${a}" ]; then
a="${b}"
fi
done
echo "${a}"
}
# Set default values
dset=
exp_dir=
scoring_opts=
n_gpu=1
score_nj=16
ref_channel=0
use_pesq=false
use_dnsmos=false
dnsmos_use_gpu=true
fs=16k
scoring_protocol="STOI SDR SAR SIR SI_SNR"
# Parse command line options
. tools/parse_options.sh || exit 1
if [ ! ${fs} = 16k ] && ${use_dnsmos}; then
echo "Warning: DNSMOS only supports 16k sampling rate."
echo "--use_dnsmos will be set to false automatically."
use_dnsmos=false
fi
# Set scoring options
scoring_opts=""
if ${use_dnsmos}; then
# Set model path
primary_model_path=DNSMOS/sig_bak_ovr.onnx
p808_model_path=DNSMOS/model_v8.onnx
if [ ! -f ${primary_model_path} ] || [ ! -f ${p808_model_path} ]; then
echo "=========================================="
echo "Warning: DNSMOS model files are not found."
echo "Trying to download them from the official repository."
echo "If this takes too long,"
echo "please manually download the model files"
echo "and put them in the DNSMOS directory."
echo "=========================================="
    # create the directory for DNSMOS model files
mkdir -p DNSMOS
# download DNSMOS model files and save them to the directory
wget -P DNSMOS https://github.com/microsoft/DNS-Challenge/raw/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx
wget -P DNSMOS https://github.com/microsoft/DNS-Challenge/raw/master/DNSMOS/DNSMOS/model_v8.onnx
# check if the model files are downloaded successfully
if [ ! -f ${primary_model_path} ] || [ ! -f ${p808_model_path} ]; then
echo "Error: DNSMOS model files are not downloaded successfully."
exit 1
fi
fi
scoring_opts+="--dnsmos_mode local "
scoring_opts+="--dnsmos_primary_model ${primary_model_path} "
scoring_opts+="--dnsmos_p808_model ${p808_model_path} "
if ${dnsmos_use_gpu}; then
score_nj=$(min "${score_nj}" "${n_gpu}")
scoring_opts+="--dnsmos_use_gpu ${dnsmos_use_gpu} "
fi
fi
# Set directories and log directory
_dir="${exp_dir}/scoring"
_logdir="${_dir}/logdir"
mkdir -p "${_logdir}"
# 0. Check the inference file
inf_scp=${exp_dir}/audio/spk1.scp
if [ ! -s "${inf_scp}" ] || [ -z "$(cat "${inf_scp}")" ]; then
echo "Error: ${inf_scp} does not exist or is empty!"
exit 1
fi
# 1. Split the key file
key_file=${dset}/single.wav.scp
split_scps=""
_nj=$(min "${score_nj}" "$(wc <${key_file} -l)")
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
./tools/split_scp.pl "${key_file}" ${split_scps}
_ref_scp="--ref_scp ${dset}/single.wav.scp "
_inf_scp="--inf_scp ${exp_dir}/audio/spk1.scp "
# 2. Submit scoring jobs
echo "log: '${_logdir}/tse_scoring.*.log'"
if ${use_dnsmos} && ${dnsmos_use_gpu}; then
cmd="./tools/run.pl --gpu ${n_gpu}"
else
cmd="./tools/run.pl"
fi
# shellcheck disable=SC2086
${cmd} JOB=1:"${_nj}" "${_logdir}"/tse_scoring.JOB.log \
python -m wesep.bin.score \
--key_file "${_logdir}"/keys.JOB.scp \
--output_dir "${_logdir}"/output.JOB \
${_ref_scp} \
${_inf_scp} \
--ref_channel ${ref_channel} \
--use_pesq ${use_pesq} \
--use_dnsmos ${use_dnsmos} \
--dnsmos_gpu_device JOB \
${scoring_opts}
# Check if PESQ is used
if "${use_pesq}"; then
if [ ${fs} = 16k ]; then
scoring_protocol+=" PESQ_WB"
else
scoring_protocol+=" PESQ_NB"
fi
fi
# Check if dnsmos is used
if "${use_dnsmos}"; then
scoring_protocol+=" BAK SIG OVRL P808_MOS"
fi
# Merge and sort result files
for protocol in ${scoring_protocol} wav; do
for i in $(seq "${_nj}"); do
cat "${_logdir}/output.${i}/${protocol}_spk1"
done | LC_ALL=C sort -k1 >"${_dir}/${protocol}_spk1"
done
# Calculate and save results
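# The awk below assumes "<utt> <score>" pairs per pasted line: it averages the
# even-numbered columns within each line, then averages across all NR lines to
# produce a single number per metric.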
for protocol in ${scoring_protocol}; do
# shellcheck disable=SC2046
paste $(printf "%s/%s_spk1 " "${_dir}" "${protocol}") |
awk 'BEGIN{sum=0}
{n=0;score=0;for (i=2; i<=NF; i+=2){n+=1;score+=$i}; sum+=score/n}
END{printf ("%.2f\n",sum/NR)}' >"${_dir}/result_${protocol,,}.txt"
done
# show the result
./tools/show_enh_score.sh "${_dir}/../.." > \
"${_dir}/../../RESULTS.md"
================================================
FILE: tools/show_enh_score.sh
================================================
#!/usr/bin/env bash
mindepth=0
maxdepth=1
. tools/parse_options.sh
if [ $# -gt 1 ]; then
echo "Usage: $0 --mindepth 0 --maxdepth 1 [exp]" 1>&2
echo ""
echo "Show the system environments and the evaluation results in Markdown format."
    echo 'The default of [exp] is "exp/".'
exit 1
fi
[ -f ./path.sh ] && . ./path.sh
set -euo pipefail
if [ $# -eq 1 ]; then
exp=$(realpath "$1")
else
exp=exp
fi
cat << EOF
# RESULTS
## Environments
- date: \`$(LC_ALL=C date)\`
EOF
while IFS= read -r expdir; do
    if ls "${expdir}"/*/scoring/result_stoi.txt &>/dev/null; then
echo -e "\n## $(basename ${expdir})\n"
[ -e "${expdir}"/config.yaml ] && grep ^config "${expdir}"/config.yaml
metrics=()
heading="\n|dataset|"
sep="|---|"
for type in pesq pesq_wb pesq_nb estoi stoi sar sdr sir si_snr ovrl sig bak p808_mos; do
if ls "${expdir}"/*/scoring/result_${type}.txt &>/dev/null; then
metrics+=("$type")
heading+="${type^^}|"
sep+="---|"
fi
done
echo -e "${heading}\n${sep}"
setnames=()
for dirname in "${expdir}"/*/scoring/result_stoi.txt; do
dset=$(echo $dirname | sed -e "s#${expdir}/\([^/]*\)/scoring/result_stoi.txt#\1#g")
setnames+=("$dset")
done
for dset in "${setnames[@]}"; do
line="|${dset}|"
for ((i = 0; i < ${#metrics[@]}; i++)); do
type=${metrics[$i]}
if [ -f "${expdir}"/${dset}/scoring/result_${type}.txt ]; then
score=$(head -n1 "${expdir}"/${dset}/scoring/result_${type}.txt)
else
score=""
fi
line+="${score}|"
done
echo $line
done
echo ""
fi
done < <(find ${exp} -mindepth ${mindepth} -maxdepth ${maxdepth} -type d)
================================================
FILE: tools/split_scp.pl
================================================
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each output file.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries. In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
if ($ARGV[0] eq "-j") {
shift @ARGV;
$num_jobs = shift @ARGV;
$job_id = shift @ARGV;
}
if ($ARGV[0] =~ /--utt2spk=(.+)/) {
$utt2spk_file=$1;
shift;
}
if ($ARGV[0] eq '--one-based') {
$one_based = 1;
shift @ARGV;
}
}
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
$job_id - $one_based >= $num_jobs)) {
die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
($one_based ? " --one-based" : "") . "'\n"
}
$one_based and $job_id--;
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
die
"Usage: split_scp.pl [--utt2spk=] in.scp out1.scp out2.scp ...
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] in.scp [out.scp]
... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
}
$error = 0;
$inscp = shift @ARGV;
if ($num_jobs == 0) { # without -j option
@OUTPUTS = @ARGV;
} else {
for ($j = 0; $j < $num_jobs; $j++) {
if ($j == $job_id) {
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
else { push @OUTPUTS, "-"; }
} else {
push @OUTPUTS, "/dev/null";
}
}
}
if ($utt2spk_file ne "") { # We have the --utt2spk option...
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
while(<$u_fh>) {
@A = split;
@A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
close $u_fh;
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
@spkrs = ();
while(<$i_fh>) {
@A = split;
if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
$u = $A[0];
$s = $utt2spk{$u};
defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = []; # ref to new empty array.
}
$spk_count{$s}++;
push @{$spk_data{$s}}, $_;
}
# Now split as equally as possible ..
# First allocate spks to files by allocating an approximately
# equal number of speakers.
$numspks = @spkrs; # number of speakers.
$numscps = @OUTPUTS; # number of output files.
if ($numspks < $numscps) {
die "$0: Refusing to split data because number of speakers $numspks " .
"is less than the number of output .scp files $numscps\n";
}
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
}
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
$scpidx = int(($spkidx*$numscps) / $numspks);
$spk = $spkrs[$spkidx];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($f_fh, '>', $scpfile)
: open($f_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
"$scpfile (too many splits and too few speakers?)\n";
$error = 1;
} else {
foreach $spk ( @{$scparray[$scpidx]} ) {
print $f_fh @{$spk_data{$spk}};
$count += $spk_count{$spk};
}
$count == $scpcount[$scpidx] || die "Count mismatch [code error]";
}
close($f_fh);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
$numscps = @OUTPUTS; # size of array.
@F = ();
while(<$i_fh>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "$0: error: empty input scp file $inscp\n";
$error = 1;
}
$linesperscp = int( $numlines / $numscps); # the "whole part"..
$linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
$remainder = $numlines - ($linesperscp * $numscps);
($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
# [just doing int() rounds down].
$n = 0;
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($o_fh, '>', $scpfile)
: open($o_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
print $o_fh $F[$n++];
}
        close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
}
$n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);
================================================
FILE: tools/test_dataset.py
================================================
from torch.utils.data import DataLoader
from wesep.dataset.dataset import Dataset
from wesep.dataset.dataset import tse_collate_fn
from wesep.utils.file_utils import load_speaker_embeddings
def test_premixed_dataset():
configs = {
"shuffle": False,
"shuffle_args": {
"shuffle_size": 2500
},
"resample_rate": 16000,
"chunk_len": 32000,
}
spk2embed_dict = load_speaker_embeddings("data/clean/test/embed.scp",
"data/clean/test/single.utt2spk")
dataset = Dataset(
"shard",
"data/clean/test/shard.list",
configs=configs,
spk2embed_dict=spk2embed_dict,
whole_utt=False,
)
return dataset
def test_online_dataset():
# Implementation to test the online speaker mixing dataloader
configs = {
"shuffle": True,
"resample_rate": 16000,
"chunk_len": 64000,
"num_speakers": 2,
"online_mix": True,
"reverb": False,
}
spk2embed_dict = load_speaker_embeddings("mydata/clean/test/embed.scp",
"mydata/clean/test/utt2spk")
dataset = Dataset(
"shard",
"mydata/clean/test/shard.list",
configs=configs,
spk2embed_dict=spk2embed_dict,
whole_utt=False,
)
return dataset
if __name__ == "__main__":
dataset = test_online_dataset()
dataloader = DataLoader(dataset,
batch_size=4,
num_workers=1,
collate_fn=tse_collate_fn)
for i, batch in enumerate(dataloader):
print(
batch["wav_mix"].size(),
batch["wav_targets"].size(),
batch["spk_embeds"].size(),
)
if i == 0:
break
================================================
FILE: wesep/__init__.py
================================================
from wesep.cli.extractor import load_model # noqa
from wesep.cli.extractor import load_model_local # noqa
================================================
FILE: wesep/bin/average_model.py
================================================
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
# 2021 Hongji Wang (jijijiang77@gmail.com)
# 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import os.path
import re
import torch
def get_args():
parser = argparse.ArgumentParser(description="average model")
parser.add_argument("--dst_model", required=True, help="averaged model")
parser.add_argument("--src_path",
required=True,
help="src model path for average")
parser.add_argument("--num",
default=5,
type=int,
help="nums for averaged model")
parser.add_argument(
"--min_epoch",
default=0,
type=int,
help="min epoch used for averaging model",
)
parser.add_argument(
"--max_epoch",
default=65536, # Big enough
type=int,
help="max epoch used for averaging model",
)
parser.add_argument(
"--mode",
default="final",
type=str,
help="use last epochs for average or best epochs",
)
parser.add_argument(
"--epochs",
default="1,2,3,4,5",
type=str,
help="use last epochs for average or best epochs",
)
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
if args.mode == "final":
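        # Heuristic glob (keeps numbered checkpoints): the character classes
        # are a crude filter meant to skip *avg.pt, *final.pt and *latest.pt.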
path_list = glob.glob("{}/*[!avg][!final][!latest].pt".format(
args.src_path))
path_list = sorted(
path_list,
key=lambda p: int(re.findall(r"(?<=checkpoint_)\d*(?=.pt)", p)[0]),
)
path_list = path_list[-args.num:]
else:
epoch_indexes = list(args.epochs.split(","))
path_list = [
os.path.join(args.src_path, "checkpoint_" + x + ".pt")
for x in epoch_indexes
]
print(path_list)
avg = None
num = args.num
assert num == len(path_list)
for path in path_list:
print("Processing {}".format(path))
states = torch.load(path, map_location=torch.device("cpu"))
states = states["models"][0] if "models" in states else states
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
# pytorch 1.6 use true_divide instead of /=
avg[k] = torch.true_divide(avg[k], num)
avg = {"models": [avg]}
print("Saving to {}".format(args.dst_model))
torch.save(avg, args.dst_model)
if __name__ == "__main__":
main()
================================================
FILE: wesep/bin/export_jit.py
================================================
from __future__ import print_function
import argparse
import os
import torch
import yaml
from wesep.models import get_model
from wesep.utils.checkpoint import load_pretrained_model
def get_args():
parser = argparse.ArgumentParser(description="export your script model")
parser.add_argument("--config", required=True, help="config file")
parser.add_argument("--checkpoint", required=True, help="checkpoint model")
parser.add_argument("--output_model", required=True, help="output file")
args = parser.parse_args()
return args
def main():
args = get_args()
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
with open(args.config, "r") as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
print(configs)
model = get_model(
configs["model"]["tse_model"])(**configs["model_args"]["tse_model"])
print(model)
load_pretrained_model(model, args.checkpoint)
model.eval()
speaker_feat_dim = configs["dataset_args"]["fbank_args"].get(
"num_mel_bins", 80)
speaker_dummy_input = torch.ones(2, 300, speaker_feat_dim)
mix_dummy_input = torch.ones(2, 81280)
    # torch.jit.script compiles from source and ignores example inputs (its
    # second positional argument is the deprecated `optimize` flag); the
    # dummy tensors above only document the expected input shapes.
    script_model = torch.jit.script(model)
script_model.save(args.output_model)
print("Export model successfully, see {}".format(args.output_model))
if __name__ == "__main__":
main()
================================================
FILE: wesep/bin/infer.py
================================================
from __future__ import print_function
import os
import time
import fire
import soundfile
import torch
from torch.utils.data import DataLoader
from wesep.dataset.dataset import Dataset, tse_collate_fn_2spk
from wesep.models import get_model
from wesep.utils.checkpoint import load_pretrained_model
from wesep.utils.file_utils import read_label_file, read_vec_scp_file
from wesep.utils.score import cal_SISNRi
from wesep.utils.utils import (
generate_enahnced_scp,
get_logger,
parse_config_or_kwargs,
set_seed,
)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
def infer(config="confs/conf.yaml", **kwargs):
start = time.time()
total_SISNR = 0
total_SISNRi = 0
total_cnt = 0
accept_cnt = 0
configs = parse_config_or_kwargs(config, **kwargs)
    sign_save_wav = configs.get(
        "save_wav", True)  # Controls whether extracted speech is saved as .wav
rank = 0
set_seed(configs["seed"] + rank)
gpu = configs["gpus"]
device = (torch.device("cuda:{}".format(gpu))
if gpu >= 0 else torch.device("cpu"))
sample_rate = configs.get("fs", None)
if sample_rate is None or sample_rate == "16k":
sample_rate = 16000
else:
sample_rate = 8000
if 'spk_model_init' in configs['model_args']['tse_model']:
configs['model_args']['tse_model']['spk_model_init'] = False
model = get_model(
configs["model"]["tse_model"])(**configs["model_args"]["tse_model"])
model_path = os.path.join(configs["checkpoint"])
load_pretrained_model(model, model_path)
logger = get_logger(configs["exp_dir"], "infer.log")
logger.info("Load checkpoint from {}".format(model_path))
save_audio_dir = os.path.join(configs["exp_dir"], "audio")
if sign_save_wav:
if not os.path.exists(save_audio_dir):
try:
os.makedirs(save_audio_dir)
print(f"Directory {save_audio_dir} created successfully.")
except OSError as e:
print(f"Error creating directory {save_audio_dir}: {e}")
else:
print(f"Directory {save_audio_dir} already exists.")
else:
print("Do NOT save the results in wav.")
model = model.to(device)
model.eval()
test_spk_embeds = configs.get("test_spk_embeds", None)
test_spk1_embed_scp = configs["test_spk1_enroll"]
test_spk2_embed_scp = configs["test_spk2_enroll"]
joint_training = configs["model_args"]["tse_model"].get(
"joint_training", None)
if not joint_training and test_spk_embeds:
test_spk2embed_dict = read_vec_scp_file(test_spk_embeds)
else:
test_spk2embed_dict = read_label_file(configs["test_spk2utt"])
test_spk1_embed = read_label_file(test_spk1_embed_scp)
test_spk2_embed = read_label_file(test_spk2_embed_scp)
lines = len(test_spk2embed_dict)
test_dataset = Dataset(
configs["data_type"],
configs["test_data"],
configs["dataset_args"],
test_spk2embed_dict,
test_spk1_embed,
test_spk2_embed,
state="test",
joint_training=joint_training,
whole_utt=configs.get("whole_utt", True),
repeat_dataset=configs.get("repeat_dataset", False),
)
test_dataloader = DataLoader(test_dataset,
batch_size=1,
collate_fn=tse_collate_fn_2spk)
test_iter = lines // 2
logger.info("test number: {}".format(test_iter))
with torch.no_grad():
for i, batch in enumerate(test_dataloader):
features = batch["wav_mix"]
targets = batch["wav_targets"]
enroll = batch["spk_embeds"]
spk = batch["spk"]
key = batch["key"]
features = features.float().to(device) # (B,T,F)
targets = targets.float().to(device)
enroll = enroll.float().to(device)
outputs = model(features, enroll)
if isinstance(outputs, (list, tuple)):
outputs = outputs[0]
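            # Peak-normalize each output to 0.9 full scale before saving, but
            # only when every utterance in the batch has a positive peak;
            # otherwise keep the raw model output.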
if torch.min(outputs.max(dim=1).values) > 0:
outputs = ((outputs /
abs(outputs).max(dim=1, keepdim=True)[0] *
0.9).cpu().numpy())
else:
outputs = outputs.cpu().numpy()
if sign_save_wav:
file1 = os.path.join(
save_audio_dir,
f"Utt{total_cnt + 1}-{key[0]}-T{spk[0]}.wav",
)
soundfile.write(file1, outputs[0], sample_rate)
file2 = os.path.join(
save_audio_dir,
f"Utt{total_cnt + 1}-{key[1]}-T{spk[1]}.wav",
)
soundfile.write(file2, outputs[1], sample_rate)
ref = targets.cpu().numpy()
ests = outputs
mix = features.cpu().numpy()
if ests[0].size != ref[0].size:
end = min(ests[0].size, ref[0].size, mix[0].size)
ests_1 = ests[0][:end]
ref_1 = ref[0][:end]
mix_1 = mix[0][:end]
SISNR1, delta1 = cal_SISNRi(ests_1, ref_1, mix_1)
else:
SISNR1, delta1 = cal_SISNRi(ests[0], ref[0], mix[0])
logger.info(
"Num={} | Utt={} | Target speaker={} | SI-SNR={:.2f} | SI-SNRi={:.2f}"
.format(total_cnt + 1, key[0], spk[0], SISNR1, delta1))
total_SISNR += SISNR1
total_SISNRi += delta1
total_cnt += 1
if delta1 > 1:
accept_cnt += 1
if ests[1].size != ref[1].size:
end = min(ests[1].size, ref[1].size, mix[1].size)
ests_2 = ests[1][:end]
ref_2 = ref[1][:end]
mix_2 = mix[1][:end]
SISNR2, delta2 = cal_SISNRi(ests_2, ref_2, mix_2)
else:
SISNR2, delta2 = cal_SISNRi(ests[1], ref[1], mix[1])
logger.info(
"Num={} | Utt={} | Target speaker={} | SI-SNR={:.2f} | SI-SNRi={:.2f}"
.format(total_cnt + 1, key[1], spk[1], SISNR2, delta2))
total_SISNR += SISNR2
total_SISNRi += delta2
total_cnt += 1
if delta2 > 1:
accept_cnt += 1
# if (i + 1) == test_iter:
# break
end = time.time()
# generate the scp file of the enhanced speech for scoring
if sign_save_wav:
generate_enahnced_scp(os.path.abspath(save_audio_dir), extension="wav")
logger.info("Time Elapsed: {:.1f}s".format(end - start))
logger.info("Average SI-SNR: {:.2f}".format(total_SISNR / total_cnt))
logger.info("Average SI-SNRi: {:.2f}".format(total_SISNRi / total_cnt))
    logger.info(
        "Acceptance rate of utterances with SI-SNRi > 1 dB: {:.2f}%".format(
            accept_cnt / total_cnt * 100))
if __name__ == "__main__":
fire.Fire(infer)
================================================
FILE: wesep/bin/score.py
================================================
# ported from
# https://github.com/espnet/espnet/blob/master/espnet2/bin/enh_scoring.py
import argparse
import logging
import sys
from pathlib import Path
from typing import Dict, List, Union
import numpy as np
from mir_eval.separation import bss_eval_sources
from pystoi import stoi
from wesep.utils.datadir_writer import DatadirWriter
from wesep.utils.file_utils import SoundScpReader
from wesep.utils.score import cal_SISNR
from wesep.utils.utils import ArgumentParser, get_commandline_args, str2bool
def get_readers(scps: List[str], dtype: str):
readers = [SoundScpReader(f, dtype=dtype) for f in scps]
audio_format = "sound"
return readers, audio_format
def read_audio(reader, key, audio_format="sound"):
if audio_format == "sound":
return reader[key][1]
else:
raise ValueError(f"Unknown audio format: {audio_format}")
def scoring(
output_dir: str,
dtype: str,
log_level: Union[int, str],
key_file: str,
ref_scp: List[str],
inf_scp: List[str],
ref_channel: int,
use_dnsmos: bool,
dnsmos_args: Dict,
use_pesq: bool,
):
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if use_dnsmos:
if dnsmos_args["mode"] == "local":
from wesep.utils.dnsmos import DNSMOS_local
if not Path(dnsmos_args["primary_model"]).exists():
raise ValueError(
f"The primary model {dnsmos_args['primary_model']} doesn't exist."
" You can download the model from https://github.com/microsoft/"
"DNS-Challenge/tree/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx")
if not Path(dnsmos_args["p808_model"]).exists():
raise ValueError(
f"The P808 model {dnsmos_args['p808_model']} doesn't exist."
" You can download the model from https://github.com/microsoft/"
"DNS-Challenge/tree/master/DNSMOS/DNSMOS/model_v8.onnx")
dnsmos = DNSMOS_local(
dnsmos_args["primary_model"],
dnsmos_args["p808_model"],
use_gpu=dnsmos_args["use_gpu"],
convert_to_torch=dnsmos_args["convert_to_torch"],
gpu_device=dnsmos_args["gpu_device"] - 1,
)
logging.warning("Using local DNSMOS models for evaluation")
elif dnsmos_args["mode"] == "web":
from wesep.utils.dnsmos import DNSMOS_web
if not dnsmos_args["auth_key"]:
raise ValueError(
"Please specify the authentication key for access to the Web-API. "
"You can apply for the AUTH_KEY at https://github.com/microsoft/"
"DNS-Challenge/blob/master/DNSMOS/README.md#to-use-the-web-api"
)
dnsmos = DNSMOS_web(dnsmos_args["auth_key"])
logging.warning("Using the DNSMOS Web-API for evaluation")
else:
dnsmos = None
if use_pesq:
try:
from pesq import PesqError, pesq
logging.warning("Using the PESQ package for evaluation")
except ImportError:
raise ImportError(
"Please install pesq and retry: pip install pesq") from None
else:
pesq = None
assert len(ref_scp) == len(inf_scp), "len(ref_scp) != len(inf_scp)"
num_spk = len(ref_scp)
keys = [
line.rstrip().split(maxsplit=1)[0]
for line in open(key_file, encoding="utf-8")
]
ref_readers, ref_audio_format = get_readers(ref_scp, dtype)
inf_readers, inf_audio_format = get_readers(inf_scp, dtype)
# get sample rate
retval = ref_readers[0][keys[0]]
if ref_audio_format == "kaldi_ark":
sample_rate = ref_readers[0].rate
elif ref_audio_format == "sound":
sample_rate = retval[0]
else:
raise NotImplementedError(ref_audio_format)
assert sample_rate is not None, (sample_rate, ref_audio_format)
# check keys
for inf_reader, ref_reader in zip(inf_readers, ref_readers):
assert inf_reader.keys() == ref_reader.keys()
with DatadirWriter(output_dir) as writer:
for n, key in enumerate(keys):
logging.info(f"[{n}] Scoring {key}")
ref_audios = [
read_audio(ref_reader, key, audio_format=ref_audio_format)
for ref_reader in ref_readers
]
inf_audios = [
read_audio(inf_reader, key, audio_format=inf_audio_format)
for inf_reader in inf_readers
]
ref = np.array(ref_audios)
inf = np.array(inf_audios)
if ref.ndim > inf.ndim:
# multi-channel reference and single-channel output
ref = ref[..., ref_channel]
elif ref.ndim < inf.ndim:
# single-channel reference and multi-channel output
inf = inf[..., ref_channel]
elif ref.ndim == inf.ndim == 3:
# multi-channel reference and output
ref = ref[..., ref_channel]
inf = inf[..., ref_channel]
assert ref.shape == inf.shape, (ref.shape, inf.shape)
sdr, sir, sar, perm = bss_eval_sources(ref,
inf,
compute_permutation=True)
for i in range(num_spk):
stoi_score = stoi(ref[i],
inf[int(perm[i])],
fs_sig=sample_rate)
estoi_score = stoi(
ref[i],
inf[int(perm[i])],
fs_sig=sample_rate,
extended=True,
)
si_snr_score = cal_SISNR(
ref[i],
inf[int(perm[i])],
)
if dnsmos:
dnsmos_score = dnsmos(inf[int(perm[i])], sample_rate)
writer[f"OVRL_spk{i + 1}"][key] = str(dnsmos_score["OVRL"])
writer[f"SIG_spk{i + 1}"][key] = str(dnsmos_score["SIG"])
writer[f"BAK_spk{i + 1}"][key] = str(dnsmos_score["BAK"])
writer[f"P808_MOS_spk{i + 1}"][key] = str(
dnsmos_score["P808_MOS"])
if pesq:
if sample_rate == 8000:
mode = "nb"
elif sample_rate == 16000:
mode = "wb"
else:
raise ValueError(
"sample rate must be 8000 or 16000 for PESQ evaluation, "
f"but got {sample_rate}")
pesq_score = pesq(
sample_rate,
ref[i],
inf[int(perm[i])],
mode=mode,
on_error=PesqError.RETURN_VALUES,
)
if pesq_score == PesqError.NO_UTTERANCES_DETECTED:
logging.warning(
f"[PESQ] Error: No utterances detected for {key}. "
"Skipping this utterance.")
else:
writer[f"PESQ_{mode.upper()}_spk{i + 1}"][key] = str(
pesq_score)
writer[f"STOI_spk{i + 1}"][key] = str(stoi_score *
100) # in percentage
writer[f"ESTOI_spk{i + 1}"][key] = str(estoi_score * 100)
writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score)
writer[f"SDR_spk{i + 1}"][key] = str(sdr[i])
writer[f"SAR_spk{i + 1}"][key] = str(sar[i])
writer[f"SIR_spk{i + 1}"][key] = str(sir[i])
# save permutation assigned script file
if i < len(ref_scp):
if inf_audio_format == "sound":
writer[f"wav_spk{i + 1}"][key] = inf_readers[
perm[i]].data[key]
elif inf_audio_format == "kaldi_ark":
# NOTE: SegmentsExtractor is not supported
writer[f"wav_spk{i + 1}"][key] = inf_readers[
perm[i]].loader._dict[key]
else:
raise ValueError(
f"Unknown audio format: {inf_audio_format}")
def get_parser():
parser = ArgumentParser(
description="Frontend inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--ref_scp",
type=str,
required=True,
action="append",
)
group.add_argument(
"--inf_scp",
type=str,
required=True,
action="append",
)
group.add_argument("--key_file", type=str)
group.add_argument("--ref_channel", type=int, default=0)
group = parser.add_argument_group("DNSMOS related")
group.add_argument("--use_dnsmos", type=str2bool, default=False)
group.add_argument(
"--dnsmos_mode",
type=str,
choices=("local", "web"),
default="local",
help="Use local DNSMOS model or web API for DNSMOS calculation",
)
group.add_argument(
"--dnsmos_auth_key",
type=str,
default="",
help="Required if dnsmsos_mode='web'",
)
group.add_argument(
"--dnsmos_use_gpu",
type=str2bool,
default=False,
help="used when dnsmsos_mode='local'",
)
group.add_argument(
"--dnsmos_convert_to_torch",
type=str2bool,
default=False,
help="used when dnsmsos_mode='local'",
)
group.add_argument("--dnsmos_primary_model",
type=str,
default="./DNSMOS/sig_bak_ovr.onnx",
help="Path to the primary DNSMOS model. "
"Required if dnsmsos_mode='local'")
group.add_argument(
"--dnsmos_p808_model",
type=str,
default="./DNSMOS/model_v8.onnx",
help="Path to the p808 model. Required if dnsmsos_mode='local'",
)
group.add_argument("--dnsmos_gpu_device",
type=int,
default=None,
help="gpu device to use for DNSMOS evaluation. "
"Used when dnsmsos_mode='local'")
group = parser.add_argument_group("PESQ related")
group.add_argument(
"--use_pesq",
type=str2bool,
default=False,
help="Bebore setting this to True, please make sure that you or "
"your institution have the license "
"(check https://www.itu.int/rec/T-REC-P.862-200511-I!Amd2/en) to report PESQ",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
dnsmos_args = {
"mode": kwargs.pop("dnsmos_mode"),
"auth_key": kwargs.pop("dnsmos_auth_key"),
"primary_model": kwargs.pop("dnsmos_primary_model"),
"p808_model": kwargs.pop("dnsmos_p808_model"),
"use_gpu": kwargs.pop("dnsmos_use_gpu"),
"convert_to_torch": kwargs.pop("dnsmos_convert_to_torch"),
"gpu_device": kwargs.pop("dnsmos_gpu_device"),
}
kwargs["dnsmos_args"] = dnsmos_args
scoring(**kwargs)
if __name__ == "__main__":
main()
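# Example invocation (a hedged sketch; the scp paths are placeholders, not
# files shipped with the repo). Each additional --ref_scp/--inf_scp pair adds
# one speaker, and --key_file only supplies the utterance keys:
#
#   python wesep/bin/score.py \
#       --output_dir exp/scoring \
#       --key_file data/test/spk1.scp \
#       --ref_scp data/test/spk1.scp --ref_scp data/test/spk2.scp \
#       --inf_scp exp/sep/spk1.scp --inf_scp exp/sep/spk2.scp \
#       --use_dnsmos false --use_pesq false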
================================================
FILE: wesep/bin/train.py
================================================
# Copyright (c) 2023 Shuai Wang (wsstriving@gmail.com)
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import re
from pprint import pformat
import fire
import matplotlib.pyplot as plt
import tableprint as tp
import torch
import torch.distributed as dist
import yaml
from torch.utils.data import DataLoader
import wesep.utils.schedulers as schedulers
from wesep.dataset.dataset import Dataset, tse_collate_fn, tse_collate_fn_2spk
from wesep.models import get_model
from wesep.utils.checkpoint import (
load_checkpoint,
load_pretrained_model,
save_checkpoint,
)
from wesep.utils.executor import Executor
from wesep.utils.file_utils import (
load_speaker_embeddings,
read_label_file,
read_vec_scp_file,
)
from wesep.utils.losses import parse_loss
from wesep.utils.utils import parse_config_or_kwargs, set_seed, setup_logger
MAX_NUM_log_files = 100 # The maximum number of log-files to be kept
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)
def train(config="conf/config.yaml", **kwargs):
"""Trains a model on the given features and spk labels.
:config: A training configuration. Note that all parameters in the
config can also be manually adjusted with --ARG VALUE
:returns: None
"""
# print(kwargs)
configs = parse_config_or_kwargs(config, **kwargs)
checkpoint = configs.get("checkpoint", None)
if checkpoint is not None:
checkpoint = os.path.realpath(checkpoint)
find_unused_parameters = configs.get("find_unused_parameters", False)
# dist configs
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
gpu = int(configs["gpus"][rank])
torch.cuda.set_device(gpu)
dist.init_process_group(backend="nccl")
# Log rotation
model_dir = os.path.join(configs["exp_dir"], "models")
logger = setup_logger(rank, configs["exp_dir"], gpu, MAX_NUM_log_files)
print("-------------------", dist.get_rank(), world_size)
if world_size > 1:
logger.info("training on multiple gpus, this gpu {}".format(gpu))
if rank == 0:
logger.info("exp_dir is: {}".format(configs["exp_dir"]))
logger.info("<== Passed Arguments ==>")
# Print arguments into logs
for line in pformat(configs).split("\n"):
logger.info(line)
# seed
set_seed(configs["seed"] + rank)
# loss
criterion = configs.get("loss", None)
if criterion:
criterion = parse_loss(criterion)
else:
criterion = [
parse_loss("SISDR"),
]
loss_posi = configs["loss_args"].get(
"loss_posi",
[[
0,
]],
)
loss_weight = configs["loss_args"].get(
"loss_weight",
[[
1.0,
]],
)
loss_args = (loss_posi, loss_weight)
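    # loss_posi[i] lists which model output feeds the i-th criterion and
    # loss_weight[i] its weight (see the fuller comments in train_gan.py).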
# embeds
tr_spk_embeds = configs.get("train_spk_embeds", None)
tr_single_utt2spk = configs["train_utt2spk"]
joint_training = configs["model_args"]["tse_model"].get(
"joint_training", False)
multi_task = configs["model_args"]["tse_model"].get("multi_task", False)
dict_spk = {}
if not joint_training and tr_spk_embeds:
tr_spk2embed_dict = load_speaker_embeddings(tr_spk_embeds,
tr_single_utt2spk)
        multi_task = False
else:
with open(configs["train_spk2utt"], "r") as f:
tr_spk2embed_dict = json.load(f)
if multi_task:
for i, j in enumerate(tr_spk2embed_dict.keys(
)): # Generate the dictionary for speakers in training set
dict_spk[j] = i
with open(tr_single_utt2spk, "r") as f:
tr_lines = f.readlines()
val_spk_embeds = configs.get("val_spk_embeds", None)
val_spk1_enroll = configs["val_spk1_enroll"]
val_spk2_enroll = configs["val_spk2_enroll"]
if not joint_training and val_spk_embeds:
val_spk2embed_dict = read_vec_scp_file(val_spk_embeds)
else:
val_spk2embed_dict = read_label_file(configs["val_spk2utt"])
val_lines = len(val_spk2embed_dict)
val_spk1_embed = read_label_file(val_spk1_enroll)
val_spk2_embed = read_label_file(val_spk2_enroll)
# dataset and dataloader
train_dataset = Dataset(
configs["data_type"],
configs["train_data"],
configs["dataset_args"],
tr_spk2embed_dict,
None,
None,
state="train",
joint_training=joint_training,
dict_spk=dict_spk,
whole_utt=configs.get("whole_utt", False),
repeat_dataset=configs.get("repeat_dataset", True),
noise_prob=configs["dataset_args"].get("noise_prob", 0),
reverb_prob=configs["dataset_args"].get("reverb_prob", 0),
noise_enroll_prob=configs["dataset_args"].get("noise_enroll_prob", 0),
reverb_enroll_prob=configs["dataset_args"].get("reverb_enroll_prob",
0),
specaug_enroll_prob=configs["dataset_args"].get(
"specaug_enroll_prob", 0),
online_mix=configs["dataset_args"].get("online_mix", False),
noise_lmdb_file=configs["dataset_args"].get("noise_lmdb_file", None),
)
val_dataset = Dataset(configs["data_type"],
configs["val_data"],
configs["dataset_args"],
val_spk2embed_dict,
val_spk1_embed,
val_spk2_embed,
state="val",
joint_training=joint_training,
whole_utt=configs.get("whole_utt", False),
repeat_dataset=True,
online_mix=False,
noise_prob=0,
reverb_prob=0,
noise_enroll_prob=0,
reverb_enroll_prob=0,
specaug_enroll_prob=0)
train_dataloader = DataLoader(train_dataset,
**configs["dataloader_args"],
collate_fn=tse_collate_fn)
val_dataloader = DataLoader(
val_dataset,
**configs["dataloader_args"],
collate_fn=tse_collate_fn_2spk,
)
batch_size = configs["dataloader_args"]["batch_size"]
if configs["dataset_args"].get("sample_num_per_epoch", 0) > 0:
sample_num_per_epoch = configs["dataset_args"]["sample_num_per_epoch"]
else:
sample_num_per_epoch = len(tr_lines) // 2
epoch_iter = sample_num_per_epoch // world_size // batch_size
val_iter = val_lines // 2 // world_size // batch_size
if rank == 0:
logger.info("<== Dataloaders ==>")
logger.info("train dataloaders created")
logger.info("epoch iteration number: {}".format(epoch_iter))
logger.info("val iteration number: {}".format(val_iter))
# model
model_list = []
scheduler_list = []
optimizer_list = []
logger.info("<== Model ==>")
model = get_model(
configs["model"]["tse_model"])(**configs["model_args"]["tse_model"])
num_params = sum(param.numel() for param in model.parameters())
if rank == 0:
logger.info("tse_model size: {:.2f} M".format(num_params / 1e6))
# print model
for line in pformat(model).split("\n"):
logger.info(line)
# ddp_model
model.cuda()
ddp_model = torch.nn.parallel.DistributedDataParallel(
model, find_unused_parameters=find_unused_parameters)
device = torch.device("cuda")
if rank == 0:
logger.info("<== TSE Model Loss ==>")
logger.info("loss criterion is: " + str(configs["loss"]))
configs["optimizer_args"]["tse_model"]["lr"] = configs["scheduler_args"][
"tse_model"]["initial_lr"]
optimizer = getattr(torch.optim, configs["optimizer"]["tse_model"])(
ddp_model.parameters(), **configs["optimizer_args"]["tse_model"])
if rank == 0:
logger.info("<== TSE Model Optimizer ==>")
logger.info("optimizer is: " + configs["optimizer"]["tse_model"])
# scheduler
configs["scheduler_args"]["tse_model"]["num_epochs"] = configs[
"num_epochs"]
configs["scheduler_args"]["tse_model"]["epoch_iter"] = epoch_iter
configs["scheduler_args"]["scale_ratio"] = 1.0
scheduler = getattr(schedulers, configs["scheduler"]["tse_model"])(
optimizer, **configs["scheduler_args"]["tse_model"])
if rank == 0:
logger.info("<== TSE Model Scheduler ==>")
logger.info("scheduler is: " + configs["scheduler"]["tse_model"])
if configs["model_init"]["tse_model"] is not None:
logger.info("Load initial model from {}".format(
configs["model_init"]["tse_model"]))
load_pretrained_model(ddp_model, configs["model_init"]["tse_model"])
elif checkpoint is None:
logger.info("Train model from scratch ...")
for c in criterion:
c = c.to(device)
# append to list
model_list.append(ddp_model)
optimizer_list.append(optimizer)
scheduler_list.append(scheduler)
scaler = torch.cuda.amp.GradScaler(enabled=configs["enable_amp"])
# If specify checkpoint, load some info from checkpoint.
if checkpoint is not None:
load_checkpoint(model_list, optimizer_list, scheduler_list, scaler,
checkpoint)
start_epoch = (
            int(re.findall(r"(?<=checkpoint_)\d*(?=\.pt)", checkpoint)[0]) + 1)
logger.info("Load checkpoint: {}".format(checkpoint))
else:
start_epoch = 1
logger.info("start_epoch: {}".format(start_epoch))
# save config.yaml
if rank == 0:
saved_config_path = os.path.join(configs["exp_dir"], "config.yaml")
with open(saved_config_path, "w") as fout:
data = yaml.dump(configs)
fout.write(data)
# training
dist.barrier(device_ids=[gpu]) # synchronize here
if rank == 0:
logger.info("<========== Training process ==========>")
header = ["Train/Val", "Epoch", "iter", "Loss", "LR"]
for line in tp.header(header, width=10, style="grid").split("\n"):
logger.info(line)
dist.barrier(device_ids=[gpu]) # synchronize here
executor = Executor()
executor.step = 0
train_losses = []
val_losses = []
for epoch in range(start_epoch, configs["num_epochs"] + 1):
train_dataset.set_epoch(epoch)
# train_loss_com
train_loss, _ = executor.train(
train_dataloader,
model_list,
epoch_iter,
optimizer_list,
criterion,
scheduler_list,
scaler=scaler,
epoch=epoch,
logger=logger,
enable_amp=configs["enable_amp"],
clip_grad=configs["clip_grad"],
log_batch_interval=configs["log_batch_interval"],
device=device,
se_loss_weight=loss_args,
multi_task=multi_task,
SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", 0),
fbank_args=configs["dataset_args"].get('fbank_args', None),
sample_rate=configs["dataset_args"]['resample_rate'],
speaker_feat=configs["dataset_args"].get('speaker_feat', True)
)
val_loss, _ = executor.cv(
val_dataloader,
model_list,
val_iter,
criterion,
epoch=epoch,
logger=logger,
enable_amp=configs["enable_amp"],
log_batch_interval=configs["log_batch_interval"],
device=device,
)
if rank == 0:
logger.info("Epoch {} Train info train_loss {}".format(
epoch, train_loss))
logger.info("Epoch {} Val info val_loss {}".format(
epoch, val_loss))
train_losses.append(train_loss)
val_losses.append(val_loss)
best_loss = val_loss
scheduler.best = best_loss
# plot
plt.figure()
plt.title("Loss of Train and Validation")
x = list(range(start_epoch, epoch + 1))
plt.plot(x, train_losses, "b-", label="Train Loss", linewidth=0.8)
plt.plot(x,
val_losses,
"c-",
label="Validation Loss",
linewidth=0.8)
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.xticks(range(start_epoch, epoch + 1, 1))
plt.savefig(
f"{configs['exp_dir']}/{configs['model']['tse_model']}.png")
plt.close()
if rank == 0:
if (epoch % configs["save_epoch_interval"] == 0
or epoch >= configs["num_epochs"] - configs["num_avg"]):
save_checkpoint(
model_list,
optimizer_list,
scheduler_list,
scaler,
os.path.join(model_dir, "checkpoint_{}.pt".format(epoch)),
)
try:
os.symlink(
"checkpoint_{}.pt".format(epoch),
os.path.join(model_dir, "latest_checkpoint.pt"),
)
except FileExistsError:
os.remove(os.path.join(model_dir, "latest_checkpoint.pt"))
os.symlink(
"checkpoint_{}.pt".format(epoch),
os.path.join(model_dir, "latest_checkpoint.pt"),
)
if rank == 0:
os.symlink(
"checkpoint_{}.pt".format(configs["num_epochs"]),
os.path.join(model_dir, "final_checkpoint.pt"),
)
logger.info(tp.bottom(len(header), width=10, style="grid"))
if __name__ == "__main__":
fire.Fire(train)
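# Launch sketch (hedged): train() reads RANK and WORLD_SIZE from the
# environment and pins each rank to configs["gpus"][rank], so it is typically
# started with a launcher such as torchrun, which sets those variables:
#
#   torchrun --nproc_per_node=2 wesep/bin/train.py \
#       --config conf/config.yaml --exp_dir exp/tse --gpus "[0,1]"
#
# The exp_dir and gpus values here are placeholders; any key in the YAML
# config can be overridden the same way via fire's --ARG VALUE syntax.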
================================================
FILE: wesep/bin/train_gan.py
================================================
# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)
# 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import re
from pprint import pformat
import fire
import matplotlib.pyplot as plt
import tableprint as tp
import torch
import torch.distributed as dist
import yaml
from torch.utils.data import DataLoader
import wesep.utils.schedulers as schedulers
from wesep.dataset.dataset import Dataset, tse_collate_fn, tse_collate_fn_2spk
from wesep.models import get_model
from wesep.utils.checkpoint import (
load_checkpoint,
load_pretrained_model,
save_checkpoint,
)
from wesep.utils.executor_gan import ExecutorGAN
from wesep.utils.file_utils import (
load_speaker_embeddings,
read_label_file,
read_vec_scp_file,
)
from wesep.utils.losses import parse_loss
from wesep.utils.utils import parse_config_or_kwargs, set_seed, setup_logger
MAX_NUM_log_files = 100 # The maximum number of log-files to be kept
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)
def train(config="conf/config.yaml", **kwargs):
"""Trains a model on the given features and spk labels.
:config: A training configuration. Note that all parameters in the
config can also be manually adjusted with --ARG VALUE
:returns: None
"""
configs = parse_config_or_kwargs(config, **kwargs)
checkpoint = configs.get("checkpoint", None)
if checkpoint is not None:
checkpoint = os.path.realpath(checkpoint)
find_unused_parameters = configs.get("find_unused_parameters", False)
gan_loss_weight = configs.get("gan_loss_weight", 0.05)
# dist configs
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
gpu = int(configs["gpus"][rank])
torch.cuda.set_device(gpu)
dist.init_process_group(backend="nccl")
# Log rotation
model_dir = os.path.join(configs["exp_dir"], "models")
logger = setup_logger(rank, configs["exp_dir"], gpu, MAX_NUM_log_files)
print("-------------------", dist.get_rank(), world_size)
if world_size > 1:
logger.info("training on multiple gpus, this gpu {}".format(gpu))
if rank == 0:
logger.info("exp_dir is: {}".format(configs["exp_dir"]))
logger.info("<== Passed Arguments ==>")
# Print arguments into logs
for line in pformat(configs).split("\n"):
logger.info(line)
# seed
set_seed(configs["seed"] + rank)
# support multiple losses, e.g., criterion = [SISNR, CE]
criterion = configs.get("loss", None)
if criterion:
criterion = parse_loss(criterion)
else:
criterion = [
parse_loss("SISNR"),
]
# loss_posi is used to store the indices when the model has multiple outputs
# loss_posi[i][j] stores the index of the output used for i-th criterion,
# that is, output[loss_posi[i][j]] is used for the i-th criterion.
loss_posi = configs["loss_args"].get(
"loss_posi",
[[
0,
]],
)
# loss_weight[i][j] stores the loss weight of output[loss_posi[i][j]] for the i-th criterion. # noqa
loss_weight = configs["loss_args"].get(
"loss_weight",
[[
1.0,
]],
)
loss_args = (loss_posi, loss_weight)
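    # Worked example (hedged, not from a config shipped with the repo): with
    # criterion = [SISNR, CE] and a model returning
    # (separated_wav, spk_logits), a config such as
    #     loss_posi:   [[0], [1]]
    #     loss_weight: [[1.0], [0.5]]
    # scores output[0] with SISNR at weight 1.0 and output[1] with CE at
    # weight 0.5; the defaults above apply a single criterion to output[0].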
# embeds
tr_spk_embeds = configs["train_spk_embeds"]
tr_single_utt2spk = configs["train_utt2spk"]
joint_training = configs["model_args"]["tse_model"].get(
"joint_training", False)
multi_task = configs["model_args"]["tse_model"].get("multi_task", False)
# dict_spk: {spk_id: int_label}
dict_spk = {}
if not joint_training:
tr_spk2embed_dict = load_speaker_embeddings(tr_spk_embeds,
tr_single_utt2spk)
multi_task = False
else:
with open(configs["train_spk2utt"], "r") as f:
tr_spk2embed_dict = json.load(f)
# tr_spk2embed_dict: {spk_id: [[spk_id, wav_path], ...]}
if multi_task:
for i, j in enumerate(tr_spk2embed_dict.keys(
)): # Generate the dictionary for speakers in training set
dict_spk[j] = i
with open(tr_single_utt2spk, "r") as f:
tr_lines = f.readlines()
val_spk_embeds = configs["val_spk_embeds"]
val_spk1_enroll = configs["val_spk1_enroll"]
val_spk2_enroll = configs["val_spk2_enroll"]
if not joint_training:
val_spk2embed_dict = read_vec_scp_file(val_spk_embeds)
else:
val_spk2embed_dict = read_label_file(configs["val_spk2utt"])
val_spk1_embed = read_label_file(val_spk1_enroll)
val_spk2_embed = read_label_file(val_spk2_enroll)
with open(val_spk_embeds, "r") as f:
val_lines = f.readlines()
# dataset and dataloader
train_dataset = Dataset(
configs["data_type"],
configs["train_data"],
configs["dataset_args"],
tr_spk2embed_dict,
None,
None,
state="train",
joint_training=joint_training,
dict_spk=dict_spk,
whole_utt=configs.get("whole_utt", False),
repeat_dataset=configs.get("repeat_dataset", True),
reverb=configs["dataset_args"].get("reverb", False),
noise=configs["dataset_args"].get("noise", False),
noise_lmdb_file=configs["dataset_args"].get("noise_lmdb_file", None),
online_mix=configs["dataset_args"].get("online_mix", False),
)
val_dataset = Dataset(
configs["data_type"],
configs["val_data"],
configs["dataset_args"],
val_spk2embed_dict,
val_spk1_embed,
val_spk2_embed,
state="val",
joint_training=joint_training,
whole_utt=configs.get("whole_utt", False),
repeat_dataset=True,
reverb=False,
online_mix=False,
)
train_dataloader = DataLoader(train_dataset,
**configs["dataloader_args"],
collate_fn=tse_collate_fn)
val_dataloader = DataLoader(
val_dataset,
**configs["dataloader_args"],
collate_fn=tse_collate_fn_2spk,
)
batch_size = configs["dataloader_args"]["batch_size"]
if configs["dataset_args"].get("sample_num_per_epoch", 0) > 0:
sample_num_per_epoch = configs["dataset_args"]["sample_num_per_epoch"]
else:
sample_num_per_epoch = len(tr_lines) // 2
epoch_iter = sample_num_per_epoch // world_size // batch_size
val_iter = len(val_lines) // 2 // world_size // batch_size
if rank == 0:
logger.info("<== Dataloaders ==>")
logger.info("train dataloaders created")
logger.info("epoch iteration number: {}".format(epoch_iter))
logger.info("val iteration number: {}".format(val_iter))
# model
model_list = []
scheduler_list = []
optimizer_list = []
logger.info("<== Model ==>")
model = get_model(
configs["model"]["tse_model"])(**configs["model_args"]["tse_model"])
num_params = sum(param.numel() for param in model.parameters())
if rank == 0:
logger.info("tse_model size: {}".format(num_params))
# print model
for line in pformat(model).split("\n"):
logger.info(line)
# ddp_model
model.cuda()
ddp_model = torch.nn.parallel.DistributedDataParallel(
model, find_unused_parameters=find_unused_parameters)
device = torch.device("cuda")
if rank == 0:
logger.info("<== TSE Model Loss ==>")
logger.info("loss criterion is: " + str(configs["loss"]))
configs["optimizer_args"]["tse_model"]["lr"] = configs["scheduler_args"][
"tse_model"]["initial_lr"]
optimizer = getattr(torch.optim, configs["optimizer"]["tse_model"])(
ddp_model.parameters(), **configs["optimizer_args"]["tse_model"])
if rank == 0:
logger.info("<== TSE Model Optimizer ==>")
logger.info("optimizer is: " + configs["optimizer"]["tse_model"])
# scheduler
configs["scheduler_args"]["tse_model"]["num_epochs"] = configs[
"num_epochs"]
configs["scheduler_args"]["tse_model"]["epoch_iter"] = epoch_iter
configs["scheduler_args"]["scale_ratio"] = 1.0
scheduler = getattr(schedulers, configs["scheduler"]["tse_model"])(
optimizer, **configs["scheduler_args"]["tse_model"])
if rank == 0:
logger.info("<== TSE Model Scheduler ==>")
logger.info("scheduler is: " + configs["scheduler"]["tse_model"])
if configs["model_init"]["tse_model"] is not None:
logger.info("Load initial model from {}".format(
configs["model_init"]["tse_model"]))
load_pretrained_model(ddp_model, configs["model_init"]["tse_model"])
elif checkpoint is None:
logger.info("Train model from scratch ...")
for c in criterion:
c = c.to(device)
# append to list
model_list.append(ddp_model)
optimizer_list.append(optimizer)
scheduler_list.append(scheduler)
scaler = torch.cuda.amp.GradScaler(enabled=configs["enable_amp"])
# discriminator
discriminator = get_model(configs["model"]["discriminator"])(
**configs["model_args"]["discriminator"])
num_params = sum(param.numel() for param in discriminator.parameters())
# optimizer
configs["optimizer_args"]["discriminator"]["lr"] = configs[
"scheduler_args"]["discriminator"]["initial_lr"]
# scheduler
configs["scheduler_args"]["discriminator"]["num_epochs"] = configs[
"num_epochs"]
configs["scheduler_args"]["discriminator"]["epoch_iter"] = epoch_iter
configs["scheduler_args"]["discriminator"]["scale_ratio"] = 1.0
# ddp model
discriminator.cuda()
ddp_discriminator = torch.nn.parallel.DistributedDataParallel(
discriminator, find_unused_parameters=find_unused_parameters)
optimizer_d = getattr(torch.optim, configs["optimizer"]["discriminator"])(
ddp_discriminator.parameters(),
**configs["optimizer_args"]["discriminator"],
)
scheduler_d = getattr(schedulers, configs["scheduler"]["discriminator"])(
optimizer_d, **configs["scheduler_args"]["discriminator"])
# initialize discriminator
if configs["model_init"]["discriminator"] is not None:
logger.info("Load initial discriminator from {}".format(
configs["model_init"]["discriminator"]))
load_pretrained_model(
ddp_discriminator,
configs["model_init"]["discriminator"],
type="discriminator",
)
elif checkpoint is None:
logger.info("Train discriminator from scratch ...")
# If specify checkpoint, load some info from checkpoint.
if checkpoint is not None:
load_checkpoint(model_list, optimizer_list, scheduler_list, scaler,
checkpoint)
start_epoch = (
            int(re.findall(r"(?<=checkpoint_)\d*(?=\.pt)", checkpoint)[0]) + 1)
logger.info("Load checkpoint: {}".format(checkpoint))
else:
start_epoch = 1
model_list.append(ddp_discriminator)
optimizer_list.append(optimizer_d)
scheduler_list.append(scheduler_d)
if rank == 0:
logger.info("<== Discriminator Model ==>")
logger.info("discriminator size: {}".format(num_params))
for line in pformat(discriminator).split("\n"):
logger.info(line)
logger.info("<== Discriminator Optimizer ==>")
logger.info("optimizer is: " + configs["optimizer"]["discriminator"])
logger.info("<== Discriminator Scheduler ==>")
logger.info("scheduler is: " + configs["scheduler"]["discriminator"])
# save config.yaml
saved_config_path = os.path.join(configs["exp_dir"], "config.yaml")
with open(saved_config_path, "w") as fout:
data = yaml.dump(configs)
fout.write(data)
logger.info("start_epoch: {}".format(start_epoch))
# training
dist.barrier(device_ids=[gpu]) # synchronize here
if rank == 0:
logger.info("<========== Training process ==========>")
header = [
"Train/Val",
"Epoch",
"iter",
"SE_Loss",
"G_Loss",
"D_Loss",
"LR",
]
for line in tp.header(header, width=10, style="grid").split("\n"):
logger.info(line)
dist.barrier(device_ids=[gpu]) # synchronize here
executor = ExecutorGAN()
executor.step = 0
train_losses = []
val_losses = []
train_d_losses = []
val_d_losses = []
for epoch in range(start_epoch, configs["num_epochs"] + 1):
train_dataset.set_epoch(epoch)
train_loss, train_d_loss = executor.train(
train_dataloader,
model_list,
epoch_iter,
optimizer_list,
criterion,
scheduler_list,
scaler=scaler,
epoch=epoch,
logger=logger,
enable_amp=configs["enable_amp"],
clip_grad=configs["clip_grad"],
log_batch_interval=configs["log_batch_interval"],
device=device,
se_loss_weight=loss_args,
gan_loss_weight=gan_loss_weight,
multi_task=multi_task,
)
val_loss, val_d_loss = executor.cv(
val_dataloader,
model_list,
val_iter,
criterion,
epoch=epoch,
logger=logger,
enable_amp=configs["enable_amp"],
log_batch_interval=configs["log_batch_interval"],
device=device,
)
if rank == 0:
logger.info(
"Epoch {} Train info train_loss {}, train_d_loss {}".format(
epoch, train_loss, train_d_loss))
logger.info("Epoch {} Val info val_loss {}, val_d_loss {}".format(
epoch, val_loss, val_d_loss))
train_losses.append(train_loss)
train_d_losses.append(train_d_loss)
val_losses.append(val_loss)
val_d_losses.append(val_d_loss)
best_loss = val_loss
scheduler.best = best_loss
# plot
plt.figure()
plt.title("Loss of Train and Validation")
x = list(range(start_epoch, epoch + 1))
plt.plot(x,
train_losses,
"b-",
label="train_G_loss",
linewidth=0.8)
plt.plot(x,
train_d_losses,
"r-",
label="train_D_loss",
linewidth=0.8)
plt.plot(x, val_losses, "c-", label="val_G_loss", linewidth=0.8)
plt.plot(x, val_d_losses, "m-", label="val_D_loss", linewidth=0.8)
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.xticks(range(start_epoch, epoch + 1, 1))
plt.savefig(
f"{configs['exp_dir']}/{configs['model']['tse_model']}.png")
plt.close()
if rank == 0:
if (epoch % configs["save_epoch_interval"] == 0
or epoch >= configs["num_epochs"] - configs["num_avg"]):
save_checkpoint(
model_list,
optimizer_list,
scheduler_list,
scaler,
os.path.join(model_dir, "checkpoint_{}.pt".format(epoch)),
)
try:
os.symlink(
"checkpoint_{}.pt".format(epoch),
os.path.join(model_dir, "latest_checkpoint.pt"),
)
except FileExistsError:
os.remove(os.path.join(model_dir, "latest_checkpoint.pt"))
os.symlink(
"checkpoint_{}.pt".format(epoch),
os.path.join(model_dir, "latest_checkpoint.pt"),
)
if rank == 0:
os.symlink(
"checkpoint_{}.pt".format(configs["num_epochs"]),
os.path.join(model_dir, "final_checkpoint.pt"),
)
logger.info(tp.bottom(len(header), width=10, style="grid"))
if __name__ == "__main__":
fire.Fire(train)
================================================
FILE: wesep/cli/__init__.py
================================================
================================================
FILE: wesep/cli/extractor.py
================================================
import os
import sys
from silero_vad import load_silero_vad, get_speech_timestamps
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
import soundfile
from wesep.cli.hub import Hub
from wesep.cli.utils import get_args
from wesep.models import get_model
from wesep.utils.checkpoint import load_pretrained_model
from wesep.utils.utils import set_seed
class Extractor:
def __init__(self, model_dir: str):
set_seed()
config_path = os.path.join(model_dir, "config.yaml")
model_path = os.path.join(model_dir, "avg_model.pt")
with open(config_path, "r") as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
if 'spk_model_init' in configs['model_args']['tse_model']:
configs['model_args']['tse_model']['spk_model_init'] = False
self.model = get_model(configs["model"]["tse_model"])(
**configs["model_args"]["tse_model"]
)
load_pretrained_model(self.model, model_path)
self.model.eval()
self.vad = load_silero_vad()
self.table = {}
self.resample_rate = configs["dataset_args"].get("resample_rate", 16000)
self.apply_vad = False
self.device = torch.device("cpu")
self.wavform_norm = True
self.output_norm = True
self.speaker_feat = configs["model_args"]["tse_model"].get("spk_feat", False)
self.joint_training = configs["model_args"]["tse_model"].get(
"joint_training", False
)
def set_wavform_norm(self, wavform_norm: bool):
self.wavform_norm = wavform_norm
def set_resample_rate(self, resample_rate: int):
self.resample_rate = resample_rate
def set_vad(self, apply_vad: bool):
self.apply_vad = apply_vad
def set_device(self, device: str):
self.device = torch.device(device)
self.model = self.model.to(self.device)
def set_output_norm(self, output_norm: bool):
self.output_norm = output_norm
def compute_fbank(
self,
wavform,
sample_rate=16000,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
cmn=True,
):
feat = kaldi.fbank(
wavform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
sample_frequency=sample_rate,
)
if cmn:
feat = feat - torch.mean(feat, 0)
return feat
def extract_speech(self, audio_path: str, audio_path_2: str):
pcm_mix, sample_rate_mix = torchaudio.load(
audio_path, normalize=self.wavform_norm
)
pcm_enroll, sample_rate_enroll = torchaudio.load(
audio_path_2, normalize=self.wavform_norm
)
return self.extract_speech_from_pcm(pcm_mix,
sample_rate_mix,
pcm_enroll,
sample_rate_enroll)
def extract_speech_from_pcm(self,
pcm_mix: torch.Tensor,
sample_rate_mix: int,
pcm_enroll: torch.Tensor,
sample_rate_enroll: int):
if self.apply_vad:
# TODO(Binbin Zhang): Refine the segments logic, here we just
# suppose there is only silence at the start/end of the speech
# Only do vad on the enrollment
vad_sample_rate = 16000
wav = pcm_enroll
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sample_rate_enroll != vad_sample_rate:
transform = torchaudio.transforms.Resample(
orig_freq=sample_rate_enroll, new_freq=vad_sample_rate
)
wav = transform(wav)
segments = get_speech_timestamps(wav, self.vad, return_seconds=True)
pcmTotal = torch.Tensor()
if len(segments) > 0: # remove all the silence
for segment in segments:
start = int(segment["start"] * sample_rate_enroll)
end = int(segment["end"] * sample_rate_enroll)
pcmTemp = pcm_enroll[0, start:end]
pcmTotal = torch.cat([pcmTotal, pcmTemp], 0)
pcm_enroll = pcmTotal.unsqueeze(0)
else: # all silence, nospeech
return None
pcm_mix = pcm_mix.to(torch.float)
if sample_rate_mix != self.resample_rate:
pcm_mix = torchaudio.transforms.Resample(
orig_freq=sample_rate_mix, new_freq=self.resample_rate
)(pcm_mix)
pcm_enroll = pcm_enroll.to(torch.float)
if sample_rate_enroll != self.resample_rate:
pcm_enroll = torchaudio.transforms.Resample(
orig_freq=sample_rate_enroll, new_freq=self.resample_rate
)(pcm_enroll)
if self.joint_training:
if self.speaker_feat:
feats = self.compute_fbank(
pcm_enroll, sample_rate=self.resample_rate, cmn=True
)
feats = feats.unsqueeze(0)
else:
feats = pcm_enroll
feats = feats.to(self.device)
pcm_mix = pcm_mix.to(self.device)
with torch.no_grad():
outputs = self.model(pcm_mix, feats)
outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
target_speech = outputs.to(torch.device("cpu"))
if self.output_norm:
target_speech = (
target_speech
/ abs(target_speech).max(dim=1, keepdim=True).values * 0.9
)
return target_speech
else:
return None
def load_model(language: str) -> Extractor:
model_path = Hub.get_model(language)
return Extractor(model_path)
def load_model_local(model_dir: str) -> Extractor:
return Extractor(model_dir)
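def _usage_sketch():
    """A hedged sketch (not part of the original CLI) of using the Extractor
    programmatically; the model directory and wav paths are placeholders."""
    extractor = load_model_local("exp/bsrnn_ecapa_vox1")  # config.yaml + avg_model.pt
    extractor.set_device("cpu")
    extractor.set_vad(True)  # silero-VAD trims silence from the enrollment
    speech = extractor.extract_speech("mix.wav", "enroll.wav")
    if speech is not None:  # None: all-silent enrollment or non-joint model
        soundfile.write("extracted.wav", speech[0], extractor.resample_rate)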
def main():
args = get_args()
if args.pretrain == "":
if args.bsrnn:
model = load_model("bsrnn")
else:
model = load_model(args.language)
else:
model = load_model_local(args.pretrain)
model.set_resample_rate(args.resample_rate)
model.set_vad(args.vad)
model.set_device(args.device)
model.set_output_norm(args.output_norm)
if args.task == "extraction":
speech = model.extract_speech(args.audio_file, args.audio_file2)
if speech is not None:
            if args.output_norm:  # the parser defines --output_norm (see wesep/cli/utils.py)
speech = speech / abs(speech).max(dim=1, keepdim=True).values * 0.9
soundfile.write(args.output_file, speech[0], args.resample_rate)
print("Succeed, see {}".format(args.output_file))
else:
print("Fails to extract the target speech")
else:
print("Unsupported task {}".format(args.task))
sys.exit(-1)
if __name__ == "__main__":
main()
================================================
FILE: wesep/cli/hub.py
================================================
# Copyright (c) 2022 Mddct(hamddct@gmail.com)
# 2023 Binbin Zhang(binbzha@qq.com)
# 2024 Shuai Wang(wsstriving@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from pathlib import Path
import tarfile
import zipfile
from urllib.request import urlretrieve
import tqdm
def download(url: str, dest: str, only_child=True):
"""download from url to dest"""
assert os.path.exists(dest)
print("Downloading {} to {}".format(url, dest))
def progress_hook(t):
last_b = [0]
def update_to(b=1, bsize=1, tsize=None):
if tsize not in (None, -1):
t.total = tsize
displayed = t.update((b - last_b[0]) * bsize)
last_b[0] = b
return displayed
return update_to
# *.tar.gz
name = url.split("?")[0].split("/")[-1]
file_path = os.path.join(dest, name)
with tqdm.tqdm(
unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=(name)
) as t:
urlretrieve(
url, filename=file_path, reporthook=progress_hook(t), data=None
)
t.total = t.n
if name.endswith((".tar.gz", ".tar")):
with tarfile.open(file_path) as f:
if not only_child:
f.extractall(dest)
else:
for tarinfo in f:
if "/" not in tarinfo.name:
continue
name = os.path.basename(tarinfo.name)
fileobj = f.extractfile(tarinfo)
with open(os.path.join(dest, name), "wb") as writer:
writer.write(fileobj.read())
elif name.endswith(".zip"):
with zipfile.ZipFile(file_path, "r") as zip_ref:
if not only_child:
zip_ref.extractall(dest)
else:
for member in zip_ref.namelist():
member_path = os.path.relpath(
member, start=os.path.commonpath(zip_ref.namelist())
)
print(member_path)
if "/" not in member_path:
continue
name = os.path.basename(member_path)
with zip_ref.open(member_path) as source, open(
os.path.join(dest, name), "wb"
) as target:
target.write(source.read())
class Hub(object):
Assets = {
"english": "bsrnn_ecapa_vox1.tar.gz",
}
# Hard coding of the URL
ModelURLs = {
"bsrnn_ecapa_vox1.tar.gz": (
"https://www.modelscope.cn/datasets/wenet/wesep_pretrained_models/"
"resolve/master/bsrnn_ecapa_vox1.tar.gz"
),
}
def __init__(self) -> None:
pass
@staticmethod
def get_model(lang: str) -> str:
if lang not in Hub.Assets.keys():
print("ERROR: Unsupported lang {} !!!".format(lang))
sys.exit(1)
# model = Hub.Assets[lang]
model_name = Hub.Assets[lang]
model_dir = os.path.join(Path.home(), ".wesep", lang)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
if set(["avg_model.pt", "config.yaml"]).issubset(
set(os.listdir(model_dir))
):
return model_dir
else:
if model_name in Hub.ModelURLs:
model_url = Hub.ModelURLs[model_name]
download(model_url, model_dir)
return model_dir
else:
print(f"ERROR: No URL found for model {model_name}")
return None
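# Usage sketch (hedged): resolve the cached model directory, downloading and
# unpacking the tarball into ~/.wesep/<lang> on first use:
#
#   model_dir = Hub.get_model("english")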
================================================
FILE: wesep/cli/utils.py
================================================
import argparse
def get_args():
parser = argparse.ArgumentParser(description="")
parser.add_argument(
"-t",
"--task",
choices=[
"extraction",
],
default="extraction",
help="task type",
)
parser.add_argument(
"-l",
"--language",
choices=[
# "chinese",
"english",
],
default="english",
help="language type",
)
parser.add_argument(
"--bsrnn",
action="store_true",
help="whether to use the bsrnn model",
)
parser.add_argument(
"-p", "--pretrain", type=str, default="", help="model directory"
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="device type (most commonly cpu or cuda,"
"but also potentially mps, xpu, xla or meta)"
"and optional device ordinal for the device type.",
)
parser.add_argument("--audio_file", help="mixture's audio file")
parser.add_argument("--audio_file2", help="enroll's audio file")
parser.add_argument(
"--resample_rate", type=int, default=16000, help="resampling rate"
)
parser.add_argument(
"--vad", action="store_true", help="whether to do VAD or not"
)
parser.add_argument(
"--output_file",
default='./extracted_speech.wav',
help="extracted speech saved in .wav"
)
parser.add_argument(
"--output_norm",
default=True,
help="Control if normalize the output audio in .wav"
)
args = parser.parse_args()
return args
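# Example invocation of the `wesep` console script declared in setup.py
# (a hedged sketch; the wav file names are placeholders):
#
#   wesep --task extraction --language english \
#       --audio_file mix.wav --audio_file2 enroll.wav \
#       --vad --output_file extracted_speech.wav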
================================================
FILE: wesep/dataset/FRAM_RIR.py
================================================
# Author: Rongzhi Gu, Yi Luo
# Copyright: Tencent AI Lab
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torchaudio.functional import highpass_biquad
from torchaudio.transforms import Resample
# set random seed
seed = 20231
np.random.seed(seed)
torch.manual_seed(seed)
def calc_cos(orientation_rad):
"""
    orientation_rad: tensor, [azimuth, elevation] with shape [..., 2]
return: [..., 3]
"""
return torch.stack(
[
torch.cos(
orientation_rad[..., 0] * torch.sin(orientation_rad[..., 1])),
torch.sin(
orientation_rad[..., 0] * torch.sin(orientation_rad[..., 1])),
torch.cos(orientation_rad[..., 1]),
],
-1,
)
def freq_invariant_decay_func(cos_theta, pattern="cardioid"):
"""
cos_theta: tensor
Return:
amplitude: tensor with same shape as cos_theta
"""
if pattern == "cardioid":
return 0.5 + 0.5 * cos_theta
elif pattern == "omni":
return torch.ones_like(cos_theta)
elif pattern == "bidirectional":
return cos_theta
elif pattern == "hyper_cardioid":
return 0.25 + 0.75 * cos_theta
elif pattern == "sub_cardioid":
return 0.75 + 0.25 * cos_theta
elif pattern == "half_omni":
c = torch.clamp(cos_theta, 0)
c[c > 0] = 1.0
return c
else:
raise NotImplementedError
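# Quick illustration (hedged): a cardioid pattern passes the front
# (cos_theta = 1) at full amplitude and nulls the back (cos_theta = -1):
#
#   freq_invariant_decay_func(torch.tensor([-1.0, 0.0, 1.0]))
#   # -> tensor([0.0000, 0.5000, 1.0000])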
def freq_invariant_src_decay_func(mic_pos,
src_pos,
src_orientation_rad,
pattern="cardioid"):
"""
mic_pos: [n_mic, 3] (tensor)
src_pos: [n_src, 3] (tensor)
    src_orientation_rad: [n_src, 2] (tensor), azimuth, elevation
Return:
        amplitude: [n_mic, n_src]
"""
# Steering vector of source(s)
    orV_src = calc_cos(src_orientation_rad).unsqueeze(0)  # [1, n_src, 3]
# receiver to src vector
rcv_to_src_vec = mic_pos.unsqueeze(1) - src_pos.unsqueeze(
0) # [n_mic, n_src, 3]
cos_theta = (rcv_to_src_vec * orV_src).sum(-1) # [n_mic, n_src]
cos_theta /= torch.sqrt(rcv_to_src_vec.pow(2).sum(-1))
cos_theta /= torch.sqrt(orV_src.pow(2).sum(-1))
return freq_invariant_decay_func(cos_theta, pattern)
def freq_invariant_mic_decay_func(mic_pos,
img_pos,
mic_orientation_rad,
pattern="cardioid"):
"""
mic_pos: [n_mic, 3] (tensor)
img_pos: [n_src, n_image, 3] (tensor)
mic_orientation_rad: [n_mic, 2] (tensor), azimuth, elevation
Return:
amplitude: [n_mic, n_src, n_image]
"""
# Steering vector of source(s)
orV_src = calc_cos(mic_orientation_rad) # [nmic, 3]
orV_src = orV_src.view(-1, 1, 1, 3) # [n_mic, 1, 1, 3]
# image to receiver vector
# [1, n_src, n_image, 3] - [n_mic, 1, 1, 3] => [n_mic, n_src, n_image, 3]
img_to_rcv_vec = img_pos.unsqueeze(0) - mic_pos.unsqueeze(1).unsqueeze(1)
cos_theta = (img_to_rcv_vec * orV_src).sum(-1) # [n_mic, n_src, n_image]
cos_theta /= torch.sqrt(img_to_rcv_vec.pow(2).sum(-1))
cos_theta /= torch.sqrt(orV_src.pow(2).sum(-1))
return freq_invariant_decay_func(cos_theta, pattern)
def FRAM_RIR(
mic_pos,
sr,
T60,
room_dim,
src_pos,
num_src=1,
direct_range=(-6, 50),
n_image=(1024, 4097),
a=-2.0,
b=2.0,
tau=0.25,
src_pattern="omni",
src_orientation_rad=None,
mic_pattern="omni",
mic_orientation_rad=None,
):
"""Fast Random Appoximation of Multi-channel Room Impulse Response (FRAM-RIR) # noqa
Args:
mic_pos: The microphone(s) position with respect to the room coordinates, # noqa
with shape [num_mic, 3] (in meters). Room coordinate system must be defined in advance, # noqa
with the constraint that the origin of the coordinate is on the floor(so positive z axis points up). # noqa
sr: RIR sampling rate (Hz).
T60: RT60 (second).
room_dim: Room size with shape [3] (meters).
src_pos: The source(s) position with respect to the room coordinate system, with shape [num_src, 3] (meters). # noqa
num_src: Number of sources. Defaults to 1.
direct_range: 2-element tuple, range of early reflection time (milliseconds, # noqa
defined as the context around the direct path signal) of RIRs. # noqa
Defaults to (-6, 50).
n_image: 2-element tuple, minimum and maximum number of images to sample from. # noqa
Defaults to (1024, 4097).
a: controlling the random perturbation added to each virtual sound source. Defaults to -2.0. # noqa
b: controlling the random perturbation added to each virtual sound source. Defaults to 2.0. # noqa
tau: controlling the relationship between the distance and the number of reflections of each # noqa
virtual sound source. Defaults to 0.25.
src_pattern: Polar pattern for all of the sources. Defaults to "omni".
src_orientation_rad: Array-like with shape [num_src, 2]. Orientation (rad) of all # noqa
            the sources, where the first column indicates azimuth and the # noqa
            second column indicates elevation. Defaults to None. # noqa
mic_pattern: Polar pattern for all of the receivers. Defaults to "omni".
mic_orientation_rad: Array-like with shape [num_mic, 2]. Orientation (rad) of all # noqa
            the microphones, where the first column indicates azimuth and # noqa
            the second column indicates elevation. Defaults to None. # noqa
Returns:
rir: RIR filters for all mic-source pairs, with shape [num_mic, num_src, rir_length]. # noqa
early_rir: Early reflection (direct path) RIR filters for all mic-source pairs, # noqa
with shape [num_mic, num_src, rir_length].
"""
# sample image
image = np.random.choice(range(n_image[0], n_image[1]))
R = torch.tensor(
1.0 / (2 *
(1.0 / room_dim[0] + 1.0 / room_dim[1] + 1.0 / room_dim[2])))
eps = np.finfo(np.float16).eps
mic_position = torch.from_numpy(mic_pos)
src_position = torch.from_numpy(src_pos) # [nsource, 3]
n_mic = mic_position.shape[0]
num_src = src_position.shape[0]
# [nmic, nsource]
direct_dist = ((mic_position.unsqueeze(1) -
src_position.unsqueeze(0)).pow(2).sum(-1) + 1e-3).sqrt()
# [nsource]
nearest_dist, nearest_mic_idx = direct_dist.min(0)
# [nsource, 3]
nearest_mic_position = mic_position[nearest_mic_idx]
ns = n_mic * num_src
ratio = 64
sample_sr = sr * ratio
velocity = 340.0
T60 = torch.tensor(T60)
direct_idx = (torch.ceil(direct_dist * sample_sr / velocity).long().view(
ns, ))
rir_length = int(np.ceil(sample_sr * T60))
resample1 = Resample(sample_sr, sample_sr // int(np.sqrt(ratio)))
resample2 = Resample(sample_sr // int(np.sqrt(ratio)), sr)
reflect_coef = (1 - (1 - torch.exp(-0.16 * R / T60)).pow(2)).sqrt()
dist_range = [
torch.linspace(1.0, velocity * T60 / nearest_dist[i] - 1, rir_length)
for i in range(num_src)
]
dist_prob = torch.linspace(0.0, 1.0, rir_length)
dist_prob /= dist_prob.sum()
dist_select_idx = dist_prob.multinomial(num_samples=int(image * num_src),
replacement=True).view(
num_src, image)
dist_nearest_ratio = torch.stack(
[dist_range[i][dist_select_idx[i]] for i in range(num_src)], 0)
    # apply different dist ratios to microphones
azm = torch.FloatTensor(num_src, image).uniform_(-np.pi, np.pi)
ele = torch.FloatTensor(num_src, image).uniform_(-np.pi / 2, np.pi / 2)
# [nsource, nimage, 3]
unit_3d = torch.stack(
[
torch.sin(ele) * torch.cos(azm),
torch.sin(ele) * torch.sin(azm),
torch.cos(ele),
],
-1,
)
# [nsource] x [nsource, T] x [nsource, nimage, 3] => [nsource, nimage, 3]
image2nearest_dist = nearest_dist.view(
-1, 1, 1) * dist_nearest_ratio.unsqueeze(-1)
image_position = (nearest_mic_position.unsqueeze(1) +
image2nearest_dist * unit_3d)
# [nmic, nsource, nimage]
dist = ((mic_position.view(-1, 1, 1, 3) -
image_position.unsqueeze(0)).pow(2).sum(-1) + 1e-3).sqrt()
# reflection perturbation
reflect_max = (torch.log10(velocity * T60) - 3) / torch.log10(reflect_coef)
reflect_ratio = (dist /
(velocity * T60)) * (reflect_max.view(1, -1, 1) - 1) + 1
reflect_pertub = torch.FloatTensor(num_src, image).uniform_(
a, b) * dist_nearest_ratio.pow(tau)
reflect_ratio = torch.maximum(reflect_ratio + reflect_pertub.unsqueeze(0),
torch.ones(1))
# [nmic, nsource, 1 + nimage]
dist = torch.cat([direct_dist.unsqueeze(2), dist], 2)
reflect_ratio = torch.cat([torch.zeros(n_mic, num_src, 1), reflect_ratio],
2)
delta_idx = (torch.minimum(
torch.ceil(dist * sample_sr / velocity),
torch.ones(1) * rir_length - 1,
).long().view(ns, -1))
delta_decay = reflect_coef.pow(reflect_ratio) / dist
#################################
# source orientation simulation #
#################################
if src_pattern != "omni":
# randomly sample each image's relative orientation with respect to the original source # noqa
# equivalent to a random decay corresponds to the source's orientation pattern decay # noqa
img_orientation_rad = torch.FloatTensor(num_src, image,
2).uniform_(-np.pi, np.pi)
img_cos_theta = torch.cos(img_orientation_rad[..., 0]) * torch.cos(
img_orientation_rad[..., 1]) # [nsource, nimage]
img_orientation_decay = freq_invariant_decay_func(
img_cos_theta, pattern=src_pattern) # [nsource, nimage]
# direct path orientation should use the provided parameter
if src_orientation_rad is None:
# assume random orientation if not given
src_orientation_azi = torch.FloatTensor(num_src).uniform_(
-np.pi, np.pi)
src_orientation_ele = torch.FloatTensor(num_src).uniform_(
-np.pi, np.pi)
src_orientation_rad = torch.stack(
[src_orientation_azi, src_orientation_ele], -1)
else:
src_orientation_rad = torch.from_numpy(
src_orientation_rad) # [nsource, 2]
src_orientation_decay = freq_invariant_src_decay_func(
mic_position,
src_position,
src_orientation_rad,
pattern=src_pattern,
) # [nmic, nsource]
# apply decay
delta_decay[:, :, 0] *= src_orientation_decay
delta_decay[:, :, 1:] *= img_orientation_decay.unsqueeze(0)
if mic_pattern != "omni":
# mic orientation simulation #
# when not given, assume that all mics facing up (positive z axis)
if mic_orientation_rad is None:
mic_orientation_rad = torch.stack(
[torch.zeros(n_mic), torch.zeros(n_mic)], -1) # [nmic, 2]
else:
mic_orientation_rad = torch.from_numpy(mic_orientation_rad)
all_src_img_pos = torch.cat(
(src_position.unsqueeze(1), image_position),
1) # [nsource, nimage+1, 3]
mic_orientation_decay = freq_invariant_mic_decay_func(
mic_position,
all_src_img_pos,
mic_orientation_rad,
pattern=mic_pattern,
) # [nmic, nsource, nimage+1]
# apply decay
delta_decay *= mic_orientation_decay
rir = torch.zeros(ns, rir_length)
delta_decay = delta_decay.view(ns, -1)
for i in range(ns):
remainder_idx = delta_idx[i]
valid_mask = np.ones(len(remainder_idx))
while np.sum(valid_mask) > 0:
valid_remainder_idx, unique_remainder_idx = np.unique(
remainder_idx, return_index=True)
rir[i][valid_remainder_idx] += (
delta_decay[i][unique_remainder_idx] *
valid_mask[unique_remainder_idx])
valid_mask[unique_remainder_idx] = 0
remainder_idx[unique_remainder_idx] = 0
direct_mask = torch.zeros(ns, rir_length).float()
for i in range(ns):
direct_mask[
i,
max(direct_idx[i] + sample_sr * direct_range[0] // 1000, 0
):min(direct_idx[i] +
sample_sr * direct_range[1] // 1000, rir_length), ] = 1.0
rir_direct = rir * direct_mask
all_rir = torch.stack([rir, rir_direct], 1).view(ns * 2, -1)
rir_downsample = resample1(all_rir)
rir_hp = highpass_biquad(rir_downsample, sample_sr // int(np.sqrt(ratio)),
80.0)
rir = resample2(rir_hp).float().view(n_mic, num_src, 2, -1)
return rir[:, :, 0].data.numpy(), rir[:, :, 1].data.numpy()
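def _fram_rir_usage_sketch():
    """A hedged usage sketch for FRAM_RIR; the room geometry values are
    placeholders chosen only to satisfy the documented shapes."""
    mic_pos = np.array([[1.95, 2.0, 1.5], [2.05, 2.0, 1.5]])  # [n_mic, 3]
    src_pos = np.array([[3.0, 3.5, 1.6]])  # [n_src, 3]
    rir, early_rir = FRAM_RIR(mic_pos,
                              sr=16000,
                              T60=0.5,
                              room_dim=[6.0, 5.0, 3.0],
                              src_pos=src_pos)
    # both returned arrays have shape [n_mic, n_src, rir_length]
    return rir, early_rir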
def sample_mic_arch(n_mic, mic_spacing=None, bounding_box=None):
if mic_spacing is None:
mic_spacing = [0.02, 0.10]
if bounding_box is None:
bounding_box = [0.08, 0.12, 0]
sample_n_mic = np.random.randint(n_mic[0], n_mic[1] + 1)
if sample_n_mic == 1:
mic_arch = np.array([[0, 0, 0]])
else:
mic_arch = []
while len(mic_arch) < sample_n_mic:
this_mic_pos = np.random.uniform(np.array([0, 0, 0]),
np.array(bounding_box))
if len(mic_arch) != 0:
ok = True
for other_mic_pos in mic_arch:
this_mic_spacing = np.linalg.norm(this_mic_pos -
other_mic_pos)
if (this_mic_spacing < mic_spacing[0]
or this_mic_spacing > mic_spacing[1]):
ok = False
break
if ok:
mic_arch.append(this_mic_pos)
else:
mic_arch.append(this_mic_pos)
mic_arch = np.stack(mic_arch, 0) # [nmic, 3]
return mic_arch
def sample_src_pos(
room_dim,
num_src,
array_pos,
min_mic_dis=0.5,
max_mic_dis=5,
min_dis_wall=None,
):
if min_dis_wall is None:
min_dis_wall = [0.5, 0.5, 0.5]
    # randomly sample the source position
src_pos = []
while len(src_pos) < num_src:
pos = np.random.uniform(np.array(min_dis_wall),
np.array(room_dim) - np.array(min_dis_wall))
dis = np.linalg.norm(pos - np.array(array_pos))
if dis >= min_mic_dis and dis <= max_mic_dis:
src_pos.append(pos)
return np.stack(src_pos, 0)
def sample_mic_array_pos(mic_arch, room_dim, min_dis_wall=None):
"""
Generate the microphone array position according to the given microphone architecture (geometry) # noqa
:param mic_arch: np.array with shape [n_mic, 3]
the relative 3D coordinate to the array_pos in (m)
e.g., 2-mic linear array [[-0.1, 0, 0], [0.1, 0, 0]];
e.g., 4-mic circular array [[0, 0.035, 0], [0.035, 0, 0], [0, -0.035, 0], [-0.035, 0, 0]] # noqa
:param min_dis_wall: minimum distance from the wall in (m)
:return
mic_pos: microphone array position in (m) with shape [n_mic, 3]
array_pos: array CENTER / REFERENCE position in (m) with shape [1, 3]
"""
def rotate(angle, valuex, valuey):
rotate_x = valuex * np.cos(angle) + valuey * np.sin(angle) # [nmic]
rotate_y = valuey * np.cos(angle) - valuex * np.sin(angle)
return np.stack(
[rotate_x, rotate_y, np.zeros_like(rotate_x)], -1) # [nmic, 3]
if min_dis_wall is None:
min_dis_wall = [0.5, 0.5, 0.5]
if isinstance(mic_arch, dict): # ADHOC ARRAY
n_mic, mic_spacing, bounding_box = (
mic_arch["n_mic"],
mic_arch["spacing"],
mic_arch["bounding_box"],
)
sample_n_mic = np.random.randint(n_mic[0], n_mic[1] + 1)
if sample_n_mic == 1:
mic_arch = np.array([[0, 0, 0]])
else:
mic_arch = [
np.random.uniform(np.array([0, 0, 0]), np.array(bounding_box))
]
while len(mic_arch) < sample_n_mic:
this_mic_pos = np.random.uniform(np.array([0, 0, 0]),
np.array(bounding_box))
ok = True
for other_mic_pos in mic_arch:
this_mic_spacing = np.linalg.norm(this_mic_pos -
other_mic_pos)
if (this_mic_spacing < mic_spacing[0]
or this_mic_spacing > mic_spacing[1]):
ok = False
break
if ok:
mic_arch.append(this_mic_pos)
mic_arch = np.stack(mic_arch, 0) # [nmic, 3]
else:
mic_arch = np.array(mic_arch)
mic_array_center = np.mean(mic_arch, 0, keepdims=True) # [1, 3]
max_radius = max(np.linalg.norm(mic_arch - mic_array_center, axis=-1))
array_pos = np.random.uniform(
np.array(min_dis_wall) + max_radius,
np.array(room_dim) - np.array(min_dis_wall) - max_radius,
).reshape(1, 3)
    # assume the array is always horizontal; apply a random azimuth rotation
    rotate_azm = np.random.uniform(-np.pi, np.pi)
    mic_pos = array_pos + rotate(rotate_azm, mic_arch[:, 0],
                                 mic_arch[:, 1])  # [n_mic, 3]
return mic_pos, array_pos
def sample_a_config(simu_config):
room_config = simu_config["min_max_room"]
rt60_config = simu_config["rt60"]
mic_dist_config = simu_config["mic_dist"]
num_src = simu_config["num_src"]
room_dim = np.random.uniform(np.array(room_config[0]),
np.array(room_config[1]))
rt60 = np.random.uniform(rt60_config[0], rt60_config[1])
sr = simu_config["sr"]
if ("array_pos"
not in simu_config.keys()): # mic_arch must be given in this case
mic_arch = simu_config["mic_arch"]
mic_pos, array_pos = sample_mic_array_pos(mic_arch, room_dim)
    else:
        array_pos = np.array(simu_config["array_pos"])
        # NOTE: the original code left mic_pos undefined on this path; as a
        # minimal fix, place the given geometry (mic_arch) at the fixed array
        # center so that FRAM_RIR receives absolute microphone positions
        mic_pos = array_pos + np.array(simu_config["mic_arch"])
if "src_pos" not in simu_config.keys():
src_pos = sample_src_pos(
room_dim,
num_src,
array_pos,
min_mic_dis=mic_dist_config[0],
max_mic_dis=mic_dist_config[1],
)
else:
src_pos = np.array(simu_config["src_pos"])
return mic_pos, sr, rt60, room_dim, src_pos, array_pos
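# Editor's note: a sketch (not part of the original file) of a complete
# simu_config accepted by sample_a_config; the values are illustrative.
_demo_simu_config = {
    "min_max_room": [[3, 3, 2.5], [10, 6, 4]],  # [min_xyz, max_xyz] in meters
    "rt60": [0.1, 0.7],                         # RT60 range in seconds
    "sr": 16000,                                # sampling rate in Hz
    "mic_dist": [0.5, 5.0],                     # source-to-array distance (m)
    "num_src": 2,                               # number of sources
    "mic_arch": [[-0.05, 0, 0], [0.05, 0, 0]],  # 2-mic linear array
}
# mic_pos, sr, rt60, room_dim, src_pos, array_pos = \
#     sample_a_config(_demo_simu_config)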
# === single-channel FRA-RIR ===
def single_channel(simu_config):
mic_arch = {"n_mic": [1, 1], "spacing": None, "bounding_box": None}
simu_config["mic_arch"] = mic_arch
mic_pos, sr, rt60, room_dim, src_pos, array_pos = sample_a_config(
simu_config)
rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos, array_pos)
# with shape [1, n_src, rir_len]
return rir, rir_direct
# === multi-channel (fixed) ===
def multi_channel_array(simu_config):
mic_arch = [[-0.05, 0, 0], [0.05, 0, 0]]
simu_config["mic_arch"] = mic_arch
mic_pos, sr, rt60, room_dim, src_pos, array_pos = sample_a_config(
simu_config)
rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos)
# with shape [n_mic, n_src, rir_len]
return rir, rir_direct
# === multi-channel (adhoc) ===
def multi_channel_adhoc(simu_config):
mic_arch = {
"n_mic": [1, 3],
"spacing": [0.02, 0.05],
"bounding_box": [0.5, 1.0, 0], # x, y, z
}
simu_config["mic_arch"] = mic_arch
mic_pos, sr, rt60, room_dim, src_pos, array_pos = sample_a_config(
simu_config)
rir, rir_direct = FRAM_RIR(mic_pos, sr, rt60, room_dim, src_pos)
# with shape [sample_n_mic, n_src, rir_len]
return rir, rir_direct
def multi_channel_src_orientation():
"""
========================= → y axis
| |
| *1 *2 |
| |
| ↑ |
| |
| *3 *4 |
| |
=========================
↓
x axis
"""
sr = 16000
rt60 = 0.6
room_dim = [8, 8, 3]
src_pos = np.array([[4, 4, 1.5]]) # middle of the room
    mic_pos = np.array([
        [2, 2, 1.5], [2, 6, 1.5],  # mic 1, 2
        [6, 2, 1.5], [6, 6, 1.5],  # mic 3, 4
    ])
src_pattern = "sub_cardioid"
src_orientation_rad = (np.array([180, 90]) / 180.0 * np.pi
) # facing *front* (negative x axis)
rir, rir_direct = FRAM_RIR(
mic_pos,
sr,
rt60,
room_dim=room_dim,
src_pos=src_pos,
src_pattern=src_pattern,
src_orientation_rad=src_orientation_rad,
)
return rir, rir_direct
def multi_channel_mic_orientation():
"""
========================= → y axis
| |
| ↑1 ↓2 |
| |
| o |
| |
| ↑3 ↓4 |
| |
=========================
↓
x axis
"""
sr = 16000
rt60 = 0.6
room_dim = [8, 8, 3]
src_pos = np.array([[4, 4, 1.5]]) # middle of the room
    mic_pos = np.array([
        [2, 2, 1.5], [2, 6, 1.5],  # mic 1, 2
        [6, 2, 1.5], [6, 6, 1.5],  # mic 3, 4
    ])
mic_pattern = "sub_cardioid"
    mic_orientation_rad = (
        np.array([
            [180, 90],  # mic 1: facing the negative x axis
            [0, 90],    # mic 2: facing the positive x axis
            [180, 90],  # mic 3: facing the negative x axis
            [0, 90],    # mic 4: facing the positive x axis
        ]) / 180.0 * np.pi)
rir, rir_direct = FRAM_RIR(
mic_pos,
sr,
rt60,
room_dim=room_dim,
src_pos=src_pos,
mic_pattern=mic_pattern,
mic_orientation_rad=mic_orientation_rad,
)
return rir, rir_direct
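# Editor's note: a minimal end-to-end sketch (not part of the original file);
# it samples one random room/source layout and simulates single-channel RIRs.
def _demo_single_channel():
    demo_config = {
        "min_max_room": [[3, 3, 2.5], [10, 6, 4]],
        "rt60": [0.1, 0.7],
        "sr": 16000,
        "mic_dist": [0.5, 5.0],
        "num_src": 1,
    }
    rir, rir_direct = single_channel(demo_config)
    print(rir.shape, rir_direct.shape)  # [1, num_src, rir_len] each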
================================================
FILE: wesep/dataset/dataset.py
================================================
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Shuai Wang (wsstriving@gmail.com)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import torch.distributed as dist
import torch.nn.functional as tf
from torch.utils.data import IterableDataset
import wesep.dataset.processor as processor
from wesep.utils.file_utils import read_lists
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
"""Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(
rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers,
)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
"""Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
if len(data) <= self.num_workers:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
else:
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
data = data[self.rank::self.world_size]
data = data[self.worker_id::self.num_workers]
return data
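# Editor's note: a small illustration (not part of the original file) of the
# two-level slicing in sample(): first across ranks, then across workers.
def _demo_partition(num_samples=8, world_size=2, num_workers=2):
    indices = list(range(num_samples))
    for rank in range(world_size):
        per_rank = indices[rank::world_size]  # split across DDP ranks
        for worker_id in range(num_workers):
            # then split each rank's share across dataloader workers
            per_worker = per_rank[worker_id::num_workers]
            print(rank, worker_id, per_worker)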
class DataList(IterableDataset):
def __init__(self,
lists,
shuffle=True,
partition=True,
repeat_dataset=False):
self.lists = lists
self.repeat_dataset = repeat_dataset
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
if not self.repeat_dataset:
for index in indexes:
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
else:
indexes_len = len(indexes)
counter = 0
while True:
index = indexes[counter % indexes_len]
counter += 1
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
def tse_collate_fn_2spk(batch, mode="min"):
# Warning: hard-coded for 2 speakers, will be deprecated in the future,
# use tse_collate_fn instead
new_batch = {}
wav_mix = []
wav_targets = []
spk_embeds = []
spk = []
key = []
spk_label = []
length_spk_embeds = []
for s in batch:
wav_mix.append(s["wav_mix"])
wav_targets.append(s["wav_spk1"])
spk.append(s["spk1"])
key.append(s["key"])
spk_embeds.append(torch.from_numpy(s["embed_spk1"].copy()))
length_spk_embeds.append(spk_embeds[-1].shape[1])
if "spk1_label" in s.keys():
spk_label.append(s["spk1_label"])
wav_mix.append(s["wav_mix"])
wav_targets.append(s["wav_spk2"])
spk.append(s["spk2"])
key.append(s["key"])
spk_embeds.append(torch.from_numpy(s["embed_spk2"].copy()))
length_spk_embeds.append(spk_embeds[-1].shape[1])
if "spk2_label" in s.keys():
spk_label.append(s["spk2_label"])
    if len(set(length_spk_embeds)) != 1:
if mode == "max":
max_len = max(length_spk_embeds)
for i in range(len(length_spk_embeds)):
if len(spk_embeds[i].shape) == 2:
spk_embeds[i] = tf.pad(
spk_embeds[i],
(0, max_len - length_spk_embeds[i]),
"constant",
0,
)
elif len(spk_embeds[i].shape) == 3:
spk_embeds[i] = tf.pad(
spk_embeds[i],
(0, 0, 0, max_len - length_spk_embeds[i]),
"constant",
0,
)
if mode == "min":
min_len = min(length_spk_embeds)
for i in range(len(length_spk_embeds)):
if len(spk_embeds[i].shape) == 2:
spk_embeds[i] = spk_embeds[i][:, :min_len]
elif len(spk_embeds[i].shape) == 3:
spk_embeds[i] = spk_embeds[i][:, :min_len, :]
new_batch["wav_mix"] = torch.concat(wav_mix)
new_batch["wav_targets"] = torch.concat(wav_targets)
new_batch["spk_embeds"] = torch.concat(spk_embeds)
new_batch["length_spk_embeds"] = length_spk_embeds
new_batch["spk"] = spk
new_batch["key"] = key
new_batch["spk_label"] = torch.as_tensor(spk_label)
return new_batch
def tse_collate_fn(batch, mode="min"):
# This is a more generalizable implementation for target speaker extraction
# Support arbitrary number of speakers
new_batch = {}
wav_mix = []
wav_targets = []
spk_embeds = []
spk = []
key = []
spk_label = []
length_spk_embeds = []
for s in batch:
for i in range(s["num_speaker"]):
wav_mix.append(s["wav_mix"])
wav_targets.append(s["wav_spk{}".format(i + 1)])
spk.append(s["spk{}".format(i + 1)])
key.append(s["key"])
spk_embeds.append(
torch.from_numpy(s["embed_spk{}".format(i + 1)].copy()))
length_spk_embeds.append(spk_embeds[-1].shape[1])
if "spk{}_label".format(i + 1) in s.keys():
spk_label.append(s["spk{}_label".format(i + 1)])
    if len(set(length_spk_embeds)) != 1:
if mode == "max":
max_len = max(length_spk_embeds)
for i in range(len(length_spk_embeds)):
if len(spk_embeds[i].shape) == 2:
spk_embeds[i] = tf.pad(
spk_embeds[i],
(0, max_len - length_spk_embeds[i]),
"constant",
0,
)
elif len(spk_embeds[i].shape) == 3:
spk_embeds[i] = tf.pad(
spk_embeds[i],
(0, 0, 0, max_len - length_spk_embeds[i]),
"constant",
0,
)
if mode == "min":
min_len = min(length_spk_embeds)
for i in range(len(length_spk_embeds)):
if len(spk_embeds[i].shape) == 2:
spk_embeds[i] = spk_embeds[i][:, :min_len]
elif len(spk_embeds[i].shape) == 3:
spk_embeds[i] = spk_embeds[i][:, :min_len, :]
new_batch["wav_mix"] = torch.concat(wav_mix)
new_batch["wav_targets"] = torch.concat(wav_targets)
new_batch["spk_embeds"] = torch.concat(spk_embeds)
new_batch["length_spk_embeds"] = (
length_spk_embeds # Not used, but maybe needed when using the enrollment utterance # noqa
)
new_batch["spk"] = spk
new_batch["key"] = key
new_batch["spk_label"] = torch.as_tensor(spk_label)
return new_batch
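# Editor's note: a minimal sketch (not part of the original file) of the batch
# layout produced by tse_collate_fn; shapes and values are illustrative.
def _demo_tse_collate():
    import numpy as np
    sample = {
        "key": "utt1",
        "num_speaker": 2,
        "wav_mix": torch.zeros(1, 16000),
        "wav_spk1": torch.zeros(1, 16000),
        "wav_spk2": torch.zeros(1, 16000),
        "spk1": "A",
        "spk2": "B",
        "embed_spk1": np.zeros((1, 256), dtype=np.float32),
        "embed_spk2": np.zeros((1, 256), dtype=np.float32),
    }
    batch = tse_collate_fn([sample])
    # One row per (mixture, target speaker) pair: 2 rows for this sample.
    print(batch["wav_mix"].shape, batch["spk_embeds"].shape)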
def Dataset(
data_type,
data_list_file,
configs,
spk2embed_dict=None,
spk1_embed=None,
spk2_embed=None,
state="train",
joint_training=False,
dict_spk=None,
whole_utt=False,
repeat_dataset=False,
noise_prob=0,
reverb_prob=0,
noise_enroll_prob=0,
reverb_enroll_prob=0,
specaug_enroll_prob=0,
noise_lmdb_file=None,
online_mix=False,
):
"""Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
shuffle at shards tar/raw/feat file level. The second is local shuffle
at training samples level.
Args:
:param spk2_embed:
:param online_mix:
:param spk1_embed:
:param data_type(str): shard/raw/feat
:param data_list_file: data list file
:param configs: dataset configs
:param noise_prob:probility to add noise on mixture
:param reverb_prob:probility to add reverb on mixture
:param noise_enroll_prob:probility to add noise on enrollment speech
:param reverb_enroll_prob:probility to add reverb on enrollment speech
:param specaug_enroll_prob: probility to apply SpecAug on fbank of enrollment speech # noqa
:param noise_lmdb_file: noise data source lmdb file
:param whole_utt: use whole utt or random chunk
:param repeat_dataset:
"""
assert data_type in ["shard", "raw"]
lists = read_lists(data_list_file)
shuffle = configs.get("shuffle", False)
# Global shuffle
dataset = DataList(lists, shuffle=shuffle, repeat_dataset=repeat_dataset)
if data_type == "shard":
dataset = Processor(dataset, processor.url_opener)
if not online_mix:
dataset = Processor(dataset, processor.tar_file_and_group)
else:
dataset = Processor(dataset,
processor.tar_file_and_group_single_spk)
else:
dataset = Processor(dataset, processor.parse_raw)
if configs.get("filter_len", False) and state == "train":
# Filter the data with unwanted length
filter_conf = configs.get("filter_args", {})
dataset = Processor(dataset, processor.filter_len, **filter_conf)
# Local shuffle
if shuffle and not online_mix:
dataset = Processor(dataset, processor.shuffle,
**configs["shuffle_args"])
# resample
resample_rate = configs.get("resample_rate", 16000)
dataset = Processor(dataset, processor.resample, resample_rate)
if not whole_utt:
# random chunk
chunk_len = configs.get("chunk_len", resample_rate * 3)
dataset = Processor(dataset, processor.random_chunk, chunk_len)
if online_mix:
dataset = Processor(
dataset,
processor.mix_speakers,
configs.get("num_speakers", 2),
configs.get("online_buffer_size", 1000),
)
if reverb_prob > 0:
dataset = Processor(dataset, processor.add_reverb, reverb_prob)
dataset = Processor(
dataset,
processor.snr_mixer,
configs.get("use_random_snr", False),
)
if noise_prob > 0:
assert noise_lmdb_file is not None
dataset = Processor(dataset, processor.add_noise, noise_lmdb_file,
noise_prob)
speaker_feat = configs.get("speaker_feat", False)
if state == "train":
if not joint_training:
dataset = Processor(dataset, processor.sample_spk_embedding,
spk2embed_dict)
else:
dataset = Processor(dataset, processor.sample_enrollment,
spk2embed_dict, dict_spk)
if reverb_enroll_prob > 0:
dataset = Processor(dataset, processor.add_reverb_on_enroll,
reverb_enroll_prob)
if noise_enroll_prob > 0:
assert noise_lmdb_file is not None
dataset = Processor(
dataset,
processor.add_noise_on_enroll,
noise_lmdb_file,
noise_enroll_prob,
)
if speaker_feat:
dataset = Processor(dataset, processor.compute_fbank,
**configs["fbank_args"])
dataset = Processor(dataset, processor.apply_cmvn)
if specaug_enroll_prob > 0:
dataset = Processor(dataset,
processor.spec_aug,
prob=specaug_enroll_prob)
else:
if not joint_training:
dataset = Processor(
dataset,
processor.sample_fix_spk_embedding,
spk2embed_dict,
spk1_embed,
spk2_embed,
)
else:
dataset = Processor(
dataset,
processor.sample_fix_spk_enrollment,
spk2embed_dict,
spk1_embed,
spk2_embed,
dict_spk,
)
if speaker_feat:
dataset = Processor(dataset, processor.compute_fbank,
**configs["fbank_args"])
dataset = Processor(dataset, processor.apply_cmvn)
return dataset
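# Editor's note: a minimal construction sketch (not part of the original
# file). The config keys follow the usage above; the file path and values are
# illustrative, and spk2embed_dict must map each speaker to a list of
# embeddings before the dataset is iterated.
def _demo_build_dataset(spk2embed_dict):
    configs = {
        "shuffle": True,
        "shuffle_args": {"shuffle_size": 2500},
        "resample_rate": 16000,
        "chunk_len": 16000 * 3,
    }
    return Dataset("shard", "data/train/shard.list", configs,
                   spk2embed_dict=spk2embed_dict, state="train")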
================================================
FILE: wesep/dataset/lmdb_data.py
================================================
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import random
import lmdb
class LmdbData:
def __init__(self, lmdb_file):
self.db = lmdb.open(lmdb_file,
readonly=True,
lock=False,
readahead=False)
with self.db.begin(write=False) as txn:
obj = txn.get(b"__keys__")
assert obj is not None
self.keys = pickle.loads(obj)
assert isinstance(self.keys, list)
def random_one(self):
assert len(self.keys) > 0
index = random.randint(0, len(self.keys) - 1)
key = self.keys[index]
with self.db.begin(write=False) as txn:
value = txn.get(key.encode())
assert value is not None
return key, value
def __del__(self):
self.db.close()
if __name__ == "__main__":
import sys
db = LmdbData(sys.argv[1])
key, _ = db.random_one()
print(key)
key, _ = db.random_one()
print(key)
================================================
FILE: wesep/dataset/processor.py
================================================
import io
import json
import logging
import random
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse
import librosa
import numpy as np
import soundfile as sf
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from scipy import signal
from wesep.dataset.FRAM_RIR import single_channel as RIR_sim
from wesep.dataset.lmdb_data import LmdbData
AUDIO_FORMAT_SETS = {"flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"}
# set the simulation configuration
simu_config = {
"min_max_room": [[3, 3, 2.5], [10, 6, 4]],
"rt60": [0.1, 0.7],
"sr": 16000,
"mic_dist": [0.2, 5.0],
"num_src": 1,
}
def url_opener(data):
"""Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert "src" in sample
# TODO(Binbin Zhang): support HTTP
url = sample["src"]
try:
pr = urlparse(url)
# local file
if pr.scheme == "" or pr.scheme == "file":
stream = open(url, "rb")
# network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
else:
cmd = f"wget -q -O - {url}"
process = Popen(cmd, shell=True, stdout=PIPE)
sample.update(process=process)
stream = process.stdout
sample.update(stream=stream)
yield sample
except Exception as ex:
logging.warning("Failed to open {}".format(url))
def tar_file_and_group(data):
"""Expand a stream of open tar files into a stream of tar file contents.
And groups the file with same prefix
Args:
data: Iterable[{src, stream}]
Returns:
Iterable[{key, mix_wav, spk1_wav, spk2_wav, ..., sample_rate}]
"""
for sample in data:
assert "stream" in sample
stream = tarfile.open(fileobj=sample["stream"], mode="r:*")
        # TODO: the mode needs to be validated.
        # To be compatible with torch 2.x, the file reading here
        # does not use streaming.
prev_prefix = None
example = {}
num_speakers = 0
valid = True
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind(".")
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if prev_prefix is not None and prev_prefix not in prefix:
example["key"] = prev_prefix
if valid:
example["num_speaker"] = num_speakers
num_speakers = 0
yield example
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if "spk" in postfix:
example[postfix] = (
file_obj.read().decode("utf8").strip())
num_speakers += 1
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = torchaudio.load(file_obj)
if prefix[-5:-1] == "_spk":
example["wav" + prefix[-5:]] = waveform
prefix = prefix[:-5]
else:
example["wav_mix"] = waveform
example["sample_rate"] = sample_rate
else:
example[postfix] = file_obj.read()
except Exception as ex:
valid = False
logging.warning("error to parse {}".format(name))
prev_prefix = prefix
if prev_prefix is not None:
example["key"] = prev_prefix
example["num_speaker"] = num_speakers
num_speakers = 0
yield example
stream.close()
if "process" in sample:
sample["process"].communicate()
sample["stream"].close()
def tar_file_and_group_single_spk(data):
"""Expand a stream of open tar files into a stream of tar file contents.
And groups the file with same prefix
Args:
data: Iterable[{src, stream}]
Returns:
Iterable[{key, wav, spk, sample_rate}]
"""
for sample in data:
assert "stream" in sample
stream = tarfile.open(fileobj=sample["stream"],
mode="r|*") # Only support pytorch version <2.0
prev_prefix = None
example = {}
valid = True
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind(".")
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if prev_prefix is not None and prefix != prev_prefix:
example["key"] = prev_prefix
if valid:
yield example
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if postfix in ["spk"]:
example[postfix] = (
file_obj.read().decode("utf8").strip())
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = torchaudio.load(file_obj)
example["wav"] = waveform
example["sample_rate"] = sample_rate
else:
example[postfix] = file_obj.read()
except Exception as ex:
valid = False
logging.warning("error to parse {}".format(name))
prev_prefix = prefix
if prev_prefix is not None:
example["key"] = prev_prefix
yield example
stream.close()
if "process" in sample:
sample["process"].communicate()
sample["stream"].close()
def parse_raw_single_spk(data):
"""Parse key/wav/spk from json line
Args:
data: Iterable[str], str is a json line has key/wav/spk
Returns:
Iterable[{key, wav, spk, sample_rate}]
"""
for sample in data:
assert "src" in sample
json_line = sample["src"]
obj = json.loads(json_line)
assert "key" in obj
assert "wav" in obj
assert "spk" in obj
key = obj["key"]
wav_file = obj["wav"]
spk = obj["spk"]
try:
waveform, sample_rate = torchaudio.load(wav_file)
example = dict(key=key,
spk=spk,
wav=waveform,
sample_rate=sample_rate)
yield example
except Exception as ex:
logging.warning("Failed to read {}".format(wav_file))
def mix_speakers(data, num_speaker=2, shuffle_size=1000):
"""Dynamic mixing speakers when loading data,
shuffle is not needed if this function is used
Args:
:param data: Iterable[{key, wavs, spks}]
:param num_speaker:
:param use_random_snr:
:param shuffle_size:
Returns:
Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
cur_spk = x["spk"]
example = {
"key": x["key"],
"wav_spk1": x["wav"],
"spk1": x["spk"],
"sample_rate": x["sample_rate"],
}
key = "mix_" + x["key"]
interference_idx = 1
while interference_idx < num_speaker:
interference = random.choice(buf)
while interference["spk"] == cur_spk:
interference = random.choice(buf)
key = key + "_" + interference["key"]
interference_idx += 1
example["wav_spk" +
str(interference_idx)] = interference["wav"]
example["spk" +
str(interference_idx)] = interference["spk"]
example["key"] = key
example["num_speaker"] = num_speaker
yield example
buf = []
# The samples left over
random.shuffle(buf)
for x in buf:
cur_spk = x["spk"]
example = {
"key": x["key"],
"wav_spk1": x["wav"],
"spk1": x["spk"],
"sample_rate": x["sample_rate"],
}
key = "mix_" + x["key"]
interference_idx = 1
while interference_idx < num_speaker:
interference = random.choice(buf)
while interference["spk"] == cur_spk:
interference = random.choice(buf)
key = key + "_" + interference["key"]
interference_idx += 1
example["wav_spk" + str(interference_idx)] = interference["wav"]
example["spk" + str(interference_idx)] = interference["spk"]
example["key"] = key
example["num_speaker"] = num_speaker
yield example
def snr_mixer(data, use_random_snr: bool = False):
"""Dynamic mixing speakers when loading data, shuffle is not needed if this function is used. # noqa
Args:
data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
use_random_snr (bool, optional): Whether use random SNR to mix speeches. Defaults to False. # noqa
Returns:
Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
"""
for sample in data:
assert "num_speaker" in sample.keys()
if "wav_spk1_reverb" in sample.keys():
suffix = "_reverb"
else:
suffix = ""
num_speaker = sample["num_speaker"]
wavs_to_mix = [sample["wav_spk1" + suffix]]
target_energy = torch.sum(wavs_to_mix[0]**2, dim=-1, keepdim=True)
for i in range(1, num_speaker):
interference = sample[f"wav_spk{i + 1}" + suffix]
if use_random_snr:
snr = random.uniform(-10, 10)
else:
snr = 0
energy = torch.sum(interference**2, dim=-1, keepdim=True)
interference *= torch.sqrt(target_energy / energy) * 10**(snr / 20)
wavs_to_mix.append(interference)
wavs_to_mix = torch.stack(wavs_to_mix)
sample["wav_mix"] = torch.sum(wavs_to_mix, 0)
max_amp = max(
torch.abs(sample["wav_mix"]).max().item(),
*[x.item() for x in torch.abs(wavs_to_mix).max(dim=-1)[0]],
)
if max_amp != 0:
mix_scaling = 1 / max_amp
else:
mix_scaling = 1
sample["wav_mix"] = sample["wav_mix"] * mix_scaling
for i in range(0, num_speaker):
sample[f"wav_spk{i + 1}" + suffix] *= mix_scaling
yield sample
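# Editor's note: a standalone check (not part of the original file) of the
# scaling rule used in snr_mixer: after scaling, the interference sits `snr`
# dB above the target in energy.
def _demo_snr_scale(snr=5.0):
    target = torch.randn(1, 16000)
    interference = torch.randn(1, 16000)
    te = torch.sum(target**2, dim=-1, keepdim=True)
    ie = torch.sum(interference**2, dim=-1, keepdim=True)
    scaled = interference * torch.sqrt(te / ie) * 10**(snr / 20)
    ratio_db = 10 * torch.log10(torch.sum(scaled**2) / torch.sum(target**2))
    print(ratio_db.item())  # ~= snr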
def shuffle(data, shuffle_size=2500):
"""Local shuffle the data
Args:
data: Iterable[{key, wavs, spks}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, wavs, spks}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def spk_to_id(data, spk2id):
"""Parse spk id
Args:
data: Iterable[{key, wav/feat, spk}]
spk2id: Dict[str, int]
Returns:
Iterable[{key, wav/feat, label}]
"""
for sample in data:
assert "spk" in sample
if sample["spk"] in spk2id:
label = spk2id[sample["spk"]]
else:
label = -1
sample["label"] = label
yield sample
def resample(data, resample_rate=16000):
"""Resample data.
Inplace operation.
Args:
data: Iterable[{key, wavs, spks, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wavs, spks, sample_rate}]
"""
for sample in data:
assert "sample_rate" in sample
sample_rate = sample["sample_rate"]
if sample_rate != resample_rate:
all_keys = list(sample.keys())
sample["sample_rate"] = resample_rate
for key in all_keys:
if "wav" in key:
waveform = sample[key]
sample[key] = torchaudio.transforms.Resample(
orig_freq=sample_rate,
new_freq=resample_rate)(waveform)
yield sample
def sample_spk_embedding(data, spk_embeds):
"""sample reference speaker embeddings for the target speaker
Args:
data: Iterable[{key, wav, label, sample_rate}]
spk_embeds: dict which stores all potential embeddings for the speaker
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("spk"):
sample["embed_" + key] = random.choice(spk_embeds[sample[key]])
yield sample
def sample_fix_spk_embedding(data, spk2embed_dict, spk1_embed, spk2_embed):
"""sample reference speaker embeddings for the target speaker
Args:
data: Iterable[{key, wav, label, sample_rate}]
spk_embeds: dict which stores all potential embeddings for the speaker
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("spk"):
if key == "spk1":
sample["embed_" +
key] = spk2embed_dict[spk1_embed[sample["key"]]]
else:
sample["embed_" +
key] = spk2embed_dict[spk2_embed[sample["key"]]]
yield sample
def sample_enrollment(data, spk_embeds, dict_spk):
"""sample reference speech for the target speaker
Args:
data: Iterable[{key, wav, label, sample_rate}]
spk_embeds: dict which stores all potential enrollment utterance files(/.wav) for the speaker # noqa
dict_spk: dict of speakers in the enrollment sets [Order: spkID]
Returns:
Iterable[{key, wav, label, sample_rate, spk_embed(raw waveform of enrollment), # noqa
        spk_label(when multi-task training)}]
"""
for sample in data:
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("spk"):
enrollment, _ = sf.read(
random.choice(spk_embeds[sample[key]])[1])
sample["embed_" + key] = np.expand_dims(enrollment, axis=0)
if dict_spk:
sample[key + "_label"] = dict_spk[sample[key]]
yield sample
def sample_fix_spk_enrollment(data,
spk2embed_dict,
spk1_embed,
spk2_embed,
dict_spk=None):
"""sample reference speaker embeddings for the target speaker
Args:
data: Iterable[{key, wav, label, sample_rate}]
spk_embeds: dict which stores all potential enrollment utterance files(/.wav) for the speaker # noqa
dict_spk: dict of speakers in the enrollment sets [Order: spkID]
Returns:
Iterable[{key, wav, label, sample_rate, spk_embed(raw waveform of enrollment), # noqa
        spk_label(when multi-task training)}]
"""
for sample in data:
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("spk"):
if key == "spk1":
enrollment, _ = sf.read(
spk2embed_dict[spk1_embed[sample["key"]]])
else:
enrollment, _ = sf.read(
spk2embed_dict[spk2_embed[sample["key"]]])
sample["embed_" + key] = np.expand_dims(enrollment, axis=0)
if dict_spk:
sample[key + "_label"] = dict_spk[sample[key]]
yield sample
def compute_fbank(data,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=1.0):
"""Extract fbank
Args:
data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa
Returns:
Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa
"""
for sample in data:
assert "sample_rate" in sample
sample_rate = sample["sample_rate"]
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("embed"):
waveform = torch.from_numpy(sample[key])
waveform = waveform * (1 << 15)
mat = kaldi.fbank(
waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
sample_frequency=sample_rate,
window_type="hamming",
use_energy=False,
)
sample[key] = mat
yield sample
def apply_cmvn(data, norm_mean=True, norm_var=False):
"""Apply CMVN
Args:
data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa
Returns:
Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1', 'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2'] # noqa
"""
for sample in data:
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("embed"):
mat = sample[key]
if norm_mean:
mat = mat - torch.mean(mat, dim=0)
if norm_var:
mat = mat / torch.sqrt(torch.var(mat, dim=0) + 1e-8)
mat = mat.unsqueeze(0)
sample[key] = mat.detach().numpy()
yield sample
def get_random_chunk(data_list, chunk_len):
"""Get random chunk
Args:
data_list: [torch.Tensor: 1XT] (random len)
chunk_len: chunk length
Returns:
[torch.Tensor] (exactly chunk_len)
"""
# Assert all entries in the list share the same length
    assert all(len(i) == len(data_list[0]) for i in data_list)
data_list = [data[0] for data in data_list]
data_len = len(data_list[0])
# random chunk
if data_len >= chunk_len:
chunk_start = random.randint(0, data_len - chunk_len)
for i in range(len(data_list)):
temp_data = data_list[i][chunk_start:chunk_start + chunk_len]
while torch.equal(temp_data, torch.zeros_like(temp_data)):
chunk_start = random.randint(0, data_len - chunk_len)
temp_data = data_list[i][chunk_start:chunk_start + chunk_len]
data_list[i] = temp_data
            # re-clone the data to avoid memory leakage
            if isinstance(data_list[i], torch.Tensor):
                data_list[i] = data_list[i].clone()
            else:  # np.array
                data_list[i] = data_list[i].copy()
else:
# padding
repeat_factor = chunk_len // data_len + 1
for i in range(len(data_list)):
            if isinstance(data_list[i], torch.Tensor):
                data_list[i] = data_list[i].repeat(repeat_factor)
            else:  # np.array
                data_list[i] = np.tile(data_list[i], repeat_factor)
data_list[i] = data_list[i][:chunk_len]
data_list = [data.unsqueeze(0) for data in data_list]
return data_list
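# Editor's note: a quick illustration (not part of the original file) of both
# branches of get_random_chunk: cropping a long signal, tiling a short one.
def _demo_get_random_chunk():
    long_wav = torch.randn(1, 48000)
    short_wav = torch.randn(1, 8000)
    cropped = get_random_chunk([long_wav], 16000)[0]  # random 16k-sample crop
    tiled = get_random_chunk([short_wav], 16000)[0]   # repeated then truncated
    print(cropped.shape, tiled.shape)  # both [1, 16000]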
def filter_len(
data,
min_num_seconds=1,
max_num_seconds=1000,
):
"""Filter the utterance with very short duration and random chunk the
utterance with very long duration.
Args:
data: Iterable[{key, wav, label, sample_rate}]
min_num_seconds: minimum number of seconds of wav file
max_num_seconds: maximum number of seconds of wav file
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert "key" in sample
assert "sample_rate" in sample
assert "wav" in sample
sample_rate = sample["sample_rate"]
wav = sample["wav"]
min_len = min_num_seconds * sample_rate
max_len = max_num_seconds * sample_rate
if wav.size(1) < min_len:
continue
elif wav.size(1) > max_len:
wav = get_random_chunk([wav], max_len)[0]
sample["wav"] = wav
yield sample
def random_chunk(data, chunk_len):
"""Random chunk the data into chunk_len
Args:
data: Iterable[{key, wav/feat, label}]
chunk_len: chunk length for each sample
Returns:
Iterable[{key, wav/feat, label}]
"""
for sample in data:
assert "key" in sample
wav_keys = [key for key in list(sample.keys()) if "wav" in key]
wav_data_list = [sample[key] for key in wav_keys]
wav_data_list = get_random_chunk(wav_data_list, chunk_len)
sample.update(zip(wav_keys, wav_data_list))
yield sample
def fix_chunk(data, chunk_len):
"""Random chunk the data into chunk_len
Args:
data: Iterable[{key, wav/feat, label}]
chunk_len: chunk length for each sample
Returns:
Iterable[{key, wav/feat, label}]
"""
for sample in data:
assert "key" in sample
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("wav"):
sample[key] = sample[key][:, :chunk_len]
yield sample
def add_noise(
data,
noise_lmdb_file,
noise_prob: float = 0.0,
noise_db_low: int = -5,
noise_db_high: int = 25,
single_channel: bool = True,
):
"""Add noise to mixture
Args:
data: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
noise_lmdb_file: noise LMDB data source.
noise_db_low (int, optional): SNR lower bound. Defaults to -5.
noise_db_high (int, optional): SNR upper bound. Defaults to 25.
single_channel (bool, optional): Whether to force the noise file to be single channel. # noqa
Defaults to True.
Returns:
Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ..., noise, snr}] # noqa
"""
noise_source = LmdbData(noise_lmdb_file)
for sample in data:
if noise_prob > random.random():
assert "sample_rate" in sample.keys()
tgt_fs = sample["sample_rate"]
speech = sample["wav_mix"].numpy() # [1, nsamples]
nsamples = speech.shape[1]
power = (speech**2).mean()
noise_key, noise_data = noise_source.random_one()
if noise_key.startswith(
"speech"): # using interference speech as additive noise
snr_range = [10, 30]
else:
snr_range = [noise_db_low, noise_db_high]
noise_db = np.random.uniform(snr_range[0], snr_range[1])
with sf.SoundFile(io.BytesIO(noise_data)) as f:
fs = f.samplerate
if tgt_fs and fs != tgt_fs:
nsamples_ = int(nsamples / tgt_fs * fs) + 1
else:
nsamples_ = nsamples
if f.frames == nsamples_:
noise = f.read(dtype=np.float64, always_2d=True)
elif f.frames < nsamples_:
offset = np.random.randint(0, nsamples_ - f.frames)
# noise: (Time, Nmic)
noise = f.read(dtype=np.float64, always_2d=True)
# Repeat noise
noise = np.pad(
noise,
[(offset, nsamples_ - f.frames - offset), (0, 0)],
mode="wrap",
)
else:
offset = np.random.randint(0, f.frames - nsamples_)
f.seek(offset)
# noise: (Time, Nmic)
noise = f.read(nsamples_, dtype=np.float64, always_2d=True)
if len(noise) != nsamples_:
raise RuntimeError(
f"Something wrong: {noise_lmdb_file}")
if single_channel:
num_ch = noise.shape[1]
chs = [np.random.randint(num_ch)]
noise = noise[:, chs]
# noise: (Nmic, Time)
noise = noise.T
if tgt_fs and fs != tgt_fs:
logging.warning(
f"Resampling noise to match the sampling rate ({fs} -> {tgt_fs} Hz)" # noqa
)
noise = librosa.resample(noise,
orig_sr=fs,
target_sr=tgt_fs,
res_type="kaiser_fast")
if noise.shape[1] < nsamples:
noise = np.pad(
noise,
[(0, 0), (0, nsamples - noise.shape[1])],
mode="wrap",
)
else:
noise = noise[:, :nsamples]
noise_power = (noise**2).mean()
scale = (10**(-noise_db / 20) * np.sqrt(power) /
np.sqrt(max(noise_power, 1e-10)))
scaled_noise = scale * noise
speech = speech + scaled_noise
sample["wav_mix"] = torch.from_numpy(speech)
sample["noise"] = torch.from_numpy(scaled_noise)
sample["snr"] = noise_db
yield sample
def add_reverb(data, reverb_prob=0):
"""
Args:
data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
Returns:
Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
Note: This function is implemented with reference to
Fast Random Appoximation of Multi-channel Room Impulse Response (FRAM-RIR)
https://arxiv.org/pdf/2304.08052
This function is only used when online_mixing.
"""
for sample in data:
assert "num_speaker" in sample.keys()
assert "sample_rate" in sample.keys()
simu_config["num_src"] = sample["num_speaker"]
simu_config["sr"] = sample["sample_rate"]
rirs, _ = RIR_sim(simu_config) # [n_mic, nsource, nsamples]
rirs = rirs[0] # [nsource, nsamples]
for i in range(sample["num_speaker"]):
if reverb_prob > random.random():
# [1, audio_len], currently only support single-channel audio
audio = sample[f"wav_spk{i + 1}"].numpy()
rir = rirs[i:i + 1, :] # [1, nsamples]
rir_audio = signal.convolve(
audio, rir,
mode="full")[:, :audio.shape[1]] # [1, audio_len]
max_scale = np.max(np.abs(rir_audio))
out_audio = rir_audio / max_scale * 0.9
# Note: Here, we do not replace the dry audio with the reverberant audio, # noqa
# which means we hope the model to perform dereverberation and
# TSE simultaneously.
sample[f"wav_spk{i + 1}"] = torch.from_numpy(out_audio)
yield sample
def add_noise_on_enroll(
data,
noise_lmdb_file,
noise_enroll_prob: float = 0.0,
noise_db_low: int = 0,
noise_db_high: int = 25,
single_channel: bool = True,
):
"""Add noise to mixture
Args:
data: Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
noise_lmdb_file: noise LMDB data source.
noise_db_low (int, optional): SNR lower bound. Defaults to 0.
noise_db_high (int, optional): SNR upper bound. Defaults to 25.
single_channel (bool, optional): Whether to force the noise file to be single channel. # noqa
Defaults to True.
Returns:
Iterable[{key, wav_mix, wav_spk1, wav_spk2, ..., spk1, spk2, ..., noise, snr}] # noqa
"""
noise_source = LmdbData(noise_lmdb_file)
for sample in data:
assert "sample_rate" in sample.keys()
tgt_fs = sample["sample_rate"]
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("spk") and "label" not in key:
if noise_enroll_prob > random.random():
speech = sample["embed_" + key]
nsamples = speech.shape[1]
power = (speech**2).mean()
noise_key, noise_data = noise_source.random_one()
if noise_key.startswith(
"speech"
): # using interference speech as additive noise
snr_range = [10, 30]
else:
snr_range = [noise_db_low, noise_db_high]
noise_db = np.random.uniform(snr_range[0], snr_range[1])
with sf.SoundFile(io.BytesIO(noise_data)) as f:
fs = f.samplerate
if tgt_fs and fs != tgt_fs:
nsamples_ = int(nsamples / tgt_fs * fs) + 1
else:
nsamples_ = nsamples
if f.frames == nsamples_:
noise = f.read(dtype=np.float64, always_2d=True)
elif f.frames < nsamples_:
offset = np.random.randint(0, nsamples_ - f.frames)
# noise: (Time, Nmic)
noise = f.read(dtype=np.float64, always_2d=True)
# Repeat noise
noise = np.pad(
noise,
[
(offset, nsamples_ - f.frames - offset),
(0, 0),
],
mode="wrap",
)
else:
offset = np.random.randint(0, f.frames - nsamples_)
f.seek(offset)
# noise: (Time, Nmic)
noise = f.read(nsamples_,
dtype=np.float64,
always_2d=True)
if len(noise) != nsamples_:
raise RuntimeError(
f"Something wrong: {noise_lmdb_file}")
if single_channel:
num_ch = noise.shape[1]
chs = [np.random.randint(num_ch)]
noise = noise[:, chs]
# noise: (Nmic, Time)
noise = noise.T
if tgt_fs and fs != tgt_fs:
logging.warning(
f"Resampling noise to match the sampling rate ({fs} -> {tgt_fs} Hz)" # noqa
)
noise = librosa.resample(
noise,
orig_sr=fs,
target_sr=tgt_fs,
res_type="kaiser_fast",
)
if noise.shape[1] < nsamples:
noise = np.pad(
noise,
[(0, 0), (0, nsamples - noise.shape[1])],
mode="wrap",
)
else:
noise = noise[:, :nsamples]
noise_power = (noise**2).mean()
scale = (10**(-noise_db / 20) * np.sqrt(power) /
np.sqrt(max(noise_power, 1e-10)))
scaled_noise = scale * noise
speech = speech + scaled_noise
sample["embed_" + key] = speech
yield sample
def add_reverb_on_enroll(data, reverb_enroll_prob=0):
"""
Args:
data: Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
Returns:
Iterable[{key, wav_spk1, wav_spk2, ..., spk1, spk2, ...}]
"""
for sample in data:
assert "num_speaker" in sample.keys()
assert "sample_rate" in sample.keys()
for i in range(sample["num_speaker"]):
simu_config["sr"] = sample["sample_rate"]
simu_config["num_src"] = 1
rirs, _ = RIR_sim(simu_config) # [n_mic, nsource, nsamples]
rirs = rirs[0] # [nsource, nsamples]
if reverb_enroll_prob > random.random():
# [1, audio_len], currently only support single-channel audio
audio = sample[f"embed_spk{i+1}"]
                rir = rirs  # [1, nsamples], since num_src == 1 here
rir_audio = signal.convolve(
audio, rir,
mode="full")[:, :audio.shape[1]] # [1, audio_len]
max_scale = np.max(np.abs(rir_audio))
out_audio = rir_audio / max_scale * 0.9
# Note: Here, we do not replace the dry audio with the reverberant audio, # noqa
# which means we hope the model to perform dereverberation and
# TSE simultaneously.
sample[f"embed_spk{i+1}"] = out_audio
yield sample
def spec_aug(data, num_t_mask=1, num_f_mask=1, max_t=10, max_f=8, prob=0):
"""Do spec augmentation
Inplace operation
Args:
data: Iterable[{key, feat, label}]
        num_t_mask: number of time masks to apply
        num_f_mask: number of frequency masks to apply
        max_t: max width of each time mask
        max_f: max width of each frequency mask
        prob: probability of applying spec_aug
Returns
Iterable[{key, feat, label}]
"""
for sample in data:
if random.random() < prob:
all_keys = list(sample.keys())
for key in all_keys:
if key.startswith("embed"):
y = sample[key]
max_frames = y.shape[1]
max_freq = y.shape[2]
# time mask
for i in range(num_t_mask):
start = random.randint(0, max_frames - 1)
length = random.randint(1, max_t)
end = min(max_frames, start + length)
y[:, start:end, :] = 0
# freq mask
for i in range(num_f_mask):
start = random.randint(0, max_freq - 1)
length = random.randint(1, max_f)
end = min(max_freq, start + length)
y[:, :, start:end] = 0
sample[key] = y
yield sample
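# Editor's note: a minimal sketch (not part of the original file) driving one
# generator stage above by hand, outside the Processor wrapper; the sample
# dict is illustrative.
def _demo_snr_mixer_stage():
    samples = [{
        "key": "utt1",
        "num_speaker": 2,
        "sample_rate": 16000,
        "wav_spk1": torch.randn(1, 16000),
        "wav_spk2": torch.randn(1, 16000),
        "spk1": "A",
        "spk2": "B",
    }]
    for sample in snr_mixer(iter(samples), use_random_snr=True):
        print(sample["key"], sample["wav_mix"].shape)  # utt1, [1, 16000]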
================================================
FILE: wesep/dataset/vad.py
================================================
import numpy as np
import soundfile as sf
class VoiceActivityDetection:
def __init__(self, wave):
self.wave = wave
def segmentation(self, overlap, slice_len):
frequency = 16000
signal = self.wave
self.seg_len = len(signal) / frequency
self.slice_len = slice_len
        # NOTE: the overlap argument is overridden; a 2 s overlap is always used
        overlap = 2
slices = np.arange(0, self.seg_len, slice_len - overlap, dtype=np.intc)
# print(slices)
audio_slices = []
for start, end in zip(slices[:-1], slices[1:]):
start_audio = start * frequency
end_audio = (end + overlap) * frequency
audio_slice = signal[int(start_audio):int(end_audio)]
# print(len(audio_slice))
audio_slices.append(audio_slice)
# wavfile.write('slices{}.wav'.format(start), 16000, audio_slice)
# print(len(audio_slices))
return audio_slices
    def calc_energy(self, audio):
        # normalize the chunk by its total energy (1e-8 avoids division by zero)
        energy = audio / np.sum(np.sum(audio**2) + 1e-8) * 1e2
        return energy
def select(self):
audio_slices = self.segmentation(overlap=1, slice_len=4)
energies = []
for audio in audio_slices:
chunk_len = len(audio) / 10
chunk_slice = np.arange(0,
len(audio) + chunk_len,
chunk_len,
dtype=np.intc)
for start, end in zip(chunk_slice[:-1], chunk_slice[1:]):
energy = self.calc_energy(audio[start:end])
# print(energy)
for i, _ in enumerate(energy):
if (energy[i]) == 0:
energy[i] = 0.00001
# print(energy[i])
energies.append(sum(energy))
# print(energies)
threshold = np.quantile(energies, 0.25)
print(threshold)
if threshold < 0.0001:
threshold = 0.0001
fin_audios = []
i = 0
for audio in audio_slices:
chunk_len = len(audio) / 10
chunk_slice = np.arange(0,
len(audio) + chunk_len,
chunk_len,
dtype=np.intc)
count = 0
for start, end in zip(chunk_slice[:-1], chunk_slice[1:]):
energy = self.calc_energy(audio[start:end])
                # if at least 50% of the energy values exceed the threshold
# print(energy)
print(sum(i >= threshold for i in energy))
if sum(i >= threshold for i in energy) >= chunk_len // 2:
count += 1
# save seg
# print(count)
if count >= 5:
sf.write("output{}.wav".format(i), audio, 16000)
if len(audio) < self.slice_len * 16000:
# print(self.slice_len*16000-len(audio))
audio = np.concatenate(
[audio,
np.zeros(self.slice_len * 16000 - len(audio))])
fin_audios.append(audio)
i += 1
if len(fin_audios) == 0:
fin_audios.append(np.zeros(self.slice_len * 16000))
return fin_audios
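# Editor's note: a minimal usage sketch (not part of the original file). Note
# that select() writes output{i}.wav files as a side effect; the input below
# is random noise used purely as a stand-in waveform.
if __name__ == "__main__":
    wave = np.random.randn(16000 * 10) * 0.1  # 10 s stand-in signal
    vad = VoiceActivityDetection(wave)
    segments = vad.select()  # 4 s slices judged to contain enough energy
    print(len(segments), segments[0].shape)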
================================================
FILE: wesep/models/__init__.py
================================================
import wesep.models.bsrnn as bsrnn
import wesep.models.convtasnet as convtasnet
import wesep.models.dpccn as dpccn
import wesep.models.tfgridnet as tfgridnet
import wesep.modules.metric_gan.discriminator as discriminator
import wesep.models.bsrnn_multi_optim as bsrnn_multi
import wesep.models.bsrnn_feats as bsrnn_feats
def get_model(model_name: str):
if model_name.startswith("ConvTasNet"):
return getattr(convtasnet, model_name)
elif model_name.startswith("BSRNN_Multi"):
return getattr(bsrnn_multi, model_name)
elif model_name.startswith("BSRNN_Feats"):
return getattr(bsrnn_feats, model_name)
elif model_name.startswith("BSRNN"):
return getattr(bsrnn, model_name)
elif model_name.startswith("DPCCN"):
return getattr(dpccn, model_name)
elif model_name.startswith("TFGridNet"):
return getattr(tfgridnet, model_name)
elif model_name.startswith("CMGAN"):
return getattr(discriminator, model_name)
else: # model_name error !!!
print(model_name + " not found !!!")
exit(1)
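# Editor's note: a minimal usage sketch (not part of the original file); the
# constructor arguments depend on the resolved class and are illustrative.
def _demo_get_model():
    model_cls = get_model("BSRNN")  # resolves to wesep.models.bsrnn.BSRNN
    return model_cls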
================================================
FILE: wesep/models/bsrnn.py
================================================
from __future__ import print_function
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from wespeaker.models.speaker_model import get_speaker_model
from wesep.modules.common.speaker import PreEmphasis
from wesep.modules.common.speaker import SpeakerFuseLayer
from wesep.modules.common.speaker import SpeakerTransform
class ResRNN(nn.Module):
def __init__(self, input_size, hidden_size, bidirectional=True):
super(ResRNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.eps = torch.finfo(torch.float32).eps
self.norm = nn.GroupNorm(1, input_size, self.eps)
self.rnn = nn.LSTM(
input_size,
hidden_size,
1,
batch_first=True,
bidirectional=bidirectional,
)
# linear projection layer
self.proj = nn.Linear(hidden_size * 2,
input_size) # hidden_size = feature_dim * 2
def forward(self, input):
# input shape: batch, dim, seq
rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())
rnn_output = self.proj(rnn_output.contiguous().view(
-1, rnn_output.shape[2])).view(input.shape[0], input.shape[2],
input.shape[1])
return input + rnn_output.transpose(1, 2).contiguous()
"""
TODO : attach the speaker embedding to each input
Input shape:(B,feature_dim + spk_emb_dim , T)
"""
class BSNet(nn.Module):
def __init__(self, in_channel, nband=7, bidirectional=True):
super(BSNet, self).__init__()
self.nband = nband
self.feature_dim = in_channel // nband
self.band_rnn = ResRNN(self.feature_dim,
self.feature_dim * 2,
bidirectional=bidirectional)
self.band_comm = ResRNN(self.feature_dim,
self.feature_dim * 2,
bidirectional=bidirectional)
def forward(self, input, dummy: Optional[torch.Tensor] = None):
# input shape: B, nband*N, T
B, N, T = input.shape
band_output = self.band_rnn(
input.view(B * self.nband, self.feature_dim,
-1)).view(B, self.nband, -1, T)
# band comm
band_output = (band_output.permute(0, 3, 2, 1).contiguous().view(
B * T, -1, self.nband))
output = (self.band_comm(band_output).view(
B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous())
return output.view(B, N, T)
class FuseSeparation(nn.Module):
def __init__(
self,
nband=7,
num_repeat=6,
feature_dim=128,
spk_emb_dim=256,
spk_fuse_type="concat",
multi_fuse=True,
):
"""
:param nband : len(self.band_width)
"""
super(FuseSeparation, self).__init__()
self.multi_fuse = multi_fuse
self.nband = nband
self.feature_dim = feature_dim
self.separation = nn.ModuleList([])
if self.multi_fuse:
for _ in range(num_repeat):
self.separation.append(
SpeakerFuseLayer(
embed_dim=spk_emb_dim,
feat_dim=feature_dim,
fuse_type=spk_fuse_type,
))
self.separation.append(BSNet(nband * feature_dim, nband))
else:
self.separation.append(
SpeakerFuseLayer(
embed_dim=spk_emb_dim,
feat_dim=feature_dim,
fuse_type=spk_fuse_type,
))
for _ in range(num_repeat):
self.separation.append(BSNet(nband * feature_dim, nband))
def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
"""
x: [B, nband, feature_dim, T]
out: [B, nband, feature_dim, T]
"""
batch_size = x.shape[0]
if self.multi_fuse:
for i, sep_func in enumerate(self.separation):
x = sep_func(x, spk_embedding)
if i % 2 == 0:
x = x.view(batch_size * nch, self.nband * self.feature_dim,
-1)
else:
x = x.view(batch_size * nch, self.nband, self.feature_dim,
-1)
else:
x = self.separation[0](x, spk_embedding)
x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
for idx, sep in enumerate(self.separation):
if idx > 0:
x = sep(x, spk_embedding)
x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)
return x
class BSRNN(nn.Module):
# self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6,
# use_bidirectional=True
def __init__(
self,
spk_emb_dim=256,
sr=16000,
win=512,
stride=128,
feature_dim=128,
num_repeat=6,
use_spk_transform=True,
use_bidirectional=True,
spk_fuse_type="concat",
multi_fuse=True,
joint_training=True,
multi_task=False,
spksInTrain=251,
spk_model=None,
spk_model_init=None,
spk_model_freeze=False,
spk_args=None,
spk_feat=False,
feat_type="consistent",
):
super(BSRNN, self).__init__()
self.sr = sr
self.win = win
self.stride = stride
self.group = self.win // 2
self.enc_dim = self.win // 2 + 1
self.feature_dim = feature_dim
self.eps = torch.finfo(torch.float32).eps
self.spk_emb_dim = spk_emb_dim
self.joint_training = joint_training
self.spk_feat = spk_feat
self.feat_type = feat_type
self.spk_model_freeze = spk_model_freeze
self.multi_task = multi_task
        # Band split for sr=16000: 0-1.5k in 100 Hz bands (x15),
        # 1.5k-3.5k in 200 Hz bands (x10), 3.5k-6k in 500 Hz bands (x5),
        # 6k-8k in one 2 kHz band, plus a final band for the remaining bins.
bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim))
bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim))
bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim))
bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim))
# add up to 8k
self.band_width = [bandwidth_100] * 15
self.band_width += [bandwidth_200] * 10
self.band_width += [bandwidth_500] * 5
self.band_width += [bandwidth_2k] * 1
self.band_width.append(self.enc_dim - int(np.sum(self.band_width)))
self.nband = len(self.band_width)
if use_spk_transform:
self.spk_transform = SpeakerTransform()
else:
self.spk_transform = nn.Identity()
if joint_training:
self.spk_model = get_speaker_model(spk_model)(**spk_args)
if spk_model_init:
pretrained_model = torch.load(spk_model_init)
state = self.spk_model.state_dict()
for key in state.keys():
if key in pretrained_model.keys():
state[key] = pretrained_model[key]
# print(key)
else:
print("not %s loaded" % key)
self.spk_model.load_state_dict(state)
if spk_model_freeze:
for param in self.spk_model.parameters():
param.requires_grad = False
if not spk_feat:
if feat_type == "consistent":
self.preEmphasis = PreEmphasis()
self.spk_encoder = torchaudio.transforms.MelSpectrogram(
sample_rate=sr,
n_fft=win,
win_length=win,
hop_length=stride,
f_min=20,
window_fn=torch.hamming_window,
n_mels=spk_args["feat_dim"],
)
else:
self.preEmphasis = nn.Identity()
self.spk_encoder = nn.Identity()
if multi_task:
self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)
else:
self.pred_linear = nn.Identity()
self.BN = nn.ModuleList([])
for i in range(self.nband):
self.BN.append(
nn.Sequential(
nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),
))
self.separator = FuseSeparation(
nband=self.nband,
num_repeat=num_repeat,
feature_dim=feature_dim,
spk_emb_dim=spk_emb_dim,
spk_fuse_type=spk_fuse_type,
multi_fuse=multi_fuse,
)
# self.proj = nn.Linear(hidden_size*2, input_size)
self.mask = nn.ModuleList([])
for i in range(self.nband):
self.mask.append(
nn.Sequential(
nn.GroupNorm(1, self.feature_dim,
torch.finfo(torch.float32).eps),
nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),
nn.Tanh(),
nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),
nn.Tanh(),
nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),
))
def pad_input(self, input, window, stride):
"""
Zero-padding input according to window/stride size.
"""
batch_size, nsample = input.shape
# pad the signals at the end for matching the window/stride size
rest = window - (stride + nsample % window) % window
if rest > 0:
pad = torch.zeros(batch_size, rest).type(input.type())
input = torch.cat([input, pad], 1)
pad_aux = torch.zeros(batch_size, stride).type(input.type())
input = torch.cat([pad_aux, input, pad_aux], 1)
return input, rest
def forward(self, input, embeddings):
# input shape: (B, C, T)
wav_input = input
spk_emb_input = embeddings
batch_size, nsample = wav_input.shape
nch = 1
# frequency-domain separation
spec = torch.stft(
wav_input,
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(wav_input.device).type(
wav_input.type()),
return_complex=True,
)
# concat real and imag, split to subbands
spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T
subband_spec = []
subband_mix_spec = []
band_idx = 0
for i in range(len(self.band_width)):
subband_spec.append(spec_RI[:, :, band_idx:band_idx +
self.band_width[i]].contiguous())
subband_mix_spec.append(spec[:, band_idx:band_idx +
self.band_width[i]]) # B*nch, BW, T
band_idx += self.band_width[i]
# normalization and bottleneck
subband_feature = []
for i, bn_func in enumerate(self.BN):
subband_feature.append(
bn_func(subband_spec[i].view(batch_size * nch,
self.band_width[i] * 2, -1)))
subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
# print(subband_feature.size(), spk_emb_input.size())
predict_speaker_lable = torch.tensor(0.0).to(
spk_emb_input.device) # dummy
if self.joint_training:
if not self.spk_feat:
if self.feat_type == "consistent":
with torch.no_grad():
spk_emb_input = self.preEmphasis(spk_emb_input)
spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8
spk_emb_input = spk_emb_input.log()
spk_emb_input = spk_emb_input - torch.mean(
spk_emb_input, dim=-1, keepdim=True)
spk_emb_input = spk_emb_input.permute(0, 2, 1)
tmp_spk_emb_input = self.spk_model(spk_emb_input)
if isinstance(tmp_spk_emb_input, tuple):
spk_emb_input = tmp_spk_emb_input[-1]
else:
spk_emb_input = tmp_spk_emb_input
predict_speaker_lable = self.pred_linear(spk_emb_input)
spk_embedding = self.spk_transform(spk_emb_input)
spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)
sep_output = self.separator(subband_feature, spk_embedding,
torch.tensor(nch))
sep_subband_spec = []
for i, mask_func in enumerate(self.mask):
this_output = mask_func(sep_output[:, i]).view(
batch_size * nch, 2, 2, self.band_width[i], -1)
this_mask = this_output[:, 0] * torch.sigmoid(
this_output[:, 1]) # B*nch, 2, K, BW, T
this_mask_real = this_mask[:, 0] # B*nch, K, BW, T
this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T
est_spec_real = (subband_mix_spec[i].real * this_mask_real -
subband_mix_spec[i].imag * this_mask_imag
) # B*nch, BW, T
est_spec_imag = (subband_mix_spec[i].real * this_mask_imag +
subband_mix_spec[i].imag * this_mask_real
) # B*nch, BW, T
sep_subband_spec.append(torch.complex(est_spec_real,
est_spec_imag))
est_spec = torch.cat(sep_subband_spec, 1) # B*nch, F, T
output = torch.istft(
est_spec.view(batch_size * nch, self.enc_dim, -1),
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(wav_input.device).type(
wav_input.type()),
length=nsample,
)
output = output.view(batch_size, nch, -1)
s = torch.squeeze(output, dim=1)
return s, predict_speaker_lable
if __name__ == "__main__":
from thop import profile, clever_format
model = BSRNN(
spk_emb_dim=256,
sr=16000,
win=512,
stride=128,
feature_dim=128,
num_repeat=6,
spk_fuse_type="additive",
)
s = 0
for param in model.parameters():
        s += np.prod(param.size())  # np.product was removed in NumPy 2.0
    print("# of parameters (M): " + str(s / 1024.0 / 1024.0))
x = torch.randn(4, 32000)
spk_embeddings = torch.randn(4, 256)
    # forward returns (separated speech, speaker prediction)
    output, _ = model(x, spk_embeddings)
    print(output.shape)
macs, params = profile(model, inputs=(x, spk_embeddings))
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)
================================================
FILE: wesep/models/bsrnn_feats.py
================================================
from __future__ import print_function
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from wespeaker.models.speaker_model import get_speaker_model
from wesep.modules.common.speaker import PreEmphasis
from wesep.modules.common.speaker import SpeakerFuseLayer
from wesep.modules.common.speaker import SpeakerTransform
from wesep.utils.funcs import compute_fbank, apply_cmvn
class ResRNN(nn.Module):
def __init__(self, input_size, hidden_size, bidirectional=True):
super(ResRNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.eps = torch.finfo(torch.float32).eps
self.norm = nn.GroupNorm(1, input_size, self.eps)
self.rnn = nn.LSTM(
input_size,
hidden_size,
1,
batch_first=True,
bidirectional=bidirectional,
)
# linear projection layer
self.proj = nn.Linear(hidden_size * 2,
input_size) # hidden_size = feature_dim * 2
def forward(self, input):
# input shape: batch, dim, seq
rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())
rnn_output = self.proj(rnn_output.contiguous().view(
-1, rnn_output.shape[2])).view(input.shape[0], input.shape[2],
input.shape[1])
return input + rnn_output.transpose(1, 2).contiguous()
"""
TODO : attach the speaker embedding to each input
Input shape:(B,feature_dim + spk_emb_dim , T)
"""
class BSNet(nn.Module):
def __init__(self, in_channel, nband=7, bidirectional=True):
super(BSNet, self).__init__()
self.nband = nband
self.feature_dim = in_channel // nband
self.band_rnn = ResRNN(self.feature_dim,
self.feature_dim * 2,
bidirectional=bidirectional)
self.band_comm = ResRNN(self.feature_dim,
self.feature_dim * 2,
bidirectional=bidirectional)
def forward(self, input, dummy: Optional[torch.Tensor] = None):
# input shape: B, nband*N, T
B, N, T = input.shape
band_output = self.band_rnn(
input.view(B * self.nband, self.feature_dim,
-1)).view(B, self.nband, -1, T)
# band comm
band_output = (band_output.permute(0, 3, 2, 1).contiguous().view(
B * T, -1, self.nband))
output = (self.band_comm(band_output).view(
B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous())
return output.view(B, N, T)
class CrossAtt(nn.Module):
def __init__(self, embed_dim, num_heads, *args, **kwargs):
super(CrossAtt, self).__init__()
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads,
*args, **kwargs)
def forward(self, query, key, value):
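# Inputs are (batch, channels, time) and are transposed to
# (batch, time, channels) for nn.MultiheadAttention (constructed with
# batch_first=True by FuseSeparation). A 4-dim query (B, nband, N, T)
# is attended band by band against the enrollment key/value sequence.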
if query.dim() == 4:
spk_embeddings = []
for i in range(query.shape[1]):
x = query[:, i, :, :].squeeze(dim=1) # (batch, feature, time)
x, _ = self.multihead_attn(x.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2))
spk_embeddings.append(x.transpose(1, 2))
spk_embeddings = torch.stack(spk_embeddings, dim=1)
elif query.dim() == 3:
x, _ = self.multihead_attn(query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2))
spk_embeddings = x.transpose(1, 2)
return spk_embeddings
class FuseSeparation(nn.Module):
def __init__(
self,
nband=7,
num_repeat=6,
feature_dim=128,
spk_emb_dim=256,
spk_fuse_type="concat",
multi_fuse=True,
):
"""
:param nband : len(self.band_width)
"""
super(FuseSeparation, self).__init__()
self.spk_fuse_type = spk_fuse_type
self.multi_fuse = multi_fuse
self.nband = nband
self.feature_dim = feature_dim
self.attenFuse = nn.ModuleList([])
if spk_fuse_type and spk_fuse_type.startswith("cross_"):
spk_emb_frame_dim = 512 # Ecapa_TDNN
spk_emb_dim = feature_dim
self.attenFuse.append(nn.Linear(spk_emb_frame_dim, feature_dim))
self.attenFuse.append(CrossAtt(embed_dim=feature_dim, num_heads=2,
batch_first=True))
self.separation = nn.ModuleList([])
if self.multi_fuse and self.spk_fuse_type:
for _ in range(num_repeat):
self.separation.append(
SpeakerFuseLayer(
embed_dim=spk_emb_dim,
feat_dim=feature_dim,
fuse_type=spk_fuse_type.removeprefix("cross_"),
))
self.separation.append(BSNet(nband * feature_dim, nband))
else:
if self.spk_fuse_type:
self.separation.append(
SpeakerFuseLayer(
embed_dim=spk_emb_dim,
feat_dim=feature_dim,
fuse_type=spk_fuse_type.removeprefix("cross_"),
))
for _ in range(num_repeat):
self.separation.append(BSNet(nband * feature_dim, nband))
def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
"""
x: [B, nband, feature_dim, T]
out: [B, nband, feature_dim, T]
"""
batch_size = x.shape[0]
if self.spk_fuse_type and self.spk_fuse_type.startswith('cross_'):
spk_embedding = spk_embedding.transpose(1, 2)
spk_embedding = self.attenFuse[0](spk_embedding)
spk_embedding = spk_embedding.transpose(1, 2)
spk_embedding = self.attenFuse[1](x, spk_embedding, spk_embedding)
if self.multi_fuse and self.spk_fuse_type:
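# SpeakerFuseLayer and BSNet blocks alternate in self.separation: after
# an even index (fuse layer) reshape to (B, nband*N, T) for BSNet, after
# an odd index (BSNet) reshape back to (B, nband, N, T) for the next fuse.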
for i, sep_func in enumerate(self.separation):
x = sep_func(x, spk_embedding)
if i % 2 == 0:
x = x.view(batch_size * nch, self.nband * self.feature_dim,
-1)
else:
x = x.view(batch_size * nch, self.nband, self.feature_dim,
-1)
if self.spk_fuse_type.startswith('cross_'):
spk_embedding = spk_embedding.transpose(1, 2)
spk_embedding = self.attenFuse[0](spk_embedding)
spk_embedding = spk_embedding.transpose(1, 2)
spk_embedding = self.attenFuse[1](x, spk_embedding,
spk_embedding)
else:
idx_start = -1
if self.spk_fuse_type:
x = self.separation[0](x, spk_embedding)
idx_start += 1
x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
for idx, sep in enumerate(self.separation):
if idx > idx_start:
x = sep(x, spk_embedding)
x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)
return x
class BSRNN_Feats(nn.Module):
# self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6,
# use_bidirectional=True
def __init__(
self,
spk_emb_dim=256,
sr=16000,
win=512,
stride=128,
feature_dim=128,
num_repeat=6,
use_spk_transform=False,
use_bidirectional=True,
spectral_feat=False,
spk_fuse_type="concat",
multi_fuse=False,
joint_training=True,
multi_task=False,
spksInTrain=251,
spk_model=None,
spk_model_init=None,
spk_model_freeze=False,
spk_args=None,
spk_feat=False,
feat_type="consistent",
):
super(BSRNN_Feats, self).__init__()
self.sr = sr
self.win = win
self.stride = stride
self.group = self.win // 2
self.enc_dim = self.win // 2 + 1
self.feature_dim = feature_dim
self.eps = torch.finfo(torch.float32).eps
self.spk_emb_dim = spk_emb_dim
self.spk_fuse_type = spk_fuse_type
self.joint_training = joint_training
self.spk_feat = spk_feat
self.feat_type = feat_type
self.spk_model_freeze = spk_model_freeze
self.multi_task = multi_task
# band split for 16 kHz signals: 15 x 100 Hz (0-1.5k), 10 x 200 Hz
# (1.5k-3.5k), 5 x 500 Hz (3.5k-6k), 1 x 2 kHz (6k-8k),
# plus one remainder band covering the leftover bins
bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim))
bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim))
bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim))
bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim))
# add up to 8k
self.band_width = [bandwidth_100] * 15
self.band_width += [bandwidth_200] * 10
self.band_width += [bandwidth_500] * 5
self.band_width += [bandwidth_2k] * 1
self.band_width.append(self.enc_dim - int(np.sum(self.band_width)))
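# Worked example with the defaults (sr=16000, win=512, enc_dim=257):
# bandwidth_100=3, bandwidth_200=6, bandwidth_500=16, bandwidth_2k=64;
# 15*3 + 10*6 + 5*16 + 64 = 249 bins, so the remainder band has
# 257 - 249 = 8 bins and nband = 32.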
self.nband = len(self.band_width)
if use_spk_transform:
self.spk_transform = SpeakerTransform()
else:
self.spk_transform = nn.Identity()
if joint_training and (spk_fuse_type or spectral_feat == 'tfmap_emb'):
self.spk_model = get_speaker_model(spk_model)(**spk_args)
if spk_model_init:
pretrained_model = torch.load(spk_model_init)
state = self.spk_model.state_dict()
for key in state.keys():
if key in pretrained_model.keys():
state[key] = pretrained_model[key]
# print(key)
else:
print("not %s loaded" % key)
self.spk_model.load_state_dict(state)
if spk_model_freeze:
for param in self.spk_model.parameters():
param.requires_grad = False
if not spk_feat:
if feat_type == "consistent":
self.preEmphasis = PreEmphasis()
self.spk_encoder = torchaudio.transforms.MelSpectrogram(
sample_rate=sr,
n_fft=win,
win_length=win,
hop_length=stride,
f_min=20,
window_fn=torch.hamming_window,
n_mels=spk_args["feat_dim"],
)
else:
self.preEmphasis = nn.Identity()
self.spk_encoder = nn.Identity()
if multi_task:
self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)
else:
self.pred_linear = nn.Identity()
spec_map = 2
if spectral_feat:
spec_map += 1
self.spectral_feat = spectral_feat
self.spec_map = spec_map
self.BN = nn.ModuleList([])
for i in range(self.nband):
self.BN.append(
nn.Sequential(
nn.GroupNorm(1, self.band_width[i] * spec_map, self.eps),
nn.Conv1d(self.band_width[i] * spec_map, self.feature_dim, 1),
))
self.separator = FuseSeparation(
nband=self.nband,
num_repeat=num_repeat,
feature_dim=feature_dim,
spk_emb_dim=spk_emb_dim,
spk_fuse_type=spk_fuse_type,
multi_fuse=multi_fuse,
)
self.mask = nn.ModuleList([])
for i in range(self.nband):
self.mask.append(
nn.Sequential(
nn.GroupNorm(1, self.feature_dim,
torch.finfo(torch.float32).eps),
nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),
nn.Tanh(),
nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),
nn.Tanh(),
nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),
))
def pad_input(self, input, window, stride):
"""
Zero-padding input according to window/stride size.
"""
batch_size, nsample = input.shape
# pad the signals at the end for matching the window/stride size
rest = window - (stride + nsample % window) % window
if rest > 0:
pad = torch.zeros(batch_size, rest).type(input.type())
input = torch.cat([input, pad], 1)
pad_aux = torch.zeros(batch_size, stride).type(input.type())
input = torch.cat([pad_aux, input, pad_aux], 1)
return input, rest
def forward(self, input, embeddings):
# input shape: (B, C, T)
wav_input = input
spk_emb_input = embeddings
batch_size, nsample = wav_input.shape
nch = 1
# frequency-domain separation
spec = torch.stft(
wav_input,
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(wav_input.device).type(
wav_input.type()),
return_complex=True,
)
spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T
# Calculate the spectral level feature
if self.spectral_feat:
aux_c = torch.stft(
spk_emb_input,
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(spk_emb_input.device).type(
spk_emb_input.type()),
return_complex=True,
)
if self.spectral_feat == 'tfmap_spec':
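# tfmap_spec: attend from each mixture frame to the enrollment frames
# over L2-normalized magnitude spectra, then paste the attended
# enrollment magnitudes back as an extra input channel.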
mix_mag_ori = torch.abs(spec)
enroll_mag = torch.abs(aux_c)
mix_mag = F.normalize(mix_mag_ori, p=2, dim=1)
enroll_mag = F.normalize(enroll_mag, p=2, dim=1)
mix_mag = mix_mag.permute(0, 2, 1).contiguous()
att_scores = torch.matmul(mix_mag, enroll_mag)
att_weights = F.softmax(att_scores, dim=-1)
enroll_mag = enroll_mag.permute(0, 2, 1).contiguous()
tf_map = torch.matmul(att_weights, enroll_mag)
tf_map = tf_map.permute(0, 2, 1).contiguous()
tf_map = tf_map / tf_map.norm(dim=1, keepdim=True)
# Recover the energy of estimated tfmap feature
tf_map = (
torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True)
* tf_map
)
# An alternative normalization for the tf_map feature
# tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True)
spec_RI = torch.cat((spec_RI, tf_map.unsqueeze(1)), dim=1)
if self.spectral_feat == 'tfmap_emb':  # only supports ECAPA-TDNN
with torch.no_grad():
signal_dim = wav_input.dim()
extended_shape = (
[1] * (3 - signal_dim)
+ list(wav_input.size())
)
pad = int(self.win // 2)
mix_emb = F.pad(
wav_input.view(extended_shape),
[pad, pad],
mode="reflect"
)
mix_emb = mix_emb.view(mix_emb.shape[-signal_dim:])
signal_dim = spk_emb_input.dim()
extended_shape = (
[1] * (3 - signal_dim)
+ list(spk_emb_input.size())
)
pad = int(self.win // 2)
spk_emb = F.pad(
spk_emb_input.view(extended_shape),
[pad, pad],
mode="reflect"
)
spk_emb = spk_emb.view(spk_emb.shape[-signal_dim:])
spk_emb = compute_fbank(
spk_emb,
frame_length=self.win * 1e3 / self.sr,
frame_shift=self.stride * 1e3 / self.sr,
dither=0.0,
sample_rate=self.sr
)
mix_emb = compute_fbank(
mix_emb,
frame_length=self.win * 1e3 / self.sr,
frame_shift=self.stride * 1e3 / self.sr,
dither=0.0,
sample_rate=self.sr
)
mix_emb = apply_cmvn(mix_emb)
spk_emb = apply_cmvn(spk_emb)
spk_emb = self.spk_model(spk_emb)
if isinstance(spk_emb, tuple):
spk_emb_frame = spk_emb[0]
else:
spk_emb_frame = spk_emb
mix_emb = self.spk_model(mix_emb)
if isinstance(mix_emb, tuple):
mix_emb_frame = mix_emb[0]
else:
mix_emb_frame = mix_emb
mix_emb_frame_ = F.normalize(mix_emb_frame, p=2, dim=1)
spk_emb_frame_ = F.normalize(spk_emb_frame, p=2, dim=1)
mix_emb_frame_ = mix_emb_frame_.transpose(1, 2)
att_scores = torch.matmul(mix_emb_frame_, spk_emb_frame_)
att_weights = F.softmax(att_scores, dim=-1)
mix_mag_ori = torch.abs(spec)
enroll_mag = torch.abs(aux_c)
enroll_mag = enroll_mag.transpose(1, 2)
# enroll_mag = F.normalize(enroll_mag, p=2, dim=1)
tf_map = torch.matmul(att_weights, enroll_mag)
tf_map = tf_map.transpose(1, 2)
tf_map = tf_map / tf_map.norm(dim=1, keepdim=True)
# Recover the energy of estimated tfmap feature
tf_map = (
torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True)
* tf_map
)
# An alternative normalization for the tf_map feature
# tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True)
spec_RI = torch.cat((spec_RI, tf_map.unsqueeze(1)), dim=1)
# concat real and imag, split to subbands
subband_spec = []
subband_mix_spec = []
band_idx = 0
for i in range(len(self.band_width)):
subband_spec.append(spec_RI[:, :, band_idx:band_idx +
self.band_width[i]].contiguous())
subband_mix_spec.append(spec[:, band_idx:band_idx +
self.band_width[i]]) # B*nch, BW, T
band_idx += self.band_width[i]
# normalization and bottleneck
subband_feature = []
for i, bn_func in enumerate(self.BN):
subband_feature.append(
bn_func(subband_spec[i].view(batch_size * nch,
self.band_width[i] * self.spec_map,
-1)))
subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
# print(subband_feature.size(), spk_emb_input.size())
predict_speaker_label = torch.tensor(0.0).to(
spk_emb_input.device)  # dummy
if (
(self.spectral_feat and self.spectral_feat == "tfmap_emb")
and (self.spk_fuse_type and self.spk_fuse_type.startswith("cross_"))
):
spk_emb_input = spk_emb_frame
elif self.joint_training and self.spk_fuse_type:
if not self.spk_feat:
if self.feat_type == "consistent":
with torch.no_grad():
spk_emb_input = self.preEmphasis(spk_emb_input)
spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8
spk_emb_input = spk_emb_input.log()
spk_emb_input = spk_emb_input - torch.mean(
spk_emb_input, dim=-1, keepdim=True)
spk_emb_input = spk_emb_input.permute(0, 2, 1)
if self.spk_fuse_type and self.spk_fuse_type.startswith("cross_"):
tmp_spk_emb_input = self.spk_model._get_frame_level_feat(
spk_emb_input)
else:
tmp_spk_emb_input = self.spk_model(spk_emb_input)
if isinstance(tmp_spk_emb_input, tuple):
spk_emb_input = tmp_spk_emb_input[-1]
else:
spk_emb_input = tmp_spk_emb_input
predict_speaker_label = self.pred_linear(spk_emb_input)
spk_embedding = self.spk_transform(spk_emb_input)
if self.spk_fuse_type and not self.spk_fuse_type.startswith("cross_"):
spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)
sep_output = self.separator(subband_feature, spk_embedding,
torch.tensor(nch))
sep_subband_spec = []
for i, mask_func in enumerate(self.mask):
this_output = mask_func(sep_output[:, i]).view(
batch_size * nch, 2, 2, self.band_width[i], -1)
this_mask = this_output[:, 0] * torch.sigmoid(
this_output[:, 1]) # B*nch, 2, K, BW, T
this_mask_real = this_mask[:, 0] # B*nch, K, BW, T
this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T
est_spec_real = (subband_mix_spec[i].real * this_mask_real -
subband_mix_spec[i].imag * this_mask_imag
) # B*nch, BW, T
est_spec_imag = (subband_mix_spec[i].real * this_mask_imag +
subband_mix_spec[i].imag * this_mask_real
) # B*nch, BW, T
sep_subband_spec.append(torch.complex(est_spec_real,
est_spec_imag))
est_spec = torch.cat(sep_subband_spec, 1) # B*nch, F, T
output = torch.istft(
est_spec.view(batch_size * nch, self.enc_dim, -1),
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(wav_input.device).type(
wav_input.type()),
length=nsample,
)
output = output.view(batch_size, nch, -1)
s = torch.squeeze(output, dim=1)
return s, predict_speaker_label
if __name__ == "__main__":
from thop import profile, clever_format
model = BSRNN_Feats(
spk_emb_dim=256,
sr=16000,
win=512,
stride=128,
feature_dim=128,
num_repeat=6,
spectral_feat='tfmap_emb',
spk_fuse_type='cross_multiply',
spk_model="ECAPA_TDNN_GLOB_c512",
spk_args={
"embed_dim": 192,
"feat_dim": 80,
"pooling_func": "ASTP",
}
)
s = 0
for param in model.parameters():
s += np.prod(param.size())
print("# of parameters (M): " + str(s / 1024.0 / 1024.0))
x = torch.randn(4, 32000)
spk_embeddings = torch.randn(4, 16000)
output = model(x, spk_embeddings)
print(output[0].shape)
macs, params = profile(model, inputs=(x, spk_embeddings))
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)
================================================
FILE: wesep/models/bsrnn_multi_optim.py
================================================
from __future__ import print_function
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from wespeaker.models.speaker_model import get_speaker_model
from wesep.modules.common.speaker import PreEmphasis
from wesep.modules.common.speaker import SpeakerFuseLayer
from wesep.modules.common.speaker import SpeakerTransform
class ResRNN(nn.Module):
def __init__(self, input_size, hidden_size, bidirectional=True):
super(ResRNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.eps = torch.finfo(torch.float32).eps
self.norm = nn.GroupNorm(1, input_size, self.eps)
self.rnn = nn.LSTM(
input_size,
hidden_size,
1,
batch_first=True,
bidirectional=bidirectional,
)
# linear projection layer
self.proj = nn.Linear(
hidden_size * 2, input_size
) # hidden_size = feature_dim * 2
def forward(self, input):
# input shape: batch, dim, seq
rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())
rnn_output = self.proj(
rnn_output.contiguous().view(-1, rnn_output.shape[2])
).view(input.shape[0], input.shape[2], input.shape[1])
return input + rnn_output.transpose(1, 2).contiguous()
"""
TODO : attach the speaker embedding to each input
Input shape:(B,feature_dim + spk_emb_dim , T)
"""
class BSNet(nn.Module):
def __init__(self, in_channel, nband=7, bidirectional=True):
super(BSNet, self).__init__()
self.nband = nband
self.feature_dim = in_channel // nband
self.band_rnn = ResRNN(
self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional
)
self.band_comm = ResRNN(
self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional
)
def forward(self, input, dummy: Optional[torch.Tensor] = None):
# input shape: B, nband*N, T
B, N, T = input.shape
band_output = self.band_rnn(
input.view(B * self.nband, self.feature_dim, -1)
).view(B, self.nband, -1, T)
# band comm
band_output = (
band_output.permute(0, 3, 2, 1).contiguous().view(B * T, -1, self.nband)
)
output = (
self.band_comm(band_output)
.view(B, T, -1, self.nband)
.permute(0, 3, 2, 1)
.contiguous()
)
return output.view(B, N, T)
class FuseSeparation(nn.Module):
def __init__(
self,
nband=7,
num_repeat=6,
feature_dim=128,
spk_emb_dim=256,
spk_fuse_type="concat",
multi_fuse=True,
):
"""
:param nband : len(self.band_width)
"""
super(FuseSeparation, self).__init__()
self.multi_fuse = multi_fuse
self.nband = nband
self.feature_dim = feature_dim
self.separation = nn.ModuleList([])
if self.multi_fuse:
for _ in range(num_repeat):
self.separation.append(
SpeakerFuseLayer(
embed_dim=spk_emb_dim,
feat_dim=feature_dim,
fuse_type=spk_fuse_type,
)
)
self.separation.append(BSNet(nband * feature_dim, nband))
else:
self.separation.append(
SpeakerFuseLayer(
embed_dim=spk_emb_dim,
feat_dim=feature_dim,
fuse_type=spk_fuse_type,
)
)
for _ in range(num_repeat):
self.separation.append(BSNet(nband * feature_dim, nband))
def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
"""
x: [B, nband, feature_dim, T]
out: [B, nband, feature_dim, T]
"""
batch_size = x.shape[0]
if self.multi_fuse:
for i, sep_func in enumerate(self.separation):
x = sep_func(x, spk_embedding)
if i % 2 == 0:
x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
else:
x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)
else:
x = self.separation[0](x, spk_embedding)
x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
for idx, sep in enumerate(self.separation):
if idx > 0:
x = sep(x, spk_embedding)
x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)
return x
class BSRNN_Multi(nn.Module):
# self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6,
# use_bidirectional=True
def __init__(
self,
spk_emb_dim=256,
sr=16000,
win=512,
stride=128,
feature_dim=128,
num_repeat=6,
use_spk_transform=True,
use_bidirectional=True,
spk_fuse_type="concat",
multi_fuse=True,
joint_training=True,
multi_task=False,
spksInTrain=251,
spk_model=None,
spk_model_init=None,
spk_model_freeze=False,
spk_args=None,
spk_feat=False,
feat_type="consistent",
):
super(BSRNN_Multi, self).__init__()
self.sr = sr
self.win = win
self.stride = stride
self.group = self.win // 2
self.enc_dim = self.win // 2 + 1
self.feature_dim = feature_dim
self.eps = torch.finfo(torch.float32).eps
self.spk_emb_dim = spk_emb_dim
self.joint_training = joint_training
self.spk_feat = spk_feat
self.feat_type = feat_type
self.spk_model_freeze = spk_model_freeze
self.multi_task = multi_task
# band split for 16 kHz signals: 15 x 100 Hz (0-1.5k), 10 x 200 Hz
# (1.5k-3.5k), 5 x 500 Hz (3.5k-6k), 1 x 2 kHz (6k-8k),
# plus one remainder band covering the leftover bins
bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim))
bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim))
bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim))
bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim))
# add up to 8k
self.band_width = [bandwidth_100] * 15
self.band_width += [bandwidth_200] * 10
self.band_width += [bandwidth_500] * 5
self.band_width += [bandwidth_2k] * 1
self.band_width.append(self.enc_dim - int(np.sum(self.band_width)))
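# with the defaults (sr=16000, win=512) this yields 32 bands in total;
# see the worked example in bsrnn_feats.py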
self.nband = len(self.band_width)
if use_spk_transform:
self.spk_transform = SpeakerTransform()
else:
self.spk_transform = nn.Identity()
if joint_training:
self.spk_model = get_speaker_model(spk_model)(**spk_args)
if spk_model_init:
pretrained_model = torch.load(spk_model_init)
state = self.spk_model.state_dict()
for key in state.keys():
if key in pretrained_model.keys():
state[key] = pretrained_model[key]
# print(key)
else:
print("not %s loaded" % key)
self.spk_model.load_state_dict(state)
if spk_model_freeze:
for param in self.spk_model.parameters():
param.requires_grad = False
if not spk_feat:
if feat_type == "consistent":
self.preEmphasis = PreEmphasis()
self.spk_encoder = torchaudio.transforms.MelSpectrogram(
sample_rate=sr,
n_fft=win,
win_length=win,
hop_length=stride,
f_min=20,
window_fn=torch.hamming_window,
n_mels=spk_args["feat_dim"],
)
else:
self.preEmphasis = nn.Identity()
self.spk_encoder = nn.Identity()
if multi_task:
self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)
else:
self.pred_linear = nn.Identity()
self.BN = nn.ModuleList([])
for i in range(self.nband):
self.BN.append(
nn.Sequential(
nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),
)
)
self.separator = FuseSeparation(
nband=self.nband,
num_repeat=num_repeat,
feature_dim=feature_dim,
spk_emb_dim=spk_emb_dim,
spk_fuse_type=spk_fuse_type,
multi_fuse=multi_fuse,
)
# self.proj = nn.Linear(hidden_size*2, input_size)
self.mask = nn.ModuleList([])
for i in range(self.nband):
self.mask.append(
nn.Sequential(
nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),
nn.Tanh(),
nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),
nn.Tanh(),
nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),
)
)
def pad_input(self, input, window, stride):
"""
Zero-padding input according to window/stride size.
"""
batch_size, nsample = input.shape
# pad the signals at the end for matching the window/stride size
rest = window - (stride + nsample % window) % window
if rest > 0:
pad = torch.zeros(batch_size, rest).type(input.type())
input = torch.cat([input, pad], 1)
pad_aux = torch.zeros(batch_size, stride).type(input.type())
input = torch.cat([pad_aux, input, pad_aux], 1)
return input, rest
def forward(self, input, embeddings):
# input shape: (B, C, T)
wav_input = input
spk_emb_input = embeddings
batch_size, nsample = wav_input.shape
nch = 1
# frequency-domain separation
spec = torch.stft(
wav_input,
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win)
.to(wav_input.device)
.type(wav_input.type()),
return_complex=True,
)
# concat real and imag, split to subbands
spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T
subband_spec = []
subband_mix_spec = []
band_idx = 0
for i in range(len(self.band_width)):
subband_spec.append(
spec_RI[:, :, band_idx : band_idx + self.band_width[i]].contiguous()
)
subband_mix_spec.append(
spec[:, band_idx : band_idx + self.band_width[i]]
) # B*nch, BW, T
band_idx += self.band_width[i]
# normalization and bottleneck
subband_feature = []
for i, bn_func in enumerate(self.BN):
subband_feature.append(
bn_func(
subband_spec[i].view(batch_size * nch, self.band_width[i] * 2, -1)
)
)
subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
# print(subband_feature.size(), spk_emb_input.size())
predict_speaker_label = torch.tensor(0.0).to(spk_emb_input.device)  # dummy
if self.joint_training:
if not self.spk_feat:
if self.feat_type == "consistent":
with torch.no_grad():
spk_emb_input = self.preEmphasis(spk_emb_input)
spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8
spk_emb_input = spk_emb_input.log()
spk_emb_input = spk_emb_input - torch.mean(
spk_emb_input, dim=-1, keepdim=True
)
spk_emb_input = spk_emb_input.permute(0, 2, 1)
tmp_spk_emb_input = self.spk_model(spk_emb_input)
if isinstance(tmp_spk_emb_input, tuple):
spk_emb_input = tmp_spk_emb_input[-1]
else:
spk_emb_input = tmp_spk_emb_input
predict_speaker_label = self.pred_linear(spk_emb_input)
spk_embedding = self.spk_transform(spk_emb_input)
spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)
sep_output = self.separator(subband_feature, spk_embedding, torch.tensor(nch))
sep_subband_spec = []
for i, mask_func in enumerate(self.mask):
this_output = mask_func(sep_output[:, i]).view(
batch_size * nch, 2, 2, self.band_width[i], -1
)
this_mask = this_output[:, 0] * torch.sigmoid(
this_output[:, 1]
) # B*nch, 2, K, BW, T
this_mask_real = this_mask[:, 0] # B*nch, K, BW, T
this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T
est_spec_real = (
subband_mix_spec[i].real * this_mask_real
- subband_mix_spec[i].imag * this_mask_imag
) # B*nch, BW, T
est_spec_imag = (
subband_mix_spec[i].real * this_mask_imag
+ subband_mix_spec[i].imag * this_mask_real
) # B*nch, BW, T
sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
est_spec = torch.cat(sep_subband_spec, 1) # B*nch, F, T
output = torch.istft(
est_spec.view(batch_size * nch, self.enc_dim, -1),
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win)
.to(wav_input.device)
.type(wav_input.type()),
length=nsample,
)
output = output.view(batch_size, nch, -1)
s = torch.squeeze(output, dim=1)
if torch.is_grad_enabled():
self_embedding = s.detach()
self_predict_speaker_label = torch.tensor(0.0).to(
self_embedding.device
)  # dummy
if self.joint_training:
if self.feat_type == "consistent":
with torch.no_grad():
self_embedding = self.preEmphasis(self_embedding)
self_embedding = self.spk_encoder(self_embedding) + 1e-8
self_embedding = self_embedding.log()
self_embedding = self_embedding - torch.mean(
self_embedding, dim=-1, keepdim=True
)
self_embedding = self_embedding.permute(0, 2, 1)
self_tmp_spk_emb_input = self.spk_model(self_embedding)
if isinstance(self_tmp_spk_emb_input, tuple):
self_spk_emb_input = self_tmp_spk_emb_input[-1]
else:
self_spk_emb_input = self_tmp_spk_emb_input
self_predict_speaker_label = self.pred_linear(self_spk_emb_input)
self_spk_embedding = self.spk_transform(self_spk_emb_input)
self_spk_embedding = self_spk_embedding.unsqueeze(1).unsqueeze(3)
self_sep_output = self.separator(
subband_feature, self_spk_embedding, torch.tensor(nch)
)
self_sep_subband_spec = []
for i, mask_func in enumerate(self.mask):
this_output = mask_func(self_sep_output[:, i]).view(
batch_size * nch, 2, 2, self.band_width[i], -1
)
this_mask = this_output[:, 0] * torch.sigmoid(
this_output[:, 1]
) # B*nch, 2, K, BW, T
this_mask_real = this_mask[:, 0] # B*nch, K, BW, T
this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T
est_spec_real = (
subband_mix_spec[i].real * this_mask_real
- subband_mix_spec[i].imag * this_mask_imag
) # B*nch, BW, T
est_spec_imag = (
subband_mix_spec[i].real * this_mask_imag
+ subband_mix_spec[i].imag * this_mask_real
) # B*nch, BW, T
self_sep_subband_spec.append(
torch.complex(est_spec_real, est_spec_imag)
)
self_est_spec = torch.cat(self_sep_subband_spec, 1) # B*nch, F, T
self_output = torch.istft(
self_est_spec.view(batch_size * nch, self.enc_dim, -1),
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win)
.to(wav_input.device)
.type(wav_input.type()),
length=nsample,
)
self_output = self_output.view(batch_size, nch, -1)
self_s = torch.squeeze(self_output, dim=1)
return s, self_s, predict_speaker_label, self_predict_speaker_label
return s, predict_speaker_label
if __name__ == "__main__":
from thop import profile, clever_format
model = BSRNN_Multi(
spk_emb_dim=256,
sr=16000,
win=512,
stride=128,
feature_dim=128,
num_repeat=6,
spk_fuse_type="additive",
# joint_training defaults to True, which needs a speaker model;
# this mirrors the BSRNN_Feats demo above
spk_model="ECAPA_TDNN_GLOB_c512",
spk_args={
"embed_dim": 256,
"feat_dim": 80,
"pooling_func": "ASTP",
},
)
s = 0
for param in model.parameters():
s += np.prod(param.size())
print("# of parameters (M): " + str(s / 1024.0 / 1024.0))
x = torch.randn(4, 32000)
# the enrollment input is a waveform since the speaker model is trained jointly
spk_embeddings = torch.randn(4, 16000)
output = model(x, spk_embeddings)
print(output[0].shape)  # a 4-tuple when grad is enabled; output[0] is the estimate
macs, params = profile(model, inputs=(x, spk_embeddings))
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)
================================================
FILE: wesep/models/convtasnet.py
================================================
import torch
import torch.nn as nn
from wesep.modules.common import select_norm
from wesep.modules.common.speaker import SpeakerTransform
from wesep.modules.tasnet import DeepEncoder, DeepDecoder
from wesep.modules.tasnet import MultiEncoder, MultiDecoder
from wesep.modules.tasnet import FuseSeparation
from wesep.modules.tasnet.convs import Conv1D, ConvTrans1D
from wesep.modules.tasnet.speaker import ResNet4SpExplus
from wespeaker.models.speaker_model import get_speaker_model
class ConvTasNet(nn.Module):
def __init__(
self,
N=512,
L=16,
B=128,
H=512,
P=3,
X=8,
R=3,
spk_emb_dim=256,
norm="gLN",
activate="relu",
causal=False,
skip_con=False,
spk_fuse_type="concatConv",
# "concat", "additive", "multiply", "FiLM", "None",
# ("concatConv" only for convtasnet)
multi_fuse=True,
use_spk_transform=True,
encoder_type="Multi", # 'Multi', 'Deep', None
decoder_type="Multi",
joint_training=True,
multi_task=False,
spksInTrain=251,
spk_model=None,
spk_model_init=None,
spk_model_freeze=False,
spk_args=None,
spk_feat=False,
feat_type="consistent",
):
"""
:param N: Number of filters in autoencoder
:param L: Length of the filters (in samples)
:param B: Number of channels in bottleneck and the residual paths
:param H: Number of channels in convolutional blocks
:param P: Kernel size in convolutional blocks
:param X: Number of convolutional blocks in each repeat
:param R: Number of repeats
:param norm: normalization type, one of gLN/cLN/BN
:param activate: output activation, one of relu/sigmoid/softmax
:param causal: whether the separator is causal
:param skip_con: whether to use skip connections in the separator
:param spk_fuse_type: concat/additive/multiply/FiLM/concatConv
:param use_spk_transform: whether to transform the speaker embeddings
:param encoder_type: 'Multi', 'Deep' or None
:param decoder_type: 'Multi', 'Deep' or None
"""
super(ConvTasNet, self).__init__()
self.encoder_type = encoder_type
self.decoder_type = decoder_type
# n x 1 x T => n x N x T
if encoder_type == "Multi":
self.encoder = MultiEncoder(
in_channels=1,
middle_channels=N,
out_channels=B,
kernel_size=L,
stride=L // 2,
)
elif encoder_type == "Deep":
self.encoder = DeepEncoder(1, N, L, stride=L // 2)
self.LayerN_S = select_norm(norm, N)
self.BottleN_S = Conv1D(N, B, 1)
else:
self.encoder = nn.Sequential(
Conv1D(1, N, L, stride=L // 2, padding=0), nn.ReLU())
self.LayerN_S = select_norm(norm, N)
self.BottleN_S = Conv1D(N, B, 1)
self.joint_training = joint_training
self.spk_feat = spk_feat
self.feat_type = feat_type
self.spk_model_freeze = spk_model_freeze
self.multi_task = multi_task
if joint_training:
if not self.spk_feat:
if self.feat_type == "consistent":
self.spk_model = ResNet4SpExplus(
in_channel=N, C_embedding=spk_emb_dim
) # The speaker model is fixed for SpEx+ currently
else:
self.spk_model = get_speaker_model(spk_model)(**spk_args)
if spk_model_init:
pretrained_model = torch.load(spk_model_init)
state = self.spk_model.state_dict()
for key in state.keys():
if key in pretrained_model.keys():
state[key] = pretrained_model[key]
# print(key)
else:
print("not %s loaded" % key)
self.spk_model.load_state_dict(state)
if self.spk_model_freeze:
for param in self.spk_model.parameters():
param.requires_grad = False
if multi_task:
self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)
if not use_spk_transform:
self.spk_transform = nn.Identity()
else:
self.spk_transform = SpeakerTransform()
# Separation block
# n x B x T => n x B x T
self.separation = FuseSeparation(
R,
X,
B,
H,
P,
norm=norm,
causal=causal,
skip_con=skip_con,
C_embedding=spk_emb_dim,
spk_fuse_type=spk_fuse_type,
multi_fuse=multi_fuse,
)
# n x N x T => n x 1 x L
if decoder_type == "Multi":
self.decoder = MultiDecoder(
in_channels=B,
middle_channels=N,
out_channels=1,
kernel_size=L,
stride=L // 2,
)
elif decoder_type == "Deep":
self.decoder = DeepDecoder(N, L, stride=L // 2)
self.gen_masks = Conv1D(B, N, 1)
else:
self.decoder = ConvTrans1D(N, 1, L, stride=L // 2)
self.gen_masks = Conv1D(B, N, 1)
# activation function
active_f = {
"relu": nn.ReLU(),
"sigmoid": nn.Sigmoid(),
"softmax": nn.Softmax(dim=0),
}
# self.activation_type = activate
self.activation = active_f[activate]
def forward(self, x, embeddings):
if x.dim() >= 3:
raise RuntimeError(
"{} accepts 1/2D tensor as input, but got {:d}".format(
self.__class__.__name__, x.dim()))
if x.dim() == 1:
x = torch.unsqueeze(x, 0)
# x: n x 1 x L => n x N x T
if self.encoder_type == "Multi":
e, w1, w2, w3 = self.encoder(x)
x = e # replace x with e, for asymmetric encoder-decoder
else:
x = self.encoder(x)
e = self.LayerN_S(x)
e = self.BottleN_S(
e)  # embedding fusion after the dimension changed from N to B
if self.joint_training:
# Only support sharing Encoder and ResNet in SpEx+ currently
# Speaker Encoder
if not self.spk_feat and self.feat_type == "consistent":
if self.encoder_type == "Multi":
_, aux_w1, aux_w2, aux_w3 = self.encoder(embeddings)
embeddings = torch.cat([aux_w1, aux_w2, aux_w3], 1)
else:
aux_x = self.encoder(embeddings)
aux_e = self.LayerN_S(aux_x)
embeddings = self.BottleN_S(aux_e)
embeddings = self.spk_model(embeddings)
if isinstance(embeddings, tuple):
embeddings = embeddings[-1]
if self.multi_task:
predict_speaker_label = self.pred_linear(embeddings)
spk_embeds = self.spk_transform(embeddings.unsqueeze(-1))
e = self.separation(e, spk_embeds)
# decoder part n x L
if self.decoder_type == "Multi":
s = self.decoder(
e, w1, w2, w3,
actLayer=self.activation) # s is a tuple by using multiDecoder
else:
# n x B x L => n x N x L
m = self.gen_masks(e)
# n x N x L
m = self.activation(m)
x = x * m
s = self.decoder(x)
if self.joint_training and self.multi_task:
if not isinstance(s, list):
s = [
s,
]
s.append(predict_speaker_label)
return s # s: N x Len Or List(N x Len,x3/x4)
def check_parameters(net):
"""
Return the number of module parameters, in millions.
"""
parameters = sum(param.numel() for param in net.parameters())
return parameters / 10**6
def test_convtasnet():
x = torch.randn(4, 32000)
spk_embeddings = torch.randn(4, 256)
net = ConvTasNet(use_spk_transform=False, spk_fuse_type="FiLM")
s = net(x, spk_embeddings)
print(str(check_parameters(net)) + " Mb")
print(s[1].shape)
if __name__ == "__main__":
test_convtasnet()
================================================
FILE: wesep/models/dpccn.py
================================================
import torch
import torch.nn as nn
import torchaudio
from wespeaker.models.speaker_model import get_speaker_model
from wesep.modules.common.speaker import PreEmphasis
from wesep.modules.common.speaker import SpeakerFuseLayer
from wesep.modules.common.speaker import SpeakerTransform
from wesep.modules.dpccn.convs import Conv2dBlock
from wesep.modules.dpccn.convs import ConvTrans2dBlock
from wesep.modules.dpccn.convs import DenseBlock
from wesep.modules.dpccn.convs import TCNBlock
class DPCCN(nn.Module):
def __init__(
self,
win=512,
stride=128,
spk_emb_dim=256,
sr=16000,
use_spk_transform=False,
spk_fuse_type="multiply",
feature_dim=257,
kernel_size=(3, 3),
stride1=(1, 1),
stride2=(1, 2),
paddings=(1, 1),
output_padding=(0, 0),
tcn_dims=384,
tcn_blocks=10,
tcn_layers=2,
causal=False,
pool_size=(4, 8, 16, 32),
multi_fuse=False,
joint_training=True,
multi_task=False,
spksInTrain=251,
spk_model=None,
spk_model_init=None,
spk_model_freeze=False,
spk_args=None,
spk_feat=False,
feat_type="consistent",
) -> None:
super(DPCCN, self).__init__()
self.win_len = win
self.hop_size = stride
self.spk_emb_dim = spk_emb_dim
self.joint_training = joint_training
self.spk_feat = spk_feat
self.feat_type = feat_type
self.spk_model_freeze = spk_model_freeze
self.multi_task = multi_task
self.conv2d = nn.Conv2d(2, 16, kernel_size, stride1, paddings)
self.encoder = self._build_encoder(kernel_size=kernel_size,
stride=stride2,
padding=paddings)
if use_spk_transform:
self.spk_transform = SpeakerTransform()
else:
self.spk_transform = nn.Identity()
if joint_training:
self.spk_model = get_speaker_model(spk_model)(**spk_args)
if spk_model_init:
pretrained_model = torch.load(spk_model_init)
state = self.spk_model.state_dict()
for key in state.keys():
if key in pretrained_model.keys():
state[key] = pretrained_model[key]
# print(key)
else:
print("not %s loaded" % key)
self.spk_model.load_state_dict(state)
if spk_model_freeze:
for param in self.spk_model.parameters():
param.requires_grad = False
if not spk_feat:
if feat_type == "consistent":
self.preEmphasis = PreEmphasis()
self.spk_encoder = torchaudio.transforms.MelSpectrogram(
sample_rate=sr,
n_fft=win,
win_length=win,
hop_length=stride,
f_min=20,
window_fn=torch.hamming_window,
n_mels=spk_args["feat_dim"],
)
else:
self.preEmphasis = nn.Identity()
self.spk_encoder = nn.Identity()
if multi_task:
self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)
else:
self.pred_linear = nn.Identity()
self.spk_fuse = SpeakerFuseLayer(
embed_dim=self.spk_emb_dim,
feat_dim=feature_dim,
fuse_type=spk_fuse_type,
)
self.tcn_layers = self._build_tcn_layers(
tcn_layers,
tcn_blocks,
in_dims=tcn_dims,
out_dims=tcn_dims,
causal=causal,
)
self.decoder = self._build_decoder(
kernel_size=kernel_size,
stride=stride2,
padding=paddings,
output_padding=output_padding,
)
self.avg_pool = self._build_avg_pool(pool_size)
self.avg_proj = nn.Conv2d(64, 32, 1, 1)
self.deconv2d = nn.ConvTranspose2d(32, 2, kernel_size, stride1,
paddings)
def _build_encoder(self, **enc_kargs):
"""
Build encoder layers
"""
encoder = nn.ModuleList()
encoder.append(DenseBlock(16, 16, "enc"))
for i in range(4):
encoder.append(
nn.Sequential(
Conv2dBlock(in_dims=16 if i == 0 else 32,
out_dims=32,
**enc_kargs),
DenseBlock(32, 32, "enc"),
))
encoder.append(Conv2dBlock(in_dims=32, out_dims=64, **enc_kargs))
encoder.append(Conv2dBlock(in_dims=64, out_dims=128, **enc_kargs))
encoder.append(Conv2dBlock(in_dims=128, out_dims=384, **enc_kargs))
return encoder
def _build_decoder(self, **dec_kargs):
"""
Build decoder layers
"""
decoder = nn.ModuleList()
decoder.append(
ConvTrans2dBlock(in_dims=384 * 2, out_dims=128, **dec_kargs))
decoder.append(
ConvTrans2dBlock(in_dims=128 * 2, out_dims=64, **dec_kargs))
decoder.append(
ConvTrans2dBlock(in_dims=64 * 2, out_dims=32, **dec_kargs))
for i in range(4):
decoder.append(
nn.Sequential(
DenseBlock(32, 64, "dec"),
ConvTrans2dBlock(in_dims=64,
out_dims=32 if i != 3 else 16,
**dec_kargs),
))
decoder.append(DenseBlock(16, 32, "dec"))
return decoder
def _build_tcn_blocks(self, tcn_blocks, **tcn_kargs):
"""
Build TCN blocks in each repeat (layer)
"""
blocks = [
TCNBlock(**tcn_kargs, dilation=(2**b)) for b in range(tcn_blocks)
]
return nn.Sequential(*blocks)
def _build_tcn_layers(self, tcn_layers, tcn_blocks, **tcn_kargs):
"""
Build TCN layers
"""
layers = [
self._build_tcn_blocks(tcn_blocks, **tcn_kargs)
for _ in range(tcn_layers)
]
return nn.Sequential(*layers)
def _build_avg_pool(self, pool_size):
"""
Build avg pooling layers
"""
avg_pool = nn.ModuleList()
for sz in pool_size:
avg_pool.append(
nn.Sequential(nn.AvgPool2d(sz), nn.Conv2d(32, 8, 1, 1)))
return avg_pool
def forward(self, input, aux):
wav_input = input
spk_emb_input = aux
batch_size, nsample = wav_input.shape
# frequency-domain separation
spec = torch.stft(
wav_input,
n_fft=self.win_len,
hop_length=self.hop_size,
window=torch.hann_window(self.win_len).to(wav_input.device).type(
wav_input.type()),
return_complex=True,
)
# concat real and imag, split to subbands
spec_RI = torch.stack([spec.real, spec.imag], 1)
# spec = torch.einsum("hijk->hikj", spec_RI) # batchsize, 2, T, F
spec = torch.transpose(spec_RI, 2, 3) # batchsize, 2, T, F
out = self.conv2d(spec)
out_list = []
out = self.encoder[0](out)
predict_speaker_label = torch.tensor(0.0).to(
spk_emb_input.device)  # dummy
if self.joint_training:
if not self.spk_feat:
if self.feat_type == "consistent":
with torch.no_grad():
spk_emb_input = self.preEmphasis(spk_emb_input)
spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8
spk_emb_input = spk_emb_input.log()
spk_emb_input = spk_emb_input - torch.mean(
spk_emb_input, dim=-1, keepdim=True)
spk_emb_input = spk_emb_input.permute(0, 2, 1)
tmp_spk_emb_input = self.spk_model(spk_emb_input)
if isinstance(tmp_spk_emb_input, tuple):
spk_emb_input = tmp_spk_emb_input[-1]
else:
spk_emb_input = tmp_spk_emb_input
predict_speaker_label = self.pred_linear(spk_emb_input)
spk_embedding = self.spk_transform(spk_emb_input)
spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)
out = self.spk_fuse(out.transpose(2, 3), spk_embedding).transpose(2, 3)
out_list.append(out)
for _, enc in enumerate(self.encoder[1:]):
out = enc(out)
out_list.append(out)
B, N, T, F = out.shape
out = out.reshape(B, N, T * F)
out = self.tcn_layers(out)
out = out.reshape(B, N, T, F)
out_list = out_list[::-1]
for idx, dec in enumerate(self.decoder):
out = dec(torch.cat([out_list[idx], out], 1))
# Pyramidal pooling
B, N, T, F = out.shape
upsample = nn.Upsample(size=(T, F), mode="bilinear")
pool_list = []
for avg in self.avg_pool:
pool_list.append(upsample(avg(out)))
out = torch.cat([out, *pool_list], 1)
out = self.avg_proj(out)
out = self.deconv2d(out)
est_spec = torch.transpose(out, 2, 3) # (batchsize, 2, F, T)
B, N, F, T = est_spec.shape
est_spec = torch.chunk(est_spec, 2, 1) # [(B, 1, F, T), (B, 1, F, T)])
est_spec = torch.complex(est_spec[0], est_spec[1])
output = torch.istft(
est_spec.reshape(B, -1, T),
n_fft=self.win_len,
hop_length=self.hop_size,
window=torch.hann_window(self.win_len).to(wav_input.device).type(
wav_input.type()),
length=nsample,
)
return output, predict_speaker_label
if __name__ == "__main__":
import numpy as np
model = DPCCN(joint_training=False)  # no spk_model is given, so skip joint training
s = 0
for param in model.parameters():
s += np.prod(param.size())
print("# of parameters (M): " + str(s / 1024.0 / 1024.0))
mix = torch.randn(4, 32000)
aux = torch.randn(4, 256)
est = model(mix, aux)
print(est[0].size())  # est is (separated, speaker_label)
================================================
FILE: wesep/models/sep_model.py
================================================
import wesep.models.bsrnn as bsrnn
import wesep.models.convtasnet as convtasnet
import wesep.models.dpccn as dpccn
import wesep.models.tfgridnet as tfgridnet
def get_model(model_name: str):
if model_name.startswith("ConvTasNet"):
return getattr(convtasnet, model_name)
elif model_name.startswith("BSRNN"):
return getattr(bsrnn, model_name)
elif model_name.startswith("DPCNN"):
return getattr(dpccn, model_name)
elif model_name.startswith("TFGridNet"):
return getattr(tfgridnet, model_name)
else: # model_name error !!!
print(model_name + " not found !!!")
exit(1)
if __name__ == "__main__":
print(get_model("ConvTasNet"))
================================================
FILE: wesep/models/tfgridnet.py
================================================
# The implementation is based on:
# https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torchaudio
from packaging.version import parse as V
from wespeaker.models.speaker_model import get_speaker_model
from wesep.modules.common.speaker import PreEmphasis
from wesep.modules.common.speaker import SpeakerFuseLayer, SpeakerTransform
from wesep.modules.tfgridnet.gridnet_block import GridNetBlock
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
class TFGridNet(nn.Module):
"""Offline TFGridNetV2. Compared with TFGridNet, TFGridNetV2 speeds up
the code by vectorizing multiple heads in self-attention,
and better dealing with Deconv1D in each intra- and inter-block
when emb_ks == emb_hs.
Reference:
[1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,
"TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation",
in TASLP, 2023.
[2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,
"TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural
Speaker Separation", in ICASSP, 2023.
NOTES:
As outlined in the Reference, this model works best when trained with
variance normalized mixture input and target, e.g., with mixture of
shape [batch, samples, microphones], you normalize it by dividing
with torch.std(mixture, (1, 2)). You must do the same for the target
signals. This normalization is especially encouraged when not using
scale-invariant loss functions such as SI-SDR.
Specifically, use:
std_ = std(mix)
mix = mix / std_
tgt = tgt / std_
Args:
n_srcs: number of output sources/speakers.
n_fft: stft window size.
stride: stft stride.
window: stft window type, choose between 'hamming' and 'hann'.
n_imics: num of channels (only fixed-array geometry supported).
n_layers: number of TFGridNetV2 blocks.
lstm_hidden_units: number of hidden units in LSTM.
attn_n_head: number of heads in self-attention
attn_approx_qk_dim: approximate dim of frame-level key/value tensors
emb_dim: embedding dimension
emb_ks: kernel size for unfolding and deconv1D
emb_hs: hop size for unfolding and deconv1D
activation: activation function to use in the whole TFGridNetV2 model,
you can use any torch supported activation e.g. 'relu' or 'elu'.
eps: small epsilon for normalization layers.
spk_emb_dim: the dimension of target speaker embeddings.
use_spk_transform: whether use networks to transfer the speaker embeds.
spk_fuse_type: the fusion method of speaker embeddings.
"""
def __init__(
self,
n_srcs=1,
sr=16000,
n_fft=128,
stride=64,
window="hann",
n_imics=1,
n_layers=6,
lstm_hidden_units=192,
attn_n_head=4,
attn_approx_qk_dim=512,
emb_dim=48,
emb_ks=4,
emb_hs=1,
activation="prelu",
eps=1.0e-5,
spk_emb_dim=256,
use_spk_transform=False,
spk_fuse_type="multiply",
joint_training=True,
multi_task=False,
spksInTrain=251,
spk_model=None,
spk_model_init=None,
spk_model_freeze=False,
spk_args=None,
spk_feat=False,
feat_type="consistent",
):
super().__init__()
self.n_srcs = n_srcs
self.n_fft = n_fft
self.stride = stride
self.window = window
self.n_imics = n_imics
self.n_layers = n_layers
self.spk_emb_dim = spk_emb_dim
self.joint_training = joint_training
self.spk_feat = spk_feat
self.feat_type = feat_type
self.spk_model_freeze = spk_model_freeze
self.multi_task = multi_task
assert n_fft % 2 == 0
n_freqs = n_fft // 2 + 1
if use_spk_transform:
self.spk_transform = SpeakerTransform()
else:
self.spk_transform = nn.Identity()
if joint_training:
self.spk_model = get_speaker_model(spk_model)(**spk_args)
if spk_model_init:
pretrained_model = torch.load(spk_model_init)
state = self.spk_model.state_dict()
for key in state.keys():
if key in pretrained_model.keys():
state[key] = pretrained_model[key]
# print(key)
else:
print("not %s loaded" % key)
self.spk_model.load_state_dict(state)
if spk_model_freeze:
for param in self.spk_model.parameters():
param.requires_grad = False
if not spk_feat:
if feat_type == "consistent":
self.preEmphasis = PreEmphasis()
self.spk_encoder = torchaudio.transforms.MelSpectrogram(
sample_rate=sr,
n_fft=n_fft,
win_length=n_fft,
hop_length=stride,
f_min=20,
window_fn=torch.hamming_window,
n_mels=spk_args["feat_dim"],
)
else:
self.preEmphasis = nn.Identity()
self.spk_encoder = nn.Identity()
if multi_task:
self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)
else:
self.pred_linear = nn.Identity()
self.spk_fuse = SpeakerFuseLayer(
embed_dim=spk_emb_dim,
feat_dim=n_freqs,
fuse_type=spk_fuse_type,
)
t_ksize = 3
ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
self.conv = nn.Sequential(
nn.Conv2d(2 * n_imics, emb_dim, ks, padding=padding),
nn.GroupNorm(1, emb_dim, eps=eps),
)
self.blocks = nn.ModuleList([])
for _ in range(n_layers):
self.blocks.append(
GridNetBlock(
emb_dim,
emb_ks,
emb_hs,
n_freqs,
lstm_hidden_units,
n_head=attn_n_head,
approx_qk_dim=attn_approx_qk_dim,
activation=activation,
eps=eps,
))
self.deconv = nn.ConvTranspose2d(emb_dim,
n_srcs * 2,
ks,
padding=padding)
def forward(
self,
input: torch.Tensor,
embeddings: torch.Tensor,
) -> torch.Tensor:
"""Forward.
Args:
input (torch.Tensor): batched multi-channel audio tensor with
M audio channels and N samples [B, N, M]
embeddings (torch.Tensor): batched target speaker embeddings [B, D]
Returns:
enhanced (torch.Tensor): [B, T] mono separated audio tensor
predict_speaker_label (torch.Tensor): speaker logits when
multi_task is enabled, otherwise a dummy scalar.
"""
batch_size, n_samples = input.shape[0], input.shape[1]
spk_emb_input = embeddings
if self.n_imics == 1:
assert len(input.shape) == 2
input = input[..., None] # [B, N, M]
mix_std_ = torch.std(input, dim=(1, 2), keepdim=True) # [B, 1, 1]
input = input / mix_std_ # RMS normalization
input = input.transpose(1, 2).reshape(
-1, input.size(1)) # [B, N, M] -> [B*M, N]
window_func = getattr(torch, f"{self.window}_window")
window = window_func(self.n_fft,
dtype=input.dtype,
device=input.device)
batch = torch.stft(
input,
n_fft=self.n_fft,
win_length=self.n_fft,
hop_length=self.stride,
window=window,
return_complex=True,
onesided=True,
) # [B, F, T]
batch = batch.transpose(1, 2) # [B, T, F]
batch0 = batch.view(batch_size, -1, batch.size(1),
batch.size(2)) # [B, M, T, F]
# ilens = torch.full((batch_size,), n_samples, dtype=torch.long)
batch = torch.cat((batch0.real, batch0.imag), dim=1) # [B, 2*M, T, F]
n_batch, _, n_frames, n_freqs = batch.shape
batch = self.conv(batch) # [B, -1, T, F]
predict_speaker_label = torch.tensor(0.0).to(
spk_emb_input.device) # dummy
if self.joint_training:
if not self.spk_feat:
if self.feat_type == "consistent":
with torch.no_grad():
spk_emb_input = self.preEmphasis(spk_emb_input)
spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8
spk_emb_input = spk_emb_input.log()
spk_emb_input = spk_emb_input - torch.mean(
spk_emb_input, dim=-1, keepdim=True)
spk_emb_input = spk_emb_input.permute(0, 2, 1)
tmp_spk_emb_input = self.spk_model(spk_emb_input)
if isinstance(tmp_spk_emb_input, tuple):
spk_emb_input = tmp_spk_emb_input[-1]
else:
spk_emb_input = tmp_spk_emb_input
predict_speaker_label = self.pred_linear(spk_emb_input)
spk_embedding = self.spk_transform(spk_emb_input) # [B, D]
spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3) # [B, 1, D, 1]
for ii in range(self.n_layers):
batch = torch.transpose(
self.spk_fuse(batch.transpose(2, 3), spk_embedding), 2,
3) # [B, -1, T, F]
batch = self.blocks[ii](batch) # [B, -1, T, F]
batch = self.deconv(batch) # [B, n_srcs*2, T, F]
batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs])
assert is_torch_1_9_plus, "Require torch 1.9.0+."
batch = torch.complex(batch[:, :, 0], batch[:, :, 1])
batch = torch.istft(
torch.transpose(batch.view(-1, n_frames, n_freqs), 1, 2),
n_fft=self.n_fft,
hop_length=self.stride,
win_length=self.n_fft,
window=window,
onesided=True,
length=n_samples,
return_complex=False,
)  # [B*n_srcs, N]
batch = self.pad2(batch.view([n_batch, self.num_spk, -1]), n_samples)
batch = batch * mix_std_ # reverse the RMS normalization
# batch = [batch[:, src] for src in range(self.num_spk)]
batch = batch.squeeze(1)
return batch, predict_speaker_label
@property
def num_spk(self):
return self.n_srcs
@staticmethod
def pad2(input_tensor, target_len):
input_tensor = torch.nn.functional.pad(
input_tensor, (0, target_len - input_tensor.shape[-1]))
return input_tensor
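if __name__ == "__main__":
# A minimal smoke test (not in the original file): joint training is
# disabled so a precomputed 256-dim speaker embedding can be fed in,
# mirroring the __main__ demos of the other models in this package.
model = TFGridNet(joint_training=False)
mix = torch.randn(4, 32000)
spk_embeddings = torch.randn(4, 256)
est, _ = model(mix, spk_embeddings)
print(est.shape)  # expected: torch.Size([4, 32000])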
================================================
FILE: wesep/modules/__init__.py
================================================
================================================
FILE: wesep/modules/common/__init__.py
================================================
from wesep.modules.common.norm import ChannelWiseLayerNorm # noqa
from wesep.modules.common.norm import FiLM # noqa
from wesep.modules.common.norm import GlobalChannelLayerNorm # noqa
from wesep.modules.common.norm import select_norm # noqa
================================================
FILE: wesep/modules/common/norm.py
================================================
import numbers
import torch
import torch.nn as nn
class GlobalChannelLayerNorm(nn.Module):
"""
Calculate Global Layer Normalization
dim: (int or list or torch.Size) –
input shape from an expected input of size
eps: a value added to the denominator for numerical stability.
elementwise_affine: a boolean value that when set to True,
this module has learnable per-element affine parameters
initialized to ones (for weights) and zeros (for biases).
"""
def __init__(self, dim, eps=1e-05, elementwise_affine=True):
super(GlobalChannelLayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(self.dim, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def forward(self, x):
# x = N x C x L
# N x 1 x 1
# cln: mean,var N x 1 x L
# gln: mean,var N x 1 x 1
if x.dim() != 3:
raise RuntimeError("{} accept 3D tensor as input".format(
self.__name__))
mean = torch.mean(x, (1, 2), keepdim=True)
var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
# N x C x L
if self.elementwise_affine:
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps) +
self.bias)
else:
x = (x - mean) / torch.sqrt(var + self.eps)
return x
class ChannelWiseLayerNorm(nn.LayerNorm):
"""
Channel wise layer normalization
"""
def __init__(self, *args, **kwargs):
super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)
def forward(self, x):
"""
x: N x C x T
"""
x = torch.transpose(x, 1, 2)
x = super().forward(x)
x = torch.transpose(x, 1, 2)
return x
def select_norm(norm, dim):
"""
Build a normalization layer.
Note: LN costs more memory than BN.
"""
if norm not in ["cLN", "gLN", "BN"]:
raise RuntimeError("Unsupported normalize layer: {}".format(norm))
if norm == "cLN":
return ChannelWiseLayerNorm(dim, elementwise_affine=True)
elif norm == "BN":
return nn.BatchNorm1d(dim)
else:
return GlobalChannelLayerNorm(dim, elementwise_affine=True)
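# Usage sketch: select_norm("gLN", 128) returns a GlobalChannelLayerNorm
# for (N, 128, L) inputs; "cLN" and "BN" select the other two variants.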
class FiLM(nn.Module):
"""Feature-wise Linear Modulation (FiLM) layer
https://github.com/HuangZiliAndy/fairseq/blob/multispk/fairseq/models/wavlm/WavLM.py#L1160 # noqa
"""
def __init__(self,
feat_size,
embed_size,
num_film_layers=1,
layer_norm=False):
super(FiLM, self).__init__()
self.feat_size = feat_size
self.embed_size = embed_size
self.num_film_layers = num_film_layers
self.layer_norm = nn.LayerNorm(embed_size) if layer_norm else None
gamma_fcs, beta_fcs = [], []
for i in range(num_film_layers):
if i == 0:
gamma_fcs.append(nn.Linear(embed_size, feat_size))
beta_fcs.append(nn.Linear(embed_size, feat_size))
else:
gamma_fcs.append(nn.Linear(feat_size, feat_size))
beta_fcs.append(nn.Linear(feat_size, feat_size))
self.gamma_fcs = nn.ModuleList(gamma_fcs)
self.beta_fcs = nn.ModuleList(beta_fcs)
self.init_weights()
def init_weights(self):
for i in range(self.num_film_layers):
nn.init.zeros_(self.gamma_fcs[i].weight)
nn.init.zeros_(self.gamma_fcs[i].bias)
nn.init.zeros_(self.beta_fcs[i].weight)
nn.init.zeros_(self.beta_fcs[i].bias)
def forward(self, embed, x):
gamma, beta = None, None
for i in range(len(self.gamma_fcs)):
if i == 0:
gamma = self.gamma_fcs[i](embed)
beta = self.beta_fcs[i](embed)
else:
gamma = self.gamma_fcs[i](gamma)
beta = self.beta_fcs[i](beta)
if len(gamma.shape) < len(x.shape):
gamma = gamma.unsqueeze(-1).expand_as(x)
beta = beta.unsqueeze(-1).expand_as(x)
else:
gamma = gamma.expand_as(x)
beta = beta.expand_as(x)
# print(gamma.size(), beta.size())
x = (1 + gamma) * x + beta
if self.layer_norm is not None:
x = self.layer_norm(x)
return x
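# FiLM usage sketch (shapes chosen for illustration): modulate a
# (B, C, T) feature map with a (B, E) utterance-level embedding.
#   film = FiLM(feat_size=128, embed_size=256)
#   y = film(torch.randn(4, 256), torch.randn(4, 128, 50))  # -> (4, 128, 50)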
class ConditionalLayerNorm(nn.Module):
"""
https://github.com/HuangZiliAndy/fairseq/blob/multispk/fairseq/models/wavlm/WavLM.py#L1160
"""
def __init__(self,
normalized_shape,
embed_dim,
modulate_bias=False,
eps=1e-5):
super(ConditionalLayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape, )
self.normalized_shape = tuple(normalized_shape)
self.embed_dim = embed_dim
self.eps = eps
self.weight = nn.Parameter(torch.empty(*normalized_shape))
self.bias = nn.Parameter(torch.empty(*normalized_shape))
assert len(normalized_shape) == 1
self.ln_weight_modulation = FiLM(normalized_shape[0], embed_dim)
self.modulate_bias = modulate_bias
if self.modulate_bias:
self.ln_bias_modulation = FiLM(normalized_shape[0], embed_dim)
else:
self.ln_bias_modulation = None
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input, embed):
mean = torch.mean(input, -1, keepdim=True)
var = torch.var(input, -1, unbiased=False, keepdim=True)
weight = self.ln_weight_modulation(
embed, self.weight.expand(embed.size(0), -1))
if self.ln_bias_modulation is None:
bias = self.bias
else:
bias = self.ln_bias_modulation(embed,
self.bias.expand(embed.size(0), -1))
res = (input - mean) / torch.sqrt(var + self.eps) * weight + bias
return res
def extra_repr(self):
return "{normalized_shape}, {embed_dim}, \
modulate_bias={modulate_bias}, eps={eps}".format(**self.__dict__)
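# Example usage (a sketch, for 2-D input of shape B x C): the LayerNorm gain
# (and optionally the bias) is modulated per utterance by a FiLM layer driven
# by the speaker embedding.
#     cln = ConditionalLayerNorm(normalized_shape=512, embed_dim=256)
#     y = cln(torch.rand(4, 512), torch.rand(4, 256))  # B x C out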
================================================
FILE: wesep/modules/common/speaker.py
================================================
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from wesep.modules.common import FiLM
class PreEmphasis(torch.nn.Module):
def __init__(self, coef: float = 0.97):
super().__init__()
self.coef = coef
self.register_buffer(
"flipped_filter",
torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.unsqueeze(1)
input = F.pad(input, (1, 0), "reflect")
return F.conv1d(input, self.flipped_filter).squeeze(1)
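# Example usage (a sketch): applies y[t] = x[t] - coef * x[t-1] via the fixed
# two-tap conv filter above, with reflect padding to keep the length.
#     pre = PreEmphasis(coef=0.97)
#     y = pre(torch.rand(4, 16000))  # B x T in, B x T out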
class SpeakerTransform(nn.Module):
def __init__(self, embed_dim=256, num_layers=3, hid_dim=128):
"""
Transform the pretrained speaker embeddings, keep the dimension
:param embed_dim:
:param num_layers:
:param hid_dim:
:return:
"""
super(SpeakerTransform, self).__init__()
self.transforms = []
self.transforms.append(nn.Conv1d(embed_dim, hid_dim, 1))
for _ in range(num_layers - 2):
self.transforms.append(nn.Conv1d(hid_dim, hid_dim, 1))
self.transforms.append(nn.Tanh())
self.transforms.append(nn.Conv1d(hid_dim, embed_dim, 1))
self.transforms = nn.Sequential(*self.transforms)
def forward(self, x):
if len(x.size()) == 2:
return self.transforms(x.unsqueeze(-1)).squeeze(-1)
else:
return self.transforms(x)
class LinearLayer(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(LinearLayer, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias)
def forward(self, x, dummy: Optional[torch.Tensor] = None):
return self.linear(x)
class SpeakerFuseLayer(nn.Module):
def __init__(self, embed_dim=256, feat_dim=512, fuse_type="concat"):
super(SpeakerFuseLayer, self).__init__()
assert fuse_type in ["concat", "additive", "multiply", "FiLM", "None"]
self.fuse_type = fuse_type
if fuse_type == "concat":
self.fc = LinearLayer(embed_dim + feat_dim, feat_dim)
elif fuse_type == "additive":
self.fc = LinearLayer(embed_dim, feat_dim)
elif fuse_type == "multiply":
self.fc = LinearLayer(embed_dim, feat_dim)
elif fuse_type == "FiLM":
self.fc = FiLM(feat_dim, embed_dim)
else:
raise ValueError("Fuse type not defined.")
def forward(self, x, embed):
"""
:param x: batch x dimension x length
:param embed: batch x dimension x 1
:return:
"""
if self.fuse_type == "concat":
# For Conv
if len(x.size()) == 3:
embed_t = embed.expand(-1, -1, x.size(2))
y = torch.cat([x, embed_t], 1)
y = torch.transpose(y, 1, 2)
x = torch.transpose(self.fc(y), 1, 2)
else:
# len(x.size()) == 4
embed_t = embed.expand(-1, x.size(1), -1, x.size(3))
y = torch.cat([x, embed_t], 2)
y = torch.transpose(y, 2, 3)
x = torch.transpose(self.fc(y), 2, 3).contiguous()
elif self.fuse_type == "additive":
if len(x.size()) == 3:
embed_t = embed.expand(-1, -1, x.size(2))
embed_t = torch.transpose(embed_t, 1, 2)
x = x + torch.transpose(self.fc(embed_t), 1, 2)
else:
# len(x.size()) == 4
embed_t = embed.expand(-1, x.size(1), -1, x.size(3))
embed_t = torch.transpose(embed_t, 2, 3)
x = x + torch.transpose(self.fc(embed_t), 2, 3)
elif self.fuse_type == "multiply":
if len(x.size()) == 3:
embed_t = embed.expand(-1, -1, x.size(2))
embed_t = torch.transpose(embed_t, 1, 2)
x = x * torch.transpose(self.fc(embed_t), 1, 2)
else:
# len(x.size()) == 4
embed_t = embed.expand(-1, x.size(1), -1, x.size(3))
embed_t = torch.transpose(embed_t, 2, 3)
x = x * torch.transpose(self.fc(embed_t), 2, 3)
else:
embed = embed.squeeze(-1)
x = self.fc(embed, x)
return x
def test_speaker_fuse():
st = SpeakerTransform(embed_dim=256, num_layers=3, hid_dim=128)
sfl = SpeakerFuseLayer(fuse_type="multiply")
embeds = torch.rand(4, 256)
encoder_output = torch.rand(4, 512, 1000)
print(embeds.size())
embeds = st(embeds)
print(embeds.size())
output = sfl(encoder_output, embeds)
print(output.size())
if __name__ == "__main__":
test_speaker_fuse()
================================================
FILE: wesep/modules/dpccn/__init__.py
================================================
================================================
FILE: wesep/modules/dpccn/convs.py
================================================
from typing import Tuple
import torch
import torch.nn as nn
class Conv1D(nn.Conv1d):
"""
1D conv in ConvTasNet
"""
def __init__(self, *args, **kwargs):
super(Conv1D, self).__init__(*args, **kwargs)
def forward(self, x, squeeze=False):
"""
x: N x L or N x C x L
"""
if x.dim() not in [2, 3]:
raise RuntimeError("{} accept 2/3D tensor as input".format(
self.__name__))
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
if squeeze:
x = torch.squeeze(x)
return x
class Conv2dBlock(nn.Module):
def __init__(
self,
in_dims: int = 16,
out_dims: int = 32,
kernel_size: Tuple[int, int] = (3, 3),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (1, 1),
) -> None:
super(Conv2dBlock, self).__init__()
self.conv2d = nn.Conv2d(in_dims, out_dims, kernel_size, stride,
padding)
self.elu = nn.ELU()
self.norm = nn.InstanceNorm2d(out_dims)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv2d(x)
x = self.elu(x)
return self.norm(x)
class ConvTrans2dBlock(nn.Module):
def __init__(
self,
in_dims: int = 32,
out_dims: int = 16,
kernel_size: Tuple[int, int] = (3, 3),
stride: Tuple[int, int] = (1, 2),
padding: Tuple[int, int] = (1, 0),
output_padding: Tuple[int, int] = (0, 0),
) -> None:
super(ConvTrans2dBlock, self).__init__()
self.convtrans2d = nn.ConvTranspose2d(in_dims, out_dims, kernel_size,
stride, padding, output_padding)
self.elu = nn.ELU()
self.norm = nn.InstanceNorm2d(out_dims)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.convtrans2d(x)
x = self.elu(x)
return self.norm(x)
class DenseBlock(nn.Module):
def __init__(self, in_dims, out_dims, mode="enc", **kargs):
super(DenseBlock, self).__init__()
if mode not in ["enc", "dec"]:
raise RuntimeError("The mode option must be 'enc' or 'dec'!")
n = 1 if mode == "enc" else 2
self.conv1 = Conv2dBlock(in_dims=in_dims * n,
out_dims=in_dims,
**kargs)
self.conv2 = Conv2dBlock(in_dims=in_dims * (n + 1),
out_dims=in_dims,
**kargs)
self.conv3 = Conv2dBlock(in_dims=in_dims * (n + 2),
out_dims=in_dims,
**kargs)
self.conv4 = Conv2dBlock(in_dims=in_dims * (n + 3),
out_dims=in_dims,
**kargs)
self.conv5 = Conv2dBlock(in_dims=in_dims * (n + 4),
out_dims=out_dims,
**kargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y1 = self.conv1(x)
y2 = self.conv2(torch.cat([x, y1], 1))
y3 = self.conv3(torch.cat([x, y1, y2], 1))
y4 = self.conv4(torch.cat([x, y1, y2, y3], 1))
y5 = self.conv5(torch.cat([x, y1, y2, y3, y4], 1))
return y5
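# Example usage (a sketch, "enc" mode): each conv sees the concatenation of
# the block input and all previous outputs, so input channel counts grow as
# in_dims * (n + k), exactly as wired above.
#     block = DenseBlock(in_dims=16, out_dims=32, mode="enc")
#     y = block(torch.rand(2, 16, 100, 64))  # -> B x 32 x 100 x 64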
class TCNBlock(nn.Module):
"""
TCN block:
IN - ELU - Conv1D - IN - ELU - Conv1D
"""
def __init__(
self,
in_dims: int = 384,
out_dims: int = 384,
kernel_size: int = 3,
dilation: int = 1,
causal: bool = False,
) -> None:
super(TCNBlock, self).__init__()
self.norm1 = nn.InstanceNorm1d(in_dims)
self.elu1 = nn.ELU()
dconv_pad = ((dilation * (kernel_size - 1)) // 2 if not causal else
(dilation * (kernel_size - 1)))
# dilated conv
self.dconv1 = nn.Conv1d(
in_dims,
out_dims,
kernel_size,
padding=dconv_pad,
dilation=dilation,
groups=in_dims,
bias=True,
)
self.norm2 = nn.InstanceNorm1d(in_dims)
self.elu2 = nn.ELU()
self.dconv2 = nn.Conv1d(in_dims, out_dims, 1, bias=True)
# different padding way
self.causal = causal
self.dconv_pad = dconv_pad
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.elu1(self.norm1(x))
y = self.dconv1(y)
if self.causal:
y = y[:, :, :-self.dconv_pad]
y = self.elu2(self.norm2(y))
y = self.dconv2(y)
x = x + y
return x
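# Example usage (a sketch): a shape-preserving residual block; stacking blocks
# with dilations 1, 2, 4, ... grows the receptive field exponentially.
#     tcn = TCNBlock(in_dims=384, out_dims=384, dilation=2)
#     y = tcn(torch.rand(2, 384, 100))  # B x C x T in, same shape out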
================================================
FILE: wesep/modules/metric_gan/__init__.py
================================================
================================================
FILE: wesep/modules/metric_gan/discriminator.py
================================================
import torch
import torch.nn as nn
# utility functions/classes used in the implementation of discriminators.
class LearnableSigmoid(nn.Module):
def __init__(self, in_features, beta=1):
super().__init__()
self.beta = beta
self.slope = nn.Parameter(torch.ones(in_features))
self.slope.requires_grad = True
def forward(self, x):
return self.beta * torch.sigmoid(self.slope * x)
# discriminators
class CMGAN_Discriminator(nn.Module):
def __init__(
self,
n_fft=400,
hop=100,
in_channels=2,
hid_chans=16,
ksz=(4, 4),
stride=(2, 2),
padding=(1, 1),
bias=False,
num_conv_blocks=4,
num_linear_layers=2,
):
"""discriminator used in CMGAN (Interspeech 2022)
paper: https://arxiv.org/pdf/2203.15149.pdf
code: https://github.com/ruizhecao96/CMGAN
Args:
n_fft (int, optional): the window length of STFT. Defaults to 400.
hop (int, optional): the hop length of STFT. Defaults to 100.
in_channels (int, optional): num of input channels. Defaults to 2.
hid_chans (int, optional): num of hidden channels. Defaults to 16.
ksz (tuple, optional): kernel size. Defaults to (4, 4).
stride (tuple, optional): stride. Defaults to (2, 2).
padding (tuple, optional): padding. Defaults to (1, 1).
bias (bool, optional): bias. Defaults to False.
num_conv_blocks (int, optional): num of conv blocks. Defaults to 4.
num_linear_layers (int, optional): num of linear layers. Defaults to 2.
"""
super(CMGAN_Discriminator, self).__init__()
assert num_conv_blocks >= num_linear_layers
self.n_fft = n_fft
self.hop = hop
self.num_conv_blocks = num_conv_blocks
self.num_linear_layers = num_linear_layers
self.conv = nn.ModuleList([])
in_chans = in_channels
out_chans = hid_chans
for i in range(num_conv_blocks):
self.conv.append(
nn.Sequential(
nn.utils.spectral_norm(
nn.Conv2d(
in_chans,
out_chans,
ksz,
stride,
padding,
bias=bias,
)),
nn.InstanceNorm2d(out_chans, affine=True),
nn.PReLU(out_chans),
))
in_chans = out_chans
out_chans = hid_chans * (2**(i + 1))
self.pooling = nn.Sequential(
nn.AdaptiveMaxPool2d(1),
nn.Flatten(),
)
self.fc = nn.ModuleList([])
for i in range(num_linear_layers - 1):
self.fc.append(
nn.Sequential(
nn.utils.spectral_norm(
nn.Linear(
hid_chans * (2**(num_conv_blocks - 1 - i)),
hid_chans * (2**(num_conv_blocks - 2 - i)),
)),
nn.Dropout(0.3),
nn.PReLU(hid_chans * (2**(num_conv_blocks - 2 - i))),
))
self.fc.append(
nn.Sequential(
nn.utils.spectral_norm(
nn.Linear(
hid_chans * (2**(num_conv_blocks - num_linear_layers)),
1,
)),
LearnableSigmoid(1),
))
def forward(self, ref_wav, est_wav):
"""
Args:
ref_wav (torch.Tensor): the reference signal. [B, T]
est_wav (torch.Tensor): the estimated signal. [B, T]
Return:
estimated_scores (torch.Tensor): estimated scores, [B]
"""
ref_spec = torch.stft(
ref_wav,
self.n_fft,
self.hop,
window=torch.hann_window(self.n_fft).to(ref_wav.device).type(
ref_wav.type()),
return_complex=True,
).transpose(-1, -2)
est_spec = torch.stft(
est_wav,
self.n_fft,
self.hop,
window=torch.hann_window(self.n_fft).to(est_wav.device).type(
est_wav.type()),
return_complex=True,
).transpose(-1, -2)
# input shape: (B, 2, T, F)
input = torch.stack((abs(ref_spec), abs(est_spec)), dim=1)
for i in range(self.num_conv_blocks):
input = self.conv[i](input)
input = self.pooling(input)
for i in range(self.num_linear_layers):
input = self.fc[i](input)
return input
if __name__ == "__main__":
# functions used to test discriminators
def test_CMGAN_Discriminator():
B, T = 2, 16000
ref_wav = torch.randn(B, T)
est_wav = torch.randn(B, T)
D = CMGAN_Discriminator()
metric = D(ref_wav, est_wav).detach()
print(f"estimated metric score is {metric}")
test_CMGAN_Discriminator()
================================================
FILE: wesep/modules/tasnet/__init__.py
================================================
from wesep.modules.tasnet.decoder import DeepDecoder # noqa
from wesep.modules.tasnet.decoder import MultiDecoder # noqa
from wesep.modules.tasnet.encoder import DeepEncoder # noqa
from wesep.modules.tasnet.encoder import MultiEncoder # noqa
from wesep.modules.tasnet.separation import Separation, FuseSeparation # noqa
from wesep.modules.tasnet.speaker import ResNet4SpExplus # noqa
================================================
FILE: wesep/modules/tasnet/convs.py
================================================
import torch
import torch.nn as nn
from wesep.modules.common import select_norm
# from wesep.modules.common.spkadapt import SpeakerFuseLayer
class Conv1D(nn.Conv1d):
def __init__(self, *args, **kwargs):
super(Conv1D, self).__init__(*args, **kwargs)
def forward(self, x, squeeze=False):
# x: N x C x L
if x.dim() not in [2, 3]:
raise RuntimeError("{} accept 2/3D tensor as input".format(
self.__name__))
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
if squeeze:
x = torch.squeeze(x)
return x
class ConvTrans1D(nn.ConvTranspose1d):
def __init__(self, *args, **kwargs):
super(ConvTrans1D, self).__init__(*args, **kwargs)
def forward(self, x, squeeze=False):
"""
x: N x L or N x C x L
"""
if x.dim() not in [2, 3]:
raise RuntimeError("{} accept 2/3D tensor as input".format(
self.__name__))
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
if squeeze:
x = torch.squeeze(x)
return x
class Conv1DBlock(nn.Module):
"""
Consider only residual links
"""
def __init__(
self,
in_channels=256,
out_channels=512,
kernel_size=3,
dilation=1,
norm="gln",
causal=False,
skip_con=True,
):
super(Conv1DBlock, self).__init__()
# conv 1 x 1
self.conv1x1 = Conv1D(in_channels, out_channels, 1)
self.PReLU_1 = nn.PReLU()
self.norm_1 = select_norm(norm, out_channels)
# non-causal: symmetric padding keeps the length;
# causal: pad dilation*(kernel_size-1), then trim the right side in forward
self.pad = ((dilation * (kernel_size - 1)) // 2 if not causal else
(dilation * (kernel_size - 1)))
# depthwise convolution
# TODO: This is not a depthwise separable convolution
self.dwconv = Conv1D(
out_channels,
out_channels,
kernel_size,
groups=out_channels,
padding=self.pad,
dilation=dilation,
)
self.PReLU_2 = nn.PReLU()
self.norm_2 = select_norm(norm, out_channels)
if skip_con:
self.Sc_conv = nn.Conv1d(out_channels, in_channels, 1, bias=True)
self.Output = nn.Conv1d(out_channels, in_channels, 1, bias=True)
self.causal = causal
self.skip_con = skip_con
def forward(self, x):
# x: N x C x L
# N x O_C x L
c = self.conv1x1(x)
# N x O_C x L
c = self.PReLU_1(c)
c = self.norm_1(c)
# causal: N x O_C x (L+pad)
# noncausal: N x O_C x L
c = self.dwconv(c)
if self.causal:
c = c[:, :, :-self.pad]
c = self.PReLU_2(c)
c = self.norm_2(c)
# N x O_C x L
if self.skip_con:
Sc = self.Sc_conv(c)
c = self.Output(c)
return Sc, c + x
c = self.Output(c)
return x + c
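# Example usage (a sketch): with skip_con=True the block returns a
# (skip, residual) pair; both paths project back to in_channels.
#     block = Conv1DBlock(in_channels=256, out_channels=512, norm="gLN")
#     skip, out = block(torch.rand(4, 256, 100))  # both B x 256 x L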
class Conv1DBlock4Fuse(nn.Module):
"""
1D convolutional block:
Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv
"""
def __init__(
self,
in_channels=256,
spk_embed_dim=100,
conv_channels=512,
kernel_size=3,
dilation=1,
norm="cLN",
causal=False,
):
super(Conv1DBlock4Fuse, self).__init__()
# 1x1 conv
self.conv1x1 = Conv1D(in_channels + spk_embed_dim, conv_channels, 1)
self.prelu1 = nn.PReLU()
self.lnorm1 = select_norm(norm, conv_channels)
dconv_pad = ((dilation * (kernel_size - 1)) // 2 if not causal else
(dilation * (kernel_size - 1)))
# depthwise conv
self.dconv = nn.Conv1d(
conv_channels,
conv_channels,
kernel_size,
groups=conv_channels,
padding=dconv_pad,
dilation=dilation,
bias=True,
)
self.prelu2 = nn.PReLU()
self.lnorm2 = select_norm(norm, conv_channels)
# 1x1 conv cross channel
self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
# different padding way
self.causal = causal
self.dconv_pad = dconv_pad
def forward(self, x, aux):
T = x.shape[-1]
aux = aux.repeat(1, 1, T)
y = torch.cat([x, aux], 1)
y = self.conv1x1(y)
y = self.lnorm1(self.prelu1(y))
y = self.dconv(y)
if self.causal:
y = y[:, :, :-self.dconv_pad]
y = self.lnorm2(self.prelu2(y))
y = self.sconv(y)
x = x + y
return x
================================================
FILE: wesep/modules/tasnet/decoder.py
================================================
import torch
import torch.nn as nn
from wesep.modules.tasnet.convs import Conv1D, ConvTrans1D
class DeepDecoder(nn.Module):
def __init__(self, N, kernel_size=16, stride=16 // 2):
super(DeepDecoder, self).__init__()
self.sequential = nn.Sequential(
nn.ConvTranspose1d(N,
N,
kernel_size=3,
stride=1,
dilation=8,
padding=8),
nn.PReLU(),
nn.ConvTranspose1d(N,
N,
kernel_size=3,
stride=1,
dilation=4,
padding=4),
nn.PReLU(),
nn.ConvTranspose1d(N,
N,
kernel_size=3,
stride=1,
dilation=2,
padding=2),
nn.PReLU(),
nn.ConvTranspose1d(N,
N,
kernel_size=3,
stride=1,
dilation=1,
padding=1),
nn.PReLU(),
nn.ConvTranspose1d(N,
1,
kernel_size=kernel_size,
stride=stride,
bias=True),
)
def forward(self, x):
"""
x: N x C x L feature maps
"""
x = self.sequential(x)
if torch.squeeze(x).dim() == 1:
x = torch.squeeze(x, dim=1)
else:
x = torch.squeeze(x)
return x
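# Example usage (a sketch): mirrors DeepEncoder, mapping N x C x L feature
# maps back to waveforms via dilated transposed convs plus a strided one.
#     dec = DeepDecoder(N=256, kernel_size=40, stride=20)
#     wav = dec(torch.rand(4, 256, 1599))  # -> B x T waveforms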
class MultiDecoder(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, kernel_size,
stride):
super(MultiDecoder, self).__init__()
B = in_channels
N = middle_channels
L = kernel_size
# n x B x T => n x 2N x T
self.mask1 = Conv1D(B, N, 1)
self.mask2 = Conv1D(B, N, 1)
self.mask3 = Conv1D(B, N, 1)
# using ConvTrans1D: n x N x T => n x 1 x To
# To = (T - 1) * L // 2 + L
self.decoder_1d_1 = ConvTrans1D(N,
out_channels,
kernel_size=L,
stride=stride,
bias=True)
self.decoder_1d_2 = ConvTrans1D(N,
out_channels,
kernel_size=80,
stride=stride,
bias=True)
self.decoder_1d_3 = ConvTrans1D(N,
out_channels,
kernel_size=160,
stride=stride,
bias=True)
def forward(self, x, w1, w2, w3, actLayer):
"""
x: N x C x L feature maps
"""
m1 = actLayer(self.mask1(x))
m2 = actLayer(self.mask2(x))
m3 = actLayer(self.mask3(x))
s1 = w1 * m1
s2 = w2 * m2
s3 = w3 * m3
est1 = self.decoder_1d_1(s1, squeeze=True)
xlen = est1.shape[-1]
if est1.dim() > 1:
est2 = self.decoder_1d_2(s2, squeeze=True)[:, :xlen]
est3 = self.decoder_1d_3(s3, squeeze=True)[:, :xlen]
else:
est1 = est1.unsqueeze(0)
est2 = self.decoder_1d_2(s2, squeeze=True).unsqueeze(0)[:, :xlen]
est3 = self.decoder_1d_3(s3, squeeze=True).unsqueeze(0)[:, :xlen]
s = [est1, est2, est3]
return s
================================================
FILE: wesep/modules/tasnet/encoder.py
================================================
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from wesep.modules.common import select_norm
from wesep.modules.tasnet.convs import Conv1D
class DeepEncoder(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(DeepEncoder, self).__init__()
self.sequential = nn.Sequential(
Conv1D(in_channels, out_channels, kernel_size, stride=stride),
Conv1D(
out_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=1,
padding=1,
),
nn.PReLU(),
Conv1D(
out_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=2,
padding=2,
),
nn.PReLU(),
Conv1D(
out_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=4,
padding=4,
),
nn.PReLU(),
Conv1D(
out_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=8,
padding=8,
),
nn.PReLU(),
)
def forward(self, x):
"""
:param x: [B, T]
:return: out: [B, N, T]
"""
x = self.sequential(x)
return x
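# Example usage (a sketch): a strided Conv1D front-end followed by dilated
# shape-preserving convs.
#     enc = DeepEncoder(in_channels=1, out_channels=256, kernel_size=40,
#                       stride=20)
#     feats = enc(torch.rand(4, 32000))  # -> B x 256 x T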
class MultiEncoder(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, kernel_size,
stride):
super(MultiEncoder, self).__init__()
self.L1 = kernel_size
self.L2 = 80
self.L3 = 160
self.encoder_1d_short = Conv1D(in_channels,
middle_channels,
self.L1,
stride=stride,
padding=0)
self.encoder_1d_middle = Conv1D(in_channels,
middle_channels,
self.L2,
stride=stride,
padding=0)
self.encoder_1d_long = Conv1D(in_channels,
middle_channels,
self.L3,
stride=stride,
padding=0)
# keep T not change
# T = int((xlen - L) / (L // 2)) + 1
# before repeat blocks, always cLN
self.ln = select_norm(
"cLN",
3 * middle_channels) # ChannelWiseLayerNorm(3 * middle_channels)
# n x N x T => n x B x T
self.proj = Conv1D(3 * middle_channels, out_channels, 1)
def forward(self, x):
"""
:param x: [B, T]
:return: out: [B, N, T]
"""
w1 = F.relu(self.encoder_1d_short(x))
T = w1.shape[-1]
xlen1 = x.shape[-1]
xlen2 = (T - 1) * (self.L1 // 2) + self.L2
xlen3 = (T - 1) * (self.L1 // 2) + self.L3
w2 = F.relu(
self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant",
0)))
w3 = F.relu(
self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))
# n x 3N x T
x = self.ln(th.cat([w1, w2, w3], 1))
# n x B x T
x = self.proj(x)
return x, w1, w2, w3
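# Example usage (a sketch, assuming stride == kernel_size // 2 as the length
# arithmetic above expects): three parallel encoders with window sizes
# kernel_size, 80 and 160 produce frame-aligned features.
#     enc = MultiEncoder(in_channels=1, middle_channels=256,
#                        out_channels=256, kernel_size=40, stride=20)
#     feats, w1, w2, w3 = enc(torch.rand(4, 32000))  # feats: B x 256 x T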
================================================
FILE: wesep/modules/tasnet/separation.py
================================================
import torch.nn as nn
from wesep.modules.common import select_norm
from wesep.modules.common.speaker import SpeakerFuseLayer
from wesep.modules.tasnet.convs import Conv1DBlock, Conv1DBlock4Fuse
class Separation(nn.Module):
def __init__(
self,
R,
X,
B,
H,
P,
norm="gLN",
causal=False,
skip_con=True,
start_dilation=0,
):
"""
Args
:param R: Number of repeats
:param X: Number of convolutional blocks in each repeat
:param B: Number of channels in bottleneck and the residual paths
:param H: Number of channels in convolutional blocks
:param P: Kernel size in convolutional blocks
:param norm: The type of normalization (gLN, cLN, BN)
:param causal: Two choices (causal or non-causal)
:param skip_con: Whether to use skip connection
"""
super(Separation, self).__init__()
self.separation = nn.ModuleList([])
for _ in range(R):
for x in range(start_dilation, X):
self.separation.append(
Conv1DBlock(B, H, P, 2**x, norm, causal, skip_con))
self.skip_con = skip_con
def forward(self, x):
"""
x: [B, N, L]
out: [B, N, L]
"""
if self.skip_con:
skip_connection = 0
for i in range(len(self.separation)):
skip, out = self.separation[i](x)
skip_connection = skip_connection + skip
x = out
return skip_connection
else:
for i in range(len(self.separation)):
out = self.separation[i](x)
x = out
return x
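# Example usage (a sketch): R repeats of X blocks whose dilations double
# (2**0 ... 2**(X-1)); with skip_con=True the summed skip path is returned.
#     sep = Separation(R=3, X=8, B=256, H=512, P=3)
#     y = sep(torch.rand(4, 256, 100))  # -> B x 256 x 100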
class FuseSeparation(nn.Module):
def __init__(
self,
R,
X,
B,
H,
P,
norm="gLN",
causal=False,
skip_con=False,
C_embedding=256,
spk_fuse_type="concatConv",
multi_fuse=True,
):
"""
:param R: Number of repeats
:param X: Number of convolutional blocks in each repeat
:param B: Number of channels in bottleneck and the residual paths
:param H: Number of channels in convolutional blocks
:param P: Kernel size in convolutional blocks
:param norm: The type of normalization (gLN, cLN, BN)
:param causal: Two choices (causal or non-causal)
:param skip_con: Whether to use skip connection
"""
super(FuseSeparation, self).__init__()
self.multi_fuse = multi_fuse
self.spk_fuse_type = spk_fuse_type
self.separation = nn.ModuleList([])
if self.multi_fuse:
for _ in range(R):
if spk_fuse_type == "concatConv":
self.separation.append(
Conv1DBlock4Fuse(
spk_embed_dim=C_embedding,
in_channels=B,
conv_channels=H,
kernel_size=P,
norm=norm,
causal=causal,
dilation=1,
))
self.separation.append(
Separation(
1,
X,
B,
H,
P,
norm=norm,
causal=causal,
skip_con=skip_con,
start_dilation=1,
))
else:
self.separation.append(
SpeakerFuseLayer(
embed_dim=C_embedding,
feat_dim=B,
fuse_type=spk_fuse_type,
))
self.separation.append(nn.PReLU())
self.separation.append(select_norm(norm, B))
self.separation.append(
Separation(
1,
X,
B,
H,
P,
norm=norm,
causal=causal,
skip_con=skip_con,
))
else:
if spk_fuse_type == "concatConv":
self.separation.append(
Conv1DBlock4Fuse(
spk_embed_dim=C_embedding,
in_channels=B,
conv_channels=H,
kernel_size=P,
norm=norm,
causal=causal,
dilation=1,
))
else:
self.separation.append(
SpeakerFuseLayer(
embed_dim=C_embedding,
feat_dim=B,
fuse_type=spk_fuse_type,
))
self.separation.append(nn.PReLU())
self.separation.append(select_norm(norm, B))
self.separation = Separation(R,
X,
B,
H,
P,
norm=norm,
causal=causal,
skip_con=skip_con)
def forward(self, x, spk_embedding):
"""
x: [B, N, L]
out: [B, N, L]
"""
if self.multi_fuse:
if self.spk_fuse_type == "concatConv":
round_num = 2
else:
round_num = 4
for i in range(len(self.separation)):
if i % round_num == 0:
x = self.separation[i](x, spk_embedding)
else:
x = self.separation[i](x)
else:
x = self.separation[0](x, spk_embedding)
for i in range(1, len(self.separation)):
x = self.separation[i](x)
return x
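# Example usage (a sketch): fuse a speaker embedding before every repeat.
# The embedding is expected as B x C x 1 so it can be broadcast over time.
#     fsep = FuseSeparation(R=3, X=8, B=256, H=512, P=3, C_embedding=256,
#                           spk_fuse_type="concatConv")
#     y = fsep(torch.rand(4, 256, 100), torch.rand(4, 256, 1))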
================================================
FILE: wesep/modules/tasnet/separator.py
================================================
import torch.nn as nn
from wesep.modules.tasnet.convs import Conv1DBlock
class Separation(nn.Module):
"""
R Number of repeats
X Number of convolutional blocks in each repeat
B Number of channels in bottleneck and the residual paths
H Number of channels in convolutional blocks
P Kernel size in convolutional blocks
norm The type of normalization (gLN, cLN, BN)
causal Two choices (causal or non-causal)
skip_con Whether to use skip connection
"""
def __init__(self, R, X, B, H, P, norm="gln", causal=False, skip_con=True):
super(Separation, self).__init__()
self.separation = nn.ModuleList([])
for _ in range(R):
for x in range(X):
self.separation.append(
Conv1DBlock(B, H, P, 2**x, norm, causal, skip_con))
self.skip_con = skip_con
def forward(self, x):
"""
x: [B, N, L]
out: [B, N, L]
"""
if self.skip_con:
skip_connection = 0
for i in range(len(self.separation)):
skip, out = self.separation[i](x)
skip_connection = skip_connection + skip
x = out
return skip_connection
else:
for i in range(len(self.separation)):
out = self.separation[i](x)
x = out
return x
================================================
FILE: wesep/modules/tasnet/speaker.py
================================================
import torch.nn as nn
from wesep.modules.common.norm import ChannelWiseLayerNorm
from wesep.modules.tasnet.convs import Conv1D
class ResBlock(nn.Module):
"""
ref to
https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
and
https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py
"""
def __init__(self, in_dims, out_dims):
super().__init__()
self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
self.batch_norm1 = nn.BatchNorm1d(out_dims)
self.batch_norm2 = nn.BatchNorm1d(out_dims)
self.prelu1 = nn.PReLU()
self.prelu2 = nn.PReLU()
self.mp = nn.MaxPool1d(3)
if in_dims != out_dims:
self.downsample = True
self.conv_downsample = nn.Conv1d(in_dims,
out_dims,
kernel_size=1,
bias=False)
else:
self.downsample = False
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.batch_norm1(x)
x = self.prelu1(x)
x = self.conv2(x)
x = self.batch_norm2(x)
if self.downsample:
residual = self.conv_downsample(residual)
x = x + residual
x = self.prelu2(x)
return self.mp(x)
class ResNet4SpExplus(nn.Module):
def __init__(self, in_channel=256, C_embedding=256):
super().__init__()
self.aux_enc3 = nn.Sequential(
ChannelWiseLayerNorm(3 * in_channel),
Conv1D(3 * 256, 256, 1),
ResBlock(256, 256),
ResBlock(256, 512),
ResBlock(512, 512),
Conv1D(512, C_embedding, 1),
)
def forward(self, x):
aux = self.aux_enc3(x)
aux = aux.mean(dim=-1)
return aux
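# Example usage (a sketch): pools a 3N-channel encoder output into a
# C_embedding-dim speaker vector (note the 256-channel sizes are hardwired).
#     spk_enc = ResNet4SpExplus(in_channel=256, C_embedding=256)
#     emb = spk_enc(torch.rand(4, 3 * 256, 1200))  # -> B x 256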
================================================
FILE: wesep/modules/tfgridnet/__init__.py
================================================
================================================
FILE: wesep/modules/tfgridnet/gridnet_block.py
================================================
# The implementation is based on:
# https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from wesep.utils.utils import get_layer
class GridNetBlock(nn.Module):
def __getitem__(self, key):
return getattr(self, key)
def __init__(
self,
emb_dim,
emb_ks,
emb_hs,
n_freqs,
hidden_channels,
n_head=4,
approx_qk_dim=512,
activation="prelu",
eps=1e-5,
):
super().__init__()
assert activation == "prelu"
in_channels = emb_dim * emb_ks
self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
self.intra_rnn = nn.LSTM(
in_channels,
hidden_channels,
1,
batch_first=True,
bidirectional=True,
)
if emb_ks == emb_hs:
self.intra_linear = nn.Linear(hidden_channels * 2, in_channels)
else:
self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2,
emb_dim,
emb_ks,
stride=emb_hs)
self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
self.inter_rnn = nn.LSTM(
in_channels,
hidden_channels,
1,
batch_first=True,
bidirectional=True,
)
if emb_ks == emb_hs:
self.inter_linear = nn.Linear(hidden_channels * 2, in_channels)
else:
self.inter_linear = nn.ConvTranspose1d(hidden_channels * 2,
emb_dim,
emb_ks,
stride=emb_hs)
E = math.ceil(approx_qk_dim * 1.0 /
n_freqs) # approx_qk_dim is only approximate
assert emb_dim % n_head == 0
self.add_module("attn_conv_Q", nn.Conv2d(emb_dim, n_head * E, 1))
self.add_module(
"attn_norm_Q",
AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
)
self.add_module("attn_conv_K", nn.Conv2d(emb_dim, n_head * E, 1))
self.add_module(
"attn_norm_K",
AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
)
self.add_module("attn_conv_V",
nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1))
self.add_module(
"attn_norm_V",
AllHeadPReLULayerNormalization4DCF(
(n_head, emb_dim // n_head, n_freqs), eps=eps),
)
self.add_module(
"attn_concat_proj",
nn.Sequential(
nn.Conv2d(emb_dim, emb_dim, 1),
get_layer(activation)(),
LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
),
)
self.emb_dim = emb_dim
self.emb_ks = emb_ks
self.emb_hs = emb_hs
self.n_head = n_head
def forward(self, x):
"""GridNetBlock Forward.
Args:
x: [B, C, T, Q]
out: [B, C, T, Q]
"""
B, C, old_T, old_Q = x.shape
olp = self.emb_ks - self.emb_hs
T = math.ceil((old_T + 2 * olp - self.emb_ks) /
self.emb_hs) * self.emb_hs + self.emb_ks
Q = math.ceil((old_Q + 2 * olp - self.emb_ks) /
self.emb_hs) * self.emb_hs + self.emb_ks
x = x.permute(0, 2, 3, 1) # [B, old_T, old_Q, C]
x = F.pad(
x,
(0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp)) # [B, T, Q, C]
# intra RNN
input_ = x
intra_rnn = self.intra_norm(input_) # [B, T, Q, C]
if self.emb_ks == self.emb_hs:
intra_rnn = intra_rnn.view([B * T, -1,
self.emb_ks * C]) # [BT, Q//I, I*C]
intra_rnn, _ = self.intra_rnn(intra_rnn) # [BT, Q//I, H]
intra_rnn = self.intra_linear(intra_rnn) # [BT, Q//I, I*C]
intra_rnn = intra_rnn.view([B, T, Q, C])
else:
intra_rnn = intra_rnn.view([B * T, Q, C]) # [BT, Q, C]
intra_rnn = intra_rnn.transpose(1, 2) # [BT, C, Q]
intra_rnn = F.unfold(intra_rnn[..., None], (self.emb_ks, 1),
stride=(self.emb_hs, 1)) # [BT, C*I, -1]
intra_rnn = intra_rnn.transpose(1, 2) # [BT, -1, C*I]
intra_rnn, _ = self.intra_rnn(intra_rnn) # [BT, -1, H]
intra_rnn = intra_rnn.transpose(1, 2) # [BT, H, -1]
intra_rnn = self.intra_linear(intra_rnn) # [BT, C, Q]
intra_rnn = intra_rnn.view([B, T, C, Q])
intra_rnn = intra_rnn.transpose(-2, -1) # [B, T, Q, C]
intra_rnn = intra_rnn + input_ # [B, T, Q, C]
intra_rnn = intra_rnn.transpose(1, 2) # [B, Q, T, C]
# inter RNN
input_ = intra_rnn
inter_rnn = self.inter_norm(input_) # [B, Q, T, C]
if self.emb_ks == self.emb_hs:
inter_rnn = inter_rnn.view([B * Q, -1,
self.emb_ks * C]) # [BQ, T//I, I*C]
inter_rnn, _ = self.inter_rnn(inter_rnn) # [BQ, T//I, H]
inter_rnn = self.inter_linear(inter_rnn) # [BQ, T//I, I*C]
inter_rnn = inter_rnn.view([B, Q, T, C])
else:
inter_rnn = inter_rnn.view(B * Q, T, C) # [BQ, T, C]
inter_rnn = inter_rnn.transpose(1, 2) # [BQ, C, T]
inter_rnn = F.unfold(inter_rnn[..., None], (self.emb_ks, 1),
stride=(self.emb_hs, 1)) # [BQ, C*I, -1]
inter_rnn = inter_rnn.transpose(1, 2) # [BQ, -1, C*I]
inter_rnn, _ = self.inter_rnn(inter_rnn) # [BQ, -1, H]
inter_rnn = inter_rnn.transpose(1, 2) # [BQ, H, -1]
inter_rnn = self.inter_linear(inter_rnn) # [BQ, C, T]
inter_rnn = inter_rnn.view([B, Q, C, T])
inter_rnn = inter_rnn.transpose(-2, -1) # [B, Q, T, C]
inter_rnn = inter_rnn + input_ # [B, Q, T, C]
inter_rnn = inter_rnn.permute(0, 3, 2, 1) # [B, C, T, Q]
inter_rnn = inter_rnn[..., olp:olp + old_T, olp:olp + old_Q]
batch = inter_rnn
Q = self["attn_norm_Q"](
self["attn_conv_Q"](batch)) # [B, n_head, C, T, Q]
K = self["attn_norm_K"](
self["attn_conv_K"](batch)) # [B, n_head, C, T, Q]
V = self["attn_norm_V"](
self["attn_conv_V"](batch)) # [B, n_head, C, T, Q]
Q = Q.view(-1, *Q.shape[2:]) # [B*n_head, C, T, Q]
K = K.view(-1, *K.shape[2:]) # [B*n_head, C, T, Q]
V = V.view(-1, *V.shape[2:]) # [B*n_head, C, T, Q]
Q = Q.transpose(1, 2)
Q = Q.flatten(start_dim=2) # [B', T, C*Q]
K = K.transpose(2, 3)
K = K.contiguous().view([B * self.n_head, -1, old_T]) # [B', C*Q, T]
V = V.transpose(1, 2) # [B', T, C, Q]
old_shape = V.shape
V = V.flatten(start_dim=2) # [B', T, C*Q]
emb_dim = Q.shape[-1]
attn_mat = torch.matmul(Q, K) / (emb_dim**0.5) # [B', T, T]
attn_mat = F.softmax(attn_mat, dim=2) # [B', T, T]
V = torch.matmul(attn_mat, V) # [B', T, C*Q]
V = V.reshape(old_shape) # [B', T, C, Q]
V = V.transpose(1, 2) # [B', C, T, Q]
emb_dim = V.shape[1]
batch = V.contiguous().view([B, self.n_head * emb_dim, old_T,
old_Q]) # [B, C, T, Q])
batch = self["attn_concat_proj"](batch) # [B, C, T, Q])
out = batch + inter_rnn
return out
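# Example usage (a sketch): the block is shape-preserving over [B, C, T, Q]
# time-frequency embeddings; emb_dim must be divisible by n_head.
#     blk = GridNetBlock(emb_dim=32, emb_ks=4, emb_hs=1, n_freqs=65,
#                        hidden_channels=64)
#     y = blk(torch.rand(2, 32, 100, 65))  # same shape out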
class LayerNormalization4DCF(nn.Module):
def __init__(self, input_dimension, eps=1e-5):
super().__init__()
assert len(input_dimension) == 2
param_size = [1, input_dimension[0], 1, input_dimension[1]]
self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
init.ones_(self.gamma)
init.zeros_(self.beta)
self.eps = eps
def forward(self, x):
if x.ndim == 4:
stat_dim = (1, 3)
else:
raise ValueError(
"Expect x to have 4 dimensions, but got {}".format(x.ndim))
mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,1,T,1]
std_ = torch.sqrt(
x.var(dim=stat_dim, unbiased=False, keepdim=True) +
self.eps) # [B,1,T,1]
x_hat = ((x - mu_) / std_) * self.gamma + self.beta
return x_hat
class AllHeadPReLULayerNormalization4DCF(nn.Module):
def __init__(self, input_dimension, eps=1e-5):
super().__init__()
assert len(input_dimension) == 3
H, E, n_freqs = input_dimension
param_size = [1, H, E, 1, n_freqs]
self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
init.ones_(self.gamma)
init.zeros_(self.beta)
self.act = nn.PReLU(num_parameters=H, init=0.25)
self.eps = eps
self.H = H
self.E = E
self.n_freqs = n_freqs
def forward(self, x):
assert x.ndim == 4
B, _, T, _ = x.shape
x = x.view([B, self.H, self.E, T, self.n_freqs])
x = self.act(x) # [B,H,E,T,F]
stat_dim = (2, 4)
mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,H,1,T,1]
std_ = torch.sqrt(
x.var(dim=stat_dim, unbiased=False, keepdim=True) +
self.eps) # [B,H,1,T,1]
x = ((x - mu_) / std_) * self.gamma + self.beta # [B,H,E,T,F]
return x
================================================
FILE: wesep/utils/abs_loss.py
================================================
from abc import ABC, abstractmethod
import torch
EPS = torch.finfo(torch.get_default_dtype()).eps
class AbsEnhLoss(torch.nn.Module, ABC):
"""Base class for all Enhancement loss modules."""
# the name will be the key that appears in the reporter
@property
def name(self) -> str:
raise NotImplementedError
# This property specifies whether the criterion will only
# be evaluated during the inference stage
@property
def only_for_test(self) -> bool:
return False
@abstractmethod
def forward(
self,
ref,
inf,
) -> torch.Tensor:
# the return tensor should be shape of (batch)
raise NotImplementedError
================================================
FILE: wesep/utils/checkpoint.py
================================================
from typing import List, Optional
import torch
from wesep.utils.schedulers import BaseClass
def load_pretrained_model(model: torch.nn.Module,
path: str,
type: str = "generator"):
assert type in ["generator", "discriminator"]
states = torch.load(
path,
map_location="cpu",
)
if type == "generator":
state = states["models"][0]
else:
assert len(states["models"]) == 2
state = states["models"][1]
if isinstance(model, torch.nn.DataParallel):
model.module.load_state_dict(state)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
model.module.load_state_dict(state)
else:
model.load_state_dict(state)
def load_checkpoint(
models: List[torch.nn.Module],
optimizers: List[torch.optim.Optimizer],
schedulers: List[BaseClass],
scaler: Optional[torch.cuda.amp.GradScaler],
path: str,
only_model: bool = False,
mode: str = "all",
):
assert mode in ["all", "generator", "discriminator"]
states = torch.load(
path,
map_location="cpu",
)
if mode == "generator":
model_state, optimizer_state, scheduler_state = (
[states["models"][0]],
[states["optimizers"][0]],
[states["schedulers"][0]],
)
elif mode == "discriminator":
model_state, optimizer_state, scheduler_state = (
[states["models"][1]],
[states["optimizers"][1]],
[states["schedulers"][1]],
)
else:
model_state, optimizer_state, scheduler_state = (
states["models"],
states["optimizers"],
states["schedulers"],
)
for model, state in zip(models, model_state):
if isinstance(model, torch.nn.DataParallel):
model.module.load_state_dict(state, strict=False)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
model.module.load_state_dict(state, strict=False)
else:
model.load_state_dict(state, strict=False)
if not only_model:
for optimizer, state in zip(optimizers, optimizer_state):
optimizer.load_state_dict(state)
for scheduler, state in zip(schedulers, scheduler_state):
if scheduler is not None:
scheduler.load_state_dict(state)
if scaler is not None:
if states["scaler"] is not None:
scaler.load_state_dict(states["scaler"])
def save_checkpoint(
models: List[torch.nn.Module],
optimizers: List[torch.optim.Optimizer],
schedulers: List[BaseClass],
scaler: Optional[torch.cuda.amp.GradScaler],
path: str,
):
if isinstance(models[0], torch.nn.DataParallel):
state_dict = [model.module.state_dict() for model in models]
elif isinstance(models[0], torch.nn.parallel.DistributedDataParallel):
state_dict = [model.module.state_dict() for model in models]
else:
state_dict = [model.state_dict() for model in models]
torch.save(
{
"models":
state_dict,
"optimizers": [o.state_dict() for o in optimizers],
"schedulers":
[s.state_dict() if s is not None else None for s in schedulers],
"scaler":
scaler.state_dict() if scaler is not None else None,
},
path,
)
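# The checkpoint written above is a single dict (layout sketch):
#     {
#         "models": [generator_state, discriminator_state, ...],
#         "optimizers": [optimizer_state, ...],
#         "schedulers": [scheduler_state_or_None, ...],
#         "scaler": scaler_state_or_None,
#     }
# load_pretrained_model() reads states["models"][0] for the generator and
# states["models"][1] for the discriminator.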
================================================
FILE: wesep/utils/datadir_writer.py
================================================
import warnings
from pathlib import Path
from typing import Union
# ported from
# https://github.com/espnet/espnet/blob/master/espnet2/fileio/datadir_writer.py
class DatadirWriter:
"""Writer class to create kaldi like data directory.
Examples:
>>> with DatadirWriter("output") as writer:
... # output/sub.txt is created here
... subwriter = writer["sub.txt"]
... # Write "uttidA some/where/a.wav"
... subwriter["uttidA"] = "some/where/a.wav"
... subwriter["uttidB"] = "some/where/b.wav"
"""
def __init__(self, p: Union[Path, str]):
self.path = Path(p)
self.children = {}
self.fd = None
self.has_children = False
self.keys = set()
def __enter__(self):
return self
def __getitem__(self, key: str) -> "DatadirWriter":
if self.fd is not None:
raise RuntimeError("This writer points out a file")
if key not in self.children:
w = DatadirWriter((self.path / key))
self.children[key] = w
self.has_children = True
retval = self.children[key]
return retval
def __setitem__(self, key: str, value: str):
if self.has_children:
raise RuntimeError("This writer points out a directory")
if key in self.keys:
warnings.warn(f"Duplicated: {key}", stacklevel=1)
if self.fd is None:
self.path.parent.mkdir(parents=True, exist_ok=True)
self.fd = self.path.open("w", encoding="utf-8")
self.keys.add(key)
self.fd.write(f"{key} {value}\n")
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
if self.has_children:
prev_child = None
for child in self.children.values():
child.close()
if prev_child is not None and prev_child.keys != child.keys:
warnings.warn(
f"Ids are mismatching between "
f"{prev_child.path} and {child.path}",
stacklevel=1)
prev_child = child
elif self.fd is not None:
self.fd.close()
================================================
FILE: wesep/utils/dnsmos.py
================================================
import json
import math
import librosa
import numpy as np
import requests
import torch
import torchaudio
SAMPLING_RATE = 16000
INPUT_LENGTH = 9.01
# URL for the web service
SCORING_URI_DNSMOS = "https://dnsmos.azurewebsites.net/score"
SCORING_URI_DNSMOS_P835 = (
"https://dnsmos.azurewebsites.net/v1/dnsmosp835/score")
def poly1d(coefficients, use_numpy=False):
if use_numpy:
return np.poly1d(coefficients)
coefficients = tuple(reversed(coefficients))
def func(p):
return sum(coef * p**i for i, coef in enumerate(coefficients))
return func
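# Example usage (a sketch): coefficients follow numpy's highest-degree-first
# convention, so poly1d([2, 3, 1]) evaluates 2*p**2 + 3*p + 1.
#     f = poly1d([2, 3, 1])
#     f(2.0)  # -> 15.0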
class DNSMOS_web:
# ported from
# https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/dnsmos.py
def __init__(self, auth_key):
self.auth_key = auth_key
def __call__(self, aud, input_fs, fname="", method="p808"):
if input_fs != SAMPLING_RATE:
audio = librosa.resample(aud,
orig_sr=input_fs,
target_sr=SAMPLING_RATE)
else:
audio = aud
# Set the content type
headers = {"Content-Type": "application/json"}
# If authentication is enabled, set the authorization header
headers["Authorization"] = f"Basic {self.auth_key}"
fname = fname + ".wav" if fname else "audio.wav"
data = {"data": audio.tolist(), "filename": fname}
input_data = json.dumps(data)
# Make the request and display the response
if method == "p808":
u = SCORING_URI_DNSMOS
else:
u = SCORING_URI_DNSMOS_P835
resp = requests.post(u, data=input_data, headers=headers)
score_dict = resp.json()
return score_dict
class DNSMOS_local:
# ported from
# https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/dnsmos_local.py
def __init__(
self,
primary_model_path,
p808_model_path,
use_gpu=False,
convert_to_torch=False,
gpu_device=None,
):
self.convert_to_torch = convert_to_torch
self.use_gpu = use_gpu
self.gpu_device = gpu_device
if convert_to_torch:
try:
from onnx2torch import convert
except ModuleNotFoundError:
raise RuntimeError(
"Please install onnx2torch manually and retry!") from None
if primary_model_path is not None:
self.primary_model = convert(primary_model_path).eval()
self.p808_model = convert(p808_model_path).eval()
self.spectrogram = torchaudio.transforms.Spectrogram(
n_fft=321, hop_length=160, pad_mode="constant")
self.to_db = torchaudio.transforms.AmplitudeToDB("power",
top_db=80.0)
if use_gpu:
if gpu_device is not None:
torch.cuda.set_device(gpu_device)
if primary_model_path is not None:
self.primary_model = self.primary_model.cuda()
self.p808_model = self.p808_model.cuda()
self.spectrogram = self.spectrogram.cuda()
else:
try:
import onnxruntime as ort
except ModuleNotFoundError:
raise RuntimeError(
"Please install onnxruntime manually and retry!") from None
prvd = ("CUDAExecutionProvider"
if use_gpu else "CPUExecutionProvider")
if primary_model_path is not None:
self.onnx_sess = ort.InferenceSession(primary_model_path,
providers=[prvd])
self.p808_onnx_sess = ort.InferenceSession(p808_model_path,
providers=[prvd])
if self.gpu_device is not None:
self.onnx_sess.set_providers([prvd],
[{
"device_id": gpu_device
}])
self.p808_onnx_sess.set_providers(
[prvd], [{
"device_id": gpu_device
}])
def audio_melspec(
self,
audio,
n_mels=120,
frame_size=320,
hop_length=160,
sr=16000,
to_db=True,
):
if self.convert_to_torch:
specgram = self.spectrogram(audio)
fb = torch.as_tensor(
librosa.filters.mel(sr=sr, n_fft=frame_size + 1,
n_mels=n_mels).T,
dtype=audio.dtype,
device=audio.device,
)
mel_spec = torch.matmul(specgram.transpose(-1, -2),
fb).transpose(-1, -2)
if to_db:
self.to_db.db_multiplier = math.log10(
max(self.to_db.amin, torch.max(mel_spec)))
mel_spec = (self.to_db(mel_spec) + 40) / 40
else:
mel_spec = librosa.feature.melspectrogram(
y=audio,
sr=sr,
n_fft=frame_size + 1,
hop_length=hop_length,
n_mels=n_mels,
)
if to_db:
mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) +
40) / 40
return mel_spec.T
def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
flag = not self.convert_to_torch
if is_personalized_MOS:
p_ovr = poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046],
flag)
p_sig = poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726],
flag)
p_bak = poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132],
flag)
else:
p_ovr = poly1d([-0.06766283, 1.11546468, 0.04602535], flag)
p_sig = poly1d([-0.08397278, 1.22083953, 0.0052439], flag)
p_bak = poly1d([-0.13166888, 1.60915514, -0.39604546], flag)
sig_poly = p_sig(sig)
bak_poly = p_bak(bak)
ovr_poly = p_ovr(ovr)
return sig_poly, bak_poly, ovr_poly
def __call__(self, aud, input_fs, is_personalized_MOS=False):
if self.convert_to_torch:
if self.use_gpu:
if self.gpu_device is not None:
device = f"cuda:{self.gpu_device}"
else:
device = "cuda"
else:
device = "cpu"
if isinstance(aud, torch.Tensor):
aud = aud.to(device=device)
else:
aud = torch.as_tensor(aud, dtype=torch.float32, device=device)
else:
aud = (aud.cpu().detach().numpy()
if isinstance(aud, torch.Tensor) else aud)
if input_fs != SAMPLING_RATE:
if self.convert_to_torch:
audio = torch.as_tensor(
librosa.resample(
aud.detach().cpu().numpy(),
orig_sr=input_fs,
target_sr=SAMPLING_RATE,
),
dtype=aud.dtype,
device=aud.device,
)
else:
audio = librosa.resample(aud,
orig_sr=input_fs,
target_sr=SAMPLING_RATE)
else:
audio = aud
len_samples = int(INPUT_LENGTH * SAMPLING_RATE)
while len(audio) < len_samples:
if self.convert_to_torch:
audio = torch.cat((audio, audio))
else:
audio = np.append(audio, audio)
num_hops = int(np.floor(len(audio) / SAMPLING_RATE) - INPUT_LENGTH) + 1
hop_len_samples = SAMPLING_RATE
predicted_mos_sig_seg_raw = []
predicted_mos_bak_seg_raw = []
predicted_mos_ovr_seg_raw = []
predicted_mos_sig_seg = []
predicted_mos_bak_seg = []
predicted_mos_ovr_seg = []
predicted_p808_mos = []
for idx in range(num_hops):
audio_seg = audio[int(idx *
hop_len_samples):int((idx + INPUT_LENGTH) *
hop_len_samples)]
if len(audio_seg) < len_samples:
continue
if self.convert_to_torch:
input_features = audio_seg.float()[None, :]
p808_input_features = self.audio_melspec(
audio=audio_seg[:-160]).float()[None, :, :]
p808_mos = self.p808_model(p808_input_features)
mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.primary_model(
input_features)[0]
else:
input_features = np.array(audio_seg).astype("float32")[
np.newaxis, :]
p808_input_features = np.array(
self.audio_melspec(audio=audio_seg[:-160])).astype(
"float32")[np.newaxis, :, :]
p808_mos = self.p808_onnx_sess.run(
None, {"input_1": p808_input_features})[0][0][0]
mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(
None, {"input_1": input_features})[0][0]
mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(
mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS)
predicted_mos_sig_seg_raw.append(mos_sig_raw)
predicted_mos_bak_seg_raw.append(mos_bak_raw)
predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
predicted_mos_sig_seg.append(mos_sig)
predicted_mos_bak_seg.append(mos_bak)
predicted_mos_ovr_seg.append(mos_ovr)
predicted_p808_mos.append(p808_mos)
to_array = torch.stack if self.convert_to_torch else np.array
return {
"OVRL_raw": to_array(predicted_mos_ovr_seg_raw).mean(),
"SIG_raw": to_array(predicted_mos_sig_seg_raw).mean(),
"BAK_raw": to_array(predicted_mos_bak_seg_raw).mean(),
"OVRL": to_array(predicted_mos_ovr_seg).mean(),
"SIG": to_array(predicted_mos_sig_seg).mean(),
"BAK": to_array(predicted_mos_bak_seg).mean(),
"P808_MOS": to_array(predicted_p808_mos).mean(),
}
================================================
FILE: wesep/utils/executor.py
================================================
# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)
# 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import tableprint as tp
import torch
from wesep.utils.funcs import clip_gradients, compute_fbank, apply_cmvn
import random
class Executor:
def __init__(self):
self.step = 0
def train(
self,
dataloader,
models,
epoch_iter,
optimizers,
criterion,
schedulers,
scaler,
epoch,
enable_amp,
logger,
clip_grad=5.0,
log_batch_interval=100,
device=torch.device("cuda"),
se_loss_weight=1.0,
multi_task=False,
SSA_enroll_prob=0,
fbank_args=None,
sample_rate=16000,
speaker_feat=True
):
"""Train one epoch"""
model = models[0]
optimizer = optimizers[0]
scheduler = schedulers[0]
model.train()
log_interval = log_batch_interval
accum_grad = 1
losses = []
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_context = model.join
else:
model_context = nullcontext
with model_context():
for i, batch in enumerate(dataloader):
features = batch["wav_mix"]
targets = batch["wav_targets"]
# embeddings when not joint training; enrollment waveforms
# when joint training
enroll = batch["spk_embeds"]
# spk_label is an empty list when joint training with
# multi-task is not used
spk_label = batch["spk_label"]
cur_iter = (epoch - 1) * epoch_iter + i
scheduler.step(cur_iter)
features = features.float().to(device) # (B, T)
targets = targets.float().to(device)
enroll = enroll.float().to(device)
spk_label = spk_label.to(device)
with torch.cuda.amp.autocast(enabled=enable_amp):
if SSA_enroll_prob > 0:
if SSA_enroll_prob > random.random():
with torch.no_grad():
outputs = model(features, enroll)
est_speech = outputs[0]
self_fbank = est_speech
if fbank_args is not None and speaker_feat:
self_fbank = compute_fbank(
est_speech, **fbank_args,
sample_rate=sample_rate)
self_fbank = apply_cmvn(self_fbank)
outputs = model(features, self_fbank)
else:
outputs = model(features, enroll)
else:
outputs = model(features, enroll)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
loss = 0
for ii in range(len(criterion)):
# se_loss_weight = (loss_positions, loss_weights), two 2-D
# lists, e.g. ([[0], [1]], [[1.0], [0.5]])
for ji in range(len(se_loss_weight[0][ii])):
if (multi_task and criterion[ii].__class__.__name__
== "CrossEntropyLoss"):
loss += se_loss_weight[1][ii][ji] * (
criterion[ii](
outputs[se_loss_weight[0][ii][ji]],
spk_label,
).mean() / accum_grad)
continue
loss += se_loss_weight[1][ii][ji] * (criterion[ii](
outputs[se_loss_weight[0][ii][ji]],
targets).mean() / accum_grad)
losses.append(loss.item())
total_loss_avg = sum(losses) / len(losses)
# update the model
optimizer.zero_grad()
# scaler does nothing here if enable_amp=False
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
clip_gradients(model, clip_grad)
scaler.step(optimizer)
scaler.update()
if (i + 1) % log_interval == 0:
logger.info(
tp.row(
(
"TRAIN",
epoch,
i + 1,
total_loss_avg * accum_grad,
optimizer.param_groups[0]["lr"],
),
width=10,
style="grid",
))
if (i + 1) == epoch_iter:
break
total_loss_avg = sum(losses) / len(losses)
return total_loss_avg, 0
def cv(
self,
dataloader,
models,
val_iter,
criterion,
epoch,
enable_amp,
logger,
log_batch_interval=100,
device=torch.device("cuda"),
):
"""Cross validation on"""
model = models[0]
model.eval()
log_interval = log_batch_interval
losses = []
with torch.no_grad():
for i, batch in enumerate(dataloader):
features = batch["wav_mix"]
targets = batch["wav_targets"]
enroll = batch["spk_embeds"]
features = features.float().to(device) # (B, T)
targets = targets.float().to(device)
enroll = enroll.float().to(device)
with torch.cuda.amp.autocast(enabled=enable_amp):
outputs = model(features, enroll)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
# By default, the first loss is used as the indicator
# of the validation set.
loss = criterion[0](outputs[0], targets).mean()
losses.append(loss.item())
total_loss_avg = sum(losses) / len(losses)
if (i + 1) % log_interval == 0:
logger.info(
tp.row(
("VAL", epoch, i + 1, total_loss_avg, "-"),
width=10,
style="grid",
))
if (i + 1) == val_iter:
break
return total_loss_avg, 0
================================================
FILE: wesep/utils/executor_gan.py
================================================
# Copyright (c) 2021 Hongji Wang (jijijiang77@gmail.com)
# 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import tableprint as tp
import torch
import torch.nn.functional as F
from wesep.utils.funcs import clip_gradients
from wesep.utils.score import batch_evaluation, cal_PESQ_norm
class ExecutorGAN:
def __init__(self):
self.step = 0
def train(
self,
dataloader,
models,
epoch_iter,
optimizers,
criterion,
schedulers,
scaler,
epoch,
enable_amp,
logger,
clip_grad=5.0,
log_batch_interval=100,
device=torch.device("cuda"),
se_loss_weight=0.95,
gan_loss_weight=0.05,
multi_task=False,
):
"""Train one epoch"""
assert (len(models) == len(optimizers) == len(schedulers) ==
2), "Currently only support one discriminator"
model, discriminator = models
optimizer, optimizer_dis = optimizers
scheduler, scheduler_dis = schedulers
model.train()
discriminator.train()
log_interval = log_batch_interval
accum_grad = 1
losses = []
se_losses = []
dis_losses = []
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_context = model.join
else:
model_context = nullcontext
with model_context():
for i, batch in enumerate(dataloader):
features = batch["wav_mix"]
targets = batch["wav_targets"]
# embeddings when not joint training; enrollment
# waveforms when joint training
enroll = batch["spk_embeds"]
# spk_label is an empty list when joint training with
# multi-task is not used
spk_label = batch["spk_label"]
one_labels = torch.ones(features.size(0))
cur_iter = (epoch - 1) * epoch_iter + i
scheduler.step(cur_iter)
scheduler_dis.step(cur_iter)
features = features.float().to(device)
targets = targets.float().to(device)
enroll = enroll.float().to(device)
spk_label = spk_label.to(device)
one_labels = one_labels.float().to(device)
# calculate discriminator loss
with torch.cuda.amp.autocast(enabled=enable_amp):
outputs = model(features, enroll)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
# outputs is a list of tensors, each tensor has shape
# (Batch, samples)
if multi_task:
# remove the predicted spk_label from the outputs list
enhanced_wavs = torch.stack(outputs[:-1], dim=0)
else:
# enhanced_wavs: [N, Batch, samples], N is the number
# of output of the model
enhanced_wavs = torch.stack(outputs, dim=0)
d_loss = self._calculate_discriminator_loss(
discriminator,
targets,
enhanced_wavs.detach(),
features.detach(),
)
dis_losses.append(d_loss.item())
total_dis_loss_avg = sum(dis_losses) / len(dis_losses)
# update the discriminator
optimizer_dis.zero_grad()
# scaler does nothing here if enable_amp=False
scaler.scale(d_loss).backward()
scaler.unscale_(optimizer_dis)
clip_gradients(discriminator, clip_grad)
scaler.step(optimizer_dis)
scaler.update()
# calculate generator loss
with torch.cuda.amp.autocast(enabled=enable_amp):
se_loss = 0
for ii in range(len(criterion)):
# se_loss_weight[0]: 2-D array of loss positions;
# se_loss_weight[1]: 2-D array of loss weights.
for ji in range(len(se_loss_weight[0][ii])):
if multi_task and ii == (len(criterion) - 1):
se_loss += se_loss_weight[1][ii][ji] * (
criterion[ii](
outputs[se_loss_weight[0][ii][ji]],
spk_label,
).mean() / accum_grad)
continue
se_loss += se_loss_weight[1][ii][ji] * (
criterion[ii]
(outputs[se_loss_weight[0][ii][ji]],
targets).mean() / accum_grad)
gan_loss = 0
len_output = (len(outputs) -
1 if multi_task else len(outputs))
for j in range(len_output):
enhanced_fake_metric = discriminator(
targets, outputs[j])
gan_loss += F.mse_loss(
enhanced_fake_metric.flatten(),
one_labels,
)
g_loss = se_loss + gan_loss_weight * gan_loss
losses.append(g_loss.item())
se_losses.append(se_loss.item())
total_loss_avg = sum(losses) / len(losses)
total_se_loss_avg = sum(se_losses) / len(se_losses)
                # update the generator
optimizer.zero_grad()
# scaler does nothing here if enable_amp=False
scaler.scale(g_loss).backward()
scaler.unscale_(optimizer)
clip_gradients(model, clip_grad)
scaler.step(optimizer)
scaler.update()
if (i + 1) % log_interval == 0:
logger.info(
tp.row(
(
"TRAIN",
epoch,
i + 1,
total_se_loss_avg,
total_loss_avg * accum_grad,
total_dis_loss_avg * accum_grad,
optimizer.param_groups[0]["lr"],
),
width=10,
style="grid",
))
if (i + 1) == epoch_iter:
break
total_loss_avg = sum(losses) / len(losses)
total_dis_loss_avg = sum(dis_losses) / len(dis_losses)
return total_loss_avg, total_dis_loss_avg
def cv(
self,
dataloader,
models,
val_iter,
criterion,
epoch,
enable_amp,
logger,
log_batch_interval=100,
device=torch.device("cuda"),
):
"""Cross validation on"""
assert len(models) == 2, "Currently only support one discriminator"
model, discriminator = models
model.eval()
discriminator.eval()
log_interval = log_batch_interval
losses = []
se_losses = []
dis_losses = []
with torch.no_grad():
for i, batch in enumerate(dataloader):
features = batch["wav_mix"]
targets = batch["wav_targets"]
enroll = batch["spk_embeds"]
one_labels = torch.ones(features.size(0))
features = features.float().to(device) # (B,T,F)
targets = targets.float().to(device)
enroll = enroll.float().to(device)
one_labels = one_labels.float().to(device)
with torch.cuda.amp.autocast(enabled=enable_amp):
outputs = model(features, enroll)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
# calculate discriminator loss
d_loss = self._calculate_discriminator_loss(
discriminator,
targets,
outputs[0].unsqueeze(0),
features,
)
dis_losses.append(d_loss.item())
total_dis_loss_avg = sum(dis_losses) / len(dis_losses)
# calculate generator loss
with torch.cuda.amp.autocast(enabled=enable_amp):
se_loss = criterion[0](outputs[0], targets).mean()
enhanced_fake_metric = discriminator(targets, outputs[0])
gan_loss = F.mse_loss(
enhanced_fake_metric.flatten(),
one_labels,
)
g_loss = se_loss + gan_loss
losses.append(g_loss.item())
se_losses.append(se_loss.item())
total_loss_avg = sum(losses) / len(losses)
total_se_loss_avg = sum(se_losses) / len(se_losses)
if (i + 1) % log_interval == 0:
logger.info(
tp.row(
(
"VAL",
epoch,
i + 1,
total_se_loss_avg,
total_loss_avg,
total_dis_loss_avg,
"-",
),
width=10,
style="grid",
))
if (i + 1) == val_iter:
break
return total_loss_avg, total_dis_loss_avg
def mse_loss(self, output, target):
return F.mse_loss(output.flatten(), target)
def _calculate_discriminator_loss(
self,
discriminator,
clean_wavs,
enhanced_wavs,
noisy_wavs,
):
"""Calculate the discriminator loss
Args:
discriminator (torch.nn.Module): the discriminator model
clean_wavs (torch.Tensor): the clean waveforms, [Batch, samples]
enhanced_wavs (torch.Tensor): the predicted waveforms,
[N, Batch, samples]
noisy_wavs (torch.Tensor): the noisy waveforms, [Batch, samples]
Returns:
torch.Tensor: the discriminator loss
"""
def calculate_mse_loss(output, target):
if target is not None:
target = torch.FloatTensor(target).to(device)
return self.mse_loss(output, target)
return 0
device = clean_wavs.device
one_labels = torch.ones(clean_wavs.size(0)).float().to(device)
noisy_fake_metric = discriminator(clean_wavs, noisy_wavs)
clean_fake_metric = discriminator(clean_wavs, clean_wavs)
audio_ref = clean_wavs.detach().cpu().numpy()
audio_noisy = noisy_wavs.detach().cpu().numpy()
noisy_real_metric = batch_evaluation(cal_PESQ_norm,
audio_noisy,
audio_ref,
parallel=False)
loss_d_clean = self.mse_loss(clean_fake_metric, one_labels)
loss_d_noisy = calculate_mse_loss(noisy_fake_metric, noisy_real_metric)
d_loss = loss_d_clean + loss_d_noisy
# unbind enhanced_wavs to get a list of tensors,
# each tensor has shape (Batch, samples)
enhanced_wavs = torch.unbind(enhanced_wavs, dim=0)
for enhanced_wav in enhanced_wavs:
enhanced_fake_metric = discriminator(clean_wavs, enhanced_wav)
audio_est = enhanced_wav.detach().cpu().numpy()
enhanced_real_metric = batch_evaluation(cal_PESQ_norm,
audio_est,
audio_ref,
parallel=False)
loss_d_enhanced = calculate_mse_loss(enhanced_fake_metric,
enhanced_real_metric)
d_loss += loss_d_enhanced
return d_loss
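# A minimal, illustrative sketch of the MetricGAN-style objective implemented
# above: the discriminator is trained to predict the normalized PESQ of
# (reference, estimate) pairs, with the clean pair targeting a score of 1.
# The toy discriminator and the shapes below are stand-ins, not part of wesep.
if __name__ == "__main__":
    class _ToyDiscriminator(torch.nn.Module):
        """Map a (reference, estimate) pair to one score per utterance."""
        def forward(self, ref, est):
            return torch.sigmoid((ref * est).mean(dim=-1, keepdim=True))

    _executor = ExecutorGAN()
    _clean = torch.randn(2, 16000)
    _noisy = _clean + 0.1 * torch.randn(2, 16000)
    # one model output: [N=1, Batch, samples]
    _enhanced = _clean.unsqueeze(0) + 0.05 * torch.randn(1, 2, 16000)
    _d_loss = _executor._calculate_discriminator_loss(
        _ToyDiscriminator(), _clean, _enhanced, _noisy)
    print("toy discriminator loss:", float(_d_loss))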
================================================
FILE: wesep/utils/file_utils.py
================================================
import collections
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import kaldiio
import numpy as np
import soundfile
def read_lists(list_file):
"""list_file: only 1 column"""
lists = []
with open(list_file, "r", encoding="utf8") as fin:
for line in fin:
lists.append(line.strip())
return lists
def read_vec_scp_file(scp_file):
"""
Read the pre-extracted kaldi-format speaker embeddings.
:param scp_file: path to xvector.scp
:return: dict {wav_name: embedding}
"""
samples_dict = {}
for key, vec in kaldiio.load_scp_sequential(scp_file):
if len(vec.shape) == 1:
vec = np.expand_dims(vec, 0)
samples_dict[key] = vec
return samples_dict
def norm_embeddings(embeddings, kaldi_style=True):
"""
Norm embeddings to unit length
:param embeddings: input embeddings
    :param kaldi_style: if true, scale the unit-length embedding by the
        square root of the embedding dimension (Kaldi convention)
:return:
"""
scale = math.sqrt(embeddings.shape[-1]) if kaldi_style else 1.0
if len(embeddings.shape) == 2:
return (scale * embeddings.transpose() /
np.linalg.norm(embeddings, axis=1)).transpose()
elif len(embeddings.shape) == 1:
return scale * embeddings / np.linalg.norm(embeddings)
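# Example (illustrative): with kaldi_style=True a 256-dim embedding is
# scaled to norm sqrt(256) == 16; with kaldi_style=False it is unit length.
#   >>> emb = np.random.randn(256)
#   >>> np.linalg.norm(norm_embeddings(emb))                     # ~16.0
#   >>> np.linalg.norm(norm_embeddings(emb, kaldi_style=False))  # ~1.0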
def read_label_file(label_file):
"""
Read the utt2spk file
:param label_file: the path to utt2spk
:return: dict {wav_name: spk_id}
"""
labels_dict = {}
with open(label_file, "r") as fin:
for line in fin:
tokens = line.strip().split()
labels_dict[tokens[0]] = tokens[1]
return labels_dict
def load_speaker_embeddings(scp_file, utt2spk_file):
"""
:param scp_file:
:param utt2spk_file:
:return: {spk1: [emb1, emb2 ...], spk2: [emb1, emb2...]}
"""
samples_dict = read_vec_scp_file(scp_file)
labels_dict = read_label_file(utt2spk_file)
spk2embeds = {}
for key, vec in samples_dict.items():
if len(vec.shape) == 1:
vec = np.expand_dims(vec, 0)
label = labels_dict[key]
if label in spk2embeds.keys():
spk2embeds[label].append(vec)
else:
spk2embeds[label] = [vec]
return spk2embeds
# ported from
# https://github.com/espnet/espnet/blob/master/espnet2/fileio/read_text.py
def read_2columns_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 columns as dict object.
Examples:
wav.scp:
key1 /some/path/a.wav
key2 /some/path/b.wav
>>> read_2columns_text('wav.scp')
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
data = {}
with Path(path).open("r", encoding="utf-8") as f:
for linenum, line in enumerate(f, 1):
sps = line.rstrip().split(maxsplit=1)
if len(sps) == 1:
k, v = sps[0], ""
else:
k, v = sps
if k in data:
raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
data[k] = v
return data
# ported from
# https://github.com/espnet/espnet/blob/master/espnet2/fileio/read_text.py
def read_multi_columns_text(
path: Union[Path, str],
return_unsplit: bool = False
) -> Tuple[Dict[str, List[str]], Optional[Dict[str, str]]]:
"""Read a text file having 2 or more columns as dict object.
Examples:
wav.scp:
key1 /some/path/a1.wav /some/path/a2.wav
key2 /some/path/b1.wav /some/path/b2.wav /some/path/b3.wav
key3 /some/path/c1.wav
...
>>> read_multi_columns_text('wav.scp')
{'key1': ['/some/path/a1.wav', '/some/path/a2.wav'],
'key2': ['/some/path/b1.wav', '/some/path/b2.wav',
'/some/path/b3.wav'],
'key3': ['/some/path/c1.wav']}
"""
data = {}
if return_unsplit:
unsplit_data = {}
else:
unsplit_data = None
with Path(path).open("r", encoding="utf-8") as f:
for linenum, line in enumerate(f, 1):
sps = line.rstrip().split(maxsplit=1)
if len(sps) == 1:
k, v = sps[0], ""
else:
k, v = sps
if k in data:
raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
data[k] = v.split() if v != "" else [""]
if return_unsplit:
unsplit_data[k] = v
return data, unsplit_data
# ported from
# https://github.com/espnet/espnet/blob/master/espnet2/fileio/sound_scp.py
def soundfile_read(
wavs: Union[str, List[str]],
dtype=None,
always_2d: bool = False,
concat_axis: int = 1,
start: int = 0,
end: int = None,
return_subtype: bool = False,
) -> Tuple[np.ndarray, int]:
if isinstance(wavs, str):
wavs = [wavs]
arrays = []
subtypes = []
prev_rate = None
prev_wav = None
for wav in wavs:
with soundfile.SoundFile(wav) as f:
f.seek(start)
if end is not None:
frames = end - start
else:
frames = -1
if dtype == "float16":
array = f.read(
frames,
dtype="float32",
always_2d=always_2d,
).astype(dtype)
else:
array = f.read(frames, dtype=dtype, always_2d=always_2d)
rate = f.samplerate
subtype = f.subtype
subtypes.append(subtype)
if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
# array: (Time, Channel)
array = array[:, None]
if prev_wav is not None:
if prev_rate != rate:
raise RuntimeError(
f"{prev_wav} and {wav} have mismatched sampling rate: "
f"{prev_rate} != {rate}")
dim1 = arrays[0].shape[1 - concat_axis]
dim2 = array.shape[1 - concat_axis]
if dim1 != dim2:
raise RuntimeError(
"Shapes must match with "
f"{1 - concat_axis} axis, but gut {dim1} and {dim2}")
prev_rate = rate
prev_wav = wav
arrays.append(array)
if len(arrays) == 1:
array = arrays[0]
else:
array = np.concatenate(arrays, axis=concat_axis)
if return_subtype:
return array, rate, subtypes
else:
return array, rate
# ported from
# https://github.com/espnet/espnet/blob/master/espnet2/fileio/sound_scp.py
class SoundScpReader(collections.abc.Mapping):
"""Reader class for 'wav.scp'.
Examples:
wav.scp is a text file that looks like the following:
key1 /some/path/a.wav
key2 /some/path/b.wav
key3 /some/path/c.wav
key4 /some/path/d.wav
...
>>> reader = SoundScpReader('wav.scp')
>>> rate, array = reader['key1']
    If multi_columns=True is given and multiple files are listed on one
    line with a space delimiter, the output arrays are concatenated
    along the channel dimension:
        key1 /some/path/a.wav /some/path/a2.wav
        key2 /some/path/b.wav /some/path/b2.wav
        ...
    >>> reader = SoundScpReader('wav.scp', multi_columns=True)
    >>> rate, array = reader['key1']
    In the above case, a.wav and a2.wav are concatenated.
    Note that even if multi_columns=True is given, SoundScpReader still
    supports a normal wav.scp, i.e., one wav file per line. The option is
    disabled by default because a dict[str, list[str]] object has to be
    kept, which increases the required amount of memory.
"""
def __init__(
self,
fname,
dtype=None,
always_2d: bool = False,
multi_columns: bool = False,
concat_axis=1,
):
self.fname = fname
self.dtype = dtype
self.always_2d = always_2d
if multi_columns:
self.data, _ = read_multi_columns_text(fname)
else:
self.data = read_2columns_text(fname)
self.multi_columns = multi_columns
self.concat_axis = concat_axis
def __getitem__(self, key) -> Tuple[int, np.ndarray]:
wavs = self.data[key]
array, rate = soundfile_read(
wavs,
dtype=self.dtype,
always_2d=self.always_2d,
concat_axis=self.concat_axis,
)
# Returned as scipy.io.wavread's order
return rate, array
def get_path(self, key):
return self.data[key]
    def __contains__(self, item):
        return item in self.data
def __len__(self):
return len(self.data)
def __iter__(self):
return iter(self.data)
def keys(self):
return self.data.keys()
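# A minimal, self-contained usage sketch for SoundScpReader; the key and
# file names below are illustrative only.
if __name__ == "__main__":
    import os
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdir:
        wav_path = os.path.join(tmpdir, "a.wav")
        soundfile.write(wav_path, np.zeros(16000, dtype=np.float32), 16000)
        scp_path = os.path.join(tmpdir, "wav.scp")
        with open(scp_path, "w", encoding="utf-8") as f:
            f.write("key1 " + wav_path + "\n")
        reader = SoundScpReader(scp_path)
        rate, array = reader["key1"]
        print(rate, array.shape)  # 16000 (16000,)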
================================================
FILE: wesep/utils/funcs.py
================================================
# Created on 2018/12
# Author: Kaituo XU
import math
import torch
import torchaudio.compliance.kaldi as kaldi
def overlap_and_add(signal, frame_step):
"""Reconstructs a signal from a framed representation.
Adds potentially overlapping frames of a signal with shape
`[..., frames, frame_length]`, offsetting subsequent frames
by `frame_step`.
The resulting tensor has shape `[..., output_size]` where
output_size = (frames - 1) * frame_step + frame_length
Args:
signal: A [..., frames, frame_length] Tensor. All dimensions
may be unknown, and rank must be at least 2.
frame_step: An integer denoting overlap offsets. Must be
less than or equal to frame_length.
Returns:
A Tensor with shape [..., output_size] containing the overlap-added
frames of signal's inner-most two dimensions.
output_size = (frames - 1) * frame_step + frame_length
Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/
contrib/signal/python/ops/reconstruction_ops.py
"""
outer_dimensions = signal.size()[:-2]
frames, frame_length = signal.size()[-2:]
subframe_length = math.gcd(frame_length,
frame_step) # gcd=Greatest Common Divisor
subframe_step = frame_step // subframe_length
subframes_per_frame = frame_length // subframe_length
output_size = frame_step * (frames - 1) + frame_length
output_subframes = output_size // subframe_length
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame,
subframe_step)
frame = signal.new_tensor(frame).long() # signal may in GPU or CPU
frame = frame.contiguous().view(-1)
result = signal.new_zeros(*outer_dimensions, output_subframes,
subframe_length)
result.index_add_(-2, frame, subframe_signal)
result = result.view(*outer_dimensions, -1)
return result
def remove_pad(inputs, inputs_lengths):
"""
Args:
inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
inputs_lengths: torch.Tensor, [B]
Returns:
results: a list containing B items, each item is [C, T], T varies
"""
results = []
dim = inputs.dim()
if dim == 3:
C = inputs.size(1)
for input, length in zip(inputs, inputs_lengths):
if dim == 3: # [B, C, T]
results.append(input[:, :length].view(C, -1).cpu().numpy())
elif dim == 2: # [B, T]
results.append(input[:length].view(-1).cpu().numpy())
return results
def clip_gradients(model, clip):
norms = []
for _, p in model.named_parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
norms.append(param_norm.item())
clip_coef = clip / (param_norm + 1e-6)
if clip_coef < 1:
p.grad.data.mul_(clip_coef)
return norms
def compute_fbank(
data,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=1.0,
sample_rate=16000,
):
"""Extract fbank"""
fbank_list = []
for index_ in range(data.shape[0]):
waveform = data[index_, :].unsqueeze(0)
waveform = waveform * (1 << 15)
mat = kaldi.fbank(
waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
sample_frequency=sample_rate,
window_type="hamming",
use_energy=False,
)
fbank_list.append(mat.unsqueeze(0))
np_fbank = torch.cat(fbank_list, 0)
return np_fbank
def apply_cmvn(data, norm_mean=True, norm_var=False):
"""Apply CMVN
Args:
data: Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1',
'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']
Returns:
Iterable['spk1', 'spk2', 'wav_mix', 'sample_rate', 'wav_spk1',
'wav_spk2', 'key', 'num_speaker', 'embed_spk1', 'embed_spk2']
"""
mat_list = []
for index_ in range(data.shape[0]):
mat = data[index_, :, :]
if norm_mean:
mat = mat - torch.mean(mat, dim=0)
if norm_var:
mat = mat / torch.sqrt(torch.var(mat, dim=0) + 1e-8)
mat = mat.unsqueeze(0)
mat_list.append(mat)
np_mat = torch.cat(mat_list, 0)
return np_mat
if __name__ == "__main__":
torch.manual_seed(123)
M, C, K, N = 2, 2, 3, 4
frame_step = 2
signal = torch.randint(5, (M, C, K, N))
result = overlap_and_add(signal, frame_step)
print(signal)
print(result)
================================================
FILE: wesep/utils/losses.py
================================================
"""Get a loss function with its name from the configuration file."""
import auraloss
import torch.nn as nn
import torchmetrics.audio as audio_metrics
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
valid_losses = {}
torch_losses = {
"L1": nn.L1Loss(),
"L2": nn.MSELoss(),
"CE": nn.CrossEntropyLoss(),
}
torchmetrics_losses = {
# Not tested
"PIT":
audio_metrics.PermutationInvariantTraining(
scale_invariant_signal_noise_ratio),
}
auraloss_losses = {
"STFT": auraloss.freq.STFTLoss(),
"MultiResolutionSTFT": auraloss.freq.MultiResolutionSTFTLoss(),
"SISDR": auraloss.time.SISDRLoss(),
"SISNR": auraloss.time.SISDRLoss(),
"SNR": auraloss.time.SNRLoss(),
}
valid_losses.update(torch_losses)
valid_losses.update(auraloss_losses)
valid_losses.update(torchmetrics_losses)
def parse_loss(loss):
loss_functions = []
if not isinstance(loss, list):
loss = [loss]
for i in range(len(loss)):
loss_name = loss[i]
loss_functions.append(valid_losses.get(loss_name))
return loss_functions
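# A brief usage sketch: map loss names from a config to callables and
# evaluate them on dummy signals. The names follow the tables above; the
# shapes are illustrative.
if __name__ == "__main__":
    import torch
    names = ["SISDR", "L1"]
    criterion = parse_loss(names)
    est = torch.randn(2, 16000)
    ref = torch.randn(2, 16000)
    for name, fn in zip(names, criterion):
        print(name, float(fn(est, ref)))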
================================================
FILE: wesep/utils/schedulers.py
================================================
# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
# 2021 Zhengyang Chen (chenzhengyang117@gmail.com)
# 2022 Hongji Wang (jijijiang77@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
class MarginScheduler:
def __init__(
self,
model,
epoch_iter,
increase_start_epoch,
fix_start_epoch,
initial_margin,
final_margin,
update_margin,
increase_type="exp",
):
"""
The margin is fixed as initial_margin before increase_start_epoch,
between increase_start_epoch and fix_start_epoch, the margin is
exponentially increasing from initial_margin to final_margin
after fix_start_epoch, the margin is fixed as final_margin.
"""
self.model = model
self.increase_start_iter = (increase_start_epoch - 1) * epoch_iter
self.fix_start_iter = (fix_start_epoch - 1) * epoch_iter
self.initial_margin = initial_margin
self.final_margin = final_margin
self.increase_type = increase_type
self.fix_already = False
self.current_iter = 0
self.update_margin = update_margin and hasattr(self.model.projection,
"update")
self.increase_iter = self.fix_start_iter - self.increase_start_iter
self.init_margin()
def init_margin(self):
if hasattr(self.model.projection, "update"):
self.model.projection.update(margin=self.initial_margin)
def get_increase_margin(self):
initial_val = 1.0
final_val = 1e-3
current_iter = self.current_iter - self.increase_start_iter
if self.increase_type == "exp": # exponentially increase the margin
ratio = (1.0 - math.exp(
(current_iter / self.increase_iter) *
math.log(final_val / (initial_val + 1e-6))) * initial_val)
else: # linearly increase the margin
ratio = 1.0 * current_iter / self.increase_iter
return (self.initial_margin +
(self.final_margin - self.initial_margin) * ratio)
def step(self, current_iter=None):
if not self.update_margin or self.fix_already:
return
if current_iter is not None:
self.current_iter = current_iter
if self.current_iter >= self.fix_start_iter:
self.fix_already = True
if hasattr(self.model.projection, "update"):
self.model.projection.update(margin=self.final_margin)
elif self.current_iter >= self.increase_start_iter:
if hasattr(self.model.projection, "update"):
self.model.projection.update(margin=self.get_increase_margin())
self.current_iter += 1
def get_margin(self):
try:
margin = self.model.projection.margin
except Exception:
margin = 0.0
return margin
class BaseClass:
"""
Base Class for learning rate scheduler
"""
def __init__(
self,
optimizer,
num_epochs,
epoch_iter,
initial_lr,
final_lr,
warm_up_epoch=6,
scale_ratio=1.0,
warm_from_zero=False,
):
"""
warm_up_epoch: the first warm_up_epoch is the multiprocess
warm-up stage
scale_ratio: multiplied to the current lr in the multiprocess
training process
"""
self.optimizer = optimizer
self.max_iter = num_epochs * epoch_iter
self.initial_lr = initial_lr
self.final_lr = final_lr
self.scale_ratio = scale_ratio
self.current_iter = 0
self.warm_up_iter = warm_up_epoch * epoch_iter
self.warm_from_zero = warm_from_zero
def get_multi_process_coeff(self):
lr_coeff = 1.0 * self.scale_ratio
if self.current_iter < self.warm_up_iter:
if self.warm_from_zero:
lr_coeff = (self.scale_ratio * self.current_iter /
self.warm_up_iter)
elif self.scale_ratio > 1:
lr_coeff = (self.scale_ratio -
1) * self.current_iter / self.warm_up_iter + 1.0
return lr_coeff
def get_current_lr(self):
"""
This function should be implemented in the child class
"""
return 0.0
def get_lr(self):
return self.optimizer.param_groups[0]["lr"]
def set_lr(self):
current_lr = self.get_current_lr()
for param_group in self.optimizer.param_groups:
param_group["lr"] = current_lr
def step(self, current_iter=None):
if current_iter is not None:
self.current_iter = current_iter
self.set_lr()
self.current_iter += 1
def step_return_lr(self, current_iter=None):
if current_iter is not None:
self.current_iter = current_iter
current_lr = self.get_current_lr()
self.current_iter += 1
return current_lr
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
return {
key: value
for key, value in self.__dict__.items() if key != "optimizer"
}
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
class ExponentialDecrease(BaseClass):
def __init__(
self,
optimizer,
num_epochs,
epoch_iter,
initial_lr,
final_lr,
warm_up_epoch=6,
scale_ratio=1.0,
warm_from_zero=False,
):
super().__init__(
optimizer,
num_epochs,
epoch_iter,
initial_lr,
final_lr,
warm_up_epoch,
scale_ratio,
warm_from_zero,
)
def get_current_lr(self):
lr_coeff = self.get_multi_process_coeff()
current_lr = (lr_coeff * self.initial_lr * math.exp(
(self.current_iter / self.max_iter) *
math.log(self.final_lr / self.initial_lr)))
return current_lr
class TriAngular2(BaseClass):
"""
The implementation of https://arxiv.org/pdf/1506.01186.pdf
"""
def __init__(
self,
optimizer,
num_epochs,
epoch_iter,
initial_lr,
final_lr,
warm_up_epoch=6,
scale_ratio=1.0,
cycle_step=2,
reduce_lr_diff_ratio=0.5,
):
super().__init__(
optimizer,
num_epochs,
epoch_iter,
initial_lr,
final_lr,
warm_up_epoch,
scale_ratio,
)
self.reduce_lr_diff_ratio = reduce_lr_diff_ratio
self.cycle_iter = cycle_step * epoch_iter
self.step_size = self.cycle_iter // 2
self.max_lr = initial_lr
self.min_lr = final_lr
self.gap = self.max_lr - self.min_lr
def get_current_lr(self):
lr_coeff = self.get_multi_process_coeff()
point = self.current_iter % self.cycle_iter
cycle_index = self.current_iter // self.cycle_iter
self.max_lr = (self.min_lr +
self.gap * self.reduce_lr_diff_ratio**cycle_index)
if point <= self.step_size:
current_lr = (self.min_lr +
(self.max_lr - self.min_lr) * point / self.step_size)
else:
current_lr = (self.max_lr - (self.max_lr - self.min_lr) *
(point - self.step_size) / self.step_size)
current_lr = lr_coeff * current_lr
return current_lr
def show_lr_curve(scheduler):
import matplotlib.pyplot as plt
lr_list = []
for current_lr in range(0, scheduler.max_iter):
lr_list.append(scheduler.step_return_lr(current_lr))
data_index = list(range(1, len(lr_list) + 1))
plt.plot(data_index, lr_list, "-o", markersize=1)
plt.legend(loc="best")
plt.xlabel("Iteration")
plt.ylabel("LR")
plt.show()
if __name__ == "__main__":
optimizer = None
num_epochs = 6
epoch_iter = 500
initial_lr = 0.6
final_lr = 0.1
warm_up_epoch = 2
scale_ratio = 4
scheduler = ExponentialDecrease(
optimizer,
num_epochs,
epoch_iter,
initial_lr,
final_lr,
warm_up_epoch,
scale_ratio,
)
# scheduler = TriAngular2(optimizer,
# num_epochs,
# epoch_iter,
# initial_lr,
# final_lr,
# warm_up_epoch,
# scale_ratio,
# cycle_step=2,
# reduce_lr_diff_ratio=0.5)
show_lr_curve(scheduler)
================================================
FILE: wesep/utils/score.py
================================================
import numpy as np
from joblib import Parallel, delayed
from pesq import pesq
from pystoi.stoi import stoi
def cal_SISNR(est, ref, eps=1e-8):
"""Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
Args:
est: separated signal, numpy.ndarray, [T]
ref: reference signal, numpy.ndarray, [T]
Returns:
SISNR
"""
assert len(est) == len(ref)
est_zm = est - np.mean(est)
ref_zm = ref - np.mean(ref)
t = np.sum(est_zm * ref_zm) * ref_zm / (np.linalg.norm(ref_zm)**2 + eps)
return 20 * np.log10(eps + np.linalg.norm(t) /
(np.linalg.norm(est_zm - t) + eps))
def cal_SISNRi(est, ref, mix, eps=1e-8):
"""Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
Args:
est: separated signal, numpy.ndarray, [T]
ref: reference signal, numpy.ndarray, [T]
Returns:
SISNR
"""
assert len(est) == len(ref) == len(mix)
sisnr1 = cal_SISNR(est, ref)
sisnr2 = cal_SISNR(mix, ref)
return sisnr1, sisnr1 - sisnr2
def cal_PESQ(est, ref):
assert len(est) == len(ref)
mode = "wb"
p = pesq(16000, ref, est, mode)
return p
def cal_PESQ_norm(est, ref):
assert len(est) == len(ref)
mode = "wb"
try:
# normalize PESQ to (0, 1)
p = (pesq(16000, ref, est, mode) + 0.5) / 5
except Exception:
# error can happen due to silent estimated signal
p = None
return p
def cal_PESQi(est, ref, mix):
"""Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
Args:
est: separated signal, numpy.ndarray, [T]
ref: reference signal, numpy.ndarray, [T]
Returns:
SISNR
"""
assert len(est) == len(ref) == len(mix)
pesq1 = cal_PESQ(est, ref)
pesq2 = cal_PESQ(mix, ref)
return pesq1, pesq1 - pesq2
def cal_STOI(est, ref):
assert len(est) == len(ref)
p = stoi(ref, est, 16000)
return p
def cal_STOIi(est, ref, mix):
"""Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
Args:
est: separated signal, numpy.ndarray, [T]
ref: reference signal, numpy.ndarray, [T]
Returns:
SISNR
"""
assert len(est) == len(ref) == len(mix)
stoi1 = cal_STOI(est, ref)
stoi2 = cal_STOI(mix, ref)
return stoi1, stoi1 - stoi2
def batch_evaluation(metric, est, ref, lengths=None, parallel=False, n_jobs=8):
"""Calculate specified evaluation metrics in batches
Args:
metric (Callable): the function to calculate metric
        est (np.ndarray or torch.Tensor): separated signals, [B, T]
        ref (np.ndarray or torch.Tensor): reference signals, [B, T]
        lengths (torch.Tensor, optional): relative lengths in (0, 1] of
            each signal, used to trim trailing padding. Defaults to None.
parallel (bool, optional): whether to calculate metric in parallel.
Default to False.
n_jobs (int, optional): number of jobs, used when `parallel` is True.
Defaults to 8.
Returns:
scores (np.ndarray): batched metrics, [B]
"""
assert callable(metric)
if lengths is not None:
assert ((0 < lengths) & (lengths <= 1)).all()
lengths = (lengths * est.size(1)).round().int().cpu()
est = [p[:length].cpu() for p, length in zip(est, lengths)]
ref = [t[:length].cpu() for t, length in zip(ref, lengths)]
if parallel:
while True:
try:
scores = Parallel(n_jobs=n_jobs,
timeout=30)(delayed(metric)(p, t)
for p, t in zip(est, ref))
break
except Exception as e:
print(e)
print("Evaluation timeout...... (will try again)")
else:
scores = []
for p, t in zip(est, ref):
score = metric(p, t)
scores.append(score)
if None in scores:
return None
return np.array(scores)
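# A small usage sketch, assuming 16 kHz mono signals as the metric functions
# above expect; the signals here are synthetic.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    ref = rng.standard_normal((2, 16000)).astype(np.float32)
    est = ref + 0.1 * rng.standard_normal((2, 16000)).astype(np.float32)
    print(batch_evaluation(cal_SISNR, est, ref, parallel=False))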
================================================
FILE: wesep/utils/signal.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
"""
Return window coefficient
"""
def sqrthann(win_len):
return get_window("hann", win_len, fftbins=True)**0.5
if win_type == "None" or win_type is None:
window = np.ones(win_len)
elif win_type == "sqrthann":
window = sqrthann(win_len)
else:
window = get_window(win_type, win_len, fftbins=True) # **0.5
N = fft_len
fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
real_kernel = np.real(fourier_basis)
imag_kernel = np.imag(fourier_basis)
kernel = np.concatenate([real_kernel, imag_kernel], 1).T
if invers:
kernel = np.linalg.pinv(kernel).T
kernel = kernel * window
kernel = kernel[:, None, :]
return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(
window[None, :, None].astype(np.float32))
class ConvSTFT(nn.Module):
def __init__(
self,
win_len,
win_inc,
fft_len=None,
win_type="hamming",
feature_type="real",
):
super(ConvSTFT, self).__init__()
if fft_len is None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
else:
self.fft_len = fft_len
kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
self.register_buffer("weight", kernel)
self.feature_type = feature_type
self.stride = win_inc
self.win_len = win_len
self.dim = self.fft_len
def forward(self, inputs):
if inputs.dim() == 2:
inputs = torch.unsqueeze(inputs, 1)
inputs = F.pad(
inputs, [self.win_len - self.stride, self.win_len - self.stride])
outputs = F.conv1d(inputs, self.weight, stride=self.stride)
if self.feature_type == "complex":
return outputs
else:
dim = self.dim // 2 + 1
real = outputs[:, :dim, :]
imag = outputs[:, dim:, :]
mags = torch.sqrt(real**2 + imag**2)
phase = torch.atan2(imag, real)
return mags, phase
class ConviSTFT(nn.Module):
def __init__(
self,
win_len,
win_inc,
fft_len=None,
win_type="hamming",
feature_type="real",
):
super(ConviSTFT, self).__init__()
if fft_len is None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
else:
self.fft_len = fft_len
kernel, window = init_kernels(win_len,
win_inc,
self.fft_len,
win_type,
invers=True)
self.register_buffer("weight", kernel)
self.feature_type = feature_type
self.win_type = win_type
self.win_len = win_len
        self.stride = win_inc
self.dim = self.fft_len
self.register_buffer("window", window)
self.register_buffer("enframe", torch.eye(win_len)[:, None, :])
def forward(self, inputs, phase=None):
"""
inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
phase: [B, N//2+1, T] (if not none)
"""
if phase is not None:
real = inputs * torch.cos(phase)
imag = inputs * torch.sin(phase)
inputs = torch.cat([real, imag], 1)
outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
# this is from torch-stft: https://github.com/pseeth/torch-stft
t = self.window.repeat(1, 1, inputs.size(-1))**2
coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
outputs = outputs / (coff + 1e-8)
# outputs = torch.where(coff == 0, outputs, outputs/coff)
outputs = outputs[..., self.win_len -
self.stride:-(self.win_len - self.stride)]
return outputs
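# A round-trip sanity sketch: analysis with ConvSTFT followed by synthesis
# with ConviSTFT should approximately reconstruct the input. The window and
# FFT sizes below are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    win_len, win_inc, fft_len = 400, 100, 512
    stft = ConvSTFT(win_len, win_inc, fft_len, feature_type="complex")
    istft = ConviSTFT(win_len, win_inc, fft_len, feature_type="complex")
    wav = torch.randn(1, 16000)
    spec = stft(wav)  # [B, fft_len + 2, T]
    recon = istft(spec)  # [B, 1, samples]
    print(spec.shape, float((wav - recon.squeeze(1)).abs().max()))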
================================================
FILE: wesep/utils/utils.py
================================================
# Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import difflib
import logging
import os
import random
import shutil
import sys
from distutils.util import strtobool
from pathlib import Path
import numpy as np
import torch
import torch.distributed as dist
import yaml
def str2bool(value: str) -> bool:
return bool(strtobool(value))
def get_logger(outdir, fname):
formatter = logging.Formatter(
"[ %(levelname)s : %(asctime)s ] - %(message)s")
logging.basicConfig(
level=logging.DEBUG,
format="[ %(levelname)s : %(asctime)s ] - %(message)s",
)
logger = logging.getLogger("Pyobj, f")
# Dump log to file
fh = logging.FileHandler(os.path.join(outdir, fname))
fh.setFormatter(formatter)
logger.addHandler(fh)
return logger
def setup_logger(rank, exp_dir, device_ids, MAX_NUM_LOG_FILES: int = 100):
model_dir = os.path.join(exp_dir, "models")
file_name = "train.log"
if rank == 0:
os.makedirs(model_dir, exist_ok=True)
for i in range(MAX_NUM_LOG_FILES - 1, -1, -1):
if i == 0:
p = Path(os.path.join(exp_dir, file_name))
pn = p.parent / (p.stem + ".1" + p.suffix)
else:
_p = Path(os.path.join(exp_dir, file_name))
p = _p.parent / (_p.stem + f".{i}" + _p.suffix)
pn = _p.parent / (_p.stem + f".{i + 1}" + _p.suffix)
if p.exists():
if i == MAX_NUM_LOG_FILES - 1:
p.unlink()
else:
shutil.move(p, pn)
dist.barrier(device_ids=[device_ids]) # let the rank 0 mkdir first
return get_logger(exp_dir, file_name)
def parse_config_or_kwargs(config_file, **kwargs):
"""parse_config_or_kwargs
:param config_file: Config file that has parameters, yaml format
:param **kwargs: Other alternative parameters or overwrites for conf
"""
with open(config_file) as con_read:
yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
# values from conf file are all possible params
help_str = "Valid Parameters are:\n"
help_str += "\n".join(list(yaml_config.keys()))
# passed kwargs will override yaml conf
# for key in kwargs.keys():
# assert key in yaml_config, "Parameter {} invalid!\n".format(key)
# add the path of config file to dict
if "config" not in kwargs:
kwargs["config"] = config_file
return dict(yaml_config, **kwargs)
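# Example (illustrative): values passed as kwargs override those in the
# yaml file, and the config path itself is recorded under "config".
#   >>> conf = parse_config_or_kwargs("conf/train.yaml", gpus="0,1")
#   >>> conf["config"]
#   'conf/train.yaml'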
def validate_path(dir_name):
"""Create the directory if it doesn't exist
:param dir_name
:return: None
"""
dir_name = os.path.dirname(dir_name) # get the path
if not os.path.exists(dir_name) and (dir_name != ""):
os.makedirs(dir_name)
def set_seed(seed=42):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
def generate_enahnced_scp(directory: str, extension: str = "wav"):
source_dir = Path(directory)
spk_scp = source_dir.joinpath("spk1.scp")
audio_list = []
for file_path in source_dir.rglob(f"*.{extension}"):
audio_list.append(file_path)
with open(spk_scp, "w") as f:
for audio in audio_list:
path = str(audio.resolve())
ori_filename = audio.stem
spk1_id = ori_filename.split("-")[1]
# spk2_id = ori_filename.split("_")[1].split("-")[0]
curr_spk = ori_filename.split("T")[1]
prefix = "s1" if curr_spk == spk1_id else "s2"
f_dash_index = ori_filename.find("-")
l_dash_index = ori_filename.rfind("-")
filename = ori_filename[f_dash_index + 1:l_dash_index]
final_filename = prefix + "/" + filename + ".wav"
line = final_filename + " " + path
f.write(line + "\n")
def get_commandline_args():
# ported from
# https://github.com/espnet/espnet/blob/master/espnet/utils/cli_utils.py
extra_chars = [
" ",
";",
"&",
"(",
")",
"|",
"^",
"<",
">",
"?",
"*",
"[",
"]",
"$",
"`",
'"',
"\\",
"!",
"{",
"}",
]
# Escape the extra characters for shell
argv = [(arg.replace("'", "'\\''") if all(
char not in arg
for char in extra_chars) else "'" + arg.replace("'", "'\\''") + "'")
for arg in sys.argv]
return sys.executable + " " + " ".join(argv)
# ported from
# https://github.com/espnet/espnet/blob/master/espnet2/utils/config_argparse.py
class ArgumentParser(argparse.ArgumentParser):
"""Simple implementation of ArgumentParser supporting config file
    This class originates from https://github.com/bw2/ConfigArgParse,
    but differs from it in the following ways:
- Not supporting multiple config files
- Automatically adding "--config" as an option.
- Not supporting any formats other than yaml
- Not checking argument type
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_argument("--config", help="Give config file in yaml format")
def parse_known_args(self, args=None, namespace=None):
# Once parsing for setting from "--config"
_args, _ = super().parse_known_args(args, namespace)
if _args.config is not None:
if not Path(_args.config).exists():
self.error(f"No such file: {_args.config}")
with open(_args.config, "r", encoding="utf-8") as f:
d = yaml.safe_load(f)
if not isinstance(d, dict):
self.error("Config file has non dict value: {_args.config}")
for key in d:
for action in self._actions:
if key == action.dest:
break
else:
self.error(
f"unrecognized arguments: {key} (from {_args.config})")
# NOTE(kamo): Ignore "--config" from a config file
# NOTE(kamo): Unlike "configargparse", this module doesn't
# check type. i.e. We can set any type value
# regardless of argument type.
self.set_defaults(**d)
return super().parse_known_args(args, namespace)
def get_layer(l_name, library=torch.nn):
"""Return layer object handler from library e.g. from torch.nn
E.g. if l_name=="elu", returns torch.nn.ELU.
Args:
l_name (string): Case insensitive name for layer in library
            (e.g. 'elu').
library (module): Name of library/module where to search for
object handler with l_name e.g. "torch.nn".
Returns:
layer_handler (object): handler for the requested layer
e.g. (torch.nn.ELU)
"""
all_torch_layers = list(dir(torch.nn))
match = [x for x in all_torch_layers if l_name.lower() == x.lower()]
if len(match) == 0:
close_matches = difflib.get_close_matches(
l_name, [x.lower() for x in all_torch_layers])
raise NotImplementedError(
"Layer with name {} not found in {}.\n Closest matches: {}".format(
l_name, str(library), close_matches))
elif len(match) > 1:
close_matches = difflib.get_close_matches(
l_name, [x.lower() for x in all_torch_layers])
raise NotImplementedError(
"Multiple matchs for layer with name {} not found in {}.\n "
"All matches: {}".format(l_name, str(library), close_matches))
else:
# valid
layer_handler = getattr(library, match[0])
return layer_handler
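# Example (illustrative): resolve an activation by its lowercase name.
#   >>> get_layer("elu")    # -> torch.nn.ELU
#   >>> get_layer("elu")()  # builds the module instance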
# def spk2id(utt_spk_list):
# _, spk_list = zip(*utt_spk_list)
# spk_list = sorted(list(set(spk_list))) # remove overlap and sort
# spk2id_dict = {}
# spk_list.sort()
# for i, spk in enumerate(spk_list):
# spk2id_dict[spk] = i
# return spk2id_dict