Showing preview only (8,357K chars total). Download the full file or copy to clipboard to get everything.
Repository: jctian98/e2e_lfmmi
Branch: master
Commit: 34b805690663
Files: 1103
Total size: 7.8 MB
Directory structure:
gitextract_ni0k430x/
├── .gitignore
├── .run.sh.swp
├── README.md
├── __init__.py
├── asr/
│ ├── __init__.py
│ ├── asr_mix_utils.py
│ ├── asr_utils.py
│ ├── chainer_backend/
│ │ ├── __init__.py
│ │ └── asr.py
│ └── pytorch_backend/
│ ├── __init__.py
│ ├── asr.py
│ ├── asr_init.py
│ ├── asr_mix.py
│ └── recog.py
├── bin/
│ ├── __init__.py
│ ├── asr_align.py
│ ├── asr_enhance.py
│ ├── asr_recog.py
│ ├── asr_train.py
│ ├── lm_train.py
│ ├── mt_train.py
│ ├── mt_trans.py
│ ├── st_train.py
│ ├── st_trans.py
│ ├── tts_decode.py
│ ├── tts_train.py
│ ├── vc_decode.py
│ └── vc_train.py
├── egs/
│ ├── .gitignore
│ ├── aishell1/
│ │ ├── .gitignore
│ │ ├── aed.sh
│ │ ├── cmd.sh
│ │ ├── conf/
│ │ │ ├── fbank.conf
│ │ │ ├── gpu.conf
│ │ │ ├── lm.yaml
│ │ │ ├── lm_rnn.yaml
│ │ │ ├── lm_transformer.yaml
│ │ │ ├── pitch.conf
│ │ │ ├── queue.conf
│ │ │ ├── slurm.conf
│ │ │ ├── specaug.yaml
│ │ │ ├── specaug_test.yaml
│ │ │ └── tuning/
│ │ │ ├── decode_pytorch_transformer.yaml
│ │ │ ├── decode_rnn.yaml
│ │ │ ├── train_pytorch_conformer_kernel15.yaml
│ │ │ ├── train_pytorch_conformer_kernel31.yaml
│ │ │ ├── train_pytorch_conformer_kernel31_large.yaml
│ │ │ ├── train_pytorch_conformer_kernel31_small.yaml
│ │ │ ├── train_pytorch_transformer.yaml
│ │ │ ├── train_rnn.yaml
│ │ │ └── transducer/
│ │ │ ├── decode_default.yaml
│ │ │ ├── train_conformer-rnn_transducer.yaml
│ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4.yaml
│ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_att.yaml
│ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_small.yaml
│ │ │ ├── train_conformer-rnn_transducer_ngpu4.yaml
│ │ │ ├── train_conformer-rnn_transducer_ngpu4_large.yaml
│ │ │ ├── train_transducer.yaml
│ │ │ └── train_transducer_aux.yaml
│ │ ├── local/
│ │ │ ├── add_lex_disambig.pl
│ │ │ ├── aishell_data_prep.sh
│ │ │ ├── aishell_train_lms.sh
│ │ │ ├── apply_map.pl
│ │ │ ├── build_sp_text.py
│ │ │ ├── build_word_mapping.py
│ │ │ ├── compile_bigram.sh
│ │ │ ├── download_and_untar.sh
│ │ │ ├── fstaddselfloops.pl
│ │ │ ├── k2_aishell_prepare_dict.sh
│ │ │ ├── k2_aishell_prepare_dict_char.sh
│ │ │ ├── k2_prepare_lang.sh
│ │ │ ├── make_lexicon_fst.py
│ │ │ ├── max_rescore.py
│ │ │ ├── parse_options.sh
│ │ │ ├── parse_text_jieba.py
│ │ │ ├── prepare_word_lex.py
│ │ │ └── sym2int.pl
│ │ ├── nt.sh
│ │ ├── path.sh
│ │ └── prepare.sh
│ ├── aishell2/
│ │ ├── .gitignore
│ │ ├── aed.sh
│ │ ├── conf/
│ │ │ ├── .fbank.conf.swp
│ │ │ ├── fbank.conf
│ │ │ ├── gpu.conf
│ │ │ ├── lm.yaml
│ │ │ ├── lm_rnn.yaml
│ │ │ ├── lm_transformer.yaml
│ │ │ ├── pitch.conf
│ │ │ ├── queue.conf
│ │ │ ├── slurm.conf
│ │ │ ├── specaug.yaml
│ │ │ ├── specaug_test.yaml
│ │ │ └── tuning/
│ │ │ ├── decode_pytorch_transformer.yaml
│ │ │ ├── decode_rnn.yaml
│ │ │ ├── train_pytorch_conformer_kernel15.yaml
│ │ │ ├── train_pytorch_conformer_kernel31.yaml
│ │ │ ├── train_pytorch_transformer.yaml
│ │ │ ├── train_rnn.yaml
│ │ │ └── transducer/
│ │ │ ├── decode_default.yaml
│ │ │ ├── train_conformer-rnn_transducer.yaml
│ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4.yaml
│ │ │ ├── train_conformer-rnn_transducer_ngpu4.yaml
│ │ │ ├── train_transducer.yaml
│ │ │ └── train_transducer_aux.yaml
│ │ ├── local/
│ │ │ ├── add_lex_disambig.pl
│ │ │ ├── apply_map.pl
│ │ │ ├── fstaddselfloops.pl
│ │ │ ├── jieba_split_text.py
│ │ │ ├── k2_prepare_lang.sh
│ │ │ ├── make_lexicon_fst.py
│ │ │ ├── max_rescore.py
│ │ │ ├── mmi_rescore.sh
│ │ │ ├── parse_options.sh
│ │ │ ├── prepare_data.sh
│ │ │ ├── prepare_dict.sh
│ │ │ ├── rerank.py
│ │ │ ├── sym2int.pl
│ │ │ ├── train_lms.sh
│ │ │ └── word_segmentation.py
│ │ ├── nt.sh
│ │ └── prepare.sh
│ ├── asrucs/
│ │ ├── .gitignore
│ │ ├── cmd.sh
│ │ ├── conf/
│ │ │ ├── decode.yaml
│ │ │ ├── fbank.conf
│ │ │ ├── gpu.conf
│ │ │ ├── lm.yaml
│ │ │ ├── lm_rnn.yaml
│ │ │ ├── lm_transformer.yaml
│ │ │ ├── pitch.conf
│ │ │ ├── pure_ctc.yaml
│ │ │ ├── queue.conf
│ │ │ ├── slurm.conf
│ │ │ ├── specaug.yaml
│ │ │ ├── specaug_test.yaml
│ │ │ ├── train.yaml
│ │ │ ├── train_conformer-rnn_transducer_cs.yaml
│ │ │ └── tuning/
│ │ │ ├── decode_pytorch_transformer.yaml
│ │ │ ├── decode_rnn.yaml
│ │ │ ├── train_pytorch_conformer_kernel15.yaml
│ │ │ ├── train_pytorch_conformer_kernel31.yaml
│ │ │ ├── train_pytorch_conformer_kernel31_large.yaml
│ │ │ ├── train_pytorch_conformer_kernel31_small.yaml
│ │ │ ├── train_pytorch_transformer.yaml
│ │ │ ├── train_rnn.yaml
│ │ │ └── transducer/
│ │ │ ├── decode_default.yaml
│ │ │ ├── train_conformer-rnn_transducer.yaml
│ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4.yaml
│ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_att.yaml
│ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_small.yaml
│ │ │ ├── train_conformer-rnn_transducer_ngpu4.yaml
│ │ │ ├── train_conformer-rnn_transducer_ngpu4_large.yaml
│ │ │ ├── train_transducer.yaml
│ │ │ └── train_transducer_aux.yaml
│ │ ├── espnet
│ │ ├── espnet_utils
│ │ ├── local/
│ │ │ ├── add_seperator.py
│ │ │ ├── generate_fake_cs.py
│ │ │ └── prepare_fake_cs.sh
│ │ ├── nt.sh
│ │ ├── path.sh
│ │ ├── prepare.sh
│ │ ├── steps
│ │ ├── text
│ │ └── utils
│ ├── espnet_utils/
│ │ ├── add_uttcls_json.py
│ │ ├── addjson.py
│ │ ├── apply-cmvn.py
│ │ ├── asr_align_wav.sh
│ │ ├── average_checkpoints.py
│ │ ├── build_fake_lexicon.py
│ │ ├── build_sp_text.py
│ │ ├── calculate_rtf.py
│ │ ├── change_root.py
│ │ ├── change_yaml.py
│ │ ├── clean_corpus.sh
│ │ ├── compute-cmvn-stats.py
│ │ ├── compute-fbank-feats.py
│ │ ├── compute-stft-feats.py
│ │ ├── concat_json_multiref.py
│ │ ├── concatjson.py
│ │ ├── convert_fbank.sh
│ │ ├── convert_fbank_to_wav.py
│ │ ├── copy-feats.py
│ │ ├── data2json.sh
│ │ ├── divide_lang.sh
│ │ ├── double_precious_cer.py
│ │ ├── download_from_google_drive.sh
│ │ ├── dump-pcm.py
│ │ ├── dump.sh
│ │ ├── dump_pcm.sh
│ │ ├── eval-source-separation.py
│ │ ├── eval_perm_free_error.py
│ │ ├── eval_source_separation.sh
│ │ ├── feat-to-shape.py
│ │ ├── feat_to_shape.sh
│ │ ├── feats2npy.py
│ │ ├── filt.py
│ │ ├── filter_all_eng_utts.py
│ │ ├── filter_scp.py
│ │ ├── filter_trn.py
│ │ ├── free-gpu.sh
│ │ ├── gdown.pl
│ │ ├── generate_wav.sh
│ │ ├── generate_wav_from_fbank.py
│ │ ├── get_yaml.py
│ │ ├── jieba_build_dict.py
│ │ ├── json2sctm.py
│ │ ├── json2text.py
│ │ ├── json2trn.py
│ │ ├── json2trn_mt.py
│ │ ├── json2trn_wo_dict.py
│ │ ├── k2/
│ │ │ ├── add_lex_disambig.pl
│ │ │ ├── apply_map.pl
│ │ │ ├── fstaddselfloops.pl
│ │ │ ├── k2_prepare_lang.sh
│ │ │ ├── parse_options.sh
│ │ │ └── sym2int.pl
│ │ ├── make_fbank.sh
│ │ ├── make_pair_json.py
│ │ ├── make_stft.sh
│ │ ├── mbr_analysis.py
│ │ ├── mcd_calculate.py
│ │ ├── merge_scp2json.py
│ │ ├── mergejson.py
│ │ ├── mix-mono-wav-scp.py
│ │ ├── mmi_rescore.sh
│ │ ├── pack_model.sh
│ │ ├── prepare_block_load.sh
│ │ ├── prepare_mer.py
│ │ ├── queue-freegpu.pl
│ │ ├── recog_wav.sh
│ │ ├── reduce_data_dir.sh
│ │ ├── remove_longshortdata.sh
│ │ ├── remove_punctuation.pl
│ │ ├── rerank_mmi.py
│ │ ├── result2json.py
│ │ ├── score_bleu.sh
│ │ ├── score_lang_id.py
│ │ ├── score_sclite.sh
│ │ ├── score_sclite_case.sh
│ │ ├── score_sclite_wo_dict.sh
│ │ ├── scp2json.py
│ │ ├── show_result.sh
│ │ ├── significant_test.sh
│ │ ├── sort_scp_by_length.py
│ │ ├── speed_perturb.sh
│ │ ├── split_scp.py
│ │ ├── split_scp_fix_length.py
│ │ ├── splitjson.py
│ │ ├── spm_decode
│ │ ├── spm_encode
│ │ ├── spm_train
│ │ ├── stdout.pl
│ │ ├── synth_wav.sh
│ │ ├── text2token.py
│ │ ├── text2vocabulary.py
│ │ ├── text_norm.py
│ │ ├── trace_rnnt.py
│ │ ├── train_lms_srilm.sh
│ │ ├── translate_wav.sh
│ │ ├── trim_silence.py
│ │ ├── trim_silence.sh
│ │ ├── trn2ctm.py
│ │ ├── trn2stm.py
│ │ ├── update_json.sh
│ │ ├── word_ngram_rescore.py
│ │ └── word_ngram_rescore.sh
│ ├── steps/
│ │ ├── align_basis_fmllr.sh
│ │ ├── align_basis_fmllr_lats.sh
│ │ ├── align_fmllr.sh
│ │ ├── align_fmllr_lats.sh
│ │ ├── align_lvtln.sh
│ │ ├── align_raw_fmllr.sh
│ │ ├── align_sgmm2.sh
│ │ ├── align_si.sh
│ │ ├── best_path_weights.sh
│ │ ├── cleanup/
│ │ │ ├── clean_and_segment_data.sh
│ │ │ ├── clean_and_segment_data_nnet3.sh
│ │ │ ├── combine_short_segments.py
│ │ │ ├── create_segments_from_ctm.pl
│ │ │ ├── debug_lexicon.sh
│ │ │ ├── decode_fmllr_segmentation.sh
│ │ │ ├── decode_segmentation.sh
│ │ │ ├── decode_segmentation_nnet3.sh
│ │ │ ├── find_bad_utts.sh
│ │ │ ├── find_bad_utts_nnet.sh
│ │ │ ├── internal/
│ │ │ │ ├── align_ctm_ref.py
│ │ │ │ ├── compute_tf_idf.py
│ │ │ │ ├── ctm_to_text.pl
│ │ │ │ ├── get_ctm_edits.py
│ │ │ │ ├── get_non_scored_words.py
│ │ │ │ ├── get_pron_stats.py
│ │ │ │ ├── make_one_biased_lm.py
│ │ │ │ ├── modify_ctm_edits.py
│ │ │ │ ├── resolve_ctm_edits_overlaps.py
│ │ │ │ ├── retrieve_similar_docs.py
│ │ │ │ ├── segment_ctm_edits.py
│ │ │ │ ├── segment_ctm_edits_mild.py
│ │ │ │ ├── split_text_into_docs.pl
│ │ │ │ ├── stitch_documents.py
│ │ │ │ ├── taint_ctm_edits.py
│ │ │ │ └── tf_idf.py
│ │ │ ├── lattice_oracle_align.sh
│ │ │ ├── make_biased_lm_graphs.sh
│ │ │ ├── make_biased_lms.py
│ │ │ ├── make_segmentation_data_dir.sh
│ │ │ ├── make_segmentation_graph.sh
│ │ │ ├── make_utterance_fsts.pl
│ │ │ ├── make_utterance_graph.sh
│ │ │ ├── segment_long_utterances.sh
│ │ │ ├── segment_long_utterances_nnet3.sh
│ │ │ └── split_long_utterance.sh
│ │ ├── combine_ali_dirs.sh
│ │ ├── combine_trans_dirs.sh
│ │ ├── compare_alignments.sh
│ │ ├── compute_cmvn_stats.sh
│ │ ├── compute_vad_decision.sh
│ │ ├── conf/
│ │ │ ├── append_eval_to_ctm.py
│ │ │ ├── append_prf_to_ctm.py
│ │ │ ├── apply_calibration.sh
│ │ │ ├── convert_ctm_to_tra.py
│ │ │ ├── get_ctm_conf.sh
│ │ │ ├── lattice_depth_per_frame.sh
│ │ │ ├── parse_arpa_unigrams.py
│ │ │ ├── prepare_calibration_data.py
│ │ │ ├── prepare_word_categories.py
│ │ │ └── train_calibration.sh
│ │ ├── copy_ali_dir.sh
│ │ ├── copy_lat_dir.sh
│ │ ├── copy_trans_dir.sh
│ │ ├── data/
│ │ │ ├── augment_data_dir.py
│ │ │ ├── data_dir_manipulation_lib.py
│ │ │ ├── make_musan.py
│ │ │ ├── make_musan.sh
│ │ │ └── reverberate_data_dir.py
│ │ ├── decode.sh
│ │ ├── decode_basis_fmllr.sh
│ │ ├── decode_biglm.sh
│ │ ├── decode_combine.sh
│ │ ├── decode_fmllr.sh
│ │ ├── decode_fmllr_extra.sh
│ │ ├── decode_fmmi.sh
│ │ ├── decode_fromlats.sh
│ │ ├── decode_lvtln.sh
│ │ ├── decode_nolats.sh
│ │ ├── decode_raw_fmllr.sh
│ │ ├── decode_sgmm2.sh
│ │ ├── decode_sgmm2_fromlats.sh
│ │ ├── decode_sgmm2_rescore.sh
│ │ ├── decode_sgmm2_rescore_project.sh
│ │ ├── decode_with_map.sh
│ │ ├── diagnostic/
│ │ │ ├── analyze_alignments.sh
│ │ │ ├── analyze_lats.sh
│ │ │ ├── analyze_lattice_depth_stats.py
│ │ │ └── analyze_phone_length_stats.py
│ │ ├── dict/
│ │ │ ├── apply_g2p.sh
│ │ │ ├── apply_g2p_phonetisaurus.sh
│ │ │ ├── apply_lexicon_edits.py
│ │ │ ├── get_pron_stats.py
│ │ │ ├── internal/
│ │ │ │ ├── get_subsegments.py
│ │ │ │ ├── prune_pron_candidates.py
│ │ │ │ └── sum_arc_info.py
│ │ │ ├── learn_lexicon_bayesian.sh
│ │ │ ├── learn_lexicon_greedy.sh
│ │ │ ├── merge_learned_lexicons.py
│ │ │ ├── prons_to_lexicon.py
│ │ │ ├── prune_pron_candidates.py
│ │ │ ├── select_prons_bayesian.py
│ │ │ ├── select_prons_greedy.py
│ │ │ ├── train_g2p.sh
│ │ │ └── train_g2p_phonetisaurus.sh
│ │ ├── get_ctm.sh
│ │ ├── get_ctm_conf_fast.sh
│ │ ├── get_ctm_fast.sh
│ │ ├── get_fmllr_basis.sh
│ │ ├── get_lexicon_probs.sh
│ │ ├── get_prons.sh
│ │ ├── get_train_ctm.sh
│ │ ├── info/
│ │ │ ├── chain_dir_info.pl
│ │ │ ├── gmm_dir_info.pl
│ │ │ ├── nnet2_dir_info.pl
│ │ │ ├── nnet3_dir_info.pl
│ │ │ └── nnet3_disc_dir_info.pl
│ │ ├── libs/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ └── nnet3/
│ │ │ ├── __init__.py
│ │ │ ├── report/
│ │ │ │ ├── __init__.py
│ │ │ │ └── log_parse.py
│ │ │ ├── train/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── chain_objf/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── acoustic_model.py
│ │ │ │ ├── common.py
│ │ │ │ ├── dropout_schedule.py
│ │ │ │ └── frame_level_objf/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── acoustic_model.py
│ │ │ │ ├── common.py
│ │ │ │ └── raw_model.py
│ │ │ └── xconfig/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── basic_layers.py
│ │ │ ├── composite_layers.py
│ │ │ ├── convolution.py
│ │ │ ├── gru.py
│ │ │ ├── layers.py
│ │ │ ├── lstm.py
│ │ │ ├── parser.py
│ │ │ ├── stats_layer.py
│ │ │ ├── trivial_layers.py
│ │ │ └── utils.py
│ │ ├── lmrescore.sh
│ │ ├── lmrescore_const_arpa.sh
│ │ ├── lmrescore_const_arpa_undeterminized.sh
│ │ ├── lmrescore_rnnlm_lat.sh
│ │ ├── make_denlats.sh
│ │ ├── make_denlats_sgmm2.sh
│ │ ├── make_fbank.sh
│ │ ├── make_fbank_pitch.sh
│ │ ├── make_index.sh
│ │ ├── make_mfcc.sh
│ │ ├── make_mfcc_pitch.sh
│ │ ├── make_mfcc_pitch_online.sh
│ │ ├── make_phone_graph.sh
│ │ ├── make_plp.sh
│ │ ├── make_plp_pitch.sh
│ │ ├── nnet/
│ │ │ ├── align.sh
│ │ │ ├── decode.sh
│ │ │ ├── ivector/
│ │ │ │ ├── extract_ivectors.sh
│ │ │ │ ├── train_diag_ubm.sh
│ │ │ │ └── train_ivector_extractor.sh
│ │ │ ├── make_bn_feats.sh
│ │ │ ├── make_denlats.sh
│ │ │ ├── make_fmllr_feats.sh
│ │ │ ├── make_fmmi_feats.sh
│ │ │ ├── make_priors.sh
│ │ │ ├── pretrain_dbn.sh
│ │ │ ├── train.sh
│ │ │ ├── train_mmi.sh
│ │ │ ├── train_mpe.sh
│ │ │ └── train_scheduler.sh
│ │ ├── nnet2/
│ │ │ ├── adjust_priors.sh
│ │ │ ├── align.sh
│ │ │ ├── check_ivectors_compatible.sh
│ │ │ ├── convert_lda_to_raw.sh
│ │ │ ├── convert_nnet1_to_nnet2.sh
│ │ │ ├── create_appended_model.sh
│ │ │ ├── decode.sh
│ │ │ ├── dump_bottleneck_features.sh
│ │ │ ├── get_egs.sh
│ │ │ ├── get_egs2.sh
│ │ │ ├── get_egs_discriminative2.sh
│ │ │ ├── get_ivector_id.sh
│ │ │ ├── get_lda.sh
│ │ │ ├── get_lda_block.sh
│ │ │ ├── get_perturbed_feats.sh
│ │ │ ├── make_denlats.sh
│ │ │ ├── make_multisplice_configs.py
│ │ │ ├── relabel_egs.sh
│ │ │ ├── relabel_egs2.sh
│ │ │ ├── remove_egs.sh
│ │ │ ├── retrain_fast.sh
│ │ │ ├── retrain_simple2.sh
│ │ │ ├── retrain_tanh.sh
│ │ │ ├── train_block.sh
│ │ │ ├── train_convnet_accel2.sh
│ │ │ ├── train_discriminative.sh
│ │ │ ├── train_discriminative2.sh
│ │ │ ├── train_discriminative_multilang2.sh
│ │ │ ├── train_more.sh
│ │ │ ├── train_more2.sh
│ │ │ ├── train_multilang2.sh
│ │ │ ├── train_multisplice_accel2.sh
│ │ │ ├── train_multisplice_ensemble.sh
│ │ │ ├── train_pnorm.sh
│ │ │ ├── train_pnorm_accel2.sh
│ │ │ ├── train_pnorm_bottleneck_fast.sh
│ │ │ ├── train_pnorm_ensemble.sh
│ │ │ ├── train_pnorm_fast.sh
│ │ │ ├── train_pnorm_multisplice.sh
│ │ │ ├── train_pnorm_multisplice2.sh
│ │ │ ├── train_pnorm_simple.sh
│ │ │ ├── train_pnorm_simple2.sh
│ │ │ ├── train_tanh.sh
│ │ │ ├── train_tanh_bottleneck.sh
│ │ │ ├── train_tanh_fast.sh
│ │ │ └── update_nnet.sh
│ │ ├── nnet3/
│ │ │ ├── adjust_priors.sh
│ │ │ ├── align.sh
│ │ │ ├── align_lats.sh
│ │ │ ├── chain/
│ │ │ │ ├── align_lats.sh
│ │ │ │ ├── build_tree.sh
│ │ │ │ ├── build_tree_multiple_sources.sh
│ │ │ │ ├── e2e/
│ │ │ │ │ ├── README.txt
│ │ │ │ │ ├── compute_biphone_stats.py
│ │ │ │ │ ├── get_egs_e2e.sh
│ │ │ │ │ ├── prepare_e2e.sh
│ │ │ │ │ ├── text_to_phones.py
│ │ │ │ │ └── train_e2e.py
│ │ │ │ ├── gen_topo.pl
│ │ │ │ ├── gen_topo.py
│ │ │ │ ├── gen_topo2.py
│ │ │ │ ├── gen_topo3.py
│ │ │ │ ├── gen_topo4.py
│ │ │ │ ├── gen_topo5.py
│ │ │ │ ├── gen_topo_orig.py
│ │ │ │ ├── get_egs.sh
│ │ │ │ ├── get_model_context.sh
│ │ │ │ ├── get_phone_post.sh
│ │ │ │ ├── make_weighted_den_fst.sh
│ │ │ │ ├── multilingual/
│ │ │ │ │ └── combine_egs.sh
│ │ │ │ ├── train.py
│ │ │ │ └── train_tdnn.sh
│ │ │ ├── chain2/
│ │ │ │ ├── combine_egs.sh
│ │ │ │ ├── compute_preconditioning_matrix.sh
│ │ │ │ ├── get_raw_egs.sh
│ │ │ │ ├── internal/
│ │ │ │ │ ├── get_best_model.sh
│ │ │ │ │ └── get_train_schedule.py
│ │ │ │ ├── process_egs.sh
│ │ │ │ ├── randomize_egs.sh
│ │ │ │ ├── train.sh
│ │ │ │ ├── validate_processed_egs.sh
│ │ │ │ ├── validate_randomized_egs.sh
│ │ │ │ └── validate_raw_egs.sh
│ │ │ ├── components.py
│ │ │ ├── compute_output.sh
│ │ │ ├── convert_nnet2_to_nnet3.py
│ │ │ ├── decode.sh
│ │ │ ├── decode_grammar.sh
│ │ │ ├── decode_lookahead.sh
│ │ │ ├── decode_looped.sh
│ │ │ ├── decode_score_fusion.sh
│ │ │ ├── decode_semisup.sh
│ │ │ ├── dot/
│ │ │ │ ├── descriptor_parser.py
│ │ │ │ └── nnet3_to_dot.py
│ │ │ ├── get_degs.sh
│ │ │ ├── get_egs.sh
│ │ │ ├── get_egs_discriminative.sh
│ │ │ ├── get_egs_targets.sh
│ │ │ ├── get_saturation.pl
│ │ │ ├── get_successful_models.py
│ │ │ ├── lstm/
│ │ │ │ ├── make_configs.py
│ │ │ │ └── train.sh
│ │ │ ├── make_bottleneck_features.sh
│ │ │ ├── make_denlats.sh
│ │ │ ├── make_tdnn_configs.py
│ │ │ ├── multilingual/
│ │ │ │ ├── allocate_multilingual_examples.py
│ │ │ │ └── combine_egs.sh
│ │ │ ├── nnet3_to_dot.sh
│ │ │ ├── report/
│ │ │ │ ├── convert_model.py
│ │ │ │ ├── generate_plots.py
│ │ │ │ └── summarize_compute_debug_timing.py
│ │ │ ├── tdnn/
│ │ │ │ ├── make_configs.py
│ │ │ │ ├── train.sh
│ │ │ │ └── train_raw_nnet.sh
│ │ │ ├── train_discriminative.sh
│ │ │ ├── train_dnn.py
│ │ │ ├── train_raw_dnn.py
│ │ │ ├── train_raw_rnn.py
│ │ │ ├── train_rnn.py
│ │ │ ├── train_tdnn.sh
│ │ │ ├── xconfig_to_config.py
│ │ │ └── xconfig_to_configs.py
│ │ ├── online/
│ │ │ ├── decode.sh
│ │ │ ├── nnet2/
│ │ │ │ ├── align.sh
│ │ │ │ ├── copy_data_dir.sh
│ │ │ │ ├── copy_ivector_dir.sh
│ │ │ │ ├── decode.sh
│ │ │ │ ├── dump_nnet_activations.sh
│ │ │ │ ├── extract_ivectors.sh
│ │ │ │ ├── extract_ivectors_online.sh
│ │ │ │ ├── get_egs.sh
│ │ │ │ ├── get_egs2.sh
│ │ │ │ ├── get_egs_discriminative2.sh
│ │ │ │ ├── get_pca_transform.sh
│ │ │ │ ├── make_denlats.sh
│ │ │ │ ├── prepare_online_decoding.sh
│ │ │ │ ├── prepare_online_decoding_retrain.sh
│ │ │ │ ├── prepare_online_decoding_transfer.sh
│ │ │ │ ├── train_diag_ubm.sh
│ │ │ │ └── train_ivector_extractor.sh
│ │ │ ├── nnet3/
│ │ │ │ ├── decode.sh
│ │ │ │ ├── decode_wake_word.sh
│ │ │ │ └── prepare_online_decoding.sh
│ │ │ └── prepare_online_decoding.sh
│ │ ├── oracle_wer.sh
│ │ ├── overlap/
│ │ │ ├── get_overlap_segments.py
│ │ │ ├── get_overlap_targets.py
│ │ │ ├── output_to_rttm.py
│ │ │ ├── post_process_output.sh
│ │ │ └── prepare_overlap_graph.py
│ │ ├── paste_feats.sh
│ │ ├── pytorchnn/
│ │ │ ├── check_py.py
│ │ │ ├── compute_sentence_scores.py
│ │ │ ├── data.py
│ │ │ ├── lmrescore_nbest_pytorchnn.sh
│ │ │ ├── model.py
│ │ │ └── train.py
│ │ ├── resegment_data.sh
│ │ ├── resegment_text.sh
│ │ ├── rnnlmrescore.sh
│ │ ├── scoring/
│ │ │ ├── score_kaldi_cer.sh
│ │ │ ├── score_kaldi_compare.sh
│ │ │ └── score_kaldi_wer.sh
│ │ ├── search_index.sh
│ │ ├── segmentation/
│ │ │ ├── ali_to_targets.sh
│ │ │ ├── combine_targets_dirs.sh
│ │ │ ├── convert_targets_dir_to_whole_recording.sh
│ │ │ ├── convert_utt2spk_and_segments_to_rttm.py
│ │ │ ├── copy_targets_dir.sh
│ │ │ ├── decode_sad.sh
│ │ │ ├── detect_speech_activity.sh
│ │ │ ├── evaluate_segmentation.pl
│ │ │ ├── get_targets_for_out_of_segments.sh
│ │ │ ├── internal/
│ │ │ │ ├── arc_info_to_targets.py
│ │ │ │ ├── find_oov_phone.py
│ │ │ │ ├── get_default_targets_for_out_of_segments.py
│ │ │ │ ├── get_transform_probs_mat.py
│ │ │ │ ├── merge_segment_targets_to_recording.py
│ │ │ │ ├── merge_targets.py
│ │ │ │ ├── prepare_sad_graph.py
│ │ │ │ ├── resample_targets.py
│ │ │ │ ├── sad_to_segments.py
│ │ │ │ └── verify_phones_list.py
│ │ │ ├── lats_to_targets.sh
│ │ │ ├── merge_targets_dirs.sh
│ │ │ ├── post_process_sad_to_segments.sh
│ │ │ ├── prepare_targets_gmm.sh
│ │ │ ├── resample_targets_dir.sh
│ │ │ └── validate_targets_dir.sh
│ │ ├── select_feats.sh
│ │ ├── shift_feats.sh
│ │ ├── subset_ali_dir.sh
│ │ ├── tandem/
│ │ │ ├── align_fmllr.sh
│ │ │ ├── align_sgmm2.sh
│ │ │ ├── align_si.sh
│ │ │ ├── decode.sh
│ │ │ ├── decode_fmllr.sh
│ │ │ ├── decode_sgmm2.sh
│ │ │ ├── make_denlats.sh
│ │ │ ├── make_denlats_sgmm2.sh
│ │ │ ├── mk_aslf_lda_mllt.sh
│ │ │ ├── mk_aslf_sgmm2.sh
│ │ │ ├── train_deltas.sh
│ │ │ ├── train_lda_mllt.sh
│ │ │ ├── train_mllt.sh
│ │ │ ├── train_mmi.sh
│ │ │ ├── train_mmi_sgmm2.sh
│ │ │ ├── train_mono.sh
│ │ │ ├── train_sat.sh
│ │ │ ├── train_sgmm2.sh
│ │ │ └── train_ubm.sh
│ │ ├── tfrnnlm/
│ │ │ ├── check_py.py
│ │ │ ├── check_tensorflow_installed.sh
│ │ │ ├── lmrescore_rnnlm_lat.sh
│ │ │ ├── lmrescore_rnnlm_lat_pruned.sh
│ │ │ ├── lstm.py
│ │ │ ├── lstm_fast.py
│ │ │ ├── reader.py
│ │ │ └── vanilla_rnnlm.py
│ │ ├── train_deltas.sh
│ │ ├── train_diag_ubm.sh
│ │ ├── train_lda_mllt.sh
│ │ ├── train_lvtln.sh
│ │ ├── train_map.sh
│ │ ├── train_mmi.sh
│ │ ├── train_mmi_fmmi.sh
│ │ ├── train_mmi_fmmi_indirect.sh
│ │ ├── train_mmi_sgmm2.sh
│ │ ├── train_mono.sh
│ │ ├── train_mpe.sh
│ │ ├── train_quick.sh
│ │ ├── train_raw_sat.sh
│ │ ├── train_sat.sh
│ │ ├── train_sat_basis.sh
│ │ ├── train_segmenter.sh
│ │ ├── train_sgmm2.sh
│ │ ├── train_sgmm2_group.sh
│ │ ├── train_smbr.sh
│ │ ├── train_ubm.sh
│ │ └── word_align_lattices.sh
│ └── utils/
│ ├── add_disambig.pl
│ ├── add_lex_disambig.pl
│ ├── analyze_segments.pl
│ ├── apply_map.pl
│ ├── best_wer.sh
│ ├── build_const_arpa_lm.sh
│ ├── combine_data.sh
│ ├── convert_slf.pl
│ ├── convert_slf_parallel.sh
│ ├── copy_data_dir.sh
│ ├── create_data_link.pl
│ ├── create_split_dir.pl
│ ├── ctm/
│ │ ├── convert_ctm.pl
│ │ ├── fix_ctm.sh
│ │ └── resolve_ctm_overlaps.py
│ ├── data/
│ │ ├── combine_short_segments.sh
│ │ ├── convert_data_dir_to_whole.sh
│ │ ├── extend_segment_times.py
│ │ ├── extract_wav_segments_data_dir.sh
│ │ ├── fix_subsegment_feats.pl
│ │ ├── get_allowed_durations.py
│ │ ├── get_frame_shift.sh
│ │ ├── get_num_frames.sh
│ │ ├── get_reco2dur.sh
│ │ ├── get_reco2utt_for_data.sh
│ │ ├── get_segments_for_data.sh
│ │ ├── get_uniform_subsegments.py
│ │ ├── get_utt2dur.sh
│ │ ├── get_utt2num_frames.sh
│ │ ├── internal/
│ │ │ ├── choose_utts_to_combine.py
│ │ │ ├── combine_segments_to_recording.py
│ │ │ ├── modify_speaker_info.py
│ │ │ └── perturb_volume.py
│ │ ├── limit_feature_dim.sh
│ │ ├── modify_speaker_info.sh
│ │ ├── modify_speaker_info_to_recording.sh
│ │ ├── normalize_data_range.pl
│ │ ├── perturb_data_dir_speed_3way.sh
│ │ ├── perturb_data_dir_volume.sh
│ │ ├── perturb_speed_to_allowed_lengths.py
│ │ ├── remove_dup_utts.sh
│ │ ├── resample_data_dir.sh
│ │ ├── shift_and_combine_feats.sh
│ │ ├── shift_feats.sh
│ │ └── subsegment_data_dir.sh
│ ├── dict_dir_add_pronprobs.sh
│ ├── eps2disambig.pl
│ ├── filt.py
│ ├── filter_scp.pl
│ ├── filter_scps.pl
│ ├── find_arpa_oovs.pl
│ ├── fix_data_dir.sh
│ ├── format_lm.sh
│ ├── format_lm_sri.sh
│ ├── gen_topo.pl
│ ├── int2sym.pl
│ ├── kwslist_post_process.pl
│ ├── lang/
│ │ ├── add_unigrams_arpa.pl
│ │ ├── adjust_unk_arpa.pl
│ │ ├── adjust_unk_graph.sh
│ │ ├── bpe/
│ │ │ ├── add_final_optional_silence.sh
│ │ │ ├── apply_bpe.py
│ │ │ ├── bidi.py
│ │ │ ├── learn_bpe.py
│ │ │ ├── prepend_words.py
│ │ │ └── reverse.py
│ │ ├── check_g_properties.pl
│ │ ├── check_phones_compatible.sh
│ │ ├── compute_sentence_probs_arpa.py
│ │ ├── extend_lang.sh
│ │ ├── get_word_position_phone_map.pl
│ │ ├── grammar/
│ │ │ ├── augment_phones_txt.py
│ │ │ └── augment_words_txt.py
│ │ ├── internal/
│ │ │ ├── apply_unk_lm.sh
│ │ │ ├── arpa2fst_constrained.py
│ │ │ └── modify_unk_pron.py
│ │ ├── limit_arpa_unk_history.py
│ │ ├── make_kn_lm.py
│ │ ├── make_lexicon_fst.py
│ │ ├── make_lexicon_fst_silprob.py
│ │ ├── make_phone_bigram_lang.sh
│ │ ├── make_phone_lm.py
│ │ ├── make_position_dependent_subword_lexicon.py
│ │ ├── make_subword_lexicon_fst.py
│ │ ├── make_unk_lm.sh
│ │ └── validate_disambig_sym_file.pl
│ ├── ln.pl
│ ├── make_absolute.sh
│ ├── make_lexicon_fst.pl
│ ├── make_lexicon_fst_silprob.pl
│ ├── make_unigram_grammar.pl
│ ├── map_arpa_lm.pl
│ ├── mkgraph.sh
│ ├── mkgraph_lookahead.sh
│ ├── nnet/
│ │ ├── gen_dct_mat.py
│ │ ├── gen_hamm_mat.py
│ │ ├── gen_splice.py
│ │ ├── make_blstm_proto.py
│ │ ├── make_cnn_proto.py
│ │ ├── make_lstm_proto.py
│ │ ├── make_nnet_proto.py
│ │ └── subset_data_tr_cv.sh
│ ├── nnet-cpu/
│ │ ├── make_nnet_config.pl
│ │ ├── make_nnet_config_block.pl
│ │ ├── make_nnet_config_preconditioned.pl
│ │ └── update_learning_rates.pl
│ ├── nnet3/
│ │ └── convert_config_tdnn_to_affine.py
│ ├── parallel/
│ │ ├── limit_num_gpus.sh
│ │ ├── pbs.pl
│ │ ├── queue.pl
│ │ ├── retry.pl
│ │ ├── run.pl
│ │ └── slurm.pl
│ ├── parse_options.sh
│ ├── perturb_data_dir_speed.sh
│ ├── pinyin_map.pl
│ ├── prepare_extended_lang.sh
│ ├── prepare_lang.sh
│ ├── prepare_online_nnet_dist_build.sh
│ ├── remove_data_links.sh
│ ├── remove_oovs.pl
│ ├── reverse_arpa.py
│ ├── rnnlm_compute_scores.sh
│ ├── s2eps.pl
│ ├── scoring/
│ │ ├── wer_ops_details.pl
│ │ ├── wer_per_spk_details.pl
│ │ ├── wer_per_utt_details.pl
│ │ └── wer_report.pl
│ ├── segmentation.pl
│ ├── show_lattice.sh
│ ├── shuffle_list.pl
│ ├── spk2utt_to_utt2spk.pl
│ ├── split_data.sh
│ ├── split_scp.pl
│ ├── ssh.pl
│ ├── subset_data_dir.sh
│ ├── subset_scp.pl
│ ├── subword/
│ │ ├── prepare_lang_subword.sh
│ │ └── prepare_subword_text.sh
│ ├── summarize_logs.pl
│ ├── summarize_warnings.pl
│ ├── sym2int.pl
│ ├── utt2spk_to_spk2utt.pl
│ ├── validate_data_dir.sh
│ ├── validate_dict_dir.pl
│ ├── validate_lang.pl
│ ├── validate_text.pl
│ └── write_kwslist.pl
├── env/
│ └── build_env.sh
├── kaldi
├── lm/
│ ├── __init__.py
│ ├── chainer_backend/
│ │ ├── __init__.py
│ │ ├── extlm.py
│ │ └── lm.py
│ ├── lm_utils.py
│ └── pytorch_backend/
│ ├── __init__.py
│ ├── extlm.py
│ └── lm.py
├── mt/
│ ├── __init__.py
│ ├── mt_utils.py
│ └── pytorch_backend/
│ ├── __init__.py
│ └── mt.py
├── nets/
│ ├── __init__.py
│ ├── asr_interface.py
│ ├── batch_beam_search.py
│ ├── batch_beam_search_online_sim.py
│ ├── beam_search.py
│ ├── beam_search_transducer.py
│ ├── chainer_backend/
│ │ ├── __init__.py
│ │ ├── asr_interface.py
│ │ ├── ctc.py
│ │ ├── deterministic_embed_id.py
│ │ ├── e2e_asr.py
│ │ ├── e2e_asr_transformer.py
│ │ ├── nets_utils.py
│ │ ├── rnn/
│ │ │ ├── __init__.py
│ │ │ ├── attentions.py
│ │ │ ├── decoders.py
│ │ │ ├── encoders.py
│ │ │ └── training.py
│ │ └── transformer/
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── ctc.py
│ │ ├── decoder.py
│ │ ├── decoder_layer.py
│ │ ├── embedding.py
│ │ ├── encoder.py
│ │ ├── encoder_layer.py
│ │ ├── label_smoothing_loss.py
│ │ ├── layer_norm.py
│ │ ├── mask.py
│ │ ├── positionwise_feed_forward.py
│ │ ├── subsampling.py
│ │ └── training.py
│ ├── ctc_prefix_score.py
│ ├── e2e_asr_common.py
│ ├── e2e_mt_common.py
│ ├── lm_interface.py
│ ├── mt_interface.py
│ ├── pytorch_backend/
│ │ ├── __init__.py
│ │ ├── conformer/
│ │ │ ├── __init__.py
│ │ │ ├── argument.py
│ │ │ ├── convolution.py
│ │ │ ├── encoder.py
│ │ │ ├── encoder_layer.py
│ │ │ └── swish.py
│ │ ├── ctc.py
│ │ ├── e2e_asr.py
│ │ ├── e2e_asr_conformer.py
│ │ ├── e2e_asr_maskctc.py
│ │ ├── e2e_asr_mix.py
│ │ ├── e2e_asr_mix_transformer.py
│ │ ├── e2e_asr_mulenc.py
│ │ ├── e2e_asr_transducer.py
│ │ ├── e2e_asr_transducer_cs.py
│ │ ├── e2e_asr_transformer.py
│ │ ├── e2e_mt.py
│ │ ├── e2e_mt_transformer.py
│ │ ├── e2e_st.py
│ │ ├── e2e_st_conformer.py
│ │ ├── e2e_st_transformer.py
│ │ ├── e2e_tts_fastspeech.py
│ │ ├── e2e_tts_tacotron2.py
│ │ ├── e2e_tts_transformer.py
│ │ ├── e2e_vc_tacotron2.py
│ │ ├── e2e_vc_transformer.py
│ │ ├── fastspeech/
│ │ │ ├── __init__.py
│ │ │ ├── duration_calculator.py
│ │ │ ├── duration_predictor.py
│ │ │ └── length_regulator.py
│ │ ├── frontends/
│ │ │ ├── __init__.py
│ │ │ ├── beamformer.py
│ │ │ ├── dnn_beamformer.py
│ │ │ ├── dnn_wpe.py
│ │ │ ├── feature_transform.py
│ │ │ ├── frontend.py
│ │ │ └── mask_estimator.py
│ │ ├── gtn_ctc.py
│ │ ├── initialization.py
│ │ ├── lm/
│ │ │ ├── __init__.py
│ │ │ ├── default.py
│ │ │ ├── seq_rnn.py
│ │ │ └── transformer.py
│ │ ├── maskctc/
│ │ │ ├── __init__.py
│ │ │ ├── add_mask_token.py
│ │ │ └── mask.py
│ │ ├── nets_utils.py
│ │ ├── rnn/
│ │ │ ├── __init__.py
│ │ │ ├── argument.py
│ │ │ ├── attentions.py
│ │ │ ├── decoders.py
│ │ │ └── encoders.py
│ │ ├── streaming/
│ │ │ ├── __init__.py
│ │ │ ├── segment.py
│ │ │ └── window.py
│ │ ├── tacotron2/
│ │ │ ├── __init__.py
│ │ │ ├── cbhg.py
│ │ │ ├── decoder.py
│ │ │ └── encoder.py
│ │ ├── transducer/
│ │ │ ├── __init__.py
│ │ │ ├── arguments.py
│ │ │ ├── auxiliary_task.py
│ │ │ ├── blocks.py
│ │ │ ├── causal_conv1d.py
│ │ │ ├── custom_decoder.py
│ │ │ ├── custom_encoder.py
│ │ │ ├── error_calculator.py
│ │ │ ├── initializer.py
│ │ │ ├── joint_network.py
│ │ │ ├── loss.py
│ │ │ ├── rnn_decoder.py
│ │ │ ├── rnn_encoder.py
│ │ │ ├── tdnn.py
│ │ │ ├── transformer_decoder_layer.py
│ │ │ ├── utils.py
│ │ │ └── vgg2l.py
│ │ ├── transformer/
│ │ │ ├── __init__.py
│ │ │ ├── add_sos_eos.py
│ │ │ ├── argument.py
│ │ │ ├── attention.py
│ │ │ ├── contextual_block_encoder_layer.py
│ │ │ ├── decoder.py
│ │ │ ├── decoder_layer.py
│ │ │ ├── dynamic_conv.py
│ │ │ ├── dynamic_conv2d.py
│ │ │ ├── embedding.py
│ │ │ ├── encoder.py
│ │ │ ├── encoder_layer.py
│ │ │ ├── encoder_mix.py
│ │ │ ├── initializer.py
│ │ │ ├── label_smoothing_loss.py
│ │ │ ├── layer_norm.py
│ │ │ ├── lightconv.py
│ │ │ ├── lightconv2d.py
│ │ │ ├── mask.py
│ │ │ ├── multi_layer_conv.py
│ │ │ ├── optimizer.py
│ │ │ ├── plot.py
│ │ │ ├── positionwise_feed_forward.py
│ │ │ ├── repeat.py
│ │ │ ├── sgd_optimizer.py
│ │ │ ├── subsampling.py
│ │ │ └── subsampling_without_posenc.py
│ │ └── wavenet.py
│ ├── scorer_interface.py
│ ├── scorers/
│ │ ├── .mmi_rnnt_scorer.py.swp
│ │ ├── __init__.py
│ │ ├── _mmi_utils.py
│ │ ├── ctc.py
│ │ ├── ctc_rnnt_scorer.py
│ │ ├── length_bonus.py
│ │ ├── lookahead.py
│ │ ├── mmi.py
│ │ ├── mmi_alignment_score.py
│ │ ├── mmi_frame_prefix_scorer.py
│ │ ├── mmi_frame_scorer.py
│ │ ├── mmi_frame_scorer_trace.py
│ │ ├── mmi_lookahead.py
│ │ ├── mmi_lookahead_bak.py
│ │ ├── mmi_lookahead_split.py
│ │ ├── mmi_prefix_score.py
│ │ ├── mmi_rescorer.py
│ │ ├── mmi_rnnt_lookahead_scorer.py
│ │ ├── mmi_rnnt_scorer.py
│ │ ├── mmi_utils.py
│ │ ├── new_mmi_frame_scorer.py
│ │ ├── ngram.py
│ │ ├── sorted_matcher.py
│ │ ├── test.py
│ │ ├── tlg_scorer.py
│ │ ├── trace_frame.py
│ │ └── word_ngram.py
│ ├── st_interface.py
│ ├── transducer_decoder_interface.py
│ └── tts_interface.py
├── optimizer/
│ ├── __init__.py
│ ├── chainer.py
│ ├── factory.py
│ ├── parser.py
│ └── pytorch.py
├── scheduler/
│ ├── __init__.py
│ ├── chainer.py
│ ├── pytorch.py
│ └── scheduler.py
├── snowfall/
│ ├── __init__.py
│ ├── common.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── aishell.py
│ │ ├── asr_datamodule.py
│ │ ├── datamodule.py
│ │ └── librispeech.py
│ ├── decoding/
│ │ ├── __init__.py
│ │ ├── graph.py
│ │ └── lm_rescore.py
│ ├── dist.py
│ ├── lexicon.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── conformer.py
│ │ ├── contextnet.py
│ │ ├── interface.py
│ │ ├── tdnn.py
│ │ ├── tdnn_lstm.py
│ │ ├── tdnnf.py
│ │ └── transformer.py
│ ├── objectives/
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── ctc.py
│ │ └── mmi.py
│ ├── training/
│ │ ├── __init__.py
│ │ ├── ctc_graph.py
│ │ ├── diagnostics.py
│ │ ├── mmi_graph.py
│ │ └── mmi_mbr_graph.py
│ └── warpper/
│ ├── k2_decode.py
│ ├── mmi_test.py
│ ├── mmi_utils.py
│ ├── prefix_scorer.py
│ ├── warpper_ctc.py
│ └── warpper_mmi.py
├── st/
│ ├── __init__.py
│ └── pytorch_backend/
│ ├── __init__.py
│ └── st.py
├── transform/
│ ├── __init__.py
│ ├── add_deltas.py
│ ├── channel_selector.py
│ ├── cmvn.py
│ ├── functional.py
│ ├── perturb.py
│ ├── spec_augment.py
│ ├── spectrogram.py
│ ├── transform_interface.py
│ ├── transformation.py
│ └── wpe.py
├── tts/
│ ├── __init__.py
│ └── pytorch_backend/
│ ├── __init__.py
│ └── tts.py
├── utils/
│ ├── __init__.py
│ ├── bmuf.py
│ ├── check_kwargs.py
│ ├── cli_readers.py
│ ├── cli_utils.py
│ ├── cli_writers.py
│ ├── dataset.py
│ ├── deterministic_utils.py
│ ├── draw_num_fst.py
│ ├── dynamic_import.py
│ ├── fill_missing_args.py
│ ├── io_utils.py
│ ├── parse_decoding_process.py
│ ├── parse_npy.py
│ ├── print.py
│ ├── rtf_calculator.py
│ ├── sampler.py
│ ├── spec_augment.py
│ └── training/
│ ├── __init__.py
│ ├── batchfy.py
│ ├── evaluator.py
│ ├── iterators.py
│ ├── tensorboard_logger.py
│ └── train_utils.py
├── vc/
│ └── pytorch_backend/
│ └── vc.py
└── version.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
interface
================================================
FILE: README.md
================================================
# End-to-end speech recognition toolkit
This is an E2E ASR toolkit modified from Espnet1 (version 0.9.9).
If this repository helps you, we would appreciate it if you could star it and cite our papers.
This is the official implementation of the following papers:
[**Consistent Training and Decoding For End-to-end Speech Recognition Using Lattice-free MMI**](https://ieeexplore.ieee.org/document/9746579/) (Accepted by ICASSP 2022)
[**Improving Mandarin End-to-End Speech Recognition with Word N-gram Language Model**](https://ieeexplore.ieee.org/document/9721084) (Accepted by SPL)
[**Integrate Lattice-Free MMI into End-to-End Speech Recognition**](https://arxiv.org/abs/2203.15614) (Submitted to TASLP)
We achieve state-of-the-art results on two of the most popular Mandarin datasets, Aishell-1 and Aishell-2.
Please feel free to change / modify the code as you like. :)
### Update
- 2021/12/29: Release the first version, which contains all MMI-related features, including the MMI training criterion, MMI Prefix Score (for attention-based encoder-decoder, AED) and MMI Alignment Score (for neural transducer, NT).
- 2022/1/6: Release the word-level N-gram LM scorer.
- 2022/1/12: We update the instructions to build the environment. We also release the trained NT model for Aishell-1 for a quick performance check. We update the guideline to run our code.
- 2022/3/29: We release a new CTC / RNN-T recipe for the code-switching problem, based on the ASRU 2019 Mandarin-English code-switching dataset (see egs/asrucs). Results on Aishell-1 and Aishell-2 are also updated.
### Environment:
The main dependencies of this code can be divided into three parts: `kaldi`, `espnet` and `k2`.
Please follow the instructions in [build_env.sh](https://github.com/jctian98/e2e_lfmmi/blob/master/env/build_env.sh) to build the environment.
Note that the script cannot be run automatically; you need to run it line by line.
### Results
Currently we have released examples on Aishell-1 and Aishell-2 datasets.
With the MMI training & decoding methods and the word-level N-gram LM, we achieve the following results on Aishell-1 and Aishell-2. All results are in CER%.
The model file of the Aishell-1 NT system is [here](https://drive.google.com/file/d/1VE2YtLb70UpQkeGWE8WhHJl7sSwNa_zG/view?usp=sharing) for a quick performance check.
| Test set | Aishell-1-dev | Aishell-1-test | Aishell-2-ios | Aishell-2-android | Aishell-2-mic |
| :---- | :-: | :--: | :-: | :-----: | :-: |
| AED | 4.60| 5.07 | 5.72| 6.60 | 6.58|
| AED + MMI + Word Ngram | 4.08| 4.45 | 5.15| 5.92 | 5.77|
| NT | 4.41| 4.82 | 5.81| 6.52 | 6.52|
| NT + MMI + Word Ngram | 3.79| 4.10 | 5.02| 5.85 | 5.66|
### Getting Started
Take Aishell-1 as an example. The workflow for the other examples is very similar.
step 1: clone the code and link kaldi
```
conda activate lfmmi
git clone https://github.com/jctian98/e2e_lfmmi E2E-ASR-Framework # clone and RENAME
cd E2E-ASR-Framework
ln -s <path-to-kaldi> kaldi # link kaldi
```
step 2: prepare data, lexicon and LMs. Before you run, please set the datadir in `prepare.sh`
```
cd egs/aishell1
bash prepare.sh
```
step 3: model training. You should split the data before starting the training.
You can skip this step and download our trained model [here](https://drive.google.com/file/d/1VE2YtLb70UpQkeGWE8WhHJl7sSwNa_zG/view?usp=sharing)
```
python3 espnet_utils/splitjson.py -p <ngpu> dump/train_sp/deltafalse/data.json
bash nt.sh --stop_stage 1
```
step 4: decode
```
bash nt.sh --stage 2 --mmi-weight 0.2 --word-ngram-weight 0.4
```
Several hints:
1. Please change the paths in `path.sh` accordingly before you start
2. Please set `data` in `prepare.sh` to point to your data path
3. Our code runs in DDP style and requires some global environment variables, which you need to set manually before you start. We assume the PyTorch distributed API works well on your machine.
```
export HOST_GPU_NUM=x # number of GPUs on each host
export HOST_NUM=x # number of hosts
export NODE_NUM=x # number of GPUs in total (on all hosts)
export INDEX=x # index of this host
export CHIEF_IP=xx.xx.xx.xx # IP of the master host
```
4. You may encounter some problems with `k2`. Try deleting `data/lang_phone/Linv.pt` (in training) and `data/word_3gram/G.pt` (in decoding) and regenerating them.
5. Multiple choices are available during decoding (we take `nt.sh` as an example, but the usage of `aed.sh` is the same).
To use the MMI-related scorers, you need to train the model with the MMI auxiliary criterion.
To use MMI Prefix Score (in AED) or MMI Alignment score (in NT):
```
bash nt.sh --stage 2 --mmi-weight 0.2
```
To use any external LM, you need to train it in advance (as implemented in `prepare.sh`).
To use word-level N-gram LM:
```
bash nt.sh --stage 2 --word-ngram-weight 0.4
```
To use character-level N-gram LM:
```
bash nt.sh --stage 2 --ngram-weight 1.0
```
To use neural network LM:
```
bash nt.sh --stage 2 --lm-weight 1.0
```
### Reference
kaldi: https://github.com/kaldi-asr/kaldi
ESPnet: https://github.com/espnet/espnet
k2-fsa: https://github.com/k2-fsa/k2
### Citations
```
@INPROCEEDINGS{9746579,
author={Tian, Jinchuan and Yu, Jianwei and Weng, Chao and Zhang, Shi-Xiong and Su, Dan and Yu, Dong and Zou, Yuexian},
booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Consistent Training and Decoding for End-to-End Speech Recognition Using Lattice-Free MMI},
year={2022},
volume={},
number={},
pages={7782-7786},
doi={10.1109/ICASSP43922.2022.9746579}}
@ARTICLE{9721084,
author={Tian, Jinchuan and Yu, Jianwei and Weng, Chao and Zou, Yuexian and Yu, Dong},
journal={IEEE Signal Processing Letters},
title={Improving Mandarin End-to-End Speech Recognition with Word N-gram Language Model},
year={2022},
volume={},
number={},
pages={1-1},
doi={10.1109/LSP.2022.3154241}}
@article{tian2022integrate,
title={Integrate Lattice-Free MMI into End-to-End Speech Recognition},
author={Tian, Jinchuan and Yu, Jianwei and Weng, Chao and Zou, Yuexian and Yu, Dong},
journal={arXiv preprint arXiv:2203.15614},
year={2022}
}
```
### Authorship
Jinchuan Tian; tianjinchuan@stu.pku.edu.cn or tyriontian@tencent.com
Jianwei Yu; tomasyu@tencent.com (supervisor)
Chao Weng; cweng@tencent.com
Yuexian Zou; zouyx@pku.edu.cn
================================================
FILE: __init__.py
================================================
"""Initialize espnet package."""
import os

# Resolve version.txt relative to this package so the import works from any CWD.
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
# Read with an explicit encoding so the result does not depend on the locale.
with open(version_file, "r", encoding="utf-8") as f:
    __version__ = f.read().strip()
================================================
FILE: asr/__init__.py
================================================
"""Initialize sub package."""
================================================
FILE: asr/asr_mix_utils.py
================================================
#!/usr/bin/env python3
"""
This script is used to provide utility functions designed for multi-speaker ASR.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
Most functions can be directly used as in asr_utils.py:
CompareValueTrigger, restore_snapshot, adadelta_eps_decay, chainer_load,
torch_snapshot, torch_save, torch_resume, AttributeDict, get_model_conf.
"""
import copy
import logging
import os
from chainer.training import extension
import matplotlib
from espnet.asr.asr_utils import parse_hypothesis
# Select matplotlib's non-interactive "Agg" backend before any pyplot import,
# presumably so attention figures can be rendered and saved without a display.
matplotlib.use("Agg")
# * -------------------- chainer extension related -------------------- *
class PlotAttentionReport(extension.Extension):
    """Plot attention reporter for multi-speaker ASR.

    Args:
        att_vis_fn (espnet.nets.*_backend.e2e_asr.calculate_all_attentions):
            Function of attention visualization. Its result is iterated per
            speaker, then per utterance.
        data (list[tuple(str, dict[str, dict[str, Any]])]): List json utt key items.
        outdir (str): Directory to save figures.
        converter (espnet.asr.*_backend.asr.CustomConverter):
            CustomConverter object. Function to convert data; its ``transform``
            attribute is applied to ``data`` first.
        device (torch.device): The destination device to send tensor.
        reverse (bool): If True, input and output length are reversed.

    """

    def __init__(self, att_vis_fn, data, outdir, converter, device, reverse=False):
        """Initialize PlotAttentionReport."""
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        self.outdir = outdir
        self.converter = converter
        self.device = device
        self.reverse = reverse
        os.makedirs(self.outdir, exist_ok=True)

    def __call__(self, trainer):
        """Plot and save imaged matrix of att_ws, one file per utterance and speaker."""
        att_ws_sd = self.get_attention_weights()
        for ns, att_ws in enumerate(att_ws_sd):
            for idx, att_w in enumerate(att_ws):
                # the epoch placeholder is filled from ``trainer`` below
                filename = "%s/%s.ep.{.updater.epoch}.output%d.png" % (
                    self.outdir,
                    self.data[idx][0],
                    ns + 1,
                )
                att_w = self.get_attention_weight(idx, att_w, ns)
                self._plot_and_save_attention(att_w, filename.format(trainer))

    def log_attentions(self, logger, step):
        """Add image files of attention matrix to tensorboard.

        Each figure is tagged with the utterance id and the 1-based speaker
        (output) index -- the same ``output%d`` suffix ``__call__`` uses for file
        names -- so figures of different speakers do not share one tag.
        """
        att_ws_sd = self.get_attention_weights()
        for ns, att_ws in enumerate(att_ws_sd):
            for idx, att_w in enumerate(att_ws):
                att_w = self.get_attention_weight(idx, att_w, ns)
                plot = self.draw_attention_plot(att_w)
                logger.add_figure(
                    "%s_output%d" % (self.data[idx][0], ns + 1), plot.gcf(), step
                )
                plot.clf()

    def get_attention_weights(self):
        """Return attention weights.

        Returns:
            att_ws_sd (numpy.ndarray): attention weights. Its shape would
                differ from backend.dtype=float
                * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax). 2)
                  other case => (B, Lmax, Tmax).
                * chainer-> attention weights (B, Lmax, Tmax).

        """
        batch = self.converter([self.converter.transform(self.data)], self.device)
        att_ws_sd = self.att_vis_fn(*batch)
        return att_ws_sd

    def get_attention_weight(self, idx, att_w, spkr_idx):
        """Trim attention weight to the true lengths, honoring self.reverse."""
        if self.reverse:
            dec_len = int(self.data[idx][1]["input"][0]["shape"][0])
            enc_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
        else:
            dec_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
            enc_len = int(self.data[idx][1]["input"][0]["shape"][0])
        if len(att_w.shape) == 3:
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Visualize attention weights matrix.

        Args:
            att_w(Tensor): Attention weight matrix.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib.pyplot as plt

        # Start from a clean figure so nothing left over from a previous plot leaks in.
        plt.clf()
        if len(att_w.shape) == 3:
            # one subplot per attention head
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename):
        plt = self.draw_attention_plot(att_w)
        plt.savefig(filename)
        plt.close()
def add_results_to_json(js, nbest_hyps_sd, char_list):
    """Add N-best results of every speaker to json.

    Args:
        js (dict[str, Any]): Groundtruth utterance dict. ``js["output"][ns]`` is
            the reference entry copied for speaker ``ns``.
        nbest_hyps_sd (list[list[dict[str, Any]]]):
            N-best hypotheses per speaker (# Spkrs x # N-best).
        char_list (list[str]): List of characters.

    Returns:
        dict[str, Any]: Utterance dict whose ``"output"`` holds, for each
            speaker, the list of N-best result dicts.

    """
    # copy old json info
    new_js = dict()
    new_js["utt2spk"] = js["utt2spk"]
    new_js["output"] = []

    for ns, nbest_hyps in enumerate(nbest_hyps_sd):
        tmp_js = []
        for n, hyp in enumerate(nbest_hyps, 1):
            # parse hypothesis
            rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)

            # copy ground-truth of this speaker
            out_dic = dict(js["output"][ns].items())

            # update name
            out_dic["name"] += "[%d]" % n

            # add recognition results
            out_dic["rec_text"] = rec_text
            out_dic["rec_token"] = rec_token
            out_dic["rec_tokenid"] = rec_tokenid
            out_dic["score"] = score

            # add to list of N-best result dicts
            tmp_js.append(out_dic)

            # show 1-best result; the reference text may be absent
            if n == 1:
                if "text" in out_dic:
                    logging.info("groundtruth: %s", out_dic["text"])
                logging.info("prediction : %s", out_dic["rec_text"])

        new_js["output"].append(tmp_js)
    return new_js
================================================
FILE: asr/asr_utils.py
================================================
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import copy
import json
import logging
import os
import shutil
import tempfile
import numpy as np
import torch
# * -------------------- training iterator related -------------------- *
class CompareValueTrigger(object):
    """Trigger driven by comparing a monitored value with its best so far.

    At every ``trigger`` interval the mean of ``key`` over that interval is
    compared with the stored best value via ``compare_fn(best, current)``.
    A truthy result fires the trigger; otherwise ``current`` becomes the new
    best. The very first interval only records the best value.

    Args:
        key (str): Key of value.
        compare_fn ((float, float) -> bool): Function to compare the values.
        trigger (tuple(int, str)): Trigger that decides the comparison interval.

    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        from chainer import training

        self._key = key
        self._best_value = None
        self._interval_trigger = training.util.get_trigger(trigger)
        self._init_summary()
        self._compare_fn = compare_fn

    def __call__(self, trainer):
        """Get value related to the key and compare with current value."""
        obs = trainer.observation
        if self._key in obs:
            self._summary.add({self._key: obs[self._key]})

        if not self._interval_trigger(trainer):
            return False

        current = float(self._summary.compute_mean()[self._key])  # copy to CPU
        self._init_summary()

        previous_best = self._best_value
        if previous_best is not None and self._compare_fn(previous_best, current):
            return True
        # first interval, or comparison did not fire: remember current as best
        self._best_value = current
        return False

    def _init_summary(self):
        import chainer

        self._summary = chainer.reporter.DictSummary()
try:
    from chainer.training import extension
except ImportError:
    # chainer is not available: expose None instead of the reporter class.
    PlotAttentionReport = None
else:

    class PlotAttentionReport(extension.Extension):
        """Plot attention reporter.

        Args:
            att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
                Function of attention visualization.
            data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
            outdir (str): Directory to save figures.
            converter (espnet.asr.*_backend.asr.CustomConverter):
                Function to convert data.
            transform (callable): Called as ``transform(data, return_uttid=True)``;
                must return ``(batch, uttid_list)``.
            device (int | torch.device): Device.
            reverse (bool): If True, input and output length are reversed.
            ikey (str): Key to access input
                (for ASR/ST ikey="input", for MT ikey="output".)
            iaxis (int): Dimension to access input
                (for ASR/ST iaxis=0, for MT iaxis=1.)
            okey (str): Key to access output
                (for ASR/ST okey="output", for MT okey="output".)
            oaxis (int): Dimension to access output
                (for ASR/ST oaxis=0, for MT oaxis=0.)
            subsampling_factor (int): subsampling factor in encoder

        """

        def __init__(
            self,
            att_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1,
        ):
            self.att_vis_fn = att_vis_fn
            self.data = copy.deepcopy(data)
            # key is utterance ID
            self.data_dict = dict(copy.deepcopy(data))
            self.outdir = outdir
            self.converter = converter
            self.transform = transform
            self.device = device
            self.reverse = reverse
            self.ikey = ikey
            self.iaxis = iaxis
            self.okey = okey
            self.oaxis = oaxis
            self.factor = subsampling_factor
            os.makedirs(self.outdir, exist_ok=True)

        def __call__(self, trainer):
            """Plot and save image file of att_ws matrix."""
            att_ws, uttid_list = self.get_attention_weights()
            if isinstance(att_ws, list):  # multi-encoder case
                # the last element holds the hierarchical (han) attention
                num_encs = len(att_ws) - 1
                # atts
                for i in range(num_encs):
                    for idx, att_w in enumerate(att_ws[i]):
                        self._save_attention(
                            trainer, uttid_list[idx], att_w, ".att%d" % (i + 1)
                        )
                # han
                for idx, att_w in enumerate(att_ws[num_encs]):
                    self._save_attention(
                        trainer, uttid_list[idx], att_w, ".han", han_mode=True
                    )
            else:
                for idx, att_w in enumerate(att_ws):
                    self._save_attention(trainer, uttid_list[idx], att_w, "")

        def _save_attention(self, trainer, uttid, att_w, suffix, han_mode=False):
            """Trim ``att_w`` and save it as both ``.npy`` and ``.png``.

            Files are named ``<outdir>/<uttid>.ep.<epoch><suffix>.{npy,png}``.
            """
            att_w = self.trim_attention_weight(uttid, att_w)
            # the epoch placeholder is filled from ``trainer`` via str.format
            base = "%s/%s.ep.{.updater.epoch}%s" % (self.outdir, uttid, suffix)
            np.save((base + ".npy").format(trainer), att_w)
            self._plot_and_save_attention(
                att_w, (base + ".png").format(trainer), han_mode=han_mode
            )

        def log_attentions(self, logger, step):
            """Add image files of att_ws matrix to the tensorboard."""
            att_ws, uttid_list = self.get_attention_weights()
            if isinstance(att_ws, list):  # multi-encoder case
                num_encs = len(att_ws) - 1
                # atts
                for i in range(num_encs):
                    for idx, att_w in enumerate(att_ws[i]):
                        att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                        plot = self.draw_attention_plot(att_w)
                        logger.add_figure(
                            "%s_att%d" % (uttid_list[idx], i + 1),
                            plot.gcf(),
                            step,
                        )
                # han
                for idx, att_w in enumerate(att_ws[num_encs]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_han_plot(att_w)
                    logger.add_figure(
                        "%s_han" % (uttid_list[idx]),
                        plot.gcf(),
                        step,
                    )
            else:
                for idx, att_w in enumerate(att_ws):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

        def get_attention_weights(self):
            """Return attention weights.

            Returns:
                numpy.ndarray: attention weights. float. Its shape would be
                    differ from backend.
                    * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
                      other case => (B, Lmax, Tmax).
                    * chainer-> (B, Lmax, Tmax)

            """
            return_batch, uttid_list = self.transform(self.data, return_uttid=True)
            batch = self.converter([return_batch], self.device)
            if isinstance(batch, tuple):
                att_ws = self.att_vis_fn(*batch)
            else:
                att_ws = self.att_vis_fn(**batch)
            return att_ws, uttid_list

        def trim_attention_weight(self, uttid, att_w):
            """Trim attention matrix to the true lengths, honoring self.reverse."""
            if self.reverse:
                enc_key, enc_axis = self.okey, self.oaxis
                dec_key, dec_axis = self.ikey, self.iaxis
            else:
                enc_key, enc_axis = self.ikey, self.iaxis
                dec_key, dec_axis = self.okey, self.oaxis
            dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
            enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
            if self.factor > 1:
                # encoder frames are subsampled relative to the input length
                enc_len //= self.factor
            if len(att_w.shape) == 3:
                att_w = att_w[:, :dec_len, :enc_len]
            else:
                att_w = att_w[:dec_len, :enc_len]
            return att_w

        def draw_attention_plot(self, att_w):
            """Plot the att_w matrix.

            Returns:
                matplotlib.pyplot: pyplot object with attention matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            plt.clf()
            att_w = att_w.astype(np.float32)
            if len(att_w.shape) == 3:
                # one subplot per attention head
                for h, aw in enumerate(att_w, 1):
                    plt.subplot(1, len(att_w), h)
                    plt.imshow(aw, aspect="auto")
                    plt.xlabel("Encoder Index")
                    plt.ylabel("Decoder Index")
            else:
                plt.imshow(att_w, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
            plt.tight_layout()
            return plt

        def draw_han_plot(self, att_w):
            """Plot the att_w matrix for hierarchical attention.

            Returns:
                matplotlib.pyplot: pyplot object with attention matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            plt.clf()
            if len(att_w.shape) == 3:
                for h, aw in enumerate(att_w, 1):
                    legends = []
                    plt.subplot(1, len(att_w), h)
                    # one line per encoder column
                    for i in range(aw.shape[1]):
                        plt.plot(aw[:, i])
                        legends.append("Att{}".format(i))
                    plt.ylim([0, 1.0])
                    plt.xlim([0, aw.shape[0]])
                    plt.grid(True)
                    plt.ylabel("Attention Weight")
                    plt.xlabel("Decoder Index")
                    plt.legend(legends)
            else:
                legends = []
                for i in range(att_w.shape[1]):
                    plt.plot(att_w[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, att_w.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
            plt.tight_layout()
            return plt

        def _plot_and_save_attention(self, att_w, filename, han_mode=False):
            if han_mode:
                plt = self.draw_han_plot(att_w)
            else:
                plt = self.draw_attention_plot(att_w)
            plt.savefig(filename)
            plt.close()
try:
    from chainer.training import extension
except ImportError:
    # chainer is not available: expose None instead of the reporter class.
    PlotCTCReport = None
else:

    class PlotCTCReport(extension.Extension):
        """Plot CTC reporter.

        Args:
            ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
                Function of CTC visualization.
            data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
            outdir (str): Directory to save figures.
            converter (espnet.asr.*_backend.asr.CustomConverter):
                Function to convert data.
            transform (callable): Called as ``transform(data, return_uttid=True)``;
                must return ``(batch, uttid_list)``.
            device (int | torch.device): Device.
            reverse (bool): If True, input and output length are reversed.
            ikey (str): Key to access input
                (for ASR/ST ikey="input", for MT ikey="output".)
            iaxis (int): Dimension to access input
                (for ASR/ST iaxis=0, for MT iaxis=1.)
            okey (str): Key to access output
                (for ASR/ST okey="output", for MT okey="output".)
            oaxis (int): Dimension to access output
                (for ASR/ST oaxis=0, for MT oaxis=0.)
            subsampling_factor (int): subsampling factor in encoder

        Note:
            ``reverse``, ``okey`` and ``oaxis`` are stored but not used by this
            reporter; trimming only looks at the input length.

        """

        def __init__(
            self,
            ctc_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1,
        ):
            self.ctc_vis_fn = ctc_vis_fn
            self.data = copy.deepcopy(data)
            # key is utterance ID
            self.data_dict = dict(copy.deepcopy(data))
            self.outdir = outdir
            self.converter = converter
            self.transform = transform
            self.device = device
            self.reverse = reverse
            self.ikey = ikey
            self.iaxis = iaxis
            self.okey = okey
            self.oaxis = oaxis
            self.factor = subsampling_factor
            os.makedirs(self.outdir, exist_ok=True)

        def __call__(self, trainer):
            """Plot and save image file of ctc prob."""
            ctc_probs, uttid_list = self.get_ctc_probs()
            if isinstance(ctc_probs, list):  # multi-encoder case
                num_encs = len(ctc_probs) - 1
                for i in range(num_encs):
                    for idx, ctc_prob in enumerate(ctc_probs[i]):
                        filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                        np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        np.save(np_filename.format(trainer), ctc_prob)
                        self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
            else:
                for idx, ctc_prob in enumerate(ctc_probs):
                    filename = "%s/%s.ep.{.updater.epoch}.png" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    np.save(np_filename.format(trainer), ctc_prob)
                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))

        def log_ctc_probs(self, logger, step):
            """Add image files of ctc probs to the tensorboard."""
            ctc_probs, uttid_list = self.get_ctc_probs()
            if isinstance(ctc_probs, list):  # multi-encoder case
                num_encs = len(ctc_probs) - 1
                for i in range(num_encs):
                    for idx, ctc_prob in enumerate(ctc_probs[i]):
                        ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                        plot = self.draw_ctc_plot(ctc_prob)
                        logger.add_figure(
                            "%s_ctc%d" % (uttid_list[idx], i + 1),
                            plot.gcf(),
                            step,
                        )
            else:
                for idx, ctc_prob in enumerate(ctc_probs):
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    plot = self.draw_ctc_plot(ctc_prob)
                    logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

        def get_ctc_probs(self):
            """Return CTC probs.

            Returns:
                numpy.ndarray: CTC probs. float. Its shape would be
                    differ from backend. (B, Tmax, vocab).

            """
            return_batch, uttid_list = self.transform(self.data, return_uttid=True)
            batch = self.converter([return_batch], self.device)
            if isinstance(batch, tuple):
                probs = self.ctc_vis_fn(*batch)
            else:
                probs = self.ctc_vis_fn(**batch)
            return probs, uttid_list

        def trim_ctc_prob(self, uttid, prob):
            """Trim CTC posteriors according to input lengths."""
            enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
            if self.factor > 1:
                # encoder frames are subsampled relative to the input length
                enc_len //= self.factor
            prob = prob[:enc_len]
            return prob

        def draw_ctc_plot(self, ctc_prob):
            """Plot the ctc_prob matrix.

            Returns:
                matplotlib.pyplot: pyplot object with CTC prob matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            ctc_prob = ctc_prob.astype(np.float32)
            # NOTE(review): this is a full argsort, so the set below covers every
            # vocabulary index, i.e. all token posteriors are plotted.
            topk_ids = np.argsort(ctc_prob, axis=1)
            n_frames, vocab = ctc_prob.shape
            times_probs = np.arange(n_frames)

            # Draw on a fresh figure only. Calling plt.clf() first would wipe
            # whatever figure happened to be current (or create an extra one
            # that nobody closes).
            plt.figure(figsize=(20, 8))

            # NOTE: index 0 is reserved for blank
            for idx in set(topk_ids.reshape(-1).tolist()):
                if idx == 0:
                    plt.plot(
                        times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
                    )
                else:
                    plt.plot(times_probs, ctc_prob[:, idx])
            plt.xlabel(u"Input [frame]", fontsize=12)
            plt.ylabel("Posteriors", fontsize=12)
            plt.xticks(list(range(0, int(n_frames) + 1, 10)))
            plt.yticks(list(range(0, 2, 1)))
            plt.tight_layout()
            return plt

        def _plot_and_save_ctc(self, ctc_prob, filename):
            plt = self.draw_ctc_plot(ctc_prob)
            plt.savefig(filename)
            plt.close()
def restore_snapshot(model, snapshot, load_fn=None):
    """Build a trainer extension that reloads ``snapshot`` into ``model``.

    Args:
        model: Model whose parameters are restored.
        snapshot (str): Path of the snapshot to load.
        load_fn (callable, optional): Called as ``load_fn(snapshot, model)``;
            ``chainer.serializers.load_npz`` when omitted.

    Returns:
        An extension function, triggered once per epoch.

    """
    import chainer
    from chainer import training

    load_fn = chainer.serializers.load_npz if load_fn is None else load_fn

    @training.make_extension(trigger=(1, "epoch"))
    def restore_snapshot(trainer):
        _restore_snapshot(model, snapshot, load_fn)

    return restore_snapshot
def _restore_snapshot(model, snapshot, load_fn=None):
if load_fn is None:
import chainer
load_fn = chainer.serializers.load_npz
load_fn(snapshot, model)
logging.info("restored from " + str(snapshot))
def adadelta_eps_decay(eps_decay):
    """Build an extension that multiplies the adadelta eps by ``eps_decay``.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function, triggered once per epoch.

    """
    from chainer import training

    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    return training.make_extension(trigger=(1, "epoch"))(adadelta_eps_decay)
def _adadelta_eps_decay(trainer, eps_decay):
optimizer = trainer.updater.get_optimizer("main")
# for chainer
if hasattr(optimizer, "eps"):
current_eps = optimizer.eps
setattr(optimizer, "eps", current_eps * eps_decay)
logging.info("adadelta eps decayed to " + str(optimizer.eps))
# pytorch
else:
for p in optimizer.param_groups:
p["eps"] *= eps_decay
logging.info("adadelta eps decayed to " + str(p["eps"]))
def adam_lr_decay(eps_decay):
    """Build an extension that multiplies the adam learning rate by ``eps_decay``.

    Args:
        eps_decay (float): Decay rate of lr.

    Returns:
        An extension function, triggered once per epoch.

    """
    from chainer import training

    def adam_lr_decay(trainer):
        _adam_lr_decay(trainer, eps_decay)

    return training.make_extension(trigger=(1, "epoch"))(adam_lr_decay)
def _adam_lr_decay(trainer, eps_decay):
optimizer = trainer.updater.get_optimizer("main")
# for chainer
if hasattr(optimizer, "lr"):
current_lr = optimizer.lr
setattr(optimizer, "lr", current_lr * eps_decay)
logging.info("adam lr decayed to " + str(optimizer.lr))
# pytorch
else:
for p in optimizer.param_groups:
p["lr"] *= eps_decay
logging.info("adam lr decayed to " + str(p["lr"]))
def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
    """Build an extension that snapshots the trainer for pytorch.

    Args:
        savefun (callable): Serializer called as ``savefun(obj, path)``.
        filename (str): Format string filled with the trainer via ``str.format``.

    Returns:
        An extension function, triggered once per epoch with priority -100.

    """
    from chainer.training import extension

    def torch_snapshot(trainer):
        target_name = filename.format(trainer)
        _torch_snapshot_object(trainer, trainer, target_name, savefun)

    return extension.make_extension(trigger=(1, "epoch"), priority=-100)(
        torch_snapshot
    )
def _torch_snapshot_object(trainer, target, filename, savefun):
    """Save trainer, model and optimizer states into ``trainer.out``.

    The snapshot is first written into a temporary directory under
    ``trainer.out`` and then moved into place. The temporary directory is
    always removed. ``target`` is accepted but not used.
    """
    from chainer.serializers import DictionarySerializer

    serializer = DictionarySerializer()
    serializer.save(trainer)

    updater = trainer.updater
    # TTS updaters keep the network one level deeper (updater.model.model).
    net = updater.model
    if hasattr(net, "model"):
        net = net.model
    # presumably a data-parallel wrapper: save the wrapped module's weights
    if hasattr(net, "module"):
        net = net.module

    # ASR updaters expose the optimizer through ddp_trainer; others (e.g. LM)
    # through the standard "main" optimizer.
    optimizer = (
        updater.ddp_trainer.optimizer
        if hasattr(updater, "ddp_trainer")
        else updater.get_optimizer("main")
    )

    snapshot_dict = {
        "trainer": serializer.target,
        "model": net.state_dict(),
        "optimizer": optimizer.state_dict(),
    }

    # save snapshot dictionary
    fn = filename.format(trainer)
    tmpdir = tempfile.mkdtemp(prefix="tmp" + fn, dir=trainer.out)
    tmppath = os.path.join(tmpdir, fn)
    try:
        savefun(snapshot_dict, tmppath)
        shutil.move(tmppath, os.path.join(trainer.out, fn))
    finally:
        shutil.rmtree(tmpdir)
def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
    """Add noise from a standard normal distribution to the gradients, in place.

    The noise standard deviation is ``sigma = eta / interval ** scale_factor``
    with ``interval = iteration // duration + 1``, so ``sigma`` goes to zero
    (no noise) as the number of iterations grows.

    Args:
        model (torch.nn.Module): Model.
        iteration (int): Number of iterations.
        duration (int) {100, 1000}:
            Number of durations to control the interval of the `sigma` change.
        eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
        scale_factor (float) {0.55}: The scale of `sigma`.

    """
    interval = (iteration // duration) + 1
    sigma = eta / interval ** scale_factor
    for param in model.parameters():
        if param.grad is not None:
            # Sample directly with the gradient's device and dtype instead of
            # allocating float32 noise on the CPU and copying it over.
            param.grad += sigma * torch.randn_like(param.grad)
# * -------------------- general -------------------- *
def get_model_conf(model_path, conf_path=None):
    """Read the model configuration json (``model.json`` by default).

    Args:
        model_path (str): Model path; ``model.json`` is looked up in its
            directory unless ``conf_path`` is given.
        conf_path (str): Optional explicit path to the config file.

    Returns:
        argparse.Namespace | tuple[int, int, argparse.Namespace]: The namespace
            alone when the json holds an object (LM), otherwise
            ``(idim, odim, args)`` (ASR, TTS, MT).

    """
    model_conf = (
        os.path.dirname(model_path) + "/model.json" if conf_path is None else conf_path
    )
    with open(model_conf, "rb") as f:
        logging.info("reading a config file from " + model_conf)
        confs = json.load(f)

    if not isinstance(confs, dict):
        # for asr, tts, mt
        idim, odim, args = confs
        return idim, odim, argparse.Namespace(**args)
    # for lm
    return argparse.Namespace(**confs)
def chainer_load(path, model):
    """Load chainer model parameters from a model file or trainer snapshot.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (chainer.Chain): Chainer model.

    """
    import chainer

    # Snapshot files keep the model under the updater's "main" model prefix.
    extra = (
        {"path": "updater/model:main/"} if "snapshot" in os.path.basename(path) else {}
    )
    chainer.serializers.load_npz(path, model, **extra)
def torch_save(path, model):
    """Save the state dict of ``model`` to ``path``.

    When ``model`` exposes a ``.module`` attribute (e.g. a data-parallel
    wrapper), the state dict of that inner module is saved instead.

    Args:
        path (str): Model path to be saved.
        model (torch.nn.Module): Torch model.

    """
    target = model.module if hasattr(model, "module") else model
    torch.save(target.state_dict(), path)
def snapshot_object(target, filename):
    """Return a trainer extension that saves the weights of ``target`` every epoch.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file into which the object is serialized.
            It may be a format string that receives the trainer through
            :meth:`str.format`; e.g. ``'snapshot_{.updater.iteration}'``
            becomes ``'snapshot_10000'`` at the 10,000th iteration.

    Returns:
        An extension function (trigger: once per epoch, priority -100).

    """
    from chainer.training import extension

    def snapshot_object(trainer):
        out_path = os.path.join(trainer.out, filename.format(trainer))
        torch_save(out_path, target)

    return extension.make_extension(trigger=(1, "epoch"), priority=-100)(
        snapshot_object
    )
def torch_load(path, model):
    """Load model weights from a state-dict file or a trainer snapshot.

    Files whose basename contains "snapshot" are treated as trainer snapshots,
    and their ``"model"`` entry is used. Tensors are loaded onto the CPU.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (torch.nn.Module): Torch model; its ``.module`` is loaded if present.

    """
    state = torch.load(path, map_location=lambda storage, loc: storage)
    if "snapshot" in os.path.basename(path):
        state = state["model"]

    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state)
    del state
def torch_resume(snapshot_path, trainer, load_trainer_and_opt=True):
    """Resume from snapshot for pytorch.

    Args:
        snapshot_path (str): Snapshot file path.
        trainer (chainer.training.Trainer): Chainer's trainer instance.
        load_trainer_and_opt (bool): If False, only the model weights are
            restored; the trainer and optimizer states in the snapshot are ignored.

    """
    if not load_trainer_and_opt:
        print("Only model weights are resumed")
        print("trainer and optimizer is ignored")
        print("make sure this is the second-stage training")

    # load snapshot
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)

    # restore trainer states (chainer is only needed for this step)
    if load_trainer_and_opt:
        from chainer.serializers import NpzDeserializer

        d = NpzDeserializer(snapshot_dict["trainer"])
        d.load(trainer)

    # restore model states; TTS updaters nest the network one level deeper
    # (updater.model.model), and a ``.module`` attribute is unwrapped.
    model = trainer.updater.model
    if hasattr(model, "model"):
        model = model.model
    if hasattr(model, "module"):
        model = model.module
    model.load_state_dict(snapshot_dict["model"])

    # restore optimizer states. This mirrors _torch_snapshot_object: updaters
    # without ``ddp_trainer`` (e.g. LM) keep the optimizer under "main".
    if load_trainer_and_opt:
        updater = trainer.updater
        if hasattr(updater, "ddp_trainer"):
            if hasattr(updater.ddp_trainer, "optimizer"):
                updater.ddp_trainer.optimizer.load_state_dict(
                    snapshot_dict["optimizer"]
                )
        else:
            updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])

    # delete opened snapshot
    del snapshot_dict
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
    """Turn a recognition hypothesis into readable strings.

    Args:
        hyp (dict[str, Any]): Hypothesis holding ``"yseq"`` (token ids whose
            first element is <sos>) and ``"score"``.
        char_list (list[str]): Token-id to token-string table.

    Returns:
        tuple(str, str, str, float): ``(text, token, tokenid, score)``. ``text``
            concatenates the tokens and maps ``<space>`` to a blank.

    """
    ids = [int(i) for i in hyp["yseq"][1:]]  # drop the leading <sos>
    tokens = [char_list[i] for i in ids]
    score = float(hyp["score"])
    text = "".join(tokens).replace("<space>", " ")
    return text, " ".join(tokens), " ".join(map(str, ids)), score
def add_results_to_json(js, nbest_hyps, char_list):
    """Add N-best results to json.

    Args:
        js (dict[str, Any]): Groundtruth utterance dict.
        nbest_hyps (list[dict[str, Any]]): N-best hypotheses of one utterance,
            best first.
        char_list (list[str]): List of characters.

    Returns:
        dict[str, Any]: Utterance dict whose ``"output"`` lists one result
            dict per hypothesis.

    """
    # copy old json info
    new_js = dict()
    new_js["utt2spk"] = js["utt2spk"]
    new_js["output"] = []

    for n, hyp in enumerate(nbest_hyps, 1):
        # parse hypothesis
        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)

        # copy ground-truth
        if len(js["output"]) > 0:
            out_dic = dict(js["output"][0].items())
        else:
            # for no reference case (e.g., speech translation)
            out_dic = {"name": ""}

        # update name
        out_dic["name"] += "[%d]" % n

        # add recognition results
        out_dic["rec_text"] = rec_text
        out_dic["rec_token"] = rec_token
        out_dic["rec_tokenid"] = rec_tokenid
        out_dic["score"] = score

        # RNNT MMI: total MMI score stored at the top level of the hypothesis
        if "mmi_tot_score" in hyp:
            out_dic["mmi_tot_score"] = hyp["mmi_tot_score"]

        # LASCTC MMI: MMI scores stored among the per-scorer scores
        if "scores" in hyp:
            if "mmi_tot_score" in hyp["scores"]:
                out_dic["mmi_tot_score"] = hyp["scores"]["mmi_tot_score"]
            if "mmi" in hyp["scores"]:
                out_dic["mmi"] = hyp["scores"]["mmi"]

        # add to list of N-best result dicts
        new_js["output"].append(out_dic)

        # show 1-best result
        if n == 1:
            if "text" in out_dic:
                logging.info("groundtruth: %s", out_dic["text"])
            logging.info("prediction : %s", out_dic["rec_text"])

    return new_js
def plot_spectrogram(
    plt,
    spec,
    mode="db",
    fs=None,
    frame_shift=None,
    bottom=True,
    left=True,
    right=True,
    top=False,
    labelbottom=True,
    labelleft=True,
    labelright=True,
    labeltop=False,
    cmap="inferno",
):
    """Plot spectrogram using matplotlib.

    Args:
        plt (matplotlib.pyplot): pyplot object.
        spec (numpy.ndarray): Input stft (Freq, Time). Its magnitude
            (``np.abs``) is plotted, so complex, float and integer input
            are all accepted.
        mode (str): "db" (``20 * log10`` of the magnitude) or "linear".
        fs (int): Sample frequency. To convert y-axis to kHz unit.
        frame_shift (int): The frame shift of stft. To convert x-axis to second
            unit (only applied when ``fs`` is also given).
        bottom (bool): Whether to draw the respective ticks.
        left (bool):
        right (bool):
        top (bool):
        labelbottom (bool): Whether to draw the respective tick labels.
            ``labelbottom`` / ``labelleft`` also toggle the x / y axis titles.
        labelleft (bool):
        labelright (bool):
        labeltop (bool):
        cmap (str): Colormap defined in matplotlib.

    Raises:
        ValueError: If ``mode`` is neither "db" nor "linear".
    """
    spec = np.abs(spec)
    if mode == "db":
        # np.finfo only accepts inexact dtypes; integer/bool magnitudes would
        # raise there, so promote them to float64 first.
        if not np.issubdtype(spec.dtype, np.inexact):
            spec = spec.astype(np.float64)
        x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
    elif mode == "linear":
        x = spec
    else:
        raise ValueError(mode)

    if fs is not None:
        ytop = fs / 2000  # Nyquist frequency in kHz
        ylabel = "kHz"
    else:
        ytop = x.shape[0]
        ylabel = "bin"

    if frame_shift is not None and fs is not None:
        xtop = x.shape[1] * frame_shift / fs
        xlabel = "s"
    else:
        xtop = x.shape[1]
        xlabel = "frame"

    extent = (0, xtop, 0, ytop)
    # flip the frequency rows so that bin 0 is drawn at the bottom
    plt.imshow(x[::-1], cmap=cmap, extent=extent)

    if labelbottom:
        plt.xlabel("time [{}]".format(xlabel))
    if labelleft:
        plt.ylabel("freq [{}]".format(ylabel))
    plt.colorbar().set_label("{}".format(mode))

    plt.tick_params(
        bottom=bottom,
        left=left,
        right=right,
        top=top,
        labelbottom=labelbottom,
        labelleft=labelleft,
        labelright=labelright,
        labeltop=labeltop,
    )
    plt.axis("auto")
# * ------------------ recognition related ------------------ *
def format_mulenc_args(args):
    """Format args for multi-encoder setup.

    Every per-encoder option listed in ``default_dict`` is converted to a list
    with one entry per encoder. It deals with following situations
    (when args.num_encs=2):

    1. args.elayers = None -> args.elayers = [4, 4];
    2. args.elayers = 4 -> args.elayers = [4, 4];
    3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].

    Only ``None`` counts as "not specified"; explicit falsy values such as
    ``0`` or ``0.0`` are kept and duplicated.

    Args:
        args (Namespace): Parsed arguments; must define ``num_encs`` and every
            key of ``default_dict``. Modified in place.

    Returns:
        Namespace: The same ``args`` object.
    """
    # default values when None is assigned.
    default_dict = {
        "etype": "blstmp",
        "elayers": 4,
        "eunits": 300,
        "subsample": "1",
        "dropout_rate": 0.0,
        "atype": "dot",
        "adim": 320,
        "awin": 5,
        "aheads": 4,
        "aconv_chans": -1,
        "aconv_filts": 100,
    }
    opts = vars(args)  # the live attribute dict, so writes update ``args``
    for k, default in default_dict.items():
        value = opts[k]
        if isinstance(value, list):
            # NOTE(review): a list shorter than num_encs is left as-is; only
            # longer lists are truncated.
            if len(value) != args.num_encs:
                logging.warning(
                    "Length mismatch {}: Convert {} to {}.".format(
                        k, value, value[: args.num_encs]
                    )
                )
                opts[k] = value[: args.num_encs]
            continue

        if value is None:
            # assign default value if it is None
            value = default
            opts[k] = value
            logging.warning(
                "{} is not specified, use default value {}.".format(k, default)
            )
        # duplicate the scalar once per encoder
        duplicated = [value for _ in range(args.num_encs)]
        logging.warning(
            "Type mismatch {}: Convert {} to {}.".format(k, value, duplicated)
        )
        opts[k] = duplicated
    return args
================================================
FILE: asr/chainer_backend/__init__.py
================================================
"""Initialize sub package."""
================================================
FILE: asr/chainer_backend/asr.py
================================================
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Training/decoding definition for the speech recognition task."""
import json
import logging
import os
import six
# chainer related
import chainer
from chainer import training
from chainer.datasets import TransformDataset
from chainer.training import extensions
# espnet related
from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import chainer_load
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import restore_snapshot
from espnet.nets.asr_interface import ASRInterface
from espnet.utils.deterministic_utils import set_deterministic_chainer
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator
from espnet.utils.training.iterators import ToggleableShufflingSerialIterator
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop
# rnnlm
import espnet.lm.chainer_backend.extlm as extlm_chainer
import espnet.lm.chainer_backend.lm as lm_chainer
# numpy related
import matplotlib
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from tensorboardX import SummaryWriter
# Select matplotlib's non-interactive "Agg" backend (no display required).
matplotlib.use("Agg")
def train(args):
    """Train an ASR model with the chainer backend.

    Builds the model class named by ``args.model_module``, writes
    ``<args.outdir>/model.json``, sets up CPU / single-GPU / multi-GPU
    iterators and updaters, attaches validation, snapshot, plotting and
    reporting extensions to a chainer ``Trainer``, and runs it.

    Args:
        args (namespace): The program arguments.
    """
    # display chainer version
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning("cuda is not available")
    if not chainer.cuda.cudnn_enabled:
        logging.warning("cudnn is not available")

    # get input and output dimension info
    # (read from the shapes recorded for the first validation utterance)
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    # (mtl_mode == "ctc" disables the accuracy-based snapshot/criterion below)
    if args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    # specify model architecture
    logging.info("import model module: " + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)
    total_subsampling_factor = model.get_total_subsampling_factor()

    # write model config as a json array: [idim, odim, {all args}]
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info("single gpu calculation.")
    elif ngpu > 1:
        gpu_id = 0
        # device map handed to model.custom_parallel_updater below; the model
        # itself is not copied to a GPU in this branch
        devices = {"main": gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices["sub_%d" % gid] = gid
        logging.info("multi gpu calculation (#gpus = %d)." % ngpu)
        logging.warning(
            "batch size is automatically increased (%d -> %d)"
            % (args.batch_size, args.batch_size * args.ngpu)
        )
    else:
        gpu_id = -1
        logging.info("cpu calculation")

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == "adam":
        optimizer = chainer.optimizers.Adam()
    elif args.opt == "noam":
        # alpha starts at 0; it is driven by the VaswaniRule extension below
        optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
    else:
        raise NotImplementedError("args.opt={}".format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup a converter
    converter = model.custom_converter(subsampling_factor=model.subsample[0])

    # read json data
    # NOTE(review): valid_json is read a second time here although it was
    # already loaded above for the dimension check.
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    # sortagrad == -1 keeps shortest-first ordering for all epochs; a positive
    # value is the number of epochs (see the ShufflingEnabler trigger below)
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        # (min_batch_size is always 1 here because ngpu <= 1)
        train = make_batchset(
            train_json,
            args.batch_size,
            args.maxlen_in,
            args.maxlen_out,
            args.minibatches,
            min_batch_size=args.ngpu if args.ngpu > 1 else 1,
            shortest_first=use_sortagrad,
            count=args.batch_count,
            batch_bins=args.batch_bins,
            batch_frames_in=args.batch_frames_in,
            batch_frames_out=args.batch_frames_out,
            batch_frames_inout=args.batch_frames_inout,
            iaxis=0,
            oaxis=0,
        )
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
            ]

        # set up updater
        updater = model.custom_updater(
            train_iters[0],
            optimizer,
            converter=converter,
            device=gpu_id,
            accum_grad=accum_grad,
        )
    else:
        # the per-GPU make_batchset call below passes no bin/frame options,
        # so only count-based batching is supported in this branch
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented "
                "in chainer multi gpu"
            )
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset (round-robin split of utterances by their index)
            train_json_subset = {
                k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(
                    train_json_subset,
                    args.batch_size,
                    args.maxlen_in,
                    args.maxlen_out,
                    args.minibatches,
                )
            ]

        # each subset must have same length for MultiprocessParallelUpdater
        # (shorter subsets are padded by repeating their own leading minibatches)
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
                for gid in six.moves.xrange(ngpu)
            ]

        # set up updater
        updater = model.custom_parallel_updater(
            train_iters, optimizer, converter=converter, devices=devices
        )

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        # presumably turns shuffling on for the toggleable iterators once the
        # sortagrad epochs are over (all epochs when sortagrad == -1)
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )
    if args.opt == "noam":
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule

        # learning-rate schedule applied to Adam's "alpha" every iteration
        trainer.extend(
            VaswaniRule(
                "alpha",
                d=args.adim,
                warmup_steps=args.transformer_warmup_steps,
                scale=args.transformer_lr,
            ),
            trigger=(1, "iteration"),
        )
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
        )
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False
        )

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch (skipped for pure CTC, mtlalpha == 1.0)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        # NOTE(review): the sort key is shape[1], the same field read as idim
        # above, so this may not order utterances by length — confirm whether
        # shape[0] was intended.
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info("Using custom PlotAttentionReport")
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=gpu_id,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"),
        trigger=(1, "epoch"),
    )

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ],
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode != "ctc":
        trainer.extend(
            extensions.snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # epsilon decay in the optimizer
    # When the chosen validation criterion gets worse than the best value so
    # far, restore the best snapshot and decay AdaDelta's eps.
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.acc.best"),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.loss.best"),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of training/evaluation statistics every
    # args.report_interval_iters iterations
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        # also report the current (possibly decayed) AdaDelta eps
        trainer.extend(
            extensions.observe_value(
                "eps", lambda trainer: trainer.updater.get_optimizer("main").eps
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer, att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def recog(args):
    """Decode utterances with a trained chainer ASR model.

    Loads the model (and optionally a character RNNLM and/or a word RNNLM),
    decodes every utterance listed in ``args.recog_json`` one at a time, and
    writes the N-best results to ``args.result_label`` as ``{"utts": ...}``.

    Args:
        args (namespace): The program arguments.
    """
    # display chainer version
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    logging.info("reading model parameters from " + args.model)
    # To be compatible with v.0.3.0 models
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.chainer_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    chainer_load(args.model, model)

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(
                len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit
            )
        )
        chainer_load(args.rnnlm, rnnlm)
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
        )
        chainer_load(args.word_rnnlm, word_rnnlm)

        # combine with the character LM when one was loaded; otherwise wrap
        # the word LM on its own
        if rnnlm is not None:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.MultiLevelLM(
                    word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
                )
            )
        else:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.LookAheadWordLM(
                    word_rnnlm.predictor, word_dict, char_dict
                )
            )

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    # decode each utterance
    new_js = {}
    with chainer.no_backprop_mode():
        for idx, name in enumerate(js.keys(), 1):
            # pass ``name`` as a logging argument rather than gluing it into
            # the format string, so ids containing "%" cannot break formatting
            logging.info("(%d/%d) decoding %s", idx, len(js), name)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
================================================
FILE: asr/pytorch_backend/__init__.py
================================================
"""Initialize sub package."""
================================================
FILE: asr/pytorch_backend/asr.py
================================================
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Training/decoding definition for the speech recognition task."""
import copy
import json
import logging
import math
import os
import sys
from chainer import reporter as reporter_module
from chainer import training
from chainer.training import extensions
from chainer.training.updater import StandardUpdater
import numpy as np
import torch
import torch.distributed as dist
import time
from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import format_mulenc_args
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import plot_spectrogram
from espnet.asr.asr_utils import restore_snapshot
from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot
from espnet.asr.pytorch_backend.asr_init import freeze_modules
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.asr.pytorch_backend.asr_init import load_trained_modules
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.beam_search_transducer import BeamSearchTransducer
from espnet.nets.pytorch_backend.e2e_asr import pad_list
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
from espnet.transform.spectrogram import IStft
from espnet.transform.transformation import Transformation
from espnet.utils.cli_writers import file_writer_helper
from espnet.utils.dataset import ChainerDataLoader
from espnet.utils.dataset import TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop
from espnet.snowfall.warpper.k2_decode import k2_decode
import matplotlib
from espnet.utils.parse_decoding_process import plot_decoding_logs
from espnet.utils.bmuf import BlockAdamTrainer
# Select matplotlib's non-interactive "Agg" backend (no display required).
matplotlib.use("Agg")

# itertools.zip_longest was called izip_longest in Python 2.
if sys.version_info[0] == 2:
    from itertools import izip_longest as zip_longest
else:
    from itertools import zip_longest as zip_longest
from espnet.nets.scorers.mmi_rnnt_scorer import MMIRNNTScorer
# from espnet.nets.scorers.mmi_alignment_score import MMIRNNTScorer
from espnet.utils.print import step_print
from espnet.utils.sampler import BufferSampler
from espnet.utils.rtf_calculator import RTF_calculator
from espnet.nets.lm_interface import dynamic_import_lm
def _recursive_to(xs, device):
if torch.is_tensor(xs):
return xs.to(device)
if isinstance(xs, tuple):
return tuple(_recursive_to(x, device) for x in xs)
return xs
def is_alphabet(char):
    """Return True when ``char`` lies in the ASCII letter ranges A-Z or a-z.

    The check is plain string comparison, so a single character is tested
    exactly while longer strings are compared lexicographically.
    """
    return "A" <= char <= "Z" or "a" <= char <= "z"
class CustomEvaluator(BaseEvaluator):
    """Custom Evaluator for Pytorch.

    Args:
        model (torch.nn.Module): The model to evaluate.
        iterator (chainer.dataset.Iterator) : The iterator over the
            evaluation (validation) minibatches.
        target (link | dict[str, link]) :Link object or a dictionary of
            links to evaluate. If this is just a link object, the link is
            registered by the name ``'main'``.
        device (torch.device): The device used.
        ngpu (int): The number of GPUs. When ``None`` it is inferred from
            ``device`` (0 for CPU, otherwise 1).
    """

    def __init__(self, model, iterator, target, device, ngpu=None):
        super(CustomEvaluator, self).__init__(iterator, target)
        self.model = model
        self.device = device
        if ngpu is not None:
            self.ngpu = ngpu
        elif device.type == "cpu":
            self.ngpu = 0
        else:
            self.ngpu = 1

    # The core part of the update routine can be customized by overriding
    def evaluate(self):
        """Main evaluate routine for CustomEvaluator.

        Runs the model on every validation minibatch without gradients and
        returns the mean of all values reported through the chainer reporter.
        """
        iterator = self._iterators["main"]

        if self.eval_hook:
            self.eval_hook(self)

        if hasattr(iterator, "reset"):
            iterator.reset()
            it = iterator
        else:
            it = copy.copy(iterator)

        summary = reporter_module.DictSummary()

        self.model.eval()
        try:
            with torch.no_grad():
                for batch in it:
                    # was a bare print() for every batch; keep it at debug level
                    logging.debug("evaluation batch")
                    x = _recursive_to(batch, self.device)
                    observation = {}
                    with reporter_module.report_scope(observation):
                        # The forward call is made for its reporting side
                        # effect: values reported inside this scope land in
                        # ``observation``. CPU and GPU use the same call
                        # (apex does not support torch.nn.DataParallel).
                        self.model(*x)
                    summary.add(observation)
        finally:
            # restore training mode even if a batch raises
            self.model.train()

        return summary.compute_mean()
class CustomUpdater(StandardUpdater):
    """Custom Updater for Pytorch.

    Every call to ``update`` runs one forward/backward pass. Parameters are
    updated once per ``accum_grad`` passes, or earlier when an epoch ends.

    Args:
        model (torch.nn.Module): The model to update.
        grad_clip_threshold (float): The gradient clipping value to use.
        train_iter (chainer.dataset.Iterator): The training iterator.
        optimizer (torch.optim.optimizer): The training optimizer.
        device (torch.device): The device to use.
        ngpu (int): The number of gpus to use.
        grad_noise (bool): Inject gradient noise after each backward pass.
        accum_grad (int): Number of passes whose gradients are accumulated
            before one parameter update.
        use_apex (bool): The flag to use Apex in backprop.
        ddp_trainer: Optional object exposing ``update_and_sync()`` and an
            ``optimizer`` attribute. When given, it performs the parameter
            update and the distributed communication. When ``None``,
            ``optimizer`` is stepped directly.
    """

    def __init__(
        self,
        model,
        grad_clip_threshold,
        train_iter,
        optimizer,
        device,
        ngpu,
        grad_noise=False,
        accum_grad=1,
        use_apex=False,
        ddp_trainer=None,
    ):
        super(CustomUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.grad_clip_threshold = grad_clip_threshold
        self.device = device
        self.ngpu = ngpu
        self.accum_grad = accum_grad
        self.forward_count = 0
        self.grad_noise = grad_noise
        self.iteration = 0
        self.use_apex = use_apex
        self.ddp_trainer = ddp_trainer
        self.optimizer = optimizer

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Main update routine of the CustomUpdater."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        epoch = train_iter.epoch

        batch = train_iter.next()
        x = _recursive_to(batch, self.device)
        is_new_epoch = train_iter.epoch != epoch

        # Compute the (accumulation-scaled) loss
        if self.ngpu == 0:
            loss = self.model(*x).mean() / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel, so the model is
            # called directly and its output is used without .mean()
            loss = self.model(*x) / self.accum_grad
        if self.use_apex:
            from apex import amp

            # NOTE: for a compatibility with noam optimizer
            opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # gradient noise injection
        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise

            add_gradient_noise(
                self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55
            )

        # update parameters only every accum_grad passes (or at epoch end)
        self.forward_count += 1
        if not is_new_epoch and self.forward_count != self.accum_grad:
            return
        self.forward_count = 0

        # compute the gradient norm to check if it is normal or not
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_clip_threshold
        )
        logging.info("on device {} grad norm={}".format(self.device, grad_norm))

        if math.isnan(grad_norm):
            logging.warning("grad norm is nan. Do not update model.")
            if self.ddp_trainer is None:
                optimizer.zero_grad()
            else:
                self.ddp_trainer.optimizer.zero_grad()
            return

        if self.ddp_trainer is None:
            # No DDP wrapper (the default): step the optimizer directly.
            optimizer.step()
            optimizer.zero_grad()
        else:
            # The optimizer is never stepped here: the real update and the
            # DDP communication happen in update_and_sync().
            self.ddp_trainer.update_and_sync()
        step_print(f"| iteration: {self.iteration} | gradient applied")

    def update(self):
        """Run one forward/backward pass and count completed updates.

        ``self.iteration`` only advances after a real parameter update
        (``forward_count`` reset to 0), i.e. #iterations with accum_grad > 1.
        Ref.: https://github.com/espnet/espnet/issues/777
        """
        self.update_core()
        if self.forward_count == 0:
            self.iteration += 1
class CustomConverter(object):
    """Collate one prepared minibatch into padded PyTorch tensors.

    Args:
        subsampling_factor (int): Keep every ``subsampling_factor``-th input
            frame (applied only when greater than 1).
        dtype (torch.dtype): Data type of the padded feature tensors.
    """

    def __init__(self, subsampling_factor=1, dtype=torch.float32):
        """Construct a CustomConverter object."""
        self.subsampling_factor = subsampling_factor
        self.ignore_id = -1  # padding value for the target id sequences
        self.dtype = dtype

    def _pad_float(self, arrays, device):
        """Zero-pad numpy arrays into one ``self.dtype`` tensor on ``device``."""
        return pad_list([torch.from_numpy(a).float() for a in arrays], 0).to(
            device, dtype=self.dtype
        )

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): One-element list holding ``(xs, ys, texts, xs_orig)``.
            device (torch.device): The device to send to.

        Returns:
            tuple: ``(xs_pad, ilens, ys_pad, texts, xs_pad_orig)``.
                ``xs_pad`` is the padded (subsampled) input, or a
                ``{"real": ..., "imag": ...}`` dict for complex input;
                ``ilens`` are the input lengths after subsampling; ``ys_pad``
                is padded with ``ignore_id``; ``texts`` is passed through
                unchanged; ``xs_pad_orig`` is the padded ``xs_orig``.
        """
        # batch should be located in list
        assert len(batch) == 1
        xs, ys, texts, xs_orig = batch[0]

        factor = self.subsampling_factor
        if factor > 1:
            xs = [feat[::factor, :] for feat in xs]

        lengths = np.array([feat.shape[0] for feat in xs])

        if xs[0].dtype.kind != "c":
            xs_pad = self._pad_float(xs, device)
        else:
            # Note(kamo):
            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
            # Don't create ComplexTensor and give it E2E here
            # because torch.nn.DataParellel can't handle it.
            xs_pad = {
                "real": self._pad_float([feat.real for feat in xs], device),
                "imag": self._pad_float([feat.imag for feat in xs], device),
            }
        xs_pad_orig = self._pad_float(xs_orig, device)

        ilens = torch.from_numpy(lengths).to(device)

        # NOTE: this is for multi-output (e.g., speech translation)
        targets = []
        for y in ys:
            ids = np.array(y[0][:]) if isinstance(y, tuple) else y
            targets.append(torch.from_numpy(ids).long())
        ys_pad = pad_list(targets, self.ignore_id).to(device)

        return xs_pad, ilens, ys_pad, texts, xs_pad_orig
class CustomConverterMulEnc(object):
    """Custom batch converter for Pytorch in multi-encoder case.

    Args:
        subsamping_factors (list): List of subsampling factors, one per
            encoder (the parameter keeps its historical spelling). Defaults
            to ``[1, 1]``: two encoders, no subsampling.
        dtype (torch.dtype): Data type to convert.
    """

    def __init__(self, subsamping_factors=None, dtype=torch.float32):
        """Initialize the converter."""
        # fresh list per instance instead of a shared mutable default
        if subsamping_factors is None:
            subsamping_factors = [1, 1]
        self.subsamping_factors = subsamping_factors
        self.ignore_id = -1
        self.dtype = dtype
        self.num_encs = len(subsamping_factors)

    @property
    def subsampling_factors(self):
        """list: Alias of ``subsamping_factors`` with the corrected spelling."""
        return self.subsamping_factors

    @subsampling_factors.setter
    def subsampling_factors(self, value):
        self.subsamping_factors = value

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): One-element list whose entry holds the per-encoder
                input lists first and the targets last.
            device (torch.device): The device to send to.

        Returns:
            tuple( list(torch.Tensor), list(torch.Tensor), torch.Tensor):
                Padded inputs per encoder, input lengths per encoder, and
                targets padded with ``ignore_id``.
        """
        # batch should be located in list
        assert len(batch) == 1
        xs_list = batch[0][: self.num_encs]
        ys = batch[0][-1]

        # perform subsampling (only needed when some factor exceeds 1).
        # Previously this read ``self.subsampling_factors``, an attribute that
        # was never set, so any factor > 1 raised AttributeError.
        if np.sum(self.subsamping_factors) > self.num_encs:
            xs_list = [
                [x[:: self.subsamping_factors[i], :] for x in xs_list[i]]
                for i in range(self.num_encs)
            ]

        # get batch of lengths of input sequences
        ilens_list = [
            np.array([x.shape[0] for x in xs_list[i]]) for i in range(self.num_encs)
        ]

        # perform padding and convert to tensor
        # currently only support real number
        xs_list_pad = [
            pad_list([torch.from_numpy(x).float() for x in xs_list[i]], 0).to(
                device, dtype=self.dtype
            )
            for i in range(self.num_encs)
        ]

        ilens_list = [
            torch.from_numpy(ilens_list[i]).to(device) for i in range(self.num_encs)
        ]
        # NOTE: this is for multi-task learning (e.g., speech translation)
        ys_pad = pad_list(
            [
                torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long()
                for y in ys
            ],
            self.ignore_id,
        ).to(device)

        return xs_list_pad, ilens_list, ys_pad
def train(args):
    """Train with the given args.

    This is the DDP variant: each process drives at most one device
    (``args.ngpu`` must be 0 or 1). A dummy SGD optimizer is built for the
    chainer-style plumbing, while the real update is delegated to
    ``BlockAdamTrainer``. The string ``RANK`` in ``args.outdir`` is replaced
    by the global rank, and in ``args.train_json`` by the global rank + 1.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for i in range(args.num_encs):
        logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    if "transducer" in args.model_module:
        if (
            getattr(args, "etype", False) == "custom"
            or getattr(args, "dtype", False) == "custom"
        ):
            mtl_mode = "custom_transducer"
        else:
            mtl_mode = "transducer"
        logging.info("Pure transducer mode")
    elif args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(
            idim_list[0] if args.num_encs == 1 else idim_list, odim, args
        )
    assert isinstance(model, ASRInterface)
    total_subsampling_factor = model.get_total_subsampling_factor()
    print(model)

    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
        )
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    global_rank = args.node_rank * args.node_size + args.local_rank
    args.outdir = args.outdir.replace("RANK", str(global_rank))
    # exist_ok: several DDP processes may race to create a shared outdir
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    # NOTE(review): kept from the non-DDP trainer; with the assert below,
    # ngpu > 1 always ends in an AssertionError.
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup."
            )

    # set torch device
    assert args.ngpu in [1, 0]  # this is ddp version
    device = torch.device(f"cuda:{args.local_rank}" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    if args.freeze_mods:
        model, model_params = freeze_modules(model, args.freeze_mods)
    else:
        model_params = model.parameters()

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    # We build the SGD optimizer but never use it.
    # Other code needs this
    # The real optimizer is in ddp_trainer
    optimizer = torch.optim.SGD(model_params, lr=1.0)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        # A wrapped (Noam-style) optimizer exposes the inner torch optimizer as
        # ``.optimizer``; the plain SGD built above does not, so fall back.
        if args.opt == "noam" and hasattr(optimizer, "optimizer"):
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True

        from espnet.nets.pytorch_backend.ctc import CTC

        amp.register_float_function(CTC, "loss_fn")
        amp.init()
        logging.warning("register ctc as float function")
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
    else:
        converter = CustomConverterMulEnc(
            [i[0] for i in model.subsample_list], dtype=dtype
        )

    # read json data (valid_json was already loaded above)
    args.train_json = args.train_json.replace("RANK", str(global_rank + 1))
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]

    # if use block_load, the utterance must sorted from shortest to longest
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 or args.block_load
    # make minibatch list (variable length)
    # disable the adaptive batch_size to sync DDP training
    # if use frame as the count, we do not set min_batch_size
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.batch_size if args.batch_size > 0 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
        no_sort=args.block_load,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size * 2,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.batch_size,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    if args.block_load:
        assert args.n_iter_processes <= 1, "never use more than one worker"
        sampler = BufferSampler(
            length=len(train),
            utts_per_ark=args.utts_per_ark,
            batch_size=args.batch_size,
            buf_size=args.block_buffer_size,
            seed=args.seed,
        )
        prefetch_factor = sampler.get_prefetch_factor()
        # ordering is fully controlled by the sampler
        shuffle = None
    else:
        sampler = None
        prefetch_factor = 20
        shuffle = not use_sortagrad

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        block_load=args.block_load,
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_dataset = TransformDataset(train, lambda data: converter([load_tr(data)]))
    valid_dataset = TransformDataset(valid, lambda data: converter([load_cv(data)]))
    train_iter = ChainerDataLoader(
        dataset=train_dataset,
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=shuffle,
        collate_fn=lambda x: x[0],
        prefetch_factor=prefetch_factor,
        sampler=sampler,
    )
    valid_iter = ChainerDataLoader(
        dataset=valid_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    ddp_trainer = BlockAdamTrainer(
        args,
        master_node=args.master_node,
        rank=global_rank,
        world_size=args.world_size,
        model=model,
    )
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
        ddp_trainer=ddp_trainer,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad and args.sortagrad != 0:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer, args.load_trainer_and_opt)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )

    # Save attention weight each epoch
    is_attn_plot = (
        "transformer" in args.model_module
        or "conformer" in args.model_module
        or mtl_mode in ["att", "mtl", "custom_transducer"]
    )

    if args.num_save_attention > 0 and is_attn_plot:
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Save CTC prob at each epoch
    if mtl_mode in ["ctc", "mtl"] and args.num_save_ctc > 0:
        # NOTE: sort it by output lengths
        data = sorted(
            list(valid_json.items())[: args.num_save_ctc],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            ctc_vis_fn = model.module.calculate_all_ctc_probs
            plot_class = model.module.ctc_plot_class
        else:
            ctc_vis_fn = model.calculate_all_ctc_probs
            plot_class = model.ctc_plot_class
        ctc_reporter = plot_class(
            ctc_vis_fn,
            data,
            args.outdir + "/ctc_prob",
            converter=converter,
            transform=load_cv,
            device=device,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(ctc_reporter, trigger=(1, "epoch"))
    else:
        ctc_reporter = None

    # Make a plot for training and validation values
    if args.num_encs > 1:
        report_keys_loss_ctc = [
            "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)]
        report_keys_cer_ctc = [
            "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)]

    if hasattr(model, "is_rnnt"):
        trainer.extend(
            extensions.PlotReport(
                [
                    "main/loss",
                    "validation/main/loss",
                    "main/loss_trans",
                    "validation/main/loss_trans",
                    "main/loss_ctc",
                    "validation/main/loss_ctc",
                    "main/loss_lm",
                    "validation/main/loss_lm",
                    "main/loss_aux_trans",
                    "validation/main/loss_aux_trans",
                    "main/loss_aux_symm_kl",
                    "validation/main/loss_aux_symm_kl",
                    "main/loss_mbr",
                    "validation/main/loss_mbr",
                    "main/loss_mmi",
                    "validation/main/loss_mmi",
                    "main/loss_lang",
                    "validation/main/loss_lang",
                    "main/loss_att",
                    "validation/main/loss_att",
                ],
                "epoch",
                file_name="loss.png",
            )
        )
    else:
        trainer.extend(
            extensions.PlotReport(
                [
                    "main/loss",
                    "validation/main/loss",
                    "main/loss_ctc",
                    "validation/main/loss_ctc",
                    "main/loss_att",
                    "validation/main/loss_att",
                    "main/loss_third",
                    "validation/main/loss_third",
                    "main/loss_mbr",
                    "validation/main/loss_mbr",
                ]
                + ([] if args.num_encs == 1 else report_keys_loss_ctc),
                "epoch",
                file_name="loss.png",
            )
        )

    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    # the CER plot must use the per-encoder CER keys, not the loss keys
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"]
            + ([] if args.num_encs == 1 else report_keys_cer_ctc),
            "epoch",
            file_name="cer.png",
        )
    )

    # save the checkpoint only if this is the master GPU
    if global_rank == 0:
        # Save best models
        trainer.extend(
            snapshot_object(model, "model.loss.best"),
            trigger=training.triggers.MinValueTrigger("validation/main/loss"),
        )
        if mtl_mode not in ["ctc", "transducer", "custom_transducer"]:
            trainer.extend(
                snapshot_object(model, "model.acc.best"),
                trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
            )

        # save snapshot which contains model and optimizer states
        if args.save_interval_iters > 0:
            trainer.extend(
                torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
                trigger=(args.save_interval_iters, "iteration"),
            )

        # save snapshot at every epoch - for model averaging
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
        # NOTE: In some cases, it may take more than one epoch for the model's loss
        # to escape from a local minimum.
        # Thus, restore_snapshot extension is not used here.
        # see details in https://github.com/espnet/espnet/pull/2171
        elif args.criterion == "loss_eps_decay_only":
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )

    if hasattr(model, "is_rnnt"):
        report_keys = [
            "epoch",
            "iteration",
            "main/loss",
            "main/loss_trans",
            "main/loss_ctc",
            "main/loss_lm",
            "main/loss_aux_trans",
            "main/loss_aux_symm_kl",
            "main/loss_mbr",
            "main/loss_mmi",
            "main/loss_att",
            "main/loss_lang",
            "validation/main/loss",
            "validation/main/loss_trans",
            "validation/main/loss_ctc",
            "validation/main/loss_lm",
            "validation/main/loss_aux_trans",
            "validation/main/loss_aux_symm_kl",
            "validation/main/loss_mbr",
            "validation/main/loss_mmi",
            "validation/main/loss_att",
            "validation/main/loss_lang",
            "elapsed_time",
        ]
    else:
        report_keys = [
            "epoch",
            "iteration",
            "main/loss",
            "main/loss_ctc",
            "main/loss_att",
            "main/loss_third",
            "main/loss_mbr",
            "validation/main/loss",
            "validation/main/loss_ctc",
            "validation/main/loss_att",
            "validation/main/loss_third",
            "validation/main/loss_mbr",
            "main/acc",
            "validation/main/acc",
            "main/cer_ctc",
            "validation/main/cer_ctc",
            "elapsed_time",
        ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc)

    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    if args.report_cer:
        report_keys.append("validation/main/cer")
    if args.report_wer:
        report_keys.append("validation/main/wer")

    # Per-rank text log; closed once training finishes (or fails).
    logwriter = open(args.outdir + f"/train.{global_rank}.log", "w")
    try:
        trainer.extend(
            extensions.PrintReport(report_keys, out=logwriter),
            trigger=(args.report_interval_iters, "iteration"),
        )
        set_early_stop(trainer, args)

        # Run the training
        trainer.run()
    finally:
        logwriter.close()
    check_early_stop(trainer, args.epochs)
def recog(args):
    """Decode with the given args.

    With ``args.ngpu == 1`` the model is placed on ``cuda:(args.local_rank - 1)``;
    otherwise it runs on CPU and CUDA is hidden. Hypotheses are written as JSON
    to ``args.result_label``.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    if args.ngpu == 1:
        # NOTE(review): local_rank appears to be 1-based for decoding jobs.
        gpu_id = args.local_rank - 1
        logging.warning("gpu id: " + str(gpu_id))
        device = torch.device("cuda:{}".format(gpu_id))
    else:
        device = torch.device("cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # disable GPU
    model, train_args = load_trained_model(args.model, training=False)
    assert isinstance(model, ASRInterface)
    model.recog_args = args

    if args.streaming_mode and "transformer" in train_args.model_module:
        raise NotImplementedError("streaming mode for transformer is not implemented")
    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )

    # read rnnlm
    if args.rnnlm and args.lm_weight > 0.0:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        if getattr(rnnlm_args, "model_module", "default") == "default":
            rnnlm = lm_pytorch.ClassifierWithState(
                lm_pytorch.RNNLM(
                    len(train_args.char_list),
                    rnnlm_args.layer,
                    rnnlm_args.unit,
                    getattr(rnnlm_args, "embed_unit", None),  # for backward compatibility
                )
            )
        elif getattr(rnnlm_args, "model_module", "default") == "transformer":
            lm_class = dynamic_import_lm("transformer", rnnlm_args.backend)
            rnnlm = lm_class(len(train_args.char_list), rnnlm_args)
        else:
            raise ValueError("Unsupported LM type")
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(word_dict),
                rnnlm_args.layer,
                rnnlm_args.unit,
                getattr(rnnlm_args, "embed_unit", None),  # for backward compatibility
            )
        )
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(
                    word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
                )
            )
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(
                    word_rnnlm.predictor, word_dict, char_dict
                )
            )

    model = model.to(device)
    if rnnlm is not None:
        rnnlm = rnnlm.to(device)

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    # load transducer beam search
    if hasattr(model, "is_rnnt"):
        if hasattr(model, "dec"):
            trans_decoder = model.dec
        else:
            trans_decoder = model.decoder
        joint_network = model.joint_network

        # We only use the MMIRNNTScorer now.
        # getattr: older training configs may not define the aux_mmi options.
        if (
            getattr(train_args, "aux_mmi", False)
            and getattr(train_args, "aux_mmi_type", None) == "mmi"
        ):
            adim = train_args.enc_block_arch[0]["d_hidden"]
            weight_path = os.path.dirname(args.result_label) + "/dump"
            os.makedirs(weight_path, exist_ok=True)
            model.aux_mmi.dump_weight(args.local_rank, weight_path)
            mmi_scorer_module = MMIRNNTScorer
            mmi_scorer = mmi_scorer_module(
                lang=model.aux_mmi.lang,
                device=device,
                idim=adim,
                sos_id=model.sos,
                rank=args.local_rank,
                use_segment=args.use_segment,
                char_list=train_args.char_list,
                weight_path=weight_path,
                lookahead=args.mas_lookahead,
            )
        else:
            mmi_scorer = None

        if args.ngram_model and args.ngram_weight > 0.0:
            print(f"Using ngram model: {args.ngram_model}", flush=True)
            from espnet.nets.scorers.ngram import NgramPartScorer

            ngram_scorer = NgramPartScorer(args.ngram_model, train_args.char_list)
        else:
            ngram_scorer = None

        if args.word_ngram is not None and args.word_ngram_weight > 0.0:
            from espnet.nets.scorers.word_ngram import WordNgramPartialScorer

            word_ngram_scorer = WordNgramPartialScorer(
                args.word_ngram,
                device,
                train_args.char_list,
                log_semiring=args.word_ngram_log_semiring,
                lower_char=args.word_ngram_lower_char,
            )
        else:
            word_ngram_scorer = None

        if args.tlg_scorer is not None and args.tlg_weight > 0.0:
            print(f"Using tlg scorer: {args.tlg_scorer}", flush=True)
            from espnet.nets.scorers.tlg_scorer import TlgPartialScorer

            tlg_scorer = TlgPartialScorer(
                lang=args.tlg_scorer, nonblk_reward=args.tlg_nonblk_reward
            )
        else:
            tlg_scorer = None

        # for code-switch data
        if args.cs_nt_decode_feature in ["chn", "eng"]:
            ctc_module = getattr(model, "aux_ctc", None)
        else:
            ctc_module = getattr(model, "decoder_ctc", None)

        if args.eng_vocab is not None and os.path.isfile(args.eng_vocab):
            # close the vocabulary file instead of leaking the handle
            with open(args.eng_vocab, encoding="utf-8") as vocab_file:
                eng_vocab = [s.strip() for s in vocab_file]
        else:
            eng_vocab = None

        beam_search_transducer = BeamSearchTransducer(
            decoder=trans_decoder,
            joint_network=joint_network,
            beam_size=args.beam_size,
            nbest=args.nbest,
            lm=rnnlm,
            lm_weight=args.lm_weight,
            search_type=args.search_type,
            char_list=train_args.char_list,
            max_sym_exp=args.max_sym_exp,
            u_max=args.u_max,
            nstep=args.nstep,
            prefix_alpha=args.prefix_alpha,
            score_norm=args.score_norm,
            mmi_scorer=mmi_scorer,
            mmi_weight=args.mmi_weight,
            ngram_scorer=ngram_scorer,
            ngram_weight=args.ngram_weight,
            word_ngram_scorer=word_ngram_scorer,
            word_ngram_weight=args.word_ngram_weight,
            tlg_scorer=tlg_scorer,
            tlg_weight=args.tlg_weight,
            forbid_eng=args.forbid_eng,
            ctc_module=ctc_module,
            ctc_weight=args.ctc_weight,
            eng_vocab=eng_vocab,
        )

    if args.k2_decode:
        k2_decode(
            model, device, js, load_inputs_and_targets, args.batchsize, args.use_segment
        )
        print("Finish FST decoding. Abort!")
        return

    def _decode_segment_streaming(feat):
        """Run SegmentStreamingE2E over one utterance and merge segment hyps."""
        nbest_hyps = [{"yseq": [], "score": 0.0} for _ in range(args.nbest)]
        se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
        r = np.prod(model.subsample)
        for i in range(0, feat.shape[0], r):
            hyps = se2e.accept_input(feat[i : i + r])
            if hyps is None:
                continue
            text = "".join(
                [
                    train_args.char_list[int(x)]
                    for x in hyps[0]["yseq"][1:-1]
                    if int(x) != -1
                ]
            )
            text = text.replace("\u2581", " ").strip()  # for SentencePiece
            text = text.replace(model.space, " ")
            text = text.replace(model.blank, "")
            logging.info(text)
            for n in range(args.nbest):
                nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"])
                nbest_hyps[n]["score"] += hyps[n]["score"]
        return nbest_hyps

    rtf_calculator = RTF_calculator(js)
    rtf_calculator.tik()
    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
                batch = [(name, js[name])]
                feats = load_inputs_and_targets(batch)
                feat = (
                    feats[0][0]
                    if args.num_encs == 1
                    else [feats[idx][0] for idx in range(model.num_encs)]
                )

                # For Oteam ASR Only: skip all transcriptions that have english chars
                if args.skip_eng and any(
                    is_alphabet(x) for x in js[name]["output"][0]["text"]
                ):
                    continue

                if args.streaming_mode == "window" and args.num_encs == 1:
                    logging.info(
                        "Using streaming recognizer with window size %d frames",
                        args.streaming_window,
                    )
                    se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
                    for i in range(0, feat.shape[0], args.streaming_window):
                        logging.info(
                            "Feeding frames %d - %d", i, i + args.streaming_window
                        )
                        se2e.accept_input(feat[i : i + args.streaming_window])
                    logging.info("Running offline attention decoder")
                    se2e.decode_with_attention_offline()
                    logging.info("Offline attention decoder finished")
                    nbest_hyps = se2e.retrieve_recognition()
                elif args.streaming_mode == "segment" and args.num_encs == 1:
                    logging.info(
                        "Using streaming recognizer with threshold value %d",
                        args.streaming_min_blank_dur,
                    )
                    nbest_hyps = _decode_segment_streaming(feat)
                elif hasattr(model, "is_rnnt"):
                    nbest_hyps = model.recognize(
                        feat,
                        beam_search_transducer,
                        decode_feature=args.cs_nt_decode_feature,
                    )
                else:
                    nbest_hyps = model.recognize(
                        feat, args, train_args.char_list, rnnlm
                    )

                new_js[name] = add_results_to_json(
                    js[name], nbest_hyps, train_args.char_list
                )

    else:

        def grouper(n, iterable, fillvalue=None):
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

        # sort data if batchsize > 1
        keys = list(js.keys())
        if args.batchsize > 1:
            feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
            sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
            keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                # drop the None padding that zip_longest adds to the last group
                names = [name for name in names if name]
                batch = [(name, js[name]) for name in names]
                feats = (
                    load_inputs_and_targets(batch)[0]
                    if args.num_encs == 1
                    else load_inputs_and_targets(batch)
                )
                if args.streaming_mode == "window" and args.num_encs == 1:
                    raise NotImplementedError
                elif args.streaming_mode == "segment" and args.num_encs == 1:
                    if args.batchsize > 1:
                        raise NotImplementedError
                    nbest_hyps = [_decode_segment_streaming(feats[0])]
                else:
                    nbest_hyps = model.recognize_batch(
                        feats, args, train_args.char_list, rnnlm=rnnlm
                    )

                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(
                        js[name], nbest_hyp, train_args.char_list
                    )
    rtf_calculator.tok()

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
def enhance(args):
    """Dumping enhanced speech and mask.

    Loads the model, runs ``model.enhance`` over ``args.recog_json`` in
    length-sorted batches, optionally plots up to ``args.num_images``
    spectrograms into ``args.image_dir``, and writes enhanced outputs through
    ``args.enh_wspecifier`` when given.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # TODO(ruizhili): implement enhance for multi-encoder model
    assert args.num_encs == 1, "number of encoder should be 1 ({} is given)".format(
        args.num_encs
    )

    # load trained model parameters
    logging.info("reading model parameters from " + args.model)
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    torch_load(args.model, model)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=None,  # Apply pre_process in outer func
    )
    if args.batchsize == 0:
        args.batchsize = 1

    # Creates writers for outputs from the network
    if args.enh_wspecifier is not None:
        enh_writer = file_writer_helper(args.enh_wspecifier, filetype=args.enh_filetype)
    else:
        enh_writer = None

    # Creates a Transformation instance
    preprocess_conf = (
        train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf
    )
    if preprocess_conf is not None:
        logging.info(f"Use preprocessing: {preprocess_conf}")
        transform = Transformation(preprocess_conf)
    else:
        transform = None

    # Creates a IStft instance
    istft = None
    frame_shift = args.istft_n_shift  # Used for plot the spectrogram
    if args.apply_istft:
        if preprocess_conf is not None:
            # Read the conffile and find stft setting
            with open(preprocess_conf) as f:
                # Json format: e.g.
                #    {"process": [{"type": "stft",
                #                  "win_length": 400,
                #                  "n_fft": 512, "n_shift": 160,
                #                  "window": "han"},
                #                 {"type": "foo", ...}, ...]}
                conf = json.load(f)
                assert "process" in conf, conf
                # Find stft setting
                for p in conf["process"]:
                    if p["type"] == "stft":
                        istft = IStft(
                            win_length=p["win_length"],
                            n_shift=p["n_shift"],
                            window=p.get("window", "hann"),
                        )
                        logging.info(
                            "stft is found in {}. "
                            "Setting istft config from it\n{}".format(
                                preprocess_conf, istft
                            )
                        )
                        frame_shift = p["n_shift"]
                        break
        if istft is None:
            # Set from command line arguments
            istft = IStft(
                win_length=args.istft_win_length,
                n_shift=args.istft_n_shift,
                window=args.istft_window,
            )
            logging.info(
                "Setting istft config from the command line args\n{}".format(istft)
            )

    # sort data
    keys = list(js.keys())
    feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
    sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
    keys = [keys[i] for i in sorted_index]

    def grouper(n, iterable, fillvalue=None):
        kargs = [iter(iterable)] * n
        return zip_longest(*kargs, fillvalue=fillvalue)

    num_images = 0
    # guard against image_dir=None, which the plotting check below allows
    if args.image_dir is not None:
        os.makedirs(args.image_dir, exist_ok=True)

    for names in grouper(args.batchsize, keys, None):
        # Nothing left to do: all images plotted and no waveform writer.
        if num_images >= args.num_images and enh_writer is None:
            logging.info("Breaking the process.")
            break
        # drop the None padding that zip_longest adds to the last group
        names = [name for name in names if name is not None]
        batch = [(name, js[name]) for name in names]

        # May be in time region: (Batch, [Time, Channel])
        org_feats = load_inputs_and_targets(batch)[0]
        if transform is not None:
            # May be in time-freq region: : (Batch, [Time, Channel, Freq])
            feats = transform(org_feats, train=False)
        else:
            feats = org_feats

        with torch.no_grad():
            enhanced, mask, ilens = model.enhance(feats)

        for idx, name in enumerate(names):
            # Assuming mask, feats : [Batch, Time, Channel. Freq]
            #          enhanced    : [Batch, Time, Freq]
            enh = enhanced[idx][: ilens[idx]]
            mas = mask[idx][: ilens[idx]]
            feat = feats[idx]

            # Plot spectrogram
            if args.image_dir is not None and num_images < args.num_images:
                import matplotlib.pyplot as plt

                num_images += 1
                ref_ch = 0

                plt.figure(figsize=(20, 10))
                plt.subplot(4, 1, 1)
                plt.title("Mask [ref={}ch]".format(ref_ch))
                plot_spectrogram(
                    plt,
                    mas[:, ref_ch].T,
                    fs=args.fs,
                    mode="linear",
                    frame_shift=frame_shift,
                    bottom=False,
                    labelbottom=False,
                )

                plt.subplot(4, 1, 2)
                plt.title("Noisy speech [ref={}ch]".format(ref_ch))
                plot_spectrogram(
                    plt,
                    feat[:, ref_ch].T,
                    fs=args.fs,
                    mode="db",
                    frame_shift=frame_shift,
                    bottom=False,
                    labelbottom=False,
                )

                plt.subplot(4, 1, 3)
                plt.title("Masked speech [ref={}ch]".format(ref_ch))
                plot_spectrogram(
                    plt,
                    (feat[:, ref_ch] * mas[:, ref_ch]).T,
                    frame_shift=frame_shift,
                    fs=args.fs,
                    mode="db",
                    bottom=False,
                    labelbottom=False,
                )

                plt.subplot(4, 1, 4)
                plt.title("Enhanced speech")
                plot_spectrogram(
                    plt, enh.T, fs=args.fs, mode="db", frame_shift=frame_shift
                )

                plt.savefig(os.path.join(args.image_dir, name + ".png"))
                plt.clf()

            # Write enhanced wave files
            if enh_writer is not None:
                if istft is not None:
                    enh = istft(enh)

                if args.keep_length:
                    # compare against this utterance's length, not the batch size
                    org_len = len(org_feats[idx])
                    if org_len < len(enh):
                        # Truncate the frames added by stft padding
                        enh = enh[:org_len]
                    elif org_len > len(enh):
                        padwidth = [(0, (org_len - len(enh)))] + [(0, 0)] * (
                            enh.ndim - 1
                        )
                        enh = np.pad(enh, padwidth, mode="constant")

                if args.enh_filetype in ("sound", "sound.hdf5"):
                    enh_writer[name] = (args.fs, enh)
                else:
                    # Hint: To dump stft_signal, mask or etc,
                    # enh_filetype='hdf5' might be convenient.
                    enh_writer[name] = enh

            if num_images >= args.num_images and enh_writer is None:
                logging.info("Breaking the process.")
                break

    # flush/close the output archive(s)
    if enh_writer is not None and hasattr(enh_writer, "close"):
        enh_writer.close()
def ctc_align(args):
    """Compute CTC forced alignments for every utterance in ``args.align_json``.

    Only per-utterance processing (``args.batchsize == 0``) and at most one GPU
    are supported. The result is written to ``args.result_label`` as
    ``{"utts": {utt_id: {"ctc_alignment": "<space-joined tokens>"}}}``.

    Args:
        args (namespace): The program arguments.

    """

    def add_alignment_to_json(js, alignment, char_list):
        """Build the per-utterance alignment entry.

        Args:
            js (dict[str, Any]): Groundtruth utterance dict (currently unused).
            alignment (list[int]): List of alignment token ids.
            char_list (list[str]): List of characters.

        Returns:
            dict[str, Any]: ``{"ctc_alignment": <tokens joined by spaces>}``.

        """
        return {"ctc_alignment": " ".join(char_list[token] for token in alignment)}

    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    model.eval()

    if args.preprocess_conf is None:
        preprocess_conf = train_args.preprocess_conf
    else:
        preprocess_conf = args.preprocess_conf
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        sort_in_input_length=False,
        preprocess_conf=preprocess_conf,
        preprocess_args={"train": False},
    )

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    device = "cuda" if args.ngpu == 1 else "cpu"
    dtype = getattr(torch, args.dtype)
    logging.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype).eval()

    # read json data
    with open(args.align_json, "rb") as f:
        js = json.load(f)["utts"]

    if args.batchsize != 0:
        raise NotImplementedError("Align_batch is not implemented.")

    new_js = {}
    total = len(js)
    with torch.no_grad():
        for position, (utt_id, utt_info) in enumerate(js.items(), 1):
            logging.info("(%d/%d) aligning " + utt_id, position, total)
            feats, labels = load_inputs_and_targets([(utt_id, utt_info)])
            encoded = model.encode(torch.as_tensor(feats[0]).to(device)).unsqueeze(0)
            alignment = model.ctc.forced_align(encoded, labels[0])
            new_js[utt_id] = add_alignment_to_json(
                utt_info, alignment, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
================================================
FILE: asr/pytorch_backend/asr_init.py
================================================
"""Finetuning methods."""
import logging
import os
import torch
from collections import OrderedDict
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import torch_load
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.mt_interface import MTInterface
from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.dynamic_import import dynamic_import
def freeze_modules(model, modules):
"""Freeze model parameters according to modules list.
Args:
model (torch.nn.Module): main model to update
modules (list): specified module list for freezing
Return:
model (torch.nn.Module): updated model
model_params (filter): filtered model parameters
"""
for mod, param in model.named_parameters():
if any(mod.startswith(m) for m in modules):
logging.info(f"freezing {mod}, it will not be updated.")
param.requires_grad = False
model_params = filter(lambda x: x.requires_grad, model.parameters())
return model, model_params
def transfer_verification(model_state_dict, partial_state_dict, modules):
"""Verify tuples (key, shape) for input model modules match specified modules.
Args:
model_state_dict (OrderedDict): the initial model state_dict
partial_state_dict (OrderedDict): the trained model state_dict
modules (list): specified module list for transfer
Return:
(boolean): allow transfer
"""
modules_model = []
partial_modules = []
for key_p, value_p in partial_state_dict.items():
if any(key_p.startswith(m) for m in modules):
partial_modules += [(key_p, value_p.shape)]
for key_m, value_m in model_state_dict.items():
if any(key_m.startswith(m) for m in modules):
modules_model += [(key_m, value_m.shape)]
len_match = len(modules_model) == len(partial_modules)
module_match = sorted(modules_model, key=lambda x: (x[0], x[1])) == sorted(
partial_modules, key=lambda x: (x[0], x[1])
)
return len_match and module_match
def get_partial_state_dict(model_state_dict, modules):
"""Create state_dict with specified modules matching input model modules.
Note that get_partial_lm_state_dict is used if a LM specified.
Args:
model_state_dict (OrderedDict): trained model state_dict
modules (list): specified module list for transfer
Return:
new_state_dict (OrderedDict): the updated state_dict
"""
new_state_dict = OrderedDict()
for key, value in model_state_dict.items():
if any(key.startswith(m) for m in modules):
new_state_dict[key] = value
return new_state_dict
def get_lm_state_dict(lm_state_dict):
"""Create compatible ASR decoder state dict from LM state dict.
Args:
lm_state_dict (OrderedDict): pre-trained LM state_dict
Return:
new_state_dict (OrderedDict): LM state_dict with updated keys
"""
new_state_dict = OrderedDict()
for key, value in list(lm_state_dict.items()):
if key == "predictor.embed.weight":
new_state_dict["dec.embed.weight"] = value
elif key.startswith("predictor.rnn."):
_split = key.split(".")
new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0"
new_state_dict[new_key] = value
return new_state_dict
def filter_modules(model_state_dict, modules):
"""Filter non-matched modules in module_state_dict.
Args:
model_state_dict (OrderedDict): trained model state_dict
modules (list): specified module list for transfer
Return:
new_mods (list): the update module list
"""
new_mods = []
incorrect_mods = []
mods_model = list(model_state_dict.keys())
for mod in modules:
if any(key.startswith(mod) for key in mods_model):
new_mods += [mod]
else:
incorrect_mods += [mod]
if incorrect_mods:
logging.warning(
"module(s) %s don't match or (partially match) "
"available modules in model.",
incorrect_mods,
)
logging.warning("for information, the existing modules in model are:")
logging.warning("%s", mods_model)
return new_mods
def load_trained_model(model_path, training=True):
"""Load the trained model for recognition.
Args:
model_path (str): Path to model.***.best
"""
idim, odim, train_args = get_model_conf(
model_path, os.path.join(os.path.dirname(model_path), "model.json")
)
logging.warning("reading model parameters from " + model_path)
if hasattr(train_args, "model_module"):
model_module = train_args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
# CTC Loss is not needed, default to builtin to prevent import errors
# if hasattr(train_args, "ctc_type"):
# train_args.ctc_type = "builtin"
model_class = dynamic_import(model_module)
if "transducer" in model_module:
model = model_class(idim, odim, train_args, training=training)
custom_torch_load(model_path, model, training=training)
else:
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
return model, train_args
# when start decoding jobs with very large nj, this function leads
# to reading error. Do this for many times
def _load_trained_model(model_path, training=True, patience=10):
for i in range(patience):
try:
model, train_args = _load_trained_model(model_path, training=training)
print(f"Model Init: Successful initialize model in {i}-th trail", flush=True)
return model, train_args
except:
print(f"Model Init: Fail in {i}-th trail. Try again!", flush=True)
def get_trained_model_state_dict(model_path):
"""Extract the trained model state dict for pre-initialization.
Args:
model_path (str): Path to model.***.best
Return:
model.state_dict() (OrderedDict): the loaded model state_dict
(bool): Boolean defining whether the model is an LM
"""
conf_path = os.path.join(os.path.dirname(model_path), "model.json")
if "rnnlm" in model_path:
logging.warning("reading model parameters from %s", model_path)
return get_lm_state_dict(torch.load(model_path))
idim, odim, args = get_model_conf(model_path, conf_path)
logging.warning("reading model parameters from " + model_path)
if hasattr(args, "model_module"):
model_module = args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
model_class = dynamic_import(model_module)
model = model_class(idim, odim, args)
torch_load(model_path, model)
assert (
isinstance(model, MTInterface)
or isinstance(model, ASRInterface)
or isinstance(model, TTSInterface)
)
return model.state_dict()
def load_trained_modules(idim, odim, args, interface=ASRInterface):
"""Load model encoder or/and decoder modules with ESPNET pre-trained model(s).
Args:
idim (int): initial input dimension.
odim (int): initial output dimension.
args (Namespace): The initial model arguments.
interface (Interface): ASRInterface or STInterface or TTSInterface.
Return:
model (torch.nn.Module): The model with pretrained modules.
"""
def print_new_keys(state_dict, modules, model_path):
logging.warning("loading %s from model: %s", modules, model_path)
for k in state_dict.keys():
logging.warning("override %s" % k)
enc_model_path = args.enc_init
dec_model_path = args.dec_init
enc_modules = args.enc_init_mods
dec_modules = args.dec_init_mods
model_class = dynamic_import(args.model_module)
main_model = model_class(idim, odim, args)
assert isinstance(main_model, interface)
main_state_dict = main_model.state_dict()
logging.warning("model(s) found for pre-initialization")
for model_path, modules in [
(enc_model_path, enc_modules),
(dec_model_path, dec_modules),
]:
if model_path is not None:
if os.path.isfile(model_path):
model_state_dict = get_trained_model_state_dict(model_path)
modules = filter_modules(model_state_dict, modules)
partial_state_dict = get_partial_state_dict(model_state_dict, modules)
if partial_state_dict:
if transfer_verification(
main_state_dict, partial_state_dict, modules
):
print_new_keys(partial_state_dict, modules, model_path)
main_state_dict.update(partial_state_dict)
else:
logging.warning(
f"modules {modules} in model {model_path} "
f"don't match your training config",
)
else:
logging.warning("model was not found : %s", model_path)
main_model.load_state_dict(main_state_dict)
return main_model
================================================
FILE: asr/pytorch_backend/asr_mix.py
================================================
#!/usr/bin/env python3
"""
This script is used for multi-speaker speech recognition.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import json
import logging
import os
# chainer related
from chainer import training
from chainer.training import extensions
from itertools import zip_longest as zip_longest
import numpy as np
from tensorboardX import SummaryWriter
import torch
from espnet.asr.asr_mix_utils import add_results_to_json
from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import restore_snapshot
from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot
from espnet.asr.pytorch_backend.asr import CustomEvaluator
from espnet.asr.pytorch_backend.asr import CustomUpdater
from espnet.asr.pytorch_backend.asr import load_trained_model
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.pytorch_backend.e2e_asr_mix import pad_list
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.utils.dataset import ChainerDataLoader
from espnet.utils.dataset import TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop
import matplotlib
matplotlib.use("Agg")
class CustomConverter(object):
"""Custom batch converter for Pytorch.
Args:
subsampling_factor (int): The subsampling factor.
dtype (torch.dtype): Data type to convert.
"""
def __init__(self, subsampling_factor=1, dtype=torch.float32, num_spkrs=2):
"""Initialize the converter."""
self.subsampling_factor = subsampling_factor
self.ignore_id = -1
self.dtype = dtype
self.num_spkrs = num_spkrs
def __call__(self, batch, device=torch.device("cpu")):
"""Transform a batch and send it to a device.
Args:
batch (list(tuple(str, dict[str, dict[str, Any]]))): The batch to transform.
device (torch.device): The device to send to.
Returns:
tuple(torch.Tensor, torch.Tensor, torch.Tensor): Transformed batch.
"""
# batch should be located in list
assert len(batch) == 1
xs, ys = batch[0][0], batch[0][-self.num_spkrs :]
# perform subsampling
if self.subsampling_factor > 1:
xs = [x[:: self.subsampling_factor, :] for x in xs]
# get batch of lengths of input sequences
ilens = np.array([x.shape[0] for x in xs])
# perform padding and convert to tensor
# currently only support real number
if xs[0].dtype.kind == "c":
xs_pad_real = pad_list(
[torch.from_numpy(x.real).float() for x in xs], 0
).to(device, dtype=self.dtype)
xs_pad_imag = pad_list(
[torch.from_numpy(x.imag).float() for x in xs], 0
).to(device, dtype=self.dtype)
# Note(kamo):
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
# Don't create ComplexTensor and give it to E2E here
# because torch.nn.DataParallel can't handle it.
xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
else:
xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
device, dtype=self.dtype
)
ilens = torch.from_numpy(ilens).to(device)
if not isinstance(ys[0], np.ndarray):
ys_pad = []
for i in range(len(ys)): # speakers
ys_pad += [torch.from_numpy(y).long() for y in ys[i]]
ys_pad = pad_list(ys_pad, self.ignore_id)
ys_pad = (
ys_pad.view(self.num_spkrs, -1, ys_pad.size(1))
.transpose(0, 1)
.to(device)
) # (B, num_spkrs, Tmax)
else:
ys_pad = pad_list(
[torch.from_numpy(y).long() for y in ys], self.ignore_id
).to(device)
return xs_pad, ilens, ys_pad
def train(args):
"""Train with the given args.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# check cuda availability
if not torch.cuda.is_available():
logging.warning("cuda is not available")
# get input and output dimension info
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
utts = list(valid_json.keys())
idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
logging.info("#input dims : " + str(idim))
logging.info("#output dims: " + str(odim))
# specify attention, CTC, hybrid mode
if args.mtlalpha == 1.0:
mtl_mode = "ctc"
logging.info("Pure CTC mode")
elif args.mtlalpha == 0.0:
mtl_mode = "att"
logging.info("Pure attention mode")
else:
mtl_mode = "mtl"
logging.info("Multitask learning mode")
# specify model architecture
model_class = dynamic_import(args.model_module)
model = model_class(idim, odim, args)
assert isinstance(model, ASRInterface)
subsampling_factor = model.subsample[0]
if args.rnnlm is not None:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(args.char_list),
rnnlm_args.layer,
rnnlm_args.unit,
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
)
)
torch.load(args.rnnlm, rnnlm)
model.rnnlm = rnnlm
# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
model_conf = args.outdir + "/model.json"
with open(model_conf, "wb") as f:
logging.info("writing a model config file to " + model_conf)
f.write(
json.dumps(
(idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
for key in sorted(vars(args).keys()):
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
reporter = model.reporter
# check the use of multi-gpu
if args.ngpu > 1:
if args.batch_size != 0:
logging.warning(
"batch size is automatically increased (%d -> %d)"
% (args.batch_size, args.batch_size * args.ngpu)
)
args.batch_size *= args.ngpu
# set torch device
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
if args.train_dtype in ("float16", "float32", "float64"):
dtype = getattr(torch, args.train_dtype)
else:
dtype = torch.float32
model = model.to(device=device, dtype=dtype)
logging.warning(
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
sum(p.numel() for p in model.parameters() if p.requires_grad)
* 100.0
/ sum(p.numel() for p in model.parameters()),
)
)
# Setup an optimizer
if args.opt == "adadelta":
optimizer = torch.optim.Adadelta(
model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
)
elif args.opt == "adam":
optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
elif args.opt == "noam":
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
optimizer = get_std_opt(
model.parameters(),
args.adim,
args.transformer_warmup_steps,
args.transformer_lr,
)
else:
raise NotImplementedError("unknown optimizer: " + args.opt)
# setup apex.amp
if args.train_dtype in ("O0", "O1", "O2", "O3"):
try:
from apex import amp
except ImportError as e:
logging.error(
f"You need to install apex for --train-dtype {args.train_dtype}. "
"See https://github.com/NVIDIA/apex#linux"
)
raise e
if args.opt == "noam":
model, optimizer.optimizer = amp.initialize(
model, optimizer.optimizer, opt_level=args.train_dtype
)
else:
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.train_dtype
)
use_apex = True
else:
use_apex = False
# FIXME: TOO DIRTY HACK
setattr(optimizer, "target", reporter)
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
# Setup a converter
converter = CustomConverter(
subsampling_factor=subsampling_factor, dtype=dtype, num_spkrs=args.num_spkrs
)
# read json data
with open(args.train_json, "rb") as f:
train_json = json.load(f)["utts"]
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
# make minibatch list (variable length)
train = make_batchset(
train_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
shortest_first=use_sortagrad,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=-1,
)
valid = make_batchset(
valid_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=-1,
)
load_tr = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": True}, # Switch the mode of preprocessing
)
load_cv = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": False}, # Switch the mode of preprocessing
)
# hack to make batchsize argument as 1
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
train_iter = {
"main": ChainerDataLoader(
dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
batch_size=1,
num_workers=args.n_iter_processes,
shuffle=True,
collate_fn=lambda x: x[0],
)
}
valid_iter = {
"main": ChainerDataLoader(
dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
batch_size=1,
shuffle=False,
collate_fn=lambda x: x[0],
num_workers=args.n_iter_processes,
)
}
# Set up a trainer
updater = CustomUpdater(
model,
args.grad_clip,
train_iter,
optimizer,
device,
args.ngpu,
args.grad_noise,
args.accum_grad,
use_apex=use_apex,
)
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
if use_sortagrad:
trainer.extend(
ShufflingEnabler([train_iter]),
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
)
# Resume from a snapshot
if args.resume:
logging.info("resumed from %s" % args.resume)
torch_resume(args.resume, trainer)
# Evaluate the model with the test dataset for each epoch
trainer.extend(CustomEvaluator(model, valid_iter, reporter, device, args.ngpu))
# Save attention weight each epoch
if args.num_save_attention > 0 and args.mtlalpha != 1.0:
data = sorted(
list(valid_json.items())[: args.num_save_attention],
key=lambda x: int(x[1]["input"][0]["shape"][1]),
reverse=True,
)
if hasattr(model, "module"):
att_vis_fn = model.module.calculate_all_attentions
plot_class = model.module.attention_plot_class
else:
att_vis_fn = model.calculate_all_attentions
plot_class = model.attention_plot_class
att_reporter = plot_class(
att_vis_fn,
data,
args.outdir + "/att_ws",
converter=converter,
transform=load_cv,
device=device,
)
trainer.extend(att_reporter, trigger=(1, "epoch"))
else:
att_reporter = None
# Make a plot for training and validation values
trainer.extend(
extensions.PlotReport(
[
"main/loss",
"validation/main/loss",
"main/loss_ctc",
"validation/main/loss_ctc",
"main/loss_att",
"validation/main/loss_att",
],
"epoch",
file_name="loss.png",
)
)
trainer.extend(
extensions.PlotReport(
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
)
)
trainer.extend(
extensions.PlotReport(
["main/cer_ctc", "validation/main/cer_ctc"], "epoch", file_name="cer.png"
)
)
# Save best models
trainer.extend(
snapshot_object(model, "model.loss.best"),
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
)
if mtl_mode != "ctc":
trainer.extend(
snapshot_object(model, "model.acc.best"),
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
)
# save snapshot which contains model and optimizer states
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
# epsilon decay in the optimizer
if args.opt == "adadelta":
if args.criterion == "acc" and mtl_mode != "ctc":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.acc.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
elif args.criterion == "loss":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.loss.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
# Write a log of evaluation statistics for each epoch
trainer.extend(
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
)
report_keys = [
"epoch",
"iteration",
"main/loss",
"main/loss_ctc",
"main/loss_att",
"validation/main/loss",
"validation/main/loss_ctc",
"validation/main/loss_att",
"main/acc",
"validation/main/acc",
"main/cer_ctc",
"validation/main/cer_ctc",
"elapsed_time",
]
if args.opt == "adadelta":
trainer.extend(
extensions.observe_value(
"eps",
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
"eps"
],
),
trigger=(args.report_interval_iters, "iteration"),
)
report_keys.append("eps")
if args.report_cer:
report_keys.append("validation/main/cer")
if args.report_wer:
report_keys.append("validation/main/wer")
trainer.extend(
extensions.PrintReport(report_keys),
trigger=(args.report_interval_iters, "iteration"),
)
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
set_early_stop(trainer, args)
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
trainer.extend(
TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
trigger=(args.report_interval_iters, "iteration"),
)
# Run the training
trainer.run()
check_early_stop(trainer, args.epochs)
def recog(args):
"""Decode with the given args.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model)
assert isinstance(model, ASRInterface)
model.recog_args = args
# read rnnlm
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
if getattr(rnnlm_args, "model_module", "default") != "default":
raise ValueError(
"use '--api v2' option to decode with non-default language model"
)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list),
rnnlm_args.layer,
rnnlm_args.unit,
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
)
)
torch_load(args.rnnlm, rnnlm)
rnnlm.eval()
else:
rnnlm = None
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
word_dict = rnnlm_args.char_list_dict
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
word_rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
)
torch_load(args.word_rnnlm, word_rnnlm)
word_rnnlm.eval()
if rnnlm is not None:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.MultiLevelLM(
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
)
)
else:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.LookAheadWordLM(
word_rnnlm.predictor, word_dict, char_dict
)
)
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info("gpu id: " + str(gpu_id))
model.cuda()
if rnnlm:
rnnlm.cuda()
# read json data
with open(args.recog_json, "rb") as f:
js = json.load(f)["utts"]
new_js = {}
load_inputs_and_targets = LoadInputsAndTargets(
mode="asr",
load_output=False,
sort_in_input_length=False,
preprocess_conf=train_args.preprocess_conf
if args.preprocess_conf is None
else args.preprocess_conf,
preprocess_args={"train": False},
)
if args.batchsize == 0:
with torch.no_grad():
for idx, name in enumerate(js.keys(), 1):
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
batch = [(name, js[name])]
feat = load_inputs_and_targets(batch)[0][0]
nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
new_js[name] = add_results_to_json(
js[name], nbest_hyps, train_args.char_list
)
else:
def grouper(n, iterable, fillvalue=None):
kargs = [iter(iterable)] * n
return zip_longest(*kargs, fillvalue=fillvalue)
# sort data if batchsize > 1
keys = list(js.keys())
if args.batchsize > 1:
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
with torch.no_grad():
for names in grouper(args.batchsize, keys, None):
names = [name for name in names if name]
batch = [(name, js[name]) for name in names]
feats = load_inputs_and_targets(batch)[0]
nbest_hyps = model.recognize_batch(
feats, args, train_args.char_list, rnnlm=rnnlm
)
for i, name in enumerate(names):
nbest_hyp = [hyp[i] for hyp in nbest_hyps]
new_js[name] = add_results_to_json(
js[name], nbest_hyp, train_args.char_list
)
with open(args.result_label, "wb") as f:
f.write(
json.dumps(
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
================================================
FILE: asr/pytorch_backend/recog.py
================================================
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
import json
import logging
import os
import torch
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import torch_load
from espnet.asr.pytorch_backend.asr import load_trained_model
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import BeamSearch
from espnet.nets.lm_interface import dynamic_import_lm
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.nets.scorers.mmi_frame_scorer import MMIFrameScorer
# from espnet.nets.scorers.mmi_prefix_score import MMIFrameScorer
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.nets.scorers.word_ngram import WordNgramPartialScorer
from espnet.nets.scorers.mmi_rescorer import MMIRescorer
from espnet.utils.rtf_calculator import RTF_calculator
def _dump_mmi_weights(model, args):
    """Dump the MMI (CTC-side) module weights for this rank.

    The weights are written below ``<dirname(args.result_label)>/dump`` and
    keyed by ``args.local_rank``, so concurrent decoding jobs that share one
    result directory do not conflict when reading / writing.

    Args:
        model: The loaded ASR model; ``model.ctc`` must provide ``dump_weight``.
        args (namespace): The program arguments.

    Returns:
        str: The dump directory, which is handed to the MMI scorers.

    """
    # Only MMI-capable CTC modules (presumably the K2-based MMI module)
    # implement ``dump_weight``.
    assert hasattr(model.ctc, "dump_weight")
    weight_path = os.path.dirname(args.result_label) + "/dump"
    os.makedirs(weight_path, exist_ok=True)
    model.ctc.dump_weight(args.local_rank, weight_path)
    return weight_path


def recog_v2(args):
    """Decode with custom models that implement ScorerInterface.

    Notes:
        The previous backend espnet.asr.pytorch_backend.asr.recog
        only supports E2E and RNNLM.

    Args:
        args (namespace): The program arguments.
            See py:func:`espnet.bin.asr_recog.get_parser` for details.

    Raises:
        NotImplementedError: For multi-utterance batch decoding, streaming
            mode, word LMs, or more than one GPU.

    The n-best hypotheses of every utterance in ``args.recog_json`` are
    written as JSON to ``args.result_label``.

    """
    logging.warning("experimental API for custom LMs is selected by --api v2")
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")
    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    if args.ngpu == 1:
        device = torch.device("cuda")
    else:
        # Hide every GPU from this process so that nothing below can
        # silently allocate on CUDA.
        device = torch.device("cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
        assert not torch.cuda.is_available()
    print(f"Rank: {args.local_rank} Using device: {device}, ngpu: {args.ngpu}")

    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    model.eval()

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    # character/token-level RNN LM
    if args.rnnlm:
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for a compatibility with less than 0.5.0 version models
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(train_args.char_list), lm_args)
        torch_load(args.rnnlm, lm)
        lm.eval()
    else:
        lm = None

    # token-level n-gram LM
    if args.ngram_model and args.ngram_weight > 0.0:
        from espnet.nets.scorers.ngram import NgramFullScorer
        from espnet.nets.scorers.ngram import NgramPartScorer

        if args.ngram_scorer == "full":
            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
        else:
            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
    else:
        ngram = None

    # MMI frame scorer (used inside the beam search)
    if args.mmi_weight > 0.0:
        weight_path = _dump_mmi_weights(model, args)
        mmi = MMIFrameScorer(
            lang=model.ctc.lang,
            device=device,
            idim=train_args.adim,
            sos_id=model.sos,
            rank=args.local_rank,
            use_segment=args.use_segment,
            char_list=train_args.char_list,
            weight_path=weight_path,
        )
    else:
        mmi = None

    # MMI rescorer (applied by the beam search via ``mmi_rescorer``)
    if args.mmi_rescore:
        # MMI rescoring and MMI frame scoring are mutually exclusive; check
        # this before writing anything to disk.
        assert args.mmi_weight <= 0.0
        weight_path = _dump_mmi_weights(model, args)
        mmi_rescorer = MMIRescorer(
            lang=model.ctc.lang,
            device=device,
            idim=train_args.adim,
            sos_id=model.sos,
            rank=args.local_rank,
            use_segment=args.use_segment,
            char_list=train_args.char_list,
            weight_path=weight_path,
        )
    else:
        mmi_rescorer = None

    if args.ctc_weight > 0.0:
        # Prefer the auxiliary ``third_loss`` CTC head when the model has one.
        ctc_module = model.third_loss if hasattr(model, "third_loss") else model.ctc
        ctc = CTCPrefixScorer(ctc_module, model.eos)
    else:
        ctc = None

    # word-level n-gram LM
    if args.word_ngram_weight > 0.0:
        print(f"Using word ngram model: {args.word_ngram}", flush=True)
        word_ngram_scorer = WordNgramPartialScorer(
            args.word_ngram,
            device,
            train_args.char_list,
            log_semiring=args.word_ngram_log_semiring,
        )
    else:
        word_ngram_scorer = None

    scorers = model.scorers()
    scorers["ctc"] = ctc
    scorers["mmi"] = mmi
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
    scorers["word_ngram"] = word_ngram_scorer
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty,
        mmi=args.mmi_weight,
        word_ngram=args.word_ngram_weight,
    )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(train_args.char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=train_args.char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
        mmi_rescorer=mmi_rescorer,
    )

    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:
        non_batch = [
            k
            for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if len(non_batch) == 0:
            beam_search.__class__ = BatchBeamSearch
            logging.info("BatchBeamSearch implementation is selected.")
        else:
            logging.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )

    dtype = getattr(torch, args.dtype)
    logging.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype).eval()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}
    num_utts = len(js)

    rtf_calculator = RTF_calculator(js)
    rtf_calculator.tik()
    with torch.no_grad():
        for idx, name in enumerate(js, 1):
            logging.info("(%d/%d) decoding " + name, idx, num_utts)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype))
            nbest_hyps = beam_search(
                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
            )
            # Slicing already clamps to the number of available hypotheses.
            nbest_hyps = [h.asdict() for h in nbest_hyps[: args.nbest]]
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )
    # presumably stops the timer started by tik() and reports the RTF
    rtf_calculator.tok()

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
================================================
FILE: bin/__init__.py
================================================
"""Initialize sub package."""
================================================
FILE: bin/asr_align.py
================================================
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2020 Johns Hopkins University (Xuankai Chang)
# 2020, Technische Universität München; Dominik Winkelbauer, Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
This program performs CTC segmentation to align utterances within audio files.
Inputs:
`--data-json`:
A json containing list of utterances and audio files
`--model`:
An already trained ASR model
Output:
`--output`:
A plain `segments` file with utterance positions in the audio files.
Selected parameters:
`--min-window-size`:
Minimum window size considered for a single utterance. The current default value
should be OK in most cases. Larger values might give better results; too large
values cause IndexErrors.
`--subsampling-factor`:
If the encoder sub-samples its input, the number of frames at the CTC layer is
reduced by this factor.
`--frame-duration`:
This is the non-overlapping duration of a single frame in milliseconds (the
inverse of frames per millisecond).
`--set-blank`:
In the rare case that the blank token does not have index 0 in the character
dictionary, this parameter sets the index of the blank token.
`--gratis-blank`:
Sets the transition cost for blank tokens to zero. Useful if there are longer
unrelated segments between segments.
`--replace-spaces-with-blanks`:
Spaces are replaced with blanks. Helps to model pauses between words. May
increase length of ground truth. May lead to misaligned segments when combined
with the option `--gratis-blank`.
"""
import configargparse
import logging
import os
import sys
# imports for inference
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.nets.asr_interface import ASRInterface
from espnet.utils.io_utils import LoadInputsAndTargets
import json
import torch
# imports for CTC segmentation
from ctc_segmentation import ctc_segmentation
from ctc_segmentation import CtcSegmentationParameters
from ctc_segmentation import determine_utterance_segments
from ctc_segmentation import prepare_text
# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Build the command-line parser for CTC-segmentation alignment.

    Returns:
        configargparse.ArgumentParser: Parser for all alignment options.

    """
    parser = configargparse.ArgumentParser(
        # NOTE: the two literals are concatenated, so the first one must end
        # with a space to form a single readable sentence.
        description="Align text to audio using CTC segmentation "
        "with a pre-trained speech recognition model.",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="Decoding config file path.")
    parser.add_argument(
        "--ngpu", type=int, default=0, help="Number of GPUs (max. 1 is supported)"
    )
    parser.add_argument(
        "--dtype",
        choices=("float16", "float32", "float64"),
        default="float32",
        help="Float precision (float16 is not supported on CPU)",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="pytorch",
        choices=["pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
    parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    # task related
    parser.add_argument(
        "--data-json", type=str, help="Json of recognition data for audio and text"
    )
    parser.add_argument("--utt-text", type=str, help="Text separated into utterances")
    # model (parameter) related
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )
    parser.add_argument(
        "--num-encs", default=1, type=int, help="Number of encoders in the model."
    )
    # ctc-segmentation related
    parser.add_argument(
        "--subsampling-factor",
        type=int,
        default=None,
        help="Subsampling factor."
        " If the encoder sub-samples its input, the number of frames at the CTC layer"
        " is reduced by this factor. For example, a BLSTMP with subsampling 1_2_2_1_1"
        " has a subsampling factor of 4.",
    )
    parser.add_argument(
        "--frame-duration",
        type=int,
        default=None,
        help="Non-overlapping duration of a single frame in milliseconds.",
    )
    parser.add_argument(
        "--min-window-size",
        type=int,
        default=None,
        help="Minimum window size considered for utterance.",
    )
    parser.add_argument(
        "--max-window-size",
        type=int,
        default=None,
        help="Maximum window size considered for utterance.",
    )
    parser.add_argument(
        "--use-dict-blank",
        type=int,
        default=None,
        help="DEPRECATED: use --set-blank instead.",
    )
    parser.add_argument(
        "--set-blank",
        type=int,
        default=None,
        help="Index of model dictionary for blank token (default: 0).",
    )
    parser.add_argument(
        "--gratis-blank",
        type=int,
        default=None,
        help="Set the transition cost of the blank token to zero. Audio sections"
        " labeled with blank tokens can then be skipped without penalty. Useful"
        " if there are unrelated audio segments between utterances.",
    )
    parser.add_argument(
        "--replace-spaces-with-blanks",
        type=int,
        default=None,
        help="Fill blanks in between words to better model pauses between words."
        " Segments can be misaligned if this option is combined with --gratis-blank."
        " May increase length of ground truth.",
    )
    parser.add_argument(
        "--scoring-length",
        type=int,
        default=None,
        help="Changes partitioning length L for calculation of the confidence score.",
    )
    parser.add_argument(
        "--output",
        type=configargparse.FileType("w"),
        required=True,
        help="Output segments file",
    )
    return parser
def main(args):
    """Run the CTC-segmentation alignment program.

    Args:
        args (list of str): Command-line arguments, without the program name.

    Raises:
        ValueError: If ``--dtype float16`` is requested without a GPU, or if
            a backend other than pytorch is selected.

    The process exits with status 1 when more than one GPU is requested and
    with status 0 after the alignment has finished.

    """
    parser = get_parser()
    args, extra = parser.parse_known_args(args)

    # logging info
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO, format=log_format)
    elif args.verbose == 2:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")
    if extra:
        # parse_known_args() drops unrecognized options without complaint;
        # report them so that a mistyped flag does not go unnoticed.
        logging.warning("Ignoring unrecognized arguments: %s", " ".join(extra))

    if args.ngpu == 0 and args.dtype == "float16":
        raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")

    # check CUDA_VISIBLE_DEVICES
    device = "cpu"
    if args.ngpu == 1:
        device = "cuda"
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
    elif args.ngpu > 1:
        logging.error("Decoding only supports ngpu=1.")
        sys.exit(1)

    # display PYTHONPATH
    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    # recog
    logging.info("backend = " + args.backend)
    if args.backend == "pytorch":
        ctc_align(args, device)
    else:
        raise ValueError("Only pytorch is supported.")
    sys.exit(0)
def _read_utterance_text(utt_text_path, audio_names):
    """Group the lines of the utterance text file by audio file.

    Every line is split at its first space into a segment name and the
    utterance text (the text keeps its trailing newline). Lines are consumed
    in file order: the run of consecutive lines starting with an audio name
    is assigned to that audio.

    :param utt_text_path: path of the UTF-8 utterance text file
    :param audio_names: iterable of audio names, in matching order
    :return: tuple (texts, segment_names) of dicts that map each audio name
        to its list of utterance texts / segment names
    """
    with open(utt_text_path, "r", encoding="utf-8") as text_file:
        all_lines = text_file.readlines()
    texts = {}
    names = {}
    cursor = 0
    for audio in audio_names:
        utterances = []
        segments = []
        while cursor < len(all_lines) and all_lines[cursor].startswith(audio):
            line = all_lines[cursor]
            split_at = line.find(" ")
            utterances.append(line[split_at + 1 :])
            segments.append(line[:split_at])
            cursor += 1
        texts[audio] = utterances
        names[audio] = segments
    return texts, names


def _make_segmentation_config(args, char_list):
    """Build the CTC-segmentation parameters from the command-line options.

    Options that are ``None`` leave the library defaults untouched.

    :param args: parsed command-line arguments
    :param char_list: token list of the trained model
    :return: configured CtcSegmentationParameters
    """
    config = CtcSegmentationParameters()
    optional_overrides = (
        ("subsampling_factor", args.subsampling_factor),
        ("frame_duration_ms", args.frame_duration),
        ("min_window_size", args.min_window_size),
        ("max_window_size", args.max_window_size),
    )
    for attribute, value in optional_overrides:
        if value is not None:
            setattr(config, attribute, value)
    config.char_list = char_list
    if args.use_dict_blank is not None:
        logging.warning(
            "The option --use-dict-blank is deprecated. If needed,"
            " use --set-blank instead."
        )
    if args.set_blank is not None:
        config.blank = args.set_blank
    if args.replace_spaces_with_blanks is not None:
        config.replace_spaces_with_blanks = bool(args.replace_spaces_with_blanks)
    if args.gratis_blank:
        config.blank_transition_cost_zero = True
    if config.blank_transition_cost_zero and args.replace_spaces_with_blanks:
        logging.error(
            "Blanks are inserted between words, and also the transition cost of blank"
            " is zero. This configuration may lead to misalignments!"
        )
    if args.scoring_length is not None:
        config.score_min_mean_over_L = args.scoring_length
    return config


def ctc_align(args, device):
    """ESPnet-specific interface for CTC segmentation.

    Parses configuration, infers the CTC posterior probabilities,
    and then aligns start and end of utterances using CTC segmentation.
    Results are written to ``args.output``, one line per utterance in the
    form ``<segment> <audio> <start> <end> <score>``.

    :param args: given configuration
    :param device: for inference; one of ['cuda', 'cpu']
    :return: 0 on success
    """
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    preprocess_conf = (
        train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf
    )
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        sort_in_input_length=False,
        preprocess_conf=preprocess_conf,
        preprocess_args={"train": False},
    )
    logging.info(f"Decoding device={device}")

    # Warn for nets with high memory consumption on long audio files
    encoder_module = "Unknown"
    for encoder_attr in ("enc", "encoder"):
        if hasattr(model, encoder_attr):
            encoder_module = getattr(model, encoder_attr).__class__.__module__
            break
    logging.info(f"Encoder module: {encoder_module}")
    logging.info(f"CTC module: {model.ctc.__class__.__module__}")
    if "rnn" not in encoder_module:
        logging.warning("No BLSTM model detected; memory consumption may be high.")
    model.to(device=device).eval()

    # read audio and text json data
    with open(args.data_json, "rb") as json_file:
        js = json.load(json_file)["utts"]
    text, segment_names = _read_utterance_text(args.utt_text, js)

    # apply configuration
    config = _make_segmentation_config(args, train_args.char_list)
    logging.info(
        f"Frame timings: {config.frame_duration_ms}ms * {config.subsampling_factor}"
    )

    # Iterate over audio files to decode and align
    total = len(js)
    for idx, name in enumerate(js, 1):
        logging.info("(%d/%d) Aligning " + name, idx, total)
        feats, _labels = load_inputs_and_targets([(name, js[name])])
        with torch.no_grad():
            # Encode input frames
            enc_output = model.encode(torch.as_tensor(feats[0]).to(device)).unsqueeze(0)
            # Apply ctc layer to obtain log character probabilities
            lpz = model.ctc.log_softmax(enc_output)[0].cpu().numpy()
        utterances = text[name]
        # Prepare the text for aligning
        ground_truth_mat, utt_begin_indices = prepare_text(config, utterances)
        # Align using CTC segmentation
        timings, char_probs, state_list = ctc_segmentation(
            config, lpz, ground_truth_mat
        )
        logging.debug(f"state_list = {state_list}")
        # Obtain list of utterances with time intervals and confidence score
        segments = determine_utterance_segments(
            config, utt_begin_indices, char_probs, timings, utterances
        )
        # Write to "segments" file
        names_for_audio = segment_names[name]
        for seg_idx, boundary in enumerate(segments):
            args.output.write(
                f"{names_for_audio[seg_idx]} {name} {boundary[0]:.2f}"
                f" {boundary[1]:.2f} {boundary[2]:.9f}\n"
            )
    return 0
if __name__ == "__main__":
    # Hand the command line (without the program name) to the entry point.
    main(args=sys.argv[1:])
================================================
FILE: bin/asr_enhance.py
================================================
#!/usr/bin/env python3
import configargparse
from distutils.util import strtobool
import logging
import os
import random
import sys
import numpy as np
from espnet.asr.pytorch_backend.asr import enhance
# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Build the command-line parser for speech enhancement.

    Returns:
        configargparse.ArgumentParser: Parser for all enhancement options.

    """
    parser = configargparse.ArgumentParser(
        description="Enhance noisy speech for speech recognition",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites the settings "
        "in `--config` and `--config2`.",
    )
    parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
    parser.add_argument(
        "--backend",
        default="chainer",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        default=1,
        type=int,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    # task related
    parser.add_argument(
        "--recog-json", type=str, help="Filename of recognition data (json)"
    )
    # model (parameter) related
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )
    # Outputs configuration
    parser.add_argument(
        "--enh-wspecifier",
        type=str,
        default=None,
        # NOTE: concatenated literals; keep the separating space.
        help="Specify the output way for enhanced speech. "
        "e.g. ark,scp:outdir,wav.scp",
    )
    parser.add_argument(
        "--enh-filetype",
        type=str,
        default="sound",
        choices=["mat", "hdf5", "sound.hdf5", "sound"],
        help="Specify the file format for enhanced speech. "
        '"mat" is the matrix format in kaldi',
    )
    parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
    parser.add_argument(
        "--keep-length",
        type=strtobool,
        default=True,
        help="Adjust the output length to match with the input for enhanced speech",
    )
    parser.add_argument(
        "--image-dir", type=str, default=None, help="The directory saving the images."
    )
    parser.add_argument(
        "--num-images",
        type=int,
        default=20,
        help="The number of images files to be saved. "
        "If negative, all samples are to be saved.",
    )
    # IStft
    parser.add_argument(
        "--apply-istft",
        type=strtobool,
        default=True,
        help="Apply istft to the output from the network",
    )
    parser.add_argument(
        "--istft-win-length",
        type=int,
        default=512,
        help="The window length for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    parser.add_argument(
        "--istft-n-shift",
        # An int, like its default: a str type would turn user-supplied
        # values into strings while the default stays an int.
        type=int,
        default=256,
        help="The window shift for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    parser.add_argument(
        "--istft-window",
        type=str,
        default="hann",
        help="The window type for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    return parser
def main(args):
    """Parse the command line, set up logging/GPUs/seeds and run enhancement.

    Args:
        args (list of str): Command-line arguments, without the program name.

    Raises:
        ValueError: If a backend other than pytorch is selected.

    The process exits with status 1 on an invalid GPU configuration.

    """
    args = get_parser().parse_args(args)

    # logging info: verbose 1 -> INFO, 2 -> DEBUG, anything else -> WARN
    verbose_levels = {1: logging.INFO, 2: logging.DEBUG}
    logging.basicConfig(
        level=verbose_levels.get(args.verbose, logging.WARN),
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose not in verbose_levels:
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible_devices is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif len(visible_devices.split(",")) != args.ngpu:
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)
        # TODO(kamo): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    # seed setting (Python first, then NumPy)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(args.seed)
    logging.info("set random seed = %d" % args.seed)

    # recog
    logging.info("backend = " + args.backend)
    if args.backend != "pytorch":
        raise ValueError("Only pytorch is supported.")
    enhance(args)
if __name__ == "__main__":
    # Hand the command line (without the program name) to the entry point.
    main(args=sys.argv[1:])
================================================
FILE: bin/asr_recog.py
================================================
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""End-to-end speech recognition model decoding script."""
import configargparse
import logging
import os
import random
import sys
import tracemalloc
import numpy as np
from espnet.utils.cli_utils import strtobool
# NOTE: you need this func to generate our sphinx doc
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
description="Transcribe text from speech using "
"a speech recognition model on one CPU or GPU",
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
)
# general configuration
parser.add("--config", is_config_file=True, help="Config file path")
parser.add(
"--config2",
is_config_file=True,
help="Second config file path that overwrites the settings in `--config`",
)
parser.add(
"--config3",
is_config_file=True,
help="Third config file path that overwrites the settings "
"in `--config` and `--config2`",
)
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
parser.add_argument(
"--dtype",
choices=("float16", "float32", "float64"),
default="float32",
help="Float precision (only available in --api v2)",
)
parser.add_argument(
"--backend",
type=str,
default="chainer",
choices=["chainer", "pytorch"],
help="Backend library",
)
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
parser.add_argument(
"--batchsize",
type=int,
default=1,
help="Batch size for beam search (0: means no batch processing)",
)
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
parser.add_argument(
"--api",
default="v1",
choices=["v1", "v2"],
help="Beam search APIs "
"v1: Default API. It only supports the ASRInterface.recognize method "
"and DefaultRNNLM. "
"v2: Experimental API. It supports any models that implements ScorerInterface.",
)
# task related
parser.add_argument(
"--recog-json", type=str, help="Filename of recognition data (json)"
)
parser.add_argument(
"--result-label",
type=str,
required=True,
help="Filename of result label data (json)",
)
# model (parameter) related
parser.add_argument(
"--model", type=str, required=True, help="Model file parameters to read"
)
parser.add_argument(
"--model-conf", type=str, default=None, help="Model config file"
)
parser.add_argument(
"--num-spkrs",
type=int,
default=1,
choices=[1, 2],
help="Number of speakers in the speech",
)
parser.add_argument(
"--num-encs", default=1, type=int, help="Number of encoders in the model."
)
# search related
parser.add_argument("--nbest", type=int, default=10, help="Output N-best hypotheses")
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty")
parser.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths""",
)
parser.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
parser.add_argument(
"--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
)
parser.add_argument(
"--weights-ctc-dec",
type=float,
action="append",
help="ctc weight assigned to each encoder during decoding."
"[in multi-encoder mode only]",
)
parser.add_argument(
"--ctc-window-margin",
type=int,
default=0,
help="""Use CTC window with margin parameter to accelerate
CTC/attention decoding especially on GPU. Smaller magin
makes decoding faster, but may increase search errors.
If margin=0 (default), this function is disabled""",
)
# transducer related
parser.add_argument(
"--search-type",
type=str,
default="alsd",
choices=["default", "nsc", "tsd", "alsd", "ctc_greedy", "ctc_beam"],
help="""Type of beam search implementation to use during inference.
Can be either: default beam search, n-step constrained beam search ("nsc"),
time-synchronous decoding ("tsd") or alignment-length synchronous decoding
("alsd").
Additional associated parameters: "nstep" + "prefix-alpha" (for nsc),
"max-sym-exp" (for tsd) and "u-max" (for alsd)""",
)
parser.add_argument(
"--nstep",
type=int,
default=1,
help="Number of expansion steps allowed in NSC beam search.",
)
parser.add_argument(
"--prefix-alpha",
type=int,
default=2,
help="Length prefix difference allowed in NSC beam search.",
)
parser.add_argument(
"--max-sym-exp",
type=int,
default=2,
help="Number of symbol expansions allowed in TSD decoding.",
)
parser.add_argument(
"--u-max",
type=int,
default=400,
help="Length prefix difference allowed in ALSD beam search.",
)
parser.add_argument(
"--score-norm",
type=strtobool,
nargs="?",
default=True,
help="Normalize transducer scores by length",
)
# rnnlm related
parser.add_argument(
"--rnnlm", type=str, default=None, help="RNNLM model file to read"
)
parser.add_argument(
"--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
)
parser.add_argument(
"--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
)
parser.add_argument(
"--word-rnnlm-conf",
type=str,
default=None,
help="Word RNNLM model config file to read",
)
parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
# ngram related
parser.add_argument(
"--ngram-model", type=str, default=None, help="ngram model file to read"
)
parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
parser.add_argument(
"--ngram-scorer",
type=str,
default="part",
choices=("full", "part"),
help="""if the ngram is set as a part scorer, similar with CTC scorer,
ngram scorer only scores topK hypethesis.
if the ngram is set as full scorer, ngram scorer scores all hypthesis
the decoding speed of part scorer is musch faster than full one""",
)
# streaming related
parser.add_argument(
"--streaming-mode",
type=str,
default=None,
choices=["window", "segment"],
help="""Use streaming recognizer for inference.
`--batchsize` must be set to 0 to enable this mode""",
)
parser.add_argument("--streaming-window", type=int, default=10, help="Window size")
parser.add_argument(
"--streaming-min-blank-dur",
type=int,
default=10,
help="Minimum blank duration threshold",
)
parser.add_argument(
"--streaming-onset-margin", type=int, default=1, help="Onset margin"
)
parser.add_argument(
"--streaming-offset-margin", type=int, default=1, help="Offset margin"
)
# non-autoregressive related
# Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
parser.add_argument(
"--maskctc-n-iterations",
type=int,
default=10,
help="Number of decoding iterations."
"For Mask CTC, set 0 to predict 1 mask/iter.",
)
parser.add_argument(
"--maskctc-probability-threshold",
type=float,
default=0.999,
help="Threshold probability for CTC output",
)
parser.add_argument(
"--k2-decode",
gitextract_ni0k430x/ ├── .gitignore ├── .run.sh.swp ├── README.md ├── __init__.py ├── asr/ │ ├── __init__.py │ ├── asr_mix_utils.py │ ├── asr_utils.py │ ├── chainer_backend/ │ │ ├── __init__.py │ │ └── asr.py │ └── pytorch_backend/ │ ├── __init__.py │ ├── asr.py │ ├── asr_init.py │ ├── asr_mix.py │ └── recog.py ├── bin/ │ ├── __init__.py │ ├── asr_align.py │ ├── asr_enhance.py │ ├── asr_recog.py │ ├── asr_train.py │ ├── lm_train.py │ ├── mt_train.py │ ├── mt_trans.py │ ├── st_train.py │ ├── st_trans.py │ ├── tts_decode.py │ ├── tts_train.py │ ├── vc_decode.py │ └── vc_train.py ├── egs/ │ ├── .gitignore │ ├── aishell1/ │ │ ├── .gitignore │ │ ├── aed.sh │ │ ├── cmd.sh │ │ ├── conf/ │ │ │ ├── fbank.conf │ │ │ ├── gpu.conf │ │ │ ├── lm.yaml │ │ │ ├── lm_rnn.yaml │ │ │ ├── lm_transformer.yaml │ │ │ ├── pitch.conf │ │ │ ├── queue.conf │ │ │ ├── slurm.conf │ │ │ ├── specaug.yaml │ │ │ ├── specaug_test.yaml │ │ │ └── tuning/ │ │ │ ├── decode_pytorch_transformer.yaml │ │ │ ├── decode_rnn.yaml │ │ │ ├── train_pytorch_conformer_kernel15.yaml │ │ │ ├── train_pytorch_conformer_kernel31.yaml │ │ │ ├── train_pytorch_conformer_kernel31_large.yaml │ │ │ ├── train_pytorch_conformer_kernel31_small.yaml │ │ │ ├── train_pytorch_transformer.yaml │ │ │ ├── train_rnn.yaml │ │ │ └── transducer/ │ │ │ ├── decode_default.yaml │ │ │ ├── train_conformer-rnn_transducer.yaml │ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4.yaml │ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_att.yaml │ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_small.yaml │ │ │ ├── train_conformer-rnn_transducer_ngpu4.yaml │ │ │ ├── train_conformer-rnn_transducer_ngpu4_large.yaml │ │ │ ├── train_transducer.yaml │ │ │ └── train_transducer_aux.yaml │ │ ├── local/ │ │ │ ├── add_lex_disambig.pl │ │ │ ├── aishell_data_prep.sh │ │ │ ├── aishell_train_lms.sh │ │ │ ├── apply_map.pl │ │ │ ├── build_sp_text.py │ │ │ ├── build_word_mapping.py │ │ │ ├── compile_bigram.sh │ │ │ ├── download_and_untar.sh │ │ │ ├── 
fstaddselfloops.pl │ │ │ ├── k2_aishell_prepare_dict.sh │ │ │ ├── k2_aishell_prepare_dict_char.sh │ │ │ ├── k2_prepare_lang.sh │ │ │ ├── make_lexicon_fst.py │ │ │ ├── max_rescore.py │ │ │ ├── parse_options.sh │ │ │ ├── parse_text_jieba.py │ │ │ ├── prepare_word_lex.py │ │ │ └── sym2int.pl │ │ ├── nt.sh │ │ ├── path.sh │ │ └── prepare.sh │ ├── aishell2/ │ │ ├── .gitignore │ │ ├── aed.sh │ │ ├── conf/ │ │ │ ├── .fbank.conf.swp │ │ │ ├── fbank.conf │ │ │ ├── gpu.conf │ │ │ ├── lm.yaml │ │ │ ├── lm_rnn.yaml │ │ │ ├── lm_transformer.yaml │ │ │ ├── pitch.conf │ │ │ ├── queue.conf │ │ │ ├── slurm.conf │ │ │ ├── specaug.yaml │ │ │ ├── specaug_test.yaml │ │ │ └── tuning/ │ │ │ ├── decode_pytorch_transformer.yaml │ │ │ ├── decode_rnn.yaml │ │ │ ├── train_pytorch_conformer_kernel15.yaml │ │ │ ├── train_pytorch_conformer_kernel31.yaml │ │ │ ├── train_pytorch_transformer.yaml │ │ │ ├── train_rnn.yaml │ │ │ └── transducer/ │ │ │ ├── decode_default.yaml │ │ │ ├── train_conformer-rnn_transducer.yaml │ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4.yaml │ │ │ ├── train_conformer-rnn_transducer_ngpu4.yaml │ │ │ ├── train_transducer.yaml │ │ │ └── train_transducer_aux.yaml │ │ ├── local/ │ │ │ ├── add_lex_disambig.pl │ │ │ ├── apply_map.pl │ │ │ ├── fstaddselfloops.pl │ │ │ ├── jieba_split_text.py │ │ │ ├── k2_prepare_lang.sh │ │ │ ├── make_lexicon_fst.py │ │ │ ├── max_rescore.py │ │ │ ├── mmi_rescore.sh │ │ │ ├── parse_options.sh │ │ │ ├── prepare_data.sh │ │ │ ├── prepare_dict.sh │ │ │ ├── rerank.py │ │ │ ├── sym2int.pl │ │ │ ├── train_lms.sh │ │ │ └── word_segmentation.py │ │ ├── nt.sh │ │ └── prepare.sh │ ├── asrucs/ │ │ ├── .gitignore │ │ ├── cmd.sh │ │ ├── conf/ │ │ │ ├── decode.yaml │ │ │ ├── fbank.conf │ │ │ ├── gpu.conf │ │ │ ├── lm.yaml │ │ │ ├── lm_rnn.yaml │ │ │ ├── lm_transformer.yaml │ │ │ ├── pitch.conf │ │ │ ├── pure_ctc.yaml │ │ │ ├── queue.conf │ │ │ ├── slurm.conf │ │ │ ├── specaug.yaml │ │ │ ├── specaug_test.yaml │ │ │ ├── train.yaml │ │ │ ├── 
train_conformer-rnn_transducer_cs.yaml │ │ │ └── tuning/ │ │ │ ├── decode_pytorch_transformer.yaml │ │ │ ├── decode_rnn.yaml │ │ │ ├── train_pytorch_conformer_kernel15.yaml │ │ │ ├── train_pytorch_conformer_kernel31.yaml │ │ │ ├── train_pytorch_conformer_kernel31_large.yaml │ │ │ ├── train_pytorch_conformer_kernel31_small.yaml │ │ │ ├── train_pytorch_transformer.yaml │ │ │ ├── train_rnn.yaml │ │ │ └── transducer/ │ │ │ ├── decode_default.yaml │ │ │ ├── train_conformer-rnn_transducer.yaml │ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4.yaml │ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_att.yaml │ │ │ ├── train_conformer-rnn_transducer_aux_ngpu4_small.yaml │ │ │ ├── train_conformer-rnn_transducer_ngpu4.yaml │ │ │ ├── train_conformer-rnn_transducer_ngpu4_large.yaml │ │ │ ├── train_transducer.yaml │ │ │ └── train_transducer_aux.yaml │ │ ├── espnet │ │ ├── espnet_utils │ │ ├── local/ │ │ │ ├── add_seperator.py │ │ │ ├── generate_fake_cs.py │ │ │ └── prepare_fake_cs.sh │ │ ├── nt.sh │ │ ├── path.sh │ │ ├── prepare.sh │ │ ├── steps │ │ ├── text │ │ └── utils │ ├── espnet_utils/ │ │ ├── add_uttcls_json.py │ │ ├── addjson.py │ │ ├── apply-cmvn.py │ │ ├── asr_align_wav.sh │ │ ├── average_checkpoints.py │ │ ├── build_fake_lexicon.py │ │ ├── build_sp_text.py │ │ ├── calculate_rtf.py │ │ ├── change_root.py │ │ ├── change_yaml.py │ │ ├── clean_corpus.sh │ │ ├── compute-cmvn-stats.py │ │ ├── compute-fbank-feats.py │ │ ├── compute-stft-feats.py │ │ ├── concat_json_multiref.py │ │ ├── concatjson.py │ │ ├── convert_fbank.sh │ │ ├── convert_fbank_to_wav.py │ │ ├── copy-feats.py │ │ ├── data2json.sh │ │ ├── divide_lang.sh │ │ ├── double_precious_cer.py │ │ ├── download_from_google_drive.sh │ │ ├── dump-pcm.py │ │ ├── dump.sh │ │ ├── dump_pcm.sh │ │ ├── eval-source-separation.py │ │ ├── eval_perm_free_error.py │ │ ├── eval_source_separation.sh │ │ ├── feat-to-shape.py │ │ ├── feat_to_shape.sh │ │ ├── feats2npy.py │ │ ├── filt.py │ │ ├── filter_all_eng_utts.py │ │ ├── 
filter_scp.py │ │ ├── filter_trn.py │ │ ├── free-gpu.sh │ │ ├── gdown.pl │ │ ├── generate_wav.sh │ │ ├── generate_wav_from_fbank.py │ │ ├── get_yaml.py │ │ ├── jieba_build_dict.py │ │ ├── json2sctm.py │ │ ├── json2text.py │ │ ├── json2trn.py │ │ ├── json2trn_mt.py │ │ ├── json2trn_wo_dict.py │ │ ├── k2/ │ │ │ ├── add_lex_disambig.pl │ │ │ ├── apply_map.pl │ │ │ ├── fstaddselfloops.pl │ │ │ ├── k2_prepare_lang.sh │ │ │ ├── parse_options.sh │ │ │ └── sym2int.pl │ │ ├── make_fbank.sh │ │ ├── make_pair_json.py │ │ ├── make_stft.sh │ │ ├── mbr_analysis.py │ │ ├── mcd_calculate.py │ │ ├── merge_scp2json.py │ │ ├── mergejson.py │ │ ├── mix-mono-wav-scp.py │ │ ├── mmi_rescore.sh │ │ ├── pack_model.sh │ │ ├── prepare_block_load.sh │ │ ├── prepare_mer.py │ │ ├── queue-freegpu.pl │ │ ├── recog_wav.sh │ │ ├── reduce_data_dir.sh │ │ ├── remove_longshortdata.sh │ │ ├── remove_punctuation.pl │ │ ├── rerank_mmi.py │ │ ├── result2json.py │ │ ├── score_bleu.sh │ │ ├── score_lang_id.py │ │ ├── score_sclite.sh │ │ ├── score_sclite_case.sh │ │ ├── score_sclite_wo_dict.sh │ │ ├── scp2json.py │ │ ├── show_result.sh │ │ ├── significant_test.sh │ │ ├── sort_scp_by_length.py │ │ ├── speed_perturb.sh │ │ ├── split_scp.py │ │ ├── split_scp_fix_length.py │ │ ├── splitjson.py │ │ ├── spm_decode │ │ ├── spm_encode │ │ ├── spm_train │ │ ├── stdout.pl │ │ ├── synth_wav.sh │ │ ├── text2token.py │ │ ├── text2vocabulary.py │ │ ├── text_norm.py │ │ ├── trace_rnnt.py │ │ ├── train_lms_srilm.sh │ │ ├── translate_wav.sh │ │ ├── trim_silence.py │ │ ├── trim_silence.sh │ │ ├── trn2ctm.py │ │ ├── trn2stm.py │ │ ├── update_json.sh │ │ ├── word_ngram_rescore.py │ │ └── word_ngram_rescore.sh │ ├── steps/ │ │ ├── align_basis_fmllr.sh │ │ ├── align_basis_fmllr_lats.sh │ │ ├── align_fmllr.sh │ │ ├── align_fmllr_lats.sh │ │ ├── align_lvtln.sh │ │ ├── align_raw_fmllr.sh │ │ ├── align_sgmm2.sh │ │ ├── align_si.sh │ │ ├── best_path_weights.sh │ │ ├── cleanup/ │ │ │ ├── clean_and_segment_data.sh │ │ │ ├── 
clean_and_segment_data_nnet3.sh │ │ │ ├── combine_short_segments.py │ │ │ ├── create_segments_from_ctm.pl │ │ │ ├── debug_lexicon.sh │ │ │ ├── decode_fmllr_segmentation.sh │ │ │ ├── decode_segmentation.sh │ │ │ ├── decode_segmentation_nnet3.sh │ │ │ ├── find_bad_utts.sh │ │ │ ├── find_bad_utts_nnet.sh │ │ │ ├── internal/ │ │ │ │ ├── align_ctm_ref.py │ │ │ │ ├── compute_tf_idf.py │ │ │ │ ├── ctm_to_text.pl │ │ │ │ ├── get_ctm_edits.py │ │ │ │ ├── get_non_scored_words.py │ │ │ │ ├── get_pron_stats.py │ │ │ │ ├── make_one_biased_lm.py │ │ │ │ ├── modify_ctm_edits.py │ │ │ │ ├── resolve_ctm_edits_overlaps.py │ │ │ │ ├── retrieve_similar_docs.py │ │ │ │ ├── segment_ctm_edits.py │ │ │ │ ├── segment_ctm_edits_mild.py │ │ │ │ ├── split_text_into_docs.pl │ │ │ │ ├── stitch_documents.py │ │ │ │ ├── taint_ctm_edits.py │ │ │ │ └── tf_idf.py │ │ │ ├── lattice_oracle_align.sh │ │ │ ├── make_biased_lm_graphs.sh │ │ │ ├── make_biased_lms.py │ │ │ ├── make_segmentation_data_dir.sh │ │ │ ├── make_segmentation_graph.sh │ │ │ ├── make_utterance_fsts.pl │ │ │ ├── make_utterance_graph.sh │ │ │ ├── segment_long_utterances.sh │ │ │ ├── segment_long_utterances_nnet3.sh │ │ │ └── split_long_utterance.sh │ │ ├── combine_ali_dirs.sh │ │ ├── combine_trans_dirs.sh │ │ ├── compare_alignments.sh │ │ ├── compute_cmvn_stats.sh │ │ ├── compute_vad_decision.sh │ │ ├── conf/ │ │ │ ├── append_eval_to_ctm.py │ │ │ ├── append_prf_to_ctm.py │ │ │ ├── apply_calibration.sh │ │ │ ├── convert_ctm_to_tra.py │ │ │ ├── get_ctm_conf.sh │ │ │ ├── lattice_depth_per_frame.sh │ │ │ ├── parse_arpa_unigrams.py │ │ │ ├── prepare_calibration_data.py │ │ │ ├── prepare_word_categories.py │ │ │ └── train_calibration.sh │ │ ├── copy_ali_dir.sh │ │ ├── copy_lat_dir.sh │ │ ├── copy_trans_dir.sh │ │ ├── data/ │ │ │ ├── augment_data_dir.py │ │ │ ├── data_dir_manipulation_lib.py │ │ │ ├── make_musan.py │ │ │ ├── make_musan.sh │ │ │ └── reverberate_data_dir.py │ │ ├── decode.sh │ │ ├── decode_basis_fmllr.sh │ │ ├── decode_biglm.sh 
│ │ ├── decode_combine.sh │ │ ├── decode_fmllr.sh │ │ ├── decode_fmllr_extra.sh │ │ ├── decode_fmmi.sh │ │ ├── decode_fromlats.sh │ │ ├── decode_lvtln.sh │ │ ├── decode_nolats.sh │ │ ├── decode_raw_fmllr.sh │ │ ├── decode_sgmm2.sh │ │ ├── decode_sgmm2_fromlats.sh │ │ ├── decode_sgmm2_rescore.sh │ │ ├── decode_sgmm2_rescore_project.sh │ │ ├── decode_with_map.sh │ │ ├── diagnostic/ │ │ │ ├── analyze_alignments.sh │ │ │ ├── analyze_lats.sh │ │ │ ├── analyze_lattice_depth_stats.py │ │ │ └── analyze_phone_length_stats.py │ │ ├── dict/ │ │ │ ├── apply_g2p.sh │ │ │ ├── apply_g2p_phonetisaurus.sh │ │ │ ├── apply_lexicon_edits.py │ │ │ ├── get_pron_stats.py │ │ │ ├── internal/ │ │ │ │ ├── get_subsegments.py │ │ │ │ ├── prune_pron_candidates.py │ │ │ │ └── sum_arc_info.py │ │ │ ├── learn_lexicon_bayesian.sh │ │ │ ├── learn_lexicon_greedy.sh │ │ │ ├── merge_learned_lexicons.py │ │ │ ├── prons_to_lexicon.py │ │ │ ├── prune_pron_candidates.py │ │ │ ├── select_prons_bayesian.py │ │ │ ├── select_prons_greedy.py │ │ │ ├── train_g2p.sh │ │ │ └── train_g2p_phonetisaurus.sh │ │ ├── get_ctm.sh │ │ ├── get_ctm_conf_fast.sh │ │ ├── get_ctm_fast.sh │ │ ├── get_fmllr_basis.sh │ │ ├── get_lexicon_probs.sh │ │ ├── get_prons.sh │ │ ├── get_train_ctm.sh │ │ ├── info/ │ │ │ ├── chain_dir_info.pl │ │ │ ├── gmm_dir_info.pl │ │ │ ├── nnet2_dir_info.pl │ │ │ ├── nnet3_dir_info.pl │ │ │ └── nnet3_disc_dir_info.pl │ │ ├── libs/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ └── nnet3/ │ │ │ ├── __init__.py │ │ │ ├── report/ │ │ │ │ ├── __init__.py │ │ │ │ └── log_parse.py │ │ │ ├── train/ │ │ │ │ ├── __init__.py │ │ │ │ ├── chain_objf/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── acoustic_model.py │ │ │ │ ├── common.py │ │ │ │ ├── dropout_schedule.py │ │ │ │ └── frame_level_objf/ │ │ │ │ ├── __init__.py │ │ │ │ ├── acoustic_model.py │ │ │ │ ├── common.py │ │ │ │ └── raw_model.py │ │ │ └── xconfig/ │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── basic_layers.py │ │ │ ├── composite_layers.py │ │ │ 
├── convolution.py │ │ │ ├── gru.py │ │ │ ├── layers.py │ │ │ ├── lstm.py │ │ │ ├── parser.py │ │ │ ├── stats_layer.py │ │ │ ├── trivial_layers.py │ │ │ └── utils.py │ │ ├── lmrescore.sh │ │ ├── lmrescore_const_arpa.sh │ │ ├── lmrescore_const_arpa_undeterminized.sh │ │ ├── lmrescore_rnnlm_lat.sh │ │ ├── make_denlats.sh │ │ ├── make_denlats_sgmm2.sh │ │ ├── make_fbank.sh │ │ ├── make_fbank_pitch.sh │ │ ├── make_index.sh │ │ ├── make_mfcc.sh │ │ ├── make_mfcc_pitch.sh │ │ ├── make_mfcc_pitch_online.sh │ │ ├── make_phone_graph.sh │ │ ├── make_plp.sh │ │ ├── make_plp_pitch.sh │ │ ├── nnet/ │ │ │ ├── align.sh │ │ │ ├── decode.sh │ │ │ ├── ivector/ │ │ │ │ ├── extract_ivectors.sh │ │ │ │ ├── train_diag_ubm.sh │ │ │ │ └── train_ivector_extractor.sh │ │ │ ├── make_bn_feats.sh │ │ │ ├── make_denlats.sh │ │ │ ├── make_fmllr_feats.sh │ │ │ ├── make_fmmi_feats.sh │ │ │ ├── make_priors.sh │ │ │ ├── pretrain_dbn.sh │ │ │ ├── train.sh │ │ │ ├── train_mmi.sh │ │ │ ├── train_mpe.sh │ │ │ └── train_scheduler.sh │ │ ├── nnet2/ │ │ │ ├── adjust_priors.sh │ │ │ ├── align.sh │ │ │ ├── check_ivectors_compatible.sh │ │ │ ├── convert_lda_to_raw.sh │ │ │ ├── convert_nnet1_to_nnet2.sh │ │ │ ├── create_appended_model.sh │ │ │ ├── decode.sh │ │ │ ├── dump_bottleneck_features.sh │ │ │ ├── get_egs.sh │ │ │ ├── get_egs2.sh │ │ │ ├── get_egs_discriminative2.sh │ │ │ ├── get_ivector_id.sh │ │ │ ├── get_lda.sh │ │ │ ├── get_lda_block.sh │ │ │ ├── get_perturbed_feats.sh │ │ │ ├── make_denlats.sh │ │ │ ├── make_multisplice_configs.py │ │ │ ├── relabel_egs.sh │ │ │ ├── relabel_egs2.sh │ │ │ ├── remove_egs.sh │ │ │ ├── retrain_fast.sh │ │ │ ├── retrain_simple2.sh │ │ │ ├── retrain_tanh.sh │ │ │ ├── train_block.sh │ │ │ ├── train_convnet_accel2.sh │ │ │ ├── train_discriminative.sh │ │ │ ├── train_discriminative2.sh │ │ │ ├── train_discriminative_multilang2.sh │ │ │ ├── train_more.sh │ │ │ ├── train_more2.sh │ │ │ ├── train_multilang2.sh │ │ │ ├── train_multisplice_accel2.sh │ │ │ ├── 
train_multisplice_ensemble.sh │ │ │ ├── train_pnorm.sh │ │ │ ├── train_pnorm_accel2.sh │ │ │ ├── train_pnorm_bottleneck_fast.sh │ │ │ ├── train_pnorm_ensemble.sh │ │ │ ├── train_pnorm_fast.sh │ │ │ ├── train_pnorm_multisplice.sh │ │ │ ├── train_pnorm_multisplice2.sh │ │ │ ├── train_pnorm_simple.sh │ │ │ ├── train_pnorm_simple2.sh │ │ │ ├── train_tanh.sh │ │ │ ├── train_tanh_bottleneck.sh │ │ │ ├── train_tanh_fast.sh │ │ │ └── update_nnet.sh │ │ ├── nnet3/ │ │ │ ├── adjust_priors.sh │ │ │ ├── align.sh │ │ │ ├── align_lats.sh │ │ │ ├── chain/ │ │ │ │ ├── align_lats.sh │ │ │ │ ├── build_tree.sh │ │ │ │ ├── build_tree_multiple_sources.sh │ │ │ │ ├── e2e/ │ │ │ │ │ ├── README.txt │ │ │ │ │ ├── compute_biphone_stats.py │ │ │ │ │ ├── get_egs_e2e.sh │ │ │ │ │ ├── prepare_e2e.sh │ │ │ │ │ ├── text_to_phones.py │ │ │ │ │ └── train_e2e.py │ │ │ │ ├── gen_topo.pl │ │ │ │ ├── gen_topo.py │ │ │ │ ├── gen_topo2.py │ │ │ │ ├── gen_topo3.py │ │ │ │ ├── gen_topo4.py │ │ │ │ ├── gen_topo5.py │ │ │ │ ├── gen_topo_orig.py │ │ │ │ ├── get_egs.sh │ │ │ │ ├── get_model_context.sh │ │ │ │ ├── get_phone_post.sh │ │ │ │ ├── make_weighted_den_fst.sh │ │ │ │ ├── multilingual/ │ │ │ │ │ └── combine_egs.sh │ │ │ │ ├── train.py │ │ │ │ └── train_tdnn.sh │ │ │ ├── chain2/ │ │ │ │ ├── combine_egs.sh │ │ │ │ ├── compute_preconditioning_matrix.sh │ │ │ │ ├── get_raw_egs.sh │ │ │ │ ├── internal/ │ │ │ │ │ ├── get_best_model.sh │ │ │ │ │ └── get_train_schedule.py │ │ │ │ ├── process_egs.sh │ │ │ │ ├── randomize_egs.sh │ │ │ │ ├── train.sh │ │ │ │ ├── validate_processed_egs.sh │ │ │ │ ├── validate_randomized_egs.sh │ │ │ │ └── validate_raw_egs.sh │ │ │ ├── components.py │ │ │ ├── compute_output.sh │ │ │ ├── convert_nnet2_to_nnet3.py │ │ │ ├── decode.sh │ │ │ ├── decode_grammar.sh │ │ │ ├── decode_lookahead.sh │ │ │ ├── decode_looped.sh │ │ │ ├── decode_score_fusion.sh │ │ │ ├── decode_semisup.sh │ │ │ ├── dot/ │ │ │ │ ├── descriptor_parser.py │ │ │ │ └── nnet3_to_dot.py │ │ │ ├── get_degs.sh │ │ │ ├── 
get_egs.sh │ │ │ ├── get_egs_discriminative.sh │ │ │ ├── get_egs_targets.sh │ │ │ ├── get_saturation.pl │ │ │ ├── get_successful_models.py │ │ │ ├── lstm/ │ │ │ │ ├── make_configs.py │ │ │ │ └── train.sh │ │ │ ├── make_bottleneck_features.sh │ │ │ ├── make_denlats.sh │ │ │ ├── make_tdnn_configs.py │ │ │ ├── multilingual/ │ │ │ │ ├── allocate_multilingual_examples.py │ │ │ │ └── combine_egs.sh │ │ │ ├── nnet3_to_dot.sh │ │ │ ├── report/ │ │ │ │ ├── convert_model.py │ │ │ │ ├── generate_plots.py │ │ │ │ └── summarize_compute_debug_timing.py │ │ │ ├── tdnn/ │ │ │ │ ├── make_configs.py │ │ │ │ ├── train.sh │ │ │ │ └── train_raw_nnet.sh │ │ │ ├── train_discriminative.sh │ │ │ ├── train_dnn.py │ │ │ ├── train_raw_dnn.py │ │ │ ├── train_raw_rnn.py │ │ │ ├── train_rnn.py │ │ │ ├── train_tdnn.sh │ │ │ ├── xconfig_to_config.py │ │ │ └── xconfig_to_configs.py │ │ ├── online/ │ │ │ ├── decode.sh │ │ │ ├── nnet2/ │ │ │ │ ├── align.sh │ │ │ │ ├── copy_data_dir.sh │ │ │ │ ├── copy_ivector_dir.sh │ │ │ │ ├── decode.sh │ │ │ │ ├── dump_nnet_activations.sh │ │ │ │ ├── extract_ivectors.sh │ │ │ │ ├── extract_ivectors_online.sh │ │ │ │ ├── get_egs.sh │ │ │ │ ├── get_egs2.sh │ │ │ │ ├── get_egs_discriminative2.sh │ │ │ │ ├── get_pca_transform.sh │ │ │ │ ├── make_denlats.sh │ │ │ │ ├── prepare_online_decoding.sh │ │ │ │ ├── prepare_online_decoding_retrain.sh │ │ │ │ ├── prepare_online_decoding_transfer.sh │ │ │ │ ├── train_diag_ubm.sh │ │ │ │ └── train_ivector_extractor.sh │ │ │ ├── nnet3/ │ │ │ │ ├── decode.sh │ │ │ │ ├── decode_wake_word.sh │ │ │ │ └── prepare_online_decoding.sh │ │ │ └── prepare_online_decoding.sh │ │ ├── oracle_wer.sh │ │ ├── overlap/ │ │ │ ├── get_overlap_segments.py │ │ │ ├── get_overlap_targets.py │ │ │ ├── output_to_rttm.py │ │ │ ├── post_process_output.sh │ │ │ └── prepare_overlap_graph.py │ │ ├── paste_feats.sh │ │ ├── pytorchnn/ │ │ │ ├── check_py.py │ │ │ ├── compute_sentence_scores.py │ │ │ ├── data.py │ │ │ ├── lmrescore_nbest_pytorchnn.sh │ │ │ ├── 
model.py │ │ │ └── train.py │ │ ├── resegment_data.sh │ │ ├── resegment_text.sh │ │ ├── rnnlmrescore.sh │ │ ├── scoring/ │ │ │ ├── score_kaldi_cer.sh │ │ │ ├── score_kaldi_compare.sh │ │ │ └── score_kaldi_wer.sh │ │ ├── search_index.sh │ │ ├── segmentation/ │ │ │ ├── ali_to_targets.sh │ │ │ ├── combine_targets_dirs.sh │ │ │ ├── convert_targets_dir_to_whole_recording.sh │ │ │ ├── convert_utt2spk_and_segments_to_rttm.py │ │ │ ├── copy_targets_dir.sh │ │ │ ├── decode_sad.sh │ │ │ ├── detect_speech_activity.sh │ │ │ ├── evaluate_segmentation.pl │ │ │ ├── get_targets_for_out_of_segments.sh │ │ │ ├── internal/ │ │ │ │ ├── arc_info_to_targets.py │ │ │ │ ├── find_oov_phone.py │ │ │ │ ├── get_default_targets_for_out_of_segments.py │ │ │ │ ├── get_transform_probs_mat.py │ │ │ │ ├── merge_segment_targets_to_recording.py │ │ │ │ ├── merge_targets.py │ │ │ │ ├── prepare_sad_graph.py │ │ │ │ ├── resample_targets.py │ │ │ │ ├── sad_to_segments.py │ │ │ │ └── verify_phones_list.py │ │ │ ├── lats_to_targets.sh │ │ │ ├── merge_targets_dirs.sh │ │ │ ├── post_process_sad_to_segments.sh │ │ │ ├── prepare_targets_gmm.sh │ │ │ ├── resample_targets_dir.sh │ │ │ └── validate_targets_dir.sh │ │ ├── select_feats.sh │ │ ├── shift_feats.sh │ │ ├── subset_ali_dir.sh │ │ ├── tandem/ │ │ │ ├── align_fmllr.sh │ │ │ ├── align_sgmm2.sh │ │ │ ├── align_si.sh │ │ │ ├── decode.sh │ │ │ ├── decode_fmllr.sh │ │ │ ├── decode_sgmm2.sh │ │ │ ├── make_denlats.sh │ │ │ ├── make_denlats_sgmm2.sh │ │ │ ├── mk_aslf_lda_mllt.sh │ │ │ ├── mk_aslf_sgmm2.sh │ │ │ ├── train_deltas.sh │ │ │ ├── train_lda_mllt.sh │ │ │ ├── train_mllt.sh │ │ │ ├── train_mmi.sh │ │ │ ├── train_mmi_sgmm2.sh │ │ │ ├── train_mono.sh │ │ │ ├── train_sat.sh │ │ │ ├── train_sgmm2.sh │ │ │ └── train_ubm.sh │ │ ├── tfrnnlm/ │ │ │ ├── check_py.py │ │ │ ├── check_tensorflow_installed.sh │ │ │ ├── lmrescore_rnnlm_lat.sh │ │ │ ├── lmrescore_rnnlm_lat_pruned.sh │ │ │ ├── lstm.py │ │ │ ├── lstm_fast.py │ │ │ ├── reader.py │ │ │ └── vanilla_rnnlm.py │ 
│ ├── train_deltas.sh │ │ ├── train_diag_ubm.sh │ │ ├── train_lda_mllt.sh │ │ ├── train_lvtln.sh │ │ ├── train_map.sh │ │ ├── train_mmi.sh │ │ ├── train_mmi_fmmi.sh │ │ ├── train_mmi_fmmi_indirect.sh │ │ ├── train_mmi_sgmm2.sh │ │ ├── train_mono.sh │ │ ├── train_mpe.sh │ │ ├── train_quick.sh │ │ ├── train_raw_sat.sh │ │ ├── train_sat.sh │ │ ├── train_sat_basis.sh │ │ ├── train_segmenter.sh │ │ ├── train_sgmm2.sh │ │ ├── train_sgmm2_group.sh │ │ ├── train_smbr.sh │ │ ├── train_ubm.sh │ │ └── word_align_lattices.sh │ └── utils/ │ ├── add_disambig.pl │ ├── add_lex_disambig.pl │ ├── analyze_segments.pl │ ├── apply_map.pl │ ├── best_wer.sh │ ├── build_const_arpa_lm.sh │ ├── combine_data.sh │ ├── convert_slf.pl │ ├── convert_slf_parallel.sh │ ├── copy_data_dir.sh │ ├── create_data_link.pl │ ├── create_split_dir.pl │ ├── ctm/ │ │ ├── convert_ctm.pl │ │ ├── fix_ctm.sh │ │ └── resolve_ctm_overlaps.py │ ├── data/ │ │ ├── combine_short_segments.sh │ │ ├── convert_data_dir_to_whole.sh │ │ ├── extend_segment_times.py │ │ ├── extract_wav_segments_data_dir.sh │ │ ├── fix_subsegment_feats.pl │ │ ├── get_allowed_durations.py │ │ ├── get_frame_shift.sh │ │ ├── get_num_frames.sh │ │ ├── get_reco2dur.sh │ │ ├── get_reco2utt_for_data.sh │ │ ├── get_segments_for_data.sh │ │ ├── get_uniform_subsegments.py │ │ ├── get_utt2dur.sh │ │ ├── get_utt2num_frames.sh │ │ ├── internal/ │ │ │ ├── choose_utts_to_combine.py │ │ │ ├── combine_segments_to_recording.py │ │ │ ├── modify_speaker_info.py │ │ │ └── perturb_volume.py │ │ ├── limit_feature_dim.sh │ │ ├── modify_speaker_info.sh │ │ ├── modify_speaker_info_to_recording.sh │ │ ├── normalize_data_range.pl │ │ ├── perturb_data_dir_speed_3way.sh │ │ ├── perturb_data_dir_volume.sh │ │ ├── perturb_speed_to_allowed_lengths.py │ │ ├── remove_dup_utts.sh │ │ ├── resample_data_dir.sh │ │ ├── shift_and_combine_feats.sh │ │ ├── shift_feats.sh │ │ └── subsegment_data_dir.sh │ ├── dict_dir_add_pronprobs.sh │ ├── eps2disambig.pl │ ├── filt.py │ ├── 
filter_scp.pl │ ├── filter_scps.pl │ ├── find_arpa_oovs.pl │ ├── fix_data_dir.sh │ ├── format_lm.sh │ ├── format_lm_sri.sh │ ├── gen_topo.pl │ ├── int2sym.pl │ ├── kwslist_post_process.pl │ ├── lang/ │ │ ├── add_unigrams_arpa.pl │ │ ├── adjust_unk_arpa.pl │ │ ├── adjust_unk_graph.sh │ │ ├── bpe/ │ │ │ ├── add_final_optional_silence.sh │ │ │ ├── apply_bpe.py │ │ │ ├── bidi.py │ │ │ ├── learn_bpe.py │ │ │ ├── prepend_words.py │ │ │ └── reverse.py │ │ ├── check_g_properties.pl │ │ ├── check_phones_compatible.sh │ │ ├── compute_sentence_probs_arpa.py │ │ ├── extend_lang.sh │ │ ├── get_word_position_phone_map.pl │ │ ├── grammar/ │ │ │ ├── augment_phones_txt.py │ │ │ └── augment_words_txt.py │ │ ├── internal/ │ │ │ ├── apply_unk_lm.sh │ │ │ ├── arpa2fst_constrained.py │ │ │ └── modify_unk_pron.py │ │ ├── limit_arpa_unk_history.py │ │ ├── make_kn_lm.py │ │ ├── make_lexicon_fst.py │ │ ├── make_lexicon_fst_silprob.py │ │ ├── make_phone_bigram_lang.sh │ │ ├── make_phone_lm.py │ │ ├── make_position_dependent_subword_lexicon.py │ │ ├── make_subword_lexicon_fst.py │ │ ├── make_unk_lm.sh │ │ └── validate_disambig_sym_file.pl │ ├── ln.pl │ ├── make_absolute.sh │ ├── make_lexicon_fst.pl │ ├── make_lexicon_fst_silprob.pl │ ├── make_unigram_grammar.pl │ ├── map_arpa_lm.pl │ ├── mkgraph.sh │ ├── mkgraph_lookahead.sh │ ├── nnet/ │ │ ├── gen_dct_mat.py │ │ ├── gen_hamm_mat.py │ │ ├── gen_splice.py │ │ ├── make_blstm_proto.py │ │ ├── make_cnn_proto.py │ │ ├── make_lstm_proto.py │ │ ├── make_nnet_proto.py │ │ └── subset_data_tr_cv.sh │ ├── nnet-cpu/ │ │ ├── make_nnet_config.pl │ │ ├── make_nnet_config_block.pl │ │ ├── make_nnet_config_preconditioned.pl │ │ └── update_learning_rates.pl │ ├── nnet3/ │ │ └── convert_config_tdnn_to_affine.py │ ├── parallel/ │ │ ├── limit_num_gpus.sh │ │ ├── pbs.pl │ │ ├── queue.pl │ │ ├── retry.pl │ │ ├── run.pl │ │ └── slurm.pl │ ├── parse_options.sh │ ├── perturb_data_dir_speed.sh │ ├── pinyin_map.pl │ ├── prepare_extended_lang.sh │ ├── prepare_lang.sh │ 
├── prepare_online_nnet_dist_build.sh │ ├── remove_data_links.sh │ ├── remove_oovs.pl │ ├── reverse_arpa.py │ ├── rnnlm_compute_scores.sh │ ├── s2eps.pl │ ├── scoring/ │ │ ├── wer_ops_details.pl │ │ ├── wer_per_spk_details.pl │ │ ├── wer_per_utt_details.pl │ │ └── wer_report.pl │ ├── segmentation.pl │ ├── show_lattice.sh │ ├── shuffle_list.pl │ ├── spk2utt_to_utt2spk.pl │ ├── split_data.sh │ ├── split_scp.pl │ ├── ssh.pl │ ├── subset_data_dir.sh │ ├── subset_scp.pl │ ├── subword/ │ │ ├── prepare_lang_subword.sh │ │ └── prepare_subword_text.sh │ ├── summarize_logs.pl │ ├── summarize_warnings.pl │ ├── sym2int.pl │ ├── utt2spk_to_spk2utt.pl │ ├── validate_data_dir.sh │ ├── validate_dict_dir.pl │ ├── validate_lang.pl │ ├── validate_text.pl │ └── write_kwslist.pl ├── env/ │ └── build_env.sh ├── kaldi ├── lm/ │ ├── __init__.py │ ├── chainer_backend/ │ │ ├── __init__.py │ │ ├── extlm.py │ │ └── lm.py │ ├── lm_utils.py │ └── pytorch_backend/ │ ├── __init__.py │ ├── extlm.py │ └── lm.py ├── mt/ │ ├── __init__.py │ ├── mt_utils.py │ └── pytorch_backend/ │ ├── __init__.py │ └── mt.py ├── nets/ │ ├── __init__.py │ ├── asr_interface.py │ ├── batch_beam_search.py │ ├── batch_beam_search_online_sim.py │ ├── beam_search.py │ ├── beam_search_transducer.py │ ├── chainer_backend/ │ │ ├── __init__.py │ │ ├── asr_interface.py │ │ ├── ctc.py │ │ ├── deterministic_embed_id.py │ │ ├── e2e_asr.py │ │ ├── e2e_asr_transformer.py │ │ ├── nets_utils.py │ │ ├── rnn/ │ │ │ ├── __init__.py │ │ │ ├── attentions.py │ │ │ ├── decoders.py │ │ │ ├── encoders.py │ │ │ └── training.py │ │ └── transformer/ │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── ctc.py │ │ ├── decoder.py │ │ ├── decoder_layer.py │ │ ├── embedding.py │ │ ├── encoder.py │ │ ├── encoder_layer.py │ │ ├── label_smoothing_loss.py │ │ ├── layer_norm.py │ │ ├── mask.py │ │ ├── positionwise_feed_forward.py │ │ ├── subsampling.py │ │ └── training.py │ ├── ctc_prefix_score.py │ ├── e2e_asr_common.py │ ├── e2e_mt_common.py │ ├── 
lm_interface.py │ ├── mt_interface.py │ ├── pytorch_backend/ │ │ ├── __init__.py │ │ ├── conformer/ │ │ │ ├── __init__.py │ │ │ ├── argument.py │ │ │ ├── convolution.py │ │ │ ├── encoder.py │ │ │ ├── encoder_layer.py │ │ │ └── swish.py │ │ ├── ctc.py │ │ ├── e2e_asr.py │ │ ├── e2e_asr_conformer.py │ │ ├── e2e_asr_maskctc.py │ │ ├── e2e_asr_mix.py │ │ ├── e2e_asr_mix_transformer.py │ │ ├── e2e_asr_mulenc.py │ │ ├── e2e_asr_transducer.py │ │ ├── e2e_asr_transducer_cs.py │ │ ├── e2e_asr_transformer.py │ │ ├── e2e_mt.py │ │ ├── e2e_mt_transformer.py │ │ ├── e2e_st.py │ │ ├── e2e_st_conformer.py │ │ ├── e2e_st_transformer.py │ │ ├── e2e_tts_fastspeech.py │ │ ├── e2e_tts_tacotron2.py │ │ ├── e2e_tts_transformer.py │ │ ├── e2e_vc_tacotron2.py │ │ ├── e2e_vc_transformer.py │ │ ├── fastspeech/ │ │ │ ├── __init__.py │ │ │ ├── duration_calculator.py │ │ │ ├── duration_predictor.py │ │ │ └── length_regulator.py │ │ ├── frontends/ │ │ │ ├── __init__.py │ │ │ ├── beamformer.py │ │ │ ├── dnn_beamformer.py │ │ │ ├── dnn_wpe.py │ │ │ ├── feature_transform.py │ │ │ ├── frontend.py │ │ │ └── mask_estimator.py │ │ ├── gtn_ctc.py │ │ ├── initialization.py │ │ ├── lm/ │ │ │ ├── __init__.py │ │ │ ├── default.py │ │ │ ├── seq_rnn.py │ │ │ └── transformer.py │ │ ├── maskctc/ │ │ │ ├── __init__.py │ │ │ ├── add_mask_token.py │ │ │ └── mask.py │ │ ├── nets_utils.py │ │ ├── rnn/ │ │ │ ├── __init__.py │ │ │ ├── argument.py │ │ │ ├── attentions.py │ │ │ ├── decoders.py │ │ │ └── encoders.py │ │ ├── streaming/ │ │ │ ├── __init__.py │ │ │ ├── segment.py │ │ │ └── window.py │ │ ├── tacotron2/ │ │ │ ├── __init__.py │ │ │ ├── cbhg.py │ │ │ ├── decoder.py │ │ │ └── encoder.py │ │ ├── transducer/ │ │ │ ├── __init__.py │ │ │ ├── arguments.py │ │ │ ├── auxiliary_task.py │ │ │ ├── blocks.py │ │ │ ├── causal_conv1d.py │ │ │ ├── custom_decoder.py │ │ │ ├── custom_encoder.py │ │ │ ├── error_calculator.py │ │ │ ├── initializer.py │ │ │ ├── joint_network.py │ │ │ ├── loss.py │ │ │ ├── rnn_decoder.py │ │ │ ├── 
rnn_encoder.py │ │ │ ├── tdnn.py │ │ │ ├── transformer_decoder_layer.py │ │ │ ├── utils.py │ │ │ └── vgg2l.py │ │ ├── transformer/ │ │ │ ├── __init__.py │ │ │ ├── add_sos_eos.py │ │ │ ├── argument.py │ │ │ ├── attention.py │ │ │ ├── contextual_block_encoder_layer.py │ │ │ ├── decoder.py │ │ │ ├── decoder_layer.py │ │ │ ├── dynamic_conv.py │ │ │ ├── dynamic_conv2d.py │ │ │ ├── embedding.py │ │ │ ├── encoder.py │ │ │ ├── encoder_layer.py │ │ │ ├── encoder_mix.py │ │ │ ├── initializer.py │ │ │ ├── label_smoothing_loss.py │ │ │ ├── layer_norm.py │ │ │ ├── lightconv.py │ │ │ ├── lightconv2d.py │ │ │ ├── mask.py │ │ │ ├── multi_layer_conv.py │ │ │ ├── optimizer.py │ │ │ ├── plot.py │ │ │ ├── positionwise_feed_forward.py │ │ │ ├── repeat.py │ │ │ ├── sgd_optimizer.py │ │ │ ├── subsampling.py │ │ │ └── subsampling_without_posenc.py │ │ └── wavenet.py │ ├── scorer_interface.py │ ├── scorers/ │ │ ├── .mmi_rnnt_scorer.py.swp │ │ ├── __init__.py │ │ ├── _mmi_utils.py │ │ ├── ctc.py │ │ ├── ctc_rnnt_scorer.py │ │ ├── length_bonus.py │ │ ├── lookahead.py │ │ ├── mmi.py │ │ ├── mmi_alignment_score.py │ │ ├── mmi_frame_prefix_scorer.py │ │ ├── mmi_frame_scorer.py │ │ ├── mmi_frame_scorer_trace.py │ │ ├── mmi_lookahead.py │ │ ├── mmi_lookahead_bak.py │ │ ├── mmi_lookahead_split.py │ │ ├── mmi_prefix_score.py │ │ ├── mmi_rescorer.py │ │ ├── mmi_rnnt_lookahead_scorer.py │ │ ├── mmi_rnnt_scorer.py │ │ ├── mmi_utils.py │ │ ├── new_mmi_frame_scorer.py │ │ ├── ngram.py │ │ ├── sorted_matcher.py │ │ ├── test.py │ │ ├── tlg_scorer.py │ │ ├── trace_frame.py │ │ └── word_ngram.py │ ├── st_interface.py │ ├── transducer_decoder_interface.py │ └── tts_interface.py ├── optimizer/ │ ├── __init__.py │ ├── chainer.py │ ├── factory.py │ ├── parser.py │ └── pytorch.py ├── scheduler/ │ ├── __init__.py │ ├── chainer.py │ ├── pytorch.py │ └── scheduler.py ├── snowfall/ │ ├── __init__.py │ ├── common.py │ ├── data/ │ │ ├── __init__.py │ │ ├── aishell.py │ │ ├── asr_datamodule.py │ │ ├── datamodule.py │ │ 
└── librispeech.py │ ├── decoding/ │ │ ├── __init__.py │ │ ├── graph.py │ │ └── lm_rescore.py │ ├── dist.py │ ├── lexicon.py │ ├── models/ │ │ ├── __init__.py │ │ ├── conformer.py │ │ ├── contextnet.py │ │ ├── interface.py │ │ ├── tdnn.py │ │ ├── tdnn_lstm.py │ │ ├── tdnnf.py │ │ └── transformer.py │ ├── objectives/ │ │ ├── __init__.py │ │ ├── common.py │ │ ├── ctc.py │ │ └── mmi.py │ ├── training/ │ │ ├── __init__.py │ │ ├── ctc_graph.py │ │ ├── diagnostics.py │ │ ├── mmi_graph.py │ │ └── mmi_mbr_graph.py │ └── warpper/ │ ├── k2_decode.py │ ├── mmi_test.py │ ├── mmi_utils.py │ ├── prefix_scorer.py │ ├── warpper_ctc.py │ └── warpper_mmi.py ├── st/ │ ├── __init__.py │ └── pytorch_backend/ │ ├── __init__.py │ └── st.py ├── transform/ │ ├── __init__.py │ ├── add_deltas.py │ ├── channel_selector.py │ ├── cmvn.py │ ├── functional.py │ ├── perturb.py │ ├── spec_augment.py │ ├── spectrogram.py │ ├── transform_interface.py │ ├── transformation.py │ └── wpe.py ├── tts/ │ ├── __init__.py │ └── pytorch_backend/ │ ├── __init__.py │ └── tts.py ├── utils/ │ ├── __init__.py │ ├── bmuf.py │ ├── check_kwargs.py │ ├── cli_readers.py │ ├── cli_utils.py │ ├── cli_writers.py │ ├── dataset.py │ ├── deterministic_utils.py │ ├── draw_num_fst.py │ ├── dynamic_import.py │ ├── fill_missing_args.py │ ├── io_utils.py │ ├── parse_decoding_process.py │ ├── parse_npy.py │ ├── print.py │ ├── rtf_calculator.py │ ├── sampler.py │ ├── spec_augment.py │ └── training/ │ ├── __init__.py │ ├── batchfy.py │ ├── evaluator.py │ ├── iterators.py │ ├── tensorboard_logger.py │ └── train_utils.py ├── vc/ │ └── pytorch_backend/ │ └── vc.py └── version.txt
Showing preview only (280K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (3493 symbols across 424 files)
FILE: asr/asr_mix_utils.py
class PlotAttentionReport (line 30) | class PlotAttentionReport(extension.Extension):
method __init__ (line 45) | def __init__(self, att_vis_fn, data, outdir, converter, device, revers...
method __call__ (line 56) | def __call__(self, trainer):
method log_attentions (line 69) | def log_attentions(self, logger, step):
method get_attention_weights (line 79) | def get_attention_weights(self):
method get_attention_weight (line 94) | def get_attention_weight(self, idx, att_w, spkr_idx):
method draw_attention_plot (line 108) | def draw_attention_plot(self, att_w):
method _plot_and_save_attention (line 133) | def _plot_and_save_attention(self, att_w, filename):
function add_results_to_json (line 139) | def add_results_to_json(js, nbest_hyps_sd, char_list):
FILE: asr/asr_utils.py
class CompareValueTrigger (line 18) | class CompareValueTrigger(object):
method __init__ (line 28) | def __init__(self, key, compare_fn, trigger=(1, "epoch")):
method __call__ (line 37) | def __call__(self, trainer):
method _init_summary (line 62) | def _init_summary(self):
class PlotAttentionReport (line 74) | class PlotAttentionReport(extension.Extension):
method __init__ (line 98) | def __init__(
method __call__ (line 130) | def __call__(self, trainer):
method log_attentions (line 180) | def log_attentions(self, logger, step):
method get_attention_weights (line 210) | def get_attention_weights(self):
method trim_attention_weight (line 229) | def trim_attention_weight(self, uttid, att_w):
method draw_attention_plot (line 247) | def draw_attention_plot(self, att_w):
method draw_han_plot (line 274) | def draw_han_plot(self, att_w):
method _plot_and_save_attention (line 314) | def _plot_and_save_attention(self, att_w, filename, han_mode=False):
class PlotCTCReport (line 329) | class PlotCTCReport(extension.Extension):
method __init__ (line 353) | def __init__(
method __call__ (line 385) | def __call__(self, trainer):
method log_ctc_probs (line 419) | def log_ctc_probs(self, logger, step):
method get_ctc_probs (line 439) | def get_ctc_probs(self):
method trim_ctc_prob (line 455) | def trim_ctc_prob(self, uttid, prob):
method draw_ctc_plot (line 463) | def draw_ctc_plot(self, ctc_prob):
method _plot_and_save_ctc (line 499) | def _plot_and_save_ctc(self, ctc_prob, filename):
function restore_snapshot (line 505) | def restore_snapshot(model, snapshot, load_fn=None):
function _restore_snapshot (line 525) | def _restore_snapshot(model, snapshot, load_fn=None):
function adadelta_eps_decay (line 535) | def adadelta_eps_decay(eps_decay):
function _adadelta_eps_decay (line 554) | def _adadelta_eps_decay(trainer, eps_decay):
function adam_lr_decay (line 568) | def adam_lr_decay(eps_decay):
function _adam_lr_decay (line 587) | def _adam_lr_decay(trainer, eps_decay):
function torch_snapshot (line 601) | def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.e...
function _torch_snapshot_object (line 617) | def _torch_snapshot_object(trainer, target, filename, savefun):
function add_gradient_noise (line 661) | def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_fa...
function get_model_conf (line 685) | def get_model_conf(model_path, conf_path=None):
function chainer_load (line 713) | def chainer_load(path, model):
function torch_save (line 729) | def torch_save(path, model):
function snapshot_object (line 743) | def snapshot_object(target, filename):
function torch_load (line 767) | def torch_load(path, model):
function torch_resume (line 790) | def torch_resume(snapshot_path, trainer, load_trainer_and_opt=True):
function parse_hypothesis (line 836) | def parse_hypothesis(hyp, char_list):
function add_results_to_json (line 860) | def add_results_to_json(js, nbest_hyps, char_list):
function plot_spectrogram (line 921) | def plot_spectrogram(
function format_mulenc_args (line 1001) | def format_mulenc_args(args):
FILE: asr/chainer_backend/asr.py
function train (line 51) | def train(args):
function recog (line 480) | def recog(args):
FILE: asr/pytorch_backend/asr.py
function _recursive_to (line 77) | def _recursive_to(xs, device):
function is_alphabet (line 84) | def is_alphabet(char):
class CustomEvaluator (line 90) | class CustomEvaluator(BaseEvaluator):
method __init__ (line 106) | def __init__(self, model, iterator, target, device, ngpu=None):
method evaluate (line 118) | def evaluate(self):
class CustomUpdater (line 155) | class CustomUpdater(StandardUpdater):
method __init__ (line 170) | def __init__(
method update_core (line 197) | def update_core(self):
method update (line 261) | def update(self):
class CustomConverter (line 269) | class CustomConverter(object):
method __init__ (line 278) | def __init__(self, subsampling_factor=1, dtype=torch.float32):
method __call__ (line 284) | def __call__(self, batch, device=torch.device("cpu")):
class CustomConverterMulEnc (line 344) | class CustomConverterMulEnc(object):
method __init__ (line 353) | def __init__(self, subsamping_factors=[1, 1], dtype=torch.float32):
method __call__ (line 360) | def __call__(self, batch, device=torch.device("cpu")):
function train (line 412) | def train(args):
function recog (line 1016) | def recog(args):
function enhance (line 1368) | def enhance(args):
function ctc_align (line 1587) | def ctc_align(args):
FILE: asr/pytorch_backend/asr_init.py
function freeze_modules (line 18) | def freeze_modules(model, modules):
function transfer_verification (line 40) | def transfer_verification(model_state_dict, partial_state_dict, modules):
function get_partial_state_dict (line 72) | def get_partial_state_dict(model_state_dict, modules):
function get_lm_state_dict (line 94) | def get_lm_state_dict(lm_state_dict):
function filter_modules (line 118) | def filter_modules(model_state_dict, modules):
function load_trained_model (line 151) | def load_trained_model(model_path, training=True):
function _load_trained_model (line 185) | def _load_trained_model(model_path, training=True, patience=10):
function get_trained_model_state_dict (line 195) | def get_trained_model_state_dict(model_path):
function load_trained_modules (line 233) | def load_trained_modules(idim, odim, args, interface=ASRInterface):
FILE: asr/pytorch_backend/asr_mix.py
class CustomConverter (line 54) | class CustomConverter(object):
method __init__ (line 63) | def __init__(self, subsampling_factor=1, dtype=torch.float32, num_spkr...
method __call__ (line 70) | def __call__(self, batch, device=torch.device("cpu")):
function train (line 130) | def train(args):
function recog (line 532) | def recog(args):
FILE: asr/pytorch_backend/recog.py
function recog_v2 (line 27) | def recog_v2(args):
FILE: bin/asr_align.py
function get_parser (line 64) | def get_parser():
function main (line 184) | def main(args):
function ctc_align (line 228) | def ctc_align(args, device):
FILE: bin/asr_enhance.py
function get_parser (line 15) | def get_parser():
function main (line 138) | def main(args):
FILE: bin/asr_recog.py
function get_parser (line 22) | def get_parser():
function main (line 373) | def main(args):
FILE: bin/asr_train.py
function get_parser (line 29) | def get_parser(parser=None, required=True):
function main (line 576) | def main(cmd_args):
FILE: bin/lm_train.py
function get_parser (line 27) | def get_parser(parser=None, required=True):
function main (line 190) | def main(cmd_args):
FILE: bin/mt_train.py
function get_parser (line 29) | def get_parser(parser=None, required=True):
function main (line 375) | def main(cmd_args):
FILE: bin/mt_trans.py
function get_parser (line 19) | def get_parser():
function main (line 125) | def main(args):
FILE: bin/st_train.py
function get_parser (line 28) | def get_parser(parser=None, required=True):
function main (line 445) | def main(cmd_args):
FILE: bin/st_trans.py
function get_parser (line 19) | def get_parser():
function main (line 122) | def main(args):
FILE: bin/tts_decode.py
function get_parser (line 19) | def get_parser():
function main (line 118) | def main(args):
FILE: bin/tts_train.py
function get_parser (line 24) | def get_parser():
function main (line 294) | def main(cmd_args):
FILE: bin/vc_decode.py
function get_parser (line 19) | def get_parser():
function main (line 112) | def main(args):
FILE: bin/vc_train.py
function get_parser (line 24) | def get_parser():
function main (line 304) | def main(cmd_args):
FILE: egs/aishell1/local/make_lexicon_fst.py
function get_args (line 21) | def get_args():
function read_lexiconp (line 60) | def read_lexiconp(filename):
function write_nonterminal_arcs (line 119) | def write_nonterminal_arcs(start_state, loop_state, next_state,
function write_fst_no_silence (line 173) | def write_fst_no_silence(lexicon, nonterminals=None, left_context_phones...
function write_fst_with_silence (line 220) | def write_fst_with_silence(lexicon, sil_prob, sil_phone, sil_disambig,
function write_words_txt (line 308) | def write_words_txt(orig_lines, highest_numbered_symbol, nonterminals, f...
function read_nonterminals (line 322) | def read_nonterminals(filename):
function read_left_context_phones (line 338) | def read_left_context_phones(filename):
function is_token (line 354) | def is_token(s):
function main (line 363) | def main():
FILE: egs/aishell2/local/make_lexicon_fst.py
function get_args (line 21) | def get_args():
function read_lexiconp (line 60) | def read_lexiconp(filename):
function write_nonterminal_arcs (line 119) | def write_nonterminal_arcs(start_state, loop_state, next_state,
function write_fst_no_silence (line 173) | def write_fst_no_silence(lexicon, nonterminals=None, left_context_phones...
function write_fst_with_silence (line 220) | def write_fst_with_silence(lexicon, sil_prob, sil_phone, sil_disambig,
function write_words_txt (line 308) | def write_words_txt(orig_lines, highest_numbered_symbol, nonterminals, f...
function read_nonterminals (line 322) | def read_nonterminals(filename):
function read_left_context_phones (line 338) | def read_left_context_phones(filename):
function is_token (line 354) | def is_token(s):
function main (line 363) | def main():
FILE: egs/asrucs/local/add_seperator.py
function is_all_chinese (line 3) | def is_all_chinese(strs):
FILE: egs/asrucs/local/generate_fake_cs.py
function read_datadir (line 9) | def read_datadir(d):
function generate_pairs (line 27) | def generate_pairs(chn_dur_dict, eng_dur_dict, dur_maximum):
function write_utts (line 51) | def write_utts(tgt_dir, pair_list, wav_dict, text_dict):
function main (line 81) | def main():
FILE: egs/espnet_utils/add_uttcls_json.py
function main (line 4) | def main():
FILE: egs/espnet_utils/addjson.py
function get_parser (line 21) | def get_parser():
FILE: egs/espnet_utils/apply-cmvn.py
function get_parser (line 16) | def get_parser():
function main (line 102) | def main():
FILE: egs/espnet_utils/average_checkpoints.py
function main (line 11) | def main():
function get_parser (line 120) | def get_parser():
FILE: egs/espnet_utils/calculate_rtf.py
function get_parser (line 14) | def get_parser():
function main (line 25) | def main():
FILE: egs/espnet_utils/change_yaml.py
function get_parser (line 8) | def get_parser():
function main (line 35) | def main():
FILE: egs/espnet_utils/compute-cmvn-stats.py
function get_parser (line 15) | def get_parser():
function main (line 63) | def main():
FILE: egs/espnet_utils/compute-fbank-feats.py
function get_parser (line 20) | def get_parser():
function main (line 92) | def main():
FILE: egs/espnet_utils/compute-stft-feats.py
function get_parser (line 20) | def get_parser():
function main (line 84) | def main():
FILE: egs/espnet_utils/concat_json_multiref.py
function get_parser (line 16) | def get_parser():
function main (line 25) | def main():
FILE: egs/espnet_utils/concatjson.py
function get_parser (line 19) | def get_parser():
function truncate_tail (line 29) | def truncate_tail(d, size):
FILE: egs/espnet_utils/convert_fbank_to_wav.py
function logmelspc_to_linearspc (line 23) | def logmelspc_to_linearspc(lmspc, fs, n_mels, n_fft, fmin=None, fmax=None):
function griffin_lim (line 49) | def griffin_lim(spc, n_fft, n_shift, win_length, window="hann", n_iters=...
function get_parser (line 98) | def get_parser():
function main (line 147) | def main():
FILE: egs/espnet_utils/copy-feats.py
function get_parser (line 13) | def get_parser():
function main (line 63) | def main():
FILE: egs/espnet_utils/dump-pcm.py
function get_parser (line 14) | def get_parser():
function main (line 81) | def main():
FILE: egs/espnet_utils/eval-source-separation.py
function eval_STOI (line 23) | def eval_STOI(ref, y, fs, extended=False, compute_permutation=True):
function eval_PESQ (line 75) | def eval_PESQ(ref, enh, fs, compute_permutation: bool = True, wideband: ...
function get_parser (line 178) | def get_parser():
function main (line 244) | def main():
FILE: egs/espnet_utils/eval_perm_free_error.py
function permutationDFS (line 17) | def permutationDFS(source, start, res):
function permutation_schemes (line 32) | def permutation_schemes(num_spkrs):
function convert_score (line 50) | def convert_score(keys, dic):
function get_utt_permutation (line 61) | def get_utt_permutation(old_dic, num_spkrs=2):
function get_results (line 90) | def get_results(result_file, result_key):
function merge_results (line 129) | def merge_results(results):
function get_parser (line 161) | def get_parser():
function main (line 177) | def main():
FILE: egs/espnet_utils/feat-to-shape.py
function get_parser (line 12) | def get_parser():
function main (line 45) | def main():
FILE: egs/espnet_utils/feats2npy.py
function get_parser (line 12) | def get_parser():
FILE: egs/espnet_utils/filt.py
function get_parser (line 12) | def get_parser():
function main (line 29) | def main(args):
function filter_file (line 34) | def filter_file(infile, filt, exclude):
FILE: egs/espnet_utils/filter_all_eng_utts.py
function is_all_chinese (line 3) | def is_all_chinese(strs):
FILE: egs/espnet_utils/generate_wav_from_fbank.py
class TimeInvariantMLSAFilter (line 29) | class TimeInvariantMLSAFilter(object):
method __init__ (line 47) | def __init__(self, coef, alpha, n_shift):
method __call__ (line 55) | def __call__(self, y):
function get_parser (line 76) | def get_parser():
function main (line 100) | def main():
FILE: egs/espnet_utils/get_yaml.py
function get_parser (line 7) | def get_parser():
function main (line 19) | def main():
FILE: egs/espnet_utils/json2sctm.py
function get_parser (line 13) | def get_parser():
function main (line 31) | def main(args):
function wrd_name (line 96) | def wrd_name(trn):
FILE: egs/espnet_utils/json2text.py
function get_parser (line 15) | def get_parser():
FILE: egs/espnet_utils/json2trn.py
function get_parser (line 17) | def get_parser():
function main (line 30) | def main(args):
function convert (line 35) | def convert(jsonf, dic, refs, hyps, num_spkrs=1):
FILE: egs/espnet_utils/json2trn_mt.py
function get_parser (line 18) | def get_parser():
function main (line 38) | def main(args):
function convert (line 44) | def convert(jsonf, dic, refs, hyps, srcs, dic_src):
FILE: egs/espnet_utils/json2trn_wo_dict.py
function get_parser (line 16) | def get_parser():
function main (line 28) | def main(args):
function convert (line 33) | def convert(jsonf, refs, hyps, num_spkrs=1):
FILE: egs/espnet_utils/make_pair_json.py
function get_parser (line 16) | def get_parser():
FILE: egs/espnet_utils/mbr_analysis.py
function main (line 11) | def main():
FILE: egs/espnet_utils/mcd_calculate.py
function find_files (line 24) | def find_files(root_dir, query="*.wav", include_root_dir=True):
function low_cut_filter (line 46) | def low_cut_filter(x, fs, cutoff=70):
function spc2npow (line 68) | def spc2npow(spectrogram):
function _spvec2pow (line 92) | def _spvec2pow(specvec):
function extfrm (line 120) | def extfrm(data, npow, power_threshold=-20):
function world_extract (line 152) | def world_extract(wav_path, args):
function get_basename (line 175) | def get_basename(path):
function calculate (line 179) | def calculate(file_list, gt_file_list, args, MCD):
function get_parser (line 213) | def get_parser():
function main (line 251) | def main():
FILE: egs/espnet_utils/merge_scp2json.py
function shape (line 21) | def shape(x):
function get_parser (line 41) | def get_parser():
FILE: egs/espnet_utils/mergejson.py
function get_parser (line 20) | def get_parser():
FILE: egs/espnet_utils/mix-mono-wav-scp.py
function get_parser (line 14) | def get_parser():
function main (line 30) | def main():
FILE: egs/espnet_utils/prepare_mer.py
function is_all_chinese (line 3) | def is_all_chinese(strs):
FILE: egs/espnet_utils/result2json.py
function get_parser (line 17) | def get_parser():
FILE: egs/espnet_utils/score_lang_id.py
function get_parser (line 12) | def get_parser():
function main (line 25) | def main(args):
function scoring (line 30) | def scoring(ref, hyp, out):
FILE: egs/espnet_utils/scp2json.py
function get_parser (line 15) | def get_parser():
FILE: egs/espnet_utils/splitjson.py
function get_parser (line 20) | def get_parser():
FILE: egs/espnet_utils/text2token.py
function exist_or_not (line 15) | def exist_or_not(i, match_pos):
function get_parser (line 27) | def get_parser():
function main (line 68) | def main():
FILE: egs/espnet_utils/text2vocabulary.py
function get_parser (line 15) | def get_parser():
FILE: egs/espnet_utils/text_norm.py
function remove_punc (line 13) | def remove_punc(s):
function is_all_chinese (line 18) | def is_all_chinese(strs):
function is_contain_chinese (line 24) | def is_contain_chinese(check_str):
function splice (line 30) | def splice(s):
function digit_norm (line 63) | def digit_norm(s):
function remove_blank_chn (line 85) | def remove_blank_chn(s):
function add_blank_boundary (line 99) | def add_blank_boundary(s):
function split_eng_words (line 114) | def split_eng_words(s):
function split_chn_words (line 125) | def split_chn_words(s):
function upper_or_lower (line 135) | def upper_or_lower(s, upper=True):
function process_one_line (line 141) | def process_one_line(content, args):
function get_parser (line 181) | def get_parser():
function main (line 204) | def main():
FILE: egs/espnet_utils/trace_rnnt.py
function main (line 23) | def main():
FILE: egs/espnet_utils/trim_silence.py
function _time_to_str (line 21) | def _time_to_str(time_idx):
function get_parser (line 26) | def get_parser():
function main (line 63) | def main():
FILE: egs/espnet_utils/trn2ctm.py
function get_parser (line 12) | def get_parser():
function main (line 19) | def main(args):
function convert (line 24) | def convert(trn=None, ctm=None):
FILE: egs/espnet_utils/trn2stm.py
function get_parser (line 11) | def get_parser():
function main (line 25) | def main(args):
function convert (line 30) | def convert(trn=None, stm=None, orig_stm=None):
FILE: egs/espnet_utils/word_ngram_rescore.py
function score_texts (line 11) | def score_texts(ngram, texts, ignore_strs=["<eos>", " "]):
function main (line 16) | def main():
FILE: egs/steps/cleanup/combine_short_segments.py
function GetArgs (line 16) | def GetArgs():
function RunKaldiCommand (line 34) | def RunKaldiCommand(command, wait = True):
function MakeDir (line 49) | def MakeDir(dir):
function CheckFiles (line 58) | def CheckFiles(input_data_dir):
function ParseFileToDict (line 64) | def ParseFileToDict(file, assert2fields = False, value_processor = None):
function WriteDictToFile (line 77) | def WriteDictToFile(dict, file_name):
function ParseDataDirInfo (line 92) | def ParseDataDirInfo(data_dir):
function GetCombinedUttIndexRange (line 108) | def GetCombinedUttIndexRange(utt_index, utts, utt_durs, minimum_duration):
function WriteCombinedDirFiles (line 175) | def WriteCombinedDirFiles(output_dir, utt2spk, spk2utt, text, feat, utt2...
function CombineSegments (line 242) | def CombineSegments(input_dir, output_dir, minimum_duration):
function Main (line 298) | def Main():
FILE: egs/steps/cleanup/internal/align_ctm_ref.py
function get_args (line 33) | def get_args():
function read_text (line 115) | def read_text(text_file):
function read_ctm (line 139) | def read_ctm(ctm_file, file_and_channel2reco=None):
function smith_waterman_alignment (line 183) | def smith_waterman_alignment(ref, hyp, similarity_score_function,
function print_alignment (line 355) | def print_alignment(recording, alignment, out_file_handle):
function get_edit_type (line 367) | def get_edit_type(hyp_word, ref_word, duration=-1, eps_symbol='<eps>',
function get_ctm_edits (line 388) | def get_ctm_edits(alignment_output, ctm_array, eps_symbol="<eps>",
function ctm_line_to_string (line 491) | def ctm_line_to_string(ctm_line):
function test_alignment (line 499) | def test_alignment(align_full_hyp):
function run (line 514) | def run(args):
function main (line 609) | def main():
FILE: egs/steps/cleanup/internal/compute_tf_idf.py
function _get_args (line 21) | def _get_args():
function _run (line 87) | def _run(args):
function main (line 130) | def main():
FILE: egs/steps/cleanup/internal/get_ctm_edits.py
function OpenFiles (line 116) | def OpenFiles():
function PadArrays (line 173) | def PadArrays(edits_array, ctm_array):
function GetEditType (line 223) | def GetEditType(hyp_word, ref_word, duration):
function FloatToString (line 245) | def FloatToString(f):
function OutputCtm (line 255) | def OutputCtm(utterance_id, edits_array, ctm_array):
function ProcessOneUtterance (line 275) | def ProcessOneUtterance(utterance_id, edits_line, ctm_lines):
function ProcessData (line 319) | def ProcessData():
FILE: egs/steps/cleanup/internal/get_non_scored_words.py
function read_lang (line 56) | def read_lang(lang_dir):
FILE: egs/steps/cleanup/internal/get_pron_stats.py
function GetArgs (line 27) | def GetArgs():
function CheckArgs (line 54) | def CheckArgs(args):
function ReadEntries (line 68) | def ReadEntries(file_handle):
function GetStatsFromCtmProns (line 122) | def GetStatsFromCtmProns(silphones, optional_silence, non_scored_words, ...
function WriteStats (line 214) | def WriteStats(stats, file_handle):
function Main (line 219) | def Main():
FILE: egs/steps/cleanup/internal/make_one_biased_lm.py
class NgramCounts (line 56) | class NgramCounts(object):
method __init__ (line 67) | def __init__(self, ngram_order):
method AddCount (line 96) | def AddCount(self, history, predicted_word, count):
method AddRawCountsFromLine (line 101) | def AddRawCountsFromLine(self, line):
method AddRawCountsFromStandardInput (line 114) | def AddRawCountsFromStandardInput(self):
method GetHistToTotalCount (line 131) | def GetHistToTotalCount(self):
method CompletelyDiscountLowCountStates (line 146) | def CompletelyDiscountLowCountStates(self, min_count):
method ApplyBackoff (line 165) | def ApplyBackoff(self, D):
method Print (line 185) | def Print(self, info_string):
method AddTopWords (line 205) | def AddTopWords(self, top_words_file):
method GetTotalCountMap (line 230) | def GetTotalCountMap(self):
method GetHistToStateMap (line 239) | def GetHistToStateMap(self):
method GetProb (line 250) | def GetProb(self, hist, word, total_count_map):
method PrintAsFst (line 261) | def PrintAsFst(self, word_disambig_symbol):
FILE: egs/steps/cleanup/internal/modify_ctm_edits.py
function ReadNonScoredWords (line 105) | def ReadNonScoredWords(non_scored_words_file):
function ProcessLineForNonScoredWords (line 134) | def ProcessLineForNonScoredWords(a):
function ProcessUtteranceForRepetitions (line 185) | def ProcessUtteranceForRepetitions(split_lines_of_utt):
function ProcessUtterance (line 307) | def ProcessUtterance(split_lines_of_utt):
function ProcessData (line 318) | def ProcessData():
function PrintNonScoredStats (line 367) | def PrintNonScoredStats():
function PrintRepetitionStats (line 400) | def PrintRepetitionStats():
FILE: egs/steps/cleanup/internal/resolve_ctm_edits_overlaps.py
function get_args (line 34) | def get_args():
function read_segments (line 56) | def read_segments(segments_file):
function read_ctm_edits (line 79) | def read_ctm_edits(ctm_edits_file, segments):
function wer (line 127) | def wer(ctm_edit_lines):
function choose_best_ctm_lines (line 142) | def choose_best_ctm_lines(first_lines, second_lines,
function resolve_overlaps (line 153) | def resolve_overlaps(ctm_edits, segments):
function ctm_edit_line_to_string (line 285) | def ctm_edit_line_to_string(line):
function write_ctm_edits (line 292) | def write_ctm_edits(ctm_edit_lines, out_file):
function run (line 298) | def run(args):
function main (line 325) | def main():
FILE: egs/steps/cleanup/internal/retrieve_similar_docs.py
function get_args (line 99) | def get_args():
function read_map (line 164) | def read_map(file_handle, num_values_per_key=None,
function get_document_ids (line 224) | def get_document_ids(source_docs, indexes):
function run (line 238) | def run(args):
function main (line 344) | def main():
FILE: egs/steps/cleanup/internal/segment_ctm_edits.py
function IsTainted (line 119) | def IsTainted(split_line_of_utt):
function ComputeSegmentCores (line 136) | def ComputeSegmentCores(split_lines_of_utt):
class Segment (line 176) | class Segment(object):
method __init__ (line 177) | def __init__(self, split_lines_of_utt, start_index, end_index, debug_s...
method PossiblyAddTaintedLines (line 213) | def PossiblyAddTaintedLines(self):
method PossiblySplitSegment (line 253) | def PossiblySplitSegment(self):
method PossiblyTruncateBoundaries (line 300) | def PossiblyTruncateBoundaries(self):
method RelaxBoundaryTruncation (line 330) | def RelaxBoundaryTruncation(self):
method PossiblyAddUnkPadding (line 381) | def PossiblyAddUnkPadding(self):
method MergeWithSegment (line 432) | def MergeWithSegment(self, other):
method StartTime (line 465) | def StartTime(self):
method DebugInfo (line 474) | def DebugInfo(self):
method EndTime (line 481) | def EndTime(self):
method Length (line 489) | def Length(self):
method IsWholeUtterance (line 492) | def IsWholeUtterance(self):
method JunkProportion (line 503) | def JunkProportion(self):
method PossiblyTruncateStartForJunkProportion (line 526) | def PossiblyTruncateStartForJunkProportion(self):
method PossiblyTruncateEndForJunkProportion (line 567) | def PossiblyTruncateEndForJunkProportion(self):
method ContainsAtLeastOneScoredNonOovWord (line 611) | def ContainsAtLeastOneScoredNonOovWord(self):
method Text (line 624) | def Text(self):
function AccumulateSegmentStats (line 643) | def AccumulateSegmentStats(segment_list, text):
function PrintSegmentStats (line 649) | def PrintSegmentStats():
function GetSegmentsForUtterance (line 680) | def GetSegmentsForUtterance(split_lines_of_utt):
function FloatToString (line 789) | def FloatToString(f):
function TimeToString (line 800) | def TimeToString(time, frame_length):
function WriteSegmentsForUtterance (line 810) | def WriteSegmentsForUtterance(text_output_handle, segments_output_handle,
function PrintDebugInfoForUtterance (line 836) | def PrintDebugInfoForUtterance(ctm_edits_out_handle,
function AccWordStatsForUtterance (line 880) | def AccWordStatsForUtterance(split_lines_of_utt,
function PrintWordStats (line 896) | def PrintWordStats(word_stats_out):
function ProcessData (line 926) | def ProcessData():
function ReadNonScoredWords (line 995) | def ReadNonScoredWords(non_scored_words_file):
FILE: egs/steps/cleanup/internal/segment_ctm_edits_mild.py
function non_scored_words (line 44) | def non_scored_words():
function get_args (line 48) | def get_args():
function is_tainted (line 245) | def is_tainted(split_line_of_utt):
function compute_segment_cores (line 250) | def compute_segment_cores(split_lines_of_utt):
class SegmentStats (line 315) | class SegmentStats(object):
method __init__ (line 318) | def __init__(self):
method wer (line 327) | def wer(self):
method bad_proportion (line 334) | def bad_proportion(self):
method incorrect_proportion (line 342) | def incorrect_proportion(self):
method combine (line 349) | def combine(self, other, scale=1):
method assert_equal (line 360) | def assert_equal(self, other):
method compare (line 375) | def compare(self, other):
method __str__ (line 393) | def __str__(self):
class Segment (line 410) | class Segment(object):
method __init__ (line 413) | def __init__(self, split_lines_of_utt, start_index, end_index,
method copy (line 458) | def copy(self, copy_stats=True):
method __str__ (line 469) | def __str__(self):
method compute_stats (line 472) | def compute_stats(self):
method possibly_add_tainted_lines (line 532) | def possibly_add_tainted_lines(self):
method possibly_split_segment (line 583) | def possibly_split_segment(self, max_internal_silence_length,
method possibly_split_long_segment (line 646) | def possibly_split_long_segment(self, max_segment_length,
method possibly_truncate_boundaries (line 789) | def possibly_truncate_boundaries(self, max_edge_silence_length,
method relax_boundary_truncation (line 826) | def relax_boundary_truncation(self, min_segment_length,
method possibly_add_unk_padding (line 890) | def possibly_add_unk_padding(self, max_unk_padding):
method start_time (line 940) | def start_time(self):
method debug_info (line 955) | def debug_info(self, include_stats=True):
method end_time (line 978) | def end_time(self):
method length (line 991) | def length(self):
method is_whole_utterance (line 995) | def is_whole_utterance(self):
method get_junk_proportion (line 1005) | def get_junk_proportion(self):
method get_junk_duration (line 1024) | def get_junk_duration(self):
method merge_adjacent_segment (line 1028) | def merge_adjacent_segment(self, other):
method merge_with_segment (line 1073) | def merge_with_segment(self, other, max_deleted_words):
method contains_atleast_one_scored_non_oov_word (line 1163) | def contains_atleast_one_scored_non_oov_word(self):
method text (line 1179) | def text(self, oov_symbol, eps_symbol="<eps_symbol>"):
class SegmentsMerger (line 1195) | class SegmentsMerger(object):
method __init__ (line 1209) | def __init__(self, segments):
method _get_merged_cluster (line 1240) | def _get_merged_cluster(self, cluster1, cluster2, rejected_clusters=None,
method merge_clusters (line 1306) | def merge_clusters(self, scoring_function,
function merge_segments (line 1402) | def merge_segments(segments, args):
function get_segments_for_utterance (line 1502) | def get_segments_for_utterance(split_lines_of_utt, args, utterance_stats):
function float_to_string (line 1741) | def float_to_string(f):
function time_to_string (line 1754) | def time_to_string(time, frame_length):
function write_segments_for_utterance (line 1768) | def write_segments_for_utterance(text_output_handle, segments_output_han...
function print_debug_info_for_utterance (line 1794) | def print_debug_info_for_utterance(ctm_edits_out_handle,
class WordStats (line 1840) | class WordStats(object):
method __init__ (line 1848) | def __init__(self):
method accumulate_for_utterance (line 1851) | def accumulate_for_utterance(self, split_lines_of_utt,
method print (line 1867) | def print(self, word_stats_out):
function process_data (line 1896) | def process_data(args, oov_symbol, utterance_stats, word_stats):
function read_non_scored_words (line 1951) | def read_non_scored_words(non_scored_words_file):
class UtteranceStats (line 1962) | class UtteranceStats(object):
method __init__ (line 1964) | def __init__(self):
method accumulate_segment_stats (line 1975) | def accumulate_segment_stats(self, segment_list, text):
method print_segment_stats (line 1984) | def print_segment_stats(self):
function main (line 2010) | def main():
FILE: egs/steps/cleanup/internal/stitch_documents.py
function get_args (line 46) | def get_args():
function run (line 82) | def run(args):
function main (line 140) | def main():
FILE: egs/steps/cleanup/internal/taint_ctm_edits.py
function ProcessUtterance (line 81) | def ProcessUtterance(split_lines_of_utt, remove_deletions=True):
function ProcessData (line 141) | def ProcessData():
function PrintNonScoredStats (line 191) | def PrintNonScoredStats():
function PrintStats (line 220) | def PrintStats():
FILE: egs/steps/cleanup/internal/tf_idf.py
class IDFStats (line 21) | class IDFStats(object):
method __init__ (line 24) | def __init__(self):
method get_inverse_document_frequency (line 28) | def get_inverse_document_frequency(self, term, weighting_scheme="log"):
method accumulate (line 57) | def accumulate(self, term):
method write (line 64) | def write(self, file_handle):
method read (line 76) | def read(self, file_handle):
class TFStats (line 89) | class TFStats(object):
method __init__ (line 93) | def __init__(self):
method get_term_frequency (line 97) | def get_term_frequency(self, term, doc, weighting_scheme="raw",
method accumulate (line 124) | def accumulate(self, doc, text, ngram_order):
method compute_term_stats (line 133) | def compute_term_stats(self, idf_stats=None):
method __str__ (line 147) | def __str__(self):
method read (line 159) | def read(self, file_handle, ngram_order=None, idf_stats=None):
class TFIDF (line 187) | class TFIDF(object):
method __init__ (line 195) | def __init__(self):
method get_value (line 198) | def get_value(self, term, doc):
method compute_similarity_scores (line 204) | def compute_similarity_scores(self, source_tfidf, source_docs=None,
method read (line 273) | def read(self, tf_idf_file):
method write (line 327) | def write(self, tf_idf_file):
function write_tfidf_from_stats (line 340) | def write_tfidf_from_stats(
function read_key (line 392) | def read_key(fd):
function read_tfidf_ark (line 407) | def read_tfidf_ark(file_handle):
FILE: egs/steps/cleanup/make_biased_lms.py
function ProcessGroupOfLines (line 47) | def ProcessGroupOfLines(group_of_lines):
FILE: egs/steps/data/augment_data_dir.py
function get_args (line 21) | def get_args():
function check_args (line 71) | def check_args(args):
function get_noise_list (line 93) | def get_noise_list(noise_wav_scp_filename):
function augment_wav (line 104) | def augment_wav(utt, wav, dur, fg_snr_opts, bg_snr_opts, fg_noise_utts, \
function get_new_id (line 153) | def get_new_id(utt, utt_modifier_type, utt_modifier):
function copy_file_if_exists (line 167) | def copy_file_if_exists(input_file, output_file, utt_modifier_type,
function create_augmented_utt2uniq (line 186) | def create_augmented_utt2uniq(input_dir, output_dir,
function main (line 197) | def main():
FILE: egs/steps/data/data_dir_manipulation_lib.py
function RunKaldiCommand (line 3) | def RunKaldiCommand(command, wait = True):
FILE: egs/steps/data/make_musan.py
function get_args (line 13) | def get_args():
function check_args (line 34) | def check_args(args):
function process_music_annotations (line 43) | def process_music_annotations(path):
function prepare_music (line 54) | def prepare_music(root_dir, use_vocals, sampling_rate):
function prepare_speech (line 92) | def prepare_speech(root_dir, sampling_rate):
function prepare_noise (line 125) | def prepare_noise(root_dir, sampling_rate):
function main (line 158) | def main():
FILE: egs/steps/data/reverberate_data_dir.py
function get_args (line 12) | def get_args():
function check_args (line 86) | def check_args(args):
class list_cyclic_iterator (line 119) | class list_cyclic_iterator(object):
method __init__ (line 120) | def __init__(self, list):
method __next__ (line 125) | def __next__(self):
function pick_item_with_probability (line 132) | def pick_item_with_probability(x):
function parse_file_to_dict (line 155) | def parse_file_to_dict(file, assert2fields = False, value_processor = No...
function write_dict_to_file (line 170) | def write_dict_to_file(dict, file_name):
function create_corrupted_utt2uniq (line 186) | def create_corrupted_utt2uniq(input_dir, output_dir, num_replicas, inclu...
function add_point_source_noise (line 206) | def add_point_source_noise(noise_addition_descriptor, # descriptor to s...
function generate_reverberation_opts (line 240) | def generate_reverberation_opts(room_dict, # the room dictionary, pleas...
function get_new_id (line 302) | def get_new_id(id, prefix=None, copy=0):
function generate_reverberated_wav_scp (line 315) | def generate_reverberated_wav_scp(wav_scp, # a dictionary whose values ...
function add_prefix_to_fields (line 379) | def add_prefix_to_fields(input_file, output_file, num_replicas, include_...
function create_reverberated_copy (line 401) | def create_reverberated_copy(input_dir,
function smooth_probability_distribution (line 458) | def smooth_probability_distribution(set_list, smoothing_weight=0.0, targ...
function parse_set_parameter_strings (line 493) | def parse_set_parameter_strings(set_para_array):
function parse_rir_list (line 516) | def parse_rir_list(rir_set_para_array, smoothing_weight, sampling_rate =...
function almost_equal (line 552) | def almost_equal(value_1, value_2, accuracy = 10**-8):
function make_room_dict (line 558) | def make_room_dict(rir_list):
function parse_noise_list (line 581) | def parse_noise_list(noise_set_para_array, smoothing_weight, sampling_ra...
function main (line 641) | def main():
FILE: egs/steps/diagnostic/analyze_lattice_depth_stats.py
function GetPercentile (line 116) | def GetPercentile(depth_to_count, fraction):
function GetMean (line 131) | def GetMean(depth_to_count):
FILE: egs/steps/diagnostic/analyze_phone_length_stats.py
function GetPercentile (line 145) | def GetPercentile(length_to_count, fraction):
function GetMean (line 160) | def GetMean(length_to_count):
FILE: egs/steps/dict/apply_lexicon_edits.py
function GetArgs (line 10) | def GetArgs():
function CheckArgs (line 37) | def CheckArgs(args):
function ReadLexicon (line 51) | def ReadLexicon(lexicon_file_handle):
function ApplyLexiconEdits (line 66) | def ApplyLexiconEdits(lexicon, lexicon_edits_file_handle):
function WriteLexicon (line 98) | def WriteLexicon(lexicon, out_lexicon_handle):
function Main (line 103) | def Main():
FILE: egs/steps/dict/get_pron_stats.py
function GetArgs (line 12) | def GetArgs():
function CheckArgs (line 41) | def CheckArgs(args):
function GetStatsFromArcInfo (line 56) | def GetStatsFromArcInfo(arc_info_file_handle, phone_map_handle):
function WriteStats (line 85) | def WriteStats(stats, file_handle):
function Main (line 91) | def Main():
FILE: egs/steps/dict/internal/get_subsegments.py
function GetArgs (line 12) | def GetArgs():
function CheckArgs (line 43) | def CheckArgs(args):
function GetSubsegments (line 60) | def GetSubsegments(args, vocab):
function ReadVocab (line 119) | def ReadVocab(vocab_file_handle):
function Main (line 133) | def Main():
FILE: egs/steps/dict/internal/prune_pron_candidates.py
function GetArgs (line 12) | def GetArgs():
function CheckArgs (line 51) | def CheckArgs(args):
function ReadStats (line 61) | def ReadStats(pron_stats_handle):
function ReadLexicon (line 77) | def ReadLexicon(lexicon_handle):
function ReadLexiconp (line 91) | def ReadLexiconp(lexiconp_handle):
function PruneProns (line 108) | def PruneProns(args, stats, ref_lexicon, lexicon_phonetic_decoding, lexi...
function Main (line 148) | def Main():
FILE: egs/steps/dict/internal/sum_arc_info.py
class StrToBoolAction (line 11) | class StrToBoolAction(argparse.Action):
method __call__ (line 14) | def __call__(self, parser, namespace, values, option_string=None):
function GetArgs (line 23) | def GetArgs():
function CheckArgs (line 52) | def CheckArgs(args):
function Main (line 67) | def Main():
FILE: egs/steps/dict/merge_learned_lexicons.py
function GetArgs (line 12) | def GetArgs():
function CheckArgs (line 62) | def CheckArgs(args):
function ReadArcStats (line 77) | def ReadArcStats(arc_stats_file_handle):
function ReadWordCounts (line 99) | def ReadWordCounts(word_counts_file_handle):
function ReadLexicon (line 111) | def ReadLexicon(args, lexicon_file_handle, counts):
function WriteEditsAndSummary (line 129) | def WriteEditsAndSummary(args, learned_lexicon, ref_lexicon, pd_lexicon,...
function WriteLearnedLexiconOov (line 237) | def WriteLearnedLexiconOov(learned_lexicon, ref_lexicon, file_handle):
function Main (line 244) | def Main():
FILE: egs/steps/dict/prons_to_lexicon.py
class StrToBoolAction (line 13) | class StrToBoolAction(argparse.Action):
method __call__ (line 16) | def __call__(self, parser, namespace, values, option_string=None):
function GetArgs (line 24) | def GetArgs():
function CheckArgs (line 64) | def CheckArgs(args):
function ReadStats (line 87) | def ReadStats(args):
function ReadLexicon (line 104) | def ReadLexicon(lexicon_file_handle):
function ConvertWordCountsToProbs (line 119) | def ConvertWordCountsToProbs(args, lexicon, word_count):
function ConvertWordProbsToLexicon (line 132) | def ConvertWordProbsToLexicon(word_probs):
function NormalizeLexicon (line 139) | def NormalizeLexicon(lexicon, set_max_to_one = True,
function TakeTopN (line 155) | def TakeTopN(lexicon, top_N):
function WriteLexicon (line 167) | def WriteLexicon(args, lexicon, filter_lexicon):
function Main (line 188) | def Main():
FILE: egs/steps/dict/prune_pron_candidates.py
function GetArgs (line 13) | def GetArgs():
function CheckArgs (line 42) | def CheckArgs(args):
function ReadStats (line 51) | def ReadStats(pron_stats_handle):
function ReadLexicon (line 69) | def ReadLexicon(ref_lexicon_handle):
function PruneProns (line 86) | def PruneProns(args, stats, ref_lexicon):
function Main (line 114) | def Main():
FILE: egs/steps/dict/select_prons_bayesian.py
function GetArgs (line 13) | def GetArgs():
function CheckArgs (line 91) | def CheckArgs(args):
function ReadPronStats (line 115) | def ReadPronStats(pron_stats_file_handle):
function ReadWordCounts (line 130) | def ReadWordCounts(word_counts_file_handle):
function ReadLexicon (line 142) | def ReadLexicon(args, lexicon_file_handle, counts):
function FilterPhoneticDecodingLexicon (line 160) | def FilterPhoneticDecodingLexicon(args, phonetic_decoding_lexicon, stats):
function ComputePriorCounts (line 184) | def ComputePriorCounts(args, counts, ref_lexicon, g2p_lexicon, phonetic_...
function ComputePosteriors (line 204) | def ComputePosteriors(args, stats, ref_lexicon, g2p_lexicon, phonetic_de...
function SelectPronsBayesian (line 258) | def SelectPronsBayesian(args, counts, posteriors, ref_lexicon, g2p_lexic...
function WriteEditsAndSummary (line 316) | def WriteEditsAndSummary(args, learned_lexicon, ref_lexicon, phonetic_de...
function WriteLearnedLexiconOov (line 414) | def WriteLearnedLexiconOov(learned_lexicon, ref_lexicon, file_handle):
function Main (line 421) | def Main():
FILE: egs/steps/dict/select_prons_greedy.py
function GetArgs (line 12) | def GetArgs():
function CheckArgs (line 69) | def CheckArgs(args):
function ReadArcStats (line 118) | def ReadArcStats(arc_stats_file_handle):
function ReadWordCounts (line 140) | def ReadWordCounts(word_counts_file_handle):
function ReadLexicon (line 152) | def ReadLexicon(args, lexicon_file_handle, counts):
function FilterPhoneticDecodingLexicon (line 170) | def FilterPhoneticDecodingLexicon(args, pd_lexicon):
function OneEMIter (line 187) | def OneEMIter(args, word, stats, prons, pron_probs, debug=False):
function SelectPronsGreedy (line 212) | def SelectPronsGreedy(args, stats, counts, ref_lexicon, g2p_lexicon, pd_...
function WriteLearnedLexicon (line 352) | def WriteLearnedLexicon(learned_lexicon, file_handle):
function Main (line 358) | def Main():
FILE: egs/steps/libs/common.py
function send_mail (line 31) | def send_mail(message, subject, email_id):
function str_to_bool (line 44) | def str_to_bool(value):
class StrToBoolAction (line 53) | class StrToBoolAction(argparse.Action):
method __call__ (line 57) | def __call__(self, parser, namespace, values, option_string=None):
class NullstrToNoneAction (line 65) | class NullstrToNoneAction(argparse.Action):
method __call__ (line 70) | def __call__(self, parser, namespace, values, option_string=None):
class smart_open (line 77) | class smart_open(object):
method __init__ (line 87) | def __init__(self, filename, mode="r"):
method __enter__ (line 92) | def __enter__(self):
method __exit__ (line 101) | def __exit__(self, *args):
function check_if_cuda_compiled (line 106) | def check_if_cuda_compiled():
function execute_command (line 115) | def execute_command(command):
function get_command_stdout (line 132) | def get_command_stdout(command, require_zero_status = True):
function wait_for_background_commands (line 159) | def wait_for_background_commands():
function background_command (line 168) | def background_command(command, require_zero_status = False):
function background_command_waiter (line 198) | def background_command_waiter(command, popen_object, require_zero_status):
function get_number_of_leaves_from_tree (line 215) | def get_number_of_leaves_from_tree(alidir):
function get_number_of_leaves_from_model (line 226) | def get_number_of_leaves_from_model(dir):
function get_number_of_jobs (line 238) | def get_number_of_jobs(alidir):
function get_ivector_dim (line 248) | def get_ivector_dim(ivector_dir=None):
function get_ivector_extractor_id (line 257) | def get_ivector_extractor_id(ivector_dir=None):
function get_feat_dim (line 268) | def get_feat_dim(feat_dir):
function get_feat_dim_from_scp (line 278) | def get_feat_dim_from_scp(feat_scp):
function read_kaldi_matrix (line 286) | def read_kaldi_matrix(matrix_file):
function write_kaldi_matrix (line 308) | def write_kaldi_matrix(output_file, matrix):
function write_matrix_ascii (line 329) | def write_matrix_ascii(file_or_fd, mat, key=None):
function read_matrix_ascii (line 365) | def read_matrix_ascii(file_or_fd):
function read_key (line 403) | def read_key(fd):
function read_mat_ark (line 421) | def read_mat_ark(file_or_fd):
function force_symlink (line 448) | def force_symlink(file1, file2):
function compute_lifter_coeffs (line 458) | def compute_lifter_coeffs(lifter, dim):
function compute_idct_matrix (line 466) | def compute_idct_matrix(K, N, cepstral_lifter=0):
function write_idct_matrix (line 488) | def write_idct_matrix(feat_dim, cepstral_lifter, file_path):
FILE: egs/steps/libs/nnet3/report/log_parse.py
class KaldiLogParseException (line 48) | class KaldiLogParseException(Exception):
method __init__ (line 52) | def __init__(self, message = None):
function fill_nonlin_stats_table_with_regex_result (line 63) | def fill_nonlin_stats_table_with_regex_result(groups, gate_index, stats_...
function parse_progress_logs_for_nonlinearity_stats (line 148) | def parse_progress_logs_for_nonlinearity_stats(exp_dir):
function parse_difference_string (line 201) | def parse_difference_string(string):
class MalformedClippedProportionLineException (line 209) | class MalformedClippedProportionLineException(Exception):
method __init__ (line 210) | def __init__(self, line):
function parse_progress_logs_for_clipped_proportion (line 216) | def parse_progress_logs_for_clipped_proportion(exp_dir):
function parse_progress_logs_for_param_diff (line 292) | def parse_progress_logs_for_param_diff(exp_dir, pattern):
function get_train_times (line 366) | def get_train_times(exp_dir):
function parse_prob_logs (line 390) | def parse_prob_logs(exp_dir, key='accuracy', output="output"):
function parse_rnnlm_prob_logs (line 447) | def parse_rnnlm_prob_logs(exp_dir, key='objf'):
function generate_acc_logprob_report (line 512) | def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"):
FILE: egs/steps/libs/nnet3/train/chain_objf/acoustic_model.py
function create_phone_lm (line 25) | def create_phone_lm(dir, tree_dir, run_opts, lm_opts=None):
function create_denominator_fst (line 53) | def create_denominator_fst(dir, tree_dir, run_opts):
function generate_chain_egs (line 65) | def generate_chain_egs(dir, data, lat_dir, egs_dir,
function train_new_models (line 121) | def train_new_models(dir, iter, srand, num_jobs,
function train_one_iteration (line 238) | def train_one_iteration(dir, iter, srand, egs_dir,
function check_for_required_files (line 369) | def check_for_required_files(feat_dir, tree_dir, lat_dir=None):
function compute_preconditioning_matrix (line 381) | def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts,
function prepare_initial_acoustic_model (line 446) | def prepare_initial_acoustic_model(dir, run_opts, srand=-1, input_model=...
function compute_train_cv_probabilities (line 472) | def compute_train_cv_probabilities(dir, iter, egs_dir, l2_regularize,
function compute_progress (line 520) | def compute_progress(dir, iter, run_opts):
function combine_models (line 558) | def combine_models(dir, num_iters, models_to_combine, num_chunk_per_mini...
FILE: egs/steps/libs/nnet3/train/common.py
class RunOpts (line 27) | class RunOpts(object):
method __init__ (line 35) | def __init__(self):
function get_outputs_list (line 44) | def get_outputs_list(model_file, get_raw_nnet_from_am=True):
function get_multitask_egs_opts (line 60) | def get_multitask_egs_opts(egs_dir, egs_prefix="",
function get_successful_models (line 107) | def get_successful_models(num_models, log_file_pattern,
function get_average_nnet_model (line 143) | def get_average_nnet_model(dir, iter, nnets_list, run_opts,
function get_best_nnet_model (line 166) | def get_best_nnet_model(dir, iter, best_model_index, run_opts,
function validate_chunk_width (line 191) | def validate_chunk_width(chunk_width):
function principal_chunk_width (line 209) | def principal_chunk_width(chunk_width):
function validate_range_str (line 217) | def validate_range_str(range_str):
function validate_minibatch_size_str (line 244) | def validate_minibatch_size_str(minibatch_size_str):
function halve_range_str (line 282) | def halve_range_str(range_str):
function halve_minibatch_size_str (line 299) | def halve_minibatch_size_str(minibatch_size_str):
function copy_egs_properties_to_exp_dir (line 321) | def copy_egs_properties_to_exp_dir(egs_dir, dir):
function parse_generic_config_vars_file (line 334) | def parse_generic_config_vars_file(var_file):
function get_input_model_info (line 364) | def get_input_model_info(input_model):
function verify_egs_dir (line 394) | def verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_extractor_id,
function compute_presoftmax_prior_scale (line 495) | def compute_presoftmax_prior_scale(dir, alidir, num_jobs, run_opts,
function smooth_presoftmax_prior_scale_vector (line 528) | def smooth_presoftmax_prior_scale_vector(pdf_counts,
function prepare_initial_network (line 542) | def prepare_initial_network(dir, run_opts, srand=-3, input_model=None):
function get_model_combine_iters (line 561) | def get_model_combine_iters(num_iters, num_epochs,
function get_current_num_jobs (line 605) | def get_current_num_jobs(it, num_it, start, step, end):
function get_learning_rate (line 615) | def get_learning_rate(iter, num_jobs, num_iters, num_archives_processed,
function should_do_shrinkage (line 630) | def should_do_shrinkage(iter, model_file, shrink_saturation_threshold,
function remove_nnet_egs (line 657) | def remove_nnet_egs(egs_dir):
function clean_nnet_dir (line 662) | def clean_nnet_dir(nnet_dir, num_iters, egs_dir,
function remove_model (line 679) | def remove_model(nnet_dir, iter, num_iters, models_to_combine=None,
function positive_int (line 695) | def positive_int(arg):
class CommonParser (line 702) | class CommonParser(object):
method __init__ (line 713) | def __init__(self,
class SelfTest (line 1000) | class SelfTest(unittest.TestCase):
method test_halve_minibatch_size_str (line 1002) | def test_halve_minibatch_size_str(self):
method test_validate_chunk_width (line 1009) | def test_validate_chunk_width(self):
method test_validate_minibatch_size_str (line 1014) | def test_validate_minibatch_size_str(self):
method test_get_current_num_jobs (line 1026) | def test_get_current_num_jobs(self):
FILE: egs/steps/libs/nnet3/train/dropout_schedule.py
function _parse_dropout_option (line 18) | def _parse_dropout_option(dropout_option):
function _parse_dropout_string (line 68) | def _parse_dropout_string(dropout_str):
function _get_component_dropout (line 131) | def _get_component_dropout(dropout_schedule, data_fraction):
function _get_dropout_proportions (line 187) | def _get_dropout_proportions(dropout_schedule, data_fraction):
function get_dropout_edit_option (line 226) | def get_dropout_edit_option(dropout_schedule, data_fraction, iter_):
function get_dropout_edit_string (line 269) | def get_dropout_edit_string(dropout_schedule, data_fraction, iter_):
function _self_test (line 308) | def _self_test():
FILE: egs/steps/libs/nnet3/train/frame_level_objf/acoustic_model.py
function generate_egs (line 21) | def generate_egs(data, alidir, egs_dir,
function prepare_initial_acoustic_model (line 64) | def prepare_initial_acoustic_model(dir, alidir, run_opts,
FILE: egs/steps/libs/nnet3/train/frame_level_objf/common.py
function train_new_models (line 28) | def train_new_models(dir, iter, srand, num_jobs,
function train_one_iteration (line 174) | def train_one_iteration(dir, iter, srand, egs_dir,
function compute_preconditioning_matrix (line 320) | def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts,
function compute_train_cv_probabilities (line 380) | def compute_train_cv_probabilities(dir, iter, egs_dir, run_opts,
function compute_progress (line 438) | def compute_progress(dir, iter, egs_dir,
function combine_models (line 476) | def combine_models(dir, num_iters, models_to_combine, egs_dir,
function get_realign_iters (line 562) | def get_realign_iters(realign_times, num_iters,
function align (line 589) | def align(dir, data, lang, run_opts, iter=None,
function realign (line 620) | def realign(dir, iter, feat_dir, lang, prev_egs_dir, cur_egs_dir,
function adjust_am_priors (line 653) | def adjust_am_priors(dir, input_model, avg_posterior_vector, output_model,
function compute_average_posterior (line 665) | def compute_average_posterior(dir, iter, egs_dir, num_archives,
FILE: egs/steps/libs/nnet3/train/frame_level_objf/raw_model.py
function generate_egs_using_targets (line 20) | def generate_egs_using_targets(data, targets_scp, egs_dir,
FILE: egs/steps/libs/nnet3/xconfig/attention.py
class XconfigAttentionLayer (line 27) | class XconfigAttentionLayer(XconfigLayerBase):
method __init__ (line 28) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 37) | def set_default_configs(self):
method check_configs (line 62) | def check_configs(self):
method output_name (line 80) | def output_name(self, auxiliary_output=None):
method attention_input_dim (line 91) | def attention_input_dim(self):
method attention_output_dim (line 100) | def attention_output_dim(self):
method output_dim (line 109) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 112) | def get_full_config(self):
method _generate_config (line 124) | def _generate_config(self):
method _add_components (line 139) | def _add_components(self, input_desc, input_dim, nonlinearities):
FILE: egs/steps/libs/nnet3/xconfig/basic_layers.py
class XconfigLayerBase (line 20) | class XconfigLayerBase(object):
method __init__ (line 24) | def __init__(self, first_token, key_to_value, all_layers):
method set_configs (line 79) | def set_configs(self, key_to_value, all_layers):
method str (line 145) | def str(self):
method __str__ (line 168) | def __str__(self):
method normalize_descriptors (line 171) | def normalize_descriptors(self):
method convert_to_descriptor (line 183) | def convert_to_descriptor(self, descriptor_string, all_layers):
method get_dim_for_descriptor (line 203) | def get_dim_for_descriptor(self, descriptor, all_layers):
method get_string_for_descriptor (line 213) | def get_string_for_descriptor(self, descriptor, all_layers):
method get_name (line 224) | def get_name(self):
method set_default_configs (line 233) | def set_default_configs(self):
method set_derived_configs (line 239) | def set_derived_configs(self):
method check_configs (line 246) | def check_configs(self):
method get_input_descriptor_names (line 252) | def get_input_descriptor_names(self):
method auxiliary_outputs (line 268) | def auxiliary_outputs(self):
method output_name (line 278) | def output_name(self, auxiliary_output=None):
method output_dim (line 293) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 300) | def get_full_config(self):
class XconfigInputLayer (line 314) | class XconfigInputLayer(XconfigLayerBase):
method __init__ (line 321) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 326) | def set_default_configs(self):
method check_configs (line 330) | def check_configs(self):
method get_input_descriptor_names (line 336) | def get_input_descriptor_names(self):
method output_name (line 340) | def output_name(self, auxiliary_outputs=None):
method output_dim (line 346) | def output_dim(self, auxiliary_outputs=None):
method get_full_config (line 352) | def get_full_config(self):
class XconfigTrivialOutputLayer (line 364) | class XconfigTrivialOutputLayer(XconfigLayerBase):
method __init__ (line 381) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 386) | def set_default_configs(self):
method check_configs (line 394) | def check_configs(self):
method output_name (line 402) | def output_name(self, auxiliary_outputs=None):
method output_dim (line 409) | def output_dim(self, auxiliary_outputs=None):
method get_full_config (line 415) | def get_full_config(self):
class XconfigOutputLayer (line 446) | class XconfigOutputLayer(XconfigLayerBase):
method __init__ (line 485) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 490) | def set_default_configs(self):
method check_configs (line 522) | def check_configs(self):
method auxiliary_outputs (line 539) | def auxiliary_outputs(self):
method output_name (line 547) | def output_name(self, auxiliary_output=None):
method output_dim (line 562) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 572) | def get_full_config(self):
method _generate_config (line 584) | def _generate_config(self):
class XconfigBasicLayer (line 672) | class XconfigBasicLayer(XconfigLayerBase):
method __init__ (line 706) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 709) | def set_default_configs(self):
method check_configs (line 739) | def check_configs(self):
method output_name (line 757) | def output_name(self, auxiliary_output=None):
method output_dim (line 768) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 776) | def get_full_config(self):
method _generate_config (line 787) | def _generate_config(self):
method _add_components (line 802) | def _add_components(self, input_desc, input_dim, nonlinearities):
class XconfigFixedAffineLayer (line 937) | class XconfigFixedAffineLayer(XconfigLayerBase):
method __init__ (line 955) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 959) | def set_default_configs(self):
method check_configs (line 968) | def check_configs(self):
method output_name (line 972) | def output_name(self, auxiliary_output=None):
method output_dim (line 978) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 985) | def get_full_config(self):
class XconfigAffineLayer (line 1028) | class XconfigAffineLayer(XconfigLayerBase):
method __init__ (line 1049) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 1053) | def set_default_configs(self):
method set_derived_configs (line 1070) | def set_derived_configs(self):
method check_configs (line 1075) | def check_configs(self):
method output_name (line 1079) | def output_name(self, auxiliary_output=None):
method output_dim (line 1085) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 1093) | def get_full_config(self):
class XconfigIdctLayer (line 1128) | class XconfigIdctLayer(XconfigLayerBase):
method __init__ (line 1150) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 1154) | def set_default_configs(self):
method check_configs (line 1163) | def check_configs(self):
method output_name (line 1167) | def output_name(self, auxiliary_output=None):
method output_dim (line 1173) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 1180) | def get_full_config(self):
method _generate_config (line 1193) | def _generate_config(self):
class XconfigExistingLayer (line 1224) | class XconfigExistingLayer(XconfigLayerBase):
method __init__ (line 1243) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 1249) | def set_default_configs(self):
method check_configs (line 1252) | def check_configs(self):
method get_input_descriptor_names (line 1257) | def get_input_descriptor_names(self):
method output_name (line 1260) | def output_name(self, auxiliary_outputs=None):
method output_dim (line 1265) | def output_dim(self, auxiliary_outputs=None):
method get_full_config (line 1270) | def get_full_config(self):
class XconfigSpecAugmentLayer (line 1277) | class XconfigSpecAugmentLayer(XconfigLayerBase):
method __init__ (line 1297) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 1300) | def set_default_configs(self):
method check_configs (line 1308) | def check_configs(self):
method output_name (line 1314) | def output_name(self, auxiliary_output=None):
method output_dim (line 1318) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 1323) | def get_full_config(self):
method _generate_config (line 1336) | def _generate_config(self):
function test_layers (line 1362) | def test_layers():
FILE: egs/steps/libs/nnet3/xconfig/composite_layers.py
class XconfigTdnnfLayer (line 68) | class XconfigTdnnfLayer(XconfigLayerBase):
method __init__ (line 70) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 74) | def set_default_configs(self):
method set_derived_configs (line 86) | def set_derived_configs(self):
method check_configs (line 89) | def check_configs(self):
method output_name (line 113) | def output_name(self, auxiliary_output=None):
method output_dim (line 127) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 130) | def get_full_config(self):
method _generate_config (line 140) | def _generate_config(self):
class XconfigPrefinalLayer (line 241) | class XconfigPrefinalLayer(XconfigLayerBase):
method __init__ (line 243) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 247) | def set_default_configs(self):
method set_derived_configs (line 255) | def set_derived_configs(self):
method check_configs (line 258) | def check_configs(self):
method output_name (line 264) | def output_name(self, auxiliary_output=None):
method output_dim (line 268) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 271) | def get_full_config(self):
method _generate_config (line 281) | def _generate_config(self):
FILE: egs/steps/libs/nnet3/xconfig/convolution.py
class XconfigConvLayer (line 115) | class XconfigConvLayer(XconfigLayerBase):
method __init__ (line 116) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 122) | def set_default_configs(self):
method set_derived_configs (line 143) | def set_derived_configs(self):
method check_offsets_var (line 157) | def check_offsets_var(self, str):
method check_configs (line 169) | def check_configs(self):
method auxiliary_outputs (line 214) | def auxiliary_outputs(self):
method output_name (line 217) | def output_name(self, auxiliary_output = None):
method output_dim (line 229) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 233) | def get_full_config(self):
method _generate_cnn_config (line 245) | def _generate_cnn_config(self):
class XconfigResBlock (line 416) | class XconfigResBlock(XconfigLayerBase):
method __init__ (line 417) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 421) | def set_default_configs(self):
method set_derived_configs (line 443) | def set_derived_configs(self):
method check_configs (line 463) | def check_configs(self):
method auxiliary_outputs (line 471) | def auxiliary_outputs(self):
method output_name (line 474) | def output_name(self, auxiliary_output = None):
method output_dim (line 492) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 497) | def get_full_config(self):
method _generate_normal_resblock_config (line 535) | def _generate_normal_resblock_config(self):
method _generate_bottleneck_resblock_config (line 629) | def _generate_bottleneck_resblock_config(self):
class XconfigRes2Block (line 775) | class XconfigRes2Block(XconfigLayerBase):
method __init__ (line 776) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 780) | def set_default_configs(self):
method set_derived_configs (line 803) | def set_derived_configs(self):
method check_configs (line 826) | def check_configs(self):
method auxiliary_outputs (line 830) | def auxiliary_outputs(self):
method output_name (line 833) | def output_name(self, auxiliary_output = None):
method output_dim (line 837) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 841) | def get_full_config(self):
method _generate_normal_resblock_config (line 869) | def _generate_normal_resblock_config(self):
method _generate_bottleneck_resblock_config (line 1015) | def _generate_bottleneck_resblock_config(self):
class ChannelAverageLayer (line 1149) | class ChannelAverageLayer(XconfigLayerBase):
method __init__ (line 1150) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 1154) | def set_default_configs(self):
method set_derived_configs (line 1158) | def set_derived_configs(self):
method check_configs (line 1161) | def check_configs(self):
method auxiliary_outputs (line 1170) | def auxiliary_outputs(self):
method output_name (line 1173) | def output_name(self, auxiliary_output = None):
method output_dim (line 1177) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 1182) | def get_full_config(self):
method _generate_channel_average_config (line 1190) | def _generate_channel_average_config(self):
FILE: egs/steps/libs/nnet3/xconfig/gru.py
class XconfigGruLayer (line 36) | class XconfigGruLayer(XconfigLayerBase):
method __init__ (line 37) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 41) | def set_default_configs(self):
method set_derived_configs (line 53) | def set_derived_configs(self):
method check_configs (line 57) | def check_configs(self):
method output_name (line 69) | def output_name(self, auxiliary_output = None):
method output_dim (line 73) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 76) | def get_full_config(self):
method generate_gru_config (line 88) | def generate_gru_config(self):
class XconfigPgruLayer (line 197) | class XconfigPgruLayer(XconfigLayerBase):
method __init__ (line 198) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 202) | def set_default_configs(self):
method set_derived_configs (line 217) | def set_derived_configs(self):
method check_configs (line 225) | def check_configs(self):
method auxiliary_outputs (line 246) | def auxiliary_outputs(self):
method output_name (line 249) | def output_name(self, auxiliary_output = None):
method output_dim (line 259) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 270) | def get_full_config(self):
method generate_pgru_config (line 282) | def generate_pgru_config(self):
class XconfigNormPgruLayer (line 400) | class XconfigNormPgruLayer(XconfigLayerBase):
method __init__ (line 401) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 405) | def set_default_configs(self):
method set_derived_configs (line 422) | def set_derived_configs(self):
method check_configs (line 430) | def check_configs(self):
method auxiliary_outputs (line 456) | def auxiliary_outputs(self):
method output_name (line 459) | def output_name(self, auxiliary_output = None):
method output_dim (line 469) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 480) | def get_full_config(self):
method generate_pgru_config (line 492) | def generate_pgru_config(self):
class XconfigOpgruLayer (line 631) | class XconfigOpgruLayer(XconfigLayerBase):
method __init__ (line 632) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 636) | def set_default_configs(self):
method set_derived_configs (line 651) | def set_derived_configs(self):
method check_configs (line 659) | def check_configs(self):
method auxiliary_outputs (line 680) | def auxiliary_outputs(self):
method output_name (line 683) | def output_name(self, auxiliary_output = None):
method output_dim (line 693) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 704) | def get_full_config(self):
method generate_pgru_config (line 716) | def generate_pgru_config(self):
class XconfigNormOpgruLayer (line 834) | class XconfigNormOpgruLayer(XconfigLayerBase):
method __init__ (line 835) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 839) | def set_default_configs(self):
method set_derived_configs (line 857) | def set_derived_configs(self):
method check_configs (line 865) | def check_configs(self):
method auxiliary_outputs (line 891) | def auxiliary_outputs(self):
method output_name (line 894) | def output_name(self, auxiliary_output = None):
method output_dim (line 904) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 915) | def get_full_config(self):
method generate_pgru_config (line 927) | def generate_pgru_config(self):
class XconfigFastGruLayer (line 1065) | class XconfigFastGruLayer(XconfigLayerBase):
method __init__ (line 1066) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 1070) | def set_default_configs(self):
method set_derived_configs (line 1088) | def set_derived_configs(self):
method check_configs (line 1092) | def check_configs(self):
method output_name (line 1104) | def output_name(self, auxiliary_output = None):
method output_dim (line 1108) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 1111) | def get_full_config(self):
method generate_gru_config (line 1123) | def generate_gru_config(self):
class XconfigFastPgruLayer (line 1225) | class XconfigFastPgruLayer(XconfigLayerBase):
method __init__ (line 1226) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 1230) | def set_default_configs(self):
method set_derived_configs (line 1251) | def set_derived_configs(self):
method check_configs (line 1259) | def check_configs(self):
method auxiliary_outputs (line 1280) | def auxiliary_outputs(self):
method output_name (line 1283) | def output_name(self, auxiliary_output = None):
method output_dim (line 1293) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 1304) | def get_full_config(self):
method generate_pgru_config (line 1316) | def generate_pgru_config(self):
class XconfigFastNormPgruLayer (line 1436) | class XconfigFastNormPgruLayer(XconfigLayerBase):
method __init__ (line 1437) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 1441) | def set_default_configs(self):
method set_derived_configs (line 1464) | def set_derived_configs(self):
method check_configs (line 1472) | def check_configs(self):
method auxiliary_outputs (line 1498) | def auxiliary_outputs(self):
method output_name (line 1501) | def output_name(self, auxiliary_output = None):
method output_dim (line 1511) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 1522) | def get_full_config(self):
method generate_pgru_config (line 1534) | def generate_pgru_config(self):
class XconfigFastOpgruLayer (line 1681) | class XconfigFastOpgruLayer(XconfigLayerBase):
method __init__ (line 1682) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 1686) | def set_default_configs(self):
method set_derived_configs (line 1707) | def set_derived_configs(self):
method check_configs (line 1715) | def check_configs(self):
method auxiliary_outputs (line 1736) | def auxiliary_outputs(self):
method output_name (line 1739) | def output_name(self, auxiliary_output = None):
method output_dim (line 1749) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 1760) | def get_full_config(self):
method generate_pgru_config (line 1772) | def generate_pgru_config(self):
class XconfigFastNormOpgruLayer (line 1895) | class XconfigFastNormOpgruLayer(XconfigLayerBase):
method __init__ (line 1896) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 1900) | def set_default_configs(self):
method set_derived_configs (line 1923) | def set_derived_configs(self):
method check_configs (line 1931) | def check_configs(self):
method auxiliary_outputs (line 1957) | def auxiliary_outputs(self):
method output_name (line 1960) | def output_name(self, auxiliary_output = None):
method output_dim (line 1970) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 1981) | def get_full_config(self):
method generate_pgru_config (line 1993) | def generate_pgru_config(self):
FILE: egs/steps/libs/nnet3/xconfig/lstm.py
class XconfigLstmLayer (line 45) | class XconfigLstmLayer(XconfigLayerBase):
method __init__ (line 46) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 50) | def set_default_configs(self):
method set_derived_configs (line 64) | def set_derived_configs(self):
method check_configs (line 68) | def check_configs(self):
method auxiliary_outputs (line 80) | def auxiliary_outputs(self):
method output_name (line 83) | def output_name(self, auxiliary_output = None):
method output_dim (line 93) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 104) | def get_full_config(self):
method _generate_lstm_config (line 116) | def _generate_lstm_config(self):
class XconfigLstmpLayer (line 295) | class XconfigLstmpLayer(XconfigLayerBase):
method __init__ (line 296) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 302) | def set_default_configs(self):
method set_derived_configs (line 321) | def set_derived_configs(self):
method check_configs (line 329) | def check_configs(self):
method auxiliary_outputs (line 356) | def auxiliary_outputs(self):
method output_name (line 359) | def output_name(self, auxiliary_output = None):
method output_dim (line 370) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 381) | def get_full_config(self):
method _generate_lstm_config (line 393) | def _generate_lstm_config(self):
class XconfigFastLstmLayer (line 601) | class XconfigFastLstmLayer(XconfigLayerBase):
method __init__ (line 602) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 606) | def set_default_configs(self):
method set_derived_configs (line 626) | def set_derived_configs(self):
method check_configs (line 630) | def check_configs(self):
method auxiliary_outputs (line 639) | def auxiliary_outputs(self):
method output_name (line 642) | def output_name(self, auxiliary_output = None):
method output_dim (line 653) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 663) | def get_full_config(self):
method _generate_lstm_config (line 675) | def _generate_lstm_config(self):
class XconfigLstmbLayer (line 798) | class XconfigLstmbLayer(XconfigLayerBase):
method __init__ (line 799) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 803) | def set_default_configs(self):
method set_derived_configs (line 823) | def set_derived_configs(self):
method check_configs (line 827) | def check_configs(self):
method auxiliary_outputs (line 837) | def auxiliary_outputs(self):
method output_name (line 840) | def output_name(self, auxiliary_output = None):
method output_dim (line 844) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 848) | def get_full_config(self):
method _generate_lstm_config (line 860) | def _generate_lstm_config(self):
class XconfigFastLstmpLayer (line 994) | class XconfigFastLstmpLayer(XconfigLayerBase):
method __init__ (line 995) | def __init__(self, first_token, key_to_value, prev_names = None):
method set_default_configs (line 999) | def set_default_configs(self):
method set_derived_configs (line 1022) | def set_derived_configs(self):
method check_configs (line 1031) | def check_configs(self):
method auxiliary_outputs (line 1050) | def auxiliary_outputs(self):
method output_name (line 1053) | def output_name(self, auxiliary_output = None):
method output_dim (line 1064) | def output_dim(self, auxiliary_output = None):
method get_full_config (line 1075) | def get_full_config(self):
method _generate_lstm_config (line 1087) | def _generate_lstm_config(self):
FILE: egs/steps/libs/nnet3/xconfig/parser.py
function xconfig_line_to_object (line 97) | def xconfig_line_to_object(config_line, prev_layers = None):
function get_model_component_info (line 113) | def get_model_component_info(model_filename):
function read_xconfig_file (line 183) | def read_xconfig_file(xconfig_filename, existing_layers=None):
FILE: egs/steps/libs/nnet3/xconfig/stats_layer.py
class XconfigStatsLayer (line 13) | class XconfigStatsLayer(XconfigLayerBase):
method __init__ (line 37) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 41) | def set_default_configs(self):
method set_derived_configs (line 46) | def set_derived_configs(self):
method check_configs (line 81) | def check_configs(self):
method _generate_config (line 92) | def _generate_config(self):
method output_name (line 129) | def output_name(self, auxiliary_output=None):
method output_dim (line 134) | def output_dim(self, auxiliary_outputs=None):
method get_full_config (line 137) | def get_full_config(self):
FILE: egs/steps/libs/nnet3/xconfig/trivial_layers.py
class XconfigRenormComponent (line 17) | class XconfigRenormComponent(XconfigLayerBase):
method __init__ (line 26) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 29) | def set_default_configs(self):
method check_configs (line 33) | def check_configs(self):
method output_name (line 36) | def output_name(self, auxiliary_output=None):
method output_dim (line 40) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 45) | def get_full_config(self):
method _generate_config (line 56) | def _generate_config(self):
class XconfigBatchnormComponent (line 74) | class XconfigBatchnormComponent(XconfigLayerBase):
method __init__ (line 86) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 89) | def set_default_configs(self):
method check_configs (line 94) | def check_configs(self):
method output_name (line 97) | def output_name(self, auxiliary_output=None):
method output_dim (line 101) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 106) | def get_full_config(self):
method _generate_config (line 119) | def _generate_config(self):
class XconfigNoOpComponent (line 136) | class XconfigNoOpComponent(XconfigLayerBase):
method __init__ (line 144) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 147) | def set_default_configs(self):
method check_configs (line 150) | def check_configs(self):
method output_name (line 153) | def output_name(self, auxiliary_output=None):
method output_dim (line 157) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 162) | def get_full_config(self):
method _generate_config (line 173) | def _generate_config(self):
class XconfigDeltaLayer (line 189) | class XconfigDeltaLayer(XconfigLayerBase):
method __init__ (line 199) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 202) | def set_default_configs(self):
method check_configs (line 205) | def check_configs(self):
method output_name (line 208) | def output_name(self, auxiliary_output=None):
method output_dim (line 212) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 217) | def get_full_config(self):
method _generate_config (line 228) | def _generate_config(self):
class XconfigLinearComponent (line 260) | class XconfigLinearComponent(XconfigLayerBase):
method __init__ (line 279) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 282) | def set_default_configs(self):
method check_configs (line 291) | def check_configs(self):
method output_name (line 295) | def output_name(self, auxiliary_output=None):
method output_dim (line 299) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 304) | def get_full_config(self):
method _generate_config (line 315) | def _generate_config(self):
class XconfigCombineFeatureMapsLayer (line 339) | class XconfigCombineFeatureMapsLayer(XconfigLayerBase):
method __init__ (line 356) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 359) | def set_default_configs(self):
method check_configs (line 366) | def check_configs(self):
method output_name (line 382) | def output_name(self, auxiliary_output=None):
method output_dim (line 386) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 391) | def get_full_config(self):
method _generate_config (line 402) | def _generate_config(self):
class XconfigAffineComponent (line 435) | class XconfigAffineComponent(XconfigLayerBase):
method __init__ (line 454) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 457) | def set_default_configs(self):
method check_configs (line 466) | def check_configs(self):
method output_name (line 470) | def output_name(self, auxiliary_output=None):
method output_dim (line 474) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 479) | def get_full_config(self):
method _generate_config (line 490) | def _generate_config(self):
class XconfigPerElementScaleComponent (line 514) | class XconfigPerElementScaleComponent(XconfigLayerBase):
method __init__ (line 533) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 536) | def set_default_configs(self):
method check_configs (line 544) | def check_configs(self):
method output_name (line 547) | def output_name(self, auxiliary_output=None):
method output_dim (line 551) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 555) | def get_full_config(self):
method _generate_config (line 566) | def _generate_config(self):
class XconfigPerElementOffsetComponent (line 588) | class XconfigPerElementOffsetComponent(XconfigLayerBase):
method __init__ (line 607) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 610) | def set_default_configs(self):
method check_configs (line 618) | def check_configs(self):
method output_name (line 621) | def output_name(self, auxiliary_output=None):
method output_dim (line 625) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 629) | def get_full_config(self):
method _generate_config (line 640) | def _generate_config(self):
class XconfigDimRangeComponent (line 663) | class XconfigDimRangeComponent(XconfigLayerBase):
method __init__ (line 672) | def __init__(self, first_token, key_to_value, prev_names=None):
method set_default_configs (line 675) | def set_default_configs(self):
method check_configs (line 680) | def check_configs(self):
method output_name (line 691) | def output_name(self, auxiliary_output=None):
method output_dim (line 695) | def output_dim(self, auxiliary_output=None):
method get_full_config (line 702) | def get_full_config(self):
method _generate_config (line 713) | def _generate_config(self):
FILE: egs/steps/libs/nnet3/xconfig/utils.py
function get_prev_names (line 22) | def get_prev_names(all_layers, current_layer):
function split_layer_name (line 46) | def split_layer_name(full_layer_name):
function get_dim_from_layer_name (line 67) | def get_dim_from_layer_name(all_layers, current_layer, full_layer_name):
function get_string_from_layer_name (line 110) | def get_string_from_layer_name(all_layers, current_layer, full_layer_name):
function convert_value_to_type (line 148) | def convert_value_to_type(key, dest_type, string_value):
class Descriptor (line 187) | class Descriptor(object):
method __init__ (line 188) | def __init__(self,
method config_string (line 233) | def config_string(self, layer_to_string):
method str (line 243) | def str(self):
method __str__ (line 251) | def __str__(self):
method dim (line 261) | def dim(self, layer_to_dim):
function expect_token (line 294) | def expect_token(expected_item, seen_item, what_parsing):
function is_valid_line_name (line 301) | def is_valid_line_name(name):
function parse_new_descriptor (line 314) | def parse_new_descriptor(tokens, pos, prev_names):
function replace_bracket_expressions_in_descriptor (line 497) | def replace_bracket_expressions_in_descriptor(descriptor_string,
function tokenize_descriptor (line 538) | def tokenize_descriptor(descriptor_string,
function parse_config_line (line 568) | def parse_config_line(orig_config_line):
function test_library (line 650) | def test_library():
FILE: egs/steps/nnet2/make_multisplice_configs.py
function get_convolution_index_set (line 13) | def get_convolution_index_set(x, y):
function parse_splice_string (line 22) | def parse_splice_string(splice_string):
function create_config_files (line 60) | def create_config_files(output_dir, params):
FILE: egs/steps/nnet3/chain/e2e/train_e2e.py
function get_args (line 38) | def get_args():
function process_args (line 180) | def process_args(args):
function train (line 244) | def train(args, run_opts):
function main (line 553) | def main():
FILE: egs/steps/nnet3/chain/train.py
function get_args (line 38) | def get_args():
function process_args (line 203) | def process_args(args):
function train (line 271) | def train(args, run_opts):
function main (line 641) | def main():
FILE: egs/steps/nnet3/chain2/internal/get_train_schedule.py
function get_args (line 22) | def get_args():
function get_schedules (line 91) | def get_schedules(args):
function main (line 154) | def main():
FILE: egs/steps/nnet3/components.py
function GetSumDescriptor (line 12) | def GetSumDescriptor(inputs):
function AddInputLayer (line 30) | def AddInputLayer(config_lines, feat_dim, splice_indexes=[0], ivector_di...
function AddNoOpLayer (line 49) | def AddNoOpLayer(config_lines, name, input):
function AddLdaLayer (line 59) | def AddLdaLayer(config_lines, name, input, lda_file):
function AddFixedAffineLayer (line 62) | def AddFixedAffineLayer(config_lines, name, input, matrix_file):
function AddBlockAffineLayer (line 73) | def AddBlockAffineLayer(config_lines, name, input, output_dim, num_blocks):
function AddPermuteLayer (line 84) | def AddPermuteLayer(config_lines, name, input, column_map):
function AddAffineLayer (line 94) | def AddAffineLayer(config_lines, name, input, output_dim, ng_affine_opti...
function AddAffRelNormLayer (line 107) | def AddAffRelNormLayer(config_lines, name, input, output_dim, ng_affine_...
function AddAffPnormLayer (line 127) | def AddAffPnormLayer(config_lines, name, input, pnorm_input_dim, pnorm_o...
function AddConvolutionLayer (line 142) | def AddConvolutionLayer(config_lines, name, input,
function AddMaxpoolingLayer (line 180) | def AddMaxpoolingLayer(config_lines, name, input,
function AddSoftmaxLayer (line 218) | def AddSoftmaxLayer(config_lines, name, input):
function AddSigmoidLayer (line 229) | def AddSigmoidLayer(config_lines, name, input, self_repair_scale = None):
function AddOutputLayer (line 240) | def AddOutputLayer(config_lines, input, label_delay = None, suffix=None,...
function AddFinalLayer (line 252) | def AddFinalLayer(config_lines, input, output_dim,
function AddLstmLayer (line 289) | def AddLstmLayer(config_lines,
function AddBLstmLayer (line 440) | def AddBLstmLayer(config_lines,
FILE: egs/steps/nnet3/convert_nnet2_to_nnet3.py
function GetArgs (line 67) | def GetArgs():
class Nnet3Model (line 97) | class Nnet3Model(object):
method __init__ (line 100) | def __init__(self):
method add_component (line 112) | def add_component(self, component, pairs):
method write_config (line 149) | def write_config(self, filename):
method write_model (line 187) | def write_model(self, model, binary="true"):
function parse_nnet2_to_nnet3 (line 206) | def parse_nnet2_to_nnet3(line_buffer):
function parse_transition_model (line 238) | def parse_transition_model(line_buffer):
function parse_nnet2_header (line 259) | def parse_nnet2_header(line_buffer):
function parse_component (line 271) | def parse_component(line, line_buffer):
function parse_standard_component (line 294) | def parse_standard_component(component, line, line_buffer):
function parse_fixed_scale_component (line 301) | def parse_fixed_scale_component(component, line, line_buffer):
function parse_sum_group_component (line 315) | def parse_sum_group_component(component, line, line_buffer):
function parse_fixed_bias_component (line 323) | def parse_fixed_bias_component(component, line, line_buffer):
function parse_splice_component (line 337) | def parse_splice_component(component, line, line_buffer):
function parse_end_of_component (line 355) | def parse_end_of_component(component, line, line_buffer):
function parse_affine_component (line 364) | def parse_affine_component(component, line, line_buffer):
function parse_weights (line 386) | def parse_weights(line_buffer):
function parse_bias (line 402) | def parse_bias(line):
function parse_vector (line 408) | def parse_vector(line):
function parse_priors (line 412) | def parse_priors(line, line_buffer):
function token_to_string (line 423) | def token_to_string(token):
function consume_token (line 432) | def consume_token(token, line):
function make_splice_string (line 440) | def make_splice_string(nodename, context, const_component_dim=0):
function Main (line 455) | def Main():
FILE: egs/steps/nnet3/dot/descriptor_parser.py
function ParseSubsegmentsAndArguments (line 12) | def ParseSubsegmentsAndArguments(segment_endpoints, sub_segments, argume...
function IdentifyNestedSegments (line 40) | def IdentifyNestedSegments(input_string):
FILE: egs/steps/nnet3/dot/nnet3_to_dot.py
function GetDotNodeName (line 89) | def GetDotNodeName(name_string, is_component = False):
function ProcessAppendDescriptor (line 103) | def ProcessAppendDescriptor(segment, parent_node_name, affix, edge_attri...
function ProcessRoundDescriptor (line 139) | def ProcessRoundDescriptor(segment, parent_node_name, affix, edge_attrib...
function ProcessOffsetDescriptor (line 161) | def ProcessOffsetDescriptor(segment, parent_node_name, affix, edge_attri...
function ProcessSumDescriptor (line 183) | def ProcessSumDescriptor(segment, parent_node_name, affix, edge_attribut...
function ProcessReplaceIndexDescriptor (line 218) | def ProcessReplaceIndexDescriptor(segment, parent_node_name, affix, edge...
function ProcessIfDefinedDescriptor (line 240) | def ProcessIfDefinedDescriptor(segment, parent_node_name, affix, edge_at...
function DescriptorSegmentToDot (line 257) | def DescriptorSegmentToDot(segment, parent_node_name, affix, edge_attrib...
function Nnet3DescriptorToDot (line 278) | def Nnet3DescriptorToDot(descriptor, parent_node_name):
function ParseNnet3String (line 289) | def ParseNnet3String(string):
function Nnet3ComponentToDot (line 322) | def Nnet3ComponentToDot(component_config, component_attributes = None):
function Nnet3InputToDot (line 344) | def Nnet3InputToDot(parsed_config):
function Nnet3OutputToDot (line 349) | def Nnet3OutputToDot(parsed_config):
function Nnet3DimrangeToDot (line 356) | def Nnet3DimrangeToDot(parsed_config):
function Nnet3ComponentNodeToDot (line 366) | def Nnet3ComponentNodeToDot(parsed_config):
function GroupConfigs (line 375) | def GroupConfigs(configs, node_prefixes = None):
function ParseConfigLines (line 395) | def ParseConfigLines(lines, node_prefixes = None, component_attributes =...
FILE: egs/steps/nnet3/lstm/make_configs.py
function GetArgs (line 17) | def GetArgs():
function CheckArgs (line 115) | def CheckArgs(args):
function PrintConfig (line 162) | def PrintConfig(file_name, config_lines):
function ParseSpliceString (line 169) | def ParseSpliceString(splice_indexes, label_delay=None):
function ParseLstmDelayString (line 211) | def ParseLstmDelayString(lstm_delay):
function MakeConfigs (line 232) | def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets,
function ProcessSpliceIndexes (line 326) | def ProcessSpliceIndexes(config_dir, splice_indexes, label_delay, num_ls...
function Main (line 347) | def Main():
FILE: egs/steps/nnet3/multilingual/allocate_multilingual_examples.py
function get_args (line 60) | def get_args():
function read_lines (line 107) | def read_lines(file_handle, num_lines):
function process_multilingual_egs (line 119) | def process_multilingual_egs(args):
function main (line 222) | def main():
FILE: egs/steps/nnet3/report/convert_model.py
function read_next_token (line 25) | def read_next_token(s, pos):
function check_for_newline (line 51) | def check_for_newline(s, pos):
function read_float (line 70) | def read_float(s, pos):
function read_int (line 88) | def read_int(s, pos):
function read_vector (line 106) | def read_vector(s, pos):
function read_matrix (line 138) | def read_matrix(s, pos):
function is_component_type (line 193) | def is_component_type(component_type):
function read_generic (line 200) | def read_generic(s, pos, terminating_token, action_dict):
function get_action_dict (line 248) | def get_action_dict(component_type):
function get_stdout_from_command (line 293) | def get_stdout_from_command(command):
function read_component (line 309) | def read_component(s, pos):
function read_model (line 339) | def read_model(filename):
function compute_derived_quantities (line 390) | def compute_derived_quantities(model):
function compute_progress (line 417) | def compute_progress(model1, model2):
function test (line 454) | def test():
FILE: egs/steps/nnet3/report/generate_plots.py
function get_args (line 41) | def get_args():
class LatexReport (line 86) | class LatexReport(object):
method __init__ (line 89) | def __init__(self, pdf_file):
method add_figure (line 98) | def add_figure(self, figure_pdf, title):
method close (line 117) | def close(self):
method compile (line 121) | def compile(self):
function latex_compliant_name (line 141) | def latex_compliant_name(name_string):
function generate_acc_logprob_plots (line 152) | def generate_acc_logprob_plots(exp_dir, output_dir, plot, key='accuracy',
function insert_a_column_legend (line 219) | def insert_a_column_legend(legend_handle, legend_label, lp, mp, hp,
function plot_a_nonlin_component (line 229) | def plot_a_nonlin_component(fig, dirs, stat_tables_per_component_per_dir,
function generate_nonlin_stats_plots (line 351) | def generate_nonlin_stats_plots(exp_dir, output_dir, plot, comparison_di...
function generate_clipped_proportion_plots (line 482) | def generate_clipped_proportion_plots(exp_dir, output_dir, plot,
function generate_parameter_diff_plots (line 579) | def generate_parameter_diff_plots(exp_dir, output_dir, plot,
function generate_plots (line 710) | def generate_plots(exp_dir, output_dir, output_names, comparison_dir=None,
function main (line 782) | def main():
FILE: egs/steps/nnet3/report/summarize_compute_debug_timing.py
function GetArgs (line 19) | def GetArgs():
function FindOpenParanthesisPosition (line 38) | def FindOpenParanthesisPosition(string):
function ExtractCommandName (line 63) | def ExtractCommandName(command_string):
function Main (line 76) | def Main():
FILE: egs/steps/nnet3/tdnn/make_configs.py
function GetArgs (line 21) | def GetArgs():
function CheckArgs (line 138) | def CheckArgs(args):
function AddConvMaxpLayer (line 221) | def AddConvMaxpLayer(config_lines, name, input, args):
function AddCnnLayers (line 241) | def AddCnnLayers(config_lines, cnn_layer, cnn_bottleneck_dim, cepstral_l...
function PrintConfig (line 273) | def PrintConfig(file_name, config_lines):
function ParseCnnString (line 280) | def ParseCnnString(cnn_param_string_list):
function ParseSpliceString (line 301) | def ParseSpliceString(splice_indexes):
function MakeConfigs (line 336) | def MakeConfigs(config_dir, splice_indexes_string,
function Main (line 537) | def Main():
FILE: egs/steps/nnet3/train_dnn.py
function get_args (line 39) | def get_args():
function process_args (line 109) | def process_args(args):
function train (line 168) | def train(args, run_opts):
function main (line 456) | def main():
FILE: egs/steps/nnet3/train_raw_dnn.py
function get_args (line 38) | def get_args():
function process_args (line 127) | def process_args(args):
function train (line 186) | def train(args, run_opts):
function main (line 498) | def main():
FILE: egs/steps/nnet3/train_raw_rnn.py
function get_args (line 38) | def get_args():
function process_args (line 167) | def process_args(args):
function train (line 232) | def train(args, run_opts):
function main (line 558) | def main():
FILE: egs/steps/nnet3/train_rnn.py
function get_args (line 38) | def get_args():
function process_args (line 158) | def process_args(args):
function train (line 223) | def train(args, run_opts):
function main (line 545) | def main():
FILE: egs/steps/nnet3/xconfig_to_config.py
function get_args (line 25) | def get_args():
function write_config_file (line 64) | def write_config_file(config_file_out, all_layers):
function main (line 92) | def main():
FILE: egs/steps/nnet3/xconfig_to_configs.py
function get_args (line 23) | def get_args():
function check_args (line 62) | def check_args(args):
function backup_xconfig_file (line 68) | def backup_xconfig_file(xconfig_file, config_dir):
function write_expanded_xconfig_files (line 99) | def write_expanded_xconfig_files(config_dir, all_layers):
function get_config_headers (line 142) | def get_config_headers():
function write_config_files (line 176) | def write_config_files(config_dir, all_layers):
function add_nnet_context_info (line 233) | def add_nnet_context_info(config_dir, nnet_edits=None,
function check_model_contexts (line 268) | def check_model_contexts(config_dir, nnet_edits=None, existing_model=None):
function main (line 317) | def main():
FILE: egs/steps/overlap/get_overlap_segments.py
function get_args (line 21) | def get_args():
class Segment (line 32) | class Segment:
method __init__ (line 35) | def __init__(self, reco_id, start_time, dur = None, end_time = None, s...
function groupby (line 46) | def groupby(iterable, keyfunc):
function find_overlapping_segments (line 52) | def find_overlapping_segments(segs, label):
function find_single_speaker_segments (line 77) | def find_single_speaker_segments(segs):
function main (line 108) | def main():
FILE: egs/steps/overlap/get_overlap_targets.py
function get_args (line 29) | def get_args():
class Segment (line 64) | class Segment:
method __init__ (line 72) | def __init__(self, reco_id, start_time, dur = None, end_time = None, l...
function groupby (line 83) | def groupby(iterable, keyfunc):
function run (line 89) | def run(args):
function main (line 149) | def main():
FILE: egs/steps/overlap/output_to_rttm.py
function get_args (line 34) | def get_args():
function to_str (line 94) | def to_str(segment):
class SegmenterStats (line 100) | class SegmenterStats(object):
method __init__ (line 103) | def __init__(self):
method add (line 113) | def add(self, other):
method __str__ (line 124) | def __str__(self):
function process_label (line 143) | def process_label(text_label):
class Segmentation (line 159) | class Segmentation(object):
method __init__ (line 164) | def __init__(self, region_type):
method initialize_segments (line 170) | def initialize_segments(self, alignment, frame_shift=0.01):
method filter_short_segments (line 204) | def filter_short_segments(self, min_dur):
method pad_segments (line 222) | def pad_segments(self, segment_padding, max_duration=float("inf")):
method merge_consecutive_segments (line 260) | def merge_consecutive_segments(self, max_dur):
method write (line 282) | def write(self, key, file_handle):
function run (line 293) | def run(args):
function main (line 330) | def main():
FILE: egs/steps/overlap/prepare_overlap_graph.py
function get_args (line 24) | def get_args():
function print_states (line 79) | def print_states(args, file_handle):
function main (line 197) | def main():
FILE: egs/steps/pytorchnn/compute_sentence_scores.py
function load_nbest (line 19) | def load_nbest(path):
function read_vocab (line 53) | def read_vocab(path):
function get_input_and_target (line 77) | def get_input_and_target(hyp, vocab):
function compute_sentence_score (line 110) | def compute_sentence_score(model, criterion, ntokens, data, target,
function compute_scores (line 148) | def compute_scores(nbest, model, criterion, ntokens, vocab, model_type='...
function write_scores (line 202) | def write_scores(nbest_and_scores, path):
function main (line 224) | def main():
FILE: egs/steps/pytorchnn/data.py
class Dictionary (line 9) | class Dictionary(object):
method __init__ (line 10) | def __init__(self):
method read_vocab (line 14) | def read_vocab(self, path):
method __len__ (line 24) | def __len__(self):
class Corpus (line 28) | class Corpus(object):
method __init__ (line 29) | def __init__(self, path):
method tokenize (line 36) | def tokenize(self, path):
FILE: egs/steps/pytorchnn/model.py
class RNNModel (line 10) | class RNNModel(nn.Module):
method __init__ (line 12) | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5,
method init_weights (line 41) | def init_weights(self):
method forward (line 47) | def forward(self, x, hidden):
method init_hidden (line 54) | def init_hidden(self, bsz):
class PositionalEncoding (line 62) | class PositionalEncoding(nn.Module):
method __init__ (line 79) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 91) | def forward(self, x):
class TransformerModel (line 106) | class TransformerModel(nn.Module):
method __init__ (line 109) | def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5,
method _generate_square_subsequent_mask (line 133) | def _generate_square_subsequent_mask(self, sz):
method init_weights (line 139) | def init_weights(self):
method forward (line 145) | def forward(self, src, has_mask=True):
FILE: egs/steps/pytorchnn/train.py
function batchify (line 92) | def batchify(data, bsz, random_start_idx=False):
function repackage_hidden (line 135) | def repackage_hidden(h):
function get_batch (line 143) | def get_batch(source, i):
function train (line 150) | def train():
function evaluate (line 189) | def evaluate(source):
FILE: egs/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py
function get_args (line 33) | def get_args():
function main (line 55) | def main():
FILE: egs/steps/segmentation/internal/arc_info_to_targets.py
function get_args (line 31) | def get_args():
function run (line 63) | def run(args):
function main (line 151) | def main():
FILE: egs/steps/segmentation/internal/find_oov_phone.py
function main (line 16) | def main():
FILE: egs/steps/segmentation/internal/get_default_targets_for_out_of_segments.py
function get_args (line 38) | def get_args():
function run (line 90) | def run(args):
function main (line 171) | def main():
FILE: egs/steps/segmentation/internal/get_transform_probs_mat.py
function get_args (line 12) | def get_args():
function run (line 50) | def run(args):
function main (line 73) | def main():
FILE: egs/steps/segmentation/internal/merge_segment_targets_to_recording.py
function get_args (line 33) | def get_args():
function read_reco2utt_file (line 80) | def read_reco2utt_file(reco2utt_file):
function read_reco2num_frames_file (line 93) | def read_reco2num_frames_file(reco2num_frames_file):
function read_segments_file (line 107) | def read_segments_file(segments_file, reco2utt):
function read_targets_scp (line 126) | def read_targets_scp(targets_scp, segments):
function run (line 142) | def run(args):
function main (line 282) | def main():
FILE: egs/steps/segmentation/internal/merge_targets.py
function get_args (line 37) | def get_args():
function should_remove_frame (line 86) | def should_remove_frame(row, dim):
function run (line 160) | def run(args):
function main (line 205) | def main():
FILE: egs/steps/segmentation/internal/prepare_sad_graph.py
function get_args (line 34) | def get_args():
function print_states (line 81) | def print_states(args, file_handle):
function main (line 154) | def main():
FILE: egs/steps/segmentation/internal/resample_targets.py
function get_args (line 35) | def get_args():
function run (line 70) | def run(args):
function main (line 108) | def main():
FILE: egs/steps/segmentation/internal/sad_to_segments.py
function get_args (line 34) | def get_args():
function to_str (line 90) | def to_str(segment):
class SegmenterStats (line 96) | class SegmenterStats(object):
method __init__ (line 99) | def __init__(self):
method add (line 109) | def add(self, other):
method __str__ (line 120) | def __str__(self):
function process_label (line 139) | def process_label(text_label):
class Segmentation (line 154) | class Segmentation(object):
method __init__ (line 157) | def __init__(self):
method initialize_segments (line 161) | def initialize_segments(self, alignment, frame_shift=0.01):
method filter_short_segments (line 195) | def filter_short_segments(self, min_dur):
method pad_speech_segments (line 213) | def pad_speech_segments(self, segment_padding, max_duration=float("inf...
method merge_consecutive_segments (line 251) | def merge_consecutive_segments(self, max_dur):
method write (line 273) | def write(self, key, file_handle):
function run (line 286) | def run(args):
function main (line 323) | def main():
FILE: egs/steps/segmentation/internal/verify_phones_list.py
function get_args (line 12) | def get_args():
function main (line 25) | def main():
FILE: egs/steps/tfrnnlm/lstm.py
class Config (line 48) | class Config(object):
function data_type (line 62) | def data_type():
class RNNLMModel (line 66) | class RNNLMModel(tf.Module):
method __init__ (line 69) | def __init__(self, config, logits_bias_initializer=None):
method get_logits (line 97) | def get_logits(self, word_ids, is_training=False):
method get_loss (line 106) | def get_loss(self, word_ids, labels, is_training=False):
method get_score (line 111) | def get_score(self, logits):
method get_initial_state (line 116) | def get_initial_state(self):
method single_step (line 124) | def single_step(self, context, word_id):
class RNNLMModelTrainer (line 142) | class RNNLMModelTrainer(tf.Module):
method __init__ (line 145) | def __init__(self, model: RNNLMModel, config):
method train_one_epoch (line 154) | def train_one_epoch(self, data_producer, learning_rate, verbose=True):
method evaluate (line 164) | def evaluate(self, data_producer):
method _train_step (line 173) | def _train_step(self, inputs, labels):
function get_config (line 184) | def get_config():
function main (line 188) | def main(_):
FILE: egs/steps/tfrnnlm/lstm_fast.py
class Config (line 50) | class Config(object):
function data_type (line 65) | def data_type():
function new_softmax (line 72) | def new_softmax(labels, logits):
class MyFastLossFunction (line 87) | class MyFastLossFunction(LossFunctionWrapper):
method __init__ (line 88) | def __init__(self):
class FastRNNLMModel (line 92) | class FastRNNLMModel(RNNLMModel):
method __init__ (line 93) | def __init__(self, config):
method get_loss (line 96) | def get_loss(self, word_ids, labels, is_training=False):
method get_score (line 101) | def get_score(self, logits):
function get_config (line 106) | def get_config():
function main (line 110) | def main(_):
FILE: egs/steps/tfrnnlm/reader.py
function _read_words (line 28) | def _read_words(filename):
function _build_vocab (line 32) | def _build_vocab(filename):
function _file_to_word_ids (line 38) | def _file_to_word_ids(filename, word_to_id):
function rnnlm_raw_data (line 43) | def rnnlm_raw_data(data_path, vocab_path):
function rnnlm_gen_data (line 64) | def rnnlm_gen_data(*files):
class RNNLMProducer (line 92) | class RNNLMProducer(tf.Module):
method __init__ (line 95) | def __init__(self, raw_data, batch_size, num_steps, name=None):
method iterate (line 112) | def iterate(self):
FILE: egs/steps/tfrnnlm/vanilla_rnnlm.py
class Config (line 54) | class Config(object):
function data_type (line 68) | def data_type():
class RnnlmInput (line 72) | class RnnlmInput(object):
method __init__ (line 75) | def __init__(self, config, data, name=None):
class RnnlmModel (line 82) | class RnnlmModel(object):
method __init__ (line 85) | def __init__(self, is_training, config, input_):
method assign_lr (line 211) | def assign_lr(self, session, lr_value):
method input (line 215) | def input(self):
method initial_state (line 219) | def initial_state(self):
method cost (line 223) | def cost(self):
method final_state (line 227) | def final_state(self):
method lr (line 231) | def lr(self):
method train_op (line 235) | def train_op(self):
function run_epoch (line 238) | def run_epoch(session, model, eval_op=None, verbose=False):
function get_config (line 272) | def get_config():
function main (line 275) | def main(_):
FILE: egs/utils/ctm/resolve_ctm_overlaps.py
function get_args (line 38) | def get_args():
function read_segments (line 61) | def read_segments(segments_file):
function read_ctm (line 84) | def read_ctm(ctm_file, segments):
function resolve_overlaps (line 127) | def resolve_overlaps(ctms, segments):
function ctm_line_to_string (line 265) | def ctm_line_to_string(line):
function write_ctm (line 271) | def write_ctm(ctm_lines, out_file):
function run (line 277) | def run(args):
function main (line 302) | def main():
FILE: egs/utils/data/extend_segment_times.py
function FloatToString (line 93) | def FloatToString(f):
FILE: egs/utils/data/get_allowed_durations.py
function get_args (line 32) | def get_args():
function read_kaldi_mapfile (line 89) | def read_kaldi_mapfile(path):
function find_duration_range (line 104) | def find_duration_range(utt2dur, coverage_factor):
function get_allowed_durations (line 136) | def get_allowed_durations(start_dur, end_dur, args):
function get_trivial_allowed_durations (line 165) | def get_trivial_allowed_durations(utt2dur, args):
function main (line 197) | def main():
FILE: egs/utils/data/get_uniform_subsegments.py
function get_args (line 13) | def get_args():
function run (line 65) | def run(args):
function main (line 109) | def main():
FILE: egs/utils/data/internal/choose_utts_to_combine.py
function LessThan (line 71) | def LessThan(x, y):
function CombineList (line 90) | def CombineList(min_duration, durations):
function SelfTest (line 188) | def SelfTest():
function GetUtteranceGroups (line 232) | def GetUtteranceGroups(min_duration, merge_within_speakers_only, spk2utt...
FILE: egs/utils/data/internal/combine_segments_to_recording.py
function get_args (line 12) | def get_args():
function main (line 32) | def main():
FILE: egs/utils/data/internal/modify_speaker_info.py
function SplitIntoGroups (line 65) | def SplitIntoGroups(uttlist):
function GetFormatString (line 87) | def GetFormatString(d):
FILE: egs/utils/data/internal/perturb_volume.py
function get_args (line 18) | def get_args():
function read_reco2vol (line 50) | def read_reco2vol(volumes_file):
function run (line 69) | def run(args):
function main (line 109) | def main():
FILE: egs/utils/data/perturb_speed_to_allowed_lengths.py
function get_args (line 30) | def get_args():
class Utterance (line 63) | class Utterance(object):
method __init__ (line 68) | def __init__(self, uid, wavefile, speaker, transcription, dur):
method to_kaldi_utt_str (line 76) | def to_kaldi_utt_str(self):
method to_kaldi_wave_str (line 79) | def to_kaldi_wave_str(self):
method to_kaldi_dur_str (line 82) | def to_kaldi_dur_str(self):
function read_kaldi_datadir (line 86) | def read_kaldi_datadir(dir):
function read_kaldi_mapfile (line 126) | def read_kaldi_mapfile(path):
function generate_kaldi_data_files (line 140) | def generate_kaldi_data_files(utterances, outdir):
function find_duration_range (line 178) | def find_duration_range(utterances, coverage_factor):
function find_allowed_durations (line 212) | def find_allowed_durations(start_dur, end_dur, args):
function perturb_utterances (line 242) | def perturb_utterances(utterances, allowed_durations, args):
function main (line 310) | def main():
FILE: egs/utils/lang/bpe/apply_bpe.py
class BPE (line 27) | class BPE(object):
method __init__ (line 29) | def __init__(self, codes, merges=-1, separator='@@', vocab=None, gloss...
method process_line (line 62) | def process_line(self, line):
method segment (line 79) | def segment(self, sentence):
method _isolate_glossaries (line 102) | def _isolate_glossaries(self, word):
function create_parser (line 109) | def create_parser():
function get_pairs (line 150) | def get_pairs(word):
function encode (line 162) | def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version...
function recursive_split (line 226) | def recursive_split(segment, bpe_codes, vocab, separator, final=False):
function check_vocab_and_split (line 253) | def check_vocab_and_split(orig, bpe_codes, vocab, separator):
function read_vocabulary (line 278) | def read_vocabulary(vocab_file, threshold):
function isolate_glossary (line 292) | def isolate_glossary(word, glossary):
FILE: egs/utils/lang/bpe/bidi.py
function determine_text_direction (line 15) | def determine_text_direction(text):
function utf8_visual_to_logical (line 23) | def utf8_visual_to_logical(text):
function utf8_logical_to_visual (line 37) | def utf8_logical_to_visual(text):
FILE: egs/utils/lang/bpe/learn_bpe.py
function create_parser (line 30) | def create_parser():
function get_vocabulary (line 58) | def get_vocabulary(fobj, is_dict=False):
function update_pair_statistics (line 76) | def update_pair_statistics(pair, changed, stats, indices):
function get_pair_statistics (line 135) | def get_pair_statistics(vocab):
function replace_pair (line 154) | def replace_pair(pair, vocab, indices):
function prune_stats (line 178) | def prune_stats(stats, big_stats, threshold):
function main (line 194) | def main(infile, outfile, num_symbols, min_frequency=2, verbose=False, i...
FILE: egs/utils/lang/compute_sentence_probs_arpa.py
function check_args (line 29) | def check_args(args):
function is_logprob (line 35) | def is_logprob(input):
function check_number (line 45) | def check_number(model_file, tot_num):
function load_model (line 58) | def load_model(model_file):
function compute_sublist_prob (line 84) | def compute_sublist_prob(sub_list):
function compute_begin_prob (line 97) | def compute_begin_prob(sub_list):
function compute_sentence_prob (line 111) | def compute_sentence_prob(sentence, ngram_order):
function output_result (line 132) | def output_result(text_in_handle, output_file_handle, ngram_order):
FILE: egs/utils/lang/grammar/augment_phones_txt.py
function get_args (line 9) | def get_args():
function read_phones_txt (line 27) | def read_phones_txt(filename):
function read_nonterminals (line 57) | def read_nonterminals(filename):
function write_phones_txt (line 73) | def write_phones_txt(orig_lines, highest_numbered_symbol, nonterminals, ...
function main (line 88) | def main():
FILE: egs/utils/lang/grammar/augment_words_txt.py
function get_args (line 9) | def get_args():
function read_words_txt (line 28) | def read_words_txt(filename):
function read_nonterminals (line 58) | def read_nonterminals(filename):
function write_words_txt (line 74) | def write_words_txt(orig_lines, highest_numbered_symbol, nonterminals, f...
function main (line 88) | def main():
FILE: egs/utils/lang/internal/arpa2fst_constrained.py
class HistoryState (line 48) | class HistoryState(object):
method __init__ (line 49) | def __init__(self):
class ArpaModel (line 60) | class ArpaModel(object):
method __init__ (line 61) | def __init__(self):
method Read (line 71) | def Read(self, arpa_in):
method GetProb (line 167) | def GetProb(self, hist, word):
method GetStateForHist (line 187) | def GetStateForHist(self, hist_to_state, hist):
method GetHistToStateMap (line 199) | def GetHistToStateMap(self):
method PrintAsFst (line 242) | def PrintAsFst(self, disambig_symbol, bigram_map):
function ReadBigramMap (line 344) | def ReadBigramMap(bigrams_file):
FILE: egs/utils/lang/limit_arpa_unk_history.py
function get_ngram_stats (line 38) | def get_ngram_stats(old_lm_lines):
function find_and_replace_unks (line 59) | def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
function read_old_lm (line 122) | def read_old_lm():
function write_new_lm (line 133) | def write_new_lm(new_lm_lines, ngram_counts, ngram_diffs):
function main (line 152) | def main():
FILE: egs/utils/lang/make_kn_lm.py
class CountsForHistory (line 46) | class CountsForHistory:
method __init__ (line 51) | def __init__(self):
method words (line 60) | def words(self):
method __str__ (line 63) | def __str__(self):
method add_count (line 70) | def add_count(self, predicted_word, context_word, count):
class NgramCounts (line 79) | class NgramCounts:
method __init__ (line 88) | def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
method add_count (line 105) | def add_count(self, history, predicted_word, context_word, count):
method add_raw_counts_from_line (line 110) | def add_raw_counts_from_line(self, line):
method add_raw_counts_from_standard_input (line 128) | def add_raw_counts_from_standard_input(self):
method add_raw_counts_from_file (line 140) | def add_raw_counts_from_file(self, filename):
method cal_discounting_constants (line 152) | def cal_discounting_constants(self):
method cal_f (line 172) | def cal_f(self):
method cal_bow (line 205) | def cal_bow(self):
method print_raw_counts (line 248) | def print_raw_counts(self, info_string):
method print_modified_counts (line 263) | def print_modified_counts(self, info_string):
method print_f (line 284) | def print_f(self, info_string):
method print_f_and_bow (line 303) | def print_f_and_bow(self, info_string):
method print_as_arpa (line 326) | def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encod...
FILE: egs/utils/lang/make_lexicon_fst.py
function get_args (line 21) | def get_args():
function read_lexiconp (line 60) | def read_lexiconp(filename):
function write_nonterminal_arcs (line 119) | def write_nonterminal_arcs(start_state, loop_state, next_state,
function write_fst_no_silence (line 173) | def write_fst_no_silence(lexicon, nonterminals=None, left_context_phones...
function write_fst_with_silence (line 220) | def write_fst_with_silence(lexicon, sil_prob, sil_phone, sil_disambig,
function write_words_txt (line 308) | def write_words_txt(orig_lines, highest_numbered_symbol, nonterminals, f...
function read_nonterminals (line 322) | def read_nonterminals(filename):
function read_left_context_phones (line 338) | def read_left_context_phones(filename):
function is_token (line 354) | def is_token(s):
function main (line 363) | def main():
FILE: egs/utils/lang/make_lexicon_fst_silprob.py
function get_args (line 23) | def get_args():
function read_silprobs (line 70) | def read_silprobs(filename):
function read_lexiconp (line 116) | def read_lexiconp(filename):
function write_nonterminal_arcs (line 184) | def write_nonterminal_arcs(start_state, sil_state, non_sil_state,
function write_fst (line 266) | def write_fst(lexicon, silprobs, sil_phone, sil_disambig,
function read_nonterminals (line 356) | def read_nonterminals(filename):
function read_left_context_phones (line 372) | def read_left_context_phones(filename):
function main (line 387) | def main():
FILE: egs/utils/lang/make_phone_lm.py
class CountsForHistory (line 69) | class CountsForHistory(object):
method __init__ (line 74) | def __init__(self):
method Words (line 80) | def Words(self):
method __str__ (line 83) | def __str__(self):
method AddCount (line 99) | def AddCount(self, predicted_word, count):
class NgramCounts (line 113) | class NgramCounts(object):
method __init__ (line 122) | def __init__(self, ngram_order):
method AddCount (line 145) | def AddCount(self, history, predicted_word, count):
method AddRawCountsFromLine (line 151) | def AddRawCountsFromLine(self, line):
method AddRawCountsFromStandardInput (line 165) | def AddRawCountsFromStandardInput(self):
method ApplyBackoff (line 184) | def ApplyBackoff(self):
method Print (line 218) | def Print(self, info_string):
method GetHistToStateMap (line 233) | def GetHistToStateMap(self):
method GetProb (line 249) | def GetProb(self, hist, word):
method PruneEmptyStates (line 270) | def PruneEmptyStates(self):
method EnsureStructurallyNeededNgramsExist (line 300) | def EnsureStructurallyNeededNgramsExist(self):
method PrintAsFst (line 339) | def PrintAsFst(self, word_disambig_symbol):
method GetProtectedNgrams (line 384) | def GetProtectedNgrams(self):
method PruneNgram (line 409) | def PruneNgram(self, hist, word):
method PruningLogprobChange (line 442) | def PruningLogprobChange(self, count, discount, backoff_count, backoff...
method GetLikeChangeFromPruningNgram (line 547) | def GetLikeChangeFromPruningNgram(self, hist, word):
method PruneToIntermediateTarget (line 568) | def PruneToIntermediateTarget(self, num_extra_ngrams):
method PruneToFinalTarget (line 643) | def PruneToFinalTarget(self, num_extra_ngrams):
method GetNumExtraNgrams (line 694) | def GetNumExtraNgrams(self):
method GetNumNgrams (line 702) | def GetNumNgrams(self, hist_len = None):
method IntToString (line 721) | def IntToString(self, i):
method PrintAsArpa (line 732) | def PrintAsArpa(self):
FILE: egs/utils/lang/make_position_dependent_subword_lexicon.py
function get_args (line 22) | def get_args():
function is_end (line 39) | def is_end(subword, separator):
function write_position_dependent_lexicon (line 44) | def write_position_dependent_lexicon(lexiconp, separator):
function main (line 101) | def main():
FILE: egs/utils/lang/make_subword_lexicon_fst.py
function get_args (line 12) | def get_args():
function contain_disambig_symbol (line 48) | def contain_disambig_symbol(phones):
function print_arc (line 55) | def print_arc(src, dest, phone, word, cost):
function is_end (line 58) | def is_end(word, separator):
function get_suffix (line 63) | def get_suffix(phone):
function write_fst_no_silence (line 71) | def write_fst_no_silence(lexicon, position_dependent, separator):
function write_fst_with_silence (line 162) | def write_fst_with_silence(lexicon, sil_phone, sil_prob, sil_disambig, p...
function main (line 287) | def main():
FILE: egs/utils/nnet/gen_dct_mat.py
function print_on_same_line (line 31) | def print_on_same_line(text):
FILE: egs/utils/nnet/gen_hamm_mat.py
function print_on_same_line (line 29) | def print_on_same_line(text):
FILE: egs/utils/nnet/gen_splice.py
function print_on_same_line (line 28) | def print_on_same_line(text):
FILE: egs/utils/nnet/make_nnet_proto.py
function Glorot (line 105) | def Glorot(dim1, dim2):
FILE: egs/utils/nnet3/convert_config_tdnn_to_affine.py
function get_parser (line 12) | def get_parser():
function main (line 30) | def main(args):
FILE: lm/chainer_backend/extlm.py
class MultiLevelLM (line 15) | class MultiLevelLM(chainer.Chain):
method __init__ (line 19) | def __init__(
method __call__ (line 45) | def __call__(self, state, x):
method final (line 95) | def final(self, state):
class LookAheadWordLM (line 106) | class LookAheadWordLM(chainer.Chain):
method __init__ (line 110) | def __init__(
method __call__ (line 127) | def __call__(self, state, x):
method final (line 192) | def final(self, state):
FILE: lm/chainer_backend/lm.py
class DefaultRNNLM (line 51) | class DefaultRNNLM(LMInterface, link.Chain):
method add_arguments (line 60) | def add_arguments(parser):
class ClassifierWithState (line 78) | class ClassifierWithState(link.Chain):
method __init__ (line 86) | def __init__(
method __call__ (line 104) | def __call__(self, state, *args, **kwargs):
method predict (line 145) | def predict(self, state, x):
method final (line 159) | def final(self, state):
class RNNLM (line 174) | class RNNLM(chainer.Chain):
method __init__ (line 183) | def __init__(self, n_vocab, n_layers, n_units, typ="lstm"):
method __call__ (line 204) | def __call__(self, state, x):
class BPTTUpdater (line 244) | class BPTTUpdater(training.updaters.StandardUpdater):
method __init__ (line 254) | def __init__(self, train_iter, optimizer, schedulers, device, accum_gr...
method update_core (line 260) | def update_core(self):
class LMEvaluator (line 301) | class LMEvaluator(BaseEvaluator):
method __init__ (line 309) | def __init__(self, val_iter, eval_model, device):
method evaluate (line 312) | def evaluate(self):
function train (line 333) | def train(args):
FILE: lm/lm_utils.py
function load_dataset (line 21) | def load_dataset(path, label_dict, outdir=None):
function read_tokens (line 61) | def read_tokens(filename, label_dict):
function count_tokens (line 81) | def count_tokens(data, unk_id=None):
function compute_perplexity (line 102) | def compute_perplexity(result):
class ParallelSentenceIterator (line 113) | class ParallelSentenceIterator(chainer.dataset.Iterator):
method __init__ (line 122) | def __init__(
method __next__ (line 165) | def __next__(self):
method start_shuffle (line 194) | def start_shuffle(self):
method epoch_detail (line 198) | def epoch_detail(self):
method previous_epoch_detail (line 203) | def previous_epoch_detail(self):
method serialize (line 208) | def serialize(self, serializer):
class MakeSymlinkToBestModel (line 227) | class MakeSymlinkToBestModel(extension.Extension):
method __init__ (line 235) | def __init__(self, key, prefix="model", suffix="best"):
method __call__ (line 243) | def __call__(self, trainer):
method serialize (line 257) | def serialize(self, serializer):
function make_lexical_tree (line 274) | def make_lexical_tree(word_dict, subword_dict, word_unk):
FILE: lm/pytorch_backend/extlm.py
class MultiLevelLM (line 18) | class MultiLevelLM(nn.Module):
method __init__ (line 22) | def __init__(
method forward (line 50) | def forward(self, state, x):
method final (line 107) | def final(self, state):
class LookAheadWordLM (line 118) | class LookAheadWordLM(nn.Module):
method __init__ (line 122) | def __init__(
method forward (line 142) | def forward(self, state, x):
method final (line 215) | def final(self, state):
FILE: lm/pytorch_backend/lm.py
function compute_perplexity (line 50) | def compute_perplexity(result):
class Reporter (line 63) | class Reporter(Chain):
method report (line 66) | def report(self, loss):
function concat_examples (line 71) | def concat_examples(batch, device=None, padding=None):
class BPTTUpdater (line 89) | class BPTTUpdater(training.StandardUpdater):
method __init__ (line 92) | def __init__(
method update_core (line 126) | def update_core(self):
class LMEvaluator (line 169) | class LMEvaluator(BaseEvaluator):
method __init__ (line 172) | def __init__(self, val_iter, eval_model, reporter, device):
method evaluate (line 185) | def evaluate(self):
function train (line 213) | def train(args):
FILE: mt/mt_utils.py
function parse_hypothesis (line 13) | def parse_hypothesis(hyp, char_list):
function add_results_to_json (line 35) | def add_results_to_json(js, nbest_hyps, char_list):
FILE: mt/pytorch_backend/mt.py
class CustomConverter (line 56) | class CustomConverter(object):
method __init__ (line 59) | def __init__(self):
method __call__ (line 68) | def __call__(self, batch, device=torch.device("cpu")):
function train (line 96) | def train(args):
function trans (line 514) | def trans(args):
FILE: nets/asr_interface.py
class ASRInterface (line 9) | class ASRInterface:
method add_arguments (line 13) | def add_arguments(parser):
method build (line 18) | def build(cls, idim: int, odim: int, **kwargs):
method forward (line 38) | def forward(self, xs, ilens, ys):
method recognize (line 55) | def recognize(self, x, recog_args, char_list=None, rnnlm=None):
method recognize_batch (line 67) | def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
method calculate_all_attentions (line 79) | def calculate_all_attentions(self, xs, ilens, ys):
method calculate_all_ctc_probs (line 90) | def calculate_all_ctc_probs(self, xs, ilens, ys):
method attention_plot_class (line 102) | def attention_plot_class(self):
method ctc_plot_class (line 109) | def ctc_plot_class(self):
method get_total_subsampling_factor (line 115) | def get_total_subsampling_factor(self):
method encode (line 121) | def encode(self, feat):
method scorers (line 133) | def scorers(self):
function dynamic_import_asr (line 157) | def dynamic_import_asr(module, backend):
FILE: nets/batch_beam_search.py
class BatchHypothesis (line 17) | class BatchHypothesis(NamedTuple):
method __len__ (line 26) | def __len__(self) -> int:
class BatchBeamSearch (line 31) | class BatchBeamSearch(BeamSearch):
method batchfy (line 34) | def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
method _batch_select (line 48) | def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> Batc...
method _select (line 60) | def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
method unbatchfy (line 70) | def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
method batch_beam (line 85) | def batch_beam(
method init_hyp (line 111) | def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
method score_full (line 137) | def score_full(
method score_partial (line 160) | def score_partial(
method merge_states (line 186) | def merge_states(self, states: Any, part_states: Any, part_idx: int) -...
method search (line 207) | def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> Ba...
method post_process (line 286) | def post_process(
FILE: nets/batch_beam_search_online_sim.py
class BatchBeamSearchOnlineSim (line 16) | class BatchBeamSearchOnlineSim(BatchBeamSearch):
method set_streaming_config (line 28) | def set_streaming_config(self, asr_config: str):
method set_block_size (line 72) | def set_block_size(self, block_size: int):
method set_hop_size (line 80) | def set_hop_size(self, hop_size: int):
method set_look_ahead (line 88) | def set_look_ahead(self, look_ahead: int):
method forward (line 96) | def forward(
method extend (line 255) | def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
FILE: nets/beam_search.py
class Hypothesis (line 19) | class Hypothesis(NamedTuple):
method asdict (line 27) | def asdict(self) -> dict:
method __str__ (line 35) | def __str__(self):
class BeamSearch (line 46) | class BeamSearch(object):
method __init__ (line 49) | def __init__(
method init_hyp (line 131) | def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
method append_token (line 156) | def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
method score_full (line 170) | def score_full(
method score_partial (line 193) | def score_partial(
method beam (line 217) | def beam(
method merge_scores (line 247) | def merge_scores(
method merge_states (line 278) | def merge_states(self, states: Any, part_states: Any, part_idx: int) -...
method search (line 299) | def search(
method __call__ (line 357) | def __call__(
method post_process (line 445) | def post_process(
function beam_search (line 496) | def beam_search(
FILE: nets/beam_search_transducer.py
class BeamSearchTransducer (line 25) | class BeamSearchTransducer:
method __init__ (line 28) | def __init__(
method __call__ (line 162) | def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NS...
method sort_nbest (line 186) | def sort_nbest(
method vocab_regularization (line 205) | def vocab_regularization(self, hyps):
method greedy_search (line 218) | def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
method ctc_greedy_search (line 247) | def ctc_greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
method ctc_beam_search (line 258) | def ctc_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
method default_beam_search (line 359) | def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
method time_sync_decoding (line 431) | def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
method align_length_sync_decoding (line 531) | def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothes...
method nsc_beam_search (line 740) | def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
function log_add (line 931) | def log_add(args: List[int]) -> float:
function is_all_chinese (line 941) | def is_all_chinese(strs):
FILE: nets/chainer_backend/asr_interface.py
class ChainerASRInterface (line 7) | class ChainerASRInterface(ASRInterface, chainer.Chain):
method custom_converter (line 11) | def custom_converter(*args, **kw):
method custom_updater (line 16) | def custom_updater(*args, **kw):
method custom_parallel_updater (line 21) | def custom_parallel_updater(*args, **kw):
method get_total_subsampling_factor (line 25) | def get_total_subsampling_factor(self):
FILE: nets/chainer_backend/ctc.py
class CTC (line 10) | class CTC(chainer.Chain):
method __init__ (line 20) | def __init__(self, odim, eprojs, dropout_rate):
method __call__ (line 28) | def __call__(self, hs, ys):
method log_softmax (line 72) | def log_softmax(self, hs):
class WarpCTC (line 87) | class WarpCTC(chainer.Chain):
method __init__ (line 97) | def __init__(self, odim, eprojs, dropout_rate):
method __call__ (line 105) | def __call__(self, hs, ys):
method log_softmax (line 140) | def log_softmax(self, hs):
method argmax (line 154) | def argmax(self, hs_pad):
function ctc_for (line 164) | def ctc_for(args, odim):
FILE: nets/chainer_backend/deterministic_embed_id.py
class EmbedIDFunction (line 22) | class EmbedIDFunction(function_node.FunctionNode):
method __init__ (line 23) | def __init__(self, ignore_label=None):
method check_type_forward (line 27) | def check_type_forward(self, in_types):
method forward (line 36) | def forward(self, inputs):
method backward (line 63) | def backward(self, indexes, grad_outputs):
class EmbedIDGrad (line 71) | class EmbedIDGrad(function_node.FunctionNode):
method __init__ (line 72) | def __init__(self, w_shape, ignore_label=None):
method forward (line 77) | def forward(self, inputs):
method backward (line 125) | def backward(self, indexes, grads):
function embed_id (line 147) | def embed_id(x, W, ignore_label=None):
class EmbedID (line 194) | class EmbedID(link.Link):
method __init__ (line 234) | def __init__(self, in_size, out_size, initialW=None, ignore_label=None):
method __call__ (line 243) | def __call__(self, x):
FILE: nets/chainer_backend/e2e_asr.py
class E2E (line 25) | class E2E(ChainerASRInterface):
method add_arguments (line 39) | def add_arguments(parser):
method get_total_subsampling_factor (line 43) | def get_total_subsampling_factor(self):
method __init__ (line 47) | def __init__(self, idim, odim, args, flag_return=True):
method forward (line 93) | def forward(self, xs, ilens, ys):
method recognize (line 147) | def recognize(self, x, recog_args, char_list, rnnlm=None):
method calculate_all_attentions (line 182) | def calculate_all_attentions(self, xs, ilens, ys):
method custom_converter (line 200) | def custom_converter(subsampling_factor=0):
method custom_updater (line 207) | def custom_updater(iters, optimizer, converter, device=-1, accum_grad=1):
method custom_parallel_updater (line 216) | def custom_parallel_updater(iters, optimizer, converter, devices, accu...
FILE: nets/chainer_backend/e2e_asr_transformer.py
class E2E (line 39) | class E2E(ChainerASRInterface):
method add_arguments (line 53) | def add_arguments(parser):
method get_total_subsampling_factor (line 142) | def get_total_subsampling_factor(self):
method __init__ (line 146) | def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
method reset_parameters (line 217) | def reset_parameters(self, args):
method forward (line 251) | def forward(self, xs, ilens, ys_pad, calculate_attentions=False):
method calculate_attentions (line 344) | def calculate_attentions(self, xs, x_mask, ys_pad):
method recognize (line 348) | def recognize(self, x_block, recog_args, char_list=None, rnnlm=None):
method recognize_beam (line 380) | def recognize_beam(self, h, lpz, recog_args, char_list=None, rnnlm=None):
method calculate_all_attentions (line 573) | def calculate_all_attentions(self, xs, ilens, ys):
method attention_plot_class (line 598) | def attention_plot_class(self):
method custom_converter (line 610) | def custom_converter(subsampling_factor=0):
method custom_updater (line 615) | def custom_updater(iters, optimizer, converter, device=-1, accum_grad=1):
method custom_parallel_updater (line 622) | def custom_parallel_updater(iters, optimizer, converter, devices, accu...
FILE: nets/chainer_backend/nets_utils.py
function _subsamplex (line 4) | def _subsamplex(x, n):
FILE: nets/chainer_backend/rnn/attentions.py
class AttDot (line 9) | class AttDot(chainer.Chain):
method __init__ (line 19) | def __init__(self, eprojs, dunits, att_dim):
method reset (line 32) | def reset(self):
method __call__ (line 38) | def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
class AttLoc (line 87) | class AttLoc(chainer.Chain):
method __init__ (line 99) | def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
method reset (line 118) | def reset(self):
method __call__ (line 124) | def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
class NoAtt (line 200) | class NoAtt(chainer.Chain):
method __init__ (line 208) | def __init__(self):
method reset (line 215) | def reset(self):
method __call__ (line 222) | def __call__(self, enc_hs, dec_z, att_prev):
function att_for (line 258) | def att_for(args):
FILE: nets/chainer_backend/rnn/decoders.py
class Decoder (line 21) | class Decoder(chainer.Chain):
method __init__ (line 42) | def __init__(
method rnn_forward (line 90) | def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
method __call__ (line 118) | def __call__(self, hs, ys):
method recognize_beam (line 221) | def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
method calculate_all_attentions (line 444) | def calculate_all_attentions(self, hs, ys):
function decoder_for (line 498) | def decoder_for(args, odim, sos, eos, att, labeldist):
FILE: nets/chainer_backend/rnn/encoders.py
class RNNP (line 16) | class RNNP(chainer.Chain):
method __init__ (line 30) | def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ=...
method __call__ (line 58) | def __call__(self, xs, ilens):
class RNN (line 94) | class RNN(chainer.Chain):
method __init__ (line 107) | def __init__(self, idim, elayers, cdim, hdim, dropout, typ="lstm"):
method __call__ (line 121) | def __call__(self, xs, ilens):
class VGG2L (line 155) | class VGG2L(chainer.Chain):
method __init__ (line 163) | def __init__(self, in_channel=1):
method __call__ (line 174) | def __call__(self, xs, ilens):
class Encoder (line 227) | class Encoder(chainer.Chain):
method __init__ (line 241) | def __init__(
method __call__ (line 293) | def __call__(self, xs, ilens):
function encoder_for (line 310) | def encoder_for(args, idim, subsample):
FILE: nets/chainer_backend/rnn/training.py
function sum_sqnorm (line 23) | def sum_sqnorm(arr):
class CustomUpdater (line 43) | class CustomUpdater(training.StandardUpdater):
method __init__ (line 68) | def __init__(self, train_iter, optimizer, converter, device, accum_gra...
method update_core (line 79) | def update_core(self):
method update (line 112) | def update(self):
class CustomParallelUpdater (line 118) | class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater):
method __init__ (line 145) | def __init__(self, train_iters, optimizer, converter, devices, accum_g...
method update_core (line 156) | def update_core(self):
method update (line 212) | def update(self):
class CustomConverter (line 218) | class CustomConverter(object):
method __init__ (line 226) | def __init__(self, subsampling_factor=1):
method __call__ (line 229) | def __call__(self, batch, device):
FILE: nets/chainer_backend/transformer/attention.py
class MultiHeadAttention (line 14) | class MultiHeadAttention(chainer.Chain):
method __init__ (line 30) | def __init__(self, n_units, h=8, dropout=0.1, initialW=None, initial_b...
method forward (line 65) | def forward(self, e_var, s_var=None, mask=None, batch=1):
FILE: nets/chainer_backend/transformer/ctc.py
class CTC (line 12) | class CTC(chainer.Chain):
method __init__ (line 22) | def __init__(self, odim, eprojs, dropout_rate):
method __call__ (line 31) | def __call__(self, hs, ys):
method log_softmax (line 75) | def log_softmax(self, hs):
class WarpCTC (line 90) | class WarpCTC(chainer.Chain):
method __init__ (line 100) | def __init__(self, odim, eprojs, dropout_rate):
method forward (line 117) | def forward(self, hs, ys):
method log_softmax (line 148) | def log_softmax(self, hs):
method argmax (line 162) | def argmax(self, hs_pad):
FILE: nets/chainer_backend/transformer/decoder.py
class Decoder (line 17) | class Decoder(chainer.Chain):
method __init__ (line 32) | def __init__(self, odim, args, initialW=None, initial_bias=None):
method make_attention_mask (line 70) | def make_attention_mask(self, source_block, target_block):
method forward (line 84) | def forward(self, ys_pad, source, x_mask):
method recognize (line 112) | def recognize(self, e, yy_mask, source):
FILE: nets/chainer_backend/transformer/decoder_layer.py
class DecoderLayer (line 15) | class DecoderLayer(chainer.Chain):
method __init__ (line 26) | def __init__(
method forward (line 58) | def forward(self, e, s, xy_mask, yy_mask, batch):
FILE: nets/chainer_backend/transformer/embedding.py
class PositionalEncoding (line 10) | class PositionalEncoding(chainer.Chain):
method __init__ (line 19) | def __init__(self, n_units, dropout=0.1, length=5000):
method forward (line 33) | def forward(self, e):
FILE: nets/chainer_backend/transformer/encoder.py
class Encoder (line 19) | class Encoder(chainer.Chain):
method __init__ (line 34) | def __init__(
method forward (line 104) | def forward(self, e, ilens):
FILE: nets/chainer_backend/transformer/encoder_layer.py
class EncoderLayer (line 15) | class EncoderLayer(chainer.Chain):
method __init__ (line 26) | def __init__(
method forward (line 51) | def forward(self, e, xx_mask, batch):
FILE: nets/chainer_backend/transformer/label_smoothing_loss.py
class LabelSmoothingLoss (line 11) | class LabelSmoothingLoss(chainer.Chain):
method __init__ (line 21) | def __init__(self, smoothing, n_target_vocab, normalize_length=False, ...
method forward (line 35) | def forward(self, ys_block, ys_pad):
FILE: nets/chainer_backend/transformer/layer_norm.py
class LayerNorm (line 7) | class LayerNorm(L.LayerNormalization):
method __init__ (line 10) | def __init__(self, dims, eps=1e-12):
method __call__ (line 14) | def __call__(self, e):
FILE: nets/chainer_backend/transformer/mask.py
function make_history_mask (line 4) | def make_history_mask(xp, block):
FILE: nets/chainer_backend/transformer/positionwise_feed_forward.py
class PositionwiseFeedForward (line 12) | class PositionwiseFeedForward(chainer.Chain):
method __init__ (line 22) | def __init__(
method __call__ (line 55) | def __call__(self, e):
FILE: nets/chainer_backend/transformer/subsampling.py
class Conv2dSubsampling (line 15) | class Conv2dSubsampling(chainer.Chain):
method __init__ (line 24) | def __init__(
method forward (line 63) | def forward(self, xs, ilens):
class LinearSampling (line 82) | class LinearSampling(chainer.Chain):
method __init__ (line 91) | def __init__(self, idim, dims, dropout=0.1, initialW=None, initial_bia...
method forward (line 105) | def forward(self, xs, ilens):
FILE: nets/chainer_backend/transformer/training.py
function sum_sqnorm (line 20) | def sum_sqnorm(arr):
class CustomUpdater (line 40) | class CustomUpdater(training.StandardUpdater):
method __init__ (line 65) | def __init__(self, train_iter, optimizer, converter, device, accum_gra...
method update_core (line 77) | def update_core(self):
method update (line 108) | def update(self):
class CustomParallelUpdater (line 115) | class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater):
method __init__ (line 141) | def __init__(self, train_iters, optimizer, converter, devices, accum_g...
method update_core (line 154) | def update_core(self):
method update (line 208) | def update(self):
class VaswaniRule (line 215) | class VaswaniRule(extension.Extension):
method __init__ (line 233) | def __init__(
method initialize (line 253) | def initialize(self, trainer):
method __call__ (line 264) | def __call__(self, trainer):
method serialize (line 273) | def serialize(self, serializer):
method _get_optimizer (line 278) | def _get_optimizer(self, trainer):
method _update_value (line 282) | def _update_value(self, optimizer, value):
class CustomConverter (line 288) | class CustomConverter(object):
method __init__ (line 296) | def __init__(self):
method __call__ (line 300) | def __call__(self, batch, device):
FILE: nets/ctc_prefix_score.py
class CTCPrefixScoreTH (line 12) | class CTCPrefixScoreTH(object):
method __init__ (line 23) | def __init__(self, x, xlens, blank, eos, margin=0):
method __call__ (line 69) | def __call__(self, y, state, scoring_ids=None, att_w=None):
method index_select_state (line 190) | def index_select_state(self, state, best_ids):
method extend_prob (line 223) | def extend_prob(self, x):
method extend_state (line 245) | def extend_state(self, state):
class CTCPrefixScore (line 273) | class CTCPrefixScore(object):
method __init__ (line 284) | def __init__(self, x, blank, eos, xp):
method initial_state (line 292) | def initial_state(self):
method __call__ (line 306) | def __call__(self, y, cs, r_prev):
FILE: nets/e2e_asr_common.py
function end_detect (line 19) | def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
function label_smoothing_dist (line 53) | def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
function get_vgg2l_odim (line 87) | def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
class ErrorCalculator (line 101) | class ErrorCalculator(object):
method __init__ (line 112) | def __init__(
method __call__ (line 130) | def __call__(self, ys_hat, ys_pad, is_ctc=False):
method calculate_cer_ctc (line 155) | def calculate_cer_ctc(self, ys_hat, ys_pad):
method convert_to_char (line 187) | def convert_to_char(self, ys_hat, ys_pad):
method calculate_cer (line 212) | def calculate_cer(self, seqs_hat, seqs_true):
method calculate_wer (line 229) | def calculate_wer(self, seqs_hat, seqs_true):
FILE: nets/e2e_mt_common.py
class ErrorCalculator (line 13) | class ErrorCalculator(object):
method __init__ (line 24) | def __init__(self, char_list, sym_space, sym_pad, report_bleu=False):
method __call__ (line 36) | def __call__(self, ys_hat, ys_pad):
method calculate_corpus_bleu (line 51) | def calculate_corpus_bleu(self, ys_hat, ys_pad):
FILE: nets/lm_interface.py
class LMInterface (line 10) | class LMInterface(ScorerInterface):
method add_arguments (line 14) | def add_arguments(parser):
method build (line 19) | def build(cls, n_vocab: int, **kwargs):
method forward (line 40) | def forward(self, x, t):
function dynamic_import_lm (line 71) | def dynamic_import_lm(module, backend):
FILE: nets/mt_interface.py
class MTInterface (line 8) | class MTInterface:
method add_arguments (line 12) | def add_arguments(parser):
method build (line 17) | def build(cls, idim: int, odim: int, **kwargs):
method forward (line 37) | def forward(self, xs, ilens, ys):
method translate (line 54) | def translate(self, x, trans_args, char_list=None, rnnlm=None):
method translate_batch (line 66) | def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
method calculate_all_attentions (line 78) | def calculate_all_attentions(self, xs, ilens, ys):
method attention_plot_class (line 90) | def attention_plot_class(self):
FILE: nets/pytorch_backend/conformer/argument.py
function add_arguments_conformer_common (line 11) | def add_arguments_conformer_common(group):
function verify_rel_pos_type (line 66) | def verify_rel_pos_type(args):
FILE: nets/pytorch_backend/conformer/convolution.py
class ConvolutionModule (line 13) | class ConvolutionModule(nn.Module):
method __init__ (line 22) | def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=T...
method forward (line 59) | def forward(self, x):
FILE: nets/pytorch_backend/conformer/encoder.py
class Encoder (line 35) | class Encoder(torch.nn.Module):
method __init__ (line 66) | def __init__(
method forward (line 222) | def forward(self, xs, masks):
FILE: nets/pytorch_backend/conformer/encoder_layer.py
class EncoderLayer (line 17) | class EncoderLayer(nn.Module):
method __init__ (line 42) | def __init__(
method forward (line 76) | def forward(self, x_input, mask, cache=None):
FILE: nets/pytorch_backend/conformer/swish.py
class Swish (line 13) | class Swish(torch.nn.Module):
method forward (line 16) | def forward(self, x):
FILE: nets/pytorch_backend/ctc.py
class CTC (line 12) | class CTC(torch.nn.Module):
method __init__ (line 22) | def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", red...
method loss_fn (line 69) | def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
method forward (line 89) | def forward(self, hs_pad, hlens, ys_pad, texts):
method softmax (line 158) | def softmax(self, hs_pad):
method log_softmax (line 168) | def log_softmax(self, hs_pad):
method argmax (line 177) | def argmax(self, hs_pad):
method forced_align (line 186) | def forced_align(self, h, y, blank_id=0):
function ctc_for (line 252) | def ctc_for(args, odim, reduce=True):
FILE: nets/pytorch_backend/e2e_asr.py
class Reporter (line 43) | class Reporter(chainer.Chain):
method report (line 46) | def report(self, loss_ctc, loss_att, loss_third, loss_mbr, acc, cer_ct...
class E2E (line 60) | class E2E(ASRInterface, torch.nn.Module):
method add_arguments (line 70) | def add_arguments(parser):
method encoder_add_arguments (line 78) | def encoder_add_arguments(parser):
method attention_add_arguments (line 85) | def attention_add_arguments(parser):
method decoder_add_arguments (line 92) | def decoder_add_arguments(parser):
method get_total_subsampling_factor (line 98) | def get_total_subsampling_factor(self):
method __init__ (line 105) | def __init__(self, idim, odim, args):
method init_like_chainer (line 186) | def init_like_chainer(self):
method forward (line 204) | def forward(self, xs_pad, ilens, ys_pad):
method scorers (line 339) | def scorers(self):
method encode (line 343) | def encode(self, x):
method recognize (line 371) | def recognize(self, x, recog_args, char_list, rnnlm=None):
method recognize_batch (line 393) | def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
method enhance (line 446) | def enhance(self, xs):
method calculate_all_attentions (line 468) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
method calculate_all_ctc_probs (line 496) | def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad):
method subsample_frames (line 526) | def subsample_frames(self, x):
FILE: nets/pytorch_backend/e2e_asr_conformer.py
class E2E (line 21) | class E2E(E2ETransformer):
method add_arguments (line 31) | def add_arguments(parser):
method add_conformer_arguments (line 38) | def add_conformer_arguments(parser):
method __init__ (line 44) | def __init__(self, idim, odim, args, ignore_id=-1):
FILE: nets/pytorch_backend/e2e_asr_maskctc.py
class E2E (line 32) | class E2E(E2ETransformer):
method add_arguments (line 42) | def add_arguments(parser):
method add_maskctc_arguments (line 50) | def add_maskctc_arguments(parser):
method __init__ (line 63) | def __init__(self, idim, odim, args, ignore_id=-1):
method forward (line 102) | def forward(self, xs_pad, ilens, ys_pad):
method recognize (line 174) | def recognize(self, x, recog_args, char_list=None, rnnlm=None):
FILE: nets/pytorch_backend/e2e_asr_mix.py
class PIT (line 48) | class PIT(object):
method __init__ (line 54) | def __init__(self, num_spkrs):
method min_pit_sample (line 71) | def min_pit_sample(self, loss):
method pit_process (line 93) | def pit_process(self, losses):
method permutationDFS (line 110) | def permutationDFS(self, source, start):
class E2E (line 131) | class E2E(ASRInterface, torch.nn.Module):
method add_arguments (line 140) | def add_arguments(parser):
method encoder_mix_add_arguments (line 149) | def encoder_mix_add_arguments(parser):
method get_total_subsampling_factor (line 168) | def get_total_subsampling_factor(self):
method __init__ (line 172) | def __init__(self, idim, odim, args):
method init_like_chainer (line 255) | def init_like_chainer(self):
method forward (line 274) | def forward(self, xs_pad, ilens, ys_pad):
method recognize (line 491) | def recognize(self, x, recog_args, char_list, rnnlm=None):
method recognize_batch (line 547) | def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
method enhance (line 610) | def enhance(self, xs):
method calculate_all_attentions (line 638) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
class EncoderMix (line 697) | class EncoderMix(torch.nn.Module):
method __init__ (line 714) | def __init__(
method forward (line 782) | def forward(self, xs_pad, ilens):
function encoder_for (line 811) | def encoder_for(args, idim, subsample):
FILE: nets/pytorch_backend/e2e_asr_mix_transformer.py
class E2E (line 43) | class E2E(E2EASR, ASRInterface, torch.nn.Module):
method add_arguments (line 52) | def add_arguments(parser):
method __init__ (line 58) | def __init__(self, idim, odim, args, ignore_id=-1):
method forward (line 92) | def forward(self, xs_pad, ilens, ys_pad):
method decoder_and_attention (line 212) | def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
method encode (line 226) | def encode(self, x):
method recog (line 238) | def recog(self, enc_output, recog_args, char_list=None, rnnlm=None, us...
method recognize (line 443) | def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit...
FILE: nets/pytorch_backend/e2e_asr_mulenc.py
class Reporter (line 36) | class Reporter(chainer.Chain):
method report (line 39) | def report(self, loss_ctc_list, loss_att, acc, cer_ctc_list, cer, wer,...
class E2E (line 58) | class E2E(ASRInterface, torch.nn.Module):
method add_arguments (line 68) | def add_arguments(parser):
method encoder_add_arguments (line 77) | def encoder_add_arguments(parser):
method attention_add_arguments (line 132) | def attention_add_arguments(parser):
method decoder_add_arguments (line 250) | def decoder_add_arguments(parser):
method ctc_add_arguments (line 290) | def ctc_add_arguments(parser):
method get_total_subsampling_factor (line 314) | def get_total_subsampling_factor(self):
method __init__ (line 325) | def __init__(self, idims, odim, args):
method init_like_chainer (line 438) | def init_like_chainer(self):
method forward (line 484) | def forward(self, xs_pad_list, ilens_list, ys_pad):
method scorers (line 673) | def scorers(self):
method encode (line 682) | def encode(self, x_list):
method recognize (line 717) | def recognize(self, x_list, recog_args, char_list, rnnlm=None):
method recognize_batch (line 748) | def recognize_batch(self, xs_list, recog_args, char_list, rnnlm=None):
method calculate_all_attentions (line 821) | def calculate_all_attentions(self, xs_pad_list, ilens_list, ys_pad):
method calculate_all_ctc_probs (line 859) | def calculate_all_ctc_probs(self, xs_pad_list, ilens_list, ys_pad):
FILE: nets/pytorch_backend/e2e_asr_transducer.py
class Reporter (line 62) | class Reporter(chainer.Chain):
method report (line 65) | def report(
class E2E (line 95) | class E2E(ASRInterface, torch.nn.Module):
method add_arguments (line 108) | def add_arguments(parser):
method att_scorer_arguments (line 126) | def att_scorer_arguments(parser):
method encoder_add_general_arguments (line 134) | def encoder_add_general_arguments(parser):
method encoder_add_rnn_arguments (line 142) | def encoder_add_rnn_arguments(parser):
method encoder_add_custom_arguments (line 150) | def encoder_add_custom_arguments(parser):
method decoder_add_general_arguments (line 158) | def decoder_add_general_arguments(parser):
method decoder_add_rnn_arguments (line 166) | def decoder_add_rnn_arguments(parser):
method decoder_add_custom_arguments (line 174) | def decoder_add_custom_arguments(parser):
method training_add_custom_arguments (line 182) | def training_add_custom_arguments(parser):
method transducer_add_arguments (line 190) | def transducer_add_arguments(parser):
method auxiliary_task_add_arguments (line 198) | def auxiliary_task_add_arguments(parser):
method attention_plot_class (line 206) | def attention_plot_class(self):
method get_total_subsampling_factor (line 210) | def get_total_subsampling_factor(self):
method __init__ (line 219) | def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, trainin...
method default_parameters (line 472) | def default_parameters(self, args):
method forward (line 481) | def forward(self, xs_pad, ilens, ys_pad, texts, xs_pad_orig):
method mbr_forward (line 664) | def mbr_forward(self, xs_pad_orig, ilens, ys_pad, hs_pad):
method compute_edit_distance (line 776) | def compute_edit_distance(self, hyps, refs):
method encode_custom (line 796) | def encode_custom(self, x):
method encode_rnn (line 811) | def encode_rnn(self, x):
method recognize (line 833) | def recognize(self, x, beam_search, decode_feature="combine"):
method calculate_all_attentions (line 856) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad, texts, xs_pa...
FILE: nets/pytorch_backend/e2e_asr_transducer_cs.py
class Reporter (line 47) | class Reporter(chainer.Chain):
method report (line 50) | def report(
class E2E (line 81) | class E2E(ASRInterface, torch.nn.Module):
method add_arguments (line 94) | def add_arguments(parser):
method encoder_add_general_arguments (line 110) | def encoder_add_general_arguments(parser):
method encoder_add_custom_arguments (line 118) | def encoder_add_custom_arguments(parser):
method decoder_add_general_arguments (line 126) | def decoder_add_general_arguments(parser):
method decoder_add_rnn_arguments (line 134) | def decoder_add_rnn_arguments(parser):
method training_add_custom_arguments (line 142) | def training_add_custom_arguments(parser):
method transducer_add_arguments (line 150) | def transducer_add_arguments(parser):
method transducer_add_code_switch_arguments (line 158) | def transducer_add_code_switch_arguments(parser):
method auxiliary_task_add_arguments (line 166) | def auxiliary_task_add_arguments(parser):
method get_total_subsampling_factor (line 173) | def get_total_subsampling_factor(self):
method __init__ (line 184) | def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, trainin...
method default_parameters (line 324) | def default_parameters(self, args):
method forward (line 334) | def forward(self, xs_pad, ilens, ys_pad, texts, xs_pad_orig):
method combine_fn (line 439) | def combine_fn(self, chn_hs_pad, eng_hs_pad, chn_hs_mask, eng_hs_mask):
method monolingual_mask (line 442) | def monolingual_mask(self, ys_pad):
method encoder_forward (line 455) | def encoder_forward(self, x):
method recognize (line 478) | def recognize(self, x, beam_search=None, decode_feature="combine"):
method calculate_all_attentions (line 493) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad, texts, xs_pa...
FILE: nets/pytorch_backend/e2e_asr_transformer.py
class E2E (line 60) | class E2E(ASRInterface, torch.nn.Module):
method add_arguments (line 70) | def add_arguments(parser):
method attention_plot_class (line 79) | def attention_plot_class(self):
method get_total_subsampling_factor (line 83) | def get_total_subsampling_factor(self):
method __init__ (line 87) | def __init__(self, idim, odim, args, ignore_id=-1):
method reset_parameters (line 218) | def reset_parameters(self, args):
method forward (line 223) | def forward(self, xs_pad, ilens, ys_pad, texts, xs_pad_orig):
method mbr_forward (line 328) | def mbr_forward(self, xs_pad_orig, ilens, ys_pad, hs_pad, hs_mask):
method compute_edit_distance (line 392) | def compute_edit_distance(self, hyps, refs):
method scorers (line 412) | def scorers(self):
method encode (line 416) | def encode(self, x):
method recognize (line 428) | def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit...
method calculate_all_attentions (line 755) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad, texts, xs_pa...
method calculate_all_ctc_probs (line 781) | def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, texts, xs_pad...
FILE: nets/pytorch_backend/e2e_mt.py
class Reporter (line 34) | class Reporter(chainer.Chain):
method report (line 37) | def report(self, loss, acc, ppl, bleu):
class E2E (line 45) | class E2E(MTInterface, torch.nn.Module):
method add_arguments (line 55) | def add_arguments(parser):
method encoder_add_arguments (line 63) | def encoder_add_arguments(parser):
method attention_add_arguments (line 70) | def attention_add_arguments(parser):
method decoder_add_arguments (line 77) | def decoder_add_arguments(parser):
method __init__ (line 83) | def __init__(self, idim, odim, args):
method init_like_fairseq (line 190) | def init_like_fairseq(self):
method forward (line 203) | def forward(self, xs_pad, ilens, ys_pad):
method target_language_biasing (line 260) | def target_language_biasing(self, xs_pad, ilens, ys_pad):
method translate (line 284) | def translate(self, x, trans_args, char_list, rnnlm=None):
method translate_batch (line 319) | def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
method calculate_all_attentions (line 353) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
FILE: nets/pytorch_backend/e2e_mt_transformer.py
class E2E (line 38) | class E2E(MTInterface, torch.nn.Module):
method add_arguments (line 48) | def add_arguments(parser):
method attention_plot_class (line 55) | def attention_plot_class(self):
method __init__ (line 59) | def __init__(self, idim, odim, args, ignore_id=-1):
method reset_parameters (line 140) | def reset_parameters(self, args):
method forward (line 152) | def forward(self, xs_pad, ilens, ys_pad):
method scorers (line 203) | def scorers(self):
method encode (line 207) | def encode(self, xs):
method target_forcing (line 214) | def target_forcing(self, xs_pad, ys_pad=None, tgt_lang=None):
method translate (line 242) | def translate(self, x, trans_args, char_list=None):
method calculate_all_attentions (line 393) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
FILE: nets/pytorch_backend/e2e_st.py
class Reporter (line 46) | class Reporter(chainer.Chain):
method report (line 49) | def report(
class E2E (line 78) | class E2E(STInterface, torch.nn.Module):
method add_arguments (line 88) | def add_arguments(parser):
method encoder_add_arguments (line 96) | def encoder_add_arguments(parser):
method attention_add_arguments (line 103) | def attention_add_arguments(parser):
method decoder_add_arguments (line 110) | def decoder_add_arguments(parser):
method get_total_subsampling_factor (line 116) | def get_total_subsampling_factor(self):
method __init__ (line 120) | def __init__(self, idim, odim, args):
method init_like_chainer (line 263) | def init_like_chainer(self):
method forward (line 281) | def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
method forward_asr (line 387) | def forward_asr(self, hs_pad, hlens, ys_pad):
method forward_mt (line 502) | def forward_mt(self, xs_pad, ys_pad):
method scorers (line 527) | def scorers(self):
method encode (line 531) | def encode(self, x):
method translate (line 552) | def translate(self, x, trans_args, char_list, rnnlm=None):
method translate_batch (line 571) | def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
method calculate_all_attentions (line 603) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src):
method calculate_all_ctc_probs (line 633) | def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src):
method subsample_frames (line 658) | def subsample_frames(self, x):
FILE: nets/pytorch_backend/e2e_st_conformer.py
class E2E (line 20) | class E2E(E2ETransformer):
method add_arguments (line 30) | def add_arguments(parser):
method add_conformer_arguments (line 37) | def add_conformer_arguments(parser):
method __init__ (line 43) | def __init__(self, idim, odim, args, ignore_id=-1):
FILE: nets/pytorch_backend/e2e_st_transformer.py
class E2E (line 41) | class E2E(STInterface, torch.nn.Module):
method add_arguments (line 51) | def add_arguments(parser):
method attention_plot_class (line 58) | def attention_plot_class(self):
method get_total_subsampling_factor (line 62) | def get_total_subsampling_factor(self):
method __init__ (line 66) | def __init__(self, idim, odim, args, ignore_id=-1):
method reset_parameters (line 183) | def reset_parameters(self, args):
method forward (line 192) | def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
method forward_asr (line 283) | def forward_asr(self, hs_pad, hs_mask, ys_pad):
method forward_mt (line 343) | def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask):
method scorers (line 375) | def scorers(self):
method encode (line 379) | def encode(self, x):
method translate (line 391) | def translate(
method calculate_all_attentions (line 540) | def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src):
method calculate_all_ctc_probs (line 563) | def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src):
FILE: nets/pytorch_backend/e2e_tts_fastspeech.py
class FeedForwardTransformerLoss (line 34) | class FeedForwardTransformerLoss(torch.nn.Module):
method __init__ (line 37) | def __init__(self, use_masking=True, use_weighted_masking=False):
method forward (line 57) | def forward(self, after_outs, before_outs, d_outs, ys, ds, ilens, olens):
class FeedForwardTransformer (line 112) | class FeedForwardTransformer(TTSInterface, torch.nn.Module):
method add_arguments (line 128) | def add_arguments(parser):
method __init__ (line 377) | def __init__(self, idim, odim, args=None):
method _forward (line 567) | def _forward(
method forward (line 627) | def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *arg...
method calculate_all_attentions (line 686) | def calculate_all_attentions(
method inference (line 741) | def inference(self, x, inference_args, spemb=None, *args, **kwargs):
method _integrate_with_spk_embed (line 777) | def _integrate_with_spk_embed(self, hs, spembs):
method _source_mask (line 801) | def _source_mask(self, ilens):
method _load_teacher_model (line 822) | def _load_teacher_model(self, model_path):
method _reset_parameters (line 844) | def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_al...
method _transfer_from_teacher (line 853) | def _transfer_from_teacher(self, transferred_encoder_module):
method attention_plot_class (line 874) | def attention_plot_class(self):
method base_plot_keys (line 882) | def base_plot_keys(self):
FILE: nets/pytorch_backend/e2e_tts_tacotron2.py
class GuidedAttentionLoss (line 25) | class GuidedAttentionLoss(torch.nn.Module):
method __init__ (line 39) | def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
method _reset_masks (line 56) | def _reset_masks(self):
method forward (line 60) | def forward(self, att_ws, ilens, olens):
method _make_guided_attention_masks (line 84) | def _make_guided_attention_masks(self, ilens, olens):
method _make_guided_attention_mask (line 96) | def _make_guided_attention_mask(ilen, olen, sigma):
method _make_masks (line 128) | def _make_masks(ilens, olens):
class Tacotron2Loss (line 166) | class Tacotron2Loss(torch.nn.Module):
method __init__ (line 169) | def __init__(
method forward (line 198) | def forward(self, after_outs, before_outs, logits, ys, labels, olens):
method _load_state_dict_pre_hook (line 249) | def _load_state_dict_pre_hook(
class Tacotron2 (line 273) | class Tacotron2(TTSInterface, torch.nn.Module):
method add_arguments (line 288) | def add_arguments(parser):
method __init__ (line 519) | def __init__(self, idim, odim, args=None):
method forward (line 702) | def forward(
method inference (line 790) | def inference(self, x, inference_args, spemb=None, *args, **kwargs):
method calculate_all_attentions (line 838) | def calculate_all_attentions(
method base_plot_keys (line 875) | def base_plot_keys(self):
FILE: nets/pytorch_backend/e2e_tts_transformer.py
class GuidedMultiHeadAttentionLoss (line 31) | class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
method forward (line 42) | def forward(self, att_ws, ilens, olens):
class TTSPlot (line 77) | class TTSPlot(PlotAttentionReport):
method plotfn (line 80) | def plotfn(
class Transformer (line 115) | class Transformer(TTSInterface, torch.nn.Module):
method add_arguments (line 129) | def add_arguments(parser):
method attention_plot_class (line 438) | def attention_plot_class(self):
method __init__ (line 442) | def __init__(self, idim, odim, args=None):
method _reset_parameters (line 677) | def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_al...
method _add_first_frame_and_remove_last_frame (line 686) | def _add_first_frame_and_remove_last_frame(self, ys):
method forward (line 692) | def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **...
method inference (line 841) | def inference(self, x, inference_args, spemb=None, *args, **kwargs):
method calculate_all_attentions (line 942) | def calculate_all_attentions(
method _integrate_with_spk_embed (line 1053) | def _integrate_with_spk_embed(self, hs, spembs):
method _source_mask (line 1077) | def _source_mask(self, ilens):
method _target_mask (line 1098) | def _target_mask(self, olens):
method base_plot_keys (line 1129) | def base_plot_keys(self):
FILE: nets/pytorch_backend/e2e_vc_tacotron2.py
class Tacotron2 (line 29) | class Tacotron2(TTSInterface, torch.nn.Module):
method add_arguments (line 38) | def add_arguments(parser):
method __init__ (line 278) | def __init__(self, idim, odim, args=None):
method forward (line 491) | def forward(
method inference (line 658) | def inference(self, x, inference_args, spemb=None, *args, **kwargs):
method calculate_all_attentions (line 706) | def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, ...
method base_plot_keys (line 752) | def base_plot_keys(self):
method _sort_by_length (line 776) | def _sort_by_length(self, xs, ilens):
method _revert_sort_by_length (line 780) | def _revert_sort_by_length(self, xs, ilens, sort_idx):
FILE: nets/pytorch_backend/e2e_vc_transformer.py
class Transformer (line 34) | class Transformer(TTSInterface, torch.nn.Module):
method add_arguments (line 53) | def add_arguments(parser):
method attention_plot_class (line 372) | def attention_plot_class(self):
method __init__ (line 376) | def __init__(self, idim, odim, args=None):
method _reset_parameters (line 619) | def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_al...
method _add_first_frame_and_remove_last_frame (line 628) | def _add_first_frame_and_remove_last_frame(self, ys):
method forward (line 634) | def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **...
method inference (line 814) | def inference(self, x, inference_args, spemb=None, *args, **kwargs):
method calculate_all_attentions (line 928) | def calculate_all_attentions(
method _integrate_with_spk_embed (line 1055) | def _integrate_with_spk_embed(self, hs, spembs):
method _source_mask (line 1079) | def _source_mask(self, ilens):
method _target_mask (line 1100) | def _target_mask(self, olens):
method base_plot_keys (line 1131) | def base_plot_keys(self):
FILE: nets/pytorch_backend/fastspeech/duration_calculator.py
class DurationCalculator (line 16) | class DurationCalculator(torch.nn.Module):
method __init__ (line 24) | def __init__(self, teacher_model):
method forward (line 44) | def forward(self, xs, ilens, ys, olens, spembs=None):
method _calculate_duration (line 82) | def _calculate_duration(att_w, ilen, olen):
method _init_diagonal_head (line 87) | def _init_diagonal_head(self, att_ws):
method _calculate_encoder_decoder_attentions (line 91) | def _calculate_encoder_decoder_attentions(self, xs, ilens, ys, olens, ...
FILE: nets/pytorch_backend/fastspeech/duration_predictor.py
class DurationPredictor (line 14) | class DurationPredictor(torch.nn.Module):
method __init__ (line 33) | def __init__(
method _forward (line 68) | def _forward(self, xs, x_masks=None, is_inference=False):
method forward (line 87) | def forward(self, xs, x_masks=None):
method inference (line 101) | def inference(self, xs, x_masks=None):
class DurationPredictorLoss (line 116) | class DurationPredictorLoss(torch.nn.Module):
method __init__ (line 123) | def __init__(self, offset=1.0, reduction="mean"):
method forward (line 135) | def forward(self, outputs, targets):
FILE: nets/pytorch_backend/fastspeech/length_regulator.py
class LengthRegulator (line 20) | class LengthRegulator(torch.nn.Module):
method __init__ (line 34) | def __init__(self, pad_value=0.0):
method forward (line 48) | def forward(self, xs, ds, alpha=1.0):
method _repeat_one_sequence (line 76) | def _repeat_one_sequence(self, x, d):
method _legacy_repeat_one_sequence (line 80) | def _legacy_repeat_one_sequence(self, x, d):
FILE: nets/pytorch_backend/frontends/beamformer.py
function get_power_spectral_density_matrix (line 6) | def get_power_spectral_density_matrix(
function get_mvdr_vector (line 40) | def get_mvdr_vector(
function apply_beamforming_vector (line 79) | def apply_beamforming_vector(
FILE: nets/pytorch_backend/frontends/dnn_beamformer.py
class DNN_Beamformer (line 19) | class DNN_Beamformer(torch.nn.Module):
method __init__ (line 28) | def __init__(
method forward (line 56) | def forward(
class AttentionReference (line 139) | class AttentionReference(torch.nn.Module):
method __init__ (line 140) | def __init__(self, bidim, att_dim):
method forward (line 145) | def forward(
FILE: nets/pytorch_backend/frontends/dnn_wpe.py
class DNN_WPE (line 11) | class DNN_WPE(torch.nn.Module):
method __init__ (line 12) | def __init__(
method forward (line 41) | def forward(
FILE: nets/pytorch_backend/frontends/feature_transform.py
class FeatureTransform (line 13) | class FeatureTransform(torch.nn.Module):
method __init__ (line 14) | def __init__(
method forward (line 45) | def forward(
class LogMel (line 78) | class LogMel(torch.nn.Module):
method __init__ (line 98) | def __init__(
method extra_repr (line 120) | def extra_repr(self):
method forward (line 123) | def forward(
class GlobalMVN (line 135) | class GlobalMVN(torch.nn.Module):
method __init__ (line 149) | def __init__(
method extra_repr (line 174) | def extra_repr(self):
method forward (line 180) | def forward(
class UtteranceMVN (line 193) | class UtteranceMVN(torch.nn.Module):
method __init__ (line 194) | def __init__(
method extra_repr (line 202) | def extra_repr(self):
method forward (line 205) | def forward(
function utterance_mvn (line 213) | def utterance_mvn(
function feature_transform_for (line 250) | def feature_transform_for(args, n_fft):
FILE: nets/pytorch_backend/frontends/frontend.py
class Frontend (line 15) | class Frontend(nn.Module):
method __init__ (line 16) | def __init__(
method forward (line 88) | def forward(
function frontend_for (line 128) | def frontend_for(args, idim):
FILE: nets/pytorch_backend/frontends/mask_estimator.py
class MaskEstimator (line 13) | class MaskEstimator(torch.nn.Module):
method __init__ (line 14) | def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
method forward (line 30) | def forward(
FILE: nets/pytorch_backend/gtn_ctc.py
class GTNCTCLossFunction (line 10) | class GTNCTCLossFunction(torch.autograd.Function):
method create_ctc_graph (line 17) | def create_ctc_graph(target, blank_idx):
method forward (line 41) | def forward(ctx, log_probs, targets, blank_idx=0, reduction="none"):
method backward (line 87) | def backward(ctx, grad_output):
FILE: nets/pytorch_backend/initialization.py
function lecun_normal_init_parameters (line 11) | def lecun_normal_init_parameters(module):
function uniform_init_parameters (line 34) | def uniform_init_parameters(module):
function set_forget_bias_to_one (line 51) | def set_forget_bias_to_one(bias):
FILE: nets/pytorch_backend/lm/default.py
class DefaultRNNLM (line 18) | class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module):
method add_arguments (line 29) | def add_arguments(parser):
method __init__ (line 69) | def __init__(self, n_vocab, args):
method state_dict (line 100) | def state_dict(self):
method load_state_dict (line 104) | def load_state_dict(self, d):
method forward (line 108) | def forward(self, x, t):
method score (line 140) | def score(self, y, state, x):
method final_score (line 157) | def final_score(self, state):
method batch_score (line 170) | def batch_score(
class ClassifierWithState (line 218) | class ClassifierWithState(nn.Module):
method __init__ (line 221) | def __init__(
method forward (line 240) | def forward(self, state, *args, **kwargs):
method predict (line 282) | def predict(self, state, x):
method buff_predict (line 296) | def buff_predict(self, state, x, n):
method final (line 311) | def final(self, state, index=None):
class RNNLM (line 328) | class RNNLM(nn.Module):
method __init__ (line 331) | def __init__(
method zero_state (line 393) | def zero_state(self, batchsize):
method forward (line 398) | def forward(self, state, x):
FILE: nets/pytorch_backend/lm/seq_rnn.py
class SequentialRNNLM (line 10) | class SequentialRNNLM(LMInterface, torch.nn.Module):
method add_arguments (line 19) | def add_arguments(parser):
method __init__ (line 40) | def __init__(self, n_vocab, args):
method _setup (line 58) | def _setup(
method _init_weights (line 98) | def _init_weights(self):
method forward (line 108) | def forward(self, x, t):
method _before_loss (line 134) | def _before_loss(self, input, hidden):
method init_state (line 143) | def init_state(self, x):
method score (line 162) | def score(self, y, state, x):
FILE: nets/pytorch_backend/lm/transformer.py
class TransformerLM (line 20) | class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
method add_arguments (line 24) | def add_arguments(parser):
method __init__ (line 79) | def __init__(self, n_vocab, args):
method _target_mask (line 137) | def _target_mask(self, ys_in_pad):
method forward (line 142) | def forward(
method score (line 178) | def score(
method score_partial (line 207) | def score_partial(
method select_state (line 214) | def select_state(self, states, i):
method batch_score (line 218) | def batch_score(
FILE: nets/pytorch_backend/maskctc/add_mask_token.py
function mask_uniform (line 13) | def mask_uniform(ys_pad, mask_token, eos, ignore_id):
FILE: nets/pytorch_backend/maskctc/mask.py
function square_mask (line 11) | def square_mask(ys_in_pad, ignore_id):
FILE: nets/pytorch_backend/nets_utils.py
function to_device (line 12) | def to_device(m, x):
function pad_list (line 34) | def pad_list(xs, pad_value):
function make_pad_mask (line 64) | def make_pad_mask(lengths, xs=None, length_dim=-1):
function make_non_pad_mask (line 179) | def make_non_pad_mask(lengths, xs=None, length_dim=-1):
function mask_by_length (line 268) | def mask_by_length(xs, lengths, fill=0):
function th_accuracy (line 299) | def th_accuracy(pad_outputs, pad_targets, ignore_label):
function to_torch_tensor (line 322) | def to_torch_tensor(x):
function get_subsample (line 390) | def get_subsample(train_args, mode, arch):
function rename_state_dict (line 471) | def rename_state_dict(
function get_activation (line 485) | def get_activation(act):
FILE: nets/pytorch_backend/rnn/argument.py
function add_arguments_rnn_encoder_common (line 7) | def add_arguments_rnn_encoder_common(group):
function add_arguments_rnn_decoder_common (line 60) | def add_arguments_rnn_decoder_common(group):
function add_arguments_rnn_attention_common (line 99) | def add_arguments_rnn_attention_common(group):
FILE: nets/pytorch_backend/rnn/attentions.py
function _apply_attention_constraint (line 13) | def _apply_attention_constraint(
class NoAtt (line 46) | class NoAtt(torch.nn.Module):
method __init__ (line 49) | def __init__(self):
method reset (line 56) | def reset(self):
method forward (line 63) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
class AttDot (line 94) | class AttDot(torch.nn.Module):
method __init__ (line 104) | def __init__(self, eprojs, dunits, att_dim, han_mode=False):
method reset (line 118) | def reset(self):
method forward (line 125) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
class AttAdd (line 171) | class AttAdd(torch.nn.Module):
method __init__ (line 181) | def __init__(self, eprojs, dunits, att_dim, han_mode=False):
method reset (line 195) | def reset(self):
method forward (line 202) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
class AttLoc (line 250) | class AttLoc(torch.nn.Module):
method __init__ (line 265) | def __init__(
method reset (line 290) | def reset(self):
method forward (line 297) | def forward(
class AttCov (line 383) | class AttCov(torch.nn.Module):
method __init__ (line 396) | def __init__(self, eprojs, dunits, att_dim, han_mode=False):
method reset (line 412) | def reset(self):
method forward (line 419) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scalin...
class AttLoc2D (line 485) | class AttLoc2D(torch.nn.Module):
method __init__ (line 502) | def __init__(
method reset (line 529) | def reset(self):
method forward (line 536) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
class AttLocRec (line 606) | class AttLocRec(torch.nn.Module):
method __init__ (line 622) | def __init__(
method reset (line 647) | def reset(self):
method forward (line 654) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scal...
class AttCovLoc (line 729) | class AttCovLoc(torch.nn.Module):
method __init__ (line 743) | def __init__(
method reset (line 769) | def reset(self):
method forward (line 776) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scalin...
class AttMultiHeadDot (line 845) | class AttMultiHeadDot(torch.nn.Module):
method __init__ (line 860) | def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_m...
method reset (line 883) | def reset(self):
method forward (line 891) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
class AttMultiHeadAdd (line 958) | class AttMultiHeadAdd(torch.nn.Module):
method __init__ (line 975) | def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_m...
method reset (line 1000) | def reset(self):
method forward (line 1008) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
class AttMultiHeadLoc (line 1075) | class AttMultiHeadLoc(torch.nn.Module):
method __init__ (line 1094) | def __init__(
method reset (line 1141) | def reset(self):
method forward (line 1149) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
class AttMultiHeadMultiResLoc (line 1232) | class AttMultiHeadMultiResLoc(torch.nn.Module):
method __init__ (line 1254) | def __init__(
method reset (line 1298) | def reset(self):
method forward (line 1306) | def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
class AttForward (line 1388) | class AttForward(torch.nn.Module):
method __init__ (line 1402) | def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
method reset (line 1423) | def reset(self):
method forward (line 1430) | def forward(
class AttForwardTA (line 1518) | class AttForwardTA(torch.nn.Module):
method __init__ (line 1533) | def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, ...
method reset (line 1556) | def reset(self):
method forward (line 1563) | def forward(
function att_for (line 1661) | def att_for(args, num_att=1, han_mode=False):
function initial_att (line 1726) | def initial_att(
function att_to_numpy (line 1774) | def att_to_numpy(att_ws, att):
FILE: nets/pytorch_backend/rnn/decoders.py
class Decoder (line 29) | class Decoder(torch.nn.Module, ScorerInterface):
method __init__ (line 50) | def __init__(
method zero_state (line 124) | def zero_state(self, hs_pad):
method rnn_forward (line 127) | def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
method forward (line 142) | def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
method recognize_beam (line 313) | def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, st...
method recognize_beam_batch (line 632) | def recognize_beam_batch(
method calculate_all_attentions (line 969) | def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, l...
method _get_last_yseq (line 1080) | def _get_last_yseq(exp_yseq):
method _append_ids (line 1087) | def _append_ids(yseq, ids):
method _index_select_list (line 1097) | def _index_select_list(yseq, lst):
method _index_select_lm_state (line 1104) | def _index_select_lm_state(rnnlm_state, dim, vidx):
method init_state (line 1116) | def init_state(self, x):
method score (line 1144) | def score(self, yseq, state, x):
function decoder_for (line 1199) | def decoder_for(args, odim, sos, eos, att, labeldist):
FILE: nets/pytorch_backend/rnn/encoders.py
class RNNP (line 15) | class RNNP(torch.nn.Module):
method __init__ (line 27) | def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ=...
method forward (line 56) | def forward(self, xs_pad, ilens, prev_state=None):
class RNN (line 95) | class RNN(torch.nn.Module):
method __init__ (line 106) | def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
method forward (line 134) | def forward(self, xs_pad, ilens, prev_state=None):
function reset_backward_rnn_state (line 165) | def reset_backward_rnn_state(states):
class VGG2L (line 178) | class VGG2L(torch.nn.Module):
method __init__ (line 184) | def __init__(self, in_channel=1):
method forward (line 194) | def forward(self, xs_pad, ilens, **kwargs):
class Encoder (line 240) | class Encoder(torch.nn.Module):
method __init__ (line 253) | def __init__(
method forward (line 307) | def forward(self, xs_pad, ilens, prev_states=None):
function encoder_for (line 331) | def encoder_for(args, idim, subsample):
FILE: nets/pytorch_backend/streaming/segment.py
class SegmentStreamingE2E (line 5) | class SegmentStreamingE2E(object):
method __init__ (line 12) | def __init__(self, e2e, recog_args, rnnlm=None):
method accept_input (line 42) | def accept_input(self, x):
FILE: nets/pytorch_backend/streaming/window.py
class WindowStreamingE2E (line 6) | class WindowStreamingE2E(object):
method __init__ (line 13) | def __init__(self, e2e, recog_args, rnnlm=None):
method accept_input (line 31) | def accept_input(self, x):
method _input_window_for_decoder (line 45) | def _input_window_for_decoder(self, use_all=False):
method decode_with_attention_offline (line 68) | def decode_with_attention_offline(self):
FILE: nets/pytorch_backend/tacotron2/cbhg.py
class CBHGLoss (line 18) | class CBHGLoss(torch.nn.Module):
method __init__ (line 21) | def __init__(self, use_masking=True):
method forward (line 31) | def forward(self, cbhg_outs, spcs, olens):
class CBHG (line 57) | class CBHG(torch.nn.Module):
method __init__ (line 69) | def __init__(
method forward (line 172) | def forward(self, xs, ilens):
method inference (line 216) | def inference(self, x):
method _sort_by_length (line 232) | def _sort_by_length(self, xs, ilens):
method _revert_sort_by_length (line 236) | def _revert_sort_by_length(self, xs, ilens, sort_idx):
class HighwayNet (line 241) | class HighwayNet(torch.nn.Module):
method __init__ (line 250) | def __init__(self, idim):
method forward (line 264) | def forward(self, x):
FILE: nets/pytorch_backend/tacotron2/decoder.py
function decoder_init (line 17) | def decoder_init(m):
class ZoneOutCell (line 23) | class ZoneOutCell(torch.nn.Module):
method __init__ (line 42) | def __init__(self, cell, zoneout_rate=0.1):
method forward (line 60) | def forward(self, inputs, hidden):
method _zoneout (line 79) | def _zoneout(self, h, next_h, prob):
class Prenet (line 96) | class Prenet(torch.nn.Module):
method __init__ (line 116) | def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
method forward (line 135) | def forward(self, x):
class Postnet (line 150) | class Postnet(torch.nn.Module):
method __init__ (line 165) | def __init__(
method forward (line 254) | def forward(self, xs):
class Decoder (line 269) | class Decoder(torch.nn.Module):
method __init__ (line 283) | def __init__(
method _zero_state (line 388) | def _zero_state(self, hs):
method forward (line 392) | def forward(self, hs, hlens, ys):
method inference (line 482) | def inference(
method calculate_all_attentions (line 618) | def calculate_all_attentions(self, hs, hlens, ys):
FILE: nets/pytorch_backend/tacotron2/encoder.py
function encoder_init (line 17) | def encoder_init(m):
class Encoder (line 23) | class Encoder(torch.nn.Module):
method __init__ (line 36) | def __init__(
method forward (line 130) | def forward(self, xs, ilens=None):
method inference (line 162) | def inference(self, x):
FILE: nets/pytorch_backend/transducer/arguments.py
function add_encoder_general_arguments (line 7) | def add_encoder_general_arguments(group):
function add_rnn_encoder_arguments (line 44) | def add_rnn_encoder_arguments(group):
function add_custom_encoder_arguments (line 74) | def add_custom_encoder_arguments(group):
function add_decoder_general_arguments (line 128) | def add_decoder_general_arguments(group):
function add_rnn_decoder_arguments (line 153) | def add_rnn_decoder_arguments(group):
function add_custom_decoder_arguments (line 171) | def add_custom_decoder_arguments(group):
function add_custom_training_arguments (line 204) | def add_custom_training_arguments(group):
function add_transducer_arguments (line 222) | def add_transducer_arguments(group):
function add_auxiliary_task_arguments (line 261) | def add_auxiliary_task_arguments(group):
function add_att_scorer_arguments (line 368) | def add_att_scorer_arguments(group):
function add_transducer_code_switch_arguments (line 453) | def add_transducer_code_switch_arguments(group):
FILE: nets/pytorch_backend/transducer/auxiliary_task.py
class AuxiliaryTask (line 14) | class AuxiliaryTask(torch.nn.Module):
method __init__ (line 17) | def __init__(
method forward (line 54) | def forward(
FILE: nets/pytorch_backend/transducer/blocks.py
function check_and_prepare (line 38) | def check_and_prepare(net_part, blocks_arch, input_layer):
function get_pos_enc_and_att_class (line 208) | def get_pos_enc_and_att_class(net_part, pos_enc_type, self_attn_type):
function build_input_layer (line 242) | def build_input_layer(
function build_transformer_block (line 311) | def build_transformer_block(net_part, block_arch, pw_layer_type, pw_acti...
function build_conformer_block (line 356) | def build_conformer_block(
function build_causal_conv1d_block (line 412) | def build_causal_conv1d_block(block_arch):
function build_tdnn_block (line 429) | def build_tdnn_block(block_arch):
function build_blocks (line 464) | def build_blocks(
FILE: nets/py
Condensed preview — 1103 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (8,529K chars).
[
{
"path": ".gitignore",
"chars": 16,
"preview": "*.pyc\ninterface\n"
},
{
"path": "README.md",
"chars": 6577,
"preview": "# End-to-end speech secognition toolkit\nThis is an E2E ASR toolkit modified from Espnet1 (version 0.9.9). \nIf this repo"
},
{
"path": "__init__.py",
"chars": 202,
"preview": "\"\"\"Initialize espnet package.\"\"\"\n\nimport os\ndirname = os.path.dirname(__file__)\nversion_file = os.path.join(dirname, \"ve"
},
{
"path": "asr/__init__.py",
"chars": 30,
"preview": "\"\"\"Initialize sub package.\"\"\"\n"
},
{
"path": "asr/asr_mix_utils.py",
"chars": 6393,
"preview": "#!/usr/bin/env python3\n\n\"\"\"\nThis script is used to provide utility functions designed for multi-speaker ASR.\n\nCopyright "
},
{
"path": "asr/asr_utils.py",
"chars": 35410,
"preview": "# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)\n\n"
},
{
"path": "asr/chainer_backend/__init__.py",
"chars": 30,
"preview": "\"\"\"Initialize sub package.\"\"\"\n"
},
{
"path": "asr/chainer_backend/asr.py",
"chars": 19752,
"preview": "# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)\n"
},
{
"path": "asr/pytorch_backend/__init__.py",
"chars": 30,
"preview": "\"\"\"Initialize sub package.\"\"\"\n"
},
{
"path": "asr/pytorch_backend/asr.py",
"chars": 61350,
"preview": "# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)\n"
},
{
"path": "asr/pytorch_backend/asr_init.py",
"chars": 9329,
"preview": "\"\"\"Finetuning methods.\"\"\"\n\nimport logging\nimport os\nimport torch\n\nfrom collections import OrderedDict\n\nfrom espnet.asr.a"
},
{
"path": "asr/pytorch_backend/asr_mix.py",
"chars": 22407,
"preview": "#!/usr/bin/env python3\n\n\"\"\"\nThis script is used for multi-speaker speech recognition.\n\nCopyright 2017 Johns Hopkins Univ"
},
{
"path": "asr/pytorch_backend/recog.py",
"chars": 8626,
"preview": "\"\"\"V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`.\"\"\"\n\nimport json\nimport logging\nimp"
},
{
"path": "bin/__init__.py",
"chars": 30,
"preview": "\"\"\"Initialize sub package.\"\"\"\n"
},
{
"path": "bin/asr_align.py",
"chars": 12960,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2020 Johns Hopkins University (Xuankai Chang)\n# 2020, Te"
},
{
"path": "bin/asr_enhance.py",
"chars": 5778,
"preview": "#!/usr/bin/env python3\nimport configargparse\nfrom distutils.util import strtobool\nimport logging\nimport os\nimport random"
},
{
"path": "bin/asr_recog.py",
"chars": 15495,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (ht"
},
{
"path": "bin/asr_train.py",
"chars": 23518,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2017 Tomoki Hayashi (Nagoya University)\n# Apache 2.0 (http://www"
},
{
"path": "bin/lm_train.py",
"chars": 8939,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.or"
},
{
"path": "bin/mt_train.py",
"chars": 15078,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2019 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://ww"
},
{
"path": "bin/mt_trans.py",
"chars": 5956,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2019 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://ww"
},
{
"path": "bin/st_train.py",
"chars": 17497,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2019 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://ww"
},
{
"path": "bin/st_trans.py",
"chars": 5855,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2019 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://ww"
},
{
"path": "bin/tts_decode.py",
"chars": 5768,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2018 Nagoya University (Tomoki Hayashi)\n# Apache 2.0 (http://www.apache.org/licens"
},
{
"path": "bin/tts_train.py",
"chars": 10680,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2018 Nagoya University (Tomoki Hayashi)\n# Apache 2.0 (http://www.apache.org/licens"
},
{
"path": "bin/vc_decode.py",
"chars": 5595,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2020 Nagoya University (Wen-Chin Huang)\n# Apache 2.0 (http://www.apache.org/licens"
},
{
"path": "bin/vc_train.py",
"chars": 10830,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2020 Nagoya University (Wen-Chin Huang)\n# Apache 2.0 (http://www.apache.org/licens"
},
{
"path": "egs/.gitignore",
"chars": 50,
"preview": "launch\nespnet-2021724\nsegment_aishell1\nword_ngram\n"
},
{
"path": "egs/aishell1/.gitignore",
"chars": 34,
"preview": "dump\ndump32\ndump64\ndata\nexp\nfbank\n"
},
{
"path": "egs/aishell1/aed.sh",
"chars": 6670,
"preview": "#!/usr/bin/env bash\n\n# author: tyriontian\n# tyriontian@tencent.com\n\n. ./path.sh || exit 1;\n. ./cmd.sh || exit 1;\n\n# gene"
},
{
"path": "egs/aishell1/cmd.sh",
"chars": 3645,
"preview": "# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======\n# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>\n#"
},
{
"path": "egs/aishell1/conf/fbank.conf",
"chars": 44,
"preview": "--sample-frequency=16000 \n--num-mel-bins=80\n"
},
{
"path": "egs/aishell1/conf/gpu.conf",
"chars": 380,
"preview": "# Default configuration\ncommand qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*\noption mem=* -l mem_free=$0,ram_free=$0"
},
{
"path": "egs/aishell1/conf/lm.yaml",
"chars": 268,
"preview": "# rnnlm related\nlayer: 2\nunit: 650\nopt: sgd # or adam\nbatchsize: 64 # batch size in LM training\nepoch: 20 "
},
{
"path": "egs/aishell1/conf/lm_rnn.yaml",
"chars": 7,
"preview": "lm.yaml"
},
{
"path": "egs/aishell1/conf/lm_transformer.yaml",
"chars": 623,
"preview": "# This Transformer LM setting w/ 4 GPUs took around 60 days for 50 epochs.\n# However, you can get better results in 6 da"
},
{
"path": "egs/aishell1/conf/pitch.conf",
"chars": 25,
"preview": "--sample-frequency=16000\n"
},
{
"path": "egs/aishell1/conf/queue.conf",
"chars": 353,
"preview": "# Default configuration\ncommand qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*\noption mem=* -l mem_free=$0,ram_free=$0"
},
{
"path": "egs/aishell1/conf/slurm.conf",
"chars": 531,
"preview": "# Default configuration\ncommand sbatch --export=PATH\noption name=* --job-name $0\noption time=* --time $0\noption mem=* --"
},
{
"path": "egs/aishell1/conf/specaug.yaml",
"chars": 322,
"preview": "process:\n # these three processes are a.k.a. SpecAugument\n - type: \"time_warp\"\n max_time_warp: 5\n inplace: true\n"
},
{
"path": "egs/aishell1/conf/specaug_test.yaml",
"chars": 320,
"preview": "process:\n # these three processes are a.k.a. SpecAugument\n - type: \"time_warp\"\n max_time_warp: 0\n inplace: true\n"
},
{
"path": "egs/aishell1/conf/tuning/decode_pytorch_transformer.yaml",
"chars": 123,
"preview": "batchsize: 0\nbeam-size: 10\npenalty: 0.0\nmaxlenratio: 0.0\nminlenratio: 0.0\nctc-weight: 0.5\nlm-weight: 0.0\nngram-weight: 0"
},
{
"path": "egs/aishell1/conf/tuning/decode_rnn.yaml",
"chars": 92,
"preview": "beam-size: 20\npenalty: 0.0\nmaxlenratio: 0.0\nminlenratio: 0.0\nctc-weight: 0.6\nlm-weight: 0.3\n"
},
{
"path": "egs/aishell1/conf/tuning/train_pytorch_conformer_kernel15.yaml",
"chars": 1249,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/aishell1/conf/tuning/train_pytorch_conformer_kernel31.yaml",
"chars": 1228,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/aishell1/conf/tuning/train_pytorch_conformer_kernel31_large.yaml",
"chars": 1228,
"preview": "# network architecture\n# encoder related\nelayers: 16\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/aishell1/conf/tuning/train_pytorch_conformer_kernel31_small.yaml",
"chars": 1227,
"preview": "# network architecture\n# encoder related\nelayers: 8\neunits: 1024\n# decoder related\ndlayers: 4\ndunits: 1024\n# attention r"
},
{
"path": "egs/aishell1/conf/tuning/train_pytorch_transformer.yaml",
"chars": 992,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/aishell1/conf/tuning/train_rnn.yaml",
"chars": 674,
"preview": "# network architecture\n# encoder related\netype: vggblstm # encoder architecture type\nelayers: 3\neunits: 1024\neprojs:"
},
{
"path": "egs/aishell1/conf/tuning/transducer/decode_default.yaml",
"chars": 83,
"preview": "# decoding parameters\nbatch: 0\nbeam-size: 10\nsearch-type: default\nscore-norm: True\n"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_conformer-rnn_transducer.yaml",
"chars": 1091,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_conformer-rnn_transducer_aux_ngpu4.yaml",
"chars": 1093,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_conformer-rnn_transducer_aux_ngpu4_att.yaml",
"chars": 1333,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_conformer-rnn_transducer_aux_ngpu4_small.yaml",
"chars": 1172,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_conformer-rnn_transducer_ngpu4.yaml",
"chars": 1171,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_conformer-rnn_transducer_ngpu4_large.yaml",
"chars": 1172,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_transducer.yaml",
"chars": 642,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/conf/tuning/transducer/train_transducer_aux.yaml",
"chars": 720,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell1/local/add_lex_disambig.pl",
"chars": 6799,
"preview": "#!/usr/bin/env perl\n# Copyright 2010-2011 Microsoft Corporation\n# 2013-2016 Johns Hopkins University (auth"
},
{
"path": "egs/aishell1/local/aishell_data_prep.sh",
"chars": 2144,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2017 Xingyu Na\n# Apache 2.0\n\n. ./path.sh || exit 1;\n\nif [ $# != 2 ]; then\n echo \"Usage"
},
{
"path": "egs/aishell1/local/aishell_train_lms.sh",
"chars": 3539,
"preview": "#!/usr/bin/env bash\n\n# To be run from one directory above this script.\n. ./path.sh\n\ntext=data/local/train/text\nlexicon=d"
},
{
"path": "egs/aishell1/local/apply_map.pl",
"chars": 2962,
"preview": "#!/usr/bin/env perl\nuse warnings; #sed replacement for -w perl parameter\n# Copyright 2012 Johns Hopkins University (Aut"
},
{
"path": "egs/aishell1/local/build_sp_text.py",
"chars": 270,
"preview": "import sys\n\nin_f = sys.argv[1]\n\nfor line in open(in_f, 'r', encoding=\"utf8\"):\n elems = line.split()\n uttid = elems"
},
{
"path": "egs/aishell1/local/build_word_mapping.py",
"chars": 533,
"preview": "# convert the attention output vocabulary into lexicon vocabulary\nimport sys\n\natt_vocab = sys.argv[1]\nlex_vocab = sys.ar"
},
{
"path": "egs/aishell1/local/compile_bigram.sh",
"chars": 410,
"preview": "# Compile char level bigram LM. for MMI training. \n# The bigram should be sparse or 4300+ words would lead to 17M arcs a"
},
{
"path": "egs/aishell1/local/download_and_untar.sh",
"chars": 2524,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2014 Johns Hopkins University (author: Daniel Povey)\n# 2017 Xingyu Na\n#"
},
{
"path": "egs/aishell1/local/fstaddselfloops.pl",
"chars": 1154,
"preview": "#!/usr/bin/env perl\n\n# Copyright 2020 Xiaomi Corporation (Author: Junbo Zhang)\n# Apache 2.0\n\nuse strict;\nuse warnings;\n\n"
},
{
"path": "egs/aishell1/local/k2_aishell_prepare_dict.sh",
"chars": 1233,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2017 Xingyu Na\n# Apache 2.0\n\n# prepare dict resources\n\n# . ./path.sh\n\n[ $# != 2 ] && ec"
},
{
"path": "egs/aishell1/local/k2_aishell_prepare_dict_char.sh",
"chars": 701,
"preview": "# Build character-level dict for K2 CTC / MMI \n# The token list would be very large (10k+) if we use aishell lexicon \n# "
},
{
"path": "egs/aishell1/local/k2_prepare_lang.sh",
"chars": 16860,
"preview": "#!/usr/bin/env bash\n# Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey);\n# Arnab"
},
{
"path": "egs/aishell1/local/make_lexicon_fst.py",
"chars": 19128,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2018 Johns Hopkins University (author: Daniel Povey)\n# Apache 2.0.\n\n# see get_arg"
},
{
"path": "egs/aishell1/local/max_rescore.py",
"chars": 1075,
"preview": "import sys\nimport json\nimport codecs\nimport copy\n\njson_f = sys.argv[1]\njson_f_out = sys.argv[2]\nbest_dict_f = sys.argv[3"
},
{
"path": "egs/aishell1/local/parse_options.sh",
"chars": 3657,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);\n# Arnab Ghoshal,"
},
{
"path": "egs/aishell1/local/parse_text_jieba.py",
"chars": 557,
"preview": "import jieba\nimport sys\n\nin_f = sys.argv[1]\nout_f = sys.argv[2]\nword_dict = sys.argv[3]\n\njieba.load_userdict(word_dict)\n"
},
{
"path": "egs/aishell1/local/prepare_word_lex.py",
"chars": 742,
"preview": "import sys\n\n\"\"\"\nMake a word-level lexicon for MMI training. \nPrevious lexicon accepts phones, here this lexicon accepts "
},
{
"path": "egs/aishell1/local/sym2int.pl",
"chars": 3237,
"preview": "#!/usr/bin/env perl\n# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)\n\n# Lice"
},
{
"path": "egs/aishell1/nt.sh",
"chars": 7267,
"preview": "#!/usr/bin/env bash\n\n# author: tyriontian\n# tyriontian@tencent.com\n\n. ./path.sh || exit 1;\n. ./cmd.sh || exit 1;\n\n# gene"
},
{
"path": "egs/aishell1/path.sh",
"chars": 1156,
"preview": "# this is necessary since docker images would not run .bashrc if the command line is not \"bash\"\nsource ~/.bashrc # to in"
},
{
"path": "egs/aishell1/prepare.sh",
"chars": 10407,
"preview": "#!/usr/bin/env bash\n\n# author: tyriontian\n# tyriontian@tencent.com\n\n. ./path.sh || exit 1;\n. ./cmd.sh || exit 1;\n\n# gene"
},
{
"path": "egs/aishell2/.gitignore",
"chars": 64,
"preview": "dump\ndump32\ndump64\ndata\nexp\nfbank\nexp_without_segmentation\n_exp\n"
},
{
"path": "egs/aishell2/aed.sh",
"chars": 6881,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.org/l"
},
{
"path": "egs/aishell2/conf/fbank.conf",
"chars": 44,
"preview": "--sample-frequency=16000 \n--num-mel-bins=80\n"
},
{
"path": "egs/aishell2/conf/gpu.conf",
"chars": 380,
"preview": "# Default configuration\ncommand qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*\noption mem=* -l mem_free=$0,ram_free=$0"
},
{
"path": "egs/aishell2/conf/lm.yaml",
"chars": 268,
"preview": "# rnnlm related\nlayer: 2\nunit: 650\nopt: sgd # or adam\nbatchsize: 64 # batch size in LM training\nepoch: 20 "
},
{
"path": "egs/aishell2/conf/lm_rnn.yaml",
"chars": 7,
"preview": "lm.yaml"
},
{
"path": "egs/aishell2/conf/lm_transformer.yaml",
"chars": 623,
"preview": "# This Transformer LM setting w/ 4 GPUs took around 60 days for 50 epochs.\n# However, you can get better results in 6 da"
},
{
"path": "egs/aishell2/conf/pitch.conf",
"chars": 25,
"preview": "--sample-frequency=16000\n"
},
{
"path": "egs/aishell2/conf/queue.conf",
"chars": 353,
"preview": "# Default configuration\ncommand qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*\noption mem=* -l mem_free=$0,ram_free=$0"
},
{
"path": "egs/aishell2/conf/slurm.conf",
"chars": 531,
"preview": "# Default configuration\ncommand sbatch --export=PATH\noption name=* --job-name $0\noption time=* --time $0\noption mem=* --"
},
{
"path": "egs/aishell2/conf/specaug.yaml",
"chars": 322,
"preview": "process:\n # these three processes are a.k.a. SpecAugument\n - type: \"time_warp\"\n max_time_warp: 5\n inplace: true\n"
},
{
"path": "egs/aishell2/conf/specaug_test.yaml",
"chars": 320,
"preview": "process:\n # these three processes are a.k.a. SpecAugument\n - type: \"time_warp\"\n max_time_warp: 0\n inplace: true\n"
},
{
"path": "egs/aishell2/conf/tuning/decode_pytorch_transformer.yaml",
"chars": 123,
"preview": "batchsize: 0\nbeam-size: 10\npenalty: 0.0\nmaxlenratio: 0.0\nminlenratio: 0.0\nctc-weight: 0.5\nlm-weight: 0.0\nngram-weight: 0"
},
{
"path": "egs/aishell2/conf/tuning/decode_rnn.yaml",
"chars": 92,
"preview": "beam-size: 20\npenalty: 0.0\nmaxlenratio: 0.0\nminlenratio: 0.0\nctc-weight: 0.6\nlm-weight: 0.3\n"
},
{
"path": "egs/aishell2/conf/tuning/train_pytorch_conformer_kernel15.yaml",
"chars": 1249,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/aishell2/conf/tuning/train_pytorch_conformer_kernel31.yaml",
"chars": 1228,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/aishell2/conf/tuning/train_pytorch_transformer.yaml",
"chars": 992,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/aishell2/conf/tuning/train_rnn.yaml",
"chars": 674,
"preview": "# network architecture\n# encoder related\netype: vggblstm # encoder architecture type\nelayers: 3\neunits: 1024\neprojs:"
},
{
"path": "egs/aishell2/conf/tuning/transducer/decode_default.yaml",
"chars": 83,
"preview": "# decoding parameters\nbatch: 0\nbeam-size: 10\nsearch-type: default\nscore-norm: True\n"
},
{
"path": "egs/aishell2/conf/tuning/transducer/train_conformer-rnn_transducer.yaml",
"chars": 1091,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell2/conf/tuning/transducer/train_conformer-rnn_transducer_aux_ngpu4.yaml",
"chars": 1173,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell2/conf/tuning/transducer/train_conformer-rnn_transducer_ngpu4.yaml",
"chars": 1171,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell2/conf/tuning/transducer/train_transducer.yaml",
"chars": 642,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell2/conf/tuning/transducer/train_transducer_aux.yaml",
"chars": 720,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/aishell2/local/add_lex_disambig.pl",
"chars": 6799,
"preview": "#!/usr/bin/env perl\n# Copyright 2010-2011 Microsoft Corporation\n# 2013-2016 Johns Hopkins University (auth"
},
{
"path": "egs/aishell2/local/apply_map.pl",
"chars": 2962,
"preview": "#!/usr/bin/env perl\nuse warnings; #sed replacement for -w perl parameter\n# Copyright 2012 Johns Hopkins University (Aut"
},
{
"path": "egs/aishell2/local/fstaddselfloops.pl",
"chars": 1154,
"preview": "#!/usr/bin/env perl\n\n# Copyright 2020 Xiaomi Corporation (Author: Junbo Zhang)\n# Apache 2.0\n\nuse strict;\nuse warnings;\n\n"
},
{
"path": "egs/aishell2/local/jieba_split_text.py",
"chars": 634,
"preview": "import jieba\nimport sys\n\nsrc_file = sys.argv[1]\ndst_file = sys.argv[2]\ndict_file = sys.argv[3]\njieba.set_dictionary(dict"
},
{
"path": "egs/aishell2/local/k2_prepare_lang.sh",
"chars": 16860,
"preview": "#!/usr/bin/env bash\n# Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey);\n# Arnab"
},
{
"path": "egs/aishell2/local/make_lexicon_fst.py",
"chars": 19128,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2018 Johns Hopkins University (author: Daniel Povey)\n# Apache 2.0.\n\n# see get_arg"
},
{
"path": "egs/aishell2/local/max_rescore.py",
"chars": 1075,
"preview": "import sys\nimport json\nimport codecs\nimport copy\n\njson_f = sys.argv[1]\njson_f_out = sys.argv[2]\nbest_dict_f = sys.argv[3"
},
{
"path": "egs/aishell2/local/mmi_rescore.sh",
"chars": 479,
"preview": "decode_dir=$1\ndict=$2\n\nmkdir -p $decode_dir/rescore\ndir=$decode_dir/rescore\n\nmkdir -p $dir/best\npython3 local/max_rescor"
},
{
"path": "egs/aishell2/local/parse_options.sh",
"chars": 3657,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);\n# Arnab Ghoshal,"
},
{
"path": "egs/aishell2/local/prepare_data.sh",
"chars": 1861,
"preview": "#!/usr/bin/env bash\n# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)\n# 20"
},
{
"path": "egs/aishell2/local/prepare_dict.sh",
"chars": 1739,
"preview": "#!/usr/bin/env bash\n# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)\n# 20"
},
{
"path": "egs/aishell2/local/rerank.py",
"chars": 639,
"preview": "import sys\nimport json\nimport codecs\n\n\njson_f = sys.argv[1]\njson_f_out = sys.argv[3]\nweight = float(sys.argv[2])\n\nwith c"
},
{
"path": "egs/aishell2/local/sym2int.pl",
"chars": 3237,
"preview": "#!/usr/bin/env perl\n# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)\n\n# Lice"
},
{
"path": "egs/aishell2/local/train_lms.sh",
"chars": 3458,
"preview": "#!/usr/bin/env bash\n# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)\n# 20"
},
{
"path": "egs/aishell2/local/word_segmentation.py",
"chars": 703,
"preview": "# encoding=utf-8\n# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)\n# 2018 "
},
{
"path": "egs/aishell2/nt.sh",
"chars": 6974,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.org/l"
},
{
"path": "egs/aishell2/prepare.sh",
"chars": 12361,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.org/l"
},
{
"path": "egs/asrucs/.gitignore",
"chars": 35,
"preview": "dump\ndump32\ndump64\ndata\nexp*\nfbank\n"
},
{
"path": "egs/asrucs/cmd.sh",
"chars": 18,
"preview": "../aishell1/cmd.sh"
},
{
"path": "egs/asrucs/conf/decode.yaml",
"chars": 38,
"preview": "tuning/decode_pytorch_transformer.yaml"
},
{
"path": "egs/asrucs/conf/fbank.conf",
"chars": 44,
"preview": "--sample-frequency=16000 \n--num-mel-bins=80\n"
},
{
"path": "egs/asrucs/conf/gpu.conf",
"chars": 380,
"preview": "# Default configuration\ncommand qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*\noption mem=* -l mem_free=$0,ram_free=$0"
},
{
"path": "egs/asrucs/conf/lm.yaml",
"chars": 268,
"preview": "# rnnlm related\nlayer: 2\nunit: 650\nopt: sgd # or adam\nbatchsize: 64 # batch size in LM training\nepoch: 20 "
},
{
"path": "egs/asrucs/conf/lm_rnn.yaml",
"chars": 7,
"preview": "lm.yaml"
},
{
"path": "egs/asrucs/conf/lm_transformer.yaml",
"chars": 623,
"preview": "# This Transformer LM setting w/ 4 GPUs took around 60 days for 50 epochs.\n# However, you can get better results in 6 da"
},
{
"path": "egs/asrucs/conf/pitch.conf",
"chars": 25,
"preview": "--sample-frequency=16000\n"
},
{
"path": "egs/asrucs/conf/pure_ctc.yaml",
"chars": 1228,
"preview": "# network architecture\n# encoder related\nelayers: 15\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/asrucs/conf/queue.conf",
"chars": 353,
"preview": "# Default configuration\ncommand qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*\noption mem=* -l mem_free=$0,ram_free=$0"
},
{
"path": "egs/asrucs/conf/slurm.conf",
"chars": 531,
"preview": "# Default configuration\ncommand sbatch --export=PATH\noption name=* --job-name $0\noption time=* --time $0\noption mem=* --"
},
{
"path": "egs/asrucs/conf/specaug.yaml",
"chars": 322,
"preview": "process:\n # these three processes are a.k.a. SpecAugument\n - type: \"time_warp\"\n max_time_warp: 5\n inplace: true\n"
},
{
"path": "egs/asrucs/conf/specaug_test.yaml",
"chars": 320,
"preview": "process:\n # these three processes are a.k.a. SpecAugument\n - type: \"time_warp\"\n max_time_warp: 0\n inplace: true\n"
},
{
"path": "egs/asrucs/conf/train.yaml",
"chars": 44,
"preview": "tuning/train_pytorch_conformer_kernel15.yaml"
},
{
"path": "egs/asrucs/conf/train_conformer-rnn_transducer_cs.yaml",
"chars": 1041,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/decode_pytorch_transformer.yaml",
"chars": 123,
"preview": "batchsize: 0\nbeam-size: 10\npenalty: 0.0\nmaxlenratio: 0.0\nminlenratio: 0.0\nctc-weight: 0.5\nlm-weight: 0.0\nngram-weight: 0"
},
{
"path": "egs/asrucs/conf/tuning/decode_rnn.yaml",
"chars": 92,
"preview": "beam-size: 20\npenalty: 0.0\nmaxlenratio: 0.0\nminlenratio: 0.0\nctc-weight: 0.6\nlm-weight: 0.3\n"
},
{
"path": "egs/asrucs/conf/tuning/train_pytorch_conformer_kernel15.yaml",
"chars": 1249,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/asrucs/conf/tuning/train_pytorch_conformer_kernel31.yaml",
"chars": 1228,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/asrucs/conf/tuning/train_pytorch_conformer_kernel31_large.yaml",
"chars": 1228,
"preview": "# network architecture\n# encoder related\nelayers: 16\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/asrucs/conf/tuning/train_pytorch_conformer_kernel31_small.yaml",
"chars": 1227,
"preview": "# network architecture\n# encoder related\nelayers: 8\neunits: 1024\n# decoder related\ndlayers: 4\ndunits: 1024\n# attention r"
},
{
"path": "egs/asrucs/conf/tuning/train_pytorch_transformer.yaml",
"chars": 992,
"preview": "# network architecture\n# encoder related\nelayers: 12\neunits: 2048\n# decoder related\ndlayers: 6\ndunits: 2048\n# attention "
},
{
"path": "egs/asrucs/conf/tuning/train_rnn.yaml",
"chars": 674,
"preview": "# network architecture\n# encoder related\netype: vggblstm # encoder architecture type\nelayers: 3\neunits: 1024\neprojs:"
},
{
"path": "egs/asrucs/conf/tuning/transducer/decode_default.yaml",
"chars": 83,
"preview": "# decoding parameters\nbatch: 0\nbeam-size: 10\nsearch-type: default\nscore-norm: True\n"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_conformer-rnn_transducer.yaml",
"chars": 1091,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_conformer-rnn_transducer_aux_ngpu4.yaml",
"chars": 1093,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_conformer-rnn_transducer_aux_ngpu4_att.yaml",
"chars": 1333,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_conformer-rnn_transducer_aux_ngpu4_small.yaml",
"chars": 1172,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_conformer-rnn_transducer_ngpu4.yaml",
"chars": 1171,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_conformer-rnn_transducer_ngpu4_large.yaml",
"chars": 1172,
"preview": "# minibatch related\nbatch-size: 32\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_transducer.yaml",
"chars": 642,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/conf/tuning/transducer/train_transducer_aux.yaml",
"chars": 720,
"preview": "# minibatch related\nbatch-size: 64\nmaxlen-in: 512\nmaxlen-out: 150\n\n# optimization related\ncriterion: loss\nearly-stop-cri"
},
{
"path": "egs/asrucs/espnet",
"chars": 27,
"preview": "../../../E2E-ASR-Framework/"
},
{
"path": "egs/asrucs/espnet_utils",
"chars": 16,
"preview": "../espnet_utils/"
},
{
"path": "egs/asrucs/local/add_seperator.py",
"chars": 573,
"preview": "import sys\n\ndef is_all_chinese(strs):\n for _char in strs:\n if not '\\u4e00' <= _char <= '\\u9fa5':\n r"
},
{
"path": "egs/asrucs/local/generate_fake_cs.py",
"chars": 2751,
"preview": "import sys\nimport torch\nimport random\nimport torchaudio\n\nMAX_LENGTH = 12 * 100 # 12 seconds\n\n\ndef read_datadir(d):\n w"
},
{
"path": "egs/asrucs/local/prepare_fake_cs.sh",
"chars": 1534,
"preview": "cs_dir=data/train_cs_fake/\n# generate wav files\n# python3 local/generate_fake_cs.py data/train_zh_trim/ data/train_en_tr"
},
{
"path": "egs/asrucs/nt.sh",
"chars": 7331,
"preview": "#!/usr/bin/env bash\n\n# author: tyriontian\n# tyriontian@tencent.com\n\n. ./path.sh || exit 1;\n. ./cmd.sh || exit 1;\n\n# gene"
},
{
"path": "egs/asrucs/path.sh",
"chars": 19,
"preview": "../aishell1/path.sh"
},
{
"path": "egs/asrucs/prepare.sh",
"chars": 10766,
"preview": "#!/usr/bin/env bash\n\n# author: tyriontian\n# tianjinchuan@stu.pku.edu.cn ; tyriontian@tencent.com\n\n# A Code-Switch ASR re"
},
{
"path": "egs/asrucs/steps",
"chars": 9,
"preview": "../steps/"
},
{
"path": "egs/asrucs/text",
"chars": 0,
"preview": ""
},
{
"path": "egs/asrucs/utils",
"chars": 9,
"preview": "../utils/"
},
{
"path": "egs/espnet_utils/add_uttcls_json.py",
"chars": 559,
"preview": "import json\nimport sys\n\ndef main():\n in_json = sys.argv[1]\n out_json = sys.argv[2]\n clsid = sys.argv[3]\n\n re"
},
{
"path": "egs/espnet_utils/addjson.py",
"chars": 4828,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2018 Nagoya University (Tomoki Hayashi)\n# Apache 2.0 (http://www"
},
{
"path": "egs/espnet_utils/apply-cmvn.py",
"chars": 4771,
"preview": "#!/usr/bin/env python3\nimport argparse\nfrom distutils.util import strtobool\nimport logging\n\nimport kaldiio\nimport numpy\n"
},
{
"path": "egs/espnet_utils/asr_align_wav.sh",
"chars": 9110,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2020 Johns Hopkins University (Xuankai Chang)\n# 2020 Technische Universität München, Au"
},
{
"path": "egs/espnet_utils/average_checkpoints.py",
"chars": 5184,
"preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport argparse\nimport json\nimport os\n\nimport numpy as np\n\n\ndef main():\n"
},
{
"path": "egs/espnet_utils/build_fake_lexicon.py",
"chars": 330,
"preview": "import sys\nimport os\n\norig_lexicon = sys.argv[1]\n\nfor line in open(orig_lexicon, encoding=\"utf-8\"):\n word = line.stri"
},
{
"path": "egs/espnet_utils/build_sp_text.py",
"chars": 270,
"preview": "import sys\n\nin_f = sys.argv[1]\n\nfor line in open(in_f, 'r', encoding=\"utf8\"):\n elems = line.split()\n uttid = elems"
},
{
"path": "egs/espnet_utils/calculate_rtf.py",
"chars": 1977,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2021 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://ww"
},
{
"path": "egs/espnet_utils/change_root.py",
"chars": 1532,
"preview": "# author: tyriontian\n# tyriontian@tencent.com\n\n# this script is to change the root dir of some data files, like dump dir"
},
{
"path": "egs/espnet_utils/change_yaml.py",
"chars": 3548,
"preview": "#!/usr/bin/env python3\nimport argparse\nfrom pathlib import Path\n\nimport yaml\n\n\ndef get_parser():\n parser = argparse.A"
},
{
"path": "egs/espnet_utils/clean_corpus.sh",
"chars": 1787,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2021 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://www.apache.org/licenses"
},
{
"path": "egs/espnet_utils/compute-cmvn-stats.py",
"chars": 6254,
"preview": "#!/usr/bin/env python3\nimport argparse\nimport logging\n\nimport kaldiio\nimport numpy as np\n\nfrom espnet.transform.transfor"
},
{
"path": "egs/espnet_utils/compute-fbank-feats.py",
"chars": 4403,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2018 Nagoya University (Tomoki Hayashi)\n# Apache 2.0 (http://www.apache.org/licens"
},
{
"path": "egs/espnet_utils/compute-stft-feats.py",
"chars": 3885,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2018 Nagoya University (Tomoki Hayashi)\n# Apache 2.0 (http://www.apache.org/licens"
},
{
"path": "egs/espnet_utils/concat_json_multiref.py",
"chars": 1811,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2018 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://ww"
},
{
"path": "egs/espnet_utils/concatjson.py",
"chars": 2246,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (ht"
},
{
"path": "egs/espnet_utils/convert_fbank.sh",
"chars": 2355,
"preview": "#!/usr/bin/env bash\n# Set bash to 'debug' mode, it will exit on :\n# -e 'error', -u 'undefined variable', -o ... 'error i"
},
{
"path": "egs/espnet_utils/convert_fbank_to_wav.py",
"chars": 6013,
"preview": "#!/usr/bin/env python3\n\n# Copyright 2018 Nagoya University (Tomoki Hayashi)\n# Apache 2.0 (http://www.apache.org/licens"
},
{
"path": "egs/espnet_utils/copy-feats.py",
"chars": 3410,
"preview": "#!/usr/bin/env python3\nimport argparse\nfrom distutils.util import strtobool\nimport logging\n\nfrom espnet.transform.transf"
},
{
"path": "egs/espnet_utils/data2json.sh",
"chars": 5028,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2017 Johns Hopkins University (Shinji Watanabe)\n# Apache 2.0 (http://www.apache.org/l"
},
{
"path": "egs/espnet_utils/divide_lang.sh",
"chars": 1130,
"preview": "#!/bin/bash\n\n# Copyright 2021 Kyoto University (Hirofumi Inaguma)\n# Apache 2.0 (http://www.apache.org/licenses/LICENSE"
},
{
"path": "egs/espnet_utils/double_precious_cer.py",
"chars": 300,
"preview": "import sys\n\nin_f = sys.argv[1]\n\nfor line in open(in_f, encoding=\"utf-8\"):\n if \"Sum\" in line and \"|\" in line and \"Avg\""
},
{
"path": "egs/espnet_utils/download_from_google_drive.sh",
"chars": 1603,
"preview": "#!/usr/bin/env bash\n\n# Download zip, tar, or tar.gz file from google drive\n\n# Copyright 2019 Tomoki Hayashi\n# Apache 2."
},
{
"path": "egs/espnet_utils/dump-pcm.py",
"chars": 4869,
"preview": "#!/usr/bin/env python3\nimport argparse\nfrom distutils.util import strtobool\nimport logging\n\nimport kaldiio\nimport numpy\n"
},
{
"path": "egs/espnet_utils/dump.sh",
"chars": 2650,
"preview": "#!/usr/bin/env bash\n\n# Copyright 2017 Nagoya University (Tomoki Hayashi)\n# Apache 2.0 (http://www.apache.org/licenses/"
},
{
"path": "egs/espnet_utils/dump_pcm.sh",
"chars": 3948,
"preview": "#!/usr/bin/env bash\n\n# Begin configuration section.\nnj=4\ncmd=run.pl\ncompress=false\nwrite_utt2num_frames=false # if true "
},
{
"path": "egs/espnet_utils/eval-source-separation.py",
"chars": 14517,
"preview": "#!/usr/bin/env python3\nimport argparse\nfrom collections import OrderedDict\nfrom distutils.util import strtobool\nimport i"
},
{
"path": "egs/espnet_utils/eval_perm_free_error.py",
"chars": 6548,
"preview": "#!/usr/bin/env python3\n# encoding: utf-8\n\n# Copyright 2019 Johns Hopkins University (Xuankai Chang)\n# Apache 2.0 (http"
},
{
"path": "egs/espnet_utils/eval_source_separation.sh",
"chars": 1838,
"preview": "#!/usr/bin/env bash\n\necho \"$0 $*\" >&2 # Print the command line for logging\n\nnj=10\ncmd=run.pl\nevaltypes=\"SDR STOI ESTOI P"
},
{
"path": "egs/espnet_utils/feat-to-shape.py",
"chars": 2800,
"preview": "#!/usr/bin/env python3\nimport argparse\nimport logging\nimport sys\n\nfrom espnet.transform.transformation import Transforma"
},
{
"path": "egs/espnet_utils/feat_to_shape.sh",
"chars": 1786,
"preview": "#!/usr/bin/env bash\n\n# Begin configuration section.\nnj=188\ncmd=run.pl\nverbose=0\nfiletype=\"\"\npreprocess_conf=\"\"\n# End con"
},
{
"path": "egs/espnet_utils/feats2npy.py",
"chars": 837,
"preview": "#!/usr/bin/env python\n# coding: utf-8\n\nimport argparse\nfrom kaldiio import ReadHelper\nimport numpy as np\nimport os\nfrom"
},
{
"path": "egs/espnet_utils/filt.py",
"chars": 1745,
"preview": "#!/usr/bin/env python3\n\n# Apache 2.0\n\nimport argparse\nimport codecs\nimport sys\n\nis_python2 = sys.version_info[0] == 2\n\n\n"
},
{
"path": "egs/espnet_utils/filter_all_eng_utts.py",
"chars": 530,
"preview": "import sys\n\ndef is_all_chinese(strs):\n for _char in strs:\n if not '\\u4e00' <= _char <= '\\u9fa5':\n r"
},
{
"path": "egs/espnet_utils/filter_scp.py",
"chars": 428,
"preview": "import sys\n\nref_f = sys.argv[1]\nin_f = sys.argv[2]\n\n# output is in the order of ref_f\nref = []\nfor line in open(ref_f, e"
},
{
"path": "egs/espnet_utils/filter_trn.py",
"chars": 198,
"preview": "# this is to process the hyp.trn of sppd3\nimport sys\n\nin_f = sys.argv[1]\nignore = \"叮 当 叮 当 \"\n\nfor line in open(in_f, enc"
}
]
// ... and 903 more files (download for full content)
About this extraction
This page contains the full source code of the jctian98/e2e_lfmmi GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 1103 files (7.8 MB), approximately 2.1M tokens, and a symbol index with 3493 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.