Repository: stanfordnlp/stanza
Branch: main
Commit: 516b07140fdf
Files: 579
Total size: 3.8 MB
Directory structure:
gitextract_z9hqe0ws/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ ├── feature_request.md
│ │ └── question.md
│ ├── pull_request_template.md
│ ├── stale.yml
│ └── workflows/
│ └── stanza-tests.yaml
├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── demo/
│ ├── CONLL_Dependency_Visualizer_Example.ipynb
│ ├── Dependency_Visualization_Testing.ipynb
│ ├── NER_Visualization.ipynb
│ ├── Stanza_Beginners_Guide.ipynb
│ ├── Stanza_CoreNLP_Interface.ipynb
│ ├── arabic_test.conllu.txt
│ ├── corenlp.py
│ ├── en_test.conllu.txt
│ ├── japanese_test.conllu.txt
│ ├── pipeline_demo.py
│ ├── scenegraph.py
│ ├── semgrex visualization.ipynb
│ ├── semgrex.py
│ └── ssurgeon_script.txt
├── doc/
│ └── CoreNLP.proto
├── scripts/
│ ├── config.sh
│ └── download_vectors.sh
├── setup.py
└── stanza/
├── __init__.py
├── _version.py
├── models/
│ ├── __init__.py
│ ├── _training_logging.py
│ ├── charlm.py
│ ├── classifier.py
│ ├── classifiers/
│ │ ├── __init__.py
│ │ ├── base_classifier.py
│ │ ├── cnn_classifier.py
│ │ ├── config.py
│ │ ├── constituency_classifier.py
│ │ ├── data.py
│ │ ├── iterate_test.py
│ │ ├── trainer.py
│ │ └── utils.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── beam.py
│ │ ├── bert_embedding.py
│ │ ├── biaffine.py
│ │ ├── build_short_name_to_treebank.py
│ │ ├── char_model.py
│ │ ├── chuliu_edmonds.py
│ │ ├── constant.py
│ │ ├── convert_pretrain.py
│ │ ├── count_ner_coverage.py
│ │ ├── count_pretrain_coverage.py
│ │ ├── crf.py
│ │ ├── data.py
│ │ ├── doc.py
│ │ ├── dropout.py
│ │ ├── exceptions.py
│ │ ├── foundation_cache.py
│ │ ├── hlstm.py
│ │ ├── large_margin_loss.py
│ │ ├── loss.py
│ │ ├── maxout_linear.py
│ │ ├── packed_lstm.py
│ │ ├── peft_config.py
│ │ ├── pretrain.py
│ │ ├── relative_attn.py
│ │ ├── seq2seq_constant.py
│ │ ├── seq2seq_model.py
│ │ ├── seq2seq_modules.py
│ │ ├── seq2seq_utils.py
│ │ ├── short_name_to_treebank.py
│ │ ├── stanza_object.py
│ │ ├── trainer.py
│ │ ├── utils.py
│ │ └── vocab.py
│ ├── constituency/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── base_trainer.py
│ │ ├── dynamic_oracle.py
│ │ ├── ensemble.py
│ │ ├── error_analysis_in_order.py
│ │ ├── evaluate_treebanks.py
│ │ ├── in_order_compound_oracle.py
│ │ ├── in_order_oracle.py
│ │ ├── label_attention.py
│ │ ├── lstm_model.py
│ │ ├── lstm_tree_stack.py
│ │ ├── parse_transitions.py
│ │ ├── parse_tree.py
│ │ ├── parser_training.py
│ │ ├── partitioned_transformer.py
│ │ ├── positional_encoding.py
│ │ ├── retagging.py
│ │ ├── score_converted_dependencies.py
│ │ ├── state.py
│ │ ├── text_processing.py
│ │ ├── top_down_oracle.py
│ │ ├── trainer.py
│ │ ├── transformer_tree_stack.py
│ │ ├── transition_sequence.py
│ │ ├── tree_embedding.py
│ │ ├── tree_reader.py
│ │ ├── tree_stack.py
│ │ └── utils.py
│ ├── constituency_parser.py
│ ├── coref/
│ │ ├── __init__.py
│ │ ├── anaphoricity_scorer.py
│ │ ├── bert.py
│ │ ├── cluster_checker.py
│ │ ├── config.py
│ │ ├── conll.py
│ │ ├── const.py
│ │ ├── coref_chain.py
│ │ ├── coref_config.toml
│ │ ├── dataset.py
│ │ ├── loss.py
│ │ ├── model.py
│ │ ├── pairwise_encoder.py
│ │ ├── predict.py
│ │ ├── rough_scorer.py
│ │ ├── span_predictor.py
│ │ ├── tokenizer_customization.py
│ │ ├── utils.py
│ │ └── word_encoder.py
│ ├── depparse/
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── model.py
│ │ ├── scorer.py
│ │ └── trainer.py
│ ├── identity_lemmatizer.py
│ ├── lang_identifier.py
│ ├── langid/
│ │ ├── __init__.py
│ │ ├── create_ud_data.py
│ │ ├── data.py
│ │ ├── model.py
│ │ └── trainer.py
│ ├── lemma/
│ │ ├── __init__.py
│ │ ├── attach_lemma_classifier.py
│ │ ├── data.py
│ │ ├── edit.py
│ │ ├── scorer.py
│ │ ├── trainer.py
│ │ └── vocab.py
│ ├── lemma_classifier/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── base_trainer.py
│ │ ├── baseline_model.py
│ │ ├── constants.py
│ │ ├── evaluate_many.py
│ │ ├── evaluate_models.py
│ │ ├── lstm_model.py
│ │ ├── prepare_dataset.py
│ │ ├── train_lstm_model.py
│ │ ├── train_many.py
│ │ ├── train_transformer_model.py
│ │ ├── transformer_model.py
│ │ └── utils.py
│ ├── lemmatizer.py
│ ├── mwt/
│ │ ├── __init__.py
│ │ ├── character_classifier.py
│ │ ├── data.py
│ │ ├── scorer.py
│ │ ├── trainer.py
│ │ ├── utils.py
│ │ └── vocab.py
│ ├── mwt_expander.py
│ ├── ner/
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── model.py
│ │ ├── scorer.py
│ │ ├── trainer.py
│ │ ├── utils.py
│ │ └── vocab.py
│ ├── ner_tagger.py
│ ├── parser.py
│ ├── pos/
│ │ ├── __init__.py
│ │ ├── build_xpos_vocab_factory.py
│ │ ├── data.py
│ │ ├── model.py
│ │ ├── scorer.py
│ │ ├── trainer.py
│ │ ├── vocab.py
│ │ ├── xpos_vocab_factory.py
│ │ └── xpos_vocab_utils.py
│ ├── tagger.py
│ ├── tokenization/
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── model.py
│ │ ├── tokenize_files.py
│ │ ├── trainer.py
│ │ ├── utils.py
│ │ └── vocab.py
│ ├── tokenizer.py
│ └── wl_coref.py
├── pipeline/
│ ├── __init__.py
│ ├── _constants.py
│ ├── constituency_processor.py
│ ├── core.py
│ ├── coref_processor.py
│ ├── demo/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── demo_server.py
│ │ ├── stanza-brat.css
│ │ ├── stanza-brat.html
│ │ ├── stanza-brat.js
│ │ └── stanza-parseviewer.js
│ ├── depparse_processor.py
│ ├── external/
│ │ ├── __init__.py
│ │ ├── corenlp_converter_depparse.py
│ │ ├── jieba.py
│ │ ├── pythainlp.py
│ │ ├── spacy.py
│ │ └── sudachipy.py
│ ├── langid_processor.py
│ ├── lemma_processor.py
│ ├── morphseg_processor.py
│ ├── multilingual.py
│ ├── mwt_processor.py
│ ├── ner_processor.py
│ ├── pos_processor.py
│ ├── processor.py
│ ├── registry.py
│ ├── sentiment_processor.py
│ └── tokenize_processor.py
├── protobuf/
│ ├── CoreNLP_pb2.py
│ └── __init__.py
├── resources/
│ ├── __init__.py
│ ├── common.py
│ ├── default_packages.py
│ ├── installation.py
│ ├── prepare_resources.py
│ └── print_charlm_depparse.py
├── server/
│ ├── __init__.py
│ ├── annotator.py
│ ├── client.py
│ ├── dependency_converter.py
│ ├── java_protobuf_requests.py
│ ├── main.py
│ ├── morphology.py
│ ├── parser_eval.py
│ ├── semgrex.py
│ ├── ssurgeon.py
│ ├── tokensregex.py
│ ├── tsurgeon.py
│ └── ud_enhancer.py
├── tests/
│ ├── __init__.py
│ ├── classifiers/
│ │ ├── __init__.py
│ │ ├── test_classifier.py
│ │ ├── test_constituency_classifier.py
│ │ ├── test_data.py
│ │ └── test_process_utils.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── test_bert_embedding.py
│ │ ├── test_char_model.py
│ │ ├── test_chuliu_edmonds.py
│ │ ├── test_common_data.py
│ │ ├── test_confusion.py
│ │ ├── test_constant.py
│ │ ├── test_data_conversion.py
│ │ ├── test_data_objects.py
│ │ ├── test_doc.py
│ │ ├── test_dropout.py
│ │ ├── test_foundation_cache.py
│ │ ├── test_pretrain.py
│ │ ├── test_relative_attn.py
│ │ ├── test_short_name_to_treebank.py
│ │ └── test_utils.py
│ ├── constituency/
│ │ ├── __init__.py
│ │ ├── test_convert_arboretum.py
│ │ ├── test_convert_it_vit.py
│ │ ├── test_convert_starlang.py
│ │ ├── test_ensemble.py
│ │ ├── test_in_order_compound_oracle.py
│ │ ├── test_in_order_oracle.py
│ │ ├── test_lstm_model.py
│ │ ├── test_parse_transitions.py
│ │ ├── test_parse_tree.py
│ │ ├── test_positional_encoding.py
│ │ ├── test_selftrain_vi_quad.py
│ │ ├── test_text_processing.py
│ │ ├── test_top_down_oracle.py
│ │ ├── test_trainer.py
│ │ ├── test_transformer_tree_stack.py
│ │ ├── test_transition_sequence.py
│ │ ├── test_tree_reader.py
│ │ ├── test_tree_stack.py
│ │ ├── test_utils.py
│ │ └── test_vietnamese.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── coref/
│ │ │ ├── __init__.py
│ │ │ └── test_hebrew_iahlt.py
│ │ ├── ner/
│ │ │ ├── __init__.py
│ │ │ ├── test_prepare_ner_file.py
│ │ │ └── test_utils.py
│ │ ├── test_common.py
│ │ └── test_vietnamese_renormalization.py
│ ├── depparse/
│ │ ├── __init__.py
│ │ ├── test_depparse_data.py
│ │ └── test_parser.py
│ ├── langid/
│ │ ├── __init__.py
│ │ ├── test_langid.py
│ │ └── test_multilingual.py
│ ├── lemma/
│ │ ├── __init__.py
│ │ ├── test_data.py
│ │ ├── test_lemma_trainer.py
│ │ └── test_lowercase.py
│ ├── lemma_classifier/
│ │ ├── __init__.py
│ │ ├── test_data_preparation.py
│ │ └── test_training.py
│ ├── morphseg/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_integration.py
│ │ ├── test_morpheme_segmenter.py
│ │ └── test_stanza_integration.py
│ ├── mwt/
│ │ ├── __init__.py
│ │ ├── test_character_classifier.py
│ │ ├── test_english_corner_cases.py
│ │ ├── test_prepare_mwt.py
│ │ └── test_utils.py
│ ├── ner/
│ │ ├── __init__.py
│ │ ├── test_bsf_2_beios.py
│ │ ├── test_bsf_2_iob.py
│ │ ├── test_combine_ner_datasets.py
│ │ ├── test_convert_amt.py
│ │ ├── test_convert_nkjp.py
│ │ ├── test_convert_starlang_ner.py
│ │ ├── test_data.py
│ │ ├── test_from_conllu.py
│ │ ├── test_models_ner_scorer.py
│ │ ├── test_ner_tagger.py
│ │ ├── test_ner_trainer.py
│ │ ├── test_ner_training.py
│ │ ├── test_ner_utils.py
│ │ ├── test_pay_amt_annotators.py
│ │ ├── test_split_wikiner.py
│ │ └── test_suc3.py
│ ├── pipeline/
│ │ ├── __init__.py
│ │ ├── pipeline_device_tests.py
│ │ ├── test_arabic_pipeline.py
│ │ ├── test_core.py
│ │ ├── test_decorators.py
│ │ ├── test_depparse.py
│ │ ├── test_english_pipeline.py
│ │ ├── test_french_pipeline.py
│ │ ├── test_lemmatizer.py
│ │ ├── test_pipeline_constituency_processor.py
│ │ ├── test_pipeline_depparse_processor.py
│ │ ├── test_pipeline_mwt_expander.py
│ │ ├── test_pipeline_ner_processor.py
│ │ ├── test_pipeline_pos_processor.py
│ │ ├── test_pipeline_sentiment_processor.py
│ │ ├── test_requirements.py
│ │ └── test_tokenizer.py
│ ├── pos/
│ │ ├── __init__.py
│ │ ├── test_data.py
│ │ ├── test_tagger.py
│ │ └── test_xpos_vocab_factory.py
│ ├── pytest.ini
│ ├── resources/
│ │ ├── __init__.py
│ │ ├── test_charlm_depparse.py
│ │ ├── test_common.py
│ │ ├── test_default_packages.py
│ │ ├── test_installation.py
│ │ └── test_prepare_resources.py
│ ├── server/
│ │ ├── __init__.py
│ │ ├── test_client.py
│ │ ├── test_java_protobuf_requests.py
│ │ ├── test_morphology.py
│ │ ├── test_parser_eval.py
│ │ ├── test_protobuf.py
│ │ ├── test_semgrex.py
│ │ ├── test_server_misc.py
│ │ ├── test_server_pretokenized.py
│ │ ├── test_server_request.py
│ │ ├── test_server_start.py
│ │ ├── test_ssurgeon.py
│ │ ├── test_tokensregex.py
│ │ ├── test_tsurgeon.py
│ │ └── test_ud_enhancer.py
│ ├── setup.py
│ └── tokenization/
│ ├── __init__.py
│ ├── test_prepare_tokenizer_treebank.py
│ ├── test_replace_long_tokens.py
│ ├── test_spaces.py
│ ├── test_tokenization_lst20.py
│ ├── test_tokenization_orchid.py
│ ├── test_tokenize_data.py
│ ├── test_tokenize_files.py
│ ├── test_tokenize_utils.py
│ └── test_vocab.py
└── utils/
├── __init__.py
├── avg_sent_len.py
├── charlm/
│ ├── __init__.py
│ ├── conll17_to_text.py
│ ├── dump_oscar.py
│ ├── make_lm_data.py
│ └── oscar_to_text.py
├── confusion.py
├── conll.py
├── constituency/
│ ├── __init__.py
│ ├── check_transitions.py
│ ├── grep_dev_logs.py
│ ├── grep_test_logs.py
│ └── list_tensors.py
├── datasets/
│ ├── __init__.py
│ ├── common.py
│ ├── conllu_to_text.py
│ ├── constituency/
│ │ ├── __init__.py
│ │ ├── build_silver_dataset.py
│ │ ├── common_trees.py
│ │ ├── convert_alt.py
│ │ ├── convert_arboretum.py
│ │ ├── convert_cintil.py
│ │ ├── convert_ctb.py
│ │ ├── convert_icepahc.py
│ │ ├── convert_it_turin.py
│ │ ├── convert_it_vit.py
│ │ ├── convert_spmrl.py
│ │ ├── convert_starlang.py
│ │ ├── count_common_words.py
│ │ ├── extract_all_silver_dataset.py
│ │ ├── extract_silver_dataset.py
│ │ ├── prepare_con_dataset.py
│ │ ├── reduce_dataset.py
│ │ ├── relabel_tags.py
│ │ ├── selftrain.py
│ │ ├── selftrain_it.py
│ │ ├── selftrain_single_file.py
│ │ ├── selftrain_vi_quad.py
│ │ ├── selftrain_wiki.py
│ │ ├── silver_variance.py
│ │ ├── split_holdout.py
│ │ ├── split_weighted_ensemble.py
│ │ ├── tokenize_wiki.py
│ │ ├── treebank_to_labeled_brackets.py
│ │ ├── utils.py
│ │ ├── vtb_convert.py
│ │ └── vtb_split.py
│ ├── contract_mwt.py
│ ├── coref/
│ │ ├── __init__.py
│ │ ├── balance_languages.py
│ │ ├── convert_hebrew_iahlt.py
│ │ ├── convert_hebrew_mixed.py
│ │ ├── convert_hindi.py
│ │ ├── convert_ontonotes.py
│ │ ├── convert_tamil.py
│ │ ├── convert_udcoref.py
│ │ ├── convert_udcoref_1.2.py
│ │ └── utils.py
│ ├── corenlp_segmenter_dataset.py
│ ├── depparse/
│ │ └── check_results.py
│ ├── ner/
│ │ ├── __init__.py
│ │ ├── build_en_combined.py
│ │ ├── check_for_duplicates.py
│ │ ├── combine_ner_datasets.py
│ │ ├── compare_entities.py
│ │ ├── conll_to_iob.py
│ │ ├── convert_amt.py
│ │ ├── convert_ar_aqmar.py
│ │ ├── convert_bn_daffodil.py
│ │ ├── convert_bsf_to_beios.py
│ │ ├── convert_bsnlp.py
│ │ ├── convert_en_conll03.py
│ │ ├── convert_fire_2013.py
│ │ ├── convert_he_iahlt.py
│ │ ├── convert_hy_armtdp.py
│ │ ├── convert_ijc.py
│ │ ├── convert_kk_kazNERD.py
│ │ ├── convert_lst20.py
│ │ ├── convert_mr_l3cube.py
│ │ ├── convert_my_ucsy.py
│ │ ├── convert_nkjp.py
│ │ ├── convert_nner22.py
│ │ ├── convert_nytk.py
│ │ ├── convert_ontonotes.py
│ │ ├── convert_rgai.py
│ │ ├── convert_sindhi_siner.py
│ │ ├── convert_starlang_ner.py
│ │ ├── count_entities.py
│ │ ├── json_to_bio.py
│ │ ├── misc_to_date.py
│ │ ├── ontonotes_multitag.py
│ │ ├── prepare_ner_dataset.py
│ │ ├── prepare_ner_file.py
│ │ ├── preprocess_wikiner.py
│ │ ├── simplify_en_worldwide.py
│ │ ├── simplify_ontonotes_to_worldwide.py
│ │ ├── split_wikiner.py
│ │ ├── suc_conll_to_iob.py
│ │ ├── suc_to_iob.py
│ │ └── utils.py
│ ├── pos/
│ │ ├── __init__.py
│ │ ├── convert_trees_to_pos.py
│ │ └── remove_columns.py
│ ├── prepare_depparse_treebank.py
│ ├── prepare_lemma_classifier.py
│ ├── prepare_lemma_treebank.py
│ ├── prepare_mwt_treebank.py
│ ├── prepare_pos_treebank.py
│ ├── prepare_tokenizer_data.py
│ ├── prepare_tokenizer_treebank.py
│ ├── pretrain/
│ │ ├── __init__.py
│ │ └── word_in_pretrain.py
│ ├── random_split_conllu.py
│ ├── sentiment/
│ │ ├── __init__.py
│ │ ├── add_constituency.py
│ │ ├── convert_italian_poetry_classification.py
│ │ ├── convert_italian_sentence_classification.py
│ │ ├── prepare_sentiment_dataset.py
│ │ ├── process_MELD.py
│ │ ├── process_airline.py
│ │ ├── process_arguana_xml.py
│ │ ├── process_corona.py
│ │ ├── process_es_tass2020.py
│ │ ├── process_it_sentipolc16.py
│ │ ├── process_ren_chinese.py
│ │ ├── process_sb10k.py
│ │ ├── process_scare.py
│ │ ├── process_slsd.py
│ │ ├── process_sst.py
│ │ ├── process_usage_german.py
│ │ ├── process_utils.py
│ │ └── process_vsfc_vietnamese.py
│ ├── thai_syllable_dict_generator.py
│ ├── tokenization/
│ │ ├── __init__.py
│ │ ├── convert_ml_cochin.py
│ │ ├── convert_my_alt.py
│ │ ├── convert_text_files.py
│ │ ├── convert_th_best.py
│ │ ├── convert_th_lst20.py
│ │ ├── convert_th_orchid.py
│ │ ├── convert_vi_vlsp.py
│ │ └── process_thai_tokenization.py
│ └── vietnamese/
│ ├── __init__.py
│ └── renormalize.py
├── default_paths.py
├── get_tqdm.py
├── helper_func.py
├── languages/
│ ├── __init__.py
│ └── kazakh_transliteration.py
├── lemma/
│ ├── __init__.py
│ └── count_ambiguous_lemmas.py
├── max_mwt_length.py
├── ner/
│ ├── __init__.py
│ ├── flair_ner_tag_dataset.py
│ ├── paying_annotators.py
│ └── spacy_ner_tag_dataset.py
├── pretrain/
│ ├── __init__.py
│ └── compare_pretrains.py
├── select_backoff.py
├── training/
│ ├── __init__.py
│ ├── common.py
│ ├── compose_ete_results.py
│ ├── remove_constituency_optimizer.py
│ ├── run_charlm.py
│ ├── run_constituency.py
│ ├── run_depparse.py
│ ├── run_ete.py
│ ├── run_lemma.py
│ ├── run_lemma_classifier.py
│ ├── run_mwt.py
│ ├── run_ner.py
│ ├── run_pos.py
│ ├── run_sentiment.py
│ ├── run_tokenizer.py
│ └── separate_ner_pretrain.py
└── visualization/
├── README
├── __init__.py
├── conll_deprel_visualization.py
├── constants.py
├── dependency_visualization.py
├── ner_visualization.py
├── semgrex_app.py
├── semgrex_visualizer.py
├── ssurgeon_visualizer.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Environment (please complete the following information):**
- OS: [e.g. Windows, Ubuntu, CentOS, MacOS]
- Python version: [e.g. Python 3.6.8 from Anaconda]
- Stanza version: [e.g., 1.0.0]
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .github/ISSUE_TEMPLATE/question.md
================================================
---
name: Question
about: 'Question about general usage. '
title: "[QUESTION]"
labels: question
assignees: ''
---
Before you start, make sure to check out:
* Our documentation: https://stanfordnlp.github.io/stanza/
* Our FAQ: https://stanfordnlp.github.io/stanza/faq.html
* Github issues (especially closed ones)
Your question might have an answer in these places!
If you still couldn't find the answer to your question, feel free to delete this text and write down your question. The more information you provide with your question, the faster we will be able to help you!
If you have a question about an issue you're facing when using Stanza, please provide detailed step-by-step instructions to reproduce it. At a minimum, include a minimal code sample that reproduces the problem instead of just describing it. That greatly helps us locate the issue and help you resolve it faster!
================================================
FILE: .github/pull_request_template.md
================================================
**BEFORE YOU START**: please make sure your pull request is against the `dev` branch.
We cannot accept pull requests against the `main` branch.
See our [contributing guide](https://github.com/stanfordnlp/stanza/blob/main/CONTRIBUTING.md) for details.
## Description
A brief and concise description of what your pull request is trying to accomplish.
## Fixes Issues
A list of issues/bugs with # references. (e.g., #123)
## Unit test coverage
Are there unit tests in place to make sure your code is functioning correctly?
(see [here](https://github.com/stanfordnlp/stanza/blob/main/stanza/tests/pos/test_tagger.py) for a simple example)
## Known breaking changes/behaviors
Does this break anything in Stanza's existing user interface? If so, what is it and how is it addressed?
================================================
FILE: .github/stale.yml
================================================
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 7
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
- fixed on dev
- bug
- enhancement
# Label to use when marking an issue as stale
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: >
This issue has been automatically closed due to inactivity.
================================================
FILE: .github/workflows/stanza-tests.yaml
================================================
name: Run Stanza Tests
on: [push]
jobs:
Run-Stanza-Tests:
runs-on: self-hosted
steps:
- run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
- run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!"
- run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
- name: Check out repository code
uses: actions/checkout@v2
- run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner."
- run: echo "🖥️ The workflow is now ready to test your code on the runner."
- name: Run Stanza Tests
run: |
# set up environment
echo "Setting up environment..."
bash
#. $CONDA_PREFIX/etc/profile.d/conda.sh
. /home/stanzabuild/miniconda3/etc/profile.d/conda.sh
conda activate stanza
export STANZA_TEST_HOME=/scr/stanza_test
export CORENLP_HOME=$STANZA_TEST_HOME/corenlp_dir
export CLASSPATH=$CORENLP_HOME/*:
echo CORENLP_HOME=$CORENLP_HOME
echo CLASSPATH=$CLASSPATH
# install from stanza repo being evaluated
echo PWD: $pwd
echo PATH: $PATH
pip3 install -e .
pip3 install -e .[test]
pip3 install -e .[transformers]
pip3 install -e .[tokenizers]
pip3 install -e .[morphseg]
# set up for tests
echo "Running stanza test set up..."
rm -rf $STANZA_TEST_HOME
python3 stanza/tests/setup.py
# run tests
echo "Running tests..."
export CUDA_VISIBLE_DEVICES=2
pytest stanza/tests
- run: echo "🍏 This job's status is ${{ job.status }}."
================================================
FILE: .gitignore
================================================
# kept from original
.DS_Store
*.tmp
*.pkl
*.conllu
*.lem
*.toklabels
# also match "data" without a trailing slash: "data/" only matches directories, so a symlink named data would not be ignored
data
data/
stanza_resources/
stanza_test/
saved_models/
logs/
log/
*_test_treebanks
wandb/
params/*/*.json
!params/*/default.json
# emacs backup files
*~
# Vim swap files for Python sources
*py.swp
# standard github python project gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# IDE-related
.vscode/
.idea/vcs.xml
.idea/inspectionProfiles/profiles_settings.xml
.idea/workspace.xml
# Jekyll stuff, triggered by running the docs locally
.jekyll-cache/
.jekyll-metadata
_site/
# symlink / directory for data files
extern_data
================================================
FILE: .travis.yml
================================================
language: python
python:
- 3.6.5
notifications:
email: false
install:
- pip install --quiet .
- export CORENLP_HOME=~/corenlp-latest CORENLP_VERSION=stanford-corenlp-latest
- export CORENLP_URL="http://nlp.stanford.edu/software/${CORENLP_VERSION}.zip"
- wget $CORENLP_URL -O corenlp-latest.zip
- unzip corenlp-latest.zip > unzip.log
- export CORENLP_UNZIP=`grep creating unzip.log | head -n 1 | cut -d ":" -f 2`
- mv $CORENLP_UNZIP $CORENLP_HOME
- mkdir ~/stanza_test
- mkdir ~/stanza_test/in
- mkdir ~/stanza_test/out
- mkdir ~/stanza_test/scripts
- cp tests/data/external_server.properties ~/stanza_test/scripts
- cp tests/data/example_french.json ~/stanza_test/out
- cp tests/data/tiny_emb.* ~/stanza_test/in
- export STANZA_TEST_HOME=~/stanza_test
script:
- python -m pytest -m travis tests/
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to Stanza
We would love to see contributions to Stanza from the community! Contributions that we welcome include bugfixes and enhancements. If you want to report a bug or suggest a feature but don't intend to fix or implement it by yourself, please create a corresponding issue on [our issues page](https://github.com/stanfordnlp/stanza/issues). If you plan to contribute a bugfix or enhancement, please read the following.
## 🛠️ Bugfixes
For bugfixes, please follow these steps:
- Make sure a fix does not already exist, by searching through existing [issues](https://github.com/stanfordnlp/stanza/issues) (including closed ones) and [pull requests](https://github.com/stanfordnlp/stanza/pulls).
- Confirm the bug with us by creating a bug-report issue. In your issue, you should at least include the platform and environment that you are running with, and a minimal code snippet that will reproduce the bug.
- Once the bug is confirmed, you can go ahead with implementing the bugfix, and create a pull request **against the `dev` branch**.
## 💡 Enhancements
For enhancements, please follow these steps:
- Make sure a similar enhancement suggestion does not already exist, by searching through existing [issues](https://github.com/stanfordnlp/stanza/issues).
- Create a feature-request issue and discuss this enhancement with us. We'll need to make sure this enhancement won't break the existing user interface or functionality.
- Once the enhancement is confirmed with us, you can go ahead with implementing it, and create a pull request **against the `dev` branch**.
================================================
FILE: LICENSE
================================================
Copyright 2019 The Board of Trustees of The Leland Stanford Junior University
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
Stanza: A Python NLP Library for Many Human Languages
The Stanford NLP Group's official Python NLP library. It contains support for running various accurate natural language processing tools on 60+ languages and for accessing the Java Stanford CoreNLP software from Python. For detailed information please visit our [official website](https://stanfordnlp.github.io/stanza/).
🔥 A new collection of **biomedical** and **clinical** English model packages is now available, offering a seamless experience for syntactic analysis and named entity recognition (NER) on biomedical literature text and clinical notes. For more information, check out our [Biomedical models documentation page](https://stanfordnlp.github.io/stanza/biomed.html).
### References
If you use this library in your research, please kindly cite our [ACL2020 Stanza system demo paper](https://arxiv.org/abs/2003.07082):
```bibtex
@inproceedings{qi2020stanza,
title={Stanza: A {Python} Natural Language Processing Toolkit for Many Human Languages},
author={Qi, Peng and Zhang, Yuhao and Zhang, Yuhui and Bolton, Jason and Manning, Christopher D.},
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
year={2020}
}
```
If you use our biomedical and clinical models, please also cite our [Stanza Biomedical Models description paper](https://arxiv.org/abs/2007.14640):
```bibtex
@article{zhang2021biomedical,
author = {Zhang, Yuhao and Zhang, Yuhui and Qi, Peng and Manning, Christopher D and Langlotz, Curtis P},
title = {Biomedical and clinical {E}nglish model packages for the {S}tanza {P}ython {NLP} library},
journal = {Journal of the American Medical Informatics Association},
year = {2021},
month = {06},
issn = {1527-974X}
}
```
The PyTorch implementation of the neural pipeline in this repository is due to [Peng Qi](http://qipeng.me) (@qipeng), [Yuhao Zhang](http://yuhao.im) (@yuhaozhang), and [Yuhui Zhang](https://cs.stanford.edu/~yuhuiz/) (@yuhui-zh15), with help from [Jason Bolton](mailto:jebolton@stanford.edu) (@j38), [Tim Dozat](https://web.stanford.edu/~tdozat/) (@tdozat) and [John Bauer](https://www.linkedin.com/in/john-bauer-b3883b60/) (@AngledLuffa). Maintenance of this repo is currently led by [John Bauer](https://www.linkedin.com/in/john-bauer-b3883b60/).
If you use the CoreNLP software through Stanza, please cite the CoreNLP software package and the respective modules as described [here](https://stanfordnlp.github.io/CoreNLP/#citing-stanford-corenlp-in-papers) ("Citing Stanford CoreNLP in papers"). The CoreNLP client is mostly written by [Arun Chaganty](http://arun.chagantys.org/), and [Jason Bolton](mailto:jebolton@stanford.edu) spearheaded merging the two projects together.
If you use the Semgrex or Ssurgeon part of CoreNLP, please cite [our GURT paper on Semgrex and Ssurgeon](https://aclanthology.org/2023.tlt-1.7/):
```bibtex
@inproceedings{bauer-etal-2023-semgrex,
title = "Semgrex and Ssurgeon, Searching and Manipulating Dependency Graphs",
author = "Bauer, John and
Kiddon, Chlo{\'e} and
Yeh, Eric and
Shan, Alex and
D. Manning, Christopher",
booktitle = "Proceedings of the 21st International Workshop on Treebanks and Linguistic Theories (TLT, GURT/SyntaxFest 2023)",
month = mar,
year = "2023",
address = "Washington, D.C.",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.tlt-1.7",
pages = "67--73",
abstract = "Searching dependency graphs and manipulating them can be a time consuming and challenging task to get right. We document Semgrex, a system for searching dependency graphs, and introduce Ssurgeon, a system for manipulating the output of Semgrex. The compact language used by these systems allows for easy command line or API processing of dependencies. Additionally, integration with publicly released toolkits in Java and Python allows for searching text relations and attributes over natural text.",
}
```
## Issues and Usage Q&A
To ask questions, report issues or request features 🤔, please use the [GitHub Issue Tracker](https://github.com/stanfordnlp/stanza/issues). Before creating a new issue, please make sure to search for existing issues that may solve your problem, or visit the [Frequently Asked Questions (FAQ) page](https://stanfordnlp.github.io/stanza/faq.html) on our website.
## Contributing to Stanza
We welcome community contributions to Stanza in the form of bugfixes 🛠️ and enhancements 💡! If you want to contribute, please first read [our contribution guideline](CONTRIBUTING.md).
## Installation
### pip
Stanza supports Python 3.6 or later. We recommend that you install Stanza via [pip](https://pip.pypa.io/en/stable/installing/), the Python package manager. To install, simply run:
```bash
pip install stanza
```
This should also help resolve all of the dependencies of Stanza, for instance [PyTorch](https://pytorch.org/) 1.3.0 or above.
If you currently have a previous version of `stanza` installed, use:
```bash
pip install stanza -U
```
### Anaconda
To install Stanza via Anaconda, use the following conda command:
```bash
conda install -c stanfordnlp stanza
```
Note that for now installing Stanza via Anaconda does not work for Python 3.10. For Python 3.10 please use pip installation.
### From Source
Alternatively, you can also install from source of this git repository, which will give you more flexibility in developing on top of Stanza. For this option, run
```bash
git clone https://github.com/stanfordnlp/stanza.git
cd stanza
pip install -e .
```
## Running Stanza
### Getting Started with the neural pipeline
To run your first Stanza pipeline, simply follow these steps in your Python interactive interpreter:
```python
>>> import stanza
>>> stanza.download('en') # Optional: pre-download English models (Pipeline can auto-download if needed)
>>> nlp = stanza.Pipeline('en') # This sets up a default neural pipeline in English
>>> doc = nlp("Barack Obama was born in Hawaii. He was elected president in 2008.")
>>> doc.sentences[0].print_dependencies()
```
If you encounter `requests.exceptions.ConnectionError`, please try to use a proxy:
```python
>>> import stanza
>>> proxies = {'http': 'http://ip:port', 'https': 'http://ip:port'}
>>> stanza.download('en', proxies=proxies) # Optional: pre-download English models (Pipeline can auto-download if needed)
>>> nlp = stanza.Pipeline('en') # This sets up a default neural pipeline in English
>>> doc = nlp("Barack Obama was born in Hawaii. He was elected president in 2008.")
>>> doc.sentences[0].print_dependencies()
```
The last command will print out the words in the first sentence in the input string (or [`Document`](https://stanfordnlp.github.io/stanza/data_objects.html#document), as it is represented in Stanza), as well as the indices for the word that governs it in the Universal Dependencies parse of that sentence (its "head"), along with the dependency relation between the words. The output should look like:
```
('Barack', '4', 'nsubj:pass')
('Obama', '1', 'flat')
('was', '4', 'aux:pass')
('born', '0', 'root')
('in', '6', 'case')
('Hawaii', '4', 'obl')
('.', '4', 'punct')
```
See [our getting started guide](https://stanfordnlp.github.io/stanza/installation_usage.html#getting-started) for more details.
### Accessing Java Stanford CoreNLP software
Aside from the neural pipeline, this package also includes an official wrapper for accessing the Java Stanford CoreNLP software with Python code.
There are a few initial setup steps.
* Download [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/) and models for the language you wish to use
* Put the model jars in the distribution folder
* Tell the Python code where Stanford CoreNLP is located by setting the `CORENLP_HOME` environment variable (e.g., in *nix): `export CORENLP_HOME=/path/to/stanford-corenlp-4.5.3`
We provide [comprehensive examples](https://stanfordnlp.github.io/stanza/corenlp_client.html) in our documentation that show how one can use CoreNLP through Stanza and extract various annotations from it.
### Online Colab Notebooks
To get you started, we also provide interactive Jupyter notebooks in the `demo` folder. You can also open these notebooks and run them interactively on [Google Colab](https://colab.research.google.com). To view all available notebooks, follow these steps:
* Go to the [Google Colab website](https://colab.research.google.com)
* Navigate to `File` -> `Open notebook`, and choose `GitHub` in the pop-up menu
* Note that you do **not** need to give Colab access permission to your GitHub account
* Type `stanfordnlp/stanza` in the search bar, and click enter
### Trained Models for the Neural Pipeline
We currently provide models for all of the [Universal Dependencies](https://universaldependencies.org/) treebanks v2.8, as well as NER models for a few widely-spoken languages. You can find instructions for downloading and using these models [here](https://stanfordnlp.github.io/stanza/models.html).
### Batching To Maximize Pipeline Speed
To maximize speed, it is essential to run the pipeline on batches of documents. Running a for loop on one sentence at a time will be very slow. The best approach at this time is to concatenate documents together, with each document separated by a blank line (i.e., two line breaks `\n\n`). The tokenizer will recognize blank lines as sentence breaks. We are actively working on improving multi-document processing.
## Training your own neural pipelines
All neural modules in this library can be trained with your own data. The tokenizer, the multi-word token (MWT) expander, the POS/morphological features tagger, the lemmatizer and the dependency parser require [CoNLL-U](https://universaldependencies.org/format.html) formatted data, while the NER model requires the BIOES format. Currently, we do not support model training via the `Pipeline` interface. Therefore, to train your own models, you need to clone this git repository and run training from the source.
For detailed step-by-step guidance on how to train and evaluate your own models, please visit our [training documentation](https://stanfordnlp.github.io/stanza/training.html).
## LICENSE
Stanza is released under the Apache License, Version 2.0. See the [LICENSE](https://github.com/stanfordnlp/stanza/blob/master/LICENSE) file for more details.
================================================
FILE: demo/CONLL_Dependency_Visualizer_Example.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c0fd86c8",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\n",
"\n",
"# load necessary conllu files - expected to be in the demo directory along with the notebook\n",
"en_file = \"en_test.conllu.txt\"\n",
"\n",
"# testing left to right languages\n",
"conll_to_visual(en_file, \"en\", sent_count=2)\n",
"conll_to_visual(en_file, \"en\", sent_count=10)\n",
"#conll_to_visual(en_file, \"en\", display_all=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc4b3f9b",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\n",
"\n",
"jp_file = \"japanese_test.conllu.txt\"\n",
"conll_to_visual(jp_file, \"ja\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6852b8e8",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\n",
"\n",
"# testing right to left languages\n",
"ar_file = \"arabic_test.conllu.txt\"\n",
"conll_to_visual(ar_file, \"ar\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.22"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: demo/Dependency_Visualization_Testing.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "64b2a9e0",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.dependency_visualization import visualize_strings\n",
"\n",
"ar_strings = ['برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة \"ليوبارد\" الالمانية', \"هل بإمكاني مساعدتك؟\", \n",
" \"أراك في مابعد\", \"لحظة من فضلك\"]\n",
"# Testing with right to left language\n",
"visualize_strings(ar_strings, \"ar\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35ef521b",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.dependency_visualization import visualize_strings\n",
"\n",
"en_strings = [\"This is a sentence.\", \n",
" \"He is wearing a red shirt\",\n",
" \"Barack Obama was born in Hawaii. He was elected President of the United States in 2008.\"]\n",
"# Testing with left to right languages\n",
"visualize_strings(en_strings, \"en\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3cf10ba",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.dependency_visualization import visualize_strings\n",
"\n",
"zh_strings = [\"中国是一个很有意思的国家。\"]\n",
"# Testing with left to right language\n",
"visualize_strings(zh_strings, \"zh\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2b9b574",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.22"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: demo/NER_Visualization.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "abf300bb",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.ner_visualization import visualize_strings\n",
"\n",
"en_strings = ['''Samuel Jackson, a Christian man from Utah, went to the JFK Airport for a flight to New York.\n",
" He was thinking of attending the US Open, his favorite tennis tournament besides Wimbledon.\n",
" That would be a dream trip, certainly not possible since it is $5000 attendance and 5000 miles away.\n",
" On the way there, he watched the Super Bowl for 2 hours and read War and Piece by Tolstoy for 1 hour.\n",
" In New York, he crossed the Brooklyn Bridge and listened to the 5th symphony of Beethoven as well as\n",
" \"All I want for Christmas is You\" by Mariah Carey.''', \n",
" \"Barack Obama was born in Hawaii. He was elected President of the United States in 2008\"]\n",
" \n",
"visualize_strings(en_strings, \"en\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5670921a",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.ner_visualization import visualize_strings\n",
"\n",
"zh_strings = ['''来自犹他州的基督徒塞缪尔杰克逊前往肯尼迪机场搭乘航班飞往纽约。\n",
" 他正在考虑参加美国公开赛,这是除了温布尔登之外他最喜欢的网球赛事。\n",
" 那将是一次梦想之旅,当然不可能,因为它的出勤费为 5000 美元,距离 5000 英里。\n",
" 在去的路上,他看了 2 个小时的超级碗比赛,看了 1 个小时的托尔斯泰的《战争与碎片》。\n",
" 在纽约,他穿过布鲁克林大桥,聆听了贝多芬的第五交响曲以及 玛丽亚凯莉的“圣诞节我想要的就是你”。''',\n",
" \"我觉得罗家费德勒住在加州, 在美国里面。\"]\n",
"visualize_strings(zh_strings, \"zh\", colors={\"PERSON\": \"yellow\", \"DATE\": \"red\", \"GPE\": \"blue\"})\n",
"visualize_strings(zh_strings, \"zh\", select=['PERSON', 'DATE'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8d96072",
"metadata": {},
"outputs": [],
"source": [
"from stanza.utils.visualization.ner_visualization import visualize_strings\n",
"\n",
"ar_strings = [\".أعيش في سان فرانسيسكو ، كاليفورنيا. اسمي أليكس وأنا ألتحق بجامعة ستانفورد. أنا أدرس علوم الكمبيوتر وأستاذي هو كريس مانينغ\"\n",
" , \"اسمي أليكس ، أنا من الولايات المتحدة.\", \n",
" '''صامويل جاكسون ، رجل مسيحي من ولاية يوتا ، ذهب إلى مطار جون كنيدي في رحلة إلى نيويورك. كان يفكر في حضور بطولة الولايات المتحدة المفتوحة للتنس ، بطولة التنس المفضلة لديه إلى جانب بطولة ويمبلدون. ستكون هذه رحلة الأحلام ، وبالتأكيد ليست ممكنة لأنها تبلغ 5000 دولار للحضور و 5000 ميل. في الطريق إلى هناك ، شاهد Super Bowl لمدة ساعتين وقرأ War and Piece by Tolstoy لمدة ساعة واحدة. في نيويورك ، عبر جسر بروكلين واستمع إلى السيمفونية الخامسة لبيتهوفن وكذلك \"كل ما أريده في عيد الميلاد هو أنت\" لماريا كاري.''']\n",
"\n",
"visualize_strings(ar_strings, \"ar\", colors={\"PER\": \"pink\", \"LOC\": \"linear-gradient(90deg, #aa9cfc, #fc9ce7)\", \"ORG\": \"yellow\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22489b27",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.22"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: demo/Stanza_Beginners_Guide.ipynb
================================================
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Stanza-Beginners-Guide.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "56LiYCkPM7V_",
"colab_type": "text"
},
"source": [
"# Welcome to Stanza!\n",
"\n",
"\n",
"\n",
"\n",
"Stanza is a Python NLP toolkit that supports 60+ human languages. It is built with highly accurate neural network components that enable efficient training and evaluation with your own annotated data, and offers pretrained models on 100 treebanks. Additionally, Stanza provides a stable, officially maintained Python interface to Java Stanford CoreNLP Toolkit.\n",
"\n",
"In this tutorial, we will demonstrate how to set up Stanza and annotate text with its native neural network NLP models. For the use of the Python CoreNLP interface, please see other tutorials."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yQff4Di5Nnq0",
"colab_type": "text"
},
"source": [
"## 1. Installing Stanza\n",
"\n",
"Note that Stanza only supports Python 3.6 and above. Installing and importing Stanza are as simple as running the following commands:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "owSj1UtdEvSU",
"colab_type": "code",
"colab": {}
},
"source": [
"# Install; note that the prefix \"!\" is not needed if you are running in a terminal\n",
"!pip install stanza\n",
"\n",
"# Import the package\n",
"import stanza"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ixllwEKeCJg",
"colab_type": "text"
},
"source": [
"### More Information\n",
"\n",
"For common troubleshooting, please visit our [troubleshooting page](https://stanfordnlp.github.io/stanza/installation_usage.html#troubleshooting)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aeyPs5ARO79d",
"colab_type": "text"
},
"source": [
"## 2. Downloading Models\n",
"\n",
"You can download models with the `stanza.download` command. The language can be specified with either a full language name (e.g., \"english\"), or a short code (e.g., \"en\"). \n",
"\n",
"By default, models will be saved to your `~/stanza_resources` directory. If you want to specify your own path to save the model files, you can pass a `dir=your_path` argument.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "HDwRm-KXGcYo",
"colab_type": "code",
"colab": {}
},
"source": [
"# Download an English model into the default directory\n",
"print(\"Downloading English model...\")\n",
"stanza.download('en')\n",
"\n",
"# Similarly, download a (simplified) Chinese model\n",
"# Note that you can use verbose=False to turn off all printed messages\n",
"print(\"Downloading Chinese model...\")\n",
"stanza.download('zh', verbose=False)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "7HCfQ0SfdmsU",
"colab_type": "text"
},
"source": [
"### More Information\n",
"\n",
"Pretrained models are provided for 60+ different languages. For all languages, available models and the corresponding short language codes, please check out the [models page](https://stanfordnlp.github.io/stanza/models.html).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b3-WZJrzWD2o",
"colab_type": "text"
},
"source": [
"## 3. Processing Text\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XrnKl2m3fq2f",
"colab_type": "text"
},
"source": [
"### Constructing Pipeline\n",
"\n",
"To process a piece of text, you'll need to first construct a `Pipeline` with different `Processor` units. The pipeline is language-specific, so again you'll need to first specify the language (see examples).\n",
"\n",
"- By default, the pipeline will include all processors, including tokenization, multi-word token expansion, part-of-speech tagging, lemmatization, dependency parsing and named entity recognition (for supported languages). However, you can always specify what processors you want to include with the `processors` argument.\n",
"\n",
"- Stanza's pipeline is CUDA-aware, meaning that a CUDA device will be used whenever one is available; otherwise, the CPU will be used. You can force the pipeline to use the CPU regardless by setting `use_gpu=False`.\n",
"\n",
"- Again, you can suppress all printed messages by setting `verbose=False`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "HbiTSBDPG53o",
"colab_type": "code",
"colab": {}
},
"source": [
"# Build an English pipeline, with all processors by default\n",
"print(\"Building an English pipeline...\")\n",
"en_nlp = stanza.Pipeline('en')\n",
"\n",
"# Build a Chinese pipeline, with customized processor list and no logging, and force it to use CPU\n",
"print(\"Building a Chinese pipeline...\")\n",
"zh_nlp = stanza.Pipeline('zh', processors='tokenize,lemma,pos,depparse', verbose=False, use_gpu=False)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Go123Bx8e1wt",
"colab_type": "text"
},
"source": [
"### Annotating Text\n",
"\n",
"After a pipeline is successfully constructed, you can get annotations of a piece of text simply by passing the string into the pipeline object. The pipeline will return a `Document` object, from which detailed annotations can be accessed. For example:\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "k_p0h1UTHDMm",
"colab_type": "code",
"colab": {}
},
"source": [
"# Processing English text\n",
"en_doc = en_nlp(\"Barack Obama was born in Hawaii. He was elected president in 2008.\")\n",
"print(type(en_doc))\n",
"\n",
"# Processing Chinese text\n",
"zh_doc = zh_nlp(\"达沃斯世界经济论坛是每年全球政商界领袖聚在一起的年度盛事。\")\n",
"print(type(zh_doc))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DavwCP9egzNZ",
"colab_type": "text"
},
"source": [
"### More Information\n",
"\n",
"For more information on how to construct a pipeline and information on different processors, please visit our [pipeline page](https://stanfordnlp.github.io/stanza/pipeline.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O_PYLEGziQWR",
"colab_type": "text"
},
"source": [
"## 4. Accessing Annotations\n",
"\n",
"Annotations can be accessed from the returned `Document` object. \n",
"\n",
"A `Document` contains a list of `Sentence`s, and a `Sentence` contains a list of `Token`s and `Word`s. For the most part `Token`s and `Word`s overlap, but some tokens can be divided into multiple words. For instance, the French token `aux` is divided into the words `à` and `les`, while in English a word and a token are equivalent. Note that dependency parses are derived over `Word`s.\n",
"\n",
"Additionally, a `Span` object is used to represent annotations that are part of a document, such as named entity mentions.\n",
"\n",
"\n",
"The following example iterates over all English sentences and words, and prints the word information one by one:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "B5691SpFHFZ6",
"colab_type": "code",
"colab": {}
},
"source": [
"for i, sent in enumerate(en_doc.sentences):\n",
" print(\"[Sentence {}]\".format(i+1))\n",
" for word in sent.words:\n",
" print(\"{:12s}\\t{:12s}\\t{:6s}\\t{:d}\\t{:12s}\".format(\\\n",
" word.text, word.lemma, word.pos, word.head, word.deprel))\n",
" print(\"\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "-AUkCkNIrusq",
"colab_type": "text"
},
"source": [
"The following example iterates over all extracted named entity mentions and prints out their character spans and types."
]
},
{
"cell_type": "code",
"metadata": {
"id": "5Uu0-WmvsnlK",
"colab_type": "code",
"colab": {}
},
"source": [
"print(\"Mention text\\tType\\tStart-End\")\n",
"for ent in en_doc.ents:\n",
" print(\"{}\\t{}\\t{}-{}\".format(ent.text, ent.type, ent.start_char, ent.end_char))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ql1SZlZOnMLo",
"colab_type": "text"
},
"source": [
"And similarly for the Chinese text:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XsVcEO9tHKPG",
"colab_type": "code",
"colab": {}
},
"source": [
"for i, sent in enumerate(zh_doc.sentences):\n",
" print(\"[Sentence {}]\".format(i+1))\n",
" for word in sent.words:\n",
" print(\"{:12s}\\t{:12s}\\t{:6s}\\t{:d}\\t{:12s}\".format(\\\n",
" word.text, word.lemma, word.pos, word.head, word.deprel))\n",
" print(\"\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "dUhWAs8pnnHT",
"colab_type": "text"
},
"source": [
"Alternatively, you can directly print a `Word` object to view all its annotations as a Python dict:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "6_UafNb7HHIg",
"colab_type": "code",
"colab": {}
},
"source": [
"word = en_doc.sentences[0].words[0]\n",
"print(word)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TAQlOsuRoq2V",
"colab_type": "text"
},
"source": [
"### More Information\n",
"\n",
"For all information on different data objects, please visit our [data objects page](https://stanfordnlp.github.io/stanza/data_objects.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiiWHxYPpmhd",
"colab_type": "text"
},
"source": [
"## 5. Resources\n",
"\n",
"Apart from this interactive tutorial, we also provide tutorials on our website that cover a variety of use cases such as how to use different model \"packages\" for a language, how to use spaCy as a tokenizer, how to process pretokenized text without running the tokenizer, etc. For these tutorials please visit [our Tutorials page](https://stanfordnlp.github.io/stanza/tutorials.html).\n",
"\n",
"Other resources that you may find helpful include:\n",
"\n",
"- [Stanza Homepage](https://stanfordnlp.github.io/stanza/index.html)\n",
"- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\n",
"- [GitHub Repo](https://github.com/stanfordnlp/stanza)\n",
"- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\n",
"- [Stanza System Description Paper](http://arxiv.org/abs/2003.07082)\n"
]
}
]
}
================================================
FILE: demo/Stanza_CoreNLP_Interface.ipynb
================================================
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Stanza-CoreNLP-Interface.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "2-4lzQTC9yxG",
"colab_type": "text"
},
"source": [
"# Stanza: A Tutorial on the Python CoreNLP Interface\n",
"\n",
"\n",
"\n",
"\n",
"While the Stanza library implements accurate neural network modules for basic functionalities such as part-of-speech tagging and dependency parsing, the [Stanford CoreNLP Java library](https://stanfordnlp.github.io/CoreNLP/) has been developed for years and offers more complementary features such as coreference resolution and relation extraction. To unlock these features, the Stanza library also offers an officially maintained Python interface to the CoreNLP Java library. This interface allows you to get NLP annotations from CoreNLP by writing native Python code.\n",
"\n",
"\n",
"This tutorial walks you through the installation, setup and basic usage of this Python CoreNLP interface. If you want to learn how to use the neural network components in Stanza, please refer to other tutorials."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YpKwWeVkASGt",
"colab_type": "text"
},
"source": [
"## 1. Installation\n",
"\n",
"Before the installation starts, please make sure that you have Python 3 and Java installed on your computer. Since Colab already has them installed, we'll skip this procedure in this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k1Az2ECuAfG8",
"colab_type": "text"
},
"source": [
"### Installing Stanza\n",
"\n",
"Installing and importing Stanza are as simple as running the following commands:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xiFwYAgW4Mss",
"colab_type": "code",
"colab": {}
},
"source": [
"# Install stanza; note that the prefix \"!\" is not needed if you are running in a terminal\n",
"!pip install stanza\n",
"\n",
"# Import stanza\n",
"import stanza"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zFvaA8_A32_",
"colab_type": "text"
},
"source": [
"### Setting up Stanford CoreNLP\n",
"\n",
"In order for the interface to work, the Stanford CoreNLP library has to be installed and a `CORENLP_HOME` environment variable has to be pointed to the installation location.\n",
"\n",
"Here we are going to show you how to download and install the CoreNLP library on your machine, with Stanza's installation command:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "MgK6-LPV-OdA",
"colab_type": "code",
"colab": {}
},
"source": [
"# Download the Stanford CoreNLP package with Stanza's installation command\n",
"# This'll take several minutes, depending on the network speed\n",
"corenlp_dir = './corenlp'\n",
"stanza.install_corenlp(dir=corenlp_dir)\n",
"\n",
"# Set the CORENLP_HOME environment variable to point to the installation location\n",
"import os\n",
"os.environ[\"CORENLP_HOME\"] = corenlp_dir"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Jdq8MT-NAhKj",
"colab_type": "text"
},
"source": [
"That's all for the installation! 🎉 We can now double check if the installation is successful by listing files in the CoreNLP directory. You should be able to see a number of `.jar` files by running the following command:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "K5eIOaJp_tuo",
"colab_type": "code",
"colab": {}
},
"source": [
"# Examine the CoreNLP installation folder to make sure the installation is successful\n",
"!ls $CORENLP_HOME"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "S0xb9BHt__gx",
"colab_type": "text"
},
"source": [
"**Note 1**:\n",
"If you want to use the interface in a terminal (instead of a Colab notebook), you can set the `CORENLP_HOME` environment variable with:\n",
"\n",
"```bash\n",
"export CORENLP_HOME=path_to_corenlp_dir\n",
"```\n",
"\n",
"Here we instead set this variable with the Python `os` library, simply because the `export` command is not well supported in Colab notebooks.\n",
"\n",
"\n",
"**Note 2**:\n",
"The `stanza.install_corenlp()` function is only available since Stanza v1.1.1. If you are using an earlier version of Stanza, please check out our [manual installation page](https://stanfordnlp.github.io/stanza/client_setup.html#manual-installation) for how to install CoreNLP on your computer.\n",
"\n",
"**Note 3**:\n",
"Besides the installation function, we also provide a `stanza.download_corenlp_models()` function to help you download additional CoreNLP models for different languages that are not shipped with the default installation. Check out our [automatic installation website page](https://stanfordnlp.github.io/stanza/client_setup.html#automated-installation) for more information on how to use it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xJsuO6D8D05q",
"colab_type": "text"
},
"source": [
"## 2. Annotating Text with CoreNLP Interface"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dZNHxXHkH1K2",
"colab_type": "text"
},
"source": [
"### Constructing CoreNLPClient\n",
"\n",
"At a high level, the CoreNLP Python interface works by first starting a background Java CoreNLP server process, and then initializing a client instance in Python which can pass the text to the background server process, and accept the returned annotation results.\n",
"\n",
"We wrap these functionalities in a `CoreNLPClient` class. Therefore, we need to start by importing this class from Stanza."
]
},
{
"cell_type": "code",
"metadata": {
"id": "LS4OKnqJ8wui",
"colab_type": "code",
"colab": {}
},
"source": [
"# Import client module\n",
"from stanza.server import CoreNLPClient"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "WP4Dz6PIJHeL",
"colab_type": "text"
},
"source": [
"After the import is done, we can construct a `CoreNLPClient` instance. The constructor method takes a Python list of annotator names as argument. Here let's explore some basic annotators including tokenization, sentence split, part-of-speech tagging, lemmatization and named entity recognition (NER). \n",
"\n",
"Additionally, the client constructor accepts a `memory` argument, which specifies how much memory will be allocated to the background Java process. An `endpoint` option can be used to specify a port number used by the communication between the server and the client. The default port is 9000. However, since this port is pre-occupied by a system process in Colab, we'll manually set it to 9001 in the following example.\n",
"\n",
"Also, here we manually set `be_quiet=True` to avoid an IO issue in Colab notebooks. You should be able to use `be_quiet=False` on your own computer, which will print detailed logging information from CoreNLP during usage.\n",
"\n",
"For more options in constructing the clients, please refer to the [CoreNLP Client Options List](https://stanfordnlp.github.io/stanza/corenlp_client.html#corenlp-client-options)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "mbOBugvd9JaM",
"colab_type": "code",
"colab": {}
},
"source": [
"# Construct a CoreNLPClient with some basic annotators, a memory allocation of 4GB, and port number 9001\n",
"client = CoreNLPClient(\n",
" annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \n",
" memory='4G', \n",
" endpoint='http://localhost:9001',\n",
" be_quiet=True)\n",
"print(client)\n",
"\n",
"# Start the background server and wait for some time\n",
"# Note that in practice this is totally optional, as by default the server will be started when the first annotation is performed\n",
"client.start()\n",
"import time; time.sleep(10)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "kgTiVjNydmIW",
"colab_type": "text"
},
"source": [
"After the above code block finishes executing, if you print the background processes, you should be able to find the Java CoreNLP server running."
]
},
{
"cell_type": "code",
"metadata": {
"id": "spZrJ-oFdkdF",
"colab_type": "code",
"colab": {}
},
"source": [
"# Print background processes and look for java\n",
"# You should be able to see a StanfordCoreNLPServer java process running in the background\n",
"!ps -o pid,cmd | grep java"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "KxJeJ0D2LoOs",
"colab_type": "text"
},
"source": [
"### Annotating Text\n",
"\n",
"Annotating a piece of text is as simple as passing the text into an `annotate` function of the client object. After the annotation is complete, a `Document` object will be returned with all annotations.\n",
"\n",
"Note that although in general annotations are very fast, the first annotation might take a while to complete in the notebook. Please stay patient."
]
},
{
"cell_type": "code",
"metadata": {
"id": "s194RnNg5z95",
"colab_type": "code",
"colab": {}
},
"source": [
"# Annotate some text\n",
"text = \"Albert Einstein was a German-born theoretical physicist. He developed the theory of relativity.\"\n",
"document = client.annotate(text)\n",
"print(type(document))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "semmA3e0TcM1",
"colab_type": "text"
},
"source": [
"## 3. Accessing Annotations\n",
"\n",
"Annotations can be accessed from the returned `Document` object.\n",
"\n",
"A `Document` contains a list of `Sentence`s, which contain a list of `Token`s. Here let's first explore the annotations stored in all tokens."
]
},
{
"cell_type": "code",
"metadata": {
"id": "lIO4B5d6Rk4I",
"colab_type": "code",
"colab": {}
},
"source": [
"# Iterate over all tokens in all sentences, and print out the word, lemma, pos and ner tags\n",
"print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(\"Word\", \"Lemma\", \"POS\", \"NER\"))\n",
"\n",
"for i, sent in enumerate(document.sentence):\n",
" print(\"[Sentence {}]\".format(i+1))\n",
" for t in sent.token:\n",
" print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(t.word, t.lemma, t.pos, t.ner))\n",
" print(\"\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "msrJfvu8VV9m",
"colab_type": "text"
},
"source": [
"Alternatively, you can also browse the NER results by iterating over entity mentions over the sentences. For example:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ezEjc9LeV2Xs",
"colab_type": "code",
"colab": {}
},
"source": [
"# Iterate over all detected entity mentions\n",
"print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n",
"\n",
"for sent in document.sentence:\n",
" for m in sent.mentions:\n",
" print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ueGzBZ3hWzkN",
"colab_type": "text"
},
"source": [
"To print all the annotations that a sentence, token or mention has, you can simply print the corresponding object."
]
},
{
"cell_type": "code",
"metadata": {
"id": "4_S8o2BHXIed",
"colab_type": "code",
"colab": {}
},
"source": [
"# Print annotations of a token\n",
"print(document.sentence[0].token[0])\n",
"\n",
"# Print annotations of a mention\n",
"print(document.sentence[0].mentions[0])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qp66wjZ10xia",
"colab_type": "text"
},
"source": [
"**Note**: Since the Stanza CoreNLP client interface simply ports the CoreNLP annotation results to native Python objects, for a comprehensive list of available annotators and how their annotation results can be accessed, you will need to visit the [Stanford CoreNLP website](https://stanfordnlp.github.io/CoreNLP/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IPqzMK90X0w3",
"colab_type": "text"
},
"source": [
"## 4. Shutting Down the CoreNLP Server\n",
"\n",
"To shut down the background CoreNLP server process, simply call the `stop` function of the client. Note that once a server is shut down, you'll have to restart it with the `start()` function before requesting any more annotations."
]
},
{
"cell_type": "code",
"metadata": {
"id": "xrJq8lZ3Nw7b",
"colab_type": "code",
"colab": {}
},
"source": [
"# Shut down the background CoreNLP server\n",
"client.stop()\n",
"\n",
"time.sleep(10)\n",
"!ps -o pid,cmd | grep java"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "23Vwa_ifYfF7",
"colab_type": "text"
},
"source": [
"### More Information\n",
"\n",
"For more information on how to use the `CoreNLPClient`, please go to the [CoreNLPClient documentation page](https://stanfordnlp.github.io/stanza/corenlp_client.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YUrVT6kA_Bzx",
"colab_type": "text"
},
"source": [
"## 5. Simplifying Client Usage with the Python `with` statement\n",
"\n",
"In the above demo, we explicitly called the `client.start()` and `client.stop()` functions to start and stop a client-server connection. However, doing this in practice is usually suboptimal, since you may forget to call the `stop()` function at the end, resulting in an unused server process occupying your machine's memory.\n",
"\n",
"To solve this, a simple solution is to use the client interface with the [Python `with` statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement). The `with` statement provides an elegant way to automatically start and stop the server process in your Python program, without you needing to worry about this. The following code snippet demonstrates how to establish a client, annotate an example text and then stop the server with a simple `with` statement. Note that we **always recommend** using the `with` statement when working with the Stanza CoreNLP client interface."
]
},
{
"cell_type": "code",
"metadata": {
"id": "H0ct2-R4AvJh",
"colab_type": "code",
"colab": {}
},
"source": [
"print(\"Starting a server with the Python \\\"with\\\" statement...\")\n",
"with CoreNLPClient(annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \n",
" memory='4G', endpoint='http://localhost:9001', be_quiet=True) as client:\n",
" text = \"Albert Einstein was a German-born theoretical physicist.\"\n",
" document = client.annotate(text)\n",
"\n",
" print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n",
" for sent in document.sentence:\n",
" for m in sent.mentions:\n",
" print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))\n",
"\n",
"print(\"\\nThe server should be stopped upon exit from the \\\"with\\\" statement.\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "W435Lwc4YqKb",
"colab_type": "text"
},
"source": [
"## 6. Other Resources\n",
"\n",
"- [Stanza Homepage](https://stanfordnlp.github.io/stanza/)\n",
"- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\n",
"- [GitHub Repo](https://github.com/stanfordnlp/stanza)\n",
"- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\n"
]
}
]
}
================================================
FILE: demo/arabic_test.conllu.txt
================================================
# newdoc id = assabah.20041005.0017
# newpar id = assabah.20041005.0017:p1
# sent_id = assabah.20041005.0017:p1u1
# text = سوريا: تعديل وزاري واسع يشمل 8 حقائب
# orig_file_sentence ASB_ARB_20041005.0017#1
1 سوريا سُورِيَا X X--------- Foreign=Yes 0 root 0:root SpaceAfter=No|Vform=سُورِيَا|Gloss=Syria|Root=sUr|Translit=sūriyā|LTranslit=sūriyā
2 : : PUNCT G--------- _ 1 punct 1:punct Vform=:|Translit=:
3 تعديل تَعدِيل NOUN N------S1I Case=Nom|Definite=Ind|Number=Sing 6 nsubj 6:nsubj Vform=تَعدِيلٌ|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=taʿdīlun|LTranslit=taʿdīl
4 وزاري وِزَارِيّ ADJ A-----MS1I Case=Nom|Definite=Ind|Gender=Masc|Number=Sing 3 amod 3:amod Vform=وِزَارِيٌّ|Gloss=ministry,ministerial|Root=w_z_r|Translit=wizārīyun|LTranslit=wizārīy
5 واسع وَاسِع ADJ A-----MS1I Case=Nom|Definite=Ind|Gender=Masc|Number=Sing 3 amod 3:amod Vform=وَاسِعٌ|Gloss=wide,extensive,broad|Root=w_s_`|Translit=wāsiʿun|LTranslit=wāsiʿ
6 يشمل شَمِل VERB VIIA-3MS-- Aspect=Imp|Gender=Masc|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act 1 parataxis 1:parataxis Vform=يَشمَلُ|Gloss=comprise,include,contain|Root=^s_m_l|Translit=yašmalu|LTranslit=šamil
7 8 8 NUM Q--------- NumForm=Digit 6 obj 6:obj Vform=٨|Translit=8
8 حقائب حَقِيبَة NOUN N------P2I Case=Gen|Definite=Ind|Number=Plur 7 nmod 7:nmod:gen Vform=حَقَائِبَ|Gloss=briefcase,suitcase,portfolio,luggage|Root=.h_q_b|Translit=ḥaqāʾiba|LTranslit=ḥaqībat
# newpar id = assabah.20041005.0017:p2
# sent_id = assabah.20041005.0017:p2u1
# text = دمشق (وكالات الانباء) - اجرى الرئيس السوري بشار الاسد تعديلا حكومياً واسعا تم بموجبه إقالة وزيري الداخلية والاعلام عن منصبيها في حين ظل محمد ناجي العطري رئيساً للحكومة.
# orig_file_sentence ASB_ARB_20041005.0017#2
1 دمشق دمشق X U--------- _ 0 root 0:root Vform=دمشق|Root=OOV|Translit=dmšq
2 ( ( PUNCT G--------- _ 3 punct 3:punct SpaceAfter=No|Vform=(|Translit=(
3 وكالات وِكَالَة NOUN N------P1R Case=Nom|Definite=Cons|Number=Plur 1 dep 1:dep Vform=وِكَالَاتُ|Gloss=agency|Root=w_k_l|Translit=wikālātu|LTranslit=wikālat
4 الانباء نَبَأ NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 3 nmod 3:nmod:gen SpaceAfter=No|Vform=اَلأَنبَاءِ|Gloss=news_item,report|Root=n_b_'|Translit=al-ʾanbāʾi|LTranslit=nabaʾ
5 ) ) PUNCT G--------- _ 3 punct 3:punct Vform=)|Translit=)
6 - - PUNCT G--------- _ 1 punct 1:punct Vform=-|Translit=-
7 اجرى أَجرَى VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 1 advcl 1:advcl:فِي_حِينَ Vform=أَجرَى|Gloss=conduct,carry_out,perform|Root=^g_r_y|Translit=ʾaǧrā|LTranslit=ʾaǧrā
8 الرئيس رَئِيس NOUN N------S1D Case=Nom|Definite=Def|Number=Sing 7 nsubj 7:nsubj Vform=اَلرَّئِيسُ|Gloss=president,head,chairman|Root=r_'_s|Translit=ar-raʾīsu|LTranslit=raʾīs
9 السوري سُورِيّ ADJ A-----MS1D Case=Nom|Definite=Def|Gender=Masc|Number=Sing 8 amod 8:amod Vform=اَلسُّورِيُّ|Gloss=Syrian|Root=sUr|Translit=as-sūrīyu|LTranslit=sūrīy
10 بشار بشار X U--------- _ 11 nmod 11:nmod Vform=بشار|Root=OOV|Translit=bšār
11 الاسد الاسد X U--------- _ 8 nmod 8:nmod Vform=الاسد|Root=OOV|Translit=ālāsd
12 تعديلا تَعدِيل NOUN N------S4I Case=Acc|Definite=Ind|Number=Sing 7 obj 7:obj Vform=تَعدِيلًا|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=taʿdīlan|LTranslit=taʿdīl
13 حكومياً حُكُومِيّ ADJ A-----MS4I Case=Acc|Definite=Ind|Gender=Masc|Number=Sing 12 amod 12:amod Vform=حُكُومِيًّا|Gloss=governmental,state,official|Root=.h_k_m|Translit=ḥukūmīyan|LTranslit=ḥukūmīy
14 واسعا وَاسِع ADJ A-----MS4I Case=Acc|Definite=Ind|Gender=Masc|Number=Sing 12 amod 12:amod Vform=وَاسِعًا|Gloss=wide,extensive,broad|Root=w_s_`|Translit=wāsiʿan|LTranslit=wāsiʿ
15 تم تَمّ VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 12 acl 12:acl Vform=تَمَّ|Gloss=conclude,take_place|Root=t_m_m|Translit=tamma|LTranslit=tamm
16-18 بموجبه _ _ _ _ _ _ _ _
16 ب بِ ADP P--------- AdpType=Prep 18 case 18:case Vform=بِ|Gloss=by,with|Root=bi|Translit=bi|LTranslit=bi
17 موجب مُوجِب NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 16 fixed 16:fixed Vform=مُوجِبِ|Gloss=reason,motive|Root=w_^g_b|Translit=mūǧibi|LTranslit=mūǧib
18 ه هُوَ PRON SP---3MS2- Case=Gen|Gender=Masc|Number=Sing|Person=3|PronType=Prs 15 nmod 15:nmod:بِ_مُوجِب:gen Vform=هِ|Gloss=he,she,it|Translit=hi|LTranslit=huwa
19 إقالة إِقَالَة NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 15 nsubj 15:nsubj Vform=إِقَالَةُ|Gloss=dismissal,discharge|Root=q_y_l|Translit=ʾiqālatu|LTranslit=ʾiqālat
20 وزيري وَزِير NOUN N------D2R Case=Gen|Definite=Cons|Number=Dual 19 nmod 19:nmod:gen Vform=وَزِيرَي|Gloss=minister|Root=w_z_r|Translit=wazīray|LTranslit=wazīr
21 الداخلية دَاخِلِيّ ADJ A-----FS2D Case=Gen|Definite=Def|Gender=Fem|Number=Sing 20 amod 20:amod Vform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy
22-23 والاعلام _ _ _ _ _ _ _ _
22 و وَ CCONJ C--------- _ 23 cc 23:cc Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa
23 الإعلام إِعلَام NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 21 conj 20:amod|21:conj Vform=اَلإِعلَامِ|Gloss=information,media|Root=`_l_m|Translit=al-ʾiʿlāmi|LTranslit=ʾiʿlām
24 عن عَن ADP P--------- AdpType=Prep 25 case 25:case Vform=عَن|Gloss=about,from|Root=`an|Translit=ʿan|LTranslit=ʿan
25-26 منصبيها _ _ _ _ _ _ _ _
25 منصبي مَنصِب NOUN N------D2R Case=Gen|Definite=Cons|Number=Dual 19 nmod 19:nmod:عَن:gen Vform=مَنصِبَي|Gloss=post,position,office|Root=n_.s_b|Translit=manṣibay|LTranslit=manṣib
26 ها هُوَ PRON SP---3FS2- Case=Gen|Gender=Fem|Number=Sing|Person=3|PronType=Prs 25 nmod 25:nmod:gen Vform=هَا|Gloss=he,she,it|Translit=hā|LTranslit=huwa
27 في فِي ADP P--------- AdpType=Prep 7 mark 7:mark Vform=فِي|Gloss=in|Root=fI|Translit=fī|LTranslit=fī
28 حين حِينَ ADP PI------2- AdpType=Prep|Case=Gen 7 mark 7:mark Vform=حِينِ|Gloss=when|Root=.h_y_n|Translit=ḥīni|LTranslit=ḥīna
29 ظل ظَلّ VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 7 parataxis 7:parataxis Vform=ظَلَّ|Gloss=remain,continue|Root=.z_l_l|Translit=ẓalla|LTranslit=ẓall
30 محمد محمد X U--------- _ 32 nmod 32:nmod Vform=محمد|Root=OOV|Translit=mḥmd
31 ناجي ناجي X U--------- _ 32 nmod 32:nmod Vform=ناجي|Root=OOV|Translit=nāǧy
32 العطري العطري X U--------- _ 29 nsubj 29:nsubj Vform=العطري|Root=OOV|Translit=ālʿṭry
33 رئيساً رَئِيس NOUN N------S4I Case=Acc|Definite=Ind|Number=Sing 29 xcomp 29:xcomp Vform=رَئِيسًا|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsan|LTranslit=raʾīs
34-35 للحكومة _ _ _ _ _ _ _ SpaceAfter=No
34 ل لِ ADP P--------- AdpType=Prep 35 case 35:case Vform=لِ|Gloss=for,to|Root=l|Translit=li|LTranslit=li
35 الحكومة حُكُومَة NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 33 nmod 33:nmod:لِ:gen Vform=اَلحُكُومَةِ|Gloss=government,administration|Root=.h_k_m|Translit=al-ḥukūmati|LTranslit=ḥukūmat
36 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=.
# newpar id = assabah.20041005.0017:p3
# sent_id = assabah.20041005.0017:p3u1
# text = واضافت المصادر ان مهدي دخل الله رئيس تحرير صحيفة الحزب الحاكم والليبرالي التوجهات تسلم منصب وزير الاعلام خلفا لاحمد الحسن فيما تسلم اللواء غازي كنعان رئيس شعبة الامن السياسي منصب وزير الداخلية.
# orig_file_sentence ASB_ARB_20041005.0017#3
1-2 واضافت _ _ _ _ _ _ _ _
1 و وَ CCONJ C--------- _ 0 root 0:root Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa
2 أضافت أَضَاف VERB VP-A-3FS-- Aspect=Perf|Gender=Fem|Number=Sing|Person=3|Voice=Act 1 parataxis 1:parataxis Vform=أَضَافَت|Gloss=add,attach,receive_as_guest|Root=.d_y_f|Translit=ʾaḍāfat|LTranslit=ʾaḍāf
3 المصادر مَصدَر NOUN N------P1D Case=Nom|Definite=Def|Number=Plur 2 nsubj 2:nsubj Vform=اَلمَصَادِرُ|Gloss=source|Root=.s_d_r|Translit=al-maṣādiru|LTranslit=maṣdar
4 ان أَنَّ SCONJ C--------- _ 16 mark 16:mark Vform=أَنَّ|Gloss=that|Root='_n|Translit=ʾanna|LTranslit=ʾanna
5 مهدي مهدي X U--------- _ 6 nmod 6:nmod Vform=مهدي|Root=OOV|Translit=mhdy
6 دخل دخل X U--------- _ 16 nsubj 16:nsubj Vform=دخل|Root=OOV|Translit=dḫl
7 الله الله X U--------- _ 6 nmod 6:nmod Vform=الله|Root=OOV|Translit=āllh
8 رئيس رَئِيس NOUN N------S4R Case=Acc|Definite=Cons|Number=Sing 6 nmod 6:nmod:acc Vform=رَئِيسَ|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsa|LTranslit=raʾīs
9 تحرير تَحرِير NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 8 nmod 8:nmod:gen Vform=تَحرِيرِ|Gloss=liberation,liberating,editorship,editing|Root=.h_r_r|Translit=taḥrīri|LTranslit=taḥrīr
10 صحيفة صَحِيفَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 9 nmod 9:nmod:gen Vform=صَحِيفَةِ|Gloss=newspaper,sheet,leaf|Root=.s_.h_f|Translit=ṣaḥīfati|LTranslit=ṣaḥīfat
11 الحزب حِزب NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 10 nmod 10:nmod:gen Vform=اَلحِزبِ|Gloss=party,band|Root=.h_z_b|Translit=al-ḥizbi|LTranslit=ḥizb
12 الحاكم حَاكِم NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 11 nmod 11:nmod:gen Vform=اَلحَاكِمِ|Gloss=ruler,governor|Root=.h_k_m|Translit=al-ḥākimi|LTranslit=ḥākim
13-14 والليبرالي _ _ _ _ _ _ _ _
13 و وَ CCONJ C--------- _ 6 cc 6:cc Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa
14 الليبرالي لِيبِرَالِيّ ADJ A-----MS4D Case=Acc|Definite=Def|Gender=Masc|Number=Sing 6 amod 6:amod Vform=اَللِّيبِرَالِيَّ|Gloss=liberal|Root=lIbirAl|Translit=al-lībirālīya|LTranslit=lībirālīy
15 التوجهات تَوَجُّه NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 14 nmod 14:nmod:gen Vform=اَلتَّوَجُّهَاتِ|Gloss=attitude,approach|Root=w_^g_h|Translit=at-tawaǧǧuhāti|LTranslit=tawaǧǧuh
16 تسلم تَسَلَّم VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 2 ccomp 2:ccomp Vform=تَسَلَّمَ|Gloss=receive,assume|Root=s_l_m|Translit=tasallama|LTranslit=tasallam
17 منصب مَنصِب NOUN N------S4R Case=Acc|Definite=Cons|Number=Sing 16 obj 16:obj Vform=مَنصِبَ|Gloss=post,position,office|Root=n_.s_b|Translit=manṣiba|LTranslit=manṣib
18 وزير وَزِير NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 17 nmod 17:nmod:gen Vform=وَزِيرِ|Gloss=minister|Root=w_z_r|Translit=wazīri|LTranslit=wazīr
19 الاعلام عَلَم NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 18 nmod 18:nmod:gen Vform=اَلأَعلَامِ|Gloss=flag,banner,badge|Root=`_l_m|Translit=al-ʾaʿlāmi|LTranslit=ʿalam
20 خلفا خَلَف NOUN N------S4I Case=Acc|Definite=Ind|Number=Sing 16 obl 16:obl:acc Vform=خَلَفًا|Gloss=substitute,scion|Root=_h_l_f|Translit=ḫalafan|LTranslit=ḫalaf
21-22 لاحمد _ _ _ _ _ _ _ _
21 ل لِ ADP P--------- AdpType=Prep 23 case 23:case Vform=لِ|Gloss=for,to|Root=l|Translit=li|LTranslit=li
22 أحمد أَحمَد NOUN N------S2I Case=Gen|Definite=Ind|Number=Sing 23 nmod 23:nmod:gen Vform=أَحمَدَ|Gloss=Ahmad|Root=.h_m_d|Translit=ʾaḥmada|LTranslit=ʾaḥmad
23 الحسن الحسن X U--------- _ 20 nmod 20:nmod:لِ Vform=الحسن|Root=OOV|Translit=ālḥsn
24 فيما فِيمَا CCONJ C--------- _ 25 cc 25:cc Vform=فِيمَا|Gloss=while,during_which|Root=fI|Translit=fīmā|LTranslit=fīmā
25 تسلم تَسَلَّم VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 16 conj 2:ccomp|16:conj Vform=تَسَلَّمَ|Gloss=receive,assume|Root=s_l_m|Translit=tasallama|LTranslit=tasallam
26 اللواء لِوَاء NOUN N------S1D Case=Nom|Definite=Def|Number=Sing 25 nsubj 25:nsubj Vform=اَللِّوَاءُ|Gloss=banner,flag|Root=l_w_y|Translit=al-liwāʾu|LTranslit=liwāʾ
27 غازي غازي X U--------- _ 28 nmod 28:nmod Vform=غازي|Root=OOV|Translit=ġāzy
28 كنعان كنعان X U--------- _ 26 nmod 26:nmod Vform=كنعان|Root=OOV|Translit=knʿān
29 رئيس رَئِيس NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 26 nmod 26:nmod:nom Vform=رَئِيسُ|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsu|LTranslit=raʾīs
30 شعبة شُعبَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 29 nmod 29:nmod:gen Vform=شُعبَةِ|Gloss=branch,subdivision|Root=^s_`_b|Translit=šuʿbati|LTranslit=šuʿbat
31 الامن أَمن NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 30 nmod 30:nmod:gen Vform=اَلأَمنِ|Gloss=security,safety|Root='_m_n|Translit=al-ʾamni|LTranslit=ʾamn
32 السياسي سِيَاسِيّ ADJ A-----MS2D Case=Gen|Definite=Def|Gender=Masc|Number=Sing 31 amod 31:amod Vform=اَلسِّيَاسِيِّ|Gloss=political|Root=s_w_s|Translit=as-siyāsīyi|LTranslit=siyāsīy
33 منصب مَنصِب NOUN N------S4R Case=Acc|Definite=Cons|Number=Sing 25 obj 25:obj Vform=مَنصِبَ|Gloss=post,position,office|Root=n_.s_b|Translit=manṣiba|LTranslit=manṣib
34 وزير وَزِير NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 33 nmod 33:nmod:gen Vform=وَزِيرِ|Gloss=minister|Root=w_z_r|Translit=wazīri|LTranslit=wazīr
35 الداخلية دَاخِلِيّ ADJ A-----FS2D Case=Gen|Definite=Def|Gender=Fem|Number=Sing 34 amod 34:amod SpaceAfter=No|Vform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy
36 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=.
# newpar id = assabah.20041005.0017:p4
# sent_id = assabah.20041005.0017:p4u1
# text = وذكرت وكالة الانباء السورية ان التعديل شمل ثماني حقائب بينها وزارتا الداخلية والاقتصاد.
# orig_file_sentence ASB_ARB_20041005.0017#4
1-2 وذكرت _ _ _ _ _ _ _ _
1 و وَ CCONJ C--------- _ 0 root 0:root Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa
2 ذكرت ذَكَر VERB VP-A-3FS-- Aspect=Perf|Gender=Fem|Number=Sing|Person=3|Voice=Act 1 parataxis 1:parataxis Vform=ذَكَرَت|Gloss=mention,cite,remember|Root=_d_k_r|Translit=ḏakarat|LTranslit=ḏakar
3 وكالة وِكَالَة NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 2 nsubj 2:nsubj Vform=وِكَالَةُ|Gloss=agency|Root=w_k_l|Translit=wikālatu|LTranslit=wikālat
4 الانباء نَبَأ NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 3 nmod 3:nmod:gen Vform=اَلأَنبَاءِ|Gloss=news_item,report|Root=n_b_'|Translit=al-ʾanbāʾi|LTranslit=nabaʾ
5 السورية سُورِيّ ADJ A-----FS1D Case=Nom|Definite=Def|Gender=Fem|Number=Sing 3 amod 3:amod Vform=اَلسُّورِيَّةُ|Gloss=Syrian|Root=sUr|Translit=as-sūrīyatu|LTranslit=sūrīy
6 ان أَنَّ SCONJ C--------- _ 8 mark 8:mark Vform=أَنَّ|Gloss=that|Root='_n|Translit=ʾanna|LTranslit=ʾanna
7 التعديل تَعدِيل NOUN N------S4D Case=Acc|Definite=Def|Number=Sing 8 obl 8:obl:acc Vform=اَلتَّعدِيلَ|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=at-taʿdīla|LTranslit=taʿdīl
8 شمل شَمِل VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 2 ccomp 2:ccomp Vform=شَمِلَ|Gloss=comprise,include,contain|Root=^s_m_l|Translit=šamila|LTranslit=šamil
9 ثماني ثَمَانُون NUM QL------4R Case=Acc|Definite=Cons|NumForm=Word 8 obj 8:obj Vform=ثَمَانِي|Gloss=eighty|Root=_t_m_n|Translit=ṯamānī|LTranslit=ṯamānūn
10 حقائب حَقِيبَة NOUN N------P2I Case=Gen|Definite=Ind|Number=Plur 9 nmod 9:nmod:gen Vform=حَقَائِبَ|Gloss=briefcase,suitcase,portfolio,luggage|Root=.h_q_b|Translit=ḥaqāʾiba|LTranslit=ḥaqībat
11-12 بينها _ _ _ _ _ _ _ _
11 بين بَينَ ADP PI------4- AdpType=Prep|Case=Acc 12 case 12:case Vform=بَينَ|Gloss=between,among|Root=b_y_n|Translit=bayna|LTranslit=bayna
12 ها هُوَ PRON SP---3FS2- Case=Gen|Gender=Fem|Number=Sing|Person=3|PronType=Prs 10 obl 10:obl:بَينَ:gen Vform=هَا|Gloss=he,she,it|Translit=hā|LTranslit=huwa
13 وزارتا وِزَارَة NOUN N------D1R Case=Nom|Definite=Cons|Number=Dual 12 nsubj 12:nsubj Vform=وِزَارَتَا|Gloss=ministry|Root=w_z_r|Translit=wizāratā|LTranslit=wizārat
14 الداخلية دَاخِلِيّ ADJ A-----FS2D Case=Gen|Definite=Def|Gender=Fem|Number=Sing 13 amod 13:amod Vform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy
15-16 والاقتصاد _ _ _ _ _ _ _ SpaceAfter=No
15 و وَ CCONJ C--------- _ 16 cc 16:cc Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa
16 الاقتصاد اِقتِصَاد NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 14 conj 13:amod|14:conj Vform=اَلِاقتِصَادِ|Gloss=economy,saving|Root=q_.s_d|Translit=al-i-ʼqtiṣādi|LTranslit=iqtiṣād
17 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=.
================================================
FILE: demo/corenlp.py
================================================
from stanza.server import CoreNLPClient
# Demo of the Stanza CoreNLP client: annotate a short text with a background
# CoreNLP server, then inspect the dependency parse, constituency parse,
# tokens, entity mentions, coreference chains, and run TokensRegex / Semgrex
# pattern queries against the same server.

# example text
print('---')
print('input text')
print('')
text = "Chris Manning is a nice person. Chris wrote a simple sentence. He also gives oranges to people."
print(text)

# start the server; using "with" guarantees it is stopped when the block exits
print('---')
print('starting up Java Stanford CoreNLP Server...')
# set up the client
with CoreNLPClient(annotators=['tokenize','ssplit','pos','lemma','ner','parse','depparse','coref'], timeout=60000, memory='16G') as client:
    # submit the request to the server
    ann = client.annotate(text)
    # get the first sentence
    sentence = ann.sentence[0]
    # get the dependency parse of the first sentence
    print('---')
    print('dependency parse of first sentence')
    dependency_parse = sentence.basicDependencies
    print(dependency_parse)
    # get the constituency parse of the first sentence
    print('---')
    print('constituency parse of first sentence')
    constituency_parse = sentence.parseTree
    print(constituency_parse)
    # get the first subtree of the constituency parse
    print('---')
    print('first subtree of constituency parse')
    print(constituency_parse.child[0])
    # get the value of the first subtree
    print('---')
    print('value of first subtree of constituency parse')
    print(constituency_parse.child[0].value)
    # get the first token of the first sentence
    print('---')
    print('first token of first sentence')
    token = sentence.token[0]
    print(token)
    # get the part-of-speech tag
    print('---')
    print('part of speech tag of token')
    print(token.pos)
    # get the named entity tag
    print('---')
    print('named entity tag of token')
    print(token.ner)
    # get an entity mention from the first sentence
    print('---')
    print('first entity mention in sentence')
    print(sentence.mentions[0])
    # access the coref chain
    print('---')
    print('coref chains for the example')
    print(ann.corefChain)

    # Use tokensregex patterns to find who wrote a sentence.
    pattern = '([ner: PERSON]+) /wrote/ /an?/ []{0,3} /sentence|article/'
    matches = client.tokensregex(text, pattern)
    # sentences contains a list with matches for each sentence.
    assert len(matches["sentences"]) == 3
    # length tells you whether or not there are any matches in this sentence
    assert matches["sentences"][1]["length"] == 1
    # You can access matches like most regex groups: "0" is the first match,
    # and the nested "1" is its first capture group.
    # (These were bare comparisons before, which silently checked nothing.)
    assert matches["sentences"][1]["0"]["text"] == "Chris wrote a simple sentence"
    assert matches["sentences"][1]["0"]["1"]["text"] == "Chris"

    # Use semgrex patterns to directly find who wrote what.
    pattern = '{word:wrote} >nsubj {}=subject >obj {}=object'
    matches = client.semgrex(text, pattern)
    # sentences contains a list with matches for each sentence.
    assert len(matches["sentences"]) == 3
    # length tells you whether or not there are any matches in this sentence
    assert matches["sentences"][1]["length"] == 1
    # You can access matches like most regex groups; named nodes are keyed
    # by "$" followed by the name given in the pattern.
    assert matches["sentences"][1]["0"]["text"] == "wrote"
    assert matches["sentences"][1]["0"]["$subject"]["text"] == "Chris"
    assert matches["sentences"][1]["0"]["$object"]["text"] == "sentence"
================================================
FILE: demo/en_test.conllu.txt
================================================
# newdoc id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200
# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0001
# newpar id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-p0001
# text = What if Google Morphed Into GoogleOS?
1 What what PRON WP PronType=Int 0 root 0:root _
2 if if SCONJ IN _ 4 mark 4:mark _
3 Google Google PROPN NNP Number=Sing 4 nsubj 4:nsubj _
4 Morphed morph VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 advcl 1:advcl:if _
5 Into into ADP IN _ 6 case 6:case _
6 GoogleOS GoogleOS PROPN NNP Number=Sing 4 obl 4:obl:into SpaceAfter=No
7 ? ? PUNCT . _ 4 punct 4:punct _
# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0002
# text = What if Google expanded on its search-engine (and now e-mail) wares into a full-fledged operating system?
1 What what PRON WP PronType=Int 0 root 0:root _
2 if if SCONJ IN _ 4 mark 4:mark _
3 Google Google PROPN NNP Number=Sing 4 nsubj 4:nsubj _
4 expanded expand VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 advcl 1:advcl:if _
5 on on ADP IN _ 15 case 15:case _
6 its its PRON PRP$ Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 15 nmod:poss 15:nmod:poss _
7 search search NOUN NN Number=Sing 9 compound 9:compound SpaceAfter=No
8 - - PUNCT HYPH _ 9 punct 9:punct SpaceAfter=No
9 engine engine NOUN NN Number=Sing 15 compound 15:compound _
10 ( ( PUNCT -LRB- _ 9 punct 9:punct SpaceAfter=No
11 and and CCONJ CC _ 13 cc 13:cc _
12 now now ADV RB _ 13 advmod 13:advmod _
13 e-mail e-mail NOUN NN Number=Sing 9 conj 9:conj:and|15:compound SpaceAfter=No
14 ) ) PUNCT -RRB- _ 15 punct 15:punct _
15 wares wares NOUN NNS Number=Plur 4 obl 4:obl:on _
16 into into ADP IN _ 22 case 22:case _
17 a a DET DT Definite=Ind|PronType=Art 22 det 22:det _
18 full full ADV RB _ 20 advmod 20:advmod SpaceAfter=No
19 - - PUNCT HYPH _ 20 punct 20:punct SpaceAfter=No
20 fledged fledged ADJ JJ Degree=Pos 22 amod 22:amod _
21 operating operating NOUN NN Number=Sing 22 compound 22:compound _
22 system system NOUN NN Number=Sing 4 obl 4:obl:into SpaceAfter=No
23 ? ? PUNCT . _ 4 punct 4:punct _
# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0003
# text = [via Microsoft Watch from Mary Jo Foley ]
1 [ [ PUNCT -LRB- _ 4 punct 4:punct SpaceAfter=No
2 via via ADP IN _ 4 case 4:case _
3 Microsoft Microsoft PROPN NNP Number=Sing 4 compound 4:compound _
4 Watch Watch PROPN NNP Number=Sing 0 root 0:root _
5 from from ADP IN _ 6 case 6:case _
6 Mary Mary PROPN NNP Number=Sing 4 nmod 4:nmod:from _
7 Jo Jo PROPN NNP Number=Sing 6 flat 6:flat _
8 Foley Foley PROPN NNP Number=Sing 6 flat 6:flat _
9 ] ] PUNCT -RRB- _ 4 punct 4:punct _
# newdoc id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700
# sent_id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-0001
# newpar id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-p0001
# text = (And, by the way, is anybody else just a little nostalgic for the days when that was a good thing?)
1 ( ( PUNCT -LRB- _ 14 punct 14:punct SpaceAfter=No
2 And and CCONJ CC _ 14 cc 14:cc SpaceAfter=No
3 , , PUNCT , _ 14 punct 14:punct _
4 by by ADP IN _ 6 case 6:case _
5 the the DET DT Definite=Def|PronType=Art 6 det 6:det _
6 way way NOUN NN Number=Sing 14 obl 14:obl:by SpaceAfter=No
7 , , PUNCT , _ 14 punct 14:punct _
8 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 14 cop 14:cop _
9 anybody anybody PRON NN Number=Sing 14 nsubj 14:nsubj _
10 else else ADJ JJ Degree=Pos 9 amod 9:amod _
11 just just ADV RB _ 13 advmod 13:advmod _
12 a a DET DT Definite=Ind|PronType=Art 13 det 13:det _
13 little little ADJ JJ Degree=Pos 14 obl:npmod 14:obl:npmod _
14 nostalgic nostalgic NOUN NN Number=Sing 0 root 0:root _
15 for for ADP IN _ 17 case 17:case _
16 the the DET DT Definite=Def|PronType=Art 17 det 17:det _
17 days day NOUN NNS Number=Plur 14 nmod 14:nmod:for|23:obl:npmod _
18 when when ADV WRB PronType=Rel 23 advmod 17:ref _
19 that that PRON DT Number=Sing|PronType=Dem 23 nsubj 23:nsubj _
20 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 23 cop 23:cop _
21 a a DET DT Definite=Ind|PronType=Art 23 det 23:det _
22 good good ADJ JJ Degree=Pos 23 amod 23:amod _
23 thing thing NOUN NN Number=Sing 17 acl:relcl 17:acl:relcl SpaceAfter=No
24 ? ? PUNCT . _ 14 punct 14:punct SpaceAfter=No
25 ) ) PUNCT -RRB- _ 14 punct 14:punct _
================================================
FILE: demo/japanese_test.conllu.txt
================================================
# newdoc id = test-s1
# sent_id = test-s1
# text = これに不快感を示す住民はいましたが,現在,表立って反対や抗議の声を挙げている住民はいないようです。
1 これ 此れ PRON 代名詞 _ 6 obl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=代名詞|SpaceAfter=No|UnidicInfo=,此れ,これ,これ,コレ,,,コレ,コレ,此れ
2 に に ADP 助詞-格助詞 _ 1 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,に,に,に,ニ,,,ニ,ニ,に
3 不快 不快 NOUN 名詞-普通名詞-形状詞可能 _ 4 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,不快,不快,不快,フカイ,,,フカイ,フカイカン,不快感
4 感 感 NOUN 名詞-普通名詞-一般 _ 6 obj _ BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,感,感,感,カン,,,カン,フカイカン,不快感
5 を を ADP 助詞-格助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,を,を,を,オ,,,ヲ,ヲ,を
6 示す 示す VERB 動詞-一般-五段-サ行 _ 7 acl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-五段-サ行|SpaceAfter=No|UnidicInfo=,示す,示す,示す,シメス,,,シメス,シメス,示す
7 住民 住民 NOUN 名詞-普通名詞-一般 _ 9 nsubj _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,住民,住民,住民,ジューミン,,,ジュウミン,ジュウミン,住民
8 は は ADP 助詞-係助詞 _ 7 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は
9 い 居る VERB 動詞-非自立可能-上一段-ア行 _ 29 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,イル,居る
10 まし ます AUX 助動詞-助動詞-マス _ 9 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-マス|SpaceAfter=No|UnidicInfo=,ます,まし,ます,マシ,,,マス,マス,ます
11 た た AUX 助動詞-助動詞-タ _ 9 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-タ|SpaceAfter=No|UnidicInfo=,た,た,た,タ,,,タ,タ,た
12 が が SCONJ 助詞-接続助詞 _ 9 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-接続助詞|SpaceAfter=No|UnidicInfo=,が,が,が,ガ,,,ガ,ガ,が
13 , , PUNCT 補助記号-読点 _ 9 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,,
14 現在 現在 ADV 名詞-普通名詞-副詞可能 _ 16 advmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=副詞|SpaceAfter=No|UnidicInfo=,現在,現在,現在,ゲンザイ,,,ゲンザイ,ゲンザイ,現在
15 , , PUNCT 補助記号-読点 _ 14 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,,
16 表立っ 表立つ VERB 動詞-一般-五段-タ行 _ 24 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-五段-タ行|SpaceAfter=No|UnidicInfo=,表立つ,表立っ,表立つ,オモテダッ,,,オモテダツ,オモテダツ,表立つ
17 て て SCONJ 助詞-接続助詞 _ 16 mark _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-接続助詞|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テ,て
18 反対 反対 NOUN 名詞-普通名詞-サ変形状詞可能 _ 20 nmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,反対,反対,反対,ハンタイ,,,ハンタイ,ハンタイ,反対
19 や や ADP 助詞-副助詞 _ 18 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-副助詞|SpaceAfter=No|UnidicInfo=,や,や,や,ヤ,,,ヤ,ヤ,や
20 抗議 抗議 NOUN 名詞-普通名詞-サ変可能 _ 22 nmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,抗議,抗議,抗議,コーギ,,,コウギ,コウギ,抗議
21 の の ADP 助詞-格助詞 _ 20 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,の,の,の,ノ,,,ノ,ノ,の
22 声 声 NOUN 名詞-普通名詞-一般 _ 24 obj _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,声,声,声,コエ,,,コエ,コエ,声
23 を を ADP 助詞-格助詞 _ 22 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,を,を,を,オ,,,ヲ,ヲ,を
24 挙げ 上げる VERB 動詞-非自立可能-下一段-ガ行 _ 27 acl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-下一段-ガ行|SpaceAfter=No|UnidicInfo=,上げる,挙げ,挙げる,アゲ,,,アゲル,アゲル,上げる
25 て て SCONJ 助詞-接続助詞 _ 24 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-上一段-ア行|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テイル,ている
26 いる 居る VERB 動詞-非自立可能-上一段-ア行 _ 25 fixed _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=助動詞-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,いる,いる,イル,,,イル,テイル,ている
27 住民 住民 NOUN 名詞-普通名詞-一般 _ 29 nsubj _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,住民,住民,住民,ジューミン,,,ジュウミン,ジュウミン,住民
28 は は ADP 助詞-係助詞 _ 27 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は
29 い 居る VERB 動詞-非自立可能-上一段-ア行 _ 0 root _ BunsetuBILabel=B|BunsetuPositionType=ROOT|LUWBILabel=B|LUWPOS=動詞-一般-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,イル,居る
30 ない ない AUX 助動詞-助動詞-ナイ Polarity=Neg 29 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-ナイ|SpaceAfter=No|UnidicInfo=,ない,ない,ない,ナイ,,,ナイ,ナイ,ない
31 よう 様 AUX 形状詞-助動詞語幹 _ 29 aux _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=形状詞-助動詞語幹|PrevUDLemma=よう|SpaceAfter=No|UnidicInfo=,様,よう,よう,ヨー,,,ヨウ,ヨウ,様
32 です です AUX 助動詞-助動詞-デス _ 29 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-デス|PrevUDLemma=だ|SpaceAfter=No|UnidicInfo=,です,です,です,デス,,,デス,デス,です
33 。 。 PUNCT 補助記号-句点 _ 29 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。
# newdoc id = test-s2
# sent_id = test-s2
# text = 幸福の科学側からは,特にどうしてほしいという要望はいただいていません。
1 幸福 幸福 NOUN 名詞-普通名詞-形状詞可能 _ 4 nmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,幸福,幸福,幸福,コーフク,,,コウフク,コウフクノカガクガワ,幸福の科学側
2 の の ADP 助詞-格助詞 _ 1 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,の,の,の,ノ,,,ノ,コウフクノカガクガワ,幸福の科学側
3 科学 科学 NOUN 名詞-普通名詞-サ変可能 _ 4 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,科学,科学,科学,カガク,,,カガク,コウフクノカガクガワ,幸福の科学側
4 側 側 NOUN 名詞-普通名詞-一般 _ 17 obl _ BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,側,側,側,ガワ,,,ガワ,コウフクノカガクガワ,幸福の科学側
5 から から ADP 助詞-格助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,から,から,から,カラ,,,カラ,カラ,から
6 は は ADP 助詞-係助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は
7 , , PUNCT 補助記号-読点 _ 4 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,,
8 特に 特に ADV 副詞 _ 17 advmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=副詞|SpaceAfter=No|UnidicInfo=,特に,特に,特に,トクニ,,,トクニ,トクニ,特に
9 どう どう ADV 副詞 _ 15 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,どう,どう,どう,ドー,,,ドウ,ドウスル,どうする
10 し 為る AUX 動詞-非自立可能-サ行変格 _ 9 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,し,する,シ,,,スル,ドウスル,どうする
11 て て SCONJ 助詞-接続助詞 _ 9 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-形容詞|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テホシイ,てほしい
12 ほしい 欲しい AUX 形容詞-非自立可能-形容詞 _ 11 fixed _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=助動詞-形容詞|PrevUDLemma=ほしい|SpaceAfter=No|UnidicInfo=,欲しい,ほしい,ほしい,ホシー,,,ホシイ,テホシイ,てほしい
13 と と ADP 助詞-格助詞 _ 9 case _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,と,と,と,ト,,,ト,トイウ,という
14 いう 言う VERB 動詞-一般-五段-ワア行 _ 13 fixed _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,言う,いう,いう,イウ,,,イウ,トイウ,という
15 要望 要望 NOUN 名詞-普通名詞-サ変可能 _ 17 nsubj _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,要望,要望,要望,ヨーボー,,,ヨウボウ,ヨウボウ,要望
16 は は ADP 助詞-係助詞 _ 15 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は
17 いただい 頂く VERB 動詞-非自立可能-五段-カ行 _ 0 root _ BunsetuBILabel=B|BunsetuPositionType=ROOT|LUWBILabel=B|LUWPOS=動詞-一般-五段-カ行|PrevUDLemma=いただく|SpaceAfter=No|UnidicInfo=,頂く,いただい,いただく,イタダイ,,,イタダク,イタダク,頂く
18 て て SCONJ 助詞-接続助詞 _ 17 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-上一段-ア行|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テイル,ている
19 い 居る VERB 動詞-非自立可能-上一段-ア行 _ 18 fixed _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=助動詞-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,テイル,ている
20 ませ ます AUX 助動詞-助動詞-マス _ 17 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-マス|SpaceAfter=No|UnidicInfo=,ます,ませ,ます,マセ,,,マス,マス,ます
21 ん ず AUX 助動詞-助動詞-ヌ Polarity=Neg 17 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-ヌ|PrevUDLemma=ぬ|SpaceAfter=No|UnidicInfo=,ず,ん,ぬ,ン,,,ヌ,ズ,ず
22 。 。 PUNCT 補助記号-句点 _ 17 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。
# newdoc id = test-s3
# sent_id = test-s3
# text = 星取り参加は当然とされ,不参加は白眼視される。
1 星取り 星取り NOUN 名詞-普通名詞-一般 _ 2 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,星取り,星取り,星取り,ホシトリ,,,ホシトリ,ホシトリサンカ,星取り参加
2 参加 参加 NOUN 名詞-普通名詞-サ変可能 _ 4 nsubj _ BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,参加,参加,参加,サンカ,,,サンカ,ホシトリサンカ,星取り参加
3 は は ADP 助詞-係助詞 _ 2 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は
4 当然 当然 ADJ 形状詞-一般 _ 6 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=形状詞-一般|SpaceAfter=No|UnidicInfo=,当然,当然,当然,トーゼン,,,トウゼン,トウゼン,当然
5 と と ADP 助詞-格助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,と,と,と,ト,,,ト,ト,と
6 さ 為る VERB 動詞-非自立可能-サ行変格 _ 13 acl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,さ,する,サ,,,スル,スル,する
7 れ れる AUX 助動詞-助動詞-レル _ 6 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-レル|SpaceAfter=No|UnidicInfo=,れる,れ,れる,レ,,,レル,レル,れる
8 , , PUNCT 補助記号-読点 _ 6 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,,
9 不 不 NOUN 接頭辞 Polarity=Neg 10 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,不,不,不,フ,,,フ,フサンカ,不参加
10 参加 参加 NOUN 名詞-普通名詞-サ変可能 _ 13 nsubj _ BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,参加,参加,参加,サンカ,,,サンカ,フサンカ,不参加
11 は は ADP 助詞-係助詞 _ 10 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は
12 白眼 白眼 NOUN 名詞-普通名詞-一般 _ 13 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,白眼,白眼,白眼,ハクガン,,,ハクガン,ハクガンシスル,白眼視する
13 視 視 NOUN 接尾辞-名詞的-サ変可能 _ 0 root _ BunsetuBILabel=I|BunsetuPositionType=ROOT|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,視,視,視,シ,,,シ,ハクガンシスル,白眼視する
14 さ 為る AUX 動詞-非自立可能-サ行変格 _ 13 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,さ,する,サ,,,スル,ハクガンシスル,白眼視する
15 れる れる AUX 助動詞-助動詞-レル _ 13 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-レル|SpaceAfter=No|UnidicInfo=,れる,れる,れる,レル,,,レル,レル,れる
16 。 。 PUNCT 補助記号-句点 _ 13 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。
================================================
FILE: demo/pipeline_demo.py
================================================
"""
A basic demo of the Stanza neural pipeline.
"""
import sys
import argparse
import os
import stanza
from stanza.resources.common import DEFAULT_MODEL_DIR
def main():
    """
    Run the demo end to end.

    Steps:
    1. Parse the command line.
    2. Download the models for the requested language.
    3. Build a pipeline and run it on a built-in example text.
    4. Print the tokens and the dependency parse of the first sentence.

    Exits with status 1 (message on stderr) if no example text exists for
    the requested language.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # %(default)s shows the real default (DEFAULT_MODEL_DIR), so the help
    # text cannot drift away from the value that is actually used
    parser.add_argument('-d', '--models_dir', default=DEFAULT_MODEL_DIR,
                        help='location of models files | default: %(default)s')
    parser.add_argument('-l', '--lang', default="en",
                        help='Demo language')
    parser.add_argument('-c', '--cpu', action='store_true', help='Use cpu as the device.')
    args = parser.parse_args()

    example_sentences = {
        "en": "Barack Obama was born in Hawaii. He was elected president in 2008.",
        "zh": "中国文化经历上千年的历史演变,是各区域、各民族古代文化长期相互交流、借鉴、融合的结果。",
        "fr": "Van Gogh grandit au sein d'une famille de l'ancienne bourgeoisie. Il tente d'abord de faire carrière comme marchand d'art chez Goupil & C.",
        "vi": "Trận Trân Châu Cảng (hay Chiến dịch Hawaii theo cách gọi của Bộ Tổng tư lệnh Đế quốc Nhật Bản) là một đòn tấn công quân sự bất ngờ được Hải quân Nhật Bản thực hiện nhằm vào căn cứ hải quân của Hoa Kỳ tại Trân Châu Cảng thuộc tiểu bang Hawaii vào sáng Chủ Nhật, ngày 7 tháng 12 năm 1941, dẫn đến việc Hoa Kỳ sau đó quyết định tham gia vào hoạt động quân sự trong Chiến tranh thế giới thứ hai.",
    }

    if args.lang not in example_sentences:
        # a fatal error belongs on stderr, next to the non-zero exit status
        print(f'Sorry, but we don\'t have a demo sentence for "{args.lang}" for the moment. Try one of these languages: {list(example_sentences.keys())}',
              file=sys.stderr)
        sys.exit(1)

    # download the models
    stanza.download(args.lang, dir=args.models_dir)

    # set up a pipeline
    print('---')
    print('Building pipeline...')
    pipeline = stanza.Pipeline(lang=args.lang, dir=args.models_dir, use_gpu=not args.cpu)

    # process the document
    text = example_sentences[args.lang]
    doc = pipeline(text)

    # access nlp annotations
    print()
    print(f'Input: {text}')
    print(f"The tokenizer split the input into {len(doc.sentences)} sentences.")
    print('---')
    print('tokens of first sentence: ')
    doc.sentences[0].print_tokens()
    print()
    print('---')
    print('dependency parse of first sentence: ')
    doc.sentences[0].print_dependencies()
    print()


if __name__ == '__main__':
    main()
================================================
FILE: demo/scenegraph.py
================================================
"""
Very short demo for the SceneGraph interface in the CoreNLP server
Requires CoreNLP >= 4.5.5, Stanza >= 1.5.1
"""
import json
from stanza.server import CoreNLPClient
# If a CoreNLP server is already running in another process on this host,
# pass start_server=None to reuse it.  Any of the usual CoreNLPClient
# options can be passed here when starting the server.
#
# preload=False stops the server from loading annotators up front, which
# would be wasted work for annotators this demo never calls.
with CoreNLPClient(preload=False) as client:
    graph = client.scenegraph("Jennifer's antennae are on her head.")
    pretty = json.dumps(graph, indent=2)
    print(pretty)
================================================
FILE: demo/semgrex visualization.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "2787d5f5",
"metadata": {},
"outputs": [],
"source": [
"import stanza\n",
"from stanza.server.semgrex import Semgrex\n",
"from stanza.models.common.constant import is_right_to_left\n",
"import spacy\n",
"from spacy import displacy\n",
"from spacy.tokens import Doc\n",
"from IPython.display import display, HTML\n",
"\n",
"\n",
"\"\"\"\n",
"IMPORTANT: For the code in this module to run, you must have corenlp and Java installed on your machine. Additionally,\n",
"set an environment variable CLASSPATH equal to the path of your corenlp directory.\n",
"\n",
"Example: CLASSPATH=C:\\\\Users\\\\Alex\\\\PycharmProjects\\\\pythonProject\\\\stanford-corenlp-4.5.0\\\\stanford-corenlp-4.5.0\\\\*\n",
"\"\"\"\n",
"\n",
"%env CLASSPATH=C:\\\\stanford-corenlp-4.5.2\\\\stanford-corenlp-4.5.2\\\\*\n",
"def get_sentences_html(doc, language):\n",
" \"\"\"\n",
" Returns a list of the HTML strings of the dependency visualizations of a given stanza doc object.\n",
"\n",
" The 'language' arg is the two-letter language code for the document to be processed.\n",
"\n",
" First converts the stanza doc object to a spacy doc object and uses displacy to generate an HTML\n",
" string for each sentence of the doc object.\n",
" \"\"\"\n",
" html_strings = []\n",
"\n",
" # blank model - we don't use any of the model features, just the visualization\n",
" nlp = spacy.blank(\"en\")\n",
" sentences_to_visualize = []\n",
" for sentence in doc.sentences:\n",
" words, lemmas, heads, deps, tags = [], [], [], [], []\n",
" if is_right_to_left(language): # order of words displayed is reversed, dependency arcs remain intact\n",
" sent_len = len(sentence.words)\n",
" for word in reversed(sentence.words):\n",
" words.append(word.text)\n",
" lemmas.append(word.lemma)\n",
" deps.append(word.deprel)\n",
" tags.append(word.upos)\n",
" if word.head == 0: # spaCy head indexes are formatted differently than that of Stanza\n",
" heads.append(sent_len - word.id)\n",
" else:\n",
" heads.append(sent_len - word.head)\n",
" else: # left to right rendering\n",
" for word in sentence.words:\n",
" words.append(word.text)\n",
" lemmas.append(word.lemma)\n",
" deps.append(word.deprel)\n",
" tags.append(word.upos)\n",
" if word.head == 0:\n",
" heads.append(word.id - 1)\n",
" else:\n",
" heads.append(word.head - 1)\n",
" document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)\n",
" sentences_to_visualize.append(document_result)\n",
"\n",
" for line in sentences_to_visualize: # render all sentences through displaCy\n",
" html_strings.append(displacy.render(line, style=\"dep\",\n",
" options={\"compact\": True, \"word_spacing\": 30, \"distance\": 100,\n",
" \"arrow_spacing\": 20}, jupyter=False))\n",
" return html_strings\n",
"\n",
"\n",
"def find_nth(haystack, needle, n):\n",
" \"\"\"\n",
" Returns the starting index of the nth occurrence of the substring 'needle' in the string 'haystack'.\n",
" \"\"\"\n",
" start = haystack.find(needle)\n",
" while start >= 0 and n > 1:\n",
" start = haystack.find(needle, start + len(needle))\n",
" n -= 1\n",
" return start\n",
"\n",
"\n",
"def round_base(num, base=10):\n",
" \"\"\"\n",
" Rounding a number to its nearest multiple of the base. round_base(49.2, base=50) = 50.\n",
" \"\"\"\n",
" return base * round(num/base)\n",
"\n",
"\n",
"def process_sentence_html(orig_html, semgrex_sentence):\n",
" \"\"\"\n",
" Takes a semgrex sentence object and modifies the HTML of the original sentence's deprel visualization,\n",
" highlighting words involved in the search queries and adding the label of the word inside of the semgrex match.\n",
"\n",
" Returns the modified html string of the sentence's deprel visualization.\n",
" \"\"\"\n",
" tracker = {} # keep track of which words have multiple labels\n",
" DEFAULT_TSPAN_COUNT = 2 # the original displacy html assigns two objects per object\n",
" CLOSING_TSPAN_LEN = 8 # is 8 chars long\n",
" colors = ['#4477AA', '#66CCEE', '#228833', '#CCBB44', '#EE6677', '#AA3377', '#BBBBBB']\n",
" css_bolded_class = \"\\n\"\n",
" found_index = orig_html.find(\"\\n\") # returns index where the opening ends\n",
" # insert the new style class into html string\n",
" orig_html = orig_html[: found_index + 1] + css_bolded_class + orig_html[found_index + 1:]\n",
"\n",
" # Add color to words in the match, bold words in the match\n",
" for query in semgrex_sentence.result:\n",
" for i, match in enumerate(query.match):\n",
" color = colors[i]\n",
" paired_dy = 2\n",
" for node in match.node:\n",
" name, match_index = node.name, node.matchIndex\n",
" # edit existing to change color and bold the text\n",
" start = find_nth(orig_html, \" of interest\n",
" if match_index not in tracker: # if we've already bolded and colored, keep the first color\n",
" tspan_start = orig_html.find(\" inside of the \n",
" tspan_end = orig_html.find(\" \", start) # finds start of the end of the above \n",
" tspan_substr = orig_html[tspan_start: tspan_end + CLOSING_TSPAN_LEN + 1] + \"\\n\"\n",
" # color words in the hit and bold words in the hit\n",
" edited_tspan = tspan_substr.replace('class=\"displacy-word\"', 'class=\"bolded\"').replace(\n",
" 'fill=\"currentColor\"', f'fill=\"{color}\"')\n",
" # insert edited object into html string\n",
" orig_html = orig_html[: tspan_start] + edited_tspan + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2:]\n",
" tracker[match_index] = DEFAULT_TSPAN_COUNT\n",
"\n",
" # next, we have to insert the new object for the label\n",
" # Copy old to copy formatting when creating new later\n",
" prev_tspan_start = find_nth(orig_html[start:], \" start index\n",
" prev_tspan_end = find_nth(orig_html[start:], \" \",\n",
" tracker[match_index] - 1) + start # find the prev start index\n",
" prev_tspan = orig_html[prev_tspan_start: prev_tspan_end + CLOSING_TSPAN_LEN + 1]\n",
"\n",
" # Find spot to insert new tspan\n",
" closing_tspan_start = find_nth(orig_html[start:], \" \", tracker[match_index]) + start\n",
" up_to_new_tspan = orig_html[: closing_tspan_start + CLOSING_TSPAN_LEN + 1]\n",
" rest_need_add_newline = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1:]\n",
"\n",
" # Calculate proper x value in svg\n",
" x_value_start = prev_tspan.find('x=\"')\n",
" x_value_end = prev_tspan[x_value_start + 3:].find('\"') + 3 # 3 is the length of the 'x=\"' substring\n",
" x_value = prev_tspan[x_value_start + 3: x_value_end + x_value_start]\n",
"\n",
" # Calculate proper y value in svg\n",
" DEFAULT_DY_VAL, dy = 2, 2\n",
" if paired_dy != DEFAULT_DY_VAL and node == match.node[\n",
" 1]: # we're on the second node and need to adjust height to match the paired node\n",
" dy = paired_dy\n",
" if node == match.node[0]:\n",
" paired_node_level = 2\n",
" if match.node[1].matchIndex in tracker: # check if we need to adjust heights of labels\n",
" paired_node_level = tracker[match.node[1].matchIndex]\n",
" dif = tracker[match_index] - paired_node_level\n",
" if dif > 0: # current node has more labels\n",
" paired_dy = DEFAULT_DY_VAL * dif + 1\n",
" dy = DEFAULT_DY_VAL\n",
" else: # paired node has more labels, adjust this label down\n",
" dy = DEFAULT_DY_VAL * (abs(dif) + 1)\n",
" paired_dy = DEFAULT_DY_VAL\n",
"\n",
" # Insert new object\n",
" new_tspan = f' {name[: 3].title()}. \\n' # abbreviate label names to 3 chars\n",
" orig_html = up_to_new_tspan + new_tspan + rest_need_add_newline\n",
" tracker[match_index] += 1\n",
" return orig_html\n",
"\n",
"\n",
"def render_html_strings(edited_html_strings):\n",
" \"\"\"\n",
" Renders the HTML to make the edits visible\n",
" \"\"\"\n",
" for html_string in edited_html_strings:\n",
" display(HTML(html_string))\n",
"\n",
"\n",
"def visualize_search_doc(doc, semgrex_queries, lang_code, start_match=0, end_match=10):\n",
" \"\"\"\n",
" Visualizes the semgrex results of running semgrex search on a stanza doc object with the given list of\n",
" semgrex queries. Returns a list of the edited HTML strings from the doc. Each element in the list represents\n",
" the HTML to render one of the sentences in the document.\n",
"\n",
" 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n",
"\n",
"\n",
" 'start_match' and 'end_match' determine which matches to visualize. Works similar to splices, so that\n",
" start_match=0 and end_match=10 will display the first 10 semgrex matches.\n",
" \"\"\"\n",
" matches_count = 0 # Limits number of visualizations\n",
" with Semgrex(classpath=\"$CLASSPATH\") as sem:\n",
" edited_html_strings = []\n",
" semgrex_results = sem.process(doc, *semgrex_queries)\n",
" # one html string for each sentence\n",
" unedited_html_strings = get_sentences_html(doc, lang_code)\n",
" for i in range(len(unedited_html_strings)):\n",
"\n",
" if matches_count >= end_match: # we've collected enough matches, stop early\n",
" break\n",
"\n",
" # check if sentence has matches, if not then do not visualize\n",
" has_none = True\n",
" for query in semgrex_results.result[i].result:\n",
" for match in query.match:\n",
" if match:\n",
" has_none = False\n",
"\n",
" # Process HTML if queries have matches\n",
" if not has_none:\n",
" if start_match <= matches_count < end_match:\n",
" edited_string = process_sentence_html(unedited_html_strings[i], semgrex_results.result[i])\n",
" edited_string = adjust_dep_arrows(edited_string)\n",
" edited_html_strings.append(edited_string)\n",
" matches_count += 1\n",
"\n",
" render_html_strings(edited_html_strings)\n",
" return edited_html_strings\n",
"\n",
"\n",
"def visualize_search_str(text, semgrex_queries, lang_code):\n",
" \"\"\"\n",
" Visualizes the deprel of the semgrex results from running semgrex search on a string with the given list of\n",
" semgrex queries. Returns a list of the edited HTML strings. Each element in the list represents\n",
" the HTML to render one of the sentences in the document.\n",
"\n",
" Internally, this function converts the string into a stanza doc object before processing the doc object.\n",
"\n",
" 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n",
" \"\"\"\n",
" nlp = stanza.Pipeline(lang_code, processors=\"tokenize, pos, lemma, depparse\")\n",
" doc = nlp(text)\n",
" return visualize_search_doc(doc, semgrex_queries, lang_code)\n",
"\n",
"\n",
"def adjust_dep_arrows(raw_html):\n",
" \"\"\"\n",
" The default spaCy dependency visualization has misaligned arrows.\n",
" We fix arrows by aligning arrow ends and bodies to the word that they are directed to. If a word has an\n",
" arrowhead that is pointing not directly on the word's center, align the arrowhead to match the center of the word.\n",
"\n",
" returns the edited html with fixed arrow placement\n",
" \"\"\"\n",
" HTML_ARROW_BEGINNING = ''\n",
" HTML_ARROW_ENDING = \" \"\n",
" HTML_ARROW_ENDING_LEN = 6 # there are 2 newline chars after the arrow ending\n",
" arrows_start_idx = find_nth(haystack=raw_html, needle='', n=1)\n",
" words_html, arrows_html = raw_html[: arrows_start_idx], raw_html[arrows_start_idx:] # separate html for words and arrows\n",
" final_html = words_html # continually concatenate to this after processing each arrow\n",
" arrow_number = 1 # which arrow we're editing (1-indexed)\n",
" start_idx, end_of_class_idx = find_nth(haystack=arrows_html, needle=HTML_ARROW_BEGINNING, n=arrow_number), find_nth(arrows_html, HTML_ARROW_ENDING, arrow_number)\n",
" while start_idx != -1: # edit every arrow\n",
" arrow_section = arrows_html[start_idx: end_of_class_idx + HTML_ARROW_ENDING_LEN] # slice a single svg arrow object\n",
" if arrow_section[-1] == \"<\": # this is the last arrow in the HTML, don't cut the splice early\n",
" arrow_section = arrows_html[start_idx:]\n",
" edited_arrow_section = edit_dep_arrow(arrow_section)\n",
"\n",
" final_html = final_html + edited_arrow_section # continually update html with new arrow html until done\n",
"\n",
" # Prepare for next iteration\n",
" arrow_number += 1\n",
" start_idx = find_nth(arrows_html, '', n=arrow_number)\n",
" end_of_class_idx = find_nth(arrows_html, \" \", arrow_number)\n",
" return final_html\n",
"\n",
"\n",
"def edit_dep_arrow(arrow_html):\n",
" \"\"\"\n",
" The formatting of a displacy arrow in svg is the following:\n",
" \n",
" \n",
" \n",
" csubj \n",
" \n",
" \n",
" \n",
"\n",
" We edit the 'd = ...' parts of the section to fix the arrow direction and length\n",
"\n",
" returns the arrow_html with distances fixed\n",
" \"\"\"\n",
" WORD_SPACING = 50 # words start at x=50 and are separated by 100s so their x values are multiples of 50\n",
" M_OFFSET = 4 # length of 'd=\"M' that we search for to extract the number from d=\"M70, for instance\n",
" ARROW_PIXEL_SIZE = 4\n",
" first_d_idx, second_d_idx = find_nth(arrow_html, 'd=\"M', 1), find_nth(arrow_html, 'd=\"M', 2) # find where d=\"M starts\n",
" first_d_cutoff, second_d_cutoff = arrow_html.find(\",\", first_d_idx), arrow_html.find(\",\", second_d_idx) # isolate the number after 'M' e.g. 'M70'\n",
" # gives svg x values of arrow body starting position and arrowhead position\n",
" arrow_position, arrowhead_position = float(arrow_html[first_d_idx + M_OFFSET: first_d_cutoff]), float(arrow_html[second_d_idx + M_OFFSET: second_d_cutoff])\n",
" # gives starting index of where 'fill=\"none\"' or 'fill=\"currentColor\"' begin, reference points to end the d= section\n",
" first_fill_start_idx, second_fill_start_idx = find_nth(arrow_html, \"fill\", n=1), find_nth(arrow_html, \"fill\", n=3)\n",
"\n",
" # isolate the d= ... section to edit\n",
" first_d, second_d = arrow_html[first_d_idx: first_fill_start_idx], arrow_html[second_d_idx: second_fill_start_idx]\n",
" first_d_split, second_d_split = first_d.split(\",\"), second_d.split(\",\")\n",
"\n",
" if arrow_position == arrowhead_position: # This arrow is incoming onto the word, center the arrow/head to word center\n",
" corrected_arrow_pos = corrected_arrowhead_pos = round_base(arrow_position, base=WORD_SPACING)\n",
"\n",
" # edit first_d -- arrow body\n",
" second_term = first_d_split[1].split(\" \")[0] + \" \" + str(corrected_arrow_pos)\n",
" first_d = 'd=\"M' + str(corrected_arrow_pos) + \",\" + second_term + \",\" + \",\".join(first_d_split[2:])\n",
"\n",
" # edit second_d -- arrowhead\n",
" second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n",
" third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n",
" second_d = 'd=\"M' + str(corrected_arrowhead_pos) + \",\" + second_term + \",\" + third_term + \",\" + \",\".join(second_d_split[3:])\n",
" else: # This arrow is outgoing to another word, center the arrow/head to that word's center\n",
" corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING)\n",
"\n",
" # edit first_d -- arrow body\n",
" third_term = first_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n",
" fourth_term = first_d_split[3].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n",
" terms = [first_d_split[0], first_d_split[1], third_term, fourth_term] + first_d_split[4:]\n",
" first_d = \",\".join(terms)\n",
"\n",
" # edit second_d -- arrow head\n",
" first_term = f'd=\"M{corrected_arrowhead_pos}'\n",
" second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n",
" third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n",
" terms = [first_term, second_term, third_term] + second_d_split[3:]\n",
" second_d = \",\".join(terms)\n",
" # rebuild and return html\n",
" return arrow_html[:first_d_idx] + first_d + \" \" + arrow_html[first_fill_start_idx:second_d_idx] + second_d + \" \" + arrow_html[second_fill_start_idx:]\n",
"\n",
"\n",
"def main():\n",
" nlp = stanza.Pipeline(\"en\", processors=\"tokenize,pos,lemma,depparse\")\n",
"\n",
" # doc = nlp(\"This a dummy sentence. Banning opal removed all artifact decks from the meta. I miss playing lantern. This is a dummy sentence.\")\n",
" doc = nlp(\"Banning opal removed artifact decks from the meta. Banning tennis resulted in players banning people.\")\n",
" # A single result .result[i].result[j] is a list of matches for sentence i on semgrex query j.\n",
" queries = [\"{pos:NN}=object en_pronouns.updated.conllu
# This script updates the UD 2.11 version of UD_English-Pronouns to
# better match punctuation attachments, MWT, and no double subjects.
# This turns unwanted csubj into advcl
{}=source >nsubj {} >csubj=bad {}
relabelNamedEdge -edge bad -reln advcl
# This detects punctuation marks which are not attached to the root and reattaches them
{word:/[.]/}=punct =3.15.0',
'requests',
'networkx',
'tomli;python_version<"3.11"',
'torch>=1.13.0',
'tqdm',
'udtools>=0.2.4',
],
# List required Python versions
python_requires='>=3.9',
# List additional groups of dependencies here (e.g. development
# dependencies). You can install these using the following syntax,
# for example:
# $ pip install -e .[dev,test]
extras_require={
'dev': [
'check-manifest',
],
'test': [
'coverage',
'pytest',
],
'transformers': [
'transformers>=3.0.0',
'peft>=0.6.1',
],
'datasets': [
'datasets',
],
'tokenizers': [
'jieba',
'pythainlp',
'python-crfsuite',
'spacy',
'sudachidict_core',
'sudachipy',
],
'visualization': [
'spacy',
'streamlit',
'ipython',
],
'morphseg': [
'morphseg>=0.2.0',
]
},
# If there are data files included in your packages that need to be
# installed, specify them here. (The old advice to also list them in
# MANIFEST.in applied only to Python 2.6 and earlier; this package
# requires Python >= 3.9.)
package_data={
"": ["pipeline/demo/*ttf",
"pipeline/demo/*css",
"pipeline/demo/*html",
"pipeline/demo/*js",
"pipeline/demo/*gif",],
},
include_package_data=True,
# Although 'package_data' is the preferred approach, in some cases you may
# need to place data files outside of your packages. See:
# http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa
# For example, data_files=[('my_data', ['data/data_file'])] would install
# 'data_file' into '<sys.prefix>/my_data'. This package installs no such files.
data_files=[],
# To provide executable scripts, use entry points in preference to the
# "scripts" keyword. Entry points provide cross-platform support and allow
# pip to create the appropriate form of executable for the target platform.
entry_points={
},
)
================================================
FILE: stanza/__init__.py
================================================
from stanza.pipeline.core import DownloadMethod, Pipeline
from stanza.pipeline.multilingual import MultilingualPipeline
from stanza.models.common.doc import Document
from stanza.resources.common import download
from stanza.resources.installation import install_corenlp, download_corenlp_models
from stanza._version import __version__, __resources_version__
from stanza.pipeline.morphseg_processor import MorphSegProcessor
import logging
logger = logging.getLogger('stanza')

# Respect a level the client application may already have configured:
# only while the logger is still unset (NOTSET, i.e. 0) do we default it to INFO.
if not logger.level:
    logger.setLevel(logging.INFO)

# Default console handler with a timestamped format.  Both objects live at
# module level so a client can detach ours later with
#   logger.removeHandler(stanza.log_handler)
log_formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s",
                                  datefmt='%Y-%m-%d %H:%M:%S')
log_handler = logging.StreamHandler()
log_handler.setFormatter(log_formatter)

# hasHandlers() also considers ancestor loggers (while propagation is on),
# so our handler is attached only if nothing up the chain - e.g. a root
# handler installed by the client - already handles output.
if not logger.hasHandlers():
    logger.addHandler(log_handler)
================================================
FILE: stanza/_version.py
================================================
""" Single source of truth for version number """
# Version of the stanza package itself (the module docstring above calls
# this the single source of truth).
__version__ = "1.11.1"
# NOTE(review): presumably the version of the downloadable model resources
# that this release pairs with; it may differ from __version__ (1.11.0 vs
# 1.11.1 here) - confirm against the resources download code.
__resources_version__ = '1.11.0'
================================================
FILE: stanza/models/__init__.py
================================================
================================================
FILE: stanza/models/_training_logging.py
================================================
import logging
# Importing this module has a side effect: it lowers the shared 'stanza'
# logger to DEBUG.  Presumably meant for training entry points only
# (charlm.py imports it as `from stanza.models import _training_logging`).
logger = logging.getLogger('stanza')
logger.setLevel(logging.DEBUG)
================================================
FILE: stanza/models/charlm.py
================================================
"""
Entry point for training and evaluating a character-level neural language model.
"""
import argparse
from copy import copy
import logging
import lzma
import math
import os
import random
import time
from types import GeneratorType
import numpy as np
import torch
from stanza.models.common.char_model import build_charlm_vocab, CharacterLanguageModel, CharacterLanguageModelTrainer
from stanza.models.common.vocab import CharVocab
from stanza.models.common import utils
from stanza.models import _training_logging
# Module logger: the package-wide 'stanza' logger.
logger = logging.getLogger('stanza')
def repackage_hidden(h):
    """
    Cut recurrent state loose from its autograd history.

    A single tensor is returned detached; any other value is treated as an
    iterable of (possibly nested) states and rebuilt as a tuple with every
    tensor detached.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()
def batchify(data, bsz, device):
    """
    Lay a flat sequence out as ``bsz`` parallel rows and move it to ``device``.

    Elements left over after the largest multiple of ``bsz`` are discarded.
    For a 1-D input the result has shape (bsz, data.size(0) // bsz), i.e. it
    is batch-first with row k holding the k-th contiguous slice of ``data``.
    """
    usable = (data.size(0) // bsz) * bsz
    rows = data[:usable].view(bsz, -1)
    return rows.to(device)
def get_batch(source, i, seq_len):
    """
    Cut an (input, next-char target) window out of batch-first ``source``.

    The window starts at column ``i`` and is clipped so that the target,
    which is shifted one column to the right, never runs past the end.
    Returns the input slice (rows x L) and the targets flattened to 1-D.
    """
    remaining = source.size(1) - 1 - i
    length = seq_len if seq_len <= remaining else remaining
    inputs = source[:, i:i + length]
    targets = source[:, i + 1:i + 1 + length].reshape(-1)
    return inputs, targets
def load_file(filename, vocab, direction):
    """
    Read ``filename`` as text and return its character ids as a tensor.

    Characters are mapped through ``vocab['char']``; for a 'backward' model
    the id sequence is reversed so the model reads the text right to left.
    """
    with utils.open_read_text(filename) as fin:
        text = fin.read()

    char_ids = vocab['char'].map(text)
    if direction == 'backward':
        char_ids = char_ids[::-1]
    return torch.tensor(char_ids)
def load_data(path, vocab, direction):
    """
    Lazily yield character-id tensors for training.

    If ``path`` is a directory, each file in it is loaded in sorted filename
    order (logging each name); otherwise ``path`` itself is the one file.
    """
    if not os.path.isdir(path):
        yield load_file(path, vocab, direction)
        return

    for filename in sorted(os.listdir(path)):
        logger.info('Loading data from {}'.format(filename))
        yield load_file(os.path.join(path, filename), vocab, direction)
def build_argparse():
    """
    Build the command-line parser for charlm training and evaluation.

    Returns:
        An argparse.ArgumentParser whose --help output shows every default
        (ArgumentDefaultsHelpFormatter).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # --- data ---
    parser.add_argument('--train_file', type=str, help="Input plaintext file")
    parser.add_argument('--train_dir', type=str, help="If non-empty, load from directory with multiple training files")
    parser.add_argument('--eval_file', type=str, help="Input plaintext file for the dev/test set")
    parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
    # main() runs train() for 'train' and evaluate() otherwise (i.e. for 'predict')
    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    # --forward / --backward are shortcuts that store into the same 'direction' dest
    parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model")
    parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help="Train a forward language model")
    parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help="Train a backward language model")
    # --- model architecture ---
    parser.add_argument('--char_emb_dim', type=int, default=100, help="Dimension of unit embeddings")
    parser.add_argument('--char_hidden_dim', type=int, default=1024, help="Dimension of hidden units")
    parser.add_argument('--char_num_layers', type=int, default=1, help="Layers of RNN in the language model")
    parser.add_argument('--char_dropout', type=float, default=0.05, help="Dropout probability")
    parser.add_argument('--char_unit_dropout', type=float, default=1e-5, help="Randomly set an input char to UNK during training")
    parser.add_argument('--char_rec_dropout', type=float, default=0.0, help="Recurrent dropout probability")
    # --- optimization ---
    parser.add_argument('--batch_size', type=int, default=100, help="Batch size to use")
    parser.add_argument('--bptt_size', type=int, default=250, help="Sequence length to consider at a time")
    parser.add_argument('--epochs', type=int, default=50, help="Total epochs to train the model for")
    parser.add_argument('--max_grad_norm', type=float, default=0.25, help="Maximum gradient norm to clip to")
    parser.add_argument('--lr0', type=float, default=5, help="Initial learning rate")
    parser.add_argument('--anneal', type=float, default=0.25, help="Anneal the learning rate by this amount when dev performance deteriorate")
    parser.add_argument('--patience', type=int, default=1, help="Patience for annealing the learning rate")
    parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay")
    parser.add_argument('--momentum', type=float, default=0.0, help='Momentum for SGD.')
    # frequency cutoff passed to build_charlm_vocab when train() builds a new vocab
    parser.add_argument('--cutoff', type=int, default=1000, help="Frequency cutoff for char vocab. By default we assume a very large corpus.")
    # --- progress reporting / evaluation ---
    parser.add_argument('--report_steps', type=int, default=50, help="Update step interval to report loss")
    # train() evaluates within an epoch only when this is > 0; any value <= 0 means once per epoch
    parser.add_argument('--eval_steps', type=int, default=100000, help="Update step interval to run eval on dev; set to -1 to eval after each epoch")
    # --- saving ---
    parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
    parser.add_argument('--vocab_save_name', type=str, default=None, help="File name to save the vocab")
    parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
    parser.add_argument('--save_dir', type=str, default='saved_models/charlm', help="Directory to save models in")
    parser.add_argument('--summary', action='store_true', help='Use summary writer to record progress.')
    # --- runtime ---
    utils.add_device_args(parser)
    parser.add_argument('--seed', type=int, default=1234)
    # --- experiment tracking; parse_args() enables --wandb whenever --wandb_name is set ---
    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
    return parser
def build_model_filename(args):
    """
    Return the path of the charlm model file inside ``args['save_dir']``.

    A non-empty ``save_name`` is used verbatim; otherwise the file is named
    ``{shorthand}_{direction}_charlm.pt``.
    """
    save_name = args['save_name'] or '{}_{}_charlm.pt'.format(args['shorthand'], args['direction'])
    return os.path.join(args['save_dir'], save_name)
def parse_args(args=None):
    """
    Parse ``args`` (or sys.argv[1:] when None) and return the options as a dict.

    Supplying ``--wandb_name`` implicitly switches ``--wandb`` on.
    """
    namespace = build_argparse().parse_args(args=args)
    if namespace.wandb_name:
        namespace.wandb = True
    return vars(namespace)
def main(args=None):
    """
    Command-line entry point.

    Parses ``args``, applies the random seed, ensures ``save_dir`` via
    utils.ensure_dir, then trains when ``--mode train`` and evaluates otherwise.
    """
    args = parse_args(args=args)
    utils.set_random_seed(args['seed'])
    logger.info("Running {} character-level language model in {} mode".format(args['direction'], args['mode']))
    utils.ensure_dir(args['save_dir'])

    action = train if args['mode'] == 'train' else evaluate
    action(args)
def evaluate_epoch(args, vocab, data, model, criterion):
    """
    Run an evaluation over the entire dataset and return the mean loss per
    predicted character.

    Args:
        args: option dict; reads 'batch_size' and 'bptt_size'.
        vocab: dict whose 'char' entry is the character vocab; its length is
            the number of output classes.
        data: a tensor of character ids, or a generator (such as the one
            load_data returns) that must yield exactly one such tensor.
        model: the character language model; switched to eval mode here.
        criterion: loss applied to (logits, targets).  Assumed to average
            over the window, as the default torch.nn.CrossEntropyLoss does.

    Returns:
        The loss averaged over every predicted position.

    Raises:
        ValueError: if a generator yields anything other than exactly one
            tensor, or if the data is too short to form a single
            (input, target) pair in each batch row.
    """
    model.eval()
    device = next(model.parameters()).device
    if isinstance(data, GeneratorType):
        data = list(data)
        # explicit check instead of assert, which disappears under python -O
        if len(data) != 1:
            raise ValueError('Only support single dev/test file')
        data = data[0]
    batches = batchify(data, args['batch_size'], device)

    # Every column except the last is an input whose target is the next
    # column, so each row contributes size(1) - 1 predictions.
    num_positions = batches.size(1) - 1
    if num_positions < 1:
        raise ValueError('Evaluation data is too short: need at least 2 characters per batch row (batch_size={})'.format(args['batch_size']))

    hidden = None
    total_loss = 0.0
    with torch.no_grad():
        for start in range(0, num_positions, args['bptt_size']):
            inputs, target = get_batch(batches, start, args['bptt_size'])
            lens = [inputs.size(1)] * inputs.size(0)
            output, hidden, decoded = model.forward(inputs, lens, hidden)
            loss = criterion(decoded.view(-1, len(vocab['char'])), target)
            hidden = repackage_hidden(hidden)
            # weight this window's mean loss by its length so windows combine correctly
            total_loss += inputs.size(1) * loss.item()
    # The window lengths sum to num_positions; dividing by size(1) instead
    # (as before) understated the loss and perplexity.
    return total_loss / num_positions
def evaluate_and_save(args, vocab, data, trainer, best_loss, model_file, checkpoint_file, writer=None):
    """
    Evaluate ``trainer.model`` on ``data``, log progress, and save as needed.

    The dev loss is fed to ``trainer.scheduler.step`` and any learning-rate
    change is logged.  The model is written to ``model_file`` (non-full save)
    when it beats ``best_loss``.  Scalars go to ``writer`` if provided, and a
    full checkpoint is written to ``checkpoint_file`` if one is configured.

    Returns:
        (loss, ppl, best_loss), with best_loss updated when this evaluation
        improved on it.
    """
    eval_start = time.time()
    dev_loss = evaluate_epoch(args, vocab, data, trainer.model, trainer.criterion)
    dev_ppl = math.exp(dev_loss)
    elapsed = int(time.time() - eval_start)

    # TODO: step the scheduler less often when the eval frequency is higher
    lr_before = get_current_lr(trainer, args)
    trainer.scheduler.step(dev_loss)
    lr_after = get_current_lr(trainer, args)
    if lr_after != lr_before:
        logger.info("Updating learning rate to %f", lr_after)

    summary = "| eval checkpoint @ global step {:10d} | time elapsed {:6d}s | loss {:5.2f} | ppl {:8.2f}".format(
        trainer.global_step, elapsed, dev_loss, dev_ppl)
    logger.info(summary)

    improved = best_loss is None or dev_loss < best_loss
    if improved:
        best_loss = dev_loss
        trainer.save(model_file, full=False)
        logger.info('new best model saved at step {:10d}'.format(trainer.global_step))

    if writer:
        writer.add_scalar('dev_loss', dev_loss, global_step=trainer.global_step)
        writer.add_scalar('dev_ppl', dev_ppl, global_step=trainer.global_step)

    if checkpoint_file:
        trainer.save(checkpoint_file, full=True)
        logger.info('new checkpoint saved at step {:10d}'.format(trainer.global_step))

    return dev_loss, dev_ppl, best_loss
def get_current_lr(trainer, args):
    """
    Return the first param group's learning rate as recorded by the scheduler.

    Falls back to ``args['lr0']`` when the scheduler's state_dict has no
    '_last_lr' entry.
    """
    fallback = args['lr0']
    scheduler_state = trainer.scheduler.state_dict()
    if '_last_lr' in scheduler_state:
        return scheduler_state['_last_lr'][0]
    return fallback
def load_char_vocab(vocab_file):
    """
    Load a saved CharVocab state dict and wrap it as ``{'char': vocab}``.

    The map_location callable returns each storage as initially deserialized
    (on CPU), and weights_only restricts unpickling to tensor-style data.
    """
    state = torch.load(vocab_file, map_location=lambda storage, loc: storage, weights_only=True)
    return {'char': CharVocab.load_state_dict(state)}
def train(args):
    """
    Train a character-level language model as configured by ``args``.

    Steps:
    - Load the char vocab from disk, or build and save it.
    - Start a fresh trainer, or resume from the checkpoint file.
    - Run epochs up to ``args['epochs']`` over the training data.

    The dev set is evaluated every ``eval_steps`` global steps when
    ``eval_steps > 0`` (plus once after the final epoch), otherwise once per
    epoch.  Best-model and checkpoint saving happen in evaluate_and_save.

    Args:
        args: option dict as produced by parse_args.
    """
    utils.log_training_args(args, logger)

    model_file = build_model_filename(args)
    # the conditional expression binds loosest: (save_dir + '/' + vocab_save_name) if given, else the default
    vocab_file = args['save_dir'] + '/' + args['vocab_save_name'] if args['vocab_save_name'] is not None \
        else '{}/{}_vocab.pt'.format(args['save_dir'], args['shorthand'])

    if args['checkpoint']:
        checkpoint_file = utils.checkpoint_name(args['save_dir'], model_file, args['checkpoint_save_name'])
    else:
        checkpoint_file = None

    if os.path.exists(vocab_file):
        logger.info('Loading existing vocab file')
        vocab = load_char_vocab(vocab_file)
    else:
        logger.info('Building and saving vocab')
        vocab = {'char': build_charlm_vocab(args['train_file'] if args['train_dir'] is None else args['train_dir'], cutoff=args['cutoff'])}
        torch.save(vocab['char'].state_dict(), vocab_file)
    logger.info("Training model with vocab size: {}".format(len(vocab['char'])))

    if checkpoint_file and os.path.exists(checkpoint_file):
        logger.info('Loading existing checkpoint: %s', checkpoint_file)
        trainer = CharacterLanguageModelTrainer.load(args, checkpoint_file, finetune=True)
    else:
        trainer = CharacterLanguageModelTrainer.from_new_model(args, vocab)

    writer = None
    if args['summary']:
        from torch.utils.tensorboard import SummaryWriter
        summary_dir = '{}/{}_summary'.format(args['save_dir'], args['save_name']) if args['save_name'] is not None \
            else '{}/{}_{}_charlm_summary'.format(args['save_dir'], args['shorthand'], args['direction'])
        writer = SummaryWriter(log_dir=summary_dir)

    # evaluate model within epoch if eval_steps is positive
    eval_within_epoch = args['eval_steps'] > 0

    if args['wandb']:
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else '%s_%s_charlm' % (args['shorthand'], args['direction'])
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('best_loss', summary='min')
        wandb.run.define_metric('ppl', summary='min')

    device = next(trainer.model.parameters()).device

    # NOTE(review): best_loss restarts at None even when resuming from a
    # checkpoint, so the first evaluation after a resume always overwrites
    # model_file - confirm whether the trainer should persist it.
    best_loss = None
    start_epoch = trainer.epoch # will default to 1 for a new trainer
    for trainer.epoch in range(start_epoch, args['epochs']+1):
        # load train data from train_dir if not empty, otherwise load from file
        train_path = args['train_dir'] if args['train_dir'] is not None else args['train_file']
        train_data = load_data(train_path, vocab, args['direction'])
        dev_data = load_file(args['eval_file'], vocab, args['direction']) # dev must be a single file

        # run over entire training set
        for data_chunk in train_data:
            batches = batchify(data_chunk, args['batch_size'], device)
            hidden = None
            total_loss = 0.0
            # seconds spent in training steps since the last progress report
            train_time = 0.0
            total_batches = math.ceil((batches.size(1) - 1) / args['bptt_size'])
            iteration, i = 0, 0
            # over the data chunk
            while i < batches.size(1) - 1 - 1:
                trainer.model.train()
                trainer.global_step += 1
                start_time = time.time()
                # 95% of the time use the full window, otherwise half of it; then jitter it
                bptt = args['bptt_size'] if np.random.random() < 0.95 else args['bptt_size']/ 2.
                # prevent excessively small or negative sequence lengths
                seq_len = max(5, int(np.random.normal(bptt, 5)))
                # prevent very large sequence length, must be <= 1.2 x bptt
                seq_len = min(seq_len, int(args['bptt_size'] * 1.2))
                data, target = get_batch(batches, i, seq_len)
                lens = [data.size(1)] * data.size(0)
                trainer.optimizer.zero_grad()
                output, hidden, decoded = trainer.model.forward(data, lens, hidden)
                loss = trainer.criterion(decoded.view(-1, len(vocab['char'])), target)
                total_loss += loss.item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainer.params, args['max_grad_norm'])
                trainer.optimizer.step()
                # keep the state values for the next window but drop their graph history
                hidden = repackage_hidden(hidden)
                # time every step in the report window (not just the last one)
                # so that sec/batch below is a true per-batch average
                train_time += time.time() - start_time

                if (iteration + 1) % args['report_steps'] == 0:
                    cur_loss = total_loss / args['report_steps']
                    logger.info(
                        "| epoch {:5d} | {:5d}/{:5d} batches | sec/batch {:.6f} | loss {:5.2f} | ppl {:8.2f}".format(
                            trainer.epoch,
                            iteration + 1,
                            total_batches,
                            train_time / args['report_steps'],
                            cur_loss,
                            math.exp(cur_loss),
                        )
                    )
                    if args['wandb']:
                        wandb.log({'train_loss': cur_loss}, step=trainer.global_step)
                    total_loss = 0.0
                    train_time = 0.0
                iteration += 1
                i += seq_len

                # evaluate if necessary
                if eval_within_epoch and trainer.global_step % args['eval_steps'] == 0:
                    _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)
                    if args['wandb']:
                        wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)

        # if eval_steps isn't positive, run evaluation after each epoch
        if not eval_within_epoch or trainer.epoch == args['epochs']:
            _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)
            if args['wandb']:
                wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)

    if writer:
        writer.close()
    if args['wandb']:
        wandb.finish()
    return
def evaluate(args):
    """
    Load the saved model named by ``args`` and log its loss / perplexity on
    ``args['eval_file']`` using cross entropy.
    """
    model = CharacterLanguageModel.load(build_model_filename(args)).to(args['device'])
    vocab = model.vocab
    eval_data = load_data(args['eval_file'], vocab, args['direction'])
    loss = evaluate_epoch(args, vocab, eval_data, model, torch.nn.CrossEntropyLoss())
    logger.info("| best model | loss {:5.2f} | ppl {:8.2f}".format(loss, math.exp(loss)))
    return
# Allow running the module as a script, e.g. python -m stanza.models.charlm
if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/classifier.py
================================================
import argparse
import ast
import logging
import os
import random
import re
from enum import Enum
import torch
import torch.nn as nn
from stanza.models.common import loss
from stanza.models.common import utils
from stanza.models.pos.vocab import CharVocab
import stanza.models.classifiers.data as data
from stanza.models.classifiers.trainer import Trainer
from stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
from stanza.utils.confusion import format_confusion, confusion_to_accuracy, confusion_to_macro_f1
class Loss(Enum):
    """
    Training loss options, chosen by name on the command line via --loss.

    Per the --loss help text, LOG_CROSS scales cross entropy by
    1/log(quantity); FOCAL uses the gamma from --loss_focal_gamma.
    """
    CROSS = 1
    WEIGHTED_CROSS = 2
    LOG_CROSS = 3
    FOCAL = 4
class DevScoring(Enum):
    """Dev-set scoring used to pick the best model (--dev_eval_scoring)."""
    ACCURACY = 'ACC'
    WEIGHTED_F1 = 'WF'
logger = logging.getLogger('stanza')
# child of the 'stanza' logger; handed to resolve_peft_args in parse_args
tlogger = logging.getLogger('stanza.classifiers.trainer')
# limit the elmoformanylangs library's own logging to warnings and above
logging.getLogger('elmoformanylangs').setLevel(logging.WARNING)
# default datasets for --train_file / --dev_file / --test_file
DEFAULT_TRAIN='data/sentiment/en_sstplus.train.txt'
DEFAULT_DEV='data/sentiment/en_sst3roots.dev.txt'
DEFAULT_TEST='data/sentiment/en_sst3roots.test.txt'
"""A script for training and testing classifier models, especially on the SST.
If you run the script with no arguments, it will start trying to train
a sentiment model.
python3 -m stanza.models.classifier
This requires the sentiment dataset to be in an `extern_data`
directory, such as by symlinking it from somewhere else.
The default model is a CNN where the word vectors are first mapped to
channels with filters of a few different widths, those channels are
maxpooled over the entire sentence, and then the resulting pools have
fully connected layers until they reach the number of classes in the
training data. You can see the defaults in the options below.
https://arxiv.org/abs/1408.5882
(Currently the CNN is the only sentence classifier implemented.)
To train with a more complicated CNN arch:
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 > FC41.out 2>&1 &
You can train models with word vectors other than the default word2vec. For example:
nohup python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --max_epochs 200 --filter_channels 1000 --fc_shapes 200,100 --base_name FC21_google > FC21_google.out 2>&1 &
A model trained on the 5 class dataset can be tested on the 2 class dataset with a command line like this:
python3 -u -m stanza.models.classifier --no_train --load_name saved_models/classifier/sst_en_ewt_FS_3_4_5_C_1000_FC_400_100_classifier.E0165-ACC41.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 1:0, 3:1, 4:1}"
python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name saved_models/classifier/FC21_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0189-ACC45.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 1:0, 3:1, 4:1}"
A model trained on the 3 class dataset can be tested on the 2 class dataset with a command line like this:
python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name saved_models/classifier/FC21_3C_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0101-ACC68.94.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 2:1}"
To train models on combined 3 class datasets:
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class --extra_wordvec_method CONCAT --extra_wordvec_dim 200 --train_file data/sentiment/en_sstplus.train.txt --dev_file data/sentiment/en_sst3roots.dev.txt --test_file data/sentiment/en_sst3roots.test.txt > FC41_3class.out 2>&1 &
This tests that model:
python3 -u -m stanza.models.classifier --no_train --load_name en_sstplus.pt --test_file data/sentiment/en_sst3roots.test.txt
Here is an example for training a model in a different language:
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german --train_file data/sentiment/de_sb10k.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 &
This uses more data, although that wound up being worse for the German model:
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german --train_file data/sentiment/de_sb10k.train.txt,data/sentiment/de_scare.train.txt,data/sentiment/de_usage.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 &
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_chinese --train_file data/sentiment/zh_ren.train.txt --dev_file data/sentiment/zh_ren.dev.txt --test_file data/sentiment/zh_ren.test.txt --shorthand zh_ren --wordvec_type fasttext --extra_wordvec_method SUM --wordvec_pretrain_file ../stanza_resources/zh-hans/pretrain/gsdsimp.pt > zh_ren.out 2>&1 &
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --save_name vi_vsfc.pt --train_file data/sentiment/vi_vsfc.train.json --dev_file data/sentiment/vi_vsfc.dev.json --test_file data/sentiment/vi_vsfc.test.json --shorthand vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --extra_wordvec_method SUM --dev_eval_scoring WEIGHTED_F1 > vi_vsfc.out 2>&1 &
python3 -u -m stanza.models.classifier --no_train --test_file extern_data/sentiment/vietnamese/_UIT-VSFC/test.txt --shorthand vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --load_name vi_vsfc.pt
"""
def convert_fc_shapes(arg):
    """
    Parse the --fc_shapes option into a tuple of FC layer sizes.

    Examples:
        "100"     -> (100,)
        "100,200" -> (100, 200)
        ""        -> ()   (no extra FC layers)

    The text goes through ast.literal_eval, so list syntax such as
    "[100, 200]" is accepted too.
    """
    text = arg.strip()
    if text == "":
        return ()
    parsed = ast.literal_eval(text)
    if isinstance(parsed, int):
        return (parsed,)
    return parsed if isinstance(parsed, tuple) else tuple(parsed)
# For the most part, these values are for the constituency parser.
# Only the WD for adadelta is originally for sentiment
# Also LR for adadelta and madgrad

# madgrad learning rate experiment on sstplus
# note that the hyperparameters are not cross-validated in tandem, so
# later changes may make some earlier experiments slightly out of date
# LR
# 0.01         failed to converge
# 0.004        failed to converge
# 0.003        0.5572
# 0.002        failed to converge
# 0.001        0.6857
# 0.0008       0.6799
# 0.0005       0.6849
# 0.00025      0.6749
# 0.0001       0.6746
# 0.00001      0.6536
# 0.000001     0.6267
# LR 0.001 produced the best model, but it does occasionally fail to
# converge to a working model, so we set the default to 0.0005 instead
# Per-optimizer defaults, looked up by parse_args when the matching flag is unset.
# (A duplicated "sgd" key with the same value was removed; lookups are unchanged.)
DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0005 }
DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
DEFAULT_LEARNING_RHO = 0.9
DEFAULT_MOMENTUM = { "madgrad": 0.9, "sgd": 0.9 }
DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.0001, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6 }
def build_argparse():
    """
    Build the argparse for the classifier.

    Refactored so that other utility scripts can use the same parser if needed.
    Several options come in --X / --no_X pairs that write the same dest.
    """
    parser = argparse.ArgumentParser()
    # --- run mode and model files ---
    parser.add_argument('--train', dest='train', default=True, action='store_true', help='Train the model (default)')
    parser.add_argument('--no_train', dest='train', action='store_false', help="Don't train the model")
    parser.add_argument('--shorthand', type=str, default='en_ewt', help="Treebank shorthand, eg 'en' for English")
    parser.add_argument('--load_name', type=str, default=None, help='Name for loading an existing model')
    parser.add_argument('--save_dir', type=str, default='saved_models/classifier', help='Root dir for saving models.')
    # template expanded by build_model_filename via utils.standard_model_file_name
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{bert_finetuning}_{classifier_type}_classifier.pt", help='Name for saving the model')
    parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")

    parser.add_argument('--save_intermediate_models', default=False, action='store_true',
                        help='Save all intermediate models - this can be a lot!')

    # --- data ---
    # NOTE(review): the help text "Should go ." looks truncated (the format
    # description may have been lost) - confirm the intended line format.
    parser.add_argument('--train_file', type=str, default=DEFAULT_TRAIN, help='Input file(s) to train a model from.  Each line is an example.  Should go  .  Comma separated list.')
    parser.add_argument('--dev_file', type=str, default=DEFAULT_DEV, help='Input file(s) to use as the dev set.')
    parser.add_argument('--test_file', type=str, default=DEFAULT_TEST, help='Input file(s) to use as the test set.')
    parser.add_argument('--output_predictions', default=False, action='store_true', help='Output predictions when running the test set')
    # --- training schedule ---
    parser.add_argument('--max_epochs', type=int, default=100)
    parser.add_argument('--tick', type=int, default=50)

    # --- model shape ---
    parser.add_argument('--model_type', type=lambda x: ModelType[x.upper()], default=ModelType.CNN,
                        help='Model type to use.  Options: %s' % " ".join(x.name for x in ModelType))

    parser.add_argument('--filter_sizes', default=(3,4,5), type=ast.literal_eval, help='Filter sizes for the layer after the word vectors')
    parser.add_argument('--filter_channels', default=1000, type=ast.literal_eval, help='Number of channels for layers after the word vectors.  Int for same number of channels (scaled by width) for each filter, or tuple/list for exact lengths for each filter')
    # argparse applies convert_fc_shapes to the string default as well
    parser.add_argument('--fc_shapes', default="400,100", type=convert_fc_shapes, help='Extra fully connected layers to put after the initial filters.  If set to blank, will FC directly from the max pooling to the output layer.')
    parser.add_argument('--dropout', default=0.5, type=float, help='Dropout value to use')

    # --- batching and dev evaluation ---
    parser.add_argument('--batch_size', default=50, type=int, help='Batch size when training')
    parser.add_argument('--batch_single_item', default=200, type=int, help='Items of this size go in their own batch')
    parser.add_argument('--dev_eval_batches', default=2000, type=int, help='Run the dev set after this many train batches.  Set to 0 to only do it once per epoch')
    parser.add_argument('--dev_eval_scoring', type=lambda x: DevScoring[x.upper()], default=DevScoring.WEIGHTED_F1,
                        help=('Scoring method to use for choosing the best model.  Options: %s' %
                              " ".join(x.name for x in DevScoring)))

    # --- optimizer; None means parse_args fills in the per-optimizer default ---
    parser.add_argument('--weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')
    parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate to use in the optimizer')
    parser.add_argument('--momentum', default=None, type=float, help='Momentum to use in the optimizer')

    # NOTE(review): the DEFAULT_* tables also have adamw/adabelief entries,
    # but those optimizers are not offered in choices here.
    parser.add_argument('--optim', default='adadelta', choices=['adadelta', 'madgrad', 'sgd'], help='Optimizer type: SGD, Adadelta, or madgrad.  Highly recommend to install madgrad and use that')

    # --- label remapping at test time ---
    parser.add_argument('--test_remap_labels', default=None, type=ast.literal_eval,
                        help='Map of which label each classifier label should map to.  For example, "{0:0, 1:0, 3:1, 4:1}" to map a 5 class sentiment test to a 2 class.  Any labels not mapped will be considered wrong')
    parser.add_argument('--forgive_unmapped_labels', dest='forgive_unmapped_labels', default=True, action='store_true',
                        help='When remapping labels, such as from 5 class to 2 class, pick a different label if the first guess is not remapped.')
    parser.add_argument('--no_forgive_unmapped_labels', dest='forgive_unmapped_labels', action='store_false',
                        help="When remapping labels, such as from 5 class to 2 class, DON'T pick a different label if the first guess is not remapped.")

    # --- loss ---
    parser.add_argument('--loss', type=lambda x: Loss[x.upper()], default=Loss.CROSS,
                        help="Whether to use regular cross entropy or scale it by 1/log(quantity)")
    parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')
    parser.add_argument('--min_train_len', type=int, default=0,
                        help="Filter sentences less than this length")

    # --- word vectors; the string defaults of the enum-typed options are converted by argparse ---
    parser.add_argument('--pretrain_max_vocab', type=int, default=-1)
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--wordvec_raw_file', type=str, default=None, help='Exact name of the raw wordvec file to read')
    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')
    parser.add_argument('--wordvec_type', type=lambda x: WVType[x.upper()], default='word2vec', help='Different vector types have different options, such as google 300d replacing numbers with #')
    parser.add_argument('--extra_wordvec_dim', type=int, default=0, help="Extra dim of word vectors - will be trained")
    parser.add_argument('--extra_wordvec_method', type=lambda x: ExtraVectors[x.upper()], default='sum', help='How to train extra dimensions of word vectors, if at all')
    parser.add_argument('--extra_wordvec_max_norm', type=float, default=None, help="Max norm for initializing the extra vectors")

    # --- character language model ---
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--charlm_projection', type=int, default=None, help="Project the charlm values to this dimension")
    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")

    # --- elmo ---
    parser.add_argument('--elmo_model', default='extern_data/manyelmo/english', help='Directory with elmo model')
    parser.add_argument('--use_elmo', dest='use_elmo', default=False, action='store_true', help='Use an elmo model as a source of parameters')
    parser.add_argument('--elmo_projection', type=int, default=None, help='Project elmo to this many dimensions')

    # --- transformer (bert) ---
    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
    parser.add_argument('--bert_finetune', default=False, action='store_true', help="Finetune the Bert model")
    parser.add_argument('--bert_learning_rate', default=0.01, type=float, help='Scale the learning rate for transformer finetuning by this much')
    parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')
    parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer")
    parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')

    # --- bilstm before the convolutions ---
    parser.add_argument('--bilstm', dest='bilstm', action='store_true', default=True, help="Use a bilstm after the inputs, before the convs.  Using bilstm is about as accurate and significantly faster (because of dim reduction) than going straight to the filters")
    parser.add_argument('--no_bilstm', dest='bilstm', action='store_false', help="Don't use a bilstm after the inputs, before the convs.")
    # somewhere between 200-300 seems to be the sweet spot for a couple datasets:
    # dev set macro f1 scores on 3 class problems
    # note that these were only run once each
    # more trials might narrow down which ones works best
    # es_tass2020:
    #  150    0.5580
    #  200    0.5629
    #  250    0.5586
    #  300    0.5642   <---
    #  400    0.5525
    #  500    0.5579
    #  750    0.5585
    # en_sstplus:
    #  150    0.6816
    #  200    0.6721
    #  250    0.6915   <---
    #  300    0.6824
    #  400    0.6757
    #  500    0.6770
    #  750    0.6781
    # de_sb10k
    #  150    0.6745
    #  200    0.6798   <---
    #  250    0.6459
    #  300    0.6665
    #  400    0.6521
    #  500    0.6584
    #  750    0.6447
    parser.add_argument('--bilstm_hidden_dim', type=int, default=300, help="Dimension of the bilstm to use")

    parser.add_argument('--maxpool_width', type=int, default=1, help="Width of the maxpool kernel to use")

    # --- constituency classifier ---
    parser.add_argument('--no_constituency_backprop', dest='constituency_backprop', default=True, action='store_false', help="When using a constituency parser, backprop into the parser's weights if True")
    # NOTE(review): the default is a developer-specific absolute path; it only
    # works on that machine - callers should pass --constituency_model.
    parser.add_argument('--constituency_model', type=str, default="/home/john/stanza_resources/it/constituency/vit_bert.pt", help="Which constituency model to use.  TODO: make this more user friendly")
    parser.add_argument('--constituency_batch_norm', default=False, action='store_true', help='Add a LayerNorm between the output of the parser and the classifier layers')
    parser.add_argument('--constituency_node_attn', default=False, action='store_true', help='True means to make an attn layer out of the tree, with the words as key and nodes as query')
    parser.add_argument('--no_constituency_node_attn', dest='constituency_node_attn', action='store_false', help='True means to make an attn layer out of the tree, with the words as key and nodes as query')
    parser.add_argument('--constituency_top_layer', dest='constituency_top_layer', default=False, action='store_true', help='True means use the top (ROOT) layer of the constituents.  Otherwise, the next layer down (S, usually) will be used')
    parser.add_argument('--no_constituency_top_layer', dest='constituency_top_layer', action='store_false', help='True means use the top (ROOT) layer of the constituents.  Otherwise, the next layer down (S, usually) will be used')
    parser.add_argument('--constituency_all_words', default=False, action='store_true', help='Use all word positions in the constituency classifier')
    parser.add_argument('--no_constituency_all_words', dest='constituency_all_words', default=False, action='store_false', help='Use the start and end word embeddings as inputs to the constituency classifier')

    # --- logging / tracking / misc ---
    parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training.  A very noisy option')

    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')

    parser.add_argument('--seed', default=None, type=int, help='Random seed for model')

    add_peft_args(parser)
    utils.add_device_args(parser)

    return parser
def build_model_filename(args):
    """
    Build the save filename for a classifier from its layer shapes.

    The shape string encodes the filter sizes, the filter channels, and
    (if any) the fully connected layer shapes, for example
      FS_3_4_5_C_1000_        or    FS_3_4_5_C_1000__FC_400_
    and is then expanded by utils.standard_model_file_name using vars(args).

    filter_channels may be a single int or one value per filter (the CNN
    classifier accepts both forms).  A single int is formatted exactly as
    before; a list or tuple is joined with "_" instead of crashing the
    %d formatting.

    Returns the expanded model filename.
    """
    shape = "FS_%s" % "_".join([str(x) for x in args.filter_sizes])
    if isinstance(args.filter_channels, (list, tuple)):
        shape = shape + "_C_%s_" % "_".join([str(x) for x in args.filter_channels])
    else:
        shape = shape + "_C_%d_" % args.filter_channels
    if args.fc_shapes:
        shape = shape + "_FC_%s_" % "_".join([str(x) for x in args.fc_shapes])
    model_save_file = utils.standard_model_file_name(vars(args), "classifier", shape=shape, classifier_type=args.model_type.name)
    logger.info("Expanded save_name: %s", model_save_file)
    return model_save_file
def parse_args(args=None):
    """
    Parse the classifier command line and fill in optimizer-dependent defaults.

    args: a list of strings to parse, or None to let argparse read sys.argv

    After parsing:
      - the peft arguments are resolved
      - --wandb_name implies --wandb
      - the optimizer name is lowercased
      - weight_decay, momentum and learning_rate, if left unset, are looked
        up in the per-optimizer default tables (None when the optimizer has
        no entry there)

    Returns the argparse Namespace.
    """
    args = build_argparse().parse_args(args)
    resolve_peft_args(args, tlogger)

    if args.wandb_name:
        args.wandb = True

    args.optim = args.optim.lower()
    optimizer_defaults = (
        ("weight_decay", DEFAULT_WEIGHT_DECAY),
        ("momentum", DEFAULT_MOMENTUM),
        ("learning_rate", DEFAULT_LEARNING_RATES),
    )
    for field_name, default_table in optimizer_defaults:
        if getattr(args, field_name) is None:
            setattr(args, field_name, default_table.get(args.optim, None))
    return args
def dataset_predictions(model, dataset):
    """
    Run the model over dataset and return the predicted label for each item.

    Items are grouped by length with data.sort_dataset_by_len, so each
    group is a single forward pass.  The predictions are then put back in
    the original dataset order.

    The model is left in eval mode.  Gradients are not tracked, since only
    the argmax of the output is used; this avoids building an autograd
    graph for every evaluation pass.

    Returns a list of labels (entries of model.labels), one per item.
    """
    model.eval()
    index_label_map = dict(enumerate(model.labels))

    # keep_index=True: each entry of a group is (datum, original index)
    dataset_lengths = data.sort_dataset_by_len(dataset, keep_index=True)
    predictions = []
    o_idx = []
    with torch.no_grad():
        for batch in dataset_lengths.values():
            output = model([x[0] for x in batch])
            # one argmax over the class dimension for the whole batch
            predicted = torch.argmax(output, dim=1).tolist()
            predictions.extend(index_label_map[p] for p in predicted)
            o_idx.extend(x[1] for x in batch)

    predictions = utils.unsort(predictions, o_idx)
    return predictions
def confusion_dataset(predictions, dataset, labels):
    """
    Build a confusion matrix from predictions and the dataset they came from.

    predictions: predicted labels, in the same order as dataset
    dataset: items whose .sentiment field holds the gold label
    labels: every gold label which may appear; each gets a row even if empty

    Returns a dict of dicts of counts:
      First key: gold
      Second key: predicted
      so: confusion_matrix[gold][predicted]
    Only (gold, predicted) pairs which actually occur have an entry.
    A gold label missing from labels raises KeyError.
    """
    matrix = {gold: {} for gold in labels}
    for guess, item in zip(predictions, dataset):
        row = matrix[item.sentiment]
        row[guess] = row.get(guess, 0) + 1
    return matrix
def score_dataset(model, dataset, label_map=None,
                  remap_labels=None, forgive_unmapped_labels=False):
    """
    Count how many items of dataset the model labels correctly.

    label_map: dict from gold label (datum.sentiment) to the class index
      it is compared against.  Defaults to each label's position in
      model.labels.

    remap_labels: a dict from old label to new label to use when
      testing a classifier on a dataset with a simpler label set.
      Both sides are integer indices: the keys are compared with the
      argmax index of the model output, and the values are compared with
      label_map[datum.sentiment].
      For example, a model trained on 5 class sentiment can be tested
      on a binary distribution with {0: 0, 1: 0, 3: 1, 4: 1}

    forgive_unmapped_labels says the following: in the case that the
      model predicts 2 in the above example for remap_labels, instead
      treat the model's prediction as whichever mapped label it gave the
      highest score.  Without it, an unmapped prediction counts as wrong.

    Returns the number of correct predictions.
    """
    model.eval()
    if label_map is None:
        label_map = {x: y for (y, x) in enumerate(model.labels)}

    correct = 0
    dataset_lengths = data.sort_dataset_by_len(dataset)
    # only the argmax is used, so there is no need to track gradients
    with torch.no_grad():
        for batch in dataset_lengths.values():
            # TODO: possibly break this up into smaller batches
            expected_labels = [label_map[x.sentiment] for x in batch]
            output = model(batch)

            for i, expected in enumerate(expected_labels):
                predicted_label = torch.argmax(output[i]).item()
                if remap_labels:
                    if predicted_label in remap_labels:
                        predicted_label = remap_labels[predicted_label]
                    elif forgive_unmapped_labels:
                        # rank the outputs from highest to lowest score
                        # (stable, so ties keep index order) and use the
                        # best one which has a mapping
                        scores = output[i].tolist()
                        ranked = sorted(range(len(scores)), key=lambda j: -scores[j])
                        fallback = next((j for j in ranked if j in remap_labels), None)
                        if fallback is None:
                            # none of the outputs has a mapping: count it wrong
                            continue
                        predicted_label = remap_labels[fallback]
                    else:
                        # unmapped prediction and no slack allowed: count it wrong
                        continue
                if predicted_label == expected:
                    correct = correct + 1
    return correct
def score_dev_set(model, dev_set, dev_eval_scoring):
    """
    Evaluate the model on the dev set.

    Logs the confusion matrix, the accuracy and the macro f1.

    Returns a tuple (score, accuracy, macro_f1), where score is whichever
    of accuracy or macro_f1 dev_eval_scoring selects.

    Raises ValueError for an unknown dev_eval_scoring.
    """
    predictions = dataset_predictions(model, dev_set)
    confusion_matrix = confusion_dataset(predictions, dev_set, model.labels)
    logger.info("Dev set confusion matrix:\n{}".format(format_confusion(confusion_matrix, model.labels)))
    correct, total = confusion_to_accuracy(confusion_matrix)
    macro_f1 = confusion_to_macro_f1(confusion_matrix)
    # compute the accuracy once, so the logged and returned values always agree
    accuracy = correct / total
    logger.info("Dev set: %d correct of %d examples. Accuracy: %f", correct, total, accuracy)
    logger.info("Macro f1: {}".format(macro_f1))

    if dev_eval_scoring is DevScoring.ACCURACY:
        return accuracy, accuracy, macro_f1
    elif dev_eval_scoring is DevScoring.WEIGHTED_F1:
        return macro_f1, accuracy, macro_f1
    else:
        raise ValueError("Unknown scoring method {}".format(dev_eval_scoring))
def intermediate_name(filename, epoch, dev_scoring, score):
    """
    Build an informative intermediate checkpoint name.

    The epoch (4 digits, zero padded), the scoring type (dev_scoring.value)
    and the score as a percentage with two decimals are inserted before
    the file extension, e.g. model.pt -> model.E0003-acc50.00.pt
    """
    base, suffix = os.path.splitext(filename)
    tag = ".E{:04d}-{}{:05.2f}".format(epoch, dev_scoring.value, score * 100)
    return base + tag + suffix
def log_param_sizes(model):
    """
    Log, at debug level, the storage used by each named parameter.

    Each line gives the parameter name, the bytes per element, the
    number of elements, and the total bytes.  A final line gives the
    total bytes over all parameters.
    """
    logger.debug("--- Model parameter sizes ---")
    grand_total = 0
    for param_name, tensor in model.named_parameters():
        per_element = tensor.element_size()
        count = tensor.nelement()
        nbytes = per_element * count
        grand_total += nbytes
        logger.debug(" %s %d %d %d", param_name, per_element, count, nbytes)
    logger.debug(" Total size: %d", grand_total)
def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, labels):
    """
    Train trainer.model on train_set, evaluating on dev_set as it goes.

    model_file: where the best model (by dev score) is saved
    checkpoint_file: if set, the trainer is saved here at the end of every
      epoch (presumably with the optimizer, as save_optimizer is not
      disabled for it) so that training can be resumed
    labels: the model's label list; its order defines the class indices

    Training continues from trainer.epochs_trained up to args.max_epochs.
    The dev set is scored at the end of every epoch.  It is also scored
    every args.dev_eval_batches batches, if that is positive; this check
    only runs at args.tick boundaries.

    Raises ValueError for an unknown args.loss, and ImportError if focal
    loss is requested but focal_loss is not installed.
    """
    tlogger.setLevel(logging.DEBUG)

    # TODO: use a (torch) dataloader to possibly speed up the GPU usage
    model = trainer.model
    optimizer = trainer.optimizer

    device = next(model.parameters()).device
    logger.info("Current device: %s" % device)

    label_map = {x: y for (y, x) in enumerate(labels)}
    label_tensors = {x: torch.tensor(y, requires_grad=False, device=device)
                     for (y, x) in enumerate(labels)}

    # applied to the model outputs before they go to the loss function
    process_outputs = lambda x: x
    if args.loss == Loss.CROSS:
        logger.info("Creating CrossEntropyLoss")
        loss_function = nn.CrossEntropyLoss()
    elif args.loss == Loss.WEIGHTED_CROSS:
        logger.info("Creating weighted cross entropy loss w/o log")
        # read the gold label via .sentiment, as the training loop does,
        # rather than relying on the datum being indexable
        loss_function = loss.weighted_cross_entropy_loss([label_map[x.sentiment] for x in train_set], log_dampened=False)
    elif args.loss == Loss.LOG_CROSS:
        logger.info("Creating weighted cross entropy loss w/ log")
        loss_function = loss.weighted_cross_entropy_loss([label_map[x.sentiment] for x in train_set], log_dampened=True)
    elif args.loss == Loss.FOCAL:
        try:
            from focal_loss.focal_loss import FocalLoss
        except ImportError:
            raise ImportError("focal_loss not installed. Must `pip install focal_loss_torch` to use the --loss=focal feature")
        logger.info("Creating FocalLoss with loss %f", args.loss_focal_gamma)
        # softmax first, so FocalLoss receives probabilities rather than raw scores
        process_outputs = lambda x: torch.softmax(x, dim=1)
        loss_function = FocalLoss(gamma=args.loss_focal_gamma)
    else:
        raise ValueError("Unknown loss function {}".format(args.loss))
    loss_function.to(device)

    train_set_by_len = data.sort_dataset_by_len(train_set)

    if trainer.global_step > 0:
        # We reloaded the model, so let's report its current dev set score
        _ = score_dev_set(model, dev_set, args.dev_eval_scoring)
        logger.info("Reloaded model for continued training.")
        if trainer.best_score is not None:
            logger.info("Previous best score: %.5f", trainer.best_score)

    log_param_sizes(model)

    # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    if args.wandb:
        import wandb
        wandb_name = args.wandb_name if args.wandb_name else "%s_classifier" % args.shorthand
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('accuracy', summary='max')
        wandb.run.define_metric('macro_f1', summary='max')
        wandb.run.define_metric('epoch_loss', summary='min')

    for opt_name, opt in optimizer.items():
        current_lr = opt.param_groups[0]['lr']
        logger.info("optimizer %s learning rate: %s", opt_name, current_lr)

    # if this is a brand new training run, and we're saving all intermediate models, save the start model as well
    if args.save_intermediate_models and trainer.epochs_trained == 0:
        intermediate_file = intermediate_name(model_file, trainer.epochs_trained, args.dev_eval_scoring, 0.0)
        trainer.save(intermediate_file, save_optimizer=False)

    # the loop variable is trainer.epochs_trained itself, so the trainer's
    # epoch count stays in step with the loop
    for trainer.epochs_trained in range(trainer.epochs_trained, args.max_epochs):
        running_loss = 0.0
        epoch_loss = 0.0
        shuffled_batches = data.shuffle_dataset(train_set_by_len, args.batch_size, args.batch_single_item)

        model.train()
        logger.info("Starting epoch %d", trainer.epochs_trained)
        if args.log_norms:
            model.log_norms()

        for batch_num, batch in enumerate(shuffled_batches):
            # logger.debug("Batch size %d max len %d" % (len(batch), max(len(x.text) for x in batch)))
            trainer.global_step += 1
            logger.debug("Starting batch: %d step %d", batch_num, trainer.global_step)

            batch_labels = torch.stack([label_tensors[x.sentiment] for x in batch])

            # zero the parameter gradients
            for opt in optimizer.values():
                opt.zero_grad()

            outputs = model(batch)
            outputs = process_outputs(outputs)
            batch_loss = loss_function(outputs, batch_labels)
            batch_loss.backward()
            for opt in optimizer.values():
                opt.step()

            # print statistics
            running_loss += batch_loss.item()
            if (batch_num + 1) % args.tick == 0: # print every so many batches
                train_loss = running_loss / args.tick
                logger.info('[%d, %5d] Average loss: %.3f', trainer.epochs_trained + 1, batch_num + 1, train_loss)
                if args.wandb:
                    wandb.log({'train_loss': train_loss}, step=trainer.global_step)
                if args.dev_eval_batches > 0 and (batch_num + 1) % args.dev_eval_batches == 0:
                    logger.info('---- Interim analysis ----')
                    dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)
                    if args.wandb:
                        wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1}, step=trainer.global_step)
                    if trainer.best_score is None or dev_score > trainer.best_score:
                        trainer.best_score = dev_score
                        trainer.save(model_file, save_optimizer=False)
                        logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d Batch %d" % (accuracy, macro_f1, trainer.epochs_trained+1, batch_num+1))
                    # scoring put the model in eval mode
                    model.train()
                    if args.log_norms:
                        trainer.model.log_norms()
                epoch_loss += running_loss
                running_loss = 0.0
        # Add any leftover loss to the epoch_loss
        epoch_loss += running_loss

        logger.info("Finished epoch %d Total loss %.3f" % (trainer.epochs_trained + 1, epoch_loss))
        dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)
        if args.wandb:
            wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1, 'epoch_loss': epoch_loss}, step=trainer.global_step)

        # Update the best score BEFORE writing the checkpoint.  Otherwise the
        # checkpoint records a stale best score, and a run resumed from it
        # could overwrite a better model_file with a worse model.
        if trainer.best_score is None or dev_score > trainer.best_score:
            trainer.best_score = dev_score
            trainer.save(model_file, save_optimizer=False)
            logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d" % (accuracy, macro_f1, trainer.epochs_trained+1))
        if checkpoint_file:
            trainer.save(checkpoint_file, epochs_trained = trainer.epochs_trained + 1)
        if args.save_intermediate_models:
            intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score)
            trainer.save(intermediate_file, save_optimizer=False)

    if args.wandb:
        wandb.finish()
def main(args=None):
    """
    Command line entry point for the classifier.

    args: a list of command line strings, or None for the process arguments

    Parses the arguments and seeds the random number generators.  With
    --train, it trains a model (resuming from a checkpoint if one exists).
    It then evaluates the trained or loaded model on args.test_file and
    logs the test accuracy.

    Raises ValueError if not training and no model name can be determined.
    """
    args = parse_args(args)
    seed = utils.set_random_seed(args.seed)
    logger.info("Using random seed: %d" % seed)

    utils.ensure_dir(args.save_dir)

    save_name = build_model_filename(args)

    # TODO: maybe the dataset needs to be in a torch data loader in order to
    # make cuda operations faster
    checkpoint_file = None
    if args.train:
        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
        logger.info("Using training set: %s" % args.train_file)
        logger.info("Training set has %d labels" % len(data.dataset_labels(train_set)))
        tlogger.setLevel(logging.DEBUG)

        tlogger.info("Saving checkpoints: %s", args.checkpoint)
        if args.checkpoint:
            checkpoint_file = utils.checkpoint_name(args.save_dir, save_name, args.checkpoint_save_name)
            tlogger.info("Checkpoint filename: %s", checkpoint_file)
    elif not args.load_name:
        # not training and no model given: fall back to the filename
        # this configuration would have been trained to
        if save_name:
            args.load_name = save_name
        else:
            raise ValueError("No model provided and not asked to train a model. This makes no sense")
    else:
        train_set = None
    # NOTE(review): train_set is left unbound in the elif branch above.  That is
    # only safe because every later use of it is reached only when args.train is set.

    # resume from the checkpoint if training and it exists, otherwise load the
    # named model, otherwise (training only) build a new model from scratch
    if args.train and checkpoint_file is not None and os.path.exists(checkpoint_file):
        trainer = Trainer.load(checkpoint_file, args, load_optimizer=args.train)
    elif args.load_name:
        trainer = Trainer.load(args.load_name, args, load_optimizer=args.train)
    else:
        trainer = Trainer.build_new_model(args, train_set)

    trainer.model.log_configuration()

    if args.train:
        utils.log_training_args(args, logger)

        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, min_len=None)
        logger.info("Using dev set: %s", args.dev_file)
        logger.info("Training set has %d items", len(train_set))
        logger.info("Dev set has %d items", len(dev_set))
        data.check_labels(trainer.model.labels, dev_set)

        train_model(trainer, save_name, checkpoint_file, args, train_set, dev_set, trainer.model.labels)

    if args.log_norms:
        trainer.model.log_norms()

    # NOTE(review): after training, trainer.model is the final-epoch model, not
    # the best dev model saved to save_name.  Confirm that scoring the test set
    # with it is intended.
    test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
    logger.info("Using test set: %s" % args.test_file)
    data.check_labels(trainer.model.labels, test_set)

    if args.test_remap_labels is None:
        # full evaluation: confusion matrix, accuracy and macro f1
        predictions = dataset_predictions(trainer.model, test_set)
        confusion_matrix = confusion_dataset(predictions, test_set, trainer.model.labels)
        if args.output_predictions:
            logger.info("List of predictions: %s", predictions)
        logger.info("Confusion matrix:\n{}".format(format_confusion(confusion_matrix, trainer.model.labels)))
        correct, total = confusion_to_accuracy(confusion_matrix)
        logger.info("Macro f1: {}".format(confusion_to_macro_f1(confusion_matrix)))
    else:
        # testing against a remapped (simpler) label set: accuracy only
        correct = score_dataset(trainer.model, test_set,
                                remap_labels=args.test_remap_labels,
                                forgive_unmapped_labels=args.forgive_unmapped_labels)
        total = len(test_set)
    logger.info("Test set: %d correct of %d examples. Accuracy: %f" %
                (correct, total, correct / total))
# allow running this module directly as a script
if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/classifiers/__init__.py
================================================
================================================
FILE: stanza/models/classifiers/base_classifier.py
================================================
from abc import ABC, abstractmethod
import logging
import torch
import torch.nn as nn
from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
"""
A base classifier type
Currently, it can process text or other inputs in a manner
suitable for the particular model type.
In other words, the CNNClassifier processes lists of words,
and the ConstituencyClassifier processes trees
"""
logger = logging.getLogger('stanza')

class BaseClassifier(ABC, nn.Module):
    """
    Abstract base for the classifier models.

    Subclasses implement extract_sentences and forward.  label_sentences
    calls forward on a list of preprocessed sentences and takes the argmax
    over dim 1 of the result.  So forward is expected to return one row of
    scores per sentence.
    """

    @abstractmethod
    def extract_sentences(self, doc):
        """
        Extract the sentences or the relevant information in the sentences from a document
        """

    def preprocess_sentences(self, sentences):
        """
        By default, don't do anything
        """
        return sentences

    def label_sentences(self, sentences, batch_size=None):
        """
        Given a list of sentences, return the model's results on that text.

        sentences are first passed through preprocess_sentences.

        batch_size: if None, all sentences go through one forward pass.
          Otherwise the sentences are sorted longest first, split into
          batches, and the labels are restored to the original order.

        Returns a list of predicted label indices (the argmax of the model
        output), one per sentence.

        The model is put in eval mode and run without gradient tracking,
        since only the argmax of the output is used.
        """
        self.eval()

        sentences = self.preprocess_sentences(sentences)

        if batch_size is None:
            intervals = [(0, len(sentences))]
            orig_idx = None
        else:
            sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)
            intervals = split_into_batches(sentences, batch_size)

        labels = []
        with torch.no_grad():
            for interval in intervals:
                if interval[1] - interval[0] == 0:
                    # this can happen for empty text
                    continue
                output = self(sentences[interval[0]:interval[1]])
                predicted = torch.argmax(output, dim=1)
                labels.extend(predicted.tolist())

        if orig_idx:
            sentences = unsort(sentences, orig_idx)
            labels = unsort(labels, orig_idx)

        # only walk the results when the debug output will actually be seen
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Found labels")
            for (label, sentence) in zip(labels, sentences):
                logger.debug((label, sentence))

        return labels
================================================
FILE: stanza/models/classifiers/cnn_classifier.py
================================================
import dataclasses
import logging
import math
import os
import random
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import stanza.models.classifiers.data as data
from stanza.models.classifiers.base_classifier import BaseClassifier
from stanza.models.classifiers.config import CNNConfig
from stanza.models.classifiers.data import SentimentDatum
from stanza.models.classifiers.utils import ExtraVectors, ModelType, build_output_layers
from stanza.models.common.bert_embedding import extract_bert_embeddings
from stanza.models.common.data import get_long_tensor, sort_all
from stanza.models.common.utils import attach_bert_model
from stanza.models.common.vocab import PAD_ID, UNK_ID
"""
The CNN classifier is based on Yoon Kim's work:
https://arxiv.org/abs/1408.5882
Also included are maxpool 2d, conv 2d, and a bilstm, as in
Text Classification Improved by Integrating Bidirectional LSTM
with Two-dimensional Max Pooling
https://aclanthology.org/C16-1329.pdf
The architecture is simple:
- Embedding at the bottom layer
- separate learnable entry for UNK, since many of the embeddings we have use 0 for UNK
- maybe a bilstm layer, as per a command line flag
- Some number of conv2d layers over the embedding
- Maxpool layers over small windows, window size being a parameter
- FC layer to the classification layer
One experiment which was run and found to be a bit of a negative was
putting a layer on top of the pretrain. You would think that might
help, but dev performance went down for each variation of
- trans(emb)
- relu(trans(emb))
- dropout(trans(emb))
- dropout(relu(trans(emb)))
"""
# general stanza logger
logger = logging.getLogger('stanza')
# logger for model construction / configuration details (layer sizes, filters, charlm dims)
tlogger = logging.getLogger('stanza.classifiers.trainer')
class CNNClassifier(BaseClassifier):
def __init__(self, pretrain, extra_vocab, labels,
             charmodel_forward, charmodel_backward, elmo_model, bert_model, bert_tokenizer, force_bert_saved, peft_name,
             args):
    """
    Build a CNN classifier, optionally with a bilstm below the convolutions.

    pretrain is a pretrained word embedding. should have .emb and .vocab

    extra_vocab is a collection of words in the training data to
    be used for the delta word embedding, if used. can be set to
    None if delta word embedding is not used.

    labels is the list of labels we expect in the training data.
    Used to derive the number of classes. Saving it in the model
    will let us check that test data has the same labels

    charmodel_forward / charmodel_backward: optional character language
    models (None if unused).  The forward one must have is_forward_lm set,
    and the backward one must not.

    elmo_model: required when args.use_elmo is set

    bert_model / bert_tokenizer: optional transformer and its tokenizer.
    A tokenizer is required whenever bert_model is given.

    force_bert_saved: passed to attach_bert_model; it is forced on when
    args.bert_finetune is set

    peft_name: presumably the name of the peft adapter; it is passed along
    when extracting bert embeddings.  It is kept out of the config since it
    may change when the model is loaded in a new Pipeline.

    args is either the complete arguments when training, or the
    subset of arguments stored in the model save file

    Raises ValueError for inconsistent inputs: a charlm of the wrong
    direction, a missing extra_vocab / elmo_model / bert_tokenizer, an
    extra_wordvec_dim which does not match the embedding for SUM, or a
    filter size which is neither an int nor a 2-tuple.
    """
    super(CNNClassifier, self).__init__()
    self.labels = labels
    bert_finetune = args.bert_finetune
    use_peft = args.use_peft
    # finetuning the transformer implies force_bert_saved
    force_bert_saved = force_bert_saved or bert_finetune
    logger.debug("bert_finetune %s / force_bert_saved %s", bert_finetune, force_bert_saved)

    # this may change when loaded in a new Pipeline, so it's not part of the config
    self.peft_name = peft_name

    # we build a separate config out of the args so that we can easily save it in torch
    self.config = CNNConfig(filter_channels = args.filter_channels,
                            filter_sizes = args.filter_sizes,
                            fc_shapes = args.fc_shapes,
                            dropout = args.dropout,
                            num_classes = len(labels),
                            wordvec_type = args.wordvec_type,
                            extra_wordvec_method = args.extra_wordvec_method,
                            extra_wordvec_dim = args.extra_wordvec_dim,
                            extra_wordvec_max_norm = args.extra_wordvec_max_norm,
                            char_lowercase = args.char_lowercase,
                            charlm_projection = args.charlm_projection,
                            has_charlm_forward = charmodel_forward is not None,
                            has_charlm_backward = charmodel_backward is not None,
                            use_elmo = args.use_elmo,
                            elmo_projection = args.elmo_projection,
                            bert_model = args.bert_model,
                            bert_finetune = bert_finetune,
                            bert_hidden_layers = args.bert_hidden_layers,
                            force_bert_saved = force_bert_saved,
                            use_peft = use_peft,
                            lora_rank = args.lora_rank,
                            lora_alpha = args.lora_alpha,
                            lora_dropout = args.lora_dropout,
                            lora_modules_to_save = args.lora_modules_to_save,
                            lora_target_modules = args.lora_target_modules,
                            bilstm = args.bilstm,
                            bilstm_hidden_dim = args.bilstm_hidden_dim,
                            maxpool_width = args.maxpool_width,
                            model_type = ModelType.CNN)

    self.char_lowercase = args.char_lowercase

    # names of the modules registered through add_unsaved_module
    self.unsaved_modules = []

    emb_matrix = pretrain.emb
    # the pretrained word embedding is frozen (freeze=True)
    self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
    self.add_unsaved_module('elmo_model', elmo_model)
    self.vocab_size = emb_matrix.shape[0]
    self.embedding_dim = emb_matrix.shape[1]

    self.add_unsaved_module('forward_charlm', charmodel_forward)
    if charmodel_forward is not None:
        tlogger.debug("Got forward char model of dimension {}".format(charmodel_forward.hidden_dim()))
        if not charmodel_forward.is_forward_lm:
            raise ValueError("Got a backward charlm as a forward charlm!")
    self.add_unsaved_module('backward_charlm', charmodel_backward)
    if charmodel_backward is not None:
        tlogger.debug("Got backward char model of dimension {}".format(charmodel_backward.hidden_dim()))
        if charmodel_backward.is_forward_lm:
            raise ValueError("Got a forward charlm as a backward charlm!")

    attach_bert_model(self, bert_model, bert_tokenizer, self.config.use_peft, force_bert_saved)

    # The Pretrain has PAD and UNK already (indices 0 and 1), but we
    # possibly want to train UNK while freezing the rest of the embedding
    # note that the /10.0 operation has to be inside nn.Parameter unless
    # you want to spend a long time debugging this
    self.unk = nn.Parameter(torch.randn(self.embedding_dim) / np.sqrt(self.embedding_dim) / 10.0)

    # replacing NBSP picks up a whole bunch of words for VI
    self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }

    # optional trainable "delta" embedding over the training vocabulary,
    # either summed with or concatenated to the pretrained embedding
    if self.config.extra_wordvec_method is not ExtraVectors.NONE:
        if not extra_vocab:
            raise ValueError("Should have had extra_vocab set for extra_wordvec_method {}".format(self.config.extra_wordvec_method))
        if not args.extra_wordvec_dim:
            self.config.extra_wordvec_dim = self.embedding_dim
        if self.config.extra_wordvec_method is ExtraVectors.SUM:
            if self.config.extra_wordvec_dim != self.embedding_dim:
                raise ValueError("extra_wordvec_dim must equal embedding_dim for {}".format(self.config.extra_wordvec_method))

        self.extra_vocab = list(extra_vocab)
        self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) }
        # TODO: possibly add regularization specifically on the extra embedding?
        # note: it looks like a bug that this doesn't add UNK or PAD, but actually
        # those are expected to already be the first two entries
        self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab),
                                            embedding_dim = self.config.extra_wordvec_dim,
                                            max_norm = self.config.extra_wordvec_max_norm,
                                            padding_idx = 0)
        tlogger.debug("Extra embedding size: {}".format(self.extra_embedding.weight.shape))
    else:
        self.extra_vocab = None
        self.extra_vocab_map = None
        self.config.extra_wordvec_dim = 0
        self.extra_embedding = None

    # total_embedding_dim is the width of the input vector at each word
    # position: the word embedding, then each optional input adds its width
    if self.config.extra_wordvec_method is ExtraVectors.NONE:
        total_embedding_dim = self.embedding_dim
    elif self.config.extra_wordvec_method is ExtraVectors.SUM:
        total_embedding_dim = self.embedding_dim
    elif self.config.extra_wordvec_method is ExtraVectors.CONCAT:
        total_embedding_dim = self.embedding_dim + self.config.extra_wordvec_dim
    else:
        raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))

    if charmodel_forward is not None:
        if args.charlm_projection:
            self.charmodel_forward_projection = nn.Linear(charmodel_forward.hidden_dim(), args.charlm_projection)
            total_embedding_dim += args.charlm_projection
        else:
            self.charmodel_forward_projection = None
            total_embedding_dim += charmodel_forward.hidden_dim()

    if charmodel_backward is not None:
        if args.charlm_projection:
            self.charmodel_backward_projection = nn.Linear(charmodel_backward.hidden_dim(), args.charlm_projection)
            total_embedding_dim += args.charlm_projection
        else:
            self.charmodel_backward_projection = None
            total_embedding_dim += charmodel_backward.hidden_dim()

    if self.config.use_elmo:
        if elmo_model is None:
            raise ValueError("Model requires elmo, but elmo_model not passed in")
        # probe the elmo model once to find its output width
        elmo_dim = elmo_model.sents2elmo([["Test"]])[0].shape[1]

        # this mapping will combine 3 layers of elmo to 1 layer of features
        self.elmo_combine_layers = nn.Linear(in_features=3, out_features=1, bias=False)
        if self.config.elmo_projection:
            self.elmo_projection = nn.Linear(in_features=elmo_dim, out_features=self.config.elmo_projection)
            total_embedding_dim = total_embedding_dim + self.config.elmo_projection
        else:
            total_embedding_dim = total_embedding_dim + elmo_dim

    if bert_model is not None:
        if self.config.bert_hidden_layers:
            # The average will be offset by 1/N so that the default zeros
            # represents an average of the N layers
            if self.config.bert_hidden_layers > bert_model.config.num_hidden_layers:
                # limit ourselves to the number of layers actually available
                # note that we can +1 because of the initial embedding layer
                self.config.bert_hidden_layers = bert_model.config.num_hidden_layers + 1
            self.bert_layer_mix = nn.Linear(self.config.bert_hidden_layers, 1, bias=False)
            nn.init.zeros_(self.bert_layer_mix.weight)
        else:
            # an average of layers 2, 3, 4 will be used
            # (for historic reasons)
            self.bert_layer_mix = None

        if bert_tokenizer is None:
            raise ValueError("Cannot have a bert model without a tokenizer")
        # self.bert_model was attached by attach_bert_model above
        self.bert_dim = self.bert_model.config.hidden_size
        total_embedding_dim += self.bert_dim

    if self.config.bilstm:
        # the convolutions then run over the bilstm outputs (both directions)
        conv_input_dim = self.config.bilstm_hidden_dim * 2
        self.bilstm = nn.LSTM(batch_first=True,
                              input_size=total_embedding_dim,
                              hidden_size=self.config.bilstm_hidden_dim,
                              num_layers=2,
                              bidirectional=True,
                              dropout=0.2)
    else:
        conv_input_dim = total_embedding_dim
        self.bilstm = None

    self.fc_input_size = 0
    # Pytorch is "aware" of the existence of the nn.Modules inside
    # an nn.ModuleList in terms of parameters() etc
    self.conv_layers = nn.ModuleList()
    # max_window: forward pads every phrase to at least this many positions
    self.max_window = 0
    for filter_idx, filter_size in enumerate(self.config.filter_sizes):
        if isinstance(filter_size, int):
            # full width filter: filter_size word positions by the whole input width
            self.max_window = max(self.max_window, filter_size)
            if isinstance(self.config.filter_channels, int):
                filter_channels = self.config.filter_channels
            else:
                filter_channels = self.config.filter_channels[filter_idx]
            fc_delta = filter_channels // self.config.maxpool_width
            tlogger.debug("Adding full width filter %d. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
            self.fc_input_size += fc_delta
            self.conv_layers.append(nn.Conv2d(in_channels=1,
                                              out_channels=filter_channels,
                                              kernel_size=(filter_size, conv_input_dim)))
        elif isinstance(filter_size, tuple) and len(filter_size) == 2:
            # 2d filter: filter_height word positions by filter_width input
            # dimensions, strided across the input width
            filter_height, filter_width = filter_size
            # NOTE(review): here the word-position extent of the kernel is
            # filter_height (kernel_size=(filter_height, filter_width) below),
            # but max_window is raised by filter_width.  A phrase shorter than
            # filter_height may not be padded enough - confirm which was intended.
            self.max_window = max(self.max_window, filter_width)
            if isinstance(self.config.filter_channels, int):
                # NOTE(review): raises ZeroDivisionError if filter_width > conv_input_dim
                filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width))
            else:
                filter_channels = self.config.filter_channels[filter_idx]
            fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width
            tlogger.debug("Adding filter %s. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
            self.fc_input_size += fc_delta
            self.conv_layers.append(nn.Conv2d(in_channels=1,
                                              out_channels=filter_channels,
                                              stride=(1, filter_width),
                                              kernel_size=(filter_height, filter_width)))
        else:
            raise ValueError("Expected int or 2d tuple for conv size")

    tlogger.debug("Input dim to FC layers: %d", self.fc_input_size)
    self.fc_layers = build_output_layers(self.fc_input_size, self.config.fc_shapes, self.config.num_classes)

    self.dropout = nn.Dropout(self.config.dropout)
def add_unsaved_module(self, name, module):
    """
    Attach module as the attribute `name` and record name in self.unsaved_modules.

    module may be None, in which case the attribute is simply set to None.

    The charlms are always frozen.  The transformer ('bert_model') is frozen
    unless peft is in use: with peft, the transformer should not be saved
    directly, and instead only the peft parameters will be saved later.
    """
    self.unsaved_modules.append(name)
    setattr(self, name, module)
    if module is None:
        return

    is_charlm = name in ('forward_charlm', 'backward_charlm')
    is_frozen_bert = name == 'bert_model' and not self.config.use_peft
    if is_charlm or is_frozen_bert:
        for parameter in module.parameters():
            parameter.requires_grad = False
def is_unsaved_module(self, name):
    """
    Return True if the dotted parameter/module name belongs to one of the
    modules registered with add_unsaved_module (only the part of name
    before the first '.' is compared).
    """
    top_level, _, _ = name.partition('.')
    return top_level in self.unsaved_modules
def log_configuration(self):
    """
    Log the filter sizes, filter channels and fc layer shapes from the
    model config to the training logger, one line each.
    """
    settings = (
        ("Filter sizes", self.config.filter_sizes),
        ("Filter channels", self.config.filter_channels),
        ("Intermediate layers", self.config.fc_shapes),
    )
    for description, value in settings:
        tlogger.info("%s: %s" % (description, str(value)))
def log_norms(self):
    """
    Log the norm (torch.norm) of every trainable parameter, one per line.

    Parameters with requires_grad off are skipped, and the charlm
    parameters are skipped by name.  The header line's misspelling
    ("PARAMTERS") has been corrected.
    """
    lines = ["NORMS FOR MODEL PARAMETERS"]
    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        if name.split(".")[0] in ('forward_charlm', 'backward_charlm'):
            continue
        lines.append("%s %.6g" % (name, torch.norm(param).item()))
    logger.info("\n".join(lines))
def build_char_reps(self, inputs, max_phrase_len, charlm, projection, begin_paddings, device):
    """
    Compute charlm representations for each phrase and pad them into one tensor.

    inputs: the phrases, passed to charlm.build_char_representation, which
      is expected to return one tensor per phrase
    projection: optional callable applied to each phrase's representation
    begin_paddings: for each phrase, the position where its rows start

    Returns a zero-filled tensor of shape
    (len(inputs), max_phrase_len, representation width) on device.  The
    rep.shape[0] rows of each phrase are written starting at its begin
    padding.
    """
    reps = charlm.build_char_representation(inputs)
    if projection is not None:
        reps = list(map(projection, reps))

    width = reps[0].shape[-1]
    padded = torch.zeros((len(inputs), max_phrase_len, width), device=device)
    for row, rep in enumerate(reps):
        offset = begin_paddings[row]
        padded[row, offset:offset + rep.shape[0], :] = rep
    return padded
def extract_bert_embeddings(self, inputs, max_phrase_len, begin_paddings, device):
    """
    Run the transformer over inputs and pad the per-phrase embeddings into one tensor.

    The call below resolves to the module-level extract_bert_embeddings
    helper, not to this method.  If bert_layer_mix is set, that many
    layers are requested and mixed with the learned weights plus their
    plain average.  Embeddings are detached unless bert_finetune is set.

    Returns a zero-filled tensor of shape
    (len(inputs), max_phrase_len, embedding width) on device.  The rows of
    each phrase are written starting at its begin_paddings offset.
    """
    layer_mix = self.bert_layer_mix
    requested_layers = None if layer_mix is None else layer_mix.in_features
    features = extract_bert_embeddings(self.config.bert_model, self.bert_tokenizer, self.bert_model, inputs, device,
                                       keep_endpoints=False,
                                       num_layers=requested_layers,
                                       detach=not self.config.bert_finetune,
                                       peft_name=self.peft_name)
    if layer_mix is not None:
        # add the average so that the default behavior is to
        # take an average of the N layers, and anything else
        # other than that needs to be learned
        layer_count = layer_mix.in_features
        features = [layer_mix(feat).squeeze(2) + feat.sum(axis=2) / layer_count for feat in features]

    padded = torch.zeros((len(inputs), max_phrase_len, features[0].shape[-1]), device=device)
    for row, rep in enumerate(features):
        offset = begin_paddings[row]
        padded[row, offset:offset + rep.shape[0], :] = rep
    return padded
def forward(self, inputs):
# Run the CNN classifier on a batch of phrases and return the raw class
# scores (logits) -- no softmax is applied, see the note before the return.
# inputs: each item is either a SentimentDatum (its .text word list is used)
# or already a list of word strings.
# NOTE(review): an empty inputs list makes max() below raise ValueError.
# assume all pieces are on the same device
device = next(self.parameters()).device
vocab_map = self.vocab_map
def map_word(word):
# Look up a word in the pretrained vocab: exact match first, then with one
# trailing apostrophe removed, then lowercased; falls back to UNK_ID.
# NOTE(review): word[-1] raises IndexError on an empty string; update_text
# drops empty words, but token lists passed in directly are not checked.
idx = vocab_map.get(word, None)
if idx is not None:
return idx
if word[-1] == "'":
idx = vocab_map.get(word[:-1], None)
if idx is not None:
return idx
return vocab_map.get(word.lower(), UNK_ID)
inputs = [x.text if isinstance(x, SentimentDatum) else x for x in inputs]
# we will pad each phrase so either it matches the longest
# conv or the longest phrase in the input, whichever is longer
# (self.max_window is presumably the widest conv filter -- set outside this view)
max_phrase_len = max(len(x) for x in inputs)
if self.max_window > max_phrase_len:
max_phrase_len = self.max_window
batch_indices = []
batch_unknowns = []
extra_batch_indices = []
begin_paddings = []
# end_paddings is filled below but not read again in this method
end_paddings = []
elmo_batch_words = []
for phrase in inputs:
# we use random at training time to try to learn different
# positions of padding. at test time, though, we want to
# have consistent results, so we set that to 0 begin_pad
if self.training:
begin_pad_width = random.randint(0, max_phrase_len - len(phrase))
else:
begin_pad_width = 0
end_pad_width = max_phrase_len - begin_pad_width - len(phrase)
begin_paddings.append(begin_pad_width)
end_paddings.append(end_pad_width)
# the initial lists are the length of the begin padding
sentence_indices = [PAD_ID] * begin_pad_width
sentence_indices.extend([map_word(x) for x in phrase])
sentence_indices.extend([PAD_ID] * end_pad_width)
# the "unknowns" will be the locations of the unknown words.
# these locations will get the specially trained unknown vector
# TODO: split UNK based on part of speech? might be an interesting experiment
sentence_unknowns = [idx for idx, word in enumerate(sentence_indices) if word == UNK_ID]
batch_indices.append(sentence_indices)
batch_unknowns.append(sentence_unknowns)
if self.extra_vocab:
# same padding layout, but indices into the extra (dataset) vocab
extra_sentence_indices = [PAD_ID] * begin_pad_width
for word in phrase:
if word in self.extra_vocab_map:
# the extra vocab is initialized from the
# words in the training set, which means there
# would be no unknown words. to occasionally
# train the extra vocab's unknown words, we
# replace 1% of the words with UNK
# we don't do that for the original embedding
# on the assumption that there may be some
# unknown words in the training set anyway
# TODO: maybe train unk for the original embedding?
if self.training and random.random() < 0.01:
extra_sentence_indices.append(UNK_ID)
else:
extra_sentence_indices.append(self.extra_vocab_map[word])
else:
extra_sentence_indices.append(UNK_ID)
extra_sentence_indices.extend([PAD_ID] * end_pad_width)
extra_batch_indices.append(extra_sentence_indices)
if self.config.use_elmo:
# ELMo gets the raw words, with "" in the padded positions
elmo_phrase_words = [""] * begin_pad_width
for word in phrase:
elmo_phrase_words.append(word)
elmo_phrase_words.extend([""] * end_pad_width)
elmo_batch_words.append(elmo_phrase_words)
# creating a single large list with all the indices lets us
# create a single tensor, which is much faster than creating
# many tiny tensors
# we can convert this to the input to the CNN
# it is padded at one or both ends so that it is now num_phrases x max_len x emb_size
# there are two ways in which this padding is suboptimal
# the first is that for short sentences, smaller windows will
# be padded to the point that some windows are entirely pad
# the second is that a sentence S will have more or less padding
# depending on what other sentences are in its batch
# we assume these effects are pretty minimal
batch_indices = torch.tensor(batch_indices, requires_grad=False, device=device)
input_vectors = self.embedding(batch_indices)
# we use the random unk so that we are not necessarily
# learning to match 0s for unk
for phrase_num, sentence_unknowns in enumerate(batch_unknowns):
input_vectors[phrase_num][sentence_unknowns] = self.unk
if self.extra_vocab:
extra_batch_indices = torch.tensor(extra_batch_indices, requires_grad=False, device=device)
extra_input_vectors = self.extra_embedding(extra_batch_indices)
# CONCAT keeps both embeddings as separate features (joined by the
# torch.cat on dim 2 below); SUM adds them elementwise
if self.config.extra_wordvec_method is ExtraVectors.CONCAT:
all_inputs = [input_vectors, extra_input_vectors]
elif self.config.extra_wordvec_method is ExtraVectors.SUM:
all_inputs = [input_vectors + extra_input_vectors]
else:
raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))
else:
all_inputs = [input_vectors]
if self.forward_charlm is not None:
char_reps_forward = self.build_char_reps(inputs, max_phrase_len, self.forward_charlm, self.charmodel_forward_projection, begin_paddings, device)
all_inputs.append(char_reps_forward)
if self.backward_charlm is not None:
char_reps_backward = self.build_char_reps(inputs, max_phrase_len, self.backward_charlm, self.charmodel_backward_projection, begin_paddings, device)
all_inputs.append(char_reps_backward)
if self.config.use_elmo:
# this will be N arrays of 3xMx1024 where M is the number of words
# and N is the number of sentences (and 1024 is actually the number of weights)
elmo_arrays = self.elmo_model.sents2elmo(elmo_batch_words, output_layer=-2)
elmo_tensors = [torch.tensor(x).to(device=device) for x in elmo_arrays]
# elmo_tensor will now be Nx3xMx1024
elmo_tensor = torch.stack(elmo_tensors)
# Nx1024xMx3
elmo_tensor = torch.transpose(elmo_tensor, 1, 3)
# NxMx1024x3
elmo_tensor = torch.transpose(elmo_tensor, 1, 2)
# NxMx1024x1
elmo_tensor = self.elmo_combine_layers(elmo_tensor)
# NxMx1024
elmo_tensor = elmo_tensor.squeeze(3)
if self.config.elmo_projection:
elmo_tensor = self.elmo_projection(elmo_tensor)
all_inputs.append(elmo_tensor)
if self.bert_model is not None:
bert_embeddings = self.extract_bert_embeddings(inputs, max_phrase_len, begin_paddings, device)
all_inputs.append(bert_embeddings)
# still works even if there's just one item
# every piece shares the same begin_paddings offsets, so positions line up
input_vectors = torch.cat(all_inputs, dim=2)
if self.config.bilstm:
input_vectors, _ = self.bilstm(self.dropout(input_vectors))
# reshape to fit the input tensors
# (adds a channel dimension at dim 1 for the conv layers)
x = input_vectors.unsqueeze(1)
conv_outs = []
for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes):
# int filter sizes: squeeze away dim 3 of the conv output;
# other (presumably tuple) sizes: move dim 3 next to the channels
# and flatten them together
if isinstance(filter_size, int):
conv_out = self.dropout(F.relu(conv(x).squeeze(3)))
conv_outs.append(conv_out)
else:
conv_out = conv(x).transpose(2, 3).flatten(1, 2)
conv_out = self.dropout(F.relu(conv_out))
conv_outs.append(conv_out)
# the kernel's second dimension spans out.shape[2], i.e. all positions
pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs]
pooled = torch.cat(pool_outs, dim=1)
previous_layer = pooled
for fc in self.fc_layers[:-1]:
previous_layer = self.dropout(F.relu(fc(previous_layer)))
out = self.fc_layers[-1](previous_layer)
# note that we return the raw logits rather than use a softmax
# https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/4
return out
def get_params(self, skip_modules=True):
    """
    Collect everything needed to save this classifier.

    Returns a dict with the model state_dict ('model'), the config as a
    plain dict with enum fields replaced by their names ('config'), the
    labels, the extra vocab and, when peft is in use, the LoRA weights
    ('bert_lora').  With skip_modules, entries for which
    is_unsaved_module() is true are dropped from the state dict.
    """
    model_state = self.state_dict()
    if skip_modules:
        # modules such as the pretrained embedding are large and saved in a separate file
        for key in [k for k in model_state if self.is_unsaved_module(k)]:
            del model_state[key]

    config = dataclasses.asdict(self.config)
    for enum_field in ('wordvec_type', 'extra_wordvec_method', 'model_type'):
        config[enum_field] = config[enum_field].name

    params = {
        'model': model_state,
        'config': config,
        'labels': self.labels,
        'extra_vocab': self.extra_vocab,
    }
    if self.config.use_peft:
        # imported here so that the peft dependency stays optional
        from peft import get_peft_model_state_dict
        params["bert_lora"] = get_peft_model_state_dict(self.bert_model, adapter_name=self.peft_name)
    return params
def preprocess_data(self, sentences):
    """Normalize each sentence with data.update_text for this model's word vector type."""
    processed = []
    for sentence in sentences:
        processed.append(data.update_text(sentence, self.config.wordvec_type))
    return processed
def extract_sentences(self, doc):
    """Return each sentence of doc as the list of its token texts."""
    # TODO: tokens or words better here?
    sentences = []
    for sentence in doc.sentences:
        sentences.append([tok.text for tok in sentence.tokens])
    return sentences
================================================
FILE: stanza/models/classifiers/config.py
================================================
from dataclasses import dataclass
from typing import List, Union
# TODO: perhaps put the enums in this file
from stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType
# Hyperparameters of a CNN classifier, kept separate from the command line
# args so they can be saved with the model.  get_params saves it via
# dataclasses.asdict (enum fields stored by name) and Trainer.load rebuilds
# it with CNNConfig(**config).  Field order is the constructor order.
@dataclass
class CNNConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods
# conv output channels: a single int or a tuple
filter_channels: Union[int, tuple]
# conv filter sizes; forward treats int entries and non-int (presumably
# tuple) entries differently
filter_sizes: tuple
# sizes of the intermediate fully connected layers ("Intermediate layers")
fc_shapes: tuple
dropout: float
# number of output classes
num_classes: int
# word vector flavor; update_text uses it to normalize words
wordvec_type: WVType
# how the extra (dataset-vocab) embedding joins the pretrained one: CONCAT or SUM
extra_wordvec_method: ExtraVectors
extra_wordvec_dim: int
extra_wordvec_max_norm: float
char_lowercase: bool
charlm_projection: int
# whether Trainer.load should load the forward / backward charlm files
has_charlm_forward: bool
has_charlm_backward: bool
use_elmo: bool
# if truthy, the combined ELMo output goes through a projection layer
elmo_projection: int
# transformer name passed to load_bert (may be None -- TODO confirm load_bert handles that)
bert_model: str
# if False, transformer embeddings are extracted with detach=True
bert_finetune: bool
bert_hidden_layers: int
# if True, Trainer.load loads the transformer without the shared foundation_cache
force_bert_saved: bool
# peft / LoRA settings for the transformer; get_params saves the LoRA
# weights under 'bert_lora' when use_peft is set
use_peft: bool
lora_rank: int
lora_alpha: float
lora_dropout: float
lora_modules_to_save: List
lora_target_modules: List
# optional BiLSTM run over the concatenated word features before the convs
bilstm: bool
bilstm_hidden_dim: int
# first dimension of the max_pool2d kernel applied to each conv output in forward
maxpool_width: int
model_type: ModelType
# Hyperparameters of a ConstituencyClassifier.  It is built from args in
# ConstituencyClassifier.__init__ and rebuilt by Trainer.load with
# ConstituencyConfig(**config).  Field order is the constructor order.
@dataclass
class ConstituencyConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods
# sizes of the intermediate fully connected layers
fc_shapes: tuple
dropout: float
# number of output classes (len(labels))
num_classes: int
# the constituency_* fields mirror the command line args of the same name;
# build_new_model forwards those args to the TreeEmbedding with the prefix stripped
# logged as "Backprop into parser"
constituency_backprop: bool
constituency_batch_norm: bool
# logged as "Attention over nodes"
constituency_node_attn: bool
# presumably selects which parser layer feeds the embedding -- TODO confirm
constituency_top_layer: bool
# logged as "all words" when True, "start and end words" when False
constituency_all_words: bool
model_type: ModelType
================================================
FILE: stanza/models/classifiers/constituency_classifier.py
================================================
"""
A classifier that uses a constituency parser for the base embeddings
"""
import dataclasses
import logging
from types import SimpleNamespace
import torch
import torch.nn as nn
import torch.nn.functional as F
from stanza.models.classifiers.base_classifier import BaseClassifier
from stanza.models.classifiers.config import ConstituencyConfig
from stanza.models.classifiers.data import SentimentDatum
from stanza.models.classifiers.utils import ModelType, build_output_layers
from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
logger = logging.getLogger('stanza')
tlogger = logging.getLogger('stanza.classifiers.trainer')
class ConstituencyClassifier(BaseClassifier):
    """
    Sentiment classifier that uses a constituency parser for its base embeddings.

    The TreeEmbedding turns each tree into a stack of vectors; these are
    max-pooled over their first dimension and fed through a small stack
    of fully connected layers that produce one logit per label.
    """

    def __init__(self, tree_embedding, labels, args):
        super().__init__()
        self.labels = labels
        # the config is kept separate from args so that it can be saved with torch
        self.config = ConstituencyConfig(
            fc_shapes=args.fc_shapes,
            dropout=args.dropout,
            num_classes=len(labels),
            constituency_backprop=args.constituency_backprop,
            constituency_batch_norm=args.constituency_batch_norm,
            constituency_node_attn=args.constituency_node_attn,
            constituency_top_layer=args.constituency_top_layer,
            constituency_all_words=args.constituency_all_words,
            model_type=ModelType.CONSTITUENCY,
        )
        self.tree_embedding = tree_embedding
        self.fc_layers = build_output_layers(self.tree_embedding.output_size, self.config.fc_shapes, self.config.num_classes)
        self.dropout = nn.Dropout(self.config.dropout)

    def is_unsaved_module(self, name):
        """Every module of this classifier is saved."""
        return False

    def log_configuration(self):
        """Log the main settings of the classifier, one line each."""
        position_desc = "all words" if self.config.constituency_all_words else "start and end words"
        settings = (
            ("Backprop into parser: %s", self.config.constituency_backprop),
            ("Batch norm: %s", self.config.constituency_batch_norm),
            ("Word positions used: %s", position_desc),
            ("Attention over nodes: %s", self.config.constituency_node_attn),
            ("Intermediate layers: %s", self.config.fc_shapes),
        )
        for template, value in settings:
            tlogger.info(template, value)

    def log_norms(self):
        """Log the tree embedding's norms followed by the norms of this model's own trainable parameters."""
        prefix = "tree_embedding."
        report = ["NORMS FOR MODEL PARAMTERS"]
        report += [prefix + norm for norm in self.tree_embedding.get_norms()]
        report += ["%s %.6g" % (name, torch.norm(param).item())
                   for name, param in self.named_parameters()
                   if param.requires_grad and not name.startswith(prefix)]
        logger.info("\n".join(report))

    def forward(self, inputs):
        """
        Return raw logits for a batch of trees.

        Each input is either a SentimentDatum (its constituency tree is
        used) or a tree itself.
        """
        trees = [item.constituency if isinstance(item, SentimentDatum) else item for item in inputs]
        tree_vectors = self.tree_embedding.embed_trees(trees)
        hidden = torch.stack([vectors.max(dim=0).values for vectors in tree_vectors], dim=0)
        hidden = self.dropout(hidden)
        *intermediate_layers, output_layer = self.fc_layers
        for layer in intermediate_layers:
            # GELU rather than ReLU: with ReLU many units ended up dead
            hidden = self.dropout(F.gelu(layer(hidden)))
        return output_layer(hidden)

    def get_params(self, skip_modules=True):
        """
        Collect everything needed to save this classifier.

        The tree_embedding entries are removed from this model's state dict
        and saved separately via the tree embedding's own get_params().
        """
        model_state = self.state_dict()
        for key in [k for k in model_state if k.startswith("tree_embedding.")]:
            del model_state[key]
        tree_embedding_params = self.tree_embedding.get_params(skip_modules)
        config = dataclasses.asdict(self.config)
        config['model_type'] = config['model_type'].name
        return {
            'model': model_state,
            'tree_embedding': tree_embedding_params,
            'config': config,
            'labels': self.labels,
        }

    def extract_sentences(self, doc):
        """Return the constituency tree of every sentence in doc."""
        return [sentence.constituency for sentence in doc.sentences]
================================================
FILE: stanza/models/classifiers/data.py
================================================
"""Stanza models classifier data functions."""
import collections
from collections import namedtuple
import logging
import json
import random
import re
from typing import List
from stanza.models.classifiers.utils import WVType
from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
import stanza.models.constituency.tree_reader as tree_reader
logger = logging.getLogger('stanza')
class SentimentDatum:
    """
    One labeled example for the classifier.

    sentiment: the label (read_dataset stores it as a str)
    text: the list of words
    constituency: optional parsed tree, or None
    """

    def __init__(self, sentiment, text, constituency=None):
        self.sentiment = sentiment
        self.text = text
        self.constituency = constituency

    # equality is field-based and the fields are mutable, so instances are
    # deliberately unhashable (this is also what defining __eq__ implies)
    __hash__ = None

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, SentimentDatum):
            # let Python try the other operand / fall back to identity
            return NotImplemented
        return (self.sentiment == other.sentiment
                and self.text == other.text
                and self.constituency == other.constituency)

    def __str__(self):
        return str(self._asdict())

    def __repr__(self):
        fields = ", ".join("%s=%r" % (key, value) for key, value in self._asdict().items())
        return "%s(%s)" % (type(self).__name__, fields)

    def _asdict(self):
        """Return the fields as a dict; the tree (if any) is converted with str()."""
        if self.constituency is None:
            return {'sentiment': self.sentiment, 'text': self.text}
        else:
            return {'sentiment': self.sentiment, 'text': self.text, 'constituency': str(self.constituency)}
def update_text(sentence: List[str], wordvec_type: WVType) -> List[str]:
    """
    Normalize an already tokenized sentence into the words used for lookup.

    Every token is further split on "-" and "/" (the Stanford sentiment
    dataset has a lot of these), the pieces are stripped, and empty pieces
    are dropped.  If nothing is left the sentence becomes ["-"].  All words
    are lowercased, since the word vectors are lowercased.  For
    WVType.GOOGLE every digit is also replaced with "#", except for the
    standalone words "0" and "1".

    Raises ValueError for an unknown wordvec_type.
    """
    words = []
    for token in sentence:
        for dash_piece in token.split("-"):
            if not dash_piece:
                continue
            for piece in dash_piece.split("/"):
                piece = piece.strip()
                if piece:
                    words.append(piece)
    if not words:
        # removed too much
        words = ["-"]
    words = [word.lower() for word in words]

    if wordvec_type in (WVType.WORD2VEC, WVType.FASTTEXT, WVType.OTHER):
        return words
    if wordvec_type == WVType.GOOGLE:
        return [word if word in ('0', '1') else re.sub('[0-9]', '#', word) for word in words]
    raise ValueError("Unknown wordvec_type {}".format(wordvec_type))
def read_dataset(dataset, wordvec_type: WVType, min_len: int) -> List[SentimentDatum]:
    """
    Load labeled examples from one or more JSON files.

    dataset: a filename, or several filenames joined with commas.  Each file
    holds a JSON list of objects with 'sentiment' and 'text' keys and an
    optional 'constituency' key.

    The label is converted to str, the text is normalized with update_text,
    and a non-empty constituency string is parsed with tree_reader.read_trees
    (the first tree is kept).  If min_len is truthy, examples whose
    normalized text has fewer than min_len words are dropped.

    Returns a list of SentimentDatum.
    """
    raw_entries = []
    for filename in str(dataset).split(","):
        with open(filename, encoding="utf-8") as fin:
            loaded = json.load(fin)
        raw_entries.extend((str(entry['sentiment']), entry['text'], entry.get('constituency', None)) for entry in loaded)

    # TODO: maybe do this processing later, once the model is built.
    # then move the processing into the model so we can use
    # overloading to potentially make future model types
    examples = []
    for sentiment, text, tree_text in raw_entries:
        words = update_text(text, wordvec_type)
        tree = tree_reader.read_trees(tree_text)[0] if tree_text else None
        examples.append(SentimentDatum(sentiment, words, tree))

    if min_len:
        examples = [example for example in examples if len(example.text) >= min_len]
    return examples
def dataset_labels(dataset):
    """
    Return the sorted list of distinct labels (``sentiment`` values) in dataset.

    If every label is a string of ASCII digits, the labels are sorted
    numerically, so "2" comes before "10".  Otherwise they are sorted as
    strings.  The returned values are always the labels exactly as they
    appear in the data, so they can be matched back against each item.
    """
    labels = {item.sentiment for item in dataset}
    # fullmatch instead of match("^...$"): "$" also matches just before a
    # trailing newline, which let a label such as "3\n" count as numeric
    if all(re.fullmatch("[0-9]+", label) for label in labels):
        # sort by value, but keep the original strings; converting through
        # int() would turn "01" into "1", a label that does not exist in
        # the data.  Break ties ("07" vs "7") by the string so the order is deterministic.
        return sorted(labels, key=lambda label: (int(label), label))
    return sorted(labels)
def dataset_vocab(dataset):
    """
    Build the extra vocabulary from the words of every item's ``text``.

    Returns [PAD, UNK] followed by the distinct words in sorted order.
    Sorting makes the vocabulary, and therefore the word -> index
    assignment, identical from run to run.  Listing a set of strings
    directly would depend on the per-process string hash seed.

    Raises ValueError if PAD and UNK do not land at PAD_ID and UNK_ID.
    """
    words = {word for line in dataset for word in line.text}
    vocab = [PAD, UNK] + sorted(words)
    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
        raise ValueError("Unexpected values for PAD and UNK!")
    return vocab
def sort_dataset_by_len(dataset, keep_index=False):
    """
    Group dataset items by the length of their ``text``.

    Returns an OrderedDict mapping length -> list of items of that length,
    with keys in increasing order.  Within a bucket, items keep their
    original order.  With keep_index, each entry is an
    (item, original_index) pair instead of the bare item.
    """
    buckets = collections.OrderedDict(
        (length, []) for length in sorted({len(entry.text) for entry in dataset}))
    for position, entry in enumerate(dataset):
        buckets[len(entry.text)].append((entry, position) if keep_index else entry)
    return buckets
def shuffle_dataset(sorted_dataset, batch_size, batch_single_item):
    """
    Turn a length-bucketed dataset into a shuffled list of batches.

    The buckets (as produced by sort_dataset_by_len) are visited in key
    order.  Each bucket is shuffled on a copy, so the input lists are not
    reordered.  The items are then chunked into batches of up to
    batch_size, so each batch holds items of similar length.  If
    batch_single_item is positive, any item whose ``text`` has at least
    that many entries is put in a batch by itself.  Finally the order of
    the batches is shuffled.

    Returns a list of batches, each a list of items.
    """
    ordered = []
    for length in sorted_dataset:
        bucket = list(sorted_dataset[length])
        random.shuffle(bucket)
        ordered += bucket

    batches = []
    pending = []
    for item in ordered:
        if batch_single_item > 0 and len(item.text) >= batch_single_item:
            batches.append([item])
            continue
        pending.append(item)
        if len(pending) >= batch_size:
            batches.append(pending)
            pending = []
    if pending:
        batches.append(pending)

    random.shuffle(batches)
    return batches
def check_labels(labels, dataset):
    """
    Raise RuntimeError if dataset contains a label that is not in labels.

    Unknown labels could in principle be tolerated by simply scoring the
    model as always wrong on them, but this check is a useful sanity test
    that the model and the dataset belong together.
    """
    missing = []
    for label in dataset_labels(dataset):
        if label not in labels:
            missing.append(label)
    if missing:
        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(missing))
================================================
FILE: stanza/models/classifiers/iterate_test.py
================================================
"""
Iterate test.

A script for running the same test file on several different classifiers.
For each one, it will output the accuracy and, if possible, the confusion matrix.

Includes the arguments for pretrain, which allows for passing in a
different directory for the pretrain file.

Example command line:
  python3 -m stanza.models.classifiers.iterate_test --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt --glob "saved_models/classifier/FC41_3class_en_ewt_FS*ACC66*"
"""
import argparse
import glob
import logging

import stanza.models.classifier as classifier
import stanza.models.classifiers.cnn_classifier as cnn_classifier
# read_dataset and check_labels are defined in the classifier data module
import stanza.models.classifiers.data as data
from stanza.models.common import utils
from stanza.utils.confusion import format_confusion, confusion_to_accuracy

logger = logging.getLogger('stanza')

def parse_args():
    """Add and parse arguments."""
    parser = classifier.build_argparse()
    parser.add_argument('--glob', type=str, default='saved_models/classifier/*classifier*pt', help='Model file(s) to test.')
    args = parser.parse_args()
    return args

args = parse_args()
seed = utils.set_random_seed(args.seed)

# --glob may hold several whitespace-separated patterns
model_files = []
for glob_piece in args.glob.split():
    model_files.extend(glob.glob(glob_piece))
model_files = sorted(set(model_files))

test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
logger.info("Using test set: %s" % args.test_file)

device = None
for load_name in model_files:
    # presumably classifier.load_model reads args.load_name -- the model it
    # returns is the one tested (a second, separate load used to overwrite it)
    args.load_name = load_name
    model = classifier.load_model(args)
    logger.info("Testing %s" % load_name)

    if device is None:
        device = next(model.parameters()).device
        logger.info("Current device: %s" % device)

    labels = model.labels
    data.check_labels(labels, test_set)

    confusion = classifier.confusion_dataset(model, test_set, device=device)
    correct, total = confusion_to_accuracy(confusion)
    # an empty test set would otherwise raise ZeroDivisionError
    accuracy = correct / total if total > 0 else 0.0
    logger.info(" Results: %d correct of %d examples. Accuracy: %f" % (correct, total, accuracy))
    logger.info("Confusion matrix:\n{}".format(format_confusion(confusion, model.labels)))
================================================
FILE: stanza/models/classifiers/trainer.py
================================================
"""
Organizes the model itself and its optimizer in one place
Saving the optimizer allows for easy restarting of training
"""
import logging
import os
import torch
import torch.optim as optim
from types import SimpleNamespace
import stanza.models.classifiers.data as data
import stanza.models.classifiers.cnn_classifier as cnn_classifier
import stanza.models.classifiers.constituency_classifier as constituency_classifier
from stanza.models.classifiers.config import CNNConfig, ConstituencyConfig
from stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.common.pretrain import Pretrain
from stanza.models.common.utils import get_split_optimizer
from stanza.models.constituency.tree_embedding import TreeEmbedding
from pickle import UnpicklingError
import warnings
logger = logging.getLogger('stanza')
class Trainer:
"""
Stores a constituency model and its optimizer
"""
def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):
self.model = model
self.optimizer = optimizer
# we keep track of position in the learning so that we can
# checkpoint & restart if needed without restarting the epoch count
self.epochs_trained = epochs_trained
self.global_step = global_step
# save the best dev score so that when reloading a checkpoint
# of a model, we know how far we got
self.best_score = best_score
def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
"""
save the current model, optimizer, and other state to filename
epochs_trained can be passed as a parameter to handle saving at the end of an epoch
"""
if epochs_trained is None:
epochs_trained = self.epochs_trained
save_dir = os.path.split(filename)[0]
os.makedirs(save_dir, exist_ok=True)
model_params = self.model.get_params(skip_modules)
params = {
'params': model_params,
'epochs_trained': epochs_trained,
'global_step': self.global_step,
'best_score': self.best_score,
}
if save_optimizer and self.optimizer is not None:
params['optimizer_state_dict'] = {opt_name: opt.state_dict() for opt_name, opt in self.optimizer.items()}
torch.save(params, filename, _use_new_zipfile_serialization=False)
logger.info("Model saved to {}".format(filename))
@staticmethod
def load(filename, args, foundation_cache=None, load_optimizer=False):
if not os.path.exists(filename):
if args.save_dir is None:
raise FileNotFoundError("Cannot find model in {} and args.save_dir is None".format(filename))
elif os.path.exists(os.path.join(args.save_dir, filename)):
filename = os.path.join(args.save_dir, filename)
else:
raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
try:
# TODO: can remove the try/except once the new version is out
#checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
try:
checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
except UnpicklingError as e:
checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
warnings.warn("The saved classifier has an old format using SimpleNamespace and/or Enum instead of a dict to store config. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained classifier using this version ASAP.")
except BaseException:
logger.exception("Cannot load model from {}".format(filename))
raise
logger.debug("Loaded model {}".format(filename))
epochs_trained = checkpoint.get('epochs_trained', 0)
global_step = checkpoint.get('global_step', 0)
best_score = checkpoint.get('best_score', None)
# TODO: can remove this block once all models are retrained
if 'params' not in checkpoint:
model_params = {
'model': checkpoint['model'],
'config': checkpoint['config'],
'labels': checkpoint['labels'],
'extra_vocab': checkpoint['extra_vocab'],
}
else:
model_params = checkpoint['params']
# TODO: this can be removed once v1.10.0 is out
if isinstance(model_params['config'], SimpleNamespace):
model_params['config'] = vars(model_params['config'])
# TODO: these isinstance can go away after 1.10.0
model_type = model_params['config']['model_type']
if isinstance(model_type, str):
model_type = ModelType[model_type]
model_params['config']['model_type'] = model_type
if model_type == ModelType.CNN:
# TODO: these updates are only necessary during the
# transition to the @dataclass version of the config
# Once those are all saved, it is no longer necessary
# to patch existing models (since they will all be patched)
if 'has_charlm_forward' not in model_params['config']:
model_params['config']['has_charlm_forward'] = args.charlm_forward_file is not None
if 'has_charlm_backward' not in model_params['config']:
model_params['config']['has_charlm_backward'] = args.charlm_backward_file is not None
for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',
'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:
model_params['config'][argname] = model_params['config'].get(argname, None)
# TODO: these isinstance can go away after 1.10.0
if isinstance(model_params['config']['wordvec_type'], str):
model_params['config']['wordvec_type'] = WVType[model_params['config']['wordvec_type']]
if isinstance(model_params['config']['extra_wordvec_method'], str):
model_params['config']['extra_wordvec_method'] = ExtraVectors[model_params['config']['extra_wordvec_method']]
model_params['config'] = CNNConfig(**model_params['config'])
pretrain = Trainer.load_pretrain(args, foundation_cache)
elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
if model_params['config'].has_charlm_forward:
charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
else:
charmodel_forward = None
if model_params['config'].has_charlm_backward:
charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
else:
charmodel_backward = None
bert_model = model_params['config'].bert_model
# TODO: can get rid of the getattr after rebuilding all models
use_peft = getattr(model_params['config'], 'use_peft', False)
force_bert_saved = getattr(model_params['config'], 'force_bert_saved', False)
peft_name = None
if use_peft:
# if loading a peft model, we first load the base transformer
# the CNNClassifier code wraps the transformer in peft
# after creating the CNNClassifier with the peft wrapper,
# we *then* load the weights
bert_model, bert_tokenizer, peft_name = load_bert_with_peft(bert_model, "classifier", foundation_cache)
bert_model = load_peft_wrapper(bert_model, model_params['bert_lora'], vars(model_params['config']), logger, peft_name)
elif force_bert_saved:
bert_model, bert_tokenizer = load_bert(bert_model)
else:
bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)
model = cnn_classifier.CNNClassifier(pretrain=pretrain,
extra_vocab=model_params['extra_vocab'],
labels=model_params['labels'],
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward,
elmo_model=elmo_model,
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
force_bert_saved=force_bert_saved,
peft_name=peft_name,
args=model_params['config'])
elif model_type == ModelType.CONSTITUENCY:
# the constituency version doesn't have a peft feature yet
use_peft = False
pretrain_args = {
'wordvec_pretrain_file': args.wordvec_pretrain_file,
'charlm_forward_file': args.charlm_forward_file,
'charlm_backward_file': args.charlm_backward_file,
}
# TODO: integrate with peft for the constituency version
tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)
model_params['config'] = ConstituencyConfig(**model_params['config'])
model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
labels=model_params['labels'],
args=model_params['config'])
else:
raise ValueError("Unknown model type {}".format(model_type))
model.load_state_dict(model_params['model'], strict=False)
model = model.to(args.device)
logger.debug("-- MODEL CONFIG --")
for k in model.config.__dict__:
logger.debug(" --{}: {}".format(k, model.config.__dict__[k]))
logger.debug("-- MODEL LABELS --")
logger.debug(" {}".format(" ".join(model.labels)))
optimizer = None
if load_optimizer:
optimizer = Trainer.build_optimizer(model, args)
if checkpoint.get('optimizer_state_dict', None) is not None:
for opt_name, opt_state_dict in checkpoint['optimizer_state_dict'].items():
optimizer[opt_name].load_state_dict(opt_state_dict)
else:
logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score)
return trainer
def load_pretrain(args, foundation_cache):
if args.wordvec_pretrain_file:
pretrain_file = args.wordvec_pretrain_file
elif args.wordvec_type:
pretrain_file = '{}/{}.{}.pretrain.pt'.format(args.save_dir, args.shorthand, args.wordvec_type.name.lower())
else:
raise RuntimeError("TODO: need to get the wv type back from get_wordvec_file")
logger.debug("Looking for pretrained vectors in {}".format(pretrain_file))
if os.path.exists(pretrain_file):
return load_pretrain(pretrain_file, foundation_cache)
elif args.wordvec_raw_file:
vec_file = args.wordvec_raw_file
logger.debug("Pretrain not found. Looking in {}".format(vec_file))
else:
vec_file = utils.get_wordvec_file(args.wordvec_dir, args.shorthand, args.wordvec_type.name.lower())
logger.debug("Pretrain not found. Looking in {}".format(vec_file))
pretrain = Pretrain(pretrain_file, vec_file, args.pretrain_max_vocab)
logger.debug("Embedding shape: %s" % str(pretrain.emb.shape))
return pretrain
@staticmethod
def build_new_model(args, train_set):
"""
Load pretrained pieces and then build a new model
"""
if train_set is None:
raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")
labels = data.dataset_labels(train_set)
if args.model_type == ModelType.CNN:
pretrain = Trainer.load_pretrain(args, foundation_cache=None)
elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
charmodel_forward = load_charlm(args.charlm_forward_file)
charmodel_backward = load_charlm(args.charlm_backward_file)
peft_name = None
bert_model, bert_tokenizer = load_bert(args.bert_model)
use_peft = getattr(args, "use_peft", False)
if use_peft:
peft_name = "sentiment"
bert_model = build_peft_wrapper(bert_model, vars(args), logger, adapter_name=peft_name)
extra_vocab = data.dataset_vocab(train_set)
force_bert_saved = args.bert_finetune
model = cnn_classifier.CNNClassifier(pretrain=pretrain,
extra_vocab=extra_vocab,
labels=labels,
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward,
elmo_model=elmo_model,
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
force_bert_saved=force_bert_saved,
peft_name=peft_name,
args=args)
model = model.to(args.device)
elif args.model_type == ModelType.CONSTITUENCY:
# this passes flags such as "constituency_backprop" from
# the classifier to the TreeEmbedding as the "backprop" flag
parser_args = { x[len("constituency_"):]: y for x, y in vars(args).items() if x.startswith("constituency_") }
parser_args.update({
"wordvec_pretrain_file": args.wordvec_pretrain_file,
"charlm_forward_file": args.charlm_forward_file,
"charlm_backward_file": args.charlm_backward_file,
"bert_model": args.bert_model,
# we found that finetuning from the classifier output
# all the way to the bert layers caused the bert model
# to go astray
# could make this an option... but it is much less accurate
# with the Bert finetuning
# noting that the constituency parser itself works better
# after finetuning, of course
"bert_finetune": False,
"stage1_bert_finetune": False,
})
logger.info("Building constituency classifier using %s as the base model" % args.constituency_model)
tree_embedding = TreeEmbedding.from_parser_file(parser_args)
model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
labels=labels,
args=args)
model = model.to(args.device)
else:
raise ValueError("Unhandled model type {}".format(args.model_type))
optimizer = Trainer.build_optimizer(model, args)
return Trainer(model, optimizer)
@staticmethod
def build_optimizer(model, args):
    """
    Build the optimizer for a classifier model from the training args.

    The BERT-specific settings are passed to get_split_optimizer as
    bert_learning_rate and bert_weight_decay; the latter is
    args.weight_decay scaled by args.bert_weight_decay.

    use_peft is read with getattr (default False), the same way the
    model-building code reads it, so args objects that lack the flag
    do not raise AttributeError here.
    """
    use_peft = getattr(args, "use_peft", False)
    return get_split_optimizer(args.optim.lower(),
                               model,
                               args.learning_rate,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay,
                               bert_learning_rate=args.bert_learning_rate,
                               bert_weight_decay=args.weight_decay * args.bert_weight_decay,
                               is_peft=use_peft)
================================================
FILE: stanza/models/classifiers/utils.py
================================================
from enum import Enum
from torch import nn
"""
Defines some methods which may occur in multiple model types
"""
# NLP machines:
# word2vec are in
# /u/nlp/data/stanfordnlp/model_production/stanfordnlp/extern_data/word2vec
# google vectors are in
# /scr/nlp/data/wordvectors/en/google/GoogleNews-vectors-negative300.txt
class WVType(Enum):
    """
    Selector for the kind of word vectors in use.

    Members are fixed integers so saved values keep their meaning.
    """
    WORD2VEC = 1   # word2vec vectors
    GOOGLE = 2     # the Google News vectors
    FASTTEXT = 3   # fasttext vectors
    OTHER = 4      # anything else
class ExtraVectors(Enum):
    """
    Options for handling an extra set of word vectors.

    Presumably: not used, concatenated, or summed with the main
    embedding - TODO confirm against the classifier that reads it.
    """
    NONE = 1
    CONCAT = 2
    SUM = 3
class ModelType(Enum):
    """
    Which classifier architecture to build.

    The trainer builds a constituency_classifier.ConstituencyClassifier for
    CONSTITUENCY; CNN is presumably the cnn_classifier.CNNClassifier path.
    """
    CNN = 1
    CONSTITUENCY = 2
def build_output_layers(fc_input_size, fc_shapes, num_classes):
    """
    Chain fully connected layers from the final conv layer to num_classes.

    The layer widths go fc_input_size -> each entry of fc_shapes -> num_classes,
    so len(fc_shapes) + 1 nn.Linear layers are created, in order.

    Returns an nn.ModuleList
    """
    widths = [fc_input_size, *fc_shapes, num_classes]
    return nn.ModuleList(nn.Linear(in_width, out_width)
                         for in_width, out_width in zip(widths, widths[1:]))
================================================
FILE: stanza/models/common/__init__.py
================================================
================================================
FILE: stanza/models/common/beam.py
================================================
from __future__ import division
import torch
import stanza.models.common.seq2seq_constant as constant
r"""
Adapted and modified from the OpenNMT project.
Class for managing the internals of the beam search process.
hyp1-hyp1---hyp1 -hyp1
\ /
hyp2 \-hyp2 /-hyp2hyp2
/ \
hyp3-hyp3---hyp3 -hyp3
========================
Takes care of beams, back pointers, and scores.
"""
# TORCH COMPATIBILITY
#
# Here we special case trunc division
# torch < 1.8.0 has no rounding_mode='trunc' argument for torch.div
# however, there were several versions in a row where // would loudly
# proclaim it was buggy, and users complained about that
# this hopefully maintains compatibility for torch
def _torch_div_supports_rounding_mode():
    """
    Return True if this torch accepts torch.div(..., rounding_mode='trunc').

    The probe lives in a function so its test tensors do not leak into
    the module namespace.
    """
    try:
        torch.div(torch.tensor([1.]), torch.tensor([2.]), rounding_mode='trunc')
    except TypeError:
        return False
    return True

if _torch_div_supports_rounding_mode():
    def trunc_division(a, b):
        """Divide a by b, rounding toward zero."""
        return torch.div(a, b, rounding_mode='trunc')
else:
    def trunc_division(a, b):
        """Fallback for torch versions without rounding_mode: a // b."""
        return a // b
class Beam:
    """
    Beam search bookkeeping for decoding a single sequence, adapted from OpenNMT.

    Keeps the `size` best partial hypotheses. At every step it records the
    backpointers (prevKs) and the chosen outputs (nextYs), so that full
    hypotheses can be rebuilt with get_hyp.
    """
    def __init__(self, size, device=None):
        self.size = size
        self.done = False
        # The score for each translation on the beam.
        self.scores = torch.zeros(size, dtype=torch.float32, device=device)
        self.allScores = []
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step: PAD everywhere except SOS in slot 0,
        # since the first advance() only expands from slot 0
        self.nextYs = [torch.full((size,), constant.PAD_ID, dtype=torch.int64, device=device)]
        self.nextYs[0][0] = constant.SOS_ID
        # The copy indices for each time
        self.copy = []

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.nextYs[-1]

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk, copy_indices=None):
        """
        Advance the beam search by one step.

        Parameters:
        * `wordLk` - scores of each next word for every hypothesis on the
          beam (K x words); they are added to the running beam scores, so
          presumably log probabilities
        * `copy_indices` - copy indices (K x ctx_len), or None

        Returns: True if beam search is complete.
        """
        if self.done:
            return True
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
        else:
            # first step, expand from the first position
            beamLk = wordLk[0]

        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.allScores.append(self.scores)
        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        # bestScoreId is the integer ids, and numWords is the integer length.
        # Need to do integer division
        prevK = trunc_division(bestScoresId, numWords)
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)
        if copy_indices is not None:
            self.copy.append(copy_indices.index_select(0, prevK))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == constant.EOS_ID:
            self.done = True
            self.allScores.append(self.scores)

        return self.done

    def sort_best(self):
        "Sort the current beam scores in descending order; returns (scores, indices)."
        return torch.sort(self.scores, 0, True)

    def get_best(self):
        "Get the score and beam index of the best hypothesis in the beam."
        scores, ids = self.sort_best()
        # sort_best sorts in descending order, so the best entry is first
        return scores[0], ids[0]

    def get_hyp(self, k):
        """
        Walk back to construct the full hypothesis.

        Parameters:
        * `k` - the position in the beam to construct.

        Returns: The hypothesis.  Positions whose copy index is >= 0 hold
        -(copy index + 1) instead of a word id.
        """
        hyp = []
        cpy = []
        for j in range(len(self.prevKs) - 1, -1, -1):
            hyp.append(self.nextYs[j+1][k])
            if len(self.copy) > 0:
                cpy.append(self.copy[j][k])
            k = self.prevKs[j][k]

        hyp = hyp[::-1]
        cpy = cpy[::-1]
        # postprocess: if cpy index is not -1, use cpy index instead of hyp word
        for i, cidx in enumerate(cpy):
            if cidx >= 0:
                hyp[i] = -(cidx+1) # make index 1-based and flip it for token generation
        return hyp
================================================
FILE: stanza/models/common/bert_embedding.py
================================================
import math
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
logger = logging.getLogger('stanza')

# Extra keyword arguments for AutoTokenizer.from_pretrained, keyed by model
# name; load_tokenizer looks these up before adding its own arguments.
# NOTE(review): treat these dicts as read-only - copy an entry before
# modifying it, otherwise the change persists for every later lookup.
BERT_ARGS = {
    "vinai/phobert-base": { "use_fast": True },
    "vinai/phobert-large": { "use_fast": True },
}
class TextTooLongError(ValueError):
    """
    A text was too long for the underlying model (possibly BERT).

    Attributes:
      line_num: the index the caller reported for the offending text
        (the extract_* functions pass the sentence index)
      text: the offending text
    """
    def __init__(self, length, max_len, line_num, text):
        message = ("Found a text of length %d (possibly after tokenizing). Maximum handled length is %d Error occurred at line %d"
                   % (length, max_len, line_num))
        super().__init__(message)
        self.line_num = line_num
        self.text = text
def update_max_length(model_name, tokenizer):
    """
    Force tokenizer.model_max_length to 512 for a fixed list of models.

    Tokenizers for any other model name are left untouched.
    """
    models_limited_to_512 = ('hf-internal-testing/tiny-bert',
                             'google/muril-base-cased',
                             'google/muril-large-cased',
                             'airesearch/wangchanberta-base-att-spm-uncased',
                             'camembert/camembert-large',
                             'hfl/chinese-electra-180g-large-discriminator',
                             'hfl/chinese-macbert-large',
                             'NYTK/electra-small-discriminator-hungarian')
    if model_name not in models_limited_to_512:
        return
    tokenizer.model_max_length = 512
def load_tokenizer(model_name, tokenizer_kwargs=None, local_files_only=False):
    """
    Load the transformers AutoTokenizer for model_name.

    The arguments passed to from_pretrained start from the per-model
    BERT_ARGS entry (if any).  Non-phobert tokenizers get
    add_prefix_space=True, then tokenizer_kwargs (if given) are applied on
    top, then local_files_only.

    Returns None if model_name is falsy.

    Raises ImportError if transformers is not installed.
    """
    if not model_name:
        return None

    # note that use_fast is the default
    try:
        from transformers import AutoTokenizer
    except ImportError as e:
        raise ImportError("Please install transformers library for BERT support! Try `pip install transformers`.") from e

    # copy the table entry: the keys added below must not leak back into
    # the shared BERT_ARGS dict and affect later calls
    bert_args = dict(BERT_ARGS.get(model_name, {}))
    if not model_name.startswith("vinai/phobert"):
        bert_args["add_prefix_space"] = True
    if tokenizer_kwargs:
        bert_args.update(tokenizer_kwargs)
    bert_args['local_files_only'] = local_files_only
    bert_tokenizer = AutoTokenizer.from_pretrained(model_name, **bert_args)
    update_max_length(model_name, bert_tokenizer)
    if model_name == 'princeton-nlp/Sheared-LLaMA-1.3B':
        bert_tokenizer.pad_token = bert_tokenizer.eos_token
        logger.debug("Tokenizer does not have a pad_token - setting to %s (%s)", bert_tokenizer.pad_token, bert_tokenizer.eos_token)
    return bert_tokenizer
def load_bert(model_name, tokenizer_kwargs=None, local_files_only=False):
    """
    Load a transformers AutoModel and its tokenizer, such as "vinai/phobert-base".

    Returns (model, tokenizer), or (None, None) if model_name is falsy.
    """
    if not model_name:
        return None, None

    try:
        from transformers import AutoModel
    except ImportError:
        raise ImportError("Please install transformers library for BERT support! Try `pip install transformers`.")

    model = AutoModel.from_pretrained(model_name, local_files_only=local_files_only)
    tokenizer = load_tokenizer(model_name, tokenizer_kwargs=tokenizer_kwargs, local_files_only=local_files_only)
    return model, tokenizer
def tokenize_manual(model_name, sent, tokenizer):
    """
    Tokenize a list of words by hand; used for checking long sentences and for PhoBERT.

    Returns (word pieces, token ids with bos and eos added).
    """
    if model_name.startswith("vinai/phobert"):
        # PhoBERT expects _ between the syllables of a word, so the
        # (possibly non-breaking) spaces inside each word become _
        words = [w.replace("\xa0", "_").replace(" ", "_") for w in sent]
    else:
        words = [w.replace("\xa0", " ") for w in sent]

    pieces = tokenizer.tokenize(' '.join(words))
    piece_ids = tokenizer.convert_tokens_to_ids(pieces)
    with_endpoints = [tokenizer.bos_token_id] + piece_ids + [tokenizer.eos_token_id]
    return pieces, with_endpoints
def filter_data(model_name, data, tokenizer=None, log_level=logging.DEBUG):
    """
    Filter out the (NER, POS) data that is too long for the BERT model.

    data: list of sentences; each word is either a string or an indexable
      whose [0] is the word text (presumably a (word, tag) pair)
    tokenizer: loaded with load_tokenizer(model_name) if not given
    log_level: level used to log how many sentences were dropped

    A sentence is kept when its token ids, including the start and end
    tokens added by tokenize_manual, number at most model_max_length - 2.

    Returns the kept sentences in their original order.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer(model_name)

    def fits(sent):
        words = [word if isinstance(word, str) else word[0] for word in sent]
        _, token_ids = tokenize_manual(model_name, words, tokenizer)
        return len(token_ids) <= tokenizer.model_max_length - 2

    kept = [sent for sent in data if fits(sent)]
    logger.log(log_level, "Eliminated %d of %d datapoints because their length is over maximum size of BERT model.", (len(data)-len(kept)), len(data))
    return kept
def needs_length_filter(model_name):
    """
    Return True if data for model_name must be length-filtered beforehand.

    This is True for bart, xlnet and phobert models.

    TODO: we were lazy and didn't implement any form of length fudging for
    models other than bert/roberta/electra
    """
    return ('bart' in model_name
            or 'xlnet' in model_name
            or model_name.startswith("vinai/phobert"))
def cloned_feature(feature, num_layers, detach=True):
"""
Clone & detach the feature, keeping the last N layers (or averaging -2,-3,-4 if not specified)
averaging 3 of the last 4 layers worked well for non-VI languages
"""
# in most cases, need to call with features.hidden_states
# bartpho is different - it has features.decoder_hidden_states
# feature[2] is the same for bert, but it didn't work for
# older versions of transformers for xlnet
if num_layers is None:
feature = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4
else:
feature = torch.stack(feature[-num_layers:], axis=3)
if detach:
return feature.clone().detach()
else:
return feature
def extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Handles vi-bart. May need testing before using on other bart
    https://github.com/VinAIResearch/BARTpho

    data: list of list of string (the text tokens); spaces inside a word
      are replaced with _ before tokenizing
    keep_endpoints: if True, the start and end positions are kept
    num_layers / detach: passed to cloned_feature, which is applied to the
      decoder hidden states

    Returns a list of tensors, one per sentence.  Each holds the first
    len(sentence) + 2 positions (len(sentence) without the endpoints).
    This assumes one token per word, as noted above.

    Sentences are run through the model 128 at a time.
    """
    processed = [] # final product, returns the list of list of word representation
    sentences = [" ".join([word.replace(" ", "_") for word in sentence]) for sentence in data]
    tokenized = tokenizer(sentences, return_tensors='pt', padding=True, return_attention_mask=True)
    all_input_ids = tokenized['input_ids'].to(device)
    all_attention_mask = tokenized['attention_mask'].to(device)

    for start_sentence in range(0, len(sentences), 128):
        end_sentence = min(start_sentence + 128, len(sentences))
        # always slice from the full tensors: reassigning the slice to the
        # same name would leave every batch after the first empty
        input_ids = all_input_ids[start_sentence:end_sentence]
        attention_mask = all_attention_mask[start_sentence:end_sentence]
        if detach:
            with torch.no_grad():
                features = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
                features = cloned_feature(features.decoder_hidden_states, num_layers, detach)
        else:
            features = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
            features = cloned_feature(features.decoder_hidden_states, num_layers, detach)

        # pair each feature with the sentence from this same batch
        for feature, sentence in zip(features, data[start_sentence:end_sentence]):
            # +2 for the endpoints
            feature = feature[:len(sentence)+2]
            if not keep_endpoints:
                feature = feature[1:-1]
            processed.append(feature)

    return processed
def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Extract transformer embeddings using a method specifically for phobert

    Since phobert doesn't have the is_split_into_words / tokenized.word_ids(batch_index=0)
    capability, we instead look for @@ to denote a continued token.

    data: list of list of string (the text tokens)
    keep_endpoints: if True, positions 0 and -1 are added for each sentence
    num_layers / detach: passed to cloned_feature

    Returns a list of tensors, one per sentence, with one vector per word
    (plus the two endpoint vectors if keep_endpoints)

    Raises TextTooLongError if a sentence, with its start and end tokens,
    is longer than tokenizer.model_max_length
    """
    processed = [] # final product, returns the list of list of word representation
    tokenized_sents = [] # list of sentences, each is a torch tensor with start and end token
    list_tokenized = [] # list of tokenized sentences from phobert
    for idx, sent in enumerate(data):
        tokenized, tokenized_sent = tokenize_manual(model_name, sent, tokenizer)
        #add tokenized to list_tokenzied for later checking
        list_tokenized.append(tokenized)
        if len(tokenized_sent) > tokenizer.model_max_length:
            logger.error("Invalid size, max size: %d, got %d %s", tokenizer.model_max_length, len(tokenized_sent), data[idx])
            raise TextTooLongError(len(tokenized_sent), tokenizer.model_max_length, idx, " ".join(data[idx]))
        #add to tokenized_sents
        tokenized_sents.append(torch.tensor(tokenized_sent).detach())
        # placeholder: these empty lists are only used below to count the
        # sentences; processed is rebuilt from the features at the end
        processed_sent = []
        processed.append(processed_sent)
    # done loading bert emb
    size = len(tokenized_sents)
    #padding the inputs
    tokenized_sents_padded = torch.nn.utils.rnn.pad_sequence(tokenized_sents,batch_first=True,padding_value=tokenizer.pad_token_id)
    features = []
    # Feed into PhoBERT 128 at a time in a batch fashion. In testing, the loop was
    # run only 1 time as the batch size for the outer model was less than that
    # (30 for conparser, for example)
    for i in range(int(math.ceil(size/128))):
        padded_input = tokenized_sents_padded[128*i:128*i+128]
        start_sentence = i * 128
        end_sentence = start_sentence + padded_input.shape[0]
        # 1 over each sentence's real tokens, 0 over the padding
        attention_mask = torch.zeros(end_sentence - start_sentence, padded_input.shape[1], device=device)
        for sent_idx, sent in enumerate(tokenized_sents[start_sentence:end_sentence]):
            attention_mask[sent_idx, :len(sent)] = 1
        if detach:
            with torch.no_grad():
                # TODO: is the clone().detach() necessary?
                feature = model(padded_input.clone().detach().to(device), attention_mask=attention_mask, output_hidden_states=True)
                features += cloned_feature(feature.hidden_states, num_layers, detach)
        else:
            feature = model(padded_input.to(device), attention_mask=attention_mask, output_hidden_states=True)
            features += cloned_feature(feature.hidden_states, num_layers, detach)
    assert len(features)==size
    assert len(features)==len(processed)
    #process the output
    #only take the vector of the FIRST word piece of a word: piece idx2 is kept when it is
    #the first piece of the sentence or the previous piece did not end with "@@"
    #(you can do other methods such as last word piece or averaging)
    # idx2+1 compensates for the start token at the start of a sentence
    offsets = [[idx2+1 for idx2, _ in enumerate(list_tokenized[idx]) if (idx2 > 0 and not list_tokenized[idx][idx2-1].endswith("@@")) or (idx2==0)]
               for idx, sent in enumerate(processed)]
    if keep_endpoints:
        # [0] and [-1] grab the start and end representations as well
        # NOTE(review): each feature comes from a padded batch, so for a sentence
        # shorter than the longest in its batch, index -1 appears to select a
        # padding position rather than that sentence's end token - confirm
        offsets = [[0] + off + [-1] for off in offsets]
    processed = [feature[offset] for feature, offset in zip(features, offsets)]
    # This is a list of tensors
    # Each tensor holds the representation of a sentence extracted from phobert
    return processed
# Transformer models whose tokenizers are known to turn some characters,
# or whole words, into nothing (see the notes on the entries below).
# This name is not referenced in this part of the file; presumably it is
# consulted elsewhere before using one of these models - TODO confirm.
BAD_TOKENIZERS = ('bert-base-german-cased',
                  # the dbmdz tokenizers turn one or more types of characters into empty words
                  # for example, from PoSTWITA:
                  #   ewww — in viaggio Roma
                  # the character which may not be rendering properly is 0xFE4FA
                  # https://github.com/dbmdz/berts/issues/48
                  'dbmdz/bert-base-german-cased',
                  'dbmdz/bert-base-italian-xxl-cased',
                  'dbmdz/bert-base-italian-cased',
                  'dbmdz/electra-base-italian-xxl-cased-discriminator',
                  # each of these (perhaps using similar tokenizers?)
                  # does not digest the script-flip-mark \u200f
                  'avichr/heBERT',
                  'onlplab/alephbert-base',
                  'imvladikon/alephbertgimmel-base-512',
                  # these indonesian models fail on a sentence in the Indonesian GSD dataset:
                  # 'Tak', 'dapat', 'disangkal', 'jika', '\u200e', 'kemenangan', ...
                  # weirdly some other indonesian models (even by the same group) don't have that problem
                  'cahya/bert-base-indonesian-1.5G',
                  'indolem/indobert-base-uncased',
                  'google/muril-base-cased',
                  'l3cube-pune/marathi-roberta')
def fix_blank_tokens(tokenizer, data):
    """Patch bert tokenizers with missing characters

    Some tokenizers (first noticed with the German ones in BAD_TOKENIZERS)
    tokenize soft hyphens or other unknown characters into nothing.  If an
    entire word consists of such characters, the tokenizer simply vaporizes
    that word, and we end up missing an embedding for a word we wanted to use.

    Each word is tokenized on its own.  Any word whose ids have length <= 2
    (presumably only the start and end special tokens, e.g. `len(token) == 2`)
    is replaced with a regular "-".

    Even the Bert / Electra tokenizers can do this for "words" which are a
    single special character, so the easiest thing to do is to always run
    this function.

    Returns a new list of sentences; data itself is not modified.
    """
    def patch_sentence(sentence):
        # is_split_into_words=False with a list of words -> one id list per word
        ids_per_word = tokenizer(sentence, is_split_into_words=False).input_ids
        return [word if len(word_ids) > 2 else "-"
                for word, word_ids in zip(sentence, ids_per_word)]

    return [patch_sentence(sentence) for sentence in data]
def extract_llama_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Extract per-word embeddings from a LLaMA-style model.

    data: list of list of string (the text tokens)
    keep_endpoints: if True, the start and end positions are included for each sentence
    num_layers / detach: passed on to build_cloned_features

    Each word is represented by its last word piece (see convert_to_position_list).
    An eos token is appended to every row.  The end position computed by
    convert_to_position_list (last word piece + 1) presumably lands on it.

    Returns a list of tensors, one per sentence

    Raises ValueError if some word in data produced no word pieces
    """
    # will calculate attention masks ourselves later
    tokenized = tokenizer(data, is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=False)
    list_offsets = [convert_to_position_list(sentence, tokenized.word_ids(batch_index=idx))
                    for idx, sentence in enumerate(data)]
    # report the first sentence that actually has a missing position
    bad_idx = next((idx for idx, offsets in enumerate(list_offsets) if any(x is None for x in offsets)), None)
    if bad_idx is not None:
        raise ValueError("OOPS, hit None when preparing to use transformer at idx {}\ndata[idx]: {}\nlist_offsets[idx]: {}\ntokenizer output: {}".format(bad_idx, data[bad_idx], list_offsets[bad_idx], tokenized))

    features = []
    for i in range(int(math.ceil(len(data)/128))):
        id_rows = [id_row + [tokenizer.eos_token_id] for id_row in tokenized['input_ids'][128*i:128*i+128]]
        max_id_len = max(len(x) for x in id_rows)
        attention_tensor = torch.zeros((len(id_rows), max_id_len), dtype=torch.long, device=device)
        for idx, id_row in enumerate(id_rows):
            # mask covers the real ids (computed before any padding is added)
            attention_tensor[idx, :len(id_row)] = 1
            if len(id_row) < max_id_len:
                # actually this value doesn't matter... autoregressive
                id_row.extend([0] * (max_id_len - len(id_row)))
        id_tensor = torch.tensor(id_rows, device=device)
        if detach:
            with torch.no_grad():
                features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)
        else:
            features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)

    processed = []
    #process the output
    if not keep_endpoints:
        #remove the bos and eos tokens
        list_offsets = [sent[1:-1] for sent in list_offsets]
    for feature, offsets in zip(features, list_offsets):
        new_sent = feature[offsets]
        processed.append(new_sent)

    return processed
def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Extract per-word embeddings from an xlnet model.

    Each sentence is fed to the model as bos + its ids without the last two
    + eos.  The last two ids are presumably xlnet's trailing sep / cls tokens.
    The last word piece of each word is used as that word's vector.

    data: list of list of string (the text tokens)
    keep_endpoints: if True, the bos and eos positions are included for each sentence
    num_layers / detach: passed to cloned_feature

    Returns a list of tensors, one per sentence

    Raises ValueError if some word could not be mapped to a word piece, and
    TextTooLongError if a sentence is too long for tokenizer.model_max_length
    """
    # using attention masks makes contextual embeddings much more useful for downstream tasks
    tokenized = tokenizer(data, is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=False)
    #tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)

    # one slot per word, plus the start endpoint [0] and the end endpoint [-1]
    list_offsets = [[None] * (len(sentence)+2) for sentence in data]
    for idx in range(len(data)):
        offsets = tokenized.word_ids(batch_index=idx)
        list_offsets[idx][0] = 0
        for pos, offset in enumerate(offsets):
            # stop at the first special token; for xlnet these presumably
            # all come after the words
            if offset is None:
                break
            # this uses the last token piece for any offset by overwriting the previous value
            # positions from word_ids are one token earlier than in the model input,
            # because a bos token is added to the start of each sentence below
            # (for the endpoints), hence pos + 1
            list_offsets[idx][offset+1] = pos + 1
        # NOTE(review): if the last word got no pieces, [-2] is None and this
        # raises TypeError before the clearer ValueError below
        list_offsets[idx][-1] = list_offsets[idx][-2] + 1
        if any(x is None for x in list_offsets[idx]):
            # NOTE(review): the message has three {} placeholders but four format
            # arguments, so the tokenizer output is not included in it
            raise ValueError("OOPS, hit None when preparing to use Bert\ndata[idx]: {}\noffsets: {}\nlist_offsets[idx]: {}".format(data[idx], offsets, list_offsets[idx], tokenized))

        if len(offsets) > tokenizer.model_max_length - 2:
            logger.error("Invalid size, max size: %d, got %d %s", tokenizer.model_max_length, len(offsets), data[idx])
            raise TextTooLongError(len(offsets), tokenizer.model_max_length, idx, " ".join(data[idx]))

    features = []
    for i in range(int(math.ceil(len(data)/128))):
        # TODO: find a suitable representation for attention masks for xlnet
        # xlnet base on WSJ:
        # sep_token_id at beginning, cls_token_id at end:     0.9441
        # bos_token_id at beginning, eos_token_id at end:     0.9463
        # bos_token_id at beginning, sep_token_id at end:     0.9459
        # bos_token_id at beginning, cls_token_id at end:     0.9457
        # bos_token_id at beginning, sep/cls at end:          0.9454
        # use the xlnet tokenization with words at end,
        # begin token is last pad, end token is sep, no mask: 0.9463
        # same, but with masks:                               0.9440
        input_ids = [[tokenizer.bos_token_id] + x[:-2] + [tokenizer.eos_token_id] for x in tokenized['input_ids'][128*i:128*i+128]]
        max_len = max(len(x) for x in input_ids)
        attention_mask = torch.zeros(len(input_ids), max_len, dtype=torch.long, device=device)
        for idx, input_row in enumerate(input_ids):
            # mask covers the real ids (computed before the row is padded)
            attention_mask[idx, :len(input_row)] = 1
            if len(input_row) < max_len:
                input_row.extend([tokenizer.pad_token_id] * (max_len - len(input_row)))

        if detach:
            with torch.no_grad():
                id_tensor = torch.tensor(input_ids, device=device)
                feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
                # feature[2] is the same for bert, but it didn't work for
                # older versions of transformers for xlnet
                # feature = feature[2]
                features += cloned_feature(feature.hidden_states, num_layers, detach)
        else:
            id_tensor = torch.tensor(input_ids, device=device)
            feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
            # feature[2] is the same for bert, but it didn't work for
            # older versions of transformers for xlnet
            # feature = feature[2]
            features += cloned_feature(feature.hidden_states, num_layers, detach)

    processed = []
    #process the output
    if not keep_endpoints:
        #remove the bos and eos tokens
        list_offsets = [sent[1:-1] for sent in list_offsets]
    for feature, offsets in zip(features, list_offsets):
        new_sent = feature[offsets]
        processed.append(new_sent)

    return processed
def build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device):
    """
    Extract an embedding from the given transformer for a certain attention mask and tokens range

    In the event that the tokens are longer than the max length
    supported by the model, the range is split up into overlapping
    sections and the overlapping pieces are connected.  No idea if
    this is actually any good, but at least it returns something
    instead of horribly failing

    attention_tensor, id_tensor: (batch, tokens), as built by the callers in this file
    num_layers, detach: passed to cloned_feature
    device: not used in this function

    Returns the cloned_feature output covering every token position, with the
    windows concatenated along dim 1

    Raises RuntimeError if model_max_length is so small that a window would
    advance by fewer than 5 tokens

    TODO: at least two upgrades are very relevant
      1) cut off some overlap at the end as well
      2) use this on the phobert, bart, and xln versions as well
    """
    if attention_tensor.shape[1] <= tokenizer.model_max_length:
        # short enough for a single forward pass
        features = model(id_tensor, attention_mask=attention_tensor, output_hidden_states=True)
        features = cloned_feature(features.hidden_states, num_layers, detach)
        return features

    slices = []
    # every window is model_max_length tokens long and consecutive windows
    # start slice_len apart, so they overlap by prefix_len tokens
    slice_len = max(tokenizer.model_max_length - 20, tokenizer.model_max_length // 2)
    prefix_len = tokenizer.model_max_length - slice_len
    if slice_len < 5:
        raise RuntimeError("Really tiny tokenizer!")
    remaining_attention = attention_tensor
    remaining_ids = id_tensor
    while True:
        attention_slice = remaining_attention[:, :tokenizer.model_max_length]
        id_slice = remaining_ids[:, :tokenizer.model_max_length]
        features = model(id_slice, attention_mask=attention_slice, output_hidden_states=True)
        features = cloned_feature(features.hidden_states, num_layers, detach)
        if len(slices) > 0:
            # drop the positions the previous window already produced, so the
            # concatenated result has exactly one vector per input position
            features = features[:, prefix_len:, :]
        slices.append(features)
        if remaining_attention.shape[1] <= tokenizer.model_max_length:
            break
        remaining_attention = remaining_attention[:, slice_len:]
        remaining_ids = remaining_ids[:, slice_len:]

    slices = torch.cat(slices, axis=1)
    return slices
def convert_to_position_list(sentence, offsets):
    """
    Convert a transformers-tokenized sentence's offsets to a list of word to position.

    offsets: the tokenizer's word_ids() for the sentence (None for special tokens)

    Returns a list of len(sentence) + 2 entries:
      [0] is 0 (the start position),
      [w + 1] is the position of the LAST word piece of word w, or None if
        the tokenizer produced no pieces for that word,
      [-1] is one past the last position found among the earlier entries.
    """
    # +2 for the beginning and end
    positions = [None] * (len(sentence) + 2)
    for token_pos, word_idx in enumerate(offsets):
        if word_idx is not None:
            # later pieces overwrite earlier ones, so the last piece wins
            positions[word_idx + 1] = token_pos
    positions[0] = 0
    # search backwards in case the final word(s) were erased by the tokenizer;
    # this always finds something, since positions[0] was just set to 0
    last_known = next(pos for pos in reversed(positions[:-1]) if pos is not None)
    positions[-1] = last_known + 1
    return positions
def extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach):
    """
    Extract per-word embeddings with the generic transformers path (bert, roberta, electra, ...).

    data: list of list of string (the text tokens)
    keep_endpoints: if True, the start and end positions are included for each sentence
    num_layers / detach: passed on to build_cloned_features

    Each word is represented by its last word piece (see convert_to_position_list).
    Texts longer than tokenizer.model_max_length are not rejected here:
    build_cloned_features runs them through the model in overlapping windows.

    Returns a list of tensors, one per sentence

    Raises ValueError if, even after fix_blank_tokens, some word produced no word pieces
    """
    def positions_for(tokenized, sentences):
        # word -> token position list for every sentence in the batch
        return [convert_to_position_list(sentence, tokenized.word_ids(batch_index=idx))
                for idx, sentence in enumerate(sentences)]

    def first_missing(list_offsets):
        # index of the first sentence with a word that has no token position, or None
        return next((idx for idx, offsets in enumerate(list_offsets) if any(x is None for x in offsets)), None)

    # add_prefix_space=True is needed for RoBERTa; load_tokenizer sets it for non-phobert models
    # using attention masks makes contextual embeddings much more useful for downstream tasks
    tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)
    list_offsets = positions_for(tokenized, data)

    if first_missing(list_offsets) is not None:
        # at least one of the tokens in the data is composed entirely of characters the tokenizer doesn't know about
        # one possible approach would be to retokenize only those sentences
        # however, in that case the attention mask might be of a different length,
        # as would the token ids, and it would be a pain to fix those
        # easiest to just retokenize the whole thing, hopefully a rare event
        data = fix_blank_tokens(tokenizer, data)
        tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)
        list_offsets = positions_for(tokenized, data)
        bad_idx = first_missing(list_offsets)
        if bad_idx is not None:
            raise ValueError("OOPS, hit None when preparing to use transformer at idx {}\ndata[idx]: {}\nlist_offsets[idx]: {}\ntokenizer output: {}".format(bad_idx, data[bad_idx], list_offsets[bad_idx], tokenized))

    features = []
    for i in range(int(math.ceil(len(data)/128))):
        attention_tensor = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
        id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
        if detach:
            with torch.no_grad():
                features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)
        else:
            features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)

    processed = []
    #process the output
    if not keep_endpoints:
        #remove the bos and eos tokens
        list_offsets = [sent[1:-1] for sent in list_offsets]
    for feature, offsets in zip(features, list_offsets):
        new_sent = feature[offsets]
        processed.append(new_sent)

    return processed
def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers=None, detach=True, peft_name=None):
    """
    Extract transformer embeddings using a generic roberta extraction

    Dispatches on model_name:
      vinai/phobert* -> extract_phobert_embeddings
      *bart*         -> extract_bart_word_embeddings
      *xlnet*        -> extract_xlnet_embeddings
      *LLaMA*        -> extract_llama_embeddings
      anything else  -> extract_base_embeddings

    data: list of list of string (the text tokens)
    num_layers: how many to return. If None, cloned_feature sums layers -2, -3, -4
      (and divides by 4)
    peft_name: if given, adapters are enabled and this adapter is made active;
      if None, any loaded adapters are disabled
    """
    # TODO: can maybe cache this value for a model and save some time
    # TODO: too bad it isn't thread safe, but then again, who does?
    if peft_name is None:
        # getattr: a model without this transformers PEFT attribute is treated
        # as having no adapters loaded, instead of raising AttributeError
        if getattr(model, "_hf_peft_config_loaded", False):
            model.disable_adapters()
    else:
        model.enable_adapters()
        model.set_adapter(peft_name)

    if model_name.startswith("vinai/phobert"):
        return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    if 'bart' in model_name:
        # this should work with "vinai/bartpho-word"
        # not sure this works with any other Bart
        return extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    if isinstance(data, tuple):
        data = list(data)

    if "xlnet" in model_name:
        return extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    if "LLaMA" in model_name:
        return extract_llama_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    return extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)
================================================
FILE: stanza/models/common/biaffine.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class PairwiseBilinear(nn.Module):
    ''' A bilinear module that deals with broadcasting for efficient memory usage.
    Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
    Output: tensor of size (N x L1 x L2 x O)

    The computation reads the weight storage as a (D1 x O x D2) tensor W':
      output[n, i, j, o] = sum_{a, b} input1[n, i, a] * W'[a, o, b] * input2[n, j, b]
    where W' = self.weight.view(D1, O, D2).  The weight is learned, so this
    is simply the layout the parameter is trained in.

    NOTE(review): self.bias is created but never added in forward.  Removing
    it would change the state_dict keys, so it is left in place.
    '''
    def __init__(self, input1_size, input2_size, output_size, bias=True):
        super().__init__()

        self.input1_size = input1_size
        self.input2_size = input2_size
        self.output_size = output_size

        # start from zeros instead of the uninitialized memory torch.Tensor(...) returns
        self.weight = nn.Parameter(torch.zeros(input1_size, input2_size, output_size))
        self.bias = nn.Parameter(torch.zeros(output_size)) if bias else 0

    def forward(self, input1, input2):
        input1_size = list(input1.size())
        input2_size = list(input2.size())

        # ((N x L1) x D1) * (D1 x (D2 x O)) -> (N x L1) x (D2 x O)
        intermediate = torch.mm(input1.view(-1, input1_size[-1]), self.weight.view(-1, self.input2_size * self.output_size))
        # (N x L2 x D2) -> (N x D2 x L2)
        input2 = input2.transpose(1, 2)
        # (N x (L1 x O) x D2) * (N x D2 x L2) -> (N x (L1 x O) x L2)
        output = intermediate.view(input1_size[0], input1_size[1] * self.output_size, input2_size[2]).bmm(input2)
        # (N x (L1 x O) x L2) -> (N x L1 x L2 x O)
        output = output.view(input1_size[0], input1_size[1], self.output_size, input2_size[1]).transpose(2, 3)

        return output
class BiaffineScorer(nn.Module):
    """
    Biaffine score between two aligned inputs, built on a zero-initialized nn.Bilinear.

    A constant 1 feature is appended to the last dimension of both inputs,
    so the bilinear form also includes linear terms in each input and a
    constant term.
    """
    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        # +1 on each side for the constant feature appended in forward
        bilinear = nn.Bilinear(input1_size + 1, input2_size + 1, output_size)
        bilinear.weight.data.zero_()
        bilinear.bias.data.zero_()
        self.W_bilin = bilinear

    def forward(self, input1, input2):
        def append_ones(x):
            return torch.cat([x, x.new_ones(*x.size()[:-1], 1)], dim=-1)
        return self.W_bilin(append_ones(input1), append_ones(input2))
class PairwiseBiaffineScorer(nn.Module):
    """
    Biaffine score for every (input1 position, input2 position) pair.

    A constant 1 feature is appended to the last dimension of both inputs
    before a zero-initialized PairwiseBilinear, so the scores also include
    linear terms in each input and a constant term.
    """
    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        # +1 on each side for the constant feature appended in forward
        bilinear = PairwiseBilinear(input1_size + 1, input2_size + 1, output_size)
        bilinear.weight.data.zero_()
        bilinear.bias.data.zero_()
        self.W_bilin = bilinear

    def forward(self, input1, input2):
        def append_ones(x):
            return torch.cat([x, x.new_ones(*x.size()[:-1], 1)], dim=-1)
        return self.W_bilin(append_ones(input1), append_ones(input2))
class DeepBiaffineScorer(nn.Module):
    """
    Project both inputs to hidden_size, then score them with a biaffine scorer.

    Each input goes through its own nn.Linear, then hidden_func, then dropout.
    With pairwise=True (the default) a PairwiseBiaffineScorer scores every
    pair of positions; otherwise a BiaffineScorer scores aligned inputs.
    """
    def __init__(self, input1_size, input2_size, hidden_size, output_size, hidden_func=F.relu, dropout=0, pairwise=True):
        super().__init__()
        self.W1 = nn.Linear(input1_size, hidden_size)
        self.W2 = nn.Linear(input2_size, hidden_size)
        self.hidden_func = hidden_func
        scorer_cls = PairwiseBiaffineScorer if pairwise else BiaffineScorer
        self.scorer = scorer_cls(hidden_size, hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input1, input2):
        hidden1 = self.dropout(self.hidden_func(self.W1(input1)))
        hidden2 = self.dropout(self.hidden_func(self.W2(input2)))
        return self.scorer(hidden1, hidden2)
if __name__ == "__main__":
    # Small smoke test.
    # The default (pairwise) scorer compares every position of input1 with every
    # position of input2, so it needs (batch, length, dim) inputs.  The old 2D
    # example crashed in PairwiseBilinear's input2.transpose(1, 2).
    x1 = torch.randn(2, 3, 4)
    x2 = torch.randn(2, 3, 5)
    scorer = DeepBiaffineScorer(4, 5, 6, 7)
    print(scorer(x1, x2))  # shape (2, 3, 3, 7)

    # With pairwise=False, aligned rows of 2D inputs are scored against each other.
    y1 = torch.randn(3, 4)
    y2 = torch.randn(3, 5)
    scorer = DeepBiaffineScorer(4, 5, 6, 7, pairwise=False)
    print(scorer(y1, y2))  # shape (3, 7)
================================================
FILE: stanza/models/common/build_short_name_to_treebank.py
================================================
import glob
import os
from stanza.models.common.constant import treebank_to_short_name, UnknownLanguageError, treebank_special_cases
from stanza.utils import default_paths
# Regenerate short_name_to_treebank.py (written next to this file) from the
# UD_* treebank directories found under the UDBASE path.
paths = default_paths.get_default_paths()
udbase = paths["UDBASE"]

directories = sorted(glob.glob(os.path.join(udbase, "UD_*")))
if not directories:
    # fail clearly here rather than with "max() arg is an empty sequence" below
    raise FileNotFoundError("No UD_* treebank directories found in %s" % udbase)

output_name = os.path.join(os.path.split(__file__)[0], "short_name_to_treebank.py")
ud_names = [os.path.split(ud_path)[1] for ud_path in directories]

# check that all languages are known in the language map,
# and use that language map to come up with a short name for each treebank
short_names = []
for directory, ud_name in zip(directories, ud_names):
    try:
        short_names.append(treebank_to_short_name(ud_name))
    except UnknownLanguageError as e:
        raise UnknownLanguageError("Could not find language short name for dataset %s, path %s" % (ud_name, directory)) from e

# Norwegian (Bokmaal vs Nynorsk) and Chinese (simplified vs traditional)
# treebanks must each be listed explicitly in treebank_special_cases
for ud_name in ud_names:
    if ud_name.startswith("UD_Norwegian") and ud_name not in treebank_special_cases:
        raise ValueError("Please figure out if dataset %s is NN or NB, then add to treebank_special_cases" % ud_name)
    if ud_name.startswith("UD_Chinese") and ud_name not in treebank_special_cases:
        # this message used to be a copy of the Norwegian one and asked about NN / NB
        raise ValueError("Please figure out if dataset %s is simplified (zh-hans) or traditional (zh-hant) Chinese, then add to treebank_special_cases" % ud_name)

# Width of the key column: the +8 presumably covers the quotes and colon around
# each key plus the longer zh-hans_ alias written for zh_ treebanks.
max_len = max(len(x) for x in short_names) + 8
line_format = " %-" + str(max_len) + "s '%s',\n"

print("Writing to %s" % output_name)
with open(output_name, "w", encoding="utf-8") as fout:
    fout.write("# This module is autogenerated by build_short_name_to_treebank.py\n")
    fout.write("# Please do not edit\n")
    fout.write("\n")
    fout.write("SHORT_NAMES = {\n")
    for short_name, ud_name in zip(short_names, ud_names):
        # Each treebank is written under its short name plus at most one alias:
        # zh_X -> zh-hans_X, zh-hans_X / zh-hant_X -> zh_X, nb_bokmaal -> no_bokmaal
        keys = [short_name]
        if short_name.startswith("zh_"):
            keys.append("zh-hans_" + short_name[3:])
        elif short_name.startswith(("zh-hans_", "zh-hant_")):
            keys.append("zh_" + short_name[8:])
        elif short_name == 'nb_bokmaal':
            keys.append('no_bokmaal')
        for key in keys:
            fout.write(line_format % ("'" + key + "':", ud_name))
    fout.write("}\n")

    # the generated function bodies must be indented for the output to be valid Python
    fout.write("""
def short_name_to_treebank(short_name):
    return SHORT_NAMES[short_name]
""")

    max_len = max(len(x) for x in ud_names) + 5
    line_format = " %-" + str(max_len) + "s '%s',\n"
    fout.write("CANONICAL_NAMES = {\n")
    for ud_name in ud_names:
        fout.write(line_format % ("'" + ud_name.lower() + "':", ud_name))
    fout.write("}\n")
    fout.write("""
def canonical_treebank_name(ud_name):
    if ud_name in SHORT_NAMES:
        return SHORT_NAMES[ud_name]
    return CANONICAL_NAMES.get(ud_name.lower(), ud_name)
""")
================================================
FILE: stanza/models/common/char_model.py
================================================
"""
Based on
@inproceedings{akbik-etal-2018-contextual,
title = "Contextual String Embeddings for Sequence Labeling",
author = "Akbik, Alan and
Blythe, Duncan and
Vollgraf, Roland",
booktitle = "Proceedings of the 27th International Conference on Computational Linguistics",
month = aug,
year = "2018",
address = "Santa Fe, New Mexico, USA",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/C18-1139",
pages = "1638--1649",
}
"""
from collections import Counter
from operator import itemgetter
import os
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pack_padded_sequence, PackedSequence
from stanza.models.common.data import get_long_tensor
from stanza.models.common.packed_lstm import PackedLSTM
from stanza.models.common.utils import open_read_text, tensor_unsort, unsort
from stanza.models.common.dropout import SequenceUnitDropout
from stanza.models.common.vocab import UNK_ID, CharVocab
class CharacterModel(nn.Module):
    """
    Character-level word encoder: runs a PackedLSTM over the characters of each word
    and reduces its outputs to one vector per word.

    With attention=True the per-character outputs are scaled by a learned sigmoid
    gate and summed over the word; otherwise the final hidden state of the last LSTM
    layer is used (both directions when bidirectional).
    """
    def __init__(self, args, vocab, pad=False, bidirectional=False, attention=True):
        """
        args: dict with char_emb_dim, char_hidden_dim, char_num_layers,
            char_rec_dropout and dropout
        vocab: dict whose 'char' entry is the character vocab (id 0 is the padding id)
        pad: if True, forward() returns a padded tensor instead of a PackedSequence
        bidirectional: run the character LSTM in both directions
        attention: use the gated sum instead of the final hidden state
        """
        super().__init__()
        self.args = args
        self.pad = pad
        self.num_dir = 2 if bidirectional else 1
        self.attn = attention

        # char embeddings
        self.char_emb = nn.Embedding(len(vocab['char']), self.args['char_emb_dim'], padding_idx=0)
        if self.attn:
            self.char_attn = nn.Linear(self.num_dir * self.args['char_hidden_dim'], 1, bias=False)
            # zero init: every gate starts at sigmoid(0) = 0.5
            self.char_attn.weight.data.zero_()

        # modules
        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True,
                                   dropout=0 if self.args['char_num_layers'] == 1 else args['dropout'],
                                   rec_dropout=self.args['char_rec_dropout'], bidirectional=bidirectional)
        # learned initial (h, c) states, expanded to the batch size in forward()
        self.charlstm_h_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))
        self.charlstm_c_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))

        self.dropout = nn.Dropout(args['dropout'])

    def forward(self, chars, chars_mask, word_orig_idx, sentlens, wordlens):
        """
        chars: padded character ids, one row per word.  The rows must be sorted by
            decreasing length, since pack_padded_sequence is called with its
            default enforce_sorted=True
        chars_mask: not used here; kept for interface compatibility
        word_orig_idx: passed to tensor_unsort to restore the original word order
        sentlens: number of words in each sentence, used to regroup the words
        wordlens: number of characters in each word

        Returns a PackedSequence with one sequence of word vectors per sentence,
        or a padded (batch first) tensor if self.pad is set.
        """
        embs = self.dropout(self.char_emb(chars))
        batch_size = embs.size(0)
        embs = pack_padded_sequence(embs, wordlens, batch_first=True)
        num_states = self.num_dir * self.args['char_num_layers']
        hx = (self.charlstm_h_init.expand(num_states, batch_size, self.args['char_hidden_dim']).contiguous(),
              self.charlstm_c_init.expand(num_states, batch_size, self.args['char_hidden_dim']).contiguous())
        output = self.charlstm(embs, wordlens, hx=hx)

        # apply attention, otherwise take final states
        if self.attn:
            char_reps = output[0]
            weights = torch.sigmoid(self.char_attn(self.dropout(char_reps.data)))
            char_reps = PackedSequence(char_reps.data * weights, char_reps.batch_sizes)
            # padding positions come back as zeros, so they do not affect the sum
            char_reps, _ = pad_packed_sequence(char_reps, batch_first=True)
            res = char_reps.sum(1)
        else:
            h, _ = output[1]
            # Assuming PackedLSTM follows nn.LSTM, h is (num_layers * num_dir, batch, hidden)
            # and its last num_dir rows are the final layer's states.  The old h[-2:] also
            # took the second-to-last layer of a unidirectional multi-layer LSTM, giving
            # 2 * char_hidden_dim features instead of char_hidden_dim.
            res = h[-self.num_dir:].transpose(0, 1).contiguous().view(batch_size, -1)

        # restore the original word order, then regroup the words into sentences
        res = tensor_unsort(res, word_orig_idx)
        res = pack_sequence(res.split(sentlens))
        if self.pad:
            res = pad_packed_sequence(res, batch_first=True)[0]

        return res
def build_charlm_vocab(path, cutoff=0):
    """
    Build a CharVocab for a CharacterLanguageModel.

    path may be a single text file or a directory; every entry of a directory is
    read, in sorted order.  Only per-character counts are held in memory while
    reading, so very large training files can be processed (the vocab only needs
    to be built once).  Characters seen fewer than cutoff times are dropped here,
    so no cutoff is passed on to CharVocab.

    Raises ValueError if no characters were read, or if every character falls
    below the cutoff.
    """
    if os.path.isdir(path):
        files = [os.path.join(path, name) for name in sorted(os.listdir(path))]
    else:
        directory, name = os.path.split(path)
        files = [os.path.join(directory, name)]

    char_counts = Counter()
    for filename in files:
        with open_read_text(filename) as fin:
            for line in fin:
                # Counter.update on a string counts each character
                char_counts.update(line)

    if not char_counts:
        raise ValueError("Training data was empty!")

    kept_chars = sorted(char for char, count in char_counts.items() if count >= cutoff)
    if not kept_chars:
        raise ValueError("All characters in the training data were less frequent than --cutoff!")

    # CharVocab expects a list of sequences; a single "sentence" holding every character
    return CharVocab([kept_chars])
# Marker characters for the character LM.  build_char_representation starts each
# sentence with CHARLM_START and follows every word with CHARLM_END; CHARLM_END's
# id is also the padding id for batched character tensors.
CHARLM_START, CHARLM_END = "\n", " "
class CharacterLanguageModel(nn.Module):
    """
    Single-direction character-level LSTM language model.

    The model reads text left-to-right or right-to-left (is_forward_lm) and produces
    scores over the character vocabulary with a linear decoder.  Its LSTM outputs are
    also used as contextual representations of characters and words
    (get_representation, per_char_representation, build_char_representation).
    """

    def __init__(self, args, vocab, pad=False, is_forward_lm=True):
        """
        args: dict of hyperparameters (char_emb_dim, char_hidden_dim, char_num_layers,
            char_dropout, char_rec_dropout, optionally char_unit_dropout)
        vocab: dict whose 'char' entry is the character vocab
        pad: if True, get_representation returns a padded tensor instead of a PackedSequence
        is_forward_lm: direction in which this LM reads the text
        """
        super().__init__()
        self.args = args
        self.vocab = vocab
        self.is_forward_lm = is_forward_lm
        self.pad = pad
        self.finetune = True # always finetune unless otherwise specified

        # char embeddings
        self.char_emb = nn.Embedding(len(self.vocab['char']), self.args['char_emb_dim'], padding_idx=None) # we use space as padding, so padding_idx is not necessary

        # modules
        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True,
                                   dropout=0 if self.args['char_num_layers'] == 1 else args['char_dropout'],
                                   rec_dropout=self.args['char_rec_dropout'], bidirectional=False)
        # learned initial (h, c) states, expanded to the batch size in forward()
        self.charlstm_h_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))
        self.charlstm_c_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))

        # decoder
        self.decoder = nn.Linear(self.args['char_hidden_dim'], len(self.vocab['char']))
        self.dropout = nn.Dropout(args['char_dropout'])
        # presumably replaces random characters with UNK_ID during training -- see SequenceUnitDropout
        self.char_dropout = SequenceUnitDropout(args.get('char_unit_dropout', 0), UNK_ID)

    def forward(self, chars, charlens, hidden=None):
        """
        Run the LM over a batch of character id sequences.

        chars: padded character ids, one row per sequence; presumably sorted by
            decreasing length, since pack_padded_sequence uses its default enforce_sorted=True
        charlens: length of each sequence
        hidden: optional (h, c) initial state; defaults to the learned initial state

        Returns (output, hidden, decoded): the dropped-out LSTM outputs padded batch-first,
        the final (h, c) state, and the decoder scores over the character vocab.
        """
        chars = self.char_dropout(chars)
        embs = self.dropout(self.char_emb(chars))
        batch_size = embs.size(0)
        embs = pack_padded_sequence(embs, charlens, batch_first=True)
        if hidden is None:
            hidden = (self.charlstm_h_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous(),
                      self.charlstm_c_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous())
        output, hidden = self.charlstm(embs, charlens, hx=hidden)
        output = self.dropout(pad_packed_sequence(output, batch_first=True)[0])
        decoded = self.decoder(output)
        return output, hidden, decoded

    def get_representation(self, chars, charoffsets, charlens, char_orig_idx):
        """
        Return the LSTM outputs at the given character offsets of each sequence,
        restored to the original order with unsort and packed (or padded if self.pad).
        Runs without gradients.
        """
        with torch.no_grad():
            output, _, _ = self.forward(chars, charlens)
            res = [output[i, offsets] for i, offsets in enumerate(charoffsets)]
            res = unsort(res, char_orig_idx)
            res = pack_sequence(res)
            if self.pad:
                res = pad_packed_sequence(res, batch_first=True)[0]
        return res

    def per_char_representation(self, words):
        """
        Return a list with one (len(word), hidden) tensor of LSTM outputs per word,
        in the same order as words.  Each word is run as its own sequence.
        """
        device = next(self.parameters()).device
        vocab = self.char_vocab()

        # sort by length (longest first) for packing, remembering the original index
        all_data = [(vocab.map(word), len(word), idx) for idx, word in enumerate(words)]
        all_data.sort(key=itemgetter(1), reverse=True)
        chars = [x[0] for x in all_data]
        char_lens = [x[1] for x in all_data]
        char_tensor = get_long_tensor(chars, len(chars), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)

        with torch.no_grad():
            output, _, _ = self.forward(char_tensor, char_lens)
            # drop the padding positions of each word
            output = [x[:y, :] for x, y in zip(output, char_lens)]
            output = unsort(output, [x[2] for x in all_data])
        return output

    def build_char_representation(self, sentences):
        """
        Return values from this charlm for a list of list of words

        input: [[str]]
          K sentences, each of length Ki (can be different for each sentence)
        output: [tensor(Ki x dim)]
          list of tensors, each one with shape Ki by the dim of the character model

        Values are taken from the last character in a word for each word.
        The words are effectively treated as if they are whitespace separated
        (which may actually be somewhat inaccurate for languages such as Chinese or for MWT)
        """
        forward = self.is_forward_lm
        vocab = self.char_vocab()
        device = next(self.parameters()).device

        all_data = []
        for idx, words in enumerate(sentences):
            if not forward:
                # a backward LM reads the whole sentence reversed, character by character
                words = [x[::-1] for x in reversed(words)]

            chars = [CHARLM_START]
            offsets = []
            for w in words:
                chars.extend(w)
                chars.append(CHARLM_END)
                # position of the CHARLM_END that follows this word
                offsets.append(len(chars) - 1)
            if not forward:
                # put the offsets back into the original word order
                offsets.reverse()
            chars = vocab.map(chars)
            all_data.append((chars, offsets, len(chars), len(all_data)))

        # sort by length (longest first) for packing
        all_data.sort(key=itemgetter(2), reverse=True)
        chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))
        # TODO: can this be faster?
        chars = get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)

        with torch.no_grad():
            output, _, _ = self.forward(chars, char_lens)
            res = [output[i, offsets] for i, offsets in enumerate(char_offsets)]
            res = unsort(res, orig_idx)

        return res

    def hidden_dim(self):
        return self.args['char_hidden_dim']

    def char_vocab(self):
        return self.vocab['char']

    def train(self, mode=True):
        """
        Override the default train() function, so that when self.finetune == False, the training mode
        won't be impacted by the parent models' status change.

        train(False) / eval() is always honored; train(True) only takes effect while finetuning.
        Returns self, like nn.Module.train.  nn.Module.eval() returns self.train(False),
        so without the return statement model.eval() would evaluate to None.
        """
        if not mode or self.finetune:
            super().train(mode)
        return self

    def full_state(self):
        """Return a dict with everything from_full_state needs to rebuild this model."""
        state = {
            'vocab': self.vocab['char'].state_dict(),
            'args': self.args,
            'state_dict': self.state_dict(),
            'pad': self.pad,
            'is_forward_lm': self.is_forward_lm
        }
        return state

    def save(self, filename):
        """Save full_state() to filename, creating missing parent directories."""
        # os.makedirs('') raises, so only create a directory when filename has one
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        state = self.full_state()
        torch.save(state, filename, _use_new_zipfile_serialization=False)

    @classmethod
    def from_full_state(cls, state, finetune=False):
        """Rebuild a model from full_state(); the model is left in eval mode."""
        vocab = {'char': CharVocab.load_state_dict(state['vocab'])}
        model = cls(state['args'], vocab, state['pad'], state['is_forward_lm'])
        model.load_state_dict(state['state_dict'])
        model.eval()
        model.finetune = finetune # set finetune status
        return model

    @classmethod
    def load(cls, filename, finetune=False):
        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        # allow saving just the Model object,
        # and allow for old charlms to still work
        if 'state_dict' in state:
            return cls.from_full_state(state, finetune)
        return cls.from_full_state(state['model'], finetune)
class CharacterLanguageModelWordAdapter(nn.Module):
    """
    Adapts one or more character models to return embeddings for each character in a word.

    forward() returns a (num_words, longest_word, total_dim) tensor: each charlm's
    per-character outputs are zero-padded to the longest word, and the charlms are
    concatenated along the last dimension.
    """
    def __init__(self, charlms):
        super().__init__()
        # Stored as given.  If this is a plain list, the charlms are not registered as
        # submodules, so .to() / state_dict() on the adapter will not include them.
        self.charlms = charlms

    def forward(self, words, wrap=True):
        if wrap:
            words = [CHARLM_START + word + CHARLM_END for word in words]

        def zero_pad(rows):
            # stack variable-length (length, dim) tensors into one zero-padded tensor
            longest = max(row.shape[0] for row in rows)
            batch = torch.zeros(len(rows), longest, rows[0].shape[1], dtype=rows[0].dtype, device=rows[0].device)
            for i, row in enumerate(rows):
                batch[i, :row.shape[0], :] = row
            return batch

        return torch.cat([zero_pad(charlm.per_char_representation(words)) for charlm in self.charlms], dim=2)

    def hidden_dim(self):
        total = 0
        for charlm in self.charlms:
            total += charlm.hidden_dim()
        return total
class CharacterLanguageModelTrainer():
    """
    Holds a CharacterLanguageModel together with its training state so that training
    can be checkpointed and resumed.

    The training state is the parameter list, the SGD optimizer, the cross-entropy
    criterion, the ReduceLROnPlateau scheduler, the epoch and the global step.
    """
    def __init__(self, model, params, optimizer, criterion, scheduler, epoch=1, global_step=0):
        self.model = model
        self.params = params
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.epoch = epoch
        self.global_step = global_step

    def save(self, filename, full=True):
        """
        Save the model's full_state() and the epoch / global step counters to filename.

        If full is True, the optimizer, criterion and scheduler states are saved as well
        (whichever are not None).  Missing parent directories are created.
        """
        # os.makedirs('') raises, so only create a directory when filename has one
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        state = {
            'model': self.model.full_state(),
            'epoch': self.epoch,
            'global_step': self.global_step,
        }
        if full:
            if self.optimizer is not None:
                state['optimizer'] = self.optimizer.state_dict()
            if self.criterion is not None:
                state['criterion'] = self.criterion.state_dict()
            if self.scheduler is not None:
                state['scheduler'] = self.scheduler.state_dict()
        torch.save(state, filename, _use_new_zipfile_serialization=False)

    @classmethod
    def from_new_model(cls, args, vocab):
        """Create a freshly initialized model plus optimizer, criterion and scheduler from args."""
        model = CharacterLanguageModel(args, vocab, is_forward_lm=(args['direction'] == 'forward'))
        model = model.to(args['device'])
        params = [param for param in model.parameters() if param.requires_grad]
        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        criterion = torch.nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args['anneal'], patience=args['patience'])
        return cls(model, params, optimizer, criterion, scheduler)

    @classmethod
    def load(cls, args, filename, finetune=False):
        """
        Load the model along with any other saved state for training

        Note that you MUST set finetune=True if planning to continue training
        Otherwise the only benefit you will get will be a warm GPU
        """
        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        model = CharacterLanguageModel.from_full_state(state['model'], finetune)
        model = model.to(args['device'])
        params = [param for param in model.parameters() if param.requires_grad]

        # each component is rebuilt from args, then restored if its state was saved
        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        if 'optimizer' in state:
            optimizer.load_state_dict(state['optimizer'])

        criterion = torch.nn.CrossEntropyLoss()
        if 'criterion' in state:
            criterion.load_state_dict(state['criterion'])

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args['anneal'], patience=args['patience'])
        if 'scheduler' in state:
            scheduler.load_state_dict(state['scheduler'])

        epoch = state.get('epoch', 1)
        global_step = state.get('global_step', 0)
        return cls(model, params, optimizer, criterion, scheduler, epoch, global_step)
================================================
FILE: stanza/models/common/chuliu_edmonds.py
================================================
# Adapted from Tim's code here: https://github.com/tdozat/Parser-v3/blob/master/scripts/chuliu_edmonds.py
import numpy as np
def tarjan(tree):
    """Finds the cycles in a dependency graph

    The input should be a numpy array of integers,
    where in the standard use case,
    tree[i] is the head of node i.

    tree[0] == 0 to represent the root

    so for example, for the English sentence "This is a test",
    the input is

    [0 4 4 4 0]

    "Arthritis makes my hip hurt"

    [0 2 0 4 2 2]

    The return is a list of cycles, where each cycle is a boolean array
    that is True for every node participating in that cycle.
    Only cycles of two or more nodes are reported, so the root's
    self-loop (tree[0] == 0) never appears in the output.

    So, for example, the previous examples both return empty lists,
    whereas an input of
      np.array([0, 3, 1, 2])
    has an output of
      [np.array([False, True, True, True])]

    Implementation: Tarjan's strongly connected components algorithm over
    head -> dependent edges, with the recursion replaced by an explicit
    stack (see strong_connect) so that deep trees cannot overflow
    Python's call stack.
    """
    # indices[i]: order in which node i was first visited (-1 = not visited yet)
    indices = -np.ones_like(tree)
    # lowlinks[i]: smallest index reachable from i via nodes still on the stack
    lowlinks = -np.ones_like(tree)
    onstack = np.zeros_like(tree, dtype=bool)
    stack = list()
    # one-element list so the nested functions can update the visit counter
    _index = [0]
    cycles = []
    #-------------------------------------------------------------
    def maybe_pop_cycle(i):
        # If i is the root of a strongly connected component, pop the whole
        # component off the stack; keep it only if it has more than one node.
        if lowlinks[i] == indices[i]:
            # There's a cycle!
            cycle = np.zeros_like(indices, dtype=bool)
            while stack[-1] != i:
                j = stack.pop()
                onstack[j] = False
                cycle[j] = True
            stack.pop()
            onstack[i] = False
            cycle[i] = True
            if cycle.sum() > 1:
                cycles.append(cycle)

    def initialize_strong_connect(i):
        # first visit to i: assign it the next index and push it on the stack
        _index[0] += 1
        index = _index[-1]
        indices[i] = lowlinks[i] = index - 1
        stack.append(i)
        onstack[i] = True

    def strong_connect(i):
        # This uses an explicit call stack instead of recursion, because people
        # keep finding graphs deep enough to overflow Python's call stack.
        # See for example
        #   https://github.com/stanfordnlp/stanza/issues/962
        #   https://github.com/spraakbanken/sparv-pipeline/issues/166
        #
        # The equivalent recursive version would look like this:
        #   initialize_strong_connect(i)
        #   dependents = iter(np.where(np.equal(tree, i))[0])
        #   for j in dependents:
        #       if indices[j] == -1:
        #           strong_connect(j)
        #           lowlinks[i] = min(lowlinks[i], lowlinks[j])
        #       elif onstack[j]:
        #           lowlinks[i] = min(lowlinks[i], indices[j])
        #
        #   maybe_pop_cycle(i)
        #
        # Each call_stack entry is (node, iterator over its dependents, the dependent
        # we descended into); (node, None, None) means the node has not been started.
        call_stack = [(i, None, None)]
        while len(call_stack) > 0:
            i, dependents_iterator, j = call_stack.pop()
            if dependents_iterator is None: # first time getting here for this i
                initialize_strong_connect(i)
                dependents_iterator = iter(np.where(np.equal(tree, i))[0])
            else: # been here before. j was the dependent we were just considering
                lowlinks[i] = min(lowlinks[i], lowlinks[j])
            for j in dependents_iterator:
                if indices[j] == -1:
                    # Unvisited dependent: save where we are in i's iterator,
                    # then visit j first.  We resume i once j is finished.
                    call_stack.append((i, dependents_iterator, j))
                    call_stack.append((j, None, None))
                    break
                elif onstack[j]:
                    lowlinks[i] = min(lowlinks[i], indices[j])
            else:
                # for/else: this runs only when the loop finished without a break,
                # i.e. every dependent of i has been handled, so any component
                # rooted at i can now be resolved
                maybe_pop_cycle(i)
            # Either i is finished (the else above ran) and the loop continues with
            # the frame beneath it, or we broke out to descend into j first.
            # Either way, keep going until the call stack is empty.

    #-------------------------------------------------------------
    for i in range(len(tree)):
        if indices[i] == -1:
            strong_connect(i)
    return cycles
def process_cycle(tree, cycle, scores):
    """
    Build a subproblem with one cycle broken

    The cycle is contracted into a single "metanode", which becomes the last
    row/column of the returned (n+1) x (n+1) score matrix, where n is the
    number of nodes not in the cycle.

    tree: current head assignment, tree[i] is the head of node i
    cycle: boolean mask over the nodes, True for the nodes in the cycle
    scores: scores[x, y] is the score of y being the head of x

    Returns
      subscores: the contracted (n+1) x (n+1) score matrix
      cycle_locs: original indices of the cycle nodes
      noncycle_locs: original indices of the other nodes
      metanode_heads: for each noncycle node h, the position (within cycle_locs) of the
        cycle node that would best take h as its head, breaking the cycle there
      metanode_deps: for each noncycle node d, the position (within cycle_locs) of the
        best cycle node to be the head of d
    """
    # Shape notation: t = len(tree), c = cycle size, n = t - c
    # indices of cycle in original tree; (c) in t
    cycle_locs = np.where(cycle)[0]
    # heads of cycle in original tree; (c) in t
    cycle_subtree = tree[cycle]
    # score of each cycle node's current head; (c) in R
    cycle_scores = scores[cycle, cycle_subtree]
    # total score of cycle; () in R
    cycle_score = cycle_scores.sum()

    # locations of noncycle; (t) in [0,1]
    noncycle = np.logical_not(cycle)
    # indices of noncycle in original tree; (n) in t
    noncycle_locs = np.where(noncycle)[0]

    # [i, j]: score of the whole cycle if cycle node i drops its current head
    # and takes noncycle node j as head instead; (c x n) - (c) + () -> (c x n) in R
    metanode_head_scores = scores[cycle][:,noncycle] - cycle_scores[:,None] + cycle_score
    # [j, i]: score of noncycle node j taking cycle node i as head; (n x c) in R
    metanode_dep_scores = scores[noncycle][:,cycle]
    # for each noncycle head, the best cycle node to attach to it; (n) in c
    metanode_heads = np.argmax(metanode_head_scores, axis=0)
    # for each noncycle dependent, the best cycle node to be its head; (n) in c
    metanode_deps = np.argmax(metanode_dep_scores, axis=1)

    # scores of noncycle graph; (n x n) in R
    subscores = scores[noncycle][:,noncycle]
    # add a zero row and column for the metanode; (n+1 x n+1) in R
    subscores = np.pad(subscores, ( (0,1) , (0,1) ), 'constant')
    # metanode as dependent: best cycle score for each possible noncycle head; (n) in R
    subscores[-1, :-1] = metanode_head_scores[metanode_heads, np.arange(len(noncycle_locs))]
    # metanode as head: best score for each noncycle dependent; (n) in R
    subscores[:-1,-1] = metanode_dep_scores[np.arange(len(noncycle_locs)), metanode_deps]
    return subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps
def expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps):
    """
    Given a partially solved tree with a cycle and a solved subproblem
    for the cycle, build a larger solution without the cycle

    tree: the head assignment that contained the cycle
    contracted_tree: head assignment for the contracted problem built by
      process_cycle; its last entry is the metanode standing in for the cycle
    cycle_locs, noncycle_locs, metanode_heads, metanode_deps: as returned by process_cycle

    Returns a new head array shaped like tree.  Cycle nodes keep their heads
    from tree, except the node chosen to enter the cycle, which takes the
    metanode's head from contracted_tree.
    """
    # Shape notation: t = len(tree), c = cycle size, n = number of noncycle nodes
    # head of the metanode (the cycle); () in n
    cycle_head = contracted_tree[-1]
    # heads of the noncycle nodes; index n means "the metanode": (n) in n+1
    contracted_tree = contracted_tree[:-1]
    # initialize new tree; (t) in 0
    new_tree = -np.ones_like(tree)
    # noncycle nodes whose head is another noncycle node (not the metanode): (n) in [0,1]
    contracted_subtree = contracted_tree < len(contracted_tree)
    # add the nodes to the new tree (t)[(n)[(n) in [0,1]] in t] in t = (n)[(n)[(n) in [0,1]] in n] in t
    new_tree[noncycle_locs[contracted_subtree]] = noncycle_locs[contracted_tree[contracted_subtree]]
    # noncycle nodes whose head is the metanode: (n) in [0,1]
    contracted_subtree = np.logical_not(contracted_subtree)
    # give them the best cycle node as head (t)[(n)[(n) in [0,1]] in t] in t = (c)[(n)[(n) in [0,1]] in c] in t
    new_tree[noncycle_locs[contracted_subtree]] = cycle_locs[metanode_deps[contracted_subtree]]
    # add the old cycle to the tree; (t)[(c) in t] in t = (t)[(c) in t] in t
    new_tree[cycle_locs] = tree[cycle_locs]
    # cycle node that enters the cycle from cycle_head; (n)[() in n] in c = () in c
    cycle_root = metanode_heads[cycle_head]
    # break the cycle: that node takes cycle_head as its head; (t)[(c)[() in c] in t] = (c)[() in c]
    new_tree[cycle_locs[cycle_root]] = noncycle_locs[cycle_head]
    return new_tree
def prepare_scores(scores):
    """
    Alter the scores matrix in place (returns None) to avoid self loops and handle the root.

    The diagonal and the whole root row are set to -inf, except scores[0, 0],
    which is set to 0 so the root is its own head.
    """
    np.fill_diagonal(scores, -np.inf)
    scores[0, :] = -np.inf
    scores[0, 0] = 0
def chuliu_edmonds(scores):
    """
    Return the highest scoring head assignment for a square score matrix
    using the Chu-Liu/Edmonds algorithm.

    scores[x, y] is the score of y being the head of x, and node 0 is the root.
    The result is an integer array where result[i] is the head of node i.

    NOTE: scores is modified in place (prepare_scores overwrites the diagonal and
    the root row), so pass a copy if the caller still needs the original.
    The number of root attachments is not restricted; see chuliu_edmonds_one_root.
    """
    # The natural formulation is recursive: contract one cycle, solve the smaller
    # problem, then expand its solution.  The recursion can get too deep for Python,
    # so contractions are recorded on an explicit stack and undone afterwards.
    pending_expansions = []

    prepare_scores(scores)
    tree = np.argmax(scores, axis=1)
    cycles = tarjan(tree)

    while cycles:
        # contract one cycle; any remaining cycles are found again in the smaller problem
        subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = process_cycle(tree, cycles.pop(), scores)
        pending_expansions.append((tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps))
        scores = subscores
        prepare_scores(scores)
        tree = np.argmax(scores, axis=1)
        cycles = tarjan(tree)

    # undo the contractions, innermost first
    while pending_expansions:
        outer_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = pending_expansions.pop()
        tree = expand_contracted_tree(outer_tree, tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps)
    return tree
#===============================================================
def chuliu_edmonds_one_root(scores):
    """
    Return the results of the dependency tree search, but with exactly one link to root (0)

    scores is a square numpy array where scores[x][y] is the score (higher is better)
    for assigning y to be the head of x.  The input array is not modified.

    Here we reweight the root arcs so as to ensure that the picker only ever chooses one root.
    See for example
      https://aclanthology.org/2021.emnlp-main.823/
      A Root of a Problem: Optimizing Single-Root Dependency Parsing
      Miloš Stanojević, Shay B. Cohen
    """
    # astype returns a new array, so the reweighting below does not touch the caller's scores
    scores = scores.astype(np.float64)
    finite_scores = scores[np.isfinite(scores)]
    min_score = finite_scores.min()
    max_score = finite_scores.max()

    # Every tree uses one arc per non-root node, so the totals of any two trees differ
    # by less than n * (max_score - min_score + 1).  Subtracting that much from every
    # root arc makes each extra root arc cost more than it could ever gain.  All
    # single-root trees get the same penalty, so the best of them still wins.
    # The previous rescaling (adding min_score * n) only penalized root arcs when all
    # scores were non-positive, e.g. log probabilities; positive scores rewarded root arcs.
    root_penalty = (max_score - min_score + 1.0) * scores.shape[0]
    scores[:, 0] = scores[:, 0] - root_penalty

    tree = chuliu_edmonds(scores)

    # +1 because we cut off the first column of the tree
    roots_to_try = np.where(np.equal(tree[1:], 0))[0]+1
    assert len(roots_to_try) == 1, "Penalizing root arcs should have prevented using multiple root edges"

    return tree
================================================
FILE: stanza/models/common/constant.py
================================================
"""
Global constants.
These language codes mirror UD language codes when possible
"""
import re
class UnknownLanguageError(ValueError):
    """Raised when a language cannot be resolved to a known language code (e.g. by treebank_to_short_name)."""
# tuples in a list so we can assert that the langcodes are all unique
# When applicable, we favor the UD decision over any other possible
# language code or language name
# An example of this is sab -> Bokota, instead of bgd in ISO 639-3
# ISO 639-1 is out of date, but many of the UD datasets are labeled
# using the two letter abbreviations, so we add those for non-UD
# languages in the hopes that we've guessed right if those languages
# are eventually processed
# (language code, language name) pairs.
# Codes are two or three letters, except "zh-hans" / "zh-hant", which
# distinguish Simplified and Traditional Chinese.
# Multi-word names use underscores instead of spaces (e.g. Ancient_Greek).
# Entries are roughly, but not strictly, alphabetical by language name.
lcode2lang_raw = [
    ("abq", "Abaza"),
    ("ab", "Abkhazian"),
    ("aa", "Afar"),
    ("af", "Afrikaans"),
    ("ak", "Akan"),
    ("akk", "Akkadian"),
    ("aqz", "Akuntsu"),
    ("sq", "Albanian"),
    ("am", "Amharic"),
    ("grc", "Ancient_Greek"),
    ("hbo", "Ancient_Hebrew"),
    ("apu", "Apurina"),
    ("ar", "Arabic"),
    ("arz", "Egyptian_Arabic"),
    ("an", "Aragonese"),
    ("hy", "Armenian"),
    ("as", "Assamese"),
    ("aii", "Assyrian"),
    ("ast", "Asturian"),
    ("av", "Avaric"),
    ("ae", "Avestan"),
    ("ay", "Aymara"),
    ("az", "Azerbaijani"),
    ("bm", "Bambara"),
    ("ba", "Bashkir"),
    ("eu", "Basque"),
    ("bar", "Bavarian"),
    ("bej", "Beja"),
    ("be", "Belarusian"),
    ("bn", "Bengali"),
    ("bho", "Bhojpuri"),
    ("bpy", "Bishnupriya_Manipuri"),
    ("bi", "Bislama"),
    ("bor", "Bororo"),
    ("sab", "Bokota"),
    ("bs", "Bosnian"),
    ("br", "Breton"),
    ("bg", "Bulgarian"),
    ("bxr", "Buryat"),
    ("yue", "Cantonese"),
    ("cpg", "Cappadocian"),
    ("ca", "Catalan"),
    ("ceb", "Cebuano"),
    ("km", "Central_Khmer"),
    ("ch", "Chamorro"),
    ("ce", "Chechen"),
    ("ny", "Chichewa"),
    ("ctn", "Chintang"),
    ("ckt", "Chukchi"),
    ("cv", "Chuvash"),
    ("xcl", "Classical_Armenian"),
    ("lzh", "Classical_Chinese"),
    ("cop", "Coptic"),
    ("kw", "Cornish"),
    ("co", "Corsican"),
    ("cr", "Cree"),
    ("hr", "Croatian"),
    ("cs", "Czech"),
    ("da", "Danish"),
    ("dar", "Dargwa"),
    ("dv", "Dhivehi"),
    ("nl", "Dutch"),
    ("dz", "Dzongkha"),
    ("egy", "Egyptian"),
    ("en", "English"),
    ("myv", "Erzya"),
    ("eo", "Esperanto"),
    ("et", "Estonian"),
    ("ee", "Ewe"),
    ("ext", "Extremaduran"),
    ("fo", "Faroese"),
    ("fj", "Fijian"),
    ("fi", "Finnish"),
    ("fon", "Fon"),
    ("fr", "French"),
    ("qfn", "Frisian_Dutch"),
    ("ff", "Fulah"),
    ("gl", "Galician"),
    ("lg", "Ganda"),
    ("ka", "Georgian"),
    ("de", "German"),
    ("aln", "Gheg"),
    ("bbj", "Ghomálá'"),
    ("got", "Gothic"),
    ("el", "Greek"),
    ("kl", "Greenlandic"),
    ("gub", "Guajajara"),
    ("gn", "Guarani"),
    ("gu", "Gujarati"),
    ("gwi", "Gwichin"),
    ("ht", "Haitian"),
    ("ha", "Hausa"),
    ("he", "Hebrew"),
    ("hz", "Herero"),
    ("azz", "Highland_Puebla_Nahuatl"),
    ("hil", "Hiligaynon"),
    ("hi", "Hindi"),
    ("qhe", "Hindi_English"),
    ("ho", "Hiri_Motu"),
    ("hit", "Hittite"),
    ("hu", "Hungarian"),
    ("is", "Icelandic"),
    ("io", "Ido"),
    ("ig", "Igbo"),
    ("arh", "Ika"),
    ("ilo", "Ilocano"),
    ("arc", "Imperial_Aramaic"),
    ("id", "Indonesian"),
    ("iu", "Inuktitut"),
    ("ik", "Inupiaq"),
    ("ga", "Irish"),
    ("it", "Italian"),
    ("ja", "Japanese"),
    ("jv", "Javanese"),
    ("urb", "Kaapor"),
    ("kab", "Kabyle"),
    ("xnr", "Kangri"),
    ("kn", "Kannada"),
    ("kr", "Kanuri"),
    ("pam", "Kapampangan"),
    ("krl", "Karelian"),
    ("arr", "Karo"),
    ("ks", "Kashmiri"),
    ("kk", "Kazakh"),
    ("naq", "Khoekhoe"),
    ("kfm", "Khunsari"),
    ("quc", "Kiche"),
    ("cgg", "Kiga"),
    ("ki", "Kikuyu"),
    ("rw", "Kinyarwanda"),
    ("ky", "Kyrgyz"),
    ("kv", "Komi"),
    ("koi", "Komi_Permyak"),
    ("kpv", "Komi_Zyrian"),
    ("kg", "Kongo"),
    ("ko", "Korean"),
    ("ku", "Kurdish"),
    ("kmr", "Northern_Kurdish"),
    ("kj", "Kwanyama"),
    ("lad", "Ladino"),
    ("lo", "Lao"),
    ("ltg", "Latgalian"),
    ("la", "Latin"),
    ("lv", "Latvian"),
    ("lij", "Ligurian"),
    ("li", "Limburgish"),
    ("ln", "Lingala"),
    ("lt", "Lithuanian"),
    ("liv", "Livonian"),
    ("olo", "Livvi"),
    ("nds", "Low_Saxon"),
    ("lu", "Luba_Katanga"),
    ("lb", "Luxembourgish"),
    ("mk", "Macedonian"),
    ("jaa", "Madi"),
    ("mag", "Magahi"),
    ("qaf", "Maghrebi_Arabic_French"),
    ("mai", "Maithili"),
    ("mpu", "Makurap"),
    ("mg", "Malagasy"),
    ("ms", "Malay"),
    ("ml", "Malayalam"),
    ("mt", "Maltese"),
    ("mjl", "Mandyali"),
    ("gv", "Manx"),
    ("mi", "Maori"),
    ("mr", "Marathi"),
    ("mh", "Marshallese"),
    ("mzn", "Mazandarani"),
    ("gun", "Mbya_Guarani"),
    ("enm", "Middle_English"),
    ("frm", "Middle_French"),
    ("min", "Minangkabau"),
    ("xmf", "Mingrelian"),
    ("mwl", "Mirandese"),
    ("mdf", "Moksha"),
    ("mn", "Mongolian"),
    ("mos", "Mossi"),
    ("myu", "Munduruku"),
    ("my", "Myanmar"),
    ("nqo", "N'Ko"),
    ("nmf", "Naga"),
    ("nah", "Nahuatl"),
    ("pcm", "Naija"),
    ("na", "Nauru"),
    ("nv", "Navajo"),
    ("nyq", "Nayini"),
    ("ng", "Ndonga"),
    ("nap", "Neapolitan"),
    ("nrk", "Nenets"),
    ("ne", "Nepali"),
    ("new", "Newar"),
    ("yrl", "Nheengatu"),
    ("nyn", "Nkore"),
    ("frr", "North_Frisian"),
    ("nd", "North_Ndebele"),
    ("sme", "North_Sami"),
    ("nso", "Northern_Sotho"),
    ("gya", "Northwest_Gbaya"),
    ("nb", "Norwegian_Bokmaal"),
    ("nn", "Norwegian_Nynorsk"),
    ("ii", "Nuosu"),
    ("oc", "Occitan"),
    ("or", "Odia"),
    ("oj", "Ojibwa"),
    ("cu", "Old_Church_Slavonic"),
    ("orv", "Old_East_Slavic"),
    ("ang", "Old_English"),
    ("fro", "Old_French"),
    ("sga", "Old_Irish"),
    ("ojp", "Old_Japanese"),
    ("pro", "Old_Occitan"),
    ("otk", "Old_Turkish"),
    ("om", "Oromo"),
    ("os", "Ossetian"),
    ("ota", "Ottoman_Turkish"),
    ("pi", "Pali"),
    ("ps", "Pashto"),
    ("pad", "Paumari"),
    ("fa", "Persian"),
    ("pay", "Pesh"),
    ("xpg", "Phrygian"),
    ("pbv", "Pnar"),
    ("pl", "Polish"),
    ("qpm", "Pomak"),
    ("pnt", "Pontic"),
    ("pt", "Portuguese"),
    ("pra", "Prakrit"),
    ("pa", "Punjabi"),
    ("qu", "Quechua"),
    ("rhg", "Rohingya"),
    ("ro", "Romanian"),
    ("rm", "Romansh"),
    ("rn", "Rundi"),
    ("ru", "Russian"),
    ("sm", "Samoan"),
    ("sg", "Sango"),
    ("sa", "Sanskrit"),
    ("skr", "Saraiki"),
    ("sc", "Sardinian"),
    ("sco", "Scots"),
    ("gd", "Scottish_Gaelic"),
    ("sr", "Serbian"),
    ("wuu", "Shanghainese"),
    ("sn", "Shona"),
    ("zh-hans", "Simplified_Chinese"),
    ("scn", "Sicilian"),
    ("sd", "Sindhi"),
    ("si", "Sinhala"),
    ("sms", "Skolt_Sami"),
    ("sk", "Slovak"),
    ("sl", "Slovenian"),
    ("soj", "Soi"),
    ("so", "Somali"),
    ("ckb", "Sorani"),
    ("ajp", "South_Levantine_Arabic"),
    ("sdh", "Southern_Kurdish"),
    ("nr", "South_Ndebele"),
    ("st", "Southern_Sotho"),
    ("es", "Spanish"),
    ("ssp", "Spanish_Sign_Language"),
    ("su", "Sundanese"),
    ("sw", "Swahili"),
    ("ss", "Swati"),
    ("sv", "Swedish"),
    ("swl", "Swedish_Sign_Language"),
    ("gsw", "Swiss_German"),
    ("syr", "Syriac"),
    ("tl", "Tagalog"),
    ("ty", "Tahitian"),
    ("tg", "Tajik"),
    ("ta", "Tamil"),
    ("tt", "Tatar"),
    ("eme", "Teko"),
    ("te", "Telugu"),
    ("qte", "Telugu_English"),
    ("th", "Thai"),
    ("bo", "Tibetan"),
    ("ti", "Tigrinya"),
    ("to", "Tonga"),
    ("zh-hant", "Traditional_Chinese"),
    ("ts", "Tsonga"),
    ("tn", "Tswana"),
    ("tpn", "Tupinamba"),
    ("tr", "Turkish"),
    ("qti", "Turkish_English"),
    ("qtd", "Turkish_German"),
    ("tk", "Turkmen"),
    ("tw", "Twi"),
    ("uk", "Ukrainian"),
    ("xum", "Umbrian"),
    ("hsb", "Upper_Sorbian"),
    ("ur", "Urdu"),
    ("ug", "Uyghur"),
    ("uz", "Uzbek"),
    ("ve", "Venda"),
    ("vep", "Veps"),
    ("vi", "Vietnamese"),
    ("vo", "Volapük"),
    ("wa", "Walloon"),
    ("war", "Waray"),
    ("wbp", "Warlpiri"),
    ("cy", "Welsh"),
    ("hyw", "Western_Armenian"),
    ("fy", "Western_Frisian"),
    ("nhi", "Western_Sierra_Puebla_Nahuatl"),
    ("wo", "Wolof"),
    ("xav", "Xavante"),
    ("xh", "Xhosa"),
    ("sjo", "Xibe"),
    ("sah", "Yakut"),
    ("yi", "Yiddish"),
    ("yo", "Yoruba"),
    ("ess", "Yupik"),
    ("say", "Zaar"),
    ("zza", "Zazaki"),
    ("zea", "Zeelandic"),
    ("za", "Zhuang"),
    ("zu", "Zulu"),
]
# Forward map: language code -> language name.  Every code in
# lcode2lang_raw must be unique.
lcode2lang = {}
for code, language in lcode2lang_raw:
    assert code not in lcode2lang
    lcode2lang[code] = language

# Reverse map: language name -> language code.  Every language name must be
# unique as well, otherwise the inversion would be ambiguous.
lang2lcode = {}
for code, language in lcode2lang_raw:
    assert language not in lang2lcode
    lang2lcode[language] = code

# With no duplicates on either side, both maps are exactly as long as the raw list
assert len(lcode2lang) == len(lcode2lang_raw)
assert len(lang2lcode) == len(lcode2lang_raw)
# some of the two letter langcodes get used elsewhere as three letters
# for example, Wolof is abbreviated "wo" in UD, but "wol" in Masakhane NER
two_to_three_letters_raw = (
    ("bm", "bam"),
    ("ee", "ewe"),
    ("ha", "hau"),
    ("ig", "ibo"),
    ("rw", "kin"),
    ("lg", "lug"),
    ("ny", "nya"),
    ("sn", "sna"),
    ("sw", "swa"),
    ("tn", "tsn"),
    ("tw", "twi"),
    ("wo", "wol"),
    ("xh", "xho"),
    ("yo", "yor"),
    ("zu", "zul"),
    # this is a weird case where a 2 letter code was available,
    # but UD used the 3 letter code instead
    ("se", "sme"),
)

# Register each pair as an alias: whichever side is already a known code
# stays canonical, and the other side becomes an alias for it in both maps.
for two, three in two_to_three_letters_raw:
    if two in lcode2lang:
        known, alias = two, three
    elif three in lcode2lang:
        known, alias = three, two
    else:
        raise AssertionError("Found a proposed alias %s -> %s when neither code was already known" % (two, three))
    # the alias must not collide with an existing code or language name
    assert alias not in lcode2lang
    assert alias not in lang2lcode
    lang2lcode[alias] = known
    lcode2lang[alias] = lcode2lang[known]

two_to_three_letters = dict(two_to_three_letters_raw)
three_to_two_letters = {three: two for two, three in two_to_three_letters_raw}
# neither direction may have repeated keys
assert len(two_to_three_letters) == len(two_to_three_letters_raw)
assert len(three_to_two_letters) == len(two_to_three_letters_raw)
# Additional code -> language entries.  These are added only after
# lang2lcode was built from lcode2lang_raw, so the duplicate-name checks
# there never see them (e.g. 'zh' would otherwise collide with 'zh-hans').
lcode2lang.update({
    'bgd': 'Bokota',              # ISO 639-3 code, although UD used sab
    'nb': 'Norwegian',            # Norwegian Bokmål mapped to default Norwegian
    'no': 'Norwegian',
    'zh': 'Simplified_Chinese',
})
# Alternative language names, as (code, language name) pairs, that
# lang_to_langcode should also accept.
#
# This is deliberately a list of pairs rather than a dict keyed by code.
# Several codes have more than one alternative name:
#   dv: Divehi / Maldivian, jaa: Jamamadí / Madí, ny: Chewa / Nyanja
# A dict literal with repeated keys silently keeps only the last entry, so
# the earlier names would be lost.
extra_lang_to_lcodes = [
    ("ab", "Abkhaz"),
    ("gsw", "Alemannic"),
    ("my", "Burmese"),
    ("ckb", "Central_Kurdish"),
    ("ny", "Chewa"),
    ("zh", "Chinese"),
    ("za", "Chuang"),
    ("dv", "Divehi"),
    ("eme", "Emerillon"),
    ("lij", "Genoese"),
    ("ga", "Gaelic"),
    ("ne", "Gorkhali"),
    ("ht", "Haitian_Creole"),
    ("ilo", "Ilokano"),
    ("nr", "isiNdebele"),
    ("xh", "isiXhosa"),
    ("zu", "isiZulu"),
    ("jaa", "Jamamadí"),
    ("kab", "Kabylian"),
    ("kl", "Kalaallisut"),
    ("km", "Khmer"),
    ("ky", "Kirghiz"),
    ("lb", "Letzeburgesch"),
    ("lg", "Luganda"),
    ("jaa", "Madí"),
    ("dv", "Maldivian"),
    ("mjl", "Mandeali"),
    ("skr", "Multani"),
    ("nb", "Norwegian"),
    ("kmr", "Kurmanji"),
    ("ny", "Nyanja"),
    ("sga", "Old_Gaelic"),
    ("or", "Oriya"),
    ("arr", "Ramarama"),
    ("sah", "Sakha"),
    ("nso", "Sepedi"),
    ("tn", "Setswana"),
    ("ii", "Sichuan_Yi"),
    ("si", "Sinhalese"),
    ("ss", "Siswati"),
    ("soj", "Sohi"),
    ("st", "Sesotho"),
    ("ve", "Tshivenda"),
    ("ts", "Xitsonga"),
    ("fy", "West_Frisian"),
    ("zza", "Zaza"),
]

for code, language in extra_lang_to_lcodes:
    # an alternative name must not shadow an existing language name...
    assert language not in lang2lcode
    # ...and must point at a code that is already known
    assert code in lcode2lang
    lang2lcode[language] = code
# treebank names changed from Old Russian to Old East Slavic in UD 2.8
lang2lcode['Old_Russian'] = 'orv'

# Case-insensitive lookup table: lowercased language name -> langcode
langlower2lcode = {name.lower(): lcode for name, lcode in lang2lcode.items()}
# UD treebank names whose short names cannot be derived mechanically:
# Chinese treebanks are split into simplified (zh-hans) and traditional
# (zh-hant), and Norwegian into Bokmål (nb) and Nynorsk (nn).
treebank_special_cases = dict([
    ("UD_Chinese-Beginner", "zh-hans_beginner"),
    ("UD_Chinese-GSDSimp", "zh-hans_gsdsimp"),
    ("UD_Chinese-GSD", "zh-hant_gsd"),
    ("UD_Chinese-HK", "zh-hant_hk"),
    ("UD_Chinese-CFL", "zh-hans_cfl"),
    ("UD_Chinese-PatentChar", "zh-hans_patentchar"),
    ("UD_Chinese-PUD", "zh-hant_pud"),
    ("UD_Norwegian-Bokmaal", "nb_bokmaal"),
    ("UD_Norwegian-Nynorsk", "nn_nynorsk"),
    ("UD_Norwegian-NynorskLIA", "nn_nynorsklia"),
])

# An already-shortened name: lowercase lang part (may contain '-'), '_',
# then a lowercase corpus name
SHORTNAME_RE = re.compile("^[a-z-]+_[a-z0-9-_]+$")
def langcode_to_lang(lcode):
    """
    Return the language name for lcode.

    The code is tried as given, then lowercased.  An unknown code is
    returned unchanged.
    """
    if lcode in lcode2lang:
        return lcode2lang[lcode]
    return lcode2lang.get(lcode.lower(), lcode)
def pretty_langcode_to_lang(lcode):
    """
    Return a human-readable language name for lcode.

    Underscores become spaces, and the two Chinese variants are rewritten
    as "Chinese (Simplified)" / "Chinese (Traditional)".
    """
    display_name = langcode_to_lang(lcode).replace("_", " ")
    chinese_variants = {
        'Simplified Chinese': 'Chinese (Simplified)',
        'Traditional Chinese': 'Chinese (Traditional)',
    }
    return chinese_variants.get(display_name, display_name)
def lang_to_langcode(lang):
    """
    Resolve a language name or code to a language code.

    The lookups are tried in this order:
      1. exact language name
      2. case-insensitive language name
      3. an exact language code
      4. a lowercased language code

    Raises UnknownLanguageError if none of them match.
    """
    if lang in lang2lcode:
        return lang2lcode[lang]
    lowered = lang.lower()
    if lowered in langlower2lcode:
        return langlower2lcode[lowered]
    if lang in lcode2lang:
        return lang
    if lowered in lcode2lang:
        return lowered
    raise UnknownLanguageError("Unable to find language code for %s" % lang)
RIGHT_TO_LEFT = {"ar", "arc", "az", "ckb", "dv", "ff", "he", "ku", "mzn", "nqo", "ps", "fa", "rhg", "sd", "syr", "ur"}

def is_right_to_left(lang):
    """
    Covers all the RtL languages we support, as well as many we don't.

    lang may be a language name or code; it is normalized with
    lang_to_langcode, so an unknown language raises UnknownLanguageError.
    If a language is left out, please let us know!
    """
    return lang_to_langcode(lang) in RIGHT_TO_LEFT
def treebank_to_short_name(treebank):
    """
    Convert a treebank name such as "UD_English-EWT" to a short name such as "en_ewt".

    Accepted inputs:
      - names listed in treebank_special_cases, returned from that table
      - names that already look short, where only the language part is normalized
      - "UD_<Language>-<Corpus>" / "<Language>-<Corpus>" / "<Language>_<Corpus>"

    The corpus part is lowercased, except for names that already look short.
    """
    if treebank in treebank_special_cases:
        return treebank_special_cases[treebank]

    # already a short name: only normalize the language part
    if SHORTNAME_RE.match(treebank):
        lang, corpus = treebank.split("_", 1)
        return lang_to_langcode(lang) + "_" + corpus

    name = treebank[3:] if treebank.startswith('UD_') else treebank

    zh_prefix_len = len("zh-hans")
    if name.startswith("zh-hans") or name.startswith("zh-hant"):
        # an already-converted Chinese name has a '-' inside its langcode,
        # so cut at the fixed prefix length instead of splitting on '-'
        pieces = (name[:zh_prefix_len], name[zh_prefix_len + 1:])
    else:
        pieces = name.split('-')
        if len(pieces) == 1:
            pieces = name.split("_", 1)
    assert len(pieces) == 2, "Unable to process %s" % name

    lang, corpus = pieces
    return "{}_{}".format(lang_to_langcode(lang), corpus.lower())
def treebank_to_langid(treebank):
    """ Return only the language code part of the treebank's short name. """
    langid, _, _ = treebank_to_short_name(treebank).partition("_")
    return langid
================================================
FILE: stanza/models/common/convert_pretrain.py
================================================
"""
A utility script to load a word embedding file from a text file and save it as a .pt
Run it as follows:
python stanza/models/common/convert_pretrain.py <.pt file> <vector file> [<# vectors>]
Note that -1 for # of vectors (the default) will keep all the vectors.
You probably want to keep fewer than that for most publicly released
embeddings, though, as they can get quite large.
As a concrete example, you can convert a newly downloaded Faroese WV file as follows:
python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/fo_farpahc.pretrain.pt ~/extern_data/wordvec/fasttext/faroese.txt -1
or save part of an Icelandic WV file:
python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/is_icepahc.pretrain.pt ~/extern_data/wordvec/fasttext/icelandic.cc.is.300.vec 150000
Note that if the pretrain already exists, nothing will be changed. It will not overwrite an existing .pt file.
"""
import argparse
import os
import sys
from stanza.models.common import pretrain
def main():
    """
    Convert a word vector file into a stanza .pt pretrain file.

    Positional arguments are the output .pt path, the input vector file
    (.csv files are passed via csv_filename) and an optional max vocab size.
    A max vocab of -1, the default, keeps every vector.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("output_pt", default=None, help="Where to write the converted PT file")
    cli.add_argument("input_vec", default=None, help="Unconverted vectors file")
    cli.add_argument("max_vocab", type=int, default=-1, nargs="?", help="How many vectors to convert. -1 means convert them all")
    args = cli.parse_args()

    if os.path.exists(args.output_pt):
        # only a notice; execution continues below
        # NOTE(review): presumably Pretrain reuses an existing .pt instead of
        # rewriting it, as the module docstring states - confirm
        print("Not overwriting existing pretrain file in %s" % args.output_pt)

    vectors_file = args.input_vec
    if vectors_file.endswith(".csv"):
        pt = pretrain.Pretrain(args.output_pt, max_vocab=args.max_vocab, csv_filename=vectors_file)
    else:
        pt = pretrain.Pretrain(args.output_pt, vectors_file, max_vocab=args.max_vocab)
    print("Pretrain is of size {}".format(len(pt.vocab)))


if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/common/count_ner_coverage.py
================================================
from stanza.models.common import pretrain
import argparse
def parse_args():
    """Parse the NER files to check and the pretrain file to check them against."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('ners', type=str, nargs='*', help='Which treebanks to run on')
    arg_parser.add_argument('--pretrain', type=str, default="/home/john/stanza_resources/hi/pretrain/hdtb.pt", help='Which pretrain to use')
    default_ner_files = ["/home/john/stanza/data/ner/hi_fire2013.train.csv",
                         "/home/john/stanza/data/ner/hi_fire2013.dev.csv"]
    arg_parser.set_defaults(ners=default_ner_files)
    return arg_parser.parse_args()
def read_ner(filename):
    """
    Read a tab-separated NER file and return the words that are part of an entity.

    Each non-blank line is split on tabs; the first column is the word and
    the second the tag.  Words tagged 'O' and blank lines are skipped.

    The file is read as UTF-8 and closed afterwards.  The previous
    version left the handle open and depended on the locale encoding.
    """
    words = []
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            columns = line.split("\t")
            if columns[1] == 'O':
                continue
            words.append(columns[0])
    return words
def count_coverage(pretrain, words):
    """
    Return the fraction of words found in pretrain.vocab.

    An empty word list returns 0.0 instead of raising ZeroDivisionError.
    This can happen for a file whose tags are all 'O'.
    """
    if not words:
        return 0.0
    found = sum(1 for w in words if w in pretrain.vocab)
    return found / len(words)
# Script body: for each NER file, print its name, then the fraction of its
# entity words that appear in the pretrain vocabulary, then a blank line.
args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
for ner_file in args.ners:
    entity_words = read_ner(ner_file)
    print(ner_file)
    coverage = count_coverage(pt, entity_words)
    print(coverage)
    print()
================================================
FILE: stanza/models/common/count_pretrain_coverage.py
================================================
"""A simple script to count the fraction of words in a UD dataset which are in a particular pretrain.
For example, this script shows that the word2vec Armenian vectors,
truncated at 250K words, have 75% coverage of the Western Armenian
dataset, whereas the vectors available here have 88% coverage:
https://github.com/ispras-texterra/word-embeddings-eval-hy
"""
from stanza.models.common import pretrain
from stanza.utils.conll import CoNLL
import argparse
def parse_args():
    """Parse the CoNLL-U treebank files to check and the pretrain to check them against."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('treebanks', type=str, nargs='*', help='Which treebanks to run on')
    arg_parser.add_argument('--pretrain', type=str, default="/home/john/extern_data/wordvec/glove/armenian.pt", help='Which pretrain to use')
    default_treebanks = ["/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Western_Armenian-ArmTDP/hyw_armtdp-ud-train.conllu",
                         "/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu"]
    arg_parser.set_defaults(treebanks=default_treebanks)
    return arg_parser.parse_args()
# Script body: load the pretrain, report its size, then for each treebank
# print the fraction of word tokens whose text is in the pretrain vocabulary.
args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
pt.load()
print("Pretrain stats: {} vectors, {} dim".format(len(pt.vocab), pt.emb[0].shape[0]))

for treebank in args.treebanks:
    print(treebank)
    doc = CoNLL.conll2doc(treebank)
    word_texts = [word.text for sentence in doc.sentences for word in sentence.words]
    total = len(word_texts)
    found = sum(1 for text in word_texts if text in pt.vocab)
    print(found / total)
================================================
FILE: stanza/models/common/crf.py
================================================
"""
CRF loss and viterbi decoding.
"""
import math
from numbers import Number
import numpy as np
import torch
from torch import nn
import torch.nn.init as init
class CRFLoss(nn.Module):
    """
    Calculate log-space crf loss, given unary potentials, a transition matrix
    and gold tag sequences.
    """
    def __init__(self, num_tag, batch_average=True):
        """
        num_tag: number of tags; the learned transition matrix is num_tag x num_tag
            and starts at all zeros
        batch_average: if True, the summed loss is divided by the batch size;
            otherwise by the number of unmasked positions
        """
        super().__init__()
        # _transitions[i, j] is the score of moving from tag i to tag j (see crf_binary_score)
        self._transitions = nn.Parameter(torch.zeros(num_tag, num_tag))
        self._batch_average = batch_average # if not batch average, average on all tokens

    def forward(self, inputs, masks, tag_indices):
        """
        inputs: batch_size x seq_len x num_tags
        masks: batch_size x seq_len, True (nonzero) at padding positions -
            those positions are zeroed out with masked_fill_ below
        tag_indices: batch_size x seq_len
        @return:
            loss: CRF negative log likelihood on all instances.
            transitions: the transition matrix
        """
        # TODO: handle <start> and <end> tags
        # NOTE(review): the TODO originally read "handle and tags"; the missing
        # words are presumably start/end tags - confirm
        input_bs, input_sl, input_nc = inputs.size()
        unary_scores = self.crf_unary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)
        binary_scores = self.crf_binary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)
        log_norm = self.crf_log_norm(inputs, masks, tag_indices)
        # score of the gold path minus the log partition function
        log_likelihood = unary_scores + binary_scores - log_norm # batch_size
        loss = torch.sum(-log_likelihood)
        if self._batch_average:
            loss = loss / input_bs
        else:
            # masks.eq(0) selects the non-padding positions
            total = masks.eq(0).sum()
            # epsilon guards against dividing by zero when everything is padding
            loss = loss / (total + 1e-8)
        return loss, self._transitions

    def crf_unary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):
        """
        Per instance, the sum of the input scores of the gold tag at each unmasked position.

        @return:
            unary_scores: batch_size
        """
        # flatten each instance to seq_len * num_tags entries, so (position t, tag k)
        # lives at index t * num_tags + k
        flat_inputs = inputs.view(input_bs, -1)
        flat_tag_indices = tag_indices + torch.arange(input_sl, device=tag_indices.device).long().unsqueeze(0) * input_nc
        unary_scores = torch.gather(flat_inputs, 1, flat_tag_indices).view(input_bs, -1)
        # padded positions contribute nothing
        unary_scores.masked_fill_(masks, 0)
        return unary_scores.sum(dim=1)

    def crf_binary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):
        """
        Per instance, the sum of transition scores between consecutive gold tags.

        @return:
            binary_scores: batch_size
        """
        # get number of transitions
        nt = tag_indices.size(-1) - 1
        start_indices = tag_indices[:, :nt]
        end_indices = tag_indices[:, 1:]
        # flat matrices
        # the transition i -> j lives at index i * num_tags + j of the flattened matrix
        flat_transition_indices = start_indices * input_nc + end_indices
        flat_transition_indices = flat_transition_indices.view(-1)
        flat_transition_matrix = self._transitions.view(-1)
        binary_scores = torch.gather(flat_transition_matrix, 0, flat_transition_indices)\
                .view(input_bs, -1)
        # a transition into a padded position is dropped
        score_masks = masks[:, 1:]
        binary_scores.masked_fill_(score_masks, 0)
        return binary_scores.sum(dim=1)

    def crf_log_norm(self, inputs, masks, tag_indices):
        """
        Calculate the CRF partition in log space for each instance, following:
            http://www.cs.columbia.edu/~mcollins/fb.pdf
        @return:
            log_norm: batch_size
        """
        start_inputs = inputs[:,0,:] # bs x nc
        rest_inputs = inputs[:,1:,:]
        # TODO: technically we need to pay attention to the initial
        # value being masked.  Currently we do compensate for the
        # entire row being masked at the end of the operation
        rest_masks = masks[:,1:]
        alphas = start_inputs # bs x nc
        trans = self._transitions.unsqueeze(0) # 1 x nc x nc
        # accumulate alphas in log space
        # (forward algorithm: alphas[b, j] is the log-sum of scores of all
        # prefixes ending in tag j)
        for i in range(rest_inputs.size(1)):
            # [b, prev, next] = alphas[b, prev] + transitions[prev, next]
            transition_scores = alphas.unsqueeze(2) + trans # bs x nc x nc
            new_alphas = rest_inputs[:,i,:] + log_sum_exp(transition_scores, dim=1)
            m = rest_masks[:,i].unsqueeze(1).expand_as(new_alphas) # bs x nc, 1 for padding idx
            # apply masks
            # at padded positions keep the previous alphas, so padding does not change the sum
            new_alphas.masked_scatter_(m, alphas.masked_select(m))
            alphas = new_alphas
        log_norm = log_sum_exp(alphas, dim=1)

        # if any row was entirely masked, we just turn its log denominator to 0
        # eg, the empty summation for the denominator will be 1, and its log will be 0
        all_masked = torch.all(masks, dim=1)
        log_norm = log_norm * torch.logical_not(all_masked)
        return log_norm
def viterbi_decode(scores, transition_params):
    """
    Find the highest-scoring tag sequence with the Viterbi algorithm.

    scores: seq_len x num_tags (numpy array) of per-position tag scores
    transition_params: num_tags x num_tags (numpy array); entry [i, j] is
        the score of moving from tag i to tag j
    @return:
        viterbi: a list of tag ids with highest score
        viterbi_score: the highest score
    """
    num_steps = scores.shape[0]
    trellis = np.zeros_like(scores)
    backpointers = np.zeros_like(scores, dtype=np.int32)
    trellis[0] = scores[0]
    for step in range(1, num_steps):
        # candidates[prev, cur] = best score ending in prev, then moving to cur
        candidates = trellis[step - 1][:, np.newaxis] + transition_params
        trellis[step] = scores[step] + candidates.max(axis=0)
        backpointers[step] = candidates.argmax(axis=0)

    # walk the backpointers from the last position back to the first
    best_path = [np.argmax(trellis[-1])]
    for step in range(num_steps - 1, 0, -1):
        best_path.append(backpointers[step][best_path[-1]])
    best_path.reverse()
    return best_path, np.max(trellis[-1])
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    value.exp().sum(dim, keepdim).log()

    The maximum is subtracted before exponentiating to avoid overflow.  When
    that maximum is infinite (e.g. a slice that is entirely -inf), 0 is used
    as the shift instead.  Otherwise inf - inf would produce NaN rather than
    the correct -inf (or +inf) result.  For finite inputs the computation is
    unchanged.
    """
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        m = torch.where(torch.isinf(m), torch.zeros_like(m), m)
        value0 = value - m
        if keepdim is False:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0),
                                       dim=dim, keepdim=keepdim))
    else:
        m = torch.max(value)
        m = torch.where(torch.isinf(m), torch.zeros_like(m), m)
        sum_exp = torch.sum(torch.exp(value - m))
        if isinstance(sum_exp, Number):
            return m + math.log(sum_exp)
        else:
            return m + torch.log(sum_exp)
================================================
FILE: stanza/models/common/data.py
================================================
"""
Utility functions for data transformations.
"""
import logging
import random
import torch
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.doc import HEAD, ID, UPOS
logger = logging.getLogger('stanza')
def map_to_ids(tokens, vocab):
    """Look up each token in vocab; tokens not in vocab map to constant.UNK_ID."""
    unk_id = constant.UNK_ID
    return [vocab[token] if token in vocab else unk_id for token in tokens]
def get_long_tensor(tokens_list, batch_size, pad_id=constant.PAD_ID):
    """
    Convert (list of )+ tokens to a padded LongTensor.

    The result has shape (batch_size, *dims), where dims holds the longest
    length found at each level of list nesting.  Unused cells hold pad_id.
    """
    dims = []
    level = tokens_list
    while isinstance(level[0], list):
        dims.append(max(map(len, level)))
        # flatten one level of nesting to measure the next one
        level = [item for sublist in level for item in sublist]
    # TODO: pass in a device parameter and put it directly on the relevant device?
    # that might be faster than creating it and then moving it
    padded = torch.LongTensor(batch_size, *dims).fill_(pad_id)
    for row, ids in enumerate(tokens_list):
        padded[row, :len(ids)] = torch.LongTensor(ids)
    return padded
def get_float_tensor(features_list, batch_size):
    """
    Pad per-position feature vectors into a zero-filled FloatTensor.

    features_list holds one sequence per instance; each sequence is a list
    of equal-length feature vectors.  The result has shape
    (batch_size, longest sequence, feature length).  Returns None when
    features_list or its first entry is None.
    """
    if features_list is None or features_list[0] is None:
        return None
    longest = max(map(len, features_list))
    width = len(features_list[0][0])
    padded = torch.FloatTensor(batch_size, longest, width).zero_()
    for row, feats in enumerate(features_list):
        padded[row, :len(feats), :] = torch.FloatTensor(feats)
    return padded
def sort_all(batch, lens):
    """
    Sort every field of batch by descending lens.

    Returns (sorted fields, original indices).  Ties are broken by the
    larger original index first.
    """
    if batch == [[]]:
        return [[]], []
    columns = [lens, range(len(lens)), *batch]
    rows = sorted(zip(*columns), reverse=True)
    sorted_columns = [list(column) for column in zip(*rows)]
    return sorted_columns[2:], sorted_columns[1]
def get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate, desired_ratio=0.1, max_ratio=0.5):
    """
    Compute what fraction of the augmentable sentences to augment.

    Returns X such that augmenting X * (number of sentences accepted by
    can_augment_predicate) sentences leaves about desired_ratio (10% by
    default) of the N sentences without the feature that
    should_augment_predicate detects.  The ratio assumes the final dataset
    is of size N rather than N + X * N.

    The result is capped at max_ratio.  It is 0.0 when no sentence can be
    augmented, or when enough sentences already lack the feature.

    should_augment_predicate: returns True if the sentence has some
      feature which we may want to change occasionally.  for example,
      depparse sentences which end in punct
    can_augment_predicate: in the depparse sentences example, it is
      technically possible for the punct at the end to be the parent
      of some other word in the sentence.  in that case, the sentence
      should not be chosen.  should be at least as restrictive as
      should_augment_predicate

    Raises AssertionError if can_augment_predicate accepts a sentence that
    should_augment_predicate rejects.
    """
    total = len(train_data)
    should_count = sum(should_augment_predicate(sentence) for sentence in train_data)
    can_count = sum(can_augment_predicate(sentence) for sentence in train_data)
    inconsistent = sum(can_augment_predicate(sentence) and not should_augment_predicate(sentence)
                       for sentence in train_data)
    if inconsistent > 0:
        raise AssertionError("can_augment_predicate allowed sentences not allowed by should_augment_predicate")

    if can_count == 0:
        logger.warning("Found no sentences which matched can_augment_predicate {}".format(can_augment_predicate))
        return 0.0

    # how many more sentences must lose the feature to reach the target
    needed = total * desired_ratio - (total - should_count)
    if needed < 0:
        # the target is already met
        return 0.0
    return min(needed / can_count, max_ratio)
def should_augment_nopunct_predicate(sentence):
    """True when the sentence's final word has UPOS 'PUNCT'."""
    return sentence[-1].get(UPOS) == 'PUNCT'
def can_augment_nopunct_predicate(sentence):
    """
    Check whether the final PUNCT word can be removed safely.

    The sentence must end with a PUNCT word.  That word must not be part of
    a multi-word token, and no other word may have it as its head.
    """
    final_word = sentence[-1]
    if final_word.get(UPOS) != 'PUNCT':
        return False
    final_id = final_word[ID]
    # an ID with more than one element belongs to an MWT; don't cut those off
    if len(final_id) > 1:
        return False
    # removing the word would orphan any dependents it has
    return not any(len(word[ID]) == 1 and word[HEAD] == final_id[0] for word in sentence)
def augment_punct(train_data, augment_ratio,
                  should_augment_predicate=should_augment_nopunct_predicate,
                  can_augment_predicate=can_augment_nopunct_predicate,
                  keep_original_sentences=True):
    """
    Adds extra training data to compensate for some models having all sentences end with PUNCT

    Some of the models (for example, UD_Hebrew-HTB) have the flaw that
    all of the training sentences end with PUNCT.  The model therefore
    learns to finish every sentence with punctuation, even if it is
    given a sentence with non-punct at the end.

    One simple way to fix this is to train on some fraction of the
    training data with the final punct removed.

    Each sentence accepted by can_augment_predicate, with more than one
    word, is replaced with probability augment_ratio by a copy missing its
    last word.  With keep_original_sentences=True every other sentence is
    kept unchanged.  With keep_original_sentences=False only the shortened
    copies are returned.

    Params:
      train_data: list of list of dicts, eg a conll doc
      augment_ratio: the fraction to augment.  if None, a best guess is made to get to 10%
      should_augment_predicate: a function which returns T/F if a sentence already ends with not PUNCT
      can_augment_predicate: a function which returns T/F if it makes sense to remove the last PUNCT
      keep_original_sentences: whether sentences that are not shortened are included in the result

    TODO: do this dynamically, as part of the DataLoader or elsewhere?
    One complication is the data comes back from the DataLoader as
    tensors & indices, so it is much more complicated to manipulate
    """
    if len(train_data) == 0:
        return []

    if augment_ratio is None:
        augment_ratio = get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate)

    if augment_ratio <= 0:
        if keep_original_sentences:
            return list(train_data)
        else:
            return []

    new_data = []
    for sentence in train_data:
        if (can_augment_predicate(sentence)
                and random.random() < augment_ratio
                and len(sentence) > 1):
            # todo: could deep copy the words
            #       or not deep copy any of this
            new_data.append(list(sentence[:-1]))
        elif keep_original_sentences:
            # Keep the original sentence: either it cannot be augmented or it
            # was not selected.  The old code appended the stale (or undefined)
            # new_sentence here instead.
            new_data.append(sentence)

    return new_data
================================================
FILE: stanza/models/common/doc.py
================================================
"""
Basic data structures
"""
import io
from itertools import repeat
import re
import json
import pickle
import warnings
from enum import Enum
import networkx as nx
from stanza.models.common.stanza_object import StanzaObject
from stanza.models.common.utils import misc_to_space_after, space_after_to_misc, misc_to_space_before, space_before_to_misc
from stanza.models.ner.utils import decode_from_bioes
from stanza.models.constituency import tree_reader
from stanza.models.coref.coref_chain import CorefMention, CorefChain, CorefAttachment
class MWTProcessingType(Enum):
    """How a token should be treated when multi-word tokens are expanded."""
    FLATTEN = 0  # collapse the current token into a single ID rather than an MWT
    PROCESS = 1  # treat the current token as an MWT and expand it accordingly
    SKIP = 2     # leave the token untouched and only advance the IDs
# Matches multi-word token ID ranges such as "3-4"
multi_word_token_id = re.compile(r"([0-9]+)-([0-9]+)")
# Matches a MISC field that contains MWT=Yes
multi_word_token_misc = re.compile(r".*MWT=Yes.*")

MEXP = 'manual_expansion'

# Names of the annotation fields
ID = 'id'
TEXT = 'text'
LEMMA = 'lemma'
UPOS = 'upos'
XPOS = 'xpos'
FEATS = 'feats'
HEAD = 'head'
DEPREL = 'deprel'
DEPS = 'deps'
MISC = 'misc'
NER = 'ner'
MULTI_NER = 'multi_ner' # will represent tags from multiple NER models
START_CHAR = 'start_char'
END_CHAR = 'end_char'
TYPE = 'type'
SENTIMENT = 'sentiment'
CONSTITUENCY = 'constituency'
COREF_CHAINS = 'coref_chains'
LINE_NUMBER = 'line_number'
MORPHEMES = 'morphemes'

# field indices when converting the document to conll (the ten CoNLL-U columns, in order)
FIELD_TO_IDX = {field: idx for idx, field in enumerate((ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC))}
FIELD_NUM = len(FIELD_TO_IDX)

DEFAULT_OUTPUT_FIELDS = [ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, NER, MULTI_NER, MEXP, COREF_CHAINS, MORPHEMES]
# the same fields without the character offsets
NO_OFFSETS_OUTPUT_FIELDS = [field for field in DEFAULT_OUTPUT_FIELDS if field not in (START_CHAR, END_CHAR)]
class DocJSONEncoder(json.JSONEncoder):
    """JSON encoder that can also serialize coref mentions and attachments."""
    def default(self, obj):
        if isinstance(obj, CorefMention):
            # mentions are written out as their attribute dict
            return obj.__dict__
        elif isinstance(obj, CorefAttachment):
            return obj.to_json()
        # anything else is handed to the base encoder
        return json.JSONEncoder.default(self, obj)
class Document(StanzaObject):
""" A document class that stores attributes of a document and carries a list of sentences.
"""
def __init__(self, sentences, text=None, comments=None, empty_sentences=None):
    """ Construct a document given a list of sentences in the form of lists of CoNLL-U dicts.

    Args:
        sentences: a list of sentences, each being a list of token entries in the form of a CoNLL-U dict.
        text: the raw text of the document.  When given, build_ents() and
            mark_whitespace() are also run at the end of construction.
        comments: A list of list of strings to use as comments on the sentences, either None or the same length as sentences
        empty_sentences: optional per-sentence empty words, passed to each
            Sentence as empty_words; None means no sentence has empty words
    """
    self._sentences = []
    self._lang = None
    self._text = text
    self._num_tokens = 0
    self._num_words = 0

    # builds the Sentence objects, the token/word counts and the per-sentence comments
    self._process_sentences(sentences, comments, empty_sentences)
    # NOTE(review): _ents and _coref are only initialized after the sentences are built
    self._ents = []
    self._coref = []
    if self._text is not None:
        self.build_ents()
        self.mark_whitespace()
def mark_whitespace(self):
    """
    Record the raw text between neighboring tokens.

    Every token gets spaces_after: the text up to the next token, or up to
    the end of the document for the very last token.  The very first token
    also gets spaces_before: the text before it.  self._text must not be None.
    """
    raw = self._text
    sentences = self._sentences

    # gaps between adjacent tokens within one sentence
    # TODO: itertools.pairwise, once we move to minimum 3.10
    for sentence in sentences:
        toks = sentence.tokens
        for left, right in zip(toks, toks[1:]):
            left.spaces_after = raw[left.end_char:right.start_char]

    # gap between the last token of a sentence and the first token of the next
    for left_sentence, right_sentence in zip(sentences, sentences[1:]):
        left = left_sentence.tokens[-1]
        right = right_sentence.tokens[0]
        left.spaces_after = raw[left.end_char:right.start_char]

    # trailing text after the final token
    if sentences and sentences[-1].tokens:
        last = sentences[-1].tokens[-1]
        last.spaces_after = raw[last.end_char:]

    # leading text before the first token
    if sentences and sentences[0].tokens:
        first = sentences[0].tokens[0]
        first.spaces_before = raw[:first.start_char]
@property
def lang(self):
    """ The language of this document (None until it is set). """
    return self._lang

@lang.setter
def lang(self, value):
    """ Replace the language of this document. """
    self._lang = value
@property
def text(self):
    """ The raw text this document was built from, or None. """
    return self._text

@text.setter
def text(self, value):
    """ Replace the raw text of this document. """
    self._text = value
@property
def sentences(self):
    """ The list of sentences that make up this document. """
    return self._sentences

@sentences.setter
def sentences(self, value):
    """ Replace the list of sentences of this document. """
    self._sentences = value
@property
def num_tokens(self):
    """ How many tokens this document contains. """
    return self._num_tokens

@num_tokens.setter
def num_tokens(self, value):
    """ Overwrite the stored token count. """
    self._num_tokens = value
@property
def num_words(self):
    """ How many words this document contains. """
    return self._num_words

@num_words.setter
def num_words(self, value):
    """ Overwrite the stored word count. """
    self._num_words = value
@property
def ents(self):
    """ The entities found in this document. """
    return self._ents

@ents.setter
def ents(self, value):
    """ Replace the entities of this document. """
    self._ents = value
@property
def entities(self):
    """ Alias of `ents`: the entities found in this document. """
    return self._ents

@entities.setter
def entities(self, value):
    """ Replace the entities of this document (same storage as `ents`). """
    self._ents = value
def _process_sentences(self, sentences, comments=None, empty_sentences=None):
    """
    Build Sentence objects for this document and attach their comments.

    sentences: list of sentences, each a list of token entries
    comments: None, or one list of comment strings per sentence
    empty_sentences: None, or one empty_words value per sentence

    Raises IndexError / ValueError (chained to the original error) naming the
    index of the sentence that could not be built.
    Also fills in each sentence's text (from the document text or a
    "# text" comment), index, sent_id and speaker, and updates the word counts.
    """
    self.sentences = []
    if empty_sentences is None:
        # an endless stream of [] so the zip below never runs short
        empty_sentences = repeat([])
    for sent_idx, (tokens, empty_words) in enumerate(zip(sentences, empty_sentences)):
        try:
            sentence = Sentence(tokens, doc=self, empty_words=empty_words)
        except IndexError as e:
            raise IndexError("Could not process document at sentence %d" % sent_idx) from e
        except ValueError as e:
            tokens = ["|%s|" % t for t in tokens]
            tokens = ", ".join(tokens)
            raise ValueError("Could not process document at sentence %d\n Raw tokens: %s" % (sent_idx, tokens)) from e
        self.sentences.append(sentence)
        # slice the sentence text out of the document text when all offsets are known
        begin_idx, end_idx = sentence.tokens[0].start_char, sentence.tokens[-1].end_char
        if all((self.text is not None, begin_idx is not None, end_idx is not None)): sentence.text = self.text[begin_idx: end_idx]
        sentence.index = sent_idx

    self._count_words()

    # Add a #text comment to each sentence in a doc if it doesn't already exist
    if not comments:
        comments = [[] for x in self.sentences]
    else:
        # copy so that appending below does not modify the caller's lists
        comments = [list(x) for x in comments]
    for sentence, sentence_comments in zip(self.sentences, comments):
        # the space after text can occur in treebanks such as the Naija-NSC treebank,
        # which extensively uses `# text_en =` and `# text_ortho`
        if sentence.text and not any(comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text=") for comment in sentence_comments):
            # split/join to handle weird whitespace, especially newlines
            sentence_comments.append("# text = " + ' '.join(sentence.text.split()))
        elif not sentence.text:
            # no text from the document: recover it from an existing "# text" comment
            for comment in sentence_comments:
                if comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text="):
                    sentence.text = comment.split("=", 1)[-1].strip()
                    break

        for comment in sentence_comments:
            sentence.add_comment(comment)

        # look for sent_id in the comments
        # if it's there, overwrite the sent_idx id from above
        for comment in sentence_comments:
            if comment.startswith("# sent_id"):
                sentence.sent_id = comment.split("=", 1)[-1].strip()
                break
        else:
            # no sent_id found.  add a comment with our enumerated id
            # setting the sent_id on the sentence will automatically add the comment
            sentence.sent_id = str(sentence.index)

        # look for speaker in the comments
        for comment in sentence_comments:
            if comment.startswith("# speaker"):
                sentence.speaker = comment.split("=", 1)[-1].strip()
                break
        else:
            sentence.speaker = None
def _count_words(self):
    """
    Recompute num_tokens and num_words by summing over all sentences.
    """
    sentences = self.sentences
    self.num_tokens = sum(len(s.tokens) for s in sentences)
    self.num_words = sum(len(s.words) for s in sentences)
def get(self, fields, as_sentences=False, from_token=False):
    """ Get fields from a list of field names.

    If only one field name (string or singleton list/tuple) is provided,
    return a list of that field; if more than one, return a list of list.
    Note that all returned fields are after multi-word expansion.

    Field names may be given as a list or a tuple.  The tuple form matches
    what `set` accepts; it used to be rejected here.

    Args:
        fields: name of the fields as a list/tuple or a single string
        as_sentences: if True, return the fields as a list of sentences; otherwise as a whole list
        from_token: if True, get the fields from Token; otherwise from Word

    Returns:
        All requested fields.

    Raises:
        AssertionError: if fields is not a string, list or tuple, or is empty
    """
    if isinstance(fields, str):
        fields = [fields]
    assert isinstance(fields, (tuple, list)), "Must provide field names as a list."
    assert len(fields) >= 1, "Must have at least one field."

    single_field = len(fields) == 1
    results = []
    for sentence in self.sentences:
        # decide word or token
        units = sentence.tokens if from_token else sentence.words
        if single_field:
            cursent = [getattr(unit, fields[0]) for unit in units]
        else:
            cursent = [[getattr(unit, field) for field in fields] for unit in units]
        # decide whether append the results as a sentence or a whole list
        if as_sentences:
            results.append(cursent)
        else:
            results.extend(cursent)
    return results
def set(self, fields, contents, to_token=False, to_sentence=False):
    """Set fields based on contents. If only one field (string or
    singleton list) is provided, then a list of content will be
    expected; otherwise a list of list of contents will be expected.

    Args:
        fields: name of the fields as a list or a single string
        contents: field values to set; one item per sentence if `to_sentence`,
            otherwise one item per token (`to_token`) or per word
        to_token: if True, set field values to tokens; otherwise to words
        to_sentence: if True, set field values to sentences.
            May not be combined with `to_token`
    """
    if isinstance(fields, str):
        fields = [fields]
    assert isinstance(fields, (tuple, list)), "Must provide field names as a list."
    assert isinstance(contents, (tuple, list)), "Must provide contents as a list (one item per line)."
    assert len(fields) >= 1, "Must have at least one field."

    assert not to_sentence or not to_token, "Both to_token and to_sentence set to True, which is very confusing"

    if to_sentence:
        assert len(self.sentences) == len(contents), \
            "Contents must have the same length as the sentences"
        for sentence, content in zip(self.sentences, contents):
            if len(fields) == 1:
                setattr(sentence, fields[0], content)
            else:
                for field, piece in zip(fields, content):
                    setattr(sentence, field, piece)
    else:
        # compare against the count for the level actually being written,
        # so a token-level call cannot be satisfied by a word-length list
        expected_length = self.num_tokens if to_token else self.num_words
        assert expected_length == len(contents), \
            "Contents must have the same length as the original file."
        cidx = 0
        for sentence in self.sentences:
            # decide word or token
            units = sentence.tokens if to_token else sentence.words
            for unit in units:
                if len(fields) == 1:
                    setattr(unit, fields[0], contents[cidx])
                else:
                    for field, content in zip(fields, contents[cidx]):
                        setattr(unit, field, content)
                cidx += 1
def set_mwt_expansions(self, expansions,
                       fake_dependencies=False,
                       process_manual_expanded=None):
    """ Extend the multi-word tokens annotated by tokenizer.

    `expansions` is consumed in order: one space-separated string of words
    for each token that gets processed.  Use `process_manual_expanded` to
    choose which tokens are processed.

    There are two kinds of tokens to expand: those whose `misc` matches
    `multi_word_token_misc` (an `MWT=Yes` piece is stripped from `misc` once
    such a token is expanded), and those with `manual_expansion` set, meaning
    the user manually specified the expansion through a postprocessor.

    process_manual_expanded = None  - default; process tokens whose misc matches the
                                       MWT marker or whose id already spans several words
                            = True  - process only tokens with `manual_expansion` set
                            = False - process tokens whose misc matches the MWT marker;
                                       manually expanded tokens that do not match are kept
                                       and only renumbered
    In every mode, any remaining token is flattened: its words get a single
    id and lose their head, deprel and deps.

    Args:
        expansions: list of strings, one per processed token, each holding the
            space-separated texts of the words for that token
        fake_dependencies: if True, call build_fake_dependencies() on every
            sentence afterwards; otherwise rebuild_dependencies()
        process_manual_expanded: None / True / False as described above

    Raises:
        AssertionError: if the number of expansions consumed differs from len(expansions)
    """
    # idx_e: index into `expansions`, counted across the whole document
    idx_e = 0
    for sentence in self.sentences:
        # idx_w: 1-based id of the current word within this sentence
        idx_w = 0
        for token in sentence.tokens:
            idx_w += 1
            is_multi = (len(token.id) > 1)
            is_mwt = (multi_word_token_misc.match(token.misc) if token.misc is not None else None)
            is_manual_expansion = token.manual_expansion

            perform_mwt_processing = MWTProcessingType.FLATTEN

            # the order of these checks matters: with process_manual_expanded == False,
            # a token that matches the MWT marker is processed even if it is also manually expanded
            if (process_manual_expanded and is_manual_expansion):
                perform_mwt_processing = MWTProcessingType.PROCESS
            elif (process_manual_expanded==False and is_mwt):
                perform_mwt_processing = MWTProcessingType.PROCESS
            elif (process_manual_expanded==False and is_manual_expansion):
                perform_mwt_processing = MWTProcessingType.SKIP
            elif (process_manual_expanded==None and (is_mwt or is_multi)):
                perform_mwt_processing = MWTProcessingType.PROCESS

            if perform_mwt_processing == MWTProcessingType.FLATTEN:
                # every word of the token gets the same id idx_w;
                # token.id is only updated if the token has at least one word
                for word in token.words:
                    token.id = (idx_w, )
                    # delete dependency information
                    word.deps = None
                    word.head, word.deprel = None, None
                    word.id = idx_w
            elif perform_mwt_processing == MWTProcessingType.PROCESS:
                expanded = [x for x in expansions[idx_e].split(' ') if len(x) > 0]
                # in the event the MWT annotator only split the
                # Token into a single Word, we preserve its text
                # otherwise the Token's text is different from its
                # only Word's text
                if len(expanded) == 1:
                    expanded = [token.text]
                idx_e += 1
                idx_w_end = idx_w + len(expanded) - 1
                if token.misc:  # None can happen when using a prebuilt doc
                    token.misc = None if token.misc == 'MWT=Yes' else '|'.join([x for x in token.misc.split('|') if x != 'MWT=Yes'])
                token.id = (idx_w, idx_w_end) if len(expanded) > 1 else (idx_w,)
                # the token's old words are discarded and replaced by fresh Words
                token.words = []
                for i, e_word in enumerate(expanded):
                    token.words.append(Word(sentence, {ID: idx_w + i, TEXT: e_word}))
                idx_w = idx_w_end
            elif perform_mwt_processing == MWTProcessingType.SKIP:
                # keep the existing words but shift their ids
                # NOTE(review): the shift is idx_e, the number of expansions consumed so far
                # in the whole document, not the number of extra words added in this
                # sentence; confirm this is the intended offset
                token.id = tuple(orig_id + idx_e for orig_id in token.id)
                for i in token.words:
                    i.id += idx_e
                idx_w = token.id[-1]
                token.manual_expansion = None

        # reprocess the words using the new tokens
        sentence.words = []
        for token in sentence.tokens:
            token.sent = sentence
            for word in token.words:
                word.sent = sentence
                word.parent = token
                sentence.words.append(word)
            # recompute word character offsets from the token's offsets
            if len(token.words) == 1:
                word.start_char = token.start_char
                word.end_char = token.end_char
            elif token.start_char is not None and token.end_char is not None:
                # match the word texts, in order and separated by optional whitespace,
                # against the whole token text
                search_string = "^%s$" % ("\\s*".join("(%s)" % re.escape(word.text) for word in token.words))
                match = re.compile(search_string).match(token.text)
                if match:
                    for word_idx, word in enumerate(token.words):
                        word.start_char = match.start(word_idx+1) + token.start_char
                        word.end_char = match.end(word_idx+1) + token.start_char
        if fake_dependencies:
            sentence.build_fake_dependencies()
        else:
            sentence.rebuild_dependencies()

    self._count_words() # update number of words & tokens
    assert idx_e == len(expansions), "{} {}".format(idx_e, len(expansions))
    return
def get_mwt_expansions(self, evaluation=False):
    """ Get the multi-word tokens which are candidates for expansion.

    A token is included if its id spans several words and it was not
    manually expanded, or if its `misc` field matches
    `multi_word_token_misc` (in which case manual expansion is ignored).

    Args:
        evaluation: if False (default, used for training), return a list of
            [token text, space-joined word texts] pairs; if True, return
            only the token texts.

    Returns:
        The list of expansions described above.
    """
    expansions = []
    for sentence in self.sentences:
        for token in sentence.tokens:
            is_multi = len(token.id) > 1
            is_mwt = multi_word_token_misc.match(token.misc) if token.misc is not None else None
            is_manual_expansion = token.manual_expansion
            if (is_multi and not is_manual_expansion) or is_mwt:
                src = token.text
                dst = ' '.join(word.text for word in token.words)
                expansions.append([src, dst])
    if evaluation:
        return [src for src, _ in expansions]
    return expansions
def build_ents(self):
    """
    Ask every sentence to build its entities, store the concatenation of
    those lists in `self.ents`, and return it.
    """
    self.ents = [entity for sent in self.sentences for entity in sent.build_ents()]
    return self.ents
def sort_features(self):
    """
    Reorder the `|`-separated feats of every word case-insensitively.
    Useful for prototype treebanks, for example.  Words with empty or
    missing feats are left untouched.
    """
    all_words = (w for sent in self.sentences for w in sent.words)
    for w in all_words:
        if w.feats:
            w.feats = "|".join(sorted(w.feats.split("|"), key=str.casefold))
def iter_words(self):
    """ Lazily yield every Word of every sentence, in document order. """
    for sent in self.sentences:
        for w in sent.words:
            yield w
def iter_tokens(self):
    """ Lazily yield every Token of every sentence, in document order. """
    for sent in self.sentences:
        for tok in sent.tokens:
            yield tok
def sentence_comments(self):
    """ Return one new list per sentence holding a copy of that sentence's comments. """
    return [list(sent.comments) for sent in self.sentences]
@property
def coref(self):
    """
    The coreference chains currently attached to this document.
    """
    return self._coref

@coref.setter
def coref(self, chains):
    """ Replace the document's coref chains and re-annotate the words they mention. """
    self._coref = chains
    self._attach_coref_mentions(chains)
def _attach_coref_mentions(self, chains):
    """
    Reset `coref_chains` on every word (including empty words), then add a
    CorefAttachment to each word covered by a mention of each chain.
    """
    for sent in self.sentences:
        for w in sent.all_words:
            w.coref_chains = []

    for chain in chains:
        for mention_idx, mention in enumerate(chain.mentions):
            sent = self.sentences[mention.sentence]
            first = mention.start_word
            if not isinstance(first, tuple):
                # a regular mention covers words [start_word, end_word)
                last = mention.end_word - 1
                for word_idx in range(first, mention.end_word):
                    sent.words[word_idx].coref_chains.append(
                        CorefAttachment(chain,
                                        word_idx == first,
                                        word_idx == last,
                                        mention_idx == chain.representative_index))
                continue
            # a tuple start_word marks a mention on an empty word: it is a
            # single start+end attachment that is never flagged representative
            # NOTE(review): the empty word is looked up by first[1] - 1, i.e. only
            # the second part of the id; confirm this is correct when a sentence
            # has empty words at more than one position
            empty_attachment = CorefAttachment(chain, True, True, False)
            sent._empty_words[first[1] - 1].coref_chains.append(empty_attachment)
def reindex_sentences(self, start_index):
    """ Renumber sent_id of each sentence consecutively (as strings), starting from start_index. """
    for offset, sent in enumerate(self.sentences):
        sent.sent_id = str(start_index + offset)
def to_dict(self):
    """ Dump the document as a list with one entry per sentence, each being
    that sentence's list of per-token dictionaries.
    """
    dumped = []
    for sent in self.sentences:
        dumped.append(sent.to_dict())
    return dumped
def __repr__(self):
    # pretty-printed JSON of to_dict(); non-ASCII text is kept as-is
    as_dicts = self.to_dict()
    return json.dumps(as_dicts, indent=2, ensure_ascii=False, cls=DocJSONEncoder)
def __format__(self, spec):
    # only specs starting with 'c' or 'C' (CoNLL output) are handled here;
    # each sentence is formatted with the same spec, separated by a blank line
    if not spec or spec[0] not in ('c', 'C'):
        return str(self)
    per_sentence = "{:%s}" % spec
    return "\n\n".join(per_sentence.format(sent) for sent in self.sentences)
def to_serialized(self):
    """ Pickle the document as bytes: a (text, to_dict(), sentence_comments()) tuple.
    """
    payload = (self.text, self.to_dict(), self.sentence_comments())
    return pickle.dumps(payload)
@classmethod
def from_serialized(cls, serialized_string):
    """ Create and initialize a new document from bytes generated by Document.to_serialized().

    Both the older (text, sentences) layout and the current
    (text, sentences, comments) layout are accepted.

    NOTE: this uses pickle.loads, which can execute arbitrary code while
    loading.  Only call it on data from a trusted source.

    Raises:
        TypeError: if the unpickled data is not a tuple
    """
    stuff = pickle.loads(serialized_string)
    if not isinstance(stuff, tuple):
        raise TypeError("Serialized data was not a tuple when building a Document")
    # unpack the already-loaded tuple rather than unpickling a second time
    if len(stuff) == 2:
        text, sentences = stuff
        return cls(sentences, text)
    text, sentences, comments = stuff
    return cls(sentences, text, comments)
class Sentence(StanzaObject):
    """ A sentence class that stores attributes of a sentence and carries a list of tokens.
    """

    def __init__(self, tokens, doc=None, empty_words=None):
        """ Construct a sentence given a list of tokens in the form of CoNLL-U dicts.

        Args:
            tokens: list of dicts, one per token or word.  Entries lacking an ID
                get a 1-based id from their position; int ids become 1-tuples.
            doc: the Document this sentence belongs to, if any
            empty_words: optional list of dicts, each turned into a Word and
                kept separately from the regular words
        """
        self._tokens = []
        self._words = []
        self._dependencies = []
        self._text = None
        self._ents = []
        self._doc = doc
        self._constituency = None
        self._sentiment = None
        # comments are a list of comment lines occurring before the
        # sentence in a CoNLL-U file.  Can be empty
        self._comments = []
        self._doc_id = None
        # sent_id, speaker and index are usually assigned afterwards, for
        # example while the owning Document processes its sentences.
        # Default them so reading the properties on a standalone Sentence
        # returns None instead of raising AttributeError
        self._sent_id = None
        self._speaker = None
        self._index = None

        # enhanced_dependencies represents the DEPS column
        # this is a networkx MultiDiGraph
        # with edges from the parent to the dependent
        # however, we set it to None until needed, as it is somewhat slow
        self._enhanced_dependencies = None
        self._process_tokens(tokens)

        if empty_words is not None:
            self._empty_words = [Word(self, entry) for entry in empty_words]
        else:
            self._empty_words = []

    def _process_tokens(self, tokens):
        """ Build self.tokens and self.words from the CoNLL-U style entries.

        An entry whose id has more than one part starts a multi-word token;
        the following word entries whose first id is within that range are
        attached to it.  Every other word entry gets its own Token.
        """
        st, en = -1, -1
        self.tokens, self.words = [], []
        for i, entry in enumerate(tokens):
            if ID not in entry: # manually set a 1-based id for word if not exist
                entry[ID] = (i+1, )
            if isinstance(entry[ID], int):
                entry[ID] = (entry[ID], )
            if len(entry.get(ID)) > 1: # if this token is a multi-word token
                st, en = entry[ID]
                self.tokens.append(Token(self, entry))
            else: # else this token is a word
                new_word = Word(self, entry)
                if len(self.words) > 0 and self.words[-1].id == new_word.id:
                    # this can happen in the following context:
                    # a document was created with MWT=Yes to mark that a token should be split
                    # and then there was an MWT "expansion" with a single word after that token
                    # we replace the Word in the Token assuming that the expansion token might
                    # have more information than the Token dict did
                    # note that a single word MWT like that can be detected with something like
                    #   multi_word_token_misc.match(entry.get(MISC)) if entry.get(MISC, None)
                    self.words[-1] = new_word
                    self.tokens[-1].words[-1] = new_word
                    # the replacement word must point back at its token,
                    # just like words attached in the normal path below
                    new_word.parent = self.tokens[-1]
                    continue
                self.words.append(new_word)
                idx = entry.get(ID)[0]
                if idx <= en:
                    self.tokens[-1].words.append(new_word)
                else:
                    self.tokens.append(Token(self, entry, words=[new_word]))
                new_word.parent = self.tokens[-1]

        # put all of the whitespace annotations (if any) on the Tokens instead of the Words
        for token in self.tokens:
            token.consolidate_whitespace()
        self.rebuild_dependencies()

    def has_enhanced_dependencies(self):
        """
        Whether or not the enhanced dependencies are part of this sentence
        """
        return self._enhanced_dependencies is not None and len(self._enhanced_dependencies) > 0

    @property
    def enhanced_dependencies(self):
        """
        Returns the enhanced_dependencies graph.

        Creates an empty one if one currently does not exist.
        """
        graph = self._enhanced_dependencies
        if graph is None:
            graph = nx.MultiDiGraph()
            self._enhanced_dependencies = graph
        return graph

    @property
    def index(self):
        """
        Access the index of this sentence within the doc.

        If multiple docs were processed together,
        the sentence index will continue counting across docs.
        """
        return self._index

    @index.setter
    def index(self, value):
        """ Set the sentence's index value. """
        self._index = value

    @property
    def id(self):
        """
        Access the index of this sentence within the doc.

        If multiple docs were processed together,
        the sentence index will continue counting across docs.

        Deprecated: use `index` instead.
        """
        warnings.warn("Use of sentence.id is deprecated.  Please use sentence.index instead", stacklevel=2)
        return self._index

    @id.setter
    def id(self, value):
        """ Set the sentence's index value.  Deprecated: use `index` instead. """
        warnings.warn("Use of sentence.id is deprecated.  Please use sentence.index instead", stacklevel=2)
        self._index = value

    @property
    def sent_id(self):
        """ conll-style sent_id  Will be set from index if unknown """
        return self._sent_id

    @sent_id.setter
    def sent_id(self, value):
        """ Set the sentence's sent_id value and the matching "# sent_id = " comment. """
        self._sent_id = value
        sent_id_comment = "# sent_id = " + str(value)
        for comment_idx, comment in enumerate(self._comments):
            if comment.startswith("# sent_id = "):
                self._comments[comment_idx] = sent_id_comment
                break
        else: # this is intended to be a for/else loop
            self._comments.append(sent_id_comment)

    @property
    def speaker(self):
        """ conll-style speaker - adopt the EN GUM formatting """
        return self._speaker

    @speaker.setter
    def speaker(self, value):
        """ Set the sentence's speaker value.

        A falsy value removes the "# speaker = " comment; otherwise the
        comment is updated or appended.
        """
        self._speaker = value
        speaker_comment = "# speaker = " + str(value)
        if not value:
            for comment_idx, comment in enumerate(self._comments):
                if comment.startswith("# speaker = "):
                    self._comments.pop(comment_idx)
                    break
        else:
            for comment_idx, comment in enumerate(self._comments):
                if comment.startswith("# speaker = "):
                    self._comments[comment_idx] = speaker_comment
                    break
            else: # this is intended to be a for/else loop
                self._comments.append(speaker_comment)

    @property
    def doc_id(self):
        """ conll-style doc_id  Can be left blank if unknown """
        return self._doc_id

    @doc_id.setter
    def doc_id(self, value):
        """ Set the sentence's doc_id value and the matching "# doc_id = " comment. """
        self._doc_id = value
        doc_id_comment = "# doc_id = " + str(value)
        for comment_idx, comment in enumerate(self._comments):
            if comment.startswith("# doc_id = "):
                self._comments[comment_idx] = doc_id_comment
                break
        else: # this is intended to be a for/else loop
            self._comments.append(doc_id_comment)

    @property
    def doc(self):
        """ Access the parent doc of this span. """
        return self._doc

    @doc.setter
    def doc(self, value):
        """ Set the parent doc of this span. """
        self._doc = value

    @property
    def text(self):
        """ Access the raw text for this sentence. """
        return self._text

    @text.setter
    def text(self, value):
        """ Set the raw text for this sentence. """
        self._text = value

    @property
    def dependencies(self):
        """ Access list of dependencies for this sentence. """
        return self._dependencies

    @dependencies.setter
    def dependencies(self, value):
        """ Set the list of dependencies for this sentence. """
        self._dependencies = value

    @property
    def tokens(self):
        """ Access the list of tokens for this sentence. """
        return self._tokens

    @tokens.setter
    def tokens(self, value):
        """ Set the list of tokens for this sentence. """
        self._tokens = value

    @property
    def words(self):
        """ Access the list of words for this sentence. """
        return self._words

    @words.setter
    def words(self, value):
        """ Set the list of words for this sentence. """
        self._words = value

    @property
    def empty_words(self):
        """ Access the list of empty words for this sentence. """
        return self._empty_words

    @empty_words.setter
    def empty_words(self, value):
        """ Set the list of empty words for this sentence. """
        self._empty_words = value

    @property
    def all_words(self):
        """ Access the list of words + empty words for this sentence, sorted by id. """
        words = self._words
        empty_words = self._empty_words

        # regular word ids are ints and empty word ids are tuples;
        # wrap the ints so the two kinds compare against each other
        all_words = sorted(words + empty_words,
                           key=lambda x: (x.id,) if isinstance(x.id, int) else x.id)

        return all_words

    @property
    def ents(self):
        """ Access the list of entities in this sentence. """
        return self._ents

    @ents.setter
    def ents(self, value):
        """ Set the list of entities in this sentence. """
        self._ents = value

    @property
    def entities(self):
        """ Access the list of entities. This is just an alias of `ents`. """
        return self._ents

    @entities.setter
    def entities(self, value):
        """ Set the list of entities in this sentence. """
        self._ents = value

    def build_ents(self):
        """ Build the list of entities by iterating over all tokens. Return all entities as a list.

        Note that unlike other attributes, since NER requires raw text, the actual tagging are always
        performed at and attached to the `Token`s, instead of `Word`s.
        """
        self.ents = []
        tags = [w.ner for w in self.tokens]
        decoded = decode_from_bioes(tags)
        for e in decoded:
            ent_tokens = self.tokens[e['start']:e['end']+1]
            self.ents.append(Span(tokens=ent_tokens, type=e['type'], doc=self.doc, sent=self))
        return self.ents

    @property
    def sentiment(self):
        """ Returns the sentiment value for this sentence """
        return self._sentiment

    @sentiment.setter
    def sentiment(self, value):
        """ Set the sentiment value and the matching "# sentiment = " comment. """
        self._sentiment = value
        sentiment_comment = "# sentiment = " + str(value)
        for comment_idx, comment in enumerate(self._comments):
            if comment.startswith("# sentiment = "):
                self._comments[comment_idx] = sentiment_comment
                break
        else: # this is intended to be a for/else loop
            self._comments.append(sentiment_comment)

    @property
    def constituency(self):
        """ Returns the constituency tree for this sentence """
        return self._constituency

    @constituency.setter
    def constituency(self, value):
        """
        Set the constituency tree

        This incidentally updates the #constituency comment if it already exists,
        or otherwise creates a new comment # constituency = ...
        Newlines in the tree's text are written as *NL* and carriage returns dropped,
        so the comment stays on one line.
        """
        self._constituency = value
        constituency_comment = "# constituency = " + str(value)
        constituency_comment = constituency_comment.replace("\n", "*NL*").replace("\r", "")
        for comment_idx, comment in enumerate(self._comments):
            if comment.startswith("# constituency = "):
                self._comments[comment_idx] = constituency_comment
                break
        else: # this is intended to be a for/else loop
            self._comments.append(constituency_comment)

    @property
    def comments(self):
        """ Returns CoNLL-style comments for this sentence """
        return self._comments

    def add_comment(self, comment):
        """ Adds a single comment to this sentence.

        If the comment does not already have # at the start, it will be added.
        For "# constituency =", "# sentiment =", "# sent_id =" and "# doc_id ="
        comments, the value is also parsed into the corresponding field and any
        earlier comment of the same kind is removed before the new one is appended.

        Raises:
            ValueError: if a constituency comment contains more than one tree
        """
        if not comment.startswith("#"):
            comment = "# " + comment
        if comment.startswith("# constituency ="):
            _, tree_text = comment.split("=", 1)
            tree = tree_reader.read_trees(tree_text)
            if len(tree) > 1:
                raise ValueError("Multiple constituency trees for one sentence: %s" % tree_text)
            self._constituency = tree[0]
            self._comments = [x for x in self._comments if not x.startswith("# constituency =")]
        elif comment.startswith("# sentiment ="):
            _, sentiment = comment.split("=", 1)
            sentiment = int(sentiment.strip())
            self._sentiment = sentiment
            self._comments = [x for x in self._comments if not x.startswith("# sentiment =")]
        elif comment.startswith("# sent_id ="):
            _, sent_id = comment.split("=", 1)
            sent_id = sent_id.strip()
            self._sent_id = sent_id
            self._comments = [x for x in self._comments if not x.startswith("# sent_id =")]
        elif comment.startswith("# doc_id ="):
            _, doc_id = comment.split("=", 1)
            doc_id = doc_id.strip()
            self._doc_id = doc_id
            self._comments = [x for x in self._comments if not x.startswith("# doc_id =")]
        self._comments.append(comment)

    def rebuild_dependencies(self):
        """ Rebuild the dependency list if every word has a head and deprel
        and the word ids run contiguously up to the number of words.
        Does nothing for a sentence without words.
        """
        # without this guard, self.words[-1] below raises IndexError for an empty sentence
        if not self.words:
            return
        # rebuild dependencies if there is dependency info
        is_complete_dependencies = all(word.head is not None and word.deprel is not None for word in self.words)
        is_complete_words = (len(self.words) >= len(self.tokens)) and (len(self.words) == self.words[-1].id)
        if is_complete_dependencies and is_complete_words:
            self.build_dependencies()

    def build_dependencies(self):
        """ Build the dependency graph for this sentence. Each dependency graph entry is
        a list of (head, deprel, word).

        Raises:
            IndexError: if a word's head is not a valid word index
            ValueError: if the word at the head position does not carry that id
        """
        self.dependencies = []
        for word in self.words:
            if word.head == 0:
                # make a word for the ROOT
                word_entry = {ID: 0, TEXT: "ROOT"}
                head = Word(self, word_entry)
            else:
                # id is index in words list + 1
                try:
                    head = self.words[word.head - 1]
                except IndexError as e:
                    raise IndexError("Word head {} is not a valid word index for word {}".format(word.head, word.id)) from e
                if word.head != head.id:
                    raise ValueError("Dependency tree is incorrectly constructed")
            self.dependencies.append((head, word.deprel, word))

    def build_fake_dependencies(self):
        """ Attach each word to the previous one (the first word gets head 0 and deprel root).

        NOTE: unlike build_dependencies, the head stored in each dependency entry is
        the int index, not a Word, so print_dependencies (which reads head.id) cannot
        be used on these entries.
        """
        self.dependencies = []
        for word_idx, word in enumerate(self.words):
            word.head = word_idx   # note that this goes one previous to the index
            word.deprel = "root" if word_idx == 0 else "dep"
            word.deps = "%d:%s" % (word.head, word.deprel)
            self.dependencies.append((word_idx, word.deprel, word))

    def print_dependencies(self, file=None):
        """ Print (word text, head id, deprel) for each dependency of this sentence. """
        for dep_edge in self.dependencies:
            print((dep_edge[2].text, dep_edge[0].id, dep_edge[1]), file=file)

    def dependencies_string(self):
        """ Dump the dependencies for this sentence into string. """
        dep_string = io.StringIO()
        self.print_dependencies(file=dep_string)
        return dep_string.getvalue().strip()

    def get_roots(self):
        """ Return a list of root(s) from a sentence """
        return [word for word in self.words if word.head == 0]

    def print_tokens(self, file=None):
        """ Print the tokens for this sentence. """
        for tok in self.tokens:
            print(tok.pretty_print(), file=file)

    def tokens_string(self):
        """ Dump the tokens for this sentence into string. """
        toks_string = io.StringIO()
        self.print_tokens(file=toks_string)
        return toks_string.getvalue().strip()

    def print_words(self, file=None):
        """ Print the words for this sentence. """
        for word in self.words:
            print(word.pretty_print(), file=file)

    def words_string(self):
        """ Dump the words for this sentence into string. """
        wrds_string = io.StringIO()
        self.print_words(file=wrds_string)
        return wrds_string.getvalue().strip()

    def to_dict(self):
        """ Dumps the sentence into a list of dictionary for each token in the sentence.

        Empty words are placed before the first token whose first id is larger
        than the empty word's first id component; any left over go at the end.
        """
        ret = []
        empty_idx = 0
        for token in self.tokens:
            while empty_idx < len(self._empty_words) and self._empty_words[empty_idx].id[0] < token.id[0]:
                ret.append(self._empty_words[empty_idx].to_dict())
                empty_idx += 1
            ret += token.to_dict()
        for empty_word in self._empty_words[empty_idx:]:
            ret.append(empty_word.to_dict())
        return ret

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def __format__(self, spec):
        """ 'c' produces CoNLL lines, 'C' adds the comments on top; "-o" in the spec
        selects NO_OFFSETS_OUTPUT_FIELDS.  Any other spec returns str(self).
        """
        if not spec or spec[0] not in ('c', 'C'):
            return str(self)

        if "-o" in spec:
            fields = NO_OFFSETS_OUTPUT_FIELDS
        else:
            fields = DEFAULT_OUTPUT_FIELDS

        pieces = []
        empty_idx = 0
        for token in self.tokens:
            while empty_idx < len(self._empty_words) and self._empty_words[empty_idx].id[0] < token.id[0]:
                pieces.append(self._empty_words[empty_idx].to_conll_text(fields))
                empty_idx += 1
            pieces.append(token.to_conll_text(fields))
        for empty_word in self._empty_words[empty_idx:]:
            pieces.append(empty_word.to_conll_text(fields))

        tokens = "\n".join(pieces)
        if spec[0] == 'C' and len(self.comments) > 0:
            text = "\n".join(self.comments)
            return text + "\n" + tokens
        return tokens
def init_from_misc(unit):
    """Create attributes by parsing from the `misc` field.

    Each `key=value` piece of `unit._misc` whose `_key` attribute already
    exists on `unit` is stored in that attribute (start_char, end_char and
    line_number are converted to int) and dropped from misc, so that we don't
    repeat ourselves.  A `ner` piece is dropped when the unit has no `_ner`
    attribute (Words have no NER field).  All other pieces, including those
    without '=', are kept, and `unit._misc` is rewritten from them.

    NOTE(review): any key whose `_key` attribute exists is overwritten, e.g. a
    `text=...` piece would replace `_text` on a Token; confirm this is intended.
    """
    kept = []
    for piece in unit._misc.split('|'):
        if '=' not in piece:
            # pieces without a value cannot be mapped to an attribute
            kept.append(piece)
            continue
        key, value = piece.split('=', 1)
        if key in (START_CHAR, END_CHAR, LINE_NUMBER):
            value = int(value)
        attr_name = '_' + key
        if hasattr(unit, attr_name):
            setattr(unit, attr_name, value)
        elif key != NER:
            kept.append(piece)
    unit._misc = "|".join(kept)
def dict_to_conll_text(token_dict, id_connector="-"):
    """ Convert one token/word dictionary into a single tab-separated CoNLL-U line.

    Args:
        token_dict: dictionary keyed by field names (ID, TEXT, FEATS, MISC, ...)
        id_connector: string used to join tuple ids, e.g. "-" turns (1, 2) into "1-2"

    Returns:
        The tab-joined line.  Columns with no value are "_"; start_char,
        end_char, ner and coref chains are appended to the MISC column.
    """
    token_conll = ['_'] * FIELD_NUM
    misc = []
    if token_dict.get(MISC):
        # avoid appending a blank misc entry.
        # otherwise the resulting misc field in the conll doc will wind up being blank text
        # TODO: potentially need to escape =|\ in the MISC as well
        misc.append(token_dict[MISC])

    # for other items meant to be in the MISC field,
    # we try to operate on those columns in a deterministic order
    # so that the output doesn't change based on the order of keys
    # in the token_dict
    for key in [START_CHAR, END_CHAR, NER]:
        if key in token_dict:
            misc.append("{}={}".format(key, token_dict[key]))

    if COREF_CHAINS in token_dict:
        chains = token_dict[COREF_CHAINS]
        if len(chains) > 0:
            misc_chains = []
            for chain in chains:
                if chain.is_start and chain.is_end:
                    coref_position = "unit-"
                elif chain.is_start:
                    coref_position = "start-"
                elif chain.is_end:
                    coref_position = "end-"
                else:
                    coref_position = "middle-"
                is_representative = "repr-" if chain.is_representative else ""
                misc_chains.append("%s%sid%d" % (coref_position, is_representative, chain.chain.index))
            # use the coref key explicitly; the loop variable `key` left over from
            # the loop above would be NER and mislabel the chains as ner=...
            misc.append("{}={}".format(COREF_CHAINS, ",".join(misc_chains)))

    for key in token_dict.keys():
        if key == ID:
            token_conll[FIELD_TO_IDX[key]] = id_connector.join([str(x) for x in token_dict[key]]) if isinstance(token_dict[key], tuple) else str(token_dict[key])
        elif key == FEATS:
            feats = token_dict[key]
            if feats:
                # features are written sorted case-insensitively
                pieces = feats.split("|")
                pieces = sorted(pieces, key=str.casefold)
                feats = "|".join(pieces)
            token_conll[FIELD_TO_IDX[key]] = str(feats)
        elif key in FIELD_TO_IDX:
            token_conll[FIELD_TO_IDX[key]] = str(token_dict[key])
        elif key == LINE_NUMBER:
            # skip this when converting back for now
            pass

    if misc:
        token_conll[FIELD_TO_IDX[MISC]] = "|".join(misc)
    else:
        token_conll[FIELD_TO_IDX[MISC]] = '_'
    # when a word (not mwt token) without head is found, we insert dummy head as required by the UD eval script
    if '-' not in token_conll[FIELD_TO_IDX[ID]] and '.' not in token_conll[FIELD_TO_IDX[ID]] and HEAD not in token_dict:
        token_conll[FIELD_TO_IDX[HEAD]] = str(int(token_dict[ID] if isinstance(token_dict[ID], int) else token_dict[ID][0]) - 1) # evaluation script requires head: int
    return "\t".join(token_conll)
class Token(StanzaObject):
""" A token class that stores attributes of a token and carries a list of words. A token corresponds to a unit in the raw
text. In some languages such as English, a token has a one-to-one mapping to a word, while in other languages such as French,
a (multi-word) token might be expanded into multiple words that carry syntactic annotations.
"""
def __init__(self, sentence, token_entry, words=None):
"""
Construct a token given a dictionary format token entry. Optionally link itself to the corresponding words.
The owning sentence must be passed in.
"""
self._id = token_entry.get(ID)
self._text = token_entry.get(TEXT)
if not self._id:
raise ValueError('id not included for the token')
if not self._text:
raise ValueError('text not included for the token')
self._misc = token_entry.get(MISC, None)
self._ner = token_entry.get(NER, None)
self._multi_ner = token_entry.get(MULTI_NER, None)
self._words = words if words is not None else []
self._start_char = token_entry.get(START_CHAR, None)
self._end_char = token_entry.get(END_CHAR, None)
self._sent = sentence
self._mexp = token_entry.get(MEXP, None)
self._spaces_before = ""
self._spaces_after = " "
self._line_number = None
if self._misc is not None:
init_from_misc(self)
@property
def id(self):
""" Access the index of this token. """
return self._id
@id.setter
def id(self, value):
""" Set the token's id value. """
self._id = value
@property
def manual_expansion(self):
""" Access the whether this token was manually expanded. """
return self._mexp
@manual_expansion.setter
def manual_expansion(self, value):
""" Set the whether this token was manually expanded. """
self._mexp = value
@property
def text(self):
""" Access the text of this token. Example: 'The' """
return self._text
@text.setter
def text(self, value):
""" Set the token's text value. Example: 'The' """
self._text = value
@property
def misc(self):
""" Access the miscellaneousness of this token. """
return self._misc
@misc.setter
def misc(self, value):
""" Set the token's miscellaneousness value. """
self._misc = value if self._is_null(value) == False else None
def consolidate_whitespace(self):
"""
Remove whitespace misc annotations from the Words and mark the whitespace on the Tokens
"""
found_after = False
found_before = False
num_words = len(self.words)
for word_idx, word in enumerate(self.words):
misc = word.misc
if not misc:
continue
pieces = misc.split("|")
if word_idx == 0:
if any(piece.startswith("SpacesBefore=") for piece in pieces):
self.spaces_before = misc_to_space_before(misc)
found_before = True
else:
if any(piece.startswith("SpacesBefore=") for piece in pieces):
warnings.warn("Found a SpacesBefore MISC annotation on a Word that was not the first Word in a Token")
if word_idx == num_words - 1:
if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
self.spaces_after = misc_to_space_after(misc)
found_after = True
else:
if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
unexpected_space_after = misc_to_space_after(misc)
if unexpected_space_after == "":
warnings.warn("Unexpected SpaceAfter=No annotation on a word in the middle of an MWT")
else:
warnings.warn("Unexpected SpacesAfter on a word in the middle on an MWT")
pieces = [x for x in pieces if not x.startswith("SpacesAfter=") and not x.startswith("SpaceAfter=") and not x.startswith("SpacesBefore=")]
word.misc = "|".join(pieces)
misc = self.misc
if misc:
pieces = misc.split("|")
if any(piece.startswith("SpacesBefore=") for piece in pieces):
spaces_before = misc_to_space_before(misc)
if found_before:
if spaces_before != self.spaces_before:
warnings.warn("Found conflicting SpacesBefore on a token and its word!")
else:
self.spaces_before = spaces_before
if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
spaces_after = misc_to_space_after(misc)
if found_after:
if spaces_after != self.spaces_after:
warnings.warn("Found conflicting SpaceAfter / SpacesAfter on a token and its word!")
else:
self.spaces_after = spaces_after
pieces = [x for x in pieces if not x.startswith("SpacesAfter=") and not x.startswith("SpaceAfter=") and not x.startswith("SpacesBefore=")]
self.misc = "|".join(pieces)
@property
def spaces_before(self):
""" SpacesBefore for the token. Translated from the MISC fields """
return self._spaces_before
@spaces_before.setter
def spaces_before(self, value):
self._spaces_before = value
@property
def spaces_after(self):
""" SpaceAfter or SpacesAfter for the token. Translated from the MISC field """
return self._spaces_after
@spaces_after.setter
def spaces_after(self, value):
self._spaces_after = value
@property
def words(self):
""" Access the list of syntactic words underlying this token. """
return self._words
@words.setter
def words(self, value):
""" Set this token's list of underlying syntactic words. """
self._words = value
for w in self._words:
w.parent = self
@property
def line_number(self):
""" Access the line number from the original document, if set """
return self._line_number
@property
def start_char(self):
""" Access the start character index for this token in the raw text. """
return self._start_char
@property
def end_char(self):
""" Access the end character index for this token in the raw text. """
return self._end_char
@property
def ner(self):
    """ The NER tag of this token, e.g. 'B-ORG', or None if unset. """
    return self._ner

@ner.setter
def ner(self, value):
    """ Store the NER tag; None and the CoNLL placeholder '_' are both stored as None. """
    self._ner = None if self._is_null(value) else value

@property
def multi_ner(self):
    """ The MULTI_NER tag of this token, e.g. '(B-ORG, B-DISEASE)', or None if unset. """
    return self._multi_ner

@multi_ner.setter
def multi_ner(self, value):
    """ Store the MULTI_NER tag; None and '_' are both stored as None. """
    self._multi_ner = None if self._is_null(value) else value
@property
def sent(self):
    """ Back-pointer to the sentence containing this token. """
    return self._sent

@sent.setter
def sent(self, value):
    """ Re-point this token at the sentence that contains it. """
    self._sent = value
def __repr__(self):
    # pretty-printed JSON of to_dict(), serialized through DocJSONEncoder
    payload = self.to_dict()
    return json.dumps(payload, ensure_ascii=False, indent=2, cls=DocJSONEncoder)
def __format__(self, spec):
if spec == 'C':
return "\n".join(self.to_conll_text(DEFAULT_OUTPUT_FIELDS))
elif spec == 'P':
return self.pretty_print()
else:
return str(self)
def to_conll_text(self, fields=DEFAULT_OUTPUT_FIELDS):
    """ Render every entry of to_dict(fields) with dict_to_conll_text, one entry per line. """
    lines = [dict_to_conll_text(entry) for entry in self.to_dict(fields)]
    return "\n".join(lines)
def to_dict(self, fields=DEFAULT_OUTPUT_FIELDS):
    """ Dumps the token into a list of dictionary for this token with its extended words
    if the token is a multi-word token.

    For a multi-word token (more than one id), the list starts with an entry
    for the token itself, followed by one entry per word.  For a single-word
    token the list holds only the word's entry, and the token's NER / MULTI_NER
    labels are copied onto it.

    When MISC is among the fields, the token's spacing is appended to the MISC
    value of the entry that represents the token on output: the MWT entry, or
    the lone word entry.  spaces_after is written unless it is None or ' ';
    spaces_before is written unless it is None or ''.
    """
    def append_misc(entry, piece):
        # join onto existing MISC content with '|', or start MISC with this piece
        if entry.get(MISC):
            entry[MISC] = entry[MISC] + "|" + piece
        else:
            entry[MISC] = piece

    def add_spacing_misc(entry):
        # shared by the MWT and single-word paths, which previously duplicated this code
        spaces_after = self.spaces_after
        if spaces_after is not None and spaces_after != ' ':
            append_misc(entry, space_after_to_misc(spaces_after))
        spaces_before = self.spaces_before
        if spaces_before is not None and spaces_before != '':
            append_misc(entry, space_before_to_misc(spaces_before))

    ret = []
    is_single_word = len(self.id) == 1
    if len(self.id) > 1:
        token_dict = {}
        for field in fields:
            value = getattr(self, field, None)
            if value is not None:
                token_dict[field] = value
        if MISC in fields:
            add_spacing_misc(token_dict)
        ret.append(token_dict)

    for word in self.words:
        word_dict = word.to_dict(fields)
        if is_single_word:
            # propagate the token-level labels to the word of a single-word token
            if NER in fields and getattr(self, NER) is not None:
                word_dict[NER] = getattr(self, NER)
            if MULTI_NER in fields and getattr(self, MULTI_NER) is not None:
                word_dict[MULTI_NER] = getattr(self, MULTI_NER)
            if MISC in fields:
                add_spacing_misc(word_dict)
        ret.append(word_dict)
    return ret
def pretty_print(self):
    """ One-line summary: class name, dash-joined id, and each word's own pretty_print. """
    id_text = '-'.join(str(part) for part in self.id)
    words_text = ', '.join(word.pretty_print() for word in self.words)
    return f"<{self.__class__.__name__} id={id_text};words=[{words_text}]>"
def _is_null(self, value):
return (value is None) or (value == '_')
def is_mwt(self):
    # a multi-word token is one that expands to two or more syntactic words
    return len(self.words) >= 2
class Word(StanzaObject):
    """ A word class that stores attributes of a word.

    Attribute values come from a dictionary entry keyed by the field
    constants (ID, TEXT, LEMMA, ...).  The enhanced dependencies exposed by
    `deps` are not stored on the word itself: they are kept as edges of the
    parent sentence's `_enhanced_dependencies` graph, keyed by word id.
    """
    def __init__(self, sentence, word_entry):
        """ Construct a word given a dictionary format word entry.

        sentence: the sentence this word belongs to
        word_entry: dict keyed by the field constants; ID and TEXT are required
        """
        self._id = word_entry.get(ID, None)
        # a one-element tuple id is stored as its single element
        if isinstance(self._id, tuple):
            if len(self._id) == 1:
                self._id = self._id[0]
        self._text = word_entry.get(TEXT, None)

        assert self._id is not None and self._text is not None, 'id and text should be included for the word. {}'.format(word_entry)

        # stored directly, bypassing the property setters (so no '_' -> None conversion here)
        self._lemma = word_entry.get(LEMMA, None)
        self._upos = word_entry.get(UPOS, None)
        self._xpos = word_entry.get(XPOS, None)
        self._feats = word_entry.get(FEATS, None)
        self._head = word_entry.get(HEAD, None)
        self._deprel = word_entry.get(DEPREL, None)
        self._misc = word_entry.get(MISC, None)
        self._start_char = word_entry.get(START_CHAR, None)
        self._end_char = word_entry.get(END_CHAR, None)
        self._parent = None
        self._sent = sentence
        self._mexp = word_entry.get(MEXP, None)
        self._coref_chains = None
        self._line_number = None

        if self._misc is not None:
            # NOTE(review): init_from_misc is defined elsewhere in this module;
            # presumably it extracts the spacing annotations from MISC - confirm
            init_from_misc(self)

        # use the setter, which will go up to the sentence and set the
        # dependencies on that graph
        self.deps = word_entry.get(DEPS, None)

    @property
    def manual_expansion(self):
        """ Access the whether this token was manually expanded. """
        return self._mexp

    @manual_expansion.setter
    def manual_expansion(self, value):
        """ Set the whether this token was manually expanded. """
        self._mexp = value

    @property
    def id(self):
        """ Access the index of this word. """
        return self._id

    @id.setter
    def id(self, value):
        """ Set the word's index value. """
        self._id = value

    @property
    def text(self):
        """ Access the text of this word. Example: 'The'"""
        return self._text

    @text.setter
    def text(self, value):
        """ Set the word's text value. Example: 'The'"""
        self._text = value

    @property
    def lemma(self):
        """ Access the lemma of this word. """
        return self._lemma

    @lemma.setter
    def lemma(self, value):
        """ Set the word's lemma value.

        '_' is normally treated as "no lemma", except when the word's own text
        is '_', in which case '_' is a legitimate lemma.
        """
        self._lemma = value if not self._is_null(value) or self._text == '_' else None

    @property
    def upos(self):
        """ Access the universal part-of-speech of this word. Example: 'NOUN'"""
        return self._upos

    @upos.setter
    def upos(self, value):
        """ Set the word's universal part-of-speech value. Example: 'NOUN'"""
        self._upos = value if not self._is_null(value) else None

    @property
    def xpos(self):
        """ Access the treebank-specific part-of-speech of this word. Example: 'NNP'"""
        return self._xpos

    @xpos.setter
    def xpos(self, value):
        """ Set the word's treebank-specific part-of-speech value. Example: 'NNP'"""
        self._xpos = value if not self._is_null(value) else None

    @property
    def feats(self):
        """ Access the morphological features of this word. Example: 'Gender=Fem'"""
        return self._feats

    @feats.setter
    def feats(self, value):
        """ Set this word's morphological features. Example: 'Gender=Fem'"""
        self._feats = value if not self._is_null(value) else None

    @property
    def head(self):
        """ Access the id of the governor of this word. """
        return self._head

    @head.setter
    def head(self, value):
        """ Set the word's governor id value (converted to int; None / '_' become None). """
        self._head = int(value) if not self._is_null(value) else None

    @property
    def deprel(self):
        """ Access the dependency relation of this word. Example: 'nmod'"""
        return self._deprel

    @deprel.setter
    def deprel(self, value):
        """ Set the word's dependency relation value. Example: 'nmod'"""
        self._deprel = value if not self._is_null(value) else None

    @property
    def deps(self):
        """ Access the dependencies of this word.

        Rebuilt from the sentence's enhanced dependency graph as a
        '|'-separated string of "parent:relation" entries.  A tuple parent id
        is written as "major.minor".  Entries are sorted by parent, then by
        relation.  Returns None when the word has no incoming edges.
        """
        graph = self._sent._enhanced_dependencies
        if graph is None or not graph.has_node(self.id):
            return None

        data = []
        # int ids are compared as 1-tuples so they sort alongside (major, minor) ids
        predecessors = sorted(graph.predecessors(self.id), key=lambda x: x if isinstance(x, tuple) else (x,))
        for parent in predecessors:
            # the MultiDiGraph edge keys are the relation labels (see the setter)
            deps = sorted(graph.get_edge_data(parent, self.id))
            for dep in deps:
                if isinstance(parent, int):
                    data.append("%d:%s" % (parent, dep))
                else:
                    data.append("%d.%d:%s" % (parent[0], parent[1], dep))
        if not data:
            return None

        return "|".join(data)

    @deps.setter
    def deps(self, value):
        """ Set the word's dependencies value.

        Replaces every incoming edge of this word in the sentence's enhanced
        dependency graph.  value may be:
          - None, '' or '_': only clear the existing edges
          - a '|'-separated string such as "2:nsubj|4.1:obj"
          - a list of "parent:relation" strings
          - a list of (parent, relation) pairs, where parent is a string like
            "4" or "4.1", an int, or a sequence of ints such as (4, 1)
        """
        # the CoNLL placeholder and the empty string mean "no dependencies"
        if isinstance(value, str) and value in ('', '_'):
            value = None

        graph = self._sent._enhanced_dependencies
        # if we don't have a graph, and we aren't trying to set any actual
        # dependencies, we can save the time of doing anything else
        if graph is None and value is None:
            return

        if graph is None:
            graph = nx.MultiDiGraph()
            self._sent._enhanced_dependencies = graph
        # need to make a new list: cannot iterate and delete at the same time
        if graph.has_node(self.id):
            in_edges = list(graph.in_edges(self.id))
            graph.remove_edges_from(in_edges)

        if value is None:
            return

        if isinstance(value, str):
            value = value.split("|")
        for item in value:
            if isinstance(item, str):
                parent, dep = item.split(":", maxsplit=1)
            else:
                parent, dep = item
            # we have to match the format of the IDs.  since the IDs
            # of the words are int if they aren't empty words, we need
            # to convert single int IDs into int instead of tuple
            if isinstance(parent, str):
                parent = parent.split(".", maxsplit=1)
            elif isinstance(parent, int):
                parent = (parent,)
            parent = tuple(int(x) for x in parent)
            if len(parent) == 1:
                parent = parent[0]
            graph.add_edge(parent, self.id, dep)

    @property
    def misc(self):
        """ Access the miscellaneousness of this word. """
        return self._misc

    @misc.setter
    def misc(self, value):
        """ Set the word's miscellaneousness value. """
        self._misc = value if not self._is_null(value) else None

    @property
    def line_number(self):
        """ Access the line number from the original document, if set """
        return self._line_number

    @property
    def start_char(self):
        """ Access the start character index for this token in the raw text. """
        return self._start_char

    @start_char.setter
    def start_char(self, value):
        self._start_char = value

    @property
    def end_char(self):
        """ Access the end character index for this token in the raw text. """
        return self._end_char

    @end_char.setter
    def end_char(self, value):
        self._end_char = value

    @property
    def parent(self):
        """ Access the parent token of this word. In the case of a multi-word token, a token can be the parent of
        multiple words. Note that this should return a reference to the parent token object.
        """
        return self._parent

    @parent.setter
    def parent(self, value):
        """ Set this word's parent token. In the case of a multi-word token, a token can be the parent of
        multiple words. Note that value here should be a reference to the parent token object.
        """
        self._parent = value

    @property
    def pos(self):
        """ Alias of upos: the universal part-of-speech of this word. Example: 'NOUN'"""
        return self._upos

    @pos.setter
    def pos(self, value):
        """ Alias of the upos setter. Example: 'NOUN'"""
        self._upos = value if not self._is_null(value) else None

    @property
    def coref_chains(self):
        """
        coref_chains points to a list of CorefChain namedtuple, which has a list of mentions and a representative mention.

        Useful for disambiguating words such as "him" (in languages where coref is available)

        Theoretically it is possible for multiple corefs to occur at the same word.  For example,
          "Chris Manning's NLP Group"
        could have "Chris Manning" and "Chris Manning's NLP Group" as overlapping entities
        """
        return self._coref_chains

    @coref_chains.setter
    def coref_chains(self, chain):
        """ Set the backref for the coref chains """
        self._coref_chains = chain

    @property
    def sent(self):
        """ Access the pointer to the sentence that this word belongs to. """
        return self._sent

    @sent.setter
    def sent(self, value):
        """ Set the pointer to the sentence that this word belongs to. """
        self._sent = value

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def __format__(self, spec):
        """ 'C' -> CoNLL text, 'P' -> pretty_print(), anything else -> str(self) """
        if spec == 'C':
            return self.to_conll_text(DEFAULT_OUTPUT_FIELDS)
        elif spec == 'P':
            return self.pretty_print()
        else:
            return str(self)

    def to_conll_text(self, fields=DEFAULT_OUTPUT_FIELDS):
        """
        Turn a word into a conll representation (10 column tab separated)
        """
        token_dict = self.to_dict(fields)
        return dict_to_conll_text(token_dict, '.')

    def to_dict(self, fields=DEFAULT_OUTPUT_FIELDS):
        """ Dumps the word into a dictionary, keeping only the fields whose value is not None.
        """
        word_dict = {}
        for field in fields:
            value = getattr(self, field, None)
            if value is not None:
                word_dict[field] = value
        return word_dict

    def pretty_print(self):
        """ Print the word in one line. """
        features = [ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL]
        feature_str = ";".join(["{}={}".format(k, getattr(self, k)) for k in features if getattr(self, k) is not None])
        return f"<{self.__class__.__name__} {feature_str}>"

    def _is_null(self, value):
        # None and the CoNLL placeholder '_' both mean "no value"
        return (value is None) or (value == '_')
class Span(StanzaObject):
    """ A span class that stores attributes of a textual span. A span can be typed.
    A range of objects (e.g., entity mentions) can be represented as spans.
    """
    def __init__(self, span_entry=None, tokens=None, type=None, doc=None, sent=None):
        """ Construct a span given a span entry or a list of tokens. A valid reference to a doc
        must be provided to construct a span (otherwise the text of the span cannot be initialized).

        span_entry: dict with TEXT / TYPE / START_CHAR / END_CHAR values
        tokens, type: alternatively, a non-empty list of tokens plus the span type
        doc: the parent document (required)
        sent: the sentence of this span; replaced by the first token's sentence
            when the span is built from tokens
        """
        assert span_entry is not None or (tokens is not None and type is not None), \
                'Either a span_entry or a token list needs to be provided to construct a span.'
        assert doc is not None, 'A parent doc must be provided to construct a span.'
        self._text, self._type, self._start_char, self._end_char = [None] * 4
        self._tokens = []
        self._words = []
        self._doc = doc
        self._sent = sent

        if span_entry is not None:
            self.init_from_entry(span_entry)

        if tokens is not None:
            self.init_from_tokens(tokens, type)

    def init_from_entry(self, span_entry):
        """ Fill text, type and character offsets from a span dictionary. """
        self.text = span_entry.get(TEXT, None)
        self.type = span_entry.get(TYPE, None)
        self.start_char = span_entry.get(START_CHAR, None)
        self.end_char = span_entry.get(END_CHAR, None)

    def init_from_tokens(self, tokens, type):
        """ Fill the span from a non-empty list of tokens.

        The text is taken, in order of preference, from the document text, from
        the text of the single sentence all tokens belong to, or by gluing the
        token texts together with their spaces_after.
        """
        assert isinstance(tokens, list), 'Tokens must be provided as a list to construct a span.'
        assert len(tokens) > 0, "Tokens of a span cannot be an empty list."
        self.tokens = tokens
        self.type = type
        # load start and end char offsets from tokens
        self.start_char = self.tokens[0].start_char
        self.end_char = self.tokens[-1].end_char
        if self.doc is not None and self.doc.text is not None:
            self.text = self.doc.text[self.start_char:self.end_char]
        elif tokens[0].sent is tokens[-1].sent:
            sentence = tokens[0].sent
            if tokens[-1].end_char is not None and tokens[0].start_char is not None and sentence.tokens[0].start_char is not None:
                # offsets are document-relative; shift them to be relative to the sentence text
                text_start = tokens[0].start_char - sentence.tokens[0].start_char
                text_end = tokens[-1].end_char - sentence.tokens[0].start_char
                self.text = sentence.text[text_start:text_end]
            else:
                text = []
                for token in tokens:
                    text.append(token.text)
                    # spaces_after may be None (Token.to_dict checks for it);
                    # treat that as the default single space instead of
                    # crashing in join.  NOTE(review): confirm ' ' is the right default
                    spaces_after = token.spaces_after
                    text.append(spaces_after if spaces_after is not None else " ")
                # drop the whitespace after the last token
                self.text = "".join(text[:-1])
        else:
            # TODO: do any spans ever cross sentences?
            raise RuntimeError("Document text does not exist, and the span tested crosses two sentences, so it is impossible to extract the entity text!")
        # collect the words of the span following tokens
        self.words = [w for t in tokens for w in t.words]
        # set the sentence back-pointer to point to the sentence of the first token
        self.sent = tokens[0].sent

    @property
    def doc(self):
        """ Access the parent doc of this span. """
        return self._doc

    @doc.setter
    def doc(self, value):
        """ Set the parent doc of this span. """
        self._doc = value

    @property
    def text(self):
        """ Access the text of this span. Example: 'Stanford University'"""
        return self._text

    @text.setter
    def text(self, value):
        """ Set the span's text value. Example: 'Stanford University'"""
        self._text = value

    @property
    def tokens(self):
        """ Access reference to a list of tokens that correspond to this span. """
        return self._tokens

    @tokens.setter
    def tokens(self, value):
        """ Set the span's list of tokens. """
        self._tokens = value

    @property
    def words(self):
        """ Access reference to a list of words that correspond to this span. """
        return self._words

    @words.setter
    def words(self, value):
        """ Set the span's list of words. """
        self._words = value

    @property
    def type(self):
        """ Access the type of this span. Example: 'PERSON'"""
        return self._type

    @type.setter
    def type(self, value):
        """ Set the type of this span. """
        self._type = value

    @property
    def start_char(self):
        """ Access the start character offset of this span. """
        return self._start_char

    @start_char.setter
    def start_char(self, value):
        """ Set the start character offset of this span. """
        self._start_char = value

    @property
    def end_char(self):
        """ Access the end character offset of this span. """
        return self._end_char

    @end_char.setter
    def end_char(self, value):
        """ Set the end character offset of this span. """
        self._end_char = value

    @property
    def sent(self):
        """ Access the pointer to the sentence that this span belongs to. """
        return self._sent

    @sent.setter
    def sent(self, value):
        """ Set the pointer to the sentence that this span belongs to. """
        self._sent = value

    def to_dict(self):
        """ Dumps the span into a dictionary with text, type, start_char and end_char. """
        attrs = ['text', 'type', 'start_char', 'end_char']
        return {attr_name: getattr(self, attr_name) for attr_name in attrs}

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def pretty_print(self):
        """ Print the span in one line. """
        span_dict = self.to_dict()
        feature_str = ";".join(["{}={}".format(k,v) for k,v in span_dict.items()])
        return f"<{self.__class__.__name__} {feature_str}>"
================================================
FILE: stanza/models/common/dropout.py
================================================
import torch
import torch.nn as nn
class WordDropout(nn.Module):
    """ Word-level dropout for embedded inputs (e.g. the input of an LSTM layer).

    While training, each unit is independently dropped with probability
    `dropprob`.  The last dimension of the input is treated as the hidden
    dimension of a unit, so a whole unit is dropped at once.  A dropped unit
    becomes zeros, or `replacement` when one is given.
    """

    def __init__(self, dropprob):
        super().__init__()
        self.dropprob = dropprob

    def forward(self, x, replacement=None):
        # identity in eval mode or when dropout is disabled
        if self.dropprob == 0 or not self.training:
            return x

        # one coin flip per unit: collapse the hidden dimension to size 1
        mask_shape = list(x.size())
        mask_shape[-1] = 1
        drop = torch.rand(*mask_shape, device=x.device) < self.dropprob

        out = x.masked_fill(drop, 0)
        if replacement is None:
            return out
        return out + drop.float() * replacement

    def extra_repr(self):
        return 'p={}'.format(self.dropprob)
class LockedDropout(nn.Module):
    """
    A variant of dropout layer that consistently drops out the same parameters over time. Also known as the variational dropout.
    This implementation was modified from the LockedDropout implementation in the flair library (https://github.com/zalandoresearch/flair).

    One mask is sampled per (sequence, feature) and reused at every time step.
    Expects a 3D input: (batch, time, features) when batch_first, otherwise
    (time, batch, features).
    """
    def __init__(self, dropprob, batch_first=True):
        super().__init__()
        self.dropprob = dropprob
        self.batch_first = batch_first

    def forward(self, x):
        if not self.training or self.dropprob == 0:
            return x

        if self.dropprob >= 1:
            # everything is dropped.  The rescaling below would compute 0 / 0
            # and turn the output into NaN, so return zeros directly.
            # Multiplying keeps dtype/device and stays in the autograd graph.
            return x * 0

        # the time dimension of the mask has size 1 so it broadcasts over all steps
        if not self.batch_first:
            m = x.new_empty(1, x.size(1), x.size(2), requires_grad=False).bernoulli_(1 - self.dropprob)
        else:
            m = x.new_empty(x.size(0), 1, x.size(2), requires_grad=False).bernoulli_(1 - self.dropprob)
        # inverted dropout: kept units are scaled by 1/(1-p) to preserve the expected value
        mask = m.div(1 - self.dropprob).expand_as(x)
        return mask * x

    def extra_repr(self):
        return 'p={}'.format(self.dropprob)
class SequenceUnitDropout(nn.Module):
    """ Dropout over sequences of unit ids (word ids, character ids, ...).

    While training, each id is independently replaced by `replacement_id`
    with probability `dropprob`.  replacement_id is typically the id of the
    unknown-unit token (e.g. <UNK>).
    """

    def __init__(self, dropprob, replacement_id):
        super().__init__()
        self.dropprob = dropprob
        self.replacement_id = replacement_id

    def forward(self, x):
        """ :param: x must be a LongTensor of unit indices. """
        if self.dropprob == 0 or not self.training:
            return x
        drop = torch.rand(*list(x.size()), device=x.device) < self.dropprob
        return x.masked_fill(drop, self.replacement_id)

    def extra_repr(self):
        return 'p={}, replacement_id={}'.format(self.dropprob, self.replacement_id)
================================================
FILE: stanza/models/common/exceptions.py
================================================
"""
A couple of more specific FileNotFoundError exceptions.
The idea is that the caller can catch them and report a more useful error resolution.
"""
import errno
class ForwardCharlmNotFoundError(FileNotFoundError):
    """ Raised when the forward character language model file is missing (errno ENOENT). """

    def __init__(self, msg, filename):
        super(ForwardCharlmNotFoundError, self).__init__(errno.ENOENT, msg, filename)
class BackwardCharlmNotFoundError(FileNotFoundError):
    """ Raised when the backward character language model file is missing (errno ENOENT). """

    def __init__(self, msg, filename):
        super(BackwardCharlmNotFoundError, self).__init__(errno.ENOENT, msg, filename)
================================================
FILE: stanza/models/common/foundation_cache.py
================================================
"""
Keeps BERT, charlm, and word embeddings in a cache to save memory
"""
from collections import namedtuple
from copy import deepcopy
import logging
import threading
from stanza.models.common import bert_embedding
from stanza.models.common.char_model import CharacterLanguageModel
from stanza.models.common.pretrain import Pretrain
logger = logging.getLogger('stanza')
# one cached transformer: model, tokenizer, and peft_ids, a dict mapping each
# peft name to the last numeric suffix handed out for it
BertRecord = namedtuple('BertRecord', 'model tokenizer peft_ids')
class FoundationCache:
    """
    Shared cache of transformers, charlms and pretrained word embeddings.

    Passing an existing cache as `other` makes the new object share that
    cache's dictionaries and lock.  local_files_only is not inherited from
    `other`.
    """
    def __init__(self, other=None, local_files_only=False):
        if other is None:
            self.bert = {}
            self.charlms = {}
            self.pretrains = {}
            # future proof the module by using a lock for the glorious day
            # when the GIL is finally gone
            self.lock = threading.Lock()
        else:
            self.bert = other.bert
            self.charlms = other.charlms
            self.pretrains = other.pretrains
            self.lock = other.lock
        self.local_files_only=local_files_only

    def load_bert(self, transformer_name, local_files_only=None):
        """ Return (model, tokenizer) for transformer_name, loading it only once. """
        m, t, _ = self.load_bert_with_peft(transformer_name, None, local_files_only=local_files_only)
        return m, t

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        """
        Load a transformer only once

        Uses a lock for thread safety

        Returns (model, tokenizer, peft_name).  When peft_name is given, it is
        made unique per transformer by appending "_<n>", with n counting up
        from 0 for each repeated request of the same name.  When
        transformer_name is None, returns (None, None, None).
        """
        if transformer_name is None:
            return None, None, None
        with self.lock:
            if transformer_name not in self.bert:
                if local_files_only is None:
                    local_files_only = self.local_files_only
                model, tokenizer = bert_embedding.load_bert(transformer_name, local_files_only=local_files_only)
                self.bert[transformer_name] = BertRecord(model, tokenizer, {})
            else:
                logger.debug("Reusing bert %s", transformer_name)

            bert_record = self.bert[transformer_name]
            if not peft_name:
                return bert_record.model, bert_record.tokenizer, None
            if peft_name not in bert_record.peft_ids:
                bert_record.peft_ids[peft_name] = 0
            else:
                bert_record.peft_ids[peft_name] = bert_record.peft_ids[peft_name] + 1
            peft_name = "%s_%d" % (peft_name, bert_record.peft_ids[peft_name])
            return bert_record.model, bert_record.tokenizer, peft_name

    def load_charlm(self, filename):
        """ Load a charlm (finetune=False) only once; returns None for an empty filename. """
        if not filename:
            return None

        with self.lock:
            if filename not in self.charlms:
                logger.debug("Loading charlm from %s", filename)
                self.charlms[filename] = CharacterLanguageModel.load(filename, finetune=False)
            else:
                logger.debug("Reusing charlm from %s", filename)

            return self.charlms[filename]

    def load_pretrain(self, filename):
        """
        Load a pretrained word embedding only once

        Uses a lock for thread safety

        Returns None for a missing or empty filename, matching load_charlm and
        the module-level load_pretrain.
        """
        if not filename:
            return None
        with self.lock:
            if filename not in self.pretrains:
                logger.debug("Loading pretrain %s", filename)
                self.pretrains[filename] = Pretrain(filename)
            else:
                logger.debug("Reusing pretrain %s", filename)

            return self.pretrains[filename]
class NoTransformerFoundationCache(FoundationCache):
    """
    A FoundationCache that shares charlm and pretrain entries with the
    underlying cache but never serves a cached transformer.

    Intended for downstream models such as POS that finetune their
    transformer: a shared cached transformer would otherwise carry those
    finetuned weights into other models that don't want them.  The bert
    loaders here go to the module-level loaders with no cache, and so to
    bert_embedding.load_bert.
    """
    def load_bert(self, transformer_name, local_files_only=None):
        if local_files_only is None:
            local_files_only = self.local_files_only
        return load_bert(transformer_name, local_files_only=local_files_only)

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        if local_files_only is None:
            local_files_only = self.local_files_only
        return load_bert_with_peft(transformer_name, peft_name, local_files_only=local_files_only)
def load_bert(model_name, foundation_cache=None, local_files_only=None):
    """
    Load a transformer and its tokenizer.

    Goes through foundation_cache when one is given, and straight to
    bert_embedding.load_bert otherwise.
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert(model_name, local_files_only=local_files_only)
    return bert_embedding.load_bert(model_name, local_files_only=local_files_only)
def load_bert_with_peft(model_name, peft_name, foundation_cache=None, local_files_only=None):
    """
    Load (model, tokenizer, peft_name).

    With a cache, the cache decides the final peft_name.  Without one, the
    transformer is loaded directly and peft_name is returned unchanged.
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert_with_peft(model_name, peft_name, local_files_only=local_files_only)
    model, tokenizer = bert_embedding.load_bert(model_name, local_files_only=local_files_only)
    return model, tokenizer, peft_name
def load_charlm(charlm_file, foundation_cache=None, finetune=False):
    """
    Load a character language model, or return None when no file is given.

    A charlm that will be finetuned is always loaded fresh.  Otherwise the
    cache is used when provided.
    """
    if not charlm_file:
        return None
    if finetune:
        # a finetuned copy gets its own weights, so it must not be the shared cached
        # instance that other users of the model rely on
        return CharacterLanguageModel.load(charlm_file, finetune=True)
    if foundation_cache is None:
        logger.debug("Loading charlm from %s", charlm_file)
        return CharacterLanguageModel.load(charlm_file, finetune=False)
    return foundation_cache.load_charlm(charlm_file)
def load_pretrain(filename, foundation_cache=None):
    """ Load pretrained word embeddings through the cache if given, or directly; None if no filename. """
    if not filename:
        return None
    if foundation_cache is None:
        logger.debug("Loading pretrain from %s", filename)
        return Pretrain(filename)
    return foundation_cache.load_pretrain(filename)
================================================
FILE: stanza/models/common/hlstm.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
from stanza.models.common.packed_lstm import PackedLSTM
class HLSTMCell(nn.modules.rnn.RNNCellBase):
    """
    A Highway LSTM Cell as proposed in Zhang et al. (2018) Highway Long Short-Term Memory RNNs for
    Distant Speech Recognition.

    Besides the usual LSTM gates, a highway gate adds a gated copy of the
    cell state from the layer below (c_l_minus_one) to the new cell state.
    """
    def __init__(self, input_size, hidden_size, bias=True):
        # In current PyTorch, RNNCellBase.__init__ requires
        # (input_size, hidden_size, bias, num_chunks).  It would also allocate
        # weight_ih / weight_hh tensors this cell never uses.  Initialize as a
        # plain Module instead; all weights live in the Linear layers below.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # LSTM parameters
        self.Wi = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wf = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wo = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wg = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)

        # highway gate parameters
        self.gate = nn.Linear(input_size + 2 * hidden_size, hidden_size, bias=bias)

    # Newer PyTorch versions dropped these helpers from RNNCellBase, so the
    # cell defines its own.
    def check_forward_input(self, input):
        """ Raise RuntimeError if input's feature dimension is not input_size. """
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))

    def check_forward_hidden(self, input, hx, hidden_label=''):
        """ Raise RuntimeError if hx's batch size or hidden size does not fit. """
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input.size(0), hidden_label, hx.size(0)))
        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hx.size(1), self.hidden_size))

    def forward(self, input, c_l_minus_one=None, hx=None):
        """
        input: (batch, input_size)
        c_l_minus_one: cell state of the layer below, (batch, hidden_size); zeros if None
        hx: (h, c) from the previous time step, each (batch, hidden_size); zeros if None
        Returns the new (h, c), each (batch, hidden_size).
        """
        self.check_forward_input(input)
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            hx = (hx, hx)
        if c_l_minus_one is None:
            c_l_minus_one = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)

        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        self.check_forward_hidden(input, c_l_minus_one, 'c_l_minus_one')

        # vanilla LSTM computation
        rec_input = torch.cat([input, hx[0]], 1)
        i = torch.sigmoid(self.Wi(rec_input))
        f = torch.sigmoid(self.Wf(rec_input))
        o = torch.sigmoid(self.Wo(rec_input))
        g = torch.tanh(self.Wg(rec_input))

        # highway gates
        gate = torch.sigmoid(self.gate(torch.cat([c_l_minus_one, hx[1], input], 1)))

        c = gate * c_l_minus_one + f * hx[1] + i * g
        h = o * torch.tanh(c)

        return h, c
# Highway LSTM network, does NOT use the HLSTMCell above
class HighwayLSTM(nn.Module):
    """
    A Highway LSTM network, as used in the original Tensorflow version of the Dozat parser. Note that this
    is independent from the HLSTMCell above.

    Each layer is a single-layer PackedLSTM whose output h is combined with a
    highway connection from that layer's input x:
        out = h + sigmoid(gate(x)) * highway_func(highway(x))
    Dropout is applied to the input of every layer except the first.
    """
    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False, rec_dropout=0, highway_func=None, pad=False):
        super(HighwayLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        # NOTE(review): dropout_state is not read anywhere in this class
        self.dropout_state = {}
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        # applied to the highway projection; None means identity (see forward)
        self.highway_func = highway_func
        # if True, forward returns a padded tensor instead of a PackedSequence
        self.pad = pad

        self.lstm = nn.ModuleList()
        self.highway = nn.ModuleList()
        self.gate = nn.ModuleList()
        # in-place dropout; only ever applied to the packed output of a previous layer
        self.drop = nn.Dropout(dropout, inplace=True)

        in_size = input_size
        for l in range(num_layers):
            self.lstm.append(PackedLSTM(in_size, hidden_size, num_layers=1, bias=bias,
                batch_first=batch_first, dropout=0, bidirectional=bidirectional, rec_dropout=rec_dropout))
            self.highway.append(nn.Linear(in_size, hidden_size * self.num_directions))
            self.gate.append(nn.Linear(in_size, hidden_size * self.num_directions))
            # highway and gate projections start with zero bias
            self.highway[-1].bias.data.zero_()
            self.gate[-1].bias.data.zero_()
            # later layers consume the (possibly bidirectional) output of the previous one
            in_size = hidden_size * self.num_directions

    def forward(self, input, seqlens, hx=None):
        """
        input: a padded tensor (packed here using seqlens) or an already packed PackedSequence
        seqlens: sequence lengths, used for packing and passed to every PackedLSTM
        hx: optional (h0, c0); layer l uses rows [l*num_directions, (l+1)*num_directions) of each
        Returns (output, (h_n, c_n)).  output is a PackedSequence, or a padded tensor when
        self.pad.  h_n / c_n concatenate every layer's final states along dim 0.
        """
        highway_func = (lambda x: x) if self.highway_func is None else self.highway_func

        hs = []
        cs = []

        if not isinstance(input, PackedSequence):
            input = pack_padded_sequence(input, seqlens, batch_first=self.batch_first)

        for l in range(self.num_layers):
            if l > 0:
                input = PackedSequence(self.drop(input.data), input.batch_sizes, input.sorted_indices, input.unsorted_indices)
            layer_hx = (hx[0][l * self.num_directions:(l+1)*self.num_directions], hx[1][l * self.num_directions:(l+1)*self.num_directions]) if hx is not None else None
            h, (ht, ct) = self.lstm[l](input, seqlens, layer_hx)

            hs.append(ht)
            cs.append(ct)

            # highway connection: LSTM output plus a gated projection of the layer's input
            input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes, input.sorted_indices, input.unsorted_indices)

        if self.pad:
            input = pad_packed_sequence(input, batch_first=self.batch_first)[0]
        return input, (torch.cat(hs, 0), torch.cat(cs, 0))
if __name__ == "__main__":
    # smoke test: 2-layer HighwayLSTM on a (T, batch, input) tensor
    T = 10
    bidir = True
    num_dir = 2 if bidir else 1
    batch_size = 3
    rnn = HighwayLSTM(10, 20, num_layers=2, bidirectional=bidir)
    input = torch.randn(T, batch_size, 10)
    # every sequence in the batch has the full length T
    seqlens = [T] * batch_size
    hx = torch.randn(2 * num_dir, batch_size, 20)
    cx = torch.randn(2 * num_dir, batch_size, 20)
    # forward(input, seqlens, hx): the lengths come second, the initial states third
    output = rnn(input, seqlens, (hx, cx))
    print(output)
================================================
FILE: stanza/models/common/large_margin_loss.py
================================================
"""
LargeMarginInSoftmax, from the article
@inproceedings{kobayashi2019bmvc,
title={Large Margin In Softmax Cross-Entropy Loss},
author={Takumi Kobayashi},
booktitle={Proceedings of the British Machine Vision Conference (BMVC)},
year={2019}
}
implementation from
https://github.com/tk1980/LargeMarginInSoftmax
There is no license specifically chosen; they just ask people to cite the paper if the work is useful.
"""
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
class LargeMarginInSoftmaxLoss(nn.CrossEntropyLoss):
    r"""
    This combines the Softmax Cross-Entropy Loss (nn.CrossEntropyLoss) and the large-margin inducing
    regularization proposed in
       T. Kobayashi, "Large-Margin In Softmax Cross-Entropy Loss." In BMVC2019.

    This loss function inherits the parameters from nn.CrossEntropyLoss except for `reg_lambda` and `deg_logit`.
    Args:
         reg_lambda (float, optional): a regularization parameter. (default: 0.3)
         deg_logit (float, optional): if not None, the target logit is lowered by this
                    amount before the cross entropy is computed. (default: None)
                    deg_logit=1 realizes the method that incorporates the modified
                    loss into ours as described in the above paper (Table 4).

    Expects input of shape [N x C] and target of shape [N].  Samples whose
    target equals ``ignore_index`` contribute neither cross entropy nor
    regularization.
    """
    def __init__(self, reg_lambda=0.3, deg_logit=None,
                weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
        super(LargeMarginInSoftmaxLoss, self).__init__(weight=weight, size_average=size_average,
                                ignore_index=ignore_index, reduce=reduce, reduction=reduction)
        self.reg_lambda = reg_lambda
        self.deg_logit = deg_logit

    def forward(self, input, target):
        N = input.size(0) # number of samples
        C = input.size(1) # number of classes

        # Ignored samples get a placeholder target of 0 so the one-hot mask can be
        # built; their mask rows are then zeroed and their regularizer is masked out.
        # Indexing with ignore_index itself (e.g. -100) would raise an IndexError.
        valid = target != self.ignore_index
        safe_target = torch.where(valid, target, torch.zeros_like(target))
        Mask = torch.zeros_like(input, requires_grad=False)
        Mask[torch.arange(N, device=input.device), safe_target] = 1
        Mask = Mask * valid.unsqueeze(1).to(Mask.dtype)

        if self.deg_logit is not None:
            input = input - self.deg_logit * Mask

        loss = F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

        X = input - 1.e6 * Mask # [N x C], excluding the target class
        reg = 0.5 * ((F.softmax(X, dim=1) - 1.0/(C-1)) * F.log_softmax(X, dim=1) * (1.0-Mask)).sum(dim=1)
        reg = reg * valid.to(reg.dtype)
        if self.reduction == 'sum':
            reg = reg.sum()
        elif self.reduction == 'mean':
            # average over the samples that actually contribute
            reg = reg.sum() / valid.sum().clamp(min=1)
        # reduction == 'none': keep the per-sample values (0 for ignored samples)

        return loss + self.reg_lambda * reg
================================================
FILE: stanza/models/common/loss.py
================================================
"""
Different loss functions.
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import stanza.models.common.seq2seq_constant as constant
logger = logging.getLogger('stanza')
def SequenceLoss(vocab_size):
    """
    Build an nn.NLLLoss over `vocab_size` classes in which the PAD class
    (constant.PAD_ID) has weight 0, so PAD targets add nothing to the loss.
    """
    class_weight = torch.ones(vocab_size)
    class_weight[constant.PAD_ID] = 0
    return nn.NLLLoss(class_weight)
def weighted_cross_entropy_loss(labels, log_dampened=False):
    """
    Either return a loss function which reweights all examples so the
    classes have the same effective weight, or dampened reweighting
    using log() so that the biggest class has some priority

    labels: a sequence (list, numpy array, ...) of class labels.  One weight
      is produced per distinct label, in sorted label order, so the labels are
      presumably expected to be exactly 0..K-1 with every class present.
    log_dampened: if False, weight each class by its inverse frequency.
      If True, use 1 + log(inverse frequency) instead.

    Returns an nn.CrossEntropyLoss with those float32 class weights.
    Raises ValueError if `labels` is empty.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        raise ValueError("Cannot compute class weights from an empty list of labels")
    _, counts = np.unique(labels, return_counts=True)
    frequencies = counts / float(np.sum(counts))
    weights = np.sum(frequencies) / frequencies
    if log_dampened:
        weights = 1 + np.log(weights)
    logger.debug("Reweighting cross entropy by {}".format(weights))
    return nn.CrossEntropyLoss(
        weight=torch.from_numpy(weights).type('torch.FloatTensor')
    )
class FocalLoss(nn.Module):
    """
    Uses the model's assessment of how likely the correct answer is
    to weight the loss for each example

    multi-category focal loss, in other words

    from "Focal Loss for Dense Object Detection"

    https://arxiv.org/abs/1708.02002
    """
    def __init__(self, reduction='mean', gamma=2.0):
        super().__init__()
        if reduction not in ('sum', 'none', 'mean'):
            raise ValueError("Unknown reduction: %s" % reduction)

        self.reduction = reduction
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        self.gamma = gamma

    def forward(self, inputs, targets):
        """
        Weight the loss using the models assessment of the correct answer

        inputs: [N, C] logits, or [C] for a single example
        targets: [N] class indices, or a 0-d tensor for a single example

        Each example's cross entropy ce is scaled by (1 - exp(-ce)) ** gamma,
        where exp(-ce) is the probability assigned to the correct class.
        With reduction='none' the result has shape [N], or [] for a single example.
        Raises ValueError on any other combination of shapes.
        """
        single_example = len(inputs.shape) == 1 and len(targets.shape) == 0
        if single_example:
            # treat a lone example as a batch of one; the batch dim is dropped again below
            inputs = inputs.unsqueeze(0)
            targets = targets.unsqueeze(0)
        elif len(inputs.shape) != 2 or len(targets.shape) != 1 or inputs.shape[0] != targets.shape[0]:
            raise ValueError("Expected inputs N,C and targets N, but got {} and {}".format(inputs.shape, targets.shape))

        raw_loss = self.ce_loss(inputs, targets)
        assert len(raw_loss.shape) == 1 and raw_loss.shape[0] == inputs.shape[0]

        # https://www.tutorialexample.com/implement-focal-loss-for-multi-label-classification-in-pytorch-pytorch-tutorial/
        final_loss = raw_loss * ((1 - torch.exp(-raw_loss)) ** self.gamma)
        assert len(final_loss.shape) == 1 and final_loss.shape[0] == inputs.shape[0]
        if self.reduction == 'sum':
            return final_loss.sum()
        elif self.reduction == 'mean':
            return final_loss.mean()
        elif self.reduction == 'none':
            return final_loss.squeeze(0) if single_example else final_loss
        raise AssertionError("unknown reduction! how did this happen??")
class MixLoss(nn.Module):
    """
    Weighted sum of a PAD-ignoring sequence NLL loss and a classification
    cross-entropy loss:

        loss = SequenceLoss(seq) + alpha * CrossEntropyLoss(class)

    alpha must be non-negative.
    """
    def __init__(self, vocab_size, alpha):
        super().__init__()
        self.seq_loss = SequenceLoss(vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        assert alpha >= 0
        self.alpha = alpha

    def forward(self, seq_inputs, seq_targets, class_inputs, class_targets):
        sequence_term = self.seq_loss(seq_inputs, seq_targets)
        class_term = self.ce_loss(class_inputs, class_targets)
        return sequence_term + self.alpha * class_term
class MaxEntropySequenceLoss(nn.Module):
    """
    NLL loss plus an entropy term that rewards high-entropy (more diverse)
    output distributions:

        loss = NLLLoss + alpha * sum(p * log p) / N

    The PAD class (constant.PAD_ID) is weighted 0 in the NLL term, and rows
    whose target is PAD are zeroed out of the entropy term.
    """
    def __init__(self, vocab_size, alpha):
        super().__init__()
        class_weight = torch.ones(vocab_size)
        class_weight[constant.PAD_ID] = 0
        self.nll = nn.NLLLoss(class_weight)
        self.alpha = alpha

    def forward(self, inputs, targets):
        """
        inputs: [N, C], presumably log-probabilities as NLLLoss expects
        targets: [N]
        """
        assert inputs.size(0) == targets.size(0)
        nll_loss = self.nll(inputs, targets)

        # zero out PAD rows: exp(0) * 0 == 0, so they add nothing below
        pad_rows = targets.eq(constant.PAD_ID).unsqueeze(1).expand_as(inputs)
        log_probs = inputs.masked_fill(pad_rows, 0.0)
        # sum of p * log p, averaged over the minibatch (PAD rows included in N)
        ent_loss = (log_probs.exp() * log_probs).sum() / inputs.size(0)
        return nll_loss + self.alpha * ent_loss
================================================
FILE: stanza/models/common/maxout_linear.py
================================================
"""
A layer which implements maxout from the "Maxout Networks" paper
https://arxiv.org/pdf/1302.4389v4.pdf
Goodfellow, Warde-Farley, Mirza, Courville, Bengio
or a simpler explanation here:
https://stats.stackexchange.com/questions/129698/what-is-maxout-in-neural-network/298705#298705
The implementation here:
for k layers of maxout, in -> out channels, we make a single linear
map of size in -> out*k
then we reshape the end to be (..., k, out)
and return the max over the k layers
"""
import torch
import torch.nn as nn
class MaxoutLinear(nn.Module):
    """
    Maxout layer: maxout_k linear maps from in_channels to out_channels,
    combined with an elementwise max over the k pieces.
    """
    def __init__(self, in_channels, out_channels, maxout_k):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxout_k = maxout_k
        # all k pieces live in one oversized projection; forward() splits it apart
        self.linear = nn.Linear(in_channels, out_channels * maxout_k)

    def forward(self, inputs):
        """
        Project once with the oversized linear, view the last dimension as
        (maxout_k, out_channels), then keep the max over the maxout_k axis.
        A single large matmul is simpler and parallelizes better than k small ones.
        """
        projected = self.linear(inputs)
        split_shape = tuple(projected.shape[:-1]) + (self.maxout_k, self.out_channels)
        pieces = projected.view(split_shape)
        best, _ = torch.max(pieces, dim=-2)
        return best
================================================
FILE: stanza/models/common/packed_lstm.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
class PackedLSTM(nn.Module):
    """
    LSTM wrapper that accepts padded input plus lengths (or a PackedSequence).

    Uses nn.LSTM when rec_dropout == 0, otherwise LSTMwRecDropout.  Returns
    (output, hidden) where output is a PackedSequence, or a padded tensor
    when pad=True.
    """
    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):
        super().__init__()

        self.batch_first = batch_first
        self.pad = pad
        if rec_dropout != 0:
            self.lstm = LSTMwRecDropout(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, rec_dropout=rec_dropout)
        else:
            # no recurrent dropout needed: use the fast, native LSTM implementation
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)

    def forward(self, input, lengths, hx=None):
        if isinstance(input, PackedSequence):
            packed = input
        else:
            packed = pack_padded_sequence(input, lengths, batch_first=self.batch_first)

        output, hidden = self.lstm(packed, hx)
        if self.pad:
            output, _ = pad_packed_sequence(output, batch_first=self.batch_first)
        return output, hidden
class LSTMwRecDropout(nn.Module):
    """ An LSTM implementation that supports recurrent dropout

    Built from one nn.LSTMCell per (layer, direction) and stepped one time step
    at a time in Python, so it is presumably much slower than the fused nn.LSTM;
    PackedLSTM only constructs it when rec_dropout != 0.

    forward() takes a PackedSequence (and optionally an initial (h0, c0)) and
    returns (PackedSequence of the last layer's outputs, (h_n, c_n)), where
    h_n / c_n stack one (batch, hidden_size) state per layer and direction along
    dim 0, i.e. (num_layers * num_directions, batch, hidden_size).

    NOTE(review): batch_first and pad are stored but never read by forward().
    """
    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):
        super().__init__()
        self.batch_first = batch_first
        self.pad = pad
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # dropout: applied to each layer's input except the first
        # rec_dropout: one mask per sequence on the recurrent hidden state
        self.dropout = dropout
        self.drop = nn.Dropout(dropout, inplace=True)
        self.rec_drop = nn.Dropout(rec_dropout, inplace=True)

        self.num_directions = 2 if bidirectional else 1

        # cells are laid out as [layer0_fwd, layer0_bwd, layer1_fwd, ...],
        # matching idx = l * num_directions + d in forward()
        self.cells = nn.ModuleList()
        for l in range(num_layers):
            in_size = input_size if l == 0 else self.num_directions * hidden_size
            for d in range(self.num_directions):
                self.cells.append(nn.LSTMCell(in_size, hidden_size, bias=bias))

    def forward(self, input, hx=None):
        def rnn_loop(x, batch_sizes, cell, inits, reverse=False):
            # RNN loop for one layer in one direction with recurrent dropout.
            # x is the flat .data tensor of a PackedSequence and batch_sizes its
            # .batch_sizes; returns the flat output tensor (same packed layout)
            # and the final (h, c), each (batch, hidden_size).
            batch_size = batch_sizes[0].item()
            # per-sequence states, so each sequence keeps the state from its own last step
            states = [list(init.split([1] * batch_size)) for init in inits]
            # the same dropout mask is reused at every time step
            h_drop_mask = x.new_ones(batch_size, self.hidden_size)
            h_drop_mask = self.rec_drop(h_drop_mask)
            resh = []

            if not reverse:
                st = 0
                for bs in batch_sizes:
                    # only the first bs sequences are still active at this step
                    s1 = cell(x[st:st+bs], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
                    resh.append(s1[0])
                    for j in range(bs):
                        states[0][j] = s1[0][j].unsqueeze(0)
                        states[1][j] = s1[1][j].unsqueeze(0)
                    st += bs
            else:
                # walk the packed data from the last time step backwards
                en = x.size(0)
                for i in range(batch_sizes.size(0)-1, -1, -1):
                    bs = batch_sizes[i]
                    s1 = cell(x[en-bs:en], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
                    resh.append(s1[0])
                    for j in range(bs):
                        states[0][j] = s1[0][j].unsqueeze(0)
                        states[1][j] = s1[1][j].unsqueeze(0)
                    en -= bs
                # restore forward time order so the output matches the packed layout
                resh = list(reversed(resh))

            return torch.cat(resh, 0), tuple(torch.cat(s, 0) for s in states)

        all_states = [[], []]
        inputdata, batch_sizes = input.data, input.batch_sizes
        for l in range(self.num_layers):
            new_input = []

            if self.dropout > 0 and l > 0:
                inputdata = self.drop(inputdata)
            for d in range(self.num_directions):
                idx = l * self.num_directions + d
                cell = self.cells[idx]
                # initial state: hx[0][idx], hx[1][idx] if given, otherwise zeros (batch, hidden_size)
                out, states = rnn_loop(inputdata, batch_sizes, cell, (hx[i][idx] for i in range(2)) if hx is not None else (input.data.new_zeros(input.batch_sizes[0].item(), self.hidden_size, requires_grad=False) for _ in range(2)), reverse=(d == 1))

                new_input.append(out)
                all_states[0].append(states[0].unsqueeze(0))
                all_states[1].append(states[1].unsqueeze(0))

            if self.num_directions > 1:
                # concatenate both directions
                inputdata = torch.cat(new_input, 1)
            else:
                inputdata = new_input[0]

        input = PackedSequence(inputdata, batch_sizes)

        return input, tuple(torch.cat(x, 0) for x in all_states)
================================================
FILE: stanza/models/common/peft_config.py
================================================
"""
Set a few common flags for peft usage
"""
# Per-transformer overrides for the LoRA hyperparameters, keyed by the
# bert_model name.  resolve_peft_args() falls back to the matching
# DEFAULT_LORA_* value for any model not listed here.
TRANSFORMER_LORA_RANK = {}
TRANSFORMER_LORA_ALPHA = {}
TRANSFORMER_LORA_DROPOUT = {}
TRANSFORMER_LORA_TARGETS = {
    "princeton-nlp/Sheared-LLaMA-1.3B": "self_attn.k_proj,self_attn.v_proj,self_attn.o_proj,mlp.gate_proj,mlp.up_proj,mlp.down_proj"
}
TRANSFORMER_LORA_SAVE = {}

# Defaults used when neither a command line flag nor a per-model override is given.
# Targets and modules-to-save are comma separated module names.
DEFAULT_LORA_RANK = 64
DEFAULT_LORA_ALPHA = 128
DEFAULT_LORA_DROPOUT = 0.1
DEFAULT_LORA_TARGETS = "query,value,output.dense,intermediate.dense"
DEFAULT_LORA_SAVE = ""
def add_peft_args(parser):
    """
    Register the LoRA / peft command line flags on an argparse parser.

    All lora_* flags default to None so that resolve_peft_args() can tell
    "not given" apart from an explicit value; --use_peft is a plain switch.
    """
    lora_flags = (
        ('--lora_rank', int, "Rank of a LoRA approximation. Default will be %d or a model-specific parameter" % DEFAULT_LORA_RANK),
        ('--lora_alpha', int, "Alpha of a LoRA approximation. Default will be %d or a model-specific parameter" % DEFAULT_LORA_ALPHA),
        ('--lora_dropout', float, "Dropout for the LoRA approximation. Default will be %s or a model-specific parameter" % DEFAULT_LORA_DROPOUT),
        ('--lora_target_modules', str, "Comma separated list of LoRA targets. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_TARGETS),
        ('--lora_modules_to_save', str, "Comma separated list of modules to save (eg, fully tune) when using LoRA. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_SAVE),
    )
    for flag, flag_type, flag_help in lora_flags:
        parser.add_argument(flag, type=flag_type, default=None, help=flag_help)
    parser.add_argument('--use_peft', default=False, action='store_true', help="Finetune Bert using peft")
def pop_peft_args(args):
    """
    Remove every peft-related key from the dict `args`, in place.

    Missing keys are ignored.  Useful when a model loaded from disk must be
    rebuilt with the right shapes rather than with the saved peft settings.
    """
    peft_keys = ("lora_rank", "lora_alpha", "lora_dropout",
                 "lora_target_modules", "lora_modules_to_save", "use_peft")
    for key in peft_keys:
        args.pop(key, None)
def _split_module_list(modules):
    """
    Normalize a module list given either as a comma separated string or as an
    already-split list/tuple: strip whitespace from each name and drop empty names.
    """
    if isinstance(modules, (list, tuple)):
        pieces = modules
    else:
        pieces = modules.split(",")
    return [piece.strip() for piece in pieces if piece.strip()]


def resolve_peft_args(args, logger, check_bert_finetune=True):
    """
    Fill in and normalize the LoRA settings on an argparse-style namespace.

    Does nothing if `args` has no `bert_model` attribute.  Every lora_* value
    that is None is replaced by the per-model value from TRANSFORMER_LORA_*
    (keyed by args.bert_model), or else the DEFAULT_LORA_* value.
    lora_target_modules and lora_modules_to_save always end up as lists of
    module names.  They may arrive as comma separated strings or as lists,
    e.g. when args were already resolved once.  If check_bert_finetune is set
    and args has bert_finetune, --use_peft turns bert_finetune on.
    """
    if not hasattr(args, 'bert_model'):
        return

    if args.lora_rank is None:
        args.lora_rank = TRANSFORMER_LORA_RANK.get(args.bert_model, DEFAULT_LORA_RANK)

    if args.lora_alpha is None:
        args.lora_alpha = TRANSFORMER_LORA_ALPHA.get(args.bert_model, DEFAULT_LORA_ALPHA)

    if args.lora_dropout is None:
        args.lora_dropout = TRANSFORMER_LORA_DROPOUT.get(args.bert_model, DEFAULT_LORA_DROPOUT)

    if args.lora_target_modules is None:
        args.lora_target_modules = TRANSFORMER_LORA_TARGETS.get(args.bert_model, DEFAULT_LORA_TARGETS)
    args.lora_target_modules = _split_module_list(args.lora_target_modules)

    if args.lora_modules_to_save is None:
        args.lora_modules_to_save = TRANSFORMER_LORA_SAVE.get(args.bert_model, DEFAULT_LORA_SAVE)
    args.lora_modules_to_save = _split_module_list(args.lora_modules_to_save)

    if check_bert_finetune and hasattr(args, 'bert_finetune'):
        if args.use_peft and not args.bert_finetune:
            logger.info("--use_peft set.  setting --bert_finetune as well")
            args.bert_finetune = True
def build_peft_config(args, logger):
    """
    Build a peft LoraConfig from the resolved lora_* entries of the mapping `args`.

    The adapter is created trainable (inference_mode=False) and with bias="none".
    """
    # Hide import so that the peft dependency is optional
    from peft import LoraConfig

    rank = args['lora_rank']
    alpha = args['lora_alpha']
    logger.debug("Creating lora adapter with rank %d and alpha %d", rank, alpha)
    return LoraConfig(inference_mode=False,
                      r=rank,
                      target_modules=args['lora_target_modules'],
                      lora_alpha=alpha,
                      lora_dropout=args['lora_dropout'],
                      modules_to_save=args['lora_modules_to_save'],
                      bias="none")
def build_peft_wrapper(bert_model, args, logger, adapter_name="default"):
    """
    Wrap `bert_model` in a fresh LoRA adapter named `adapter_name`, activate
    that adapter, and return the peft-wrapped model.
    """
    # Hide import so that the peft dependency is optional
    from peft import get_peft_model

    pefted = get_peft_model(bert_model, build_peft_config(args, logger), adapter_name=adapter_name)

    # get_peft_model apparently does not flag the peft config as loaded,
    # which would make it impossible to turn the adapter off (or on) later,
    # so set the flag by hand on both the base model and the wrapper
    for model in (bert_model, pefted):
        model._hf_peft_config_loaded = True

    pefted.set_adapter(adapter_name)
    return pefted
def load_peft_wrapper(bert_model, lora_params, args, logger, adapter_name):
    """
    Load saved LoRA weights `lora_params` into `bert_model` as adapter
    `adapter_name`, make it the active adapter, and return `bert_model`.
    """
    config = build_peft_config(args, logger)
    try:
        bert_model.load_adapter(adapter_name=adapter_name, peft_config=config, adapter_state_dict=lora_params)
    except (ValueError, TypeError):
        # load_adapter can fail this way if the adapter already exists;
        # in that case, overwrite the existing adapter's weights instead
        from peft import set_peft_model_state_dict
        set_peft_model_state_dict(bert_model, lora_params, adapter_name=adapter_name)

    bert_model.set_adapter(adapter_name)
    return bert_model
================================================
FILE: stanza/models/common/pretrain.py
================================================
"""
Support for pretrained data.
"""
import csv
import os
import re
import lzma
import logging
import numpy as np
import torch
from .vocab import BaseVocab, VOCAB_PREFIX, UNK_ID
from stanza.models.common.utils import open_read_binary, open_read_text
from stanza.resources.common import DEFAULT_MODEL_DIR
from pickle import UnpicklingError
import warnings
logger = logging.getLogger('stanza')
class PretrainedWordVocab(BaseVocab):
    """
    Vocabulary for pretrained word vectors: the VOCAB_PREFIX entries come
    first, followed by the words in self.data in their original order.
    """
    def build_vocab(self):
        units = VOCAB_PREFIX + self.data
        self._id2unit = units
        self._unit2id = dict(zip(units, range(len(units))))

    def normalize_unit(self, unit):
        normalized = super().normalize_unit(unit)
        # spaces inside a word are stored as nbsp, matching how multi-piece
        # words are joined when vectors are read from text
        return normalized.replace(" ", "\xa0") if normalized else normalized
class Pretrain:
    """ A loader and saver for pretrained embeddings.

    The vectors are loaded lazily, the first time `vocab` or `emb` is accessed.
    `filename` is a torch-saved dict with 'vocab' and 'emb' entries.  If it is
    missing or cannot be read, the vectors are read from `vec_filename`
    (whitespace separated text) or `csv_filename` instead, and, if
    `save_to_file` is set, written back to `filename` for next time.
    `max_vocab`, if positive, limits how many vectors are read from text; when
    it exceeds len(VOCAB_PREFIX) the final vocab, including the VOCAB_PREFIX
    entries, has at most max_vocab entries.
    """

    def __init__(self, filename=None, vec_filename=None, max_vocab=-1, save_to_file=True, csv_filename=None):
        self.filename = filename
        self._vec_filename = vec_filename
        self._csv_filename = csv_filename
        self._max_vocab = max_vocab
        self._save_to_file = save_to_file

    def __len__(self):
        return len(self.vocab)

    @property
    def vocab(self):
        if not hasattr(self, '_vocab'):
            self.load()
        return self._vocab

    @property
    def emb(self):
        if not hasattr(self, '_emb'):
            self.load()
        return self._emb

    def load(self):
        """
        Populate _vocab and _emb, preferring the saved .pt file and falling
        back to the text/csv vectors (then optionally saving the .pt file).
        """
        if self.filename is not None and os.path.exists(self.filename):
            try:
                # TODO: after making the next release, remove the weights_only=False version
                try:
                    data = torch.load(self.filename, lambda storage, loc: storage, weights_only=True)
                except UnpicklingError:
                    data = torch.load(self.filename, lambda storage, loc: storage, weights_only=False)
                    warnings.warn("The saved pretrain has an old format using numpy.ndarray instead of torch to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained embedding using this version ASAP.")
                logger.debug("Loaded pretrain from {}".format(self.filename))
                if not isinstance(data, dict):
                    raise RuntimeError("File {} exists but is not a stanza pretrain file. It is not a dict, whereas a Stanza pretrain should have a dict with 'emb' and 'vocab'".format(self.filename))
                if 'emb' not in data or 'vocab' not in data:
                    raise RuntimeError("File {} exists but is not a stanza pretrain file. A Stanza pretrain file should have 'emb' and 'vocab' fields in its state dict".format(self.filename))
                self._vocab = PretrainedWordVocab.load_state_dict(data['vocab'])
                self._emb = data['emb']
                if isinstance(self._emb, np.ndarray):
                    self._emb = torch.from_numpy(self._emb)
                return
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as e:
                # without a text/csv fallback there is nothing else to try
                if not self._vec_filename and not self._csv_filename:
                    raise
                logger.warning("Pretrained file exists but cannot be loaded from {}, due to the following exception:\n\t{}".format(self.filename, e))
                vocab, emb = self.read_pretrain()
        else:
            if not self._vec_filename and not self._csv_filename:
                raise FileNotFoundError("Pretrained file {} does not exist, and no text/xz file was provided".format(self.filename))
            if self.filename is not None:
                logger.info("Pretrained filename %s specified, but file does not exist. Attempting to load from text file", self.filename)
            vocab, emb = self.read_pretrain()

        self._vocab = vocab
        self._emb = emb

        if self._save_to_file:
            # save to file
            assert self.filename is not None, "Filename must be provided to save pretrained vector to file."
            self.save(self.filename)

    def save(self, filename):
        """
        Save the vocab and embedding to `filename` with torch.save.
        Failures are logged and otherwise ignored.
        """
        directory, _ = os.path.split(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # should not infinite loop since the load function sets _vocab and _emb before trying to save
        data = {'vocab': self.vocab.state_dict(), 'emb': self.emb}
        try:
            torch.save(data, filename, _use_new_zipfile_serialization=False)
            logger.info("Saved pretrained vocab and vectors to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as e:
            logger.warning("Saving pretrained data failed due to the following exception... continuing anyway.\n\t{}".format(e))

    def write_text(self, filename, header=False):
        """
        Write the vocab & values to a text file

        One "word v1 v2 ..." line per vocab entry, optionally preceded by a
        "<count> <dim>" header.  The file is written as UTF-8 so that it can be
        read back by read_from_file, which decodes lines as UTF-8, regardless of
        the platform's default encoding.
        """
        with open(filename, "w", encoding="utf-8") as fout:
            if header:
                word_dim = self.emb[0].shape[0]
                fout.write("%d %d\n" % (len(self.vocab), word_dim))
            for word_idx, word in enumerate(self.vocab):
                row = self.emb[word_idx].to("cpu")
                fout.write(word)
                fout.write(" ")
                fout.write(" ".join(["%.6f" % x.item() for x in row]))
                fout.write("\n")

    def read_pretrain(self):
        """
        Read (vocab, emb) from the text or csv vector file, applying max_vocab.
        """
        # load from pretrained filename
        if self._vec_filename is not None:
            words, emb, _ = self.read_from_file(self._vec_filename, self._max_vocab)
        elif self._csv_filename is not None:
            words, emb = self.read_from_csv(self._csv_filename)
        else:
            raise RuntimeError("Vector file is not provided.")

        if len(emb) - len(VOCAB_PREFIX) != len(words):
            raise RuntimeError("Loaded number of vectors does not match number of words.")

        # Use a fixed vocab size
        if self._max_vocab > len(VOCAB_PREFIX) and self._max_vocab < len(words) + len(VOCAB_PREFIX):
            words = words[:self._max_vocab - len(VOCAB_PREFIX)]
            emb = emb[:self._max_vocab]

        vocab = PretrainedWordVocab(words)

        return vocab, emb

    @staticmethod
    def read_from_csv(filename):
        """
        Read vectors from CSV

        Skips the first row.  Every other row is a word followed by its vector
        values; the vector width is taken from the first data row.
        Raises ValueError if the file has no data rows.
        """
        logger.info("Reading pretrained vectors from csv file %s ...", filename)
        with open_read_text(filename) as fin:
            csv_reader = csv.reader(fin)
            # the header of the thai csv vector file we have is just the number of columns
            # so we read past the first line
            next(csv_reader, None)
            lines = [line for line in csv_reader]

        if not lines:
            raise ValueError("No vectors found in csv file {}".format(filename))
        rows = len(lines)
        cols = len(lines[0]) - 1

        emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)
        for i, line in enumerate(lines):
            emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)
        words = [line[0].replace(' ', '\xa0') for line in lines]
        return words, emb

    @staticmethod
    def read_from_file(filename, max_vocab=None):
        """
        Open a vector file using the provided function and read from it.

        The file is opened with open_read_binary and each line is decoded as
        UTF-8; lines that fail to decode are skipped and counted.  An optional
        first line "<count> <dim>" fixes the vector width.  A line whose first
        token is the literal <unk> provides the vector for UNK_ID instead of a
        vocab entry.

        Returns (words, emb, failed): words has spaces inside multi-piece words
        replaced by nbsp, emb is a float32 tensor whose first len(VOCAB_PREFIX)
        rows are reserved for the special tokens, failed counts undecodable lines.
        Raises ValueError if no vectors are found and there is no header.
        """
        logger.info("Reading pretrained vectors from %s ...", filename)

        # some vector files, such as Google News, use tabs
        tab_space_pattern = re.compile(r"[ \t]+")
        first = True
        cols = None
        lines = []
        failed = 0
        unk_line = None
        with open_read_binary(filename) as f:
            for line in f:
                try:
                    line = line.decode()
                except UnicodeDecodeError:
                    failed += 1
                    continue
                line = line.rstrip()
                if not line:
                    continue
                pieces = tab_space_pattern.split(line)
                if first:
                    # the first line contains the number of word vectors and the dimensionality
                    # note that a 1d embedding with a number as the first entry
                    # will fail to read properly.  we ignore that case
                    first = False
                    if len(pieces) == 2:
                        cols = int(pieces[1])
                        continue

                # the UNK vector is marked by a literal <unk> token.  An empty first
                # token would instead only come from a line with leading whitespace,
                # which must not be mistaken for the UNK vector
                if pieces[0] == '<unk>':
                    if unk_line is not None:
                        logger.error("More than one <unk> line in the pretrain! Keeping the most recent one")
                    else:
                        logger.debug("Found an unk line while reading the pretrain")
                    unk_line = pieces
                else:
                    if not max_vocab or max_vocab < 0 or len(lines) < max_vocab:
                        lines.append(pieces)

        if cols is None:
            # another failure case: all words have spaces in them
            if not lines:
                raise ValueError("Could not determine the vector dimension: no vectors found in {}".format(filename))
            cols = min(len(x) for x in lines) - 1
        rows = len(lines)
        emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)
        if unk_line is not None:
            emb[UNK_ID] = torch.tensor([float(x) for x in unk_line[-cols:]], dtype=torch.float32)
        for i, line in enumerate(lines):
            emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)

        # if there were word pieces separated with spaces, rejoin them with nbsp instead
        # this way, the normalize_unit method in vocab.py can find the word at test time
        words = ['\xa0'.join(line[:-cols]) for line in lines]
        if failed > 0:
            logger.info("Failed to read %d lines from embedding", failed)
        return words, emb, failed
def find_pretrain_file(wordvec_pretrain_file, save_dir, shorthand, lang):
    """
    When training a model, look in a few different places for a .pt file

    An explicitly passed wordvec_pretrain_file always wins.  Otherwise the
    following are checked in order, and the first that exists is returned:
      saved_models/{model}/{shorthand}.pretrain.pt
      saved_models/{model}/{shorthand}_pretrain.pt
      ~/stanza_resources/{language}/pretrain/{package}.pt  (only if shorthand has a '_';
          package is the part after the first '_')
    If none exist, the first location is returned anyway, in the hope that
    the original text vectors can be used to create it.
    """
    if wordvec_pretrain_file:
        return wordvec_pretrain_file

    default_pretrain_file = os.path.join(save_dir, '{}.pretrain.pt'.format(shorthand))
    candidates = [default_pretrain_file,
                  os.path.join(save_dir, '{}_pretrain.pt'.format(shorthand))]
    if "_" in shorthand:
        # e.g. /home/user/stanza_resources/vi/pretrain/vtb.pt
        package = shorthand.split('_', 1)[1]
        candidates.append(os.path.join(DEFAULT_MODEL_DIR, lang, 'pretrain', '{}.pt'.format(package)))

    for candidate in candidates:
        if os.path.exists(candidate):
            logger.debug("Found existing .pt file in %s", candidate)
            return candidate
        logger.debug("Cannot find pretrained vectors in %s", candidate)

    return default_pretrain_file
if __name__ == '__main__':
    # Small smoke test: write a tiny vector file, convert it to a .pt file,
    # then reload from the .pt file.  Everything happens in a temporary
    # directory so running this never overwrites or leaves files in the cwd.
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        txt_file = os.path.join(tmpdir, 'test.txt')
        pt_file = os.path.join(tmpdir, 'test.pt')
        with open(txt_file, 'w') as fout:
            fout.write('3 2\na 1 1\nb -1 -1\nc 0 0\n')
        # 1st load: read the text file and save to the pt file
        pretrain = Pretrain(pt_file, txt_file)
        print(pretrain.emb)
        # verify pt file
        x = torch.load(pt_file, weights_only=True)
        print(x)
        # 2nd load: load saved pt file
        pretrain = Pretrain(pt_file, txt_file)
        print(pretrain.emb)
================================================
FILE: stanza/models/common/relative_attn.py
================================================
import logging
import torch
from torch import nn
import torch.nn.functional as F
logger = logging.getLogger('stanza')
class RelativeAttention(nn.Module):
    """
    Windowed multi-head attention with learned relative position vectors.

    Per the stripe diagrams in skew_repeat, output position s draws on the
    window of positions s, s+1, ..., s+window-1, plus `num_sinks` optional
    sink positions that every position can see.  With reverse=True the
    sequence is flipped before and after, so the window looks the other way.

    Note that the q/k product is an elementwise einsum with no reduction over
    d_head, so every channel of every head gets its own softmax over the window.

    Input: (batch_size, seq_len, d_model).  Output: (batch_size, seq_len, d_output).
    """
    def __init__(self, d_model, num_heads, window=8, dropout=0.2, reverse=False, d_output=None, fudge_output=False, num_sinks=0):
        super().__init__()
        if d_output is None:
            d_output = d_model

        d_head, remainder = divmod(d_output, num_heads)
        if remainder:
            if fudge_output:
                # round d_head up so that d_output becomes a multiple of num_heads
                d_head = d_head + 1
                logger.debug("Relative attn: %d %% %d != 0, updating d_output to %d", d_output, num_heads, num_heads * d_head)
                d_output = num_heads * d_head
            else:
                # NOTE(review): it is d_output (which defaults to d_model) that must divide by num_heads
                raise ValueError("incompatible `d_model` and `num_heads`")
        self.window = window
        self.num_sinks = num_sinks
        self.d_model = d_model
        self.d_head = d_head
        self.num_heads = num_heads
        self.d_output = d_output
        self.key = nn.Linear(d_model, d_output)
        # the bias for query all gets trained to 0 anyway
        self.query = nn.Linear(d_model, d_output, bias=False)
        self.value = nn.Linear(d_model, d_output, bias=False)
        # initializing value with eye seems to hurt!
        #nn.init.eye_(self.value.weight)
        self.dropout = nn.Dropout(dropout)
        # one learned vector per window offset (and per sink), broadcast over batch, heads, and positions
        self.position = nn.Parameter(torch.randn(1, 1, d_head, window + num_sinks, 1))

        # mask[..., i, j] == 1 for j < i: shape (1, 1, 1, window, window)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(window, window), diagonal=-1).unsqueeze(0).unsqueeze(0).unsqueeze(0)
        )
        # mask flipped along the last axis: marks offsets that run past the end of the sequence
        self.register_buffer(
            "flipped_mask",
            torch.flip(self.mask, (-1,))
        )

        self.reverse = reverse

    def forward(self, x, sink=None):
        # x.shape == (batch_size, seq_len, d_model)
        batch_size, seq_len, d_model = x.shape
        if d_model != self.d_model:
            raise ValueError("Incompatible input")

        if self.reverse:
            x = torch.flip(x, (1,))

        # skew_repeat needs at least `window` positions, so pad short inputs with zeros
        orig_seq_len = seq_len
        if seq_len < self.window:
            zeros = torch.zeros((x.shape[0], self.window - seq_len, x.shape[2]), dtype=x.dtype, device=x.device)
            x = torch.cat((x, zeros), axis=1)
            seq_len = self.window

        if self.num_sinks > 0:
            # could keep a parameter to train sinks, but as it turns out,
            # the position vectors just overlap that parameter space anyway
            # generally the model trains the sinks to zero if we do that
            if sink is None:
                sink = torch.zeros((batch_size, self.num_sinks, d_model), dtype=x.dtype, device=x.device)
            else:
                sink = sink.expand(batch_size, self.num_sinks, d_model)
            x = torch.cat((sink, x), axis=1)

        # k.shape = (batch_size, num_heads, d_head, seq_len)  -- the sink positions are sliced off
        k = self.key(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)[:, :, :, self.num_sinks:]
        # v.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)
        v = self.value(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)
        # q.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)
        q = self.query(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)
        # q.shape = (batch_size, num_heads, d_head, window + num_sinks, seq_len)
        q = self.skew_repeat(q)
        q = q + self.position

        # elementwise per channel, no sum over d
        # qk.shape = (batch_size, num_heads, d_head, window + num_sinks, seq_len)
        qk = torch.einsum('bndws,bnds->bndws', q, k)

        # TODO: fix mask
        # mask out the padding spaces at the end
        # can only attend to spots that aren't padded
        if orig_seq_len < seq_len:
            # mask out the part of the sentence which is empty
            shorter_mask = self.flipped_mask[:, :, :, :orig_seq_len, -orig_seq_len:]
            qk = qk[:, :, :, :(orig_seq_len + self.num_sinks), :orig_seq_len]
            qk[:, :, :, -orig_seq_len:, :] = qk[:, :, :, -orig_seq_len:, :].masked_fill(shorter_mask == 1, float("-inf"))
        else:
            # the last `window` positions have offsets that run past the end
            qk[:, :, :, -self.window:, -self.window:] = qk[:, :, :, -self.window:, -self.window:].masked_fill(self.flipped_mask == 1, float("-inf"))
        # normalize over the window (+ sinks) axis
        qk = F.softmax(qk, dim=3)

        # v.shape = (batch_size, num_heads, d_head, window + num_sinks, seq_len)
        v = self.skew_repeat(v)
        if orig_seq_len < seq_len:
            v = v[:, :, :, :(orig_seq_len + self.num_sinks), :orig_seq_len]
        # result.shape = (batch_size, num_heads, d_head, orig_seq_len)
        result = torch.einsum('bndws,bndws->bnds', qk, v)
        # batch_size, orig_seq_len, d_output
        result = result.reshape(batch_size, self.d_output, orig_seq_len).transpose(1, 2)

        if self.reverse:
            result = torch.flip(result, (1,))

        return self.dropout(result)

    def skew_repeat(self, q):
        """
        q (currently, at least) is num_sinks + seq_len long
        and the num_sinks are there to be chopped off the front
        then the seq_len remainder is skewed

        Input: (batch_size, num_heads, d_head, num_sinks + seq_len)
        Output: (batch_size, num_heads, d_head, num_sinks + window, seq_len), where
        row num_sinks + w, column s holds position s + w (zero past the end), and
        the sink rows repeat each sink value across every column.
        Assumes seq_len >= window; forward() pads shorter inputs.
        """
        if self.num_sinks > 0:
            q_sink = q[:, :, :, :self.num_sinks]
            q_sink = q_sink.unsqueeze(4)
            q_sink = q_sink.repeat(1, 1, 1, 1, q.shape[-1] - self.num_sinks)
            q = q[:, :, :, self.num_sinks:]
        # make stripes that look like this
        # (seq_len 5, window 3)
        #   1 2 3 4 5
        #   1 2 3 4 5
        #   1 2 3 4 5
        q = q.unsqueeze(4).repeat(1, 1, 1, 1, self.window).transpose(3, 4)
        # now the stripes look like
        #   1 2 3 4 5
        #   0 2 3 4 5
        #   0 0 3 4 5
        q[:, :, :, :, :self.window] = q[:, :, :, :, :self.window].masked_fill(self.mask == 1, 0)
        q_shape = list(q.shape)  # unused
        q_new_shape = list(q.shape)[:-2] + [-1]
        q = q.reshape(q_new_shape)
        # appending `window` zeros and re-viewing with one extra column per row
        # shifts row w left by w positions
        zeros = torch.zeros_like(q[:, :, :, :1])
        zeros = zeros.repeat(1, 1, 1, self.window)
        q = torch.cat((q, zeros), axis=-1)
        q_new_shape = q_new_shape[:-1] + [self.window, -1]
        # now the stripes look like
        #   1 2 3 4 5
        #   2 3 4 5 0
        #   3 4 5 0 0
        # q.shape = (batch_size, num_heads, d_head, window, seq_len)
        q = q.reshape(q_new_shape)[:, :, :, :, :-1]
        if self.num_sinks > 0:
            q = torch.cat([q_sink, q], dim=3)
        return q
================================================
FILE: stanza/models/common/seq2seq_constant.py
================================================
"""
Constants for seq2seq models.
"""
# Special tokens at the front of every seq2seq vocab.  Each *_ID is that
# token's index in VOCAB_PREFIX, so the four strings must be distinct:
# identical strings would collapse every token -> id lookup onto a single id.
PAD = '<PAD>'
PAD_ID = 0
UNK = '<UNK>'
UNK_ID = 1
SOS = '<SOS>'
SOS_ID = 2
EOS = '<EOS>'
EOS_ID = 3

VOCAB_PREFIX = [PAD, UNK, SOS, EOS]

# half-width of the uniform range used to initialize embeddings
EMB_INIT_RANGE = 1.0
# presumably a large finite stand-in for infinity
INFINITY_NUMBER = 1e12
================================================
FILE: stanza/models/common/seq2seq_model.py
================================================
"""
The full encoder-decoder model, built on top of the base seq2seq modules.
"""
import logging
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common import utils
from stanza.models.common.seq2seq_modules import LSTMAttention
from stanza.models.common.beam import Beam
from stanza.models.common.seq2seq_constant import UNK_ID
logger = logging.getLogger('stanza')
class Seq2SeqModel(nn.Module):
"""
A complete encoder-decoder model, with optional attention.
A parent class which makes use of the contextual_embedding (such as a charlm)
can make use of unsaved_modules when saving.
"""
def __init__(self, args, emb_matrix=None, contextual_embedding=None):
super().__init__()
self.unsaved_modules = []
self.vocab_size = args['vocab_size']
self.emb_dim = args['emb_dim']
self.hidden_dim = args['hidden_dim']
self.nlayers = args['num_layers'] # encoder layers, decoder layers = 1
self.emb_dropout = args.get('emb_dropout', 0.0)
self.dropout = args['dropout']
self.pad_token = constant.PAD_ID
self.max_dec_len = args['max_dec_len']
self.top = args.get('top', 1e10)
self.args = args
self.emb_matrix = emb_matrix
self.add_unsaved_module("contextual_embedding", contextual_embedding)
logger.debug("Building an attentional Seq2Seq model...")
logger.debug("Using a Bi-LSTM encoder")
self.num_directions = 2
self.enc_hidden_dim = self.hidden_dim // 2
self.dec_hidden_dim = self.hidden_dim
self.use_pos = args.get('pos', False)
self.pos_dim = args.get('pos_dim', 0)
self.pos_vocab_size = args.get('pos_vocab_size', 0)
self.pos_dropout = args.get('pos_dropout', 0)
self.edit = args.get('edit', False)
self.num_edit = args.get('num_edit', 0)
self.copy = args.get('copy', False)
self.emb_drop = nn.Dropout(self.emb_dropout)
self.drop = nn.Dropout(self.dropout)
self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
self.input_dim = self.emb_dim
if self.contextual_embedding is not None:
self.input_dim += self.contextual_embedding.hidden_dim()
self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)
self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \
batch_first=True, attn_type=self.args['attn_type'])
self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size)
if self.use_pos and self.pos_dim > 0:
logger.debug("Using POS in encoder")
self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, self.pad_token)
self.pos_drop = nn.Dropout(self.pos_dropout)
if self.edit:
edit_hidden = self.hidden_dim//2
self.edit_clf = nn.Sequential(
nn.Linear(self.hidden_dim, edit_hidden),
nn.ReLU(),
nn.Linear(edit_hidden, self.num_edit))
if self.copy:
self.copy_gate = nn.Linear(self.dec_hidden_dim, 1)
SOS_tensor = torch.LongTensor([constant.SOS_ID])
self.register_buffer('SOS_tensor', SOS_tensor)
self.init_weights()
def add_unsaved_module(self, name, module):
    """Attach ``module`` as attribute ``name`` and record ``name`` in ``self.unsaved_modules``.

    The recorded names are kept apart from ordinary submodules, presumably so
    that checkpoint-saving code can exclude them -- TODO confirm against the
    save logic.
    """
    self.unsaved_modules.append(name)
    setattr(self, name, module)
def init_weights(self):
    """Initialize the embedding tables and configure embedding finetuning.

    If ``self.emb_matrix`` is given (numpy arrays are converted to a tensor
    first), it is copied into ``self.embedding``; otherwise the weights are
    drawn uniformly from [-EMB_INIT_RANGE, EMB_INIT_RANGE].  ``self.top``
    then controls finetuning: <= 0 freezes the embedding, a value below the
    vocab size keeps gradient only for the first ``top`` rows, and anything
    else finetunes every row.  The POS embedding is re-initialized uniformly
    when POS is used.
    """
    bound = constant.EMB_INIT_RANGE
    weights = self.embedding.weight
    if self.emb_matrix is None:
        weights.data.uniform_(-bound, bound)
    else:
        if isinstance(self.emb_matrix, np.ndarray):
            self.emb_matrix = torch.from_numpy(self.emb_matrix)
        assert self.emb_matrix.size() == (self.vocab_size, self.emb_dim), \
            "Input embedding matrix must match size: {} x {}".format(self.vocab_size, self.emb_dim)
        weights.data.copy_(self.emb_matrix)

    if self.top <= 0:
        logger.debug("Do not finetune embedding layer.")
        weights.requires_grad = False
    elif self.top < self.vocab_size:
        logger.debug("Finetune top {} embeddings.".format(self.top))
        # the hook zeroes the gradient of every row from index `top` onward
        weights.register_hook(lambda x: utils.keep_partial_grad(x, self.top))
    else:
        logger.debug("Finetune all embeddings.")

    if self.use_pos:
        self.pos_embedding.weight.data.uniform_(-bound, bound)
def zero_state(self, inputs):
    """Return fresh all-zero ``(h0, c0)`` states for the bidirectional encoder.

    Each tensor has shape (encoder.num_layers * 2, batch, enc_hidden_dim),
    where batch is ``inputs.size(0)``.  Both tensors are placed on the device
    of ``self.SOS_tensor``.
    """
    shape = (self.encoder.num_layers * 2, inputs.size(0), self.enc_hidden_dim)
    target_device = self.SOS_tensor.device
    h0, c0 = (torch.zeros(*shape, requires_grad=False, device=target_device) for _ in range(2))
    return h0, c0
def encode(self, enc_inputs, lens):
    """ Encode source sequence.

    Runs ``self.encoder`` over the batch-first padded ``enc_inputs``, packed
    using the per-example lengths ``lens``.  Returns ``(h_in, (hn, cn))``.
    ``h_in`` holds the re-padded per-position outputs.  ``hn`` and ``cn``
    each join the last two slices of the encoder's final state along dim 1,
    with index -1 first and then -2.
    """
    init_state = self.zero_state(enc_inputs)
    packed = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
    packed_out, (h_last, c_last) = self.encoder(packed, init_state)
    h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
    hn, cn = (torch.cat((state[-1], state[-2]), 1) for state in (h_last, c_last))
    return h_in, (hn, cn)
def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False):
    """ Decode a step, based on context encoding and source context states.

    dec_inputs: embedded decoder inputs (batch-first)
    hn, cn: decoder LSTM state
    ctx, ctx_mask: encoder outputs and their mask, passed to the attention decoder
    src: source token ids; required when ``self.copy`` is set, because copy
         probabilities are scattered onto these ids
    never_decode_unk: if True, UNK_ID gets -inf log probability

    Returns (log_probs, dec_hidden).  log_probs has shape (batch, steps, V').
    V' is the output vocab size, widened to max(src) + 1 when copying and the
    source contains ids beyond the output vocab.
    """
    dec_hidden = (hn, cn)
    decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy)
    if self.copy:
        h_out, dec_hidden, log_attn = decoder_output
    else:
        h_out, dec_hidden = decoder_output

    h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1)
    decoder_logits = self.dec2vocab(h_out_reshape)
    decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1)
    log_probs = self.get_log_prob(decoder_logits)

    if self.copy:
        copy_logit = self.copy_gate(h_out)
        if self.use_pos:
            # can't copy the UPOS, which embed() prepends at position 0
            log_attn = log_attn[:, :, 1:]
            # renormalize over the remaining source positions
            log_attn = torch.log_softmax(log_attn, -1)
        # log P(copy) + log P(attend to each source position)
        log_copy_prob = torch.nn.functional.logsigmoid(copy_logit) + log_attn
        # scatter logsumexp: shift by the max before exponentiating
        mx = log_copy_prob.max(-1, keepdim=True)[0]
        log_copy_prob = log_copy_prob - mx
        # Make room in the log probs for vocab items that may be copied
        # from the encoder side but were not known at training time.
        # The model cannot predict such an item as a raw output token,
        # but the copy gate can score high on copying it.
        copy_prob = torch.exp(log_copy_prob)
        copied_vocab_shape = list(log_probs.size())
        max_src_id = int(torch.max(src))
        if max_src_id >= copied_vocab_shape[-1]:
            copied_vocab_shape[-1] = max_src_id + 1
        copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape)
        scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1))
        # accumulate the copy probability of each source position onto its id;
        # ids never seen in src keep 0 and are filled with -largenumber below
        copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob)
        zero_mask = (copied_vocab_prob == 0)
        log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx
        log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12)

        # log P(no copy) = log(1 - sigmoid(x)) = logsigmoid(-x).
        # The former -log(1 + exp(x)) overflowed to -inf once exp(x)
        # overflowed, producing NaN gradients for the copy gate.
        log_nocopy_prob = torch.nn.functional.logsigmoid(-copy_logit)
        if log_probs.shape[-1] < copied_vocab_shape[-1]:
            # For previously unknown vocab items that occur in the encoder
            # input, reuse the UNK_ID prediction as a baseline to combine with
            # the copy gate.  log_probs then no longer sums to 1 over the
            # unknown items, which is probably not a serious problem.
            # Example: a lemmatizer that never saw "ã" can still copy it to
            # depluralize "ãntennae" -> "ãntenna" instead of producing UNK.
            new_log_probs = log_probs.new_zeros(copied_vocab_shape)
            new_log_probs[:, :, :log_probs.shape[-1]] = log_probs
            new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2)
            log_probs = new_log_probs
        log_probs = log_probs + log_nocopy_prob
        log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0)
    if never_decode_unk:
        log_probs[:, :, UNK_ID] = float("-inf")
    return log_probs, dec_hidden
def embed(self, src, src_mask, pos, raw):
    """Build encoder inputs from source ids plus optional POS and contextual features.

    Returns ``(enc_inputs, batch_size, src_lens, src_mask)``.

    - Source ids at or beyond ``self.vocab_size`` are embedded as UNK_ID.
    - When POS is used, a POS embedding is prepended as a new first position
      and a zero column is prepended to ``src_mask``.
    - When contextual embeddings are available, they are concatenated onto
      the feature dimension.
    - ``src_lens`` counts, per row, the mask entries equal to
      ``constant.PAD_ID``.
    """
    in_vocab_src = src.clone()
    in_vocab_src[in_vocab_src >= self.vocab_size] = UNK_ID
    enc_inputs = self.emb_drop(self.embedding(in_vocab_src))
    batch_size = enc_inputs.size(0)

    if self.use_pos:
        assert pos is not None, "Missing POS input for seq2seq lemmatizer."
        pos_inputs = self.pos_drop(self.pos_embedding(pos))
        enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
        src_mask = torch.cat([src_mask.new_zeros([batch_size, 1]), src_mask], dim=1)

    if raw is not None and self.contextual_embedding is not None:
        raw_inputs = self.contextual_embedding(raw)
        if self.use_pos:
            # add one zero position so the length matches enc_inputs
            # NOTE(review): the POS slot is prepended above, but this zero
            # slot is appended at the end -- confirm the resulting alignment
            # is intended (changing it would affect trained models)
            pad = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2]))
            raw_inputs = torch.cat([raw_inputs, pad], dim=1)
        enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)

    src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
    return enc_inputs, batch_size, src_lens, src_mask
def forward(self, src, src_mask, tgt_in, pos=None, raw=None):
    """Teacher-forced pass: encode ``src``, then decode the gold inputs ``tgt_in``.

    Returns ``(log_probs, edit_logits)``.  ``edit_logits`` is None unless
    ``self.edit`` is set, in which case it is computed from the encoder's
    final hidden state.
    """
    enc_inputs, _, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
    h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
    edit_logits = self.edit_clf(hn) if self.edit else None
    tgt_embedded = self.emb_drop(self.embedding(tgt_in))
    log_probs, _ = self.decode(tgt_embedded, hn, cn, h_in, src_mask, src=src)
    return log_probs, edit_logits
def get_log_prob(self, logits):
    """Apply log-softmax over the vocabulary dimension of ``logits``.

    The logits are flattened to (-1, vocab_size).  A 2-D input comes back
    in that flat form; any other input is reshaped to
    (size(0), size(1), size(2)).
    """
    flat_log_probs = F.log_softmax(logits.view(-1, self.vocab_size), dim=1)
    if logits.dim() != 2:
        flat_log_probs = flat_log_probs.view(logits.size(0), logits.size(1), logits.size(2))
    return flat_log_probs
def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False):
    """ Predict with greedy decoding.

    Starting from the SOS embedding, the highest-scoring id is picked at each
    step.  Decoding runs for at most ``self.max_dec_len`` steps and stops
    early once every example has produced EOS_ID.  Returns
    ``(output_seqs, edit_logits)``: ``output_seqs[i]`` is the list of ids
    predicted for example ``i`` before its EOS.
    """
    enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
    h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
    edit_logits = self.edit_clf(hn) if self.edit else None

    # every example starts decoding from the same SOS embedding
    sos_embedding = self.embedding(self.SOS_tensor)
    dec_inputs = sos_embedding.expand(batch_size, sos_embedding.size(0), sos_embedding.size(1))

    finished = [False] * batch_size
    num_finished = 0
    step = 0
    output_seqs = [[] for _ in range(batch_size)]
    while num_finished < batch_size and step < self.max_dec_len:
        log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
        assert log_probs.size(1) == 1, "Output must have 1-step of output."
        preds = log_probs.squeeze(1).max(1, keepdim=True)[1]
        # ids beyond the embedding vocab can only come from the copy
        # mechanism; feed them back through the UNK embedding
        next_ids = preds.clone()
        next_ids[next_ids >= self.vocab_size] = UNK_ID
        dec_inputs = self.embedding(next_ids)
        step += 1
        for idx, already_done in enumerate(finished):
            if already_done:
                continue
            token = preds.data[idx][0].item()
            if token == constant.EOS_ID:
                finished[idx] = True
                num_finished += 1
            else:
                output_seqs[idx].append(token)
    return output_seqs, edit_logits
def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False):
    """ Predict with beam search.

    With ``beam_size == 1`` this delegates to ``predict_greedy``.  Otherwise
    it runs ``Beam`` search for at most ``self.max_dec_len`` steps and
    returns ``(all_hyp, edit_logits)``.  ``all_hyp[b]`` is the best
    hypothesis for example ``b``, pruned at EOS, as a list of ints.
    """
    if beam_size == 1:
        return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk)

    enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

    # (1) encode source
    h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
    edit_logits = self.edit_clf(hn) if self.edit else None

    # (2) set up beam
    with torch.no_grad():
        # tile every per-example tensor beam_size times along dim 0
        # (row k * batch_size + b is beam k of example b, matching the
        # layout of dec_inputs and log_probs in the main loop)
        h_in = h_in.data.repeat(beam_size, 1, 1)
        src_mask = src_mask.repeat(beam_size, 1)
        # decode() scatters copy probabilities onto the ids in src, so src
        # must be tiled too; otherwise scatter_add silently fills only the
        # first batch_size rows and beams >= 1 get no copy probability
        src = src.repeat(beam_size, 1)
        # repeat decoder hidden states
        hn = hn.data.repeat(beam_size, 1)
        cn = cn.data.repeat(beam_size, 1)
    device = self.SOS_tensor.device
    beam = [Beam(beam_size, device) for _ in range(batch_size)]

    def update_state(states, idx, positions, beam_size):
        """ Select the states according to back pointers. """
        for e in states:
            br, d = e.size()
            s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx]
            s.data.copy_(s.data.index_select(0, positions))

    # (3) main loop
    for i in range(self.max_dec_len):
        dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1)
        # if a unlearned character is predicted via the copy mechanism,
        # use the UNK embedding for it
        dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
        dec_inputs = self.embedding(dec_inputs)
        log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
        log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V]

        # advance each beam
        done = []
        for b in range(batch_size):
            is_done = beam[b].advance(log_probs.data[b])
            if is_done:
                done += [b]
            # update beam state
            update_state((hn, cn), b, beam[b].get_current_origin(), beam_size)

        if len(done) == batch_size:
            break

    # back trace and find hypothesis
    all_hyp = []
    for b in range(batch_size):
        _, ks = beam[b].sort_best()
        hyp = beam[b].get_hyp(ks[0])
        hyp = utils.prune_hyp(hyp)
        all_hyp.append([i.item() for i in hyp])

    return all_hyp, edit_logits
================================================
FILE: stanza/models/common/seq2seq_modules.py
================================================
"""
PyTorch implementation of basic sequence-to-sequence modules.
"""
import logging
import torch
import torch.nn as nn
import math
import numpy as np
import stanza.models.common.seq2seq_constant as constant
logger = logging.getLogger('stanza')
class BasicAttention(nn.Module):
    """
    A basic MLP attention layer.

    score_j = v . tanh(W_in input + W_c context_j + b_c)
    """
    def __init__(self, dim):
        super(BasicAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_c = nn.Linear(dim, dim)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.sm = nn.Softmax(dim=1)

    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        mask: optional, batch x sourceL; positions where it is set get
            -INFINITY_NUMBER logits before normalization
        attn_only: if True, return only the attention
        return_logattn: if True, the returned attention is in log space.
            LSTMAttention always passes this keyword, so it must be accepted.

        Returns (h_tilde, attn): h_tilde is batch x dim, attn is batch x sourceL.
        """
        batch_size = context.size(0)
        source_len = context.size(1)
        dim = context.size(2)
        target = self.linear_in(input) # batch x dim
        source = self.linear_c(context.contiguous().view(-1, dim)).view(batch_size, source_len, dim)
        attn = target.unsqueeze(1).expand_as(context) + source
        attn = self.tanh(attn) # batch x sourceL x dim
        attn = self.linear_v(attn.view(-1, dim)).view(batch_size, source_len)

        if mask is not None:
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        # same convention as SoftDotAttention: attn is returned in log space
        # when requested, while attn_w always holds the probabilities
        if return_logattn:
            attn = torch.log_softmax(attn, 1)
            attn_w = torch.exp(attn)
        else:
            attn = self.sm(attn)
            attn_w = attn

        if attn_only:
            return attn

        weighted_context = torch.bmm(attn_w.unsqueeze(1), context).squeeze(1)
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))

        return h_tilde, attn
class SoftDotAttention(nn.Module):
    """Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OPEN NMT.
    """

    def __init__(self, dim):
        """Initialize layer."""
        super(SoftDotAttention, self).__init__()
        # creation order is kept as-is: it fixes parameter init order and state_dict keys
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):
        """Attend over ``context`` using the linearly projected ``input`` as the query.

        input: batch x dim
        context: batch x sourceL x dim
        mask: optional, must match batch x sourceL; set positions get
            -INFINITY_NUMBER logits before normalization
        attn_only: return only the attention
        return_logattn: return log attention weights instead of probabilities

        Returns (h_tilde, attn) with h_tilde batch x dim and attn batch x sourceL.
        """
        query = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
        scores = torch.bmm(context, query).squeeze(2)  # batch x sourceL
        if mask is not None:
            assert mask.size() == scores.size(), "Mask size must match the attention size!"
            scores.masked_fill_(mask, -constant.INFINITY_NUMBER)

        if not return_logattn:
            attn = self.sm(scores)
            weights = attn
        else:
            attn = torch.log_softmax(scores, 1)
            weights = torch.exp(attn)

        if attn_only:
            return attn

        # (batch x 1 x sourceL) @ (batch x sourceL x dim) -> batch x dim
        weighted_context = torch.bmm(weights.unsqueeze(1), context).squeeze(1)
        combined = torch.cat((weighted_context, input), 1)
        return self.tanh(self.linear_out(combined)), attn
class LinearAttention(nn.Module):
    """ A linear attention form, inspired by BiDAF:
        a = W (u; v; u o v)
    """

    def __init__(self, dim):
        super(LinearAttention, self).__init__()
        self.linear = nn.Linear(dim*3, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        mask: optional, must match batch x sourceL; set positions get
            -INFINITY_NUMBER logits before normalization
        attn_only: if True, return only the attention
        return_logattn: if True, the returned attention is in log space.
            LSTMAttention always passes this keyword, so it must be accepted.

        Returns (h_tilde, attn): h_tilde is batch x dim, attn is batch x sourceL.
        """
        batch_size = context.size(0)
        source_len = context.size(1)
        dim = context.size(2)
        u = input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim) # batch*sourceL x dim
        v = context.contiguous().view(-1, dim)
        attn_in = torch.cat((u, v, u.mul(v)), 1)
        attn = self.linear(attn_in).view(batch_size, source_len)

        if mask is not None:
            # set the padding attention logits to -INFINITY_NUMBER
            assert mask.size() == attn.size(), "Mask size must match the attention size!"
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        # same convention as SoftDotAttention: attn is returned in log space
        # when requested, while attn_w always holds the probabilities
        if return_logattn:
            attn = torch.log_softmax(attn, 1)
            attn_w = torch.exp(attn)
        else:
            attn = self.sm(attn)
            attn_w = attn

        if attn_only:
            return attn

        attn3 = attn_w.view(batch_size, 1, source_len) # batch x 1 x sourceL
        weighted_context = torch.bmm(attn3, context).squeeze(1) # batch x dim
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
class DeepAttention(nn.Module):
    """ A deep attention form, invented by Robert:
        u = ReLU(Wx)
        v = ReLU(Wy)
        a = V.(u o v)
    (the same W, ``linear_in``, projects both x and y)
    """

    def __init__(self, dim):
        super(DeepAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.relu = nn.ReLU()
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        mask: optional, must match batch x sourceL; set positions get
            -INFINITY_NUMBER logits before normalization
        attn_only: if True, return only the attention
        return_logattn: if True, the returned attention is in log space.
            LSTMAttention always passes this keyword, so it must be accepted.

        Returns (h_tilde, attn): h_tilde is batch x dim, attn is batch x sourceL.
        """
        batch_size = context.size(0)
        source_len = context.size(1)
        dim = context.size(2)
        u = input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim) # batch*sourceL x dim
        u = self.relu(self.linear_in(u))
        v = self.relu(self.linear_in(context.contiguous().view(-1, dim)))
        attn = self.linear_v(u.mul(v)).view(batch_size, source_len)

        if mask is not None:
            # set the padding attention logits to -INFINITY_NUMBER
            assert mask.size() == attn.size(), "Mask size must match the attention size!"
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        # same convention as SoftDotAttention: attn is returned in log space
        # when requested, while attn_w always holds the probabilities
        if return_logattn:
            attn = torch.log_softmax(attn, 1)
            attn_w = torch.exp(attn)
        else:
            attn = self.sm(attn)
            attn_w = attn

        if attn_only:
            return attn

        attn3 = attn_w.view(batch_size, 1, source_len) # batch x 1 x sourceL
        weighted_context = torch.bmm(attn3, context).squeeze(1) # batch x dim
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
class LSTMAttention(nn.Module):
    r"""A long short-term memory (LSTM) cell with attention."""

    def __init__(self, input_size, hidden_size, batch_first=True, attn_type='soft'):
        """Initialize params.

        attn_type selects the attention layer: 'soft' (SoftDotAttention),
        'mlp' (BasicAttention), 'linear' (LinearAttention) or 'deep'
        (DeepAttention).  Any other value raises an Exception.
        """
        super(LSTMAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        if attn_type == 'soft':
            self.attention_layer = SoftDotAttention(hidden_size)
        elif attn_type == 'mlp':
            self.attention_layer = BasicAttention(hidden_size)
        elif attn_type == 'linear':
            self.attention_layer = LinearAttention(hidden_size)
        elif attn_type == 'deep':
            self.attention_layer = DeepAttention(hidden_size)
        else:
            raise Exception("Unsupported LSTM attention type: {}".format(attn_type))
        logger.debug("Using {} attention for LSTM.".format(attn_type))

    def forward(self, input, hidden, ctx, ctx_mask=None, return_logattn=False):
        """Propagate input through the network.

        input: batch x steps x input_size if batch_first, else steps x batch x input_size
        hidden: initial (h, c) state of the LSTM cell
        ctx, ctx_mask: attention context and mask, passed to the attention layer
        return_logattn: request log attention from the attention layer and
            return the stacked per-step attention as a third value

        Returns (output, hidden), or (output, hidden, attn) when return_logattn.
        """
        if self.batch_first:
            input = input.transpose(0,1)

        # Forward return_logattn only when it is requested, so attention layers
        # whose forward() lacks that keyword still work in the common case.
        attn_kwargs = {'return_logattn': return_logattn} if return_logattn else {}

        output = []
        attn = []
        for i in range(input.size(0)):
            hidden = self.lstm_cell(input[i], hidden)
            hy = hidden[0]
            h_tilde, alpha = self.attention_layer(hy, ctx, mask=ctx_mask, **attn_kwargs)
            output.append(h_tilde)
            attn.append(alpha)
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        if self.batch_first:
            output = output.transpose(0,1)

        if return_logattn:
            attn = torch.stack(attn, 0)
            if self.batch_first:
                attn = attn.transpose(0, 1)
            return output, hidden, attn

        return output, hidden
================================================
FILE: stanza/models/common/seq2seq_utils.py
================================================
"""
Utils for seq2seq models.
"""
from collections import Counter
import random
import json
import torch
import stanza.models.common.seq2seq_constant as constant
# torch utils
def get_optimizer(name, parameters, lr):
    """Build a torch optimizer from its short name.

    'sgd' and 'adagrad' use the given ``lr``.  'adam' and 'adamax'
    deliberately ignore it and keep torch's default learning rate.  Any
    other name raises Exception.
    """
    if name in ('sgd', 'adagrad'):
        optim_cls = torch.optim.SGD if name == 'sgd' else torch.optim.Adagrad
        return optim_cls(parameters, lr=lr)
    if name in ('adam', 'adamax'):
        optim_cls = torch.optim.Adam if name == 'adam' else torch.optim.Adamax
        return optim_cls(parameters) # use default lr
    raise Exception("Unsupported optimizer: {}".format(name))
def change_lr(optimizer, new_lr):
    """Set the learning rate of every parameter group in ``optimizer`` to ``new_lr``."""
    for group in optimizer.param_groups:
        group.update(lr=new_lr)
def flatten_indices(seq_lens, width):
    """Return the flat positions of the valid entries in a (batch x width) layout.

    Row ``i`` contributes ``i * width + j`` for every ``j < seq_lens[i]``.
    """
    return [row * width + col for row, length in enumerate(seq_lens) for col in range(length)]
def keep_partial_grad(grad, topk):
    """
    Zero, in place, every gradient row from index ``topk`` onward.

    Meant for use as a tensor gradient hook, so that only the first ``topk``
    rows are updated.  Returns the same ``grad`` object.
    """
    assert topk < grad.size(0)
    grad.data[topk:].fill_(0)
    return grad
# other utils
def save_config(config, path, verbose=True):
    """Serialize ``config`` to ``path`` as JSON with 2-space indentation.

    Prints a confirmation when ``verbose`` is set.  Returns ``config``
    unchanged so calls can be chained.
    """
    with open(path, 'w') as fout:
        json.dump(config, fout, indent=2)
    if verbose:
        message = "Config saved to file {}".format(path)
        print(message)
    return config
def load_config(path, verbose=True):
    """Read a JSON config from ``path`` and return the decoded object.

    Prints a confirmation when ``verbose`` is set.
    """
    with open(path) as fin:
        loaded = json.load(fin)
    if verbose:
        message = "Config loaded from file {}".format(path)
        print(message)
    return loaded
def unmap_with_copy(indices, src_tokens, vocab):
    """
    Unmap a list of list of indices, by optionally copying from src_tokens.

    A non-negative id is looked up in ``vocab.id2word``.  A negative id
    ``-k`` copies ``tokens[k - 1]`` from the paired source token list.
    """
    def lookup(idx, tokens):
        return vocab.id2word[idx] if idx >= 0 else tokens[-idx - 1]

    return [[lookup(idx, tokens) for idx in ind] for ind, tokens in zip(indices, src_tokens)]
def prune_decoded_seqs(seqs, eos=None):
    """
    Prune decoded sequences after EOS token.

    Each sequence is cut just before its first ``eos``.  Sequences without
    one are returned as-is.  ``eos`` defaults to ``constant.EOS``.
    """
    if eos is None:
        eos = constant.EOS
    out = []
    for s in seqs:
        # test for and search for the same token: the previous version
        # checked constant.EOS but looked up constant.EOS_TOKEN
        if eos in s:
            out.append(s[:s.index(eos)])
        else:
            out.append(s)
    return out
def prune_hyp(hyp):
    """
    Prune a decoded hypothesis: keep only the ids before the first EOS_ID.

    The hypothesis is returned unchanged when it contains no EOS_ID.
    """
    if constant.EOS_ID not in hyp:
        return hyp
    return hyp[:hyp.index(constant.EOS_ID)]
def prune(data_list, lens):
    """Truncate each item of ``data_list`` to the matching length in ``lens``."""
    assert len(data_list) == len(lens)
    return [data[:length] for data, length in zip(data_list, lens)]
def sort(packed, ref, reverse=True):
    """
    Sort a series of packed list, according to a ref list.
    Also return the original index before the sort.

    Returns a tuple whose first element is the list of original indices,
    followed by each list of ``packed`` in sorted order.  Ties in ``ref``
    fall back to comparing the original index, then the packed values.
    """
    assert isinstance(packed, (tuple, list)) and isinstance(ref, list)
    columns = [ref, range(len(ref)), *packed]
    rows = sorted(zip(*columns), reverse=reverse)
    transposed = [list(column) for column in zip(*rows)]
    return tuple(transposed[1:])
def unsort(sorted_list, oidx):
    """
    Unsort a sorted list, based on the original idx.

    ``oidx[i]`` is the original position of ``sorted_list[i]``, as returned
    by ``sort``.  Returns a new list in the original order.  Empty input
    returns [].
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    # Sort on the index alone: the items never need to be comparable, and
    # empty input no longer fails while unpacking an empty zip.
    return [item for _, item in sorted(zip(oidx, sorted_list), key=lambda pair: pair[0])]
================================================
FILE: stanza/models/common/short_name_to_treebank.py
================================================
# This module is autogenerated by build_short_name_to_treebank.py
# Please do not edit
SHORT_NAMES = {
'abq_atb': 'UD_Abaza-ATB',
'ab_abnc': 'UD_Abkhaz-AbNC',
'af_afribooms': 'UD_Afrikaans-AfriBooms',
'akk_pisandub': 'UD_Akkadian-PISANDUB',
'akk_riao': 'UD_Akkadian-RIAO',
'aqz_tudet': 'UD_Akuntsu-TuDeT',
'sq_staf': 'UD_Albanian-STAF',
'sq_tsa': 'UD_Albanian-TSA',
'gsw_divital': 'UD_Alemannic-DIVITAL',
'gsw_uzh': 'UD_Alemannic-UZH',
'am_att': 'UD_Amharic-ATT',
'grc_proiel': 'UD_Ancient_Greek-PROIEL',
'grc_ptnk': 'UD_Ancient_Greek-PTNK',
'grc_perseus': 'UD_Ancient_Greek-Perseus',
'hbo_ptnk': 'UD_Ancient_Hebrew-PTNK',
'apu_ufpa': 'UD_Apurina-UFPA',
'ar_nyuad': 'UD_Arabic-NYUAD',
'ar_padt': 'UD_Arabic-PADT',
'ar_pud': 'UD_Arabic-PUD',
'hy_armtdp': 'UD_Armenian-ArmTDP',
'hy_bsut': 'UD_Armenian-BSUT',
'aii_as': 'UD_Assyrian-AS',
'az_tuecl': 'UD_Azerbaijani-TueCL',
'bm_crb': 'UD_Bambara-CRB',
'eu_bdt': 'UD_Basque-BDT',
'bar_maibaam': 'UD_Bavarian-MaiBaam',
'bej_autogramm': 'UD_Beja-Autogramm',
'be_hse': 'UD_Belarusian-HSE',
'bn_bru': 'UD_Bengali-BRU',
'bho_bhtb': 'UD_Bhojpuri-BHTB',
'sab_chibergis': 'UD_Bokota-ChibErgIS',
'bor_bdt': 'UD_Bororo-BDT',
'br_keb': 'UD_Breton-KEB',
'bg_btb': 'UD_Bulgarian-BTB',
'bxr_bdt': 'UD_Buryat-BDT',
'yue_hk': 'UD_Cantonese-HK',
'cpg_amgic': 'UD_Cappadocian-AMGiC',
'cpg_tuecl': 'UD_Cappadocian-TueCL',
'ca_ancora': 'UD_Catalan-AnCora',
'ceb_gja': 'UD_Cebuano-GJA',
'ckb_mukri': 'UD_Central_Kurdish-Mukri',
'zh-hans_beginner': 'UD_Chinese-Beginner',
'zh_beginner': 'UD_Chinese-Beginner',
'zh-hans_cfl': 'UD_Chinese-CFL',
'zh_cfl': 'UD_Chinese-CFL',
'zh-hant_gsd': 'UD_Chinese-GSD',
'zh_gsd': 'UD_Chinese-GSD',
'zh-hans_gsdsimp': 'UD_Chinese-GSDSimp',
'zh_gsdsimp': 'UD_Chinese-GSDSimp',
'zh-hant_hk': 'UD_Chinese-HK',
'zh_hk': 'UD_Chinese-HK',
'zh-hant_pud': 'UD_Chinese-PUD',
'zh_pud': 'UD_Chinese-PUD',
'zh-hans_patentchar': 'UD_Chinese-PatentChar',
'zh_patentchar': 'UD_Chinese-PatentChar',
'ctn_ctntb': 'UD_Chintang-CTNTB',
'ckt_hse': 'UD_Chukchi-HSE',
'xcl_caval': 'UD_Classical_Armenian-CAVaL',
'lzh_kyoto': 'UD_Classical_Chinese-Kyoto',
'lzh_tuecl': 'UD_Classical_Chinese-TueCL',
'cop_bohairic': 'UD_Coptic-Bohairic',
'cop_scriptorium': 'UD_Coptic-Scriptorium',
'hr_set': 'UD_Croatian-SET',
'cs_cac': 'UD_Czech-CAC',
'cs_cltt': 'UD_Czech-CLTT',
'cs_fictree': 'UD_Czech-FicTree',
'cs_pdtc': 'UD_Czech-PDTC',
'cs_pud': 'UD_Czech-PUD',
'cs_poetry': 'UD_Czech-Poetry',
'da_ddt': 'UD_Danish-DDT',
'nl_alpino': 'UD_Dutch-Alpino',
'nl_lassysmall': 'UD_Dutch-LassySmall',
'egy_ujaen': 'UD_Egyptian-UJaen',
'en_atis': 'UD_English-Atis',
'en_childes': 'UD_English-CHILDES',
'en_ctetex': 'UD_English-CTeTex',
'en_eslspok': 'UD_English-ESLSpok',
'en_ewt': 'UD_English-EWT',
'en_gentle': 'UD_English-GENTLE',
'en_gum': 'UD_English-GUM',
'en_gumreddit': 'UD_English-GUMReddit',
'en_lines': 'UD_English-LinES',
'en_littleprince': 'UD_English-LittlePrince',
'en_pud': 'UD_English-PUD',
'en_partut': 'UD_English-ParTUT',
'en_pronouns': 'UD_English-Pronouns',
'myv_jr': 'UD_Erzya-JR',
'eo_cairo': 'UD_Esperanto-Cairo',
'eo_prago': 'UD_Esperanto-Prago',
'et_edt': 'UD_Estonian-EDT',
'et_ewt': 'UD_Estonian-EWT',
'fo_farpahc': 'UD_Faroese-FarPaHC',
'fo_oft': 'UD_Faroese-OFT',
'fi_ftb': 'UD_Finnish-FTB',
'fi_ood': 'UD_Finnish-OOD',
'fi_pud': 'UD_Finnish-PUD',
'fi_tdt': 'UD_Finnish-TDT',
'fr_alts': 'UD_French-ALTS',
'fr_fqb': 'UD_French-FQB',
'fr_gsd': 'UD_French-GSD',
'fr_pud': 'UD_French-PUD',
'fr_partut': 'UD_French-ParTUT',
'fr_parisstories': 'UD_French-ParisStories',
'fr_poitevindivital': 'UD_French-PoitevinDIVITAL',
'fr_rhapsodie': 'UD_French-Rhapsodie',
'fr_sequoia': 'UD_French-Sequoia',
'qfn_fame': 'UD_Frisian_Dutch-Fame',
'gl_ctg': 'UD_Galician-CTG',
'gl_pud': 'UD_Galician-PUD',
'gl_treegal': 'UD_Galician-TreeGal',
'ka_glc': 'UD_Georgian-GLC',
'ka_gnc': 'UD_Georgian-GNC',
'de_gsd': 'UD_German-GSD',
'de_hdt': 'UD_German-HDT',
'de_lit': 'UD_German-LIT',
'de_pud': 'UD_German-PUD',
'aln_gps': 'UD_Gheg-GPS',
'got_proiel': 'UD_Gothic-PROIEL',
'el_cretan': 'UD_Greek-Cretan',
'el_gdt': 'UD_Greek-GDT',
'el_gud': 'UD_Greek-GUD',
'el_lesbian': 'UD_Greek-Lesbian',
'el_messinian': 'UD_Greek-Messinian',
'gub_tudet': 'UD_Guajajara-TuDeT',
'gn_oldtudet': 'UD_Guarani-OldTuDeT',
'gu_gujtb': 'UD_Gujarati-GujTB',
'gwi_tuecl': 'UD_Gwichin-TueCL',
'ht_adolphe': 'UD_Haitian_Creole-Adolphe',
'ht_autogramm': 'UD_Haitian_Creole-Autogramm',
'ha_northernautogramm': 'UD_Hausa-NorthernAutogramm',
'ha_southernautogramm': 'UD_Hausa-SouthernAutogramm',
'ha_westernautogramm': 'UD_Hausa-WesternAutogramm',
'he_htb': 'UD_Hebrew-HTB',
'he_iahltknesset': 'UD_Hebrew-IAHLTknesset',
'he_iahltwiki': 'UD_Hebrew-IAHLTwiki',
'azz_itml': 'UD_Highland_Puebla_Nahuatl-ITML',
'hi_hdtb': 'UD_Hindi-HDTB',
'hi_pud': 'UD_Hindi-PUD',
'hit_hittb': 'UD_Hittite-HitTB',
'hu_szeged': 'UD_Hungarian-Szeged',
'is_gc': 'UD_Icelandic-GC',
'is_icepahc': 'UD_Icelandic-IcePaHC',
'is_modern': 'UD_Icelandic-Modern',
'is_pud': 'UD_Icelandic-PUD',
'arh_chibergis': 'UD_Ika-ChibErgIS',
'id_csui': 'UD_Indonesian-CSUI',
'id_gsd': 'UD_Indonesian-GSD',
'id_pud': 'UD_Indonesian-PUD',
'ga_cadhan': 'UD_Irish-Cadhan',
'ga_idt': 'UD_Irish-IDT',
'ga_twittirish': 'UD_Irish-TwittIrish',
'it_isdt': 'UD_Italian-ISDT',
'it_kiparlaforest': 'UD_Italian-KIParlaForest',
'it_markit': 'UD_Italian-MarkIT',
'it_old': 'UD_Italian-Old',
'it_pud': 'UD_Italian-PUD',
'it_partut': 'UD_Italian-ParTUT',
'it_parlamint': 'UD_Italian-ParlaMint',
'it_postwita': 'UD_Italian-PoSTWITA',
'it_twittiro': 'UD_Italian-TWITTIRO',
'it_vit': 'UD_Italian-VIT',
'it_valico': 'UD_Italian-Valico',
'ja_bccwj': 'UD_Japanese-BCCWJ',
'ja_bccwjluw': 'UD_Japanese-BCCWJLUW',
'ja_gsd': 'UD_Japanese-GSD',
'ja_gsdluw': 'UD_Japanese-GSDLUW',
'ja_pud': 'UD_Japanese-PUD',
'ja_pudluw': 'UD_Japanese-PUDLUW',
'jv_csui': 'UD_Javanese-CSUI',
'urb_tudet': 'UD_Kaapor-TuDeT',
'xnr_kdtb': 'UD_Kangri-KDTB',
'krl_kkpp': 'UD_Karelian-KKPP',
'arr_tudet': 'UD_Karo-TuDeT',
'kk_ktb': 'UD_Kazakh-KTB',
'naq_kdt': 'UD_Khoekhoe-KDT',
'kfm_aha': 'UD_Khunsari-AHA',
'quc_iu': 'UD_Kiche-IU',
'koi_uh': 'UD_Komi_Permyak-UH',
'kpv_ikdp': 'UD_Komi_Zyrian-IKDP',
'kpv_lattice': 'UD_Komi_Zyrian-Lattice',
'ko_gsd': 'UD_Korean-GSD',
'ko_ksl': 'UD_Korean-KSL',
'ko_kaist': 'UD_Korean-Kaist',
'ko_littleprince': 'UD_Korean-LittlePrince',
'ko_pud': 'UD_Korean-PUD',
'ky_ktmu': 'UD_Kyrgyz-KTMU',
'ky_tuecl': 'UD_Kyrgyz-TueCL',
'ltg_cairo': 'UD_Latgalian-Cairo',
'la_circse': 'UD_Latin-CIRCSE',
'la_ittb': 'UD_Latin-ITTB',
'la_llct': 'UD_Latin-LLCT',
'la_proiel': 'UD_Latin-PROIEL',
'la_perseus': 'UD_Latin-Perseus',
'la_udante': 'UD_Latin-UDante',
'lv_cairo': 'UD_Latvian-Cairo',
'lv_lvtb': 'UD_Latvian-LVTB',
'lij_glt': 'UD_Ligurian-GLT',
'lt_alksnis': 'UD_Lithuanian-ALKSNIS',
'lt_hse': 'UD_Lithuanian-HSE',
'olo_kkpp': 'UD_Livvi-KKPP',
'nds_lsdc': 'UD_Low_Saxon-LSDC',
'lb_luxbank': 'UD_Luxembourgish-LuxBank',
'mk_mtb': 'UD_Macedonian-MTB',
'jaa_jarawara': 'UD_Madi-Jarawara',
'qaf_arabizi': 'UD_Maghrebi_Arabic_French-Arabizi',
'mpu_tudet': 'UD_Makurap-TuDeT',
'ml_ufal': 'UD_Malayalam-UFAL',
'mt_mudt': 'UD_Maltese-MUDT',
'gv_cadhan': 'UD_Manx-Cadhan',
'mr_ufal': 'UD_Marathi-UFAL',
'gun_dooley': 'UD_Mbya_Guarani-Dooley',
'gun_thomas': 'UD_Mbya_Guarani-Thomas',
'frm_altm': 'UD_Middle_French-ALTM',
'frm_profiterole': 'UD_Middle_French-PROFITEROLE',
'mdf_jr': 'UD_Moksha-JR',
'myu_tudet': 'UD_Munduruku-TuDeT',
'nmf_suansu': 'UD_Naga-Suansu',
'pcm_nsc': 'UD_Naija-NSC',
'nyq_aha': 'UD_Nayini-AHA',
'nap_rb': 'UD_Neapolitan-RB',
'nrk_tundra': 'UD_Nenets-Tundra',
'yrl_complin': 'UD_Nheengatu-CompLin',
'sme_giella': 'UD_North_Sami-Giella',
'kmr_kurmanji': 'UD_Northern_Kurdish-Kurmanji',
'gya_autogramm': 'UD_Northwest_Gbaya-Autogramm',
'nb_bokmaal': 'UD_Norwegian-Bokmaal',
'no_bokmaal': 'UD_Norwegian-Bokmaal',
'nn_nynorsk': 'UD_Norwegian-Nynorsk',
'oc_ttb': 'UD_Occitan-TTB',
'or_odtb': 'UD_Odia-ODTB',
'cu_proiel': 'UD_Old_Church_Slavonic-PROIEL',
'orv_birchbark': 'UD_Old_East_Slavic-Birchbark',
'orv_rnc': 'UD_Old_East_Slavic-RNC',
'orv_ruthenian': 'UD_Old_East_Slavic-Ruthenian',
'orv_torot': 'UD_Old_East_Slavic-TOROT',
'ang_cairo': 'UD_Old_English-Cairo',
'fro_altm': 'UD_Old_French-ALTM',
'fro_profiterole': 'UD_Old_French-PROFITEROLE',
'sga_dipsgg': 'UD_Old_Irish-DipSGG',
'sga_dipwbg': 'UD_Old_Irish-DipWBG',
'pro_corag': 'UD_Old_Occitan-CorAG',
'otk_clausal': 'UD_Old_Turkish-Clausal',
'ota_boun': 'UD_Ottoman_Turkish-BOUN',
'ota_dudu': 'UD_Ottoman_Turkish-DUDU',
'ps_sikaram': 'UD_Pashto-Sikaram',
'pad_tuecl': 'UD_Paumari-TueCL',
'fa_perdt': 'UD_Persian-PerDT',
'fa_seraji': 'UD_Persian-Seraji',
'pay_chibergis': 'UD_Pesh-ChibErgIS',
'xpg_kul': 'UD_Phrygian-KUL',
'pl_lfg': 'UD_Polish-LFG',
'pl_mpdt': 'UD_Polish-MPDT',
'pl_pdb': 'UD_Polish-PDB',
'pl_pud': 'UD_Polish-PUD',
'qpm_philotis': 'UD_Pomak-Philotis',
'pt_bosque': 'UD_Portuguese-Bosque',
'pt_cintil': 'UD_Portuguese-CINTIL',
'pt_dantestocks': 'UD_Portuguese-DANTEStocks',
'pt_gsd': 'UD_Portuguese-GSD',
'pt_pud': 'UD_Portuguese-PUD',
'pt_petrogold': 'UD_Portuguese-PetroGold',
'pt_porttinari': 'UD_Portuguese-Porttinari',
'ro_art': 'UD_Romanian-ArT',
'ro_moldoro': 'UD_Romanian-MolDoRo',
'ro_nonstandard': 'UD_Romanian-Nonstandard',
'ro_rrt': 'UD_Romanian-RRT',
'ro_simonero': 'UD_Romanian-SiMoNERo',
'ro_tuecl': 'UD_Romanian-TueCL',
'ru_gsd': 'UD_Russian-GSD',
'ru_pud': 'UD_Russian-PUD',
'ru_poetry': 'UD_Russian-Poetry',
'ru_syntagrus': 'UD_Russian-SynTagRus',
'ru_taiga': 'UD_Russian-Taiga',
'sa_ufal': 'UD_Sanskrit-UFAL',
'sa_vedic': 'UD_Sanskrit-Vedic',
'gd_arcosg': 'UD_Scottish_Gaelic-ARCOSG',
'sr_set': 'UD_Serbian-SET',
'wuu_shud': 'UD_Shanghainese-ShUD',
'scn_stb': 'UD_Sicilian-STB',
'sd_isra': 'UD_Sindhi-Isra',
'si_stb': 'UD_Sinhala-STB',
'sms_giellagas': 'UD_Skolt_Sami-Giellagas',
'sk_snk': 'UD_Slovak-SNK',
'sl_ssj': 'UD_Slovenian-SSJ',
'sl_sst': 'UD_Slovenian-SST',
'soj_aha': 'UD_Soi-AHA',
'ajp_madar': 'UD_South_Levantine_Arabic-MADAR',
'sdh_garrusi': 'UD_Southern_Kurdish-Garrusi',
'es_ancora': 'UD_Spanish-AnCora',
'es_coser': 'UD_Spanish-COSER',
'es_gsd': 'UD_Spanish-GSD',
'es_pud': 'UD_Spanish-PUD',
'ssp_lse': 'UD_Spanish_Sign_Language-LSE',
'sv_lines': 'UD_Swedish-LinES',
'sv_old': 'UD_Swedish-Old',
'sv_pud': 'UD_Swedish-PUD',
'sv_swell': 'UD_Swedish-SweLL',
'sv_talbanken': 'UD_Swedish-Talbanken',
'swl_sslc': 'UD_Swedish_Sign_Language-SSLC',
'tl_trg': 'UD_Tagalog-TRG',
'tl_ugnayan': 'UD_Tagalog-Ugnayan',
'ta_mwtt': 'UD_Tamil-MWTT',
'ta_ttb': 'UD_Tamil-TTB',
'tt_nmctt': 'UD_Tatar-NMCTT',
'eme_tudet': 'UD_Teko-TuDeT',
'te_mtg': 'UD_Telugu-MTG',
'qte_tect': 'UD_Telugu_English-TECT',
'th_pud': 'UD_Thai-PUD',
'th_tud': 'UD_Thai-TUD',
'tn_popapolelo': 'UD_Tswana-Popapolelo',
'tpn_tudet': 'UD_Tupinamba-TuDeT',
'tr_atis': 'UD_Turkish-Atis',
'tr_boun': 'UD_Turkish-BOUN',
'tr_framenet': 'UD_Turkish-FrameNet',
'tr_gb': 'UD_Turkish-GB',
'tr_imst': 'UD_Turkish-IMST',
'tr_kenet': 'UD_Turkish-Kenet',
'tr_pud': 'UD_Turkish-PUD',
'tr_penn': 'UD_Turkish-Penn',
'tr_tourism': 'UD_Turkish-Tourism',
'tr_tuecl': 'UD_Turkish-TueCL',
'qti_butr': 'UD_Turkish_English-BUTR',
'qtd_sagt': 'UD_Turkish_German-SAGT',
'uk_iu': 'UD_Ukrainian-IU',
'uk_parlamint': 'UD_Ukrainian-ParlaMint',
'xum_ikuvina': 'UD_Umbrian-IKUVINA',
'hsb_ufal': 'UD_Upper_Sorbian-UFAL',
'ur_udtb': 'UD_Urdu-UDTB',
'ug_udt': 'UD_Uyghur-UDT',
'uz_tuecl': 'UD_Uzbek-TueCL',
'uz_ut': 'UD_Uzbek-UT',
'uz_uzudt': 'UD_Uzbek-UzUDT',
'vep_vwt': 'UD_Veps-VWT',
'vi_tuecl': 'UD_Vietnamese-TueCL',
'vi_vtb': 'UD_Vietnamese-VTB',
'wbp_ufal': 'UD_Warlpiri-UFAL',
'cy_ccg': 'UD_Welsh-CCG',
'hyw_armtdp': 'UD_Western_Armenian-ArmTDP',
'nhi_itml': 'UD_Western_Sierra_Puebla_Nahuatl-ITML',
'wo_wtb': 'UD_Wolof-WTB',
'xav_xdt': 'UD_Xavante-XDT',
'sjo_xdt': 'UD_Xibe-XDT',
'sah_yktdt': 'UD_Yakut-YKTDT',
'yi_yitb': 'UD_Yiddish-YiTB',
'yo_ytb': 'UD_Yoruba-YTB',
'ess_sli': 'UD_Yupik-SLI',
'say_autogramm': 'UD_Zaar-Autogramm',
}
def short_name_to_treebank(short_name):
    """
    Return the UD treebank directory name for a short name such as 'sv_talbanken'.

    The lookup is exact (case-sensitive); an unknown short name raises KeyError.
    """
    treebank = SHORT_NAMES[short_name]
    return treebank
# Maps a lowercased UD treebank directory name (e.g. 'ud_swedish-talbanken') to its
# properly-cased form (e.g. 'UD_Swedish-Talbanken').  canonical_treebank_name() uses
# this table for case-insensitive lookups of full UD names.
# NOTE(review): this table looks machine-generated (build_short_name_to_treebank.py
# exists alongside this module) -- presumably edit the generator, not this dict; confirm.
CANONICAL_NAMES = {
    'ud_abaza-atb': 'UD_Abaza-ATB',
    'ud_abkhaz-abnc': 'UD_Abkhaz-AbNC',
    'ud_afrikaans-afribooms': 'UD_Afrikaans-AfriBooms',
    'ud_akkadian-pisandub': 'UD_Akkadian-PISANDUB',
    'ud_akkadian-riao': 'UD_Akkadian-RIAO',
    'ud_akuntsu-tudet': 'UD_Akuntsu-TuDeT',
    'ud_albanian-staf': 'UD_Albanian-STAF',
    'ud_albanian-tsa': 'UD_Albanian-TSA',
    'ud_alemannic-divital': 'UD_Alemannic-DIVITAL',
    'ud_alemannic-uzh': 'UD_Alemannic-UZH',
    'ud_amharic-att': 'UD_Amharic-ATT',
    'ud_ancient_greek-proiel': 'UD_Ancient_Greek-PROIEL',
    'ud_ancient_greek-ptnk': 'UD_Ancient_Greek-PTNK',
    'ud_ancient_greek-perseus': 'UD_Ancient_Greek-Perseus',
    'ud_ancient_hebrew-ptnk': 'UD_Ancient_Hebrew-PTNK',
    'ud_apurina-ufpa': 'UD_Apurina-UFPA',
    'ud_arabic-nyuad': 'UD_Arabic-NYUAD',
    'ud_arabic-padt': 'UD_Arabic-PADT',
    'ud_arabic-pud': 'UD_Arabic-PUD',
    'ud_armenian-armtdp': 'UD_Armenian-ArmTDP',
    'ud_armenian-bsut': 'UD_Armenian-BSUT',
    'ud_assyrian-as': 'UD_Assyrian-AS',
    'ud_azerbaijani-tuecl': 'UD_Azerbaijani-TueCL',
    'ud_bambara-crb': 'UD_Bambara-CRB',
    'ud_basque-bdt': 'UD_Basque-BDT',
    'ud_bavarian-maibaam': 'UD_Bavarian-MaiBaam',
    'ud_beja-autogramm': 'UD_Beja-Autogramm',
    'ud_belarusian-hse': 'UD_Belarusian-HSE',
    'ud_bengali-bru': 'UD_Bengali-BRU',
    'ud_bhojpuri-bhtb': 'UD_Bhojpuri-BHTB',
    'ud_bokota-chibergis': 'UD_Bokota-ChibErgIS',
    'ud_bororo-bdt': 'UD_Bororo-BDT',
    'ud_breton-keb': 'UD_Breton-KEB',
    'ud_bulgarian-btb': 'UD_Bulgarian-BTB',
    'ud_buryat-bdt': 'UD_Buryat-BDT',
    'ud_cantonese-hk': 'UD_Cantonese-HK',
    'ud_cappadocian-amgic': 'UD_Cappadocian-AMGiC',
    'ud_cappadocian-tuecl': 'UD_Cappadocian-TueCL',
    'ud_catalan-ancora': 'UD_Catalan-AnCora',
    'ud_cebuano-gja': 'UD_Cebuano-GJA',
    'ud_central_kurdish-mukri': 'UD_Central_Kurdish-Mukri',
    'ud_chinese-beginner': 'UD_Chinese-Beginner',
    'ud_chinese-cfl': 'UD_Chinese-CFL',
    'ud_chinese-gsd': 'UD_Chinese-GSD',
    'ud_chinese-gsdsimp': 'UD_Chinese-GSDSimp',
    'ud_chinese-hk': 'UD_Chinese-HK',
    'ud_chinese-pud': 'UD_Chinese-PUD',
    'ud_chinese-patentchar': 'UD_Chinese-PatentChar',
    'ud_chintang-ctntb': 'UD_Chintang-CTNTB',
    'ud_chukchi-hse': 'UD_Chukchi-HSE',
    'ud_classical_armenian-caval': 'UD_Classical_Armenian-CAVaL',
    'ud_classical_chinese-kyoto': 'UD_Classical_Chinese-Kyoto',
    'ud_classical_chinese-tuecl': 'UD_Classical_Chinese-TueCL',
    'ud_coptic-bohairic': 'UD_Coptic-Bohairic',
    'ud_coptic-scriptorium': 'UD_Coptic-Scriptorium',
    'ud_croatian-set': 'UD_Croatian-SET',
    'ud_czech-cac': 'UD_Czech-CAC',
    'ud_czech-cltt': 'UD_Czech-CLTT',
    'ud_czech-fictree': 'UD_Czech-FicTree',
    'ud_czech-pdtc': 'UD_Czech-PDTC',
    'ud_czech-pud': 'UD_Czech-PUD',
    'ud_czech-poetry': 'UD_Czech-Poetry',
    'ud_danish-ddt': 'UD_Danish-DDT',
    'ud_dutch-alpino': 'UD_Dutch-Alpino',
    'ud_dutch-lassysmall': 'UD_Dutch-LassySmall',
    'ud_egyptian-ujaen': 'UD_Egyptian-UJaen',
    'ud_english-atis': 'UD_English-Atis',
    'ud_english-childes': 'UD_English-CHILDES',
    'ud_english-ctetex': 'UD_English-CTeTex',
    'ud_english-eslspok': 'UD_English-ESLSpok',
    'ud_english-ewt': 'UD_English-EWT',
    'ud_english-gentle': 'UD_English-GENTLE',
    'ud_english-gum': 'UD_English-GUM',
    'ud_english-gumreddit': 'UD_English-GUMReddit',
    'ud_english-lines': 'UD_English-LinES',
    'ud_english-littleprince': 'UD_English-LittlePrince',
    'ud_english-pud': 'UD_English-PUD',
    'ud_english-partut': 'UD_English-ParTUT',
    'ud_english-pronouns': 'UD_English-Pronouns',
    'ud_erzya-jr': 'UD_Erzya-JR',
    'ud_esperanto-cairo': 'UD_Esperanto-Cairo',
    'ud_esperanto-prago': 'UD_Esperanto-Prago',
    'ud_estonian-edt': 'UD_Estonian-EDT',
    'ud_estonian-ewt': 'UD_Estonian-EWT',
    'ud_faroese-farpahc': 'UD_Faroese-FarPaHC',
    'ud_faroese-oft': 'UD_Faroese-OFT',
    'ud_finnish-ftb': 'UD_Finnish-FTB',
    'ud_finnish-ood': 'UD_Finnish-OOD',
    'ud_finnish-pud': 'UD_Finnish-PUD',
    'ud_finnish-tdt': 'UD_Finnish-TDT',
    'ud_french-alts': 'UD_French-ALTS',
    'ud_french-fqb': 'UD_French-FQB',
    'ud_french-gsd': 'UD_French-GSD',
    'ud_french-pud': 'UD_French-PUD',
    'ud_french-partut': 'UD_French-ParTUT',
    'ud_french-parisstories': 'UD_French-ParisStories',
    'ud_french-poitevindivital': 'UD_French-PoitevinDIVITAL',
    'ud_french-rhapsodie': 'UD_French-Rhapsodie',
    'ud_french-sequoia': 'UD_French-Sequoia',
    'ud_frisian_dutch-fame': 'UD_Frisian_Dutch-Fame',
    'ud_galician-ctg': 'UD_Galician-CTG',
    'ud_galician-pud': 'UD_Galician-PUD',
    'ud_galician-treegal': 'UD_Galician-TreeGal',
    'ud_georgian-glc': 'UD_Georgian-GLC',
    'ud_georgian-gnc': 'UD_Georgian-GNC',
    'ud_german-gsd': 'UD_German-GSD',
    'ud_german-hdt': 'UD_German-HDT',
    'ud_german-lit': 'UD_German-LIT',
    'ud_german-pud': 'UD_German-PUD',
    'ud_gheg-gps': 'UD_Gheg-GPS',
    'ud_gothic-proiel': 'UD_Gothic-PROIEL',
    'ud_greek-cretan': 'UD_Greek-Cretan',
    'ud_greek-gdt': 'UD_Greek-GDT',
    'ud_greek-gud': 'UD_Greek-GUD',
    'ud_greek-lesbian': 'UD_Greek-Lesbian',
    'ud_greek-messinian': 'UD_Greek-Messinian',
    'ud_guajajara-tudet': 'UD_Guajajara-TuDeT',
    'ud_guarani-oldtudet': 'UD_Guarani-OldTuDeT',
    'ud_gujarati-gujtb': 'UD_Gujarati-GujTB',
    'ud_gwichin-tuecl': 'UD_Gwichin-TueCL',
    'ud_haitian_creole-adolphe': 'UD_Haitian_Creole-Adolphe',
    'ud_haitian_creole-autogramm': 'UD_Haitian_Creole-Autogramm',
    'ud_hausa-northernautogramm': 'UD_Hausa-NorthernAutogramm',
    'ud_hausa-southernautogramm': 'UD_Hausa-SouthernAutogramm',
    'ud_hausa-westernautogramm': 'UD_Hausa-WesternAutogramm',
    'ud_hebrew-htb': 'UD_Hebrew-HTB',
    'ud_hebrew-iahltknesset': 'UD_Hebrew-IAHLTknesset',
    'ud_hebrew-iahltwiki': 'UD_Hebrew-IAHLTwiki',
    'ud_highland_puebla_nahuatl-itml': 'UD_Highland_Puebla_Nahuatl-ITML',
    'ud_hindi-hdtb': 'UD_Hindi-HDTB',
    'ud_hindi-pud': 'UD_Hindi-PUD',
    'ud_hittite-hittb': 'UD_Hittite-HitTB',
    'ud_hungarian-szeged': 'UD_Hungarian-Szeged',
    'ud_icelandic-gc': 'UD_Icelandic-GC',
    'ud_icelandic-icepahc': 'UD_Icelandic-IcePaHC',
    'ud_icelandic-modern': 'UD_Icelandic-Modern',
    'ud_icelandic-pud': 'UD_Icelandic-PUD',
    'ud_ika-chibergis': 'UD_Ika-ChibErgIS',
    'ud_indonesian-csui': 'UD_Indonesian-CSUI',
    'ud_indonesian-gsd': 'UD_Indonesian-GSD',
    'ud_indonesian-pud': 'UD_Indonesian-PUD',
    'ud_irish-cadhan': 'UD_Irish-Cadhan',
    'ud_irish-idt': 'UD_Irish-IDT',
    'ud_irish-twittirish': 'UD_Irish-TwittIrish',
    'ud_italian-isdt': 'UD_Italian-ISDT',
    'ud_italian-kiparlaforest': 'UD_Italian-KIParlaForest',
    'ud_italian-markit': 'UD_Italian-MarkIT',
    'ud_italian-old': 'UD_Italian-Old',
    'ud_italian-pud': 'UD_Italian-PUD',
    'ud_italian-partut': 'UD_Italian-ParTUT',
    'ud_italian-parlamint': 'UD_Italian-ParlaMint',
    'ud_italian-postwita': 'UD_Italian-PoSTWITA',
    'ud_italian-twittiro': 'UD_Italian-TWITTIRO',
    'ud_italian-vit': 'UD_Italian-VIT',
    'ud_italian-valico': 'UD_Italian-Valico',
    'ud_japanese-bccwj': 'UD_Japanese-BCCWJ',
    'ud_japanese-bccwjluw': 'UD_Japanese-BCCWJLUW',
    'ud_japanese-gsd': 'UD_Japanese-GSD',
    'ud_japanese-gsdluw': 'UD_Japanese-GSDLUW',
    'ud_japanese-pud': 'UD_Japanese-PUD',
    'ud_japanese-pudluw': 'UD_Japanese-PUDLUW',
    'ud_javanese-csui': 'UD_Javanese-CSUI',
    'ud_kaapor-tudet': 'UD_Kaapor-TuDeT',
    'ud_kangri-kdtb': 'UD_Kangri-KDTB',
    'ud_karelian-kkpp': 'UD_Karelian-KKPP',
    'ud_karo-tudet': 'UD_Karo-TuDeT',
    'ud_kazakh-ktb': 'UD_Kazakh-KTB',
    'ud_khoekhoe-kdt': 'UD_Khoekhoe-KDT',
    'ud_khunsari-aha': 'UD_Khunsari-AHA',
    'ud_kiche-iu': 'UD_Kiche-IU',
    'ud_komi_permyak-uh': 'UD_Komi_Permyak-UH',
    'ud_komi_zyrian-ikdp': 'UD_Komi_Zyrian-IKDP',
    'ud_komi_zyrian-lattice': 'UD_Komi_Zyrian-Lattice',
    'ud_korean-gsd': 'UD_Korean-GSD',
    'ud_korean-ksl': 'UD_Korean-KSL',
    'ud_korean-kaist': 'UD_Korean-Kaist',
    'ud_korean-littleprince': 'UD_Korean-LittlePrince',
    'ud_korean-pud': 'UD_Korean-PUD',
    'ud_kyrgyz-ktmu': 'UD_Kyrgyz-KTMU',
    'ud_kyrgyz-tuecl': 'UD_Kyrgyz-TueCL',
    'ud_latgalian-cairo': 'UD_Latgalian-Cairo',
    'ud_latin-circse': 'UD_Latin-CIRCSE',
    'ud_latin-ittb': 'UD_Latin-ITTB',
    'ud_latin-llct': 'UD_Latin-LLCT',
    'ud_latin-proiel': 'UD_Latin-PROIEL',
    'ud_latin-perseus': 'UD_Latin-Perseus',
    'ud_latin-udante': 'UD_Latin-UDante',
    'ud_latvian-cairo': 'UD_Latvian-Cairo',
    'ud_latvian-lvtb': 'UD_Latvian-LVTB',
    'ud_ligurian-glt': 'UD_Ligurian-GLT',
    'ud_lithuanian-alksnis': 'UD_Lithuanian-ALKSNIS',
    'ud_lithuanian-hse': 'UD_Lithuanian-HSE',
    'ud_livvi-kkpp': 'UD_Livvi-KKPP',
    'ud_low_saxon-lsdc': 'UD_Low_Saxon-LSDC',
    'ud_luxembourgish-luxbank': 'UD_Luxembourgish-LuxBank',
    'ud_macedonian-mtb': 'UD_Macedonian-MTB',
    'ud_madi-jarawara': 'UD_Madi-Jarawara',
    'ud_maghrebi_arabic_french-arabizi': 'UD_Maghrebi_Arabic_French-Arabizi',
    'ud_makurap-tudet': 'UD_Makurap-TuDeT',
    'ud_malayalam-ufal': 'UD_Malayalam-UFAL',
    'ud_maltese-mudt': 'UD_Maltese-MUDT',
    'ud_manx-cadhan': 'UD_Manx-Cadhan',
    'ud_marathi-ufal': 'UD_Marathi-UFAL',
    'ud_mbya_guarani-dooley': 'UD_Mbya_Guarani-Dooley',
    'ud_mbya_guarani-thomas': 'UD_Mbya_Guarani-Thomas',
    'ud_middle_french-altm': 'UD_Middle_French-ALTM',
    'ud_middle_french-profiterole': 'UD_Middle_French-PROFITEROLE',
    'ud_moksha-jr': 'UD_Moksha-JR',
    'ud_munduruku-tudet': 'UD_Munduruku-TuDeT',
    'ud_naga-suansu': 'UD_Naga-Suansu',
    'ud_naija-nsc': 'UD_Naija-NSC',
    'ud_nayini-aha': 'UD_Nayini-AHA',
    'ud_neapolitan-rb': 'UD_Neapolitan-RB',
    'ud_nenets-tundra': 'UD_Nenets-Tundra',
    'ud_nheengatu-complin': 'UD_Nheengatu-CompLin',
    'ud_north_sami-giella': 'UD_North_Sami-Giella',
    'ud_northern_kurdish-kurmanji': 'UD_Northern_Kurdish-Kurmanji',
    'ud_northwest_gbaya-autogramm': 'UD_Northwest_Gbaya-Autogramm',
    'ud_norwegian-bokmaal': 'UD_Norwegian-Bokmaal',
    'ud_norwegian-nynorsk': 'UD_Norwegian-Nynorsk',
    'ud_occitan-ttb': 'UD_Occitan-TTB',
    'ud_odia-odtb': 'UD_Odia-ODTB',
    'ud_old_church_slavonic-proiel': 'UD_Old_Church_Slavonic-PROIEL',
    'ud_old_east_slavic-birchbark': 'UD_Old_East_Slavic-Birchbark',
    'ud_old_east_slavic-rnc': 'UD_Old_East_Slavic-RNC',
    'ud_old_east_slavic-ruthenian': 'UD_Old_East_Slavic-Ruthenian',
    'ud_old_east_slavic-torot': 'UD_Old_East_Slavic-TOROT',
    'ud_old_english-cairo': 'UD_Old_English-Cairo',
    'ud_old_french-altm': 'UD_Old_French-ALTM',
    'ud_old_french-profiterole': 'UD_Old_French-PROFITEROLE',
    'ud_old_irish-dipsgg': 'UD_Old_Irish-DipSGG',
    'ud_old_irish-dipwbg': 'UD_Old_Irish-DipWBG',
    'ud_old_occitan-corag': 'UD_Old_Occitan-CorAG',
    'ud_old_turkish-clausal': 'UD_Old_Turkish-Clausal',
    'ud_ottoman_turkish-boun': 'UD_Ottoman_Turkish-BOUN',
    'ud_ottoman_turkish-dudu': 'UD_Ottoman_Turkish-DUDU',
    'ud_pashto-sikaram': 'UD_Pashto-Sikaram',
    'ud_paumari-tuecl': 'UD_Paumari-TueCL',
    'ud_persian-perdt': 'UD_Persian-PerDT',
    'ud_persian-seraji': 'UD_Persian-Seraji',
    'ud_pesh-chibergis': 'UD_Pesh-ChibErgIS',
    'ud_phrygian-kul': 'UD_Phrygian-KUL',
    'ud_polish-lfg': 'UD_Polish-LFG',
    'ud_polish-mpdt': 'UD_Polish-MPDT',
    'ud_polish-pdb': 'UD_Polish-PDB',
    'ud_polish-pud': 'UD_Polish-PUD',
    'ud_pomak-philotis': 'UD_Pomak-Philotis',
    'ud_portuguese-bosque': 'UD_Portuguese-Bosque',
    'ud_portuguese-cintil': 'UD_Portuguese-CINTIL',
    'ud_portuguese-dantestocks': 'UD_Portuguese-DANTEStocks',
    'ud_portuguese-gsd': 'UD_Portuguese-GSD',
    'ud_portuguese-pud': 'UD_Portuguese-PUD',
    'ud_portuguese-petrogold': 'UD_Portuguese-PetroGold',
    'ud_portuguese-porttinari': 'UD_Portuguese-Porttinari',
    'ud_romanian-art': 'UD_Romanian-ArT',
    'ud_romanian-moldoro': 'UD_Romanian-MolDoRo',
    'ud_romanian-nonstandard': 'UD_Romanian-Nonstandard',
    'ud_romanian-rrt': 'UD_Romanian-RRT',
    'ud_romanian-simonero': 'UD_Romanian-SiMoNERo',
    'ud_romanian-tuecl': 'UD_Romanian-TueCL',
    'ud_russian-gsd': 'UD_Russian-GSD',
    'ud_russian-pud': 'UD_Russian-PUD',
    'ud_russian-poetry': 'UD_Russian-Poetry',
    'ud_russian-syntagrus': 'UD_Russian-SynTagRus',
    'ud_russian-taiga': 'UD_Russian-Taiga',
    'ud_sanskrit-ufal': 'UD_Sanskrit-UFAL',
    'ud_sanskrit-vedic': 'UD_Sanskrit-Vedic',
    'ud_scottish_gaelic-arcosg': 'UD_Scottish_Gaelic-ARCOSG',
    'ud_serbian-set': 'UD_Serbian-SET',
    'ud_shanghainese-shud': 'UD_Shanghainese-ShUD',
    'ud_sicilian-stb': 'UD_Sicilian-STB',
    'ud_sindhi-isra': 'UD_Sindhi-Isra',
    'ud_sinhala-stb': 'UD_Sinhala-STB',
    'ud_skolt_sami-giellagas': 'UD_Skolt_Sami-Giellagas',
    'ud_slovak-snk': 'UD_Slovak-SNK',
    'ud_slovenian-ssj': 'UD_Slovenian-SSJ',
    'ud_slovenian-sst': 'UD_Slovenian-SST',
    'ud_soi-aha': 'UD_Soi-AHA',
    'ud_south_levantine_arabic-madar': 'UD_South_Levantine_Arabic-MADAR',
    'ud_southern_kurdish-garrusi': 'UD_Southern_Kurdish-Garrusi',
    'ud_spanish-ancora': 'UD_Spanish-AnCora',
    'ud_spanish-coser': 'UD_Spanish-COSER',
    'ud_spanish-gsd': 'UD_Spanish-GSD',
    'ud_spanish-pud': 'UD_Spanish-PUD',
    'ud_spanish_sign_language-lse': 'UD_Spanish_Sign_Language-LSE',
    'ud_swedish-lines': 'UD_Swedish-LinES',
    'ud_swedish-old': 'UD_Swedish-Old',
    'ud_swedish-pud': 'UD_Swedish-PUD',
    'ud_swedish-swell': 'UD_Swedish-SweLL',
    'ud_swedish-talbanken': 'UD_Swedish-Talbanken',
    'ud_swedish_sign_language-sslc': 'UD_Swedish_Sign_Language-SSLC',
    'ud_tagalog-trg': 'UD_Tagalog-TRG',
    'ud_tagalog-ugnayan': 'UD_Tagalog-Ugnayan',
    'ud_tamil-mwtt': 'UD_Tamil-MWTT',
    'ud_tamil-ttb': 'UD_Tamil-TTB',
    'ud_tatar-nmctt': 'UD_Tatar-NMCTT',
    'ud_teko-tudet': 'UD_Teko-TuDeT',
    'ud_telugu-mtg': 'UD_Telugu-MTG',
    'ud_telugu_english-tect': 'UD_Telugu_English-TECT',
    'ud_thai-pud': 'UD_Thai-PUD',
    'ud_thai-tud': 'UD_Thai-TUD',
    'ud_tswana-popapolelo': 'UD_Tswana-Popapolelo',
    'ud_tupinamba-tudet': 'UD_Tupinamba-TuDeT',
    'ud_turkish-atis': 'UD_Turkish-Atis',
    'ud_turkish-boun': 'UD_Turkish-BOUN',
    'ud_turkish-framenet': 'UD_Turkish-FrameNet',
    'ud_turkish-gb': 'UD_Turkish-GB',
    'ud_turkish-imst': 'UD_Turkish-IMST',
    'ud_turkish-kenet': 'UD_Turkish-Kenet',
    'ud_turkish-pud': 'UD_Turkish-PUD',
    'ud_turkish-penn': 'UD_Turkish-Penn',
    'ud_turkish-tourism': 'UD_Turkish-Tourism',
    'ud_turkish-tuecl': 'UD_Turkish-TueCL',
    'ud_turkish_english-butr': 'UD_Turkish_English-BUTR',
    'ud_turkish_german-sagt': 'UD_Turkish_German-SAGT',
    'ud_ukrainian-iu': 'UD_Ukrainian-IU',
    'ud_ukrainian-parlamint': 'UD_Ukrainian-ParlaMint',
    'ud_umbrian-ikuvina': 'UD_Umbrian-IKUVINA',
    'ud_upper_sorbian-ufal': 'UD_Upper_Sorbian-UFAL',
    'ud_urdu-udtb': 'UD_Urdu-UDTB',
    'ud_uyghur-udt': 'UD_Uyghur-UDT',
    'ud_uzbek-tuecl': 'UD_Uzbek-TueCL',
    'ud_uzbek-ut': 'UD_Uzbek-UT',
    'ud_uzbek-uzudt': 'UD_Uzbek-UzUDT',
    'ud_veps-vwt': 'UD_Veps-VWT',
    'ud_vietnamese-tuecl': 'UD_Vietnamese-TueCL',
    'ud_vietnamese-vtb': 'UD_Vietnamese-VTB',
    'ud_warlpiri-ufal': 'UD_Warlpiri-UFAL',
    'ud_welsh-ccg': 'UD_Welsh-CCG',
    'ud_western_armenian-armtdp': 'UD_Western_Armenian-ArmTDP',
    'ud_western_sierra_puebla_nahuatl-itml': 'UD_Western_Sierra_Puebla_Nahuatl-ITML',
    'ud_wolof-wtb': 'UD_Wolof-WTB',
    'ud_xavante-xdt': 'UD_Xavante-XDT',
    'ud_xibe-xdt': 'UD_Xibe-XDT',
    'ud_yakut-yktdt': 'UD_Yakut-YKTDT',
    'ud_yiddish-yitb': 'UD_Yiddish-YiTB',
    'ud_yoruba-ytb': 'UD_Yoruba-YTB',
    'ud_yupik-sli': 'UD_Yupik-SLI',
    'ud_zaar-autogramm': 'UD_Zaar-Autogramm',
}
def canonical_treebank_name(ud_name):
    """
    Normalize a treebank identifier to its canonical UD directory name.

    Resolution order:
      1. an exact short name (e.g. 'sv_talbanken') is looked up in SHORT_NAMES;
      2. otherwise the name is matched case-insensitively against CANONICAL_NAMES
         (e.g. 'ud_swedish-talbanken' -> 'UD_Swedish-Talbanken');
      3. anything unrecognized is returned unchanged.
    """
    try:
        return SHORT_NAMES[ud_name]
    except KeyError:
        return CANONICAL_NAMES.get(ud_name.lower(), ud_name)
================================================
FILE: stanza/models/common/stanza_object.py
================================================
def _readonly_setter(self, name):
full_classname = self.__class__.__module__
if full_classname is None:
full_classname = self.__class__.__qualname__
else:
full_classname += '.' + self.__class__.__qualname__
raise ValueError(f'Property "{name}" of "{full_classname}" is read-only.')
class StanzaObject(object):
    """
    Base class for Stanza data objects.

    Subclasses can register extra annotations at runtime through `add_property`,
    which attaches a Python property to the class.
    """

    @classmethod
    def add_property(cls, name, default=None, getter=None, setter=None):
        """
        Register a property `name` on this class, backed by the attribute `_{name}`.

        Args:
            name: public attribute name; must not already exist on the class.
            default: value stored as the class attribute `_{name}`, so instances see it
                until they assign their own `_{name}`.
            getter: optional custom getter; by default the property reads `self._{name}`.
            setter: optional custom setter; by default assignment raises ValueError
                (the property is read-only).

        Raises:
            ValueError: if `cls` already has an attribute called `name`.
        """
        if hasattr(cls, name):
            raise ValueError(f'Property by the name of {name} already exists in {cls}. Maybe you want to find another name?')

        backing_name = f'_{name}'
        setattr(cls, backing_name, default)

        def read_backing(self):
            return getattr(self, backing_name)

        def reject_write(self, value):
            _readonly_setter(self, name)

        fget = read_backing if getter is None else getter
        fset = reject_write if setter is None else setter
        setattr(cls, name, property(fget, fset))
================================================
FILE: stanza/models/common/trainer.py
================================================
import torch
class Trainer:
    """
    Minimal base trainer providing learning-rate changes and checkpoint save/load.

    Expects subclasses to set `self.model`, `self.optimizer` and `self.args`
    (a mapping that has a 'mode' key read by `load`).
    """

    def change_lr(self, new_lr):
        """Set the learning rate of every parameter group of `self.optimizer` to `new_lr`."""
        for group in self.optimizer.param_groups:
            group.update(lr=new_lr)

    def save(self, filename):
        """Write the model and optimizer state dicts to `filename` via torch.save."""
        checkpoint = dict(model=self.model.state_dict(),
                          optimizer=self.optimizer.state_dict())
        torch.save(checkpoint, filename)

    def load(self, filename):
        """
        Restore state from a checkpoint written by `save`.

        The model weights are always restored; the optimizer state only when
        `self.args['mode'] == 'train'`. The map_location callback returns each
        storage unchanged, which keeps tensors on CPU, and weights_only=True
        restricts unpickling to plain tensors and containers.
        """
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage, weights_only=True)
        self.model.load_state_dict(checkpoint['model'])
        if self.args['mode'] != 'train':
            return
        self.optimizer.load_state_dict(checkpoint['optimizer'])
================================================
FILE: stanza/models/common/utils.py
================================================
"""
Utility functions.
"""
import argparse
from collections import Counter
from contextlib import contextmanager
import gzip
import json
import logging
import lzma
import os
import random
import re
import sys
import unicodedata
import zipfile
import torch
import torch.nn as nn
import numpy as np
try:
from udtools import udeval
except ImportError:
from udtools.src.udtools import udeval
try:
from udtools.udeval import UDError
except ImportError:
from udtools.src.udtools.udeval import UDError
from stanza.models.common.constant import lcode2lang
import stanza.models.common.seq2seq_constant as constant
from stanza.resources.default_packages import TRANSFORMER_NICKNAMES
# Module-level logger; the helpers in this module that log (e.g. ensure_dir,
# print_config, and the optimizer builders by default) write to the 'stanza' logger.
logger = logging.getLogger('stanza')
# filenames
def get_wordvec_file(wordvec_dir, shorthand, wordvec_type=None):
    """
    Locate the word vector file for a language, given a root directory and a shorthand.

    Args:
        wordvec_dir: root directory holding `<type>/<language>/` subdirectories.
        shorthand: a treebank shorthand such as 'en_ewt'; only the part before the
            first '_' (the language code) is used.
        wordvec_type: optional explicit subdirectory name (e.g. 'fasttext'). When
            omitted, 'word2vec' is tried first, then 'fasttext'.

    Returns:
        `<lang_dir>/<lcode>.vectors.xz` if it exists, else `<lcode>.vectors.txt` if it
        exists, else the bare `<lcode>.vectors` path (which may not exist).

    Raises:
        FileNotFoundError: if the language directory cannot be found.
    """
    lcode, _ = shorthand.split('_', 1)
    lang = lcode2lang[lcode]

    word2vec_dir = os.path.join(wordvec_dir, 'word2vec', lang)
    fasttext_dir = os.path.join(wordvec_dir, 'fasttext', lang)
    if wordvec_type is not None:
        lang_dir = os.path.join(wordvec_dir, wordvec_type, lang)
        if not os.path.exists(lang_dir):
            raise FileNotFoundError("Word vector type {} was specified, but directory {} does not exist".format(wordvec_type, lang_dir))
    else:
        # prefer word2vec, fall back to fasttext
        lang_dir = next((candidate for candidate in (word2vec_dir, fasttext_dir) if os.path.exists(candidate)), None)
        if lang_dir is None:
            raise FileNotFoundError("Cannot locate word vector directory for language: {} Looked in {} and {}".format(lang, word2vec_dir, fasttext_dir))

    base_name = os.path.join(lang_dir, '{}.vectors'.format(lcode))
    for suffix in (".xz", ".txt"):
        if os.path.exists(base_name + suffix):
            return base_name + suffix
    return base_name
@contextmanager
def output_stream(filename=None):
    """
    Context manager yielding a writable text stream.

    With no filename, yields sys.stdout (which is left open). Otherwise opens
    `filename` for writing as UTF-8 and closes it when the context exits.
    """
    if filename is not None:
        with open(filename, "w", encoding="utf-8") as fout:
            yield fout
        return
    yield sys.stdout
@contextmanager
def open_read_text(filename, encoding="utf-8"):
    """
    Open `filename` for reading text, transparently decompressing by extension.

    '.xz' files are read with lzma, '.gz' files with gzip, and anything else as a
    plain text file. All are decoded with `encoding`.

    Use as a context; the file is closed once the context exits, e.g.

        with open_read_text(filename) as fin:
            do stuff
    """
    if filename.endswith(".xz"):
        handle = lzma.open(filename, mode='rt', encoding=encoding)
    elif filename.endswith(".gz"):
        handle = gzip.open(filename, mode='rt', encoding=encoding)
    else:
        handle = open(filename, encoding=encoding)
    with handle as fin:
        yield fin
@contextmanager
def open_read_binary(filename):
    """
    Open `filename` for binary reading, transparently decompressing by extension.

    '.xz' files are read with lzma and '.gz' files with gzip. A '.zip' archive is
    accepted only if it contains exactly one member, which is the stream yielded.
    Anything else is opened as a regular binary file.

    Use as a context; the file is closed once the context exits, e.g.

        with open_read_binary(filename) as fin:
            do stuff

    Raises:
        ValueError: if a '.zip' archive is empty or holds more than one member.
    """
    if filename.endswith(".xz"):
        with lzma.open(filename, mode='rb') as fin:
            yield fin
    elif filename.endswith(".gz"):
        with gzip.open(filename, mode='rb') as fin:
            yield fin
    elif filename.endswith(".zip"):
        with zipfile.ZipFile(filename) as zin:
            input_names = zin.namelist()
            if len(input_names) == 0:
                raise ValueError("Empty zip archive")
            if len(input_names) > 1:
                # previously the %s placeholder was never filled in
                raise ValueError("zip file %s has more than one file in it" % filename)
            with zin.open(input_names[0]) as fin:
                yield fin
    else:
        with open(filename, mode='rb') as fin:
            yield fin
# training schedule
def get_adaptive_eval_interval(cur_dev_size, thres_dev_size, base_interval):
    """
    Scale the evaluation interval with the size of the dev set.

    Returns `base_interval` while cur_dev_size <= thres_dev_size. Beyond that the
    interval grows linearly: base_interval times round(cur_dev_size / thres_dev_size),
    using Python's round-half-to-even.
    """
    if cur_dev_size <= thres_dev_size:
        return base_interval
    multiplier = round(cur_dev_size / thres_dev_size)
    return base_interval * multiplier
# ud utils
def ud_scores(gold_conllu_file, system_conllu_file):
    """
    Evaluate a system CoNLL-U against a gold CoNLL-U with udtools' udeval.

    Each argument may be a file-like object (anything with a callable `readline`),
    parsed with `udeval.load_conllu`, or a path, read with `udeval.load_conllu_file`.

    Returns:
        The result of `udeval.evaluate(gold, system)`.

    Raises:
        UDError: wrapping the underlying parse error, with a message naming which
            input failed.
    """
    def load(source, stream_error):
        # stream-like objects are parsed directly; anything else is treated as a path
        if hasattr(source, 'readline') and callable(source.readline):
            try:
                return udeval.load_conllu(source, '', {})
            except UDError as e:
                raise UDError(stream_error) from e
        try:
            return udeval.load_conllu_file(source)
        except UDError as e:
            raise UDError("Could not read %s" % source) from e

    gold_ud = load(gold_conllu_file, "Could not process gold UD file")
    system_ud = load(system_conllu_file, "Could not process system UD file")
    return udeval.evaluate(gold_ud, system_ud)
def harmonic_mean(a, weights=None):
    """
    Return the (optionally weighted) harmonic mean of the values in `a`.

    If any value equals 0 the result is 0; that check happens before the weights
    are validated. Otherwise the result is sum(weights) / sum(w / x), where omitted
    weights count as 1 each.
    """
    if any(value == 0 for value in a):
        return 0
    assert weights is None or len(weights) == len(a), 'Weights has length {} which is different from that of the array ({}).'.format(len(weights), len(a))
    if weights is None:
        weights = [1] * len(a)
    return sum(weights) / sum(w / x for x, w in zip(a, weights))
# torch utils
def dispatch_optimizer(name, parameters, opt_logger, lr=None, betas=None, eps=None, momentum=None, **extra_args):
    """
    Construct the optimizer called `name` over `parameters`.

    Args:
        name: one of 'amsgrad', 'amsgradw', 'sgd', 'adagrad', 'adam', 'adamw',
            'adamax', 'adadelta', 'adabelief', 'madgrad', 'mirror_madgrad', 'rmsprop'.
        parameters: iterable of parameters or param-group dicts, passed straight to
            the optimizer constructor.
        opt_logger: logger that receives a debug message describing the optimizer.
        lr, betas, eps, momentum: forwarded only to the optimizers that take them;
            'adamax' ignores `lr` and keeps its library default.
        **extra_args: additional keyword arguments forwarded to the constructor and
            echoed in the debug message.

    Raises:
        ModuleNotFoundError: if 'adabelief', 'madgrad' or 'mirror_madgrad' is requested
            but the package providing it is not installed.
        ValueError: if `name` is not a supported optimizer.
    """
    extra_logging = ""
    if extra_args:
        extra_logging = ", " + ", ".join("%s=%s" % (arg, val) for arg, val in extra_args.items())

    # Adam family: (log format, optimizer class, variant-specific constructor kwargs)
    adam_like = {
        'amsgrad': ("Building Adam w/ amsgrad with lr=%f, betas=%s, eps=%f%s", torch.optim.Adam, {'amsgrad': True}),
        'amsgradw': ("Building AdamW w/ amsgrad with lr=%f, betas=%s, eps=%f%s", torch.optim.AdamW, {'amsgrad': True}),
        'adam': ("Building Adam with lr=%f, betas=%s, eps=%f%s", torch.optim.Adam, {}),
        'adamw': ("Building AdamW with lr=%f, betas=%s, eps=%f%s", torch.optim.AdamW, {}),
    }
    # optimizers configured only through lr (plus extra_args): (log format, optimizer class)
    lr_only = {
        'adagrad': ("Building Adagrad with lr=%f%s", torch.optim.Adagrad),
        'adadelta': ("Building Adadelta with lr=%f%s", torch.optim.Adadelta),
        'rmsprop': ("Building RMSprop with lr=%f%s", torch.optim.RMSprop),
    }

    if name in adam_like:
        message, optim_cls, variant_kwargs = adam_like[name]
        opt_logger.debug(message, lr, betas, eps, extra_logging)
        return optim_cls(parameters, lr=lr, betas=betas, eps=eps, **variant_kwargs, **extra_args)
    if name in lr_only:
        message, optim_cls = lr_only[name]
        opt_logger.debug(message, lr, extra_logging)
        return optim_cls(parameters, lr=lr, **extra_args)
    if name == 'sgd':
        opt_logger.debug("Building SGD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return torch.optim.SGD(parameters, lr=lr, momentum=momentum, **extra_args)
    if name == 'adamax':
        opt_logger.debug("Building Adamax%s", extra_logging)
        return torch.optim.Adamax(parameters, **extra_args)  # use default lr
    if name == 'adabelief':
        try:
            from adabelief_pytorch import AdaBelief
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create adabelief optimizer. Perhaps the adabelief-pytorch package is not installed") from e
        opt_logger.debug("Building AdaBelief with lr=%f, eps=%f%s", lr, eps, extra_logging)
        # TODO: add weight_decouple and rectify as extra args?
        return AdaBelief(parameters, lr=lr, eps=eps, weight_decouple=True, rectify=True, **extra_args)
    if name == 'madgrad':
        try:
            import madgrad
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create madgrad optimizer. Perhaps the madgrad package is not installed") from e
        opt_logger.debug("Building MADGRAD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return madgrad.MADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)
    if name == 'mirror_madgrad':
        try:
            import madgrad
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create mirror_madgrad optimizer. Perhaps the madgrad package is not installed") from e
        opt_logger.debug("Building MirrorMADGRAD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return madgrad.MirrorMADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)
    raise ValueError("Unsupported optimizer: {}".format(name))
def get_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None, opt_logger=None):
    """
    Build one optimizer for `model`, with separate parameter groups per component.

    Groups, in order:
      * 'base': every trainable parameter outside `bert_model.` and the charlm modules
        (`charmodel_forward.` / `charmodel_backward.`);
      * 'charlm': trainable charlm parameters at lr * charlm_learning_rate, only when
        there are any and charlm_learning_rate > 0;
      * 'bert': bert parameters at lr * bert_learning_rate, only when
        bert_learning_rate > 0. Without PEFT these are the trainable `bert_model.`
        parameters, optionally limited to the last `bert_finetune_layers` layers;
        with PEFT, all of `model.bert_model.parameters()` is handed over as is.
        `bert_weight_decay`, if given, overrides the weight decay for this group.

    `weight_decay` (if not None) becomes the optimizer-wide default. Construction is
    delegated to `dispatch_optimizer`; `opt_logger` defaults to the module logger.
    """
    if opt_logger is None:
        opt_logger = logger

    charlm_prefixes = ("charmodel_forward.", "charmodel_backward.")
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

    param_groups = [{'param_group_name': 'base',
                     'params': [p for n, p in trainable
                                if not n.startswith("bert_model.") and not n.startswith(charlm_prefixes)]}]

    charlm_params = [p for n, p in trainable if n.startswith(charlm_prefixes)]
    if charlm_params and charlm_learning_rate > 0:
        param_groups.append({'param_group_name': 'charlm', 'params': charlm_params, 'lr': lr * charlm_learning_rate})

    if is_peft:
        # some optimizers seem to train some even with a learning rate of 0...
        if bert_learning_rate > 0:
            # because PEFT handles what to hand to an optimizer, we don't want to touch that
            param_groups.append({'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate})
            if bert_weight_decay is not None:
                param_groups[-1]['weight_decay'] = bert_weight_decay
    else:
        bert_params = [p for n, p in trainable if n.startswith("bert_model.")]
        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model
        if bert_params and bert_finetune_layers is not None:
            num_layers = model.bert_model.config.num_hidden_layers
            bert_params = [p
                           for layer_num in range(num_layers - bert_finetune_layers, num_layers)
                           for n, p in trainable
                           if n.startswith("bert_model.") and "layer.%d." % layer_num in n]
        if bert_params and bert_learning_rate > 0:
            opt_logger.debug("Finetuning %d bert parameters with LR %s and WD %s", len(bert_params), lr * bert_learning_rate, bert_weight_decay)
            param_groups.append({'param_group_name': 'bert', 'params': bert_params, 'lr': lr * bert_learning_rate})
            if bert_weight_decay is not None:
                param_groups[-1]['weight_decay'] = bert_weight_decay

    extra_args = {} if weight_decay is None else {"weight_decay": weight_decay}
    return dispatch_optimizer(name, param_groups, opt_logger=opt_logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
def get_split_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None, opt_logger=None):
    """
    Same as `get_optimizer`, but splits the optimizer for Bert into a separate optimizer.

    Returns:
        A dict with
          * "general_optimizer": the 'base' group plus, when there are trainable
            charlm parameters and charlm_learning_rate > 0, a 'charlm' group at
            lr * charlm_learning_rate;
          * "bert_optimizer": present only when there are bert parameters to train and
            bert_learning_rate > 0. It uses lr * bert_learning_rate, and
            `bert_weight_decay` (if not None) instead of `weight_decay`.

    `opt_logger` matches `get_optimizer`: it receives the debug messages emitted while
    building the optimizers and defaults to the module logger.
    """
    opt_logger = opt_logger if opt_logger is not None else logger

    base_parameters = [p for n, p in model.named_parameters()
                       if p.requires_grad and not n.startswith("bert_model.")
                       and not n.startswith("charmodel_forward.") and not n.startswith("charmodel_backward.")]
    parameters = [{'param_group_name': 'base', 'params': base_parameters}]

    charlm_parameters = [p for n, p in model.named_parameters()
                         if p.requires_grad and (n.startswith("charmodel_forward.") or n.startswith("charmodel_backward."))]
    if len(charlm_parameters) > 0 and charlm_learning_rate > 0:
        parameters.append({'param_group_name': 'charlm', 'params': charlm_parameters, 'lr': lr * charlm_learning_rate})

    bert_parameters = None
    if not is_peft:
        trainable_parameters = [p for n, p in model.named_parameters() if p.requires_grad and n.startswith("bert_model.")]
        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model
        if len(trainable_parameters) > 0 and bert_finetune_layers is not None:
            num_layers = model.bert_model.config.num_hidden_layers
            start_layer = num_layers - bert_finetune_layers
            trainable_parameters = []
            for layer_num in range(start_layer, num_layers):
                trainable_parameters.extend([param for param_name, param in model.named_parameters()
                                             if param.requires_grad and param_name.startswith("bert_model.") and "layer.%d." % layer_num in param_name])
        if len(trainable_parameters) > 0:
            bert_parameters = [{'param_group_name': 'bert', 'params': trainable_parameters, 'lr': lr * bert_learning_rate}]
            if bert_learning_rate > 0.0:
                # same diagnostic get_optimizer emits for the bert group
                opt_logger.debug("Finetuning %d bert parameters with LR %s and WD %s", len(trainable_parameters), lr * bert_learning_rate, bert_weight_decay)
    else:
        # because PEFT handles what to hand to an optimizer, we don't want to touch that
        bert_parameters = [{'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate}]

    extra_args = {}
    if weight_decay is not None:
        extra_args["weight_decay"] = weight_decay

    optimizers = {
        "general_optimizer": dispatch_optimizer(name, parameters, opt_logger=opt_logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
    }
    if bert_parameters is not None and bert_learning_rate > 0.0:
        # copy rather than mutate, so the general optimizer's settings stay untouched
        bert_extra_args = dict(extra_args)
        if bert_weight_decay is not None:
            bert_extra_args['weight_decay'] = bert_weight_decay
        optimizers["bert_optimizer"] = dispatch_optimizer(name, bert_parameters, opt_logger=opt_logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **bert_extra_args)
    return optimizers
def change_lr(optimizer, new_lr):
    """Set the learning rate of every parameter group of `optimizer` to `new_lr`, in place."""
    for group in optimizer.param_groups:
        group.update(lr=new_lr)
def flatten_indices(seq_lens, width):
    """
    Map (row, position) pairs of a padded batch to indices into the flattened batch.

    For each row i with length seq_lens[i], yields i * width + j for j in
    range(seq_lens[i]), in row-major order, as a list.
    """
    return [row * width + col
            for row, length in enumerate(seq_lens)
            for col in range(length)]
def keep_partial_grad(grad, topk):
    """
    Zero every row of `grad` from index `topk` onward, in place, keeping only the first
    `topk` rows. Returns the same tensor.
    """
    assert topk < grad.size(0)
    tail = grad.data[topk:]
    tail.zero_()
    return grad
# other utils
def ensure_dir(d, verbose=True):
    """
    Create directory `d` (and any missing parents) unless it already exists.

    When `verbose`, logs a message before creating it.
    """
    if os.path.exists(d):
        return
    if verbose:
        logger.info("Directory {} does not exist; creating...".format(d))
    # exist_ok: guard against race conditions
    os.makedirs(d, exist_ok=True)
def save_config(config, path, verbose=True):
    """
    Write `config` to `path` as JSON with 2-space indentation and return `config`.

    When `verbose`, prints a confirmation to stdout.
    """
    with open(path, mode='w') as handle:
        json.dump(config, handle, indent=2)
    if not verbose:
        return config
    print("Config saved to file {}".format(path))
    return config
def load_config(path, verbose=True):
    """
    Read a JSON config file from `path` and return the parsed object.

    The file is decoded explicitly as UTF-8, the standard JSON encoding. The
    platform's locale default is not used, so configs containing non-ASCII text
    load the same way on every OS.

    When `verbose`, prints a confirmation to stdout.
    """
    with open(path, encoding="utf-8") as f:
        config = json.load(f)
    if verbose:
        print("Config loaded from file {}".format(path))
    return config
def print_config(config):
    """Log every key/value pair of `config` at INFO level, one tab-indented line per entry."""
    lines = ["Running with the following configs:\n"]
    lines.extend("\t{} : {}\n".format(key, str(value)) for key, value in config.items())
    logger.info("\n" + "".join(lines) + "\n")
def normalize_text(text):
    """Return `text` in Unicode normalization form NFD (canonical decomposition)."""
    decomposed = unicodedata.normalize('NFD', text)
    return decomposed
def unmap_with_copy(indices, src_tokens, vocab):
    """
    Convert lists of ids back into words, copying from the source for negative ids.

    A non-negative id is looked up in `vocab.id2word`. A negative id -k refers to
    position k-1 of the matching source token list (-1 -> 0, -2 -> 1, ...).
    Returns a list of word lists, one per (ids, tokens) pair.
    """
    def resolve(idx, tokens):
        return vocab.id2word[idx] if idx >= 0 else tokens[-idx - 1]

    return [[resolve(idx, tokens) for idx in ind]
            for ind, tokens in zip(indices, src_tokens)]
def prune_decoded_seqs(seqs, eos=None):
    """
    Prune decoded sequences after the EOS token.

    Args:
        seqs: iterable of token sequences (lists of token strings).
        eos: the end-of-sequence token to cut at; defaults to `constant.EOS`.

    Returns:
        A list with, for each sequence, the prefix before its first `eos` token.
        Sequences without `eos` are returned unchanged (same object).
    """
    if eos is None:
        eos = constant.EOS
    out = []
    for s in seqs:
        if eos in s:
            # the membership test and the index lookup must use the same token;
            # the original indexed with constant.EOS_TOKEN, which does not match the check
            out.append(s[:s.index(eos)])
        else:
            out.append(s)
    return out
def prune_hyp(hyp):
    """
    Prune a decoded hypothesis: return the prefix of `hyp` before the first
    `constant.EOS_ID`, or `hyp` itself if that id does not occur.
    """
    if constant.EOS_ID not in hyp:
        return hyp
    return hyp[:hyp.index(constant.EOS_ID)]
def prune(data_list, lens):
    """Truncate each item of `data_list` to the matching length in `lens`; the two must be the same length."""
    assert len(data_list) == len(lens)
    return [item[:length] for item, length in zip(data_list, lens)]
def sort(packed, ref, reverse=True):
    """
    Sort a series of packed list, according to a ref list.
    Also return the original index before the sort.

    Returns a tuple ``(original_indices, *sorted_lists)`` where every list has been
    reordered by the values in ``ref`` (descending by default).
    """
    assert isinstance(packed, (tuple, list)) and isinstance(ref, list)
    columns = [ref, range(len(ref)), *packed]
    rows = sorted(zip(*columns), reverse=reverse)
    sorted_columns = [list(column) for column in zip(*rows)]
    # drop the ref column itself; keep the index column and the packed lists
    return tuple(sorted_columns[1:])
def unsort(sorted_list, oidx):
    """
    Unsort a sorted list, based on the original idx.

    ``oidx[i]`` is the original position of ``sorted_list[i]``; the returned list
    puts every element back at that position.
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    if not sorted_list:
        return []
    return [item for _, item in sorted(zip(oidx, sorted_list))]
def sort_with_indices(data, key=None, reverse=False):
    """
    Sort data and return both the data and the original indices.
    One useful application is to sort by length, which can be done with key=len
    Returns the data as a sorted list, then the indices of the original list.

    For non-empty input both return values are tuples; empty input yields two empty lists.
    """
    if not data:
        return [], []
    if key:
        sort_key = lambda pair: key(pair[1])
    else:
        sort_key = lambda pair: pair[1]
    indexed = sorted(enumerate(data), key=sort_key, reverse=reverse)
    original_indices, values = zip(*indexed)
    return values, original_indices
def split_into_batches(data, batch_size):
    """
    Returns a list of intervals so that each interval is either <= batch_size or one element long.
    Long elements are not dropped from the intervals.
    data is a list of lists
    batch_size is how long to make each batch
    return value is a list of pairs, start_idx end_idx

    The intervals always cover every index of ``data`` exactly once, including
    zero-length elements.  Previously an interval was only emitted when its total
    length was positive, so empty elements at the end of the data, or right after
    an oversized element, were silently dropped.
    """
    intervals = []
    interval_start = 0
    interval_size = 0
    for idx, line in enumerate(data):
        line_len = len(line)
        if line_len > batch_size:
            # guess we'll just hope the model can handle a batch of this size after all
            if idx > interval_start:
                # flush whatever was accumulated before the long element
                intervals.append((interval_start, idx))
            intervals.append((idx, idx + 1))
            interval_start = idx + 1
            interval_size = 0
        elif line_len + interval_size > batch_size:
            # this line puts us over batch_size
            intervals.append((interval_start, idx))
            interval_start = idx
            interval_size = line_len
        else:
            interval_size += line_len
    if interval_start < len(data):
        # there's some leftover (possibly only zero-length elements)
        intervals.append((interval_start, len(data)))
    return intervals
def tensor_unsort(sorted_tensor, oidx):
    """
    Unsort a sorted tensor on its 0-th dimension, based on the original idx.

    Row i of ``sorted_tensor`` originally lived at position ``oidx[i]``.
    """
    assert sorted_tensor.size(0) == len(oidx), "Number of list elements must match with original indices."
    # positions of sorted rows, ordered by their original index (stable sort)
    backidx = sorted(range(len(oidx)), key=lambda pos: oidx[pos])
    return sorted_tensor[backidx]
def set_random_seed(seed):
    """
    Set a random seed on all of the things which might need it.
    torch, np, python random, and torch.cuda

    If ``seed`` is None a fresh one is drawn from python's ``random``.
    Returns the seed that was actually used.
    """
    if seed is None:
        seed = random.randint(0, 1000000000)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # some of these calls are probably redundant
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    return seed
def find_missing_tags(known_tags, test_tags):
    """
    Return the sorted tags from ``test_tags`` that do not appear in ``known_tags``.

    Either argument may be a flat collection of tags or a list of lists of tags.
    A list of lists is flattened first, and the flattened test tags are de-duplicated.
    Empty lists are handled; previously ``x[0]`` raised IndexError on them.
    """
    if isinstance(known_tags, list) and known_tags and isinstance(known_tags[0], list):
        known_tags = set(x for y in known_tags for x in y)
    if isinstance(test_tags, list) and test_tags and isinstance(test_tags[0], list):
        test_tags = sorted(set(x for y in test_tags for x in y))
    missing_tags = sorted(x for x in test_tags if x not in known_tags)
    return missing_tags
def warn_missing_tags(known_tags, test_tags, test_set_name):
    """
    Print a warning if any tags present in the second list are not in the first list.
    Can also handle a list of lists.

    Returns True if a warning was logged, False otherwise.
    """
    missing_tags = find_missing_tags(known_tags, test_tags)
    if not missing_tags:
        return False
    logger.warning("Found tags in {} missing from the expected tag set: {}".format(test_set_name, missing_tags))
    return True
def checkpoint_name(save_dir, save_name, checkpoint_name):
    """
    Will return a recommended checkpoint name for the given dir, save_name, optional checkpoint_name
    For example, can pass in args['save_dir'], args['save_name'], args['checkpoint_save_name']

    An explicit checkpoint_name is placed under save_dir unless it is already there.
    Otherwise the name is derived from save_name: ``x.pt`` -> ``x_checkpoint.pt``,
    anything else gets ``_checkpoint`` appended.
    """
    if checkpoint_name:
        if os.path.dirname(checkpoint_name) == save_dir:
            return checkpoint_name
        return os.path.join(save_dir, checkpoint_name)

    if os.path.dirname(save_name) != save_dir:
        save_name = os.path.join(save_dir, save_name)
    if not save_name.endswith(".pt"):
        return save_name + "_checkpoint"
    return save_name[:-3] + "_checkpoint.pt"
def default_device():
    """
    Pick a default device based on what's available on this system: 'cuda' if available, else 'cpu'
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'
def add_device_args(parser):
    """
    Add args which specify cpu, cuda, or arbitrary device

    --device takes any torch device string; --cuda and --cpu are shortcuts that
    write the same ``device`` destination.
    """
    parser.add_argument('--device', type=str, default=default_device(), help='Which device to run on - use a torch device string name')
    shortcuts = (
        ('--cuda', 'cuda', 'Run on CUDA'),
        ('--cpu', 'cpu', 'Ignore CUDA and run on CPU'),
    )
    for flag, device, description in shortcuts:
        parser.add_argument(flag, dest='device', action='store_const', const=device, help=description)
def load_elmo(elmo_model):
    """
    Load an ELMo embedder from ``elmo_model`` using the optional elmoformanylangs package.
    """
    # This import is here so that Elmo integration can be treated
    # as an optional feature
    import elmoformanylangs
    logger.info("Loading elmo: %s", elmo_model)
    return elmoformanylangs.Embedder(elmo_model)
def log_training_args(args, args_logger, name="training"):
    """
    For record keeping purposes, log the arguments when training

    ``args`` may be a dict or an argparse.Namespace; keys are logged in sorted order.
    """
    arg_dict = vars(args) if isinstance(args, argparse.Namespace) else args
    summary = '\n'.join('%s: %s' % (key, arg_dict[key]) for key in sorted(arg_dict.keys()))
    args_logger.info('ARGS USED AT %s TIME:\n%s\n', name.upper(), summary)
def embedding_name(args):
    """
    Return the generic name of the biggest embedding used by a model.
    Used by POS and depparse, for example.

    Precedence, from weakest to strongest: "nocharlm" / "nopretrain", then "charlm",
    then a transformer nickname (or "transformer" if the model has no nickname).

    TODO: Probably will make the transformer names a bit more informative,
    such as electra, roberta, etc.  Maybe even phobert for VI, for example
    """
    if args['wordvec_pretrain_file'] is None and args['wordvec_file'] is None:
        embedding = "nopretrain"
    else:
        embedding = "nocharlm"

    uses_charlm = args.get('charlm', True) and (args['charlm_forward_file'] or args['charlm_backward_file'])
    if uses_charlm:
        embedding = "charlm"

    bert_model = args['bert_model']
    if bert_model:
        embedding = TRANSFORMER_NICKNAMES.get(bert_model, "transformer")
    return embedding
def standard_model_file_name(args, model_type, **kwargs):
    """
    Returns a model file name based on some common args found in the various models.

    The expectation is that the args will have something like
      parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_parser.pt", help="File name to save the model")
    Then the model shorthand, embedding type, and other args will be
    turned into arguments in a format string.

    Available fields: batch_size, bert_finetuning, embedding, finetune, peft,
    seed, shorthand, transformer_lr, plus anything passed in kwargs.
    Runs of underscores are collapsed.  The result is placed under
    args['save_dir'] unless it already points there, or it names an existing
    file which does not exist under save_dir.
    """
    embedding = embedding_name(args)

    is_finetuned = args.get("bert_finetune", False)
    if is_finetuned:
        finetune = "finetuned"
        transformer_lr = "{}".format(args["bert_learning_rate"]) if "bert_learning_rate" in args else ""
        with_peft = args.get("use_peft", False)
        peft = "peft" if with_peft else "nopeft"
        bert_finetuning = "peft" if with_peft else "ft"
    else:
        finetune = ""
        transformer_lr = ""
        peft = "nopeft"
        bert_finetuning = ""

    raw_seed = args.get('seed', None)
    seed = "" if raw_seed is None else str(raw_seed)

    format_args = dict(
        batch_size=args['batch_size'],
        bert_finetuning=bert_finetuning,
        embedding=embedding,
        finetune=finetune,
        peft=peft,
        seed=seed,
        shorthand=args['shorthand'],
        transformer_lr=transformer_lr,
    )
    format_args.update(**kwargs)

    model_file = re.sub("_+", "_", args['save_name'].format(**format_args))
    model_dir = os.path.split(model_file)[0]
    save_dir = args['save_dir']
    if not os.path.exists(os.path.join(save_dir, model_file)) and os.path.exists(model_file):
        return model_file
    if model_dir.startswith(save_dir):
        return model_file
    return os.path.join(save_dir, model_file)
def escape_misc_space(space):
    r"""
    Escape a whitespace string for a SpacesBefore / SpacesAfter value in the MISC column.

    Each character is mapped independently:
      ' ' -> \s, tab -> \t, CR -> \r, LF -> \n, '|' -> \p, '\' -> \\,
      no-break space (U+00A0) -> \u00A0
    Any other character is copied unchanged.

    The no-break space is written as an explicit escape.  A literal that looks
    identical to an ordinary space either silently became an unreachable branch
    or was indistinguishable when reading the code.
    """
    escapes = {
        ' ': '\\s',
        '\t': '\\t',
        '\r': '\\r',
        '\n': '\\n',
        '|': '\\p',
        '\\': '\\\\',
        '\u00A0': '\\u00A0',
    }
    return "".join(escapes.get(char, char) for char in space)
def unescape_misc_space(misc_space):
    r"""
    Inverse of escape_misc_space: decode a SpacesBefore / SpacesAfter MISC value.

    Recognized escapes: \s, \t, \r, \n, \p ('|'), \\ ('\') and \u00A0 (no-break space).
    Unrecognized characters, including a lone backslash, are copied through.

    The decoded no-break space is written as an explicit '\u00A0'.  A literal that
    looks identical to an ordinary space could decode it to a plain space,
    which would break the escape/unescape round trip.
    """
    # two-character escapes are tried before \u00A0; none of them is a prefix of it
    escapes = (
        ('\\s', ' '),
        ('\\t', '\t'),
        ('\\r', '\r'),
        ('\\n', '\n'),
        ('\\p', '|'),
        ('\\\\', '\\'),
        ('\\u00A0', '\u00A0'),
    )
    spaces = []
    pos = 0
    while pos < len(misc_space):
        for escaped, char in escapes:
            if misc_space.startswith(escaped, pos):
                spaces.append(char)
                pos += len(escaped)
                break
        else:
            spaces.append(misc_space[pos])
            pos += 1
    return "".join(spaces)
def space_before_to_misc(space):
    """
    Convert whitespace to SpacesBefore specifically for the start of a document.

    In general, UD datasets do not have both SpacesAfter on a token and SpacesBefore on the next token.
    The space(s) are only marked on one of the tokens.
    Only at the very beginning of a document is it necessary to mark what spaces occurred before the actual text,
    and the default assumption is that there is no space if there is no SpacesBefore annotation.

    Returns "" for empty / missing whitespace.
    """
    if space:
        return "SpacesBefore=%s" % escape_misc_space(space)
    return ""
def space_after_to_misc(space):
    """
    Convert whitespace back to the escaped format - either SpaceAfter=No or SpacesAfter=...

    A single ordinary space is the UD default and produces no annotation at all.
    """
    if space == " ":
        return ""
    if not space:
        return "SpaceAfter=No"
    return "SpacesAfter=%s" % escape_misc_space(space)
def misc_to_space_before(misc):
    """
    Find any SpacesBefore annotation in the MISC column and turn it into a space value

    The key is matched case-insensitively; the first match wins.  Returns "" when absent.
    """
    if not misc:
        return ""
    values = (piece.split("=", maxsplit=1)[1]
              for piece in misc.split("|")
              if piece.lower().startswith("spacesbefore="))
    first = next(values, None)
    return "" if first is None else unescape_misc_space(first)
def misc_to_space_after(misc):
    """
    Convert either SpaceAfter=No or the SpacesAfter annotation into the actual whitespace.

    see https://universaldependencies.org/misc.html#spacesafter
    We compensate for some treebanks writing an escaped value under SpaceAfter=
    instead of SpacesAfter=.  On the way back, though, those annotations will be
    turned into SpacesAfter.

    With no relevant annotation, a single space is assumed.
    """
    if not misc:
        return " "
    pieces = misc.split("|")
    if any(piece.lower() == "spaceafter=no" for piece in pieces):
        return ""
    if "SpaceAfter=Yes" in pieces:
        # as of UD 2.11, the Cantonese treebank had this as a misc feature
        return " "
    if "SpaceAfter=No~" in pieces:
        # as of UD 2.11, a weird typo in the Russian Taiga dataset
        return ""
    for piece in pieces:
        if piece.startswith(("SpaceAfter=", "SpacesAfter=")):
            return unescape_misc_space(piece.partition("=")[2])
    return " "
def log_norms(model):
    """
    Log the norm and element count of every trainable parameter of ``model``.

    Parameters with ``requires_grad`` False are skipped.  Names and norms are
    padded into aligned columns.  If the model has no trainable parameters, a
    short note is logged instead of crashing on ``max()`` of an empty sequence.
    """
    pieces = [(name, "%.6g" % torch.norm(param).item(), "%d" % param.numel())
              for name, param in model.named_parameters()
              if param.requires_grad]
    lines = ["NORMS FOR MODEL PARAMETERS"]
    if not pieces:
        lines.append("  (no trainable parameters)")
    else:
        name_len = max(len(x[0]) for x in pieces)
        norm_len = max(len(x[1]) for x in pieces)
        line_format = " %-" + str(name_len) + "s %" + str(norm_len) + "s %s"
        lines.extend(line_format % piece for piece in pieces)
    logger.info("\n".join(lines))
def attach_bert_model(model, bert_model, bert_tokenizer, use_peft, force_bert_saved):
    """
    Attach a transformer and its tokenizer to ``model``.

    - use_peft: registered via add_unsaved_module and put in train mode
      (peft weights are saved through a peft-specific pathway)
    - force_bert_saved: assigned as a regular attribute so it is saved with the model
    - otherwise, a non-None bert_model is registered unsaved and frozen
    - with no bert_model, model.bert_model is set to None
    The tokenizer is always registered via add_unsaved_module.
    """
    if use_peft:
        model.add_unsaved_module('bert_model', bert_model)
        model.bert_model.train()
    elif force_bert_saved:
        model.bert_model = bert_model
    elif bert_model is None:
        model.bert_model = None
    else:
        model.add_unsaved_module('bert_model', bert_model)
        for parameter in bert_model.parameters():
            parameter.requires_grad = False
    model.add_unsaved_module('bert_tokenizer', bert_tokenizer)
def build_save_each_filename(base_filename):
    """
    If the given name has no format placeholder, insert _%04d before the extension.

    This way, there's something to count how many models have been saved, e.g.
    ``models.pt`` -> ``models_%04d.pt`` (later formatted as models_0001.pt, ...).
    A name that already accepts ``% 1`` (such as ``models_%d.pt``) is returned as is.

    Fixes a NameError: the fallback branch referenced an undefined
    ``model_save_each_file`` instead of ``base_filename``.
    """
    try:
        base_filename % 1
    except TypeError:
        # no placeholder: "not all arguments converted during string formatting"
        root, ext = os.path.splitext(base_filename)
        base_filename = root + "_%04d" + ext
    return base_filename
# the constituency parser went through a large suite of experiments to
# optimize which nonlinearity to use
#
# this is on a VI dataset, VLSP_22, using 1/10th of the data as a dev set
# (no released test set at the time of the experiment)
# using the original non-Bert tagger, with a single run per nonlinearity
# instead of an average over 5, given the number of experiments and how long each takes
#
# Gelu had the highest score, which tracks with other experiments run.
# Note that publicly released models have typically used Relu
# on account of the runtime speed improvement
#
# Anyway, a larger experiment of 5x models on gelu or relu, using the
# Roberta POS tagger and a corpus of silver trees, resulted in 0.8270
# for relu and 0.8248 for gelu. So it is not even clear that
# switching to gelu would be an accuracy improvement.
#
# Gelu: 82.32
# Relu: 82.14
# Mish: 81.95
# Relu6: 81.91
# Silu: 81.90
# ELU: 81.73
# Hardswish: 81.67
# Softsign: 81.63
# Hardtanh: 81.44
# Celu: 81.43
# Selu: 81.17
# TODO: need to redo the prelu experiment with
# possibly different numbers of parameters
# and proper weight decay
# Prelu: 80.95 (terminated early)
# Softplus: 80.94
# Logsigmoid: 80.91
# Hardsigmoid: 79.03
# RReLU: 77.00
# Hardshrink: failed
# Softshrink: failed
# Map from option name to the torch.nn layer class that implements it.
# Insertion order is kept stable, since it determines the order of keys seen by callers.
NONLINEARITY = dict(
    none=nn.Identity,
    celu=nn.CELU,
    elu=nn.ELU,
    gelu=nn.GELU,
    glu=nn.GLU,
    hardsigmoid=nn.Hardsigmoid,
    hardshrink=nn.Hardshrink,
    hardswish=nn.Hardswish,
    hardtanh=nn.Hardtanh,
    leaky_relu=nn.LeakyReLU,
    logsigmoid=nn.LogSigmoid,
    mish=nn.Mish,
    prelu=nn.PReLU,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    rrelu=nn.RReLU,
    selu=nn.SELU,
    silu=nn.SiLU,
    softplus=nn.Softplus,
    softshrink=nn.Softshrink,
    softsign=nn.Softsign,
    tanhshrink=nn.Tanhshrink,
    tanh=nn.Tanh,
)
def build_nonlinearity(nonlinearity):
    """
    Instantiate the layer named by ``nonlinearity`` (a key of NONLINEARITY).

    None yields an identity layer; an unknown name raises ValueError.
    """
    if nonlinearity is None:
        return nn.Identity()
    layer_cls = NONLINEARITY.get(nonlinearity)
    if layer_cls is None:
        raise ValueError('Chosen value of nonlinearity, "%s", not handled' % nonlinearity)
    return layer_cls()
# Word cutoff chosen by update_word_cutoff when no cutoff was given and the
# pretrain has at least 5000 entries (smaller pretrains get a cutoff of 0)
DEFAULT_WORD_CUTOFF = 7
def update_word_cutoff(pt, word_cutoff):
    """
    Pick a word cutoff based on the size of the pretrain, unless one was set explicitly.

    An explicit word_cutoff is returned unchanged.  Otherwise the cutoff is 0 when
    there is no pretrain or the pretrain has fewer than 5000 entries, and
    DEFAULT_WORD_CUTOFF for larger pretrains.

    Using a lower word cutoff for the smaller pretrains helps quite a bit on the Abkhaz tagger,
    where all we have is a very small PT.

    no WV:
      ab_abnc dev
      UPOS XPOS UFeats AllTags
      89.06 62.53 75.21 61.53
      ab_abnc test
      UPOS XPOS UFeats AllTags
      88.96 61.37 74.85 60.29
    WV, cutoff 7
      ab_abnc dev
      UPOS XPOS UFeats AllTags
      89.15 62.76 75.43 61.62
      ab_abnc test
      UPOS XPOS UFeats AllTags
      89.64 61.56 75.31 60.88
    WV, cutoff 0
      ab_abnc
      UPOS XPOS UFeats AllTags
      90.02 64.81 76.75 64.13
      ab_abnc
      UPOS XPOS UFeats AllTags
      90.19 63.95 76.62 63.59

    The results are less compelling for depparse, though:

    no WV
      ab_abnc dev
      UAS LAS CLAS MLAS BLEX
      78.85 65.27 57.31 56.27 57.31
      ab_abnc test
      UAS LAS CLAS MLAS BLEX
      78.11 64.22 57.45 56.90 57.45
    WV with cutoff 7
      ab_abnc dev
      UAS LAS CLAS MLAS BLEX
      79.49 65.41 57.15 56.38 57.15
      ab_abnc test
      UAS LAS CLAS MLAS BLEX
      77.30 64.41 57.13 56.65 57.13
    WV with cutoff 0
      ab_abnc dev
      UAS LAS CLAS MLAS BLEX
      80.04 65.68 56.81 56.04 56.81
      ab_abnc test
      UAS LAS CLAS MLAS BLEX
      77.66 64.86 57.28 57.00 57.28
    """
    if word_cutoff is not None:
        return word_cutoff
    if pt is None:
        logger.info('Using 0 as the word cutoff (no pretrain available)')
        return 0
    chosen = 0 if len(pt) < 5000 else DEFAULT_WORD_CUTOFF
    logger.info('Using %d as the word cutoff based on the size of the pretrain (%d)', chosen, len(pt))
    return chosen
# A word made of two or more question/exclamation marks that starts with a question mark.
# Besides ASCII, the classes cover fullwidth (U+FF1F, U+FF01), vertical presentation
# (U+FE16, U+FE15), small (U+FE56, U+FE57) and double (U+2047, U+203C) forms.
# The non-ASCII characters are written as escapes so they cannot silently be
# normalized to ASCII duplicates by an editor or extraction tool.
QUESTION_RE = re.compile("^[?\uFF1F\uFE16\uFE56\u2047][?\uFF1F\uFE16\uFE56\u2047!\uFF01\uFE15\uFE57\u203C]+$")
# Same, but for words that start with an exclamation mark
EXCLAM_RE = re.compile("^[!\uFF01\uFE15\uFE57\u203C][?\uFF1F\uFE16\uFE56\u2047!\uFF01\uFE15\uFE57\u203C]+$")
def simplify_punct(data):
    """
    For the data formats used in the POS and depparse, replace long punct words with simpler forms

    replace ?[?!]+ -> ?
    replace ![?!]+ -> !
    also, include other non-ascii ?!

    ``data`` is a list of sentences, each a list of token lists whose element 0 is
    the word text.  Tokens are modified in place and ``data`` is returned.
    """
    for sentence in data:
        for token in sentence:
            token[0] = EXCLAM_RE.sub("!", QUESTION_RE.sub("?", token[0]))
    return data
================================================
FILE: stanza/models/common/vocab.py
================================================
from copy import copy
from collections import Counter, OrderedDict
from collections.abc import Iterable
import os
import pickle
# Reserved units placed at the start of every vocab, in id order.
# They must be distinct strings: with four identical empty strings, every
# _unit2id built from VOCAB_PREFIX maps the shared string to the last index
# (ROOT_ID), so unknown-unit lookups returned ROOT_ID instead of UNK_ID.
PAD = '<PAD>'
PAD_ID = 0
UNK = '<UNK>'
UNK_ID = 1
EMPTY = '<EMPTY>'
EMPTY_ID = 2
ROOT = '<ROOT>'
ROOT_ID = 3
VOCAB_PREFIX = [PAD, UNK, EMPTY, ROOT]
VOCAB_PREFIX_SIZE = len(VOCAB_PREFIX)
class BaseVocab:
    """
    Shared vocabulary behaviour: unit <-> id lookups, normalization and state-dict (de)serialization.

    Subclasses must implement build_vocab(), which fills in ``_id2unit`` and ``_unit2id``
    from ``self.data``.  build_vocab() only runs when data is given to the constructor.
    """
    def __init__(self, data=None, lang="", idx=0, cutoff=0, lower=False):
        self.data = data
        self.lang = lang
        self.idx = idx
        self.cutoff = cutoff
        self.lower = lower
        if data is not None:
            self.build_vocab()
        # attributes written out by state_dict(), in this order
        self.state_attrs = ['lang', 'idx', 'cutoff', 'lower', '_unit2id', '_id2unit']

    def build_vocab(self):
        raise NotImplementedError("This BaseVocab does not have build_vocab implemented.  This method should create _id2unit and _unit2id")

    def state_dict(self):
        """ Return an OrderedDict of every attribute in state_attrs that this vocab actually has. """
        return OrderedDict((attr, getattr(self, attr))
                           for attr in self.state_attrs
                           if hasattr(self, attr))

    @classmethod
    def load_state_dict(cls, state_dict):
        """ Build an empty instance of cls and copy every entry of state_dict onto it as an attribute. """
        vocab = cls()
        for name, value in state_dict.items():
            setattr(vocab, name, value)
        return vocab

    def normalize_unit(self, unit):
        # be sure to look in subclasses for other normalization being done
        # especially PretrainWordVocab
        if unit is not None and self.lower:
            return unit.lower()
        return unit

    def unit2id(self, unit):
        key = self.normalize_unit(unit)
        if key not in self._unit2id:
            # unknown units fall back to the UNK entry
            key = UNK
        return self._unit2id[key]

    def id2unit(self, id):
        return self._id2unit[id]

    def map(self, units):
        return [self.unit2id(unit) for unit in units]

    def unmap(self, ids):
        return [self.id2unit(unit_id) for unit_id in ids]

    def __str__(self):
        suffix = "(%s)" % self.lang if self.lang else ""
        return "<%s: %s>" % (str(type(self)) + suffix, self._id2unit)

    def __len__(self):
        return len(self._id2unit)

    def __getitem__(self, key):
        # strings are looked up as units; ints (and lists, for composite vocabs) as ids
        if isinstance(key, str):
            return self.unit2id(key)
        if isinstance(key, (int, list)):
            return self.id2unit(key)
        raise TypeError("Vocab key must be one of str, list, or int")

    def __contains__(self, key):
        return self.normalize_unit(key) in self._unit2id

    @property
    def size(self):
        return len(self)
class DeltaVocab(BaseVocab):
    """
    A vocab that starts from an existing vocab and appends any units found in ``data`` that it lacks.

    Currently meant only for characters, such as built by MWT or Lemma.
    Expected data format is either a list of strings, or a list of list of strings.
    New units are appended in sorted order after the original vocab's units.
    When nothing is new, the original vocab's lookup tables are reused directly.
    """
    def __init__(self, data, orig_vocab):
        self.orig_vocab = orig_vocab
        super().__init__(data=data, lang=orig_vocab.lang, idx=orig_vocab.idx, cutoff=orig_vocab.cutoff, lower=orig_vocab.lower)

    def build_vocab(self):
        if all(isinstance(word, str) for word in self.data):
            words = self.data
        else:
            words = [word for sentence in self.data for word in sentence]
        allchars = "".join(words)

        known = self.orig_vocab._unit2id
        new_chars = sorted({c for c in allchars if c not in known})
        if not new_chars:
            self._id2unit = self.orig_vocab._id2unit
            self._unit2id = known
            return

        self._id2unit = self.orig_vocab._id2unit + new_chars
        self._unit2id = dict(known)
        for c in new_chars:
            self._unit2id[c] = len(self._unit2id)
class CompositeVocab(BaseVocab):
    ''' Vocabulary class that handles parsing and printing composite values such as
    compositional XPOS and universal morphological features (UFeats).

    Two key options are `keyed` and `sep`. `sep` specifies the separator used between
    different parts of the composite values, which is `|` for UFeats, for example.
    If `keyed` is `True`, then the incoming value is treated similarly to UFeats, where
    each part is a key/value pair separated by an equal sign (`=`). There is no inherent
    order to the keys, and we sort them alphabetically for serialization and deserialization.
    Whenever a part is absent, its internal value is the special EMPTY symbol (EMPTY_ID) that will
    be treated accordingly when generating the output. If `keyed` is `False`, then the parts
    are treated as positioned values, and EMPTY is used to pad parts at the end when the
    incoming value is not long enough.'''
    def __init__(self, data=None, lang="", idx=0, sep="", keyed=False):
        self.sep = sep
        self.keyed = keyed
        super().__init__(data, lang, idx=idx)
        self.state_attrs += ['sep', 'keyed']

    def unit2parts(self, unit):
        """
        Split a composite unit into its parts.

        Without a separator every character is a part.  Keyed vocabs return a dict
        of key -> value ('_' gives an empty dict); positional vocabs return a list
        ('_' gives an empty list).
        """
        if not self.sep:
            parts = [x for x in unit]
        else:
            parts = unit.split(self.sep)
        if self.keyed:
            if len(parts) == 1 and parts[0] == '_':
                return dict()
            parts = [x.split('=') for x in parts]
            if any(len(x) != 2 for x in parts):
                raise ValueError('Received "%s" for a dictionary which is supposed to be keyed, eg the entries should all be of the form key=value and separated by %s' % (unit, self.sep))

            # Just treat multi-valued properties values as one possible value
            parts = dict(parts)
        elif unit == '_':
            parts = []
        return parts

    def unit2id(self, unit):
        """ Return one id per known key (keyed) or position (positional); absent parts map to EMPTY_ID. """
        parts = self.unit2parts(unit)
        if self.keyed:
            # treat multi-valued properties as singletons
            return [self._unit2id[k].get(parts[k], UNK_ID) if k in parts else EMPTY_ID for k in self._unit2id]
        else:
            return [self._unit2id[i].get(parts[i], UNK_ID) if i < len(parts) else EMPTY_ID for i in range(len(self._unit2id))]

    def id2unit(self, id):
        """ Rebuild the composite string from a sequence of per-key ids, skipping EMPTY parts; '_' if nothing remains. """
        # special case: allow single ids for vocabs with length 1
        if len(self._id2unit) == 1 and not isinstance(id, Iterable):
            id = (id,)
        items = []
        for v, k in zip(id, self._id2unit.keys()):
            if v == EMPTY_ID: continue
            if self.keyed:
                items.append("{}={}".format(k, self._id2unit[k][v]))
            else:
                items.append(self._id2unit[k][v])
        if self.sep is not None:
            res = self.sep.join(items)
            if res == "":
                res = "_"
            return res
        else:
            return items

    def build_vocab(self):
        """
        Collect, for every key (keyed) or position (positional), the list of values seen in the data.

        Each per-key list starts with VOCAB_PREFIX.  Keys are sorted, and a single
        placeholder key is created when the data contains no parts at all.  Data
        with zero words no longer crashes: an unused ``max()`` over the (empty)
        part lengths has been removed.
        """
        allunits = [w[self.idx] for sent in self.data for w in sent]
        if self.keyed:
            self._id2unit = dict()
            for u in allunits:
                parts = self.unit2parts(u)
                for key in parts:
                    if key not in self._id2unit:
                        self._id2unit[key] = copy(VOCAB_PREFIX)

                    # treat multi-valued properties as singletons
                    if parts[key] not in self._id2unit[key]:
                        self._id2unit[key].append(parts[key])

            # special handle for the case where upos/xpos/ufeats are always empty
            if len(self._id2unit) == 0:
                self._id2unit['_'] = copy(VOCAB_PREFIX) # use an arbitrary key

        else:
            self._id2unit = dict()

            allparts = [self.unit2parts(u) for u in allunits]
            for parts in allparts:
                for i, p in enumerate(parts):
                    if i not in self._id2unit:
                        self._id2unit[i] = copy(VOCAB_PREFIX)
                    if p not in self._id2unit[i]:
                        self._id2unit[i].append(p)

            # special handle for the case where upos/xpos/ufeats are always empty
            if len(self._id2unit) == 0:
                self._id2unit[0] = copy(VOCAB_PREFIX) # use an arbitrary key

        self._id2unit = OrderedDict([(k, self._id2unit[k]) for k in sorted(self._id2unit.keys())])
        self._unit2id = {k: {w:i for i, w in enumerate(self._id2unit[k])} for k in self._id2unit}

    def lens(self):
        """ Number of ids (including the prefix) for each key/position. """
        return [len(self._unit2id[k]) for k in self._unit2id]

    def items(self, idx):
        return self._id2unit[idx]

    def __str__(self):
        pieces = ["[" + ",".join(x) + "]" for _, x in self._id2unit.items()]
        rep = "<{}:\n {}>".format(type(self), "\n ".join(pieces))
        return rep
class BaseMultiVocab:
    """
    An ordered, name-keyed container of BaseVocab instances that serializes them together.

    state_dict() collects each member's own state dict; subclasses must implement
    load_state_dict() to say how to rebuild the members from such a dict.
    """
    def __init__(self, vocab_dict=None):
        self._vocabs = OrderedDict()
        if vocab_dict is not None:
            # check all values provided must be a subclass of the Vocab base class
            assert all([isinstance(v, BaseVocab) for v in vocab_dict.values()])
            self._vocabs.update(vocab_dict)

    def __setitem__(self, key, item):
        self._vocabs[key] = item

    def __getitem__(self, key):
        return self._vocabs[key]

    def __str__(self):
        names = ", ".join(self._vocabs.keys())
        return "<{}: [{}]>".format(type(self), names)

    def __contains__(self, key):
        return key in self._vocabs

    def keys(self):
        return self._vocabs.keys()

    def state_dict(self):
        """ Build a state dict by iteratively calling state_dict() of all vocabs. """
        return OrderedDict((name, vocab.state_dict()) for name, vocab in self._vocabs.items())

    @classmethod
    def load_state_dict(cls, state_dict):
        """ Construct a MultiVocab by reading from a state dict."""
        raise NotImplementedError
class CharVocab(BaseVocab):
    """
    Character vocabulary ordered by descending frequency (ties broken by descending character).

    Two data shapes are accepted: DataLoader-style sentences of token tuples, where
    characters come from field ``idx`` and characters rarer than ``cutoff`` are
    dropped; or char-LM style sentences that are themselves iterables of characters
    (no cutoff applied).
    """
    def build_vocab(self):
        if isinstance(self.data[0][0], (list, tuple)): # general data from DataLoader
            counts = Counter(c for sent in self.data for w in sent for c in w[self.idx])
            rare = [c for c, n in counts.items() if n < self.cutoff]
            for c in rare:
                del counts[c]
        else: # special data from Char LM
            counts = Counter(c for sent in self.data for c in sent)
        ordered = sorted(counts, key=lambda c: (counts[c], c), reverse=True)
        self._id2unit = VOCAB_PREFIX + ordered
        self._unit2id = {unit: i for i, unit in enumerate(self._id2unit)}
================================================
FILE: stanza/models/constituency/__init__.py
================================================
================================================
FILE: stanza/models/constituency/base_model.py
================================================
"""
The BaseModel is passed to the transitions so that the transitions
can operate on a parsing state without knowing the exact
representation used in the model.
For example, a SimpleModel simply looks at the top of the various stacks in the state.
A model with LSTM representations for the different transitions may
attach the hidden and output states of the LSTM to the word /
constituent / transition stacks.
Reminder: the parsing state is a list of words to parse, the
transitions used to build a (possibly incomplete) parse, and the
constituent(s) built so far by those transitions. Each of these
components are represented using stacks to improve the efficiency
of operations such as "combine the most recent 4 constituents"
or "turn the next input word into a constituent"
"""
from abc import ABC, abstractmethod
from collections import defaultdict
import logging
import torch
from stanza.models.common import utils
from stanza.models.constituency import transition_sequence
from stanza.models.constituency.parse_transitions import TransitionScheme, CloseConstituent
from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency.state import State
from stanza.models.constituency.tree_stack import TreeStack
from stanza.server.parser_eval import ParseResult, ScoredTree
# default unary limit (presumably the default value for a model's unary_limit,
# i.e. the limit on consecutive unary transitions).
# some treebanks may have longer chains (CTB, for example)
UNARY_LIMIT = 4
# NOTE(review): base_model logs under the trainer's logger name
logger = logging.getLogger('stanza.constituency.trainer')
class BaseModel(ABC):
"""
This base class defines abstract methods for manipulating a State.
Applying transitions may change important metadata about a State
such as the vectors associated with LSTM hidden states, for example.
The constructor forwards all unused arguments to other classes in the
constructor sequence, so put this before other classes such as nn.Module
"""
def __init__(self, transition_scheme, unary_limit, reverse_sentence, root_labels, *args, **kwargs):
super().__init__(*args, **kwargs) # forwards all unused arguments
self._transition_scheme = transition_scheme
self._unary_limit = unary_limit
self._reverse_sentence = reverse_sentence
self._root_labels = sorted(list(root_labels))
self._is_top_down = (self._transition_scheme is TransitionScheme.TOP_DOWN or
self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY or
self._transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND)
@abstractmethod
def initial_word_queues(self, tagged_word_lists):
"""
For each list of tagged words, builds a TreeStack of word nodes
The word lists should be backwards so that the first word is the last word put on the stack (LIFO)
"""
@abstractmethod
def initial_transitions(self):
"""
Builds an initial transition stack with whatever values need to go into first position
"""
@abstractmethod
def initial_constituents(self):
"""
Builds an initial constituent stack with whatever values need to go into first position
"""
@abstractmethod
def get_word(self, word_node):
"""
Get the word corresponding to this position in the word queue
"""
@abstractmethod
def transform_word_to_constituent(self, state):
"""
Transform the top node of word_queue to something that can push on the constituent stack
"""
@abstractmethod
def dummy_constituent(self, dummy):
"""
When using a dummy node as a sentinel, transform it to something usable by this model
"""
@abstractmethod
def build_constituents(self, labels, children_lists):
"""
Build multiple constituents at once. This gives the opportunity for batching operations
"""
@abstractmethod
def push_constituents(self, constituent_stacks, constituents):
"""
Add a multiple constituents to multiple constituent_stacks
Useful to factor this out in case batching will help
"""
@abstractmethod
def get_top_constituent(self, constituents):
"""
Get the first constituent from the constituent stack
For example, a model might want to remove embeddings and LSTM state vectors
"""
@abstractmethod
def push_transitions(self, transition_stacks, transitions):
"""
Add a multiple transitions to multiple transition_stacks
Useful to factor this out in case batching will help
"""
@abstractmethod
def get_top_transition(self, transitions):
"""
Get the first transition from the transition stack
For example, a model might want to remove transition embeddings before returning the transition
"""
@property
def root_labels(self):
"""
Return ROOT labels for this model. Probably ROOT, TOP, or both
(Danish uses 's', though)
"""
return self._root_labels
def unary_limit(self):
"""
Limit on the number of consecutive unary transitions
"""
return self._unary_limit
def transition_scheme(self):
"""
Transition scheme used - see parse_transitions
"""
return self._transition_scheme
def has_unary_transitions(self):
"""
Whether or not this model uses unary transitions, based on transition_scheme
"""
return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY
@property
def is_top_down(self):
"""
Whether or not this model is TOP_DOWN
"""
return self._is_top_down
@property
def reverse_sentence(self):
"""
Whether or not this model is built to parse backwards
"""
return self._reverse_sentence
def predict(self, states, is_legal=True):
raise NotImplementedError("LSTMModel can predict, but SimpleModel cannot")
def weighted_choice(self, states):
raise NotImplementedError("LSTMModel can weighted_choice, but SimpleModel cannot")
def predict_gold(self, states, is_legal=True):
"""
For each State, return the next item in the gold_sequence
"""
transitions = [y.gold_sequence[y.num_transitions] for y in states]
if is_legal:
for trans, state in zip(transitions, states):
if not trans.is_legal(state, self):
raise RuntimeError("Transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(state.num_transitions, trans, state.gold_tree, state.gold_sequence))
return None, transitions, None
def initial_state_from_preterminals(self, preterminal_lists, gold_trees, gold_sequences):
    """
    Build one starting State per sentence

    preterminal_lists: a list of lists of preterminals, one list per sentence
    gold_trees: optional list of gold trees, parallel to preterminal_lists
    gold_sequences: optional list of gold transition sequences, parallel to preterminal_lists
    """
    word_queues = self.initial_word_queues(preterminal_lists)
    # the bottoms of the transition and constituent stacks are the same for every State
    transitions = self.initial_transitions()
    constituents = self.initial_constituents()

    states = []
    for word_queue in word_queues:
        states.append(State(sentence_length=len(word_queue) - 2,  # the queue starts and ends with a sentinel
                            num_opens=0,
                            word_queue=word_queue,
                            gold_tree=None,
                            gold_sequence=None,
                            transitions=transitions,
                            constituents=constituents,
                            word_position=0,
                            score=0.0,
                            broken=False))

    if gold_trees:
        states = [state._replace(gold_tree=tree) for state, tree in zip(states, gold_trees)]
    if gold_sequences:
        states = [state._replace(gold_sequence=sequence) for state, sequence in zip(states, gold_sequences)]
    return states
def initial_state_from_words(self, word_lists):
    """
    Build starting States from sentences of (word, tag) pairs, with no gold trees or sequences
    """
    preterminal_lists = []
    for words in word_lists:
        preterminal_lists.append([Tree(tag, Tree(word)) for word, tag in words])
    return self.initial_state_from_preterminals(preterminal_lists, gold_trees=None, gold_sequences=None)
def initial_state_from_gold_trees(self, trees, gold_sequences=None):
    """
    Build starting States for gold trees

    Each tree's preterminals are copied (tag over word) to make the
    input, and the trees themselves are attached as the gold trees.
    """
    def copy_preterminals(tree):
        return [Tree(pt.label, Tree(pt.children[0].label)) for pt in tree.yield_preterminals()]

    preterminal_lists = [copy_preterminals(tree) for tree in trees]
    return self.initial_state_from_preterminals(preterminal_lists, gold_trees=trees, gold_sequences=gold_sequences)
def build_batch_from_trees(self, batch_size, data_iterator):
    """
    Pull up to batch_size trees from data_iterator and turn them into new parsing states

    Reading stops early if the iterator is exhausted or yields None.
    An empty list is returned when no trees were read.
    """
    trees = []
    for _ in range(batch_size):
        tree = next(data_iterator, None)
        if tree is None:
            break
        trees.append(tree)
    if not trees:
        return trees
    return self.initial_state_from_gold_trees(trees)
def build_batch_from_trees_with_gold_sequence(self, batch_size, data_iterator):
    """
    Same as build_batch_from_trees, but also attach each tree's gold transition sequence

    The sequences are built with this model's transition_scheme() and
    reverse_sentence settings.
    """
    state_batch = self.build_batch_from_trees(batch_size, data_iterator)
    if not state_batch:
        return state_batch

    gold_trees = [state.gold_tree for state in state_batch]
    gold_sequences = transition_sequence.build_treebank(gold_trees, self.transition_scheme(), self.reverse_sentence)
    return [state._replace(gold_sequence=sequence) for state, sequence in zip(state_batch, gold_sequences)]
def build_batch_from_tagged_words(self, batch_size, data_iterator):
    """
    Pull up to batch_size tagged sentences from data_iterator and turn them into new parsing states

    Each sentence is expected to be a list of (word, tag).
    Reading stops early if the iterator is exhausted or yields None;
    an empty list is returned when nothing was read.
    """
    sentences = []
    for _ in range(batch_size):
        sentence = next(data_iterator, None)
        if sentence is None:
            break
        sentences.append(sentence)
    return self.initial_state_from_words(sentences) if sentences else sentences
def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """
    Repeat transitions to build a list of trees from the input batches.

    data_iterator: anything which returns the data for a parse task via next()
    build_batch_fn: turns that data into State objects.  It is called as
      build_batch_fn(batch_size, data_iterator) to refill the batch
      until the data is exhausted
    batch_size: number of states parsed in parallel
    transition_choice: which method of the model to use for choosing the next transition
      predict for predicting the transition based on the model
      predict_gold to just extract the gold transition from the sequence
    keep_state: keep the final State in each ParseResult
    keep_constituents: keep the Constituent built by each CloseConstituent
    keep_scores: if true, the score will be the sum of the values
      returned by the model for each transition

    The return is a list of ParseResult(gold_tree, [ScoredTree(predicted, score)], state, constituents)
    in the same order as the input.  gold_tree will be left blank if the
    data did not include gold trees.  States which break are currently
    dropped from the output.
    """
    treebank = []
    treebank_indices = []
    state_batch = build_batch_fn(batch_size, data_iterator)
    # the input position of each state in state_batch.
    # parses finish at different times, so this lets us unsort at the end
    batch_indices = list(range(len(state_batch)))
    # input position of the next state pulled into the batch.  A plain
    # counter keeps the positions unique and in input order; the previous
    # len(treebank) + len(state_batch) formula could repeat a position
    # once broken states were dropped, scrambling the unsorted output
    next_index = len(state_batch)
    horizon_iterator = iter([])

    if keep_constituents:
        constituents = defaultdict(list)

    while len(state_batch) > 0:
        _, transitions, scores = transition_choice(state_batch)
        if keep_scores and scores is not None:
            state_batch = [state._replace(score=state.score + score) for state, score in zip(state_batch, scores)]
        state_batch = self.bulk_apply(state_batch, transitions)

        if keep_constituents:
            for t_idx, transition in enumerate(transitions):
                if isinstance(transition, CloseConstituent):
                    # constituents is a TreeStack with information on how to build the next state of the LSTM or attn
                    # constituents.value is the TreeStack node
                    # constituents.value.value is the Constituent itself (with the tree and the embedding)
                    constituents[batch_indices[t_idx]].append(state_batch[t_idx].constituents.value.value)

        remove = set()
        for idx, state in enumerate(state_batch):
            if state.broken:
                # TODO: make a fake tree with the appropriate words at least?
                # something like the X-tree CoreNLP does
                remove.add(idx)
            elif state.finished(self):
                predicted_tree = state.get_tree(self)
                if self.reverse_sentence:
                    predicted_tree = predicted_tree.reverse()
                gold_tree = state.gold_tree
                treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, state.score)], state if keep_state else None, constituents[batch_indices[idx]] if keep_constituents else None))
                treebank_indices.append(batch_indices[idx])
                remove.add(idx)

        if remove:
            state_batch = [state for idx, state in enumerate(state_batch) if idx not in remove]
            batch_indices = [batch_idx for idx, batch_idx in enumerate(batch_indices) if idx not in remove]

        # refill the open slots, building a new horizon batch whenever the current one runs out
        for _ in range(batch_size - len(state_batch)):
            horizon_state = next(horizon_iterator, None)
            if horizon_state is None:
                horizon_batch = build_batch_fn(batch_size, data_iterator)
                if len(horizon_batch) == 0:
                    break
                horizon_iterator = iter(horizon_batch)
                horizon_state = next(horizon_iterator, None)
            state_batch.append(horizon_state)
            batch_indices.append(next_index)
            next_index += 1

    treebank = utils.unsort(treebank, treebank_indices)
    return treebank
def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """
    Same as parse_sentences, but run inside torch.no_grad()

    Not tracking gradients makes the model run faster and use less
    memory at inference time.
    """
    no_grad = torch.no_grad()
    with no_grad:
        result = self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)
    return result
def analyze_trees(self, trees, batch_size=None, keep_state=True, keep_constituents=True, keep_scores=True):
    """
    Run each tree's gold transition sequence through the model and return a ParseResult per tree

    The output layers will be available in result.state for each result.
    keep_state=True is the default here because a method which keeps
    the grad is likely to want to keep the resulting state as well.
    batch_size defaults to self.args['eval_batch_size'].
    """
    if batch_size is None:
        # TODO: refactor?
        batch_size = self.args['eval_batch_size']
    return self.parse_sentences(iter(trees), self.build_batch_from_trees_with_gold_sequence, batch_size, self.predict_gold, keep_state, keep_constituents, keep_scores=keep_scores)
def parse_tagged_words(self, words, batch_size):
    """
    Parse tagged sentences into trees - useful at Pipeline time

    words: one list per sentence, each sentence a list of (word, tag)
    batch_size: number of sentences parsed at once

    Puts the model in eval mode, then returns the top predicted tree of
    each ParseResult from parse_sentences_no_grad.
    """
    logger.debug("Processing %d sentences", len(words))
    self.eval()
    treebank = self.parse_sentences_no_grad(iter(words), self.build_batch_from_tagged_words, batch_size, self.predict, keep_state=False, keep_constituents=False)
    return [result.predictions[0].tree for result in treebank]
def bulk_apply(self, state_batch, transitions, fail=False):
    """
    Apply the given list of Transitions to the given list of States, using the model as a reference

    model: SimpleModel, LSTMModel, or any other form of model
    state_batch: list of States
    transitions: list of transitions, one per state
    fail: throw an exception on a failed transition, as opposed to skipping the tree

    Returns a new list of States, parallel to state_batch.  A state with
    a missing transition, or which has taken too many transitions, is
    returned with broken=True (or a ValueError is raised if fail is set).
    Every other state is advanced by its transition.
    """
    word_positions = []
    constituent_stacks = []
    new_constituents = []
    # indices into state_batch of the states which actually get updated
    valid_state_indices = []
    callbacks = defaultdict(list)

    state_batch = list(state_batch)
    # indexed below via valid_state_indices, so make sure it is a sequence
    transitions = list(transitions)
    for idx, (state, transition) in enumerate(zip(state_batch, transitions)):
        if not transition:
            error = "Got stuck and couldn't find a legal transition on the following gold tree:\n{}\n\nFinal state:\n{}".format(state.gold_tree, state.to_string(self))
            if fail:
                raise ValueError(error)
            logger.error(error)
            state_batch[idx] = state._replace(broken=True)
            continue

        if state.num_transitions >= len(state.word_queue) * 20:
            # too many transitions
            # x20 is somewhat empirically chosen based on certain
            # treebanks having deep unary structures, especially early
            # on when the model is fumbling around
            if state.gold_tree:
                error = "Went infinite on the following gold tree:\n{}\n\nFinal state:\n{}".format(state.gold_tree, state.to_string(self))
            else:
                error = "Went infinite!:\nFinal state:\n{}".format(state.to_string(self))
            if fail:
                raise ValueError(error)
            logger.error(error)
            state_batch[idx] = state._replace(broken=True)
            continue

        wq, c, nc, callback = transition.update_state(state, self)
        word_positions.append(wq)
        constituent_stacks.append(c)
        new_constituents.append(nc)
        valid_state_indices.append(idx)
        if callback:
            # index into new_constituents, not `idx`, in case something was broken
            callbacks[callback].append(len(new_constituents) - 1)

    for key, idxs in callbacks.items():
        data = [new_constituents[x] for x in idxs]
        callback_constituents = key.build_constituents(self, data)
        for idx, constituent in zip(idxs, callback_constituents):
            new_constituents[idx] = constituent

    # Only the updated states get new stacks.  Select the states and
    # transitions through valid_state_indices so every list below is
    # aligned; zipping the full state_batch here used to copy a broken
    # state over every later state in the batch.
    valid_states = [state_batch[idx] for idx in valid_state_indices]
    valid_transitions = [transitions[idx] for idx in valid_state_indices]
    if len(valid_transitions) > 0:
        new_transitions = self.push_transitions([state.transitions for state in valid_states], valid_transitions)
        new_constituents = self.push_constituents(constituent_stacks, new_constituents)
    else:
        new_transitions = []
        new_constituents = []

    for state_idx, state, transition, word_position, transition_stack, constituent_stack in zip(valid_state_indices, valid_states, valid_transitions, word_positions, new_transitions, new_constituents):
        state_batch[state_idx] = state._replace(num_opens=state.num_opens + transition.delta_opens(),
                                                word_position=word_position,
                                                transitions=transition_stack,
                                                constituents=constituent_stack)
    return state_batch
class SimpleModel(BaseModel):
    """
    A model which pushes and pops transitions and constituents with no extra data

    Mostly used for testing operations which don't need the NN's
    weights, and for rebuilding trees from transitions when verifying
    transition sequences where the NN state is irrelevant - this class
    is faster than running the NN.
    """
    def __init__(self, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, unary_limit=UNARY_LIMIT, reverse_sentence=False, root_labels=("ROOT",)):
        super().__init__(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse_sentence, root_labels=root_labels)

    def initial_word_queues(self, tagged_word_lists):
        """
        Wrap each list of tagged words with a None sentinel at both ends,
        reversing the queue if the model parses backwards
        """
        word_queues = []
        for tagged_words in tagged_word_lists:
            word_queue = [None, *tagged_words, None]
            if self.reverse_sentence:
                word_queue = word_queue[::-1]
            word_queues.append(word_queue)
        return word_queues

    def initial_transitions(self):
        return TreeStack(value=None, parent=None, length=1)

    def initial_constituents(self):
        return TreeStack(value=None, parent=None, length=1)

    def get_word(self, word_node):
        return word_node

    def transform_word_to_constituent(self, state):
        return state.get_word(state.word_position)

    def dummy_constituent(self, dummy):
        return dummy

    def build_constituents(self, labels, children_lists):
        """
        Build one Tree per (label, children) pair

        A label may be a single string or a sequence of labels, in which
        case the labels are nested with the first label outermost.
        """
        constituents = []
        for label, children in zip(labels, children_lists):
            label_chain = (label,) if isinstance(label, str) else label
            node = children
            for value in reversed(label_chain):
                node = Tree(label=value, children=node)
            constituents.append(node)
        return constituents

    def push_constituents(self, constituent_stacks, constituents):
        return [stack.push(item) for stack, item in zip(constituent_stacks, constituents)]

    def get_top_constituent(self, constituents):
        return constituents.value

    def push_transitions(self, transition_stacks, transitions):
        return [stack.push(item) for stack, item in zip(transition_stacks, transitions)]

    def get_top_transition(self, transitions):
        return transitions.value
================================================
FILE: stanza/models/constituency/base_trainer.py
================================================
from enum import Enum
import logging
import os
import torch
from pickle import UnpicklingError
import warnings
# module-level logger for the trainer save/load messages
logger = logging.getLogger('stanza')
class ModelType(Enum):
    """
    Which kind of model a saved checkpoint holds

    BaseTrainer.save writes the member's name into the checkpoint and
    BaseTrainer.load maps the name back with ModelType[...], so the
    member names must stay stable.  Older checkpoints may contain the
    member itself rather than its name.
    """
    LSTM = 1
    ENSEMBLE = 2
class BaseTrainer:
    """
    Shared save/load logic for the constituency parser trainers

    Holds the model along with its optimizer, scheduler, and training
    statistics.  save() reads self.model_type and calls
    self.get_peft_params(), neither of which is defined here -
    presumably each subclass provides them.
    """
    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        # keeping track of the epochs trained will be useful
        # for adjusting the learning scheme
        self.epochs_trained = epochs_trained
        self.batches_trained = batches_trained
        self.best_f1 = best_f1
        self.best_epoch = best_epoch
        self.first_optimizer = first_optimizer

    def save(self, filename, save_optimizer=True):
        """
        Save the model params and training statistics to filename

        If save_optimizer is set and there is an optimizer, its state is
        saved too, along with the scheduler state if there is a scheduler.
        """
        params = self.model.get_params()
        checkpoint = {
            'params': params,
            'epochs_trained': self.epochs_trained,
            'batches_trained': self.batches_trained,
            'best_f1': self.best_f1,
            'best_epoch': self.best_epoch,
            'model_type': self.model_type.name,
            'first_optimizer': self.first_optimizer,
        }
        checkpoint["bert_lora"] = self.get_peft_params()
        if save_optimizer and self.optimizer is not None:
            checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
            # an optimizer without a scheduler previously crashed here with an AttributeError
            if self.scheduler is not None:
                checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        torch.save(checkpoint, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to %s", filename)

    def log_norms(self):
        self.model.log_norms()

    def log_shapes(self):
        self.model.log_shapes()

    @property
    def transitions(self):
        return self.model.transitions

    @property
    def root_labels(self):
        return self.model.root_labels

    @property
    def device(self):
        return next(self.model.parameters()).device

    def train(self):
        return self.model.train()

    def eval(self):
        return self.model.eval()

    # TODO: make ABC with methods such as model_from_params?
    @staticmethod
    def load(filename, args=None, load_optimizer=False, foundation_cache=None, peft_name=None):
        """
        Load back a model and possibly its optimizer.

        filename: path to the checkpoint.  If it does not exist, it is
          looked for under args['save_dir']
        args: optional dict of args, passed on to model_from_params
        load_optimizer: if True, also rebuild the optimizer and scheduler
        foundation_cache, peft_name: passed on to model_from_params

        Returns a Trainer or EnsembleTrainer, depending on the model_type
        stored in the checkpoint.

        Raises FileNotFoundError if the model cannot be found and
        ValueError if the checkpoint has an unexpected model_type.
        """
        if not os.path.exists(filename):
            # args defaults to None, which simply means there is no save_dir to check
            # (previously this raised AttributeError on None.get)
            save_dir = None if args is None else args.get('save_dir', None)
            if save_dir is None:
                raise FileNotFoundError("Cannot find model in {} and args['save_dir'] is None".format(filename))
            save_path = os.path.join(save_dir, filename)
            if not os.path.exists(save_path):
                raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, save_path))
            filename = save_path

        # hide the import here to avoid circular imports
        from stanza.models.constituency.ensemble import EnsembleTrainer
        from stanza.models.constituency.trainer import Trainer

        try:
            try:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
            except UnpicklingError:
                # older checkpoints contain Enums, sets, and unsanitized Transitions,
                # which weights_only=True refuses to unpickle.
                # SECURITY NOTE: weights_only=False can execute arbitrary code stored
                # in the file, so only load checkpoints from trusted sources
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
                warnings.warn("The saved constituency parser has an old format using Enum, set, unsanitized Transitions, etc. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the constituency parser using this version ASAP.")
        except BaseException:
            logger.exception("Cannot load model from %s", filename)
            raise
        logger.debug("Loaded model from %s", filename)

        params = checkpoint['params']

        if 'model_type' not in checkpoint:
            # old models will have this trait
            # TODO: can remove this after 1.10
            checkpoint['model_type'] = ModelType.LSTM
        if isinstance(checkpoint['model_type'], str):
            checkpoint['model_type'] = ModelType[checkpoint['model_type']]
        if checkpoint['model_type'] == ModelType.LSTM:
            clazz = Trainer
        elif checkpoint['model_type'] == ModelType.ENSEMBLE:
            clazz = EnsembleTrainer
        else:
            raise ValueError("Unexpected model type: %s" % checkpoint['model_type'])
        model = clazz.model_from_params(params, checkpoint.get('bert_lora', None), args, foundation_cache, peft_name)

        epochs_trained = checkpoint['epochs_trained']
        batches_trained = checkpoint.get('batches_trained', 0)
        best_f1 = checkpoint['best_f1']
        best_epoch = checkpoint['best_epoch']

        if 'first_optimizer' not in checkpoint:
            # this will only apply to old (LSTM) Trainers
            # EnsembleTrainers will always have this value saved
            # so here we can compensate by looking at the old training statistics...
            # we use params['config'] here instead of model.args
            # because the args might have a different training
            # mechanism, but in order to reload the optimizer, we need
            # to match the optimizer we build with the one that was
            # used at training time
            build_simple_adadelta = params['config']['multistage'] and epochs_trained < params['config']['epochs'] // 2
            checkpoint['first_optimizer'] = build_simple_adadelta
        first_optimizer = checkpoint['first_optimizer']

        if load_optimizer:
            optimizer = clazz.load_optimizer(model, checkpoint, first_optimizer, filename)
            scheduler = clazz.load_scheduler(model, optimizer, checkpoint, first_optimizer)
        else:
            optimizer = None
            scheduler = None

        if checkpoint['model_type'] == ModelType.LSTM:
            logger.debug("-- MODEL CONFIG --")
            for k in model.args.keys():
                logger.debug(" --%s: %s", k, model.args[k])
            return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)
        elif checkpoint['model_type'] == ModelType.ENSEMBLE:
            return EnsembleTrainer(ensemble=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)
        else:
            raise ValueError("Unexpected model type: %s" % checkpoint['model_type'])
================================================
FILE: stanza/models/constituency/dynamic_oracle.py
================================================
from collections import namedtuple
import numpy as np
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent
# lightweight record with name, value, and is_correct fields
RepairEnum = namedtuple("RepairEnum", ["name", "value", "is_correct"])
def score_candidates_single_block(model, state, candidates, candidate_idx):
    """
    Score candidate fixed sequences using only one block of each candidate

    Each candidate is a list of blocks of transitions.  Blocks
    1..candidate_idx-1 are applied to the state without being scored.
    Then the model's scores for each transition in block candidate_idx
    are summed, applying each transition after it is scored.  Block 0 is
    neither applied nor scored.

    Returns (scores, best_idx, best_candidate), where best_candidate is
    the highest scoring candidate with its blocks flattened into one list.
    """
    scores = []
    # could bulkify this if we wanted
    for candidate in candidates:
        current_state = [state]
        for transition in (t for block in candidate[1:candidate_idx] for t in block):
            current_state = model.bulk_apply(current_state, [transition])

        total = 0.0
        for transition in candidate[candidate_idx]:
            predictions = model.forward(current_state)
            total += predictions[0, model.transition_map[transition]].cpu().item()
            current_state = model.bulk_apply(current_state, [transition])
        scores.append(total)

    best_idx = np.argmax(scores)
    best_candidate = [transition for block in candidates[best_idx] for transition in block]
    return scores, best_idx, best_candidate
def score_candidates(model, state, candidates):
    """
    Score candidate fixed sequences by summing the model's transition scores

    Every transition in blocks 1.. of each candidate is scored (block 0
    is skipped), and each transition is applied to the state after it
    is scored.  The candidate with the best summed score is chosen, and
    its sequence is reconstructed by flattening its blocks.

    Using either this or score_candidates_single_block (eg, with
    candidate_idx=2) doesn't really help: it still winds up being
    slightly better for accuracy to simply revert to teacher forcing for
    ambiguous transition errors.

    Returns (scores, best_idx, best_candidate)
    """
    def candidate_score(candidate):
        current_state = [state]
        total = 0.0
        for block in candidate[1:]:
            for transition in block:
                predictions = model.forward(current_state)
                total += predictions[0, model.transition_map[transition]].cpu().item()
                current_state = model.bulk_apply(current_state, [transition])
        return total

    # could bulkify this if we wanted
    scores = [candidate_score(candidate) for candidate in candidates]
    best_idx = np.argmax(scores)
    best_candidate = [transition for block in candidates[best_idx] for transition in block]
    return scores, best_idx, best_candidate
def advance_past_constituents(gold_sequence, cur_index):
    """
    Walk forward from cur_index until we have seen 1 more Close than Open

    Returns the index of that Close, which occurs after all the stuff
    in the current constituent, or None if the sequence ends first.
    """
    depth = 0
    for index in range(cur_index, len(gold_sequence)):
        transition = gold_sequence[index]
        if isinstance(transition, OpenConstituent):
            depth += 1
        elif isinstance(transition, CloseConstituent):
            depth -= 1
            if depth == -1:
                return index
    return None
def find_previous_open(gold_sequence, cur_index):
    """
    Walk backwards from just before cur_index to the Open which starts the previous block of stuff

    Returns the index of that Open, or None if it can't be found.
    """
    depth = 0
    for index in range(cur_index - 1, -1, -1):
        transition = gold_sequence[index]
        if isinstance(transition, OpenConstituent):
            depth += 1
            if depth > 0:
                return index
        elif isinstance(transition, CloseConstituent):
            depth -= 1
    return None
def find_in_order_constituent_end(gold_sequence, cur_index):
    """
    Walk forward from cur_index until the next block has ended

    Unlike advance_past_constituents, this also stops at a second Shift
    seen at depth 0, so that we return the first block of things we know
    attach to the left.  Returns None if the sequence ends first.
    """
    depth = 0
    saw_shift = False
    for index in range(cur_index, len(gold_sequence)):
        transition = gold_sequence[index]
        if isinstance(transition, OpenConstituent):
            depth += 1
        elif isinstance(transition, CloseConstituent):
            depth -= 1
            if depth == -1:
                return index
        elif isinstance(transition, Shift):
            if saw_shift and depth == 0:
                return index
            saw_shift = True
    return None
class DynamicOracle():
    """
    Decides how to repair a transition sequence when the predicted
    transition differs from the gold transition

    The members of repair_types are tried in order; the first one whose
    fn returns a repair is used.
    """
    def __init__(self, root_labels, oracle_level, repair_types, additional_levels, deactivated_levels):
        """
        root_labels: passed through to each repair fn
        oracle_level: repair types with a value above this are skipped,
          unless they are in additional_levels or are marked debug.
          Defaults to the value of repair_types.UNKNOWN
        repair_types: the repair type enum for this transition scheme
        additional_levels: comma separated names of repair types to use
          even if they are above oracle_level
        deactivated_levels: comma separated names of repair types to never use
        """
        self.root_labels = root_labels
        # default oracle_level will be the UNKNOWN repair type (which each oracle should have)
        # transitions after that as experimental or ambiguous, not to be used by default
        self.oracle_level = oracle_level if oracle_level is not None else repair_types.UNKNOWN.value
        self.repair_types = repair_types
        self.additional_levels = self._parse_levels(repair_types, additional_levels)
        self.deactivated_levels = self._parse_levels(repair_types, deactivated_levels)

    @staticmethod
    def _parse_levels(repair_types, levels):
        """
        Turn a comma separated string of repair type names into a set of repair types

        Names are case insensitive.  Surrounding whitespace and empty
        entries are ignored, so "a, b" or "a,b," work the same as "a,b"
        (previously they raised a KeyError).
        """
        if not levels:
            return set()
        names = (name.strip().upper() for name in levels.split(","))
        return {repair_types[name] for name in names if name}

    def fix_error(self, pred_transition, model, state):
        """
        Return which error has been made, if any, along with an updated transition list

        Returns (repair_types.CORRECT, None) if pred_transition is the gold
        transition, (repair_type, repair) for the first applicable repair,
        or (repair_types.UNKNOWN, None) if no repair applies.  A repair fn
        may also return its own 2-tuple, which is returned unchanged.

        We assume the transition sequence builds a correct tree, meaning
        that there will always be a CloseConstituent sometime after an
        OpenConstituent, for example
        """
        gold_transition = state.gold_sequence[state.num_transitions]
        if gold_transition == pred_transition:
            return self.repair_types.CORRECT, None

        for repair_type in self.repair_types:
            if repair_type.fn is None:
                continue
            if self.oracle_level is not None and repair_type.value > self.oracle_level and repair_type not in self.additional_levels and not repair_type.debug:
                continue
            if repair_type in self.deactivated_levels:
                continue

            repair = repair_type.fn(gold_transition, pred_transition, state.gold_sequence, state.num_transitions, self.root_labels, model, state)
            if repair is None:
                continue

            if isinstance(repair, tuple) and len(repair) == 2:
                return repair
            # TODO: could update all of the returns to be tuples of length 2
            return repair_type, repair

        return self.repair_types.UNKNOWN, None
================================================
FILE: stanza/models/constituency/ensemble.py
================================================
"""
Prototype of ensembling N models together on the same dataset
The main inference method is to run the normal transition sequence,
but sum the scores for the N models and use that to choose the highest
scoring transition
Example of how to run it to build a silver dataset
(or just parse a text file in general):
# first, use this tool to build a saved ensemble
python3 stanza/models/constituency/ensemble.py
saved_models/constituency/wsj_inorder_?.pt
--save_name saved_models/constituency/en_ensemble.pt
# then use the ensemble directly as a model in constituency_parser.py
python3 stanza/models/constituency_parser.py
--save_name saved_models/constituency/en_ensemble.pt
--mode parse_text
--tokenized_file /nlp/scr/horatio/en_silver/en_split_100
--predict_file /nlp/scr/horatio/en_silver/en_split_100.inorder.mrg
--retag_package en_combined_bert
--lang en
then, ideally, run a second time with a set of topdown models,
then take the trees which match from the files
"""
import argparse
import copy
import logging
import os
import torch
import torch.nn as nn
from stanza.models.common import utils
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.constituency.base_trainer import BaseTrainer, ModelType
from stanza.models.constituency.state import MultiState
from stanza.models.constituency.trainer import Trainer
from stanza.models.constituency.utils import build_optimizer, build_scheduler
from stanza.server.parser_eval import ParseResult, ScoredTree
# the ensemble logs under the same logger name as the constituency trainer
logger = logging.getLogger('stanza.constituency.trainer')
class Ensemble(nn.Module):
def __init__(self, args, filenames=None, models=None, foundation_cache=None):
    """
    Build the ensemble either by loading each model in filenames or from already built models

    args: kept as self.args and passed to Trainer.load when loading from filenames
    filenames: a single filename or a list of filenames
    models: a list of already built models
    foundation_cache: if None when loading from filenames, we build one
      on our own, as the expectation is the models will reuse modules
      such as pretrain, charlm, bert

    Raises ValueError if both or neither of filenames and models are
    given, or if the models are incompatible with each other.
    """
    super().__init__()

    self.args = args
    if filenames:
        if models:
            raise ValueError("both filenames and models set when making the Ensemble")

        if foundation_cache is None:
            foundation_cache = FoundationCache()

        if isinstance(filenames, str):
            filenames = [filenames]
        logger.info("Models used for ensemble:\n %s", "\n ".join(filenames))
        models = [Trainer.load(filename, args, load_optimizer=False, foundation_cache=foundation_cache).model for filename in filenames]
        # how the models are described in the error messages below
        model_names = list(filenames)
    elif not models:
        raise ValueError("filenames and models both not set!")
    else:
        # there are no filenames to report, so describe the models by position.
        # previously an incompatibility here crashed with a TypeError on filenames[0]
        models = list(models)
        model_names = ["model %d" % idx for idx in range(len(models))]

    self.models = nn.ModuleList(models)

    first = self.models[0]
    first_name = model_names[0]
    for model_idx, model in enumerate(self.models):
        model_name = model_names[model_idx]
        if first.transition_scheme() != model.transition_scheme():
            raise ValueError("Models {} and {} are incompatible. {} vs {}".format(first_name, model_name, first.transition_scheme(), model.transition_scheme()))
        if first.transitions != model.transitions:
            raise ValueError(f"Models {first_name} and {model_name} are incompatible: different transitions\n{first_name}:\n{first.transitions}\n{model_name}:\n{model.transitions}")
        if first.constituents != model.constituents:
            raise ValueError("Models %s and %s are incompatible: different constituents" % (first_name, model_name))
        if first.root_labels != model.root_labels:
            raise ValueError("Models %s and %s are incompatible: different root_labels" % (first_name, model_name))
        if first.uses_xpos() != model.uses_xpos():
            raise ValueError("Models %s and %s are incompatible: different uses_xpos" % (first_name, model_name))
        if first.reverse_sentence != model.reverse_sentence:
            raise ValueError("Models %s and %s are incompatible: different reverse_sentence" % (first_name, model_name))

    self._reverse_sentence = first.reverse_sentence

    # submodels are not trained (so far)
    self.detach_submodels()

    logger.debug("Number of models in the Ensemble: %d", len(self.models))
    self.register_parameter('weighted_sum', torch.nn.Parameter(torch.zeros(len(self.models), len(self.transitions), requires_grad=True)))
def detach_submodels(self):
    """
    Freeze every submodel parameter - submodels are not trained (so far)
    """
    for submodel in self.models:
        for parameter in submodel.parameters():
            parameter.requires_grad = False
def train(self, mode=True):
    """
    Set training mode while keeping the submodels frozen

    Returns self, matching the nn.Module.train contract (the previous
    version returned None, which broke callers chaining on the result).
    """
    super().train(mode)
    if mode:
        # peft has a weird interaction where it turns requires_grad back on
        # even if it was previously off
        self.detach_submodels()
    return self
@property
def transitions(self):
    """
    The transitions of the ensemble, taken from the first submodel
    """
    first_model = self.models[0]
    return first_model.transitions
@property
def root_labels(self):
    """
    The ROOT labels of the ensemble, taken from the first submodel
    """
    first_model = self.models[0]
    return first_model.root_labels
@property
def device(self):
    """
    The device of the ensemble's first parameter
    """
    first_param = next(self.parameters())
    return first_param.device
def unary_limit(self):
    """
    Limit on the number of consecutive unary transitions:
    the strictest limit among the submodels
    """
    limits = [model.unary_limit() for model in self.models]
    return min(limits)
def transition_scheme(self):
    """
    The transition scheme of the first submodel
    """
    first_model = self.models[0]
    return first_model.transition_scheme()
def has_unary_transitions(self):
    """
    Whether the first submodel uses unary transitions
    """
    first_model = self.models[0]
    return first_model.has_unary_transitions()
@property
def is_top_down(self):
    """
    Whether the first submodel is TOP_DOWN
    """
    first_model = self.models[0]
    return first_model.is_top_down
@property
def reverse_sentence(self):
    """
    Whether the ensemble parses backwards (recorded from the first submodel at construction)
    """
    backwards = self._reverse_sentence
    return backwards
@property
def retag_method(self):
    """
    The retag_method arg of the first submodel
    """
    # TODO: make the method an enum
    first_args = self.models[0].args
    return first_args['retag_method']
def uses_xpos(self):
    """
    Whether the first submodel uses xpos tags
    """
    first_model = self.models[0]
    return first_model.uses_xpos()
def get_top_constituent(self, constituents):
    """
    Delegate to the first submodel's get_top_constituent
    """
    first_model = self.models[0]
    return first_model.get_top_constituent(constituents)
def get_top_transition(self, transitions):
    """
    Delegate to the first submodel's get_top_transition
    """
    first_model = self.models[0]
    return first_model.get_top_transition(transitions)
def log_norms(self):
    """
    Log norms of the ensemble's own trainable parameters, then each submodel's norms

    Each ensemble parameter line is: name, norm, count of near-zero
    entries, total number of entries.
    """
    lines = ["NORMS FOR MODEL PARAMETERS"]
    own_params = [(name, param) for name, param in self.named_parameters()
                  if param.requires_grad and not name.startswith("models.")]
    for name, param in own_params:
        zeros = torch.sum(param.abs() < 0.000001).item()
        norm = "%.6g" % torch.norm(param).item()
        lines.append("%s %s %d %d" % (name, norm, zeros, param.nelement()))

    for model_idx, model in enumerate(self.models):
        sublines = model.get_norms()
        if len(sublines) == 0:
            continue
        lines.append(" ---- MODEL %d ----" % model_idx)
        lines.extend(sublines)

    logger.info("\n".join(lines))
def log_shapes(self):
    """
    Log the shape of each trainable parameter

    The submodels are normally frozen by detach_submodels, so in
    practice this lists the ensemble's own parameters.  The header used
    to be a copy of log_norms' "NORMS FOR MODEL PARAMETERS", which was
    misleading for a shape listing.
    """
    lines = ["SHAPES FOR MODEL PARAMETERS"]
    for name, param in self.named_parameters():
        if param.requires_grad:
            lines.append("{} {}".format(name, param.shape))
    logger.info("\n".join(lines))
def get_params(self):
    """
    Return the ensemble's own parameters separately from each submodel's params

    state_dict entries under "models." belong to the submodels, so they
    are left out of base_params; each submodel's get_params() goes into
    children_params instead.
    """
    base_params = {}
    for key, value in self.state_dict().items():
        if key.startswith("models."):
            continue
        base_params[key] = value
    children_params = [model.get_params() for model in self.models]
    return {"base_params": base_params, "children_params": children_params}
def initial_state_from_preterminals(self, preterminal_lists, gold_trees, gold_sequences):
    """
    Build one MultiState per sentence, holding each submodel's starting State

    preterminal_lists: one list of preterminals per sentence
    gold_trees, gold_sequences: optional lists parallel to
      preterminal_lists.  As with BaseModel.initial_state_from_preterminals,
      they may be None; previously that crashed here when zipping.
    """
    per_model_states = [model.initial_state_from_preterminals(preterminal_lists, gold_trees, gold_sequences) for model in self.models]
    state_batch = list(zip(*per_model_states))
    # missing gold data becomes None for every MultiState
    if not gold_trees:
        gold_trees = [None] * len(state_batch)
    if not gold_sequences:
        gold_sequences = [None] * len(state_batch)
    return [MultiState(states, gold_tree, gold_sequence, 0.0)
            for states, gold_tree, gold_sequence in zip(state_batch, gold_trees, gold_sequences)]
def build_batch_from_tagged_words(self, batch_size, data_iterator):
    """
    Read up to batch_size tagged sentences from data_iterator and turn them into MultiStates.

    Each sentence is expected to be a list of (word, tag).  Every sub-model builds
    initial states for the sentences via initial_state_from_words.  Each MultiState
    holds one state per sub-model, with no gold tree or gold sequence.
    Returns an empty list once the iterator yields nothing.
    """
    sentences = []
    for _ in range(batch_size):
        sentence = next(data_iterator, None)
        if sentence is None:
            break
        sentences.append(sentence)
    if not len(sentences) > 0:
        return sentences
    per_model = [sub_model.initial_state_from_words(sentences) for sub_model in self.models]
    return [MultiState(sentence_states, None, None, 0.0) for sentence_states in list(zip(*per_model))]
def build_batch_from_trees(self, batch_size, data_iterator):
    """
    Read up to batch_size trees from data_iterator and turn them into MultiStates.

    Every sub-model builds initial states for the trees via
    initial_state_from_gold_trees.  Each MultiState holds one state per
    sub-model and carries the gold tree it was built from.  This matches
    initial_state_from_preterminals and lets parse_sentences report the gold
    tree in each ParseResult.

    Returns an empty list once data_iterator yields nothing.
    """
    gold_trees = []
    for _ in range(batch_size):
        gold_tree = next(data_iterator, None)
        if gold_tree is None:
            break
        gold_trees.append(gold_tree)
    if len(gold_trees) == 0:
        return []
    per_model_states = [model.initial_state_from_gold_trees(gold_trees) for model in self.models]
    # num models lists of batch size -> batch size tuples of num models
    per_tree_states = list(zip(*per_model_states))
    # keep the gold tree on each MultiState; parse_sentences copies it into ParseResult.gold
    return [MultiState(states, gold_tree, None, 0.0)
            for states, gold_tree in zip(per_tree_states, gold_trees)]
def predict(self, states, is_legal=True):
    """
    Choose the next transition for each MultiState by combining the sub-models' scores.

    states: a list of MultiState, each holding one state per sub-model
    is_legal: if True, an illegal argmax transition is replaced by the
              highest-scoring legal transition, or by None if no transition is legal

    Returns (predictions, pred_trans, scores):
      predictions: the combined score tensor, one row per state, one column per transition
      pred_trans: the chosen transition for each state, taken from models[0].transitions
      scores: the combined score of each chosen transition (squeezed to one value per state)
    """
    # transpose: one tuple per sub-model, holding that sub-model's state for every sentence
    states = list(zip(*[x.states for x in states]))
    predictions = [model.forward(state_batch) for model, state_batch in zip(self.models, states)]
    # batch X num transitions X num models
    predictions = torch.stack(predictions, dim=2)
    # per-transition weighted sum over models: flat[b, t] = sum_m predictions[b, t, m] * weighted_sum[m, t]
    flat_predictions = torch.einsum("BTM,MT->BT", predictions, self.weighted_sum)
    # final score: plain sum over the models plus the weighted term
    predictions = torch.sum(predictions, dim=2) + flat_predictions
    # the first sub-model is used for the transition list and for legality checks
    model = self.models[0]
    # TODO: possibly refactor with lstm_model.predict
    pred_max = torch.argmax(predictions, dim=1)
    scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)
    pred_max = pred_max.detach().cpu()
    pred_trans = [model.transitions[pred_max[idx]] for idx in range(len(states[0]))]
    if is_legal:
        for idx, (state, trans) in enumerate(zip(states[0], pred_trans)):
            if not trans.is_legal(state, model):
                # fall back to the best-scoring transition that is legal for this state
                _, indices = predictions[idx, :].sort(descending=True)
                for index in indices:
                    if model.transitions[index].is_legal(state, model):
                        pred_trans[idx] = model.transitions[index]
                        scores[idx] = predictions[idx, index]
                        break
                else: # yeah, else on a for loop, deal with it
                    pred_trans[idx] = None
                    # NOTE(review): scores is a tensor; assigning None into it likely raises
                    # TypeError in torch - confirm, and consider a numeric sentinel such as NaN
                    scores[idx] = None
    return predictions, pred_trans, scores.squeeze(1)
def bulk_apply(self, state_batch, transitions, fail=False):
    """
    Apply one transition to each MultiState, forwarding to every sub-model.

    state_batch: list of MultiState, each holding one state per sub-model
    transitions: one transition per MultiState, applied by every sub-model
    fail: passed through to each sub-model's bulk_apply

    Returns new MultiStates (via _replace) holding the updated per-model states.
    """
    # batch of MultiStates -> one tuple of states per sub-model
    per_model = list(zip(*[x.states for x in state_batch]))
    per_model = [model.bulk_apply(model_states, transitions, fail=fail)
                 for model, model_states in zip(self.models, per_model)]
    # back to one tuple of states per sentence
    per_sentence = list(zip(*per_model))
    return [multi._replace(states=states) for multi, states in zip(state_batch, per_sentence)]
def parse_tagged_words(self, words, batch_size):
    """
    Parse already-tagged sentences and return the predicted tree for each.

    Intended for use at Pipeline time.  `words` holds one list per sentence,
    and each sentence is a list of (word, tag).  The ensemble is switched to
    eval mode and parsing runs without gradients.

    TODO: this really ought to be refactored with base_model
    """
    logger.debug("Processing %d sentences", len(words))
    self.eval()
    parsed = self.parse_sentences_no_grad(iter(words), self.build_batch_from_tagged_words, batch_size, self.predict,
                                          keep_state=False, keep_constituents=False)
    return [result.predictions[0].tree for result in parsed]
def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """
    Repeat transitions to build a list of trees from the input batches.

    data_iterator: anything which returns the data for a parse task via next()
    build_batch_fn: turns that data into MultiState objects; called with
      (batch_size, data_iterator) until it returns an empty batch
    batch_size: how many sentences are parsed concurrently
    transition_choice: which method of the model to use for choosing the
      next transition; must return (scores tensor, transitions, scores)

    keep_state, keep_constituents, keep_scores: accepted for interface
      compatibility with the single-model parser but currently ignored.
      Each result always has state=None and constituents=None.

    Returns a list of ParseResult(gold_tree, [ScoredTree(predicted, None)], None, None),
    in the same order as the input.  gold_tree is whatever the MultiState carried
    (None if the data did not include gold trees).  The score is currently always None.

    TODO: refactor with base_model
    """
    treebank = []
    treebank_indices = []
    # each element is a MultiState: one parsing state per sub-model
    state_batch = build_batch_fn(batch_size, data_iterator)
    # batch_indices[i] is the input position of state_batch[i]
    batch_indices = list(range(len(state_batch)))
    # input position to assign to the next state pulled into the batch
    next_input_index = len(state_batch)
    horizon_iterator = iter([])

    while len(state_batch) > 0:
        _, transitions, _ = transition_choice(state_batch)
        state_batch = self.bulk_apply(state_batch, transitions)

        remove = set()
        for idx, states in enumerate(state_batch):
            if not states.finished(self):
                continue
            predicted_tree = states.get_tree(self)
            if self.reverse_sentence:
                predicted_tree = predicted_tree.reverse()
            # TODO: could easily store the score here
            # not sure what it means to store the state,
            # since each model is tracking its own state
            treebank.append(ParseResult(states.gold_tree, [ScoredTree(predicted_tree, None)], None, None))
            treebank_indices.append(batch_indices[idx])
            remove.add(idx)

        if remove:
            state_batch = [state for idx, state in enumerate(state_batch) if idx not in remove]
            batch_indices = [batch_idx for idx, batch_idx in enumerate(batch_indices) if idx not in remove]

        # refill the batch with new sentences so the batch stays full while input remains
        for _ in range(batch_size - len(state_batch)):
            horizon_state = next(horizon_iterator, None)
            if horizon_state is None:
                horizon_batch = build_batch_fn(batch_size, data_iterator)
                if len(horizon_batch) == 0:
                    break
                horizon_iterator = iter(horizon_batch)
                horizon_state = next(horizon_iterator, None)
            state_batch.append(horizon_state)
            batch_indices.append(next_input_index)
            next_input_index += 1

    # results were collected in completion order; restore input order
    treebank = utils.unsort(treebank, treebank_indices)
    return treebank
def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """Run parse_sentences with the same arguments while gradient tracking is disabled."""
    with torch.no_grad():
        result = self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice,
                                      keep_state, keep_constituents, keep_scores)
    return result
class EnsembleTrainer(BaseTrainer):
    """
    Stores a list of constituency models, useful for combining their results into one stronger model
    """
    def __init__(self, ensemble, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
        super().__init__(ensemble, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer)

    @staticmethod
    def from_files(args, filenames, foundation_cache=None):
        """
        Build an Ensemble from the model files in filenames and wrap it in a new trainer.

        The ensemble is moved to args['device'] (None if unset) before wrapping.
        """
        ensemble = Ensemble(args, filenames, foundation_cache=foundation_cache)
        ensemble = ensemble.to(args.get('device', None))
        return EnsembleTrainer(ensemble)

    def get_peft_params(self):
        """
        Return one entry per sub-model.

        The entry is that sub-model's PEFT adapter state dict if its args set
        use_peft, and None otherwise.
        """
        params = []
        for model in self.model.models:
            if model.args.get('use_peft', False):
                # imported lazily so peft is only needed when a sub-model actually uses it
                from peft import get_peft_model_state_dict
                params.append(get_peft_model_state_dict(model.bert_model, adapter_name=model.peft_name))
            else:
                params.append(None)
        return params

    @property
    def model_type(self):
        return ModelType.ENSEMBLE

    def log_num_words_known(self, words):
        """Log how many of `words` each sub-model knows, on one line if all sub-models agree."""
        nwk = [m.num_words_known(words) for m in self.model.models]
        if all(x == nwk[0] for x in nwk):
            logger.info("Number of words in the training set known to each sub-model: %d out of %d", nwk[0], len(words))
        else:
            logger.info("Number of words in the training set known to the sub-models:\n %s" % "\n ".join(["%d/%d" % (x, len(words)) for x in nwk]))

    @staticmethod
    def build_optimizer(args, model, first_optimizer):
        """
        Build an optimizer over the ensemble's own parameters only.

        Parameters under the "models." prefix belong to the sub-models and are
        excluded.  This is done by handing the optimizer builder a shallow copy
        of the model whose named_parameters is filtered.
        """
        def fake_named_parameters():
            for n, p in model.named_parameters():
                if not n.startswith("models."):
                    yield n, p
        # TODO: there has to be a cleaner way to do this, like maybe a "keep" callback
        # TODO: if we finetune the underlying models, we will want a series of optimizers
        # so that they can have a different learning rate from the ensemble's fields
        fake_model = copy.copy(model)
        fake_model.named_parameters = fake_named_parameters
        # this resolves to the module-level build_optimizer, not this staticmethod
        optimizer = build_optimizer(args, fake_model, first_optimizer)
        return optimizer

    @staticmethod
    def load_optimizer(model, checkpoint, first_optimizer, filename):
        """
        Build an optimizer using the first sub-model's args.

        If the checkpoint contains optimizer state, that state is restored.
        Raises ValueError (chained) if the saved state cannot be loaded.
        """
        optimizer = EnsembleTrainer.build_optimizer(model.models[0].args, model, first_optimizer)
        if checkpoint.get('optimizer_state_dict', None) is not None:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except ValueError as e:
                raise ValueError("Failed to load optimizer from %s" % filename) from e
        else:
            logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
        return optimizer

    @staticmethod
    def load_scheduler(model, optimizer, checkpoint, first_optimizer):
        """Build a scheduler using the first sub-model's args, restoring saved state if present."""
        scheduler = build_scheduler(model.models[0].args, optimizer, first_optimizer=first_optimizer)
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return scheduler

    @staticmethod
    def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):
        """
        Rebuild an Ensemble from saved params.

        params: either a dict with "base_params" (the ensemble's own state dict)
          and "children_params" (one entry per sub-model), or just the list of
          sub-model params, which gets empty base params
        peft_params, peft_name: optional lists with one entry per sub-model;
          None means None for every sub-model

        Raises ValueError if peft_params or peft_name does not have one entry per sub-model.
        """
        # TODO: no need for the if/else once the models are rebuilt
        children_params = params["children_params"] if isinstance(params, dict) else params
        base_params = params["base_params"] if isinstance(params, dict) else {}
        # TODO: fill in peft_name
        if peft_params is None:
            peft_params = [None] * len(children_params)
        if peft_name is None:
            peft_name = [None] * len(children_params)
        # report the number of sub-models, not len(params): for the dict form that would always be 2
        if len(children_params) != len(peft_params):
            raise ValueError("Model file had params length %d and peft params length %d" % (len(children_params), len(peft_params)))
        if len(children_params) != len(peft_name):
            raise ValueError("Model file had params length %d and peft name length %d" % (len(children_params), len(peft_name)))
        models = [Trainer.model_from_params(model_param, peft_param, args, foundation_cache, peft_name=pname)
                  for model_param, peft_param, pname in zip(children_params, peft_params, peft_name)]
        ensemble = Ensemble(args, models=models)
        # base_params holds only the ensemble-level keys, so the sub-model keys are expected to be missing
        ensemble.load_state_dict(base_params, strict=False)
        ensemble = ensemble.to(args.get('device', None))
        return ensemble
def parse_args(args=None):
    """
    Parse the command line for building an ensemble.

    args: optional list of argument strings; None means argparse reads sys.argv[1:]
    Returns the parsed options as a dict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    utils.add_device_args(parser)
    parser.add_argument('--lang', default='en', help='Language to use')
    parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load")
    parser.add_argument('--save_name', type=str, default=None, required=True, help='Where to save the combined ensemble')
    # pass args through so callers such as main(args) can supply their own argument list
    args = vars(parser.parse_args(args))
    return args
def main(args=None):
    """
    Load the models named on the command line, combine them into an ensemble,
    and save it to --save_name without optimizer state.
    """
    parsed = parse_args(args)
    cache = FoundationCache()
    trainer = EnsembleTrainer.from_files(parsed, parsed['models'], cache)
    trainer.save(parsed['save_name'], save_optimizer=False)
if __name__ == "__main__":
    # build and save the ensemble when run as a script
    main()
================================================
FILE: stanza/models/constituency/error_analysis_in_order.py
================================================
"""
A tool providing an initial set of error analyses for in-order parsing.
Analyzes the first error made by the parser.
TODO: there are more errors to analyze; see the comment in analyze_tree for a case where an attachment error is misidentified as an extra bracket
"""
from enum import Enum
from stanza.models.constituency.dynamic_oracle import advance_past_constituents
from stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme, Finalize
from stanza.models.constituency.transition_sequence import build_sequence
class FirstError(Enum):
    """
    Category of the first point where a predicted in-order transition sequence
    diverges from the gold sequence, as assigned by analyze_tree.

    For the paired categories, NO_CASCADE means the remainders of the two
    sequences are identical after the erroneous region (and any unary
    Open/Close pairs following it).  CASCADE means the remainders still differ.
    """
    # the trees are equal, or no transition differed over the compared length
    NONE = 1
    # the divergence did not match any recognized pattern
    UNKNOWN = 2
    # both sequences Open at the same point with different transitions, same contents
    WRONG_OPEN_LABEL_NO_CASCADE = 3
    WRONG_OPEN_LABEL_CASCADE = 4
    # both Open, contents differ but shift the same number of words
    WRONG_SUBTREE_NO_CASCADE = 5
    WRONG_SUBTREE_CASCADE = 6
    # gold Closes where the prediction Shifts (see check_attachment_error)
    EXTRA_ATTACHMENT = 7
    # the prediction Closes where gold Shifts
    MISSING_ATTACHMENT = 8
    # the prediction Opens where gold Shifts
    EXTRA_BRACKET_NO_CASCADE = 9
    EXTRA_BRACKET_CASCADE = 10
    # the prediction Shifts where gold Opens
    MISSING_BRACKET_NO_CASCADE = 11
    MISSING_BRACKET_CASCADE = 12
def advance_past_unaries(sequence, idx):
    """
    Skip over consecutive (OpenConstituent, CloseConstituent) pairs that directly follow idx.

    Given the function name, each such pair presumably builds a unary node.
    Returns idx advanced by two for every pair skipped.  A pair is only
    considered if both of its transitions lie within the sequence.
    """
    while idx + 2 < len(sequence):
        opens_next = isinstance(sequence[idx + 1], OpenConstituent)
        if not opens_next or not isinstance(sequence[idx + 2], CloseConstituent):
            break
        idx += 2
    return idx
def check_attachment_error(gold_sequence, pred_sequence, idx, error_type):
    """
    Check whether the first mismatch at idx looks like an attachment error.

    Callers pass the sequences so that gold_sequence has a Close at idx where
    pred_sequence has a Shift.  analyze_tree swaps the arguments for
    MISSING_ATTACHMENT.  The transitions that follow are compared with a
    one-position offset to account for that extra Close.

    Returns:
      FirstError.UNKNOWN if the constituent(s) built after idx differ once the offset is applied
      error_type if, after that material, gold_sequence has a Close and
        pred_sequence has two consecutive Closes
      None otherwise

    NOTE(review): pred_sequence[pred_close_idx+1] is accessed without a bounds check;
      confirm the sequence always continues past that point.
    """
    # this will find the Close that closes the constituent that
    # was just closed in the gold sequence
    # hopefully we will have built the same constituent(s)
    # that were built after the gold sequence closed
    pred_close_idx = advance_past_constituents(pred_sequence, idx)
    # gold_sequence is one position ahead because of its extra Close at idx
    gold_close_idx = pred_close_idx + 1
    #gold_close_idx = find_in_order_constituent_end(gold_sequence, idx+1) # +1 represents, start counting from the Shift
    #pred_close_idx = find_in_order_constituent_end(pred_sequence, idx)
    if gold_sequence[idx+1:gold_close_idx] != pred_sequence[idx:pred_close_idx]:
        return FirstError.UNKNOWN
    if (isinstance(gold_sequence[gold_close_idx], CloseConstituent) and
        isinstance(pred_sequence[pred_close_idx], CloseConstituent) and
        isinstance(pred_sequence[pred_close_idx+1], CloseConstituent)):
        #print(gold_sequence)
        #print(gold_close_idx)
        #print(pred_sequence)
        #print(pred_close_idx)
        #print("{:P}".format(gold_tree))
        #print("{:P}".format(pred_tree))
        #print("=================")
        return error_type
    return None
def analyze_tree(gold_tree, pred_tree):
    """
    Classify the first divergence between the in-order transition sequences of two trees.

    Both trees are converted with build_sequence(..., TransitionScheme.IN_ORDER)
    and compared transition by transition.  The transitions at the first
    mismatch, and the material around them, are checked against a series of
    patterns in order; the first pattern that fits determines the result.

    Returns a FirstError:
      NONE    - the trees are equal, or no transition differs over the compared length
      UNKNOWN - the divergence did not match any recognized pattern
      otherwise one of the attachment / open label / subtree / bracket categories
    """
    if gold_tree == pred_tree:
        return FirstError.NONE
    gold_sequence = build_sequence(gold_tree, TransitionScheme.IN_ORDER)
    pred_sequence = build_sequence(pred_tree, TransitionScheme.IN_ORDER)
    # locate the first index where the two sequences disagree
    for idx, (gold_trans, pred_trans) in enumerate(zip(gold_sequence, pred_sequence)):
        if gold_trans != pred_trans:
            break
    else:
        # guess only the tags were different?
        # NOTE(review): zip stops at the shorter sequence, so a pure length difference also ends up here
        return FirstError.NONE
    # gold Closes where the prediction Shifts, and gold Shifts right after its Close
    if isinstance(gold_trans, CloseConstituent) and isinstance(pred_trans, Shift) and isinstance(gold_sequence[idx + 1], Shift):
        # perhaps this is an attachment error
        # we can see if the exact same sequence of moved constituent was built
        error = check_attachment_error(gold_sequence, pred_sequence, idx, FirstError.EXTRA_ATTACHMENT)
        if error is not None:
            return error
    # mirror image: the prediction Closes where gold Shifts
    if isinstance(pred_trans, CloseConstituent) and isinstance(gold_trans, Shift) and isinstance(pred_sequence[idx + 1], Shift):
        # perhaps this is an attachment error
        # we can see if the exact same sequence of moved constituent was built
        error = check_attachment_error(pred_sequence, gold_sequence, idx, FirstError.MISSING_ATTACHMENT)
        if error is not None:
            return error
    # both sequences Open here, but with different Open transitions
    if isinstance(gold_trans, OpenConstituent) and isinstance(pred_trans, OpenConstituent):
        gold_close_idx = advance_past_constituents(gold_sequence, idx+1)
        gold_unary_idx = advance_past_unaries(gold_sequence, gold_close_idx)
        pred_close_idx = advance_past_constituents(pred_sequence, idx+1)
        pred_unary_idx = advance_past_unaries(pred_sequence, pred_close_idx)
        if gold_sequence[idx+1:gold_close_idx] != pred_sequence[idx+1:pred_close_idx]:
            # maybe the internal structure is the same?
            # actually, if the number of shifts inside is the same,
            # then the words shifted were the same,
            # so the internal structure is different but the parser
            # is getting back on track after closing
            if (sum(isinstance(gt, Shift) for gt in gold_sequence[idx+1:gold_close_idx]) ==
                sum(isinstance(pt, Shift) for pt in pred_sequence[idx+1:pred_close_idx])):
                if gold_sequence[gold_unary_idx:] == pred_sequence[pred_unary_idx:]:
                    return FirstError.WRONG_SUBTREE_NO_CASCADE
                else:
                    return FirstError.WRONG_SUBTREE_CASCADE
            return FirstError.UNKNOWN
        # at this point, everything is the same aside from the open being a different label
        if gold_sequence[gold_unary_idx:] == pred_sequence[pred_unary_idx:]:
            return FirstError.WRONG_OPEN_LABEL_NO_CASCADE
        else:
            return FirstError.WRONG_OPEN_LABEL_CASCADE
    # the prediction Opens where gold Shifts
    if isinstance(gold_trans, Shift) and isinstance(pred_trans, OpenConstituent):
        # This could be a case of an extra bracket inserted into the tree
        # We will search for the end of the new bracket, then check if
        # all the children were properly constructed the way the gold sequence wanted to,
        # aside from the extra bracket
        # TODO: this is also capturing what are effectively attachment
        # errors in the case of nested nodes (S over S) where a node
        # at the start should have been connected to the below node
        # gold:
        # (ROOT
        # (S
        # (S
        # (`` ``)
        # (NP (PRP$ Our) (NN balance) (NNS sheets))
        # (VP
        # (VBP look)
        # (SBAR
        # (IN like)
        # (S
        # (NP (PRP they))
        # (VP
        # (VBD came)
        # (PP
        # (IN from)
        # (NP
        # (NP (NNP Alice) (POS 's))
        # (NN wonderland)))))))
        # (, ,)
        # ('' ''))
        # (NP (NNP Mr.) (NNP Fromstein))
        # (VP (VBD said))
        # (. .)))
        #
        # pred:
        # (ROOT
        # (S
        # (`` ``)
        # (S
        # (NP (PRP$ Our) (NN balance) (NNS sheets))
        # (VP
        # (VBP look)
        # (SBAR
        # (IN like)
        # (S
        # (NP (PRP they))
        # (VP
        # (VBD came)
        # (PP
        # (IN from)
        # (NP
        # (NP (NNP Alice) (POS 's))
        # (NN wonderland))))))))
        # (, ,)
        # ('' '')
        # (NP (NNP Mr.) (NNP Fromstein))
        # (VP (VBD said))
        # (. .)))
        pred_close_idx = advance_past_constituents(pred_sequence, idx+1)
        pred_unary_idx = advance_past_unaries(pred_sequence, pred_close_idx + 1)
        # the gold transitions match the extra bracket's contents, offset by the extra Open
        if gold_sequence[idx:pred_close_idx-1] == pred_sequence[idx+1:pred_close_idx]:
            #print(gold_sequence)
            #print(pred_sequence)
            #print(idx, pred_close_idx)
            #print("{:P}".format(gold_tree))
            #print("{:P}".format(pred_tree))
            #print("=================")
            gold_unary_idx = advance_past_unaries(gold_sequence, pred_close_idx - 1)
            if pred_sequence[pred_unary_idx:] == gold_sequence[gold_unary_idx:]:
                return FirstError.EXTRA_BRACKET_NO_CASCADE
            else:
                return FirstError.EXTRA_BRACKET_CASCADE
    # the prediction Shifts where gold Opens
    if isinstance(pred_trans, Shift) and isinstance(gold_trans, OpenConstituent):
        # presumably this has attachment errors as well, similarly to EXTRA_BRACKET
        gold_close_idx = advance_past_constituents(gold_sequence, idx+1)
        gold_unary_idx = advance_past_unaries(gold_sequence, gold_close_idx + 1)
        # the predicted transitions match the missing bracket's contents, offset by the missing Open
        if pred_sequence[idx:gold_close_idx-1] == gold_sequence[idx+1:gold_close_idx]:
            #print(gold_sequence)
            #print(pred_sequence)
            #print(idx, gold_close_idx)
            #print("{:P}".format(gold_tree))
            #print("{:P}".format(pred_tree))
            #print("=================")
            pred_unary_idx = advance_past_unaries(pred_sequence, gold_close_idx - 1)
            if pred_sequence[pred_unary_idx:] == gold_sequence[gold_unary_idx:]:
                return FirstError.MISSING_BRACKET_NO_CASCADE
            else:
                return FirstError.MISSING_BRACKET_CASCADE
    return FirstError.UNKNOWN
================================================
FILE: stanza/models/constituency/evaluate_treebanks.py
================================================
"""
Read multiple treebanks, score the results.
Reports the k-best score if multiple predicted treebanks are given.
"""
import argparse
from stanza.models.constituency import tree_reader
from stanza.server.parser_eval import EvaluateParser, ParseResult
def main(args=None):
    """
    Score one or more predicted treebanks against a gold treebank.

    args: optional list of argument strings; None means argparse reads sys.argv[1:]

    If more than one predicted treebank is given, the evaluation is k-best,
    with the first predicted treebank treated as the canonical prediction.
    """
    parser = argparse.ArgumentParser(description='Get scores for one or more treebanks against the gold')
    parser.add_argument('gold', type=str, help='Which file to load as the gold trees')
    parser.add_argument('pred', type=str, nargs='+', help='Which file(s) are the predictions. If more than one is given, the evaluation will be "k-best" with the first prediction treated as the canonical')
    args = parser.parse_args(args)

    print("Loading gold treebank: " + args.gold)
    gold = tree_reader.read_treebank(args.gold)
    # args.pred is a list (nargs='+'); concatenating it directly to a str raises TypeError
    print("Loading predicted treebanks: " + ", ".join(args.pred))
    pred = [tree_reader.read_treebank(x) for x in args.pred]

    # one result per sentence: the gold tree plus the matching tree from each predicted treebank
    full_results = [ParseResult(parses[0], [*parses[1:]])
                    for parses in zip(gold, *pred)]
    kbest = len(pred) if len(pred) > 1 else None
    with EvaluateParser(kbest=kbest) as evaluator:
        evaluator.process(full_results)
if __name__ == '__main__':
    # evaluate the treebanks given on the command line
    main()
================================================
FILE: stanza/models/constituency/in_order_compound_oracle.py
================================================
from enum import Enum
from stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, DynamicOracle
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, CompoundUnary, Finalize
def fix_missing_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a skipped CompoundUnary.

    Applies when gold expected a CompoundUnary after a Shift, but the parser
    predicted the transition that comes right after it.  The unary is dropped
    from the gold sequence and parsing continues as if nothing happened.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(gold_transition, CompoundUnary) or pred_transition != gold_sequence[gold_index + 1]:
        return None
    # this can happen if the entire tree is a single word,
    # but it can't be fixed if it means the parser missed the ROOT transition
    if isinstance(pred_transition, Finalize):
        return None
    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:]
def fix_wrong_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Gold and prediction are both CompoundUnary transitions but different ones.

    Nothing can be recovered, so the predicted unary replaces the gold one and parsing continues.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    both_unary = isinstance(gold_transition, CompoundUnary) and isinstance(pred_transition, CompoundUnary)
    if not both_unary:
        return None
    assert gold_transition != pred_transition
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:]
def fix_spurious_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    The parser predicted a CompoundUnary where gold had some other transition.

    The extra unary is inserted before the gold transition and parsing continues.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    if isinstance(gold_transition, CompoundUnary) or not isinstance(pred_transition, CompoundUnary):
        return None
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:]
def fix_open_shift_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a missed Open constituent where we predicted a Shift and the next gold transition was a Shift.

    With this transition scheme the transition after an Open MUST be a Shift.
    The missed Open cannot be recovered.  Both it and the Close that would have
    ended it are removed from the gold sequence.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(gold_transition, OpenConstituent) or not isinstance(pred_transition, Shift):
        return None
    assert isinstance(gold_sequence[gold_index+1], Shift)
    # index of the Close belonging to the missed Open
    close_index = advance_past_constituents(gold_sequence, gold_index+1)
    assert close_index is not None
    inner = gold_sequence[gold_index+1:close_index]
    return gold_sequence[:gold_index] + inner + gold_sequence[close_index+1:]
def fix_open_open_two_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Recognize a wrong Open where the gold constituent has only two subtrees (T1 O_x T2 C).

    No repair is possible, so instead of a new sequence this returns the tuple
    (RepairType.OPEN_OPEN_TWO_SUBTREES_ERROR, None).  Presumably the caller
    records the repair type and keeps the sequence as-is; confirm in DynamicOracle.
    Returns None if the transitions match, are not both Opens, or more subtrees follow.
    """
    if gold_transition == pred_transition:
        return None
    both_open = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)
    if not both_open:
        return None
    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    # a Shift here means another subtree follows: that is the multiple subtree version of this error
    if isinstance(gold_sequence[block_end], Shift):
        return None
    return RepairType.OPEN_OPEN_TWO_SUBTREES_ERROR, None
def fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three):
    """
    Shared repair for a wrong Open where the gold constituent has three or more subtrees.

    With the in-order scheme the gold sequence looks like
      T1 O_x T2 T3 ... C
    and the parser predicted O_y instead of O_x.  The repair closes O_y right
    after T2 and then reopens O_x:
      T1 O_y T2 C O_x T3 ... C

    exactly_three: if True, skip (return None) when the transition after T3 is a
      Shift, i.e. more subtrees follow.  If False, skip when the transition
      after T3 is a Close, i.e. exactly three subtrees.
    root_labels is accepted for signature compatibility but is not used here.
    Returns the repaired gold sequence, or None if this repair does not apply.
    """
    if gold_transition == pred_transition:
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None
    # index just past T2 (judging by how the result is sliced below)
    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[block_end], Shift):
        # no further subtree follows T2: this is the two subtrees version of this error,
        # which fix_open_open_two_subtrees_error handles instead
        return None
    # index just past T3
    next_block_end = find_in_order_constituent_end(gold_sequence, block_end+1)
    if exactly_three and isinstance(gold_sequence[next_block_end], Shift):
        # more than three subtrees: leave this to the many subtrees version
        # (for exactly three subtrees, putting back the missing open transition
        # leaves no recall error, only a precision error)
        return None
    elif not exactly_three and isinstance(gold_sequence[next_block_end], CloseConstituent):
        # exactly three subtrees: leave this to the three subtrees version
        # (the many subtrees case is ambiguous, but we can still try this fix)
        return None
    # at this point, we build a new sequence with the predicted constituent closed after T2
    # and the original gold Open reopened after it
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:block_end] + [CloseConstituent(), gold_transition] + gold_sequence[block_end:]
def fix_open_open_three_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Repair a wrong Open whose gold constituent has exactly three subtrees (model and state are unused)."""
    return fix_open_open_error(gold_transition=gold_transition, pred_transition=pred_transition,
                               gold_sequence=gold_sequence, gold_index=gold_index,
                               root_labels=root_labels, exactly_three=True)
def fix_open_open_many_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Repair a wrong Open whose gold constituent has more than three subtrees (model and state are unused)."""
    return fix_open_open_error(gold_transition=gold_transition, pred_transition=pred_transition,
                               gold_sequence=gold_sequence, gold_index=gold_index,
                               root_labels=root_labels, exactly_three=False)
def fix_open_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Find the closed bracket and reopen it.

    Applies when gold wanted an Open but the parser predicted a Close.  The
    missed Open must be forgotten because it cannot be reopened, so the Close
    that would have ended it is dropped as well.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not (isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, CloseConstituent)):
        return None
    # the Open matching the constituent the parser just closed
    open_idx = find_previous_open(gold_sequence, gold_index)
    # if the Close is legal this can't happen, but a unit test that skips legality checks might hit it
    if open_idx is None:
        return None
    # the Close which would have closed the missed Open
    close_idx = advance_past_constituents(gold_sequence, gold_index+1)
    reopened = [pred_transition, gold_sequence[open_idx]]
    return gold_sequence[:gold_index] + reopened + gold_sequence[gold_index+1:close_idx] + gold_sequence[close_idx+1:]
def fix_shift_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Find the closed bracket and reopen it.

    Applies when gold wanted a Shift but the parser predicted a Close.  The
    predicted Close is kept and immediately followed by the Open it closed.
    Not applied at the very start or directly after an Open.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(gold_transition, Shift) or not isinstance(pred_transition, CloseConstituent):
        return None
    at_start = gold_index == 0
    if at_start or isinstance(gold_sequence[gold_index - 1], OpenConstituent):
        return None
    open_idx = find_previous_open(gold_sequence, gold_index)
    assert open_idx is not None
    reopened = [pred_transition, gold_sequence[open_idx]]
    return gold_sequence[:gold_index] + reopened + gold_sequence[gold_index:]
def fix_shift_open_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    The parser opened a new constituent where gold shifted.

    If the next gold constituent is followed by a Close, the spurious
    constituent can only end there, so it is closed at that point.  If it is
    followed by a Shift, the end is ambiguous and no repair is made.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(gold_transition, Shift) or not isinstance(pred_transition, OpenConstituent):
        return None
    bracket_end = find_in_order_constituent_end(gold_sequence, gold_index)
    assert bracket_end is not None
    following = gold_sequence[bracket_end]
    if isinstance(following, Shift):
        # multiple possible places to end the wrong constituent
        return None
    assert isinstance(following, CloseConstituent)
    return (gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:bracket_end]
            + [CloseConstituent()] + gold_sequence[bracket_end:])
def fix_close_shift_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Gold wanted Close then Shift, but the parser only Shifted.

    The constituent that should have been closed stays open.  If the next gold
    constituent is followed by a Close, it is closed there unambiguously.
    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent) or not isinstance(pred_transition, Shift):
        return None
    if not isinstance(gold_sequence[gold_index+1], Shift):
        return None
    bracket_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    assert bracket_end is not None
    following = gold_sequence[bracket_end]
    if isinstance(following, Shift):
        # multiple possible places to end the wrong constituent
        return None
    assert isinstance(following, CloseConstituent)
    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:bracket_end] + [CloseConstituent()] + gold_sequence[bracket_end:]
class RepairType(Enum):
    """
    Keep track of which repair is used, if any, on an incorrect transition

    Each member stores the function that attempts that repair (fn, possibly None),
    a `correct` flag exposed through is_correct, and a `debug` flag.
    Values are assigned 1, 2, 3, ... in definition order, so member order matters.

    Effects of different repair types:
    (the two score columns are unlabeled; presumably dev and test scores - TODO confirm)
      no oracle: 0.9251 0.9226
      +missing_unary: 0.9246 0.9214
      +wrong_unary: 0.9236 0.9213
      +spurious_unary: 0.9247 0.9229
      +open_shift_error: 0.9258 0.9226
      +open_open_two_subtrees: 0.9256 0.9215 # nothing changes with this one...
      +open_open_three_subtrees: 0.9256 0.9226
      +open_open_many_subtrees: 0.9257 0.9234
      +shift_close: 0.9267 0.9250
      +shift_open: 0.9273 0.9247
      +close_shift: 0.9266 0.9229
      +open_close: 0.9267 0.9256
    """
    def __new__(cls, fn, correct=False, debug=False):
        """
        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error
        """
        # number of members defined so far; +1 makes the values start at 1
        value = len(cls.__members__)
        obj = object.__new__(cls)
        obj._value_ = value + 1
        obj.fn = fn
        obj.correct = correct
        obj.debug = debug
        return obj
    @property
    def is_correct(self):
        return self.correct
    # The correct sequence went Shift - Unary - Stuff
    # but the CompoundUnary was missed and Stuff predicted
    # so now we just proceed as if nothing happened
    # note that CompoundUnary happens immediately after a Shift
    # complicated nodes are created with single Open transitions
    MISSING_UNARY_ERROR = (fix_missing_unary_error,)
    # Predicted a wrong CompoundUnary. No way to fix this, so just keep going
    WRONG_UNARY_ERROR = (fix_wrong_unary_error,)
    # The correct sequence went Shift - Stuff
    # but instead we predicted a CompoundUnary
    # again, we just keep going
    SPURIOUS_UNARY_ERROR = (fix_spurious_unary_error,)
    # Were supposed to open a new constituent,
    # but instead shifted an item onto the stack
    #
    # The missed Open cannot be recovered
    #
    # One could ask, is it possible to open a bigger constituent later,
    # but if the constituent patterns go
    # X (good open) Y (missed open) Z
    # when we eventually close Y and Z, because of the missed Open,
    # it is guaranteed to capture X as well
    # since it will grab constituents until one left of the previous Open before Y
    #
    # Therefore, in this case, we must simply forget about this Open (recall error)
    OPEN_SHIFT_ERROR = (fix_open_shift_error,)
    # With this transition scheme, it is not possible to fix the following pattern:
    # T1 O_x T2 C -> T1 O_y T2 C
    # seeing as how there are no unary transitions
    # so whatever precision & recall errors are caused by substituting O_x -> O_y
    # (which could include multiple transitions)
    # those errors are unfixable in any way
    OPEN_OPEN_TWO_SUBTREES_ERROR = (fix_open_open_two_subtrees_error,)
    # With this transition scheme, a three subtree branch with a wrong Open
    # has a non-ambiguous fix
    # T1 O_x T2 T3 C -> T1 O_y T2 T3 C
    # this can become
    # T1 O_y T2 C O_x T3 C
    # now there are precision errors from the incorrectly added transition(s),
    # but the correctly replaced transitions are unambiguous
    OPEN_OPEN_THREE_SUBTREES_ERROR = (fix_open_open_three_subtrees_error,)
    # We were supposed to shift a new item onto the stack,
    # but instead we closed the previous constituent
    # This causes a precision error, but we can avoid the recall error
    # by immediately reopening the closed constituent.
    SHIFT_CLOSE_ERROR = (fix_shift_close_error,)
    # We opened a new constituent instead of shifting
    # In the event that the next constituent ends with a close,
    # rather than building another new constituent,
    # then there is no ambiguity
    SHIFT_OPEN_UNAMBIGUOUS_ERROR = (fix_shift_open_unambiguous_error,)
    # Suppose we were supposed to Close, then Shift
    # but instead we just did a Shift
    # Similar to shift_open_unambiguous, we now have an opened
    # constituent which shouldn't be there
    # We can scroll past the next constituent created to see
    # if the outer constituents close at that point
    # If so, we can close this constituent as well in an unambiguous manner
    # TODO: analyze the case where we were supposed to Close, Open
    # but instead did a Shift
    CLOSE_SHIFT_UNAMBIGUOUS_ERROR = (fix_close_shift_unambiguous_error,)
    # Supposed to open a new constituent,
    # instead closed an existing constituent
    #
    # X (good open) Y (open -> close) Z
    #
    # the constituent that should contain Y, Z is unfortunately lost
    # since now the stack has
    #
    # XY ...
    #
    # furthermore, there is now a precision error for the extra XY
    # constituent that should not exist
    # however, what we can do to minimize further errors is
    # to at least reopen the label between X and Y
    OPEN_CLOSE_ERROR = (fix_open_close_error,)
    # this is ambiguous, but we can still try the same fix as three_subtrees (see above)
    OPEN_OPEN_MANY_SUBTREES_ERROR = (fix_open_open_many_subtrees_error,)
    # no repair function, correct=True; presumably used when the prediction matched gold
    CORRECT = (None, True)
    # a non-tuple value reaches __new__ as fn=None: no repair function, correct=False
    UNKNOWN = None
class InOrderCompoundOracle(DynamicOracle):
    """
    Dynamic oracle which hands this module's RepairType enum (the table of
    repair functions) to the generic DynamicOracle machinery.

    All constructor arguments are forwarded unchanged; RepairType is
    inserted as the third positional argument.
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        repair_enum = RepairType
        super(InOrderCompoundOracle, self).__init__(root_labels,
                                                    oracle_level,
                                                    repair_enum,
                                                    additional_oracle_levels,
                                                    deactivated_oracle_levels)
================================================
FILE: stanza/models/constituency/in_order_oracle.py
================================================
from enum import Enum
from stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, score_candidates, DynamicOracle, RepairEnum
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent
def fix_wrong_open_root_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open mismatch where the gold Open carries a ROOT label.

    The wrongly predicted Open is immediately followed by a Close (a unary),
    after which the gold sequence resumes at gold_index, so the ROOT Open
    can still be produced.

    Returns the repaired transition list, or None if the repair does not apply.
    model and state are unused.
    """
    if gold_transition == pred_transition:
        return None
    both_open = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)
    if not both_open or gold_transition.top_label not in root_labels:
        return None
    prefix = gold_sequence[:gold_index]
    suffix = gold_sequence[gold_index:]
    return prefix + [pred_transition, CloseConstituent()] + suffix
def fix_wrong_open_unary_chain(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a wrong Open/Open in a unary chain by removing the skipped unary transitions.

    Only applies if the wrongly predicted Open appears higher up in the
    same unary chain, eg the gold sequence is
        ... NT_x Close NT_y Close NT_z ...
    and the model predicted NT_y or NT_z.  The repair resumes at the
    predicted Open, discarding the unaries which were skipped.

    Returns the repaired transition list, or None if the repair does not apply.
    """
    # useful to have this check here in case the call is made independently in a unit test
    if gold_transition == pred_transition:
        return None
    if not (isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)):
        return None
    seq_len = len(gold_sequence)
    # If the gold Open is a unary, the transition after it is a Close
    close_idx = gold_index + 1
    while (close_idx + 1 < seq_len
           and isinstance(gold_sequence[close_idx], CloseConstituent)
           and isinstance(gold_sequence[close_idx + 1], OpenConstituent)):
        open_idx = close_idx + 1      # the next Open in the chain
        if gold_sequence[open_idx] == pred_transition:
            return gold_sequence[:gold_index] + gold_sequence[open_idx:]
        close_idx = open_idx + 1      # the Close which would follow that Open
    return None
def fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, more_than_two):
    """
    Shared Open/Open repair for a gold Open which is not a unary.

    Gold:    T1 O_x T2 ... C
    Repair:  T1 O_y T2 C O_x ... C

    The predicted Open keeps only the first block after it (T2) and is
    closed, then the gold Open is performed over the resulting constituent.

    more_than_two: False means only apply when T2 is the last block of the
      gold constituent (the block ends at a Close); True means only apply
      when further blocks follow T2 (the block ends at a Shift).

    Returns the repaired transition list, or None if the repair does not apply.
    """
    if gold_transition == pred_transition:
        return None
    if not isinstance(gold_transition, OpenConstituent) or not isinstance(pred_transition, OpenConstituent):
        return None
    following = gold_sequence[gold_index + 1]
    if isinstance(following, CloseConstituent):
        # if Close, the gold was a unary
        return None
    assert not isinstance(following, OpenConstituent)
    assert isinstance(following, Shift)
    block_end = find_in_order_constituent_end(gold_sequence, gold_index + 1)
    assert block_end is not None
    ends_in_close = isinstance(gold_sequence[block_end], CloseConstituent)
    ends_in_shift = isinstance(gold_sequence[block_end], Shift)
    if (more_than_two and ends_in_close) or (not more_than_two and ends_in_shift):
        return None
    first_block = gold_sequence[gold_index + 1:block_end]
    return (gold_sequence[:gold_index] + [pred_transition] + first_block
            + [CloseConstituent(), gold_transition] + gold_sequence[block_end:])
def fix_wrong_open_two_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Open/Open repair restricted to a gold constituent with exactly two subtrees.  model and state are unused."""
    return fix_wrong_open_subtrees(gold_transition=gold_transition,
                                   pred_transition=pred_transition,
                                   gold_sequence=gold_sequence,
                                   gold_index=gold_index,
                                   root_labels=root_labels,
                                   more_than_two=False)
def fix_wrong_open_multiple_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Open/Open repair restricted to a gold constituent with more than two subtrees.  model and state are unused."""
    return fix_wrong_open_subtrees(gold_transition=gold_transition,
                                   pred_transition=pred_transition,
                                   gold_sequence=gold_sequence,
                                   gold_index=gold_index,
                                   root_labels=root_labels,
                                   more_than_two=True)
def advance_past_unaries(gold_sequence, cur_index):
    """
    Skip consecutive unary (Open, Close) pairs starting at cur_index.

    A pair is only skipped if at least one more transition follows it,
    so the index is never advanced past the end of gold_sequence.

    Returns the index of the first transition which was not skipped.
    """
    limit = len(gold_sequence) - 2
    while cur_index < limit:
        if not isinstance(gold_sequence[cur_index], OpenConstituent):
            break
        if not isinstance(gold_sequence[cur_index + 1], CloseConstituent):
            break
        cur_index += 2
    return cur_index
def fix_wrong_open_stuff_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a wrong open/open when there is an intervening constituent and then the guessed NT

    This happens when the correct pattern is
      stuff_1 NT_X stuff_2 close NT_Y ...
    and instead of guessing the gold transition NT_X,
    the prediction was NT_Y

    NT_Y may also be found after one or more unary transitions on X.
    The repair uses NT_Y in place of NT_X over stuff_2, drops NT_X's
    Close (and any skipped unaries), and continues after NT_Y.

    Returns the repaired transition list, or None if the pattern does not
    apply.  This includes the case where nothing follows the Close of NT_X
    (eg NT_X is the outermost constituent).
    model and state are unused.
    """
    if gold_transition == pred_transition:
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    # TODO: Here we could advance past unary transitions while
    # watching for hitting pred_transition.  However, that is an open
    # question... is it better to try to keep such an Open as part of
    # the sequence, or is it better to skip them and attach the inner
    # nodes to the upper level
    stuff_start = gold_index + 1
    if not isinstance(gold_sequence[stuff_start], Shift):
        return None
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    if stuff_end is None:
        return None
    # at this point, stuff_end points to the Close which occurred after stuff_2
    # also, stuff_start points to the first transition which makes stuff_2, the Shift
    cur_index = stuff_end + 1
    # bounds check: if NT_X's Close is the final transition, there is no
    # following Open to look at (previously this raised IndexError)
    while cur_index < len(gold_sequence) and isinstance(gold_sequence[cur_index], OpenConstituent):
        if gold_sequence[cur_index] == pred_transition:
            return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index+1:]
        # this was an OpenConstituent, but not the OpenConstituent we guessed
        # maybe there's a unary transition which lets us try again
        if cur_index + 2 < len(gold_sequence) and isinstance(gold_sequence[cur_index + 1], CloseConstituent):
            cur_index = cur_index + 2
        else:
            break

    # oh well, none of this worked
    return None
def fix_wrong_open_general(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a general wrong Open/Open by accepting the predicted Open and continuing.

    The predicted Open simply takes the place of the gold Open at gold_index.
    A couple other Open/Open patterns have already been carved out by other repairs.

    TODO: negative checks for the previous patterns, in case we turn those off

    Returns the repaired transition list, or None if the repair does not apply.
    model and state are unused.
    """
    if gold_transition == pred_transition:
        return None
    if not (isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)):
        return None
    # Swapping a ROOT Open for a non-ROOT Open would produce an illegal
    # transition sequence; the ROOT case has its own repair anyway
    if gold_transition.top_label in root_labels:
        return None
    head = gold_sequence[:gold_index]
    tail = gold_sequence[gold_index + 1:]
    return head + [pred_transition] + tail
def fix_missed_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix one or more missed unary transitions followed by an otherwise correct transition.

    The unary (Open, Close) pairs starting at gold_index are skipped.  If
    the transition reached matches the prediction, the repair drops those
    unaries and resumes there.

    Returns the repaired transition list, or None if the repair does not apply.
    """
    if gold_transition == pred_transition:
        return None
    resume_index = advance_past_unaries(gold_sequence, gold_index)
    if gold_sequence[resume_index] == pred_transition:
        return gold_sequence[:gold_index] + gold_sequence[resume_index:]
    return None
def fix_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix an Open replaced with a Shift

    Suppose we were supposed to guess NT_X and instead did S

    We derive the repair as follows.

    For simplicity, assume the open is not a unary for now

    Since we know an Open was legal, there must be stuff
      stuff NT_X
    Shift is also legal, so there must be other stuff and a previous Open
      stuff_1 NT_Y stuff_2 NT_X
    After the NT_X which we missed, there was a bunch of stuff and a close for NT_X
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C
    There could be more stuff here which can be saved...
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C stuff_4 C
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C C

    The repair drops NT_X, its Close, and any unary transitions directly
    before NT_X or directly after its Close.  stuff_3 is left attached at
    the NT_Y level.

    Returns the repaired transition list, or None if the repair does not apply.
    model and state are unused.
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    cur_index = gold_index
    cur_index = advance_past_unaries(gold_sequence, cur_index)
    if not isinstance(gold_sequence[cur_index], OpenConstituent):
        return None
    if gold_sequence[cur_index].top_label in root_labels:
        return None
    # cur_index now points to the NT_X we missed (not counting unaries)

    stuff_start = cur_index + 1
    # can't be a Close, since we just went past an Open and checked for unaries
    # can't be an Open, since two Open in a row is illegal
    assert isinstance(gold_sequence[stuff_start], Shift)
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    if stuff_end is None:
        # advance_past_constituents can return None (it is checked for None
        # elsewhere in this file); bail out rather than crash on None + 1
        return None
    # stuff_end is now the Close which ends NT_X
    cur_index = stuff_end + 1
    if cur_index >= len(gold_sequence):
        return None
    if isinstance(gold_sequence[cur_index], OpenConstituent):
        # unaries on X cannot be recovered, since X is no longer built
        cur_index = advance_past_unaries(gold_sequence, cur_index)
        if cur_index >= len(gold_sequence):
            return None
        if isinstance(gold_sequence[cur_index], OpenConstituent):
            # an Open here signifies that there was a bracket containing X underneath Y
            # TODO: perhaps try to salvage something out of that situation?
            return None
    # the repair starts with the sequence up through the error,
    # then stuff_3, which includes the error
    # skip the Close for the missed NT_X
    # then finish the sequence with any potential stuff_4, the next Close, and everything else
    repair = gold_sequence[:gold_index] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]
    return repair
def fix_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix an Open replaced with a Close

    Call the Open NT_X

    Open legal, so there must be stuff:
      stuff NT_X
    Close legal, so there must be something to close:
      stuff_1 NT_Y stuff_2 NT_X

    The incorrect close makes the following brackets:
      (Y stuff_1 stuff_2)
    We were supposed to build
      (Y stuff_1 (X stuff_2 ...) (possibly more stuff))
    The simplest fix here is to reopen Y at this point.

    One issue might be if there is another bracket which encloses X underneath Y
    So, for example, the tree was supposed to be
      (Y stuff_1 (Z (X stuff_2 stuff_3) stuff_4))
    The pattern for this case is
      stuff_1 NT_Y stuff_2 NT_X stuff_3 close NT_Z stuff_4 close close
    That case is currently rejected.

    Returns the repaired transition list, or None if the repair does not apply.
    model and state are unused.
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, CloseConstituent):
        return None
    cur_index = advance_past_unaries(gold_sequence, gold_index)
    if cur_index >= len(gold_sequence):
        return None
    if not isinstance(gold_sequence[cur_index], OpenConstituent):
        return None
    if gold_sequence[cur_index].top_label in root_labels:
        return None

    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    # prev_open is now NT_Y from above

    stuff_start = cur_index + 1
    assert isinstance(gold_sequence[stuff_start], Shift)
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    if stuff_end is None:
        # advance_past_constituents can return None (it is checked for None
        # elsewhere in this file); bail out rather than crash on None + 1
        return None
    # stuff_end is now the Close which ends NT_X
    # stuff_start:stuff_end is the stuff_3 block above
    cur_index = stuff_end + 1
    if cur_index >= len(gold_sequence):
        return None
    # if there are unary transitions here, we want to skip those.
    # those are unary transitions on X and cannot be recovered, since X is gone
    cur_index = advance_past_unaries(gold_sequence, cur_index)
    # now there is a certain failure case which has to be accounted for.
    # specifically, if there is a new non-terminal which opens
    # immediately after X closes, it is encompassing X in a way that
    # cannot be recovered now that part of X is stuck under Y.
    # The two choices at this point would be to eliminate the new
    # transition or just reject the tree from the repair
    # For now, we reject the tree
    if isinstance(gold_sequence[cur_index], OpenConstituent):
        return None

    repair = gold_sequence[:gold_index] + [pred_transition, prev_open] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]
    return repair
def fix_shift_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift which was replaced with a Close.

    This error occurs in the pattern
      stuff_1 NT_X stuff... shift
    where the model closed NT_X instead of shifting.  The repair keeps the
    erroneous Close, immediately reopens NT_X, and resumes at the Shift.

    The gold transition may also be a run of unaries on the previous
    constituent before the Shift; those unaries are skipped.

    Returns the repaired transition list, or None if the repair does not apply.
    """
    if not isinstance(pred_transition, CloseConstituent):
        return None
    resume_index = gold_index
    if isinstance(gold_transition, OpenConstituent):
        resume_index = advance_past_unaries(gold_sequence, resume_index)
    if not isinstance(gold_sequence[resume_index], Shift):
        return None

    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    reopened = gold_sequence[prev_open_index]   # NT_X from above
    return gold_sequence[:gold_index] + [pred_transition, reopened] + gold_sequence[resume_index:]
def fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):
    """
    Repair a Close / (unaries) / Open_Z / Shift pattern where the model shifted instead of closing.

    Open_Z must be a *different* label from the constituent being closed.
    The repair builds the block(s) following Open_Z first, and only then
    does the skipped Close, the unaries, and Open_Z:
      gold:   ... Close (unaries) Open_Z blocks... rest
      repair: ... block(s) Close (unaries) Open_Z rest

    ambiguous: False means only apply when exactly one block follows
      Open_Z before Z closes; True means only apply when more blocks follow.
    late: if True, move every block after Open_Z rather than just the first.

    Returns the repaired transition list, or None if the repair does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open a *different* constituent
    # from the one we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] == prev_open:
        return None

    # check that the following stuff is a single bracket, not multiple brackets
    end_index = find_in_order_constituent_end(gold_sequence, open_index+1)
    if end_index is None:
        # same guard as used for this helper in fix_close_shift_shift
        return None
    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):
        return None
    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):
        return None

    # if closing at the end of the next blocks,
    # instead of closing after the first block ends,
    # we go to the end of the last block
    if late:
        end_index = advance_past_constituents(gold_sequence, open_index+1)
        if end_index is None:
            # a None slice bound would silently splice the whole sequence back in
            return None

    return gold_sequence[:gold_index] + gold_sequence[open_index+1:end_index] + gold_sequence[gold_index:open_index+1] + gold_sequence[end_index:]
def fix_close_open_shift_unambiguous_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close/Open/Shift -> Shift repair when a single block follows the Open.  model and state are unused."""
    return fix_close_shift_open_bracket(gold_transition=gold_transition,
                                        pred_transition=pred_transition,
                                        gold_sequence=gold_sequence,
                                        gold_index=gold_index,
                                        root_labels=root_labels,
                                        ambiguous=False,
                                        late=False)
def fix_close_open_shift_ambiguous_bracket_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous Close/Open/Shift -> Shift repair, closing after the first block.  model and state are unused."""
    return fix_close_shift_open_bracket(gold_transition=gold_transition,
                                        pred_transition=pred_transition,
                                        gold_sequence=gold_sequence,
                                        gold_index=gold_index,
                                        root_labels=root_labels,
                                        ambiguous=True,
                                        late=False)
def fix_close_open_shift_ambiguous_bracket_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous Close/Open/Shift -> Shift repair, closing after the last block.  model and state are unused."""
    return fix_close_shift_open_bracket(gold_transition=gold_transition,
                                        pred_transition=pred_transition,
                                        gold_sequence=gold_sequence,
                                        gold_index=gold_index,
                                        root_labels=root_labels,
                                        ambiguous=True,
                                        late=True)
def fix_close_open_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair Close / (unaries) / Open_Z / Shift -> Shift, letting the model choose where to close.

    Open_Z must be a *different* label from the constituent being closed.
    One candidate is built per block following Open_Z.  Each candidate
    moves that many blocks ahead of the skipped Close / unaries / Open_Z.
    The candidates are scored with score_candidates, and the best one is returned.

    Returns (repair_type, best_candidate), or None if the pattern does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open a *different* constituent
    # from the one we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] == prev_open:
        return None

    # alright, at long last we have:
    #   a close that was missed
    #   a non-nested open that was missed
    # Enumerate one candidate per following block.  As at every other call
    # site, find_in_order_constituent_end is started *at* the Shift which
    # begins a block.  Starting one past it (as this loop previously did)
    # merged a bare-word block, or a word under a unary, with the next
    # block, so some candidates were never generated.
    candidates = []
    end_index = open_index + 1      # known to be a Shift from the check above
    while isinstance(gold_sequence[end_index], Shift):
        end_index = find_in_order_constituent_end(gold_sequence, end_index)
        assert end_index is not None
        candidates.append((gold_sequence[:gold_index], gold_sequence[open_index+1:end_index], gold_sequence[gold_index:open_index+1], gold_sequence[end_index:]))

    _, best_idx, best_candidate = score_candidates(model, state, candidates)
    if len(candidates) == 1:
        return RepairType.CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET, best_candidate

    # the last candidate (all blocks moved) is reported as -1
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
def fix_close_open_shift_nested(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a Close X..Open X..Shift pattern where both the Close and Open were skipped.

    Gold:       stuff_A open_X stuff_B *close* open_X shift...
    Predicted:  stuff_A open_X stuff_B shift...
    The missed Close and Open lose the inner bracket (X A B), a recall
    error.  The earlier open_X can still produce the outer bracket.  The
    repair drops the Close, any unaries, and the second open_X.

    Returns the repaired transition list, or None if the repair does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent) or not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index + 1], OpenConstituent):
        return None

    # unaries in between are allowed, eg
    #   stuff_A open_X stuff_B close open_Y close open_X shift
    reopen_index = advance_past_unaries(gold_sequence, gold_index + 1)
    if not isinstance(gold_sequence[reopen_index], OpenConstituent) or not isinstance(gold_sequence[reopen_index + 1], Shift):
        return None

    # only applies if the constituent reopened is the same one just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    if gold_sequence[reopen_index] != gold_sequence[prev_open_index]:
        return None

    return gold_sequence[:gold_index] + gold_sequence[reopen_index + 1:]
def fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):
    """
    Repair Close/Shift -> Shift by moving the Close to after the next block is created

    Gold:    stuff_1 Close (unaries) stuff_2 ...
    Repair:  stuff_1 stuff_2 Close ...
    Unaries on the closed constituent are dropped.

    ambiguous: False means only apply when stuff_2 is the last block
      before an enclosing Close; True means only apply when more blocks follow.
    late: if True, close after all remaining blocks rather than after stuff_2.

    Returns the repaired transition list, or None if the repair does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 2:
        return None
    start_index = gold_index + 1
    start_index = advance_past_unaries(gold_sequence, start_index)
    if len(gold_sequence) < start_index + 2:
        return None
    if not isinstance(gold_sequence[start_index], Shift):
        return None

    end_index = find_in_order_constituent_end(gold_sequence, start_index)
    if end_index is None:
        return None
    # if this *isn't* a close, we don't allow it in the unambiguous case
    # that case seems to be ambiguous...
    #   stuff_1 close stuff_2 stuff_3
    # if you would normally start building stuff_3,
    # it is not clear if you want to close at the end of
    # stuff_2 or build stuff_3 instead.
    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):
        return None
    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):
        return None

    # close at the end of the brackets, rather than once the first bracket is finished
    if late:
        end_index = advance_past_constituents(gold_sequence, start_index)
        if end_index is None:
            # with a None bound the slices below would yield
            # tail + Close + the entire sequence, a corrupt repair
            return None

    return gold_sequence[:gold_index] + gold_sequence[start_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]
def fix_close_shift_shift_unambiguous(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close/Shift -> Shift repair when only one block remains before the enclosing Close.  model and state are unused."""
    return fix_close_shift_shift(gold_transition=gold_transition,
                                 pred_transition=pred_transition,
                                 gold_sequence=gold_sequence,
                                 gold_index=gold_index,
                                 root_labels=root_labels,
                                 ambiguous=False,
                                 late=False)
def fix_close_shift_shift_ambiguous_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous Close/Shift -> Shift repair, closing after the first block.  model and state are unused."""
    return fix_close_shift_shift(gold_transition=gold_transition,
                                 pred_transition=pred_transition,
                                 gold_sequence=gold_sequence,
                                 gold_index=gold_index,
                                 root_labels=root_labels,
                                 ambiguous=True,
                                 late=False)
def fix_close_shift_shift_ambiguous_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous Close/Shift -> Shift repair, closing after the last block.  model and state are unused."""
    return fix_close_shift_shift(gold_transition=gold_transition,
                                 pred_transition=pred_transition,
                                 gold_sequence=gold_sequence,
                                 gold_index=gold_index,
                                 root_labels=root_labels,
                                 ambiguous=True,
                                 late=True)
def fix_close_shift_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair Close/Shift -> Shift, letting the model choose where the Close goes.

    The gold pattern is Close (unaries) Shift ..., and the model predicted
    Shift.  One candidate is built per block after the unaries; each moves
    the Close to after that many blocks.  The candidates are scored with
    score_candidates.

    Returns (repair_type, best_candidate), or None if the pattern does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 2:
        return None
    first_block = advance_past_unaries(gold_sequence, gold_index + 1)
    if len(gold_sequence) < first_block + 2:
        return None
    if not isinstance(gold_sequence[first_block], Shift):
        return None

    candidates = []
    block_end = first_block
    while isinstance(gold_sequence[block_end], Shift):
        block_end = find_in_order_constituent_end(gold_sequence, block_end)
        assert block_end is not None
        candidates.append((gold_sequence[:gold_index],
                           gold_sequence[first_block:block_end],
                           [CloseConstituent()],
                           gold_sequence[block_end:]))

    _, best_idx, best_candidate = score_candidates(model, state, candidates)
    if len(candidates) == 1:
        return RepairType.CLOSE_SHIFT_SHIFT, best_candidate

    # the final candidate (close after every remaining block) is reported as -1
    label_idx = -1 if best_idx == len(candidates) - 1 else best_idx
    repair_type = RepairEnum(name=RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.value, label_idx),
                             is_correct=False)
    return repair_type, best_candidate
def ambiguous_shift_open_unary_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair Shift -> Open by immediately closing the predicted Open.

    The extra Open/Close pair is a unary over the previous item; the gold
    sequence then resumes with the Shift.  Returns None if not a Shift/Open error.
    """
    if isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent):
        inserted = [pred_transition, CloseConstituent()]
        return gold_sequence[:gold_index] + inserted + gold_sequence[gold_index:]
    return None
def ambiguous_shift_open_early_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair Shift -> Open by wrapping only the next block in the predicted Open.

    The Close is placed where the current block ends (at the next Shift or
    Close at this level).  Returns None if not a Shift/Open error.
    """
    if not (isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent)):
        return None
    block_end = find_in_order_constituent_end(gold_sequence, gold_index)
    head = gold_sequence[:gold_index]
    body = gold_sequence[gold_index:block_end]
    tail = gold_sequence[block_end:]
    return head + [pred_transition] + body + [CloseConstituent()] + tail
def ambiguous_shift_open_late_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair Shift -> Open by wrapping every remaining block in the predicted Open.

    The Close is placed at the index returned by advance_past_constituents.
    Returns None if not a Shift/Open error.
    """
    if not (isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent)):
        return None
    stop = advance_past_constituents(gold_sequence, gold_index)
    head = gold_sequence[:gold_index]
    body = gold_sequence[gold_index:stop]
    tail = gold_sequence[stop:]
    return head + [pred_transition] + body + [CloseConstituent()] + tail
def ambiguous_shift_open_predicted_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair Shift -> Open, letting the model choose where to Close.

    Candidates:
      U - close immediately (unary)
      E - close after the next block
      L - close after all remaining blocks (only when different from E)
    When E and L coincide, the non-unary choice is labelled S.
    The candidates are scored with score_candidates.

    Returns (repair_type, best_candidate), or None if not a Shift/Open error.
    """
    if not isinstance(gold_transition, Shift):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    unary_option = (gold_sequence[:gold_index], [pred_transition], [CloseConstituent()], gold_sequence[gold_index:])
    early_index = find_in_order_constituent_end(gold_sequence, gold_index)
    early_option = (gold_sequence[:gold_index], [pred_transition] + gold_sequence[gold_index:early_index], [CloseConstituent()], gold_sequence[early_index:])
    late_index = advance_past_constituents(gold_sequence, gold_index)

    if early_index == late_index:
        _, best_idx, best_candidate = score_candidates(model, state, [unary_option, early_option])
        return_label = "U" if best_idx == 0 else "S"
    else:
        late_option = (gold_sequence[:gold_index], [pred_transition] + gold_sequence[gold_index:late_index], [CloseConstituent()], gold_sequence[late_index:])
        _, best_idx, best_candidate = score_candidates(model, state, [unary_option, early_option, late_option])
        if best_idx == 0:
            return_label = "U"
        elif best_idx == 1:
            return_label = "E"
        else:
            return_label = "L"

    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_PREDICTED_CLOSE.name,
                             value="%d.%s" % (RepairType.SHIFT_OPEN_PREDICTED_CLOSE.value, return_label),
                             is_correct=False)
    return repair_type, best_candidate
def report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Close which was predicted as a Shift."""
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, Shift):
        return RepairType.OTHER_CLOSE_SHIFT, None
    return None
def report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Close which was predicted as an Open."""
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_CLOSE_OPEN, None
    return None
def report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Open which was predicted as a different Open."""
    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_OPEN_OPEN, None
    return None
def report_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Open which was predicted as a Shift."""
    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, Shift):
        return RepairType.OTHER_OPEN_SHIFT, None
    return None
def report_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Open which was predicted as a Close."""
    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, CloseConstituent):
        return RepairType.OTHER_OPEN_CLOSE, None
    return None
def report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Shift which was predicted as an Open."""
    if isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_SHIFT_OPEN, None
    return None
class RepairType(Enum):
"""
Keep track of which repair is used, if any, on an incorrect transition
Statistics on English w/ no charlm, no transformer,
eg word vectors only, best model as of January 2024
unambiguous transitions only:
oracle scheme dev test
no oracle 0.9245 0.9226
+wrong_open_root 0.9244 0.9224
+wrong_unary_chain 0.9243 0.9237
+wrong_open_unary 0.9249 0.9223
+wrong_open_general 0.9251 0.9215
+missed_unary 0.9248 0.9215
+open_shift 0.9243 0.9216
+open_close 0.9254 0.9217
+shift_close 0.9261 0.9238
+close_shift_nested 0.9253 0.9250
Redoing the wrong_open_general, which seemed to hurt test scores:
wrong_open_two_subtrees - L4 0.9244 0.9220
every else w/o ambiguous open/open fix 0.9259 0.9241
everything w/ open_two_subtrees 0.9261 0.9246
w/ ambiguous open_three_subtrees 0.9264 0.9243
Testing three different possible repairs for shift-open:
w/ ambiguous open_three_subtrees 0.9264 0.9243
immediate close (unary) 0.9267 0.9246
close after first bracket 0.9265 0.9256
close after last bracket 0.9264 0.9240
Testing three possible repairs for close-open-shift/shift
w/ ambiguous open_three_subtrees 0.9264 0.9243
unambiguous c-o-s/shift 0.9265 0.9246
ambiguous c-o-s/shift closed early 0.9262 0.9246
ambiguous c-o-s/shift closed late 0.9259 0.9245
Testing three possible repairs for close-shift/shift
w/ ambiguous open_three_subtrees 0.9264 0.9243
unambiguous c-s/shift 0.9253 0.9239
ambiguous c-s/shift closed early 0.9259 0.9235
ambiguous c-s/shift closed late 0.9252 0.9241
ambiguous c-s/shift predicted 0.9264 0.9243
--------------------------------------------------------
Running ID experiments to verify some of the above findings
no charlm or bert, only 200 epochs
Comparing wrong_open fixes
w/ ambiguous open_two_subtrees 0.8448 0.8335
w/ ambiguous open_three_subtrees 0.8424 0.8336
Testing three possible repairs for close-shift/shift
unambiguous c-s/shift 0.8448 0.8360
ambiguous c-s/shift closed early 0.8425 0.8352
ambiguous c-s/shift closed late 0.8452 0.8334
--------------------------------------------------------
Running ID experiments to verify some of the above findings
bert + peft, only 200 epochs
Comparing wrong_open fixes
w/o ambiguous open/open fix 0.8923 0.8834
w/ ambiguous open_two_subtrees 0.8908 0.8828
w/ ambiguous open_three_subtrees 0.8901 0.8801
Testing three possible repairs for close-shift/shift
unambiguous c-s/shift 0.8921 0.8825
ambiguous c-s/shift closed early 0.8924 0.8841
ambiguous c-s/shift closed late 0.8921 0.8806
ambiguous c-s/shift predicted 0.8923 0.8835
--------------------------------------------------------
Running DE experiments to verify some of the above findings
bert + peft, only 200 epochs
Comparing wrong_open fixes
w/o ambiguous open/open fix 0.9576 0.9402
w/ ambiguous open_two_subtrees 0.9570 0.9410
w/ ambiguous open_three_subtrees 0.9569 0.9412
Testing three possible repairs for close-shift/shift
unambiguous c-s/shift 0.9566 0.9408
ambiguous c-s/shift closed early 0.9564 0.9394
ambiguous c-s/shift closed late 0.9572 0.9408
ambiguous c-s/shift predicted 0.9571 0.9404
--------------------------------------------------------
Running IT experiments to verify some of the above findings
bert + peft, only 200 epochs
Comparing wrong_open fixes
w/o ambiguous open/open fix 0.8380 0.8361
w/ ambiguous open_two_subtrees 0.8377 0.8351
w/ ambiguous open_three_subtrees 0.8381 0.8368
Testing three possible repairs for close-shift/shift
unambiguous c-s/shift 0.8376 0.8392
ambiguous c-s/shift closed early 0.8363 0.8359
ambiguous c-s/shift closed late 0.8365 0.8383
ambiguous c-s/shift predicted 0.8379 0.8371
--------------------------------------------------------
Running ZH experiments to verify some of the above findings
bert + peft, only 200 epochs
Comparing wrong_open fixes
w/o ambiguous open/open fix 0.9160 0.9143
w/ ambiguous open_two_subtrees 0.9145 0.9144
w/ ambiguous open_three_subtrees 0.9146 0.9142
Testing three possible repairs for close-shift/shift
unambiguous c-s/shift 0.9155 0.9146
ambiguous c-s/shift closed early 0.9145 0.9153
ambiguous c-s/shift closed late 0.9138 0.9140
ambiguous c-s/shift predicted 0.9154 0.9144
--------------------------------------------------------
Running VI experiments to verify some of the above findings
bert + peft, only 200 epochs
Comparing wrong_open fixes
w/o ambiguous open/open fix 0.8282 0.7668
w/ ambiguous open_two_subtrees 0.8272 0.7670
w/ ambiguous open_three_subtrees 0.8282 0.7668
Testing three possible repairs for close-shift/shift
unambiguous c-s/shift 0.8285 0.7683
ambiguous c-s/shift closed early 0.8276 0.7678
ambiguous c-s/shift closed late 0.8278 0.7668
ambiguous c-s/shift predicted 0.8270 0.7668
--------------------------------------------------------
Testing a combination of ambiguous vs predicted transitions
ambiguous
EN: (no CSS_U) 0.9258 0.9252
ZH: (no CSS_U) 0.9153 0.9145
predicted
EN: (no CSS_U) 0.9264 0.9241
ZH: (no CSS_U) 0.9145 0.9141
"""
def __new__(cls, fn, correct=False, debug=False):
"""
Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error
correct: this represents a correct transition
debug: always run this, as it just counts statistics
"""
value = len(cls.__members__)
obj = object.__new__(cls)
obj._value_ = value + 1
obj.fn = fn
obj.correct = correct
obj.debug = debug
return obj
@property
def is_correct(self):
    """Whether this member stands for a correct transition (the `correct` flag given at creation)."""
    correct_flag = self.correct
    return correct_flag
# Repair types, listed in definition order.  __new__ numbers each member
# by its position (1, 2, 3, ...), so reordering this list changes the values.
# Each value below is unpacked into __new__'s (fn, correct=False, debug=False)
# arguments; a bare non-tuple value such as UNKNOWN's None becomes fn.
# The note on WRONG_OPEN_GENERAL at the end says its placement after the
# other fixes is deliberate, so the order here presumably also decides which
# repair is tried first (NOTE(review): confirm against DynamicOracle).
#
# The first section is a sequence of repairs when the parser
# should have chosen NTx but instead chose NTy
# Blocks of transitions which can be abstracted away to be
# anything will be represented as S1, S2, etc... S for stuff
# We carve out an exception for a wrong open at the root
# The only possible transitions at this point are to close
# the error and try again with the root
WRONG_OPEN_ROOT_ERROR = (fix_wrong_open_root_error,)
# The simplest form of such an error is when there is a sequence
# of unary transitions and the parser chose a wrong parent.
# Remember that a unary transition is represented by a pair
# of transitions, NTx, Close.
# In this case, the correct sequence was
# S1 NTx Close NTy Close NTz ...
# but the parser chose NTy, NTz, etc
# The repair in this case is to simply discard the unchosen
# unary transitions and continue
WRONG_OPEN_UNARY_CHAIN = (fix_wrong_open_unary_chain,)
# Similar to the UNARY_CHAIN error, but in this case there is a
# bunch of stuff (one or more constituents built) between the
# missed open transition and the close transition
WRONG_OPEN_STUFF_UNARY = (fix_wrong_open_stuff_unary,)
# If the correct sequence is
# T1 O_x T2 C
# and instead we predicted
# T1 O_y ...
# this can be fixed with a unary transition after
# T1 O_y T2 C O_x C
# note that this is technically ambiguous
# could have done
# T1 O_x C O_y T2 C
# but doing this should be easier for the parser to detect (untested)
# also this way the same code paths can be used for two subtrees
# and for multiple subtrees
WRONG_OPEN_TWO_SUBTREES = (fix_wrong_open_two_subtrees,)
# If the gold transition is an Open because it is part of
# a unary transition, and the following transition is a
# correct Shift or Close, we can just skip past the unary.
MISSED_UNARY = (fix_missed_unary,)
# Open -> Shift errors which don't just represent a unary
# generally represent a missing bracket which cannot be
# recovered using the in-order mechanism.  Dropping the
# missing transition is generally the only fix.
# (This means removing the corresponding Close)
# One could theoretically create a new transition which
# grabs two constituents, though
OPEN_SHIFT = (fix_open_shift,)
# Open -> Close is a rather drastic break in the
# potential structure of the tree.  We can no longer
# recover the missed Open, and we might not be able
# to recover other following missed Opens as well.
# In most cases, the only thing to do is reopen the
# incorrectly closed outer bracket and keep going.
OPEN_CLOSE = (fix_open_close,)
# Similar to the Open -> Close error, but at least
# in this case we are just introducing one wrong bracket
# rather than also breaking some existing brackets.
# The fix here is to reopen the closed bracket.
SHIFT_CLOSE = (fix_shift_close,)
# Specifically fixes an error where bracket X is
# closed and then immediately opened to build a
# new X bracket.  In this case, the simplest fix
# will be to skip both the close and the new open
# and continue from there.
CLOSE_OPEN_SHIFT_NESTED = (fix_close_open_shift_nested,)
# Fix an error where the correct sequence was to Close X, Open Y,
# then continue building,
# but instead the model did a Shift in place of C_X O_Y
# The damage here is a recall error for the missed X and
# a precision error for the incorrectly opened X
# However, the Y can actually be recovered - whenever we finally
# close X, we can then open Y
# One form of that is unambiguous, that of
# T_A O_X T_B C O_Y T_C C
# with only one subtree after the O_Y
# In that case, the Close that would have closed Y
# is the only place for the missing close of X
# So we can produce the following:
# T_A O_X T_B T_C C O_Y C
CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET = (fix_close_open_shift_unambiguous_bracket,)
# Similarly to WRONG_OPEN_TWO_SUBTREES, if the correct sequence is
# T1 O_x T2 T3 C
# and instead we predicted
# T1 O_y ...
# this can be fixed by closing O_y in any number of places
# T1 O_y T2 C O_x T3 C
# T1 O_y T2 C T3 O_x C
# Either solution is a single precision error,
# but keeps the O_x subtree correct
# This is an ambiguous transition - we can experiment with different fixes
WRONG_OPEN_MULTIPLE_SUBTREES = (fix_wrong_open_multiple_subtrees,)
# A correct transition: there is no repair function, and correct=True
CORRECT = (None, True)
# No repair function is attached (fn=None, correct=False)
UNKNOWN = None
# If the model is supposed to build a block after a Close
# operation, attach that block to the piece to the left
# a couple different variations on this were tried
# we tried attaching all constituents to the
#   bracket which should have been closed
# we tried attaching exactly one constituent
# and we tried attaching only if there was
#   exactly one following constituent
# none of these improved f1.  for example, on the VI dataset, we
#   lost 0.15 F1 with the exactly one following constituent version
# it might be worthwhile double checking some of the other
#   versions to make sure those also fail, though
CLOSE_SHIFT_SHIFT = (fix_close_shift_shift_unambiguous,)
# In the ambiguous close-shift/shift case, this closes the surrounding bracket
# (which should have already been closed)
# as soon as the next constituent is built
# this turns
#   (A (B s1 s2) s3 s4)
# into
#   (A (B s1 s2 s3) s4)
CLOSE_SHIFT_SHIFT_AMBIGUOUS_EARLY = (fix_close_shift_shift_ambiguous_early,)
# In the ambiguous close-shift/shift case, this closes the surrounding bracket
# (which should have already been closed)
# when the rest of the constituents in this bracket are built
# this turns
#   (A (B s1 s2) s3 s4)
# into
#   (A (B s1 s2 s3 s4))
CLOSE_SHIFT_SHIFT_AMBIGUOUS_LATE = (fix_close_shift_shift_ambiguous_late,)
# For the close-shift/shift errors which are ambiguous,
# this uses the model's predictions to guess which block
# to put the close after
CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_shift_shift_ambiguous_predicted,)
# If a sequence should have gone Close - Open - Shift,
# and instead we went Shift,
# we need to close the previous bracket
# If it is ambiguous
# such as Close - Open - Shift - Shift
# close the bracket ASAP
# eg, Shift - Close - Open - Shift
CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_EARLY = (fix_close_open_shift_ambiguous_bracket_early,)
# for Close - Open - Shift - Shift
# close the bracket as late as possible
# eg, Shift - Shift - Close - Open
CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_LATE = (fix_close_open_shift_ambiguous_bracket_late,)
# If the sequence should have gone
#   Close - Open - Shift
# and instead we predicted a Shift
# in a context where closing the bracket would be ambiguous
# we use the model to predict where the close should actually happen
CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_open_shift_ambiguous_predicted,)
# This particular repair effectively turns the shift -> ambiguous open
# into a unary transition
SHIFT_OPEN_UNARY_CLOSE = (ambiguous_shift_open_unary_close,)
# Fix the shift -> ambiguous open by closing after the first constituent
# This is an ambiguous solution because it could also be closed either
# as a unary transition or with a close at the end of the outer bracket
SHIFT_OPEN_EARLY_CLOSE = (ambiguous_shift_open_early_close,)
# Fix the shift -> ambiguous open by closing after all constituents
# This is an ambiguous solution because it could also be closed either
# as a unary transition or with a close at the end of the first constituent
SHIFT_OPEN_LATE_CLOSE = (ambiguous_shift_open_late_close,)
# Use the model to predict when to close!
# The different options for where to put the Close are put into the model,
# and the highest scoring close is used
SHIFT_OPEN_PREDICTED_CLOSE = (ambiguous_shift_open_predicted_close,)
# The OTHER_* entries are created with correct=False, debug=True.
# Per the __new__ docstring, debug entries are always run and only
# count statistics (here via the report_* functions).
OTHER_CLOSE_SHIFT = (report_close_shift, False, True)
OTHER_CLOSE_OPEN = (report_close_open, False, True)
OTHER_OPEN_OPEN = (report_open_open, False, True)
OTHER_OPEN_CLOSE = (report_open_close, False, True)
OTHER_OPEN_SHIFT = (report_open_shift, False, True)
OTHER_SHIFT_OPEN = (report_shift_open, False, True)
# any other open transition we get wrong, which hasn't already
# been carved out as an exception above, we just accept the
# incorrect Open and keep going
#
# TODO: check if there is a way to improve this
#   it appears to hurt scores simply by existing
# explanation: this is wrong logic
#   Suppose the correct sequence had been
#     T1 open(NP) T2 T3 close
#   Instead we had done
#     T1 open(VP) T2 T3 close
#   We can recover the missing NP!
#     T1 open(VP) T2 close open(NP) T3 close
#   Can also recover it as
#     T1 open(VP) T2 T3 close open(NP) close
#   So this is actually an ambiguous transition
#   except in the case of
#     T1 open(...) close
#   In this case, a unary transition can make it so we only have
#   a precision error, not also a recall error
# Currently, the approach is to put this after the default fixes
# and use the two & more-than-two versions of the fix above
WRONG_OPEN_GENERAL = (fix_wrong_open_general,)
class InOrderOracle(DynamicOracle):
    """
    Dynamic oracle for in-order transition sequences.

    Everything except the choice of repairs is inherited from DynamicOracle;
    this class only plugs in RepairType as the set of repairs.
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        # DynamicOracle receives the repair enum as its third positional argument
        repair_args = (root_labels, oracle_level, RepairType,
                       additional_oracle_levels, deactivated_oracle_levels)
        super().__init__(*repair_args)
================================================
FILE: stanza/models/constituency/label_attention.py
================================================
import numpy as np
import functools
import sys
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.init as init
# dtype used for the padding masks built in this file (see pad_and_rearrange).
# Publicly available versions of this code alternate between torch.uint8 and
# torch.bool, but uint8 masks are only needed for older versions of torch anyway
DTYPE = torch.bool
class BatchIndices:
    """
    Batch indices container class (used to implement packed batches)

    batch_idxs_np gives, for each row of the packed input, the index of the
    batch element it belongs to.  Rows of one element are expected to be
    contiguous; the assertion below catches most violations.

    Attributes:
      batch_idxs_np / batch_idxs_torch: the indices as numpy / long tensor
      batch_size: number of batch elements (largest index + 1)
      boundaries_np: start offset of every element, followed by the total length
      seq_lens_np: length of every element
      max_len: length of the longest element
    """
    def __init__(self, batch_idxs_np, device):
        self.batch_idxs_np = batch_idxs_np
        self.batch_idxs_torch = torch.as_tensor(batch_idxs_np, dtype=torch.long, device=device)
        self.batch_size = int(np.max(batch_idxs_np) + 1)

        # sentinels on both sides make the first start and the final end show up as changes
        padded = np.concatenate([[-1], batch_idxs_np, [-1]])
        changes = padded[1:] != padded[:-1]
        self.boundaries_np = np.nonzero(changes)[0]

        starts = self.boundaries_np[:-1]
        ends = self.boundaries_np[1:]
        self.seq_lens_np = ends - starts
        assert len(self.seq_lens_np) == self.batch_size
        self.max_len = int(np.max(self.seq_lens_np))
class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
    """
    Autograd function behind FeatureDropout.

    One dropout mask row is sampled per batch element and shared by every row
    of the packed input which belongs to that element.
    """
    @staticmethod
    def forward(ctx, input, batch_idxs, p=0.5, train=False, inplace=False):
        """
        input: packed 2d input, rows are positions and columns are features
        batch_idxs: BatchIndices-like object; only batch_size and
          batch_idxs_torch (the batch element of each row) are used
        p: probability of dropping a feature
        train: the mask is only applied when True
        inplace: modify input itself instead of a copy

        Raises ValueError if p is outside [0, 1].
        """
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))

        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if ctx.p > 0 and ctx.train:
            # one mask row per batch element
            ctx.noise = input.new_empty((batch_idxs.batch_size, input.size(1)))
            if ctx.p == 1:
                ctx.noise.fill_(0)
            else:
                # kept features are scaled by 1 / (1 - p) to preserve the expected value
                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
            # expand to one mask row per input row via each row's batch element
            ctx.noise = ctx.noise[batch_idxs.batch_idxs_torch, :]
            output.mul_(ctx.noise)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # only input receives a gradient; batch_idxs, p, train, inplace get None
        if ctx.p > 0 and ctx.train:
            return grad_output.mul(ctx.noise), None, None, None, None
        else:
            return grad_output, None, None, None, None
#
class FeatureDropout(nn.Module):
    """
    Feature-level dropout for a packed input of size len x num_features.

    Each feature is dropped with probability p, and the same features are
    dropped across every row that belongs to a single batch element.
    """
    def __init__(self, p=0.5, inplace=False):
        super().__init__()
        if p > 1 or p < 0:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.inplace = inplace
        self.p = p

    def forward(self, input, batch_idxs):
        # as with nn.Dropout, self.training turns the mask on and off
        dropout_args = (input, batch_idxs, self.p, self.training, self.inplace)
        return FeatureDropoutFunction.apply(*dropout_args)
class LayerNormalization(nn.Module):
    """
    Layer normalization over the last dimension.

    Uses torch.std (unbiased) and adds eps to the standard deviation rather
    than the variance.  Inputs whose last dimension is 1 are returned unchanged.
    With affine=True, a learned scale (a_2) and shift (b_2) are applied.
    """
    def __init__(self, d_hid, eps=1e-3, affine=True):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
            self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)

    def forward(self, z):
        if z.size(-1) == 1:
            return z
        mean = z.mean(dim=-1, keepdim=True)
        std = z.std(dim=-1, keepdim=True)
        normed = (z - mean) / (std + self.eps)
        if not self.affine:
            return normed
        return normed * self.a_2 + self.b_2
class ScaledDotProductAttention(nn.Module):
    """
    Batched dot-product attention.

    Note that the logits are divided by sqrt(d_model) (self.temper) rather than
    sqrt(d_k).  This is left unchanged because changing it would change the
    outputs of any already-trained model.
    """
    def __init__(self, d_model, attention_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temper = d_model ** 0.5
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, attn_mask=None):
        """
        q: batch x len_q x d_k  (in LAL, (batch * d_l) x 1 x d_k)
        k: batch x len_k x d_k
        v: batch x len_k x d_v
        attn_mask: optional bool tensor shaped like the logits,
          batch x len_q x len_k; True marks positions to ignore

        Returns (output, attn): output is batch x len_q x d_v, and attn is the
        attention distribution after dropout, batch x len_q x len_k.
        """
        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper
        # in LAL: attention weights from each label (as a learned vector) to each word

        if attn_mask is not None:
            assert attn_mask.size() == attn.size(), \
                    'Attention mask shape {} mismatch ' \
                    'with Attention logit tensor shape ' \
                    '{}.'.format(attn_mask.size(), attn.size())
            # out-of-place so autograd sees the masking
            # (the previous attn.data.masked_fill_ bypassed it)
            attn = attn.masked_fill(attn_mask, -float('inf'))

        attn = self.softmax(attn)
        # Note that this makes the distribution not sum to 1. At some point it
        # may be worth researching whether this is the right way to apply
        # dropout to the attention.
        # Note that the t2t code also applies dropout in this manner
        attn = self.dropout(attn)
        output = torch.bmm(attn, v) # batch x len_q x d_v

        return output, attn
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module

    Works on packed input: all sentences of a batch are concatenated into one
    len_inp x d_model tensor, and batch_idxs (a BatchIndices) records where
    each sentence starts and ends.  Internally the per-head projections are
    padded out to (n_head * batch) x max_len, attended, and packed again.

    If d_positional is set, the first d_model - d_positional features of each
    row are treated as content and the last d_positional as positional
    features; each half gets its own q/k/v projections (each producing half of
    d_k / d_v) and its own output projection.
    """

    def __init__(self, n_head, d_model, d_k, d_v, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        if not d_positional:
            self.partitioned = False
        else:
            self.partitioned = True

        if self.partitioned:
            self.d_content = d_model - d_positional
            self.d_positional = d_positional

            # "1" weights act on the content features, "2" weights on the positional ones
            self.w_qs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
            self.w_ks1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
            self.w_vs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_v // 2))

            self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
            self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
            self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_v // 2))

            init.xavier_normal_(self.w_qs1)
            init.xavier_normal_(self.w_ks1)
            init.xavier_normal_(self.w_vs1)

            init.xavier_normal_(self.w_qs2)
            init.xavier_normal_(self.w_ks2)
            init.xavier_normal_(self.w_vs2)
        else:
            self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
            self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
            self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

            init.xavier_normal_(self.w_qs)
            init.xavier_normal_(self.w_ks)
            init.xavier_normal_(self.w_vs)

        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
        self.layer_norm = LayerNormalization(d_model)

        if not self.partitioned:
            # The lack of a bias term here is consistent with the t2t code, though
            # in my experiments I have never observed this making a difference.
            self.proj = nn.Linear(n_head*d_v, d_model, bias=False)
        else:
            self.proj1 = nn.Linear(n_head*(d_v//2), self.d_content, bias=False)
            self.proj2 = nn.Linear(n_head*(d_v//2), self.d_positional, bias=False)

        self.residual_dropout = FeatureDropout(residual_dropout)

    def split_qkv_packed(self, inp, qk_inp=None):
        """
        Project the packed input to per-head queries, keys and values.

        inp: len_inp x d_model, used for the values (and for q/k if qk_inp is None)
        qk_inp: optional separate input for the queries and keys
        Returns q_s, k_s (n_head x len_inp x d_k) and v_s (n_head x len_inp x d_v)
        """
        v_inp_repeated = inp.repeat(self.n_head, 1).view(self.n_head, -1, inp.size(-1)) # n_head x len_inp x d_model
        if qk_inp is None:
            qk_inp_repeated = v_inp_repeated
        else:
            qk_inp_repeated = qk_inp.repeat(self.n_head, 1).view(self.n_head, -1, qk_inp.size(-1))

        if not self.partitioned:
            q_s = torch.bmm(qk_inp_repeated, self.w_qs) # n_head x len_inp x d_k
            k_s = torch.bmm(qk_inp_repeated, self.w_ks) # n_head x len_inp x d_k
            v_s = torch.bmm(v_inp_repeated, self.w_vs) # n_head x len_inp x d_v
        else:
            # project content and positional features separately, then concatenate
            q_s = torch.cat([
                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_qs1),
                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_qs2),
                ], -1)
            k_s = torch.cat([
                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_ks1),
                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_ks2),
                ], -1)
            v_s = torch.cat([
                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
                ], -1)
        return q_s, k_s, v_s

    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
        # Input is packed representation: n_head x len_inp x d
        # Output is padded representation: (n_head * mb_size) x len_padded x d
        # (along with masks for the attention and output)
        # Rows of the padded output are head-major: row h * mb_size + b is head h, sentence b
        n_head = self.n_head
        d_k, d_v = self.d_k, self.d_v

        len_padded = batch_idxs.max_len
        mb_size = batch_idxs.batch_size
        q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
        # True marks padding positions, which attention must ignore
        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)

        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
            q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
            invalid_mask[i, :end-start].fill_(False)

        # returns: q, k, v padded; the attention mask
        # ((n_head * mb_size) x len_padded x len_padded); and the output mask
        # ((n_head * mb_size) x len_padded, True for real positions) used to re-pack
        return(
            q_padded.view(-1, len_padded, d_k),
            k_padded.view(-1, len_padded, d_k),
            v_padded.view(-1, len_padded, d_v),
            invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1),
            (~invalid_mask).repeat(n_head, 1),
            )

    def combine_v(self, outputs):
        """
        Merge the heads of the packed attention output back to len_inp x d_model.

        outputs: (n_head * len_inp) x d_v, head-major
        """
        # Combine attention information from the different heads
        n_head = self.n_head
        outputs = outputs.view(n_head, -1, self.d_v) # n_head x len_inp x d_kv

        if not self.partitioned:
            # Switch from n_head x len_inp x d_v to len_inp x (n_head * d_v)
            outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, n_head * self.d_v)

            # Project back to residual size
            outputs = self.proj(outputs)
        else:
            # the first half of each head's d_v came from the content weights,
            # the second half from the positional weights
            d_v1 = self.d_v // 2
            outputs1 = outputs[:,:,:d_v1]
            outputs2 = outputs[:,:,d_v1:]
            outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, n_head * d_v1)
            outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, n_head * d_v1)
            outputs = torch.cat([
                self.proj1(outputs1),
                self.proj2(outputs2),
                ], -1)

        return outputs

    def forward(self, inp, batch_idxs, qk_inp=None):
        """
        inp: packed len_inp x d_model input (also the residual)
        batch_idxs: BatchIndices describing the sentences packed into inp
        qk_inp: optional separate input for the queries and keys

        Returns (layer_norm(dropout(attention output) + inp), padded attention weights)
        """
        residual = inp

        # While still using a packed representation, project to obtain the
        # query/key/value for each head
        q_s, k_s, v_s = self.split_qkv_packed(inp, qk_inp=qk_inp)
        # n_head x len_inp x d_kv

        # Switch to padded representation, perform attention, then switch back
        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
        # (n_head * batch) x len_padded x d_kv

        outputs_padded, attns_padded = self.attention(
            q_padded, k_padded, v_padded,
            attn_mask=attn_mask,
            )
        # drop the padding positions, returning to the packed layout
        outputs = outputs_padded[output_mask]
        # (n_head * len_inp) x d_kv
        outputs = self.combine_v(outputs)
        # len_inp x d_model

        outputs = self.residual_dropout(outputs, batch_idxs)

        return self.layer_norm(outputs + residual), attns_padded
#
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed forward sublayer with a residual connection.

    Each position (row of x) is expanded from d_hid to d_ff, passed through
    ReLU and feature dropout, projected back to d_hid, dropped out again,
    added to the input and layer-normalized.
    """
    def __init__(self, d_hid, d_ff, relu_dropout=0.1, residual_dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_hid, d_ff)
        self.w_2 = nn.Linear(d_ff, d_hid)

        self.layer_norm = LayerNormalization(d_hid)
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        self.relu = nn.ReLU()

    def forward(self, x, batch_idxs):
        hidden = self.relu_dropout(self.relu(self.w_1(x)), batch_idxs)
        delta = self.residual_dropout(self.w_2(hidden), batch_idxs)
        return self.layer_norm(delta + x)
#
class PartitionedPositionwiseFeedForward(nn.Module):
    """
    Feed forward sublayer which keeps content and positional features apart.

    The first d_hid - d_positional columns (content) and the last d_positional
    columns (position) each go through their own two-layer ReLU network with
    hidden size d_ff // 2.  The results are concatenated, dropped out, added to
    the input and layer-normalized.
    """
    def __init__(self, d_hid, d_ff, d_positional, relu_dropout=0.1, residual_dropout=0.1):
        super().__init__()
        self.d_content = d_hid - d_positional
        half_ff = d_ff // 2
        self.w_1c = nn.Linear(self.d_content, half_ff)
        self.w_1p = nn.Linear(d_positional, half_ff)
        self.w_2c = nn.Linear(half_ff, self.d_content)
        self.w_2p = nn.Linear(half_ff, d_positional)
        self.layer_norm = LayerNormalization(d_hid)
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        self.relu = nn.ReLU()

    def forward(self, x, batch_idxs):
        split = self.d_content
        # content branch first, then positional branch
        branches = []
        for part, w_in, w_out in ((x[:, :split], self.w_1c, self.w_2c),
                                  (x[:, split:], self.w_1p, self.w_2p)):
            hidden = self.relu_dropout(self.relu(w_in(part)), batch_idxs)
            branches.append(w_out(hidden))

        combined = self.residual_dropout(torch.cat(branches, -1), batch_idxs)
        return self.layer_norm(combined + x)
class LabelAttention(nn.Module):
    """
    Single-head Attention layer for label-specific representations

    Each of the d_l labels has its own key/value projections and its own query.
    With q_as_matrix=False the query of a label is a learned vector (w_qs);
    with q_as_matrix=True it is a per-label projection of the input.

    With combine_as_self, the label outputs are merged as in regular
    self-attention and the result is len_inp x d_model.  Otherwise each label
    keeps its own d_model vector (plus the input as a residual), which is
    reduced to d_proj and flattened to len_inp x (d_l * d_proj).
    """
    def __init__(self, d_model, d_k, d_v, d_l, d_proj, combine_as_self, use_resdrop=True, q_as_matrix=False, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
        super(LabelAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.d_l = d_l # Number of Labels
        self.d_model = d_model # Model Dimensionality
        self.d_proj = d_proj # Projection dimension of each label output
        self.use_resdrop = use_resdrop # Using Residual Dropout?
        self.q_as_matrix = q_as_matrix # Using a Matrix of Q to be multiplied with input instead of learned q vectors
        self.combine_as_self = combine_as_self # Using the Combination Method of Self-Attention

        if not d_positional:
            self.partitioned = False
        else:
            self.partitioned = True

        if self.partitioned:
            if d_model <= d_positional:
                raise ValueError("Unable to build LabelAttention. d_model %d <= d_positional %d" % (d_model, d_positional))
            self.d_content = d_model - d_positional
            self.d_positional = d_positional

            # "1" weights act on the content features, "2" weights on the positional ones
            if self.q_as_matrix:
                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
            else:
                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
            self.w_ks1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
            self.w_vs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_v // 2), requires_grad=True)

            if self.q_as_matrix:
                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
            else:
                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
            self.w_ks2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
            self.w_vs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_v // 2), requires_grad=True)

            init.xavier_normal_(self.w_qs1)
            init.xavier_normal_(self.w_ks1)
            init.xavier_normal_(self.w_vs1)

            init.xavier_normal_(self.w_qs2)
            init.xavier_normal_(self.w_ks2)
            init.xavier_normal_(self.w_vs2)
        else:
            if self.q_as_matrix:
                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
            else:
                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_k), requires_grad=True)
            self.w_ks = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
            self.w_vs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_v), requires_grad=True)

            init.xavier_normal_(self.w_qs)
            init.xavier_normal_(self.w_ks)
            init.xavier_normal_(self.w_vs)

        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
        if self.combine_as_self:
            self.layer_norm = LayerNormalization(d_model)
        else:
            self.layer_norm = LayerNormalization(self.d_proj)

        if not self.partitioned:
            # The lack of a bias term here is consistent with the t2t code, though
            # in my experiments I have never observed this making a difference.
            if self.combine_as_self:
                self.proj = nn.Linear(self.d_l * d_v, d_model, bias=False)
            else:
                # applied to each label's d_v output separately
                # (len_inp x d_l x d_v), hence d_v rather than d_l * d_v
                self.proj = nn.Linear(d_v, d_model, bias=False)
        else:
            if self.combine_as_self:
                self.proj1 = nn.Linear(self.d_l*(d_v//2), self.d_content, bias=False)
                self.proj2 = nn.Linear(self.d_l*(d_v//2), self.d_positional, bias=False)
            else:
                self.proj1 = nn.Linear(d_v//2, self.d_content, bias=False)
                self.proj2 = nn.Linear(d_v//2, self.d_positional, bias=False)
        if not self.combine_as_self:
            self.reduce_proj = nn.Linear(d_model, self.d_proj, bias=False)

        self.residual_dropout = FeatureDropout(residual_dropout)

    def split_qkv_packed(self, inp, k_inp=None):
        """
        Project the packed input to per-label queries, keys and values.

        inp: len_inp x d_model, used for the values (and keys if k_inp is None)
        k_inp: optional separate input for the keys (and matrix queries)
        Returns q_s (d_l x len_inp x d_k, or d_l x 1 x d_k for learned query
        vectors), k_s (d_l x len_inp x d_k) and v_s (d_l x len_inp x d_v)
        """
        v_inp_repeated = inp.repeat(self.d_l, 1).view(self.d_l, -1, inp.size(-1)) # d_l x len_inp x d_model
        if k_inp is None:
            k_inp_repeated = v_inp_repeated
        else:
            k_inp_repeated = k_inp.repeat(self.d_l, 1).view(self.d_l, -1, k_inp.size(-1)) # d_l x len_inp x d_model

        if not self.partitioned:
            if self.q_as_matrix:
                q_s = torch.bmm(k_inp_repeated, self.w_qs) # d_l x len_inp x d_k
            else:
                q_s = self.w_qs.unsqueeze(1) # d_l x 1 x d_k
            k_s = torch.bmm(k_inp_repeated, self.w_ks) # d_l x len_inp x d_k
            v_s = torch.bmm(v_inp_repeated, self.w_vs) # d_l x len_inp x d_v
        else:
            if self.q_as_matrix:
                q_s = torch.cat([
                    torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_qs1),
                    torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_qs2),
                    ], -1)
            else:
                q_s = torch.cat([
                    self.w_qs1.unsqueeze(1),
                    self.w_qs2.unsqueeze(1),
                    ], -1)
            k_s = torch.cat([
                torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_ks1),
                torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_ks2),
                ], -1)
            v_s = torch.cat([
                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
                ], -1)
        return q_s, k_s, v_s

    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
        # Input is packed representation: n_head x len_inp x d
        # Output is padded representation: (n_head * mb_size) x len_padded x d
        # (along with masks for the attention and output)
        # Every padded tensor is label-major: row l * mb_size + b is label l, sentence b
        n_head = self.d_l
        d_k, d_v = self.d_k, self.d_v

        len_padded = batch_idxs.max_len
        mb_size = batch_idxs.batch_size
        if self.q_as_matrix:
            q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
        else:
            # q_s is d_l x 1 x d_k.  repeat_interleave gives the label-major
            # layout used by k, v and the masks.  (q_s.repeat(mb_size, 1, 1)
            # tiles it sentence-major, so row i would get label i % d_l and
            # most labels would attend with another label's query.)
            q_padded = q_s.repeat_interleave(mb_size, dim=0) # (d_l * mb_size) x 1 x d_k
        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
        # True marks padding positions, which attention must ignore
        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)

        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
            if self.q_as_matrix:
                q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
            invalid_mask[i, :end-start].fill_(False)

        if self.q_as_matrix:
            q_padded = q_padded.view(-1, len_padded, d_k)
            attn_mask = invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1)
        else:
            # a single query per (label, sentence): (d_l * mb_size) x 1 x len_padded
            attn_mask = invalid_mask.unsqueeze(1).repeat(n_head, 1, 1)

        # (d_l * mb_size) x len_padded, True for real positions; used to re-pack
        output_mask = (~invalid_mask).repeat(n_head, 1)

        return(
            q_padded,
            k_padded.view(-1, len_padded, d_k),
            v_padded.view(-1, len_padded, d_v),
            attn_mask,
            output_mask,
            )

    def combine_v(self, outputs):
        """
        Reshape the packed attention output, (d_l * len_inp) x d_v label-major,
        back to per-word rows.

        combine_as_self: len_inp x d_model
        otherwise: len_inp x d_l x d_model
        """
        # Combine attention information from the different labels
        d_l = self.d_l
        outputs = outputs.view(d_l, -1, self.d_v) # d_l x len_inp x d_v

        if not self.partitioned:
            # Switch from d_l x len_inp x d_v to len_inp x d_l x d_v
            if self.combine_as_self:
                outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, d_l * self.d_v)
            else:
                outputs = torch.transpose(outputs, 0, 1)
            # Project back to residual size
            outputs = self.proj(outputs) # Becomes len_inp x d_l x d_model
        else:
            d_v1 = self.d_v // 2
            outputs1 = outputs[:,:,:d_v1]
            outputs2 = outputs[:,:,d_v1:]
            if self.combine_as_self:
                outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, d_l * d_v1)
                outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, d_l * d_v1)
            else:
                outputs1 = torch.transpose(outputs1, 0, 1)
                outputs2 = torch.transpose(outputs2, 0, 1)
            outputs = torch.cat([
                self.proj1(outputs1),
                self.proj2(outputs2),
                ], -1)
        return outputs

    def forward(self, inp, batch_idxs, k_inp=None):
        """
        inp: packed len_inp x d_model input
        batch_idxs: BatchIndices describing the sentences packed into inp
        k_inp: optional separate input for the keys

        Returns (outputs, padded attention weights), where outputs is
        len_inp x d_model with combine_as_self and len_inp x (d_l * d_proj) otherwise.
        """
        len_inp = inp.size(0)

        # While still using a packed representation, project to obtain the
        # query/key/value for each label
        q_s, k_s, v_s = self.split_qkv_packed(inp, k_inp=k_inp)
        # d_l x len_inp x d_k
        # q_s is d_l x 1 x d_k

        # Switch to padded representation, perform attention, then switch back
        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
        # q_padded, k_padded, v_padded: (d_l * batch_size) x max_len x d_kv
        # q_padded in LAL is (d_l * batch_size) x 1 x d_kv

        outputs_padded, attns_padded = self.attention(
            q_padded, k_padded, v_padded,
            attn_mask=attn_mask,
            )
        # outputs_padded: (d_l * batch_size) x max_len x d_kv
        # in LAL: (d_l * batch_size) x 1 x d_kv
        # with learned query vectors there is one value vector per (label, sentence),
        # which is repeated max_len times so every word of the sentence gets it
        if not self.q_as_matrix:
            outputs_padded = outputs_padded.repeat(1,output_mask.size(-1),1)
        outputs = outputs_padded[output_mask]
        # outputs: (d_l * len_inp) x d_kv
        # output_mask: (d_l * batch_size) x max_len
        outputs = self.combine_v(outputs)
        # outputs: len_inp x d_l x d_model, whereas a normal self-attention layer gets len_inp x d_model
        if self.use_resdrop:
            if self.combine_as_self:
                outputs = self.residual_dropout(outputs, batch_idxs)
            else:
                # separate dropout mask for each label
                outputs = torch.cat([self.residual_dropout(outputs[:,i,:], batch_idxs).unsqueeze(1) for i in range(self.d_l)], 1)
        if self.combine_as_self:
            outputs = self.layer_norm(outputs + inp)
        else:
            # add the input as a residual to every label's output
            outputs = outputs + inp.unsqueeze(1)
            outputs = self.reduce_proj(outputs) # len_inp x d_l x d_proj
            outputs = self.layer_norm(outputs) # len_inp x d_l x d_proj
            outputs = outputs.view(len_inp, -1).contiguous() # len_inp x (d_l * d_proj)

        return outputs, attns_padded
#
class LabelAttentionModule(nn.Module):
    """
    Label Attention Module for label-specific representations
    The module can be used right after the Partitioned Attention, or it can be experimented with for the transition stack
    """
    def __init__(self,
                 d_model,
                 d_input_proj,
                 d_k,
                 d_v,
                 d_l,
                 d_proj,
                 combine_as_self,
                 use_resdrop=True,
                 q_as_matrix=False,
                 residual_dropout=0.1,
                 attention_dropout=0.1,
                 d_positional=None,
                 d_ff=2048,
                 relu_dropout=0.2,
                 lattn_partitioned=True):
        super().__init__()
        self.ff_dim = d_proj * d_l

        if not lattn_partitioned:
            self.d_positional = 0
        else:
            self.d_positional = d_positional if d_positional else 0

        if d_input_proj:
            if d_input_proj <= self.d_positional:
                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, self.d_positional))
            # only the content features are projected; the positional features pass through
            self.input_projection = nn.Linear(d_model - self.d_positional, d_input_proj - self.d_positional, bias=False)
            d_input = d_input_proj
        else:
            self.input_projection = None
            d_input = d_model

        self.label_attention = LabelAttention(d_input,
                                              d_k,
                                              d_v,
                                              d_l,
                                              d_proj,
                                              combine_as_self,
                                              use_resdrop,
                                              q_as_matrix,
                                              residual_dropout,
                                              attention_dropout,
                                              self.d_positional)

        if not lattn_partitioned:
            self.lal_ff = PositionwiseFeedForward(self.ff_dim,
                                                  d_ff,
                                                  relu_dropout,
                                                  residual_dropout)
        else:
            self.lal_ff = PartitionedPositionwiseFeedForward(self.ff_dim,
                                                             d_ff,
                                                             self.d_positional,
                                                             relu_dropout,
                                                             residual_dropout)

    def forward(self, word_embeddings, tagged_word_lists):
        """
        word_embeddings: list of 2d per-sentence tensors (words x features)
        tagged_word_lists: not used here; kept for interface compatibility

        Returns a list with one tensor per sentence, each with one row per word
        of width self.ff_dim (d_l * d_proj).  An empty batch returns [].
        """
        if not word_embeddings:
            return []

        if self.input_projection is not None:
            if self.d_positional > 0:
                # project the content features, keep the trailing positional features as-is
                word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
                                              sentence[:, -self.d_positional:]), dim=1)
                                   for sentence in word_embeddings]
            else:
                word_embeddings = [self.input_projection(sentence) for sentence in word_embeddings]

        # Pack every word of every sentence into one tensor, remembering
        # which sentence each row came from
        sentence_lengths = [len(sentence) for sentence in word_embeddings]
        batch_idxs = np.repeat(np.arange(len(word_embeddings)), sentence_lengths)
        batch_idxs = BatchIndices(batch_idxs, word_embeddings[0].device)
        packed_embeddings = torch.cat(word_embeddings, dim=0)

        # Extract Labeled Representation
        labeled_representations, _ = self.label_attention(packed_embeddings, batch_idxs)
        labeled_representations = self.lal_ff(labeled_representations, batch_idxs)

        # Unpack back into one tensor per sentence
        return list(torch.split(labeled_representations, sentence_lengths, dim=0))
================================================
FILE: stanza/models/constituency/lstm_model.py
================================================
"""
A version of the BaseModel which uses LSTMs to predict the correct next transition
based on the current known state.
The primary purpose of this class is to implement the prediction of the next
transition, which is done by concatenating the output of an LSTM operated over
previous transitions, the words, and the partially built constituents.
A complete processing of a sentence is as follows:
1) Run the input words through an encoder.
The encoder includes some or all of the following:
pretrained word embedding
finetuned word embedding for training set words - "delta_embedding"
POS tag embedding
pretrained charlm representation
BERT or similar large language model representation
attention transformer over the previous inputs
labeled attention transformer over the first attention layer
The encoded input is then put through a bi-lstm, giving a word representation
2) Transitions are put in an embedding, and transitions already used are tracked
in an LSTM
3) Constituents already built are also processed in an LSTM
4) Every transition is chosen by taking the output of the current word position,
the transition LSTM, and the constituent LSTM, and classifying the next
transition
5) Transitions are repeated (with constraints) until the sentence is completed
"""
from collections import namedtuple
import copy
from enum import Enum
import logging
import math
import random
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from stanza.models.common.bert_embedding import extract_bert_embeddings
from stanza.models.common.maxout_linear import MaxoutLinear
from stanza.models.common.relative_attn import RelativeAttention
from stanza.models.common.utils import attach_bert_model, build_nonlinearity, unsort
from stanza.models.common.vocab import PAD_ID, UNK_ID
from stanza.models.constituency.base_model import BaseModel
from stanza.models.constituency.label_attention import LabelAttentionModule
from stanza.models.constituency.lstm_tree_stack import LSTMTreeStack
from stanza.models.constituency.parse_transitions import TransitionScheme
from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule
from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
from stanza.models.constituency.tree_stack import TreeStack
from stanza.models.constituency.utils import initialize_linear
# module logger, plus a separate logger used for trainer progress messages
logger = logging.getLogger('stanza')
tlogger = logging.getLogger('stanza.constituency.trainer')

# an entry on the word queue: the word's value and its hx
WordNode = namedtuple("WordNode", "value hx")

# an entry on the constituent stack
# tree_hx and tree_cx are the states of the lstm going up the constituents
# when the tree_lstm combination method is used
Constituent = namedtuple("Constituent", "value tree_hx tree_cx")
# The sentence boundary vectors are marginally useful at best.
# However, they make it much easier to use non-bert layers as input to
# attention layers, as the attention layers work better when they have
# an index 0 to attend to.
class SentenceBoundary(Enum):
    """
    Which parts of the model get learned sentence start / end vectors.
    """
    # no boundary vectors at all
    NONE = 1
    # word_start_embedding / word_end_embedding around the word inputs only
    WORDS = 2
    # word boundaries, and the transition and constituent stacks also use boundary vectors
    EVERYTHING = 3
class StackHistory(Enum):
    """
    How the history of the transition / constituent stacks is encoded.
    """
    # an LSTMTreeStack over the stack entries
    LSTM = 1
    # a TransformerTreeStack (attention) over the stack entries
    ATTN = 2
# How to compose constituent children into new constituents
# MAX is simply take the max value of the children
# this is surprisingly effective
# for example, a Turkish dataset went from 81->81.5 dev, 75->75.5 test
# BILSTM is the method described in the papers of making an lstm
# out of the constituents
# BILSTM_MAX is the same as BILSTM, but instead of using a Linear
# to reduce the outputs of the lstm, we first take the max
# and then use a linear to reduce the max
# BIGRAM combines pairs of children and then takes the max over those
# ATTN means to put an attention layer over the children nodes
# we then take the max of the children with their attention
#
# Experiments show that MAX is noticeably better than the other options
# On ja_alt, here are a few results after 200 iterations,
# averaged over 5 iterations:
# MAX: 0.8985
# BILSTM: 0.8964
# BILSTM_MAX: 0.8973
# BIGRAM: 0.8982
#
# The MAX method has a linear transform after the max.
# Removing that transform makes the score go down to 0.8982
#
# We tried a few varieties of BILSTM_MAX
# In particular:
# max over LSTM, combining forward & backward using the max: 0.8970
# max over forward & backward separately, then reduce: 0.8970
# max over forward & backward only over 1:-1
# (eg, leave out the node embedding): 0.8969
# same as previous, but split the reduce into 2 pieces: 0.8973
# max over forward & backward separately, then reduce as
# 1/2(F + B) + W(F,B)
# the idea being that this way F and B are guaranteed
# to be represented: 0.8971
#
# BIGRAM is an attempt to mix information from nodes
# when building constituents, but it didn't help
# The first example, just taking pairs and learning
# a transform, went to NaN. Likely the transform
# expanded the embedding too much. Switching it to
# scale the matrix by 0.5 didn't go to NaN, but only
# resulted in 0.8982
#
# A couple varieties of ATTN:
# first an input linear, then attn, then an output linear
# the upside of this would be making the dimension of the attn
# independent from the rest of the model
# however, this caused an expansion in the magnitude of the vectors,
# resulting in NaN for deep enough trees
# adding layernorm or tanh to balance this out resulted in
# disappointing performance
# tanh: 0.8972
# another alternative not tested yet: lower initialization weights
# and enforce that the norms of the matrices are low enough that
# exponential explosion up the layers of the tree doesn't happen
# using just an attention layer requires hidden_size % reduce_heads == 0
# that is simple enough to enforce by slightly changing hidden_size
# if needed
# appending the embedding for the open state to the start of the
# sequence of children and taking only the content nodes
# was very disappointing: 0.8967
# taking the entire sequence of children including the open state
# embedding resulted in 0.8973
# long story short, this looks like an idea that should work, but it
# doesn't help. suggestions welcome for improving these results
#
# The current TREE_LSTM_CX mechanism uses a word's embedding
# as the hx and a trained embedding over tags as the cx 0.8996
# This worked slightly better than 0s for cx (TREE_LSTM) 0.8992
# A variant of TREE_LSTM which didn't work out:
# nodes are combined with an LSTM
# hx & cx are embeddings of the node type (eg S, NP, etc)
# input is the max over children: 0.8977
# Another variant which didn't work: use the word embedding
# as input to the same LSTM to get hx & cx 0.8985
# Note that although the scores for TREE_LSTM_CX are slightly higher
# than MAX for the JA dataset, the benefit was not as clear for EN,
# so we left the default at MAX.
# For example, on English WSJ, before switching to Bert POS and
# a learned Bert mixing layer, a comparison of 5x models trained
# for 400 iterations got dev scores of:
# TREE_LSTM_CX 0.9589
# MAX 0.9593
#
# UNTIED_MAX has a different reduce_linear for each type of
# constituent in the model. Similar to the different linear
# maps used in the CVG paper from Socher, Bauer, Manning, Ng
# This is implemented as a large CxHxH parameter,
# with num_constituent layers of hidden-hidden transform,
# along with a CxH bias parameter.
# Essentially C Linears stacked on top of each other,
# but in a parameter so that indexing can be done quickly.
# Unfortunately this does not beat out MAX with one combined linear.
# On an experiment on WSJ with all the best settings as of early
# October 2022, such as a Bert model POS tagger:
# MAX 0.9597
# UNTIED_MAX 0.9592
# Furthermore, starting from a finished MAX model and restarting
# by splitting the MAX layer into multiple pieces did not improve.
#
# KEY has a single Key which is used for a facsimile of ATTN
# each incoming subtree has its values weighted by a Query
# then the Key is used to calculate a softmax
# finally, a Value is used to scale the subtrees
# reduce_heads is used to determine the number of heads
# There is an option to use or not use position information
# using a sinusoidal position embedding
# UNTIED_KEY is the same, but has a different key
# for each possible constituent
# On a VI dataset:
# MAX 0.82064
# KEY (pos, 8) 0.81739
# UNTIED_KEY (pos, 8) 0.82046
# UNTIED_KEY (pos, 4) 0.81742
# Attempted to add a linear to mix the attn heads together,
# but that was awful: 0.81567
# Adding two position vectors, one in each direction, did not help:
# UNTIED_KEY (2x pos, 8) 0.8188
# To redo that experiment, double the width of reduce_query and
# reduce_value, then call reduce_position on nhx, flip it,
# and call reduce_position again
# Evidently the experiments to try should be:
# no pos at all
# more heads
class ConstituencyComposition(Enum):
    """
    How the children of a constituent are composed into the new constituent.

    The experimental comparisons of these methods are in the comment block
    preceding this class.
    """
    # bi-lstm over the children, then a linear
    BILSTM = 1
    # max over the children, then a linear
    MAX = 2
    # lstm going up the tree, 0s for cx
    TREE_LSTM = 3
    # bi-lstm over the children, then max, then a linear
    BILSTM_MAX = 4
    # combine pairs of children, then take the max
    BIGRAM = 5
    # attention layer over the children
    ATTN = 6
    # tree lstm using the word embedding as hx and a tag embedding as cx
    TREE_LSTM_CX = 7
    # MAX, but with a separate linear for each constituent type
    UNTIED_MAX = 8
    # single learned key used for attention-style weighting of the children
    KEY = 9
    # KEY, but with a separate key for each constituent type
    UNTIED_KEY = 10
class LSTMModel(BaseModel, nn.Module):
def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, force_bert_saved, peft_name, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
    """
    pretrain: a Pretrain object
    forward_charlm, backward_charlm: optional character language models
      (None to skip).  Each one's hidden_dim() is added to the word input size.
      Passing them in the wrong direction raises ValueError.
    bert_model, bert_tokenizer: optional transformer model and its tokenizer.
      A bert_model without a tokenizer raises ValueError.
    force_bert_saved: passed (or'd with bert_finetune / stage1_bert_finetune)
      to attach_bert_model - presumably forces the bert weights to be saved
      with this model
    peft_name: stored as self.peft_name
    transitions: a list of all possible transitions which will be
      used to build trees
    constituents: a list of all possible constituents in the treebank
    tags: a list of all possible tags in the treebank
    words: a list of all known words, used for a delta word embedding.
      note that there will be an attempt made to learn UNK words as well,
      and tags by themselves may help UNK words
    rare_words: a list of rare words, used to occasionally replace with UNK
    root_labels: probably ROOT, although apparently some treebanks use TOP or even S
    constituent_opens: a list of all possible open nodes which will go on the stack
      - this might be different from constituents if there are nodes
        which represent multiple constituents at once
    unary_limit: passed through to BaseModel
    args: hidden_size, transition_hidden_size, etc as gotten from
      constituency_parser.py
      this dict is kept as self.args and may be updated in place
      (bert_hidden_layers, rattn_heads)

    Note that it might look like a hassle to pass all of this in
    when it can be collected directly from the trees themselves.
    However, that would only work at train time.  At eval or
    pipeline time we will load the lists from the saved model.
    """
    super().__init__(transition_scheme=args['transition_scheme'], unary_limit=unary_limit, reverse_sentence=args.get('reversed', False), root_labels=root_labels)

    self.args = args
    self.unsaved_modules = []

    emb_matrix = pretrain.emb
    self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))

    # replacing NBSP picks up a whole bunch of words for VI
    self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }
    # precompute tensors for the word indices
    # the tensors should be put on the GPU if needed by calling to(device)
    self.register_buffer('vocab_tensors', torch.tensor(range(len(pretrain.vocab)), requires_grad=False))
    self.vocab_size = emb_matrix.shape[0]
    self.embedding_dim = emb_matrix.shape[1]

    self.constituents = sorted(list(constituents))

    self.hidden_size = self.args['hidden_size']
    self.constituency_composition = self.args.get("constituency_composition", ConstituencyComposition.BILSTM)
    if self.constituency_composition in (ConstituencyComposition.ATTN, ConstituencyComposition.KEY, ConstituencyComposition.UNTIED_KEY):
        self.reduce_heads = self.args['reduce_heads']
        if self.hidden_size % self.reduce_heads != 0:
            # round hidden_size up to the next multiple of reduce_heads
            self.hidden_size = self.hidden_size + self.reduce_heads - (self.hidden_size % self.reduce_heads)

    if args['constituent_stack'] == StackHistory.ATTN:
        self.reduce_heads = self.args['reduce_heads']
        if self.hidden_size % args['constituent_heads'] != 0:
            # TODO: technically we should either use the LCM of this and reduce_heads, or just have two separate fields
            self.hidden_size = self.hidden_size + args['constituent_heads'] - (self.hidden_size % args['constituent_heads'])
            # NOTE(review): KEY / UNTIED_KEY also divide hidden_size by reduce_heads
            # (see reduce_key below), but only ATTN is checked here
            if self.constituency_composition == ConstituencyComposition.ATTN and self.hidden_size % self.reduce_heads != 0:
                raise ValueError("--reduce_heads and --constituent_heads not compatible!")

    self.transition_hidden_size = self.args['transition_hidden_size']
    if args['transition_stack'] == StackHistory.ATTN:
        if self.transition_hidden_size % args['transition_heads'] > 0:
            logger.warning("transition_hidden_size %d %% transition_heads %d != 0. reconfiguring", self.transition_hidden_size, args['transition_heads'])
            # round up to the next multiple of transition_heads
            self.transition_hidden_size = self.transition_hidden_size + args['transition_heads'] - (self.transition_hidden_size % args['transition_heads'])

    self.tag_embedding_dim = self.args['tag_embedding_dim']
    self.transition_embedding_dim = self.args['transition_embedding_dim']
    self.delta_embedding_dim = self.args['delta_embedding_dim']

    self.word_input_size = self.embedding_dim + self.tag_embedding_dim + self.delta_embedding_dim

    if forward_charlm is not None:
        self.add_unsaved_module('forward_charlm', forward_charlm)
        self.word_input_size += self.forward_charlm.hidden_dim()
        if not forward_charlm.is_forward_lm:
            raise ValueError("Got a backward charlm as a forward charlm!")
    else:
        self.forward_charlm = None
    if backward_charlm is not None:
        self.add_unsaved_module('backward_charlm', backward_charlm)
        self.word_input_size += self.backward_charlm.hidden_dim()
        if backward_charlm.is_forward_lm:
            raise ValueError("Got a forward charlm as a backward charlm!")
    else:
        self.backward_charlm = None

    self.delta_words = sorted(set(words))
    # indices 0 and 1 are reserved for PAD and UNK, so known words start at 2
    self.delta_word_map = { word: i+2 for i, word in enumerate(self.delta_words) }
    assert PAD_ID == 0
    assert UNK_ID == 1
    # initialization is chosen based on the observed values of the norms
    # after several long training cycles
    # (this is true for other embeddings and embedding-like vectors as well)
    # the experiments which showed this slightly helps were done with
    # Adadelta and the correct initialization may be slightly
    # different for a different optimizer.
    # in fact, it is likely a scheme other than normal_ would
    # be better - the optimizer tends to learn the weights
    # rather close to 0 before learning in the direction it
    # actually wants to go
    self.delta_embedding = nn.Embedding(num_embeddings = len(self.delta_words)+2,
                                        embedding_dim = self.delta_embedding_dim,
                                        padding_idx = 0)
    nn.init.normal_(self.delta_embedding.weight, std=0.05)
    self.register_buffer('delta_tensors', torch.tensor(range(len(self.delta_words) + 2), requires_grad=False))

    self.rare_words = set(rare_words)

    self.tags = sorted(list(tags))
    if self.tag_embedding_dim > 0:
        # as with the delta words, 0 and 1 are reserved for PAD and UNK
        self.tag_map = { t: i+2 for i, t in enumerate(self.tags) }
        self.tag_embedding = nn.Embedding(num_embeddings = len(tags)+2,
                                          embedding_dim = self.tag_embedding_dim,
                                          padding_idx = 0)
        nn.init.normal_(self.tag_embedding.weight, std=0.25)
        self.register_buffer('tag_tensors', torch.tensor(range(len(self.tags) + 2), requires_grad=False))

    self.num_lstm_layers = self.args['num_lstm_layers']
    self.num_tree_lstm_layers = self.args['num_tree_lstm_layers']
    self.lstm_layer_dropout = self.args['lstm_layer_dropout']

    self.word_dropout = nn.Dropout(self.args['word_dropout'])
    self.predict_dropout = nn.Dropout(self.args['predict_dropout'])
    self.lstm_input_dropout = nn.Dropout(self.args['lstm_input_dropout'])

    # also register a buffer of zeros so that we can always get zeros on the appropriate device
    self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers))
    self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))

    # possibly add a couple vectors for bookends of the sentence
    # We put the word_start and word_end here, AFTER counting the
    # charlm dimension, but BEFORE counting the bert dimension,
    # as we want word_start and word_end to not have dimensions
    # for the bert embedding.  The bert model will add its own
    # start and end representation.
    self.sentence_boundary_vectors = self.args['sentence_boundary_vectors']
    if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
        self.register_parameter('word_start_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))
        self.register_parameter('word_end_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))

    # we set up the bert AFTER building word_start and word_end
    # so that we can use the charlm endpoint values rather than
    # try to train our own
    self.force_bert_saved = force_bert_saved or self.args['bert_finetune'] or self.args['stage1_bert_finetune']
    attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), self.force_bert_saved)
    self.peft_name = peft_name

    if bert_model is not None:
        if bert_tokenizer is None:
            raise ValueError("Cannot have a bert model without a tokenizer")
        self.bert_dim = self.bert_model.config.hidden_size
        if args['bert_hidden_layers']:
            # The average will be offset by 1/N so that the default zeros
            # represents an average of the N layers
            if args['bert_hidden_layers'] > bert_model.config.num_hidden_layers:
                # limit ourselves to the number of layers actually available
                # note that we can +1 because of the initial embedding layer
                args['bert_hidden_layers'] = bert_model.config.num_hidden_layers + 1
            self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
            nn.init.zeros_(self.bert_layer_mix.weight)
        else:
            # an average of layers 2, 3, 4 will be used
            # (for historic reasons)
            self.bert_layer_mix = None
        self.word_input_size = self.word_input_size + self.bert_dim

    self.partitioned_transformer_module = None
    self.pattn_d_model = 0
    if LSTMModel.uses_pattn(self.args):
        # Initializations of parameters for the Partitioned Attention
        # round off the size of the model so that it divides in half evenly
        self.pattn_d_model = self.args['pattn_d_model'] // 2 * 2

        # Initializations for the Partitioned Attention
        # experiments suggest having a bias does not help here
        self.partitioned_transformer_module = PartitionedTransformerModule(
            self.args['pattn_num_layers'],
            d_model=self.pattn_d_model,
            n_head=self.args['pattn_num_heads'],
            d_qkv=self.args['pattn_d_kv'],
            d_ff=self.args['pattn_d_ff'],
            ff_dropout=self.args['pattn_relu_dropout'],
            residual_dropout=self.args['pattn_residual_dropout'],
            attention_dropout=self.args['pattn_attention_dropout'],
            word_input_size=self.word_input_size,
            bias=self.args['pattn_bias'],
            morpho_emb_dropout=self.args['pattn_morpho_emb_dropout'],
            timing=self.args['pattn_timing'],
            encoder_max_len=self.args['pattn_encoder_max_len']
        )
        self.word_input_size += self.pattn_d_model

    self.label_attention_module = None
    if LSTMModel.uses_lattn(self.args):
        if self.partitioned_transformer_module is None:
            logger.error("Not using Labeled Attention, as the Partitioned Attention module is not used")
        else:
            # TODO: think of a couple ways to use alternate inputs
            # for example, could pass in the word inputs with a positional embedding
            # that would also allow it to work in the case of no partitioned module
            if self.args['lattn_combined_input']:
                self.lattn_d_input = self.word_input_size
            else:
                self.lattn_d_input = self.pattn_d_model
            self.label_attention_module = LabelAttentionModule(self.lattn_d_input,
                                                               self.args['lattn_d_input_proj'],
                                                               self.args['lattn_d_kv'],
                                                               self.args['lattn_d_kv'],
                                                               self.args['lattn_d_l'],
                                                               self.args['lattn_d_proj'],
                                                               self.args['lattn_combine_as_self'],
                                                               self.args['lattn_resdrop'],
                                                               self.args['lattn_q_as_matrix'],
                                                               self.args['lattn_residual_dropout'],
                                                               self.args['lattn_attention_dropout'],
                                                               self.pattn_d_model // 2,
                                                               self.args['lattn_d_ff'],
                                                               self.args['lattn_relu_dropout'],
                                                               self.args['lattn_partitioned'])
            self.word_input_size = self.word_input_size + self.args['lattn_d_proj']*self.args['lattn_d_l']

    self.rel_attn_forward = None
    self.rel_attn_reverse = None
    if self.args.get('use_rattn', False):
        if not self.args['rattn_cat'] and self.word_input_size % self.args['rattn_heads'] != 0:
            # the heads must divide the input width evenly, so search
            # outward from the requested head count for one which does
            for rattn_heads in range(self.args['rattn_heads'] // 2):
                if self.word_input_size % (self.args['rattn_heads'] + rattn_heads) == 0:
                    new_rattn_heads = self.args['rattn_heads'] + rattn_heads
                    break
                if self.word_input_size % (self.args['rattn_heads'] - rattn_heads) == 0:
                    new_rattn_heads = self.args['rattn_heads'] - rattn_heads
                    break
            else:
                raise ValueError("Number of heads %d does not divide evenly into input size %d" % (self.args['rattn_heads'], self.word_input_size))
            logger.warning("rattn_heads of %d does not work, but found a similar value of %d which does work", self.args['rattn_heads'], new_rattn_heads)
            self.args['rattn_heads'] = new_rattn_heads
        if self.args['rattn_forward']:
            if self.args['rattn_cat']:
                self.rel_attn_forward = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], d_output=self.args['rattn_dim'], fudge_output=True, num_sinks=self.args['rattn_sinks'])
            else:
                self.rel_attn_forward = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], num_sinks=self.args['rattn_sinks'])
        if self.args['rattn_reverse']:
            if self.args['rattn_cat']:
                self.rel_attn_reverse = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], reverse=True, d_output=self.args['rattn_dim'], fudge_output=True, num_sinks=self.args['rattn_sinks'])
            else:
                self.rel_attn_reverse = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], reverse=True, num_sinks=self.args['rattn_sinks'])
        # the word input only grows when the attention outputs are concatenated on
        if self.args['rattn_forward'] and self.args['rattn_cat']:
            self.word_input_size += self.rel_attn_forward.d_output
        if self.args['rattn_reverse'] and self.args['rattn_cat']:
            self.word_input_size += self.rel_attn_reverse.d_output

    self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)

    # after putting the word_delta_tag input through the word_lstm, we get back
    # hidden_size * 2 output with the front and back lstms concatenated.
    # this transforms it into hidden_size with the values mixed together
    self.word_to_constituent = nn.Linear(self.hidden_size * 2, self.hidden_size * self.num_tree_lstm_layers)
    initialize_linear(self.word_to_constituent, self.args['nonlinearity'], self.hidden_size * 2)

    self.transitions = sorted(list(transitions))
    self.transition_map = { t: i for i, t in enumerate(self.transitions) }
    # precompute tensors for the transitions
    self.register_buffer('transition_tensors', torch.tensor(range(len(transitions)), requires_grad=False))
    self.transition_embedding = nn.Embedding(num_embeddings = len(transitions),
                                             embedding_dim = self.transition_embedding_dim)
    nn.init.normal_(self.transition_embedding.weight, std=0.25)
    if args['transition_stack'] == StackHistory.LSTM:
        self.transition_stack = LSTMTreeStack(input_size=self.transition_embedding_dim,
                                              hidden_size=self.transition_hidden_size,
                                              num_lstm_layers=self.num_lstm_layers,
                                              dropout=self.lstm_layer_dropout,
                                              uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
                                              input_dropout=self.lstm_input_dropout)
    elif args['transition_stack'] == StackHistory.ATTN:
        self.transition_stack = TransformerTreeStack(input_size=self.transition_embedding_dim,
                                                     output_size=self.transition_hidden_size,
                                                     input_dropout=self.lstm_input_dropout,
                                                     use_position=True,
                                                     num_heads=args['transition_heads'])
    else:
        raise ValueError("Unhandled transition_stack StackHistory: {}".format(args['transition_stack']))

    self.constituent_opens = sorted(list(constituent_opens))
    # an embedding for the spot on the constituent LSTM taken up by the Open transitions
    # the pattern when condensing constituents is embedding - con1 - con2 - con3 - embedding
    # TODO: try the two ends have different embeddings?
    self.constituent_open_map = { x: i for (i, x) in enumerate(self.constituent_opens) }
    self.constituent_open_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
                                                   embedding_dim = self.hidden_size)
    nn.init.normal_(self.constituent_open_embedding.weight, std=0.2)

    # input_size is hidden_size - could introduce a new constituent_size instead if we liked
    if args['constituent_stack'] == StackHistory.LSTM:
        self.constituent_stack = LSTMTreeStack(input_size=self.hidden_size,
                                               hidden_size=self.hidden_size,
                                               num_lstm_layers=self.num_lstm_layers,
                                               dropout=self.lstm_layer_dropout,
                                               uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
                                               input_dropout=self.lstm_input_dropout)
    elif args['constituent_stack'] == StackHistory.ATTN:
        self.constituent_stack = TransformerTreeStack(input_size=self.hidden_size,
                                                      output_size=self.hidden_size,
                                                      input_dropout=self.lstm_input_dropout,
                                                      use_position=True,
                                                      num_heads=args['constituent_heads'])
    else:
        raise ValueError("Unhandled constituent_stack StackHistory: {}".format(args['constituent_stack']))

    if args['combined_dummy_embedding']:
        self.dummy_embedding = self.constituent_open_embedding
    else:
        self.dummy_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
                                            embedding_dim = self.hidden_size)
        nn.init.normal_(self.dummy_embedding.weight, std=0.2)
    self.register_buffer('constituent_open_tensors', torch.tensor(range(len(constituent_opens)), requires_grad=False))

    # TODO: refactor
    if (self.constituency_composition == ConstituencyComposition.BILSTM or
        self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
        # forward and backward pieces for crunching several
        # constituents into one, combined into a bi-lstm
        # TODO: make the hidden size here an option?
        self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
        # affine transformation from bi-lstm reduce to a new hidden layer
        if self.constituency_composition == ConstituencyComposition.BILSTM:
            self.reduce_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)
            initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size * 2)
        else:
            self.reduce_forward = nn.Linear(self.hidden_size, self.hidden_size)
            self.reduce_backward = nn.Linear(self.hidden_size, self.hidden_size)
            initialize_linear(self.reduce_forward, self.args['nonlinearity'], self.hidden_size)
            initialize_linear(self.reduce_backward, self.args['nonlinearity'], self.hidden_size)
    elif self.constituency_composition == ConstituencyComposition.MAX:
        # transformation to turn several constituents into one new constituent
        self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)
        initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)
    elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
        # transformation to turn several constituents into one new constituent
        # one hidden x hidden transform (and bias) per constituent open type
        self.register_parameter('reduce_linear_weight', torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, self.hidden_size, requires_grad=True)))
        self.register_parameter('reduce_linear_bias', torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, requires_grad=True)))
        for layer_idx in range(len(constituent_opens)):
            nn.init.kaiming_normal_(self.reduce_linear_weight[layer_idx], nonlinearity=self.args['nonlinearity'])
        nn.init.uniform_(self.reduce_linear_bias, 0, 1 / (self.hidden_size * 2) ** 0.5)
    elif self.constituency_composition == ConstituencyComposition.BIGRAM:
        self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)
        self.reduce_bigram = nn.Linear(self.hidden_size * 2, self.hidden_size)
        initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)
        # NOTE(review): reduce_bigram has hidden_size * 2 inputs, but the other
        # Linears here pass their input width as the last argument - confirm intended
        initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size)
    elif self.constituency_composition == ConstituencyComposition.ATTN:
        self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)
    elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
        if self.args['reduce_position']:
            # unsaved module so that if it grows, we don't save
            # the larger version unnecessarily
            # under any normal circumstances, the growth will
            # happen early in training when the model is not
            # behaving well, then will not be needed once the
            # model learns not to make super degenerate
            # constituents
            self.add_unsaved_module("reduce_position", ConcatSinusoidalEncoding(self.args['reduce_position'], 50))
        else:
            self.add_unsaved_module("reduce_position", nn.Identity())
        self.reduce_query = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size, bias=False)
        self.reduce_value = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size)
        if self.constituency_composition == ConstituencyComposition.KEY:
            self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
        else:
            # one key per constituent open type
            self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(len(constituent_opens), self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
    elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
        self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
    elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:
        self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2,
                                                         embedding_dim = self.num_tree_lstm_layers * self.hidden_size)
        self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
    else:
        raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))

    self.nonlinearity = build_nonlinearity(self.args['nonlinearity'])

    # matrix for predicting the next transition using word/constituent/transition queues
    # word size + constituency size + transition size
    # TODO: .get() is only necessary until all models rebuilt with this param
    self.maxout_k = self.args.get('maxout_k', 0)
    self.output_layers = self.build_output_layers(self.args['num_output_layers'], len(transitions), self.maxout_k)
@staticmethod
def uses_lattn(args):
return args.get('use_lattn', True) and args.get('lattn_d_proj', 0) > 0 and args.get('lattn_d_l', 0) > 0
@staticmethod
def uses_pattn(args):
    """
    Whether the config in args enables the pattn (partitioned attention) block

    Both pattn_num_heads and pattn_num_layers must be positive.
    There are no defaults: a missing key raises KeyError.
    """
    has_heads = args['pattn_num_heads'] > 0
    if not has_heads:
        return has_heads
    return args['pattn_num_layers'] > 0
def copy_with_new_structure(self, other):
    """
    Copy the parameters of another model into this model

    The two models must use the same constituency_composition, except
    that this model may be UNTIED_MAX.  In that case the other model's
    single reduce_linear layer is copied into every per-label slot of
    reduce_linear_weight and reduce_linear_bias.

    The bottom layer input matrices of the word_lstm
    (word_lstm.weight_ih_l0*) may differ in width, for example when
    this model adds a pattn / lattn block the other model did not
    have.  The shared columns are copied and any extra columns are set
    to 0, so the rebuilt model produces exactly the same outputs as
    the other model.

    Raises ValueError for incompatible compositions, and AttributeError
    if a parameter of other has no counterpart in this model.
    """
    composition = self.constituency_composition
    if composition != other.constituency_composition and composition != ConstituencyComposition.UNTIED_MAX:
        raise ValueError("Models are incompatible: self.constituency_composition == {}, other.constituency_composition == {}".format(self.constituency_composition, other.constituency_composition))

    for name, source in other.named_parameters():
        # if other is also UNTIED_MAX, its parameters are not named
        # reduce_linear.* and they go through the generic copy at the bottom
        if name.startswith('reduce_linear.') and composition == ConstituencyComposition.UNTIED_MAX:
            if name == 'reduce_linear.weight':
                untied = self.reduce_linear_weight
            elif name == 'reduce_linear.bias':
                untied = self.reduce_linear_bias
            else:
                raise ValueError("Unexpected other parameter name {}".format(name))
            for open_idx in range(len(self.constituent_opens)):
                untied[open_idx].data.copy_(source.data)
            continue

        if name.startswith('word_lstm.weight_ih_l0'):
            # the input width of the bottom layer may have grown:
            # copy the overlapping columns and zero the rest
            target = self.get_parameter(name)
            shared = min(source.data.shape[-1], target.data.shape[-1])
            padded = torch.zeros_like(target.data)
            padded[..., :shared] = source.data[..., :shared]
            target.data.copy_(padded)
            continue

        try:
            self.get_parameter(name).data.copy_(source.data)
        except AttributeError as e:
            raise AttributeError("Could not process %s" % name) from e
def build_output_layers(self, num_output_layers, final_layer_size, maxout_k):
    """
    Build a ModuleList of the layers which turn a parser state into transition scores

    num_output_layers: how many layers to stack
    final_layer_size: output size of the last layer
    maxout_k: if nonzero, use MaxoutLinear layers with this k instead of nn.Linear

    The first layer's input is the combination of word, transition, and
    constituent vectors.  Every other layer boundary is self.hidden_size.
    """
    # first layer input:
    #   word queue entry:          hidden_size * num_tree_lstm_layers
    #   transition_stack output:   transition_hidden_size
    #   constituent_stack output:  hidden_size
    state_size = self.hidden_size + self.hidden_size * self.num_tree_lstm_layers + self.transition_hidden_size
    hidden_sizes = [self.hidden_size] * (num_output_layers - 1)
    layer_dims = list(zip([state_size] + hidden_sizes, hidden_sizes + [final_layer_size]))

    if maxout_k:
        return nn.ModuleList([MaxoutLinear(in_dim, out_dim, maxout_k)
                              for in_dim, out_dim in layer_dims])

    # create every Linear first and only then initialize them, so the
    # random number stream is consumed in a fixed order
    linear_layers = nn.ModuleList([nn.Linear(in_dim, out_dim) for in_dim, out_dim in layer_dims])
    for layer, (in_dim, _) in zip(linear_layers, layer_dims):
        initialize_linear(layer, self.args['nonlinearity'], in_dim)
    return linear_layers
def num_words_known(self, words):
    """
    Count how many of the given words the vocabulary knows

    A word is known if either it or its lowercased form is in vocab_map.
    """
    vocab = self.vocab_map
    known = 0
    for word in words:
        if word in vocab or word.lower() in vocab:
            known += 1
    return known
@property
def retag_method(self):
    """
    The retagging method this model was configured with

    This is the raw args['retag_method'] value; uses_xpos() compares
    it against 'xpos'.
    """
    # TODO: make the method an enum
    configured = self.args['retag_method']
    return configured
def uses_xpos(self):
    """
    True if a retag package is configured and its tags come from xpos
    """
    has_package = self.args['retag_package'] is not None
    if not has_package:
        return has_package
    return self.args['retag_method'] == 'xpos'
def add_unsaved_module(self, name, module):
    """
    Attach a module which will not be saved to disk

    Best used for large pieces such as pretrained word embeddings.
    The name is recorded in unsaved_modules, which is what
    is_unsaved_module checks, and the module is set as an attribute.
    The forward / backward charlms are also frozen by turning off
    requires_grad on all of their parameters.
    """
    self.unsaved_modules += [name]
    setattr(self, name, module)
    if module is None:
        return
    if name not in ('forward_charlm', 'backward_charlm'):
        return
    for _, param in module.named_parameters():
        param.requires_grad = False
def is_unsaved_module(self, name):
    """
    Whether a (possibly dotted) state name belongs to an unsaved module

    Only the part before the first '.' is checked, so 'embedding.weight'
    matches an unsaved module registered as 'embedding'.
    """
    top_level, _, _ = name.partition('.')
    return top_level in self.unsaved_modules
def get_norms(self):
    """
    Build a report of the norm of each trainable parameter

    Returns a list of lines, one per parameter with requires_grad set.
    Each line gives the parameter name, its norm, and how many of its
    entries are (nearly) zero.

    For UNTIED_MAX composition, the per-label reduce_linear weight and
    bias norms are listed first.  The stacked reduce_linear_weight /
    reduce_linear_bias parameters are then left out of the general listing.
    """
    lines = []
    skip = set()
    if self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
        skip = {'reduce_linear_weight', 'reduce_linear_bias'}
        lines.append("reduce_linear:")
        for c_idx, c_open in enumerate(self.constituent_opens):
            lines.append(" %s weight %.6g bias %.6g" % (c_open, torch.norm(self.reduce_linear_weight[c_idx]).item(), torch.norm(self.reduce_linear_bias[c_idx]).item()))

    active_params = [(name, param) for name, param in self.named_parameters() if param.requires_grad and name not in skip]
    if not active_params:
        return lines

    # format each norm once; the formatted widths are used to align the columns
    norms = ["%.6g" % torch.norm(param).item() for _, param in active_params]
    max_name_len = max(len(name) for name, _ in active_params)
    max_norm_len = max(len(norm) for norm in norms)
    format_string = "%-" + str(max_name_len) + "s norm %" + str(max_norm_len) + "s zeros %d / %d"
    for (name, param), norm in zip(active_params, norms):
        # entries with magnitude below 1e-6 are counted as zero
        zeros = torch.sum(param.abs() < 0.000001).item()
        lines.append(format_string % (name, norm, zeros, param.nelement()))
    return lines
def log_norms(self):
    """
    Log the per-parameter norm report from get_norms() at INFO level
    """
    report = "\n".join(["NORMS FOR MODEL PARAMETERS", *self.get_norms()])
    logger.info(report)
def log_shapes(self):
    """
    Log the shape of every trainable parameter at INFO level

    Parameters with requires_grad turned off are not listed.
    """
    # the header previously said NORMS, copied from log_norms, which was misleading
    lines = ["SHAPES FOR MODEL PARAMETERS"]
    for name, param in self.named_parameters():
        if param.requires_grad:
            lines.append("{} {}".format(name, param.shape))
    logger.info("\n".join(lines))
def initial_word_queues(self, tagged_word_lists):
    """
    Produce initial word queues out of the model's LSTMs for use in the tagged word lists.

    Operates in a batched fashion to reduce the runtime for the LSTM operations

    tagged_word_lists: one list per sentence of tagged word nodes.
      Each node's .label is used as its tag and .children[0].label
      as its word.
    Returns: one word queue per sentence, each a list of WordNode.
      The queue starts and ends with a boundary WordNode whose tree is
      None.  It is reversed if self.reverse_sentence is set.
    """
    # only used when extracting the bert embeddings below
    device = next(self.parameters()).device

    vocab_map = self.vocab_map
    def map_word(word):
        # exact match first, then the lowercased word, then UNK
        idx = vocab_map.get(word, None)
        if idx is not None:
            return idx
        return vocab_map.get(word.lower(), UNK_ID)

    all_word_inputs = []
    all_word_labels = [[word.children[0].label for word in tagged_words]
                       for tagged_words in tagged_word_lists]

    # per sentence: pretrained embedding + delta embedding (+ tag embedding)
    for sentence_idx, tagged_words in enumerate(tagged_word_lists):
        word_labels = all_word_labels[sentence_idx]
        word_idx = torch.stack([self.vocab_tensors[map_word(word.children[0].label)] for word in tagged_words])
        word_input = self.embedding(word_idx)

        # this occasionally learns UNK at train time
        if self.training:
            delta_labels = [None if word in self.rare_words and random.random() < self.args['rare_word_unknown_frequency'] else word
                            for word in word_labels]
        else:
            delta_labels = word_labels
        delta_idx = torch.stack([self.delta_tensors[self.delta_word_map.get(word, UNK_ID)] for word in delta_labels])

        delta_input = self.delta_embedding(delta_idx)
        word_inputs = [word_input, delta_input]

        if self.tag_embedding_dim > 0:
            # tags are randomly dropped to UNK at train time, in the same way as rare words
            if self.training:
                tag_labels = [None if random.random() < self.args['tag_unknown_frequency'] else word.label for word in tagged_words]
            else:
                tag_labels = [word.label for word in tagged_words]
            tag_idx = torch.stack([self.tag_tensors[self.tag_map.get(tag, UNK_ID)] for tag in tag_labels])
            tag_input = self.tag_embedding(tag_idx)
            word_inputs.append(tag_input)

        all_word_inputs.append(word_inputs)

    # the charlms process the whole batch of sentences at once
    if self.forward_charlm is not None:
        all_forward_chars = self.forward_charlm.build_char_representation(all_word_labels)
        for word_inputs, forward_chars in zip(all_word_inputs, all_forward_chars):
            word_inputs.append(forward_chars)
    if self.backward_charlm is not None:
        all_backward_chars = self.backward_charlm.build_char_representation(all_word_labels)
        for word_inputs, backward_chars in zip(all_word_inputs, all_backward_chars):
            word_inputs.append(backward_chars)

    # one (sentence length, feature) tensor per sentence
    all_word_inputs = [torch.cat(word_inputs, dim=1) for word_inputs in all_word_inputs]
    if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
        # add learned start / end rows, making each sentence 2 rows longer
        word_start = self.word_start_embedding.unsqueeze(0)
        word_end = self.word_end_embedding.unsqueeze(0)
        all_word_inputs = [torch.cat([word_start, word_inputs, word_end], dim=0) for word_inputs in all_word_inputs]

    if self.bert_model is not None:
        # BERT embedding extraction
        # result will be len+2 for each sentence
        # we will take 1:-1 if we don't care about the endpoints
        bert_embeddings = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, all_word_labels, device,
                                                  keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE,
                                                  num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                  detach=not self.args['bert_finetune'] and not self.args['stage1_bert_finetune'],
                                                  peft_name=self.peft_name)
        if self.bert_layer_mix is not None:
            # add the average so that the default behavior is to
            # take an average of the N layers, and anything else
            # other than that needs to be learned
            bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]

        all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]

    # Extract partitioned representation
    if self.partitioned_transformer_module is not None:
        partitioned_embeddings = self.partitioned_transformer_module(None, all_word_inputs)
        # the slice trims each output to the length of its sentence's inputs
        all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, partitioned_embeddings)]

    # Extract Labeled Representation
    if self.label_attention_module is not None:
        if self.args['lattn_combined_input']:
            labeled_representations = self.label_attention_module(all_word_inputs, tagged_word_lists)
        else:
            # NOTE(review): partitioned_embeddings only exists when
            # partitioned_transformer_module is set; presumably the config
            # guarantees pattn is on whenever lattn_combined_input is off - confirm
            labeled_representations = self.label_attention_module(partitioned_embeddings, tagged_word_lists)
        all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, labeled_representations)]

    if self.rel_attn_forward is not None or self.rel_attn_reverse is not None:
        # each entry starts as [inputs] and gets one extra tensor per relative attention direction
        rattn_inputs = [[x] for x in all_word_inputs]

        if self.rel_attn_forward is not None:
            if self.args['rattn_use_endpoint_sinks']:
                rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0), x[0][0]).squeeze(0)] for x in rattn_inputs]
            else:
                rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]

        if self.rel_attn_reverse is not None:
            if self.args['rattn_use_endpoint_sinks']:
                rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0), x[0][-1]).squeeze(0)] for x in rattn_inputs]
            else:
                rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]

        # either concatenate the pieces or sum them elementwise
        if self.args['rattn_cat']:
            all_word_inputs = [torch.cat(x, axis=1) for x in rattn_inputs]
        else:
            rattn_inputs = [torch.stack(x, axis=2) for x in rattn_inputs]
            all_word_inputs = [torch.sum(x, axis=2) for x in rattn_inputs]

    all_word_inputs = [self.word_dropout(word_inputs) for word_inputs in all_word_inputs]
    packed_word_input = torch.nn.utils.rnn.pack_sequence(all_word_inputs, enforce_sorted=False)
    word_output, _ = self.word_lstm(packed_word_input)
    # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear
    # word_output will now be sentence x batch x 2*hidden_size
    # (word_output_lens is not used below)
    word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output)
    # now sentence x batch x hidden_size

    word_queues = []
    for sentence_idx, tagged_words in enumerate(tagged_word_lists):
        # keep the boundary rows if they were added above
        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
            sentence_output = word_output[:len(tagged_words)+2, sentence_idx, :]
        else:
            sentence_output = word_output[:len(tagged_words), sentence_idx, :]
        sentence_output = self.word_to_constituent(sentence_output)
        sentence_output = self.nonlinearity(sentence_output)
        # TODO: this makes it so constituents downstream are
        # build with the outputs of the LSTM, not the word
        # embeddings themselves.  It is possible we want to
        # transform the word_input to hidden_size in some way
        # and use that instead
        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
            # boundary nodes use the LSTM outputs at the start / end rows
            word_queue = [WordNode(None, sentence_output[0, :])]
            word_queue += [WordNode(tag_node, sentence_output[idx+1, :])
                           for idx, tag_node in enumerate(tagged_words)]
            word_queue.append(WordNode(None, sentence_output[len(tagged_words)+1, :]))
        else:
            # without boundary vectors, the boundary nodes use word_zeros
            word_queue = [WordNode(None, self.word_zeros)]
            word_queue += [WordNode(tag_node, sentence_output[idx, :])
                           for idx, tag_node in enumerate(tagged_words)]
            word_queue.append(WordNode(None, self.word_zeros))
        if self.reverse_sentence:
            word_queue = list(reversed(word_queue))
        word_queues.append(word_queue)

    return word_queues
def initial_transitions(self):
    """
    Start a fresh transition TreeStack, with no transitions applied yet
    """
    stack = self.transition_stack
    return stack.initial_state()
def initial_constituents(self):
    """
    Start a fresh constituent TreeStack, with no constituents built yet

    The bottom of the stack holds a placeholder Constituent with no
    tree, and constituent_zeros as both its hx and its cx.
    """
    placeholder = Constituent(None, self.constituent_zeros, self.constituent_zeros)
    return self.constituent_stack.initial_state(placeholder)
def get_word(self, word_node):
    """
    Return the value stored in a word node
    """
    stored = word_node.value
    return stored
def transform_word_to_constituent(self, state):
    """
    Turn the word at the state's current word_position into a Constituent

    The tree of the new Constituent is the word node's value.  Its hx and
    cx depend on the composition method:
      TREE_LSTM:     hx is the word's hx viewed as (num_tree_lstm_layers, hidden_size);
                     cx is word_zeros viewed the same way
      TREE_LSTM_CX:  hx as for TREE_LSTM; cx is a learned per-tag
                     embedding multiplied elementwise by hx
      anything else: hx is the first hidden_size values of the word's hx,
                     shaped (1, hidden_size); cx is None
    """
    word_node = state.get_word(state.word_position)
    word = word_node.value
    composition = self.constituency_composition
    layer_shape = (self.num_tree_lstm_layers, self.hidden_size)

    if composition == ConstituencyComposition.TREE_LSTM:
        return Constituent(word, word_node.hx.view(layer_shape), self.word_zeros.view(layer_shape))

    if composition == ConstituencyComposition.TREE_LSTM_CX:
        tree_hx = word_node.hx.view(layer_shape)
        # the UNK tag will be trained thanks to occasionally dropping out tags
        tag_idx = self.tag_tensors[self.tag_map.get(word.label, UNK_ID)]
        tree_cx = self.constituent_reduce_embedding(tag_idx).view(layer_shape)
        return Constituent(word, tree_hx, tree_cx * tree_hx)

    return Constituent(word, word_node.hx[:self.hidden_size].unsqueeze(0), None)
def dummy_constituent(self, dummy):
    """
    Build a Constituent for a dummy node from the embedding of its label

    hx is dummy_embedding of the label's open index, with a leading
    dimension added.  cx is None: it does not matter, because the dummy
    is discarded when the new constituent is built.
    """
    label_index = self.constituent_open_map[dummy.label]
    hx = self.dummy_embedding(self.constituent_open_tensors[label_index])
    return Constituent(dummy, hx.unsqueeze(0), None)
def build_constituents(self, labels, children_lists):
    """
    Build new constituents with the given label from the list of children

    labels is a list of labels for each of the new nodes to construct.
      Each label is either a str, or a sequence of str for a chain of
      nested nodes, outermost label first.
    children_lists is a list of children that go under each of the new nodes.
      Each child's .value is a Constituent whose tree_hx (and tree_cx
      for the TREE_LSTM variants) is combined into the new node's hx.
    lists of each are used so that we can stack operations

    Returns a list of Constituent(tree, hx, cx), one per label, where
    hx / cx are the idx-th batch slice of lstm_hx / lstm_cx.
    cx is None for every composition except the TREE_LSTM variants.
    """
    # at the end of each of these operations, we expect lstm_hx.shape
    # is (L, N, hidden_size) for N lists of children
    if (self.constituency_composition == ConstituencyComposition.BILSTM or
        self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
        # run constituent_reduce_lstm over the sequence
        #   label embedding, children..., label embedding, zero padding
        node_hx = [[child.value.tree_hx.squeeze(0) for child in children] for children in children_lists]
        label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]

        max_length = max(len(children) for children in children_lists)
        zeros = torch.zeros(self.hidden_size, device=label_hx[0].device)
        # weirdly, this is faster than using pack_sequence
        unpacked_hx = [[lhx] + nhx + [lhx] + [zeros] * (max_length - len(nhx)) for lhx, nhx in zip(label_hx, node_hx)]
        unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in unpacked_hx]
        packed_hx = torch.stack(unpacked_hx, axis=1)
        # +2 for the label embedding at each end
        packed_hx = torch.nn.utils.rnn.pack_padded_sequence(packed_hx, [len(x)+2 for x in children_lists], enforce_sorted=False)
        lstm_output = self.constituent_reduce_lstm(packed_hx)
        # take just the output of the final layer
        #   result of lstm is ouput, (hx, cx)
        #   so [1][0] gets hx
        #      [1][0][-1] is the final output
        # will be shape len(children_lists) * 2, hidden_size for bidirectional
        # where forward outputs are -2 and backwards are -1
        if self.constituency_composition == ConstituencyComposition.BILSTM:
            # BILSTM: project the final forward & backward hx together
            lstm_output = lstm_output[1][0]
            forward_hx = lstm_output[-2, :, :]
            backward_hx = lstm_output[-1, :, :]
            hx = self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1))
        else:
            # BILSTM_MAX: max over the outputs at the children's positions
            # (dropping the label row at each end), then project the first
            # and second hidden_size halves separately
            lstm_output, lstm_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_output[0])
            lstm_output = [lstm_output[1:length-1, x, :] for x, length in zip(range(len(lstm_lengths)), lstm_lengths)]
            lstm_output = torch.stack([torch.max(x, 0).values for x in lstm_output], axis=0)
            hx = self.reduce_forward(lstm_output[:, :self.hidden_size]) + self.reduce_backward(lstm_output[:, self.hidden_size:])
        lstm_hx = self.nonlinearity(hx).unsqueeze(0)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.MAX:
        # elementwise max over the children's hx, then one shared reduce_linear
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]
        packed_hx = torch.stack(unpacked_hx, axis=1)
        hx = self.reduce_linear(packed_hx)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
        # like MAX, but with a separate weight & bias for each constituent label
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]
        # shape == len(labels),1,hidden_size after the stack
        #packed_hx = torch.stack(unpacked_hx, axis=0)
        label_indices = [self.constituent_open_map[label] for label in labels]
        # we would like to stack the reduce_linear_weight calculations as follows:
        #reduce_weight = self.reduce_linear_weight[label_indices]
        #reduce_bias = self.reduce_linear_bias[label_indices]
        # this would allow for faster vectorized operations.
        # however, this runs out of memory on larger training examples,
        # presumably because there are too many stacks in a row and each one
        # has its own gradient kept for the entire calculation
        # fortunately, this operation is not a huge part of the expense
        hx = [torch.matmul(self.reduce_linear_weight[label_idx], hx_layer.squeeze(0)) + self.reduce_linear_bias[label_idx]
              for label_idx, hx_layer in zip(label_indices, unpacked_hx)]
        hx = torch.stack(hx, axis=0)
        hx = hx.unsqueeze(0)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.BIGRAM:
        # max over the children's hx together with a learned
        # combination of each pair of adjacent children
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = []
        for nhx in node_hx:
            # tanh or otherwise limit the size of the output?
            stacked_nhx = self.lstm_input_dropout(torch.cat(nhx, axis=0))
            if stacked_nhx.shape[0] > 1:
                bigram_hx = torch.cat((stacked_nhx[:-1, :], stacked_nhx[1:, :]), axis=1)
                bigram_hx = self.reduce_bigram(bigram_hx) / 2
                stacked_nhx = torch.cat((stacked_nhx, bigram_hx), axis=0)
            unpacked_hx.append(torch.max(stacked_nhx, 0).values)
        packed_hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
        hx = self.reduce_linear(packed_hx)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.ATTN:
        # self attention over [label embedding; children], then max over the outputs
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]
        unpacked_hx = [torch.stack(nhx) for nhx in node_hx]
        unpacked_hx = [torch.cat((lhx.unsqueeze(0).unsqueeze(0), nhx), axis=0) for lhx, nhx in zip(label_hx, unpacked_hx)]
        unpacked_hx = [self.reduce_attn(nhx, nhx, nhx)[0].squeeze(1) for nhx in unpacked_hx]
        unpacked_hx = [self.lstm_input_dropout(torch.max(nhx, 0).values) for nhx in unpacked_hx]
        hx = torch.stack(unpacked_hx, axis=0)
        lstm_hx = self.nonlinearity(hx).unsqueeze(0)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
        # attention pooling over the position-encoded children:
        # KEY uses one learned key per head, UNTIED_KEY one per label and head
        node_hx = [torch.stack([child.value.tree_hx for child in children]) for children in children_lists]
        # add a position vector to each node_hx
        node_hx = [self.reduce_position(x.reshape(x.shape[0], -1)) for x in node_hx]
        query_hx = [self.reduce_query(nhx) for nhx in node_hx]
        # reshape query for MHA
        query_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in query_hx]
        if self.constituency_composition == ConstituencyComposition.KEY:
            queries = [torch.matmul(nhx, self.reduce_key) for nhx in query_hx]
        else:
            label_indices = [self.constituent_open_map[label] for label in labels]
            queries = [torch.matmul(nhx, self.reduce_key[label_idx]) for nhx, label_idx in zip(query_hx, label_indices)]
        # softmax each head
        weights = [torch.nn.functional.softmax(nhx, dim=1).transpose(1, 2) for nhx in queries]
        value_hx = [self.reduce_value(nhx) for nhx in node_hx]
        value_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in value_hx]
        # use the softmaxes to add up the heads
        unpacked_hx = [torch.matmul(weight, nhx).squeeze(1) for weight, nhx in zip(weights, value_hx)]
        unpacked_hx = [nhx.reshape(-1) for nhx in unpacked_hx]
        hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
        # one LSTM step with the label embedding as input.  The initial hx
        # is the elementwise max of the children's hx, and each cx value is
        # taken from the child which supplied the max at that position
        label_hx = [self.lstm_input_dropout(self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]])) for label in labels]
        label_hx = torch.stack(label_hx).unsqueeze(0)

        # NOTE(review): max_length is not used in this branch
        max_length = max(len(children) for children in children_lists)

        # stacking will let us do elementwise multiplication faster, hopefully
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in node_hx]
        unpacked_hx = [nhx.max(dim=0) for nhx in unpacked_hx]
        packed_hx = torch.stack([nhx.values for nhx in unpacked_hx], axis=1)
        #packed_hx = packed_hx.max(dim=0).values

        node_cx = [torch.stack([child.value.tree_cx for child in children]) for children in children_lists]
        node_cx_indices = [uhx.indices.unsqueeze(0) for uhx in unpacked_hx]
        unpacked_cx = [ncx.gather(0, nci).squeeze(0) for ncx, nci in zip(node_cx, node_cx_indices)]
        packed_cx = torch.stack(unpacked_cx, axis=1)

        _, (lstm_hx, lstm_cx) = self.constituent_reduce_lstm(label_hx, (packed_hx, packed_cx))
    else:
        raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))

    constituents = []
    for idx, (label, children) in enumerate(zip(labels, children_lists)):
        children = [child.value.value for child in children]
        if isinstance(label, str):
            node = Tree(label=label, children=children)
        else:
            # compound label: build from the innermost label outwards,
            # so label[0] ends up as the outermost node
            for value in reversed(label):
                node = Tree(label=value, children=children)
                children = node
        constituents.append(Constituent(node, lstm_hx[:, idx, :], lstm_cx[:, idx, :] if lstm_cx is not None else None))

    return constituents
def push_constituents(self, constituent_stacks, constituents):
    """
    Push a batch of new constituents onto their constituent stacks

    constituent_stacks: the current stack for each item in the batch
    constituents: the Constituent(tree, tree_hx, tree_cx) to push onto
      the matching stack
    Returns the new stacks from constituent_stack.push_states

    The LSTM input for each constituent is the last row of its tree_hx.
    The inputs are stacked along axis 1 into a single batch, so for
    (layers, hidden_size) tree_hx the input is (1, batch, hidden_size).
    """
    # Another possibility here would be to use output[0, i, :]
    # from the constituency lstm for the value of the new node.
    # This might theoretically make the new constituent include
    # information from neighboring constituents.  However, this
    # lowers the scores of various models.
    # For example, an experiment on ja_alt built this way,
    # averaged over 5 trials, had the following loss in accuracy:
    # 150 epochs: 0.8971 to 0.8953
    # 200 epochs: 0.8985 to 0.8964
    constituent_input = torch.stack([x.tree_hx[-1:] for x in constituents], axis=1)

    # the constituents are already Constituent(tree, tree_hx, tree_cx)
    return self.constituent_stack.push_states(constituent_stacks, constituents, constituent_input)
def get_top_constituent(self, constituents):
    """
    Return the tree on top of a state's constituent stack

    The stack nodes carry more than the tree, so this unwraps
    TreeStack value -> LSTMTreeStack value -> Constituent value -> constituent
    """
    lstm_node = constituents.value
    constituent = lstm_node.value
    return constituent.value
def push_transitions(self, transition_stacks, transitions):
    """
    Push one transition onto each of the given transition stacks

    All of the transitions are embedded and sent through the
    transition_stack as one batch, which is much faster than pushing
    them one at a time.
    """
    index_tensors = []
    for transition in transitions:
        index_tensors.append(self.transition_tensors[self.transition_map[transition]])
    transition_input = self.transition_embedding(torch.stack(index_tensors)).unsqueeze(0)
    return self.transition_stack.push_states(transition_stacks, transitions, transition_input)
def get_top_transition(self, transitions):
    """
    Return the transition on top of a state's transition stack

    Unwraps TreeStack value -> LSTMTreeStack value -> transition
    """
    lstm_node = transitions.value
    return lstm_node.value
def forward(self, states):
    """
    Return logits for a prediction of what transition to make next

    The vector for each state concatenates the hx of its current word,
    the transition_stack output, and the constituent_stack output.
    Most of the work happened while the transitions were applied, so
    this only runs the output layers.

    return shape: (num_states, num_transitions)
    """
    # the constituent_stack output is the output of the stack's LSTM,
    # not the top constituent itself.  This way the constituent vectors
    # can, as an option, leave out the constituents to their left,
    # while the vector used for inference still sees the entire LSTM
    features = [
        torch.stack([state.get_word(state.word_position).hx for state in states]),
        torch.stack([self.transition_stack.output(state.transitions) for state in states]),
        torch.stack([self.constituent_stack.output(state.constituents) for state in states]),
    ]
    hx = torch.cat(features, dim=1)

    num_layers = len(self.output_layers)
    use_nonlinearity = not self.maxout_k
    for layer_idx, layer in enumerate(self.output_layers):
        hx = self.predict_dropout(hx)
        # NOTE(review): the nonlinearity is applied to the input of every
        # layer except the last, so the raw state vector is transformed
        # but the final hidden layer's output is not.  Existing weights
        # presumably depend on this ordering - confirm before changing
        if use_nonlinearity and layer_idx + 1 < num_layers:
            hx = self.nonlinearity(hx)
        hx = layer(hx)
    return hx
def predict(self, states, is_legal=True):
    """
    Generate and return predictions, along with the transitions those predictions represent

    If is_legal is set to True, an illegal top-scoring transition is
    replaced by the highest scoring legal transition for that state.
    If a state has no legal transition at all, its transition is None
    and its score is NaN.
    Hopefully the constraints prevent that from happening

    Returns:
      tensor(batch_size, num_transitions) - final output layer
      list(Transition) - predicted transitions
      tensor(batch_size) - the final output specifically for the chosen transition
    """
    predictions = self.forward(states)
    pred_max = torch.argmax(predictions, dim=1)
    scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)

    pred_trans = [self.transitions[trans_idx] for trans_idx in pred_max.tolist()]
    if is_legal:
        for idx, (state, trans) in enumerate(zip(states, pred_trans)):
            if trans.is_legal(state, self):
                continue
            _, indices = predictions[idx, :].sort(descending=True)
            for index in indices.tolist():
                if self.transitions[index].is_legal(state, self):
                    pred_trans[idx] = self.transitions[index]
                    scores[idx] = predictions[idx, index]
                    break
            else:  # yeah, else on a for loop, deal with it
                # no legal transition for this state.  A tensor element
                # cannot hold None (assigning it raises TypeError),
                # so the score is marked with NaN instead
                pred_trans[idx] = None
                scores[idx] = float('nan')

    return predictions, pred_trans, scores.squeeze(1)
def weighted_choice(self, states, temperature=1.0):
    """
    Generate and return predictions, and randomly choose a prediction weighted by the scores

    For each state, one of its legal transitions is sampled with
    probability softmax(score / temperature) over the legal transitions.
    The default temperature of 1.0 samples from the plain softmax.
    A state with no legal transitions gets None as its transition and
    NaN as its score, so the returned scores stay aligned with states.

    Returns:
      tensor(batch_size, num_transitions) - final output layer
      list(Transition) - chosen transitions
      tensor(batch_size) - the output for each chosen transition

    Raises ValueError if temperature is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive, got {}".format(temperature))
    predictions = self.forward(states)
    pred_trans = []
    all_scores = []
    for state, prediction in zip(states, predictions):
        legal_idx = [idx for idx in range(prediction.shape[0]) if self.transitions[idx].is_legal(state, self)]
        if len(legal_idx) == 0:
            pred_trans.append(None)
            # placeholder keeps all_scores aligned with states
            all_scores.append(prediction.new_full((), float('nan')))
            continue
        probs = torch.softmax(prediction[legal_idx] / temperature, dim=0)
        idx = legal_idx[torch.multinomial(probs, 1).item()]
        pred_trans.append(self.transitions[idx])
        all_scores.append(prediction[idx])
    all_scores = torch.stack(all_scores)
    return predictions, pred_trans, all_scores
def predict_gold(self, states):
    """
    For each State, return the next item in the gold_sequence

    Returns the same triple as predict(): the full output layer, the
    gold transitions (each state's gold_sequence at its num_transitions),
    and the output for each of those transitions.
    """
    predictions = self.forward(states)
    gold = [state.gold_sequence[state.num_transitions] for state in states]
    gold_indices = torch.tensor([self.transition_map[transition] for transition in gold], device=predictions.device)
    gold_scores = predictions.gather(1, gold_indices.unsqueeze(1)).squeeze(1)
    return predictions, gold, gold_scores
def get_params(self, skip_modules=True):
    """
    Get a dictionary for saving the model

    skip_modules: if True, the state of unsaved modules (see
      is_unsaved_module), such as pretrained embeddings, is left out.
      Those are large and are saved in a separate file.

    The enum-valued settings in the config are replaced by their .name,
    working on a deep copy so self.args is not modified.
    """
    model_state = self.state_dict()
    if skip_modules:
        unsaved_keys = [key for key in model_state if self.is_unsaved_module(key)]
        for key in unsaved_keys:
            del model_state[key]

    config = copy.deepcopy(self.args)
    for enum_field in ('sentence_boundary_vectors', 'constituency_composition', 'transition_stack', 'constituent_stack', 'transition_scheme'):
        config[enum_field] = config[enum_field].name

    assert isinstance(self.rare_words, set)
    return {
        'model': model_state,
        'model_type': "LSTM",
        'config': config,
        'transitions': [repr(transition) for transition in self.transitions],
        'constituents': self.constituents,
        'tags': self.tags,
        'words': self.delta_words,
        'rare_words': list(self.rare_words),
        'root_labels': self.root_labels,
        'constituent_opens': self.constituent_opens,
        'unary_limit': self.unary_limit(),
    }
================================================
FILE: stanza/models/constituency/lstm_tree_stack.py
================================================
"""
Keeps an LSTM in TreeStack form.
The TreeStack nodes keep the hx and cx for the LSTM, along with a
"value" which represents whatever the user needs to store.
The TreeStacks can be popped to get back to the previous LSTM state.
The module itself implements three methods: initial_state, push_states, output
"""
from collections import namedtuple
import torch
import torch.nn as nn
from stanza.models.constituency.tree_stack import TreeStack
Node = namedtuple("Node", ['value', 'lstm_hx', 'lstm_cx'])
class LSTMTreeStack(nn.Module):
def __init__(self, input_size, hidden_size, num_lstm_layers, dropout, uses_boundary_vector, input_dropout):
"""
Prepare LSTM and parameters
input_size: dimension of the inputs to the LSTM
hidden_size: LSTM internal & output dimension
num_lstm_layers: how many layers of LSTM to use
dropout: value of the LSTM dropout
uses_boundary_vector: if set, learn a start_embedding parameter. otherwise, use zeros
input_dropout: an nn.Module to dropout inputs. TODO: allow a float parameter as well
"""
super().__init__()
self.uses_boundary_vector = uses_boundary_vector
# The start embedding needs to be input_size as we put it through the LSTM
if uses_boundary_vector:
self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
else:
self.register_buffer('input_zeros', torch.zeros(num_lstm_layers, 1, input_size))
self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size))
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout)
self.input_dropout = input_dropout
def initial_state(self, initial_value=None):
"""
Return an initial state, either based on zeros or based on the initial embedding and LSTM
Note that LSTM start operation is already batched, in a sense
The subsequent batch built this way will be used for batch_size trees
Returns a stack with None value, hx & cx either based on the
start_embedding or zeros, and no parent.
"""
if self.uses_boundary_vector:
start = self.start_embedding.unsqueeze(0).unsqueeze(0)
output, (hx, cx) = self.lstm(start)
start = output[0, 0, :]
else:
start = self.input_zeros
hx = self.hidden_zeros
cx = self.hidden_zeros
return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1)
def push_states(self, stacks, values, inputs):
"""
Starting from a list of current stacks, put the inputs through the LSTM and build new stack nodes.
B = stacks.len() = values.len()
inputs must be of shape 1 x B x input_size
"""
inputs = self.input_dropout(inputs)
hx = torch.cat([t.value.lstm_hx for t in stacks], axis=1)
cx = torch.cat([t.value.lstm_cx for t in stacks], axis=1)
output, (hx, cx) = self.lstm(inputs, (hx, cx))
new_stacks = [stack.push(Node(transition, hx[:, i:i+1, :], cx[:, i:i+1, :]))
for i, (stack, transition) in enumerate(zip(stacks, values))]
return new_stacks
def output(self, stack):
    """
    Return the top layer of the hidden state stored on a stack's node

    Kept as its own method so alternate stack structures have an easy way of supplying their output
    """
    hidden = stack.value.lstm_hx
    # lstm_hx is (num_layers, 1, hidden_size): take the last layer, sole batch entry
    return hidden[-1, 0, :]
================================================
FILE: stanza/models/constituency/parse_transitions.py
================================================
"""
Defines a series of transitions (open a constituent, close a constituent, etc)
"""
from abc import ABC, abstractmethod
import ast
from collections import defaultdict
from enum import Enum
import functools
import logging
from stanza.models.constituency.parse_tree import Tree
# messages go to the package-wide 'stanza' logger rather than a per-module __name__ logger
logger = logging.getLogger('stanza')
class TransitionScheme(Enum):
    """
    The transition schemes available to the constituency parser

    Each member is defined with a (value, short_name) pair: value is the
    Enum value, and short_name is a short string such as "top" or "in".
    """
    def __new__(cls, value, short_name):
        member = object.__new__(cls)
        member._value_ = value
        member.short_name = short_name
        return member

    # top down, so the open transition comes before any constituents
    # score on vi_vlsp22 with 5 different sizes of bert layers,
    # bert tagger, no silver dataset:
    #   0.8171
    TOP_DOWN           = 1, "top"

    # unary transitions are modeled as one entire transition
    # version that uses one transform per item,
    # score on experiment described above:
    #   0.8157
    # score using one combination step for an entire transition:
    #   0.8178
    TOP_DOWN_COMPOUND  = 2, "topc"

    # unary is a separate transition.  doesn't help
    # score on experiment described above:
    #   0.8128
    TOP_DOWN_UNARY     = 3, "topu"

    # open transition comes after the first constituent it cares about
    # score on experiment described above:
    #   0.8205
    # note that this is with an oracle, whereas IN_ORDER_COMPOUND does
    # not have a dynamic oracle, so there may be room for improvement
    IN_ORDER           = 4, "in"

    # in order, with unaries after preterminals represented as a single
    # transition after the preterminal
    # and unaries elsewhere tied to the rest of the constituent
    # score: 0.8186
    IN_ORDER_COMPOUND  = 5, "inc"

    # in order, with CompoundUnary on both preterminals and internal nodes
    # score: 0.8166
    IN_ORDER_UNARY     = 6, "inu"
@functools.total_ordering
class Transition(ABC):
    """
    Base class for a parser transition (open a constituent, close a constituent, etc)

    model is passed in as a dependency injection
    for example, an LSTM model can update hidden & output vectors when transitioning
    """

    @abstractmethod
    def update_state(self, state, model):
        """
        update the word queue position, possibly remove old pieces from the constituents state, and return the new constituent

        the return value should be a tuple:
          updated word_position
          updated constituents
          new constituent to put on the constituent stack and None
            - note that the constituent shouldn't be on the stack yet
              that allows putting it on as a batch operation, which
              saves a significant amount of time in an LSTM, for example
          OR
          data used to make a new constituent and the method used
            - for example, CloseConstituent can return the children needed
              and itself.  this allows a batch operation to build
              the constituent
        """

    def delta_opens(self):
        """
        How this transition changes the number of open constituents

        0 here; OpenConstituent and CloseConstituent override it
        """
        return 0

    def apply(self, state, model):
        """
        return a new State transformed via this transition

        convenience method to call bulk_apply, which is significantly
        faster than single operations for an NN based model
        """
        update = model.bulk_apply([state], [self])
        return update[0]

    @abstractmethod
    def is_legal(self, state, model):
        """
        assess whether or not this transition is legal in this state

        at parse time, the parser might choose a transition which cannot be made
        """

    def components(self):
        """
        Return a list of transitions which could theoretically make up this transition

        For example, an Open transition with multiple labels would
        return a list of Opens with those labels
        """
        return [self]

    @abstractmethod
    def short_name(self):
        """
        A short name to identify this transition
        """

    def short_label(self):
        """
        short_name() followed by the label(s), if this transition has any

        A single label is shown bare, eg Open(NP); multiple labels are shown as the tuple
        """
        if not hasattr(self, "label"):
            return self.short_name()

        if isinstance(self.label, str):
            label = self.label
        elif len(self.label) == 1:
            label = self.label[0]
        else:
            label = self.label
        return "{}({})".format(self.short_name(), label)

    def __lt__(self, other):
        # put the Shift at the front of a list, and otherwise sort alphabetically
        if self == other:
            return False
        if isinstance(self, Shift):
            return True
        if isinstance(other, Shift):
            return False
        return str(self) < str(other)

    @staticmethod
    def _parse_labels(label_text):
        """
        Turn the text between the parens of a Transition repr into a tuple of labels

        Two forms are written by the reprs in this module:
          OpenConstituent writes the label tuple itself:   ('NP', 'S')
          CompoundUnary and Finalize join bare labels:     NP,S
        A python literal that evaluates to a str or a sequence of strs is
        used as is; anything else is treated as comma-joined bare labels.
        ast.literal_eval never executes code, so this is safe on untrusted text.
        """
        try:
            parsed = ast.literal_eval(label_text)
        except (ValueError, TypeError, SyntaxError):
            # eg NP,S or S-TPC are not python literals
            parsed = None
        if isinstance(parsed, str):
            return (parsed,)
        if isinstance(parsed, (tuple, list)) and parsed and all(isinstance(x, str) for x in parsed):
            return tuple(parsed)
        return tuple(label_text.split(","))

    @staticmethod
    def from_repr(desc):
        """
        Rebuild a Transition from its repr()

        This method is to avoid using eval() or otherwise trying to
        deserialize strings in a possibly untrusted manner when
        loading from a checkpoint

        Accepts Shift, CloseConstituent, OpenConstituent(('NP', 'S')),
        CompoundUnary(NP,S) and Finalize(ROOT).
        Raises ValueError on anything else.
        """
        if desc == 'Shift':
            return Shift()
        if desc == 'CloseConstituent':
            return CloseConstituent()
        labels = desc.split("(", maxsplit=1)
        if labels[0] not in ('CompoundUnary', 'OpenConstituent', 'Finalize'):
            raise ValueError("Unknown Transition %s" % desc)
        if len(labels) == 1:
            raise ValueError("Unexpected Transition repr, %s needs labels" % labels[0])
        # endswith rather than [-1] so that "OpenConstituent(" raises ValueError, not IndexError
        if not labels[1].endswith(')'):
            raise ValueError("Expected Transition repr for %s: %s(labels)" % (labels[0], labels[0]))
        trans_type = labels[0]
        label_text = labels[1][:-1]
        if not label_text:
            raise ValueError("Unexpected Transition repr, %s needs labels" % trans_type)
        labels = Transition._parse_labels(label_text)
        if trans_type == 'CompoundUnary':
            return CompoundUnary(*labels)
        if trans_type == 'OpenConstituent':
            return OpenConstituent(*labels)
        if trans_type == 'Finalize':
            return Finalize(*labels)
        raise ValueError("Unexpected Transition %s" % desc)
class Shift(Transition):
    """
    Move the next word of the word queue onto the constituent stack
    """
    def update_state(self, state, model):
        """
        This will handle all aspects of a shift transition
          - turn the next word of the word queue into a new constituent,
            which the caller pushes onto the constituents (in bulk)
          - advance the word queue past that word
        """
        word_constituent = model.transform_word_to_constituent(state)
        next_position = state.word_position + 1
        return next_position, state.constituents, word_constituent, None

    def is_legal(self, state, model):
        """
        Disallow shifting when the word queue is empty or there are no opens to eventually eat this word
        """
        if state.empty_word_queue():
            return False

        if not model.is_top_down:
            # in-order k==1 (the only other option currently)
            # can shift ONCE, but note that there is no way to consume
            # two items in a row if there is no Open on the stack.
            # As long as there is one or more open transitions,
            # everything can be eaten
            if state.num_opens == 0 and not state.empty_constituents:
                return False
            return True

        # top down transition sequences cannot shift if there are currently no
        # Open transitions on the stack.  in such a case, the new constituent
        # will never be reduced
        if state.num_opens == 0:
            return False
        if state.num_opens != 1:
            return True

        # there must be at least one transition, since there is an open
        assert state.transitions.parent is not None
        if state.transitions.parent.parent is not None:
            # more than one transition so far
            return True

        # exactly one transition, and it must be the one Open
        # an S, FRAG, etc could happen if we're using unary, and ROOT-S is
        # possible with compound Opens; in both cases Shift is legal.
        # Shifting after ROOT-S has been closed to just ROOT is handled in CloseConstituent
        first = model.get_top_transition(state.transitions)
        if len(first.label) == 1 and first.top_label in model.root_labels:
            # don't shift a word at the very start of a parse
            # we want there to be an extra layer below ROOT
            return False
        return True

    def short_name(self):
        return "Shift"

    def __repr__(self):
        return "Shift"

    def __eq__(self, other):
        # every Shift equals every other Shift
        return self is other or isinstance(other, Shift)

    def __hash__(self):
        return hash(37)
class CompoundUnary(Transition):
    """
    Wrap the top constituent in one or more unary nodes with a single transition

    The labels are stored outermost first: label[0] becomes the top of the
    resulting tree, so a CompoundUnary that produces the root has the root
    label as label[0]
    """
    def __init__(self, *label):
        self.label = tuple(label)

    def update_state(self, state, model):
        """
        Pop the top constituent and hand it, with our labels, to the CloseConstituent build machinery

        Unlike CloseConstituent, the labels are not on the stack as a Dummy -
        they live on this transition
        """
        # only the top constituent is meaningful here
        top = state.constituents
        children = [top.value]
        remaining = top.pop()
        return state.word_position, remaining, (self.label, children), CloseConstituent

    def is_legal(self, state, model):
        """
        Disallow consecutive CompoundUnary transitions, force final transition to go to ROOT

        IN_ORDER_COMPOUND only uses these for unaries directly over a preterminal.
        TOP_DOWN_UNARY only allows a root label once the words are used up and
        one constituent remains, and then requires one.
        """
        tree = model.get_top_constituent(state.constituents)
        if tree is None:
            # nothing to wrap in a unary
            return False

        previous = model.get_top_transition(state.transitions)
        if isinstance(previous, (CompoundUnary, OpenConstituent)):
            # right after an Open the top is just a Dummy,
            # and CompoundUnary transitions are never stacked
            return False

        if model.transition_scheme() is TransitionScheme.IN_ORDER_COMPOUND:
            return tree.is_preterminal()
        if model.transition_scheme() is not TransitionScheme.TOP_DOWN_UNARY:
            return True

        is_root = self.label[0] in model.root_labels
        parse_complete = state.empty_word_queue() and state.has_one_constituent()
        if parse_complete:
            return is_root
        return not is_root

    def components(self):
        return [CompoundUnary(single) for single in self.label]

    def short_name(self):
        return "Unary"

    def __repr__(self):
        return "CompoundUnary({})".format(",".join(self.label))

    def __eq__(self, other):
        if self is other:
            return True
        return isinstance(other, CompoundUnary) and self.label == other.label

    def __hash__(self):
        return hash(self.label)
class Dummy():
    """
    Takes a space on the constituent stack to represent where an Open transition occurred

    label is whatever the Open passed in.  OpenConstituent passes its tuple
    of labels, but a plain string works as well.
    """
    def __init__(self, label):
        self.label = label

    def is_preterminal(self):
        return False

    def _format_label(self):
        """
        The label text used by __format__

        A one element tuple prints as just that element, matching
        Transition.short_label; longer tuples print as the tuple itself.
        Handing a tuple straight to % would unpack it, which raises
        TypeError for compound (multi-label) Opens.
        """
        label = self.label
        if isinstance(label, tuple) and len(label) == 1:
            return label[0]
        return label

    def __format__(self, spec):
        label = self._format_label()
        if spec is None or spec == '' or spec == 'O':
            return "(%s ...)" % (label,)
        if spec == 'T':
            # a single backslash, matching the \Tree prefix written by Tree.__format__
            return r"\Tree [.%s ? ]" % (label,)
        raise ValueError("Unhandled spec: %s" % spec)

    def __str__(self):
        return "Dummy({})".format(self.label)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, Dummy):
            return False
        if self.label == other.label:
            return True
        return False

    def __hash__(self):
        return hash(self.label)
def too_many_unary_nodes(tree, unary_limit):
    """
    Return True if unary_limit + 1 nodes in a row, starting at tree, each have exactly one child

    The chain includes a preterminal's link to its word, so
    (NP (NN word)) is already too many for unary_limit=1.
    A tree of None is never too many.

    helps prevent infinite open/close patterns
    otherwise, the model can get stuck in essentially an infinite loop
    """
    if tree is None:
        return False

    node = tree
    for _depth in range(unary_limit + 1):
        if len(node.children) == 1:
            node = node.children[0]
            continue
        return False
    return True
class OpenConstituent(Transition):
    """
    Open a new constituent with the given label(s)

    update_state puts a Dummy carrying the labels on the constituent stack;
    CloseConstituent later pops back to that Dummy to build the constituent.
    More than one label makes a compound Open, whose components() are the
    single-label Opens.
    """
    def __init__(self, *label):
        self.label = tuple(label)
        # the outermost label of a (possibly compound) Open
        self.top_label = self.label[0]

    def delta_opens(self):
        # opening adds one open constituent
        return 1

    def update_state(self, state, model):
        # open a new constituent which can later be closed
        # puts a DUMMY constituent on the stack to mark where the constituents end
        return state.word_position, state.constituents, model.dummy_constituent(Dummy(self.label)), None

    def is_legal(self, state, model):
        """
        disallow based on the length of the sentence

        Always illegal once num_opens exceeds sentence_length + 10.
        Top down: illegal with no words left in the queue.  If the model has
          no unary transitions, a root label is legal only as the very first
          transition and any other label is illegal as the first transition;
          with unary transitions, nothing further is checked.
        In order: illegal with an empty constituent stack or right after
          another Open.  Other than plain IN_ORDER, it is legal exactly when
          words remain.  In IN_ORDER, a root label is legal only with no opens
          and no words left; a non-root label is illegal when
          too_many_unary_nodes says the top constituent already heads a long
          unary chain and there are opens or no words left.
        """
        if state.num_opens > state.sentence_length + 10:
            # fudge a bit so we don't miss root nodes etc in very small trees
            # also there's one really deep tree in CTB 9.0
            return False
        if model.is_top_down:
            # If the model is top down, you can't Open if there are
            # no words to eventually eat
            if state.empty_word_queue():
                return False
            # Also, you can only Open a ROOT iff it is at the root position
            # The assumption in the unary scheme is there will be no
            # root open transitions
            if not model.has_unary_transitions():
                # TODO: maybe cache this value if this is an expensive operation
                is_root = self.top_label in model.root_labels
                if is_root:
                    return state.empty_transitions()
                else:
                    return not state.empty_transitions()
        else:
            # in-order nodes can Open as long as there is at least one thing
            # on the constituency stack
            # since closing the in-order involves removing one more
            # item before the open, and it can close at any time
            # (a close immediately after the open represents a unary)
            if state.empty_constituents:
                return False
            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                # consecutive Opens don't make sense in the context of in-order
                return False
            if not model.transition_scheme() is TransitionScheme.IN_ORDER:
                # eg, IN_ORDER_UNARY or IN_ORDER_COMPOUND
                # if compound unary opens are used
                # or the unary transitions are via CompoundUnary
                # can always open as long as the word queue isn't empty
                # if the word queue is empty, only close is allowed
                return not state.empty_word_queue()
            # one other restriction - we assume all parse trees
            # start with (ROOT (first_real_con ...))
            # therefore ROOT can only occur via Open after everything
            # else has been pushed and processed
            # there are no further restrictions
            is_root = self.top_label in model.root_labels
            if is_root:
                # can't make a root node if it will be in the middle of the parse
                # can't make a root node if there's still words to eat
                # note that the second assumption wouldn't work,
                # except we are assuming there will never be multiple
                # nodes under one root
                return state.num_opens == 0 and state.empty_word_queue()
            else:
                if (state.num_opens > 0 or state.empty_word_queue()) and too_many_unary_nodes(model.get_top_constituent(state.constituents), model.unary_limit()):
                    # looks like we've been in a loop of lots of unary transitions
                    # note that we check `num_opens > 0` because otherwise we might wind up stuck
                    # in a state where the only legal transition is open, such as if the
                    # constituent stack is otherwise empty, but the open is illegal because
                    # it causes too many unaries
                    # in such a case we can forbid the corresponding close instead...
                    # if empty_word_queue, that means it is trying to make infinitely many
                    # non-ROOT Open transitions instead of just transitioning ROOT
                    return False
                return True
        # top down with unary transitions: anything is legal while words remain
        return True

    def components(self):
        # one single-label Open per label of a compound Open
        return [OpenConstituent(label) for label in self.label]

    def short_name(self):
        return "Open"

    def __repr__(self):
        # writes the label tuple itself, eg OpenConstituent(('NP',))
        return "OpenConstituent({})".format(self.label)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, OpenConstituent):
            return False
        if self.label == other.label:
            return True
        return False

    def __hash__(self):
        return hash(self.label)
class Finalize(Transition):
    """
    Specifically applies at the end of a parse sequence to add a ROOT

    Seemed like the simplest way to remove ROOT from the
    in_order_compound transitions while still using the mechanism of
    the transitions to build the parse tree
    """
    def __init__(self, *label):
        self.label = tuple(label)

    def update_state(self, state, model):
        """
        Wrap the top constituent in this transition's label(s)

        is_legal only allows this when it is the only constituent left.
        It reuses the CloseConstituent build machinery; unlike CloseConstituent,
        the label is stored here rather than on a Dummy on the stack
        """
        top = state.constituents
        children = [top.value]
        remaining = top.pop()
        return state.word_position, remaining, (self.label, children), CloseConstituent

    def is_legal(self, state, model):
        """
        Legal if & only if there is one tree, no more words, and no ROOT yet
        """
        return (state.empty_word_queue()
                and state.has_one_constituent()
                and not state.finished(model))

    def short_name(self):
        return "Finalize"

    def __repr__(self):
        return "Finalize({})".format(",".join(self.label))

    def __eq__(self, other):
        if self is other:
            return True
        if isinstance(other, Finalize):
            return other.label == self.label
        return False

    def __hash__(self):
        return hash((53, self.label))
class CloseConstituent(Transition):
    """
    Close the most recently opened constituent

    The children are everything on the constituent stack above the Dummy
    left by the matching OpenConstituent.  update_state returns the data
    needed to build the constituent; building is batched, presumably via
    build_constituents.
    """
    def delta_opens(self):
        # closing removes one open constituent
        return -1

    def update_state(self, state, model):
        """
        Pop the children and the Dummy off the constituent stack

        Returns (word_position, remaining constituents, (label, children), CloseConstituent)
        with label taken from the Dummy.  In the in-order schemes the
        constituent below the Dummy is popped too and becomes the first child.
        """
        # pop constituents until we are done
        children = []
        constituents = state.constituents
        while not isinstance(model.get_top_constituent(constituents), Dummy):
            # keep the entire value from the stack - the model may need
            # the whole thing to transform the children into a new node
            children.append(constituents.value)
            constituents = constituents.pop()
        # the Dummy has the label on it
        label = model.get_top_constituent(constituents).label
        # pop past the Dummy as well
        constituents = constituents.pop()
        if not model.is_top_down:
            # the alternative to TOP_DOWN_... is IN_ORDER
            # in which case we want to pop one more constituent
            children.append(constituents.value)
            constituents = constituents.pop()
        # the children are in the opposite order of what we expect
        children.reverse()

        return state.word_position, constituents, (label, children), CloseConstituent

    @staticmethod
    def build_constituents(model, data):
        """
        builds new constituents out of the incoming data

        data is a list of tuples: (label, children)
        the model will batch the build operation
        again, the purpose of this batching is to do multiple deep learning operations at once
        """
        # NOTE(review): an empty data list makes this unpacking raise ValueError
        labels, children_lists = map(list, zip(*data))
        new_constituents = model.build_constituents(labels, children_lists)
        return new_constituents


    def is_legal(self, state, model):
        """
        Disallow if there is no Open on the stack yet

        Top down: illegal right after an Open (nothing built yet), and the
          last remaining Open cannot be closed while words remain.  For the
          other top down schemes without unary transitions, two remaining
          Opens also cannot be reduced to one while words remain, which
          leaves a real constituent under the ROOT open.
        IN_ORDER: previous transition does not matter, except for one small
          corner case: right after an Open, with only one open and words
          still queued, the Close is illegal if the constituent under the
          Dummy already heads a long unary chain.
        IN_ORDER_COMPOUND / IN_ORDER_UNARY: illegal right after an Open.
        """
        if state.num_opens <= 0:
            return False
        if model.is_top_down:
            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                return False
            if state.num_opens <= 1 and not state.empty_word_queue():
                # don't close the last open until all words have been used
                return False
            if model.transition_scheme() == TransitionScheme.TOP_DOWN_COMPOUND:
                # when doing TOP_DOWN_COMPOUND, we assume all transitions
                # at the ROOT level have an S, SQ, FRAG, etc underneath
                # this is checked when the model is first trained
                if state.num_opens == 1 and not state.empty_word_queue():
                    return False
            elif not model.has_unary_transitions():
                # in fact, we have to leave the top level constituent
                # under the ROOT open if unary transitions are not possible
                if state.num_opens == 2 and not state.empty_word_queue():
                    return False
        elif model.transition_scheme() is TransitionScheme.IN_ORDER:
            if not isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                # we're not stuck in a loop of unaries
                return True
            if state.num_opens > 1 or state.empty_word_queue():
                # in either of these cases, the corresponding Open should be eliminated
                # if we're stuck in a loop of unaries
                return True
            # the top is the Dummy from the Open; look at the constituent under it
            node = model.get_top_constituent(state.constituents.pop())
            if too_many_unary_nodes(node, model.unary_limit()):
                # at this point, we are in a situation where
                # - multiple unaries have happened in a row
                # - there is stuff on the word_queue, so a ROOT open isn't legal
                # - there's only one constituent on the stack, so the only legal
                #   option once there are no opens left will be an open
                # this means we'll be stuck having to open again if we do close
                # this node, so instead we make the Close illegal
                return False
        else:
            # model.transition_scheme() == TransitionScheme.IN_ORDER_COMPOUND or
            # model.transition_scheme() == TransitionScheme.IN_ORDER_UNARY:
            # in both of these cases, we cannot do open/close
            # IN_ORDER_COMPOUND will use compound opens and preterminal unaries
            # IN_ORDER_UNARY will use compound unaries
            # the only restriction here is that we can't close immediately after an open
            # internal unaries are handled by the opens being compound
            # preterminal unaries are handled with CompoundUnary
            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                return False
        return True

    def short_name(self):
        return "Close"

    def __repr__(self):
        return "CloseConstituent"

    def __eq__(self, other):
        if self is other:
            return True
        if isinstance(other, CloseConstituent):
            return True
        return False

    def __hash__(self):
        return hash(93)
def check_transitions(train_transitions, other_transitions, treebank_name):
    """
    Verify that every transition used by another dataset is known from the train set

    A transition missing from train_transitions is tolerated, with a logged
    warning, as long as each of its components() is known.  This covers
    weird nested unaries - there is a tree in VLSP, for example, with
    three (!) nested NP nodes.  We won't possibly get such a transition
    right when parsing, but at least we don't need to fail.

    Raises RuntimeError if any component of a transition is unknown
    """
    unknown_but_composable = set()
    for trans in other_transitions:
        if trans in train_transitions:
            continue
        if any(part not in train_transitions for part in trans.components()):
            raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(trans, treebank_name))
        unknown_but_composable.add(trans)

    if unknown_but_composable:
        logger.warning("Found transitions where the components are all valid transitions, but the complete transition is unknown: %s", sorted(unknown_but_composable))
================================================
FILE: stanza/models/constituency/parse_tree.py
================================================
"""
Tree datastructure
"""
from collections import deque, Counter
import copy
from enum import Enum
from io import StringIO
import itertools
import re
import warnings
from stanza.models.common.stanza_object import StanzaObject
# Shared single-character strings.  pretty_print compares stack entries
# against CLOSE_PAREN with "is", so these are more useful for that
# identity check than for any time savings
OPEN_PAREN = '('
CLOSE_PAREN = ')'
SPACE_SEPARATOR = ' '

# shared by every Tree constructed without children
EMPTY_CHILDREN = ()

# Splits the functional tags off constituent labels from various treebanks.
# For example, the Icelandic treebank (which we don't currently
# incorporate) uses * to distinguish 'ADJP', 'ADJP*OC', but we treat
# those as the same
CONSTITUENT_SPLIT = re.compile("[-=#*]")

# These words occur in the VLSP dataset.
# The documentation claims there might be *O*, although those don't
# seem to exist in practice
WORDS_TO_PRUNE = ('*E*', '*T*', '*O*')
class TreePrintMethod(Enum):
    """
    Describes a few options for printing trees.

    This probably doesn't need to be used directly.  See __format__,
    which maps its format specs O, L, P, V and T to these members.
    """
    ONE_LINE = 1          # (ROOT (S ... ))
    LABELED_PARENS = 2    # (_ROOT (_S ... )_S )_ROOT
    PRETTY = 3            # multiple lines
    VLSP = 4              # (S ... ), printed without the ROOT
    LATEX_TREE = 5        # \Tree [.S [.NP ... ] ]
class Tree(StanzaObject):
"""
A data structure to represent a parse tree
"""
def __init__(self, label=None, children=None):
    """
    label: the node label (constituent, tag, or word); may be None
    children: None, a single Tree, or an iterable of Trees
    """
    # children are always stored as a tuple: a lone Tree is wrapped and
    # None shares the module-level empty tuple
    if children is None:
        kids = EMPTY_CHILDREN
    elif isinstance(children, Tree):
        kids = (children,)
    else:
        kids = tuple(children)
    self.children = kids
    self.label = label
def is_leaf(self):
    """True when this node has no children"""
    return not len(self.children)
def is_preterminal(self):
    """True when this node has exactly one child and that child is a leaf"""
    if len(self.children) != 1:
        return False
    return len(self.children[0].children) == 0
def yield_preterminals(self):
    """
    Yield the preterminals one at a time, left to right

    If this node is itself a preterminal, only it is yielded.
    Raises ValueError (on the first next()) if this node is a leaf.

    Uses an explicit stack instead of recursion so deep trees cannot blow
    the call stack.  Each step is O(1): prepending every internal node's
    children with another itertools.chain would leave one pass-through
    wrapper per internal node, making later steps progressively slower.
    """
    if self.is_preterminal():
        yield self
        return

    if self.is_leaf():
        raise ValueError("Attempted to iterate preterminals on non-internal node")

    # children go on in reverse so the leftmost one is popped first
    stack = list(reversed(self.children))
    while stack:
        node = stack.pop()
        if node.is_preterminal():
            yield node
        else:
            stack.extend(reversed(node.children))
def leaf_labels(self):
    """
    Return the labels of the leaves (the words), in order, as a list

    A bare leaf returns a one element list of its own label
    """
    if self.is_leaf():
        return [self.label]
    return [preterminal.children[0].label for preterminal in self.yield_preterminals()]
def __len__(self):
    """The number of leaves (words) in the tree"""
    leaves = self.leaf_labels()
    return len(leaves)
def all_leaves_are_preterminals(self):
    """
    Returns True if all leaves are under preterminals, False otherwise

    A bare leaf returns False
    """
    if self.is_leaf():
        return False
    if self.is_preterminal():
        return True
    for child in self.children:
        if not child.all_leaves_are_preterminals():
            return False
    return True
def pretty_print(self, normalize=None):
    """
    Print with newlines & indentation on each line

    Preterminals and nodes with all preterminal children go on their own line

    You can pass in your own normalize() function.  If you do,
    make sure the function updates the parens to be something
    other than () or the brackets will be broken

    Like __format__, this walks the tree with an explicit stack rather
    than recursion.  Returns the result as a string.
    """
    if normalize is None:
        # default: only escape the parens, as -LRB- and -RRB-
        normalize = lambda x: x.replace("(", "-LRB-").replace(")", "-RRB-")

    indent = 0
    with StringIO() as buf:
        stack = deque()
        stack.append(self)
        while len(stack) > 0:
            node = stack.pop()

            # CLOSE_PAREN markers are pushed below a node's children; compared by identity
            if node is CLOSE_PAREN:
                # if we're trying to pretty print trees, pop all off close parens
                # then write a newline
                while node is CLOSE_PAREN:
                    indent -= 1
                    buf.write(CLOSE_PAREN)
                    if len(stack) == 0:
                        node = None
                        break
                    node = stack.pop()
                buf.write("\n")
                if node is None:
                    break
                # the node popped after the run of parens still needs printing
                stack.append(node)
            elif node.is_preterminal():
                buf.write("    " * indent) if False else buf.write(" " * indent)
                buf.write("%s%s %s%s" % (OPEN_PAREN, normalize(node.label), normalize(node.children[0].label), CLOSE_PAREN))
                # a following close paren goes on the same line
                if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:
                    buf.write("\n")
            elif all(x.is_preterminal() for x in node.children):
                # every child is a preterminal, so the whole node fits on one line
                buf.write(" " * indent)
                buf.write("%s%s" % (OPEN_PAREN, normalize(node.label)))
                for child in node.children:
                    buf.write(" %s%s %s%s" % (OPEN_PAREN, normalize(child.label), normalize(child.children[0].label), CLOSE_PAREN))
                buf.write(CLOSE_PAREN)
                if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:
                    buf.write("\n")
            else:
                # open the node on its own line, then print the children one level deeper
                buf.write(" " * indent)
                buf.write("%s%s\n" % (OPEN_PAREN, normalize(node.label)))
                stack.append(CLOSE_PAREN)
                for child in reversed(node.children):
                    stack.append(child)
                indent += 1

        buf.seek(0)
        return buf.read()
def __format__(self, spec):
    """
    Turn the tree into a string representing the tree

    Note that this is not a recursive traversal
    Otherwise, a tree too deep might blow up the call stack

    There is a type specific format:
      O  -> one line PTB format, which is the default anyway
      L  -> open and close brackets are labeled, spaces in the tokens are replaced with _
      P  -> pretty print over multiple lines
      V  -> surround lines with <s>...</s>, don't print ROOT, and turn () into L/RBKT
      T  -> LaTeX, eg \\Tree [.S [.NP ... ] ]
      ?  -> spaces in the tokens are replaced with ? for any value of ? other than OLPVT
            warning: this may be removed in the future
      ?{OLPVT} -> specific format AND a custom space replacement
      Vi -> add an ID to the <s> in the V format.  Also works with ?Vi
            The ID is read from self.tree_id, which __init__ does not set
    """
    space_replacement = " "
    print_format = TreePrintMethod.ONE_LINE
    if spec == 'L':
        print_format = TreePrintMethod.LABELED_PARENS
        space_replacement = "_"
    elif spec and spec[-1] == 'L':
        print_format = TreePrintMethod.LABELED_PARENS
        space_replacement = spec[0]
    elif spec == 'O':
        print_format = TreePrintMethod.ONE_LINE
    elif spec and spec[-1] == 'O':
        print_format = TreePrintMethod.ONE_LINE
        space_replacement = spec[0]
    elif spec == 'P':
        print_format = TreePrintMethod.PRETTY
    elif spec and spec[-1] == 'P':
        print_format = TreePrintMethod.PRETTY
        space_replacement = spec[0]
    elif spec and spec[0] == 'V':
        print_format = TreePrintMethod.VLSP
        use_tree_id = spec[-1] == 'i'
    elif spec and len(spec) > 1 and spec[1] == 'V':
        print_format = TreePrintMethod.VLSP
        space_replacement = spec[0]
        use_tree_id = spec[-1] == 'i'
    elif spec == 'T':
        print_format = TreePrintMethod.LATEX_TREE
    elif spec and len(spec) > 1 and spec[1] == 'T':
        print_format = TreePrintMethod.LATEX_TREE
        space_replacement = spec[0]
    elif spec:
        space_replacement = spec[0]
        warnings.warn("Use of a custom replacement without a format specifier is deprecated.  Please use {}O instead".format(space_replacement), stacklevel=2)

    LRB = "LBKT" if print_format == TreePrintMethod.VLSP else "-LRB-"
    RRB = "RBKT" if print_format == TreePrintMethod.VLSP else "-RRB-"
    def normalize(text):
        return text.replace(" ", space_replacement).replace("(", LRB).replace(")", RRB)

    if print_format is TreePrintMethod.PRETTY:
        return self.pretty_print(normalize)

    with StringIO() as buf:
        stack = deque()
        if print_format == TreePrintMethod.VLSP:
            # VLSP trees are wrapped in <s> ... </s>, optionally with an id
            if use_tree_id:
                buf.write("<s id={}>\n".format(self.tree_id))
            else:
                buf.write("<s>\n")
            if len(self.children) == 0:
                raise ValueError("Cannot print an empty tree with V format")
            elif len(self.children) > 1:
                raise ValueError("Cannot print a tree with %d branches with V format" % len(self.children))
            # skip the ROOT node itself
            stack.append(self.children[0])
        elif print_format == TreePrintMethod.LATEX_TREE:
            buf.write("\\Tree ")
            if len(self.children) == 0:
                raise ValueError("Cannot print an empty tree with T format")
            elif len(self.children) == 1 and len(self.children[0].children) == 0:
                buf.write("[.? ")
                buf.write(normalize(self.children[0].label))
                buf.write(" ]")
            elif self.label == 'ROOT':
                stack.append(self.children[0])
            else:
                stack.append(self)
        else:
            stack.append(self)

        while len(stack) > 0:
            node = stack.pop()

            # strings on the stack are separators and closing brackets
            if isinstance(node, str):
                buf.write(node)
                continue
            if len(node.children) == 0:
                if node.label is not None:
                    buf.write(normalize(node.label))
                continue

            if print_format is TreePrintMethod.LATEX_TREE:
                if node.is_preterminal():
                    buf.write(normalize(node.children[0].label))
                    continue
                buf.write("[.%s" % normalize(node.label))
                stack.append(" ]")
            elif print_format is TreePrintMethod.ONE_LINE or print_format is TreePrintMethod.VLSP:
                buf.write(OPEN_PAREN)
                if node.label is not None:
                    buf.write(normalize(node.label))
                stack.append(CLOSE_PAREN)
            elif print_format is TreePrintMethod.LABELED_PARENS:
                buf.write("%s_%s" % (OPEN_PAREN, normalize(node.label)))
                stack.append(CLOSE_PAREN + "_" + normalize(node.label))
                stack.append(SPACE_SEPARATOR)

            # pushed in reverse so the leftmost child is written first
            for child in reversed(node.children):
                stack.append(child)
                stack.append(SPACE_SEPARATOR)
        if print_format == TreePrintMethod.VLSP:
            buf.write("\n</s>")
        return buf.getvalue()
def __repr__(self):
    """The repr is the default (one line PTB) formatting of the tree"""
    return format(self)
def __eq__(self, other):
    """
    Trees are equal when their labels match and their children are pairwise equal

    Anything which is not a Tree is never equal to a Tree
    """
    if self is other:
        return True
    if not isinstance(other, Tree):
        return False
    if self.label != other.label or len(self.children) != len(other.children):
        return False
    return not any(mine != theirs for mine, theirs in zip(self.children, other.children))
def depth(self):
    """Number of edges on the longest path from this node down to a leaf; a leaf has depth 0"""
    if self.children:
        return 1 + max(child.depth() for child in self.children)
    return 0
def visit_preorder(self, internal=None, preterminal=None, leaf=None):
    """
    Visit the tree in a preorder order

    Applies the given functions to each node.
    internal: if not None, applies this function to each non-leaf, non-preterminal node
    preterminal: if not None, applies this function to each preterminal
    leaf: if not None, applies this function to each leaf

    The functions should *not* destructively alter the trees.
    There is no attempt to interpret the results of calling these functions.
    Rather, you can use visit_preorder to collect stats on trees, etc.
    """
    if self.is_leaf():
        callback = leaf
    elif self.is_preterminal():
        callback = preterminal
    else:
        callback = internal
    if callback:
        callback(self)
    # preterminals recurse too, which is how the leaf callback reaches the words
    for child in self.children:
        child.visit_preorder(internal, preterminal, leaf)
@staticmethod
def get_unique_constituent_labels(trees):
    """
    Walks over all of the trees and gets all of the unique constituent names from the trees

    trees may be a single Tree or an iterable of Trees; returns a sorted list
    """
    if isinstance(trees, Tree):
        trees = [trees]
    return sorted(Tree.get_constituent_counts(trees))
@staticmethod
def get_constituent_counts(trees):
    """
    Walks over all of the trees and counts the labels of the internal (constituent) nodes

    trees may be a single Tree or an iterable of Trees; returns a Counter
    """
    if isinstance(trees, Tree):
        trees = [trees]
    counts = Counter()
    def count_label(node):
        counts[node.label] += 1
    for tree in trees:
        tree.visit_preorder(internal=count_label)
    return counts
@staticmethod
def get_unique_tags(trees):
    """
    Walks over all of the trees and gets all of the unique tags (preterminal labels)

    trees may be a single Tree or an iterable of Trees; returns a sorted list
    """
    tree_list = [trees] if isinstance(trees, Tree) else trees
    tags = set()
    def collect(node):
        tags.add(node.label)
    for tree in tree_list:
        tree.visit_preorder(preterminal=collect)
    return sorted(tags)
@staticmethod
def get_unique_words(trees):
    """
    Walks over all of the trees and gets all of the unique words (leaf labels)

    trees may be a single Tree or an iterable of Trees; returns a sorted list
    """
    tree_list = [trees] if isinstance(trees, Tree) else trees
    words = set()
    def collect(node):
        words.add(node.label)
    for tree in tree_list:
        tree.visit_preorder(leaf=collect)
    return sorted(words)
@staticmethod
def get_common_words(trees, num_words):
    """
    Walks over all of the trees and gets the num_words most frequently occurring words

    Returns them sorted alphabetically.  Note that num_words == 0 returns
    an empty set rather than a list.
    """
    if num_words == 0:
        return set()
    if isinstance(trees, Tree):
        trees = [trees]
    counts = Counter()
    def count_word(node):
        counts[node.label] += 1
    for tree in trees:
        tree.visit_preorder(leaf=count_word)
    most_frequent = counts.most_common()[:num_words]
    return sorted(word for word, _ in most_frequent)
@staticmethod
def get_rare_words(trees, threshold=0.05):
    """
    Walks over all of the trees and gets the least frequently occurring words

    threshold: the fraction of distinct words to return, eg 0.05 for the bottom 5%.
      At least one word is returned whenever any words exist.
    Returns the words sorted alphabetically
    """
    if isinstance(trees, Tree):
        trees = [trees]
    counts = Counter()
    def count_word(node):
        counts[node.label] += 1
    for tree in trees:
        tree.visit_preorder(leaf=count_word)
    cutoff = max(int(len(counts) * threshold), 1)
    least_common = counts.most_common()[::-1][:cutoff]
    return sorted(word for word, _ in least_common)
@staticmethod
def get_root_labels(trees):
    """Return the sorted, distinct labels of the given trees' root nodes"""
    return sorted({tree.label for tree in trees})
@staticmethod
def get_compound_constituents(trees, separate_root=False):
    """
    Collect every chain of labels produced by collapsing unary chains of internal nodes

    A chain starts at an internal (non-preterminal) node and follows single
    children until the next node would be a preterminal or has several
    children, eg (S (VP (VB ...) ...)) gives ('S', 'VP').
    With separate_root, each root label is recorded alone as (label,)
    and its children start the chains instead.

    Returns a sorted list of label tuples
    """
    constituents = set()
    pending = []
    for tree in trees:
        if separate_root:
            constituents.add((tree.label,))
            pending.extend(tree.children)
        else:
            pending.append(tree)

        while pending:
            node = pending.pop()
            if node.is_leaf() or node.is_preterminal():
                continue
            chain = [node.label]
            while len(node.children) == 1 and not node.children[0].is_preterminal():
                node = node.children[0]
                chain.append(node.label)
            constituents.add(tuple(chain))
            pending.extend(node.children)
    return sorted(constituents)
# TODO: test different pattern
def simplify_labels(self, pattern=CONSTITUENT_SPLIT):
    """
    Return a copy of the tree with the labels truncated at `pattern`.

    Every non-leaf label is split with pattern.split and only the first
    piece is kept (the original comment describes this as removing the
    -=# annotations).  The text of the leaves is left alone, as are empty
    labels, one-character labels (so a tag of just - or = survives), and
    the bracket tags -LRB- / -RRB-.
    """
    label = self.label
    keep_whole = (not label
                  or self.is_leaf()
                  or len(label) <= 1
                  or label in ('-LRB-', '-RRB-'))
    if not keep_whole:
        label = pattern.split(label)[0]
    simplified_children = [child.simplify_labels(pattern) for child in self.children]
    return Tree(label, simplified_children)
def reverse(self):
    """
    Return a mirror-image copy of the tree: children appear in reverse order at every level.

    The intent is to train a parser on reversed trees, to see whether a
    forward and a backward parser can augment each other.
    """
    if not self.is_leaf():
        mirrored = [child.reverse() for child in reversed(self.children)]
        return Tree(self.label, mirrored)
    return Tree(self.label)
def remap_constituent_labels(self, label_map):
    """
    Copy the tree, renaming constituent labels through label_map.

    Internal-node labels found in label_map are replaced with the mapped
    value; labels not in the map are unchanged.  Preterminal (tag) labels
    and words are never remapped.
    """
    if self.is_leaf():
        return Tree(self.label)
    if self.is_preterminal():
        word = self.children[0].label
        return Tree(self.label, Tree(word))
    renamed = label_map.get(self.label, self.label)
    remapped_children = [child.remap_constituent_labels(label_map) for child in self.children]
    return Tree(renamed, remapped_children)
def remap_words(self, word_map):
    """
    Copy the tree, replacing words (leaf labels) through word_map.

    Words found in word_map are replaced with the mapped value; words not
    in the map are unchanged.  Tag and constituent labels are copied as is.
    """
    if not self.is_leaf():
        if self.is_preterminal():
            # the single word child is passed to Tree directly, not in a list
            return Tree(self.label, self.children[0].remap_words(word_map))
        return Tree(self.label, [child.remap_words(word_map) for child in self.children])
    replacement = word_map.get(self.label, self.label)
    return Tree(replacement)
def replace_words(self, words):
    """
    Return a new tree with the leaves replaced, left to right, by `words`.

    words: a list or any iterable of the new leaf labels
    Raises ValueError if there are fewer words than leaves (a None entry
    is also treated as running out), or more words than leaves.
    """
    remaining = iter(words)

    def rebuild(node):
        if not node.is_leaf():
            return Tree(node.label, [rebuild(child) for child in node.children])
        replacement = next(remaining, None)
        if replacement is None:
            raise ValueError("Not enough words to replace all leaves")
        return Tree(replacement)

    result = rebuild(self)
    exhausted = object()
    if next(remaining, exhausted) is not exhausted:
        raise ValueError("Too many words for the given tree")
    return result
def replace_tags(self, tags):
    """
    Return a deep copy of this tree with its preterminal labels replaced in order.

    tags: either another Tree, whose preterminal labels (via
      yield_preterminals) are used in order, or any iterable of tag labels
    Raises ValueError if this tree is a leaf, if a leaf is reached without
    a preterminal parent, or if the number of tags differs from the number
    of preterminals.  The original tree is not modified.
    """
    if self.is_leaf():
        raise ValueError("Must call replace_tags with non-leaf")

    if isinstance(tags, Tree):
        remaining_tags = (preterminal.label for preterminal in tags.yield_preterminals())
    else:
        remaining_tags = iter(tags)

    relabeled = copy.deepcopy(self)
    # explicit stack: children are pushed right-to-left so that the
    # preterminals are popped (and relabeled) left-to-right
    pending = [relabeled]
    while pending:
        node = pending.pop()
        if node.is_preterminal():
            try:
                new_tag = next(remaining_tags)
            except StopIteration:
                raise ValueError("Not enough tags in sentence for given tree")
            node.label = new_tag
        elif node.is_leaf():
            raise ValueError("Got a badly structured tree: {}".format(self))
        else:
            pending.extend(reversed(node.children))

    if any(True for _ in remaining_tags):
        raise ValueError("Too many tags for the given tree")
    return relabeled
def prune_none(self):
    """
    Return a copy of the tree with empty elements removed.

    A preterminal is dropped when its tag is -NONE- (as appears in PTB)
    or when its word is in WORDS_TO_PRUNE (the original notes *E*, which
    shows up in a VLSP dataset).  An internal node is dropped when all of
    its children were dropped.  Returns None if nothing remains.
    """
    if self.is_leaf():
        return Tree(self.label)

    if self.is_preterminal():
        word = self.children[0].label
        if self.label == '-NONE-' or word in WORDS_TO_PRUNE:
            return None
        return Tree(self.label, Tree(word))

    # must be an internal node
    pruned_children = (child.prune_none() for child in self.children)
    survivors = [child for child in pruned_children if child is not None]
    return Tree(self.label, survivors) if survivors else None
def count_unary_depth(self):
    """
    Return the length of the longest chain of single-child internal nodes in this subtree.

    Preterminals and leaves count as 0.  A chain is counted from a node
    with exactly one child down to (not including) the first preterminal
    or branching node; deeper chains below that point are also considered.
    """
    if self.is_preterminal() or self.is_leaf():
        return 0

    if len(self.children) != 1:
        return max(child.count_unary_depth() for child in self.children)

    node = self
    chain_length = 0
    while not node.is_preterminal() and not node.is_leaf() and len(node.children) == 1:
        chain_length += 1
        node = node.children[0]
    deepest_below = max(child.count_unary_depth() for child in node.children)
    return max(chain_length, deepest_below)
@staticmethod
def write_treebank(trees, out_file, fmt="{}"):
    """
    Write the trees to out_file (UTF-8, overwriting it), one per line.

    fmt: a str.format template applied to each tree; the default "{}"
      writes the tree's str() form
    """
    with open(out_file, "w", encoding="utf-8") as fout:
        fout.writelines(fmt.format(tree) + "\n" for tree in trees)
================================================
FILE: stanza/models/constituency/parser_training.py
================================================
from collections import Counter, namedtuple
import copy
import logging
import os
import random
import re
import torch
from torch import nn
#from stanza.models.common import pretrain
from stanza.models.common import utils
from stanza.models.common.foundation_cache import FoundationCache, NoTransformerFoundationCache
from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss
from stanza.models.common.utils import sort_with_indices, unsort
from stanza.models.constituency import error_analysis_in_order
from stanza.models.constituency import parse_transitions
from stanza.models.constituency import transition_sequence
from stanza.models.constituency import tree_reader
from stanza.models.constituency.in_order_compound_oracle import InOrderCompoundOracle
from stanza.models.constituency.in_order_oracle import InOrderOracle
from stanza.models.constituency.lstm_model import LSTMModel
from stanza.models.constituency.parse_transitions import TransitionScheme
from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency.top_down_oracle import TopDownOracle
from stanza.models.constituency.trainer import Trainer
from stanza.models.constituency.utils import retag_trees, build_optimizer, build_scheduler, verify_transitions, get_open_nodes, check_constituents, check_root_labels, remove_duplicate_trees, remove_singleton_trees
from stanza.server.parser_eval import EvaluateParser, ParseResult
from stanza.utils.get_tqdm import get_tqdm
# progress-bar factory, as chosen by stanza.utils.get_tqdm
tqdm = get_tqdm()
tlogger = logging.getLogger('stanza.constituency.trainer')
# One training example: the gold tree, its gold transition sequence, and
# a list of copies of the tree's preterminal (tag, word) subtrees
# (built by compose_train_data)
TrainItem = namedtuple("TrainItem", ('tree', 'gold_sequence', 'preterminals'))
class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
    """
    Training statistics accumulated over a batch or over an epoch.

    epoch_loss: summed loss
    transitions_correct / transitions_incorrect: Counters of transitions
      the model predicted correctly / incorrectly
    repairs_used: Counter of oracle repairs applied
    fake_transitions_used: number of fake transitions used
    nans: number of batches that had to be ignored because of NaN

    Two instances are combined field by field with `+`.
    """
    def __add__(self, other):
        summed = {field: getattr(self, field) + getattr(other, field) for field in self._fields}
        return EpochStats(**summed)
def evaluate(args, model_file, retag_pipeline):
    """
    Load the model in model_file and score it on the treebank in args['eval_file'].

    If retag_pipeline is given, the trees fed to the parser are retagged
    first (xpos or upos according to the loaded model's retag_method),
    while the originally read treebank is still passed to run_dev_set
    alongside them.  Scoring goes through EvaluateParser, which the
    original notes uses a subprocess to run the Java EvalB code.
    Logs the F1 (and the k-best F1 when num_generate > 0).
    """
    # The evaluator is created up front because otherwise the transformers
    # library constantly complains about forking the process.  This does
    # not help when training several models in one run, but that is rare.
    num_generate = args['num_generate']
    kbest = num_generate + 1 if num_generate > 0 else None

    with EvaluateParser(kbest=kbest) as evaluator:
        if retag_pipeline:
            foundation_cache = retag_pipeline[0].foundation_cache
        else:
            foundation_cache = FoundationCache()
        load_args = dict(wordvec_pretrain_file=args['wordvec_pretrain_file'],
                         charlm_forward_file=args['charlm_forward_file'],
                         charlm_backward_file=args['charlm_backward_file'],
                         device=args['device'])
        trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)
        if args['log_shapes']:
            trainer.log_shapes()

        gold_treebank = tree_reader.read_treebank(args['eval_file'])
        tlogger.info("Read %d trees for evaluation", len(gold_treebank))

        if retag_pipeline is None:
            parser_input = gold_treebank
        else:
            retag_method = trainer.model.retag_method
            tlogger.info("Retagging trees using the %s tags from the %s package...", retag_method, args['retag_package'])
            parser_input = retag_trees(gold_treebank, retag_pipeline, retag_method == 'xpos')
            tlogger.info("Retagging finished")

        if args['log_norms']:
            trainer.log_norms()
        f1, kbest_f1, _ = run_dev_set(trainer.model, parser_input, gold_treebank, args, evaluator, analyze_first_errors=True)
        tlogger.info("F1 score on %s: %f", args['eval_file'], f1)
        if kbest_f1 is not None:
            tlogger.info("KBest F1 score on %s: %f", args['eval_file'], kbest_f1)
def remove_optimizer(args, model_save_file, model_load_file):
    """
    Re-save the model in model_load_file to model_save_file without its optimizer state.

    Dropping the optimizer makes the save file a lot smaller.
    """
    # TODO: kind of overkill to load the pretrain rather than making
    # load/save work without it, but this functionality isn't used often
    load_args = {key: args[key] for key in ('wordvec_pretrain_file',
                                            'charlm_forward_file',
                                            'charlm_backward_file',
                                            'device')}
    stripped = Trainer.load(model_load_file, args=load_args, load_optimizer=False)
    stripped.save(model_save_file)
def add_grad_clipping(trainer, grad_clipping):
    """
    Clamp each gradient of trainer.model to [-grad_clipping, grad_clipping].

    A backward hook is registered on every parameter that requires grad.
    Does nothing when grad_clipping is None.
    """
    if grad_clipping is None:
        return

    def clamp_grad(grad):
        return torch.clamp(grad, -grad_clipping, grad_clipping)

    trainable = (param for param in trainer.model.parameters() if param.requires_grad)
    for param in trainable:
        param.register_hook(clamp_grad)
def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file):
    """
    Builds a Trainer (with model) and the train_sequences and transitions for the given trees.

    Gathers the constituent labels, tags, unary limit, transitions, root
    labels, vocabulary and open nodes from the training trees, checks the
    dev (and silver) trees against them, and hands everything to
    Trainer.build_trainer.  Grad clipping hooks are then attached.

    Returns (trainer, train_sequences, silver_sequences, train_transitions)
    """
    train_constituents = Tree.get_unique_constituent_labels(train_trees)
    tlogger.info("Unique constituents in training set: %s", train_constituents)
    if args['check_valid_states']:
        check_constituents(train_constituents, dev_trees, "dev", fail=args['strict_check_constituents'])
        check_constituents(train_constituents, silver_trees, "silver", fail=args['strict_check_constituents'])
    constituent_counts = Tree.get_constituent_counts(train_trees)
    tlogger.info("Constituent node counts: %s", constituent_counts)

    tags = Tree.get_unique_tags(train_trees)
    if None in tags:
        raise RuntimeError("Fatal problem: the tagger put None on some of the nodes!")
    tlogger.info("Unique tags in training set: %s", tags)
    # no need to fail for missing tags between train/dev set
    # the model has an unknown tag embedding
    for tag in Tree.get_unique_tags(dev_trees):
        if tag not in tags:
            tlogger.info("Found tag in dev set which does not exist in train set: %s Continuing...", tag)

    # longest unary chain in train/dev, plus one.
    # NOTE(review): max() raises ValueError if train_trees or dev_trees is empty.
    # NOTE(review): silver depths are folded in without the +1 margin used
    # for train/dev; confirm whether that asymmetry is intended.
    unary_limit = max(max(t.count_unary_depth() for t in train_trees),
                      max(t.count_unary_depth() for t in dev_trees)) + 1
    if silver_trees:
        unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))
    tlogger.info("Unary limit: %d", unary_limit)
    train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], args['reversed'])
    dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], args['reversed'])
    silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], args['reversed'])

    tlogger.info("Total unique transitions in train set: %d", len(train_transitions))
    tlogger.info("Unique transitions in training set:\n %s", "\n ".join(map(str, train_transitions)))
    # the dev/silver transitions are checked against the train transitions
    # plus each train transition's components()
    expanded_train_transitions = set(train_transitions + [x for trans in train_transitions for x in trans.components()])
    if args['check_valid_states']:
        parse_transitions.check_transitions(expanded_train_transitions, dev_transitions, "dev")
        # theoretically could just train based on the items in the silver dataset
        parse_transitions.check_transitions(expanded_train_transitions, silver_transitions, "silver")

    root_labels = Tree.get_root_labels(train_trees)
    check_root_labels(root_labels, dev_trees, "dev")
    check_root_labels(root_labels, silver_trees, "silver")
    tlogger.info("Root labels in treebank: %s", root_labels)

    # the gold sequences must replay correctly under the unary limit and root labels
    verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], "train", root_labels)
    verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit, args['reversed'], "dev", root_labels)

    # we don't check against the words in the dev set as it is
    # expected there will be some UNK words
    words = Tree.get_unique_words(train_trees)
    rare_words = Tree.get_rare_words(train_trees, args['rare_word_threshold'])
    # rare/unknown silver words will just get UNK if they are not already known
    if silver_trees and args['use_silver_words']:
        tlogger.info("Getting silver words to add to the delta embedding")
        # add at most as many common silver words as there are training words
        silver_words = Tree.get_common_words(tqdm(silver_trees, postfix='Silver words'), len(words))
        words = sorted(set(words + silver_words))

    # also, it's not actually an error if there is a pattern of
    # compound unary or compound open nodes which doesn't exist in the
    # train set.  it just means we probably won't ever get that right
    open_nodes = get_open_nodes(train_trees, args['transition_scheme'])
    tlogger.info("Using the following open nodes:\n %s", "\n ".join(map(str, open_nodes)))

    # at this point we have:
    # pretrain
    # train_trees, dev_trees
    # lists of transitions, internal nodes, and root states the parser needs to be aware of

    trainer = Trainer.build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file)

    trainer.log_num_words_known(words)
    # grad clipping is not saved with the rest of the model,
    # so even in the case of a model we saved,
    # we now have to add the grad clipping
    add_grad_clipping(trainer, args['grad_clipping'])

    return trainer, train_sequences, silver_sequences, train_transitions
def train(args, model_load_file, retag_pipeline):
    """
    Build a model, train it using the requested train & dev files

    Reads the train, dev and (optional) silver treebanks, removes
    duplicates/singletons as configured, optionally retags all of them
    with retag_pipeline, then builds the trainer (build_trainer) and runs
    iterate_training inside an EvaluateParser context.

    When args['wandb'] is set, a wandb run is started here and finished at
    the end; `global wandb` binds the module name so that iterate_training
    can log to it.

    model_load_file: passed through to build_trainer / Trainer.build_trainer
    retag_pipeline: None, or a pipeline whose first element's
      foundation_cache is reused
    Returns the trainer returned by iterate_training.
    """
    utils.log_training_args(args, tlogger)

    # we create the Evaluator here because otherwise the transformers
    # library constantly complains about forking the process
    # note that this won't help in the event of training multiple
    # models in the same run, although since that would take hours
    # or days, that's not a very common problem
    if args['num_generate'] > 0:
        kbest = args['num_generate'] + 1
    else:
        kbest = None

    if args['wandb']:
        global wandb
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_constituency" % args['shorthand']
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('dev_score', summary='max')

    with EvaluateParser(kbest=kbest) as evaluator:
        utils.ensure_dir(args['save_dir'])

        train_trees = tree_reader.read_treebank(args['train_file'])
        tlogger.info("Read %d trees for the training set", len(train_trees))
        if args['train_remove_duplicates']:
            train_trees = remove_duplicate_trees(train_trees, "train")
        train_trees = remove_singleton_trees(train_trees)

        dev_trees = tree_reader.read_treebank(args['eval_file'])
        tlogger.info("Read %d trees for the dev set", len(dev_trees))
        dev_trees = remove_duplicate_trees(dev_trees, "dev")

        silver_trees = []
        if args['silver_file']:
            silver_trees = tree_reader.read_treebank(args['silver_file'])
            tlogger.info("Read %d trees for the silver training set", len(silver_trees))
            if args['silver_remove_duplicates']:
                silver_trees = remove_duplicate_trees(silver_trees, "silver")

        if retag_pipeline is not None:
            tlogger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
            train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])
            dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos'])
            silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos'])
            tlogger.info("Retagging finished")

        # reuse the pipeline's cache of pretrains/charlms if there is one
        foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
        trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file)

        if args['log_shapes']:
            trainer.log_shapes()
        trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator)

    if args['wandb']:
        wandb.finish()

    return trainer
def compose_train_data(trees, sequences):
    """
    Pair each tree with its gold transition sequence as a TrainItem.

    The preterminals field of each item holds freshly built copies of the
    tree's preterminal (tag, word) subtrees, taken from yield_preterminals.
    Trees and sequences are matched by position; zip stops at the shorter.
    """
    def copy_preterminals(tree):
        return [Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label))
                for preterminal in tree.yield_preterminals()]

    preterminal_lists = [copy_preterminals(tree) for tree in trees]
    return [TrainItem(tree, sequence, preterminals)
            for tree, sequence, preterminals in zip(trees, sequences, preterminal_lists)]
def next_epoch_data(leftover_training_data, train_data, epoch_size):
    """
    Return the next epoch_size trees from the training data, starting
    with leftover data from the previous epoch if there is any

    The training loop generally operates on a fixed number of trees,
    rather than going through all the trees in the training set
    exactly once, and keeping the leftover training data via this
    function ensures that each tree in the training set is touched
    once before beginning to iterate again.

    leftover_training_data: items left over from the previous call; not modified
    train_data: the full training list; it is shuffled in place each time
      it is appended, so successive passes see a different order
    epoch_size: number of items to return for this epoch

    Returns (new_leftover_data, epoch_data).  Both are empty lists if
    train_data is empty.
    """
    if not train_data:
        return [], []

    # copy so the caller's leftover list is never extended in place
    epoch_data = list(leftover_training_data)
    while len(epoch_data) < epoch_size:
        random.shuffle(train_data)
        epoch_data.extend(train_data)

    return epoch_data[epoch_size:], epoch_data[:epoch_size]
def update_bert_learning_rate(args, optimizer, epochs_trained):
    """
    Update the learning rate for the bert finetuning, if applicable

    Every parameter group named 'bert' has its lr recomputed from the lr
    of the 'base' group:
      - 0.0 before args['bert_finetune_begin_epoch'] (when it is set)
      - 0.0 from args['bert_finetune_end_epoch'] on (when it is set)
      - base lr * args['stage1_bert_learning_rate'] in the first half of a multistage run
      - base lr * args['bert_learning_rate'] otherwise
    Changes are logged.  Raises AssertionError if there is no 'base' group.
    """
    # A per-group scheduler would be nicer, but madgrad (the optimizer we
    # had the most success with) still learns a little through its eps
    # term when a group's lr is 0, and that was enough to break the bert
    # learning in the second half - hence setting the rates explicitly.
    base_group = next((group for group in optimizer.param_groups
                       if group['param_group_name'] == 'base'), None)
    if base_group is None:
        raise AssertionError("There should always be a base parameter group")

    def scheduled_bert_lr():
        # Occasionally a model goes haywire and forgets how to use the
        # transformer (seen with Electra on the non-NML PTB).  An increasing
        # transformer lr didn't fully fix that; starting the finetuning
        # after a few epochs seems to help a lot.
        begin_epoch = args['bert_finetune_begin_epoch']
        if begin_epoch is not None and epochs_trained < begin_epoch:
            return 0.0
        end_epoch = args['bert_finetune_end_epoch']
        if end_epoch is not None and epochs_trained >= end_epoch:
            return 0.0
        if args['multistage'] and epochs_trained < args['epochs'] // 2:
            return base_group['lr'] * args['stage1_bert_learning_rate']
        return base_group['lr'] * args['bert_learning_rate']

    for group in optimizer.param_groups:
        if group['param_group_name'] != 'bert':
            continue
        previous_lr = group['lr']
        group['lr'] = scheduled_bert_lr()
        if group['lr'] != previous_lr:
            tlogger.info("Setting %s finetuning rate from %f to %f", group['param_group_name'], previous_lr, group['lr'])
def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator):
    """
    Given an initialized model, a processed dataset, and a secondary dev dataset, train the model

    The training is iterated in the following loop:
      extract a batch of trees of the same length from the training set
      convert those trees into initial parsing states
      repeat until trees are done:
        batch predict the model's interpretation of the current states
        add the errors to the list of things to backprop
        advance the parsing state for each of the trees

    After every epoch the model is scored on dev_trees with evaluator, and
    the best scoring model so far is saved to args['save_name'].  With
    args['multistage'], the trainer is rebuilt (new model structure and new
    optimizer) at the epochs listed in multistage_splits.

    NOTE(review): the `transitions` parameter is not referenced in this
    body; trainer.transitions is used instead.

    Returns the final trainer, which is a new object if a multistage split
    rebuilt it.
    """
    # Somewhat unusual, but possibly related to the extreme variability in length of trees
    # Various experiments generally show about 0.5 F1 loss on various
    # datasets when using 'mean' instead of 'sum' for reduction
    # (Remember to adjust the weight decay when rerunning that experiment)
    if args['loss'] == 'cross':
        tlogger.info("Building CrossEntropyLoss(sum)")
        process_outputs = lambda x: x
        model_loss_function = nn.CrossEntropyLoss(reduction='sum')
    elif args['loss'] == 'focal':
        try:
            from focal_loss.focal_loss import FocalLoss
        except ImportError:
            raise ImportError("focal_loss not installed. Must `pip install focal_loss_torch` to use the --loss=focal feature")
        tlogger.info("Building FocalLoss, gamma=%f", args['loss_focal_gamma'])
        # FocalLoss is given softmax probabilities rather than raw logits
        process_outputs = lambda x: torch.softmax(x, dim=1)
        model_loss_function = FocalLoss(reduction='sum', gamma=args['loss_focal_gamma'])
    elif args['loss'] == 'large_margin':
        tlogger.info("Building LargeMarginInSoftmaxLoss(sum)")
        process_outputs = lambda x: x
        model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum')
    else:
        raise ValueError("Unexpected loss term: %s" % args['loss'])

    device = trainer.device
    model_loss_function.to(device)
    # transition -> its index in trainer.transitions, as a shape (1,) tensor
    transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)
                          for (y, x) in enumerate(trainer.transitions)}
    trainer.train()

    train_data = compose_train_data(train_trees, train_sequences)
    silver_data = compose_train_data(silver_trees, silver_sequences)

    # default epoch sizes: one pass over the gold data; silver matches it
    if not args['epoch_size']:
        args['epoch_size'] = len(train_data)
    if silver_data and not args['silver_epoch_size']:
        args['silver_epoch_size'] = args['epoch_size']

    if args['multistage']:
        # epoch -> (pattn_num_layers, uses_lattn) for the model rebuilt at that epoch
        multistage_splits = {}
        # if we're halfway, only do pattn. save lattn for next time
        multistage_splits[args['epochs'] // 2] = (args['pattn_num_layers'], False)
        if LSTMModel.uses_lattn(args):
            multistage_splits[args['epochs'] * 3 // 4] = (args['pattn_num_layers'], True)

    # TODO: refactor the oracle choice into the transition scheme?
    oracle = None
    if args['transition_scheme'] is TransitionScheme.IN_ORDER:
        oracle = InOrderOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])
    elif args['transition_scheme'] is TransitionScheme.IN_ORDER_COMPOUND:
        oracle = InOrderCompoundOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])
    elif args['transition_scheme'] is TransitionScheme.TOP_DOWN:
        oracle = TopDownOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])

    leftover_training_data = []
    leftover_silver_data = []
    if trainer.best_epoch > 0:
        tlogger.info("Restarting trainer with a model trained for %d epochs. Best epoch %d, f1 %f", trainer.epochs_trained, trainer.best_epoch, trainer.best_f1)

    # if we're training a new model, save the initial state so it can be inspected
    if args['save_each_start'] == 0 and trainer.epochs_trained == 0:
        trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=True)

    # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1
    # (the loop variable is the trainer attribute itself, so it is updated in place each epoch)
    for trainer.epochs_trained in range(trainer.epochs_trained+1, args['epochs']+1):
        trainer.train()
        tlogger.info("Starting epoch %d", trainer.epochs_trained)
        update_bert_learning_rate(args, trainer.optimizer, trainer.epochs_trained)
        if args['log_norms']:
            trainer.log_norms()
        leftover_training_data, epoch_data = next_epoch_data(leftover_training_data, train_data, args['epoch_size'])
        leftover_silver_data, epoch_silver_data = next_epoch_data(leftover_silver_data, silver_data, args['silver_epoch_size'])
        epoch_data = epoch_data + epoch_silver_data
        # sort by gold sequence length (x[1] is TrainItem.gold_sequence)
        epoch_data.sort(key=lambda x: len(x[1]))

        epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args)

        # print statistics
        # by now we've forgotten about the original tags on the trees,
        # but it doesn't matter for hill climbing
        f1, _, _ = run_dev_set(trainer.model, dev_trees, dev_trees, args, evaluator)
        if f1 > trainer.best_f1 or (trainer.best_epoch == 0 and trainer.best_f1 == 0.0):
            # best_epoch == 0 to force a save of an initial model
            # useful for tests which expect something, even when a
            # very simple model didn't learn anything
            tlogger.info("New best dev score: %.5f > %.5f", f1, trainer.best_f1)
            trainer.best_f1 = f1
            trainer.best_epoch = trainer.epochs_trained
            trainer.save(args['save_name'], save_optimizer=False)
        if epoch_stats.nans > 0:
            tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
        # TODO: refactor the logging?
        total_correct = sum(v for _, v in epoch_stats.transitions_correct.items())
        correct_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_correct[x]) for x in epoch_stats.transitions_correct])
        tlogger.info("Transitions correct: %d\n %s", total_correct, correct_transitions_str)
        total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items())
        incorrect_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_incorrect[x]) for x in epoch_stats.transitions_incorrect])
        tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, incorrect_transitions_str)
        if len(epoch_stats.repairs_used) > 0:
            tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common()))
        if epoch_stats.fake_transitions_used > 0:
            tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used)

        stats_log_lines = [
            "Epoch %d finished" % trainer.epochs_trained,
            "Transitions correct: %d" % total_correct,
            "Transitions incorrect: %d" % total_incorrect,
            "Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
            "Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
            "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
        ]
        tlogger.info("\n ".join(stats_log_lines))

        # the scheduler is stepped with the dev score; only group 0's lr is reported
        old_lr = trainer.optimizer.param_groups[0]['lr']
        trainer.scheduler.step(f1)
        new_lr = trainer.optimizer.param_groups[0]['lr']
        if old_lr != new_lr:
            tlogger.info("Updating learning rate from %f to %f", old_lr, new_lr)

        # `wandb` is the module global bound by train() when args['wandb'] is set
        if args['wandb']:
            wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
            if args['wandb_norm_regex']:
                watch_regex = re.compile(args['wandb_norm_regex'])
                for n, p in trainer.model.named_parameters():
                    if watch_regex.search(n):
                        wandb.log({n: torch.linalg.norm(p)})

        # turn off the dropouts once args['early_dropout'] epochs have been trained
        if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']:
            if any(x > 0.0 for x in (trainer.model.word_dropout.p, trainer.model.predict_dropout.p, trainer.model.lstm_input_dropout.p)):
                tlogger.info("Setting dropout to 0.0 at epoch %d", trainer.epochs_trained)
            trainer.model.word_dropout.p = 0
            trainer.model.predict_dropout.p = 0
            trainer.model.lstm_input_dropout.p = 0

        # recreate the optimizer and alter the model as needed if we hit a new multistage split
        if args['multistage'] and trainer.epochs_trained in multistage_splits:
            # we may be loading a save model from an earlier epoch if the scores stopped increasing
            # so remember the current epoch & batch counts before replacing the trainer
            epochs_trained = trainer.epochs_trained
            batches_trained = trainer.batches_trained

            stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained]

            # when loading the model, let the saved model determine whether it has pattn or lattn
            temp_args = copy.deepcopy(trainer.model.args)
            temp_args.pop('pattn_num_layers', None)
            temp_args.pop('lattn_d_proj', None)
            # overwriting the old trainer & model will hopefully free memory
            # load a new bert, even in PEFT mode, mostly so that the bert model
            # doesn't collect a whole bunch of PEFTs
            # for one thing, two PEFTs would mean 2x the optimizer parameters,
            # messing up saving and loading the optimizer without jumping
            # through more hoops
            # loading the trainer w/o the foundation_cache should create
            # the necessary bert_model and bert_tokenizer, and then we
            # can reuse those values when building out new LSTMModel
            trainer = Trainer.load(args['save_name'], temp_args, load_optimizer=False)
            model = trainer.model
            tlogger.info("Finished stage at epoch %d. Restarting optimizer", epochs_trained)
            tlogger.info("Previous best model was at epoch %d", trainer.epochs_trained)

            temp_args = dict(args)
            tlogger.info("Switching to a model with %d pattn layers and %slattn", stage_pattn_layers, "" if stage_uses_lattn else "NO ")
            temp_args['pattn_num_layers'] = stage_pattn_layers
            if not stage_uses_lattn:
                temp_args['lattn_d_proj'] = 0
            pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file'])
            forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])
            backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])
            new_model = LSTMModel(pt,
                                  forward_charlm,
                                  backward_charlm,
                                  model.bert_model,
                                  model.bert_tokenizer,
                                  model.force_bert_saved,
                                  model.peft_name,
                                  model.transitions,
                                  model.constituents,
                                  model.tags,
                                  model.delta_words,
                                  model.rare_words,
                                  model.root_labels,
                                  model.constituent_opens,
                                  model.unary_limit(),
                                  temp_args)
            new_model.to(device)
            # copy the weights of the reloaded best model into the new structure
            new_model.copy_with_new_structure(model)

            optimizer = build_optimizer(temp_args, new_model, False)
            scheduler = build_scheduler(temp_args, optimizer)
            # the new trainer keeps the epoch/batch counts of the current run,
            # but the best f1/epoch of the reloaded model
            trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, batches_trained, trainer.best_f1, trainer.best_epoch)
            add_grad_clipping(trainer, args['grad_clipping'])

        # checkpoint needs to be saved AFTER rebuilding the optimizer
        # so that assumptions about the optimizer in the checkpoint
        # can be made based on the end of the epoch
        if args['checkpoint'] and args['checkpoint_save_name']:
            trainer.save(args['checkpoint_save_name'], save_optimizer=True)
        # same with the "each filename", actually, in case those are
        # brought back for more training or even just for testing
        if args['save_each_start'] is not None and args['save_each_start'] <= trainer.epochs_trained and trainer.epochs_trained % args['save_each_frequency'] == 0:
            trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=args['save_each_optimizer'])

    return trainer
def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args):
    """
    Train on epoch_data once, in batches of args['train_batch_size'] taken in shuffled order.

    Each batch goes through train_model_one_batch and increments
    trainer.batches_trained.  The optimizer only steps for batches that
    produced no NaNs; gradients are cleared after every batch either way.
    Returns the EpochStats summed over all batches.
    """
    batch_starts = list(range(0, len(epoch_data), args['train_batch_size']))
    random.shuffle(batch_starts)
    optimizer = trainer.optimizer

    running_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)
    progress = tqdm(batch_starts, postfix="Epoch %d" % epoch)
    for batch_idx, start in enumerate(progress):
        end = start + args['train_batch_size']
        batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, epoch_data[start:end], transition_tensors, process_outputs, model_loss_function, oracle, args)
        trainer.batches_trained += 1

        # Early in training, some degenerate trees make the layers going up
        # the tree amplify the weights until they overflow.  That usually
        # resolves itself within a few iterations, so such batches are
        # simply not applied, and how often it happens is reported.
        if batch_stats.nans == 0:
            optimizer.step()
        # cleared unconditionally so a skipped batch's gradients never leak into the next one
        optimizer.zero_grad()
        running_stats += batch_stats
    return running_stats
def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args):
    """
    Train the model for one batch

    Every tree in training_batch is reparsed transition by transition.  The
    scores for each step are collected and compared against the gold
    transition, and backward() is called on the resulting loss.  The
    optimizer is NOT stepped here: the caller decides whether to step
    based on the returned nans value.

    Returns EpochStats(batch_loss, transitions_correct, transitions_incorrect,
                       repairs_used, fake_transitions_used, nans)
    where nans is 1 (and batch_loss is reported as 0.0) if the loss was NaN.

    It is unclear if this refactoring is useful in any way.  Might not be
    ... although the indentation does get pretty ridiculous if this is
    merged into train_model_one_epoch and then iterate_training
    """
    # now we add the state to the trees in the batch
    # the state is built as a bulk operation
    current_batch = model.initial_state_from_preterminals([x.preterminals for x in training_batch],
                                                          [x.tree for x in training_batch],
                                                          [x.gold_sequence for x in training_batch])

    transitions_correct = Counter()
    transitions_incorrect = Counter()
    repairs_used = Counter()
    fake_transitions_used = 0

    all_errors = []
    all_answers = []

    # we iterate through the batch in the following sequence:
    # predict the logits and the applied transition for each tree in the batch
    # collect errors
    #  - we always train to the desired one-hot vector
    #    this was a noticeable improvement over training just the
    #    incorrect transitions
    # determine whether the training can continue using the "student" transition
    #   or if we need to use teacher forcing
    # update all states using either the gold or predicted transition
    # any trees which are now finished are removed from the training cycle
    while len(current_batch) > 0:
        outputs, pred_transitions, _ = model.predict(current_batch, is_legal=False)
        gold_transitions = [x.gold_sequence[x.num_transitions] for x in current_batch]
        trans_tensor = [transition_tensors[gold_transition] for gold_transition in gold_transitions]
        all_errors.append(outputs)
        all_answers.extend(trans_tensor)

        new_batch = []
        update_transitions = []
        for pred_transition, gold_transition, state in zip(pred_transitions, gold_transitions, current_batch):
            # forget teacher forcing vs scheduled sampling
            # we're going with idiot forcing
            if pred_transition == gold_transition:
                transitions_correct[gold_transition.short_name()] += 1
                if state.num_transitions + 1 < len(state.gold_sequence):
                    if oracle is not None and epoch >= args['oracle_initial_epoch'] and random.random() < args['oracle_forced_errors']:
                        # TODO: could randomly choose from the legal transitions
                        # perhaps the second best scored transition
                        fake_transition = random.choice(model.transitions)
                        if fake_transition.is_legal(state, model):
                            _, new_sequence = oracle.fix_error(fake_transition, model, state)
                            if new_sequence is not None:
                                new_batch.append(state._replace(gold_sequence=new_sequence))
                                update_transitions.append(fake_transition)
                                fake_transitions_used = fake_transitions_used + 1
                                continue
                    new_batch.append(state)
                    update_transitions.append(gold_transition)
                # a correct final transition means the tree is finished,
                # so it is not carried into the next iteration
                continue

            transitions_incorrect[gold_transition.short_name(), pred_transition.short_name()] += 1
            # if we are on the final operation, there are two choices:
            #   - the parsing mode is IN_ORDER, and the final transition
            #     is the close to end the sequence, which has no alternatives
            #   - the parsing mode is something else, in which case
            #     we have no oracle anyway
            if state.num_transitions + 1 >= len(state.gold_sequence):
                continue

            if oracle is None or epoch < args['oracle_initial_epoch'] or not pred_transition.is_legal(state, model):
                new_batch.append(state)
                update_transitions.append(gold_transition)
                continue

            repair_type, new_sequence = oracle.fix_error(pred_transition, model, state)
            # we can only reach here on an error
            assert not repair_type.is_correct
            repairs_used[repair_type] += 1
            if new_sequence is not None and random.random() < args['oracle_frequency']:
                new_batch.append(state._replace(gold_sequence=new_sequence))
                update_transitions.append(pred_transition)
            else:
                new_batch.append(state)
                update_transitions.append(gold_transition)

        # bulk update states - significantly faster
        # (the while condition guarantees current_batch was non-empty here;
        # new_batch may be empty once every tree has finished)
        current_batch = model.bulk_apply(new_batch, update_transitions, fail=True)

    errors = torch.cat(all_errors)
    answers = torch.cat(all_answers)

    errors = process_outputs(errors)
    tree_loss = model_loss_function(errors, answers)
    tree_loss.backward()
    if args['watch_regex']:
        matched = False
        tlogger.info("Watching %s ... epoch %d batch %d", args['watch_regex'], epoch, batch_idx)
        watch_regex = re.compile(args['watch_regex'])
        # there is no trainer in scope here: the parameters to watch are
        # those of the model being trained in this batch
        for n, p in model.named_parameters():
            if watch_regex.search(n):
                matched = True
                if p.requires_grad and p.grad is not None:
                    tlogger.info(" %s norm: %f grad: %f", n, torch.linalg.norm(p), torch.linalg.norm(p.grad))
                elif p.requires_grad:
                    tlogger.info(" %s norm: %f grad required, but is None!", n, torch.linalg.norm(p))
                else:
                    tlogger.info(" %s norm: %f grad not required", n, torch.linalg.norm(p))
        if not matched:
            tlogger.info(" (none found!)")
    if torch.any(torch.isnan(tree_loss)):
        batch_loss = 0.0
        nans = 1
    else:
        batch_loss = tree_loss.item()
        nans = 0

    return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None, analyze_first_errors=False):
    """
    This reparses a treebank and executes the CoreNLP Java EvalB code.

    It only works if CoreNLP 4.3.0 or higher is in the classpath.

    model: the parser; it is switched to eval mode here
    retagged_trees: the trees to reparse
    original_trees: the reference trees, substituted as gold when every
      tree produced a result
    args: reads eval_file, eval_batch_size, transition_scheme, and
      optionally num_generate (extra sampled analyses via
      model.weighted_choice, used for kbest scoring).  When mode is
      'predict' and predict_file is set, predictions are also written to
      predict_dir using predict_format / predict_output_gold_tags
    evaluator: an existing evaluator to reuse; otherwise an
      EvaluateParser is created just for this call
    analyze_first_errors: log a first-error breakdown (IN_ORDER scheme only)

    Returns (f1, kbestF1, treeF1).  kbestF1 is None unless the evaluator
    reported one.  If no trees were parsed, returns (0.0, None, 0.0) so
    callers can always unpack three values.
    """
    tlogger.info("Processing %d trees from %s", len(retagged_trees), args['eval_file'])
    model.eval()

    num_generate = args.get('num_generate', 0)
    keep_scores = num_generate > 0

    # parse longest first, then restore the original order
    sorted_trees, original_indices = sort_with_indices(retagged_trees, key=len, reverse=True)
    tree_iterator = iter(tqdm(sorted_trees))
    treebank = model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.predict, keep_scores=keep_scores)
    treebank = unsort(treebank, original_indices)
    full_results = treebank

    if num_generate > 0:
        tlogger.info("Generating %d random analyses", args['num_generate'])
        # index 0 is the regular parse, index i+1 is sampled analysis i
        generated_treebanks = [treebank]
        for i in tqdm(range(num_generate)):
            tree_iterator = iter(tqdm(retagged_trees, leave=False, postfix="tb%03d" % i))
            generated_treebanks.append(model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.weighted_choice, keep_scores=keep_scores))

        # TODO: if the model is dropping trees, this will not work
        full_results = [ParseResult(parses[0].gold, [p.predictions[0] for p in parses], None, None)
                        for parses in zip(*generated_treebanks)]

    if len(full_results) < len(retagged_trees):
        tlogger.warning("Only evaluating %d trees instead of %d", len(full_results), len(retagged_trees))
    else:
        full_results = [x._replace(gold=gold) for x, gold in zip(full_results, original_trees)]

    if args.get('mode', None) == 'predict' and args['predict_file']:
        utils.ensure_dir(args['predict_dir'], verbose=False)
        pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".pred.mrg")
        orig_file = os.path.join(args['predict_dir'], args['predict_file'] + ".orig.mrg")
        if os.path.exists(pred_file):
            tlogger.warning("Cowardly refusing to overwrite {}".format(pred_file))
        elif os.path.exists(orig_file):
            tlogger.warning("Cowardly refusing to overwrite {}".format(orig_file))
        else:
            with open(pred_file, 'w') as fout:
                for tree in full_results:
                    output_tree = tree.predictions[0].tree
                    if args['predict_output_gold_tags']:
                        output_tree = output_tree.replace_tags(tree.gold)
                    fout.write(args['predict_format'].format(output_tree))
                    fout.write("\n")

            for i in range(num_generate):
                pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".%03d.pred.mrg" % i)
                with open(pred_file, 'w') as fout:
                    # file %03d holds sampled analysis i, matching the tb%03d
                    # label used while generating it
                    # NOTE(review): these results presumably still carry the
                    # retagged trees as gold, so predict_output_gold_tags
                    # would use the retagged tags here - confirm
                    for tree in generated_treebanks[i + 1]:
                        output_tree = tree.predictions[0].tree
                        if args['predict_output_gold_tags']:
                            output_tree = output_tree.replace_tags(tree.gold)
                        fout.write(args['predict_format'].format(output_tree))
                        fout.write("\n")

            with open(orig_file, 'w') as fout:
                for tree in full_results:
                    fout.write(args['predict_format'].format(tree.gold))
                    fout.write("\n")

    if len(full_results) == 0:
        # same arity as the normal return: (f1, kbestF1, treeF1)
        return 0.0, None, 0.0
    if evaluator is None:
        if num_generate > 0:
            kbest = max(len(fr.predictions) for fr in full_results)
        else:
            kbest = None
        with EvaluateParser(kbest=kbest) as evaluator:
            response = evaluator.process(full_results)
    else:
        response = evaluator.process(full_results)

    if analyze_first_errors and args['transition_scheme'] is TransitionScheme.IN_ORDER:
        errors = Counter()
        for result in full_results:
            first_error = error_analysis_in_order.analyze_tree(result.gold, result.predictions[0].tree)
            errors[first_error] += 1

        log_lines = ["%30s: %d" % (key.name, errors[key]) for key in error_analysis_in_order.FirstError]
        tlogger.info("First error frequency:\n %s", "\n ".join(log_lines))

    kbestF1 = response.kbestF1 if response.HasField("kbestF1") else None
    return response.f1, kbestF1, response.treeF1
================================================
FILE: stanza/models/constituency/partitioned_transformer.py
================================================
"""
Transformer with partitioned content and position features.
See section 3 of https://arxiv.org/pdf/1805.01052.pdf
"""
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
    """
    Autograd function which drops whole features rather than single values

    A single mask of shape (input.size(0), input.size(-1)) is sampled and
    broadcast over dimension 1, so every position of a batch element loses
    the same features.  Kept features are scaled by 1 / (1 - p).
    """
    @staticmethod
    def forward(ctx, input, p=0.5, train=False, inplace=False):
        """
        input: the tensor to apply dropout to.  The [:, None, :] broadcast
          of the mask below only lines up for a 3-D input
        p: probability of dropping each feature, must be in [0, 1]
        train: the mask is only applied when this is True
        inplace: if True, input itself is modified and returned

        Raises ValueError if p is outside [0, 1]
        """
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, but got {}".format(p)
            )

        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            # tell autograd that input is overwritten by this op
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if ctx.p > 0 and ctx.train:
            # one mask entry per (batch element, feature)
            ctx.noise = torch.empty(
                (input.size(0), input.size(-1)),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            if ctx.p == 1:
                # bernoulli_(0) then div_(0) would produce 0/0 = NaN, so zero directly
                ctx.noise.fill_(0)
            else:
                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
            # insert dimension 1 so the same mask covers every position
            ctx.noise = ctx.noise[:, None, :]
            output.mul_(ctx.noise)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # the gradient passes through the same mask used in forward;
        # p, train and inplace receive no gradient
        if ctx.p > 0 and ctx.train:
            return grad_output.mul(ctx.noise), None, None, None
        else:
            return grad_output, None, None, None
class FeatureDropout(nn.Dropout):
    """
    Feature-level dropout: for an input of shape (batch, len, num_features),
    each feature is dropped with probability p.  A dropped feature is
    dropped across the full portion of the input that corresponds to a
    single batch element.

    A (content, position) tuple is also accepted; each half gets its own
    mask and a tuple is returned.
    """
    def forward(self, x):
        def drop(tensor):
            return FeatureDropoutFunction.apply(tensor, self.p, self.training, self.inplace)

        if not isinstance(x, tuple):
            return drop(x)
        x_c, x_p = x
        # content first, then position, so masks are sampled in that order
        return drop(x_c), drop(x_p)
# TODO: this module apparently is not treated the same as the built-in
# nonlinearity modules, as multiple uses of the same relu on different
# tensors wind up mixing the gradients.  See if there is a way to
# resolve that other than creating a new nonlinearity for each layer
class PartitionedReLU(nn.ReLU):
    """
    ReLU applied separately to the content and position halves

    Accepts a (content, position) tuple, or a single tensor whose last
    dimension is split in two with torch.chunk.  Always returns a tuple.
    """
    def forward(self, x):
        halves = x if isinstance(x, tuple) else torch.chunk(x, 2, dim=-1)
        x_c, x_p = halves
        return super().forward(x_c), super().forward(x_p)
class PartitionedLinear(nn.Module):
    """
    Two independent linear layers: one for the content half of the
    features and one for the position half

    in_features and out_features are the combined sizes; each half-layer
    maps in_features // 2 inputs to out_features // 2 outputs.
    forward accepts a (content, position) tuple or a single tensor whose
    last dimension is split in two, and returns (out_c, out_p).
    """
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        half_in = in_features // 2
        half_out = out_features // 2
        # content layer is constructed before the position layer, which
        # fixes the order of the random initialization draws
        self.linear_c = nn.Linear(half_in, half_out, bias)
        self.linear_p = nn.Linear(half_in, half_out, bias)

    def forward(self, x):
        x_c, x_p = x if isinstance(x, tuple) else torch.chunk(x, 2, dim=-1)
        return self.linear_c(x_c), self.linear_p(x_p)
class PartitionedMultiHeadAttention(nn.Module):
    """
    Multi-head self attention with separate content and position weights

    The content and position halves are projected separately into q, k, v
    (d_qkv // 2 wide per head for each half).  The halves are then
    concatenated, so the attention scores combine both.  The attended
    values are split back into halves and each half is projected to
    d_model // 2 outputs.
    """
    def __init__(
        self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02
    ):
        super().__init__()

        # (n_head, d_model // 2, 3, d_qkv // 2): the size-3 axis holds q, k, v
        self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
        self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
        self.w_o_c = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))
        self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))

        # uniform on [-bound, bound] has standard deviation bound / sqrt(3),
        # so these weights start with standard deviation initializer_range
        bound = math.sqrt(3.0) * initializer_range
        for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]:
            nn.init.uniform_(param, -bound, bound)
        # dot product scaling 1 / sqrt(d_qkv)
        self.scaling_factor = 1 / d_qkv ** 0.5

        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, x, mask=None):
        """
        x: a (content, position) tuple, or a tensor whose last dimension is
          split in half.  The einsums treat the axes as (batch, time, feature)
        mask: optional boolean tensor indexed as mask[:, None, None, :];
          keys where it is False get -inf before the softmax

        Returns (out_c, out_p), one output per half
        """
        if isinstance(x, tuple):
            x_c, x_p = x
        else:
            x_c, x_p = torch.chunk(x, 2, dim=-1)
        # project each half into (batch, head, time, qkv, d_qkv // 2)
        qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c)
        qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p)
        q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)]
        q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)]
        # rejoin the halves so the scores use content and position together
        q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor
        k = torch.cat([k_c, k_p], dim=-1)
        v = torch.cat([v_c, v_p], dim=-1)
        dots = torch.einsum("bhqa,bhka->bhqk", q, k)
        if mask is not None:
            # fill through .data so the masking itself is not recorded by autograd
            dots.data.masked_fill_(~mask[:, None, None, :], -float("inf"))
        probs = F.softmax(dots, dim=-1)
        probs = self.dropout(probs)
        o = torch.einsum("bhqk,bhka->bhqa", probs, v)
        # split the attended values back into content and position halves
        o_c, o_p = torch.chunk(o, 2, dim=-1)
        out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c)
        out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p)
        return out_c, out_p
class PartitionedTransformerEncoderLayer(nn.Module):
    """
    One partitioned transformer layer: self attention, then a feed forward
    block.  Each sub-block's output goes through residual dropout, is added
    back to its input, and is normalized with LayerNorm.

    activation: the module applied between linear1 and linear2.  If None
      (the default), a new PartitionedReLU is created for this layer, so
      layers built with the default never share an activation instance.
    """
    def __init__(self,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 activation=None,
    ):
        super().__init__()
        self.self_attn = PartitionedMultiHeadAttention(
            d_model, n_head, d_qkv, attention_dropout=attention_dropout
        )
        self.linear1 = PartitionedLinear(d_model, d_ff)
        self.ff_dropout = FeatureDropout(ff_dropout)
        self.linear2 = PartitionedLinear(d_ff, d_model)

        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ff = nn.LayerNorm(d_model)

        self.residual_dropout_attn = FeatureDropout(residual_dropout)
        self.residual_dropout_ff = FeatureDropout(residual_dropout)

        # A module used as a default argument value is created once at
        # import time and shared by every layer relying on the default.
        # The TODO above PartitionedReLU reports that reusing one relu
        # mixes gradients, so build a fresh one per layer instead.
        if activation is None:
            activation = PartitionedReLU()
        self.activation = activation

    def forward(self, x, mask=None):
        """
        x: a tensor whose last dimension is d_model (content and position
          halves concatenated)
        mask: optional attention mask, passed through to self_attn
        """
        residual = self.self_attn(x, mask=mask)
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_attn(residual)
        x = self.norm_attn(x + residual)
        residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x))))
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_ff(residual)
        x = self.norm_ff(x + residual)
        return x
class PartitionedTransformerEncoder(nn.Module):
    """
    A stack of n_layers PartitionedTransformerEncoderLayer modules

    activation is a factory (by default the PartitionedReLU class).  It is
    called once per layer, so each layer gets its own activation module.
    """
    def __init__(self,
                 n_layers,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 activation=PartitionedReLU,
    ):
        super().__init__()
        stack = []
        for _ in range(n_layers):
            layer = PartitionedTransformerEncoderLayer(d_model=d_model,
                                                       n_head=n_head,
                                                       d_qkv=d_qkv,
                                                       d_ff=d_ff,
                                                       ff_dropout=ff_dropout,
                                                       residual_dropout=residual_dropout,
                                                       attention_dropout=attention_dropout,
                                                       activation=activation())
            stack.append(layer)
        self.layers = nn.ModuleList(stack)

    def forward(self, x, mask=None):
        # each layer consumes the previous layer's output
        output = x
        for layer in self.layers:
            output = layer(output, mask=mask)
        return output
class ConcatPositionalEncoding(nn.Module):
    """
    Learns a position embedding

    timing_table is a learned (max_len, d_model) parameter, initialized
    from a standard normal.  forward takes x of shape (batch, length, ...)
    and concatenates the first `length` rows of the table onto the last
    dimension of every batch element.
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        table = torch.FloatTensor(max_len, d_model)
        self.timing_table = nn.Parameter(table)
        nn.init.normal_(self.timing_table)

    def forward(self, x):
        batch_size, length = x.shape[0], x.shape[1]
        timing = self.timing_table[:length, :].expand(batch_size, -1, -1)
        return torch.cat([x, timing], dim=-1)
#
class PartitionedTransformerModule(nn.Module):
    """
    Partitioned transformer over per-word input embeddings

    Each word embedding is projected to d_model // 2 content features, and
    a d_model // 2 position encoding is concatenated onto it.  The result
    is normalized and run through a PartitionedTransformerEncoder.

    timing: 'sin' for a fixed sinusoidal encoding, 'learned' for a learned
      position table; anything else raises ValueError
    activation: NOTE(review): currently not forwarded; the encoder always
      builds its own default PartitionedReLU per layer
    """
    def __init__(self,
                 n_layers,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 word_input_size,
                 bias,
                 morpho_emb_dropout,
                 timing,
                 encoder_max_len,
                 activation=PartitionedReLU()
    ):
        super().__init__()
        self.project_pretrained = nn.Linear(
            word_input_size, d_model // 2, bias=bias
        )

        self.pattention_morpho_emb_dropout = FeatureDropout(morpho_emb_dropout)
        if timing == 'sin':
            self.add_timing = ConcatSinusoidalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
        elif timing == 'learned':
            self.add_timing = ConcatPositionalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
        else:
            raise ValueError("Unhandled timing type: %s" % timing)
        self.transformer_input_norm = nn.LayerNorm(d_model)
        self.pattn_encoder = PartitionedTransformerEncoder(
            n_layers,
            d_model=d_model,
            n_head=n_head,
            d_qkv=d_qkv,
            d_ff=d_ff,
            ff_dropout=ff_dropout,
            residual_dropout=residual_dropout,
            attention_dropout=attention_dropout,
        )

    def forward(self, attention_mask, bert_embeddings):
        """
        attention_mask: a boolean mask with True for real tokens, or None to
          build one from the lengths of the sentences in bert_embeddings
        bert_embeddings: a list of per-sentence tensors, each with one row
          per word and word_input_size columns

        Returns the encoder output for the padded batch
        """
        # Prepares attention mask for feeding into the self-attention
        device = bert_embeddings[0].device
        # a mask tensor with more than one element has no truth value,
        # so test for None explicitly
        if attention_mask is not None:
            valid_token_mask = attention_mask
        else:
            valids = []
            for sent in bert_embeddings:
                valids.append(torch.ones(len(sent), device=device))

            padded_data = torch.nn.utils.rnn.pad_sequence(
                valids,
                batch_first=True,
                padding_value=-100
            )

            valid_token_mask = padded_data != -100

        valid_token_mask = valid_token_mask.to(device=device)
        padded_embeddings = torch.nn.utils.rnn.pad_sequence(
            bert_embeddings,
            batch_first=True,
            padding_value=0
        )

        # Project the pretrained embedding onto the desired dimension
        extra_content_annotations = self.project_pretrained(padded_embeddings)

        # Add positional information through the table
        encoder_in = self.add_timing(self.pattention_morpho_emb_dropout(extra_content_annotations))
        encoder_in = self.transformer_input_norm(encoder_in)
        # Put the partitioned input through the partitioned attention
        annotations = self.pattn_encoder(encoder_in, valid_token_mask)

        return annotations
================================================
FILE: stanza/models/constituency/positional_encoding.py
================================================
"""
Based on
https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
"""
import math
import torch
from torch import nn
class SinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position

    The table lives in the `pe` buffer with shape (max_len, model_dim).
    forward grows the table on demand if asked for a position past its end.
    """
    def __init__(self, model_dim, max_len):
        super().__init__()
        self.register_buffer('pe', self.build_position(model_dim, max_len))

    @staticmethod
    def build_position(model_dim, max_len, device=None):
        """
        Build a (max_len, model_dim) table of sinusoidal encodings

        Even columns hold sin(pos * freq) and odd columns cos(pos * freq).
        An odd model_dim is allowed: it simply has one more sin column
        than cos columns.
        """
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(max_len, model_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # with an odd model_dim there are ceil(d/2) frequencies but only
        # floor(d/2) cos columns, so trim the extra one
        pe[:, 1::2] = torch.cos(position * div_term)[:, :model_dim // 2]
        if device is not None:
            pe = pe.to(device=device)
        return pe

    def forward(self, x):
        """
        Look up the encodings for the integer positions in x

        Returns self.pe[x], so the result has shape x.shape + (model_dim,)
        If any position is past the end of the current table, the table is
        rebuilt just large enough to cover it.  An empty x is allowed.
        """
        # x.max() avoids iterating the tensor element by element in Python,
        # and works for any number of dimensions; an empty x has no max
        if x.numel() > 0:
            needed = int(x.max()) + 1
            if needed > self.pe.shape[0]:
                # try to drop the reference first before creating a new encoding
                # the goal being to save memory if we are close to the memory limit
                device = self.pe.device
                model_dim = self.pe.shape[1]
                self.register_buffer('pe', None)
                # TODO: this may result in very poor performance
                # in the event of a model that increases size one at a time
                self.register_buffer('pe', self.build_position(model_dim, needed, device=device))
        return self.pe[x]

    def max_len(self):
        return self.pe.shape[0]
class AddSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position. Adds the position to the given matrix

    Default behavior is batch_first
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x, scale=1.0):
        """
        Adds the positional encoding to the input tensor

        The tensor is expected to be of the shape B, N, D,
        or N, D for a single unbatched sequence.
        The encoding is multiplied by scale before being added.

        Properly masking the output tensor is up to the caller

        Raises ValueError for any other number of dimensions
        """
        if len(x.shape) == 3:
            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
            timing = timing.expand(x.shape[0], -1, -1)
        elif len(x.shape) == 2:
            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
        else:
            # otherwise `timing` would be unbound and the failure would be
            # an unhelpful UnboundLocalError
            raise ValueError("AddSinusoidalEncoding expects a 2D or 3D tensor, got shape {}".format(tuple(x.shape)))
        return x + timing * scale
class ConcatSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position. Concats the position and returns a larger object

    Default behavior is batch_first: a 3-D input is (batch, length, features)
    and every batch element receives the same encodings; any other input is
    treated as (length, features).
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x):
        batched = len(x.shape) == 3
        length = x.shape[1] if batched else x.shape[0]
        timing = self.encoding(torch.arange(length, device=x.device))
        if batched:
            timing = timing.expand(x.shape[0], -1, -1)
        return torch.cat((x, timing), dim=-1)
================================================
FILE: stanza/models/constituency/retagging.py
================================================
"""
Refactor a few functions specifically for retagging trees
Retagging is important because the gold tags will not be available at runtime
Note that the method which does the actual retagging is in utils.py
so as to avoid unnecessary circular imports
(eg, Pipeline imports constituency/trainer which imports this which imports Pipeline)
"""
import copy
import logging
from stanza import Pipeline
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.resources.common import download_resources_json, load_resources_json, get_language_resources
tlogger = logging.getLogger('stanza.constituency.trainer')

# Per-language default for which tag type to retag with.
# postprocess_args uses this when --retag_method is not given;
# languages not listed here fall back to xpos.
#
# xpos tagger doesn't produce PP tag on the turin treebank,
# so instead we use upos to avoid unknown tag errors
RETAG_METHOD = {
    "da": "upos",   # the DDT has no xpos tags anyway
    "de": "upos",   # DE GSD is also missing a few punctuation tags
    "es": "upos",   # AnCora has half-finished xpos tags
    "id": "upos",   # GSD is missing a few punctuation tags - fixed in 2.12, though
    "it": "upos",
    "pt": "upos",   # default PT model has no xpos either
    "vi": "xpos",   # the new version of UD can be merged with xpos from VLSP22
}
def add_retag_args(parser):
    """
    Arguments specifically for retagging treebanks

    Registers the retagging flags on an argparse parser.  retag_package
    defaults to "default"; --no_retag stores None into retag_package,
    which turns retagging off.
    """
    parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees.  None for no retagging.  Retagging is recommended, as gold tags will not be available at pipeline time')
    parser.add_argument('--retag_method', default=None, choices=['xpos', 'upos'], help='Which tags to use when retagging.  Default depends on the language')
    parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use.  Will use a downloaded Stanza model by default.  Can specify multiple taggers with ; in which case the majority vote wins')

    # optional resources for a custom POS model, all defaulting to None
    resource_flags = (
        ('--retag_pretrain_path', 'Use this for a pretrain path for the retagging pipeline.  Generally not needed unless using a custom POS model with a custom pretrain'),
        ('--retag_charlm_forward_file', 'Use this for a forward charlm path for the retagging pipeline.  Generally not needed unless using a custom POS model with a custom charlm'),
        ('--retag_charlm_backward_file', 'Use this for a backward charlm path for the retagging pipeline.  Generally not needed unless using a custom POS model with a custom charlm'),
    )
    for flag, help_text in resource_flags:
        parser.add_argument(flag, default=None, help=help_text)

    parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees")
def postprocess_args(args):
    """
    After parsing args, unify some settings

    args is a dict and is modified in place:
      - if retag_method is None, it is filled from RETAG_METHOD for a known
        args['lang'], otherwise set to 'xpos'
      - retag_xpos is set to True for 'xpos' and False for 'upos'

    Raises ValueError for any other retag_method
    """
    # use a language specific default for retag_method if we know the language
    # otherwise, use xpos
    if args['retag_method'] is None and 'lang' in args and args['lang'] in RETAG_METHOD:
        args['retag_method'] = RETAG_METHOD[args['lang']]
    if args['retag_method'] is None:
        args['retag_method'] = 'xpos'

    if args['retag_method'] == 'xpos':
        args['retag_xpos'] = True
    elif args['retag_method'] == 'upos':
        args['retag_xpos'] = False
    else:
        # report the offending value (formatting an undefined name here
        # would turn this into a NameError)
        raise ValueError("Unknown retag method {}".format(args['retag_method']))
def build_retag_pipeline(args):
    """
    Builds retag pipelines based on the arguments

    May alter the arguments if the pipeline is incompatible, such as
    taggers with no xpos: if a tagger's xpos vocab is empty while
    retag_xpos is set, args['retag_xpos'] becomes False and
    args['retag_method'] becomes 'upos'

    Will return a list of one or more retag pipelines.
    Multiple tagger models can be specified by having them
    semi-colon separated in retag_model_path.

    Returns None if retag_package is None or mode is 'remove_optimizer'.
    Calls download_resources_json() before building anything.

    Raises ValueError if retag_package contains no '_' and args has no 'lang'
    """
    # some argument sets might not use 'mode'
    if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer':
        download_resources_json()
        resources = load_resources_json()

        if '_' in args['retag_package']:
            # first interpretation: retag_package is <lang>_<package>
            lang, package = args['retag_package'].split('_', 1)
            lang_resources = get_language_resources(resources, lang)
            if lang_resources is None and 'lang' in args:
                # the prefix is not a known language, so the underscore is
                # presumably part of the package name itself: check whether
                # args['lang'] has a pos package with the full name
                lang_resources = get_language_resources(resources, args['lang'])
                if lang_resources is not None and 'pos' in lang_resources and args['retag_package'] in lang_resources['pos']:
                    lang = args['lang']
                    package = args['retag_package']
        else:
            if 'lang' not in args:
                raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package'])
            lang = args.get('lang', None)
            package = args['retag_package']
        # shared by every pipeline built below
        foundation_cache = FoundationCache()
        retag_args = {"lang": lang,
                      "processors": "tokenize, pos",
                      "tokenize_pretokenized": True,
                      "package": {"pos": package}}
        if args['retag_pretrain_path'] is not None:
            retag_args['pos_pretrain_path'] = args['retag_pretrain_path']
        if args['retag_charlm_forward_file'] is not None:
            retag_args['pos_forward_charlm_path'] = args['retag_charlm_forward_file']
        if args['retag_charlm_backward_file'] is not None:
            retag_args['pos_backward_charlm_path'] = args['retag_charlm_backward_file']

        def build(retag_args, path):
            # builds one pipeline; path, if given, overrides the pos model
            # (deep copy so the shared retag_args dict is not modified)
            retag_args = copy.deepcopy(retag_args)
            # we just downloaded the resources a moment ago
            # no need to repeatedly download
            retag_args['download_method'] = 'reuse_resources'
            if path is not None:
                retag_args['allow_unknown_language'] = True
                retag_args['pos_model_path'] = path
                tlogger.debug('Creating retag pipeline using %s', path)
            else:
                tlogger.debug('Creating retag pipeline for %s package', package)

            retag_pipeline = Pipeline(foundation_cache=foundation_cache, **retag_args)
            # a vocab containing only the special prefix entries means the
            # tagger has no real xpos tags; modifies args as a side effect
            if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):
                tlogger.warning("XPOS for the %s tagger is empty.  Switching to UPOS", package)
                args['retag_xpos'] = False
                args['retag_method'] = 'upos'
            return retag_pipeline

        if args['retag_model_path'] is None:
            return [build(retag_args, None)]
        paths = args['retag_model_path'].split(";")
        # can be length 1 if only one tagger to work with
        return [build(retag_args, path) for path in paths]

    return None
================================================
FILE: stanza/models/constituency/score_converted_dependencies.py
================================================
"""
Script which processes a dependency file by using the constituency parser, then converting with the CoreNLP converter
Currently this does not have the constituency parser as an option,
although that is easy to add.
Only English is supported, as only English is available in the CoreNLP converter
"""
import argparse
import os
import tempfile
import stanza
from stanza.models.constituency import retagging
from stanza.models.depparse import scorer
from stanza.utils.conll import CoNLL
def score_converted_dependencies(args):
    """
    Parse args['eval_file'] with a pipeline whose depparse step is the
    'converter' package applied to the constituency parse, then report:
      - the per-relation breakdown from scorer.score_named_dependencies
      - the overall score from scorer.score, printed as a percentage

    Only args['lang'] == 'en' is supported; anything else raises ValueError
    """
    lang = args['lang']
    if lang != 'en':
        raise ValueError("Converting and scoring dependencies is currently only supported for English")

    constituency_package = args['constituency_package']
    packages = {'pos': args['retag_package'], 'depparse': 'converter', 'constituency': constituency_package}
    pipeline = stanza.Pipeline(lang=lang,
                               tokenize_pretokenized=True,
                               package=packages,
                               processors='tokenize, pos, constituency, depparse')

    eval_file = args['eval_file']
    parsed_doc = pipeline(CoNLL.conll2doc(eval_file))
    print("Processed %d sentences" % len(parsed_doc.sentences))

    # read the file again: running the pipeline overwrote the gold
    # annotations in the first copy
    gold_doc = CoNLL.conll2doc(eval_file)
    scorer.score_named_dependencies(parsed_doc, gold_doc)

    with tempfile.TemporaryDirectory() as workdir:
        converted_path = os.path.join(workdir, "converted.conll")
        CoNLL.write_doc2conll(parsed_doc, converted_path)
        _, _, score = scorer.score(converted_path, eval_file)

        print("Parser score:")
        print("{} {:.2f}".format(constituency_package, score*100))
def main():
    """
    Command line entry point: parse flags, fill in the retagging
    defaults, then convert and score the eval file
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lang', default='en', type=str, help='Language')
    parser.add_argument('--eval_file', default="extern_data/ud2/ud-treebanks-v2.13/UD_English-EWT/en_ewt-ud-test.conllu", help='Input file for data loader.')
    parser.add_argument('--constituency_package', default="ptb3-revised_electra-large", help='Which constituency parser to use for converting')
    retagging.add_retag_args(parser)

    args = vars(parser.parse_args())
    retagging.postprocess_args(args)

    score_converted_dependencies(args)

if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/constituency/state.py
================================================
from collections import namedtuple
class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence',
                                 'sentence_length', 'num_opens', 'word_position', 'score', 'broken'])):
    """
    Represents a partially completed transition parse

    transitions and constituents are linked stacks: each node has a
    `parent` and a `length`, and the bottom node is a sentinel with no
    value and no parent.  Lengths therefore count the sentinel too.

    word_queue holds the words with an extra entry at each end (those
    endpoints can be None when unused).  Rather than consuming the queue,
    word_position records how far into it the parse has gotten.

    num_opens is useful for tracking
      1) if the parser is in a stuck state where it is making infinite opens
      2) if a close transition is impossible because there are no previous opens
    sentence_length tracks how long the sentence is, both to detect an
    exhausted word queue and so runaway parses can be aborted.

    gold_tree and gold_sequence are only known when reparsing gold data
    at training time (gold_tree may be None); at runtime they are absent.
    """
    def empty_word_queue(self):
        # every word has been consumed once the position reaches the length
        return self.word_position == self.sentence_length

    def empty_transitions(self):
        # only the sentinel remains when the top node has no parent
        return self.transitions.parent is None

    def has_one_constituent(self):
        # the length includes the sentinel, so one real constituent means 2
        return self.constituents.length == 2

    @property
    def empty_constituents(self):
        return self.constituents.parent is None

    def num_constituents(self):
        # discount the sentinel
        return self.constituents.length - 1

    @property
    def num_transitions(self):
        # discount the sentinel
        return self.transitions.length - 1

    def get_word(self, pos):
        # word_queue[0] is the leading endpoint, so shift by one
        # (pos=-1 therefore returns that endpoint)
        return self.word_queue[pos + 1]

    def finished(self, model):
        if not self.empty_word_queue():
            return False
        if not self.has_one_constituent():
            return False
        return model.get_top_constituent(self.constituents).label in model.root_labels

    def get_tree(self, model):
        return model.get_top_constituent(self.constituents)

    @staticmethod
    def _unwind(node, extract):
        # walk from the top down to (not including) the sentinel, then
        # flip so the result runs from the bottom of the stack to the top
        items = []
        while node.parent is not None:
            items.append(extract(node))
            node = node.parent
        items.reverse()
        return items

    def all_transitions(self, model):
        return self._unwind(self.transitions, lambda node: model.get_top_transition(node))

    def all_constituents(self, model):
        return self._unwind(self.constituents, lambda node: model.get_top_constituent(node))

    def all_words(self, model):
        return [model.get_word(word) for word in self.word_queue]

    def to_string(self, model):
        fields = (str(self.all_words(model)),
                  str(self.all_transitions(model)),
                  str(self.all_constituents(model)),
                  self.word_position,
                  self.num_opens)
        return "State(\n buffer:%s\n transitions:%s\n constituents:%s\n word_position:%d num_opens:%d)" % fields

    def __str__(self):
        fields = (str(self.word_queue), str(self.transitions), str(self.constituents))
        return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % fields
class MultiState(namedtuple('MultiState', ['states', 'gold_tree', 'gold_sequence', 'score'])):
    """
    Parse state for an ensemble of models.

    states holds one per-model state.  Structural queries (finished,
    counts, queue status, the stacks themselves) are answered from the
    first state, states[0]; presumably the member states advance in
    lockstep - verify against the ensemble code.
    """
    def finished(self, ensemble):
        return self.states[0].finished(ensemble.models[0])

    def get_tree(self, ensemble):
        return self.states[0].get_tree(ensemble.models[0])

    @property
    def empty_constituents(self):
        return self.states[0].empty_constituents

    def num_constituents(self):
        # delegate so the sentinel adjustment is defined in one place
        # (State.num_constituents) rather than re-derived via len()
        return self.states[0].num_constituents()

    @property
    def num_transitions(self):
        # delegate so the sentinel adjustment is defined in one place
        # (State.num_transitions)
        return self.states[0].num_transitions

    @property
    def num_opens(self):
        return self.states[0].num_opens

    @property
    def sentence_length(self):
        return self.states[0].sentence_length

    def empty_word_queue(self):
        return self.states[0].empty_word_queue()

    def empty_transitions(self):
        return self.states[0].empty_transitions()

    @property
    def constituents(self):
        # warning! if there is information in the constituents such as
        # the embedding of the constituent, this will only contain the
        # first such embedding
        # the other models' constituent states won't be returned
        return self.states[0].constituents

    @property
    def transitions(self):
        # warning! if there is information in the transitions such as
        # the embedding of the transition, this will only contain the
        # first such embedding
        # the other models' transition states won't be returned
        return self.states[0].transitions
================================================
FILE: stanza/models/constituency/text_processing.py
================================================
import os
import logging
from stanza.models.common import utils
from stanza.models.constituency.utils import retag_tags
from stanza.models.constituency.trainer import Trainer
from stanza.models.constituency.tree_reader import read_trees
from stanza.utils.get_tqdm import get_tqdm
# the package-level 'stanza' logger (the same name trainer.py uses)
logger = logging.getLogger('stanza')
# progress-bar factory; get_tqdm presumably selects the tqdm variant suited
# to the current environment - confirm in stanza.utils.get_tqdm
tqdm = get_tqdm()
def read_tokenized_file(tokenized_file):
    """
    Load pre-tokenized sentences, one sentence per line, from tokenized_file.

    Lines are stripped and blank lines are ignored.  Tokens are split on
    whitespace; inside a token, _ is turned back into a space (for
    languages such as VI which join syllables with _), except that a
    token consisting only of underscores is kept unchanged.

    Returns (docs, ids): docs is a list of token lists and ids holds one
    None per sentence, since plain text carries no tree ids.
    """
    def restore_spaces(token):
        # a token such as "___" is literal text, not joined syllables
        if all(ch == '_' for ch in token):
            return token
        return token.replace("_", " ")

    docs = []
    with open(tokenized_file, encoding='utf-8') as fin:
        for raw_line in fin:
            sentence = raw_line.strip()
            if not sentence:
                continue
            docs.append([restore_spaces(token) for token in sentence.split()])
    ids = [None for _ in docs]
    return docs, ids
def read_xml_tree_file(tree_file):
"""
Read sentences from a file of the format unique to VLSP test sets
in particular, it should be multiple blocks of
(tree ...)
"""
with open(tree_file, encoding='utf-8') as fin:
lines = fin.readlines()
lines = [x.strip() for x in lines]
lines = [x for x in lines if x]
docs = []
ids = []
tree_id = None
tree_text = []
for line in lines:
if line.startswith(" 1:
tree_id = tree_id[1]
if tree_id.endswith(">"):
tree_id = tree_id[:-1]
tree_id = int(tree_id)
else:
tree_id = None
elif line.startswith(" = len(gold_sequence):
raise AssertionError("Found a sequence of OpenConstituent at the end of a TOP_DOWN sequence!")
if not isinstance(gold_sequence[shift_index], Shift):
raise AssertionError("Expected to find a Shift after a sequence of OpenConstituent. There should not be a %s" % gold_sequence[shift_index])
#print("Input sequence: %s\nIndex %d\nGold %s Pred %s" % (gold_sequence, gold_index, gold_transition, pred_transition))
updated_sequence = gold_sequence
while shift_index > gold_index:
close_index = advance_past_constituents(updated_sequence, shift_index)
if close_index is None:
raise AssertionError("Did not find a corresponding Close for this Open")
# cut out the corresponding open and close
updated_sequence = updated_sequence[:shift_index-1] + updated_sequence[shift_index:close_index] + updated_sequence[close_index+1:]
shift_index -= 1
#print(" %s" % updated_sequence)
#print("Final updated sequence: %s" % updated_sequence)
return updated_sequence
def fix_nested_open_constituent(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair skipping one level of nesting: gold wanted Open(X) then
    Open(Y), but the parser predicted Open(Y) straight away.

    This is treated as a single recall error: Open(X) and its matching
    Close are removed, leaving Y's subtree where it was.

    A more aggressive alternative would rebuild X as a unary, such as
    Open(Y), Open(X), Open(Y)..., but that would likely confuse the
    parser, and it is ambiguous where the new constituent should close.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    both_opens = isinstance(pred_transition, OpenConstituent) and isinstance(gold_transition, OpenConstituent)
    if not both_opens:
        return None
    assert len(gold_sequence) > gold_index + 1
    inner_index = gold_index + 1
    inner = gold_sequence[inner_index]
    if not isinstance(inner, OpenConstituent):
        return None
    # only handles skipping exactly one level
    if inner.label != pred_transition.label:
        return None
    # Close of the constituent enclosing Open(Y), ie the Close for Open(X)
    outer_close = advance_past_constituents(gold_sequence, inner_index)
    assert outer_close is not None
    return gold_sequence[:gold_index] + gold_sequence[inner_index:outer_close] + gold_sequence[outer_close+1:]
def fix_shift_open_immediate_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift that was mispredicted as an Open, in the one case
    where the extra Close has an unambiguous position.

    In general the Close for the unwanted Open could go immediately
    after the Shift, immediately before the next Close, or anywhere in
    between.  When the gold sequence continues Shift - Close, though,
    the only possible repair is Open - Shift - Close - Close.

    Returns the repaired sequence, or None if this case does not apply.
    """
    if not (isinstance(pred_transition, OpenConstituent) and isinstance(gold_transition, Shift)):
        return None
    assert len(gold_sequence) > gold_index + 1
    following = gold_sequence[gold_index+1]
    if not isinstance(following, CloseConstituent):
        # the ambiguous case, left to the shift/open ambiguous repairs
        return None
    repair = [pred_transition, gold_transition, CloseConstituent()]
    return gold_sequence[:gold_index] + repair + gold_sequence[gold_index+1:]
def fix_shift_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift mispredicted as an Open in the ambiguous case (the
    Shift is not followed by a Close) by treating the Open as a unary:
    the new constituent wraps only the shifted word.

    Returns the repaired sequence, or None if this case does not apply.
    """
    if not (isinstance(pred_transition, OpenConstituent) and isinstance(gold_transition, Shift)):
        return None
    assert len(gold_sequence) > gold_index + 1
    following = gold_sequence[gold_index+1]
    if isinstance(following, CloseConstituent):
        # the unambiguous case, which should already be handled
        return None
    repair = [pred_transition, gold_transition, CloseConstituent()]
    return gold_sequence[:gold_index] + repair + gold_sequence[gold_index+1:]
def fix_shift_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift mispredicted as an Open in the ambiguous case (the
    Shift is not followed by a Close) by closing the new constituent at
    the end of the enclosing bracket, so it wraps everything from the
    Shift up to that bracket's Close.

    Returns the repaired sequence, or None if this case does not apply.
    """
    if not (isinstance(pred_transition, OpenConstituent) and isinstance(gold_transition, Shift)):
        return None
    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # the unambiguous case, which should already be handled
        return None
    # Close of the constituent which encloses the Shift
    enclosing_close = advance_past_constituents(gold_sequence, gold_index)
    before = gold_sequence[:gold_index]
    wrapped = gold_sequence[gold_index:enclosing_close]
    after = gold_sequence[enclosing_close:]
    return before + [pred_transition] + wrapped + [CloseConstituent()] + after
def fix_shift_open_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift mispredicted as an Open in the ambiguous case by
    letting the model choose where the unwanted constituent closes.

    One candidate is built for each possible Close position: after the
    Shift itself, and after each following sibling (a single Shift or a
    whole constituent), up to the Close of the enclosing constituent.

    Returns None if the repair does not apply.  Otherwise it returns a
    tuple (RepairEnum, repaired sequence), where the RepairEnum value is
    "<SHIFT_OPEN_AMBIGUOUS_PREDICTED value>.<chosen candidate index>"
    and the last candidate is reported as -1.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, Shift):
        return None
    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # this is the unambiguous case, which should already be handled
        return None
    # at this point: have Opened a constituent which we don't want
    # need to figure out where to Close it
    # could close it after the shift or after any given block
    # each candidate is a 5-tuple of blocks:
    #   (prefix, [pred Open], gold items up to end_index, [Close], rest)
    candidates = []
    current_index = gold_index
    while not isinstance(gold_sequence[current_index], CloseConstituent):
        if isinstance(gold_sequence[current_index], Shift):
            end_index = current_index
        else:
            # a whole constituent: jump to its Close
            end_index = find_constituent_end(gold_sequence, current_index)
        candidates.append((gold_sequence[:gold_index], [pred_transition], gold_sequence[gold_index:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))
        current_index = end_index + 1
    # NOTE(review): candidate_idx=3 is the [Close] block in these 5-tuples,
    # but fix_close_next_correct_predicted passes candidate_idx=3 for
    # 4-tuples, where index 3 is the trailing block - confirm which
    # block score_candidates_single_block is meant to score
    # (the returned scores are not used here)
    scores, best_idx, best_candidate = score_candidates_single_block(model, state, candidates, candidate_idx=3)
    # report "close at the last possible position" as -1, independent of
    # how many candidates there were
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
def fix_close_shift_ambiguous_immediate(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    A Close was mispredicted as a Shift, in the ambiguous case.  Here the
    missed Close(s) are placed immediately after the predicted Shift,
    regardless of what comes after that Shift.

    The alternative strategy, closing at the end of the outer
    constituent, is fix_close_shift_ambiguous_later.
    """
    if not isinstance(pred_transition, Shift):
        return None
    if not isinstance(gold_transition, CloseConstituent):
        return None
    # skip over the run of Closes which starts with the missed one
    shift_index = gold_index
    while isinstance(gold_sequence[shift_index], CloseConstituent):
        shift_index += 1
    if not isinstance(gold_sequence[shift_index], Shift):
        # TODO: we should be able to handle this case too (an Open)
        # however, it will be rare once the parser gets going and it
        # would cause a lot of errors, anyway
        return None
    if isinstance(gold_sequence[shift_index + 1], CloseConstituent):
        # this one should just have been satisfied in the non-ambiguous version
        return None
    missed_closes = gold_sequence[gold_index:shift_index]
    return gold_sequence[:gold_index] + [pred_transition] + missed_closes + gold_sequence[shift_index+1:]
def fix_close_shift_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    A Close was mispredicted as a Shift, in the ambiguous case.  Here the
    missed Close(s) are moved to the end of the outer bracket, regardless
    of what comes after the Shift.

    The alternative strategy, closing right after the Shift, is
    fix_close_shift_ambiguous_immediate.
    """
    if not isinstance(pred_transition, Shift):
        return None
    if not isinstance(gold_transition, CloseConstituent):
        return None
    # skip over the run of Closes which starts with the missed one
    shift_index = gold_index
    while isinstance(gold_sequence[shift_index], CloseConstituent):
        shift_index += 1
    if not isinstance(gold_sequence[shift_index], Shift):
        # TODO: we should be able to handle this case too (an Open)
        # however, it will be rare once the parser gets going and it
        # would cause a lot of errors, anyway
        return None
    if isinstance(gold_sequence[shift_index + 1], CloseConstituent):
        # this one should just have been satisfied in the non-ambiguous version
        return None
    # where the constituent containing the broken constituent(s) gets closed
    outer_close = advance_past_constituents(gold_sequence, shift_index)
    moved_forward = gold_sequence[shift_index:outer_close]
    missed_closes = gold_sequence[gold_index:shift_index]
    return gold_sequence[:gold_index] + moved_forward + missed_closes + gold_sequence[outer_close:]
def fix_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, count_opens=False):
    """
    We were supposed to Close, but instead did a Shift

    In most cases, this will be ambiguous.  There is now a constituent
    which has been missed, no matter what we do, and we are on the
    hook for eventually closing this constituent, creating a precision
    error as well.  The ambiguity arises because there will be
    multiple places where the Close could occur if there are more
    constituents created between now and when the outer constituent is
    Closed.

    The non-ambiguous case is if the proper sequence was
      Close - Shift - Close
    similar cases are also non-ambiguous, such as
      Close - Close - Shift - Close
    for that matter, so is the following, although the Opens will be lost
      Close - Open - Shift - Close - Close

    That is: after the Shift there must be one Close for each skipped
    Open, followed by the Close of the enclosing constituent.

    count_opens is an option to make it easy to count with or without
    Open as different oracle fixes

    Returns the repaired sequence, or None if this case does not apply.
    Raises AssertionError if count_opens is set and no Shift follows the
    run of Closes and Opens.
    """
    if not isinstance(pred_transition, Shift):
        return None
    if not isinstance(gold_transition, CloseConstituent):
        return None

    num_closes = 0
    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):
        num_closes += 1

    # We may allow unary transitions here
    # the opens will be lost in the repaired sequence
    num_opens = 0
    if count_opens:
        while isinstance(gold_sequence[gold_index + num_closes + num_opens], OpenConstituent):
            num_opens += 1

    shift_index = gold_index + num_closes + num_opens
    if not isinstance(gold_sequence[shift_index], Shift):
        if count_opens:
            raise AssertionError("Should have found a Shift after a sequence of Opens or a Close with no Open.  Started counting at %d in sequence %s" % (gold_index, gold_sequence))
        return None

    # The Shift must be followed by num_opens Closes (one per skipped
    # Open) plus the Close of the enclosing constituent.  Checking only
    # num_opens positions would leave the enclosing Close unverified
    # whenever num_opens > 0, letting ambiguous cases through.
    for offset in range(1, num_opens + 2):
        if not isinstance(gold_sequence[shift_index + offset], CloseConstituent):
            return None

    # Now we know it is Close x num_closes, (Open x num_opens), Shift,
    # Close x num_opens, Close.  Since we have erroneously predicted a
    # Shift now, the best we can do is to follow that, then add
    # num_closes Closes, then continue from the enclosing Close
    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[shift_index+num_opens+1:]
    return updated_sequence
def fix_close_shift_with_opens(*args, **kwargs):
    """
    fix_close_shift with count_opens=True: Opens between the missed
    Close(s) and the next Shift are allowed, and are dropped from the
    repaired sequence.
    """
    return fix_close_shift(*args, count_opens=True, **kwargs)
def fix_close_next_correct_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Close, but instead predicted Shift when the next transition is Shift

    This differs from the previous Close-Shift in that this case does
    not have an unambiguous place to put the Close.  Instead, we let
    the model predict where to put the Close

    Note that this can also work for Close-Open with the next Open correct

    Not covered (yet?) is multiple Close in a row

    Returns None if the repair does not apply.  Otherwise it returns a
    tuple (RepairEnum, repaired sequence), where the RepairEnum value is
    "<CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED value>.<chosen candidate index>"
    and the last candidate is reported as -1.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, (Shift, OpenConstituent)):
        return None
    # the prediction must equal the gold transition right after the missed Close
    if gold_sequence[gold_index+1] != pred_transition:
        return None
    # one candidate per possible position of the missed Close: after each
    # following sibling (a single Shift or a whole constituent), up to the
    # Close of the enclosing constituent
    # each candidate is a 4-tuple of blocks:
    #   (prefix, gold items up to end_index, [Close], rest)
    candidates = []
    current_index = gold_index + 1
    while not isinstance(gold_sequence[current_index], CloseConstituent):
        if isinstance(gold_sequence[current_index], Shift):
            end_index = current_index
        else:
            # a whole constituent: jump to its Close
            end_index = find_constituent_end(gold_sequence, current_index)
        candidates.append((gold_sequence[:gold_index], gold_sequence[gold_index+1:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))
        current_index = end_index + 1
    # NOTE(review): candidate_idx=3 is the trailing block in these
    # 4-tuples, whereas fix_shift_open_ambiguous_predicted uses
    # candidate_idx=3 for its [Close] block - confirm which block
    # score_candidates_single_block is meant to score
    # (the returned scores are not used here)
    scores, best_idx, best_candidate = score_candidates_single_block(model, state, candidates, candidate_idx=3)
    # report "close at the last possible position" as -1, independent of
    # how many candidates there were
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
def fix_close_open_correct_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    Repair a Close that was mispredicted as an Open, when the predicted
    Open matches the gold transition right after the missed Close.

    In general this is ambiguous (like close/shift), since we must decide
    when to close the constituent we failed to close.  It is not
    ambiguous when exactly one constituent was supposed to follow the
    Close and it matches the Open just created.  That constituent is then
    absorbed into the unclosed one, eg
      "ate (NP spaghetti) (PP with a fork)" ->
      "ate (NP spaghetti (PP with a fork))"
    (delicious)

    With check_close=False the Close after that first constituent is not
    required, so any number of constituents may follow it.  This is one
    resolution of the ambiguous Close/Open case, where the Close could
    occur in several places later in the sequence.

    Returns the repaired sequence, or None if this case does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, CloseConstituent):
        return None
    next_gold = gold_sequence[gold_index+1]
    if next_gold != pred_transition:
        return None
    # Close of the constituent which the prediction opened early
    sibling_close = find_constituent_end(gold_sequence, gold_index+1)
    if check_close and not isinstance(gold_sequence[sibling_close+1], CloseConstituent):
        return None
    # the missed Close moves to just after the absorbed constituent
    return (gold_sequence[:gold_index]
            + gold_sequence[gold_index+1:sibling_close+1]
            + [gold_transition]
            + gold_sequence[sibling_close+1:])
def fix_close_open_correct_open_ambiguous_immediate(*args, **kwargs):
    """
    fix_close_open_correct_open with check_close=False: the missed Close
    goes right after the first constituent even when more constituents
    follow it.
    """
    return fix_close_open_correct_open(*args, check_close=False, **kwargs)
def fix_close_open_correct_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    A Close was mispredicted as an Open (matching the next gold
    transition) in an ambiguous context.  Here the missed Close is moved
    to the end of the surrounding constituent, just before its own Close.

    check_close is accepted but not used by this repair.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if gold_sequence[gold_index+1] != pred_transition:
        return None
    # index of the Close for the surrounding constituent
    outer_close = advance_past_constituents(gold_sequence, gold_index+1)
    moved = gold_sequence[gold_index+1:outer_close]
    return gold_sequence[:gold_index] + moved + [gold_transition] + gold_sequence[outer_close:]
def fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open error not covered by the unambiguous nested-open
    repair by treating the predicted Open as a unary wrapped around the
    gold constituent which starts at gold_index.

    Returns the repaired sequence, or None if this case does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, OpenConstituent):
        return None
    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # the nested open repair covers this case
        return None
    gold_close = find_constituent_end(gold_sequence, gold_index)
    assert gold_close is not None
    assert isinstance(gold_sequence[gold_close], CloseConstituent)
    inner = gold_sequence[gold_index:gold_close]
    return gold_sequence[:gold_index] + [pred_transition] + inner + [CloseConstituent()] + gold_sequence[gold_close:]
def fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open error not covered by the unambiguous nested-open
    repair by closing the predicted constituent at the end of the outer
    constituent, so it wraps everything from gold_index up to the outer
    constituent's Close.

    Returns the repaired sequence, or None if this case does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, OpenConstituent):
        return None
    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # the nested open repair covers this case
        return None
    outer_close = advance_past_constituents(gold_sequence, gold_index)
    wrapped = gold_sequence[gold_index:outer_close]
    return gold_sequence[:gold_index] + [pred_transition] + wrapped + [CloseConstituent()] + gold_sequence[outer_close:]
def fix_open_open_ambiguous_random(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open error not covered by the unambiguous nested-open
    repair by choosing, with equal probability, between the two
    ambiguous strategies:
      - fix_open_open_ambiguous_later: close the predicted constituent
        at the end of the outer constituent
      - fix_open_open_ambiguous_unary: treat the predicted Open as a
        unary around the gold constituent

    Returns the repaired sequence from the chosen strategy, or None if
    this case does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # This case is covered by the nested open repair
        return None
    # Both strategies take the full oracle argument list; calling them
    # without model and state raises a TypeError.
    if random.random() < 0.5:
        return fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)
    return fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)
def report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Label, without repairing, a gold Shift that was predicted as an Open.
    Returns (RepairType.OTHER_SHIFT_OPEN, None) for that pair, otherwise None.
    """
    if isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_SHIFT_OPEN, None
    return None
def report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Label, without repairing, a gold Close that was predicted as a Shift.
    Returns (RepairType.OTHER_CLOSE_SHIFT, None) for that pair, otherwise None.
    """
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, Shift):
        return RepairType.OTHER_CLOSE_SHIFT, None
    return None
def report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Label, without repairing, a gold Close that was predicted as an Open.
    Returns (RepairType.OTHER_CLOSE_OPEN, None) for that pair, otherwise None.
    """
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_CLOSE_OPEN, None
    return None
def report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Label, without repairing, a gold Open that was predicted as a different Open.
    Returns (RepairType.OTHER_OPEN_OPEN, None) for that pair, otherwise None.
    """
    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_OPEN_OPEN, None
    return None
class RepairType(Enum):
    """
    Keep track of which repair is used, if any, on an incorrect transition

    Each member is defined by a tuple (fn, correct=False, debug=False)
    which is unpacked into __new__ below.  fn is the repair (or report)
    function for that kind of error; every fn takes
    (gold_transition, pred_transition, gold_sequence, gold_index,
    root_labels, model, state).  Member values are assigned in
    definition order, so inserting or reordering members renumbers
    every later member.

    A test of the top-down oracle with no charlm or transformer
    (eg, word vectors only) on EN PTB3 goes as follows.
    3x training rounds, best training parameters as of Jan. 2024
    unambiguous transitions only:
    oracle scheme dev test
    no oracle 0.9230 0.9194
    +shift/close 0.9224 0.9180
    +open/close 0.9225 0.9193
    +open/shift (one) 0.9245 0.9207
    +open/shift (mult) 0.9243 0.9211
    +open/open nested 0.9258 0.9213
    +shift/open 0.9266 0.9229
    +close/shift (only) 0.9270 0.9230
    +close/shift w/ opens 0.9262 0.9221
    +close/open one con 0.9273 0.9230
    Potential solutions for various ambiguous transitions:
    close/open
    can close immediately after the corresponding constituent or after any number of constituents
    close/shift
    can close immediately
    can close anywhere up to the next close
    any number of missed Opens are treated as recall errors
    open/open
    could treat as unary
    could close at any number of positions after the next structures, up to the outer open's closing
    shift/open ambiguity resolutions:
    treat as unary
    treat as wrapper around the next full constituent to build
    treat as wrapper around everything to build until the next constituent
    testing one at a time in addition to the full set of unambiguous corrections:
    +close/open immediate 0.9259 0.9225
    +close/open later 0.9258 0.9257
    +close/shift immediate 0.9261 0.9219
    +close/shift later 0.9270 0.9230
    +open/open later 0.9269 0.9239
    +open/open unary 0.9275 0.9246
    +shift/open later 0.9263 0.9253
    +shift/open unary 0.9264 0.9243
    so there is some evidence that open/open or shift/open would be beneficial
    Training by randomly choosing between the open/open, 50/50
    +open/open random 0.9257 0.9235
    so that didn't work great compared to the individual transitions
    Testing deterministic resolutions of the ambiguous transitions
    vs predicting the appropriate transition to use:
    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR,CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR,CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR
    SHIFT_OPEN_AMBIGUOUS_PREDICTED,CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED
    EN ambiguous (no charlm or transformer) 0.9268 0.9231
    EN predicted 0.9270 0.9257
    EN none of the above 0.9268 0.9229
    ZH ambiguous 0.9137 0.9127
    ZH predicted 0.9148 0.9141
    ZH none of the above 0.9141 0.9143
    DE ambiguous 0.9579 0.9408
    DE predicted 0.9575 0.9406
    DE none of the above 0.9581 0.9411
    ID ambiguous 0.8889 0.8794
    ID predicted 0.8911 0.8801
    ID none of the above 0.8913 0.8822
    IT ambiguous 0.8404 0.8380
    IT predicted 0.8397 0.8398
    IT none of the above 0.8400 0.8409
    VI ambiguous 0.8290 0.7676
    VI predicted 0.8287 0.7682
    VI none of the above 0.8292 0.7691
    """
    def __new__(cls, fn, correct=False, debug=False):
        """
        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error

        fn: the repair or report function for this error type, or None
            (CORRECT and UNKNOWN have no function)
        correct: stored as obj.correct and exposed via is_correct;
            only CORRECT sets it
        debug: stored as obj.debug; set on the report-only OTHER_*
            members (presumably so the oracle treats them as
            diagnostics rather than repairs - confirm in DynamicOracle)
        """
        # values count up from 1 in definition order
        value = len(cls.__members__)
        obj = object.__new__(cls)
        obj._value_ = value + 1
        obj.fn = fn
        obj.correct = correct
        obj.debug = debug
        return obj

    @property
    def is_correct(self):
        # True only for members defined with correct=True (CORRECT)
        return self.correct

    # NOTE: values depend on member order (see __new__); the numeric value
    # is also embedded in the RepairEnum values built by the *_PREDICTED
    # repairs, so new members should be appended rather than inserted

    # The parser chose to close a bracket instead of shift something
    # into the bracket
    # This causes both a precision and a recall error as there is now
    # an incorrect bracket and a missing correct bracket
    # Any bracket creation here would cause more wrong brackets, though
    SHIFT_CLOSE_ERROR = (fix_shift_close,)

    OPEN_CLOSE_ERROR = (fix_open_close,)

    # a single open followed by shift was instead predicted to be shift
    ONE_OPEN_SHIFT_ERROR = (fix_one_open_shift,)

    # several opens followed by shift were instead predicted to be shift
    MULTIPLE_OPEN_SHIFT_ERROR = (fix_multiple_open_shift,)

    # should have done Open(X), Open(Y)
    # instead just did Open(Y)
    NESTED_OPEN_OPEN_ERROR = (fix_nested_open_constituent,)

    SHIFT_OPEN_ERROR = (fix_shift_open_immediate_close,)

    CLOSE_SHIFT_ERROR = (fix_close_shift,)

    CLOSE_SHIFT_WITH_OPENS_ERROR = (fix_close_shift_with_opens,)

    CLOSE_OPEN_ONE_CON_ERROR = (fix_close_open_correct_open,)

    # no repair function; the only member with correct=True
    CORRECT = (None, True)

    # a bare None is passed to __new__ as fn, so there is no repair function
    UNKNOWN = None

    CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_open_correct_open_ambiguous_immediate,)

    CLOSE_OPEN_AMBIGUOUS_LATER_ERROR = (fix_close_open_correct_open_ambiguous_later,)

    CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_shift_ambiguous_immediate,)

    CLOSE_SHIFT_AMBIGUOUS_LATER_ERROR = (fix_close_shift_ambiguous_later,)

    # can potentially fix either close/shift or close/open
    # as long as the gold transition after the close
    # was the same as the transition we just predicted
    CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED = (fix_close_next_correct_predicted,)

    OPEN_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_open_open_ambiguous_unary,)

    OPEN_OPEN_AMBIGUOUS_LATER_ERROR = (fix_open_open_ambiguous_later,)

    OPEN_OPEN_AMBIGUOUS_RANDOM_ERROR = (fix_open_open_ambiguous_random,)

    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_shift_open_ambiguous_unary,)

    SHIFT_OPEN_AMBIGUOUS_LATER_ERROR = (fix_shift_open_ambiguous_later,)

    SHIFT_OPEN_AMBIGUOUS_PREDICTED = (fix_shift_open_ambiguous_predicted,)

    # report-only entries (debug=True): their functions return
    # (RepairType member, None) instead of a repaired sequence
    OTHER_SHIFT_OPEN = (report_shift_open, False, True)

    OTHER_CLOSE_SHIFT = (report_close_shift, False, True)

    OTHER_CLOSE_OPEN = (report_close_open, False, True)

    OTHER_OPEN_OPEN = (report_open_open, False, True)
class TopDownOracle(DynamicOracle):
    """
    Dynamic oracle for the top-down transition scheme.

    Its repairs are the members of RepairType.  The oracle levels are
    forwarded unchanged to DynamicOracle, which presumably decides which
    repairs are active - confirm there.
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        repair_table = RepairType
        super().__init__(root_labels, oracle_level, repair_table, additional_oracle_levels, deactivated_oracle_levels)
================================================
FILE: stanza/models/constituency/trainer.py
================================================
"""
This file includes a variety of methods needed to train new
constituency parsers. It also includes a method to load an
already-trained parser.
See the `train` method for the code block which starts from
raw treebank and returns a new parser.
`evaluate` reads a treebank and gives a score for those trees.
"""
import copy
import logging
import os
import torch
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain, NoTransformerFoundationCache
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper, pop_peft_args
from stanza.models.constituency.base_trainer import BaseTrainer, ModelType
from stanza.models.constituency.lstm_model import LSTMModel, SentenceBoundary, StackHistory, ConstituencyComposition
from stanza.models.constituency.parse_transitions import Transition, TransitionScheme
from stanza.models.constituency.utils import build_optimizer, build_scheduler
# TODO: could put find_wordvec_pretrain, choose_charlm, etc in a more central place if it becomes widely used
from stanza.utils.training.common import find_wordvec_pretrain, choose_charlm, find_charlm_file
from stanza.resources.default_packages import default_charlms, default_pretrains
# the package-level 'stanza' logger
logger = logging.getLogger('stanza')
# child of the 'stanza' logger for training-specific messages
# (used by Trainer.log_num_words_known)
tlogger = logging.getLogger('stanza.constituency.trainer')
class Trainer(BaseTrainer):
"""
Stores a constituency model and its optimizer
Not inheriting from common/trainer.py because there's no concept of change_lr (yet?)
"""
def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
    """
    Wrap a model, its optional optimizer and scheduler, and the training
    progress counters.

    Everything is forwarded, in order, to BaseTrainer.__init__.
    """
    progress = (epochs_trained, batches_trained, best_f1, best_epoch)
    super().__init__(model, optimizer, scheduler, *progress, first_optimizer)
def save(self, filename, save_optimizer=True):
    """
    Write this trainer to filename.

    The optimizer is included unless save_optimizer is False; the
    actual serialization is done by BaseTrainer.save.
    """
    super().save(filename, save_optimizer)
def get_peft_params(self):
    """
    Return the PEFT adapter state of the transformer, or None when the
    model's args do not enable use_peft.
    """
    if not self.model.args.get('use_peft', False):
        return None
    # imported here so that peft remains an optional dependency
    from peft import get_peft_model_state_dict
    return get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)
@property
def model_type(self):
    """
    The kind of constituency model this trainer holds: ModelType.LSTM.
    """
    return ModelType.LSTM
@staticmethod
def find_and_load_pretrain(saved_args, foundation_cache):
    """
    Load the word vector pretrain named in saved_args.

    Returns None if saved_args has no 'wordvec_pretrain_file' entry.  If
    the saved path does not exist, the default pretrain for
    saved_args['lang'] is loaded instead.
    """
    if 'wordvec_pretrain_file' not in saved_args:
        return None
    pretrain_path = saved_args['wordvec_pretrain_file']
    if not os.path.exists(pretrain_path):
        logger.info("Unable to find pretrain in %s Will try to load from the default resources instead", pretrain_path)
        pretrain_path = find_wordvec_pretrain(saved_args['lang'], default_pretrains)
    return load_pretrain(pretrain_path, foundation_cache)
@staticmethod
def find_and_load_charlm(charlm_file, direction, saved_args, foundation_cache):
    """
    Load the charlm at charlm_file.

    If that file is missing, choose the default charlm for the model's
    language and dataset (the second "_"-separated field of
    saved_args['shorthand']) and load its file for the given direction.
    """
    try:
        return load_charlm(charlm_file, foundation_cache)
    except FileNotFoundError:
        logger.info("Unable to load charlm from %s Will try to load from the default resources instead", charlm_file)
        lang = saved_args['lang']
        treebank = saved_args['shorthand'].split("_")[1]
        default_charlm = choose_charlm(lang, treebank, "default", default_charlms, {})
        fallback_file = find_charlm_file(direction, lang, default_charlm)
        return load_charlm(fallback_file, foundation_cache)
def log_num_words_known(self, words):
    # report how much of the training vocabulary the embedding covers
    known = self.model.num_words_known(words)
    total = len(words)
    tlogger.info("Number of words in the training set found in the embedding: %d out of %d", known, total)
@staticmethod
def load_optimizer(model, checkpoint, first_optimizer, filename):
    """
    Build an optimizer for model and restore its saved state, if any.

    A missing or None 'optimizer_state_dict' yields a freshly built
    optimizer.  A ValueError while restoring is re-raised as a
    ValueError naming filename.
    """
    optimizer = build_optimizer(model.args, model, first_optimizer)
    saved_state = checkpoint.get('optimizer_state_dict', None)
    if saved_state is None:
        logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
        return optimizer
    try:
        optimizer.load_state_dict(saved_state)
    except ValueError as e:
        raise ValueError("Failed to load optimizer from %s" % filename) from e
    return optimizer
@staticmethod
def load_scheduler(model, optimizer, checkpoint, first_optimizer):
    """
    Build a scheduler for optimizer and restore its saved state, if any.

    A checkpoint without 'scheduler_state_dict', or with None stored
    under that key, yields a freshly built scheduler.  This matches
    load_optimizer, which also treats a None state dict as "not saved"
    instead of passing None to load_state_dict.
    """
    scheduler = build_scheduler(model.args, optimizer, first_optimizer=first_optimizer)
    scheduler_state = checkpoint.get('scheduler_state_dict', None)
    if scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)
    return scheduler
@staticmethod
def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):
    """
    Build a new model just from the saved params and some extra args

    Refactoring allows other processors to include a constituency parser as a module

    params: the saved model dict.  Keys read here: 'config', 'model_type',
      'model' (the state dict), 'transitions', 'constituents', 'tags',
      'words', 'rare_words', 'root_labels', 'constituent_opens', 'unary_limit'
    peft_params: passed to load_peft_wrapper when the saved config has use_peft set
    args: dict of values which override the saved config, or None.
      Values which would change the structure of the model are not applied,
      and a None wordvec_pretrain_file / charlm_*_file keeps the saved value
    foundation_cache: passed to the pretrain & charlm loaders, and to the
      bert loader except in the finetuned bert case, which loads a fresh copy
    peft_name: name of the peft adapter.  If None and the model uses peft,
      the name returned by load_bert_with_peft is used

    Returns the LSTMModel with params['model'] loaded (strict=False),
    moved to args['device'] if that is set

    Raises ValueError if params['model_type'] is not 'LSTM'
    """
    saved_args = dict(params['config'])
    # the saved config may store these enums by name;
    # convert the strings back into the enum values
    if isinstance(saved_args['sentence_boundary_vectors'], str):
        saved_args['sentence_boundary_vectors'] = SentenceBoundary[saved_args['sentence_boundary_vectors']]
    if isinstance(saved_args['constituency_composition'], str):
        saved_args['constituency_composition'] = ConstituencyComposition[saved_args['constituency_composition']]
    if isinstance(saved_args['transition_stack'], str):
        saved_args['transition_stack'] = StackHistory[saved_args['transition_stack']]
    if isinstance(saved_args['constituent_stack'], str):
        saved_args['constituent_stack'] = StackHistory[saved_args['constituent_stack']]
    if isinstance(saved_args['transition_scheme'], str):
        saved_args['transition_scheme'] = TransitionScheme[saved_args['transition_scheme']]

    # some parameters which change the structure of a model have
    # to be ignored, or the model will not function when it is
    # reloaded from disk
    if args is None: args = {}
    # work on a copy so the caller's args are not modified by the pops below
    update_args = copy.deepcopy(args)
    pop_peft_args(update_args)
    update_args.pop("bert_hidden_layers", None)
    update_args.pop("bert_model", None)
    update_args.pop("constituency_composition", None)
    update_args.pop("constituent_stack", None)
    update_args.pop("num_tree_lstm_layers", None)
    update_args.pop("transition_scheme", None)
    update_args.pop("transition_stack", None)
    update_args.pop("maxout_k", None)
    # if the pretrain or charlms are not specified, don't override the values in the model
    # (if any), since the model won't even work without loading the same charlm
    if 'wordvec_pretrain_file' in update_args and update_args['wordvec_pretrain_file'] is None:
        update_args.pop('wordvec_pretrain_file')
    if 'charlm_forward_file' in update_args and update_args['charlm_forward_file'] is None:
        update_args.pop('charlm_forward_file')
    if 'charlm_backward_file' in update_args and update_args['charlm_backward_file'] is None:
        update_args.pop('charlm_backward_file')
    # we don't pop bert_finetune, with the theory being that if
    # the saved model has bert_finetune==True we can load the bert
    # weights but then not further finetune if bert_finetune==False
    saved_args.update(update_args)

    # TODO: not needed if we rebuild the models
    # treat a missing or None bert_finetune / stage1_bert_finetune as False
    if saved_args.get("bert_finetune", None) is None:
        saved_args["bert_finetune"] = False
    if saved_args.get("stage1_bert_finetune", None) is None:
        saved_args["stage1_bert_finetune"] = False

    model_type = params['model_type']
    if model_type == 'LSTM':
        pt = Trainer.find_and_load_pretrain(saved_args, foundation_cache)
        # bert_saved is passed to LSTMModel as force_bert_saved.  It is True
        # when the transformer weights are part of this model: peft,
        # finetuning, or bert_model.* weights present in params['model']
        if saved_args.get('use_peft', False):
            # if loading a peft model, we first load the base transformer
            # then we load the weights using the saved weights in the file
            if peft_name is None:
                bert_model, bert_tokenizer, peft_name = load_bert_with_peft(saved_args.get('bert_model', None), "constituency", foundation_cache)
            else:
                bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
            bert_model = load_peft_wrapper(bert_model, peft_params, saved_args, logger, peft_name)
            bert_saved = True
        elif saved_args['bert_finetune'] or saved_args['stage1_bert_finetune'] or any(x.startswith("bert_model.") for x in params['model'].keys()):
            # if bert_finetune is True, don't use the cached model!
            # otherwise, other uses of the cached model will be ruined
            bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None))
            bert_saved = True
        else:
            bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
            bert_saved = False
        # falls back to the default charlm for the language if the saved path is missing
        forward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_forward_file"], "forward", saved_args, foundation_cache)
        backward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_backward_file"], "backward", saved_args, foundation_cache)

        # TODO: the isinstance will be unnecessary after 1.10.0
        # older saves store the transitions as their repr strings
        transitions = params['transitions']
        if all(isinstance(x, str) for x in transitions):
            transitions = [Transition.from_repr(x) for x in transitions]

        model = LSTMModel(pretrain=pt,
                          forward_charlm=forward_charlm,
                          backward_charlm=backward_charlm,
                          bert_model=bert_model,
                          bert_tokenizer=bert_tokenizer,
                          force_bert_saved=bert_saved,
                          peft_name=peft_name,
                          transitions=transitions,
                          constituents=params['constituents'],
                          tags=params['tags'],
                          words=params['words'],
                          rare_words=set(params['rare_words']),
                          root_labels=params['root_labels'],
                          constituent_opens=params['constituent_opens'],
                          unary_limit=params['unary_limit'],
                          args=saved_args)
    else:
        raise ValueError("Unknown model type {}".format(model_type))
    # strict=False: keys missing from params['model'] are left as initialized
    model.load_state_dict(params['model'], strict=False)
    # model will stay on CPU if device==None
    # can be moved elsewhere later, of course
    model = model.to(args.get('device', None))
    return model
@staticmethod
def build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file):
    """
    Build the Trainer for a training run

    The options are checked in this order:
      - args['checkpoint'] is set and args['checkpoint_save_name'] exists:
        resume from that checkpoint, including its optimizer
      - args['finetune']: load model_load_file, including its optimizer,
        and restart the epoch count at 0
      - args['relearn_structure']: load model_load_file, then copy its
        weights into a new model with the structure given in args
      - otherwise build a new model from the training data.  With
        args['multistage'], this first stage model has no pattn or lattn
        layers and uses the first stage (AdaDelta) optimizer & scheduler

    Returns the Trainer
    """
    # TODO: turn finetune, relearn_structure, multistage into an enum?
    # finetune just means continue learning, so checkpoint is sufficient
    # relearn_structure is essentially a one stage multistage
    # multistage with a checkpoint will have the proper optimizer for that epoch
    # and no special learning mode means we are training a new model and should continue
    if args['checkpoint'] and args['checkpoint_save_name'] and os.path.exists(args['checkpoint_save_name']):
        tlogger.info("Found checkpoint to continue training: %s", args['checkpoint_save_name'])
        trainer = Trainer.load(args['checkpoint_save_name'], args, load_optimizer=True, foundation_cache=foundation_cache)
        return trainer

    # in the 'finetune' case, this will preload the models into foundation_cache,
    # so the effort is not wasted
    pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file'])
    forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])
    backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])

    if args['finetune']:
        tlogger.info("Loading model to finetune: %s", model_load_file)
        trainer = Trainer.load(model_load_file, args, load_optimizer=True, foundation_cache=NoTransformerFoundationCache(foundation_cache))
        # a new finetuning will start with a new epochs_trained count
        trainer.epochs_trained = 0
        return trainer

    if args['relearn_structure']:
        tlogger.info("Loading model to continue training with new structure from %s", model_load_file)
        temp_args = dict(args)
        # remove the pattn & lattn layers unless the saved model had them
        temp_args.pop('pattn_num_layers', None)
        temp_args.pop('lattn_d_proj', None)
        trainer = Trainer.load(model_load_file, temp_args, load_optimizer=False, foundation_cache=NoTransformerFoundationCache(foundation_cache))

        # using the model's current values works for if the new
        # dataset is the same or smaller
        # TODO: handle a larger dataset as well
        model = LSTMModel(pt,
                          forward_charlm,
                          backward_charlm,
                          trainer.model.bert_model,
                          trainer.model.bert_tokenizer,
                          trainer.model.force_bert_saved,
                          trainer.model.peft_name,
                          trainer.model.transitions,
                          trainer.model.constituents,
                          trainer.model.tags,
                          trainer.model.delta_words,
                          trainer.model.rare_words,
                          trainer.model.root_labels,
                          trainer.model.constituent_opens,
                          trainer.model.unary_limit(),
                          args)
        model = model.to(args['device'])
        model.copy_with_new_structure(trainer.model)
        optimizer = build_optimizer(args, model, False)
        scheduler = build_scheduler(args, optimizer)
        trainer = Trainer(model, optimizer, scheduler)
        return trainer

    if args['multistage']:
        # run adadelta over the model for half the time with no pattn or lattn
        # training then switches to a different optimizer for the rest
        # this works surprisingly well
        tlogger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2)
        temp_args = dict(args)
        # remove the attention layers for the temporary model
        temp_args['pattn_num_layers'] = 0
        temp_args['lattn_d_proj'] = 0
        args = temp_args

    peft_name = None
    if args['use_peft']:
        peft_name = "constituency"
        bert_model, bert_tokenizer = load_bert(args['bert_model'])
        # use args, not temp_args: temp_args is only bound when multistage
        # is set, and in that case args has already been replaced by it.
        # Referring to temp_args raised UnboundLocalError for use_peft
        # without multistage
        bert_model = build_peft_wrapper(bert_model, args, tlogger, adapter_name=peft_name)
    elif args['bert_finetune'] or args['stage1_bert_finetune']:
        # finetuned weights must not come from (and then corrupt) the shared cache
        bert_model, bert_tokenizer = load_bert(args['bert_model'])
    else:
        bert_model, bert_tokenizer = load_bert(args['bert_model'], foundation_cache)

    model = LSTMModel(pt,
                      forward_charlm,
                      backward_charlm,
                      bert_model,
                      bert_tokenizer,
                      False,
                      peft_name,
                      train_transitions,
                      train_constituents,
                      tags,
                      words,
                      rare_words,
                      root_labels,
                      open_nodes,
                      unary_limit,
                      args)
    model = model.to(args['device'])

    optimizer = build_optimizer(args, model, build_simple_adadelta=args['multistage'])
    scheduler = build_scheduler(args, optimizer, first_optimizer=args['multistage'])

    trainer = Trainer(model, optimizer, scheduler, first_optimizer=args['multistage'])
    return trainer
================================================
FILE: stanza/models/constituency/transformer_tree_stack.py
================================================
"""
Based on
Transition-based Parsing with Stack-Transformers
Ramon Fernandez Astudillo, Miguel Ballesteros, Tahira Naseem,
Austin Blodgett, and Radu Florian
https://aclanthology.org/2020.findings-emnlp.89.pdf
"""
from collections import namedtuple
import torch
import torch.nn as nn
from stanza.models.constituency.positional_encoding import SinusoidalEncoding
from stanza.models.constituency.tree_stack import TreeStack
# One entry of a TransformerTreeStack:
#   value: the item pushed (eg a tree node)
#   key_stack / value_stack: the keys & values of the stack up to and including this item
#   output: the attention output computed when this item was pushed
Node = namedtuple("Node", "value key_stack value_stack output")
class TransformerTreeStack(nn.Module):
    """
    A stack which summarizes its contents using attention

    Each node pushed on the stack keeps the keys & values of the items
    below it, plus the result of attending over them with the query of
    the newest item.  If length_limit is set, only the start embedding
    and the most recent length_limit items are kept for attention.
    """
    def __init__(self, input_size, output_size, input_dropout, length_limit=None, use_position=False, num_heads=1):
        """
        Builds the internal matrices and start parameter

        input_size: size of the inputs pushed on the stack
        output_size: size of the keys, values, and outputs.
          It is split evenly between the heads, so it must be divisible by num_heads
        input_dropout: an nn.Module to apply to the inputs,
          or a dropout probability used to build an nn.Dropout
        length_limit: if not None, the maximum number of items (not counting
          the start embedding) kept in the key & value stacks
        use_position: if True, add a sinusoidal position encoding to the inputs
        num_heads: number of attention heads

        Raises ValueError if num_heads < 1, output_size is not divisible
        by num_heads, or length_limit < 1
        """
        super().__init__()

        # check early, rather than failing in a reshape inside attention()
        if num_heads < 1:
            raise ValueError("num_heads must be at least 1, got {}".format(num_heads))
        if output_size % num_heads != 0:
            raise ValueError("output_size {} must be divisible by num_heads {}".format(output_size, num_heads))

        self.input_size = input_size
        self.output_size = output_size
        # NOTE: scaled by the full output_size, not the per-head size
        # output_size / num_heads.  Changing this would alter existing models
        self.inv_sqrt_output_size = 1 / output_size ** 0.5
        self.num_heads = num_heads

        self.w_query = nn.Linear(input_size, output_size)
        self.w_key = nn.Linear(input_size, output_size)
        self.w_value = nn.Linear(input_size, output_size)

        self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
        if isinstance(input_dropout, nn.Module):
            self.input_dropout = input_dropout
        else:
            self.input_dropout = nn.Dropout(input_dropout)

        if length_limit is not None and length_limit < 1:
            raise ValueError("length_limit < 1 makes no sense")
        self.length_limit = length_limit

        self.use_position = use_position
        if use_position:
            self.position_encoding = SinusoidalEncoding(model_dim=self.input_size, max_len=512)

    def attention(self, key, query, value, mask=None):
        """
        Calculate attention for the given key, query value

        Where B is the number of items stacked together, N is the length:
          The key should be BxNxD
          The query is BxD
          The value is BxNxD

        If mask is specified, it should be BxN of True/False values,
        where True means that location is masked out

        Reshapes and reorders are used to handle num_heads

        Return will be softmax(query x key^T) * value
        of size BxD
        """
        B = key.shape[0]
        N = key.shape[1]
        D = key.shape[2]
        H = self.num_heads

        # query is now BxDx1
        query = query.unsqueeze(2)
        # BxHxD/Hx1
        query = query.reshape((B, H, -1, 1))

        # BxNxHxD/H
        key = key.reshape((B, N, H, -1))
        # BxHxNxD/H
        key = key.transpose(1, 2)

        # BxNxHxD/H
        value = value.reshape((B, N, H, -1))
        # BxHxNxD/H
        value = value.transpose(1, 2)

        # BxHxNxD/H x BxHxD/Hx1
        # result shape: BxHxN
        attn = torch.matmul(key, query).squeeze(3) * self.inv_sqrt_output_size
        if mask is not None:
            # mask goes from BxN -> Bx1xN, then is expanded to BxHxN
            mask = mask.unsqueeze(1)
            mask = mask.expand(-1, H, -1)
            attn.masked_fill_(mask, float('-inf'))
        # attn shape will now be BxHx1xN
        attn = torch.softmax(attn, dim=2).unsqueeze(2)
        # BxHx1xN x BxHxNxD/H -> BxHxD/H
        output = torch.matmul(attn, value).squeeze(2)
        output = output.reshape(B, -1)
        return output

    def initial_state(self, initial_value=None):
        """
        Return an initial state based on a single layer of attention

        Running attention might be overkill, but it is the simplest
        way to put the Linears and start_embedding in the computation graph

        initial_value: stored as the value of the start node
        """
        start = self.start_embedding
        if self.use_position:
            position = self.position_encoding([0]).squeeze(0)
            start = start + position

        # N=1
        # shape: 1xD
        key = self.w_key(start).unsqueeze(0)
        # shape: D
        query = self.w_query(start)
        # shape: 1xD
        value = self.w_value(start).unsqueeze(0)

        # unsqueeze to make it look like we are part of a batch of size 1
        output = self.attention(key.unsqueeze(0), query.unsqueeze(0), value.unsqueeze(0)).squeeze(0)
        return TreeStack(value=Node(initial_value, key, value, output), parent=None, length=1)

    def push_states(self, stacks, values, inputs):
        """
        Push new inputs to the stacks and rerun attention on them

        Where B is the number of items stacked together, I is input_size
          stacks: B TreeStacks such as produced by initial_state and/or push_states
          values: the new items to push on the stacks such as tree nodes or anything
          inputs: BxI for the new input items (1xBxI is also accepted)

        Runs attention starting from the existing keys & values

        Returns a list of the B new stacks
        """
        device = self.w_key.weight.device

        batch_len = len(stacks)  # B
        # number of keys & values currently stored in each stack.
        # this can be smaller than len(stack) once length_limit trims the stack
        positions = [x.value.key_stack.shape[0] for x in stacks]
        max_len = max(positions)  # N

        if self.use_position:
            position_encodings = self.position_encoding(positions)
            inputs = inputs + position_encodings
        inputs = self.input_dropout(inputs)
        if len(inputs.shape) == 3:
            if inputs.shape[0] == 1:
                inputs = inputs.squeeze(0)
            else:
                raise ValueError("Expected the inputs to be of shape 1xBxI, got {}".format(inputs.shape))

        # each stack is right aligned in a Bx(N+1) batch,
        # with the newly pushed item in the last position
        new_keys = self.w_key(inputs)
        key_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
        key_stack[:, -1, :] = new_keys
        for stack_idx, stack in enumerate(stacks):
            key_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.key_stack

        new_values = self.w_value(inputs)
        value_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
        value_stack[:, -1, :] = new_values
        for stack_idx, stack in enumerate(stacks):
            value_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.value_stack

        query = self.w_query(inputs)

        # mask out the zero padding at the start of the shorter stacks.
        # keyed on the stored key count, which is what determines the padding
        mask = torch.zeros(batch_len, max_len+1, device=device, dtype=torch.bool)
        for stack_idx, position in enumerate(positions):
            padding = max_len - position
            if padding > 0:
                mask[stack_idx, :padding] = True

        batched_output = self.attention(key_stack, query, value_stack, mask)

        new_stacks = []
        for stack_idx, (stack, node_value, new_key, new_value, output) in enumerate(zip(stacks, values, key_stack, value_stack, batched_output)):
            # start at max_len-positions[stack_idx] so that we ignore the padding at the start of shorter stacks
            new_key_stack = new_key[max_len-positions[stack_idx]:, :]
            new_value_stack = new_value[max_len-positions[stack_idx]:, :]
            if self.length_limit is not None and new_key_stack.shape[0] > self.length_limit + 1:
                # drop the oldest item, keeping the start embedding in position 0
                new_key_stack = torch.cat([new_key_stack[:1, :], new_key_stack[2:, :]], dim=0)
                new_value_stack = torch.cat([new_value_stack[:1, :], new_value_stack[2:, :]], dim=0)
            new_stacks.append(stack.push(value=Node(node_value, new_key_stack, new_value_stack, output)))
        return new_stacks

    def output(self, stack):
        """
        Return the attention output stored in the top node of a stack

        Refactored so that alternate structures have an easy way of getting the output
        """
        return stack.value.output
================================================
FILE: stanza/models/constituency/transition_sequence.py
================================================
"""
Build a transition sequence from parse trees.
Supports multiple transition schemes - TOP_DOWN and variants, IN_ORDER
"""
import logging
from stanza.models.common import utils
from stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme, Finalize
from stanza.models.constituency.tree_reader import read_trees
from stanza.utils.get_tqdm import get_tqdm
# progress bar wrapper, used by convert_trees_to_sequences when logging at INFO or below
tqdm = get_tqdm()
# NOTE(review): this shares the trainer's logger name rather than having its own
logger = logging.getLogger('stanza.constituency.trainer')
def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
    """
    Yield the top down transitions which build tree

    For tree (X A B C D), yield Open(X) A B C D Close

    The schemes differ in how chains of unary nodes are handled:
      TOP_DOWN_UNARY:    (Y (X ...)) -> Open(X) ... Close Unary(Y)
      TOP_DOWN_COMPOUND: (Y (X ...)) -> Open(Y, X) ... Close
      TOP_DOWN:          (Y (X ...)) -> Open(Y) Open(X) ... Close Close
    """
    if tree.is_preterminal():
        yield Shift()
        return
    if tree.is_leaf():
        return

    if transition_scheme is TransitionScheme.TOP_DOWN_UNARY and len(tree.children) == 1:
        # build the innermost non-unary node, then add the whole chain as one unary
        unary_labels = []
        node = tree
        while not node.is_preterminal() and len(node.children) == 1:
            unary_labels.append(node.label)
            node = node.children[0]
        yield from yield_top_down_sequence(node, transition_scheme)
        yield CompoundUnary(*unary_labels)
        return

    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        # open the whole chain of unary labels with a single transition
        open_labels = [tree.label]
        while len(tree.children) == 1 and not tree.children[0].is_preterminal():
            tree = tree.children[0]
            open_labels.append(tree.label)
        yield OpenConstituent(*open_labels)
    else:
        yield OpenConstituent(tree.label)

    for child in tree.children:
        yield from yield_top_down_sequence(child, transition_scheme)
    yield CloseConstituent()
def yield_in_order_sequence(tree):
    """
    Yield the in order transitions which build tree

    For tree (X A B C D), yield A Open(X) B C D Close
    """
    if tree.is_preterminal():
        yield Shift()
    elif not tree.is_leaf():
        # the first child is built before its parent is opened
        yield from yield_in_order_sequence(tree.children[0])
        yield OpenConstituent(tree.label)
        for later_child in tree.children[1:]:
            yield from yield_in_order_sequence(later_child)
        yield CloseConstituent()
def yield_in_order_compound_sequence(tree, transition_scheme):
    """
    Yield the in order transitions for tree, collapsing chains of unary nodes

    IN_ORDER_UNARY: a unary chain above a node becomes a CompoundUnary
      after the node is closed
    otherwise (IN_ORDER_COMPOUND): the unary chain is opened together with
      the node in a single OpenConstituent

    The tree must have exactly one child below its root; the root label
    is produced by the final Finalize transition.

    Raises ValueError (when iterated) if the root has zero or several children
    """
    def transitions_for(node):
        if node.is_leaf():
            return

        # walk down any chain of single-child nodes, remembering the labels
        chain = []
        while len(node.children) == 1 and not node.is_preterminal():
            chain.append(node.label)
            node = node.children[0]

        if node.is_preterminal():
            yield Shift()
            if chain:
                yield CompoundUnary(*chain)
            return

        yield from transitions_for(node.children[0])
        if transition_scheme is TransitionScheme.IN_ORDER_UNARY:
            yield OpenConstituent(node.label)
        else:
            chain.append(node.label)
            yield OpenConstituent(*chain)
        for later_child in node.children[1:]:
            yield from transitions_for(later_child)
        yield CloseConstituent()
        if transition_scheme is TransitionScheme.IN_ORDER_UNARY and chain:
            yield CompoundUnary(*chain)

    if len(tree.children) == 0:
        raise ValueError("Cannot build {} on an empty tree".format(transition_scheme))
    if len(tree.children) != 1:
        raise ValueError("Cannot build {} with a tree that has two top level nodes: {}".format(transition_scheme, tree))

    yield from transitions_for(tree.children[0])
    yield Finalize(tree.label)
def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
    """
    Return the list of transitions which builds tree under transition_scheme

    IN_ORDER, IN_ORDER_COMPOUND and IN_ORDER_UNARY are in order schemes;
    anything else is treated as one of the top down schemes
    """
    if transition_scheme is TransitionScheme.IN_ORDER:
        transition_generator = yield_in_order_sequence(tree)
    elif transition_scheme is TransitionScheme.IN_ORDER_COMPOUND or transition_scheme is TransitionScheme.IN_ORDER_UNARY:
        transition_generator = yield_in_order_compound_sequence(tree, transition_scheme)
    else:
        transition_generator = yield_top_down_sequence(tree, transition_scheme)
    return list(transition_generator)
def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, reverse=False):
    """
    Return the transition sequence for each tree in trees

    If reverse is set, each tree is reversed before its sequence is built
    """
    prepare = (lambda t: t.reverse()) if reverse else (lambda t: t)
    return [build_sequence(prepare(tree), transition_scheme) for tree in trees]
def all_transitions(transition_lists):
    """
    Combine several transition lists into one sorted list of the unique transitions
    """
    unique = {transition for trans_list in transition_lists for transition in trans_list}
    return sorted(unique)
def convert_trees_to_sequences(trees, treebank_name, transition_scheme, reverse=False):
    """
    Convert trees to transition sequences and collect the transitions they use

    A progress bar is shown if the logger would print INFO messages.

    Returns (sequences, sorted list of the unique transitions);
    both are empty lists if trees is empty
    """
    if len(trees) == 0:
        return [], []

    logger.info("Building %s transition sequences", treebank_name)
    show_progress = logger.getEffectiveLevel() <= logging.INFO
    tree_iter = tqdm(trees) if show_progress else trees
    sequences = build_treebank(tree_iter, transition_scheme, reverse)
    return sequences, all_transitions(sequences)
def main():
    """
    Convert a sample tree and print its transitions
    """
    sample = "( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    tree = read_trees(sample)[0]
    print(tree)
    print(build_sequence(tree))

if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/constituency/tree_embedding.py
================================================
"""
A module to use a Constituency Parser to make an embedding for a tree
The embedding can be produced just from the words and the top of the
tree, or it can be done with a form of attention over the nodes
Can be done over an existing parse tree or unparsed text
"""
import torch
import torch.nn as nn
from stanza.models.constituency.trainer import Trainer
class TreeEmbedding(nn.Module):
    """
    Embed trees using the internal states of a constituency parser

    The parser's analyze_trees is run on the inputs, and the embedding is
    built from its word, transition, and constituent states.  If node_attn
    is set, the nodes of each tree are additionally attended over.
    """
    def __init__(self, constituency_parser, args):
        """
        constituency_parser: the parser model whose states are embedded
        args: dict with the flags all_words, backprop, node_attn, top_layer
        """
        super().__init__()

        self.config = {
            "all_words": args["all_words"],
            "backprop": args["backprop"],
            #"batch_norm": args["batch_norm"],
            "node_attn": args["node_attn"],
            "top_layer": args["top_layer"],
        }

        self.constituency_parser = constituency_parser

        # word_lstm: hidden_size * num_tree_lstm_layers * 2 (start & end)
        # transition_stack: transition_hidden_size
        # constituent_stack: hidden_size
        self.hidden_size = self.constituency_parser.hidden_size + self.constituency_parser.transition_hidden_size
        if self.config["all_words"]:
            # each word is used on its own rather than the start & end pair
            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers
        else:
            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers * 2

        if self.config["node_attn"]:
            self.query = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)
            self.key = nn.Linear(self.hidden_size, self.constituency_parser.hidden_size)
            self.value = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)

            # TODO: cat transition and constituent hx as well?
            self.output_size = self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers
        else:
            self.output_size = self.hidden_size

        # TODO: maybe have batch_norm, maybe use Identity
        #if self.config["batch_norm"]:
        #    self.input_norm = nn.BatchNorm1d(self.output_size)

    def embed_trees(self, inputs):
        """
        Run the parser over inputs and build the embedding from its states

        If backprop is False, the parser runs under torch.no_grad()

        Returns, without node_attn, the concatenated states: a list of one
        tensor per tree if all_words, otherwise a single tensor with a
        length 1 second dimension.  With node_attn, a list with one
        attention result per tree.
        """
        if self.config["backprop"]:
            states = self.constituency_parser.analyze_trees(inputs)
        else:
            with torch.no_grad():
                states = self.constituency_parser.analyze_trees(inputs)

        constituent_lists = [x.constituents for x in states]
        states = [x.state for x in states]

        word_begin_hx = torch.stack([state.word_queue[0].hx for state in states])
        word_end_hx = torch.stack([state.word_queue[state.word_position].hx for state in states])
        transition_hx = torch.stack([self.constituency_parser.transition_stack.output(state.transitions) for state in states])
        # go down one layer to get the embedding off the top of the S, not the ROOT
        # (in terms of the typical treebank)
        # the idea being that the ROOT has no additional information
        # and may even have 0s for the embedding in certain circumstances,
        # such as after learning UNTIED_MAX long enough
        if self.config["top_layer"]:
            constituent_hx = torch.stack([self.constituency_parser.constituent_stack.output(state.constituents) for state in states])
        else:
            constituent_hx = torch.cat([constituents[-2].tree_hx for constituents in constituent_lists], dim=0)

        if self.config["all_words"]:
            # need B matrices of N x hidden_size
            key = [torch.stack([torch.cat([word.hx, thx, chx]) for word in state.word_queue], dim=0)
                   for state, thx, chx in zip(states, transition_hx, constituent_hx)]
        else:
            key = torch.cat((word_begin_hx, word_end_hx, transition_hx, constituent_hx), dim=1).unsqueeze(1)

        if not self.config["node_attn"]:
            return key
        key = [self.key(x) for x in key]

        node_hx = [torch.stack([con.tree_hx for con in constituents], dim=0) for constituents in constituent_lists]
        queries = [self.query(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        values = [self.value(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        # TODO: could pad to make faster here
        attn = [torch.matmul(q, k.transpose(0, 1)) for q, k in zip(queries, key)]
        attn = [torch.softmax(x, dim=0) for x in attn]
        previous_layer = [torch.matmul(weight.transpose(0, 1), value) for weight, value in zip(attn, values)]
        return previous_layer

    def forward(self, inputs):
        # previously called the undefined global embed_trees, a NameError
        return self.embed_trees(inputs)

    def get_norms(self):
        """
        Return "name norm" lines for the parser's parameters and this module's own trainable parameters
        """
        lines = ["constituency_parser." + x for x in self.constituency_parser.get_norms()]
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith('constituency_parser.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        return lines

    def get_params(self, skip_modules=True):
        """
        Return a dict with this module's own weights ('model'), the
        parser's params ('constituency'), and the config ('config')
        """
        model_state = self.state_dict()
        # skip all of the constituency parameters here -
        # we will add them by calling the model's get_params()
        skipped = [k for k in model_state.keys() if k.startswith("constituency_parser.")]
        for k in skipped:
            del model_state[k]

        parser = self.constituency_parser.get_params(skip_modules)

        params = {
            'model': model_state,
            'constituency': parser,
            'config': self.config,
        }
        return params

    @staticmethod
    def from_parser_file(args, foundation_cache=None):
        """
        Build a TreeEmbedding around the parser saved in args['model']
        """
        # foundation_cache must be passed by keyword: the third positional
        # parameter of Trainer.load is load_optimizer
        constituency_parser = Trainer.load(args['model'], args, foundation_cache=foundation_cache)
        return TreeEmbedding(constituency_parser.model, args)

    @staticmethod
    def model_from_params(params, args, foundation_cache=None):
        """
        Rebuild a TreeEmbedding from the dict produced by get_params
        """
        # TODO: integrate with peft
        constituency_parser = Trainer.model_from_params(params['constituency'], None, args, foundation_cache)
        model = TreeEmbedding(constituency_parser, params['config'])
        model.load_state_dict(params['model'], strict=False)
        return model
================================================
FILE: stanza/models/constituency/tree_reader.py
================================================
"""
Reads ParseTree objects from a file, string, or similar input
Works by first splitting the input into (, ), and all other tokens,
then recursively processing those tokens into trees.
"""
from collections import deque
import logging
import os
import re
from stanza.models.constituency.parse_tree import Tree
from stanza.utils.get_tqdm import get_tqdm
# progress bar wrapper, used by the token iterators for inputs over 1000 lines
tqdm = get_tqdm()
# the bracket tokens which open and close a tree or node
OPEN_PAREN = "("
CLOSE_PAREN = ")"
logger = logging.getLogger('stanza.constituency')
# A few specific exception types to clarify parsing errors
# They store the line number where the error occurred
class UnclosedTreeError(ValueError):
    """
    Raised when the input ends inside a tree, such as (Foo

    line_num: the line on which the unfinished tree started
    """
    def __init__(self, line_num):
        message = "Found an unfinished tree (missing close brackets). Tree started on line %d" % line_num
        super().__init__(message)
        self.line_num = line_num
class ExtraCloseTreeError(ValueError):
    """
    Raised for a close bracket with no tree open, such as (Foo))

    line_num: the line where the extra bracket was found
    """
    def __init__(self, line_num):
        message = "Found a broken tree (extra close brackets). Tree started on line %d" % line_num
        super().__init__(message)
        self.line_num = line_num
class UnlabeledTreeError(ValueError):
    """
    Raised for a node with no label, such as the children in ((Foo) (Bar))

    The root itself never raises this: an unlabeled root is silently labeled ROOT
    """
    def __init__(self, line_num):
        message = "Found a tree with no label on a node! Line number %d" % line_num
        super().__init__(message)
        self.line_num = line_num
class MixedTreeError(ValueError):
    """
    Raised when a node has both text and bracketed children, such as (Foo bar (Baz))

    line_num: the line where the node was closed
    child_label: the text child
    children: the bracketed children
    """
    def __init__(self, line_num, child_label, children):
        self.line_num = line_num
        self.child_label = child_label
        self.children = children
        super().__init__("Found a tree with both text children and bracketed children! Line number {} Child label {} Children {}".format(line_num, child_label, children))
def normalize(text):
    """
    Replace the -LRB- / -RRB- escapes in a leaf with literal ( and )
    """
    for escaped, bracket in (("-LRB-", "("), ("-RRB-", ")")):
        text = text.replace(escaped, bracket)
    return text
def read_single_tree(token_iterator, broken_ok):
    """
    Build a tree from the tokens in the token_iterator

    Called just after the opening ( of a tree has been read.  Tokens are
    read up to and including the matching ).

    token_iterator: the TokenIterator to read from.  Its mark is set to the
      line of the first token read, so UnclosedTreeError can report where
      the tree started
    broken_ok: if True, malformed nodes are kept instead of raising:
      an inner node with no label becomes Tree(None, children), and a node
      with both text and bracketed children keeps the text as its last child

    Returns the tree.  If the outermost node has no label, it is returned
    as a node labeled ROOT.

    Raises:
      UnclosedTreeError: the tokens ran out before the tree was closed
      MixedTreeError: text and bracketed children in one node, and not broken_ok
      UnlabeledTreeError: an inner node has no label, and not broken_ok
    """
    # we were called here at a open paren, so start the stack of
    # children with one empty list already on it
    children_stack = deque()
    children_stack.append([])
    # text tokens (label and words) of each open node, parallel to children_stack
    text_stack = deque()
    text_stack.append([])

    token = next(token_iterator, None)
    token_iterator.set_mark()
    while token is not None:
        if token == OPEN_PAREN:
            # a new node is opened: start collecting its children and text
            children_stack.append([])
            text_stack.append([])
        elif token == CLOSE_PAREN:
            # the innermost open node is finished: build it from what was collected
            text = text_stack.pop()
            children = children_stack.pop()
            if text:
                pieces = " ".join(text).split()
                if len(pieces) == 1:
                    # only a label, so the children are the bracketed children
                    child = Tree(pieces[0], children)
                else:
                    # the assumption here is that a language such as VI may
                    # have spaces in the words, but it still represents
                    # just one child
                    label = pieces[0]
                    child_label = " ".join(pieces[1:])
                    if children:
                        if broken_ok:
                            child = Tree(label, children + [Tree(normalize(child_label))])
                        else:
                            raise MixedTreeError(token_iterator.line_num, child_label, children)
                    else:
                        child = Tree(label, Tree(normalize(child_label)))
                if not children_stack:
                    # that was the outermost node
                    return child
            else:
                if not children_stack:
                    # the outermost node has no label
                    return Tree("ROOT", children)
                elif broken_ok:
                    child = Tree(None, children)
                else:
                    raise UnlabeledTreeError(token_iterator.line_num)
            children_stack[-1].append(child)
        else:
            text_stack[-1].append(token)
        token = next(token_iterator, None)
    raise UnclosedTreeError(token_iterator.get_mark())
# splits a line around each bracket; the capturing group keeps the brackets as their own pieces
LINE_SPLIT_RE = re.compile(r"(\(|\))")
class TokenIterator:
    """
    A specific iterator for reading trees from a tree file

    The idea is that this will keep track of which line
    we are processing, so that an error can be logged
    from the correct line

    Subclasses supply the text by setting self.line_iterator.
    line_num is the 0-based index of the line currently being
    tokenized (-1 before any line has been read).
    """
    def __init__(self):
        self.token_iterator = iter([])
        # replaced by subclasses; an empty default means "no lines"
        # instead of an AttributeError in __next__
        self.line_iterator = iter([])
        self.line_num = -1
        self.mark = None

    def set_mark(self):
        """
        The mark is used for determining where the start of a tree occurs for an error
        """
        self.mark = self.line_num

    def get_mark(self):
        """
        Return the line number saved by set_mark

        Raises ValueError if set_mark was never called
        """
        if self.mark is None:
            raise ValueError("No mark set!")
        return self.mark

    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the next token, reading and splitting further lines as needed

        Blank lines are skipped.  Raises StopIteration when the lines run out.
        """
        token = next(self.token_iterator, None)
        while token is None:
            self.line_num += 1
            # default of None so that running out of lines reaches the
            # explicit StopIteration below (previously that check was dead code)
            line = next(self.line_iterator, None)
            if line is None:
                raise StopIteration

            line = line.strip()
            if not line:
                continue

            pieces = [piece.strip() for piece in LINE_SPLIT_RE.split(line)]
            self.token_iterator = iter([piece for piece in pieces if piece])
            token = next(self.token_iterator, None)

        return token
class TextTokenIterator(TokenIterator):
    """
    Tokenize trees given as one string, one line at a time

    A progress bar is shown for texts of more than 1000 lines, unless use_tqdm is False
    """
    def __init__(self, text, use_tqdm=True):
        super().__init__()

        self.lines = text.split("\n")
        self.num_lines = len(self.lines)
        show_progress = use_tqdm and self.num_lines > 1000
        self.line_iterator = iter(tqdm(self.lines) if show_progress else self.lines)
class FileTokenIterator(TokenIterator):
    """
    Tokenize the trees in a file

    The file is opened by __enter__ and closed by __exit__, so use it as:

        with FileTokenIterator(filename) as token_iterator:
            ...

    The file is read as UTF-8 regardless of the platform's default encoding.
    A progress bar is shown for files of more than 1000 lines.
    """
    def __init__(self, filename):
        super().__init__()
        self.filename = filename
        self.file_obj = None

    def __enter__(self):
        # TODO: use the file_size instead of counting the lines
        # file_size = Path(self.filename).stat().st_size
        # explicit encoding: the locale default (eg cp1252 on Windows)
        # would mis-decode treebanks in many languages
        with open(self.filename, encoding="utf-8") as fin:
            num_lines = sum(1 for _ in fin)

        self.file_obj = open(self.filename, encoding="utf-8")
        if num_lines > 1000:
            self.line_iterator = iter(tqdm(self.file_obj, total=num_lines))
        else:
            self.line_iterator = iter(self.file_obj)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self.file_obj:
            self.file_obj.close()
            self.file_obj = None
def read_token_iterator(token_iterator, broken_ok, tree_callback):
    """
    Read every tree from token_iterator

    If tree_callback is given, each tree is replaced by tree_callback(tree),
    and trees for which it returns None are dropped.

    Raises ExtraCloseTreeError for a ) between trees, and ValueError for
    other text between trees
    """
    trees = []
    token = next(token_iterator, None)
    while token:
        if token == CLOSE_PAREN:
            raise ExtraCloseTreeError(token_iterator.line_num)
        if token != OPEN_PAREN:
            raise ValueError("Tree document had text between trees! Line number %d" % token_iterator.line_num)

        tree = read_single_tree(token_iterator, broken_ok=broken_ok)
        if tree is None:
            raise ValueError("Tree reader somehow created a None tree! Line number %d" % token_iterator.line_num)
        if tree_callback is not None:
            tree = tree_callback(tree)
        if tree is not None:
            trees.append(tree)
        token = next(token_iterator, None)
    return trees
def read_trees(text, broken_ok=False, tree_callback=None, use_tqdm=True):
    """
    Read every tree contained in the string text.

    broken_ok and tree_callback are passed through to read_token_iterator.
    use_tqdm allows a progress bar for long (more than 1000 line) texts.

    TODO: some of the error cases we hit can be recovered from
    """
    tokens = TextTokenIterator(text, use_tqdm)
    return read_token_iterator(tokens, broken_ok=broken_ok, tree_callback=tree_callback)
def read_tree_file(filename, broken_ok=False, tree_callback=None):
    """
    Read every tree in the file filename.

    The file stays open only while the trees are being read.
    """
    with FileTokenIterator(filename) as tokens:
        return read_token_iterator(tokens, broken_ok=broken_ok, tree_callback=tree_callback)
def read_directory(dirname, broken_ok=False, tree_callback=None):
    """
    Read every tree from every entry of the directory dirname.

    Entries are processed in sorted filename order, and the trees are
    concatenated into a single list.
    """
    trees = []
    for entry in sorted(os.listdir(dirname)):
        trees += read_tree_file(os.path.join(dirname, entry), broken_ok, tree_callback)
    return trees
def read_treebank(filename, tree_callback=None):
    """
    Read a treebank file and simplify its trees for training a parser.

    Each tree has prune_none() and then simplify_labels() applied.
    Raises ValueError if any resulting tree has more than one child at
    the root.
    """
    logger.info("Reading trees from %s", filename)
    raw_trees = read_tree_file(filename, tree_callback=tree_callback)
    trees = [tree.prune_none().simplify_labels() for tree in raw_trees]

    illegal_trees = [tree for tree in trees if len(tree.children) > 1]
    if illegal_trees:
        raise ValueError("Found {} tree(s) which had non-unary transitions at the ROOT. First illegal tree: {:P}".format(len(illegal_trees), illegal_trees[0]))
    return trees
def main():
    """
    Parse one hard-coded example tree and print the resulting list.
    """
    sample = "( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    print(read_trees(sample))


if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/constituency/tree_stack.py
================================================
"""
A utility class for keeping track of intermediate parse states
"""
from collections import namedtuple
class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])):
    """
    A stack which can branch in several directions, as long as you
    keep track of the branching heads

    An example usage is when K constituents are removed at once
    to create a new constituent, and then the LSTM which tracks the
    values of the constituents is updated starting from the Kth
    output of the LSTM with the new value.

    We don't simply keep track of a single stack object using a deque
    because versions of the parser which use a beam will want to be
    able to branch in different directions from the same base stack

    Another possible usage is if an oracle is used for training
    in a manner where some fraction of steps are non-gold steps,
    but we also want to take a gold step from the same state.
    Eg, parser gets to state X, wants to make incorrect transition T
    instead of gold transition G, and so we continue training both
    X+G and X+T.  If we only represent the state X with standard
    python stacks, it would not be possible to track both of these
    states at the same time without copying the entire thing.

    The value can be a transition, a word, or a partially built constituent

    Implemented as a namedtuple to make it a bit more efficient.
    __slots__ is empty so that instances do not also carry a
    per-instance __dict__, which would undo the namedtuple savings.
    """
    __slots__ = ()

    def pop(self):
        """Return the node below this one (None below the bottom node)"""
        return self.parent

    def push(self, value):
        """Return a new node holding value on top of this one; this node is unchanged"""
        return TreeStack(value, self, self.length+1)

    def __iter__(self):
        """Yield the values from the top of the stack down to, and including, the bottom node"""
        stack = self
        while stack.parent is not None:
            yield stack.value
            stack = stack.parent
        yield stack.value

    def __reversed__(self):
        """Yield the values from the bottom node up to the top"""
        items = list(iter(self))
        for item in reversed(items):
            yield item

    def __str__(self):
        return "TreeStack(%s)" % ", ".join([str(x) for x in self])

    def __len__(self):
        # the stored length field: push() adds 1 to the parent's length.
        # Iteration also yields the bottom node's value, so this is not
        # necessarily the number of values iterated
        return self.length
================================================
FILE: stanza/models/constituency/utils.py
================================================
"""
Collects a few of the conparser utility methods which don't belong elsewhere
"""
from collections import Counter
import logging
import warnings
import torch.nn as nn
from torch import optim
from stanza.models.common.doc import TEXT, Document
from stanza.models.common.utils import get_optimizer
from stanza.models.constituency.base_model import SimpleModel
from stanza.models.constituency.parse_transitions import TransitionScheme
from stanza.models.constituency.parse_tree import Tree
from stanza.utils.get_tqdm import get_tqdm
tqdm = get_tqdm()

# default learning rate for each supported optimizer
DEFAULT_LEARNING_RATES = {
    "adamw": 0.0002,
    "adadelta": 1.0,
    "sgd": 0.001,
    "adabelief": 0.00005,
    "madgrad": 0.0000007,
    "mirror_madgrad": 0.00005,
}

# default epsilon for the optimizers which use one
DEFAULT_LEARNING_EPS = {
    "adabelief": 1e-12,
    "adadelta": 1e-6,
    "adamw": 1e-8,
}

# default rho (an AdaDelta parameter)
DEFAULT_LEARNING_RHO = 0.9

# default momentum for the optimizers which use momentum
DEFAULT_MOMENTUM = {
    "madgrad": 0.9,
    "mirror_madgrad": 0.9,
    "sgd": 0.9,
}

tlogger = logging.getLogger('stanza.constituency.trainer')

# madgrad experiment for weight decay
# with learning_rate set to 0.0000007 and momentum 0.9
# on en_wsj, with a baseline model trained on adadelta for 200,
# then madgrad used to further improve that model
#   0.00000002.out: 0.9590347746438835
#   0.00000005.out: 0.9591378819960182
#   0.0000001.out:  0.9595450596319405
#   0.0000002.out:  0.9594603134479271
#   0.0000005.out:  0.9591317672706594
#   0.000001.out:   0.9592548741021389
#   0.000002.out:   0.9598395477013945
#   0.000003.out:   0.9594974271553495
#   0.000004.out:   0.9596665982603754
#   0.000005.out:   0.9591620720706487
DEFAULT_WEIGHT_DECAY = {
    "adamw": 0.05,
    "adadelta": 0.02,
    "sgd": 0.01,
    "adabelief": 1.2e-6,
    "madgrad": 2e-6,
    "mirror_madgrad": 2e-6,
}
def retag_tags(doc, pipelines, xpos):
    """
    Tag doc with each pipeline and combine the results by majority vote.

    doc: anything which can be fed to the pipeline(s)
    pipelines: a list of one or more retag pipelines, applied in order;
        each pipeline is called on the output of the previous one
    xpos: if True, read each word's xpos, otherwise its upos

    Returns one list of tags per sentence.  For each word, the tag chosen
    by the most pipelines wins (via Counter.most_common).
    """
    attribute = 'xpos' if xpos else 'upos'

    # one entry per pipeline: a list of sentences, each a list of tags
    per_pipeline = []
    for pipeline in pipelines:
        doc = pipeline(doc)
        per_pipeline.append([[getattr(word, attribute) for word in sentence.words]
                             for sentence in doc.sentences])

    voted = []
    # zip(*per_pipeline) groups the N predictions of each sentence,
    # zip(*sentence_votes) groups the N predictions of each word
    for sentence_votes in zip(*per_pipeline):
        voted.append([Counter(word_votes).most_common(1)[0][0]
                      for word_votes in zip(*sentence_votes)])
    return voted
def retag_trees(trees, pipelines, xpos=True):
    """
    Retag all of the trees using the given pipeline(s)

    trees: a list of trees.  The words are the labels of the first child
        of each preterminal
    pipelines: one or more retag pipelines; see retag_tags
    xpos: if True, use the xpos tags, otherwise the upos tags

    Trees are processed in chunks of 1000.  Returns a list of new trees;
    an empty input is returned unchanged.

    Raises ValueError if a tree cannot be turned into a sentence or its
    tags cannot be replaced, RuntimeError if a pipeline produced a None
    tag, and AssertionError if the number of retagged trees does not
    match the input.  Tree numbers in error messages are indices into
    the full trees list.
    """
    if len(trees) == 0:
        return trees

    new_trees = []
    chunk_size = 1000
    with tqdm(total=len(trees)) as pbar:
        for chunk_start in range(0, len(trees), chunk_size):
            chunk_end = min(chunk_start + chunk_size, len(trees))
            chunk = trees[chunk_start:chunk_end]
            sentences = []
            try:
                for idx, tree in enumerate(chunk):
                    tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]
                    sentences.append(tokens)
            except ValueError as e:
                raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e

            doc = Document(sentences)
            tag_lists = retag_tags(doc, pipelines, xpos)

            for chunk_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):
                # report the position in the whole treebank, not within this chunk
                tree_idx = chunk_start + chunk_idx
                try:
                    if any(tag is None for tag in tags):
                        raise RuntimeError("Tagged tree #{} with a None tag!\n{}\n{}".format(tree_idx, tree, tags))
                    new_tree = tree.replace_tags(tags)
                    new_trees.append(new_tree)
                    pbar.update(1)
                except ValueError as e:
                    raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e

    # if a pipeline dropped sentences, zip() silently truncated a chunk
    if len(new_trees) != len(trees):
        raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees)))
    return new_trees
def build_optimizer(args, model, build_simple_adadelta=False):
    """
    Build the optimizer for the conparser from its args.

    With build_simple_adadelta, an AdaDelta optimizer is built from the
    stage1_* settings plus this module's AdaDelta defaults, regardless of
    args['optim'].  The caller uses this for the first part of "multistage"
    training (epochs_trained < epochs // 2).  Otherwise, the optimizer
    named by args['optim'] is built from the regular learning_* settings.

    The transformer learning rate is 0.0 unless the matching finetune flag
    (stage1_bert_finetune or bert_finetune) is set.  The transformer weight
    decay is the chosen weight decay scaled by args['bert_weight_decay'].
    """
    bert_weight_decay = args['bert_weight_decay']

    if build_simple_adadelta:
        optim_type = 'adadelta'
        bert_finetune = args.get('stage1_bert_finetune', False)
        bert_learning_rate = args['stage1_bert_learning_rate'] if bert_finetune else 0.0
        # AdaDelta ignores betas and momentum, so these are just placeholders
        learning_beta2 = 0.999
        learning_eps = DEFAULT_LEARNING_EPS['adadelta']
        learning_rate = args['stage1_learning_rate']
        learning_rho = DEFAULT_LEARNING_RHO
        momentum = None
        weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']
    else:
        optim_type = args['optim'].lower()
        bert_finetune = args.get('bert_finetune', False)
        bert_learning_rate = args['bert_learning_rate'] if bert_finetune else 0.0
        learning_beta2 = args['learning_beta2']
        learning_eps = args['learning_eps']
        learning_rate = args['learning_rate']
        learning_rho = args['learning_rho']
        momentum = args['learning_momentum']
        weight_decay = args['learning_weight_decay']

    # TODO: allow rho as an arg for AdaDelta
    # (learning_rho is computed above but not yet passed to get_optimizer)
    return get_optimizer(
        name=optim_type,
        model=model,
        lr=learning_rate,
        betas=(0.9, learning_beta2),
        eps=learning_eps,
        momentum=momentum,
        weight_decay=weight_decay,
        bert_learning_rate=bert_learning_rate,
        bert_weight_decay=weight_decay * bert_weight_decay,
        is_peft=args.get('use_peft', False),
        bert_finetune_layers=args['bert_finetune_layers'],
        opt_logger=tlogger,
    )
def build_scheduler(args, optimizer, first_optimizer=False):
    """
    Build the learning rate scheduler for the conparser from its args.

    Uses a ReduceLROnPlateau in 'max' mode, so the monitored score is
    expected to increase.  factor, patience and cooldown come from
    args['learning_rate_factor'], args['learning_rate_patience'] and
    args['learning_rate_cooldown'].  The floor on the learning rate is
    args['stage1_learning_rate_min_lr'] when first_optimizer is set,
    otherwise args['learning_rate_min_lr'].

    A linear learning rate warmup (LambdaLR) was used in the past, but it
    didn't work very well, whereas ReduceLROnPlateau does quite well.
    """
    min_lr_key = 'stage1_learning_rate_min_lr' if first_optimizer else 'learning_rate_min_lr'
    return optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=args['learning_rate_factor'],
        patience=args['learning_rate_patience'],
        cooldown=args['learning_rate_cooldown'],
        min_lr=args[min_lr_key],
    )
def initialize_linear(linear, nonlinearity, bias):
    """
    Initialize a Linear layer which feeds a relu or leaky_relu.

    For 'relu' and 'leaky_relu', the weight gets Kaiming (He) normal
    initialization, and the bias is drawn uniformly from
    [0, 1 / sqrt(2 * bias)].  A positive bias hopefully prevents dead
    neurons.  For any other nonlinearity, the layer is left as is.

    linear: the layer to initialize (uses .weight and .bias)
    nonlinearity: name of the nonlinearity applied after the layer
    bias: a size which scales the bias range
        (NOTE(review): presumably the layer's input size - confirm against callers)
    """
    if nonlinearity in ('relu', 'leaky_relu'):
        nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity)
        # a layer built with bias=False has linear.bias == None
        if linear.bias is not None:
            nn.init.uniform_(linear.bias, 0, 1 / (bias * 2) ** 0.5)
def add_predict_output_args(parser):
    """
    Add the options which control where and how predictions are written.

    Adds --predict_dir, --predict_file, --predict_format and
    --predict_output_gold_tags to parser.
    """
    add = parser.add_argument
    add('--predict_dir',
        type=str,
        default=".",
        help='Where to write the predictions during --mode predict.  Pred and orig files will be written - the orig file will be retagged if that is requested.  Writing the orig file is useful for removing None and retagging')
    add('--predict_file',
        type=str,
        default=None,
        help='Base name for writing predictions')
    add('--predict_format',
        type=str,
        default="{:_O}",
        help='Format to use when writing predictions')
    add('--predict_output_gold_tags',
        default=False,
        action='store_true',
        help='Output gold tags as part of the evaluation - useful for putting the trees through EvalB')
def postprocess_predict_output_args(args):
    """
    Expand a bare tree format spec in args['predict_format'] into a format string.

    Values of at most 2 characters (eg "P" or "_O"), and values of at most
    3 characters ending in "Vi", are wrapped as "{:" + value + "}".
    Values which already contain "{" are complete format strings and are
    left alone.  Wrapping "{}" would produce the unusable "{:{}}".

    args: dict of parsed arguments, updated in place
    """
    predict_format = args['predict_format']
    if "{" in predict_format:
        return
    if len(predict_format) <= 2 or (len(predict_format) <= 3 and predict_format.endswith("Vi")):
        args['predict_format'] = "{:" + predict_format + "}"
def get_open_nodes(trees, transition_scheme):
    """
    List the open nodes needed to parse the given dataset.

    For TOP_DOWN_COMPOUND and IN_ORDER_COMPOUND, this is whatever
    Tree.get_compound_constituents returns (the in-order variant passes
    separate_root=True).  For every other scheme, each unique constituent
    label becomes a one-element tuple.
    """
    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        return Tree.get_compound_constituents(trees)
    if transition_scheme is TransitionScheme.IN_ORDER_COMPOUND:
        return Tree.get_compound_constituents(trees, separate_root=True)
    labels = Tree.get_unique_constituent_labels(trees)
    return [(label,) for label in labels]
def verify_transitions(trees, sequences, transition_scheme, unary_limit, reverse, name, root_labels):
    """
    Check that each transition sequence rebuilds its tree.

    Every transition is checked with is_legal and applied to a SimpleModel
    state started from the gold tree.  The top constituent is then compared
    to the original tree, after reversing it if reverse is set.  A progress
    bar is shown when the trainer logger is at INFO level or lower.

    Raises RuntimeError on an illegal transition or a mismatched result.
    """
    model = SimpleModel(transition_scheme, unary_limit, reverse, root_labels)
    tlogger.info("Verifying the transition sequences for %d trees", len(trees))

    pairs = zip(trees, sequences)
    if tlogger.getEffectiveLevel() <= logging.INFO:
        pairs = tqdm(pairs, total=len(trees))

    for tree_idx, (tree, sequence) in enumerate(pairs):
        # TODO: make the SimpleModel have a parse operation?
        state = model.initial_state_from_gold_trees([tree])[0]
        for trans_idx, transition in enumerate(sequence):
            if not transition.is_legal(state, model):
                raise RuntimeError("Tree {} of {} failed: transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(tree_idx, name, trans_idx, transition, tree, sequence))
            state = transition.apply(state, model)

        rebuilt = model.get_top_constituent(state.constituents)
        if reverse:
            rebuilt = rebuilt.reverse()
        if tree != rebuilt:
            raise RuntimeError("Tree {} of {} failed: transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree_idx, name, tree, sequence, rebuilt))
def check_constituents(train_constituents, trees, treebank_name, fail=True):
    """
    Check that all the constituents in the other dataset are known in the train set

    train_constituents: constituent labels known from the train set
    trees: the trees of the dataset being checked
    treebank_name: name of that dataset, used in the message
    fail: if True, raise RuntimeError for the first unknown label;
        otherwise issue a warning for each unknown label

    The message gives how many trees use the unknown label and prints
    the first such tree (its index counted from 1).
    """
    # labels of each individual tree; only computed once an unknown label
    # turns up, and then shared by every unknown label
    per_tree_labels = None
    for con in Tree.get_unique_constituent_labels(trees):
        if con in train_constituents:
            continue
        if per_tree_labels is None:
            per_tree_labels = [Tree.get_unique_constituent_labels(tree) for tree in trees]

        first_error = None
        num_errors = 0
        for tree_idx, tree_labels in enumerate(per_tree_labels):
            if con in tree_labels:
                num_errors += 1
                if first_error is None:
                    first_error = tree_idx

        error = "Found constituent label {} in the {} set which don't exist in the train set.  This constituent label occurred in {} trees, with the first tree index at {} counting from 1\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\n{:P}".format(con, treebank_name, num_errors, (first_error+1), trees[first_error])
        if fail:
            raise RuntimeError(error)
        warnings.warn(error)
def check_root_labels(root_labels, other_trees, treebank_name):
    """
    Check that every root label in other_trees is one of the train set's root labels.

    Raises RuntimeError for the first root label that is not in root_labels.
    """
    for root_state in Tree.get_root_labels(other_trees):
        if root_state in root_labels:
            continue
        raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(root_state, treebank_name))
def remove_duplicate_trees(trees, treebank_name):
    """
    Return trees without duplicates, keeping the first copy of each.

    Two trees are duplicates if their default string formats are equal.
    The number removed is logged, using treebank_name to identify the
    dataset.
    """
    seen = set()
    unique_trees = []
    for tree in trees:
        key = "{}".format(tree)
        if key not in seen:
            seen.add(key)
            unique_trees.append(tree)

    num_removed = len(trees) - len(unique_trees)
    if num_removed > 0:
        tlogger.info("Filtered %d duplicates from %s dataset", num_removed, treebank_name)
    return unique_trees
def remove_singleton_trees(trees):
    """
    Drop trees which are just a root and a single word.

    A tree is kept if:
    - its root has more than one child, or
    - its only child has more than one child, or
    - its only child has exactly one child, which itself has children.
    The number of dropped trees is logged.

    TODO: remove these trees in the conversion instead of here
    """
    def has_structure(tree):
        if len(tree.children) > 1:
            return True
        if len(tree.children) != 1:
            return False
        child = tree.children[0]
        if len(child.children) > 1:
            return True
        return len(child.children) == 1 and len(child.children[0].children) >= 1

    kept = [tree for tree in trees if has_structure(tree)]
    num_removed = len(trees) - len(kept)
    if num_removed > 0:
        tlogger.info("Eliminated %d trees with missing structure", num_removed)
    return kept
================================================
FILE: stanza/models/constituency_parser.py
================================================
"""A command line interface to a shift reduce constituency parser.
This follows the work of
Recurrent neural network grammars by Dyer et al
In-Order Transition-based Constituent Parsing by Liu & Zhang
The general outline is:
Train a model by taking a list of trees, converting them to
transition sequences, and learning a model which can predict the
next transition given a current state
Then, at inference time, repeatedly predict the next transition until parsing is complete
The "transitions" are variations on shift/reduce as per an
intro-to-compilers class. The idea is that you can treat all of the
words in a sentence as a buffer of tokens, then either "shift" them to
represent a new constituent, or "reduce" one or more constituents to
form a new constituent.
In order to make the runtime a more competitive speed, effort is taken
to batch the transitions and apply multiple transitions at once. At
train time, batches are grouped together by length, and at inference
time, new trees are added to the batch as previous trees on the batch
finish their inference.
There are a few minor differences in the model:
- The word input is a bi-lstm, not a uni-lstm.
This gave a small increase in accuracy.
- The combination of several constituents into one constituent is done
via a single bi-lstm rather than two separate lstms. This increases
speed without a noticeable effect on accuracy.
- In fact, an even better (in terms of final model accuracy) method
is to combine the constituents with torch.max, believe it or not
See lstm_model.py for more details
- Initializing the embeddings with smaller values than pytorch default
For example, on a ja_alt dataset, scores went from 0.8980 to 0.8985
at 200 iterations averaged over 5 trials
- Training with AdaDelta first, then AdamW or madgrad later improves
results quite a bit. See --multistage
A couple of experiments which have been tried with little noticeable impact:
- Combining constituents using the method in the paper (only a trained
vector at the start instead of both ends) did not affect results
and is a little slower
- Using multiple layers of LSTM hidden state for the input to the final
classification layers didn't help
- Initializing Linear layers with He initialization and a positive bias
(to avoid dead connections) had no noticeable effect on accuracy
0.8396 on it_turin with the original initialization
0.8401 and 0.8427 on two runs with updated initialization
(so maybe a small improvement...)
- Initializing LSTM layers with different gates was slightly worse:
forget gates of 1.0
forget gates of 1.0, input gates of -1.0
- Replacing the LSTMs that make up the Transition and Constituent
LSTMs with Dynamic Skip LSTMs made no difference, but was slower
- Highway LSTMs also made no difference
- Putting labels on the shift transitions (the word or the tag shifted)
or putting labels on the close transitions didn't help
- Building larger constituents from the output of the constituent LSTM
instead of the children constituents hurts scores
For example, an experiment on ja_alt went from 0.8985 to 0.8964
when built that way
- The initial transition scheme implemented was TOP_DOWN. We tried
a compound unary option, since this worked so well in the CoreNLP
constituency parser. Unfortunately, this is far less effective
than IN_ORDER. Both specialized unary matrices and reusing the
n-ary constituency combination fell short. On the ja_alt dataset:
IN_ORDER, max combination method: 0.8985
TOP_DOWN_UNARY, specialized matrices: 0.8501
TOP_DOWN_UNARY, max combination method: 0.8508
- Adding multiple layers of MLP to combine inputs for words made
no difference in the scores
Tried both before the LSTM and after
A simple single layer tensor multiply after the LSTM works well.
Replacing that with a two layer MLP on the English PTB
with roberta-base causes a notable drop in scores
First experiment didn't use the fancy Linear weight init,
but adding that barely made a difference
260 training iterations on en_wsj dev, roberta-base
model as of bb983fd5e912f6706ad484bf819486971742c3d1
two layer MLP: 0.9409
two layer MLP, init weights: 0.9413
single layer: 0.9467
- There is code to rebuild models with a new structure in lstm_model.py
As part of this, we tried to randomly reinitialize the transitions
if the transition embedding had gone to 0, which often happens
This didn't help at all
- We tried something akin to attention with just the query vector
over the bert embeddings as a way to mix them, but that did not
improve scores.
Example, with a self.bert_layer_mix of size bert_dim x 1:
mixed_bert_embeddings = []
for feature in bert_embeddings:
weighted_feature = self.bert_layer_mix(feature.transpose(1, 2))
weighted_feature = torch.softmax(weighted_feature, dim=1)
weighted_feature = torch.matmul(feature, weighted_feature).squeeze(2)
mixed_bert_embeddings.append(weighted_feature)
bert_embeddings = mixed_bert_embeddings
It seems just finetuning the transformer is already enough
(in general, no need to mix layers at all when finetuning bert embeddings)
The code breakdown is as follows:
this file: main interface for training or evaluating models
constituency/trainer.py: contains the training & evaluation code
constituency/ensemble.py: evaluation code specifically for letting multiple models
vote on the correct next transition. a modest improvement.
constituency/evaluate_treebanks.py: specifically to evaluate multiple parsed treebanks
against a gold. in particular, reports whether the theoretical best from those
parsed treebanks is an improvement (eg, the k-best score as reported by CoreNLP)
constituency/parse_tree.py: a data structure for representing a parse tree and utility methods
constituency/tree_reader.py: a module which can read trees from a string or input file
constituency/tree_stack.py: a linked list which can branch in
different directions, which will be useful when implementing beam
search or a dynamic oracle
constituency/lstm_tree_stack.py: an LSTM over the elements of a TreeStack
constituency/transformer_tree_stack.py: attempts to run attention over the nodes
of a tree_stack. not as effective as the lstm_tree_stack in the initial experiments.
perhaps it could be refined to work better, though
constituency/parse_transitions.py: transitions and a State data structure to store them
constituency/transition_sequence.py: turns ParseTree objects into
the transition sequences needed to make them
constituency/base_model.py: operates on the transitions to turn them into constituents,
eventually forming one final parse tree composed of all of the constituents
constituency/lstm_model.py: adds LSTM features to the constituents to predict what the
correct transition to make is, allowing for predictions on previously unseen text
constituency/retagging.py: a couple utility methods specifically for retagging
constituency/utils.py: a couple utility methods
constituency/dynamic_oracle.py: a dynamic oracle which currently
only operates for the inorder transition sequence.
uses deterministic rules to redo the correct action sequence when
the parser makes an error.
constituency/partitioned_transformer.py: implementation of a transformer for self-attention.
presumably this should help, but we have yet to find a model structure where
this makes the scores go up.
constituency/label_attention.py: an even fancier form of transformer based on labeled attention:
https://arxiv.org/abs/1911.03875
constituency/positional_encoding.py: so far, just the sinusoidal is here.
a trained encoding is in partitioned_transformer.py.
this should probably be refactored to common, especially if used elsewhere.
stanza/pipeline/constituency_processor.py: interface between this model and the Pipeline
stanza/utils/datasets/constituency: various scripts and tools for processing constituency datasets
Some alternate optimizer methods:
adabelief: https://github.com/juntang-zhuang/Adabelief-Optimizer
madgrad: https://github.com/facebookresearch/madgrad
"""
import argparse
import logging
import os
import random
import re
import torch
import stanza
from stanza.models.common import constant
from stanza.models.common import utils
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
from stanza.models.common.utils import NONLINEARITY
from stanza.models.constituency import parser_training
from stanza.models.constituency import retagging
from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory
from stanza.models.constituency.parse_transitions import TransitionScheme
from stanza.models.constituency.text_processing import load_model_parse_text
from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, add_predict_output_args, postprocess_predict_output_args
from stanza.resources.common import DEFAULT_MODEL_DIR
# general stanza logger
logger = logging.getLogger('stanza')
# training progress logger; constituency/utils.py uses this same logger name
tlogger = logging.getLogger('stanza.constituency.trainer')
def build_argparse():
"""
Adds the arguments for building the con parser
For the most part, defaults are set to cross-validated values, at least for WSJ
"""
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/constituency', help='Directory of constituency data.')
parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')
parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
# BERT helps a lot and actually doesn't slow things down too much
# for VI, for example, use vinai/phobert-base
parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer")
parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')
# BERT finetuning (or any transformer finetuning)
# also helps quite a lot.
# Experimentally, finetuning all of the layers is the most effective
# On the id_icon dataset with the indolem transformer
# In this experiment, we trained for 150 iterations with AdaDelta,
# with the learning rate 0.01,
# then trained for another 150 with madgrad and no finetuning
# 1 layer 0.880753 (152)
# 2 layers 0.880453 (174)
# 3 layers 0.881774 (163)
# 4 layers 0.886915 (194)
# 5 layers 0.892064 (299)
# 6 layers 0.891825 (224)
# 7 layers 0.894373 (173)
# 8 layers 0.894505 (233)
# 9 layers 0.896676 (269)
# 10 layers 0.897525 (269)
# 11 layers 0.897348 (211)
# 12 layers 0.898729 (270)
# everything 0.898855 (252)
# so the trend is clear that more finetuning is better
#
# We found that finetuning works very well on the AdaDelta portion
# of a multistage training, but less well on a madgrad second
# stage. The issue was that we literally could not set the
# learning rate low enough because madgrad used epsilon in the LR:
# https://github.com/facebookresearch/madgrad/issues/16
#
# Possible values of the AdaDelta learning rate on the id_icon dataset
# In this experiment, we finetuned the entire transformer 150
# iterations on AdaDelta, then trained with madgrad for another
# 150 with no finetuning
# 0.0005: 0.89122 (155)
# 0.001: 0.889807 (241)
# 0.002: 0.894874 (202)
# 0.005: 0.896327 (270)
# 0.006: 0.898989 (246)
# 0.007: 0.896712 (167)
# 0.008: 0.900136 (237)
# 0.009: 0.898597 (169)
# 0.01: 0.898665 (251)
# 0.012: 0.89661 (274)
# 0.014: 0.899149 (283)
# 0.016: 0.896314 (230)
# 0.018: 0.897753 (257)
# 0.02: 0.893665 (256)
# 0.05: 0.849274 (159)
# 0.1: 0.850633 (183)
# 0.2: 0.847332 (176)
#
# The peak is somewhere around 0.008 to 0.014, with the further
# observation that at the 150 iteration mark, 0.009 was winning:
# 0.007: 0.894589 (33)
# 0.008: 0.894777 (53)
# 0.009: 0.896466 (56)
# 0.01: 0.895557 (71)
# 0.012: 0.893479 (45)
# 0.014: 0.89468 (116)
# 0.016: 0.893053 (128)
# 0.018: 0.893086 (48)
#
# Another option is to train for a few iterations with no
# finetuning, then begin finetuning. However, that was not
# beneficial at all.
# Start iteration on id_icon, same setup as above:
# 1: 0.898855 (252)
# 5: 0.897885 (217)
# 10: 0.895367 (215)
# 25: 0.896781 (193)
# 50: 0.895216 (193)
# Using adamw instead of madgrad:
# 1: 0.900594 (226)
# 5: 0.898153 (267)
# 10: 0.898756 (271)
# 25: 0.896867 (256)
# 50: 0.895025 (220)
#
#
# With the observation that very low learning rate is currently
# not working for madgrad, we tried to parameter sweep LR for
# AdamW, and got the following, using a first stage LR of 0.009:
# 0.0: 0.899706 (290)
# 0.00005: 0.899631 (176)
# 0.0001: 0.899851 (233)
# 0.0002: 0.898601 (207)
# 0.0003: 0.899258 (252)
# 0.0004: 0.90033 (187)
# 0.0005: 0.899091 (183)
# 0.001: 0.899791 (268)
# 0.002: 0.899453 (196)
# 0.003: 0.897029 (173)
# 0.004: 0.899566 (290)
# 0.005: 0.899285 (289)
# 0.01: 0.898938 (233)
# 0.02: 0.898983 (248)
# 0.03: 0.898571 (247)
# 0.04: 0.898466 (180)
# 0.05: 0.897448 (214)
# It should be noted that in the 0.0001 range, the epoch to epoch
# change of the Bert weights was almost negligible. Weights would
# change in the 5th or 6th decimal place, if at all.
#
# The conclusion of all these experiments is that, if we are using
# bert_finetuning, the best approach is probably a stage1 learning
# rate of 0.009 or so and a second stage optimizer of adamw with
# no LR or a very low LR. This behavior is what happens with the
# --stage1_bert_finetune flag
parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer')
parser.add_argument('--bert_finetune_begin_epoch', default=None, type=int, help='Which epoch to start finetuning the transformer')
parser.add_argument('--bert_finetune_end_epoch', default=None, type=int, help='Which epoch to stop finetuning the transformer')
parser.add_argument('--bert_learning_rate', default=0.009, type=float, help='Scale the learning rate for transformer finetuning by this much')
parser.add_argument('--stage1_bert_learning_rate', default=None, type=float, help="Scale the learning rate for transformer finetuning by this much only during an AdaDelta warmup")
parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')
parser.add_argument('--stage1_bert_finetune', default=None, action='store_true', help="Finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune")
parser.add_argument('--no_stage1_bert_finetune', dest='stage1_bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune")
add_peft_args(parser)
parser.add_argument('--tag_embedding_dim', type=int, default=20, help="Embedding size for a tag. 0 turns off the feature")
# Smaller values also seem to work
# For example, after 700 iterations:
# 32: 0.9174
# 50: 0.9183
# 72: 0.9176
# 100: 0.9185
# not a huge difference regardless
# (these numbers were without retagging)
parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding")
parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
parser.add_argument('--no_train_remove_duplicates', default=True, action='store_false', dest="train_remove_duplicates", help="Do/don't remove duplicates from the training file. Could be useful for intentionally reweighting some trees")
parser.add_argument('--silver_file', type=str, default=None, help='Secondary training file.')
parser.add_argument('--silver_remove_duplicates', default=False, action='store_true', help="Do/don't remove duplicates from the silver training file. Could be useful for intentionally reweighting some trees")
parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
# TODO: possibly refactor --tokenized_file / --tokenized_dir from here & ensemble
parser.add_argument('--xml_tree_file', type=str, default=None, help='Input file of VLSP formatted trees for parsing with parse_text.')
parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
parser.add_argument('--tokenized_dir', type=str, default=None, help='Input directory of tokenized text for parsing with parse_text.')
parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])
parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one')
add_predict_output_args(parser)
parser.add_argument('--lang', type=str, help='Language')
parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
parser.add_argument('--transition_embedding_dim', type=int, default=20, help="Embedding size for a transition")
parser.add_argument('--transition_hidden_size', type=int, default=20, help="Embedding size for transition stack")
parser.add_argument('--transition_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],
help='How to track transitions over a parse. {}'.format(", ".join(x.name for x in StackHistory)))
parser.add_argument('--transition_heads', default=4, type=int, help="How many heads to use in MHA *if* the transition_stack is Attention")
parser.add_argument('--constituent_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],
help='How to track transitions over a parse. {}'.format(", ".join(x.name for x in StackHistory)))
parser.add_argument('--constituent_heads', default=8, type=int, help="How many heads to use in MHA *if* the transition_stack is Attention")
# larger was more effective, up to a point
# substantially smaller, such as 128,
# is fine if bert & charlm are not available
parser.add_argument('--hidden_size', type=int, default=512, help="Size of the output layers for constituency stack and word queue")
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--epoch_size', type=int, default=5000, help="Runs this many trees in an 'epoch' instead of going through the training dataset exactly once. Set to 0 to do the whole training set")
parser.add_argument('--silver_epoch_size', type=int, default=None, help="Runs this many trees in a silver 'epoch'. If not set, will match --epoch_size")
# AdaDelta warmup for the conparser. Motivation: AdaDelta results in
# higher scores overall, but learns 0s for the weights of the pattn and
# lattn layers. AdamW learns weights for pattn, and the models are more
# accurate than models trained without pattn using AdamW, but the models
# are lower scores overall than the AdaDelta models.
#
# This improves that by first running AdaDelta, then switching.
#
# Now, if --multistage is set, run AdaDelta for half the epochs with no
# pattn or lattn. Then start the specified optimizer for the rest of
# the time with the full model. If pattn and lattn are both present,
# the model is 1/2 no attn, 1/4 pattn, 1/4 pattn and lattn
#
# Improvement on the WSJ dev set can be seen from 94.8 to 95.3
# when 4 layers of pattn are trained this way.
# More experiments to follow.
parser.add_argument('--multistage', default=True, action='store_true', help='1/2 epochs with adadelta no pattn or lattn, 1/4 with chosen optim and no lattn, 1/4 full model')
parser.add_argument('--no_multistage', dest='multistage', action='store_false', help="don't do the multistage learning")
# 1 seems to be the most effective, but we should cross-validate
parser.add_argument('--oracle_initial_epoch', type=int, default=1, help="Epoch where we start using the dynamic oracle to let the parser keep going with wrong decisions")
parser.add_argument('--oracle_frequency', type=float, default=0.8, help="How often to use the oracle vs how often to force the correct transition")
parser.add_argument('--oracle_forced_errors', type=float, default=0.001, help="Occasionally have the model randomly walk through the state space to try to learn how to recover")
parser.add_argument('--oracle_level', type=int, default=None, help='Restrict oracle transitions to this level or lower. 0 means off. None means use all oracle transitions.')
parser.add_argument('--additional_oracle_levels', type=str, default=None, help='Add some additional experimental oracle transitions. Basically for A/B testing transitions we expect to be bad.')
parser.add_argument('--deactivated_oracle_levels', type=str, default=None, help='Temporarily turn off a default oracle level. Basically for A/B testing transitions we expect to be bad.')
# 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ
# earlier version of the model (less accurate overall) had the following results with adadelta:
# 30: 0.9085
# 50: 0.9070
# 75: 0.9010
# 150: 0.8985
# as another data point, running a newer version with better constituency lstm behavior had:
# 30: 0.9111
# 50: 0.9094
# checking smaller batch sizes to see how this works, at 135 epochs, the values are
# 10: 0.8919
# 20: 0.9072
# 30: 0.9121
# obviously these experiments aren't the complete story, but it
# looks like 30 trees per batch is the best value for WSJ
# note that these numbers are for adadelta and might not apply
# to other optimizers
# eval batch should generally be faster the bigger the batch,
# up to a point, as it allows for more batching of the LSTM
# operations and the prediction step
parser.add_argument('--train_batch_size', type=int, default=30, help='How many trees to train before taking an optimizer step')
parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.')
parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_constituency.pt", help="File name to save the model")
parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing")
parser.add_argument('--save_each_start', type=int, default=None, help="When to start saving each model")
parser.add_argument('--save_each_frequency', type=int, default=1, help="How frequently to save each model")
parser.add_argument('--no_save_each_optimizer', dest='save_each_optimizer', default=True, action='store_false', help="Don't save the optimizer when saving 'each' model")
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--no_seed', action='store_const', const=None, dest='seed', help='Remove the random seed, resulting in a randomly chosen random seed')
parser.add_argument('--no_check_valid_states', default=True, action='store_false', dest='check_valid_states', help="Don't check the constituents or transitions in the dev set when starting a new parser. Warning: the parser will never guess unknown constituents")
parser.add_argument('--no_strict_check_constituents', default=True, action='store_false', dest='strict_check_constituents', help="Don't check the constituents between the train & dev set. May result in untrainable transitions")
utils.add_device_args(parser)
# Numbers are on a VLSP dataset, before adding attn or other improvements
# baseline is an 80.6 model that occurs when trained using adadelta, lr 1.0
#
# adabelief 0.1: fails horribly
# 0.02: converges very low scores
# 0.01: very slow learning
# 0.002: almost decent
# 0.001: close, but about 1 f1 low on IT
# 0.0005: 79.71
# 0.0002: 80.11
# 0.0001: 79.85
# 0.00005: 80.40
# 0.00002: 80.02
# 0.00001: 78.95
#
# madgrad 0.005: fails horribly
# 0.001: low scores
# 0.0005: still somewhat low
# 0.0002: close, but about 1 f1 low on IT
# 0.0001: 80.04
# 0.00005: 79.91
# 0.00002: 80.15
# 0.00001: 80.44
# 0.000005: 80.34
# 0.000002: 80.39
#
# adamw experiment on a TR dataset (not necessarily the best test case)
# note that at that time, the expected best for adadelta was 0.816
#
# 0.00005 - 0.7925
# 0.0001 - 0.7889
# 0.0002 - 0.8110
# 0.00025 - 0.8108
# 0.0003 - 0.8050
# 0.0005 - 0.8076
# 0.001 - 0.8069
# Numbers on the VLSP Dataset, with --multistage and default learning rates and adabelief optimizer
# Gelu: 82.32
# Mish: 81.95
# ELU: 81.73
# Hardshrink: 0.3
# Hardsigmoid: 79.03
# Hardtanh: 81.44
# Hardswish: 81.67
# Logsigmoid: 80.91
# Prelu: 80.95 (terminated early)
# Relu6: 81.91
# RReLU: 77.00
# Selu: 81.17
# Celu: 81.43
# Silu: 81.90
# Softplus: 80.94
# Softshrink: 0.3
# Softsign: 81.63
# Softshrink: 13.74
#
# Tests with --no_charlm, --multistage
# Gelu
# 0.00002 0.819746
# 0.00005 0.818
# 0.0001 0.818566
# 0.0002 0.819111
# 0.001 0.815609
#
# Mish
# 0.00002 0.816898
# 0.00005 0.821085
# 0.0001 0.817821
# 0.0002 0.818806
# 0.001 0.816494
#
# Relu
# 0.00002 0.818402
# 0.00005 0.819019
# 0.0001 0.821625
# 0.0002 0.820633
# 0.001 0.814315
#
# Relu6
# 0.00002 0.819719
# 0.00005 0.819871
# 0.0001 0.819018
# 0.0002 0.819506
# 0.001 0.819018
parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate for the optimizer. Reasonable values are 1.0 for adadelta or 0.001 for SGD. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_RATES))
parser.add_argument('--learning_eps', default=None, type=float, help='eps value to use in the optimizer. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_EPS))
parser.add_argument('--learning_momentum', default=None, type=float, help='Momentum. None uses a default for the given optimizer: {}'.format(DEFAULT_MOMENTUM))
# weight decay values other than adadelta have not been thoroughly tested.
# When using adadelta, weight_decay of 0.01 to 0.001 had the best results.
# 0.1 was very clearly too high. 0.0001 might have been okay.
# Running a series of 5x experiments on a VI dataset:
# 0.030: 0.8167018
# 0.025: 0.81659
# 0.020: 0.81722
# 0.015: 0.81721
# 0.010: 0.81474348
# 0.005: 0.81503
parser.add_argument('--learning_weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')
parser.add_argument('--learning_rho', default=DEFAULT_LEARNING_RHO, type=float, help='Rho parameter in Adadelta')
# A few experiments on beta2 didn't show much benefit from changing it
# On an experiment with training WSJ with default parameters
# AdaDelta for 200 iterations, then training AdamW for 200 more,
# 0.999, 0.997, 0.995 all wound up with 0.9588
# values lower than 0.995 all had a slight dropoff
parser.add_argument('--learning_beta2', default=0.999, type=float, help='Beta2 argument for AdamW')
parser.add_argument('--optim', default=None, help='Optimizer type: SGD, AdamW, Adadelta, AdaBelief, Madgrad')
parser.add_argument('--stage1_learning_rate', default=None, type=float, help='Learning rate to use in the first stage of --multistage. None means use default: {}'.format(DEFAULT_LEARNING_RATES['adadelta']))
parser.add_argument('--learning_rate_warmup', default=0, type=int, help="Number of epochs to ramp up learning rate from 0 to full. Set to 0 to always use the chosen learning rate. Currently not functional, as it didn't do anything")
parser.add_argument('--learning_rate_factor', default=0.6, type=float, help='Plateau learning rate decreate when plateaued')
parser.add_argument('--learning_rate_patience', default=5, type=int, help='Plateau learning rate patience')
parser.add_argument('--learning_rate_cooldown', default=10, type=int, help='Plateau learning rate cooldown')
parser.add_argument('--learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum')
parser.add_argument('--stage1_learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum (stage 1)')
parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')
# Large Margin is from Large Margin In Softmax Cross-Entropy Loss
# it did not help on an Italian VIT test
# scores went from 0.8252 to 0.8248
parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal. Focal requires `pip install focal_loss_torch`')
parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')
# turn off dropout for word_dropout, predict_dropout, and lstm_input_dropout
# this mechanism doesn't actually turn off lstm_layer_dropout (yet)
# but that is set to a default of 0 anyway
# this is reusing the idea presented in
# https://arxiv.org/pdf/2303.01500v2
# "Dropout Reduces Underfitting"
# Zhuang Liu, Zhiqiu Xu, Joseph Jin, Zhiqiang Shen, Trevor Darrell
# Unfortunately, this does not consistently help results
# Averages of 5 models w/ transformer, dev / test
# id_icon - improves a little
# baseline 0.8823 0.8904
# early_dropout 40 0.8835 0.8919
# ja_alt - worsens a little
# baseline 0.9308 0.9355
# early_dropout 40 0.9287 0.9345
# vi_vlsp23 - worsens a little
# baseline 0.8262 0.8290
# early_dropout 40 0.8255 0.8286
# We keep this as an available option for further experiments, if needed
parser.add_argument('--early_dropout', default=-1, type=int, help='When to turn off dropout')
# When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:
# 0.0: 0.9085
# 0.2: 0.9165
# 0.4: 0.9162
# 0.5: 0.9123
# Letting 0.2 and 0.4 run for longer, along with 0.3 as another
# trial, continued to give extremely similar results over time.
# No attempt has been made to test the different dropouts separately...
parser.add_argument('--word_dropout', default=0.2, type=float, help='Dropout on the word embedding')
parser.add_argument('--predict_dropout', default=0.2, type=float, help='Dropout on the final prediction layer')
# lstm_dropout has not been fully tested yet
# one experiment after 200 iterations (after retagging, so scores are lower than some other experiments):
# 0.0: 0.9093
# 0.1: 0.9094
# 0.2: 0.9094
# 0.3: 0.9076
# 0.4: 0.9077
parser.add_argument('--lstm_layer_dropout', default=0.0, type=float, help='Dropout in the LSTM layers')
# one not very conclusive experiment (not long enough) came up with these numbers after ~200 iterations
# 0.0 0.9091
# 0.1 0.9095
# 0.2 0.9118
# 0.3 0.9123
# 0.4 0.9080
parser.add_argument('--lstm_input_dropout', default=0.2, type=float, help='Dropout on the input to an LSTM')
parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()],
help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme)))
parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed')
# combining dummy and open node embeddings might be a slight improvement
# for example, after 550 iterations, one experiment had
# True: 0.9154
# False: 0.9150
# another (with a different structure) had 850 iterations
# True: 0.9155
# False: 0.9149
parser.add_argument('--combined_dummy_embedding', default=True, action='store_true', help="Use the same embedding for dummy nodes and the vectors used when combining constituents")
parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding', action='store_false', help="Don't use the same embedding for dummy nodes and the vectors used when combining constituents")
# relu gave at least 1 F1 improvement over tanh in various experiments
# relu & gelu seem roughly the same, but relu is clearly faster.
# relu, 496 iterations: 0.9176
# gelu, 467 iterations: 0.9181
# after the same clock time on the same hardware. the two had been
# trading places in terms of accuracy over those ~500 iterations.
# leaky_relu was not an improvement - a full run on WSJ led to 0.9181 f1 instead of 0.919
# See constituency/utils.py for more extensive comments on nonlinearity options
parser.add_argument('--nonlinearity', default='relu', choices=NONLINEARITY.keys(), help='Nonlinearity to use in the model. relu is a noticeable improvement over tanh')
# In one experiment on an Italian dataset, VIT, we got the following:
# 0.8254 with relu as the nonlinearity (10 trials)
# 0.8265 with maxout, k = 2 (15)
# 0.8253 with maxout, k = 3 (5)
# The speed in terms of trees/second might be slightly slower with maxout.
# 51.4 it/s on a Titan Xp with maxout 2 and 51.9 it/s with relu
# It might also be worth running some experiments with bigger
# output layers to see if that makes up for the difference in score.
parser.add_argument('--maxout_k', default=None, type=int, help="Use maxout layers instead of a nonlinearity for the output layers")
parser.add_argument('--use_silver_words', default=True, dest='use_silver_words', action='store_true', help="Train/don't train word vectors for words only in the silver dataset")
parser.add_argument('--no_use_silver_words', default=True, dest='use_silver_words', action='store_false', help="Train/don't train word vectors for words only in the silver dataset")
parser.add_argument('--rare_word_unknown_frequency', default=0.02, type=float, help='How often to replace a rare word with UNK when training')
parser.add_argument('--rare_word_threshold', default=0.02, type=float, help='How many words to consider as rare words as a fraction of the dataset')
parser.add_argument('--tag_unknown_frequency', default=0.001, type=float, help='How often to replace a tag with UNK when training')
parser.add_argument('--num_lstm_layers', default=2, type=int, help='How many layers to use in the LSTMs')
parser.add_argument('--num_tree_lstm_layers', default=None, type=int, help='How many layers to use in the TREE_LSTMs, if used. This also increases the width of the word outputs to match the tree lstm inputs. Default 2 if TREE_LSTM or TREE_LSTM_CX, 1 otherwise')
parser.add_argument('--num_output_layers', default=3, type=int, help='How many layers to use at the prediction level')
parser.add_argument('--sentence_boundary_vectors', default=SentenceBoundary.EVERYTHING, type=lambda x: SentenceBoundary[x.upper()],
help='Vectors to learn at the start & end of sentences. {}'.format(", ".join(x.name for x in SentenceBoundary)))
parser.add_argument('--constituency_composition', default=ConstituencyComposition.MAX, type=lambda x: ConstituencyComposition[x.upper()],
help='How to build a new composition from its children. {}'.format(", ".join(x.name for x in ConstituencyComposition)))
parser.add_argument('--reduce_heads', default=8, type=int, help='Number of attn heads to use when reducing children into a parent tree (constituency_composition == attn)')
parser.add_argument('--reduce_position', default=None, type=int, help="Dimension of position vector to use when reducing children. None means 1/4 hidden_size, 0 means don't use (constituency_composition == key | untied_key)")
parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')
parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')
parser.add_argument('--load_package', type=str, default=None, help='Download an existing stanza package & use this for tests, finetuning, etc')
retagging.add_retag_args(parser)
# Partitioned Attention
parser.add_argument('--pattn_d_model', default=1024, type=int, help='Partitioned attention model dimensionality')
parser.add_argument('--pattn_morpho_emb_dropout', default=0.2, type=float, help='Dropout rate for morphological features obtained from pretrained model')
parser.add_argument('--pattn_encoder_max_len', default=512, type=int, help='Max length that can be put into the transformer attention layer')
parser.add_argument('--pattn_num_heads', default=8, type=int, help='Partitioned attention model number of attention heads')
parser.add_argument('--pattn_d_kv', default=64, type=int, help='Size of the query and key vector')
parser.add_argument('--pattn_d_ff', default=2048, type=int, help='Size of the intermediate vectors in the feed-forward sublayer')
parser.add_argument('--pattn_relu_dropout', default=0.1, type=float, help='ReLU dropout probability in feed-forward sublayer')
parser.add_argument('--pattn_residual_dropout', default=0.2, type=float, help='Residual dropout probability for all residual connections')
parser.add_argument('--pattn_attention_dropout', default=0.2, type=float, help='Attention dropout probability')
parser.add_argument('--pattn_num_layers', default=0, type=int, help='Number of layers for the Partitioned Attention. Currently turned off')
parser.add_argument('--pattn_bias', default=False, action='store_true', help='Whether or not to learn an additive bias')
# Results seem relatively similar with learned position embeddings or sin/cos position embeddings
parser.add_argument('--pattn_timing', default='sin', choices=['learned', 'sin'], help='Use a learned embedding or a sin embedding')
# Label Attention
parser.add_argument('--lattn_d_input_proj', default=None, type=int, help='If set, project the non-positional inputs down to this size before proceeding.')
parser.add_argument('--lattn_d_kv', default=64, type=int, help='Dimension of the key/query vector')
parser.add_argument('--lattn_d_proj', default=64, type=int, help='Dimension of the output vector from each label attention head')
parser.add_argument('--lattn_resdrop', default=True, action='store_true', help='Whether or not to use Residual Dropout')
parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer')
parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Whether or not Label Attention uses learned query vectors. False means it does')
parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')
parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')
parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Whether or not the layer uses concatenation. False means it does')
# currently unused - always assume 1/2 of pattn
#parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding')
parser.add_argument('--lattn_d_l', default=32, type=int, help='Number of labels')
parser.add_argument('--lattn_attention_dropout', default=0.2, type=float, help='Dropout for attention layer')
parser.add_argument('--lattn_d_ff', default=2048, type=int, help='Dimension of the Feed-forward layer')
parser.add_argument('--lattn_relu_dropout', default=0.2, type=float, help='Relu dropout for the label attention')
parser.add_argument('--lattn_residual_dropout', default=0.2, type=float, help='Residual dropout for the label attention')
parser.add_argument('--lattn_combined_input', default=True, action='store_true', help='Combine all inputs for the lattn, not just the pattn')
parser.add_argument('--use_lattn', default=False, action='store_true', help='Use the lattn layers - currently turned off')
parser.add_argument('--no_use_lattn', dest='use_lattn', action='store_false', help='Use the lattn layers - currently turned off')
parser.add_argument('--no_lattn_combined_input', dest='lattn_combined_input', action='store_false', help="Don't combine all inputs for the lattn, not just the pattn")
parser.add_argument('--use_rattn', default=False, action='store_true', help='Use a local attention layer')
parser.add_argument('--rattn_window', default=16, type=int, help='Number of tokens to use for context in the local attention')
# Ran an experiment on id_icon with in_order, peft, 200 epochs training
# Equivalent experiment with no rattn had an average of 0.8922 dev
# window 16, cat, dim 200, sinks 0
# head dev score
# 1 0.8915
# 2 0.8933
# 3 0.8918
# 4 0.8934
# 5 0.8924
# 6 0.8936
# 8 0.8920
# 10 0.8909
# 12 0.8939
# 14 0.8949
# 16 0.8952
# 18 0.8915
# 20 0.8925
# 25 0.8913
# 30 0.8913
# 40 0.8943
# 50 0.8931
# 75 0.8940
# The average here is 0.8928, which is a tiny bit higher...
parser.add_argument('--rattn_heads', default=16, type=int, help='Number of heads to use for context in the local attention')
parser.add_argument('--no_rattn_forward', default=True, action='store_false', dest='rattn_forward', help="Use or don't use the forward relative attention")
parser.add_argument('--no_rattn_reverse', default=True, action='store_false', dest='rattn_reverse', help="Use or don't use the reverse relative attention")
parser.add_argument('--no_rattn_cat', action='store_false', dest='rattn_cat', help='Stack the rattn layers instead of adding them')
parser.add_argument('--rattn_cat', default=True, action='store_true', help='Stack the rattn layers instead of adding them')
parser.add_argument('--rattn_dim', default=200, type=int, help='Dimension of the rattn output when cat')
parser.add_argument('--rattn_sinks', default=0, type=int, help='Number of attention sink tokens to learn')
parser.add_argument('--rattn_use_endpoint_sinks', default=False, action='store_true', help='Use the endpoints of the sentences as sinks')
parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training. A very noisy option')
parser.add_argument('--log_shapes', default=False, action='store_true', help='Log the parameters shapes at the beginning')
parser.add_argument('--watch_regex', default=None, help='regex to describe which weights and biases to output, if any')
parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
parser.add_argument('--wandb_norm_regex', default=None, help='Log on wandb any tensor whose norm matches this matrix. Might get cluttered?')
return parser
def build_model_filename(args):
    """
    Expand the --save_name template into the path where the model will be saved.

    The template may reference {shorthand}, {oracle_level}, {embedding},
    {finetune}, {transformer_finetune_begin}, {transition_scheme}, {tscheme},
    {trans_layers}, {rattn}, and {seed}.  After expansion, runs of
    underscores are collapsed into a single underscore, so an empty field
    such as {finetune} does not leave "__" in the filename.

    If the directory part of the expanded name is not already --save_dir,
    the name is joined onto --save_dir.  The directories are compared after
    normalization, so "dir/" or "./dir" in --save_dir still matches a
    save_name which already starts with "dir".

    args: the dict of processed command line arguments
    returns: the expanded model filename
    """
    embedding = utils.embedding_name(args)
    maybe_finetune = "finetuned" if args['bert_finetune'] or args['stage1_bert_finetune'] else ""
    if args['bert_finetune_begin_epoch'] is not None:
        transformer_finetune_begin = "%d" % args['bert_finetune_begin_epoch']
    else:
        transformer_finetune_begin = ""

    # compact description of the relative attention settings,
    # eg "FRch16w16" for forward + reverse, cat, 16 heads, window 16
    rattn = ""
    if args['use_rattn']:
        if args['rattn_forward']:
            rattn += "F"
        if args['rattn_reverse']:
            rattn += "R"
        # only describe the rattn shape if at least one direction is active
        if rattn:
            if args['rattn_cat']:
                rattn += "c"
            rattn += "h%02d" % args['rattn_heads']
            rattn += "w%02d" % args['rattn_window']
            if args['rattn_sinks'] > 0:
                rattn += "s%d" % args['rattn_sinks']

    model_save_file = args['save_name'].format(shorthand=args['shorthand'],
                                               oracle_level=args['oracle_level'],
                                               embedding=embedding,
                                               finetune=maybe_finetune,
                                               transformer_finetune_begin=transformer_finetune_begin,
                                               transition_scheme=args['transition_scheme'].name.lower().replace("_", ""),
                                               tscheme=args['transition_scheme'].short_name,
                                               trans_layers=args['bert_hidden_layers'],
                                               rattn=rattn,
                                               seed=args['seed'])
    model_save_file = re.sub("_+", "_", model_save_file)
    logger.info("Expanded save_name: %s", model_save_file)

    # An empty model_dir means save_name is a bare filename, which always
    # gets save_dir prepended.  Otherwise compare normalized paths so that a
    # trailing slash or "./" prefix on --save_dir doesn't cause the
    # directory to be prepended a second time.
    model_dir = os.path.split(model_save_file)[0]
    if not model_dir or os.path.normpath(model_dir) != os.path.normpath(args['save_dir']):
        model_save_file = os.path.join(args['save_dir'], model_save_file)
    return model_save_file
def parse_args(args=None):
    """
    Parse the constituency parser command line and fill in derived defaults.

    Beyond the plain argparse step, this:
      - infers --lang from the first piece of --shorthand if --lang is not given
      - defaults --stage1_bert_learning_rate to --bert_learning_rate
      - in train mode, picks an optimizer if none was given, then fills in
        learning rate, eps, momentum, weight decay, and plateau minimums
      - derives reduce_position and num_tree_lstm_layers if unset
      - turns on wandb if a wandb name or norm regex was given
      - picks a random seed if the seed is None (eg, --no_seed)
      - expands save_name, save_each_name, and checkpoint_save_name

    args: list of command line arguments, or None for argparse to read sys.argv
    returns: a dict of the processed arguments

    raises ValueError in train mode if no learning rate was given and there
    is no default learning rate for the chosen optimizer, since the plateau
    minimum learning rate is derived from it
    """
    parser = build_argparse()
    args = parser.parse_args(args=args)
    resolve_peft_args(args, logger, check_bert_finetune=False)

    # eg, a shorthand of xx_yyy implies --lang xx
    if not args.lang and args.shorthand and len(args.shorthand.split("_", maxsplit=1)) == 2:
        args.lang = args.shorthand.split("_")[0]

    if args.stage1_bert_learning_rate is None:
        args.stage1_bert_learning_rate = args.bert_learning_rate

    if args.optim is None and args.mode == 'train':
        if not args.multistage:
            # this seemed to work the best when not doing multistage
            args.optim = "adadelta"
            if args.use_peft and not args.bert_finetune:
                logger.info("--use_peft set. setting --bert_finetune as well")
                args.bert_finetune = True
        elif args.bert_finetune or args.stage1_bert_finetune:
            logger.info("Multistage training is set, optimizer is not chosen, and bert finetuning is active. Will use AdamW as the second stage optimizer.")
            args.optim = "adamw"
        else:
            # if MADGRAD exists, use it
            # otherwise, adamw
            try:
                import madgrad  # noqa: F401  (only checking that it is installed)
                args.optim = "madgrad"
                logger.info("Multistage training is set, optimizer is not chosen, and MADGRAD is available. Will use MADGRAD as the second stage optimizer.")
            except ModuleNotFoundError:
                logger.warning("Multistage training is set. Best models are with MADGRAD, but it is not installed. Will use AdamW for the second stage optimizer. Consider installing MADGRAD")
                args.optim = "adamw"

    if args.mode == 'train':
        if args.learning_rate is None:
            args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim.lower(), None)
        if args.learning_eps is None:
            args.learning_eps = DEFAULT_LEARNING_EPS.get(args.optim.lower(), None)
        if args.learning_momentum is None:
            args.learning_momentum = DEFAULT_MOMENTUM.get(args.optim.lower(), None)
        if args.learning_weight_decay is None:
            args.learning_weight_decay = DEFAULT_WEIGHT_DECAY.get(args.optim.lower(), None)

        if args.stage1_learning_rate is None:
            args.stage1_learning_rate = DEFAULT_LEARNING_RATES["adadelta"]
        if args.stage1_bert_finetune is None:
            args.stage1_bert_finetune = args.bert_finetune

        if args.learning_rate_min_lr is None:
            # previously this crashed with an opaque TypeError on None * float
            if args.learning_rate is None:
                raise ValueError("No default learning rate is known for optimizer %s.  Please set --learning_rate" % args.optim)
            args.learning_rate_min_lr = args.learning_rate * 0.02
        if args.stage1_learning_rate_min_lr is None:
            args.stage1_learning_rate_min_lr = args.stage1_learning_rate * 0.02

    if args.reduce_position is None:
        args.reduce_position = args.hidden_size // 4

    if args.num_tree_lstm_layers is None:
        if args.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
            args.num_tree_lstm_layers = 2
        else:
            args.num_tree_lstm_layers = 1

    if args.wandb_name or args.wandb_norm_regex:
        args.wandb = True

    args = vars(args)

    retagging.postprocess_args(args)
    postprocess_predict_output_args(args)

    if args['seed'] is None:
        args['seed'] = random.randint(0, 1000000000)
        logger.info("Using random seed %d", args['seed'])

    model_save_file = build_model_filename(args)
    args['save_name'] = model_save_file

    if args['save_each_name']:
        model_save_each_file = os.path.join(args['save_dir'], args['save_each_name'])
        model_save_each_file = utils.build_save_each_filename(model_save_each_file)
        args['save_each_name'] = model_save_each_file
    else:
        # in the event that there is a start epoch setting,
        # this will make a reasonable default for the path
        pieces = os.path.splitext(args['save_name'])
        model_save_each_file = pieces[0] + "_%04d" + pieces[1]
        args['save_each_name'] = model_save_each_file

    if args['checkpoint']:
        args['checkpoint_save_name'] = utils.checkpoint_name(args['save_dir'], model_save_file, args['checkpoint_save_name'])

    return args
def main(args=None):
    """
    Main function for building con parser

    Processes args, calls the appropriate function for the chosen --mode

    Args:
        args: forwarded unchanged to parse_args (presumably a list of command
            line strings, or None to read sys.argv -- confirm in parse_args)

    Raises:
        ValueError: if --lang is not given and --load_package does not start
            with a known language followed by "_<package>"
        FileNotFoundError: if the downloaded package file is missing
    """
    args = parse_args(args=args)
    utils.set_random_seed(args['seed'])

    logger.info("Running constituency parser in %s mode", args['mode'])
    logger.debug("Using device: %s", args['device'])

    model_load_file = args['save_name']
    if args['load_name']:
        # an existing path is used as-is; otherwise it is relative to --save_dir
        if os.path.exists(args['load_name']):
            model_load_file = args['load_name']
        else:
            model_load_file = os.path.join(args['save_dir'], args['load_name'])
    elif args['load_package']:
        if args['lang'] is None:
            # infer the language from a package name of the form <lang>_<package>
            lang_pieces = args['load_package'].split("_", maxsplit=1)
            try:
                lang = constant.lang_to_langcode(lang_pieces[0])
            except ValueError as e:
                raise ValueError("--lang not specified, and the start of the --load_package name, %s, is not a known language. Please check the values of those parameters" % args['load_package']) from e
            if len(lang_pieces) < 2:
                # without this check, a name with no "_" (such as "en") would
                # fail below with an unhelpful IndexError
                raise ValueError("--lang not specified, and the --load_package name, %s, does not have the form <lang>_<package>. Please check the values of those parameters" % args['load_package'])
            args['lang'] = lang
            args['load_package'] = lang_pieces[1]
        stanza.download(args['lang'], processors="constituency", package={"constituency": args['load_package']})
        model_load_file = os.path.join(DEFAULT_MODEL_DIR, args['lang'], 'constituency', args['load_package'] + ".pt")
        if not os.path.exists(model_load_file):
            raise FileNotFoundError("Expected the downloaded model file for language %s package %s to be in %s, but there is nothing there.  Perhaps the package name doesn't exist?" % (args['lang'], args['load_package'], model_load_file))
        else:
            logger.info("Model for language %s package %s is in %s", args['lang'], args['load_package'], model_load_file)

    # TODO: when loading a saved model, we should default to whatever
    # is in the model file for --retag_method, not the default for the language
    if args['mode'] == 'train':
        # only turn on trainer debug logging if nobody configured it already
        if tlogger.level == logging.NOTSET:
            tlogger.setLevel(logging.DEBUG)
            tlogger.debug("Set trainer logging level to DEBUG")

    retag_pipeline = retagging.build_retag_pipeline(args)

    if args['mode'] == 'train':
        parser_training.train(args, model_load_file, retag_pipeline)
    elif args['mode'] == 'predict':
        parser_training.evaluate(args, model_load_file, retag_pipeline)
    elif args['mode'] == 'parse_text':
        load_model_parse_text(args, model_load_file, retag_pipeline)
    elif args['mode'] == 'remove_optimizer':
        parser_training.remove_optimizer(args, args['save_name'], model_load_file)

if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/coref/__init__.py
================================================
================================================
FILE: stanza/models/coref/anaphoricity_scorer.py
================================================
""" Describes AnaphicityScorer, a torch module that for a matrix of
mentions produces their anaphoricity scores.
"""
import torch
from stanza.models.coref import utils
from stanza.models.coref.config import Config
class AnaphoricityScorer(torch.nn.Module):
    """ Calculates anaphoricity scores by passing the inputs into a FFNN """
    def __init__(self,
                 in_features: int,
                 config: Config):
        """
        Args:
            in_features: size of the last dimension of the pair matrix built
                by _get_pair_matrix (3 * mention_emb + pw_emb)
            config: reads hidden_size, n_hidden_layers, dropout_rate,
                singletons and rough_k
        """
        super().__init__()
        hidden_size = config.hidden_size
        if not config.n_hidden_layers:
            # no hidden layers: the output layer reads the pair features directly
            hidden_size = in_features
        layers = []
        for i in range(config.n_hidden_layers):
            # first layer maps in_features -> hidden_size, the rest are square
            layers.extend([torch.nn.Linear(hidden_size if i else in_features,
                                           hidden_size),
                           torch.nn.LeakyReLU(),
                           torch.nn.Dropout(config.dropout_rate)])
        self.hidden = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(hidden_size, out_features=1)

        # are we going to predict singletons
        self.predict_singletons = config.singletons
        if self.predict_singletons:
            # map to whether or not this is a start of a coref given all the
            # antecedents; not used when config.singletons = False because
            # we only need to know this for predicting singletons
            self.start_map = torch.nn.Linear(config.rough_k, out_features=1, bias=False)

    def forward(self, *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                top_mentions: torch.Tensor,
                mentions_batch: torch.Tensor,
                pw_batch: torch.Tensor,
                top_rough_scores_batch: torch.Tensor,
                ) -> torch.Tensor:
        """ Builds a pairwise matrix, scores the pairs and returns the scores.

        Args:
            top_mentions (torch.Tensor): [batch_size, n_ants, mention_emb]
                embeddings of the candidate antecedents of each mention in
                the batch (multiplied elementwise with mentions_batch)
            mentions_batch (torch.Tensor): [batch_size, mention_emb]
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]
                rough scores, added to the FFNN scores

        Returns:
            torch.Tensor [batch_size, n_ants + 1]
                anaphoricity scores for the pairs + a dummy column.
                When singletons are predicted, the cluster-start score is
                prepended as an extra first column: [batch_size, n_ants + 2]
        """
        # [batch_size, n_ants, pair_emb]
        pair_matrix = self._get_pair_matrix(mentions_batch, pw_batch, top_mentions)

        # [batch_size, n_ants] vs [batch_size, 1]
        # first is coref scores, the second is whether its the start of a coref
        if self.predict_singletons:
            scores, start = self._ffnn(pair_matrix)
            scores = utils.add_dummy(scores+top_rough_scores_batch, eps=True)
            return torch.cat([start, scores], dim=1)
        else:
            scores = self._ffnn(pair_matrix)
            return utils.add_dummy(scores+top_rough_scores_batch, eps=True)

    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculates anaphoricity scores.

        Args:
            x: tensor of shape [batch_size, n_ants, n_features]

        Returns:
            tensor of shape [batch_size, n_ants] when singletons are not
            predicted; otherwise (despite the annotation) a tuple of
            (scores [batch_size, n_ants], start [batch_size, 1])
        """
        x = self.out(self.hidden(x))
        x = x.squeeze(2)

        if not self.predict_singletons:
            return x

        # because sometimes we only have the first 49 anaphoricities:
        # n_ants can be smaller than rough_k, so only the matching
        # prefix of the start_map weights is used
        start = x @ self.start_map.weight[:,:x.shape[1]].T
        return x, start

    @staticmethod
    def _get_pair_matrix(mentions_batch: torch.Tensor,
                         pw_batch: torch.Tensor,
                         top_mentions: torch.Tensor) -> torch.Tensor:
        """
        Builds the matrix used as input for AnaphoricityScorer.

        Args:
            mentions_batch (torch.Tensor): [batch_size, mention_emb],
                the mentions of the current batch,
                is expected to be on the current device
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
                pairwise features of the current batch,
                is expected to be on the current device
            top_mentions (torch.Tensor): [batch_size, n_ants, mention_emb],
                the antecedent candidates of each mention

        Returns:
            torch.Tensor: [batch_size, n_ants, pair_emb], where
                pair_emb = 3 * mention_emb + pw_emb
                (mention, antecedent, their elementwise product, pw features)
        """
        emb_size = mentions_batch.shape[1]
        n_ants = pw_batch.shape[1]

        a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
        b_mentions = top_mentions
        similarity = a_mentions * b_mentions

        out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
        return out
================================================
FILE: stanza/models/coref/bert.py
================================================
"""Functions related to BERT or similar models"""
import logging
from typing import List, Tuple
import numpy as np # type: ignore
from transformers import AutoModel, AutoTokenizer # type: ignore
from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc
logger = logging.getLogger('stanza')
def get_subwords_batches(doc: Doc,
                         config: Config,
                         tok: AutoTokenizer
                         ) -> np.ndarray:
    """
    Turns a list of subwords to a list of lists of subword indices
    of max length == batch_size (or shorter, as batch boundaries
    should match sentence boundaries). Each batch is enclosed in cls and sep
    special tokens (eos tokens for tokenizers lacking cls/sep) and padded
    with pad tokens.

    Args:
        doc: reads doc["subwords"], doc["word_id"] (word index of each
            subword) and doc["sent_id"] (sentence index of each word)
        config: reads config.bert_window_size
        tok: supplies the special tokens and convert_tokens_to_ids

    Returns:
        batches of bert token ids [n_batches, bert_window_size]
    """
    batch_size = config.bert_window_size - 2  # to save space for CLS and SEP

    subwords: List[str] = doc["subwords"]
    subwords_batches = []
    start, end = 0, 0

    while end < len(subwords):
        # to prevent the case where a batch_size step forward
        # doesn't capture more than 1 sentence, we will just cut
        # that sequence
        prev_end = end
        end = min(end + batch_size, len(subwords))

        # Move back till we hit a sentence end, but never past the start of
        # this batch.  If the previous batch chopped an overlong sentence,
        # an unbounded scan would walk back into that sentence, making
        # `length` negative and looping forever.
        if end < len(subwords):
            sent_id = doc["sent_id"][doc["word_id"][end]]
            while end > prev_end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id:
                end -= 1

        # this occurs IFF there was no sentence end found throughout
        # the forward scan; this means that our sentence was waay too
        # long (i.e. longer than the max length of the transformer.
        #
        # if so, we give up and just chop the sentence off at the max length
        # that was given
        if end == prev_end:
            end = min(end + batch_size, len(subwords))

        length = end - start
        if tok.cls_token is None or tok.sep_token is None:
            batch = [tok.eos_token] + subwords[start:end] + [tok.eos_token]
        else:
            batch = [tok.cls_token] + subwords[start:end] + [tok.sep_token]

        # Padding to desired length
        batch += [tok.pad_token] * (batch_size - length)

        subwords_batches.append([tok.convert_tokens_to_ids(token)
                                 for token in batch])
        start += length

    return np.array(subwords_batches)
================================================
FILE: stanza/models/coref/cluster_checker.py
================================================
""" Describes ClusterChecker, a class used to retrieve LEA scores.
See aclweb.org/anthology/P16-1060.pdf. """
from typing import Hashable, List, Tuple
from stanza.models.coref.const import EPSILON
import numpy as np
import math
import logging
logger = logging.getLogger('stanza')
class ClusterChecker:
    """ Collects information on gold and predicted clusters across documents.
    Can be used to retrieve weighted LEA-score for them, as well as the
    per-document averages of MUC, B3 and CEAFe.
    """
    def __init__(self):
        # LEA: summed numerators and their weights (summed entity sizes)
        self._lea_precision = 0.0
        self._lea_recall = 0.0
        self._lea_precision_weighting = 0.0
        self._lea_recall_weighting = 0.0
        # number of documents added so far
        self._num_preds = 0.0

        # muc, b3 and ceafe: per-document scores summed here and averaged
        # over _num_preds in mbc
        self._muc_precision = 0.0
        self._muc_recall = 0.0
        self._b3_precision = 0.0
        self._b3_recall = 0.0
        self._ceafe_precision = 0.0
        self._ceafe_recall = 0.0

    @staticmethod
    def _f1(p, r):
        return (p * r) / (p + r + EPSILON) * 2

    def add_predictions(self,
                        gold_clusters: List[List[Hashable]],
                        pred_clusters: List[List[Hashable]]):
        """
        Calculates LEA, MUC, B3 and CEAFe for the document's clusters and
        accumulates them to later output scores across documents.

        Returns:
            LEA score for the document as a tuple of (f1, precision, recall)
        """
        self._num_preds += 1

        recall, r_weight = ClusterChecker._lea(gold_clusters, pred_clusters)
        precision, p_weight = ClusterChecker._lea(pred_clusters, gold_clusters)

        self._muc_recall += ClusterChecker._muc(gold_clusters, pred_clusters)
        self._muc_precision += ClusterChecker._muc(pred_clusters, gold_clusters)

        self._b3_recall += ClusterChecker._b3(gold_clusters, pred_clusters)
        self._b3_precision += ClusterChecker._b3(pred_clusters, gold_clusters)

        # _ceafe returns 0.0 rather than nan when either side is empty,
        # so these running sums cannot be poisoned by an empty document
        ceafe_precision, ceafe_recall = ClusterChecker._ceafe(pred_clusters, gold_clusters)
        self._ceafe_precision += ceafe_precision
        self._ceafe_recall += ceafe_recall

        self._lea_recall += recall
        self._lea_recall_weighting += r_weight
        self._lea_precision += precision
        self._lea_precision_weighting += p_weight

        doc_precision = precision / (p_weight + EPSILON)
        doc_recall = recall / (r_weight + EPSILON)
        doc_f1 = (doc_precision * doc_recall) \
            / (doc_precision + doc_recall + EPSILON) * 2
        return doc_f1, doc_precision, doc_recall

    @property
    def bakeoff(self):
        """ Get the F1 macroaverage score used by the bakeoff """
        return sum(self.mbc)/3

    @property
    def mbc(self):
        """ Get the F1 average score of (muc, b3, ceafe) over docs """
        avg_precisions = [self._muc_precision, self._b3_precision, self._ceafe_precision]
        avg_precisions = [i/(self._num_preds + EPSILON) for i in avg_precisions]

        avg_recalls = [self._muc_recall, self._b3_recall, self._ceafe_recall]
        avg_recalls = [i/(self._num_preds + EPSILON) for i in avg_recalls]

        avg_f1s = [self._f1(p, r) for p, r in zip(avg_precisions, avg_recalls)]
        return avg_f1s

    @property
    def total_lea(self):
        """ Returns weighted LEA for all the documents as
        (f1, precision, recall) """
        precision = self._lea_precision / (self._lea_precision_weighting + EPSILON)
        recall = self._lea_recall / (self._lea_recall_weighting + EPSILON)
        f1 = self._f1(precision, recall)
        return f1, precision, recall

    @staticmethod
    def _lea(key: List[List[Hashable]],
             response: List[List[Hashable]]) -> Tuple[float, float]:
        """ See aclweb.org/anthology/P16-1060.pdf.

        Returns (sum of importance * resolution, sum of importances) over
        the non-singleton entities of key. """
        response_clusters = [set(cluster) for cluster in response]
        response_map = {mention: cluster
                        for cluster in response_clusters
                        for mention in cluster}
        importances = []
        resolutions = []
        for entity in key:
            size = len(entity)
            if size == 1:  # entities of size 1 are not annotated
                continue
            importances.append(size)
            correct_links = 0
            for i in range(size):
                for j in range(i + 1, size):
                    correct_links += int(entity[i]
                                         in response_map.get(entity[j], {}))
            resolutions.append(correct_links / (size * (size - 1) / 2))
        res = sum(imp * res for imp, res in zip(importances, resolutions))
        weight = sum(importances)
        return res, weight

    @staticmethod
    def _muc(key: List[List[Hashable]],
             response: List[List[Hashable]]) -> float:
        """ See aclweb.org/anthology/P16-1060.pdf. """
        response_map = {mention: cluster
                        for cluster in (set(c) for c in response)
                        for mention in cluster}

        top = 0  # sum over k of |k_i| - response_partitions(|k_i|)
        bottom = 0  # sum over k of |k_i| - 1

        for entity in key:
            size = len(entity)
            # the number of DIFFERENT response clusters the entity's mentions
            # fall into; ideally 1.  Mentions missing from the response each
            # count as their own partition.
            partitions = []
            for mention in entity:
                cluster = response_map.get(mention)
                if cluster is None or cluster not in partitions:
                    partitions.append(cluster)
            top += size - len(partitions)
            bottom += size - 1

        try:
            return top/bottom
        except ZeroDivisionError:
            logger.warning("muc got a zero division error because the model predicted no spans!")
            return 0  # +inf technically

    @staticmethod
    def _b3(key: List[List[Hashable]],
            response: List[List[Hashable]]) -> float:
        """ See aclweb.org/anthology/P16-1060.pdf. """
        response_clusters = [set(cluster) for cluster in response]

        top = 0  # sum over key and response of (|k intersect response|^2/|k|)
        bottom = 0  # sum over k of |k_i|

        for entity in key:
            bottom += len(entity)
            entity = set(entity)

            for res_entity in response_clusters:
                top += (len(entity.intersection(res_entity))**2)/len(entity)

        try:
            return top/bottom
        except ZeroDivisionError:
            logger.warning("b3 got a zero division error because the model predicted no spans!")
            return 0  # +inf technically

    @staticmethod
    def _phi4(c1, c2):
        return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))

    @staticmethod
    def _ceafe(clusters: List[List[Hashable]], gold_clusters: List[List[Hashable]]):
        """ see https://github.com/ufal/corefud-scorer/blob/main/coval/eval/evaluator.py

        Returns (precision, recall).  When either side has no clusters both
        are 0.0: the similarity is 0 anyway, and numpy division by a zero
        count would produce nan (it does not raise ZeroDivisionError). """
        try:
            from scipy.optimize import linear_sum_assignment
        except ImportError as e:
            raise ImportError("To perform CEAF scoring, please install scipy via `pip install scipy` for the Kuhn-Munkres linear assignment scheme.") from e

        if not clusters or not gold_clusters:
            return 0.0, 0.0

        clusters = list(clusters)
        scores = np.zeros((len(gold_clusters), len(clusters)))
        for i, gold in enumerate(gold_clusters):
            for j, cluster in enumerate(clusters):
                scores[i, j] = ClusterChecker._phi4(gold, cluster)

        row_ind, col_ind = linear_sum_assignment(-scores)
        similarity = scores[row_ind, col_ind].sum()

        # precision, recall
        prec = similarity / len(clusters)
        recc = similarity / len(gold_clusters)
        return prec, recc
================================================
FILE: stanza/models/coref/config.py
================================================
""" Describes Config, a simple namespace for config values.
For description of all config values, refer to config.toml.
"""
from dataclasses import dataclass
from typing import Dict, List
@dataclass
class Config:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """ Contains values needed to set up the coreference model.

    The field comments below summarize the comments in coref_config.toml,
    which also holds the default values.  Field order is the positional
    constructor order of this dataclass.
    """
    # presumably the name of the config section these values came from;
    # it is also used as the prefix of the conll log file names
    section: str

    # TODO: can either eliminate data_dir or use it for the train/dev/test data
    # directory containing the extracted downloaded files
    data_dir: str
    # where to put checkpoints and final models
    save_dir: str
    save_name: str
    # train, dev and test data files (json, or the older jsonlines format)
    train_data: str
    dev_data: str
    test_data: str
    # where everything is placed, e.g. "cuda:N" or "cpu"
    device: str

    # base transformer model architecture and tokenizer
    bert_model: str
    # max length of each sequence passed through the transformer,
    # including the two special tokens wrapped around each window
    bert_window_size: int

    # dimensionality of feature embeddings
    embedding_size: int
    # dimensionality of the distance embeddings used by SpanPredictor
    sp_embedding_size: int
    # number of spans whose anaphoricity is scored in one batch
    a_scoring_batch_size: int
    # AnaphoricityScorer FFNN size and depth
    hidden_size: int
    n_hidden_layers: int

    # mention extraction checks spans up to this many words
    max_span_len: int

    # how many pairs are preserved per mention after rough scoring
    rough_k: int

    # LoRA settings
    lora: bool
    lora_alpha: int
    lora_rank: int
    lora_dropout: float
    # whether to use the full pairwise encoding scheme
    full_pairwise: bool
    lora_target_modules: List[str]
    lora_modules_to_save: List[str]

    # whether the first dummy node predicts cluster starts or singletons
    clusters_starts_are_singletons: bool
    # whether to fine-tune bert_model
    bert_finetune: bool
    # dropout rate throughout all models
    dropout_rate: float
    # task learning rate
    learning_rate: float
    # transformer learning rate (only used if bert_finetune is set)
    bert_learning_rate: float
    # we find that setting this to a small but non-zero number
    # makes the model less likely to forget how to do anything
    bert_finetune_begin_epoch: float
    train_epochs: int
    # if plateaued for this many epochs, stop training
    plateau_epochs: int

    # weight of the binary cross entropy loss added to the nlml loss
    bce_loss_weight: float

    # extra keyword arguments for the tokenizers of specific models,
    # keyed by model name
    tokenizer_kwargs: Dict[str, dict]
    # directory that will contain the conll prediction files
    conll_log_dir: str

    save_each_checkpoint: bool
    log_norms: bool
    # whether to support (predict) singletons
    singletons: bool

    # skip any documents longer than this length
    max_train_len: int
    # if false, the model sets its zero_predictor to 0
    use_zeros: bool

    # two ways of weakening the LR for certain languages, e.g.
    #   lang_lr_attenuation = "es,en,de"
    #   lang_lr_weights = "es=0.2,en=0.2,de=0.2"
    lang_lr_attenuation: str
    lang_lr_weights: str
================================================
FILE: stanza/models/coref/conll.py
================================================
""" Contains functions to produce conll-formatted output files with
predicted spans and their clustering """
from collections import defaultdict
from contextlib import contextmanager
import os
from typing import List, TextIO
from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc, Span
# pylint: disable=too-many-locals
def write_conll(doc: Doc,
                clusters: List[List[Span]],
                heads: List[List[int]],
                f_obj: TextIO):
    """ Writes span/cluster information to f_obj, which is assumed to be a file
    object open for writing

    Args:
        doc: document dict; reads "document_id", "cased_words" and "sent_id"
            (the sentence index of each word)
        clusters: for each cluster, its spans as (start, end) word indices
            with end exclusive
        heads: for each cluster, one head value per span; a cluster whose
            heads and spans differ in length is skipped
        f_obj: text file object the output is written to
    """
    # seven "_" columns after ID and FORM; list index 9 is the fifth of them
    # (HEAD, in CoNLL-U column order), which needs to be a number
    placeholder = list("\t_" * 7)
    placeholder[9] = "0"
    placeholder = "".join(placeholder)

    doc_id = doc["document_id"].replace("-", "_").replace("/", "_").replace(".", "_")
    words = doc["cased_words"]
    sents = doc["sent_id"]

    starts = defaultdict(list)
    ends = defaultdict(list)
    single_word = defaultdict(list)

    for cluster_id, cluster in enumerate(clusters):
        if len(heads[cluster_id]) != len(cluster):
            # TODO debug this fact and why it occurs
            continue
        for cluster_part, (start, end) in enumerate(cluster):
            if end - start == 1:
                single_word[start].append((cluster_part, cluster_id))
            else:
                starts[start].append((cluster_part, cluster_id))
                ends[end - 1].append((cluster_part, cluster_id))

    # we need our clusters to be ordered such that the one that is closest
    # the first change is listed last in the chains
    def compare_sort(x):
        split = x.split("-")
        if len(split) > 1:
            return int(split[-1].replace(")", "").strip())
        # we want everything that's a closer to be first
        return float("inf")

    f_obj.write(f"# newdoc id = {doc_id}\n# global.Entity = eid-head\n")

    sent_id = 0
    for word_id, word in enumerate(words):
        cluster_info_lst = []
        for part, cluster_marker in starts[word_id]:
            start, end = clusters[cluster_marker][part]
            cluster_info_lst.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)}")
        for part, cluster_marker in single_word[word_id]:
            start, end = clusters[cluster_marker][part]
            cluster_info_lst.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)})")
        for part, cluster_marker in ends[word_id]:
            cluster_info_lst.append(f"e{cluster_marker})")

        cluster_info_lst = sorted(cluster_info_lst, key=compare_sort, reverse=True)
        cluster_info = "".join(cluster_info_lst) if cluster_info_lst else "_"

        if word_id == 0 or sents[word_id] != sents[word_id - 1]:
            f_obj.write(f"# sent_id = {doc_id}-{sent_id}\n")
            sent_id += 1

        if cluster_info != "_":
            cluster_info = f"Entity={cluster_info}"

        # NOTE(review): the ID column is the document-level, 0-based word_id,
        # not a 1-based per-sentence index as in CoNLL-U.  Gold and predicted
        # files are both written by this function so they stay consistent;
        # confirm with the scorer before changing it.
        f_obj.write(f"{word_id}\t{word}{placeholder}\t{cluster_info}\n")

    f_obj.write("\n")
@contextmanager
def open_(config: Config, epochs: int, data_split: str):
    """ Opens conll log files for writing in a safe way.

    Creates config.conll_log_dir if needed, then yields a
    (gold_file, pred_file) pair named
    "<section>_<data_split>_e<epochs>.{gold,pred}.conll", both opened for
    writing as utf8 and both closed when the context exits. """
    log_dir = config.conll_log_dir
    prefix = os.path.join(log_dir, f"{config.section}_{data_split}_e{epochs}")
    os.makedirs(log_dir, exist_ok=True)
    with open(prefix + ".gold.conll", mode="w", encoding="utf8") as gold_file, \
         open(prefix + ".pred.conll", mode="w", encoding="utf8") as pred_file:
        yield (gold_file, pred_file)
================================================
FILE: stanza/models/coref/const.py
================================================
""" Contains type aliases for coref module """
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import torch
# small constant added to denominators so that empty counts do not divide by zero
EPSILON = 1e-7
LARGE_VALUE = 1000  # used instead of inf due to bug #16762 in pytorch

# a document as a plain dict; keys read in this package include
# "document_id", "cased_words", "sent_id", "word_id", "subwords",
# "span_clusters" and "head2span"
Doc = Dict[str, Any]
# a mention as (start, end) word indices, end exclusive
Span = Tuple[int, int]
@dataclass
class CorefResult:
    """ Bundle of scores, targets and clusters for one document, presumably
    filled in by the coref model.  Every field defaults to None, so a code
    path only needs to set the parts it produces. """
    coref_scores: torch.Tensor = None  # [n_words, k + 1]
    coref_y: torch.Tensor = None  # [n_words, k + 1]
    rough_y: torch.Tensor = None  # [n_words, n_words]

    # presumably clusters of word indices
    word_clusters: List[List[int]] = None
    # clusters of (start, end) spans
    span_clusters: List[List[Span]] = None

    rough_scores: torch.Tensor = None  # [n_words, n_words]
    span_scores: torch.Tensor = None  # [n_heads, n_words, 2]
    span_y: Tuple[torch.Tensor, torch.Tensor] = None  # [n_heads] x2

    # presumably the output of the zero predictor (see use_zeros) -- TODO confirm
    zero_scores: torch.Tensor = None
================================================
FILE: stanza/models/coref/coref_chain.py
================================================
"""
Coref chain suitable for attaching to a Document after coref processing
"""
# by not using namedtuple, we can use this object as output from the json module
# in the doc class as long as we wrap the encoder to print these out in dict() form
# CorefMention = namedtuple('CorefMention', ['sentence', 'start_word', 'end_word'])
class CorefMention:
    """ A single mention in a coref chain: a sentence index plus the
    start_word and end_word it spans.

    Kept as a plain class (see the module comment) so that its __dict__ can
    be emitted by a JSON encoder; attribute order is therefore the output
    key order. """
    def __init__(self, sentence, start_word, end_word):
        self.sentence = sentence
        self.start_word = start_word
        self.end_word = end_word

    def __repr__(self):
        return "%s(sentence=%r, start_word=%r, end_word=%r)" % (
            type(self).__name__, self.sentence, self.start_word, self.end_word)
class CorefChain:
    """ A coref chain: its index, its mentions, and the text and index of
    the mention chosen as representative.

    Kept as a plain class (see the module comment) so that its __dict__ can
    be emitted by a JSON encoder; attribute order is therefore the output
    key order. """
    def __init__(self, index, mentions, representative_text, representative_index):
        self.index = index
        self.mentions = mentions
        self.representative_text = representative_text
        self.representative_index = representative_index

    def __repr__(self):
        return "%s(index=%r, mentions=%r, representative_text=%r, representative_index=%r)" % (
            type(self).__name__, self.index, self.mentions,
            self.representative_text, self.representative_index)
class CorefAttachment:
    """ Links a word to a coref chain, recording whether the word starts or
    ends a mention of that chain and whether it belongs to the chain's
    representative mention. """
    def __init__(self, chain, is_start, is_end, is_representative):
        self.chain = chain
        self.is_start = is_start
        self.is_end = is_end
        self.is_representative = is_representative

    def to_json(self):
        """ Returns a dict with the chain's index and representative text,
        plus True entries for whichever of the three flags are truthy
        (absent otherwise), in start / end / representative order. """
        result = {
            "index": self.chain.index,
            "representative_text": self.chain.representative_text,
        }
        flags = (
            ("is_start", self.is_start),
            ("is_end", self.is_end),
            ("is_representative", self.is_representative),
        )
        for name, value in flags:
            if value:
                result[name] = True
        return result
================================================
FILE: stanza/models/coref/coref_config.toml
================================================
# =============================================================================
# Before you start changing anything here, read the comments.
# All of them can be found below in the "DEFAULT" section
[DEFAULT]
# The directory that contains extracted files of everything you've downloaded.
data_dir = "data/coref"
# where to put checkpoints and final models
save_dir = "saved_models/coref"
save_name = "bert-large-cased"
# Train, dev and test jsonlines
# train_data = "data/coref/en_gum-ud.train.nosgl.json"
# dev_data = "data/coref/en_gum-ud.dev.nosgl.json"
# test_data = "data/coref/en_gum-ud.test.nosgl.json"
train_data = "data/coref/corefud_concat_v1_0_langid.train.json"
dev_data = "data/coref/corefud_concat_v1_0_langid.dev.json"
test_data = "data/coref/corefud_concat_v1_0_langid.dev.json"
#train_data = "data/coref/english_train_head.jsonlines"
#dev_data = "data/coref/english_development_head.jsonlines"
#test_data = "data/coref/english_test_head.jsonlines"
# Whether to use the full pairwise encoding scheme (off by default)
full_pairwise = false
# The device where everything is to be placed. "cuda:N"/"cpu" are supported.
device = "cuda:0"
save_each_checkpoint = false
log_norms = false
# Bert settings ======================
# Base bert model architecture and tokenizer
bert_model = "bert-large-cased"
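# Controls the max length of sequences passed through bert to obtain its
# contextual embeddings
# Must not exceed the max input length of bert_model (512 for most BERT-style
# models; the t5_lora and longformer sections below use larger windows)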
bert_window_size = 512
# General model settings =============
# Controls the dimensionality of feature embeddings
embedding_size = 20
# Controls the dimensionality of distance embeddings used by SpanPredictor
sp_embedding_size = 64
# Controls the number of spans for which anaphoricity can be scored in one
# batch. Only affects final scoring; mention extraction and rough scoring
# are less memory intensive, so they are always done in just one batch.
a_scoring_batch_size = 128
# AnaphoricityScorer FFNN parameters
hidden_size = 1024
n_hidden_layers = 1
# Do you want to support singletons?
singletons = true
# Mention extraction settings ========
# Mention extractor will check spans up to max_span_len words
# The default value is chosen to be big enough to hold any dev data span
max_span_len = 64
# Pruning settings ===================
# Controls how many pairs should be preserved per mention
# after applying rough scoring.
rough_k = 50
# LoRA settings ======================
lora = false
lora_alpha = 128
lora_dropout = 0.1
lora_rank = 64
lora_target_modules = []
lora_modules_to_save = []
# Training settings ==================
# Controls whether the first dummy node predicts cluster starts or singletons
clusters_starts_are_singletons = true
# Controls whether to fine-tune bert_model
bert_finetune = true
# Controls the dropout rate throughout all models
dropout_rate = 0.3
# Bert learning rate (only used if bert_finetune is set)
bert_learning_rate = 1e-6
bert_finetune_begin_epoch = 0.5
# Task learning rate
learning_rate = 3e-4
# For how many epochs the training is done
train_epochs = 32
# Stop training early if the score has plateaued for this many epochs
plateau_epochs = 0
# Controls the weight of binary cross entropy loss added to nlml loss
bce_loss_weight = 0.5
# The directory that will contain conll prediction files
conll_log_dir = "data/conll_logs"
# Skip any documents longer than this length
max_train_len = 5000
# if this is set to false, the model will set its zero_predictor to, well, 0
use_zeros = true
# Two different methods for weakening the LR for certain languages.
# However, in their current forms, neither worked better on a Hebrew (HE)
# experiment than simply mixing the two datasets together unweighted.
# Starting from the HE IAHLT dataset, and possibly mixing in the ger/rom ud coref,
# averaging over 5 different seeds, we got the following results:
# HE only: 0.497
# Attenuated: 0.508
# Scaled: 0.517
# Mixed: 0.517
# the attenuation scheme for that experiment was 1/epoch
# These were the settings
# --lang_lr_weights es=0.2,en=0.2,de=0.2,ca=0.2,fr=0.2,no=0.2
# --lang_lr_attenuation es,en,de,ca,fr,no
lang_lr_attenuation = ""
lang_lr_weights = ""
# =============================================================================
# Extra keyword arguments to be passed to bert tokenizers of specified models
[DEFAULT.tokenizer_kwargs]
[DEFAULT.tokenizer_kwargs.roberta-large]
"add_prefix_space" = true
[DEFAULT.tokenizer_kwargs.xlm-roberta-large]
"add_prefix_space" = true
[DEFAULT.tokenizer_kwargs.spanbert-large-cased]
"do_lower_case" = false
[DEFAULT.tokenizer_kwargs.bert-large-cased]
"do_lower_case" = false
# =============================================================================
# The sections listed here do not need to make use of all config variables
# If a variable is omitted, its default value will be used instead
[roberta]
bert_model = "roberta-large"
[roberta_lora]
bert_model = "roberta-large"
bert_learning_rate = 0.00005
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[scandibert_lora]
bert_model = "vesteinn/ScandiBERT"
bert_learning_rate = 0.0002
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[xlm_roberta]
bert_model = "FacebookAI/xlm-roberta-large"
bert_learning_rate = 0.00001
bert_finetune = true
[xlm_roberta_lora]
bert_model = "FacebookAI/xlm-roberta-large"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[deeppavlov_slavic_bert_lora]
bert_model = "DeepPavlov/bert-base-bg-cs-pl-ru-cased"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[deberta_lora]
bert_model = "microsoft/deberta-v3-large"
bert_learning_rate = 0.00001
lora = true
lora_target_modules = [ "query_proj", "value_proj", "output.dense" ]
lora_modules_to_save = [ ]
[electra]
bert_model = "google/electra-large-discriminator"
bert_learning_rate = 0.00002
[electra_lora]
bert_model = "google/electra-large-discriminator"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ ]
[hungarian_electra_lora]
# TODO: experiment with tokenizer options for this to see if that's
# why the results are so low using this transformer
bert_model = "NYTK/electra-small-discriminator-hungarian"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ ]
[muril_large_cased_lora]
bert_model = "google/muril-large-cased"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[muril_base_cased_lora]
bert_model = "google/muril-base-cased"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[indic_bert_lora]
bert_model = "ai4bharat/indic-bert"
bert_learning_rate = 0.0005
lora = true
# indic-bert is an albert with repeating layers of different names
lora_target_modules = [ "query", "value", "dense", "ffn", "full_layer" ]
lora_modules_to_save = [ "pooler" ]
[alephbertgimmel_lora]
bert_model = "imvladikon/alephbertgimmel-base-512"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[alephbert_lora]
bert_model = "onlplab/alephbert-base"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[hero_lora]
# LR sweep on Hebrew IAHLT coref dev scores
# (although there may be tokenization problems)
# 0.000005 0.44202
# 0.00001 0.45271
# 0.000015 0.45771
# 0.00002 0.45877
# 0.000025 0.46076
# 0.00003 0.45957
# 0.000035 0.46187
# 0.00004 0.46066
# 0.000045 0.46132
# 0.00005 0.46238
# 0.000055 0.46084
# 0.00006 0.46047
# 0.000075 0.45772
# 0.0001 0.44910
bert_model = "HeNLP/HeRo"
bert_learning_rate = 0.00005
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[bert_multilingual_cased_lora]
# LR sweep on a Hindi dataset
# 0.00001: 0.53238
# 0.00002: 0.54012
# 0.000025: 0.54206
# 0.00003: 0.54050
# 0.00004: 0.55081
# 0.00005: 0.55135
# 0.000075: 0.54482
# 0.0001: 0.53888
bert_model = "google-bert/bert-base-multilingual-cased"
bert_learning_rate = 0.00005
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]
[t5_lora]
bert_model = "google-t5/t5-large"
bert_learning_rate = 0.000025
bert_window_size = 1024
lora = true
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
lora_modules_to_save = [ ]
[mt5_lora]
bert_model = "google/mt5-base"
bert_learning_rate = 0.000025
lora_alpha = 64
lora_rank = 32
lora = true
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
lora_modules_to_save = [ ]
[deepnarrow_t5_xl_lora]
bert_model = "google/t5-efficient-xl"
bert_learning_rate = 0.00025
lora = true
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
lora_modules_to_save = [ ]
[roberta_no_finetune]
bert_model = "roberta-large"
bert_finetune = false
[roberta_no_bce]
bert_model = "roberta-large"
bce_loss_weight = 0.0
[spanbert]
bert_model = "SpanBERT/spanbert-large-cased"
[spanbert_no_bce]
bert_model = "SpanBERT/spanbert-large-cased"
bce_loss_weight = 0.0
[bert]
bert_model = "bert-large-cased"
[longformer]
bert_model = "allenai/longformer-large-4096"
bert_window_size = 2048
[debug]
bert_window_size = 384
bert_finetune = false
device = "cpu:0"
[debug_gpu]
bert_window_size = 384
bert_finetune = false
================================================
FILE: stanza/models/coref/dataset.py
================================================
import json
import logging
from torch.utils.data import Dataset
from stanza.models.coref.tokenizer_customization import TOKENIZER_FILTERS, TOKENIZER_MAPS
logger = logging.getLogger('stanza')
class CorefDataset(Dataset):
    """Coreference documents read from a file, with subword tokenization added.

    Each document gets three extra fields computed with the given tokenizer:
    "subwords" (the subword tokens), "word2subword" (a (start, end) subword
    span per word) and "word_id" (the word index of every subword).
    """
    def __init__(self, path, config, tokenizer):
        self.config = config
        self.tokenizer = tokenizer
        # by default, this doesn't filter anything (see `lambda _: True`);
        # however, there are some subword symbols which are standalone
        # tokens which we don't want on models like Albert; hence we
        # pass along a filter if needed.
        self.__filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
                                                   lambda _: True)
        self.__token_map = TOKENIZER_MAPS.get(self.config.bert_model, {})

        data_f = self._read_documents(path)
        logger.info("Processing %d docs from %s...", len(data_f), path)

        self.__raw = data_f
        # an empty file yields an average of 0.0 instead of a ZeroDivisionError
        if self.__raw:
            self.__avg_span = sum(len(doc["head2span"]) for doc in self.__raw) / len(self.__raw)
        else:
            self.__avg_span = 0.0
        self.__out = []
        for doc in self.__raw:
            doc["span_clusters"] = [[tuple(mention) for mention in cluster]
                                    for cluster in doc["span_clusters"]]

            word2subword = []
            subwords = []
            word_id = []
            for i, word in enumerate(doc["cased_words"]):
                tokenized = self.tokenizer.tokenize(word)
                if len(tokenized) == 0:
                    # the tokenizer produced no subwords for this word;
                    # substitute a placeholder so the word keeps a non-empty span
                    word = "_"
                    doc["cased_words"][i] = word
                    tokenized = self.tokenizer.tokenize(word)
                assert len(tokenized) > 0
                tokenized_word = self.__token_map.get(word, tokenized)
                tokenized_word = list(filter(self.__filter_func, tokenized_word))
                word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
                subwords.extend(tokenized_word)
                word_id.extend([i] * len(tokenized_word))
            doc["word2subword"] = word2subword
            doc["subwords"] = subwords
            doc["word_id"] = word_id

            self.__out.append(doc)
        logger.info("Loaded %d docs from %s.", len(data_f), path)

    @staticmethod
    def _read_documents(path):
        """Read a list of documents from a JSON array file or a jsonlines file.

        Returns:
            a list of document dicts

        Raises:
            json.decoder.JSONDecodeError: if a jsonlines line is not valid JSON
        """
        with open(path, encoding="utf-8") as fin:
            text = fin.read()
        try:
            data_f = json.loads(text)
        except json.decoder.JSONDecodeError:
            # read the old jsonlines format if necessary: one JSON object
            # per line; blank lines (e.g. a trailing empty line) are skipped.
            # Text-mode reading already normalized line endings to "\n".
            data_f = [json.loads(line) for line in text.split("\n") if line.strip()]
        else:
            if isinstance(data_f, dict):
                # a jsonlines file containing exactly one document is itself
                # valid JSON, but parses as one object rather than a list
                data_f = [data_f]
        return data_f

    @property
    def avg_span(self):
        """Average number of "head2span" entries per document (0.0 if empty)."""
        return self.__avg_span

    def __getitem__(self, x):
        return self.__out[x]

    def __len__(self):
        return len(self.__out)
================================================
FILE: stanza/models/coref/loss.py
================================================
""" Describes the loss function used to train the model, which is a weighted
sum of NLML and BCE losses. """
import torch
class CorefLoss(torch.nn.Module):
    """ See the rationale for using NLML in Lee et al. 2017
    https://www.aclweb.org/anthology/D17-1018/
    The added weighted summand of BCE helps the model learn even after
    converging on the NLML task.

    Args:
        bce_weight: weight of the BCE summand, must be in [0, 1];
            0 disables the BCE term entirely

    Raises:
        ValueError: if bce_weight is outside [0, 1]
    """

    def __init__(self, bce_weight: float):
        # explicit check instead of `assert`, which is stripped under -O
        if not 0 <= bce_weight <= 1:
            raise ValueError(f"bce_weight must be in [0, 1], got {bce_weight}")
        super().__init__()
        self._bce_module = torch.nn.BCEWithLogitsLoss()
        self._bce_weight = bce_weight

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                input_: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """ Returns a weighted sum of two losses as a torch.Tensor

        input_ holds per-candidate scores (logits) and target is 1.0 where a
        candidate is correct and 0.0 otherwise; both are reduced over dim 1.
        """
        loss = self._nlml(input_, target)
        if self._bce_weight:
            # skipped when the weight is 0: it would only add 0 * bce
            loss = loss + self._bce(input_, target) * self._bce_weight
        return loss

    def _bce(self,
             input_: torch.Tensor,
             target: torch.Tensor) -> torch.Tensor:
        """ For numerical stability, clamps the input before passing it to BCE.
        """
        return self._bce_module(torch.clamp(input_, min=-50, max=50), target)

    @staticmethod
    def _nlml(input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """ Negative log marginal likelihood of the correct candidates.

        log(0) = -inf removes the non-target columns from the gold logsumexp.
        """
        gold = torch.logsumexp(input_ + torch.log(target), dim=1)
        input_ = torch.logsumexp(input_, dim=1)
        return (input_ - gold).mean()
================================================
FILE: stanza/models/coref/model.py
================================================
""" see __init__.py """
from datetime import datetime
import dataclasses
import json
import logging
import os
import random
import re
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np # type: ignore
try:
import tomllib
except ImportError:
import tomli as tomllib
import torch
import transformers # type: ignore
from pickle import UnpicklingError
import warnings
from stanza.utils.get_tqdm import get_tqdm # type: ignore
tqdm = get_tqdm()
from stanza.models.coref import bert, conll, utils
from stanza.models.coref.anaphoricity_scorer import AnaphoricityScorer
from stanza.models.coref.cluster_checker import ClusterChecker
from stanza.models.coref.config import Config
from stanza.models.coref.const import CorefResult, Doc
from stanza.models.coref.loss import CorefLoss
from stanza.models.coref.pairwise_encoder import PairwiseEncoder
from stanza.models.coref.rough_scorer import RoughScorer
from stanza.models.coref.span_predictor import SpanPredictor
from stanza.models.coref.utils import GraphNode
from stanza.models.coref.utils import sigmoid_focal_loss
from stanza.models.coref.word_encoder import WordEncoder
from stanza.models.coref.dataset import CorefDataset
from stanza.models.coref.tokenizer_customization import *
from stanza.models.common.bert_embedding import load_tokenizer
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
import torch.nn as nn
logger = logging.getLogger('stanza')
class CorefModel:  # pylint: disable=too-many-instance-attributes
    """Combines all coref modules together to find coreferent spans.

    Attributes:
        config (coref.config.Config): the model's configuration,
            see config.toml for the details
        epochs_trained (int): number of epochs the model has been trained for
        trainable (Dict[str, torch.nn.Module]): trainable submodules with their
            names used as keys
        training (bool): used to toggle train/eval modes
        optimizers, schedulers (dict): keyed by names ending in "_optimizer"
            and "_scheduler" respectively; left empty unless the model was
            created with build_optimizers=True

    Submodules (in the order of their usage in the pipeline):
        tokenizer (transformers.AutoTokenizer)
        bert (transformers.AutoModel)
        we (WordEncoder)
        rough_scorer (RoughScorer)
        pw (PairwiseEncoder)
        a_scorer (AnaphoricityScorer)
        sp (SpanPredictor)
        zeros_predictor (torch.nn.Sequential): outputs one logit per scored
            word, compared against the document's "is_zero" labels
    """
def __init__(self,
             epochs_trained: int = 0,
             build_optimizers: bool = True,
             config: Optional[dict] = None,
             foundation_cache=None):
    """
    A newly created model is set to evaluation mode.

    Args:
        epochs_trained (int): the number of epochs finished
            (useful for warm start)
        build_optimizers (bool): whether to create the optimizers and
            schedulers; note that _build_optimizers loads the training data
            (config.train_data) to count the documents
        config: the model configuration; required. In practice a Config
            object (see load_model), despite the `dict` annotation
        foundation_cache: passed on to _build_model when loading the
            transformer

    Raises:
        ValueError: if config is None
    """
    if config is None:
        raise ValueError("Cannot create a model without a config")
    self.config = config
    self.epochs_trained = epochs_trained
    # datasets already loaded, keyed by file path (see _get_docs)
    self._docs: Dict[str, List[Doc]] = {}
    self._build_model(foundation_cache)

    self.optimizers = {}
    self.schedulers = {}
    if build_optimizers:
        self._build_optimizers()
    # _set_training is defined outside this chunk; presumably it toggles
    # the submodules into eval mode here - TODO confirm
    self._set_training(False)

    # final coreference resolution score
    self._coref_criterion = CorefLoss(self.config.bce_loss_weight)
    # score simply for the top-k choices out of the rough scorer
    self._rough_criterion = CorefLoss(0)
    # exact span matches
    self._span_criterion = torch.nn.CrossEntropyLoss(reduction="sum")
@property
def training(self) -> bool:
    """ True while the model is in training mode, False in evaluation mode """
    return self._training

@training.setter
def training(self, new_value: bool):
    # switch modes only on an actual change, so repeated assignments of
    # the same value do not call _set_training again
    if self._training is not new_value:
        self._set_training(new_value)
# ========================================================== Public methods
@torch.no_grad()
def evaluate(self,
             data_split: str = "dev",
             word_level_conll: bool = False,
             eval_lang: Optional[str] = None
             ) -> Tuple[float, Tuple[float, float, float]]:
    """ Evaluates the model on the data split provided.

    The gold and predicted clusters of every scored document are also
    written out through conll.open_ / conll.write_conll.

    Args:
        data_split (str): one of 'dev'/'test'/'train'; the documents are
            read from config.<data_split>_data
        word_level_conll (bool): no longer supported; True raises
            NotImplementedError before any document is processed
        eval_lang (str): if given, only documents whose "lang" equals it
            are evaluated (used for per-language ablations)

    Returns:
        a flat tuple: mean loss, span-level LEA (f1, precision, recall),
        word-level LEA (f1, precision, recall), the span-level and
        word-level mbc values, then the word-level and span-level
        bakeoff scores

    Raises:
        NotImplementedError: if word_level_conll is True
        FileNotFoundError: if the data file of data_split does not exist
    """
    # fail fast, before any document is scored or output file opened
    if word_level_conll:
        raise NotImplementedError("We now write Conll-U conforming to UDCoref, which means that the span_clusters annotations will have headword info.  word_level option is meaningless.")

    self.training = False
    w_checker = ClusterChecker()
    s_checker = ClusterChecker()

    data_split_data = f"{data_split}_data"
    data_path = self.config.__dict__[data_split_data]
    try:
        docs = self._get_docs(data_path)
    except FileNotFoundError as e:
        raise FileNotFoundError("Unable to find data split %s at file %s" % (data_split_data, data_path)) from e

    running_loss = 0.0
    s_correct = 0
    s_total = 0
    z_correct = 0
    z_total = 0

    with conll.open_(self.config, self.epochs_trained, data_split) as (gold_f, pred_f):
        pbar = tqdm(docs, unit="docs", ncols=0)
        for doc in pbar:
            if eval_lang and doc.get("lang", "") != eval_lang:
                # skip that document, only used for ablation where we only
                # want to test evaluation on one language
                continue

            res = self.run(doc, True)

            # measure zero prediction accuracy
            zero_preds = (res.zero_scores > 0).view(-1).to(device=res.zero_scores.device)
            is_zero = doc.get("is_zero")
            if is_zero is None:
                zero_targets = torch.zeros_like(zero_preds, device=zero_preds.device)
            else:
                zero_targets = torch.tensor(is_zero, device=zero_preds.device)
            z_correct += (zero_preds == zero_targets).sum().item()
            z_total += zero_targets.numel()

            # NOTE(review): with singletons, column 1 of coref_y is the
            # "no antecedent" dummy (see _get_ground_truth); without
            # singletons the dummy is column 0, so this check may never
            # trigger for such models - confirm intent
            if (res.coref_y.argmax(dim=1) == 1).all():
                logger.warning("EVAL: skipping document with no corefs...")
                continue

            running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item()
            if res.word_clusters is None or res.span_clusters is None:
                logger.warning("EVAL: skipping document with no clusters...")
                continue

            if res.span_y:
                pred_starts = res.span_scores[:, :, 0].argmax(dim=1)
                pred_ends = res.span_scores[:, :, 1].argmax(dim=1)
                s_correct += ((res.span_y[0] == pred_starts) * (res.span_y[1] == pred_ends)).sum().item()
                s_total += len(pred_starts)

            conll.write_conll(doc, doc["span_clusters"], doc["word_clusters"], gold_f)
            conll.write_conll(doc, res.span_clusters, res.word_clusters, pred_f)

            w_checker.add_predictions(doc["word_clusters"], res.word_clusters)
            w_lea = w_checker.total_lea

            s_checker.add_predictions(doc["span_clusters"], res.span_clusters)
            s_lea = s_checker.total_lea

            del res

            # either total can still be 0 here (no gold spans / no zero
            # candidates seen yet), so guard the ratios
            span_acc = s_correct / s_total if s_total else 0.0
            zero_acc = z_correct / z_total if z_total else 0.0
            pbar.set_description(
                f"{data_split}:"
                f" | WL: "
                f" loss: {running_loss / (pbar.n + 1):<.5f},"
                f" f1: {w_lea[0]:.5f},"
                f" p: {w_lea[1]:.5f},"
                f" r: {w_lea[2]:<.5f}"
                f" | SL: "
                f" sa: {span_acc:<.5f},"
                f" f1: {s_lea[0]:.5f},"
                f" p: {s_lea[1]:.5f},"
                f" r: {s_lea[2]:<.5f}"
                f" | ZA: {zero_acc:<.5f}"
            )

    zero_acc = z_correct / z_total if z_total else 0.0
    logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}")
    logger.info(f"Zero prediction accuracy: {zero_acc:.5f}")

    mean_loss = running_loss / len(docs) if len(docs) > 0 else 0.0
    return (mean_loss, *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff)
def load_weights(self,
                 path: Optional[str] = None,
                 ignore: Optional[Set[str]] = None,
                 map_location: Optional[str] = None,
                 noexception: bool = False) -> None:
    """
    Loads pretrained weights of modules saved in a file located at path.

    If path is None, config.save_dir is searched for files matching
    "<save_name>*.checkpoint.pt" (with a trailing ".pt" removed from
    save_name first, the way train() names its checkpoints). The
    lexicographically last match is loaded.

    Args:
        path: the file to load, or None to search save_dir as above
        ignore: state dict keys to skip
        map_location: device for torch.load; defaults to config.device
        noexception: if True, return quietly when no checkpoint is found

    Raises:
        OSError: if path is None, no checkpoint is found and noexception
            is False
    """
    if path is None:
        # train() writes "<save_name minus .pt>.checkpoint.pt", so strip the
        # suffix here as well; save_name is a file name, not a regex
        base_name = self.config.save_name
        if base_name.endswith(".pt"):
            base_name = base_name[:-3]
        pattern = re.compile(rf"{re.escape(base_name)}.*?\.checkpoint\.pt")
        os.makedirs(self.config.save_dir, exist_ok=True)
        files = [f for f in os.listdir(self.config.save_dir) if pattern.match(f)]
        if not files:
            if noexception:
                # logging calls do not accept print()'s `flush` keyword;
                # passing it raised TypeError whenever DEBUG was enabled
                logger.debug("No weights have been loaded")
                return
            raise OSError(f"No weights found in {self.config.save_dir}!")
        path = os.path.join(self.config.save_dir, sorted(files)[-1])

    if map_location is None:
        map_location = self.config.device
    logger.debug(f"Loading from {path}...")
    try:
        state_dicts = torch.load(path, map_location=map_location, weights_only=True)
    except UnpicklingError:
        state_dicts = torch.load(path, map_location=map_location, weights_only=False)
        warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the coref model using this version ASAP.")
    self.epochs_trained = state_dicts.pop("epochs_trained", 0)
    # just ignore a config in the model, since we should already have one
    # TODO: some config elements may be fixed parameters of the model,
    # such as the dimensions of the head,
    # so we would want to use the ones from the config even if the
    # user created a weird shaped model
    state_dicts.pop("config", None)
    self.load_state_dicts(state_dicts, ignore)
def load_state_dicts(self,
                     state_dicts: dict,
                     ignore: Optional[Set[str]] = None):
    """
    Loads a mapping of saved states into this model.

    Keys ending in "_optimizer" / "_scheduler" are loaded into the matching
    entry of self.optimizers / self.schedulers. "bert_lora" rebuilds the
    PEFT-wrapped transformer, which is only allowed when config.lora is set.
    Every other key is loaded (strict=False) into the same-named module of
    self.trainable. Keys in `ignore` are skipped. If config.log_norms is
    set, parameter norms are logged at the end.
    """
    for key, state_dict in state_dicts.items():
        logger.debug("Loading state: %s", key)
        if ignore and key in ignore:
            continue

        if key.endswith("_optimizer"):
            self.optimizers[key].load_state_dict(state_dict)
        elif key.endswith("_scheduler"):
            self.schedulers[key].load_state_dict(state_dict)
        elif key == "bert_lora":
            assert self.config.lora, "Unable to load state dict of LoRA model into model initialized without LoRA!"
            self.bert = load_peft_wrapper(self.bert, state_dict, vars(self.config), logger, self.peft_name)
        else:
            self.trainable[key].load_state_dict(state_dict, strict=False)
        logger.debug(f"Loaded {key}")

    if self.config.log_norms:
        self.log_norms()
def build_doc(self, doc: dict) -> dict:
    """ Adds the fields run() needs to a raw document, for inference.

    doc must contain "cased_words". This fills in the subword tokenization
    ("subwords", "word2subword", "word_id"). It also adds empty gold
    annotations ("head2span", "word_clusters", "span_clusters") and, if
    missing, a "_" speaker for every word. The dict is modified in place
    and returned.

    Words are tokenized the same way CorefDataset tokenizes training data.
    This includes tokenizing "_" instead of a word for which the tokenizer
    returns no subwords, so that no word gets an empty subword span.
    """
    filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
                                        lambda _: True)
    token_map = TOKENIZER_MAPS.get(self.config.bert_model, {})

    word2subword = []
    subwords = []
    word_id = []
    for i, word in enumerate(doc["cased_words"]):
        tokenized = self.tokenizer.tokenize(word)
        if len(tokenized) == 0:
            # same fallback as CorefDataset; unlike there, doc["cased_words"]
            # itself is left untouched so the caller's words are preserved
            word = "_"
            tokenized = self.tokenizer.tokenize(word)
        tokenized_word = token_map.get(word, tokenized)
        tokenized_word = list(filter(filter_func, tokenized_word))
        word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
        subwords.extend(tokenized_word)
        word_id.extend([i] * len(tokenized_word))
    doc["word2subword"] = word2subword
    doc["subwords"] = subwords
    doc["word_id"] = word_id

    doc["head2span"] = []
    if "speaker" not in doc:
        doc["speaker"] = ["_" for _ in doc["cased_words"]]
    doc["word_clusters"] = []
    doc["span_clusters"] = []

    return doc
@staticmethod
def load_model(path: str,
               map_location: str = "cpu",
               ignore: Optional[Set[str]] = None,
               config_update: Optional[dict] = None,
               foundation_cache = None):
    """
    Loads a coref model saved with save_weights.

    Both the current format (config stored as a dict, loadable with
    weights_only=True) and the old format (a pickled Config object) are
    accepted.

    Args:
        path: the saved model file
        map_location: passed to torch.load
        ignore: state dict keys to skip when loading the weights
        config_update: attribute overrides applied to the saved config
        foundation_cache: passed on to the CorefModel constructor

    Raises:
        FileNotFoundError: if path is empty or does not exist
        ValueError: if the file contains no config
    """
    if not path:
        raise FileNotFoundError("coref model got an invalid path |%s|" % path)
    if not os.path.exists(path):
        raise FileNotFoundError("coref model file %s not found" % path)
    try:
        state_dicts = torch.load(path, map_location=map_location, weights_only=True)
    except UnpicklingError:
        state_dicts = torch.load(path, map_location=map_location, weights_only=False)
        warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the coref model using this version ASAP.")
    epochs_trained = state_dicts.pop("epochs_trained", 0)
    config = state_dicts.pop('config', None)
    if config is None:
        raise ValueError("Cannot load this format model without config in the dicts")

    # TODO: the max_train_len default is to keep old models working.
    # Can get rid of it if those models are rebuilt
    if isinstance(config, dict):
        if 'max_train_len' not in config:
            config['max_train_len'] = 5000
        config = Config(**config)
    elif not hasattr(config, 'max_train_len'):
        # old-format files unpickle a Config object, which does not support
        # the `in` / item assignment used for dicts above
        config.max_train_len = 5000

    if config_update:
        for key, value in config_update.items():
            setattr(config, key, value)
    model = CorefModel(config=config, build_optimizers=False,
                       epochs_trained=epochs_trained, foundation_cache=foundation_cache)
    model.load_state_dicts(state_dicts, ignore)
    return model
def run(self,  # pylint: disable=too-many-locals
        doc: Doc,
        use_gold_spans_for_zeros = False
        ) -> CorefResult:
    """
    This is a massive method, but it made sense to me to not split it into
    several ones to let one see the data flow.

    Args:
        doc (Doc): a dictionary with the document data.
        use_gold_spans_for_zeros (bool): if True, the zeros predictor scores
            the words given by doc["head2span"] even when not training;
            otherwise, outside training it scores every word of the
            predicted word clusters

    Returns:
        CorefResult (see const.py). span_clusters is only set when the
        model is not in training mode.
    """
    # Encode words with bert
    # words         [n_words, span_emb]
    # cluster_ids   [n_words]
    words, cluster_ids = self.we(doc, self._bertify(doc))

    # Obtain bilinear scores and leave only top-k antecedents for each word
    # top_rough_scores  [n_words, n_ants]
    # top_indices       [n_words, n_ants]
    # (rough_scores itself is not used further in this method)
    top_rough_scores, top_indices, rough_scores = self.rough_scorer(words)

    # Get pairwise features [n_words, n_ants, n_pw_features]
    pw = self.pw(top_indices, doc)

    # the anaphoricity scorer processes a_scoring_batch_size words at a time
    batch_size = self.config.a_scoring_batch_size
    a_scores_lst: List[torch.Tensor] = []

    for i in range(0, len(words), batch_size):
        pw_batch = pw[i:i + batch_size]
        words_batch = words[i:i + batch_size]
        top_indices_batch = top_indices[i:i + batch_size]
        top_rough_scores_batch = top_rough_scores[i:i + batch_size]

        # a_scores_batch    [batch_size, n_ants]
        a_scores_batch = self.a_scorer(
            top_mentions=words[top_indices_batch], mentions_batch=words_batch,
            pw_batch=pw_batch, top_rough_scores_batch=top_rough_scores_batch
        )
        a_scores_lst.append(a_scores_batch)

    res = CorefResult()

    # coref_scores  [n_spans, n_ants]
    res.coref_scores = torch.cat(a_scores_lst, dim=0)

    # gold labels; pairs whose rough score is -inf are marked invalid
    res.coref_y = self._get_ground_truth(
        cluster_ids, top_indices, (top_rough_scores > float("-inf")),
        self.config.clusters_starts_are_singletons,
        self.config.singletons
    )
    res.word_clusters = self._clusterize(
        doc, res.coref_scores, top_indices,
        self.config.singletons
    )

    res.span_scores, res.span_y = self.sp.get_training_data(doc, words)

    if not self.training:
        res.span_clusters = self.sp.predict(doc, words, res.word_clusters)

    # choose which word embeddings the zeros predictor scores
    if not self.training and not use_gold_spans_for_zeros:
        # inference: every word index appearing in a predicted word cluster
        zero_words = words[[word_id
                            for cluster in res.word_clusters
                            for word_id in cluster]]
    else:
        # training (or forced): the first element of each sorted
        # doc["head2span"] entry - presumably the head word index, TODO confirm
        zero_words = words[[i[0] for i in sorted(doc["head2span"])]]
    res.zero_scores = self.zeros_predictor(zero_words)

    return res
def save_weights(self, save_path=None, save_optimizers=True):
    """ Saves trainable models as state dicts.

    The transformer ("bert") is included only when it is finetuned without
    LoRA; with LoRA, the adapter weights are stored under "bert_lora"
    instead. Optimizer and scheduler states are added if save_optimizers.
    The epoch count and the config (as a plain dict) are stored too.

    If save_path is None, a timestamped file name in config.save_dir is used.
    """
    to_save: List[Tuple[str, Any]] = [
        (name, module)
        for name, module in self.trainable.items()
        if name != "bert" or (self.config.bert_finetune and not self.config.lora)
    ]
    if save_optimizers:
        to_save += list(self.optimizers.items())
        to_save += list(self.schedulers.items())

    if save_path is None:
        stamp = datetime.now().strftime("%Y.%m.%d_%H.%M")
        save_path = os.path.join(self.config.save_dir,
                                 f"{self.config.save_name}"
                                 f"_e{self.epochs_trained}_{stamp}.pt")

    savedict = dict((name, module.state_dict()) for name, module in to_save)
    if self.config.lora:
        # imported here so that peft remains an optional dependency
        from peft import get_peft_model_state_dict
        savedict["bert_lora"] = get_peft_model_state_dict(self.bert, adapter_name="coref")
    savedict["epochs_trained"] = self.epochs_trained  # type: ignore
    # a plain dict rather than the Config dataclass, so that the file can be
    # read back with torch.load(weights_only=True)
    savedict["config"] = dataclasses.asdict(self.config)

    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(savedict, save_path)
def log_norms(self):
    """ Logs, at INFO level and as a single message, the norm (torch.norm's
    default) and element count of every parameter that requires gradients,
    for each trainable submodule. """
    # fixed typo in the header ("PARAMTERS")
    lines = ["NORMS FOR MODEL PARAMETERS"]
    for t_name, trainable in self.trainable.items():
        for name, param in trainable.named_parameters():
            if param.requires_grad:
                lines.append(" %s: %s %.6g (%d)" % (t_name, name, torch.norm(param).item(), param.numel()))
    logger.info("\n".join(lines))
def train(self, log=False):
    """
    Trains all the trainable blocks in the model using the config provided.

    Runs from epoch self.epochs_trained up to config.train_epochs. After
    every epoch it evaluates on the dev split and saves the best model
    (judged by scores[1], the span-level LEA F1 returned by evaluate) to
    config.save_dir / config.save_name without optimizer state. It also
    saves a checkpoint every epoch, either timestamped (with
    save_each_checkpoint) or as "<save_name>.checkpoint.pt". It stops early
    once more than config.plateau_epochs epochs pass without improvement
    (if plateau_epochs > 0).

    log: whether or not to log using wandb
    """
    if log:
        import wandb
        wandb.watch((self.bert, self.pw,
                     self.a_scorer, self.we,
                     self.rough_scorer, self.sp))

    docs = self._get_docs(self.config.train_data)
    docs_ids = list(range(len(docs)))
    # used to normalize the span loss below
    avg_spans = docs.avg_span

    # for a brand new model, we set the zeros prediction to all 0 if the dataset has no zeros
    training_has_zeros = any('is_zero' in doc for doc in docs)
    if not training_has_zeros:
        logger.info("No zeros found in the dataset. The zeros predictor will set to 0")
        if self.epochs_trained == 0:
            # new model, set it to always predict not-zero
            self.disable_zeros_predictor()

    # comma-separated language codes whose loss is divided by the epoch number
    attenuated_languages = set()
    if self.config.lang_lr_attenuation:
        attenuated_languages = self.config.lang_lr_attenuation.split(",")
        logger.info("Attenuating LR for the following languages: %s", attenuated_languages)

    # "lang=weight,lang=weight" pairs; the loss of those languages is multiplied by weight
    lr_scaled_languages = dict()
    if self.config.lang_lr_weights:
        scaled_languages = self.config.lang_lr_weights.split(",")
        for piece in scaled_languages:
            pieces = piece.split("=")
            lr_scaled_languages[pieces[0]] = float(pieces[1])
        logger.info("Scaling LR for the following languages: %s", lr_scaled_languages)

    best_f1 = None
    best_epoch = self.epochs_trained
    for epoch in range(self.epochs_trained, self.config.train_epochs):
        self.training = True
        if self.config.log_norms:
            self.log_norms()
        running_c_loss = 0.0
        running_s_loss = 0.0
        running_z_loss = 0.0
        random.shuffle(docs_ids)
        pbar = tqdm(docs_ids, unit="docs", ncols=0)
        for doc_indx, doc_id in enumerate(pbar):
            doc = docs[doc_id]

            # skip very long documents during training time
            if len(doc["subwords"]) > self.config.max_train_len:
                continue

            for optim in self.optimizers.values():
                optim.zero_grad()

            res = self.run(doc)

            if res.zero_scores.size(0) == 0 or not training_has_zeros:
                z_loss = 0.0 # no words were scored for zeros, or the training data has no zero labels
            else:
                is_zero = doc.get("is_zero")
                if is_zero is None:
                    is_zero = torch.zeros_like(res.zero_scores.squeeze(-1), device=res.zero_scores.device, dtype=torch.float)
                else:
                    is_zero = torch.tensor(is_zero).to(res.zero_scores.device).float()
                z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1), is_zero, reduction="mean")

            c_loss = self._coref_criterion(res.coref_scores, res.coref_y)

            if res.span_y:
                s_loss = (self._span_criterion(res.span_scores[:, :, 0], res.span_y[0])
                          + self._span_criterion(res.span_scores[:, :, 1], res.span_y[1])) / avg_spans / 2
            else:
                s_loss = torch.zeros_like(c_loss)

            # per-language loss scaling acts as a per-language learning rate
            lr_scale = lr_scaled_languages.get(doc.get("lang"), 1.0)
            if doc.get("lang") in attenuated_languages:
                lr_scale = lr_scale / max(epoch, 1.0)
            c_loss = c_loss * lr_scale
            s_loss = s_loss * lr_scale
            z_loss = z_loss * lr_scale

            (c_loss + s_loss + z_loss).backward()

            running_c_loss += c_loss.item()
            running_s_loss += s_loss.item()
            # z_loss is a plain float (no .item()) when it was not computed
            if res.zero_scores.size(0) != 0 and training_has_zeros:
                running_z_loss += z_loss.item()

            # log every 100 docs
            if log and doc_indx % 100 == 0:
                logged = {
                    'train_c_loss': c_loss.item(),
                    'train_s_loss': s_loss.item(),
                }
                if res.zero_scores.size(0) != 0 and training_has_zeros:
                    logged['train_z_loss'] = z_loss.item()
                wandb.log(logged)

            del c_loss, s_loss, z_loss, res

            # optimizers and schedulers step once per document
            for optim in self.optimizers.values():
                optim.step()
            for scheduler in self.schedulers.values():
                scheduler.step()

            pbar.set_description(
                f"Epoch {epoch + 1}:"
                f" {doc['document_id']:26}"
                f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}"
                f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}"
                f" z_loss: {running_z_loss / (pbar.n + 1):<.5f}"
            )

        self.epochs_trained += 1
        # scores[1] is the span-level LEA F1 and scores[-1] the span-level
        # bakeoff score (see the tuple returned by evaluate)
        scores = self.evaluate()
        prev_best_f1 = best_f1
        if log:
            wandb.log({'dev_score': scores[1]})
            wandb.log({'dev_bakeoff': scores[-1]})

        if best_f1 is None or scores[1] > best_f1:
            best_epoch = epoch
            if best_f1 is None:
                logger.info("Saving new best model: F1 %.4f", scores[1])
            else:
                logger.info("Saving new best model: F1 %.4f > %.4f", scores[1], best_f1)
            best_f1 = scores[1]
            if self.config.save_name.endswith(".pt"):
                save_path = os.path.join(self.config.save_dir,
                                         f"{self.config.save_name}")
            else:
                save_path = os.path.join(self.config.save_dir,
                                         f"{self.config.save_name}.pt")
            self.save_weights(save_path, save_optimizers=False)
        if self.config.save_each_checkpoint:
            self.save_weights()
        else:
            if self.config.save_name.endswith(".pt"):
                checkpoint_path = os.path.join(self.config.save_dir,
                                               f"{self.config.save_name[:-3]}.checkpoint.pt")
            else:
                checkpoint_path = os.path.join(self.config.save_dir,
                                               f"{self.config.save_name}.checkpoint.pt")
            self.save_weights(checkpoint_path)

        # NOTE(review): the "Sentence F1" reported here is scores[1], the span-level LEA F1
        if prev_best_f1 is not None and prev_best_f1 != best_f1:
            logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f\nPrevious best F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1, prev_best_f1)
        else:
            logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1)

        if self.config.plateau_epochs > 0 and best_epoch + self.config.plateau_epochs < epoch:
            logger.info("Have plateaued for too long (%d epochs). Will terminate training", self.config.plateau_epochs)
            break
# ========================================================= Private methods
def _bertify(self, doc: Doc) -> torch.Tensor:
    """ Runs the transformer over the document's subword windows.

    The windows come from bert.get_subwords_batches (presumably an array of
    shape [n_windows, window_len] of token ids - TODO confirm). They are fed
    to the model 1024 windows at a time to limit memory use. The hidden
    states of all positions that are not cls/sep/pad/eos tokens are
    concatenated in order and returned.
    """
    all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer)
    device = self.config.device

    # ids whose outputs are dropped; they are the same for every chunk
    special_tokens = np.array([self.tokenizer.cls_token_id,
                               self.tokenizer.sep_token_id,
                               self.tokenizer.pad_token_id,
                               self.tokenizer.eos_token_id])

    chunk_size = 1024
    pieces = []
    for start in range(0, all_batches.shape[0], chunk_size):
        chunk = all_batches[start:start + chunk_size]
        keep_mask = torch.tensor(~(np.isin(chunk, special_tokens)), device=device)
        input_ids = torch.tensor(chunk, device=device, dtype=torch.long)
        attention_mask = torch.tensor(chunk != self.tokenizer.pad_token_id, device=device)

        if "t5" in self.config.bert_model:
            # for T5 models only the encoder stack is run
            outputs = self.bert.encoder(input_ids=input_ids,
                                        attention_mask=attention_mask)
        else:
            outputs = self.bert(input_ids, attention_mask=attention_mask)

        # [n_subwords, bert_emb]
        pieces.append(outputs['last_hidden_state'][keep_mask])

    return torch.cat(pieces)
def _build_model(self, foundation_cache):
    """ Creates the transformer, the tokenizer and all coref submodules,
    moves them to config.device and registers them in self.trainable.

    With config.lora, the transformer is loaded via load_bert_with_peft and
    wrapped with a PEFT adapter (its name stored in self.peft_name).
    Otherwise it is loaded with load_bert; when it will be finetuned, the
    foundation cache is wrapped in NoTransformerFoundationCache so that the
    finetuned weights are not shared elsewhere.
    """
    if hasattr(self.config, 'lora') and self.config.lora:
        self.bert, self.tokenizer, peft_name = load_bert_with_peft(self.config.bert_model, "coref", foundation_cache)
        # vars() converts a dataclass to a dict, used for being able to index things like args["lora_*"]
        self.bert = build_peft_wrapper(self.bert, vars(self.config), logger, adapter_name=peft_name)
        self.peft_name = peft_name
    else:
        if self.config.bert_finetune:
            logger.debug("Coref model requested a finetuned transformer; we are not using the foundation model cache to prevent we accidentally leak the finetuning weights elsewhere.")
            foundation_cache = NoTransformerFoundationCache(foundation_cache)
        self.bert, self.tokenizer = load_bert(self.config.bert_model, foundation_cache)

    # optional per-model tokenizer options, looked up by the model name
    # without its hub organization prefix
    base_bert_name = self.config.bert_model.split("/")[-1]
    tokenizer_kwargs = self.config.tokenizer_kwargs.get(base_bert_name, {})
    if tokenizer_kwargs:
        logger.debug(f"Using tokenizer kwargs: {tokenizer_kwargs}")
        # we just downloaded the tokenizer, so for simplicity, we don't make another request to HF
        self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs, local_files_only=True)

    if self.config.bert_finetune or (hasattr(self.config, 'lora') and self.config.lora):
        self.bert = self.bert.train()

    self.bert = self.bert.to(self.config.device)
    self.pw = PairwiseEncoder(self.config).to(self.config.device)

    bert_emb = self.bert.config.hidden_size
    # pair representation size: 3 * bert_emb plus the pairwise features
    # (presumably two word embeddings and their product - TODO confirm)
    pair_emb = bert_emb * 3 + self.pw.shape

    # pylint: disable=line-too-long
    self.a_scorer = AnaphoricityScorer(pair_emb, self.config).to(self.config.device)
    self.we = WordEncoder(bert_emb, self.config).to(self.config.device)
    self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)
    self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)
    # one logit per scored word
    self.zeros_predictor = nn.Sequential(
        nn.Linear(bert_emb, bert_emb),
        nn.ReLU(),
        nn.Linear(bert_emb, 1)
    ).to(self.config.device)
    if not hasattr(self.config, 'use_zeros') or not self.config.use_zeros:
        self.disable_zeros_predictor()

    self.trainable: Dict[str, torch.nn.Module] = {
        "bert": self.bert, "we": self.we,
        "rough_scorer": self.rough_scorer,
        "pw": self.pw, "a_scorer": self.a_scorer,
        "sp": self.sp, "zeros_predictor": self.zeros_predictor
    }
def disable_zeros_predictor(self):
    """ Zeroes the weight and bias of the zeros predictor's final layer,
    so that it outputs a logit of 0 for every input. """
    final_layer = self.zeros_predictor[-1]
    for tensor in (final_layer.weight, final_layer.bias):
        nn.init.zeros_(tensor)
def _build_optimizers(self):
    """ Creates the Adam optimizers and the LR schedulers used by train().

    The scheduler lengths are based on the number of training documents,
    which means config.train_data is loaded here. train() steps the
    schedulers once per document.

    * "bert_optimizer" / "bert_scheduler", only with config.bert_finetune:
      for the first int(n_docs * bert_finetune_begin_epoch) steps, the LR
      is multiplied by 0 (ConstantLR with factor=0). After that, the
      schedule switches to transformers' linear warmup schedule.
    * "general_optimizer" / "general_scheduler": all other trainable
      submodules (their parameters get requires_grad=True), with a linear
      schedule with no warmup over n_docs * train_epochs steps.
    """
    n_docs = len(self._get_docs(self.config.train_data))
    self.optimizers: Dict[str, torch.optim.Optimizer] = {}
    self.schedulers: Dict[str, torch.optim.lr_scheduler.LRScheduler] = {}

    # without LoRA, the transformer's parameters are trainable only if finetuning
    if not getattr(self.config, 'lora', False):
        for param in self.bert.parameters():
            param.requires_grad = self.config.bert_finetune

    if self.config.bert_finetune:
        logger.debug("Making bert optimizer with LR of %f", self.config.bert_learning_rate)
        self.optimizers["bert_optimizer"] = torch.optim.Adam(
            self.bert.parameters(), lr=self.config.bert_learning_rate
        )
        start_finetuning = int(n_docs * self.config.bert_finetune_begin_epoch)
        if start_finetuning > 0:
            logger.info("Will begin finetuning transformer at iteration %d", start_finetuning)
        zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizers["bert_optimizer"], factor=0, total_iters=start_finetuning)
        warmup_scheduler = transformers.get_linear_schedule_with_warmup(
            self.optimizers["bert_optimizer"],
            start_finetuning, n_docs * self.config.train_epochs - start_finetuning)
        self.schedulers["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR(
            self.optimizers["bert_optimizer"],
            schedulers=[zero_scheduler, warmup_scheduler],
            milestones=[start_finetuning])

    # Must ensure the same ordering of parameters between launches
    modules = sorted((key, value) for key, value in self.trainable.items()
                     if key != "bert")
    params = []
    for _, module in modules:
        for param in module.parameters():
            param.requires_grad = True
            params.append(param)

    self.optimizers["general_optimizer"] = torch.optim.Adam(
        params, lr=self.config.learning_rate)
    self.schedulers["general_scheduler"] = \
        transformers.get_linear_schedule_with_warmup(
            self.optimizers["general_optimizer"],
            0, n_docs * self.config.train_epochs
        )
def _clusterize(self, doc: Doc, scores: torch.Tensor, top_indices: torch.Tensor,
                singletons: bool = True):
    """ Turns antecedent scores into word clusters.

    Each word is linked to its highest-scoring antecedent (unless the
    "no antecedent" dummy wins). The connected components with more than
    one word become clusters. With singletons, a word whose "mention start"
    column (column 0) beats the "no antecedent" column (column 1), and that
    is in no other cluster, becomes a one-word cluster.

    Args:
        scores: [n_words, n_cols] antecedent scores with the dummy
            column(s) first - two with singletons, one without
        top_indices: [n_words, n_ants] word index of each candidate antecedent
        singletons: whether scores has the extra "mention start" column

    Returns:
        a sorted list of clusters, each a sorted list of word indices
    """
    if singletons:
        # ignore the start column: -1 now means "no antecedent"
        antecedents = scores[:,1:].argmax(dim=1) - 1
        # set the dummy values to -1, so that they are not coref to themselves
        is_start = (scores[:, :2].argmax(dim=1) == 0)
    else:
        antecedents = scores.argmax(dim=1) - 1

    not_dummy = antecedents >= 0
    coref_span_heads = torch.arange(0, len(scores), device=not_dummy.device)[not_dummy]
    # map candidate positions back to word indices
    antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]

    nodes = [GraphNode(i) for i in range(len(doc["cased_words"]))]
    for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
        nodes[i].link(nodes[j])
        assert nodes[i] is not nodes[j]

    # word index -> True for words placed in some multi-word cluster
    visited = {}

    clusters = []
    for node in nodes:
        if len(node.links) > 0 and not node.visited:
            # iterative depth-first traversal of this connected component
            cluster = []
            stack = [node]
            while stack:
                current_node = stack.pop()
                current_node.visited = True
                cluster.append(current_node.id)
                stack.extend(link for link in current_node.links if not link.visited)
            assert len(cluster) > 1
            for i in cluster:
                visited[i] = True
            clusters.append(sorted(cluster))

    if singletons:
        # go through the is_start nodes; if no clusters contain that node
        # i.e. visited[i] == False, we add it as a singleton
        for indx, i in enumerate(is_start):
            if i and not visited.get(indx, False):
                clusters.append([indx])

    return sorted(clusters)
def _get_docs(self, path: str) -> List[Doc]:
    """ Returns the CorefDataset stored at path.

    The dataset is built on the first request and cached in self._docs;
    later requests for the same path return the cached object.
    """
    if path in self._docs:
        return self._docs[path]
    dataset = CorefDataset(path, self.config, self.tokenizer)
    self._docs[path] = dataset
    return dataset
@staticmethod
def _get_ground_truth(cluster_ids: torch.Tensor,
                      top_indices: torch.Tensor,
                      valid_pair_map: torch.Tensor,
                      cluster_starts: bool,
                      singletons: bool = True) -> torch.Tensor:
    """
    Build the gold target matrix for antecedent scoring.

    Args:
        cluster_ids: tensor of shape [n_words], containing cluster indices
            for each word. Non-gold words have cluster id of zero.
        top_indices: tensor of shape [n_words, n_ants],
            indices of antecedents of each word
        valid_pair_map: boolean tensor of shape [n_words, n_ants],
            whether for pair at [i, j] (i-th word and j-th word)
            j < i is True
        cluster_starts: when True, the first word of every gold cluster is
            marked as a mention start. When False, only words of
            one-word (singleton) gold clusters are marked.
        singletons: when True, an extra leading "mention start" dummy column
            is added and filled as described above.
    Returns:
        float tensor of shape [n_words, n_ants + 1] (dummy added), or
        [n_words, n_ants + 2] when ``singletons`` is True, containing 1 at
        position [i, j] if i-th and j-th words corefer.
    """
    y = cluster_ids[top_indices] * valid_pair_map  # [n_words, n_ants]
    y[y == 0] = -1  # -1 for non-gold words
    y = utils.add_dummy(y)  # [n_words, n_ants + 1]

    if singletons:
        if not cluster_starts:
            unique, counts = cluster_ids.unique(return_counts=True)
            singleton_clusters = unique[(counts == 1) & (unique != 0)]
            first_corefs = [(cluster_ids == i).nonzero().flatten()[0] for i in singleton_clusters]
            if len(first_corefs) > 0:
                first_coref = torch.stack(first_corefs)
            else:
                first_coref = torch.tensor([]).to(cluster_ids.device).long()
        else:
            # Index of the FIRST occurrence of each gold (non-zero) cluster
            # id, used for the "is-start-of-ref" marker that detects
            # singletons. Zero is filtered explicitly: if every word belongs
            # to a gold cluster there is no 0, and slicing off the first
            # unique value would silently drop the smallest cluster.
            gold_ids = cluster_ids.unique()
            gold_ids = gold_ids[gold_ids != 0]
            if gold_ids.numel() == 0:
                first_coref = torch.empty(0, dtype=torch.long, device=cluster_ids.device)
            else:
                # [n_clusters, n_words]; argmax is documented to return the
                # first maximal index, whereas topk's tie order is unspecified
                membership = (cluster_ids.unsqueeze(0) == gold_ids.unsqueeze(1))
                first_coref = membership.to(torch.long).argmax(dim=1)

    y = (y == cluster_ids.unsqueeze(1))  # True if coreferent
    # For all rows with no gold antecedents setting dummy to True
    y[y.sum(dim=1) == 0, 0] = True

    if singletons:
        # add another dummy for first coref
        y = utils.add_dummy(y)  # [n_words, n_ants + 2]
        # for all rows that's a first coref, setting its dummy to True and unset the
        # non-coref dummy to false
        y[first_coref, 0] = True
        y[first_coref, 1] = False
    return y.to(torch.float)
@staticmethod
def _load_config(config_path: str,
section: str) -> Config:
with open(config_path, "rb") as fin:
config = tomllib.load(fin)
default_section = config["DEFAULT"]
current_section = config[section]
unknown_keys = (set(current_section.keys())
- set(default_section.keys()))
if unknown_keys:
raise ValueError(f"Unexpected config keys: {unknown_keys}")
return Config(section, **{**default_section, **current_section})
def _set_training(self, value: bool):
self._training = value
for module in self.trainable.values():
module.train(self._training)
================================================
FILE: stanza/models/coref/pairwise_encoder.py
================================================
""" Describes PairwiseEncodes, that transforms pairwise features, such as
distance between the mentions, same/different speaker into feature embeddings
"""
from typing import List
import torch
from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc
class PairwiseEncoder(torch.nn.Module):
    """ A Pytorch module to obtain feature embeddings for pairwise features

    Usage:

    encoder = PairwiseEncoder(config)
    pairwise_features = encoder(pair_indices, doc)
    """
    def __init__(self, config: Config):
        """
        Args:
            config: reads embedding_size, dropout_rate and full_pairwise.
        """
        super().__init__()
        emb_size = config.embedding_size

        self.genre2int = {g: gi for gi, g in enumerate(["bc", "bn", "mz", "nw",
                                                         "pt", "tc", "wb"])}
        # not used by forward(); kept so saved parameters keep loading
        self.genre_emb = torch.nn.Embedding(len(self.genre2int), emb_size)

        # each position corresponds to a bucket:
        #   [(0, 2), (2, 3), (3, 4), (4, 5), (5, 8),
        #    (8, 16), (16, 32), (32, 64), (64, float("inf"))]
        self.distance_emb = torch.nn.Embedding(9, emb_size)

        # two possibilities: same vs different speaker
        self.speaker_emb = torch.nn.Embedding(2, emb_size)

        self.dropout = torch.nn.Dropout(config.dropout_rate)
        self.__full_pw = config.full_pairwise

        if self.__full_pw:
            self.shape = emb_size * 2  # distance, speaker
        else:
            self.shape = emb_size  # distance only

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.genre_emb.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                top_indices: torch.Tensor,
                doc: Doc) -> torch.Tensor:
        """
        Embed the features of every (word, candidate antecedent) pair.

        Args:
            top_indices: [n_words, n_ants] word index of each candidate
            doc: document data; reads "cased_words" and optionally "speaker"

        Returns:
            [n_words, n_ants, self.shape] tensor: the distance embedding,
            preceded by the same-speaker embedding when full_pairwise is set.
        """
        word_ids = torch.arange(0, len(doc["cased_words"]), device=self.device)

        # bucketing the distance (see __init__()): distances 1..4 get their
        # own bucket, larger ones are bucketed by floor(log2), capped at 64+
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                    ).clamp_min_(min=1)
        log_distance = distance.to(torch.float).log2().floor_()
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
        distance = self.distance_emb(distance)

        if not self.__full_pw:
            return self.dropout(distance)

        # calculate speaker embeddings
        speaker_map = torch.tensor(self._speaker_map(doc), device=self.device)
        same_speaker = (speaker_map[top_indices] == speaker_map.unsqueeze(1))
        same_speaker = self.speaker_emb(same_speaker.to(torch.long))

        return self.dropout(torch.cat((same_speaker, distance), dim=2))

    @staticmethod
    def _speaker_map(doc: Doc) -> List[int]:
        """
        Returns a list where the i-th element is the speaker id of the i-th word.

        If the doc has no "speaker" entry, every word is attributed to the
        same speaker ("speaker#1") and therefore gets the same id.
        Ids are assigned in order of first appearance, so they are
        deterministic across runs (unlike iterating over a set of strings).
        """
        n_words = len(doc["cased_words"])
        speakers = doc.get("speaker", ["speaker#1"] * n_words)
        # speaker string -> speaker id
        str2int = {s: i for i, s in enumerate(dict.fromkeys(speakers))}
        # word id -> speaker id
        return [str2int[s] for s in speakers]
================================================
FILE: stanza/models/coref/predict.py
================================================
import argparse
import json
import torch
from tqdm import tqdm
from stanza.models.coref.model import CorefModel
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("experiment")
    argparser.add_argument("input_file")
    argparser.add_argument("output_file")
    argparser.add_argument("--config-file", default="config.toml")
    argparser.add_argument("--batch-size", type=int,
                           help="Adjust to override the config value if you're"
                                " experiencing out-of-memory issues")
    argparser.add_argument("--weights",
                           help="Path to file with weights to load."
                                " If not supplied, the latest"
                                " weights of the experiment will be loaded;"
                                " if there aren't any, an error is raised.")
    args = argparser.parse_args()

    model = CorefModel.load_model(path=args.weights,
                                  map_location="cpu",
                                  ignore={"bert_optimizer", "general_optimizer",
                                          "bert_scheduler", "general_scheduler"})
    if args.batch_size:
        model.config.a_scoring_batch_size = args.batch_size
    model.training = False

    # The input is either one JSON array of documents or the older
    # jsonlines format (one JSON object per line).
    try:
        with open(args.input_file, encoding="utf-8") as fin:
            input_data = json.load(fin)
    except json.decoder.JSONDecodeError:
        # read the old jsonlines format if necessary; blank lines are skipped
        # so that a trailing newline does not produce an empty array element
        with open(args.input_file, encoding="utf-8") as fin:
            text = "[" + ",\n".join(line for line in fin if line.strip()) + "]"
        input_data = json.loads(text)
    if isinstance(input_data, dict):
        # a jsonlines file holding exactly one document parses as a bare
        # object; wrap it so the loop below iterates documents, not keys
        input_data = [input_data]

    docs = [model.build_doc(doc) for doc in input_data]

    with torch.no_grad():
        for doc in tqdm(docs, unit="docs"):
            result = model.run(doc)
            doc["span_clusters"] = result.span_clusters
            doc["word_clusters"] = result.word_clusters

            # drop the intermediate fields added by build_doc
            for key in ("word2subword", "subwords", "word_id", "head2span"):
                del doc[key]

    # Write jsonlines: one object per line. Previously the objects were
    # concatenated with no separator, which no JSON reader (including the
    # one above) can parse back.
    with open(args.output_file, mode="w", encoding="utf-8") as fout:
        for doc in docs:
            json.dump(doc, fout)
            fout.write("\n")
================================================
FILE: stanza/models/coref/rough_scorer.py
================================================
""" Describes RoughScorer, a simple bilinear module to calculate rough
anaphoricity scores.
"""
from typing import Tuple
import torch
from stanza.models.coref.config import Config
class RoughScorer(torch.nn.Module):
    """
    Gives a rough estimate of the anaphoricity of every (mention, antecedent)
    pair. Only the top-scoring candidates are considered in later steps,
    which reduces computational complexity.
    """
    def __init__(self, features: int, config: Config):
        """
        Args:
            features: size of each mention representation
            config: reads config.dropout_rate and config.rough_k
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(config.dropout_rate)
        self.bilinear = torch.nn.Linear(features, features)

        self.k = config.rough_k

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                mentions: torch.Tensor,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns rough anaphoricity scores for candidates: the bilinear score
        of every mention pair. Pairs whose candidate antecedent does not come
        strictly before the mention are masked to -inf.

        Args:
            mentions: [n_mentions, features] mention representations

        Returns:
            (top scores [n_mentions, k], top indices [n_mentions, k],
             full rough scores [n_mentions, n_mentions]) -- see _prune
        """
        # [n_mentions, n_mentions]: log(1) = 0 where antecedent j precedes
        # mention i (i - j > 0), log(0) = -inf elsewhere. Built directly on
        # the mentions' device to avoid a host-to-device copy.
        positions = torch.arange(mentions.shape[0], device=mentions.device)
        pair_mask = positions.unsqueeze(1) - positions.unsqueeze(0)
        pair_mask = torch.log((pair_mask > 0).to(torch.float))

        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)

        rough_scores = pair_mask + bilinear_scores

        return self._prune(rough_scores)

    def _prune(self,
               rough_scores: torch.Tensor
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Selects top-k rough antecedent scores for each mention.

        Args:
            rough_scores: tensor of shape [n_mentions, n_mentions], containing
                rough antecedent scores of each mention-antecedent pair.

        Returns:
            FloatTensor of shape [n_mentions, k], top rough scores
            LongTensor of shape [n_mentions, k], top indices
            FloatTensor of shape [n_mentions, n_mentions], the input scores
            where k = min(self.k, n_mentions); the top-k are not sorted.
        """
        top_scores, indices = torch.topk(rough_scores,
                                         k=min(self.k, len(rough_scores)),
                                         dim=1, sorted=False)
        return top_scores, indices, rough_scores
================================================
FILE: stanza/models/coref/span_predictor.py
================================================
""" Describes SpanPredictor which aims to predict spans by taking as input
head word and context embeddings.
"""
from typing import List, Optional, Tuple
from stanza.models.coref.const import Doc, Span
import torch
class SpanPredictor(torch.nn.Module):
    """Predicts the start and end word of the span headed by each given word.

    For every (head, candidate) pair in the same sentence, a feed-forward
    network scores the concatenation [head_emb; candidate_emb; distance_emb].
    Two 1-D convolutions over the candidates then turn this into a start
    score and an end score per candidate word.
    """
    def __init__(self, input_size: int, distance_emb_size: int):
        """
        Args:
            input_size: size of each word embedding
            distance_emb_size: size of the relative-distance embedding
        """
        super().__init__()
        self.ffnn = torch.nn.Sequential(
            # input is [head_emb; candidate_emb; distance_emb]; this was
            # previously hard-coded as "+ 64", which broke any other
            # distance_emb_size
            torch.nn.Linear(input_size * 2 + distance_emb_size, input_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(input_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 64),
        )
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, distance_emb_size)  # [-63, 63] + too_far

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.ffnn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids

        Args:
            doc (Doc): the document data; reads doc["sent_id"]
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2].
                Words outside the head's sentence score -inf. In eval mode,
                starts after the head and ends before it also score -inf.
        """
        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
        emb_ids = relative_positions + 63  # make all valid distances positive
        emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127  # "too_far"

        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        sent_id = torch.tensor(doc["sent_id"], device=words.device)
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))

        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_pairs, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)

        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]

        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, n_candidates, 2]

        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
        scores[rows, cols] = res[padding_mask]

        # Make sure that start <= head <= end during inference
        if not self.training:
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores

    def get_training_data(self,
                          doc: Doc,
                          words: torch.Tensor
                          ) -> Tuple[Optional[torch.Tensor],
                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """ Returns span starts/ends for gold mentions in the document.

        Returns (scores, (starts, ends)) where ends are inclusive word
        indices (doc["head2span"] ends are exclusive), or (None, None) when
        the document has no gold spans.
        """
        head2span = sorted(doc["head2span"])
        if not head2span:
            return None, None
        heads, starts, ends = zip(*head2span)
        heads = torch.tensor(heads, device=self.device)
        starts = torch.tensor(starts, device=self.device)
        ends = torch.tensor(ends, device=self.device) - 1
        return self(doc, words, heads), (starts, ends)

    def predict(self,
                doc: Doc,
                words: torch.Tensor,
                clusters: List[List[int]]) -> List[List[Span]]:
        """
        Predicts span clusters based on the word clusters.

        Args:
            doc (Doc): the document data
            words (torch.Tensor): [n_words, emb_size] matrix containing
                embeddings for each of the words in the text
            clusters (List[List[int]]): a list of clusters where each cluster
                is a list of word indices

        Returns:
            List[List[Span]]: span clusters, each span a (start, end) pair
                with an exclusive end
        """
        if not clusters:
            return []

        heads_ids = torch.tensor(
            sorted(i for cluster in clusters for i in cluster),
            device=self.device
        )

        scores = self(doc, words, heads_ids)
        starts = scores[:, :, 0].argmax(dim=1).tolist()
        ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist()

        head2span = {
            head: (start, end)
            for head, start, end in zip(heads_ids.tolist(), starts, ends)
        }

        return [[head2span[head] for head in cluster]
                for cluster in clusters]
================================================
FILE: stanza/models/coref/tokenizer_customization.py
================================================
""" This file defines functions used to modify the default behaviour
of transformers.AutoTokenizer. These changes are necessary, because some
tokenizers are meant to be used with raw text, while the OntoNotes documents
have already been split into words.
All the functions are used in coref_model.CorefModel._get_docs. """
# Filters out unwanted tokens produced by the tokenizer.
# For the ALBERT models the only rejected token is a bare U+2581 ("▁",
# the sentencepiece word-boundary symbol), not an ASCII "_".
def _is_not_bare_boundary_marker(token):
    return token != "▁"


TOKENIZER_FILTERS = {
    model_name: _is_not_bare_boundary_marker
    for model_name in ("albert-xxlarge-v2", "albert-large-v2")
}

# Maps some words to tokens directly, without a tokenizer: each listed
# word becomes a single token identical to itself.
_ROBERTA_VERBATIM_WORDS = (".", ",", "!", "?", ":", ";", "'s")
TOKENIZER_MAPS = {
    "roberta-large": {word: [word] for word in _ROBERTA_VERBATIM_WORDS},
}
================================================
FILE: stanza/models/coref/utils.py
================================================
""" Contains functions not directly linked to coreference resolution """
from typing import List, Set
import torch
import torch.nn.functional as F
from stanza.models.coref.const import EPSILON
class GraphNode:
    """Vertex of an undirected graph with a DFS ``visited`` flag.

    Nodes hash by identity, so they can be stored in each other's link sets.
    """
    def __init__(self, node_id: int):
        self.id = node_id
        self.visited = False
        self.links: Set[GraphNode] = set()

    def link(self, another: "GraphNode"):
        """Connect this node and ``another`` in both directions."""
        for source, target in ((self, another), (another, self)):
            source.links.add(target)

    def __repr__(self) -> str:
        return f"{self.id}"
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """ Prepend a size-1 dummy slice along dimension 1 (not dimension 0).

    The dummy shares ``tensor``'s device and dtype. It is filled with zeros,
    or with EPSILON when ``eps`` is True.
    """
    dummy_shape: List[int] = [*tensor.shape]
    dummy_shape[1] = 1
    fill_value = EPSILON if eps else 0
    dummy = torch.full(dummy_shape, fill_value,  # type: ignore
                       device=tensor.device, dtype=tensor.dtype)
    return torch.cat([dummy, tensor], dim=1)
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The predictions for each example.
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range [0, 1] to balance
                positive vs negative examples or -1 for ignore. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.
    Returns:
        Loss tensor with the reduction option applied.
    Raises:
        ValueError: if alpha is outside [0, 1] (and not -1), or if reduction
            is not one of the supported modes.
    """
    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
    if not (0 <= alpha <= 1) and alpha != -1:
        raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
    # validate up front rather than after the (possibly large) loss computation
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )

    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # probability assigned to the true class
    p_t = p * targets + (1 - p) * (1 - targets)
    # down-weight well-classified examples
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
================================================
FILE: stanza/models/coref/word_encoder.py
================================================
""" Describes WordEncoder. Extracts mention vectors from bert-encoded text.
"""
from typing import Tuple
import torch
from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc
class WordEncoder(torch.nn.Module):  # pylint: disable=too-many-instance-attributes
    """ Receives bert contextual embeddings of a text and pools the subword
    embeddings of each word into one word representation, using a learned
    attention over that word's subwords. """

    def __init__(self, features: int, config: Config):
        """
        Args:
            features (int): the number of features in the input embeddings
            config (Config): the configuration of the current session;
                only config.dropout_rate is read here
        """
        super().__init__()
        self.attn = torch.nn.Linear(in_features=features, out_features=1)
        self.dropout = torch.nn.Dropout(config.dropout_rate)

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.attn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                x: torch.Tensor,
                ) -> Tuple[torch.Tensor, ...]:
        """
        Extracts word representations from text.

        Args:
            doc: the document data; reads "word2subword" (a [start, end)
                subword range per word), "word_clusters" and "cased_words"
            x: a tensor containing bert output, shape (n_subtokens, bert_dim)

        Returns:
            words: a Tensor of shape [n_words, bert_dim];
                word representations
            cluster_ids: tensor of shape [n_words], containing cluster indices
                for each word. Non-coreferent words have cluster id of zero.

        Raises:
            ValueError: if some word covers no subwords.
        """
        word_boundaries = torch.tensor(doc["word2subword"], device=self.device)
        starts = word_boundaries[:, 0]
        ends = word_boundaries[:, 1]

        # [n_words, bert_dim]
        words = self._attn_scores(x, starts, ends).mm(x)
        words = self.dropout(words)

        return (words, self._cluster_ids(doc))

    def _attn_scores(self,
                     bert_out: torch.Tensor,
                     word_starts: torch.Tensor,
                     word_ends: torch.Tensor) -> torch.Tensor:
        """ Calculates attention scores for each of the words.

        Args:
            bert_out (torch.Tensor): [n_subwords, bert_emb], bert embeddings
                for each of the subwords in the document
            word_starts (torch.Tensor): [n_words], start indices of words
            word_ends (torch.Tensor): [n_words], end indices of words
                (exclusive)

        Returns:
            torch.Tensor: [n_words, n_subwords] attention weights; each row
                is a softmax over that word's subwords and is 0 elsewhere.

        Raises:
            ValueError: if some word has no subwords.
        """
        n_subtokens = len(bert_out)
        n_words = len(word_starts)

        # [n_words, n_subtokens]: True where the subword belongs to the word
        positions = torch.arange(0, n_subtokens, device=self.device).expand((n_words, n_subtokens))
        in_word = ((positions >= word_starts.unsqueeze(1))
                   & (positions < word_ends.unsqueeze(1)))

        # A word with no subwords would give a row that is entirely -inf,
        # which softmax turns into NaN, so refuse such input outright.
        word_lengths = torch.sum(in_word, dim=1)
        if torch.any(word_lengths == 0):
            raise ValueError("Found a blank word in training data! This will break everything, starting with the attention masks, as some rows of the scoring table will be set to entirely -inf and then softmax to NaN.")

        # 0 at positions belonging to the word and -inf elsewhere
        attn_mask = torch.log(in_word.to(torch.float))

        attn_scores = self.attn(bert_out).T  # [1, n_subtokens]
        attn_scores = attn_scores.expand((n_words, n_subtokens))
        attn_scores = attn_mask + attn_scores
        del attn_mask
        return torch.softmax(attn_scores, dim=1)  # [n_words, n_subtokens]

    def _cluster_ids(self, doc: Doc) -> torch.Tensor:
        """
        Args:
            doc: document information; reads "word_clusters" and "cased_words"

        Returns:
            torch.Tensor of shape [n_word], containing cluster indices for
                each word (clusters are numbered from 1 in the order of
                doc["word_clusters"]). Non-coreferent words have cluster id
                of zero.
        """
        word2cluster = {word_i: i
                        for i, cluster in enumerate(doc["word_clusters"], start=1)
                        for word_i in cluster}

        return torch.tensor(
            [word2cluster.get(word_i, 0)
             for word_i in range(len(doc["cased_words"]))],
            device=self.device
        )
================================================
FILE: stanza/models/depparse/__init__.py
================================================
================================================
FILE: stanza/models/depparse/data.py
================================================
import random
import logging
import torch
from stanza.models.common.bert_embedding import filter_data, needs_length_filter
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.utils import DEFAULT_WORD_CUTOFF, simplify_punct
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, ROOT_ID, CompositeVocab, CharVocab
from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
from stanza.models.common.doc import *
logger = logging.getLogger('stanza')
def data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately):
    """
    Group sentences into batches.

    ``data`` is a list of lists whose first element is the sentence. A batch
    collects consecutive sentences until adding the next one would push the
    total length of their first elements past ``batch_size``. A sentence
    longer than ``min_length_to_batch_separately`` (when that is not None)
    always gets a batch of its own.

    Ordering before batching:
      * training (not eval_mode): sorted by length, ascending or descending
        at random, for better memory utilization
      * eval_mode with sort_during_eval: sorted by length via sort_all
      * eval_mode otherwise: left as given

    Returns (batches, original_order). original_order is None in train mode
    or when unsorted; otherwise it gives the original location of each
    sentence in the sort.
    """
    if eval_mode:
        if sort_during_eval:
            (data, ), data_orig_idx = sort_all([data], [len(x[0]) for x in data])
        else:
            data_orig_idx = None
    else:
        # sort sentences (roughly) by length for better memory utilization
        data = sorted(data, key=lambda x: len(x[0]), reverse=random.random() > .5)
        data_orig_idx = None

    batches = []
    pending = []
    pending_len = 0
    for item in data:
        item_len = len(item[0])
        oversized = (min_length_to_batch_separately is not None
                     and item_len > min_length_to_batch_separately)
        overflow = pending_len > 0 and item_len + pending_len > batch_size
        if (oversized or overflow) and pending_len > 0:
            # close the batch being built
            batches.append(pending)
            pending = []
            pending_len = 0
        if oversized:
            batches.append([item])
            continue
        pending.append(item)
        pending_len += item_len

    if pending_len > 0:
        batches.append(pending)
    return batches, data_orig_idx
class DataLoader:
    """Serves padded tensor batches of a dependency-parsing Document.

    On construction the document is read (load_doc) and vocabularies are
    built unless one is passed in. Long sentences may be filtered for BERT,
    the training set may be subsampled, and every sentence is mapped to ids
    (preprocess) and chunked into batches (chunk_batches).

    Each preprocessed sentence is a list of 10 fields, in this order:
        0 word ids, 1 char ids, 2 upos ids, 3 xpos ids, 4 feats ids,
        5 pretrained-vocab word ids, 6 lemma ids, 7 head indices,
        8 deprel ids, 9 raw word strings.
    Fields 0-6 start with a ROOT entry; fields 7-9 do not.
    __getitem__ depends on this exact order.
    """
    def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, min_length_to_batch_separately=None, bert_tokenizer=None):
        """
        Args:
            doc: the Document to read (see load_doc)
            batch_size: budget per batch, measured as the summed length of
                field 0 (word ids, including ROOT) -- see data_to_batches
            args: dict of options; this class reads 'bert_model', 'pretrain',
                'sample_train', 'word_cutoff' and 'shorthand'
            pretrain: object with a .vocab; used only if args['pretrain']
            vocab: MultiVocab to reuse; built from the data when None
                (only allowed when not evaluating, see init_vocab)
            evaluation: eval mode -- no shuffling, no subsampling, and
                unparsable heads become 0 instead of raising
            sort_during_eval: sort by length in eval mode (data_to_batches)
            min_length_to_batch_separately: sentences longer than this get a
                batch of their own
            bert_tokenizer: passed to filter_data when the bert model needs
                a length filter
        """
        self.batch_size = batch_size
        self.min_length_to_batch_separately=min_length_to_batch_separately
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc

        data = self.load_doc(doc)

        # handle vocab
        if vocab is None:
            self.vocab = self.init_vocab(data)
        else:
            self.vocab = vocab

        # filter out the long sentences if bert is used
        # (note: this happens after the vocab is built from the full data)
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)
        # shuffle for training
        if self.shuffled:
            random.shuffle(data)
        self.num_examples = len(data)

        # chunk into batches
        self.data = self.chunk_batches(data)
        logger.debug("{} batches created.".format(len(self.data)))

    def init_vocab(self, data):
        """Build the MultiVocab from raw training data (never in eval mode).

        Column indices follow load_doc: 0 text, 1 upos, 2 xpos, 3 feats,
        4 lemma, 5 head, 6 deprel.
        """
        assert self.eval == False # for eval vocab must exist
        cutoff = self.args['word_cutoff'] if self.args.get('word_cutoff') is not None else DEFAULT_WORD_CUTOFF
        charvocab = CharVocab(data, self.args['shorthand'])
        wordvocab = WordVocab(data, self.args['shorthand'], cutoff=cutoff, lower=True)
        uposvocab = WordVocab(data, self.args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, self.args['shorthand'])
        featsvocab = FeatureVocab(data, self.args['shorthand'], idx=3)
        lemmavocab = WordVocab(data, self.args['shorthand'], cutoff=cutoff, idx=4, lower=True)
        deprelvocab = WordVocab(data, self.args['shorthand'], idx=6)
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab,
                            'lemma': lemmavocab,
                            'deprel': deprelvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """Map each raw sentence to the 10-field id representation.

        See the class docstring for the field order. A ROOT entry is
        prepended to fields 0-6 (a composite placeholder for composite
        xpos/feats vocabs). Heads and deprels get no ROOT entry.
        """
        processed = []
        xpos_replacement = [[ROOT_ID] * len(vocab['xpos'])] if isinstance(vocab['xpos'], CompositeVocab) else [ROOT_ID]
        feats_replacement = [[ROOT_ID] * len(vocab['feats'])]
        for sent in data:
            processed_sent = [[ROOT_ID] + vocab['word'].map([w[0] for w in sent])]
            processed_sent += [[[ROOT_ID]] + [vocab['char'].map([x for x in w[0]]) for w in sent]]
            processed_sent += [[ROOT_ID] + vocab['upos'].map([w[1] for w in sent])]
            processed_sent += [xpos_replacement + vocab['xpos'].map([w[2] for w in sent])]
            processed_sent += [feats_replacement + vocab['feats'].map([w[3] for w in sent])]
            if pretrain_vocab is not None:
                # always use lowercase lookup in pretrained vocab
                processed_sent += [[ROOT_ID] + pretrain_vocab.map([w[0].lower() for w in sent])]
            else:
                processed_sent += [[ROOT_ID] + [PAD_ID] * len(sent)]
            processed_sent += [[ROOT_ID] + vocab['lemma'].map([w[4] for w in sent])]
            # heads: in eval mode an unparsable head becomes 0 (see to_int)
            processed_sent += [[to_int(w[5], ignore_error=self.eval) for w in sent]]
            processed_sent += [vocab['deprel'].map([w[6] for w in sent])]
            processed_sent.append([w[0] for w in sent])
            processed.append(processed_sent)
        return processed

    def __len__(self):
        """Number of batches."""
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index.

        Returns, in order: words, words_mask, wordchars, wordchars_mask,
        upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx,
        word_orig_idx, sentlens, word_lens, text. Sentences are sorted by
        length (orig_idx undoes it) and words by character length for the
        char model (word_orig_idx undoes it). Masks are True at PAD_ID.

        Raises:
            TypeError: if key is not an int
            IndexError: if key is out of range (negative keys included)
        """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        # transpose: from a list of sentences to a list of 10 fields
        batch = list(zip(*batch))
        assert len(batch) == 10

        # sort sentences by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # sort words by lens for easy char-RNN operations
        batch_words = [w for sent in batch[1] for w in sent]
        word_lens = [len(x) for x in batch_words]
        batch_words, word_orig_idx = sort_all([batch_words], word_lens)
        batch_words = batch_words[0]
        word_lens = [len(x) for x in batch_words]

        # convert to tensors
        words = batch[0]
        words = get_long_tensor(words, batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(batch_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos = get_long_tensor(batch[2], batch_size)
        xpos = get_long_tensor(batch[3], batch_size)
        ufeats = get_long_tensor(batch[4], batch_size)
        pretrained = get_long_tensor(batch[5], batch_size)
        # sentence lengths include the ROOT entry of field 0
        sentlens = [len(x) for x in batch[0]]
        lemma = get_long_tensor(batch[6], batch_size)
        head = get_long_tensor(batch[7], batch_size)
        deprel = get_long_tensor(batch[8], batch_size)
        text = batch[9]
        return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx, word_orig_idx, sentlens, word_lens, text

    def load_doc(self, doc):
        """Read [TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL] per word, per
        sentence, replacing None with '_' and applying simplify_punct."""
        data = doc.get([TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL], as_sentences=True)
        data = self.resolve_none(data)
        data = simplify_punct(data)
        return data

    def resolve_none(self, data):
        """Replace every None field with '_' (mutates and returns data)."""
        # replace None to '_'
        for sent_idx in range(len(data)):
            for tok_idx in range(len(data[sent_idx])):
                for feat_idx in range(len(data[sent_idx][tok_idx])):
                    if data[sent_idx][tok_idx][feat_idx] is None:
                        data[sent_idx][tok_idx][feat_idx] = '_'
        return data

    def __iter__(self):
        """Yield every batch in order, as produced by __getitem__."""
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def set_batch_size(self, batch_size):
        """Change the batch budget. Existing batches are not rebuilt; the
        new size only takes effect at the next chunk_batches (e.g. reshuffle)."""
        self.batch_size = batch_size

    def reshuffle(self):
        """Flatten the current batches, re-chunk them and shuffle batch order."""
        data = [y for x in self.data for y in x]
        self.data = self.chunk_batches(data)
        random.shuffle(self.data)

    def chunk_batches(self, data):
        """Split preprocessed sentences into batches via data_to_batches and
        remember the sort permutation in self.data_orig_idx."""
        batches, data_orig_idx = data_to_batches(data=data, batch_size=self.batch_size,
                                                 eval_mode=self.eval, sort_during_eval=self.sort_during_eval,
                                                 min_length_to_batch_separately=self.min_length_to_batch_separately)
        # data_orig_idx might be None at train time, since we don't anticipate unsorting
        self.data_orig_idx = data_orig_idx
        return batches
def to_int(string, ignore_error=False):
    """Parse ``string`` as an int.

    If it cannot be parsed, return 0 when ``ignore_error`` is true;
    otherwise re-raise the ValueError.
    """
    try:
        return int(string)
    except ValueError:
        if not ignore_error:
            raise
        return 0
================================================
FILE: stanza/models/depparse/model.py
================================================
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
from stanza.models.common.bert_embedding import extract_bert_embeddings
from stanza.models.common.biaffine import DeepBiaffineScorer
from stanza.models.common.foundation_cache import load_charlm
from stanza.models.common.hlstm import HighwayLSTM
from stanza.models.common.dropout import WordDropout
from stanza.models.common.utils import attach_bert_model
from stanza.models.common.vocab import CompositeVocab
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
from stanza.models.common import utils
logger = logging.getLogger('stanza')
class Parser(nn.Module):
    """
    Biaffine dependency parser.

    The enabled input features (word/lemma embeddings, POS and feats
    embeddings, a character model or charlm, a transformer, pretrained
    embeddings -- each switched on by ``args``) are concatenated, passed
    through a bidirectional HighwayLSTM, and scored pairwise with deep
    biaffine scorers for head attachment and dependency relation labels.
    Optional ``linearization`` and ``distance`` scorers add auxiliary terms
    to the attachment scores and to the training loss.
    """
    def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
        """
        Args:
          args: dict of hyperparameters (word_emb_dim, tag_emb_dim, char, pretrain, hidden_dim, ...)
          vocab: vocab collection indexed by 'word', 'lemma', 'upos', 'xpos', 'feats', 'deprel'
          emb_matrix: pretrained embedding matrix; must be given when args['pretrain'] is set
          foundation_cache: passed to load_charlm when loading character language models
          bert_model, bert_tokenizer: optional transformer and tokenizer, handed to attach_bert_model
          force_bert_saved: passed through to attach_bert_model
          peft_name: name of the peft adapter used when extracting transformer embeddings

        Raises:
          FileNotFoundError: if a charlm is requested but its forward/backward file is missing
        """
        super().__init__()

        self.vocab = vocab
        self.args = args
        # names registered through add_unsaved_module (pretrained embeddings, charlm)
        self.unsaved_modules = []

        # input layers
        input_size = 0
        if self.args['word_emb_dim'] > 0:
            # frequent word embeddings
            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
            self.lemma_emb = nn.Embedding(len(vocab['lemma']), self.args['word_emb_dim'], padding_idx=0)
            input_size += self.args['word_emb_dim'] * 2

        if self.args['tag_emb_dim'] > 0:
            if self.args.get('use_upos', True):
                self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)
            if self.args.get('use_xpos', True):
                if not isinstance(vocab['xpos'], CompositeVocab):
                    self.xpos_emb = nn.Embedding(len(vocab['xpos']), self.args['tag_emb_dim'], padding_idx=0)
                else:
                    self.xpos_emb = nn.ModuleList()
                    for l in vocab['xpos'].lens():
                        self.xpos_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
                # upos and xpos embeddings are summed in forward, so they share one slot
                input_size += self.args['tag_emb_dim']

            if self.args.get('use_ufeats', True):
                self.ufeats_emb = nn.ModuleList()
                for l in vocab['feats'].lens():
                    self.ufeats_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
                input_size += self.args['tag_emb_dim']

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                if self.args['charlm_forward_file'] is None or not os.path.exists(self.args['charlm_forward_file']):
                    raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(self.args['charlm_forward_file']))
                if self.args['charlm_backward_file'] is None or not os.path.exists(self.args['charlm_backward_file']):
                    raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(self.args['charlm_backward_file']))
                logger.debug("Depparse model loading charmodels: %s and %s", self.args['charlm_forward_file'], self.args['charlm_backward_file'])
                self.add_unsaved_module('charmodel_forward', load_charlm(self.args['charlm_forward_file'], foundation_cache=foundation_cache))
                self.add_unsaved_module('charmodel_backward', load_charlm(self.args['charlm_backward_file'], foundation_cache=foundation_cache))
                input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
            else:
                self.charmodel = CharacterModel(self.args, vocab)
                self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
                input_size += self.args['transformed_dim']

        self.peft_name = peft_name
        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
        if self.args.get('bert_model', None):
            # TODO: refactor bert_hidden_layers between the different models
            if self.args.get('bert_hidden_layers', False):
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                self.bert_layer_mix = nn.Linear(self.args['bert_hidden_layers'], 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None
            input_size += self.bert_model.config.hidden_size

        if self.args['pretrain']:
            # pretrained embeddings, by default this won't be saved into model file
            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
            input_size += self.args['transformed_dim']

        # recurrent layers
        self.parserlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
        # vector passed to WordDropout as the replacement for dropped input words
        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
        # learned initial LSTM states, expanded to the batch size in forward
        self.parserlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
        self.parserlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))

        # dropout
        self.drop = nn.Dropout(self.args['dropout'])
        self.worddrop = WordDropout(self.args['word_dropout'])

        # classifiers
        # args.get to preserve old models, including models other people might have created
        if self.args.get('use_arc_embedding'):
            logger.debug("Using arc embedding enhancement")
            self.arc_embedding = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], self.args['deep_biaff_output_dim'], pairwise=True, dropout=self.args['dropout'])
            self.unlabeled_linear = nn.Sequential(self.drop,
                                                  nn.Linear(self.args['deep_biaff_output_dim'], 1))
            self.deprel_linear = nn.Sequential(self.drop,
                                               nn.Linear(self.args['deep_biaff_output_dim'], 2 * self.args['deep_biaff_output_dim']),
                                               nn.ReLU(),
                                               self.drop,
                                               nn.Linear(self.args['deep_biaff_output_dim'] * 2, len(vocab['deprel'])))
        else:
            logger.debug("Not using arc embedding enhancement")
            self.unlabeled = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])
            self.deprel = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], len(vocab['deprel']), pairwise=True, dropout=self.args['dropout'])
        if self.args['linearization']:
            self.linearization = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])
        if self.args['distance']:
            self.distance = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])

        # criterion
        self.crit = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum') # ignore padding

    def add_unsaved_module(self, name, module):
        """Attach ``module`` as attribute ``name`` and record the name in self.unsaved_modules."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def log_norms(self):
        """Log parameter norms via utils.log_norms."""
        utils.log_norms(self)

    def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
        """
        Score a batch of sentences and, in training mode, compute the loss.

        Index tensors are presumably (batch, seq_len) with position 0 reserved
        for the ROOT symbol (the training loss slices off row 0 as "the root
        symbol"); xpos/ufeats carry an extra trailing axis per composite
        field -- TODO confirm against the data loader.

        Returns:
          (loss, preds).  In training mode ``loss`` is the summed
          cross-entropy (plus auxiliary terms) divided by
          ``wordchars.size(0)`` and ``preds`` is empty.  In eval mode
          ``loss`` is the int 0 and ``preds`` holds two numpy arrays: the
          log-softmax head scores and the argmax deprel ids for every
          (dependent, head) pair.
        """
        def pack(x):
            return pack_padded_sequence(x, sentlens, batch_first=True)

        inputs = []
        if self.args['pretrain']:
            pretrained_emb = self.pretrained_emb(pretrained)
            pretrained_emb = self.trans_pretrained(pretrained_emb)
            pretrained_emb = pack(pretrained_emb)
            inputs += [pretrained_emb]

        #def pad(x):
        #    return pad_packed_sequence(PackedSequence(x, pretrained_emb.batch_sizes), batch_first=True)[0]

        if self.args['word_emb_dim'] > 0:
            word_emb = self.word_emb(word)
            word_emb = pack(word_emb)
            lemma_emb = self.lemma_emb(lemma)
            lemma_emb = pack(lemma_emb)
            inputs += [word_emb, lemma_emb]

        if self.args['tag_emb_dim'] > 0:
            if self.args.get('use_upos', True):
                pos_emb = self.upos_emb(upos)
            else:
                pos_emb = 0

            if self.args.get('use_xpos', True):
                if isinstance(self.vocab['xpos'], CompositeVocab):
                    for i in range(len(self.vocab['xpos'])):
                        pos_emb += self.xpos_emb[i](xpos[:, :, i])
                else:
                    pos_emb += self.xpos_emb(xpos)

            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
                pos_emb = pack(pos_emb)
                inputs += [pos_emb]

            if self.args.get('use_ufeats', True):
                feats_emb = 0
                for i in range(len(self.vocab['feats'])):
                    feats_emb += self.ufeats_emb[i](ufeats[:, :, i])
                feats_emb = pack(feats_emb)
                # This used to append pos_emb a second time, so the feats
                # embeddings were never used, and it crashed when upos and xpos
                # were both disabled (pos_emb was the int 0).
                # NOTE: models trained before this fix saw pos_emb in this slot.
                inputs += [feats_emb]

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                # \n is to add a somewhat neutral "word" for the ROOT
                charlm_text = [["\n"] + x for x in text]
                all_forward_chars = self.charmodel_forward.build_char_representation(charlm_text)
                all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))
                all_backward_chars = self.charmodel_backward.build_char_representation(charlm_text)
                all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))
                inputs += [all_forward_chars, all_backward_chars]
            else:
                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
                char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
                inputs += [char_reps]

        if self.bert_model is not None:
            device = next(self.parameters()).device
            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=True,
                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                     detach=not self.args.get('bert_finetune', False) or not self.training,
                                                     peft_name=self.peft_name)
            if self.bert_layer_mix is not None:
                # use a linear layer to weighted average the embedding dynamically
                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]

            # we are using the first endpoint from the transformer as the "word" for ROOT
            processed_bert = [x[:-1, :] for x in processed_bert]
            processed_bert = pad_sequence(processed_bert, batch_first=True)
            inputs += [pack(processed_bert)]

        # all inputs are packed with the same sentlens, so their .data rows line up
        lstm_inputs = torch.cat([x.data for x in inputs], 1)
        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
        lstm_inputs = self.drop(lstm_inputs)
        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)

        lstm_outputs, _ = self.parserlstm(lstm_inputs, sentlens, hx=(self.parserlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.parserlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
        lstm_outputs, _ = pad_packed_sequence(lstm_outputs, batch_first=True)

        if self.args.get('use_arc_embedding'):
            arc_scores = self.arc_embedding(self.drop(lstm_outputs), self.drop(lstm_outputs))
            unlabeled_scores = self.unlabeled_linear(arc_scores).squeeze(3)
            # the last axis holds one score per deprel, so there is nothing to squeeze
            # (this matches the non-arc-embedding branch below)
            deprel_scores = self.deprel_linear(arc_scores)
        else:
            unlabeled_scores = self.unlabeled(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            deprel_scores = self.deprel(self.drop(lstm_outputs), self.drop(lstm_outputs))

        #goldmask = head.new_zeros(*head.size(), head.size(-1)+1, dtype=torch.uint8)
        #goldmask.scatter_(2, head.unsqueeze(2), 1)

        if self.args['linearization'] or self.args['distance']:
            # head_offset[b, i, j] = j - i
            head_offset = torch.arange(word.size(1), device=head.device).view(1, 1, -1).expand(word.size(0), -1, -1) - torch.arange(word.size(1), device=head.device).view(1, -1, 1).expand(word.size(0), -1, -1)

        if self.args['linearization']:
            lin_scores = self.linearization(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            unlabeled_scores += F.logsigmoid(lin_scores * torch.sign(head_offset).float()).detach()

        if self.args['distance']:
            dist_scores = self.distance(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            dist_pred = 1 + F.softplus(dist_scores)
            dist_target = torch.abs(head_offset)
            dist_kld = -torch.log((dist_target.float() - dist_pred)**2/2 + 1)
            unlabeled_scores += dist_kld.detach()

        # a word may not be its own head
        diag = torch.eye(head.size(-1)+1, dtype=torch.bool, device=head.device).unsqueeze(0)
        unlabeled_scores.masked_fill_(diag, -float('inf'))

        preds = []

        if self.training:
            unlabeled_scores = unlabeled_scores[:, 1:, :] # exclude attachment for the root symbol
            unlabeled_scores = unlabeled_scores.masked_fill(word_mask.unsqueeze(1), -float('inf'))
            unlabeled_target = head.masked_fill(word_mask[:, 1:], -1)
            loss = self.crit(unlabeled_scores.contiguous().view(-1, unlabeled_scores.size(2)), unlabeled_target.view(-1))

            deprel_scores = deprel_scores[:, 1:] # exclude attachment for the root symbol
            #deprel_scores = deprel_scores.masked_select(goldmask.unsqueeze(3)).view(-1, len(self.vocab['deprel']))
            # keep only the label scores for each word's gold head
            deprel_scores = torch.gather(deprel_scores, 2, head.unsqueeze(2).unsqueeze(3).expand(-1, -1, -1, len(self.vocab['deprel']))).view(-1, len(self.vocab['deprel']))
            deprel_target = deprel.masked_fill(word_mask[:, 1:], -1)
            loss += self.crit(deprel_scores.contiguous(), deprel_target.view(-1))

            if self.args['linearization']:
                #lin_scores = lin_scores[:, 1:].masked_select(goldmask)
                lin_scores = torch.gather(lin_scores[:, 1:], 2, head.unsqueeze(2)).view(-1)
                lin_scores = torch.cat([-lin_scores.unsqueeze(1)/2, lin_scores.unsqueeze(1)/2], 1)
                #lin_target = (head_offset[:, 1:] > 0).long().masked_select(goldmask)
                lin_target = torch.gather((head_offset[:, 1:] > 0).long(), 2, head.unsqueeze(2))
                loss += self.crit(lin_scores.contiguous(), lin_target.view(-1))

            if self.args['distance']:
                #dist_kld = dist_kld[:, 1:].masked_select(goldmask)
                # dist_kld[:, 1:] so that the root isn't included in the distance calculation
                dist_kld = torch.gather(dist_kld[:, 1:], 2, head.unsqueeze(2))
                loss -= dist_kld.sum()

            loss /= wordchars.size(0) # number of words
        else:
            loss = 0
            preds.append(F.log_softmax(unlabeled_scores, 2).detach().cpu().numpy())
            preds.append(deprel_scores.max(3)[1].detach().cpu().numpy())

        return loss, preds
================================================
FILE: stanza/models/depparse/scorer.py
================================================
"""
Utils and wrappers for scoring parsers.
"""
from collections import Counter
import logging
from stanza.models.common.utils import ud_scores
logger = logging.getLogger('stanza')
def score_named_dependencies(pred_doc, gold_doc, output_latex=False):
    """Log per-relation precision, recall and F1 for a predicted parse.

    A word is a true positive for its gold deprel only when both the
    predicted head and the predicted deprel match the gold word.  Otherwise
    the gold deprel gets a false negative and the predicted deprel a false
    positive, so unlabeled attachment errors also lower the labeled scores.

    Nothing is returned; the table is written with logger.info.  If the
    documents differ in sentence count, or any sentence pair differs in word
    count, a warning is logged and no scores are computed.  Empty documents
    log an empty table.

    Args:
        pred_doc: document with predicted heads and deprels
        gold_doc: document with gold heads and deprels
        output_latex: if True, format the table as a LaTeX tabular
    """
    if len(pred_doc.sentences) != len(gold_doc.sentences):
        logger.warning("Not evaluating individual dependency F1 on accound of document length mismatch")
        return
    for pred_sentence, gold_sentence in zip(pred_doc.sentences, gold_doc.sentences):
        if len(pred_sentence.words) != len(gold_sentence.words):
            logger.warning("Not evaluating individual dependency F1 on accound of sentence length mismatch")
            return

    tp = Counter()
    fp = Counter()
    fn = Counter()
    for pred_sentence, gold_sentence in zip(pred_doc.sentences, gold_doc.sentences):
        for pred_word, gold_word in zip(pred_sentence.words, gold_sentence.words):
            if pred_word.head == gold_word.head and pred_word.deprel == gold_word.deprel:
                tp[gold_word.deprel] += 1
            else:
                fn[gold_word.deprel] += 1
                fp[pred_word.deprel] += 1

    labels = sorted(set(tp) | set(fp) | set(fn))
    # default=0 so documents without any words don't crash max() on an empty sequence
    max_len = max((len(x) for x in labels), default=0)
    log_lines = []
    if output_latex:
        log_lines.append(r"\begin{tabular}{lrr}")
        log_lines.append(r"Reln & F1 & Total \\")
        log_line_fmt = "{label} & {f1:0.4f} & {actual} \\\\"
    else:
        log_line_fmt = "{label:>" + str(max_len) + "s}: p {precision:0.4f} r {recall:0.4f} f1 {f1:0.4f} ({actual} actual)"

    for label in labels:
        if tp[label] == 0:
            # also avoids dividing by zero below
            precision = 0
            recall = 0
            f1 = 0
        else:
            precision = tp[label] / (tp[label] + fp[label])
            recall = tp[label] / (tp[label] + fn[label])
            f1 = 2 * (precision * recall) / (precision + recall)
        actual = tp[label] + fn[label]
        template = {
            'label': label,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'actual': actual
        }
        log_lines.append(log_line_fmt.format(**template))

    if output_latex:
        log_lines.append(r"\end{tabular}")

    logger.info("F1 scores for each dependency:\n  Note that unlabeled attachment errors hurt the labeled attachment scores\n%s" % "\n".join(log_lines))
def score(system_conllu_file, gold_conllu_file, verbose=True):
    """ Wrapper for UD parser scorer.

    Returns the LAS (precision, recall, f1) of the system file against the
    gold file.  When ``verbose`` is set, the LAS, MLAS and BLEX F1 scores are
    also logged as percentages in a tab separated table.
    """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)
    las = evaluation['LAS']
    result = (las.precision, las.recall, las.f1)

    if verbose:
        percentages = [evaluation[metric].f1 * 100 for metric in ('LAS', 'MLAS', 'BLEX')]
        logger.info("LAS\tMLAS\tBLEX")
        logger.info("{:.2f}\t{:.2f}\t{:.2f}".format(*percentages))

    return result
================================================
FILE: stanza/models/depparse/trainer.py
================================================
"""
A trainer class to handle training and testing of models.
"""
import copy
import sys
import logging
import torch
from torch import nn
try:
import transformers
except ImportError:
pass
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.common import utils, loss
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
from stanza.models.common.chuliu_edmonds import chuliu_edmonds_one_root
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.common.vocab import VOCAB_PREFIX_SIZE
from stanza.models.depparse.model import Parser
from stanza.models.pos.vocab import MultiVocab
logger = logging.getLogger('stanza')
def unpack_batch(batch, device):
    """ Unpack a batch from the data loader.

    The first 11 entries are moved to ``device`` (None entries stay None);
    entries 11-15 are returned unchanged as orig_idx, word_orig_idx,
    sentlens, wordlens and text.
    """
    inputs = [None if item is None else item.to(device) for item in batch[:11]]
    orig_idx, word_orig_idx, sentlens, wordlens, text = (batch[pos] for pos in range(11, 16))
    return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text
class Trainer(BaseTrainer):
    """ A trainer for training models. """
    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
                 device=None, foundation_cache=None, ignore_model_config=False, reset_history=False):
        """
        Build a Parser from scratch, or load one from ``model_file``.

        Args:
          args: config dict; when loading, it is merged over the saved config
          vocab: vocab for a new model (ignored when loading)
          pretrain: pretrained embeddings; its .emb matrix is given to the Parser
          model_file: if given, load the model, vocab and config from this file
          device: device the model is moved to
          foundation_cache: shared cache for transformer/charlm loading
          ignore_model_config: after building/loading, restore self.args to a copy of ``args``
          reset_history: when loading, zero the step counters and dev score history
        """
        self.global_step = 0
        self.last_best_step = 0
        self.dev_score_history = []

        orig_args = copy.deepcopy(args)
        # whether the training is in primary or secondary stage
        # during FT (loading weights), etc., the training is considered to be in "secondary stage"
        # during this time, we (optionally) use a different set of optimizers than that during "primary stage".
        #
        # Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary

        if model_file is not None:
            # load everything from file
            self.load(model_file, pretrain, args, foundation_cache, device)

            if reset_history:
                self.global_step = 0
                self.last_best_step = 0
                self.dev_score_history = []
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab

            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
            peft_name = None
            if self.args['use_peft']:
                # fine tune the bert if we're using peft
                self.args['bert_finetune'] = True
                peft_name = "depparse"
                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)

            self.model = Parser(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)
            self.model = self.model.to(device)
            self.__init_optim()

        # label id substituted in predict() for predictions in the vocab prefix range
        self.fallback = self.vocab['deprel'].unit2id('dep') if 'dep' in self.vocab['deprel'] else None

        if ignore_model_config:
            self.args = orig_args

        if self.args.get('wandb'):
            import wandb
            # track gradients!
            wandb.watch(self.model, log_freq=4, log="all", log_graph=True)

    def __init_optim(self):
        """Create self.optimizer (a dict of optimizers) and self.scheduler (a dict of schedulers)."""
        # TODO: can get rid of args.get when models are rebuilt
        if (self.args.get("second_stage", False) and self.args.get('second_optim')):
            self.optimizer = utils.get_split_optimizer(self.args['second_optim'], self.model,
                                                       self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,
                                                       bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0),
                                                       is_peft=self.args.get('use_peft', False),
                                                       bert_finetune_layers=self.args.get('bert_finetune_layers', None))
        else:
            self.optimizer = utils.get_split_optimizer(self.args['optim'], self.model,
                                                       self.args['lr'], betas=(0.9, self.args['beta2']),
                                                       eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0),
                                                       weight_decay=self.args.get('weight_decay', None),
                                                       bert_weight_decay=self.args.get('bert_weight_decay', 0.0),
                                                       is_peft=self.args.get('use_peft', False),
                                                       bert_finetune_layers=self.args.get('bert_finetune_layers', None))
        self.scheduler = {}
        if self.args.get("second_stage", False) and self.args.get('second_optim'):
            if self.args.get('second_warmup_steps', None):
                for name, optimizer in self.optimizer.items():
                    name = name + "_scheduler"
                    warmup_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, self.args['second_warmup_steps'])
                    self.scheduler[name] = warmup_scheduler
        else:
            if "bert_optimizer" in self.optimizer:
                # bert learning rate is held at 0 for bert_start_finetuning steps, then warmed up
                zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer["bert_optimizer"], factor=0, total_iters=self.args['bert_start_finetuning'])
                warmup_scheduler = transformers.get_constant_schedule_with_warmup(
                    self.optimizer["bert_optimizer"],
                    self.args['bert_warmup_steps'])
                self.scheduler["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR(
                    self.optimizer["bert_optimizer"],
                    schedulers=[zero_scheduler, warmup_scheduler],
                    milestones=[self.args['bert_start_finetuning']])

    def update(self, batch, eval=False):
        """
        Run one training step on ``batch`` and return the loss as a float.

        With ``eval=True`` no optimizer or scheduler step is taken and the
        model is run in eval mode without gradients.  Parser.forward only
        computes a loss in training mode, so in that case the returned value
        is whatever it reports (currently 0.0).
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs

        if eval:
            self.model.eval()
            with torch.no_grad():
                loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
            # In eval mode Parser.forward returns the plain int 0 instead of a
            # tensor; calling loss.data.item() on it used to raise AttributeError
            return loss.item() if torch.is_tensor(loss) else float(loss)

        self.model.train()
        for opt in self.optimizer.values():
            opt.zero_grad()
        loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
        loss_val = loss.data.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        for opt in self.optimizer.values():
            opt.step()
        for scheduler in self.scheduler.values():
            scheduler.step()
        return loss_val

    def predict(self, batch, unsort=True):
        """
        Predict heads and deprels for ``batch``.

        Returns one list per sentence of [head, deprel] pairs, one pair per
        word; when ``unsort`` is set the sentences are restored to their
        original order using orig_idx.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs

        self.model.eval()
        batch_size = word.size(0)
        # predictions are detached numpy arrays, so no autograd graph is needed
        with torch.no_grad():
            _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
        # TODO: would be cleaner for the model to not have the capability to produce predictions < VOCAB_PREFIX_SIZE
        if self.fallback is not None:
            preds[1][preds[1] < VOCAB_PREFIX_SIZE] = self.fallback
        head_seqs = [chuliu_edmonds_one_root(adj[:l, :l])[1:] for adj, l in zip(preds[0], sentlens)] # remove attachment for the root
        deprel_seqs = [self.vocab['deprel'].unmap([preds[1][i][j+1][h] for j, h in enumerate(hs)]) for i, hs in enumerate(head_seqs)]

        pred_tokens = [[[head_seqs[i][j], deprel_seqs[i][j]] for j in range(sentlens[i]-1)] for i in range(batch_size)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def save(self, filename, skip_modules=True, save_optimizer=False):
        """
        Save the model, vocab, config and training history to ``filename``.

        Failures are logged as a warning and otherwise ignored.
        """
        model_state = self.model.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
            for k in skipped:
                del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            'config': self.args,
            'global_step': self.global_step,
            'last_best_step': self.last_best_step,
            'dev_score_history': self.dev_score_history,
        }
        if self.args.get('use_peft', False):
            # Hide import so that peft dependency is optional
            from peft import get_peft_model_state_dict
            params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)

        if save_optimizer and self.optimizer is not None:
            params['optimizer_state_dict'] = {k: opt.state_dict() for k, opt in self.optimizer.items()}
            params['scheduler_state_dict'] = {k: scheduler.state_dict() for k, scheduler in self.scheduler.items()}

        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException as e:
            logger.warning("Saving failed... continuing anyway.  Error was: %s" % e)

    def load(self, filename, pretrain, args=None, foundation_cache=None, device=None):
        """
        Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input,
        and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args is not None: self.args.update(args)

        # preserve old models which were created before transformers were added
        if 'bert_model' not in self.args:
            self.args['bert_model'] = None

        lora_weights = checkpoint.get('bert_lora')
        if lora_weights:
            logger.debug("Found peft weights for depparse; loading a peft adapter")
            self.args["use_peft"] = True

        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
        # load model
        emb_matrix = None
        if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None
            emb_matrix = pretrain.emb

        # TODO: refactor this common block of code with NER
        force_bert_saved = False
        peft_name = None
        if self.args.get('use_peft', False):
            force_bert_saved = True
            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "depparse", foundation_cache)
            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
            logger.debug("Loaded peft with name %s", peft_name)
        else:
            if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
                logger.debug("Model %s has a finetuned transformer.  Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
                foundation_cache = NoTransformerFoundationCache(foundation_cache)
                force_bert_saved = True
            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)

        self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
        # strict=False: the checkpoint omits the unsaved modules (see save)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        if device is not None:
            self.model = self.model.to(device)

        self.__init_optim()
        optim_state_dict = checkpoint.get("optimizer_state_dict")
        if optim_state_dict:
            for k, state in optim_state_dict.items():
                self.optimizer[k].load_state_dict(state)

        scheduler_state_dict = checkpoint.get("scheduler_state_dict")
        if scheduler_state_dict:
            for k, state in scheduler_state_dict.items():
                self.scheduler[k].load_state_dict(state)

        self.global_step = checkpoint.get("global_step", 0)
        self.last_best_step = checkpoint.get("last_best_step", 0)
        self.dev_score_history = checkpoint.get("dev_score_history", list())
================================================
FILE: stanza/models/identity_lemmatizer.py
================================================
"""
An identity lemmatizer that mimics the behavior of a normal lemmatizer but directly uses word as lemma.
"""
import os
import argparse
import logging
import random
from stanza.models.lemma.data import DataLoader
from stanza.models.lemma import scorer
from stanza.models.common import utils
from stanza.models.common.doc import *
from stanza.utils.conll import CoNLL
from stanza.models import _training_logging
logger = logging.getLogger('stanza')
def parse_args(args=None):
    """Parse command line options for the identity lemmatizer.

    Args:
        args: list of argument strings; None makes argparse read sys.argv

    Returns:
        argparse.Namespace with the parsed options
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')
    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
    # this file is the gold reference scored against in main(), not an output
    parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file for scoring.')

    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--shorthand', type=str, help='Shorthand')

    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--seed', type=int, default=1234)

    args = parser.parse_args(args=args)
    return args
def main(args=None):
    """Run the identity lemmatizer, which uses each word's text as its lemma.

    No training happens in either mode.  The eval file is read, every LEMMA
    is set to the word's TEXT, and the result is optionally written to
    ``output_file``.  If ``gold_file`` is given, the result is also scored
    against that file.

    Returns:
        (None, doc): there is no trainer, so the first element is always None
    """
    # io is not explicitly imported at the top of this file (it could only
    # arrive through the star import), so import it where it is needed
    import io

    args = parse_args(args=args)

    random.seed(args.seed)

    args = vars(args)

    logger.info("[Launching identity lemmatizer...]")

    if args['mode'] == 'train':
        logger.info("[No training is required; will only generate evaluation output...]")

    document = CoNLL.conll2doc(input_file=args['eval_file'])
    batch = DataLoader(document, args['batch_size'], args, evaluation=True, conll_only=True)
    system_pred_file = args['output_file']
    gold_file = args['gold_file']

    # use identity mapping for prediction
    preds = batch.doc.get([TEXT])

    # write to file and score
    batch.doc.set([LEMMA], preds)
    if system_pred_file is not None:
        CoNLL.write_doc2conll(batch.doc, system_pred_file)

    if gold_file is not None:
        # score an in-memory copy, which works even when no output_file was given
        system_pred_file = "{:C}\n\n".format(batch.doc)
        system_pred_file = io.StringIO(system_pred_file)
        _, _, score = scorer.score(system_pred_file, gold_file)

        logger.info("Lemma score:")
        logger.info("{} {:.2f}".format(args['shorthand'], score*100))

    return None, batch.doc
# Run the identity lemmatizer when this file is executed as a script
if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/lang_identifier.py
================================================
"""
Entry point for training and evaluating a Bi-LSTM language identifier
"""
import argparse
import json
import logging
import os
import random
import torch
from datetime import datetime
from stanza.models.common import utils
from stanza.models.langid.data import DataLoader
from stanza.models.langid.trainer import Trainer
from stanza.utils.get_tqdm import get_tqdm
tqdm = get_tqdm()
logger = logging.getLogger('stanza')
def parse_args(args=None):
    """Build the language identifier's command line parser and parse ``args``.

    Args:
        args: list of argument strings; None makes argparse read sys.argv

    Returns:
        argparse.Namespace with the parsed options (including device options)
    """
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument
    add("--batch_mode", help="custom settings when running in batch mode", action="store_true")
    add("--batch_size", help="batch size for training", type=int, default=64)
    add("--eval_length", help="length of strings to eval on", type=int, default=None)
    add("--eval_set", help="eval on dev or test", default="test")
    add("--data_dir", help="directory with train/dev/test data", default=None)
    add("--load_name", help="path to load model from", default=None)
    add("--mode", help="train or eval", default="train")
    add("--num_epochs", help="number of epochs for training", type=int, default=50)
    add("--randomize", help="take random substrings of samples", action="store_true")
    add("--randomize_lengths_range",
        help="range of lengths to use when random sampling text",
        type=randomize_lengths_range,
        default="5,20")
    add("--merge_labels_for_eval",
        help="merge some language labels for eval (e.g. \"zh-hans\" and \"zh-hant\" to \"zh\")",
        action="store_true")
    add("--save_best_epochs", help="save model for every epoch with new best score", action="store_true")
    add("--save_name", help="where to save model", default=None)
    utils.add_device_args(arg_parser)
    return arg_parser.parse_args(args=args)
def randomize_lengths_range(range_list):
    """
    Range of lengths for random samples

    argparse ``type=`` converter for a "min,max" string such as "5,20".
    Returns the two bounds as a list of ints.

    Raises:
        argparse.ArgumentTypeError: if the string does not hold exactly two
            comma separated values, or if min is not strictly less than max.
            argparse turns this into a normal usage error.  The previous
            ``assert`` was stripped under ``python -O`` and was not converted
            by argparse into a usage error.
        ValueError: if a bound is not an integer
    """
    range_boundaries = [int(x) for x in range_list.split(",")]
    if len(range_boundaries) != 2:
        raise argparse.ArgumentTypeError(f"Expected two comma separated lengths, got: {range_list}")
    if range_boundaries[0] >= range_boundaries[1]:
        raise argparse.ArgumentTypeError(f"Invalid range: ({range_boundaries[0]}, {range_boundaries[1]})")
    return range_boundaries
def main(args=None):
    """Parse the command line, seed torch, then train or evaluate.

    Any --mode other than "train" runs evaluation.
    """
    args = parse_args(args=args)
    torch.manual_seed(0)
    run = train_model if args.mode == "train" else eval_model
    run(args)
def build_indexes(args):
    """Build the label and character vocabularies from the training files.

    Every file in ``args.data_dir`` whose name contains "train" is read as
    JSON lines, and each non-blank line must have a "label" and a "text"
    field.  Labels and characters are numbered in order of first appearance.
    Files are processed in sorted filename order so the numbering does not
    depend on the arbitrary order ``os.listdir`` returns.  Two extra
    character entries, "UNK" and "" (empty string), are appended last.

    Returns:
        (tag_to_idx, char_to_idx): dicts mapping label -> id and char -> id
    """
    tag_to_idx = {}
    char_to_idx = {}
    train_files = sorted(f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x)
    for train_file in train_files:
        # JSON text is UTF-8; don't depend on the platform's default encoding
        with open(train_file, encoding="utf-8") as curr_file:
            for line in curr_file:
                if not line.strip():
                    continue
                example = json.loads(line)
                label = example["label"]
                if label not in tag_to_idx:
                    tag_to_idx[label] = len(tag_to_idx)
                for char in example["text"]:
                    if char not in char_to_idx:
                        char_to_idx[char] = len(char_to_idx)
    char_to_idx["UNK"] = len(char_to_idx)
    char_to_idx[""] = len(char_to_idx)
    return tag_to_idx, char_to_idx
def train_model(args):
# set up indexes
tag_to_idx, char_to_idx = build_indexes(args)
# load training data
train_data = DataLoader(args.device)
train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x]
train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
# load dev data
dev_data = DataLoader(args.device)
dev_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "dev" in x]
dev_data.load_data(args.batch_size, dev_files, char_to_idx, tag_to_idx, randomize=False,
max_length=args.eval_length)
# set up trainer
trainer_config = {
"model_path": args.save_name,
"char_to_idx": char_to_idx,
"tag_to_idx": tag_to_idx,
"batch_size": args.batch_size,
"lang_weights": train_data.lang_weights
}
if args.load_name:
trainer_config["load_name"] = args.load_name
logger.info(f"{datetime.now()}\tLoading model from: {args.load_name}")
trainer = Trainer(trainer_config, load_model=args.load_name is not None, device=args.device)
# run training
best_accuracy = 0.0
for epoch in range(1, args.num_epochs+1):
logger.info(f"{datetime.now()}\tEpoch {epoch}")
logger.info(f"{datetime.now()}\tNum training batches: {len(train_data.batches)}")
batches = train_data.batches
if not args.batch_mode:
batches = tqdm(batches)
for train_batch in batches:
inputs = (train_batch["sentences"], train_batch["targets"])
trainer.update(inputs)
logger.info(f"{datetime.now()}\tEpoch complete. Evaluating on dev data.")
curr_dev_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
eval_trainer(trainer, dev_data, batch_mode=args.batch_mode)
logger.info(f"{datetime.now()}\tCurrent dev accuracy: {curr_dev_accuracy}")
if curr_dev_accuracy > best_accuracy:
logger.info(f"{datetime.now()}\tNew best score. Saving model.")
model_label = f"epoch{epoch}" if args.save_best_epochs else None
trainer.save(label=model_label)
with open(score_log_path(args.save_name), "w") as score_log_file:
for score_log in [{"dev_accuracy": curr_dev_accuracy}, curr_confusion_matrix, curr_precisions,
curr_recalls, curr_f1s]:
score_log_file.write(json.dumps(score_log) + "\n")
best_accuracy = curr_dev_accuracy
# reload training data
logger.info(f"{datetime.now()}\tResampling training data.")
train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
def score_log_path(file_path):
"""
Helper that will determine corresponding log file (e.g. /path/to/demo.pt to /path/to/demo.json
"""
model_suffix = os.path.splitext(file_path)
if model_suffix[1]:
score_log_path = f"{file_path[:-len(model_suffix[1])]}.json"
else:
score_log_path = f"{file_path}.json"
return score_log_path
def eval_model(args):
# set up trainer
trainer_config = {
"model_path": None,
"load_name": args.load_name,
"batch_size": args.batch_size
}
trainer = Trainer(trainer_config, load_model=True, device=args.device)
# load test data
test_data = DataLoader(args.device)
test_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if args.eval_set in x]
test_data.load_data(args.batch_size, test_files, trainer.model.char_to_idx, trainer.model.tag_to_idx,
randomize=False, max_length=args.eval_length)
curr_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
eval_trainer(trainer, test_data, batch_mode=args.batch_mode, fine_grained=not args.merge_labels_for_eval)
logger.info(f"{datetime.now()}\t{args.eval_set} accuracy: {curr_accuracy}")
eval_save_path = args.save_name if args.save_name else score_log_path(args.load_name)
if not os.path.exists(eval_save_path) or args.save_name:
with open(eval_save_path, "w") as score_log_file:
for score_log in [{"dev_accuracy": curr_accuracy}, curr_confusion_matrix, curr_precisions,
curr_recalls, curr_f1s]:
score_log_file.write(json.dumps(score_log) + "\n")
def eval_trainer(trainer, dev_data, batch_mode=False, fine_grained=True):
"""
Produce dev accuracy and confusion matrix for a trainer
"""
# set up confusion matrix
tag_to_idx = dev_data.tag_to_idx
idx_to_tag = dev_data.idx_to_tag
confusion_matrix = {}
for row_label in tag_to_idx:
confusion_matrix[row_label] = {}
for col_label in tag_to_idx:
confusion_matrix[row_label][col_label] = 0
# process dev batches
batches = dev_data.batches
if not batch_mode:
batches = tqdm(batches)
for dev_batch in batches:
inputs = (dev_batch["sentences"], dev_batch["targets"])
predictions = trainer.predict(inputs)
for target_idx, prediction in zip(dev_batch["targets"], predictions):
prediction_label = idx_to_tag[prediction] if fine_grained else idx_to_tag[prediction].split("-")[0]
confusion_matrix[idx_to_tag[target_idx]][prediction_label] += 1
# calculate dev accuracy
total_examples = sum([sum([confusion_matrix[i][j] for j in confusion_matrix[i]]) for i in confusion_matrix])
total_correct = sum([confusion_matrix[i][i] for i in confusion_matrix])
dev_accuracy = float(total_correct) / float(total_examples)
# calculate precision, recall, F1
precision_scores = {"type": "precision"}
recall_scores = {"type": "recall"}
f1_scores = {"type": "f1"}
for prediction_label in tag_to_idx:
total = sum([confusion_matrix[k][prediction_label] for k in tag_to_idx])
if total != 0.0:
precision_scores[prediction_label] = float(confusion_matrix[prediction_label][prediction_label])/float(total)
else:
precision_scores[prediction_label] = 0.0
for target_label in tag_to_idx:
total = sum([confusion_matrix[target_label][k] for k in tag_to_idx])
if total != 0:
recall_scores[target_label] = float(confusion_matrix[target_label][target_label])/float(total)
else:
recall_scores[target_label] = 0.0
for label in tag_to_idx:
if precision_scores[label] == 0.0 and recall_scores[label] == 0.0:
f1_scores[label] = 0.0
else:
f1_scores[label] = \
2.0 * (precision_scores[label] * recall_scores[label]) / (precision_scores[label] + recall_scores[label])
return dev_accuracy, confusion_matrix, precision_scores, recall_scores, f1_scores
if __name__ == "__main__":
main()
================================================
FILE: stanza/models/langid/__init__.py
================================================
================================================
FILE: stanza/models/langid/create_ud_data.py
================================================
"""
Script for producing training/dev/test data from UD data or sentences
Example output data format (one example per line):
{"text": "Hello world.", "label": "en"}
This is an attempt to recreate data pre-processing in https://github.com/AU-DIS/LSTM_langid
Specifically borrows methods from https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py
Data format is same as LSTM_langid as well.
"""
import argparse
import json
import logging
import os
import re
import sys
from pathlib import Path
from random import randint, random, shuffle
from string import digits
from tqdm import tqdm
from stanza.models.common.constant import treebank_to_langid
logger = logging.getLogger('stanza')
DEFAULT_LANGUAGES = "af,ar,be,bg,bxr,ca,cop,cs,cu,da,de,el,en,es,et,eu,fa,fi,fr,fro,ga,gd,gl,got,grc,he,hi,hr,hsb,hu,hy,id,it,ja,kk,kmr,ko,la,lt,lv,lzh,mr,mt,nl,nn,no,olo,orv,pl,pt,ro,ru,sk,sl,sme,sr,sv,swl,ta,te,tr,ug,uk,ur,vi,wo,zh-hans,zh-hant".split(",")
def parse_args(args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--data-format", help="input data format", choices=["ud", "one-per-line"], default="ud")
parser.add_argument("--eval-length", help="length of eval strings", type=int, default=10)
parser.add_argument("--languages", help="list of languages to use, or \"all\"", default=DEFAULT_LANGUAGES)
parser.add_argument("--min-window", help="minimal training example length", type=int, default=10)
parser.add_argument("--max-window", help="maximum training example length", type=int, default=50)
parser.add_argument("--ud-path", help="path to ud data")
parser.add_argument("--save-path", help="path to save data", default=".")
parser.add_argument("--splits", help="size of train/dev/test splits in percentages", type=splits_from_list,
default="0.8,0.1,0.1")
args = parser.parse_args(args=args)
return args
def splits_from_list(value_list):
return [float(x) for x in value_list.split(",")]
def main(args=None):
args = parse_args(args=args)
if isinstance(args.languages, str):
args.languages = args.languages.split(",")
data_paths = [f"{args.save_path}/{data_split}.jsonl" for data_split in ["train", "dev", "test"]]
lang_to_files = collect_files(args.ud_path, args.languages, data_format=args.data_format)
logger.info(f"Building UD data for languages: {','.join(args.languages)}")
for lang_id in tqdm(lang_to_files):
lang_examples = generate_examples(lang_id, lang_to_files[lang_id], splits=args.splits,
min_window=args.min_window, max_window=args.max_window,
eval_length=args.eval_length, data_format=args.data_format)
for (data_set, save_path) in zip(lang_examples, data_paths):
with open(save_path, "a") as json_file:
for json_entry in data_set:
json.dump(json_entry, json_file, ensure_ascii=False)
json_file.write("\n")
def collect_files(ud_path, languages, data_format="ud"):
"""
Given path to UD, collect files
If data_format = "ud", expects files to be of form *.conllu
If data_format = "one-per-line", expects files to be of form "*.sentences.txt"
In all cases, the UD path should be a directory with subdirectories for each language
"""
data_format_to_search_path = {"ud": "*/*.conllu", "one-per-line": "*/*sentences.txt"}
ud_files = Path(ud_path).glob(data_format_to_search_path[data_format])
lang_to_files = {}
for ud_file in ud_files:
if data_format == "ud":
lang_id = treebank_to_langid(ud_file.parent.name)
else:
lang_id = ud_file.name.split("_")[0]
if lang_id not in languages and "all" not in languages:
continue
if not lang_id in lang_to_files:
lang_to_files[lang_id] = []
lang_to_files[lang_id].append(ud_file)
return lang_to_files
def generate_examples(lang_id, list_of_files, splits=(0.8,0.1,0.1), min_window=10, max_window=50,
eval_length=10, data_format="ud"):
"""
Generate train/dev/test examples for a given language
"""
examples = []
for ud_file in list_of_files:
sentences = sentences_from_file(ud_file, data_format=data_format)
for sentence in sentences:
sentence = clean_sentence(sentence)
if validate_sentence(sentence, min_window):
examples += sentence_to_windows(sentence, min_window=min_window, max_window=max_window)
shuffle(examples)
train_idx = int(splits[0] * len(examples))
train_set = [example_json(lang_id, example) for example in examples[:train_idx]]
dev_idx = int(splits[1] * len(examples)) + train_idx
dev_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[train_idx:dev_idx]]
test_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[dev_idx:]]
return train_set, dev_set, test_set
def sentences_from_file(ud_file_path, data_format="ud"):
"""
Retrieve all sentences from a UD file
"""
if data_format == "ud":
with open(ud_file_path) as ud_file:
ud_file_contents = ud_file.read().strip()
assert "# text = " in ud_file_contents, \
f"{ud_file_path} does not have expected format, \"# text =\" does not appear"
sentences = [x[9:] for x in ud_file_contents.split("\n") if x.startswith("# text = ")]
elif data_format == "one-per-line":
with open(ud_file_path) as ud_file:
sentences = [x for x in ud_file.read().strip().split("\n") if x]
return sentences
def sentence_to_windows(sentence, min_window, max_window):
"""
Create window size chunks from a sentence, always starting with a word
"""
windows = []
words = sentence.split(" ")
curr_window = ""
for idx, word in enumerate(words):
curr_window += (" " + word)
curr_window = curr_window.lstrip()
next_word_len = len(words[idx+1]) + 1 if idx+1 < len(words) else 0
if len(curr_window) + next_word_len > max_window:
curr_window = clean_sentence(curr_window)
if validate_sentence(curr_window, min_window):
windows.append(curr_window.strip())
curr_window = ""
if len(curr_window) >= min_window:
windows.append(curr_window)
return windows
def validate_sentence(current_window, min_window):
"""
Sentence validation from: LSTM-LID
GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py
"""
if len(current_window) < min_window:
return False
return True
def find(s, ch):
"""
Helper for clean_sentence from LSTM-LID
GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py
"""
return [i for i, ltr in enumerate(s) if ltr == ch]
def clean_sentence(line):
"""
Sentence cleaning from LSTM-LID
GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py
"""
# We remove some special characters and fix small errors in the data, to improve the quality of the data
line = line.replace("\n", '') #{"text": "- Mor.\n", "label": "da"}
line = line.replace("- ", '') #{"text": "- Mor.", "label": "da"}
line = line.replace("_", '') #{"text": "- Mor.", "label": "da"}
line = line.replace("\\", '')
line = line.replace("\"", '')
line = line.replace(" ", " ")
remove_digits = str.maketrans('', '', digits)
line = line.translate(remove_digits)
words = line.split()
new_words = []
# Below fixes large I instead of l. Does not catch everything, but should also not really make any mistakes either
for word in words:
clean_word = word
s = clean_word
if clean_word[1:].__contains__("I"):
indices = find(clean_word, "I")
for indx in indices:
if clean_word[indx-1].islower():
if len(clean_word) > indx + 1:
if clean_word[indx+1].islower():
s = s[:indx] + "l" + s[indx + 1:]
else:
s = s[:indx] + "l" + s[indx + 1:]
new_words.append(s)
new_line = " ".join(new_words)
return new_line
def example_json(lang_id, text, eval_length=None):
if eval_length is not None:
text = text[:eval_length]
return {"text": text.strip(), "label": lang_id}
if __name__ == "__main__":
main()
================================================
FILE: stanza/models/langid/data.py
================================================
import json
import random
import torch
class DataLoader:
"""
Class for loading language id data and providing batches
Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid
Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
Data format is same as LSTM_langid
"""
def __init__(self, device=None):
self.batches = None
self.batches_iter = None
self.tag_to_idx = None
self.idx_to_tag = None
self.lang_weights = None
self.device = device
def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),
max_length=None):
"""
Load sequence data and labels, calculate weights for weighted cross entropy loss.
Data is stored in a file, 1 example per line
Example: {"text": "Hello world.", "label": "en"}
"""
# set up examples from data files
examples = []
for data_file in data_files:
examples += [x for x in open(data_file).read().split("\n") if x.strip()]
random.shuffle(examples)
examples = [json.loads(x) for x in examples]
# add additional labels in this data set to tag index
tag_index = dict(tag_index)
new_labels = set([x["label"] for x in examples]) - set(tag_index.keys())
for new_label in new_labels:
tag_index[new_label] = len(tag_index)
self.tag_to_idx = tag_index
self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]
# set up lang counts used for weights for cross entropy loss
lang_counts = [0 for _ in tag_index]
# optionally limit text to max length
if max_length is not None:
examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]
# randomize data
if randomize:
split_examples = []
for example in examples:
sequence = example["text"]
label = example["label"]
sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1],
lower_lim=randomize_range[0])
split_examples += [{"text": seq, "label": label} for seq in sequences]
examples = split_examples
random.shuffle(examples)
# break into equal length batches
batch_lengths = {}
for example in examples:
sequence = example["text"]
label = example["label"]
if len(sequence) not in batch_lengths:
batch_lengths[len(sequence)] = []
sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in list(sequence)]
batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
lang_counts[tag_index[label]] += 1
for length in batch_lengths:
random.shuffle(batch_lengths[length])
# create final set of batches
batches = []
for length in batch_lengths:
for sublist in [batch_lengths[length][i:i + batch_size] for i in
range(0, len(batch_lengths[length]), batch_size)]:
batches.append(sublist)
self.batches = [self.build_batch_tensors(batch) for batch in batches]
# set up lang weights
most_frequent = max(lang_counts)
# set to 0.0 if lang_count is 0 or most_frequent/lang_count otherwise
lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts]
self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)
# shuffle batches to mix up lengths
random.shuffle(self.batches)
self.batches_iter = iter(self.batches)
@staticmethod
def randomize_data(sentences, upper_lim=20, lower_lim=5):
"""
Takes the original data and creates random length examples with length between upper limit and lower limit
From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
"""
new_data = []
for sentence in sentences:
remaining = sentence
while lower_lim < len(remaining):
lim = random.randint(lower_lim, upper_lim)
m = min(len(remaining), lim)
new_sentence = remaining[:m]
new_data.append(new_sentence)
split = remaining[m:].split(" ", 1)
if len(split) <= 1:
break
remaining = split[1]
random.shuffle(new_data)
return new_data
def build_batch_tensors(self, batch):
"""
Helper to turn batches into tensors
"""
batch_tensors = dict()
batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)
return batch_tensors
def next(self):
return next(self.batches_iter)
================================================
FILE: stanza/models/langid/model.py
================================================
import os
import torch
import torch.nn as nn
class LangIDBiLSTM(nn.Module):
"""
Multi-layer BiLSTM model for language detecting. A recreation of "A reproduction of Apple's bi-directional LSTM models
for language identification in short strings." (Toftrup et al 2021)
Arxiv: https://arxiv.org/abs/2102.06282
GitHub: https://github.com/AU-DIS/LSTM_langid
This class is similar to https://github.com/AU-DIS/LSTM_langid/blob/main/src/LSTMLID.py
"""
def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_dim, batch_size=64, weights=None,
dropout=0.0, lang_subset=None):
super(LangIDBiLSTM, self).__init__()
self.num_layers = num_layers
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.char_to_idx = char_to_idx
self.vocab_size = len(char_to_idx)
self.tag_to_idx = tag_to_idx
self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]
self.lang_subset = lang_subset
self.padding_idx = char_to_idx[""]
self.tagset_size = len(tag_to_idx)
self.batch_size = batch_size
self.loss_train = nn.CrossEntropyLoss(weight=weights)
self.dropout_prob = dropout
# embeddings for chars
self.char_embeds = nn.Embedding(
num_embeddings=self.vocab_size,
embedding_dim=self.embedding_dim,
padding_idx=self.padding_idx
)
# the bidirectional LSTM
self.lstm = nn.LSTM(
self.embedding_dim,
self.hidden_dim,
num_layers=self.num_layers,
bidirectional=True,
batch_first=True
)
# convert output to tag space
self.hidden_to_tag = nn.Linear(
self.hidden_dim * 2,
self.tagset_size
)
# dropout layer
self.dropout = nn.Dropout(p=self.dropout_prob)
def build_lang_mask(self, device):
"""
Build language mask if a lang subset is specified (e.g. ["en", "fr"])
The mask will be added to the results to set the prediction scores of illegal languages to -inf
"""
if self.lang_subset:
lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag]
self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
else:
self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float)
def loss(self, Y_hat, Y):
return self.loss_train(Y_hat, Y)
def forward(self, x):
# embed input
x = self.char_embeds(x)
# run through LSTM
x, _ = self.lstm(x)
# run through linear layer
x = self.hidden_to_tag(x)
# sum character outputs for each sequence
x = torch.sum(x, dim=1)
return x
def prediction_scores(self, x):
prediction_probs = self(x)
if self.lang_subset:
prediction_batch_size = prediction_probs.size()[0]
batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)])
prediction_probs = prediction_probs + batch_mask
return torch.argmax(prediction_probs, dim=1)
def save(self, path):
""" Save a model at path """
checkpoint = {
"char_to_idx": self.char_to_idx,
"tag_to_idx": self.tag_to_idx,
"num_layers": self.num_layers,
"embedding_dim": self.embedding_dim,
"hidden_dim": self.hidden_dim,
"model_state_dict": self.state_dict()
}
torch.save(checkpoint, path)
@classmethod
def load(cls, path, device=None, batch_size=64, lang_subset=None):
""" Load a serialized model located at path """
if path is None:
raise FileNotFoundError("Trying to load langid model, but path not specified! Try --load_name")
if not os.path.exists(path):
raise FileNotFoundError("Trying to load langid model from path which does not exist: %s" % path)
checkpoint = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
weights = checkpoint["model_state_dict"]["loss_train.weight"]
model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"],
checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights,
lang_subset=lang_subset)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
model.build_lang_mask(device)
return model
================================================
FILE: stanza/models/langid/trainer.py
================================================
import torch
import torch.optim as optim
from stanza.models.langid.model import LangIDBiLSTM
class Trainer:
DEFAULT_BATCH_SIZE = 64
DEFAULT_LAYERS = 2
DEFAULT_EMBEDDING_DIM = 150
DEFAULT_HIDDEN_DIM = 150
def __init__(self, config, load_model=False, device=None):
self.model_path = config["model_path"]
self.batch_size = config.get("batch_size", Trainer.DEFAULT_BATCH_SIZE)
if load_model:
self.load(config["load_name"], device)
else:
self.model = LangIDBiLSTM(config["char_to_idx"], config["tag_to_idx"], Trainer.DEFAULT_LAYERS,
Trainer.DEFAULT_EMBEDDING_DIM,
Trainer.DEFAULT_HIDDEN_DIM,
batch_size=self.batch_size,
weights=config["lang_weights"]).to(device)
self.optimizer = optim.AdamW(self.model.parameters())
def update(self, inputs):
self.model.train()
sentences, targets = inputs
self.optimizer.zero_grad()
y_hat = self.model.forward(sentences)
loss = self.model.loss(y_hat, targets)
loss.backward()
self.optimizer.step()
def predict(self, inputs):
self.model.eval()
sentences, targets = inputs
return torch.argmax(self.model(sentences), dim=1)
def save(self, label=None):
# save a copy of model with label
if label:
self.model.save(f"{self.model_path[:-3]}-{label}.pt")
self.model.save(self.model_path)
def load(self, model_path=None, device=None):
if not model_path:
model_path = self.model_path
self.model = LangIDBiLSTM.load(model_path, device, self.batch_size)
================================================
FILE: stanza/models/lemma/__init__.py
================================================
================================================
FILE: stanza/models/lemma/attach_lemma_classifier.py
================================================
import argparse
from stanza.models.lemma.trainer import Trainer
from stanza.models.lemma_classifier.base_model import LemmaClassifier
def attach_classifier(input_filename, output_filename, classifiers):
trainer = Trainer(model_file=input_filename)
for classifier in classifiers:
classifier = LemmaClassifier.load(classifier)
trainer.contextual_lemmatizers.append(classifier)
trainer.save(output_filename)
def main(args=None):
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, required=True, help='Which lemmatizer to start from')
parser.add_argument('--output', type=str, required=True, help='Where to save the lemmatizer')
parser.add_argument('--classifier', type=str, required=True, nargs='+', help='Lemma classifier to attach')
args = parser.parse_args(args)
attach_classifier(args.input, args.output, args.classifier)
if __name__ == '__main__':
main()
================================================
FILE: stanza/models/lemma/data.py
================================================
import random
import numpy as np
import os
from collections import Counter
import logging
import torch
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.vocab import DeltaVocab
from stanza.models.lemma.vocab import Vocab, MultiVocab
from stanza.models.lemma import edit
from stanza.models.common.doc import *
logger = logging.getLogger('stanza')
class DataLoader:
def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, conll_only=False, skip=None, expand_unk_vocab=False):
self.batch_size = batch_size
self.args = args
self.eval = evaluation
self.shuffled = not self.eval
self.doc = doc
data = self.raw_data()
if conll_only: # only load conll file
return
if skip is not None:
assert len(data) == len(skip)
data = [x for x, y in zip(data, skip) if not y]
# handle vocab
if vocab is not None:
if expand_unk_vocab:
pos_vocab = vocab['pos']
char_vocab = DeltaVocab(data, vocab['char'])
self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})
else:
self.vocab = vocab
else:
self.vocab = dict()
char_vocab, pos_vocab = self.init_vocab(data)
self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})
# filter and sample data
if args.get('sample_train', 1.0) < 1.0 and not self.eval:
keep = int(args['sample_train'] * len(data))
data = random.sample(data, keep)
logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))
data = self.preprocess(data, self.vocab['char'], self.vocab['pos'], args)
# shuffle for training
if self.shuffled:
indices = list(range(len(data)))
random.shuffle(indices)
data = [data[i] for i in indices]
self.num_examples = len(data)
# chunk into batches
data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
self.data = data
logger.debug("{} batches created.".format(len(data)))
def init_vocab(self, data):
assert self.eval is False, "Vocab file must exist for evaluation"
char_data = "".join(d[0] + d[2] for d in data)
char_vocab = Vocab(char_data, self.args['lang'])
pos_data = [d[1] for d in data]
pos_vocab = Vocab(pos_data, self.args['lang'])
return char_vocab, pos_vocab
def preprocess(self, data, char_vocab, pos_vocab, args):
processed = []
for d in data:
edit_type = edit.EDIT_TO_ID[edit.get_edit_type(d[0], d[2])]
src = list(d[0])
src = [constant.SOS] + src + [constant.EOS]
src = char_vocab.map(src)
pos = d[1]
pos = pos_vocab.unit2id(pos)
tgt = list(d[2])
tgt_in = char_vocab.map([constant.SOS] + tgt)
tgt_out = char_vocab.map(tgt + [constant.EOS])
processed += [[src, tgt_in, tgt_out, pos, edit_type, d[0]]]
return processed
def __len__(self):
return len(self.data)
def __getitem__(self, key):
""" Get a batch with index. """
if not isinstance(key, int):
raise TypeError
if key < 0 or key >= len(self.data):
raise IndexError
batch = self.data[key]
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 6
# sort all fields by lens for easy RNN operations
lens = [len(x) for x in batch[0]]
batch, orig_idx = sort_all(batch, lens)
# convert to tensors
src = batch[0]
src = get_long_tensor(src, batch_size)
src_mask = torch.eq(src, constant.PAD_ID)
tgt_in = get_long_tensor(batch[1], batch_size)
tgt_out = get_long_tensor(batch[2], batch_size)
pos = torch.LongTensor(batch[3])
edits = torch.LongTensor(batch[4])
text = batch[5]
assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match."
return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx, text
def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)
def raw_data(self):
return self.load_doc(self.doc, self.args.get('caseless', False), self.args.get('skip_blank_lemmas', False), self.eval)
@staticmethod
def load_doc(doc, caseless, skip_blank_lemmas, evaluation):
if evaluation:
data = doc.get([TEXT, UPOS, LEMMA])
else:
data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL, MISC], as_sentences=True)
data = DataLoader.remove_goeswith(data)
data = DataLoader.extract_correct_forms(data)
data = DataLoader.resolve_none(data)
if not evaluation and skip_blank_lemmas:
data = DataLoader.skip_blank_lemmas(data)
if caseless:
data = DataLoader.lowercase_data(data)
return data
@staticmethod
def extract_correct_forms(data):
"""
Here we go through the raw data and use the CorrectForm of words tagged with CorrectForm
In addition, if the incorrect form of the word is not present in the training data,
we keep the incorrect form for the lemmatizer to learn from.
This way, it can occasionally get things right in misspelled input text.
We do check for and eliminate words where the incorrect form is already known as the
lemma for a different word. For example, in the English datasets, there is a "busy"
which was meant to be "buys", and we don't want the model to learn to lemmatize "busy" to "buy"
"""
new_data = []
incorrect_forms = []
for word in data:
misc = word[-1]
if not misc:
new_data.append(word[:3])
continue
misc = misc.split("|")
for piece in misc:
if piece.startswith("CorrectForm="):
cf = piece.split("=", maxsplit=1)[1]
# treat the CorrectForm as the desired word
new_data.append((cf, word[1], word[2]))
# and save the broken one for later in case it wasn't used anywhere else
incorrect_forms.append((cf, word))
break
else:
# if no CorrectForm, just keep the word as normal
new_data.append(word[:3])
known_words = {x[0] for x in new_data}
for correct_form, word in incorrect_forms:
if word[0] not in known_words:
new_data.append(word[:3])
return new_data
@staticmethod
def remove_goeswith(data):
"""
This method specifically removes words that goeswith something else, along with the something else
The purpose is to eliminate text such as
1 Ken kenrice@enroncommunications X GW Typo=Yes 0 root 0:root _
2 Rice@ENRON _ X GW _ 1 goeswith 1:goeswith _
3 COMMUNICATIONS _ X ADD _ 1 goeswith 1:goeswith _
"""
filtered_data = []
remove_indices = set()
for sentence in data:
remove_indices.clear()
for word_idx, word in enumerate(sentence):
if word[4] == 'goeswith':
remove_indices.add(word_idx)
remove_indices.add(word[3]-1)
filtered_data.extend([x for idx, x in enumerate(sentence) if idx not in remove_indices])
return filtered_data
@staticmethod
def lowercase_data(data):
for token in data:
token[0] = token[0].lower()
return data
@staticmethod
def skip_blank_lemmas(data):
data = [x for x in data if x[2] != '_']
return data
@staticmethod
def resolve_none(data):
# replace None to '_'
for tok_idx in range(len(data)):
for feat_idx in range(len(data[tok_idx])):
if data[tok_idx][feat_idx] is None:
data[tok_idx][feat_idx] = '_'
return data
================================================
FILE: stanza/models/lemma/edit.py
================================================
"""
Utilities for calculating edits between word and lemma forms.
"""
EDIT_TO_ID = {'none': 0, 'identity': 1, 'lower': 2}
def get_edit_type(word, lemma):
""" Calculate edit types. """
if lemma == word:
return 'identity'
elif lemma == word.lower():
return 'lower'
return 'none'
def edit_word(word, pred, edit_id):
"""
Edit a word, given edit and seq2seq predictions.
"""
if edit_id == 1:
return word
elif edit_id == 2:
return word.lower()
elif edit_id == 0:
return pred
else:
raise Exception("Unrecognized edit ID: {}".format(edit_id))
================================================
FILE: stanza/models/lemma/scorer.py
================================================
"""
Utils and wrappers for scoring lemmatizers.
"""
import logging
from stanza.models.common.utils import ud_scores
logger = logging.getLogger('stanza')
def score(system_conllu_file, gold_conllu_file):
    """
    Score the lemmas of a system CoNLL-U file against a gold CoNLL-U file.

    Returns the (precision, recall, f1) of the "Lemmas" entry produced by ud_scores.
    """
    logger.debug("Evaluating system file %s vs gold file %s", system_conllu_file, gold_conllu_file)
    lemma_scores = ud_scores(gold_conllu_file, system_conllu_file)["Lemmas"]
    return lemma_scores.precision, lemma_scores.recall, lemma_scores.f1
================================================
FILE: stanza/models/lemma/trainer.py
================================================
"""
A trainer class to handle training and testing of models.
"""
import os
import sys
import numpy as np
from collections import Counter
import logging
import torch
from torch import nn
import torch.nn.init as init
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.doc import TEXT, UPOS
from stanza.models.common.foundation_cache import load_charlm
from stanza.models.common.seq2seq_model import Seq2SeqModel
from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
from stanza.models.common import utils, loss
from stanza.models.lemma import edit
from stanza.models.lemma.vocab import MultiVocab
from stanza.models.lemma_classifier.base_model import LemmaClassifier
logger = logging.getLogger('stanza')
def unpack_batch(batch, device):
    """
    Split a data-loader batch into (inputs, orig_idx, text).

    The first six entries are moved to `device` (None entries stay None);
    entry 6 is returned as orig_idx and entry 7 as text.
    """
    inputs = []
    for item in batch[:6]:
        inputs.append(None if item is None else item.to(device))
    return inputs, batch[6], batch[7]
class Trainer(object):
    """
    A trainer for the lemmatizer.

    Combines three components:
      - an optional seq2seq model (absent when args['dict_only'] is set),
      - dictionary lookups: word_dict (word -> lemma) and
        composite_dict ((word, pos) -> lemma),
      - optional contextual lemma classifiers that override predictions for
        specific words.
    """
    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None, lemma_classifier_args=None):
        """
        Either load everything from `model_file`, or build a fresh trainer from
        `args`, `vocab` and `emb_matrix`.  Unless args['dict_only'] is set, the
        model is moved to `device` and a loss and optimizer are created.
        """
        if model_file is not None:
            # load everything from file
            self.load(model_file, args, foundation_cache, lemma_classifier_args)
        else:
            # build model from scratch
            self.args = args
            if args['dict_only']:
                self.model = None
            else:
                self.model = self.build_seq2seq(args, emb_matrix, foundation_cache)
            self.vocab = vocab
            # dict-based components
            self.word_dict = dict()
            self.composite_dict = dict()
            self.contextual_lemmatizers = []

        self.caseless = self.args.get('caseless', False)

        if not self.args['dict_only']:
            self.model = self.model.to(device)
            if self.args.get('edit', False):
                self.crit = loss.MixLoss(self.vocab['char'].size, self.args['alpha']).to(device)
                logger.debug("Running seq2seq lemmatizer with edit classifier...")
            else:
                self.crit = loss.SequenceLoss(self.vocab['char'].size).to(device)
            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])

    def build_seq2seq(self, args, emb_matrix, foundation_cache):
        """
        Build the Seq2SeqModel.  If args name forward and/or backward charlm
        files, they are loaded (through foundation_cache) and wrapped in a
        CharacterLanguageModelWordAdapter passed as the contextual embedding.
        """
        charmodel = None
        charlms = []
        if args is not None and args.get('charlm_forward_file', None):
            charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)
            charlms.append(charmodel_forward)
        if args is not None and args.get('charlm_backward_file', None):
            charmodel_backward = load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache)
            charlms.append(charmodel_backward)
        if len(charlms) > 0:
            charlms = nn.ModuleList(charlms)
            charmodel = CharacterLanguageModelWordAdapter(charlms)
        model = Seq2SeqModel(args, emb_matrix=emb_matrix, contextual_embedding=charmodel)
        return model

    def update(self, batch, eval=False):
        """
        Compute the loss on one batch and return it as a Python float.

        With eval=False the model is put in train mode, the loss is
        backpropagated, gradients are clipped to args['max_grad_norm'] and
        the optimizer takes a step.  With eval=True the model is put in eval
        mode and no backward pass or optimizer step happens.
        """
        device = next(self.model.parameters()).device
        inputs, _, text = unpack_batch(batch, device)
        src, src_mask, tgt_in, tgt_out, pos, edits = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()
        log_probs, edit_logits = self.model(src, src_mask, tgt_in, pos, raw=text)
        char_vocab_size = self.vocab['char'].size
        # named batch_loss so it does not shadow the imported `loss` module
        if self.args.get('edit', False):
            assert edit_logits is not None
            batch_loss = self.crit(log_probs.view(-1, char_vocab_size), tgt_out.view(-1), edit_logits, edits)
        else:
            batch_loss = self.crit(log_probs.view(-1, char_vocab_size), tgt_out.view(-1))
        loss_val = batch_loss.data.item()
        if eval:
            return loss_val

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, beam_size=1, vocab=None):
        """
        Decode lemmas for one batch with the seq2seq model.

        Returns (pred_tokens, edits).  Both are reordered with
        utils.unsort(..., orig_idx).  `edits` holds the argmax edit id per word
        when args['edit'] is set, otherwise None.
        """
        if vocab is None:
            vocab = self.vocab

        device = next(self.model.parameters()).device
        inputs, orig_idx, text = unpack_batch(batch, device)
        src, src_mask, _, _, pos, _ = inputs

        self.model.eval()
        batch_size = src.size(0)
        preds, edit_logits = self.model.predict(src, src_mask, pos=pos, beam_size=beam_size, raw=text)
        pred_seqs = [vocab['char'].unmap(ids) for ids in preds]  # unmap to tokens
        pred_seqs = utils.prune_decoded_seqs(pred_seqs)
        pred_tokens = ["".join(seq) for seq in pred_seqs]  # join chars to be tokens
        pred_tokens = utils.unsort(pred_tokens, orig_idx)
        if self.args.get('edit', False):
            assert edit_logits is not None
            edits = np.argmax(edit_logits.data.cpu().numpy(), axis=1).reshape([batch_size]).tolist()
            edits = utils.unsort(edits, orig_idx)
        else:
            edits = None
        return pred_tokens, edits

    def postprocess(self, words, preds, edits=None):
        """
        Postprocess, mainly for handling edits.

        When args['edit'] is set, each prediction is combined with its edit id
        via edit.edit_word.  Any resulting lemma that is empty or contains
        constant.UNK is replaced by the original word.
        """
        assert len(words) == len(preds), "Lemma predictions must have same length as words."
        edited = []
        if self.args.get('edit', False):
            assert edits is not None and len(words) == len(edits)
            for w, p, e in zip(words, preds, edits):
                lem = edit.edit_word(w, p, e)
                edited += [lem]
        else:
            edited = preds  # do not edit
        # final sanity check
        assert len(edited) == len(words)
        final = []
        for lem, w in zip(edited, words):
            if len(lem) == 0 or constant.UNK in lem:
                final += [w]  # invalid prediction, fall back to word
            else:
                final += [lem]
        return final

    def has_contextual_lemmatizers(self):
        """Return True if at least one contextual lemmatizer is attached."""
        return self.contextual_lemmatizers is not None and len(self.contextual_lemmatizers) > 0

    def predict_contextual(self, sentence_words, sentence_tags, preds):
        """
        Overwrite entries of the per-sentence `preds` (modified in place and
        returned) with the output of each contextual lemmatizer, for the word
        positions each lemmatizer reports via target_indices().
        """
        if len(self.contextual_lemmatizers) == 0:
            return preds

        # reversed so that the first lemmatizer has priority (it writes last)
        for contextual in reversed(self.contextual_lemmatizers):
            pred_idx = []
            pred_sent_words = []
            pred_sent_tags = []
            pred_sent_ids = []
            for sent_id, (words, tags) in enumerate(zip(sentence_words, sentence_tags)):
                indices = contextual.target_indices(words, tags)
                for idx in indices:
                    pred_idx.append(idx)
                    pred_sent_words.append(words)
                    pred_sent_tags.append(tags)
                    pred_sent_ids.append(sent_id)
            if len(pred_idx) == 0:
                continue
            contextual_predictions = contextual.predict(pred_idx, pred_sent_words, pred_sent_tags)
            for sent_id, word_id, pred in zip(pred_sent_ids, pred_idx, contextual_predictions):
                preds[sent_id][word_id] = pred
        return preds

    def update_contextual_preds(self, doc, preds):
        """
        Update a flat list of preds with the output of the contextual lemmatizers
          - First, it unflattens the preds based on the lengths of the sentences
          - Then it uses the contextual lemmatizers
          - Finally, it reflattens the preds into the format expected by the caller
        """
        if len(self.contextual_lemmatizers) == 0:
            return preds

        sentence_words = doc.get([TEXT], as_sentences=True)
        sentence_tags = doc.get([UPOS], as_sentences=True)
        sentence_preds = []
        start_index = 0
        for sent in sentence_words:
            end_index = start_index + len(sent)
            sentence_preds.append(preds[start_index:end_index])
            start_index = end_index
        preds = self.predict_contextual(sentence_words, sentence_tags, sentence_preds)
        preds = [lemma for sentence in preds for lemma in sentence]
        return preds

    def update_lr(self, new_lr):
        """Set the optimizer's learning rate to `new_lr`."""
        utils.change_lr(self.optimizer, new_lr)

    def train_dict(self, triples, update_word_dict=True):
        """
        Train a dict lemmatizer given training (word, pos, lemma) triples.

        The most frequent lemma seen for each (word, pos) pair (and, unless
        update_word_dict is False, for each word) is stored; existing entries
        are never overwritten.  Updating only the composite_dict is useful when
        the data might be limited by the tags, such as when adding more words
        at pipeline time.
        """
        # accumulate counter
        ctr = Counter()
        ctr.update([(p[0], p[1], p[2]) for p in triples])
        # find the most frequent mappings
        for p, _ in ctr.most_common():
            w, pos, l = p
            if (w, pos) not in self.composite_dict:
                self.composite_dict[(w, pos)] = l
            if update_word_dict and w not in self.word_dict:
                self.word_dict[w] = l
        return

    def predict_dict(self, pairs):
        """
        Predict a list of lemmas using the dict model given (word, pos) pairs.

        Lookup order: composite_dict[(word, pos)], then word_dict[word], then
        the word itself.  The word is lowercased first when self.caseless is set.
        """
        lemmas = []
        for p in pairs:
            w, pos = p
            if self.caseless:
                w = w.lower()
            if (w, pos) in self.composite_dict:
                lemmas += [self.composite_dict[(w, pos)]]
            elif w in self.word_dict:
                lemmas += [self.word_dict[w]]
            else:
                lemmas += [w]
        return lemmas

    def skip_seq2seq(self, pairs):
        """
        Determine if we can skip the seq2seq module when ensembling with the
        frequency lexicon: True for each (word, pos) pair found in either dict.
        """
        skip = []
        for p in pairs:
            w, pos = p
            if self.caseless:
                w = w.lower()
            if (w, pos) in self.composite_dict:
                skip.append(True)
            elif w in self.word_dict:
                skip.append(True)
            else:
                skip.append(False)
        return skip

    def ensemble(self, pairs, other_preds):
        """
        Ensemble the dict with statistical model predictions.

        The dict answer wins when available.  Otherwise the corresponding entry
        of other_preds is used, falling back to the (possibly lowercased) word
        when that is None.
        """
        lemmas = []
        assert len(pairs) == len(other_preds)
        for p, pred in zip(pairs, other_preds):
            w, pos = p
            if self.caseless:
                w = w.lower()
            if (w, pos) in self.composite_dict:
                lemma = self.composite_dict[(w, pos)]
            elif w in self.word_dict:
                lemma = self.word_dict[w]
            else:
                lemma = pred
            if lemma is None:
                lemma = w
            lemmas.append(lemma)
        return lemmas

    def save(self, filename, skip_modules=True):
        """
        Save model weights, dicts, vocab, config and contextual lemmatizers to
        `filename`, creating its directory if needed.  With skip_modules,
        weights of the model's unsaved_modules (e.g. a pretrained charlm) are
        left out.
        """
        model_state = None
        if self.model is not None:
            model_state = self.model.state_dict()
            # skip saving modules like the pretrained charlm
            if skip_modules:
                skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
                for k in skipped:
                    del model_state[k]
        params = {
            'model': model_state,
            'dicts': (self.word_dict, self.composite_dict),
            'vocab': self.vocab.state_dict(),
            'config': self.args,
            'contextual': [],
        }
        for contextual in self.contextual_lemmatizers:
            params['contextual'].append(contextual.get_save_dict())
        save_dir = os.path.split(filename)[0]
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(params, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to {}".format(filename))

    def load(self, filename, args, foundation_cache, lemma_classifier_args=None):
        """
        Restore config, dicts, model, vocab and contextual lemmatizers from a
        file written by save().  Charlm file paths in `args` (when given)
        override the ones stored in the saved config.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args is not None:
            # .get() on the saved config avoids a KeyError if it has no entry
            # for these keys
            self.args['charlm_forward_file'] = args.get('charlm_forward_file', self.args.get('charlm_forward_file'))
            self.args['charlm_backward_file'] = args.get('charlm_backward_file', self.args.get('charlm_backward_file'))
        self.word_dict, self.composite_dict = checkpoint['dicts']
        if not self.args['dict_only']:
            self.model = self.build_seq2seq(self.args, None, foundation_cache)
            # could remove strict=False after rebuilding all models,
            # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False
            self.model.load_state_dict(checkpoint['model'], strict=False)
        else:
            self.model = None
        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
        self.contextual_lemmatizers = []
        for contextual in checkpoint.get('contextual', []):
            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual, args=lemma_classifier_args))
================================================
FILE: stanza/models/lemma/vocab.py
================================================
from collections import Counter
from stanza.models.common.vocab import BaseVocab, BaseMultiVocab
from stanza.models.common.seq2seq_constant import VOCAB_PREFIX
class Vocab(BaseVocab):
    """Lemmatizer unit vocabulary: the reserved VOCAB_PREFIX entries, then units by descending frequency."""
    def build_vocab(self):
        unit_counts = Counter(self.data)
        # most_common() sorts stably by count, so ties keep first-seen order
        ranked_units = [unit for unit, _ in unit_counts.most_common()]
        self._id2unit = VOCAB_PREFIX + ranked_units
        self._unit2id = dict(zip(self._id2unit, range(len(self._id2unit))))
class MultiVocab(BaseMultiVocab):
    """A named collection of lemmatizer Vocabs (the trainer reads the 'char' entry)."""
    @classmethod
    def load_state_dict(cls, state_dict):
        """Rebuild a MultiVocab from a state dict, restoring each entry with Vocab.load_state_dict."""
        restored = cls()
        for name, vocab_state in state_dict.items():
            restored[name] = Vocab.load_state_dict(vocab_state)
        return restored
================================================
FILE: stanza/models/lemma_classifier/__init__.py
================================================
================================================
FILE: stanza/models/lemma_classifier/base_model.py
================================================
"""
Base class for the LemmaClassifier types.
Versions include LSTM and Transformer varieties
"""
import logging
from abc import ABC, abstractmethod
import os
import torch
import torch.nn as nn
from stanza.models.common.foundation_cache import load_pretrain
from stanza.models.lemma_classifier.constants import ModelType
from typing import List
logger = logging.getLogger('stanza.lemmaclassifier')
class LemmaClassifier(ABC, nn.Module):
    """
    Abstract base for contextual lemma classifiers (LSTM and transformer variants).

    A classifier only applies to words whose lowercased text is in
    `target_words` and whose UPOS is in `target_upos`.  It predicts one of the
    labels in `label_decoder`.
    """
    def __init__(self, label_decoder, target_words, target_upos, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # label_decoder maps label -> class index.  label_encoder is the
        # inverse (class index -> label) and is used by predict() to turn
        # argmax indices back into labels.
        self.label_decoder = label_decoder
        self.label_encoder = {y: x for x, y in label_decoder.items()}
        self.target_words = target_words
        self.target_upos = target_upos
        # names of submodules registered via add_unsaved_module(); presumably
        # excluded by subclasses' get_save_dict() -- confirm in subclasses
        self.unsaved_modules = []

    def add_unsaved_module(self, name, module):
        """Attach `module` as attribute `name` and record the name in unsaved_modules."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def is_unsaved_module(self, name):
        """Return True if the dotted parameter/module path `name` starts with an unsaved module name."""
        return name.split('.')[0] in self.unsaved_modules

    def save(self, save_name):
        """
        Save the dict from get_save_dict() to `save_name`, creating its
        directory if needed, and return that dict.
        """
        save_dir = os.path.split(save_name)[0]
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_dict = self.get_save_dict()
        torch.save(save_dict, save_name)
        return save_dict

    @abstractmethod
    def model_type(self):
        """
        return a ModelType
        """

    def target_indices(self, words, tags):
        """Return the positions of words this classifier should relabel (target word and target UPOS)."""
        return [idx for idx, (word, tag) in enumerate(zip(words, tags)) if word.lower() in self.target_words and tag in self.target_upos]

    def predict(self, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags=None) -> List:
        """
        Predict a label for the target word of each example.

        Args:
            position_indices: index of the target word in each sentence.
            sentences: tokenized sentences.
            upos_tags: per-sentence UPOS tags, converted with self.convert_tags.
                Defaults to an empty list (None avoids a shared mutable default).

        Returns:
            A list of labels (values of label_encoder), one per example.
        """
        if upos_tags is None:
            upos_tags = []
        upos_tags = self.convert_tags(upos_tags)
        with torch.no_grad():
            logits = self.forward(position_indices, sentences, upos_tags)  # should be size (batch_size, output_size)
            predicted_class = torch.argmax(logits, dim=1)  # size (batch_size,)
        predicted_class = [self.label_encoder[x.item()] for x in predicted_class]
        return predicted_class

    @staticmethod
    def from_checkpoint(checkpoint, args=None):
        """
        Build a LemmaClassifier subclass from a checkpoint dict.

        The subclass is chosen by checkpoint['model_type'].  For LSTM models,
        the file paths in `keep_args` may be overridden by non-None entries of
        `args`.  Parameters are loaded with strict=False.
        """
        model_type = ModelType[checkpoint['model_type']]
        if model_type is ModelType.LSTM:
            # TODO: if anyone can suggest a way to avoid this circular import
            # (or better yet, avoid the load method knowing about subclasses)
            # please do so
            # maybe the subclassing is not necessary and we just put
            # save & load in the trainer
            from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM

            saved_args = checkpoint['args']
            # other model args are part of the model and cannot be changed for evaluation or pipeline
            # the file paths might be relevant, though
            keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file']
            for arg in keep_args:
                if args is not None and args.get(arg, None) is not None:
                    saved_args[arg] = args[arg]

            # TODO: refactor loading the pretrain (also done in the trainer)
            pt = load_pretrain(saved_args['wordvec_pretrain_file'])

            use_charlm = saved_args['use_charlm']
            charlm_forward_file = saved_args.get('charlm_forward_file', None)
            charlm_backward_file = saved_args.get('charlm_backward_file', None)

            model = LemmaClassifierLSTM(model_args=saved_args,
                                        output_dim=len(checkpoint['label_decoder']),
                                        pt_embedding=pt,
                                        label_decoder=checkpoint['label_decoder'],
                                        upos_to_id=checkpoint['upos_to_id'],
                                        known_words=checkpoint['known_words'],
                                        target_words=set(checkpoint['target_words']),
                                        target_upos=set(checkpoint['target_upos']),
                                        use_charlm=use_charlm,
                                        charlm_forward_file=charlm_forward_file,
                                        charlm_backward_file=charlm_backward_file)
        elif model_type is ModelType.TRANSFORMER:
            from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer

            output_dim = len(checkpoint['label_decoder'])
            saved_args = checkpoint['args']
            bert_model = saved_args['bert_model']
            model = LemmaClassifierWithTransformer(model_args=saved_args,
                                                   output_dim=output_dim,
                                                   transformer_name=bert_model,
                                                   label_decoder=checkpoint['label_decoder'],
                                                   target_words=set(checkpoint['target_words']),
                                                   target_upos=set(checkpoint['target_upos']))
        else:
            raise ValueError("Unknown model type %s" % model_type)

        # strict=False to accommodate missing parameters from the transformer or charlm
        model.load_state_dict(checkpoint['params'], strict=False)
        return model

    @staticmethod
    def load(filename, args=None):
        """
        Load a LemmaClassifier from `filename`.

        `args` is forwarded to from_checkpoint so that its file-path overrides
        (pretrain / charlm files) take effect.  Previously it was silently ignored.
        """
        try:
            # NOTE: torch.load unpickles the file; only load trusted model files
            checkpoint = torch.load(filename, lambda storage, loc: storage)
        except BaseException:
            logger.exception("Cannot load model from %s", filename)
            raise

        logger.debug("Loading LemmaClassifier model from %s", filename)

        return LemmaClassifier.from_checkpoint(checkpoint, args=args)
================================================
FILE: stanza/models/lemma_classifier/base_trainer.py
================================================
from abc import ABC, abstractmethod
import logging
import os
from typing import List, Tuple, Any, Mapping
import torch
import torch.nn as nn
import torch.optim as optim
from stanza.models.common.utils import default_device
from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
from stanza.models.lemma_classifier.evaluate_models import evaluate_model
from stanza.utils.get_tqdm import get_tqdm
tqdm = get_tqdm()
logger = logging.getLogger('stanza.lemmaclassifier')
class BaseLemmaClassifierTrainer(ABC):
    """
    Shared training loop for lemma classifiers.

    Subclasses implement build_model() and provide the attributes read here:
    `lr`, `weighted_loss`, and `criterion` (which is replaced by
    configure_weighted_loss when weighted_loss is set).
    """
    def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
        """
        Replace self.criterion with a class-weighted BCEWithLogitsLoss.

        The weights are determined by the counts of the classes in the dataset.
        They are inversely proportional to the frequency of the class, so
        classes with lower frequency get higher weight:
            weight_i = total / (count_i * number of classes in `counts`)
        Classes missing from `counts` get weight 0.
        """
        weights = [0 for _ in label_decoder.keys()]  # each key in the label decoder is one class, we have one weight per class
        total_samples = sum(counts.values())
        for class_idx in counts:
            weights[class_idx] = total_samples / (counts[class_idx] * len(counts))  # weight_i = total / (# examples in class i * num classes)
        weights = torch.tensor(weights)
        logger.info(f"Using weights {weights} for weighted loss.")
        self.criterion = nn.BCEWithLogitsLoss(weight=weights)

    @abstractmethod
    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        """
        Build a model using pieces of the dataset to determine some of the model shape
        """

    def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None:
        """
        Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.

        Args:
            num_epochs (int): Number of training epochs
            save_name (str): Path to file where trained model should be saved.
            args (Mapping): Options; reads "batch_size" (default DEFAULT_BATCH_SIZE) and "force" (allow overwriting save_name).
            eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
                If empty, the model is saved after every epoch instead.
            train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.

        Raises:
            ValueError: if no train_file is given.
            FileExistsError: if save_name exists and args["force"] is not set.
        """
        # Put model on GPU (if possible)
        device = default_device()

        if not train_file:
            raise ValueError("Cannot train model - no train_file supplied!")
        # Check before loading data / building the model so we fail fast
        if os.path.exists(save_name) and not args.get('force', False):
            raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...")

        dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
        label_decoder = dataset.label_decoder
        upos_to_id = dataset.upos_to_id
        self.output_dim = len(label_decoder)
        logger.info(f"Loaded dataset successfully from {train_file}")
        logger.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")
        logger.info(f"Target words: {dataset.target_words}")

        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words, set(dataset.target_upos))
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.model.to(device)
        logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}")

        if self.weighted_loss:
            self.configure_weighted_loss(label_decoder, dataset.counts)

        # Put the criterion on GPU too
        logger.debug(f"Criterion on {next(self.model.parameters()).device}")
        self.criterion = self.criterion.to(next(self.model.parameters()).device)

        best_f1 = float("-inf")  # Used for saving checkpoints of the model
        for epoch in range(num_epochs):
            last_loss = None
            # go over entire dataset with each epoch
            for sentences, positions, upos_tags, labels in tqdm(dataset):
                assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

                self.optimizer.zero_grad()
                outputs = self.model(positions, sentences, upos_tags)

                # Compute loss, which is different if using CE or BCEWithLogitsLoss
                if self.weighted_loss:
                    # BCEWithLogitsLoss needs a one-hot target of shape (batch_size, output_dim).
                    # For two classes this is the same [1, 0] / [0, 1] encoding as before,
                    # and it also works for more than two classes.
                    targets = nn.functional.one_hot(labels.long(), num_classes=self.output_dim).to(dtype=torch.float32, device=device)
                else:
                    # CELoss accepts target as just raw label
                    targets = labels.to(device)

                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()
                # keep the tensor; calling .item() here would force a device sync every batch
                last_loss = loss

            if last_loss is None:
                logger.warning(f"Epoch [{epoch + 1}/{num_epochs}]: no training batches")
            else:
                logger.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {last_loss.item()}")
            if eval_file:
                # Evaluate model on dev set to see if it should be saved.
                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                logger.info(f"Weighted f1 for model: {f1}")
                if f1 > best_f1:
                    best_f1 = f1
                    self.model.save(save_name)
                    logger.info(f"New best model: weighted f1 score of {f1}.")
            else:
                self.model.save(save_name)
================================================
FILE: stanza/models/lemma_classifier/baseline_model.py
================================================
"""
Baseline model for the existing lemmatizer which always predicts "be" and never "have" on the "'s" token.
The BaselineModel class can be updated to any arbitrary token and predicton lemma, not just "be" on the "s" token.
"""
import stanza
import os
from stanza.models.lemma_classifier.evaluate_models import evaluate_sequences
from stanza.models.lemma_classifier.prepare_dataset import load_doc_from_conll_file
class BaselineModel:
    """
    Rule-based baseline: every occurrence of `token_to_lemmatize` whose UPOS is
    in `prediction_upos` is assigned `prediction_lemma`.
    """
    def __init__(self, token_to_lemmatize, prediction_lemma, prediction_upos):
        self.token_to_lemmatize = token_to_lemmatize
        self.prediction_lemma = prediction_lemma
        self.prediction_upos = prediction_upos

    def predict(self, token):
        """Return prediction_lemma if `token` is token_to_lemmatize, otherwise None."""
        if token == self.token_to_lemmatize:
            return self.prediction_lemma
        return None

    def evaluate(self, conll_path):
        """
        Evaluates the baseline model against the test set defined in conll_path.

        Returns a map where the keys are each class (lemma) and the values are
        another map including the precision, recall and f1 scores for that class.
        Also returns the confusion matrix: keys are gold lemmas and inner keys
        are predicted lemmas.  Both are empty if no word in the file matches
        the token and UPOS.
        """
        doc = load_doc_from_conll_file(conll_path)

        gold_tags, pred_tags = [], []
        for sentence in doc.sentences:
            for word in sentence.words:
                if word.upos in self.prediction_upos and word.text == self.token_to_lemmatize:
                    gold_tags.append(word.lemma)
                    pred_tags.append(self.prediction_lemma)

        if not gold_tags:
            # nothing to score; the weighted F1 would divide by zero
            return {}, {}

        # evaluate_sequences expects flat lists of class ids plus a
        # label -> id mapping (its results are keyed by the labels)
        label_decoder = {}
        for label in gold_tags + pred_tags:
            label_decoder.setdefault(label, len(label_decoder))
        gold_ids = [label_decoder[label] for label in gold_tags]
        pred_ids = [label_decoder[label] for label in pred_tags]

        multiclass_result, confusion_mtx, _ = evaluate_sequences(gold_ids, pred_ids, label_decoder)
        return multiclass_result, confusion_mtx
if __name__ == "__main__":
    # Evaluate the "'s" -> "be" (AUX) baseline on a GUM training file next to this module
    baseline = BaselineModel("'s", "be", ["AUX"])
    conll_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu")
    baseline.evaluate(conll_path)
================================================
FILE: stanza/models/lemma_classifier/constants.py
================================================
from enum import Enum
UNKNOWN_TOKEN = "unk" # token name for unknown tokens
UNKNOWN_TOKEN_IDX = -1 # custom index we apply to unknown tokens
# TODO: ModelType could just be LSTM and TRANSFORMER
# and then the transformer baseline would have the transformer as another argument
class ModelType(Enum):
    """Kinds of lemma classifier; checkpoints store the member name and it is looked up with ModelType[name]."""
    LSTM = 1
    TRANSFORMER = 2
    BERT = 3
    ROBERTA = 4
# Batch size used for training when the args do not specify "batch_size"
DEFAULT_BATCH_SIZE = 16
================================================
FILE: stanza/models/lemma_classifier/evaluate_many.py
================================================
"""
Utils to evaluate many models of the same type at once
"""
import argparse
import os
import logging
from stanza.models.lemma_classifier.evaluate_models import main as evaluate_main
logger = logging.getLogger('stanza.lemmaclassifier')
def evaluate_n_models(path_to_models_dir, args):
    """
    Evaluate every entry of `path_to_models_dir` with evaluate_models.main and
    return the averaged results.

    The returned dict has one key per lemma seen in any evaluation (per-lemma F1
    as a percentage; "be" and "have" are always present), plus "accuracy" and
    "weighted_f1" averaged as returned by evaluate_main.  Each model is
    evaluated by setting args.save_name to its path.

    Raises:
        ValueError: if the directory is empty.
    """
    total_results = {
        "be": 0.0,
        "have": 0.0,
        "accuracy": 0.0,
        "weighted_f1": 0.0
    }
    paths = os.listdir(path_to_models_dir)
    num_models = len(paths)
    if num_models == 0:
        raise ValueError(f"No models found in {path_to_models_dir}")

    for model_path in paths:
        full_path = os.path.join(path_to_models_dir, model_path)
        args.save_name = full_path
        mcc_results, confusion, acc, weighted_f1 = evaluate_main(predefined_args=args)
        for lemma, lemma_result in mcc_results.items():
            # accept lemmas other than be/have instead of raising KeyError
            total_results[lemma] = total_results.get(lemma, 0.0) + lemma_result["f1"] * 100
        total_results["accuracy"] += acc
        total_results["weighted_f1"] += weighted_f1

    # average every accumulated metric, not just a hard-coded subset
    for key in total_results:
        total_results[key] /= num_models

    logger.info(f"Models in {path_to_models_dir} had average weighted f1 of {100 * total_results['weighted_f1']}.\nLemma 'be' had f1: {total_results['be']}\nLemma 'have' had f1: {total_results['have']}.\nAccuracy: {100 * total_results['accuracy']}.\n ({num_models} models evaluated).")
    return total_results
def main():
    """Command-line entry point: average evaluation results over every model in --base_path."""
    module_dir = os.path.dirname(__file__)
    option_specs = [
        ("--vocab_size", dict(type=int, default=10000, help="Number of tokens in vocab")),
        ("--embedding_dim", dict(type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")),
        ("--hidden_dim", dict(type=int, default=256, help="Size of hidden layer")),
        ('--wordvec_pretrain_file', dict(type=str, default=None, help='Exact name of the pretrain file to read')),
        ("--charlm", dict(action='store_true', default=False, help="Whether not to use the charlm embeddings")),
        ('--charlm_shorthand', dict(type=str, default=None, help="Shorthand for character-level language model training corpus.")),
        ("--charlm_forward_file", dict(type=str, default=os.path.join(module_dir, "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")),
        ("--charlm_backward_file", dict(type=str, default=os.path.join(module_dir, "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")),
        ("--save_name", dict(type=str, default=os.path.join(module_dir, "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")),
        ("--model_type", dict(type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')")),
        ("--bert_model", dict(type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")),
        ("--eval_file", dict(type=str, help="path to evaluation file")),
        # Args specific to several model eval
        ("--base_path", dict(type=str, default=None, help="path to dir for eval")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    cli_args = parser.parse_args()
    evaluate_n_models(cli_args.base_path, cli_args)


if __name__ == "__main__":
    main()
================================================
FILE: stanza/models/lemma_classifier/evaluate_models.py
================================================
import os
import sys
parentdir = os.path.dirname(__file__)
parentdir = os.path.dirname(parentdir)
parentdir = os.path.dirname(parentdir)
sys.path.append(parentdir)
import logging
import argparse
import os
from typing import Any, List, Tuple, Mapping
from collections import defaultdict
from numpy import random
import torch
import torch.nn as nn
import stanza
from stanza.models.common.utils import default_device
from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
from stanza.utils.confusion import format_confusion
from stanza.utils.get_tqdm import get_tqdm
tqdm = get_tqdm()
logger = logging.getLogger('stanza.lemmaclassifier')
def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float:
    """
    Computes the weighted F1 score across an evaluation set.

    Each class's F1 is weighted by its number of gold examples (the sum of its
    row in `confusion`), so classes with more examples count more.

    Returns 0.0 when there are no examples at all, instead of raising
    ZeroDivisionError.  A class with no row in `confusion` contributes zero
    examples.
    """
    num_total_examples = 0
    weighted_f1 = 0.0
    for class_id, class_result in mcc_results.items():
        # .get avoids both a crash and inserting into a defaultdict for a missing row
        num_class_examples = sum(confusion.get(class_id, {}).values())
        weighted_f1 += class_result["f1"] * num_class_examples
        num_total_examples += num_class_examples
    if num_total_examples == 0:
        return 0.0
    return weighted_f1 / num_total_examples
def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True):
    """
    Score predicted tags against gold tags and compute per-class precision, recall and F1.

    Every item of `gold_tag_sequences` / `pred_tag_sequences` must be one of
    the values of `label_decoder`.  The mapping is inverted, so all results are
    keyed by `label_decoder`'s keys (the labels).  Classes are the labels that
    occur at least once as gold.

        Precision = TP / (TP + FP)
        Recall    = TP / (TP + FN)
        F1        = 2 * (Precision * Recall) / (Precision + Recall)

    Any ratio with a zero denominator is reported as 0.0.

    Returns:
        1. Multi class result dictionary: each class maps to a dict with its
           "precision", "recall" and "f1", e.g. multiclass_results["be"]["precision"].
        2. Confusion matrix: confusion[gold][pred] is the number of times gold
           label `gold` was predicted as `pred`.
        3. Weighted F1 (float), as computed by get_weighted_f1.
    """
    assert len(gold_tag_sequences) == len(pred_tag_sequences), \
        f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}"

    id_to_label = {class_id: label for label, class_id in label_decoder.items()}
    confusion = defaultdict(lambda: defaultdict(int))
    for gold_id, pred_id in zip(gold_tag_sequences, pred_tag_sequences):
        confusion[id_to_label[gold_id]][id_to_label[pred_id]] += 1

    def _ratio(numerator, denominator):
        # zero denominator -> 0.0, matching the "no examples" convention
        return numerator / denominator if denominator else 0.0

    multi_class_result = defaultdict(lambda: defaultdict(float))
    # .get() is used throughout so that no zero entries are inserted into confusion
    for label in confusion.keys():
        row = confusion.get(label, {})
        true_positives = row.get(label, 0)
        predicted_as_label = sum(confusion.get(other, {}).get(label, 0) for other in confusion.keys())
        prec = _ratio(true_positives, predicted_as_label)
        recall = _ratio(true_positives, sum(row.values()))
        f1 = _ratio(2 * (prec * recall), prec + recall)
        multi_class_result[label] = {
            "precision": prec,
            "recall": recall,
            "f1": f1
        }

    if verbose:
        for lemma in multi_class_result:
            logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}")

    weighted_f1 = get_weighted_f1(multi_class_result, confusion)
    return multi_class_result, confusion, weighted_f1
def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]] = None) -> torch.Tensor:
    """
    Use a LemmaClassifierLSTM or LemmaClassifierWithTransformer to predict the
    class of the target token in each example of a batch.

    Args:
        model (LemmaClassifier): A trained LemmaClassifier that is able to predict on a target token.
        position_indices (Tensor[int]): A tensor of the (zero-indexed) position of the target token in each sentence of the batch.
        sentences (List[List[str]]): A list of lists of the tokenized strings of the input sentences.
        upos_tags (List[List[int]], optional): UPOS tag ids per sentence. Defaults to an
            empty list (None is used as the default to avoid a shared mutable default).

    Returns:
        (Tensor[int]): The index of the predicted class for each example, shape (batch_size,).
    """
    if upos_tags is None:
        upos_tags = []
    with torch.no_grad():
        logits = model(position_indices, sentences, upos_tags)  # should be size (batch_size, output_size)
        predicted_class = torch.argmax(logits, dim=1)  # size (batch_size,)

    return predicted_class
def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]:
    """
    Helper function for model evaluation

    Args:
        model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): An already-loaded LemmaClassifier.
            Its `label_decoder` is used to read the evaluation dataset.
        eval_path (str): Path to the saved evaluation dataset.
        verbose (bool, optional): True if `evaluate_sequences()` should print the F1, Precision, and Recall for each class. Defaults to True.
        is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode.

    Returns:
        1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is
            another map with key of "f1", "precision", or "recall" with corresponding values.
        2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the
            map with the key as the predicted tag and corresponding count of that (gold, pred) pair.
        3. Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set.
            0.0 if the evaluation set is empty.
        4. Weighted F1 (float): the weighted F1 score computed by `evaluate_sequences`.
    """
    # load model
    device = default_device()
    model.to(device)

    if not is_training:
        model.eval()  # set to eval mode

    # load in eval data
    dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False)

    logger.info(f"Evaluating on evaluation file {eval_path}")

    correct, total = 0, 0
    # Gold labels are taken directly from the dataset, while predictions are appended in
    # iteration order; presumably this relies on shuffle=False keeping the two aligned
    gold_tags, pred_tags = dataset.labels, []

    # run eval on each example from dataset
    for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"):
        pred = model_predict(model, pos_indices, sentences, upos_tags)  # Pred should be size (batch_size, )
        correct_preds = pred == labels.to(device)
        # .item() keeps `correct` a plain Python int, so the returned accuracy is a float, not a tensor
        correct += torch.sum(correct_preds).item()
        total += len(correct_preds)
        pred_tags += pred.tolist()

    logger.info("Finished evaluating on dataset. Computing scores...")
    accuracy = correct / total if total > 0 else 0.0

    mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose)
    if verbose:
        logger.info(f"Accuracy: {accuracy} ({correct}/{total})")
        logger.info(f"Label decoder: {dataset.label_decoder}")

    return mc_results, confusion, accuracy, weighted_f1
def main(args=None, predefined_args=None):
    """
    Evaluate a saved lemma classifier on an evaluation file.

    Args:
        args (list[str], optional): command line arguments to parse. Ignored if `predefined_args` is given.
        predefined_args (argparse.Namespace, optional): already-parsed arguments to use instead of parsing `args`

    Returns:
        tuple: (multi-class results, confusion matrix, accuracy, weighted F1), as returned by `evaluate_model`
    """
    # TODO: can unify this script with train_lstm_model.py?
    # TODO: can save the model type in the model .pt, then
    # automatically figure out what type of model we are using by
    # looking in the file
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
    parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--eval_file", type=str, help="path to evaluation file")

    args = parser.parse_args(args) if not predefined_args else predefined_args

    # this script evaluates; the old message claimed it was the training script
    logger.info("Running evaluation script with the following args:")
    args = vars(args)
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    logger.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}")
    model = LemmaClassifier.load(args['save_name'], args)

    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file'])

    logger.info(f"MCC Results: {dict(mcc_results)}")
    logger.info("______________________________________________")
    # lazy %-style formatting; no f-string needed here
    logger.info("Confusion:\n%s", format_confusion(confusion))
    logger.info("______________________________________________")
    logger.info(f"Accuracy: {acc}")
    logger.info("______________________________________________")
    logger.info(f"Weighted f1: {weighted_f1}")

    return mcc_results, confusion, acc, weighted_f1
if __name__ == "__main__":
    # Run the evaluation entry point when this file is executed as a script
    main()
================================================
FILE: stanza/models/lemma_classifier/lstm_model.py
================================================
import torch
import torch.nn as nn
import os
import logging
import math
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
from typing import List, Tuple
from stanza.models.common.vocab import UNK_ID
from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.constants import ModelType
# Same logger name as the other lemma_classifier modules, so their output is configured together
logger = logging.getLogger('stanza.lemmaclassifier')
class LemmaClassifierLSTM(LemmaClassifier):
    """
    Model architecture:

    Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding.
    From the LSTM output, we get the embedding of the specific token that we classify on. That embedding
    is fed into an MLP for classification.

    When model_args['num_heads'] > 0, a causally masked multihead self-attention layer (with sinusoidal
    positional encodings) is used instead of the bi-LSTM.
    """

    def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,
                 use_charlm=False, charlm_forward_file=None, charlm_backward_file=None):
        """
        Args:
            model_args (dict): Model shape parameters. This class reads:
                hidden_dim (int): Size of the LSTM hidden layer (per direction)
                num_heads (int): The number of heads to use for attention. If there are more than 0 heads,
                    attention will be used instead of the LSTM.
                upos_emb_dim (int): The size of the UPOS tag embeddings (0 disables them)
            output_dim (int): Size of output vector from MLP layer
            pt_embedding (Pretrain): pretrained embeddings. `emb` is used as a frozen embedding matrix and
                `vocab` gives the word for each of its rows
            label_decoder: passed to the LemmaClassifier base class and saved with the model
            upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs.
                If None, no UPOS embeddings are used
            known_words (list(str)): Words which are in the training data. Each gets a trainable "delta"
                embedding which is added to its frozen pretrained embedding
            target_words (set(str)): a set of the words which might need lemmatization
            target_upos: passed to the LemmaClassifier base class and saved with the model
            use_charlm (bool): Whether or not to use the charlm embeddings
            charlm_forward_file (str): The path to the forward pass model for the character language model
            charlm_backward_file (str): The path to the backward pass model for the character language model

        Raises:
            FileNotFoundError: if the forward or backward charlm file cannot be found.
        """
        super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words, target_upos)
        self.model_args = model_args

        self.hidden_dim = model_args['hidden_dim']
        self.input_size = 0
        self.num_heads = self.model_args['num_heads']

        # frozen pretrained embeddings; registered as unsaved so they are not written into the state dict
        emb_matrix = pt_embedding.emb
        self.add_unsaved_module("embeddings", nn.Embedding.from_pretrained(emb_matrix, freeze=True))
        self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt_embedding.vocab) }
        self.vocab_size = emb_matrix.shape[0]
        self.embedding_dim = emb_matrix.shape[1]

        self.known_words = known_words
        # NOTE(review): indices start at 0, which is also the padding_idx of delta_embedding and the id used
        # for words missing from known_words, so known_words[0] shares the unknown-word row (and the
        # table's last row is never indexed). nn.init.normal_ below also overwrites the padding row, so
        # that row is a fixed random vector that padding_idx keeps from being trained.
        # Changing the numbering would invalidate saved models, so it is left as is.
        self.known_word_map = {word: idx for idx, word in enumerate(known_words)}
        self.delta_embedding = nn.Embedding(num_embeddings=len(known_words)+1,
                                            embedding_dim=self.embedding_dim,
                                            padding_idx=0)
        nn.init.normal_(self.delta_embedding.weight, std=0.01)

        self.input_size += self.embedding_dim

        # Optionally, include charlm embeddings
        self.use_charlm = use_charlm

        if self.use_charlm:
            if charlm_forward_file is None or not os.path.exists(charlm_forward_file):
                raise FileNotFoundError(f'Could not find forward character model: {charlm_forward_file}')
            if charlm_backward_file is None or not os.path.exists(charlm_backward_file):
                raise FileNotFoundError(f'Could not find backward character model: {charlm_backward_file}')
            self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(charlm_forward_file, finetune=False))
            self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(charlm_backward_file, finetune=False))

            self.input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()

        self.upos_emb_dim = self.model_args["upos_emb_dim"]
        self.upos_to_id = upos_to_id
        if self.upos_emb_dim > 0 and self.upos_to_id is not None:
            # TODO: should leave space for unknown POS?
            self.upos_emb = nn.Embedding(num_embeddings=len(self.upos_to_id),
                                         embedding_dim=self.upos_emb_dim,
                                         padding_idx=0)
            self.input_size += self.upos_emb_dim

        device = next(self.parameters()).device
        # Determine if attn or LSTM should be used
        if self.num_heads > 0:
            # MultiheadAttention needs embed_dim divisible by num_heads; forward() zero-pads the
            # features up to this rounded size
            self.input_size = utils.round_up_to_multiple(self.input_size, self.num_heads)
            self.multihead_attn = nn.MultiheadAttention(embed_dim=self.input_size, num_heads=self.num_heads, batch_first=True).to(device)
            logger.debug(f"Using attention mechanism with embed dim {self.input_size} and {self.num_heads} attention heads.")
        else:
            self.lstm = nn.LSTM(self.input_size,
                                self.hidden_dim,
                                batch_first=True,
                                bidirectional=True)
            logger.debug(f"Using LSTM mechanism.")

        mlp_input_size = self.hidden_dim * 2 if self.num_heads == 0 else self.input_size
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def get_save_dict(self):
        """
        Build the dictionary that is written to disk for this model.

        Parameters of modules registered with add_unsaved_module (the frozen pretrained embeddings and,
        if used, the charlm) are removed from "params".
        """
        save_dict = {
            "params": self.state_dict(),
            "label_decoder": self.label_decoder,
            "model_type": self.model_type().name,
            "args": self.model_args,
            "upos_to_id": self.upos_to_id,
            "known_words": self.known_words,
            "target_words": list(self.target_words),
            "target_upos": list(self.target_upos),
        }
        skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
        for k in skipped:
            del save_dict["params"][k]
        return save_dict

    def convert_tags(self, upos_tags: List[List[str]]):
        """
        Map UPOS tag strings to ids with upos_to_id, or return None if this model has no UPOS map.
        Raises KeyError on a tag that is not in upos_to_id.
        """
        if self.upos_to_id is not None:
            return [[self.upos_to_id[x] for x in sentence] for sentence in upos_tags]
        return None

    @staticmethod
    def _positional_encoding(seq_len, d_model, device):
        """
        Sinusoidal positional encodings of shape (1, seq_len, d_model).
        Works for odd d_model, which has one more sine column than cosine columns.
        """
        encoding = torch.zeros(seq_len, d_model, device=device)
        position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Add a new dimension to fit the batch size
        return encoding.unsqueeze(0)

    def forward(self, pos_indices: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
        """
        Computes the forward pass of the neural net

        Args:
            pos_indices (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.
            sentences (List[List[str]]): A list of the token-split sentences of the input data.
            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence.
                Ignored (may be None) if the model has no UPOS embeddings.

        Returns:
            torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences.
        """
        device = next(self.parameters()).device
        token_ids = []
        delta_token_ids = []
        for words in sentences:
            sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words]
            sentence_token_ids = torch.tensor(sentence_token_ids, device=device)
            token_ids.append(sentence_token_ids)

            sentence_delta_token_ids = [self.known_word_map.get(word.lower(), 0) for word in words]
            sentence_delta_token_ids = torch.tensor(sentence_delta_token_ids, device=device)
            delta_token_ids.append(sentence_delta_token_ids)

        token_ids = pad_sequence(token_ids, batch_first=True)
        delta_token_ids = pad_sequence(delta_token_ids, batch_first=True)
        embedded = self.embeddings(token_ids) + self.delta_embedding(delta_token_ids)

        # Same condition as in __init__: self.upos_emb only exists when both hold
        if self.upos_emb_dim > 0 and self.upos_to_id is not None:
            upos_tags = [torch.tensor(sentence_tags) for sentence_tags in upos_tags]  # convert internal lists to tensors
            upos_tags = pad_sequence(upos_tags, batch_first=True, padding_value=0).to(device)
            pos_emb = self.upos_emb(upos_tags)
            embedded = torch.cat((embedded, pos_emb), 2)

        if self.use_charlm:
            char_reps_forward = self.charmodel_forward.build_char_representation(sentences)  # takes [[str]]
            char_reps_backward = self.charmodel_backward.build_char_representation(sentences)

            char_reps_forward = pad_sequence(char_reps_forward, batch_first=True)
            char_reps_backward = pad_sequence(char_reps_backward, batch_first=True)

            embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2)

        if self.num_heads > 0:
            # __init__ may have rounded input_size up to a multiple of num_heads; zero-pad the
            # feature dimension so the attention layer receives the size it was built for
            feature_pad = self.input_size - embedded.shape[2]
            if feature_pad > 0:
                embedded = torch.cat((embedded, embedded.new_zeros(embedded.shape[0], embedded.shape[1], feature_pad)), 2)

            seq_len, d_model = embedded.shape[1], embedded.shape[2]
            embedded = embedded + self._positional_encoding(seq_len, d_model, device)

            # causal mask: position i may only attend to positions <= i. A 2D (seq, seq) mask is
            # applied to every batch element and head. Since each target index lies inside its own
            # sentence, this also keeps padding (always after the sentence) out of the target's context.
            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
            attn_output, _ = self.multihead_attn(embedded, embedded, embedded, attn_mask=attn_mask)
            # Extract the hidden state at the index of the token to classify
            token_reps = attn_output[torch.arange(attn_output.size(0)), pos_indices]
        else:
            # Use the real sentence lengths so the (backward) LSTM never reads padding positions;
            # otherwise a sentence's prediction would depend on the other sentences in its batch.
            # Lengths must be a CPU tensor; batches are not sorted by length.
            lengths = torch.tensor([len(words) for words in sentences])
            packed_sequences = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
            lstm_out, _ = self.lstm(packed_sequences)
            # Extract the hidden state at the index of the token to classify
            unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True)
            token_reps = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices]

        # MLP forward pass
        output = self.mlp(token_reps)
        return output

    def model_type(self):
        return ModelType.LSTM
================================================
FILE: stanza/models/lemma_classifier/prepare_dataset.py
================================================
import argparse
import json
import os
import re
import stanza
from stanza.models.lemma_classifier import utils
from typing import List, Tuple, Any
"""
The code in this file processes a CoNLL dataset by keeping only the sentences that contain the target token with one of the target UPOS tags.
For each such occurrence whose lemma matches the allowed-lemma regex, it stores the sentence's words, their UPOS tags, the position index of the target token, and its lemma.
"""
def load_doc_from_conll_file(path: str):
    """
    Load a Stanza Document from a CoNLL-U file containing annotated sentences.

    Args:
        path (str): path to the CoNLL-U file

    Returns:
        The Stanza Document built by `CoNLL.conll2doc`.
    """
    # Import the submodule explicitly: `import stanza` alone may not load `stanza.utils.conll`
    from stanza.utils.conll import CoNLL
    return CoNLL.conll2doc(path)
class DataProcessor():
    """
    Extracts lemma classifier examples from a Stanza document.

    A word is a target if its text fully matches the `target_word` regex and its UPOS is in `target_upos`.
    An example is produced for each target word whose lemma fully matches the `allowed_lemmas` regex.
    """

    def __init__(self, target_word: str, target_upos: List[str], allowed_lemmas: str):
        """
        Args:
            target_word (str): regex (applied with fullmatch) for the token text to classify on, e.g. "'s"
            target_upos (List[str]): UPOS tags the target token may have
            allowed_lemmas (str): regex (applied with fullmatch) which the target token's lemma must match
        """
        self.target_word = target_word
        self.target_word_regex = re.compile(target_word)
        self.target_upos = target_upos
        self.allowed_lemmas = re.compile(allowed_lemmas)

    def _is_target(self, word) -> bool:
        """Return True if `word` has the target text and one of the target UPOS tags."""
        return bool(self.target_word_regex.fullmatch(word.text)) and word.upos in self.target_upos

    def keep_sentence(self, sentence):
        """Return True if `sentence` contains at least one target word."""
        return any(self._is_target(word) for word in sentence.words)

    def find_all_occurrences(self, sentence) -> List[int]:
        """
        Finds all occurrences of self.target_word in tokens and returns the index(es) of such occurrences.
        """
        return [idx for idx, word in enumerate(sentence.words) if self._is_target(word)]

    @staticmethod
    def write_output_file(save_name, target_upos, sentences):
        """
        Write `sentences` as a JSON object with keys "upos" (the target UPOS list) and "sentences",
        putting each example on its own line.
        """
        with open(save_name, "w+", encoding="utf-8") as output_f:
            output_f.write("{\n")
            output_f.write(' "upos": %s,\n' % json.dumps(target_upos))
            output_f.write(' "sentences": [')
            wrote_sentence = False
            for sentence in sentences:
                if not wrote_sentence:
                    output_f.write("\n ")
                    wrote_sentence = True
                else:
                    output_f.write(",\n ")
                output_f.write(json.dumps(sentence))
            output_f.write("\n ]\n}\n")

    def process_document(self, doc, save_name: str) -> List[dict]:
        """
        Takes any sentence from `doc` that meets the condition of `keep_sentence` and writes its tokens, index of target word, and lemma to `save_name`

        Sentences that meet `keep_sentence` and contain `self.target_word` multiple times have each instance in a different example in the output file.
        Target words without a lemma (None) are skipped, since they have no label to train on.

        Args:
            doc (Stanza.doc): Document object that represents the file to be analyzed
            save_name (str): Path to the file for storing output. If empty or None, nothing is written.

        Returns:
            List[dict]: the examples, each with keys "words", "upos_tags", "index" and "lemma"
        """
        sentences = []
        for sentence in doc.sentences:
            # for each sentence, we need to determine if it should be added to the output file.
            # if the sentence fulfills keep_sentence, then we will save it along with the target word's index and its corresponding lemma
            if not self.keep_sentence(sentence):
                continue
            words = sentence.words
            tokens = [word.text for word in words]
            upos_tags = [word.upos for word in words]
            for idx in self.find_all_occurrences(sentence):
                lemma = words[idx].lemma
                # re.fullmatch raises TypeError on None, so unlemmatized words are skipped explicitly
                if lemma is None or not self.allowed_lemmas.fullmatch(lemma):
                    continue
                # for each example found, we write the tokens,
                # their respective upos tags, the target token index,
                # and the target lemma
                sentences.append({
                    "words": tokens,
                    "upos_tags": list(upos_tags),
                    "index": idx,
                    "lemma": lemma
                })

        if save_name:
            self.write_output_file(save_name, self.target_upos, sentences)
        return sentences
def main(args=None):
    """
    Command line entry point: extract lemma classifier examples from a CoNLL-U file.

    Args:
        args (list[str], optional): command line arguments; argparse reads sys.argv when None

    Returns:
        list[dict]: the extracted examples, as returned by DataProcessor.process_document
    """
    module_dir = os.path.dirname(__file__)
    parser = argparse.ArgumentParser()
    parser.add_argument("--conll_path", type=str, default=os.path.join(module_dir, "en_gum-ud-train.conllu"), help="path to the conll file to translate")
    parser.add_argument("--target_word", type=str, default="'s", help="Token to classify on, e.g. 's.")
    parser.add_argument("--target_upos", type=str, default="AUX", help="upos on target token")
    parser.add_argument("--output_path", type=str, default="test_output.txt", help="Path for output file")
    parser.add_argument("--allowed_lemmas", type=str, default=".*", help="A regex for allowed lemmas. If not set, all lemmas are allowed")

    options = vars(parser.parse_args(args))
    for name, value in options.items():
        print(f"{name}: {value}")

    document = load_doc_from_conll_file(options['conll_path'])
    # only a single UPOS tag can be given on the command line
    processor = DataProcessor(target_word=options['target_word'],
                              target_upos=[options['target_upos']],
                              allowed_lemmas=options['allowed_lemmas'])
    return processor.process_document(document, options['output_path'])
if __name__ == "__main__":
    # Run the dataset preparation entry point when this file is executed as a script
    main()
================================================
FILE: stanza/models/lemma_classifier/train_lstm_model.py
================================================
"""
The code in this file trains an LSTM-based lemma classifier (optionally using multihead attention instead of the LSTM), e.g. for the ambiguous token 's
"""
import argparse
import logging
import os
import torch
import torch.nn as nn
from stanza.models.common.foundation_cache import load_pretrain
from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
# Same logger name as the other lemma_classifier modules, so their output is configured together
logger = logging.getLogger('stanza.lemmaclassifier')
class LemmaClassifierTrainer(BaseLemmaClassifierTrainer):
    """
    Class to assist with training a LemmaClassifierLSTM
    """

    def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = False, charlm_forward_file: str = None, charlm_backward_file: str = None, lr: float = 0.001, loss_func: str = None):
        """
        Initializes the LemmaClassifierTrainer class.

        Args:
            model_args (dict): Various model shape parameters, passed to LemmaClassifierLSTM
                (which reads e.g. hidden_dim, upos_emb_dim and num_heads from it)
            embedding_file (str): What word embeddings file to use. Use a Stanza pretrain .pt
            use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False.
            charlm_forward_file (str): Path to the forward pass embeddings for the charlm
            charlm_backward_file (str): Path to the backward pass embeddings for the charlm
            lr (float): Learning rate, defaults to 0.001.
            loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce')

        Raises:
            FileNotFoundError: If the forward charlm file is not present
            FileNotFoundError: If the backward charlm file is not present
            ValueError: If `loss_func` is neither 'ce' nor 'weighted_bce'
        """
        super().__init__()

        self.model_args = model_args

        # Validate the cheap arguments first, so a bad charlm path or loss name
        # fails before the (potentially large) pretrained embeddings are loaded
        if use_charlm and charlm_forward_file is not None and not os.path.exists(charlm_forward_file):
            raise FileNotFoundError(f"Could not find forward charlm file: {charlm_forward_file}")
        if use_charlm and charlm_backward_file is not None and not os.path.exists(charlm_backward_file):
            raise FileNotFoundError(f"Could not find backward charlm file: {charlm_backward_file}")

        # Find loss function
        if loss_func == "ce":
            self.criterion = nn.CrossEntropyLoss()
            self.weighted_loss = False
            logger.debug("Using CE loss")
        elif loss_func == "weighted_bce":
            self.criterion = nn.BCEWithLogitsLoss()
            self.weighted_loss = True  # used to add weights during train time.
            logger.debug("Using Weighted BCE loss")
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

        # TODO: just pass around the args instead
        self.use_charlm = use_charlm
        self.charlm_forward_file = charlm_forward_file
        self.charlm_backward_file = charlm_backward_file
        self.lr = lr

        # Load word embeddings
        self.pt_embedding = load_pretrain(embedding_file)

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        """
        Construct a new LemmaClassifierLSTM from the settings stored on this trainer.
        """
        # NOTE(review): self.output_dim is not set in this class; presumably BaseLemmaClassifierTrainer sets it before calling build_model
        return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,
                                   use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file)
def build_argparse():
    """
    Build the command line parser used to train an LSTM lemma classifier.

    Returns:
        argparse.ArgumentParser: parser covering model shape, embedding/charlm files, data paths and optimization settings
    """
    here = os.path.dirname(__file__)
    parser = argparse.ArgumentParser()

    # model shape and embeddings
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(here, "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(here, "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(here, "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")

    # output, optimization and data
    parser.add_argument("--save_name", type=str, default=os.path.join(here, "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(here, "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(here, "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')

    return parser
def main(args=None, predefined_args=None):
    """
    Train an LSTM (or attention) lemma classifier.

    Args:
        args (list[str], optional): command line arguments to parse. Ignored if `predefined_args` is given.
        predefined_args (argparse.Namespace, optional): already-parsed arguments with the options from `build_argparse`

    Returns:
        LemmaClassifierTrainer: the trainer that was used for training

    Raises:
        FileExistsError: if `save_name` already exists and `force` is not set
        FileNotFoundError: if `train_file` does not exist
    """
    parser = build_argparse()
    args = parser.parse_args(args) if predefined_args is None else predefined_args
    args = vars(args)

    save_name = args['save_name']
    train_file = args['train_file']

    if os.path.exists(save_name) and not args.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    # model shape options (hidden_dim, upos_emb_dim, num_heads, ...) reach the model through model_args
    trainer = LemmaClassifierTrainer(model_args=args,
                                     embedding_file=args['wordvec_pretrain_file'],
                                     use_charlm=args['use_charlm'],
                                     charlm_forward_file=args['charlm_forward_file'],
                                     charlm_backward_file=args['charlm_backward_file'],
                                     lr=args['lr'],
                                     loss_func="weighted_bce" if args['weighted_loss'] else "ce",
                                     )

    trainer.train(
        num_epochs=args['num_epochs'], save_name=save_name, args=args, eval_file=args['eval_file'], train_file=train_file
    )

    return trainer
if __name__ == "__main__":
    # Run the training entry point when this file is executed as a script
    main()
================================================
FILE: stanza/models/lemma_classifier/train_many.py
================================================
"""
Utils for training and evaluating multiple models in a single run (the models are trained one after another, not in parallel)
"""
import argparse
import os
from stanza.models.lemma_classifier.train_lstm_model import main as train_lstm_main
from stanza.models.lemma_classifier.train_transformer_model import main as train_tfmr_main
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
# Hyperparameter values swept by train_n_models for each sweep-style `change_param`
change_params_map = {
    "lstm_layer": [16, 32, 64, 128, 256, 512],
    "upos_emb_dim": [5, 10, 20, 30],
    "training_size": [150, 300, 450, 600, 'full'],
}  # TODO: Add attention


def train_n_models(num_models: int, base_path: str, args):
    """
    Train `num_models` LSTM lemma classifiers for the experiment named by `args.change_param`.

    The sweep experiments ("lstm_layer", "upos_emb_dim", "training_size") train `num_models` models
    for every value listed in `change_params_map`. "base" trains, for each run, one model without and
    one with weighted loss. The other experiments ("base_charlm", "base_charlm_upos", "base_upos",
    "attn_model") train `num_models` models with `args` as given. An unknown value trains nothing.

    Args:
        num_models (int): number of models to train per setting
        base_path (str): directory in which the models are saved
        args (argparse.Namespace): training arguments. `save_name` and the swept attribute are overwritten in place.
    """
    change_param = args.change_param
    if change_param == "lstm_layer":
        for hidden_dim in change_params_map["lstm_layer"]:
            for i in range(num_models):
                args.save_name = os.path.join(base_path, f"{hidden_dim}_{i}.pt")
                args.hidden_dim = hidden_dim
                train_lstm_main(predefined_args=args)

    elif change_param == "upos_emb_dim":
        for upos_dim in change_params_map["upos_emb_dim"]:
            for i in range(num_models):
                args.save_name = os.path.join(base_path, f"dim_{upos_dim}_{i}.pt")
                args.upos_emb_dim = upos_dim
                train_lstm_main(predefined_args=args)

    elif change_param == "training_size":
        for size in change_params_map["training_size"]:
            for i in range(num_models):
                args.save_name = os.path.join(base_path, f"{size}_examples_{i}.pt")
                # NOTE(review): `size` only affects the save name; every run trains on the
                # same full combined_train.txt
                args.train_file = os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt")
                train_lstm_main(predefined_args=args)

    elif change_param == "base":
        for i in range(num_models):
            # one model with the plain CE loss, then one with the weighted loss
            args.save_name = os.path.join(base_path, f"lstm_model_{i}.pt")
            args.weighted_loss = False
            train_lstm_main(predefined_args=args)

            args.weighted_loss = True
            args.save_name = os.path.join(base_path, f"lstm_model_wloss_{i}.pt")
            train_lstm_main(predefined_args=args)

    elif change_param == "base_charlm":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"lstm_charlm_{i}.pt")
            train_lstm_main(predefined_args=args)

    elif change_param == "base_charlm_upos":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"lstm_charlm_upos_{i}.pt")
            train_lstm_main(predefined_args=args)

    elif change_param == "base_upos":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"lstm_upos_{i}.pt")
            train_lstm_main(predefined_args=args)

    elif change_param == "attn_model":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt")
            train_lstm_main(predefined_args=args)
def train_n_tfmrs(num_models: int, base_path: str, args):
    """
    Train `num_models` pairs of transformer lemma classifiers when `args.multi_train_type` is "tfmr".

    For each run index i, when `args.change_param` is "bert" or "roberta", two models are trained and
    saved under `base_path`: "<change_param>_<i>.pt" with the "ce" loss, then
    "<change_param>_wloss_<i>.pt" with the "weighted_bce" loss. `args.save_name` and `args.loss_fn`
    are overwritten in place. Any other `change_param` trains nothing.
    """
    if args.multi_train_type != "tfmr":
        return

    # (file name infix, loss function) for the two models trained in every run
    loss_variants = (("", "ce"), ("_wloss", "weighted_bce"))
    for run_idx in range(num_models):
        model_family = args.change_param
        if model_family not in ("bert", "roberta"):
            continue
        for infix, loss_fn in loss_variants:
            args.save_name = os.path.join(base_path, f"{model_family}{infix}_{run_idx}.pt")
            args.loss_fn = loss_fn
            train_tfmr_main(predefined_args=args)
def main():
    """
    Parse the command line and train several lemma classifiers in a row.

    Dispatches to train_n_models (--multi_train_type lstm) or train_n_tfmrs
    (--multi_train_type tfmr). Every model file is written under --base_path.

    Raises:
        ValueError: if --multi_train_type is neither 'lstm' nor 'tfmr'
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether or not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    # Tfmr-specific args
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    # Multi-model train args
    parser.add_argument("--multi_train_type", type=str, default="lstm", help="Whether you are attempting to multi-train an LSTM or transformer")
    parser.add_argument("--multi_train_count", type=int, default=5, help="Number of each model to build")
    parser.add_argument("--base_path", type=str, default=None, help="Path to start generating model type for.")
    parser.add_argument("--change_param", type=str, default=None, help="Which hyperparameter to change when training")
    args = parser.parse_args()

    # every model is saved with os.path.join(base_path, ...), which fails with
    # an unhelpful TypeError when base_path is None
    if args.base_path is None:
        parser.error("--base_path is required: it is the directory where the trained models are saved")

    if args.multi_train_type == "lstm":
        train_n_models(num_models=args.multi_train_count,
                       base_path=args.base_path,
                       args=args)
    elif args.multi_train_type == "tfmr":
        train_n_tfmrs(num_models=args.multi_train_count,
                      base_path=args.base_path,
                      args=args)
    else:
        raise ValueError(f"Improper input {args.multi_train_type}")


if __name__ == "__main__":
    main()
================================================
FILE: stanza/models/lemma_classifier/train_transformer_model.py
================================================
"""
This file contains code used to train a baseline transformer model to classify on a lemma of a particular token.
"""
import argparse
import os
import sys
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
from stanza.models.common.utils import default_device
logger = logging.getLogger('stanza.lemmaclassifier')
class TransformerBaselineTrainer(BaseLemmaClassifierTrainer):
    """
    Class to assist with training a baseline transformer model to classify on token lemmas.

    To find the model spec, refer to `transformer_model.py` in this directory.
    """

    def __init__(self, model_args: dict, transformer_name: str = "roberta", loss_func: str = "ce", lr: float = 0.001):
        """
        Creates the Trainer object

        Args:
            model_args (dict): args for the model; passed unchanged to the model built in build_model
            transformer_name (str, optional): name of the HF transformer to use for embeddings,
                passed to LemmaClassifierWithTransformer. Defaults to "roberta".
            loss_func (str, optional): Which loss function to use (either 'ce' or 'weighted_bce'). Defaults to "ce".
            lr (float, optional): learning rate for the optimizer. Defaults to 0.001.

        Raises:
            ValueError: if loss_func is neither 'ce' nor 'weighted_bce'
        """
        super().__init__()

        self.model_args = model_args

        # Find loss function
        if loss_func == "ce":
            self.criterion = nn.CrossEntropyLoss()
            self.weighted_loss = False
        elif loss_func == "weighted_bce":
            self.criterion = nn.BCEWithLogitsLoss()
            self.weighted_loss = True  # used to add weights during train time.
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

        self.transformer_name = transformer_name
        self.lr = lr

    def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> optim.Optimizer:
        """
        Sets learning rates for each layer of the model.

        Currently, the model has the transformer layer and the MLP layer, so these are tweakable.
        Parameters whose names contain neither 'transformer' nor 'mlp' are not given to the optimizer.

        Args:
            transformer_lr (float): learning rate for parameters whose name contains 'transformer'
            mlp_lr (float): learning rate for parameters whose name contains 'mlp'

        Returns:
            optim.Optimizer: An Adam optimizer with the learning rates adjusted per layer.

        Currently unused - could be refactored into the parent class's train method,
        or the parent class could call a build_optimizer and this subclass would use the optimizer
        """
        transformer_params, mlp_params = [], []
        for name, param in self.model.named_parameters():
            if 'transformer' in name:
                transformer_params.append(param)
            elif 'mlp' in name:
                mlp_params.append(param)
        optimizer = optim.Adam([
            {"params": transformer_params, "lr": transformer_lr},
            {"params": mlp_params, "lr": mlp_lr}
        ])
        return optimizer

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        """
        Build the transformer-based classifier.

        upos_to_id and known_words are accepted for interface compatibility but unused by this model.
        """
        return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words, target_upos=target_upos)
def main(args=None, predefined_args=None):
    """
    Train a LemmaClassifierWithTransformer, either from the command line or from pre-parsed arguments.

    Args:
        args (list of str, optional): command line arguments to parse; ignored when predefined_args is given
        predefined_args (argparse.Namespace, optional): already-parsed arguments, e.g. from a script
            that trains several models in a row. It is not modified by this function.

    Returns:
        TransformerBaselineTrainer: the trainer used to train the model

    Raises:
        ValueError: for an unknown model_type, or model_type 'transformer' without a bert_model
        FileExistsError: if save_name already exists and --force was not given
        FileNotFoundError: if train_file does not exist
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "saved_models", "big_model_roberta_weighted_loss.pt"), help="Path to model save file")
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')

    args = parser.parse_args(args) if predefined_args is None else predefined_args

    save_name = args.save_name
    num_epochs = args.num_epochs
    train_file = args.train_file
    loss_fn = args.loss_fn
    eval_file = args.eval_file
    lr = args.lr

    # vars() returns the namespace's own __dict__; copy it so that setting
    # bert_model below does not leak back into a caller's predefined_args
    args = dict(vars(args))

    if args['model_type'] == 'bert':
        args['bert_model'] = 'bert-base-uncased'
    elif args['model_type'] == 'roberta':
        args['bert_model'] = 'roberta-base'
    elif args['model_type'] == 'transformer':
        if args['bert_model'] is None:
            raise ValueError("Need to specify a bert_model for model_type transformer!")
    else:
        raise ValueError("Unknown model type " + args['model_type'])

    if os.path.exists(save_name) and not args.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for arg in args:
        logger.info("%s: %s", arg, args[arg])
    logger.info("------------------------------------------------------------")

    trainer = TransformerBaselineTrainer(model_args=args, transformer_name=args['bert_model'], loss_func=loss_fn, lr=lr)

    trainer.train(num_epochs=num_epochs, save_name=save_name, train_file=train_file, args=args, eval_file=eval_file)
    return trainer


if __name__ == "__main__":
    main()
================================================
FILE: stanza/models/lemma_classifier/transformer_model.py
================================================
import torch
import torch.nn as nn
import os
import sys
import logging
from transformers import AutoTokenizer, AutoModel
from typing import Mapping, List, Tuple, Any
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
from stanza.models.common.bert_embedding import extract_bert_embeddings
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.constants import ModelType
logger = logging.getLogger('stanza.lemmaclassifier')
class LemmaClassifierWithTransformer(LemmaClassifier):
    """
    Lemma classifier which embeds the sentence with a pretrained HF transformer
    (loaded via AutoModel) and classifies the target word's vector with a small MLP.
    """

    def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set, target_upos: set):
        """
        Build the transformer + MLP classifier.

        The MLP is two linear layers: the transformer's hidden size -> 64 units
        (followed by ReLU) -> output_dim.

        Args:
            model_args (dict): args for the model, stored and written out by get_save_dict
            output_dim (int): Dimension of the output from the MLP
            transformer_name (str): name of the HF transformer (and tokenizer) to load
            label_decoder (dict): a map of the labels available to the model
            target_words (set(str)): a set of the words which might need lemmatization
            target_upos (set(str)): the UPOS tags of interest, passed to the base class
        """
        super().__init__(label_decoder, target_words, target_upos)
        self.model_args = model_args

        self.transformer_name = transformer_name
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)
        # registered as an unsaved module: get_save_dict strips unsaved-module weights from the checkpoint
        self.add_unsaved_module("transformer", AutoModel.from_pretrained(transformer_name))

        hidden_size = self.transformer.config.hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
        )

    def get_save_dict(self):
        """
        Collect everything needed to save this model.

        Entries of the state dict that belong to unsaved modules (see
        add_unsaved_module in __init__) are removed before returning.
        """
        params = self.state_dict()
        unsaved_keys = [key for key in params.keys() if self.is_unsaved_module(key)]
        for key in unsaved_keys:
            del params[key]
        return {
            "params": params,
            "label_decoder": self.label_decoder,
            "target_words": list(self.target_words),
            "target_upos": list(self.target_upos),
            "model_type": self.model_type().name,
            "args": self.model_args,
        }

    def convert_tags(self, upos_tags: List[List[str]]):
        """This model does not use UPOS tags, so there is nothing to convert."""
        return None

    def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
        """
        Score the target word of each sentence.

        Args:
            idx_positions (List[int]): for each sentence, the position of the token to classify
            sentences (List[List[str]]): the tokenized input sentences
            upos_tags (List[List[int]]): UPOS ids per token - ignored here, kept for interface compatibility

        Returns:
            torch.tensor: MLP logits with one row per sentence
        """
        device = next(self.transformer.parameters()).device
        per_sentence = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device,
                                               keep_endpoints=False, num_layers=1, detach=True)

        target_vectors = torch.stack([sent_emb[pos] for pos, sent_emb in zip(idx_positions, per_sentence)], dim=0)
        # take index 0 of the last axis - presumably the single layer requested
        # with num_layers=1 (confirm against extract_bert_embeddings)
        target_vectors = target_vectors[:, :, 0]
        return self.mlp(target_vectors)

    def model_type(self):
        """Identify this classifier as the transformer variant."""
        return ModelType.TRANSFORMER
================================================
FILE: stanza/models/lemma_classifier/utils.py
================================================
from collections import Counter
import json
import logging
import os
import random
from typing import List, Tuple, Any, Mapping
import stanza
import torch
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
logger = logging.getLogger('stanza.lemmaclassifier')
class Dataset:
    """
    Lemma classifier examples loaded from a JSON file and served in (optionally shuffled) batches.
    """

    def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):
        """
        Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence.

        Args:
            data_path (str): Path to a JSON data file with a 'sentences' list and a 'upos' entry.
                Each sentence needs 'words', 'index', 'upos_tags' and 'lemma' fields.
            batch_size (int): Size of each batch of examples
            get_counts (bool, optional): Whether to count how often each label ID occurs (stored in self.counts)
            label_decoder (dict, optional): an existing label -> ID map. It is copied, and labels
                not already in it get new IDs.
            shuffle (bool, optional): Whether to shuffle the example order each time the dataset is iterated

        Raises:
            FileNotFoundError: if data_path is None or does not exist
            ValueError: if any sentence is missing one of the required fields

        After loading, the instance holds:
            sentences (List[List[str]]): the tokens of each sentence
            indices (List[int]): the index of the target token in each sentence
            upos_ids (List[List[int]]): UPOS IDs per token (not padded)
            labels (List[int]): the label ID of each target token
            counts (Counter): label ID -> count (only filled when get_counts is True)
            label_decoder (Mapping[str, int]): label -> ID
            upos_to_id (Mapping[str, int]): UPOS tag -> ID used in upos_ids
            target_upos: the 'upos' entry of the data file
            known_words (List[str]): lowercased words of the whole dataset
            target_words (Set[str]): lowercased target words
        """
        if data_path is None or not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file {data_path} could not be found.")

        if label_decoder is None:
            label_decoder = {}
        else:
            # if labels in the test set aren't in the original model,
            # the model will never predict those labels,
            # but we can still use those labels in a confusion matrix
            # (copied so the caller's decoder is not extended in place)
            label_decoder = dict(label_decoder)

        logger.debug("Final label decoder: %s Should be strings to ints", label_decoder)

        # words which we are analyzing
        target_words = set()
        # all known words in the dataset, not just target words
        known_words = set()

        # the data file is only read, so open it read-only ("r+" would fail on read-only files)
        with open(data_path, "r", encoding="utf-8") as fin:
            sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {}

            input_json = json.load(fin)
            sentences_data = input_json['sentences']
            self.target_upos = input_json['upos']

            for idx, sentence in enumerate(sentences_data):
                # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatability reasons
                words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
                if None in [words, target_idx, upos_tags, label]:
                    raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

                if label not in label_decoder:
                    label_decoder[label] = len(label_decoder)  # create a new ID for the unknown label

                converted_upos_tags = []  # convert upos tags to upos IDs
                for upos_tag in upos_tags:
                    if upos_tag not in upos_to_id:
                        upos_to_id[upos_tag] = len(upos_to_id)  # create a new ID for the unknown UPOS tag
                    converted_upos_tags.append(upos_to_id[upos_tag])

                sentences.append(words)
                indices.append(target_idx)
                upos_ids.append(converted_upos_tags)
                labels.append(label_decoder[label])

                if get_counts:
                    counts[label_decoder[label]] += 1

                target_words.add(words[target_idx])
                known_words.update(words)

        self.sentences = sentences
        self.indices = indices
        self.upos_ids = upos_ids
        self.labels = labels

        self.counts = counts
        self.label_decoder = label_decoder
        self.upos_to_id = upos_to_id

        self.batch_size = batch_size
        self.shuffle = shuffle

        self.known_words = [x.lower() for x in sorted(known_words)]
        self.target_words = set(x.lower() for x in target_words)

    def __len__(self):
        """
        Number of batches, rounded up to nearest batch
        """
        return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)

    def __iter__(self):
        """
        Yield (sentences, target indices tensor, UPOS id lists, labels tensor) for each batch.

        The example order is reshuffled on every iteration when shuffle is True.
        """
        num_sentences = len(self.sentences)
        indices = list(range(num_sentences))
        if self.shuffle:
            random.shuffle(indices)
        for i in range(self.__len__()):
            batch_start = self.batch_size * i
            batch_end = min(batch_start + self.batch_size, num_sentences)

            batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]
            batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])
            batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]]
            batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])
            yield batch_sentences, batch_indices, batch_upos_ids, batch_labels
def extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]:
    """
    Find the positions in `tokenized_indices` holding the unknown-word index.

    Args:
        tokenized_indices (torch.tensor): word-vector indices of a sequence of tokens.
        unknown_token_idx (int): the index used for unknown tokens in the word vectors.

    Returns:
        List[int]: every position whose value equals `unknown_token_idx`, in order.
    """
    positions = []
    for position, value in enumerate(tokenized_indices):
        if value == unknown_token_idx:
            positions.append(position)
    return positions
def get_device():
    """
    Get the device to run computations on

    Prefers CUDA, then Apple MPS, and falls back to the CPU.
    (Previously `torch.cuda.is_available` was not called, and the following
    if/else overwrote the CUDA choice with the CPU.)

    Returns:
        torch.device: the selected device
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def round_up_to_multiple(number, multiple):
    """
    Round `number` up to the next multiple of `multiple`.

    A number that is already a multiple is returned unchanged, e.g.
    round_up_to_multiple(5, 4) == 8 and round_up_to_multiple(8, 4) == 8.

    Raises:
        ValueError: if `multiple` is zero (previously an error string was returned
            instead, which silently propagated into arithmetic)
    """
    if multiple == 0:
        raise ValueError("The second number (multiple) cannot be zero.")

    # Calculate the remainder when dividing the number by the multiple
    remainder = number % multiple
    if remainder == 0:
        return number  # No rounding needed
    # Non-zero remainder: round up to the next multiple
    return number + (multiple - remainder)
def main():
    """
    Quick sanity check: load the default dev set with Dataset and log a summary of it.

    (This previously called a nonexistent `load_dataset` function and failed with a NameError.)
    """
    default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt") # get the GUM stuff
    dataset = Dataset(default_test_path, get_counts=True)
    logger.info("Loaded %d sentences in %d batches", len(dataset.sentences), len(dataset))
    logger.info("Label counts: %s", dataset.counts)
    logger.info("UPOS to id: %s", dataset.upos_to_id)


if __name__ == "__main__":
    main()
================================================
FILE: stanza/models/lemmatizer.py
================================================
"""
Entry point for training and evaluating a lemmatizer.
This lemmatizer combines a neural sequence-to-sequence architecture with an `edit` classifier
and two dictionaries to produce robust lemmas from word forms.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
"""
import logging
import sys
import os
import shutil
import time
from datetime import datetime
import argparse
import numpy as np
import random
import torch
from torch import nn, optim
from stanza.models.lemma.data import DataLoader
from stanza.models.lemma.vocab import Vocab
from stanza.models.lemma.trainer import Trainer
from stanza.models.lemma import scorer, edit
from stanza.models.common import utils
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.doc import *
from stanza.utils.conll import CoNLL
from stanza.models import _training_logging
logger = logging.getLogger('stanza')
def build_argparse():
    """
    Create the command line parser for training / evaluating the lemmatizer.

    The device options are added by utils.add_device_args, between --seed and --wandb.

    Returns:
        argparse.ArgumentParser: the configured parser
    """
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument

    # data locations and mode
    add('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')
    add('--train_file', type=str, default=None, help='Training input file for data loader.')
    add('--eval_file', type=str, default=None, help='Evaluation input file for data loader.')
    add('--output_file', type=str, default=None, help='Output CoNLL-U file.')
    add('--mode', default='train', choices=['train', 'predict'])
    add('--shorthand', type=str, help='Shorthand for the dataset to use. lang_dataset')

    # dictionary vs seq2seq
    add('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. By default use ensemble.')
    add('--dict_only', action='store_true', help='Only train a dictionary-based lemmatizer.')

    # model architecture
    add('--hidden_dim', type=int, default=200)
    add('--emb_dim', type=int, default=50)
    add('--num_layers', type=int, default=1)
    add('--emb_dropout', type=float, default=0.5)
    add('--dropout', type=float, default=0.5)
    add('--max_dec_len', type=int, default=50)
    add('--beam_size', type=int, default=1)
    add('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
    add('--pos_dim', type=int, default=50)
    add('--pos_dropout', type=float, default=0.5)
    add('--no_edit', dest='edit', action='store_false', help='Do not use edit classifier in lemmatization. By default use an edit classifier.')
    add('--num_edit', type=int, default=len(edit.EDIT_TO_ID))
    add('--alpha', type=float, default=1.0)
    add('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.')
    add('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. By default copy mechanism is used to improve generalization.')

    # character language model
    add('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    add('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    add('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    add('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")

    # optimization
    add('--sample_train', type=float, default=1.0, help='Subsample training data.')
    add('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
    add('--lr', type=float, default=1e-3, help='Learning rate')
    add('--lr_decay', type=float, default=0.9)
    add('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.")
    add('--num_epoch', type=int, default=60)
    add('--batch_size', type=int, default=50)
    add('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
    add('--log_step', type=int, default=20, help='Print log every k steps.')

    # saving and misc
    add('--save_dir', type=str, default='saved_models/lemma', help='Root dir for saving models.')
    add('--save_name', type=str, default="{shorthand}_{embedding}_lemmatizer.pt", help="File name to save the model")
    add('--caseless', default=False, action='store_true', help='Lowercase everything first before processing. This will happen automatically if 100%% of the data is caseless')
    add('--skip_blank_lemmas', default=False, action='store_true', help='Skip blank entries in the data files. Useful for training a lemmatizer from a partially annotated dataset')
    add('--seed', type=int, default=1234)
    utils.add_device_args(arg_parser)

    # experiment tracking
    add('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
    add('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
    return arg_parser
def parse_args(args=None):
    """
    Parse lemmatizer options into a dict.

    Giving --wandb_name implies --wandb. A 'lang' entry is added, taken from
    the part of --shorthand before the first underscore (or "" without a shorthand).
    """
    namespace = build_argparse().parse_args(args=args)
    if namespace.wandb_name:
        namespace.wandb = True

    parsed = vars(namespace)

    # when building the vocab, we keep track of the original language name
    shorthand = parsed['shorthand']
    parsed['lang'] = shorthand.split("_")[0] if shorthand else ""
    return parsed
def main(args=None):
    """
    Entry point: parse the options, set the random seed with utils.set_random_seed,
    then train when --mode is 'train' and evaluate otherwise.
    """
    args = parse_args(args=args)
    utils.set_random_seed(args['seed'])

    mode = args['mode']
    logger.info("Running lemmatizer in {} mode".format(mode))
    action = train if mode == 'train' else evaluate
    action(args)
def all_lowercase(doc):
    """Return True when no word of any sentence in `doc` changes under .lower()."""
    return all(word.text.lower() == word.text
               for sentence in doc.sentences
               for word in sentence.words)
def build_model_filename(args):
    """
    Expand --save_name into the model path.

    {shorthand} is filled from args['shorthand']. {embedding} is "charlm" when a
    charlm with a forward file is enabled, and "nocharlm" otherwise. The result
    is put under args['save_dir'] unless its directory already starts with it.
    """
    use_charlm = args['charlm'] and args['charlm_forward_file']
    embedding = "charlm" if use_charlm else "nocharlm"
    model_file = args['save_name'].format(shorthand=args['shorthand'], embedding=embedding)

    directory, _ = os.path.split(model_file)
    if directory.startswith(args['save_dir']):
        return model_file
    return os.path.join(args['save_dir'], model_file)
def train(args):
    """
    Train a lemmatizer and save it to build_model_filename(args).

    A dictionary-based lemmatizer is always trained first and scored on the dev set.
    With args['dict_only'] only the dictionaries are saved. Otherwise a seq2seq
    lemmatizer is trained for args['num_epoch'] epochs. It is saved after the first
    epoch and whenever the dev score beats every previous epoch.

    Returns early, without training, if the train or dev data is empty.

    Args:
        args (dict): options from parse_args(). 'vocab_size', 'pos_vocab_size' and
            possibly 'caseless' are written back into it.
    """
    # load data
    logger.info("[Loading data with batch size {}...]".format(args['batch_size']))
    train_doc = CoNLL.conll2doc(input_file=args['train_file'])
    train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)
    vocab = train_batch.vocab
    args['vocab_size'] = vocab['char'].size
    args['pos_vocab_size'] = vocab['pos'].size
    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
    dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)

    utils.ensure_dir(args['save_dir'])
    model_file = build_model_filename(args)
    logger.info("Using full savename: %s", model_file)

    # gold path
    gold_file = args['eval_file']

    utils.print_config(args)

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        logger.warning("[Skip training because no training data available...]")
        return

    if not args['caseless'] and all_lowercase(train_doc):
        logger.info("Building a caseless model, as all of the training data is caseless")
        args['caseless'] = True

    # start training
    # train a dictionary-based lemmatizer
    logger.info("Building lemmatizer in %s", model_file)
    trainer = Trainer(args=args, vocab=vocab, device=args['device'])
    logger.info("[Training dictionary-based lemmatizer...]")
    trainer.train_dict(train_batch.raw_data())
    logger.info("Evaluating on dev set...")
    dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS]))
    dev_batch.doc.set([LEMMA], dev_preds)
    system_pred_file = "{:C}\n\n".format(dev_batch.doc)
    system_pred_file = io.StringIO(system_pred_file)
    _, _, dev_f = scorer.score(system_pred_file, gold_file)
    logger.info("Dev F1 = {:.2f}".format(dev_f * 100))

    if args.get('dict_only', False):
        # save dictionaries
        trainer.save(model_file)
    else:
        if args['wandb']:
            import wandb
            wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_lemmatizer" % args['shorthand']
            wandb.init(name=wandb_name, config=args)
            wandb.run.define_metric('train_loss', summary='min')
            wandb.run.define_metric('dev_score', summary='max')

        # train a seq2seq model
        logger.info("[Training seq2seq-based lemmatizer...]")
        global_step = 0
        max_steps = len(train_batch) * args['num_epoch']
        dev_score_history = []
        current_lr = args['lr']
        format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

        # start training
        for epoch in range(1, args['num_epoch']+1):
            train_loss = 0
            for batch in train_batch:
                start_time = time.time()
                global_step += 1
                loss = trainer.update(batch, eval=False) # update step
                train_loss += loss
                if global_step % args['log_step'] == 0:
                    duration = time.time() - start_time
                    logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                                                  max_steps, epoch, args['num_epoch'], loss, duration, current_lr))

            # eval on dev
            logger.info("Evaluating on dev set...")
            dev_preds = []
            dev_edits = []
            for batch in dev_batch:
                preds, edits = trainer.predict(batch, args['beam_size'])
                dev_preds += preds
                if edits is not None:
                    dev_edits += edits
            dev_preds = trainer.postprocess(dev_batch.doc.get([TEXT]), dev_preds, edits=dev_edits)

            # try ensembling with dict if necessary
            if args.get('ensemble_dict', False):
                logger.info("[Ensembling dict with seq2seq model...]")
                dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds)
            dev_batch.doc.set([LEMMA], dev_preds)
            system_pred_file = "{:C}\n\n".format(dev_batch.doc)
            system_pred_file = io.StringIO(system_pred_file)
            _, _, dev_score = scorer.score(system_pred_file, gold_file)

            train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
            logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))

            if args['wandb']:
                wandb.log({'train_loss': train_loss, 'dev_score': dev_score})

            # save best model
            if epoch == 1 or dev_score > max(dev_score_history):
                trainer.save(model_file)
                logger.info("new best model saved.")

            # lr schedule
            # (the history check matters when decay_epoch < 1: the first epoch
            # would otherwise index an empty history and raise IndexError)
            if (epoch > args['decay_epoch'] and dev_score_history and dev_score <= dev_score_history[-1]
                    and args['optim'] in ['sgd', 'adagrad']):
                current_lr *= args['lr_decay']
                trainer.update_lr(current_lr)

            dev_score_history += [dev_score]
            logger.info("")

        # one dev score is recorded per epoch, so this is the number of epochs run
        logger.info("Training ended with {} epochs.".format(len(dev_score_history)))

        if args['wandb']:
            wandb.finish()

        if dev_score_history:
            best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1
            logger.info("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))
        else:
            logger.warning("No seq2seq training epochs were run (num_epoch = %d), so no seq2seq model was saved", args['num_epoch'])
def evaluate(args):
    """
    Lemmatize args['eval_file'] with a saved model and log the lemma score.

    Options ending in _dir or _file, plus 'shorthand', from the current command
    line override the ones stored with the model. If args['output_file'] is set,
    the predictions are also written there as CoNLL-U.
    """
    output_file = args['output_file']

    # load model
    trainer = Trainer(model_file=build_model_filename(args), device=args['device'], args=args)
    loaded_args = trainer.args
    vocab = trainer.vocab
    overridden = [key for key in args if key.endswith('_dir') or key.endswith('_file') or key in ['shorthand']]
    for key in overridden:
        loaded_args[key] = args[key]

    # load data
    logger.info("Loading data with batch size {}...".format(args['batch_size']))
    doc = CoNLL.conll2doc(input_file=args['eval_file'])
    batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)

    # skip eval if dev data does not exist
    if len(batch) == 0:
        logger.warning("Skip evaluation because no dev data is available...\nLemma score:\n{} ".format(args['shorthand']))
        return

    dict_preds = trainer.predict_dict(batch.doc.get([TEXT, UPOS]))

    if loaded_args.get('dict_only', False):
        preds = dict_preds
    else:
        logger.info("Running the seq2seq model...")
        preds, edits = [], []
        for sub_batch in batch:
            batch_preds, batch_edits = trainer.predict(sub_batch, args['beam_size'])
            preds.extend(batch_preds)
            if batch_edits is not None:
                edits.extend(batch_edits)
        preds = trainer.postprocess(batch.doc.get([TEXT]), preds, edits=edits)

        if loaded_args.get('ensemble_dict', False):
            logger.info("[Ensembling dict with seq2seq lemmatizer...]")
            preds = trainer.ensemble(batch.doc.get([TEXT, UPOS]), preds)

    if trainer.has_contextual_lemmatizers():
        preds = trainer.update_contextual_preds(batch.doc, preds)

    # write to file and score
    batch.doc.set([LEMMA], preds)
    if output_file:
        CoNLL.write_doc2conll(batch.doc, output_file)

    scored_stream = io.StringIO("{:C}\n\n".format(batch.doc))
    _, _, score = scorer.score(scored_stream, args['eval_file'])

    logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))


if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/mwt/__init__.py
================================================
================================================
FILE: stanza/models/mwt/character_classifier.py
================================================
"""
Classify characters based on an LSTM with learned character representations
"""
import logging
import torch
from torch import nn
import stanza.models.common.seq2seq_constant as constant
logger = logging.getLogger('stanza')

class CharacterClassifier(nn.Module):
    """
    Two-way classifier applied to every character of a token.

    The characters are embedded, encoded with a bidirectional LSTM, and each
    position is scored by a small MLP with two outputs.  The MWT Trainer treats
    output 1 beating output 0 at a position as "split the token before this
    character".
    """
    def __init__(self, args):
        super().__init__()
        self.args = args

        # sizes
        self.vocab_size = args['vocab_size']
        self.emb_dim = args['emb_dim']
        self.input_dim = self.emb_dim
        self.hidden_dim = args['hidden_dim']
        # the encoder is bidirectional, so each direction gets half of hidden_dim
        self.enc_hidden_dim = self.hidden_dim // 2
        self.nlayers = args['num_layers']  # lstm encoder layers
        self.num_outputs = 2
        self.pad_token = constant.PAD_ID

        # regularization
        self.emb_dropout = args.get('emb_dropout', 0.0)
        self.dropout = args['dropout']

        # submodules; creation order determines parameter registration order
        self.emb_drop = nn.Dropout(self.emb_dropout)
        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
        # LSTM dropout only acts between stacked layers, so a 1-layer encoder gets none
        lstm_dropout = self.dropout if self.nlayers > 1 else 0
        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers,
                               bidirectional=True, batch_first=True, dropout=lstm_dropout)
        self.output_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.num_outputs),
        )

    def encode(self, enc_inputs, lens):
        """Pack the embedded batch using lens, run the BiLSTM, and return its PackedSequence output."""
        packed = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
        outputs, _ = self.encoder(packed)
        return outputs

    def embed(self, src, src_mask):
        """
        Embed the character ids in src.

        Returns (embeddings, batch size, per-row lengths, src_mask).
        """
        # ids past the end of the vocab can show up when the vocabulary was
        # temporarily expanded; this model has no embedding for them, so use UNK
        safe_src = src.clone()
        safe_src[safe_src >= self.vocab_size] = constant.UNK_ID
        embedded = self.emb_drop(self.embedding(safe_src))
        # src_mask is True on padding positions; assuming PAD_ID == 0, comparing
        # the mask against it selects the real characters, which are then counted per row
        lengths = list(src_mask.data.eq(self.pad_token).long().sum(1))
        return embedded, embedded.size(0), lengths, src_mask

    def forward(self, src, src_mask):
        """Return unnormalized two-way scores for every position of src."""
        embedded, _, lengths, src_mask = self.embed(src, src_mask)
        packed_out = self.encode(embedded, lengths)
        padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return self.output_layer(padded_out)
================================================
FILE: stanza/models/mwt/data.py
================================================
import random
import numpy as np
import os
from collections import Counter, namedtuple
import logging
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader as DL
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.vocab import DeltaVocab
from stanza.models.mwt.vocab import Vocab
from stanza.models.common.doc import Document
logger = logging.getLogger('stanza')

# one processed example, as returned by DataLoader.__getitem__
DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text")
# one collated batch, as returned by the DataLoader collate function
DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx")

# enforce that the MWT splitter knows about a couple different alternate apostrophes,
# including covering some potential " typos.
# setting the augmentation to a very low value should be enough to teach it
# about the unknown characters without messing up the predictions for other text
#
# 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07
# The non-ASCII members are written as escapes so that Unicode normalization
# (e.g. folding the fullwidth apostrophe U+FF07 into a plain ') cannot corrupt
# this tuple or turn the last element into an unterminated ''' string.
APOS = ('"', "'", '\u02bc', '\u02ca', '\u055a', '\u07f4', '\u2019', '\uff07')
class DataLoader:
    """
    Holds the MWT examples of a Document and turns them into tensors.

    At training time each example is a (token, expansion) pair; at evaluation
    time it is a one-element list holding only the token text.
    Indexing returns single processed examples; to_loader() wraps this object in
    a torch DataLoader which collates them into padded DataBatch objects.
    """
    def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
        self.batch_size = batch_size
        self.args = args
        self.augment_apos = args.get('augment_apos', 0.0)
        self.evaluation = evaluation
        self.doc = doc

        data = self.load_doc(self.doc, evaluation=self.evaluation)

        # pick the vocab: extend the given one, reuse it unchanged, or build a new one
        if vocab is not None:
            self.vocab = DeltaVocab(data, vocab) if expand_unk_vocab else vocab
        else:
            assert self.evaluation == False # for eval vocab must exist
            self.vocab = self.init_vocab(data)
            # if any apostrophe variant occurs in the data, make all of APOS known
            # so that augmentation can swap between them
            if self.augment_apos > 0 and any(x in self.vocab for x in APOS):
                for apos in APOS:
                    self.vocab.add_unit(apos)

        # optionally train on a random subset of the data
        if args.get('sample_train', 1.0) < 1.0 and not self.evaluation:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        # shuffle for training
        if not self.evaluation:
            order = list(range(len(data)))
            random.shuffle(order)
            data = [data[pos] for pos in order]

        self.data = data
        self.num_examples = len(data)

    def init_vocab(self, data):
        """Build a fresh character Vocab from the training pairs."""
        assert self.evaluation == False # for eval vocab must exist
        return Vocab(data, self.args['shorthand'])

    def maybe_augment_apos(self, datum):
        """
        Find the first APOS character present in the token; with probability
        self.augment_apos replace it, in both token and expansion, by a random
        member of APOS.  Returns the (possibly new) datum.
        """
        token = datum[0]
        found = next((apos for apos in APOS if apos in token), None)
        if found is None:
            return datum
        if random.uniform(0,1) < self.augment_apos:
            replacement = random.choice(APOS)
            return (token.replace(found, replacement), datum[1].replace(found, replacement))
        return datum

    def process(self, sample):
        """Map one example to [src ids, tgt_in ids, tgt_out ids, original text]."""
        if not self.evaluation and self.augment_apos > 0:
            sample = self.maybe_augment_apos(sample)
        src_chars = [constant.SOS] + list(sample[0]) + [constant.EOS]
        tgt_in, tgt_out = self.prepare_target(self.vocab, sample)
        src_ids = self.vocab.map(src_chars)
        return [src_ids, tgt_in, tgt_out, sample[0]]

    def prepare_target(self, vocab, datum):
        """Return (SOS + target, target + EOS) as id lists."""
        # at eval time there is no gold expansion, so the token serves as a placeholder
        target = list(datum[0] if self.evaluation else datum[1])
        return vocab.map([constant.SOS] + target), vocab.map(target + [constant.EOS])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """Return (DataSample, key) for the single example at integer position key."""
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        processed = self.process(self.data[key])
        assert len(processed) == 4
        src, tgt_in, tgt_out, orig_text = processed
        sample = DataSample(torch.tensor(src), torch.tensor(tgt_in), torch.tensor(tgt_out), orig_text)
        return sample, key

    @staticmethod
    def __collate_fn(data):
        """Sort a list of (DataSample, key) items by source length and pad them into a DataBatch."""
        samples, _ = zip(*data)
        src, tgt_in, tgt_out, orig_text = zip(*samples)

        # the model packs the sequences, so the batch has to be sorted by src length
        src_lengths = [len(seq) for seq in src]
        (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), src_lengths)

        src = pad_sequence(src, True, constant.PAD_ID)
        tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID)
        tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID)
        assert tgt_in.size(1) == tgt_out.size(1), \
            "Target input and output sequence sizes do not match."
        src_mask = torch.eq(src, constant.PAD_ID)
        return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)

    def __iter__(self):
        for position in range(len(self)):
            yield self[position]

    def to_loader(self):
        """Wrap this dataset in a torch DataLoader (shuffled unless evaluating)."""
        return DL(self,
                  collate_fn=self.__collate_fn,
                  batch_size=self.batch_size,
                  shuffle=not self.evaluation)

    def load_doc(self, doc, evaluation=False):
        """Read the MWT examples out of doc."""
        expansions = doc.get_mwt_expansions(evaluation)
        if evaluation:
            # wrap each bare token so that datum[0] works the same as for training pairs
            return [[token] for token in expansions]
        return expansions
class BinaryDataLoader(DataLoader):
    """
    Same as DataLoader, except that the targets are per-character 0/1 labels.

    A 1 marks a character which directly follows a space in the expansion,
    i.e. the start of a new word of the MWT.  Every other character, plus the
    SOS and EOS positions, is labeled 0.
    """
    def prepare_target(self, vocab, datum):
        # vocab is not needed: the labels are not character ids
        text = datum[0] if self.evaluation else datum[1]
        labels = [0]  # SOS position
        after_space = False
        for ch in text:
            if ch == ' ':
                after_space = True
                continue
            labels.append(1 if after_space else 0)
            after_space = False
        labels.append(0)  # EOS position
        # the same list serves as both tgt_in and tgt_out
        return labels, labels
================================================
FILE: stanza/models/mwt/scorer.py
================================================
"""
Utils and wrappers for scoring MWT
"""
from stanza.models.common.utils import ud_scores
def score(system_conllu_file, gold_conllu_file):
    """
    Score an MWT expansion by its word segmentation.

    Returns the (precision, recall, f1) of the "Words" metric computed by
    ud_scores on the gold and system CoNLL-U files.
    """
    words = ud_scores(gold_conllu_file, system_conllu_file)["Words"]
    return words.precision, words.recall, words.f1
================================================
FILE: stanza/models/mwt/trainer.py
================================================
"""
A trainer class to handle training and testing of models.
"""
import sys
import numpy as np
from collections import Counter
import logging
import torch
from torch import nn
import torch.nn.init as init
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.common.seq2seq_model import Seq2SeqModel
from stanza.models.common import utils, loss
from stanza.models.mwt.character_classifier import CharacterClassifier
from stanza.models.mwt.vocab import Vocab
logger = logging.getLogger('stanza')

def unpack_batch(batch, device):
    """
    Split a collated batch into (tensor inputs, original text, original indices).

    The first four entries (src, src_mask, tgt_in, tgt_out) are moved to device;
    None entries are passed through untouched.
    """
    inputs = [None if tensor is None else tensor.to(device) for tensor in batch[:4]]
    return inputs, batch[4], batch[5]
class Trainer(BaseTrainer):
    """
    A trainer for MWT expansion models.

    Depending on the config (args), the neural model is one of:
      - nothing at all, when args['dict_only'] is set
      - a CharacterClassifier, when args['force_exact_pieces'] is set; it decides
        for each character of the token whether a split goes right before it
      - a Seq2SeqModel otherwise, which generates the expansion character by character
    In every case an expansion dictionary (filled by train_dict) is kept alongside it.
    """
    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None):
        """
        Either load a saved model from model_file, or build a fresh one from args and vocab.

        emb_matrix is only passed to a freshly built Seq2SeqModel.
        device is where the neural model and its loss are placed.
        """
        if model_file is not None:
            # load from file; this also restores args, vocab and the expansion dict
            self.load(model_file)
        else:
            self.args = args
            if args['dict_only']:
                self.model = None
            elif args.get('force_exact_pieces', False):
                self.model = CharacterClassifier(args)
            else:
                self.model = Seq2SeqModel(args, emb_matrix=emb_matrix)
            self.vocab = vocab
            self.expansion_dict = dict()
        if not self.args['dict_only']:
            self.model = self.model.to(device)
            if self.args.get('force_exact_pieces', False):
                # the classifier makes a two-way decision per character
                self.crit = nn.CrossEntropyLoss()
            else:
                self.crit = loss.SequenceLoss(self.vocab.size).to(device)
            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])

    def update(self, batch, eval=False):
        """
        Compute the loss on one batch and, unless eval is True, take an optimizer step.

        Returns the loss as a Python float.
        """
        device = next(self.model.parameters()).device
        # ignore the original text when training
        # can try to learn the correct values, even if we eventually
        # copy directly from the original text
        inputs, _, _ = unpack_batch(batch, device)
        src, src_mask, tgt_in, tgt_out = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()

        # named batch_loss so it does not shadow the imported `loss` module
        if self.args.get('force_exact_pieces', False):
            logits = self.model(src, src_mask)
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            # pack both sides so that padding positions do not contribute to the loss
            packed_output = nn.utils.rnn.pack_padded_sequence(logits, src_lens, batch_first=True)
            packed_tgt = nn.utils.rnn.pack_padded_sequence(tgt_in, src_lens, batch_first=True)
            batch_loss = self.crit(packed_output.data, packed_tgt.data)
        else:
            log_probs, _ = self.model(src, src_mask, tgt_in)
            batch_loss = self.crit(log_probs.view(-1, self.vocab.size), tgt_out.view(-1))
        loss_val = batch_loss.data.item()
        if eval:
            return loss_val

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True, never_decode_unk=False, vocab=None):
        """
        Predict the expansion of every token in the batch.

        vocab overrides self.vocab for mapping ids back to characters.
        Returns a list of expansion strings; when unsort is True they are put
        back in the order the tokens had before the batch was sorted.
        """
        if vocab is None:
            vocab = self.vocab

        device = next(self.model.parameters()).device
        inputs, orig_text, orig_idx = unpack_batch(batch, device)
        src, src_mask, _, _ = inputs

        self.model.eval()
        if self.args.get('force_exact_pieces', False):
            logits = self.model(src, src_mask)
            cuts = logits[:, :, 1] > logits[:, :, 0]
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            pred_tokens = []
            for src_ids, cut, src_len in zip(src, cuts, src_lens):
                src_chars = vocab.unmap(src_ids)
                pred_seq = []
                # skip position 0 (SOS) and the last position (EOS)
                for char_idx in range(1, src_len-1):
                    if cut[char_idx]:
                        pred_seq.append(' ')
                    pred_seq.append(src_chars[char_idx])
                pred_tokens.append("".join(pred_seq).strip())
        else:
            preds, _ = self.model.predict(src, src_mask, self.args['beam_size'], never_decode_unk=never_decode_unk)
            pred_seqs = [vocab.unmap(ids) for ids in preds] # unmap to tokens
            pred_seqs = utils.prune_decoded_seqs(pred_seqs)
            pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens
            # if any tokens are predicted to expand to blank,
            # that is likely an error.  use the original text
            # this originally came up with the Spanish model turning 's' into a blank
            # furthermore, if there are no spaces predicted by the seq2seq,
            # might as well use the original in case the seq2seq went crazy
            # this particular error came up training a Hebrew MWT
            pred_tokens = [x if x and ' ' in x else y for x, y in zip(pred_tokens, orig_text)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def train_dict(self, pairs):
        """ Train a MWT expander given training word-expansion pairs. """
        # accumulate counter
        ctr = Counter()
        ctr.update([(p[0], p[1]) for p in pairs])
        seen = set()
        # keep only the most frequent mapping of each word, and only if it
        # actually changes the word
        for p, _ in ctr.most_common():
            w, l = p
            if w not in seen and w != l:
                self.expansion_dict[w] = l
            seen.add(w)
        return

    def dict_expansion(self, word):
        """
        Look up word in the expansion dictionary.

        If the exact word is missing, an UPPERCASE word or a Leadingcase word is
        looked up in lowercase and the casing is reapplied to the expansion.
        Returns None when no expansion is found (including for an empty word).
        """
        expansion = self.expansion_dict.get(word)
        if expansion is not None:
            return expansion

        if word.isupper():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion.upper()

        # slicing instead of indexing keeps empty words / expansions from raising IndexError
        if word[:1].isupper() and word[1:].islower():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion[:1].upper() + expansion[1:]

        # could build a truecasing model of some kind to handle cRaZyCaSe...
        # but that's probably too much effort
        return None

    def predict_dict(self, words):
        """ Predict a list of expansions given words; words without an entry are kept unchanged. """
        expansions = []
        for w in words:
            expansion = self.dict_expansion(w)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(w)
        return expansions

    def ensemble(self, cands, other_preds):
        """ Ensemble the dict with statistical model predictions: the dict wins whenever it has an entry. """
        expansions = []
        assert len(cands) == len(other_preds)
        for c, pred in zip(cands, other_preds):
            expansion = self.dict_expansion(c)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(pred)
        return expansions

    def save(self, filename):
        """
        Save model weights, expansion dict, vocab and config to filename.

        Saving is best-effort: a failure is logged (with its traceback) and execution continues.
        """
        params = {
            'model': self.model.state_dict() if self.model is not None else None,
            'dict': self.expansion_dict,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except Exception:
            # Exception rather than BaseException so that Ctrl-C / SystemExit
            # during a save are not silently swallowed
            logger.warning("Saving failed... continuing anyway.", exc_info=True)

    def load(self, filename):
        """Restore args, expansion dict, model (unless dict_only) and vocab from filename."""
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        self.expansion_dict = checkpoint['dict']
        if not self.args['dict_only']:
            if self.args.get('force_exact_pieces', False):
                self.model = CharacterClassifier(self.args)
            else:
                self.model = Seq2SeqModel(self.args)
            # could remove strict=False after rebuilding all models,
            # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False
            self.model.load_state_dict(checkpoint['model'], strict=False)
        else:
            self.model = None
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
================================================
FILE: stanza/models/mwt/utils.py
================================================
import stanza
from stanza.models.common import doc
from stanza.models.tokenization.data import TokenizationDataset
from stanza.models.tokenization.utils import predict, decode_predictions
def mwts_composed_of_words(doc):
    """
    Check whether every multi-word token in doc is spelled exactly by its words.

    Returns True if, for each token with more than one word, concatenating the
    words' text reproduces the token's text; False as soon as one does not.
    """
    return all(token.text == "".join(word.text for word in token.words)
               for sentence in doc.sentences
               for token in sentence.tokens
               if len(token.words) > 1)
def resplit_mwt(tokens, pipeline, keep_tokens=True):
    """
    Uses the tokenize processor and the mwt processor in the pipeline to resplit tokens into MWT

    tokens: a list of list of string
    pipeline: a Stanza pipeline which contains, at a minimum, tokenize and mwt
    keep_tokens: if True, enforce the old token boundaries by modifying
      the results of the tokenize inference.
      Otherwise, use whatever new boundaries the model comes up with.

    Between running the tokenize model and breaking the text into tokens,
    we can update all_preds to use the original token boundaries
    (if and only if keep_tokens == True)

    This method returns a Document with just the tokens and words annotated.

    Raises ValueError if the pipeline is missing a tokenize or mwt processor.
    """
    if "tokenize" not in pipeline.processors:
        raise ValueError("Need a Pipeline with a valid tokenize processor")
    if "mwt" not in pipeline.processors:
        raise ValueError("Need a Pipeline with a valid mwt processor")
    tokenize_processor = pipeline.processors["tokenize"]
    mwt_processor = pipeline.processors["mwt"]
    # each sentence becomes its own paragraph (blank-line separated),
    # with its tokens joined by single spaces
    fake_text = "\n\n".join(" ".join(sentence) for sentence in tokens)

    # set up batches
    batches = TokenizationDataset(tokenize_processor.config,
                                  input_text=fake_text,
                                  vocab=tokenize_processor.vocab,
                                  evaluation=True,
                                  dictionary=tokenize_processor.trainer.dictionary)

    all_preds, all_raw = predict(trainer=tokenize_processor.trainer,
                                 data_generator=batches,
                                 batch_size=tokenize_processor.trainer.args['batch_size'],
                                 max_seqlen=tokenize_processor.config.get('max_seqlen', tokenize_processor.MAX_SEQ_LENGTH_DEFAULT),
                                 use_regex_tokens=True,
                                 num_workers=tokenize_processor.config.get('num_workers', 0))

    if keep_tokens:
        # NOTE(review): assumes all_preds has one label array per paragraph of fake_text,
        # indexed by character position with the joining spaces included -- TODO confirm
        # against tokenization predict()
        for sentence, pred in zip(tokens, all_preds):
            char_idx = 0
            for word in sentence:
                if len(word) > 0:
                    # clear the labels on every character but the last one of the word
                    # (presumably 0 == "no token boundary here")
                    pred[char_idx:char_idx+len(word)-1] = 0
                    # force a token end on the word's last character; a nonzero label
                    # already there (presumably a sentence / MWT end) is left as is
                    if pred[char_idx+len(word)-1] == 0:
                        pred[char_idx+len(word)-1] = 1
                # step past the word and the single space that joined it in fake_text
                char_idx += len(word) + 1

    _, _, document = decode_predictions(vocab=tokenize_processor.vocab,
                                        mwt_dict=None,
                                        orig_text=fake_text,
                                        all_raw=all_raw,
                                        all_preds=all_preds,
                                        no_ssplit=True,
                                        skip_newline=tokenize_processor.trainer.args['skip_newline'],
                                        use_la_ittb_shorthand=tokenize_processor.trainer.args['shorthand'] == 'la_ittb')

    document = doc.Document(document, fake_text)
    # runs the MWT expander over the new document
    mwt_processor.process(document)
    return document
def main():
    """Demo: resplit two small English sentences, first keeping the given token boundaries, then not."""
    pipe = stanza.Pipeline("en", processors="tokenize,mwt", package="gum")
    tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
    for keep_tokens in (True, False):
        print(resplit_mwt(tokens, pipe, keep_tokens=keep_tokens))

if __name__ == "__main__":
    main()
================================================
FILE: stanza/models/mwt/vocab.py
================================================
from collections import Counter
from stanza.models.common.vocab import BaseVocab
import stanza.models.common.seq2seq_constant as constant
class Vocab(BaseVocab):
    """
    Character vocabulary for the MWT expander.

    After the shared constant.VOCAB_PREFIX entries, the units are the characters
    of both sides of every (token, expansion) pair, most frequent first
    (ties keep the order in which the characters were first seen).
    """
    def build_vocab(self):
        char_counts = Counter()
        for src, tgt in self.data:
            char_counts.update(src + tgt)
        by_frequency = [char for char, _ in char_counts.most_common()]
        self._id2unit = constant.VOCAB_PREFIX + by_frequency
        self._unit2id = {unit: idx for idx, unit in enumerate(self._id2unit)}

    def add_unit(self, unit):
        """Append unit with the next free id, unless it is already in the vocab."""
        if unit not in self._unit2id:
            self._unit2id[unit] = len(self._id2unit)
            self._id2unit.append(unit)
================================================
FILE: stanza/models/mwt_expander.py
================================================
"""
Entry point for training and evaluating a multi-word token (MWT) expander.
This MWT expander combines a neural sequence-to-sequence architecture with a dictionary
to decode the token into multiple words.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf
When every MWT in the training data splits exactly into the text of its words,
a classifier over the characters is used instead of the seq2seq model.
"""
import io
import sys
import os
import shutil
import time
from datetime import datetime
import argparse
import logging
import math
import numpy as np
import random
import torch
from torch import nn, optim
import copy
from stanza.models.mwt.data import DataLoader, BinaryDataLoader
from stanza.models.mwt.utils import mwts_composed_of_words
from stanza.models.mwt.vocab import Vocab
from stanza.models.mwt.trainer import Trainer
from stanza.models.mwt import scorer
from stanza.models.common import utils
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.doc import Document
from stanza.utils.conll import CoNLL
from stanza.models import _training_logging
logger = logging.getLogger('stanza')

def build_argparse():
    """Build the command line parser for training or running the MWT expander."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='data/mwt', help='Root dir for the MWT data.')
    parser.add_argument('--train_file', type=str, default=None, help='Input CoNLL-U file with the training data.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input CoNLL-U file to evaluate on (the dev set when training).')
    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
    parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file to score the predictions against.')
    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--lang', type=str, help='Language')
    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")

    parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. By default ensemble a dict.')
    parser.add_argument('--ensemble_early_stop', action='store_true', help='Early stopping based on ensemble performance.')
    parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based MWT expander.')

    parser.add_argument('--hidden_dim', type=int, default=100)
    parser.add_argument('--emb_dim', type=int, default=50)
    parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
    parser.add_argument('--emb_dropout', type=float, default=0.5)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--max_dec_len', type=int, default=50)
    parser.add_argument('--beam_size', type=int, default=1)
    parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
    parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')

    parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|')
    parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)')
    parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)")

    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--lr_decay', type=float, default=0.9)
    parser.add_argument('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.")
    parser.add_argument('--num_epoch', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
    parser.add_argument('--save_dir', type=str, default='saved_models/mwt', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
    parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern.  Mostly for testing")

    parser.add_argument('--seed', type=int, default=1234)
    utils.add_device_args(parser)

    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')
    return parser
def parse_args(args=None):
    """Parse args (or sys.argv when None) into a Namespace."""
    parsed = build_argparse().parse_args(args=args)
    # naming a wandb run implies that wandb logging is wanted
    if parsed.wandb_name:
        parsed.wandb = True
    return parsed
def main(args=None):
    """Parse the arguments, seed the RNGs, then dispatch to train() or evaluate() based on --mode."""
    parsed = parse_args(args=args)
    utils.set_random_seed(parsed.seed)
    args = vars(parsed)
    logger.info("Running MWT expander in {} mode".format(args['mode']))
    runner = train if args['mode'] == 'train' else evaluate
    return runner(args)
def _score_mwt_preds(eval_doc, preds, gold_file):
    """
    Score predicted MWT expansions against gold_file.

    eval_doc is deep-copied (not modified); preds are set as the copy's MWT
    expansions and the copy is scored in CoNLL-U form.
    Returns the (precision, recall, f1) tuple from scorer.score.
    """
    doc = copy.deepcopy(eval_doc)
    doc.set_mwt_expansions(preds, fake_dependencies=True)
    system_preds = io.StringIO("{:C}\n\n".format(doc))
    return scorer.score(system_preds, gold_file)


def train(args):
    """
    Train an MWT expander on args['train_file'], using args['eval_file'] as the dev set
    and args['gold_file'] as its gold annotation.

    A dictionary expander is always trained.  Unless the run ends up dict_only, a
    neural model (seq2seq, or a character classifier when force_exact_pieces applies)
    is trained as well, and the best model by dev score is saved.

    Returns None if there is no training data.  Otherwise returns (trainer, r), where
    r is the dev recall from the last dev scoring (kept for backward compatibility).
    """
    # load data
    logger.debug('max_dec_len: %d' % args['max_dec_len'])
    logger.debug("Loading data with batch size {}...".format(args['batch_size']))
    train_doc = CoNLL.conll2doc(input_file=args['train_file'])
    train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)
    vocab = train_batch.vocab
    args['vocab_size'] = vocab.size
    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
    dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)

    utils.ensure_dir(args['save_dir'])
    save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
    model_file = os.path.join(args['save_dir'], save_name)

    save_each_name = None
    if args['save_each_name']:
        save_each_name = os.path.join(args['save_dir'], args['save_each_name'])
        save_each_name = utils.build_save_each_filename(save_each_name)

    # pred and gold path
    gold_file = args['gold_file']

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0:
        logger.warning("Skip training because no data available...")
        return

    # the fallback only applies when dict_only is not already set
    # (the previous condition required dict_only to be True, making it a no-op)
    if len(dev_batch) == 0 and not args.get('dict_only', False):
        logger.warning("Training data available, but dev data has no MWTs.  Only training a dict based MWT")
        args['dict_only'] = True

    if args['force_exact_pieces'] and not mwts_composed_of_words(train_doc):
        raise ValueError("Cannot train model with --force_exact_pieces, as the MWT in this dataset are not entirely composed of their subwords")

    if args['force_exact_pieces'] is None and mwts_composed_of_words(train_doc):
        # the force_exact_pieces mechanism trains a separate version of the MWT expander in the Trainer
        # (the training loop here does not need to change)
        # in this model, a classifier distinguishes whether or not a location is a split
        # and the text is copied exactly from the input rather than created via seq2seq
        # this behavior can be turned off at training time with --no_force_exact_pieces
        logger.info("Train MWTs entirely composed of their subwords.  Training the MWT to match that paradigm as closely as possible")
        args['force_exact_pieces'] = True

    if args['force_exact_pieces']:
        logger.info("Reconverting to BinaryDataLoader")
        train_batch = BinaryDataLoader(train_doc, args['batch_size'], args, evaluation=False)
        vocab = train_batch.vocab
        args['vocab_size'] = vocab.size
        dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)

    if args['num_layers'] is None:
        args['num_layers'] = 2 if args['force_exact_pieces'] else 1

    # train a dictionary-based MWT expander
    trainer = Trainer(args=args, vocab=vocab, device=args['device'])
    logger.info("Training dictionary-based MWT expander...")
    trainer.train_dict(train_batch.doc.get_mwt_expansions(evaluation=False))
    logger.info("Evaluating on dev set...")
    dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))
    _, dev_recall, dev_f = _score_mwt_preds(dev_batch.doc, dev_preds, gold_file)
    logger.info("Dev F1 = {:.2f}".format(dev_f * 100))

    if args.get('dict_only', False):
        # save dictionaries
        trainer.save(model_file)
    else:
        # train a seq2seq model
        logger.info("Training seq2seq-based MWT expander...")
        global_step = 0
        steps_per_epoch = math.ceil(len(train_batch) / args['batch_size'])
        max_steps = steps_per_epoch * args['num_epoch']
        dev_score_history = []
        best_dev_preds = []
        current_lr = args['lr']
        format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

        if args['wandb']:
            import wandb
            wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_mwt" % args['shorthand']
            wandb.init(name=wandb_name, config=args)
            wandb.run.define_metric('train_loss', summary='min')
            wandb.run.define_metric('dev_score', summary='max')

        # start training
        for epoch in range(1, args['num_epoch']+1):
            train_loss = 0
            for batch in train_batch.to_loader():
                start_time = time.time()
                global_step += 1
                loss = trainer.update(batch, eval=False) # update step
                train_loss += loss
                if global_step % args['log_step'] == 0:
                    duration = time.time() - start_time
                    logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                                                  max_steps, epoch, args['num_epoch'], loss, duration, current_lr))

            if save_each_name:
                trainer.save(save_each_name % epoch)
                logger.info("Saved epoch %d model to %s" % (epoch, save_each_name % epoch))

            # eval on dev
            logger.info("Evaluating on dev set...")
            dev_preds = []
            for batch in dev_batch.to_loader():
                dev_preds += trainer.predict(batch)
            if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False):
                logger.info("[Ensembling dict with seq2seq model...]")
                dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)
            _, dev_recall, dev_score = _score_mwt_preds(dev_batch.doc, dev_preds, gold_file)

            train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
            logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))

            if args['wandb']:
                wandb.log({'train_loss': train_loss, 'dev_score': dev_score})

            # save best model
            if epoch == 1 or dev_score > max(dev_score_history):
                trainer.save(model_file)
                logger.info("new best model saved.")
                best_dev_preds = dev_preds

            # lr schedule; needs a previous score to compare against
            # (decay_epoch < 1 would otherwise index an empty history)
            if epoch > args['decay_epoch'] and dev_score_history and dev_score <= dev_score_history[-1]:
                current_lr *= args['lr_decay']
                trainer.change_lr(current_lr)

            dev_score_history += [dev_score]

        logger.info("Training ended with {} epochs.".format(epoch))

        if args['wandb']:
            wandb.finish()

        best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1
        logger.info("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))

        # try ensembling with dict if necessary
        if args.get('ensemble_dict', False):
            logger.info("[Ensembling dict with seq2seq model...]")
            dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)
            _, dev_recall, dev_score = _score_mwt_preds(dev_batch.doc, dev_preds, gold_file)
            logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100))
            # best_f is a percentage, so scale dev_score the same way before comparing
            best_f = max(best_f, dev_score * 100)

    return trainer, dev_recall
def evaluate(args):
# file paths
system_pred_file = args['output_file']
gold_file = args['gold_file']
model_file = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
if args['save_dir'] and not model_file.startswith(args['save_dir']) and not os.path.exists(model_file):
model_file = os.path.join(args['save_dir'], model_file)
# load model
trainer = Trainer(model_file=model_file, device=args['device'])
loaded_args, vocab = trainer.args, trainer.vocab
for k in args:
if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']:
loaded_args[k] = args[k]
logger.debug('max_dec_len: %d' % loaded_args['max_dec_len'])
# load data
logger.debug("Loading data with batch size {}...".format(args['batch_size']))
doc = CoNLL.conll2doc(input_file=args['eval_file'])
batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)
if len(batch) > 0:
dict_preds = trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True))
# decide trainer type and run eval
if loaded_args['dict_only']:
preds = dict_preds
else:
logger.info("Running the seq2seq model...")
preds = []
for i, b in enumerate(batch.to_loader()):
preds += trainer.predict(b)
if loaded_args.get('ensemble_dict', False):
preds = trainer.ensemble(batch.doc.get_mwt_expansions(evaluation=True), preds)
else:
# skip eval if dev data does not exist
preds = []
# write to file and score
doc = copy.deepcopy(batch.doc)
doc.set_mwt_expansions(preds, fake_dependencies=True)
if system_pred_file is not None:
CoNLL.write_doc2conll(doc, system_pred_file)
else:
system_pred_file = "{:C}\n\n".format(doc)
system_pred_file = io.StringIO(system_pred_file)
if gold_file is not None:
_, _, score = scorer.score(system_pred_file, gold_file)
logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score*100))
return trainer, doc
if __name__ == '__main__':
main()
================================================
FILE: stanza/models/ner/__init__.py
================================================
================================================
FILE: stanza/models/ner/data.py
================================================
import random
import logging
import torch
from stanza.models.common.bert_embedding import filter_data, needs_length_filter
from stanza.models.common.data import map_to_ids, get_long_tensor, sort_all
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX
from stanza.models.pos.vocab import CharVocab, CompositeVocab, WordVocab
from stanza.models.ner.vocab import MultiVocab
from stanza.models.common.doc import *
from stanza.models.ner.utils import process_tags, normalize_empty_tags
logger = logging.getLogger('stanza')
class DataLoader:
def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=False, preprocess_tags=True, bert_tokenizer=None, scheme=None, max_batch_words=None):
self.max_batch_words = max_batch_words
self.batch_size = batch_size
self.args = args
self.eval = evaluation
self.shuffled = not self.eval
self.doc = doc
self.preprocess_tags = preprocess_tags
data = self._load_doc(self.doc, scheme)
# filter out the long sentences if bert is used
if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
data = filter_data(self.args['bert_model'], data, bert_tokenizer)
self.tags = [[w[1] for w in sent] for sent in data]
# handle vocab
self.pretrain = pretrain
if vocab is None:
self.vocab = self.init_vocab(data)
else:
self.vocab = vocab
# filter and sample data
if args.get('sample_train', 1.0) < 1.0 and not self.eval:
keep = int(args['sample_train'] * len(data))
data = random.sample(data, keep)
logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))
data = self.preprocess(data, self.vocab, args)
# shuffle for training
if self.shuffled:
random.shuffle(data)
self.num_examples = len(data)
# chunk into batches
self.data = self.chunk_batches(data)
logger.debug("{} batches created.".format(len(self.data)))
def init_vocab(self, data):
def from_model(model_filename):
""" Try loading vocab from charLM model file. """
state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True)
if 'vocab' in state_dict:
return state_dict['vocab']
if 'model' in state_dict and 'vocab' in state_dict['model']:
return state_dict['model']['vocab']
raise ValueError("Cannot find vocab in charLM model file %s" % model_filename)
if self.eval:
raise AssertionError("Vocab must exist for evaluation.")
if self.args['charlm']:
charvocab = CharVocab.load_state_dict(from_model(self.args['charlm_forward_file']))
else:
charvocab = CharVocab(data, self.args['shorthand'])
wordvocab = self.pretrain.vocab if self.pretrain is not None else None
tag_data = [[(x[1],) for x in sentence] for sentence in data]
tagvocab = CompositeVocab(tag_data, self.args['shorthand'], idx=0, sep=None)
ignore = None
if self.args['emb_finetune_known_only']:
if self.pretrain is None:
raise ValueError("Cannot train emb_finetune_known_only with no pretrain of known words")
if self.args['lowercase']:
ignore = set([w[0].lower() for sent in data for w in sent if w[0] not in wordvocab and w[0].lower() not in wordvocab])
else:
ignore = set([w[0] for sent in data for w in sent if w[0] not in wordvocab])
logger.debug("Ignoring %d in the delta vocab as they did not appear in the original embedding", len(ignore))
deltavocab = WordVocab(data, self.args['shorthand'], cutoff=1, lower=self.args['lowercase'], ignore=ignore)
logger.debug("Creating delta vocab of size %s", len(deltavocab))
vocabs = {'char': charvocab,
'delta': deltavocab,
'tag': tagvocab}
if wordvocab is not None:
vocabs['word'] = wordvocab
vocab = MultiVocab(vocabs)
return vocab
def preprocess(self, data, vocab, args):
processed = []
if args.get('char_lowercase', False): # handle character case
char_case = lambda x: x.lower()
else:
char_case = lambda x: x
for sent_idx, sent in enumerate(data):
processed_sent = [[w[0] for w in sent]]
processed_sent += [[vocab['char'].map([char_case(x) for x in w[0]]) for w in sent]]
processed_sent += [vocab['tag'].map([w[1] for w in sent])]
processed.append(processed_sent)
return processed
def __len__(self):
return len(self.data)
def __getitem__(self, key):
""" Get a batch with index. """
if not isinstance(key, int):
raise TypeError
if key < 0 or key >= len(self.data):
raise IndexError
batch = self.data[key]
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 3 # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[List[int]]]
# sort sentences by lens for easy RNN operations
sentlens = [len(x) for x in batch[0]]
batch, orig_idx = sort_all(batch, sentlens)
sentlens = [len(x) for x in batch[0]]
# sort chars by lens for easy char-LM operations
chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens = self.process_chars(batch[1])
chars_sorted, char_orig_idx = sort_all([chars_forward, chars_backward, charoffsets_forward, charoffsets_backward], charlens)
chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = chars_sorted
charlens = [len(sent) for sent in chars_forward]
# sort words by lens for easy char-RNN operations
batch_words = [w for sent in batch[1] for w in sent]
wordlens = [len(x) for x in batch_words]
batch_words, word_orig_idx = sort_all([batch_words], wordlens)
batch_words = batch_words[0]
wordlens = [len(x) for x in batch_words]
words = batch[0]
wordchars = get_long_tensor(batch_words, len(wordlens))
wordchars_mask = torch.eq(wordchars, PAD_ID)
chars_forward = get_long_tensor(chars_forward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
chars_backward = get_long_tensor(chars_backward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
chars = torch.cat([chars_forward.unsqueeze(0), chars_backward.unsqueeze(0)]) # padded forward and backward char idx
charoffsets = [charoffsets_forward, charoffsets_backward] # idx for forward and backward lm to get word representation
tags = get_long_tensor(batch[2], batch_size)
return words, wordchars, wordchars_mask, chars, tags, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets
def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)
def _load_doc(self, doc, scheme):
# preferentially load the MULTI_NER in case we are training /
# testing a model with multiple layers of tags
data = doc.get([TEXT, NER, MULTI_NER], as_sentences=True, from_token=True)
data = [[[token[0], token[2]] if token[2] else [token[0], (token[1],)] for token in sentence] for sentence in data]
if self.preprocess_tags: # preprocess tags
if scheme is None:
data = process_tags(data, self.args.get('scheme', 'bio'))
data = normalize_empty_tags(data)
return data
def process_chars(self, sents):
start_id, end_id = self.vocab['char'].unit2id('\n'), self.vocab['char'].unit2id(' ') # special token
start_offset, end_offset = 1, 1
chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = [], [], [], []
# get char representation for each sentence
for sent in sents:
chars_forward_sent, chars_backward_sent, charoffsets_forward_sent, charoffsets_backward_sent = [start_id], [start_id], [], []
# forward lm
for word in sent:
chars_forward_sent += word
charoffsets_forward_sent = charoffsets_forward_sent + [len(chars_forward_sent)] # add each token offset in the last for forward lm
chars_forward_sent += [end_id]
# backward lm
for word in sent[::-1]:
chars_backward_sent += word[::-1]
charoffsets_backward_sent = [len(chars_backward_sent)] + charoffsets_backward_sent # add each offset in the first for backward lm
chars_backward_sent += [end_id]
# store each sentence
chars_forward.append(chars_forward_sent)
chars_backward.append(chars_backward_sent)
charoffsets_forward.append(charoffsets_forward_sent)
charoffsets_backward.append(charoffsets_backward_sent)
charlens = [len(sent) for sent in chars_forward] # forward lm and backward lm should have the same lengths
return chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens
def reshuffle(self):
data = [y for x in self.data for y in x]
random.shuffle(data)
self.data = self.chunk_batches(data)
def chunk_batches(self, data):
if self.max_batch_words is None:
return [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
batches = []
next_batch = []
for item in data:
next_batch.append(item)
if len(next_batch) >= self.batch_size:
batches.append(next_batch)
next_batch = []
if sum(len(x[0]) for x in next_batch) >= self.max_batch_words:
batches.append(next_batch)
next_batch = []
if len(next_batch) > 0:
batches.append(next_batch)
return batches
================================================
FILE: stanza/models/ner/model.py
================================================
import os
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
from stanza.models.common.data import map_to_ids, get_long_tensor
from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
from stanza.models.common.packed_lstm import PackedLSTM
from stanza.models.common.dropout import WordDropout, LockedDropout
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
from stanza.models.common.crf import CRFLoss
from stanza.models.common.foundation_cache import load_bert
from stanza.models.common.utils import attach_bert_model
from stanza.models.common.vocab import PAD_ID, UNK_ID, EMPTY_ID
from stanza.models.common.bert_embedding import extract_bert_embeddings
logger = logging.getLogger('stanza')
# this gets created in two places in trainer
# in both places, pass in the bert model & tokenizer
class NERTagger(nn.Module):
def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
super().__init__()
self.vocab = vocab
self.args = args
self.unsaved_modules = []
# input layers
input_size = 0
if self.args['word_emb_dim'] > 0:
emb_finetune = self.args.get('emb_finetune', True)
if 'word' in self.vocab:
# load pretrained embeddings if specified
word_emb = nn.Embedding(len(self.vocab['word']), self.args['word_emb_dim'], PAD_ID)
# if a model trained with no 'delta' vocab is loaded, and
# emb_finetune is off, any resaving of the model will need
# the updated vectors. this is accounted for in load()
if not emb_finetune or 'delta' in self.vocab:
# if emb_finetune is off
# or if the delta embedding is present
# then we won't fine tune the original embedding
self.add_unsaved_module('word_emb', word_emb)
self.word_emb.weight.detach_()
else:
self.word_emb = word_emb
if emb_matrix is not None:
self.init_emb(emb_matrix)
# TODO: allow for expansion of delta embedding if new
# training data has new words in it?
self.delta_emb = None
if 'delta' in self.vocab:
# zero inits seems to work better
# note that the gradient will flow to the bottom and then adjust the 0 weights
# as opposed to a 0 matrix cutting off the gradient if higher up in the model
self.delta_emb = nn.Embedding(len(self.vocab['delta']), self.args['word_emb_dim'], PAD_ID)
nn.init.zeros_(self.delta_emb.weight)
# if the model was trained with a delta embedding, but emb_finetune is off now,
# then we will detach the delta embedding
if not emb_finetune:
self.delta_emb.weight.detach_()
input_size += self.args['word_emb_dim']
self.peft_name = peft_name
attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
if self.args.get('bert_model', None):
# TODO: refactor bert_hidden_layers between the different models
if args.get('bert_hidden_layers', False):
# The average will be offset by 1/N so that the default zeros
# represents an average of the N layers
self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
nn.init.zeros_(self.bert_layer_mix.weight)
else:
# an average of layers 2, 3, 4 will be used
# (for historic reasons)
self.bert_layer_mix = None
input_size += self.bert_model.config.hidden_size
if self.args['char'] and self.args['char_emb_dim'] > 0:
if self.args['charlm']:
if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
raise ForwardCharlmNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']), args['charlm_forward_file'])
if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
raise BackwardCharlmNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']), args['charlm_backward_file'])
self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))
self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))
input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
else:
self.charmodel = CharacterModel(args, vocab, bidirectional=True, attention=False)
input_size += self.args['char_hidden_dim'] * 2
# optionally add a input transformation layer
if self.args.get('input_transform', False):
self.input_transform = nn.Linear(input_size, input_size)
else:
self.input_transform = None
# recurrent layers
self.taggerlstm = PackedLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, \
bidirectional=True, dropout=0 if self.args['num_layers'] == 1 else self.args['dropout'])
# self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
self.drop_replacement = None
self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)
self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)
# tag classifier
tag_lengths = self.vocab['tag'].lens()
self.num_output_layers = len(tag_lengths)
if self.args.get('connect_output_layers'):
tag_clfs = [nn.Linear(self.args['hidden_dim']*2, tag_lengths[0])]
for prev_length, next_length in zip(tag_lengths[:-1], tag_lengths[1:]):
tag_clfs.append(nn.Linear(self.args['hidden_dim']*2 + prev_length, next_length))
self.tag_clfs = nn.ModuleList(tag_clfs)
else:
self.tag_clfs = nn.ModuleList([nn.Linear(self.args['hidden_dim']*2, num_tag) for num_tag in tag_lengths])
for tag_clf in self.tag_clfs:
tag_clf.bias.data.zero_()
self.crits = nn.ModuleList([CRFLoss(num_tag) for num_tag in tag_lengths])
self.drop = nn.Dropout(args['dropout'])
self.worddrop = WordDropout(args['word_dropout'])
self.lockeddrop = LockedDropout(args['locked_dropout'])
def init_emb(self, emb_matrix):
if isinstance(emb_matrix, np.ndarray):
emb_matrix = torch.from_numpy(emb_matrix)
vocab_size = len(self.vocab['word'])
dim = self.args['word_emb_dim']
assert emb_matrix.size() == (vocab_size, dim), \
"Input embedding matrix must match size: {} x {}, found {}".format(vocab_size, dim, emb_matrix.size())
self.word_emb.weight.data.copy_(emb_matrix)
def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)
def log_norms(self):
lines = ["NORMS FOR MODEL PARAMTERS"]
for name, param in self.named_parameters():
if param.requires_grad and name.split(".")[0] not in ('charmodel_forward', 'charmodel_backward'):
lines.append(" %s %.6g" % (name, torch.norm(param).item()))
logger.info("\n".join(lines))
def forward(self, sentences, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx):
device = next(self.parameters()).device
def pack(x):
return pack_padded_sequence(x, sentlens, batch_first=True)
inputs = []
batch_size = len(sentences)
has_embedding = False
if self.args['word_emb_dim'] > 0:
#extract static embeddings
if 'word' in self.vocab:
static_words, word_mask = self.extract_static_embeddings(self.args, sentences, self.vocab['word'])
word_mask = word_mask.to(device)
static_words = static_words.to(device)
word_static_emb = self.word_emb(static_words)
has_embedding = True
if 'delta' in self.vocab and self.delta_emb is not None:
# masks should be the same
delta_words, delta_mask = self.extract_static_embeddings(self.args, sentences, self.vocab['delta'])
delta_words = delta_words.to(device)
# unclear whether to treat words in the main embedding
# but not in delta as unknown
# simple heuristic though - treating them as not
# unknown keeps existing models the same when
# separating models into the base WV and delta WV
# also, note that at training time, words like this
# did not show up in the training data, but are
# not exactly UNK, so it makes sense
if has_embedding:
delta_unk_mask = torch.eq(delta_words, UNK_ID)
static_unk_mask = torch.not_equal(static_words, UNK_ID)
unk_mask = delta_unk_mask * static_unk_mask
delta_words[unk_mask] = PAD_ID
else:
word_mask = delta_mask.to(device)
delta_emb = self.delta_emb(delta_words)
if has_embedding:
word_static_emb = word_static_emb + delta_emb
else:
has_embedding = True
word_static_emb = delta_emb
if has_embedding:
word_emb = pack(word_static_emb)
inputs += [word_emb]
if self.bert_model is not None:
device = next(self.parameters()).device
processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, sentences, device, keep_endpoints=False,
num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
detach=not self.args.get('bert_finetune', False),
peft_name=self.peft_name)
if self.bert_layer_mix is not None:
# use a linear layer to weighted average the embedding dynamically
processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]
processed_bert = pad_sequence(processed_bert, batch_first=True)
inputs += [pack(processed_bert)]
if self.args['char'] and self.args['char_emb_dim'] > 0:
if self.args.get('charlm', None):
char_reps_forward = self.charmodel_forward.get_representation(chars[0], charoffsets[0], charlens, char_orig_idx)
char_reps_forward = PackedSequence(char_reps_forward.data, char_reps_forward.batch_sizes)
char_reps_backward = self.charmodel_backward.get_representation(chars[1], charoffsets[1], charlens, char_orig_idx)
char_reps_backward = PackedSequence(char_reps_backward.data, char_reps_backward.batch_sizes)
inputs += [char_reps_forward, char_reps_backward]
else:
char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
char_reps = PackedSequence(char_reps.data, char_reps.batch_sizes)
inputs += [char_reps]
batch_sizes = inputs[0].batch_sizes
def pad(x):
return pad_packed_sequence(PackedSequence(x, batch_sizes), batch_first=True)[0]
lstm_inputs = torch.cat([x.data for x in inputs], 1)
if self.args['word_dropout'] > 0:
lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
lstm_inputs = self.drop(lstm_inputs)
lstm_inputs = pad(lstm_inputs)
lstm_inputs = self.lockeddrop(lstm_inputs)
lstm_inputs = pack(lstm_inputs).data
if self.input_transform:
lstm_inputs = self.input_transform(lstm_inputs)
lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)
lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(\
self.taggerlstm_h_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous(), \
self.taggerlstm_c_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous()))
lstm_outputs = lstm_outputs.data
# prediction layer
lstm_outputs = self.drop(lstm_outputs)
lstm_outputs = pad(lstm_outputs)
lstm_outputs = self.lockeddrop(lstm_outputs)
lstm_outputs = pack(lstm_outputs).data
loss = 0
logits = []
trans = []
for idx, (tag_clf, crit) in enumerate(zip(self.tag_clfs, self.crits)):
if not self.args.get('connect_output_layers') or idx == 0:
next_logits = pad(tag_clf(lstm_outputs)).contiguous()
else:
# here we pack the output of the previous round, then append it
packed_logits = pack(next_logits).data
input_logits = torch.cat([lstm_outputs, packed_logits], axis=1)
next_logits = pad(tag_clf(input_logits)).contiguous()
# the tag_mask lets us avoid backprop on a blank tag
tag_mask = torch.eq(tags[:, :, idx], EMPTY_ID)
if has_embedding:
tag_mask = torch.bitwise_or(tag_mask, word_mask)
else:
tag_mask = torch.bitwise_or(tag_mask, torch.eq(tags[:, :, idx], PAD_ID))
next_loss, next_trans = crit(next_logits, tag_mask, tags[:, :, idx])
loss = loss + next_loss
logits.append(next_logits)
trans.append(next_trans)
return loss, logits, trans
@staticmethod
def extract_static_embeddings(args, sents, vocab):
processed = []
if args.get('lowercase', True): # handle word case
case = lambda x: x.lower()
else:
case = lambda x: x
for idx, sent in enumerate(sents):
processed_sent = [vocab.map([case(w) for w in sent])]
processed.append(processed_sent[0])
words = get_long_tensor(processed, len(sents))
words_mask = torch.eq(words, PAD_ID)
return words, words_mask
================================================
FILE: stanza/models/ner/scorer.py
================================================
"""
An NER scorer that calculates F1 score given gold and predicted tags.
"""
import sys
import os
import logging
from collections import Counter, defaultdict
from stanza.models.ner.utils import decode_from_bioes
logger = logging.getLogger('stanza')
def score_by_entity(pred_tag_sequences, gold_tag_sequences, verbose=True, ignore_tags=None):
""" Score predicted tags at the entity level.
Args:
pred_tags_sequences: a list of list of predicted tags for each word
gold_tags_sequences: a list of list of gold tags for each word
verbose: print log with results
ignore_tags: a list or a string with a comma-separated list of tags to ignore
Returns:
Precision, recall and F1 scores.
"""
assert(len(gold_tag_sequences) == len(pred_tag_sequences)), \
"Number of predicted tag sequences does not match gold sequences."
def decode_all(tag_sequences):
# decode from all sequences, each sequence with a unique id
ents = []
for sent_id, tags in enumerate(tag_sequences):
for ent in decode_from_bioes(tags):
ent['sent_id'] = sent_id
ents += [ent]
return ents
ignore_tag_set = set()
if ignore_tags:
if isinstance(ignore_tags, str):
ignore_tag_set.update(ignore_tags.split(","))
else:
ignore_tag_set.update(ignore_tags)
gold_ents = decode_all(gold_tag_sequences)
gold_ents = [x for x in gold_ents if x['type'] not in ignore_tag_set]
pred_ents = decode_all(pred_tag_sequences)
pred_ents = [x for x in pred_ents if x['type'] not in ignore_tag_set]
# scoring
true_positive_by_type = Counter()
false_positive_by_type = Counter()
false_negative_by_type = Counter()
guessed_by_type = Counter()
gold_by_type = Counter()
for p in pred_ents:
guessed_by_type[p['type']] += 1
if p in gold_ents:
true_positive_by_type[p['type']] += 1
else:
false_positive_by_type[p['type']] += 1
for g in gold_ents:
gold_by_type[g['type']] += 1
if g not in pred_ents:
false_negative_by_type[g['type']] += 1
entities = sorted(set(list(true_positive_by_type.keys()) + list(false_positive_by_type.keys()) + list(false_negative_by_type.keys())))
entity_f1 = {}
for entity in entities:
entity_f1[entity] = 2 * true_positive_by_type[entity] / (2 * true_positive_by_type[entity] + false_positive_by_type[entity] + false_negative_by_type[entity])
prec_micro = 0.0
if sum(guessed_by_type.values()) > 0:
prec_micro = sum(true_positive_by_type.values()) * 1.0 / sum(guessed_by_type.values())
rec_micro = 0.0
if sum(gold_by_type.values()) > 0:
rec_micro = sum(true_positive_by_type.values()) * 1.0 / sum(gold_by_type.values())
f_micro = 0.0
if prec_micro + rec_micro > 0:
f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro)
if verbose:
logger.info("Score by entity:\nPrec.\tRec.\tF1\n{:.2f}\t{:.2f}\t{:.2f}".format(
prec_micro*100, rec_micro*100, f_micro*100))
return prec_micro, rec_micro, f_micro, entity_f1
def score_by_token(pred_tag_sequences, gold_tag_sequences, verbose=True, ignore_tags=None):
""" Score predicted tags at the token level.
Args:
pred_tags_sequences: a list of list of predicted tags for each word
gold_tags_sequences: a list of list of gold tags for each word
verbose: print log with results
ignore_tags: a list or a string with a comma-separated list of tags to ignore
Returns:
Precision, recall and F1 scores, along with a confusion matrix
"""
assert(len(gold_tag_sequences) == len(pred_tag_sequences)), \
"Number of predicted tag sequences does not match gold sequences."
ignore_tag_set = set()
if ignore_tags:
if isinstance(ignore_tags, str):
ignore_tag_set.update(ignore_tags.split(","))
else:
ignore_tag_set.update(ignore_tags)
def ignore_tag(tag):
if tag == 'O':
return True
if len(tag) > 2 and (tag[1] == '_' or tag[1] == '-'):
tag = tag[2:]
if tag in ignore_tag_set:
return True
return False
correct_by_tag = Counter()
guessed_by_tag = Counter()
gold_by_tag = Counter()
confusion = defaultdict(lambda: defaultdict(int))
for gold_tags, pred_tags in zip(gold_tag_sequences, pred_tag_sequences):
assert(len(gold_tags) == len(pred_tags)), \
"Number of predicted tags does not match gold."
for g, p in zip(gold_tags, pred_tags):
if ignore_tag(g):
g = 'O'
if ignore_tag(p):
p = 'O'
confusion[g][p] = confusion[g][p] + 1
if g == 'O' and p == 'O':
continue
elif g == 'O' and p != 'O':
guessed_by_tag[p] += 1
elif g != 'O' and p == 'O':
gold_by_tag[g] += 1
else:
guessed_by_tag[p] += 1
gold_by_tag[p] += 1
if g == p:
correct_by_tag[p] += 1
prec_micro = 0.0
if sum(guessed_by_tag.values()) > 0:
prec_micro = sum(correct_by_tag.values()) * 1.0 / sum(guessed_by_tag.values())
rec_micro = 0.0
if sum(gold_by_tag.values()) > 0:
rec_micro = sum(correct_by_tag.values()) * 1.0 / sum(gold_by_tag.values())
f_micro = 0.0
if prec_micro + rec_micro > 0:
f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro)
if verbose:
logger.info("Score by token:\nPrec.\tRec.\tF1\n{:.2f}\t{:.2f}\t{:.2f}".format(
prec_micro*100, rec_micro*100, f_micro*100))
return prec_micro, rec_micro, f_micro, confusion
def test():
pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'],
['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']]
gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'],
['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']]
print(score_by_token(pred_sequences, gold_sequences))
print(score_by_entity(pred_sequences, gold_sequences))
if __name__ == '__main__':
test()
================================================
FILE: stanza/models/ner/trainer.py
================================================
"""
A trainer class to handle training and testing of models.
"""
import sys
import logging
import torch
from torch import nn
from stanza.models.common.foundation_cache import NoTransformerFoundationCache, load_bert, load_bert_with_peft
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE
from stanza.models.common import utils, loss
from stanza.models.ner.model import NERTagger
from stanza.models.ner.vocab import MultiVocab
from stanza.models.common.crf import viterbi_decode
logger = logging.getLogger('stanza')
def unpack_batch(batch, device):
""" Unpack a batch from the data loader. """
inputs = [batch[0]]
inputs += [b.to(device) if b is not None else None for b in batch[1:5]]
orig_idx = batch[5]
word_orig_idx = batch[6]
char_orig_idx = batch[7]
sentlens = batch[8]
wordlens = batch[9]
charlens = batch[10]
charoffsets = batch[11]
return inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets
def fix_singleton_tags(tags):
"""
If there are any singleton B- or E- tags, convert them to S-
"""
new_tags = list(tags)
# first update all I- tags at the start or end of sequence to B- or E- as appropriate
for idx, tag in enumerate(new_tags):
if (tag.startswith("I-") and
(idx == len(new_tags) - 1 or
(new_tags[idx+1] != "I-" + tag[2:] and new_tags[idx+1] != "E-" + tag[2:]))):
new_tags[idx] = "E-" + tag[2:]
if (tag.startswith("I-") and
(idx == 0 or
(new_tags[idx-1] != "B-" + tag[2:] and new_tags[idx-1] != "I-" + tag[2:]))):
new_tags[idx] = "B-" + tag[2:]
# now make another pass through the data to update any singleton tags,
# including ones which were turned into singletons by the previous operation
for idx, tag in enumerate(new_tags):
if (tag.startswith("B-") and
(idx == len(new_tags) - 1 or
(new_tags[idx+1] != "I-" + tag[2:] and new_tags[idx+1] != "E-" + tag[2:]))):
new_tags[idx] = "S-" + tag[2:]
if (tag.startswith("E-") and
(idx == 0 or
(new_tags[idx-1] != "B-" + tag[2:] and new_tags[idx-1] != "I-" + tag[2:]))):
new_tags[idx] = "S-" + tag[2:]
return new_tags
class Trainer(BaseTrainer):
""" A trainer for training models. """
def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None,
train_classifier_only=False, foundation_cache=None, second_optim=False):
if model_file is not None:
# load everything from file
self.load(model_file, pretrain, args, foundation_cache)
else:
assert args is not None
assert vocab is not None
# build model from scratch
self.args = args
self.vocab = vocab
bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
peft_name = None
if self.args['use_peft']:
# fine tune the bert if we're using peft
self.args['bert_finetune'] = True
peft_name = "ner"
# peft the lovely model
bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)
emb_matrix=None
if pretrain is not None:
emb_matrix = pretrain.emb
self.model = NERTagger(args, vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)
# IMPORTANT: gradient checkpointing BREAKS peft if applied before
# 1. Apply PEFT FIRST (looksie! it's above this line)
# 2. Run gradient checkpointing
# https://github.com/huggingface/peft/issues/742
if self.args.get("gradient_checkpointing", False) and self.args.get("bert_finetune", False):
self.model.bert_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
# if this wasn't set anywhere, we use a default of the 0th tagset
# we don't set this as a default in the options so that
# we can distinguish "intentionally set to 0" and "not set at all"
if self.args.get('predict_tagset', None) is None:
self.args['predict_tagset'] = 0
if train_classifier_only:
logger.info('Disabling gradient for non-classifier layers')
exclude = ['tag_clf', 'crit']
for pname, p in self.model.named_parameters():
if pname.split('.')[0] not in exclude:
p.requires_grad = False
self.model = self.model.to(device)
if not second_optim:
self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("use_peft"))
else:
self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model, self.args['second_lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0), is_peft=self.args.get("use_peft"))
def update(self, batch, eval=False):
    """
    Compute the loss on one batch and, unless ``eval`` is set, take one optimizer step.

    Args:
        batch: a batch from the NER data loader; it is split apart by unpack_batch
        eval: if True, the model is put in eval mode and no backward pass / step happens

    Returns:
        the loss of the batch as a python float
    """
    device = next(self.model.parameters()).device
    inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device)
    word, wordchars, wordchars_mask, chars, tags = inputs

    training = not eval
    if training:
        self.model.train()
        self.optimizer.zero_grad()
    else:
        self.model.eval()

    loss, _, _ = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx)
    batch_loss = loss.data.item()
    if not training:
        return batch_loss

    loss.backward()
    # clip before stepping so a single bad batch cannot blow up the weights
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
    self.optimizer.step()
    return batch_loss
def predict(self, batch, unsort=True):
    """
    Tag one batch of sentences.

    Every tag column produced by the model is Viterbi-decoded separately.
    If the tag vocab unmaps a word to a list of tags, the entry chosen by
    ``self.args['predict_tagset']`` is kept.

    Args:
        batch: a batch from the NER data loader; it is split apart by unpack_batch
        unsort: if True, put the sentences back into their original order

    Returns:
        a list with one list of tags per sentence
    """
    device = next(self.model.parameters()).device
    inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device)
    word, wordchars, wordchars_mask, chars, tags = inputs

    self.model.eval()
    _, logits, trans = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx)

    # decode on the cpu with numpy
    # TODO: might need to decode multiple columns of output for
    # models with multiple layers
    transitions = [column.data.cpu().numpy() for column in trans]
    scores = [column.data.cpu().numpy() for column in logits]

    batch_size = scores[0].shape[0]
    if any(column.shape[0] != batch_size for column in scores):
        raise AssertionError("Expected all of the logits to have the same size")

    predict_tagset = self.args['predict_tagset']
    tag_seqs = []
    for sent_idx in range(batch_size):
        sent_len = sentlens[sent_idx]
        column_ids = []
        for column_scores, column_trans in zip(scores, transitions):
            best_path = viterbi_decode(column_scores[sent_idx, :sent_len], column_trans)[0]
            # TODO: this is to patch that the model can sometimes predict < "O"
            column_ids.append([tag_id if tag_id >= VOCAB_PREFIX_SIZE else VOCAB_PREFIX_SIZE for tag_id in best_path])
        # N lists of |sent| ids become |sent| tuples of N ids
        per_word_ids = list(zip(*column_ids))
        sent_tags = self.vocab['tag'].unmap(per_word_ids)
        # for now, allow either TagVocab or CompositeVocab
        # TODO: we might want to return all of the predictions
        # rather than a single column
        sent_tags = [word_tags[predict_tagset] if isinstance(word_tags, list) else word_tags for word_tags in sent_tags]
        tag_seqs.append(fix_singleton_tags(sent_tags))

    if unsort:
        tag_seqs = utils.unsort(tag_seqs, orig_idx)
    return tag_seqs
def save(self, filename, skip_modules=True):
    """
    Save the model weights, vocab and config to ``filename``.

    Args:
        filename: where to write the checkpoint (via torch.save)
        skip_modules: if True, drop the parameters of the model's
            ``unsaved_modules`` (such as the pretrained embeddings), which are
            large and saved in a separate file

    A failure while writing is logged (with traceback) and otherwise ignored,
    so a training run is not aborted by a failed save.  KeyboardInterrupt and
    SystemExit still propagate.
    """
    model_state = self.model.state_dict()
    # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
    if skip_modules:
        skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
        for k in skipped:
            del model_state[k]
    params = {
        'model': model_state,
        'vocab': self.vocab.state_dict(),
        'config': self.args
    }
    # .get: configs loaded from older checkpoints may not contain use_peft at all
    if self.args.get("use_peft", False):
        from peft import get_peft_model_state_dict
        params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)
    try:
        torch.save(params, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to {}".format(filename))
    except Exception:
        # best-effort save: report why it failed, but keep going
        logger.warning("Saving failed... continuing anyway.", exc_info=True)
def load(self, filename, pretrain=None, args=None, foundation_cache=None):
    """
    Load a checkpoint written by ``save`` into this Trainer.

    Args:
        filename: path of the checkpoint
        pretrain: optional pretrained embeddings; ``pretrain.emb`` is handed to the model
        args: optional dict of settings overriding the saved config.  For
            predict_tagset, train_scheme and scheme, a value of None (or a
            missing key) means "use the value saved with the model".
        foundation_cache: optional cache used when loading the transformer

    Raises:
        whatever torch.load raises if the file cannot be read (logged first)
    """
    try:
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
    except BaseException:
        logger.error("Cannot load model from {}".format(filename))
        raise
    # Work on a copy: updating the saved config in place would make the
    # fallback below read the overridden value instead of the trained one.
    saved_config = checkpoint['config']
    self.args = dict(saved_config)
    if args:
        self.args.update(args)

    # if predict_tagset was not explicitly set in the args,
    # we use the value the model was trained with
    for keep_arg in ('predict_tagset', 'train_scheme', 'scheme'):
        if self.args.get(keep_arg, None) is None:
            self.args[keep_arg] = saved_config.get(keep_arg, None)

    lora_weights = checkpoint.get('bert_lora')
    if lora_weights:
        logger.debug("Found peft weights for NER; loading a peft adapter")
        self.args["use_peft"] = True

    self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])

    emb_matrix = None
    if pretrain is not None:
        emb_matrix = pretrain.emb

    force_bert_saved = False
    peft_name = None
    if self.args.get('use_peft', False):
        force_bert_saved = True
        bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "ner", foundation_cache)
        bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
        logger.debug("Loaded peft with name %s", peft_name)
    else:
        if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
            logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
            foundation_cache = NoTransformerFoundationCache(foundation_cache)
            force_bert_saved = True
        bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)

    # checkpoints with a single "crit"/"tag_clf" are renamed to the
    # per-column names used by the current model
    if any(x.startswith("crit.") for x in checkpoint['model'].keys()):
        logger.debug("Old model format detected. Updating to the new format with one column of tags")
        checkpoint['model']['crits.0._transitions'] = checkpoint['model'].pop('crit._transitions')
        checkpoint['model']['tag_clfs.0.weight'] = checkpoint['model'].pop('tag_clf.weight')
        checkpoint['model']['tag_clfs.0.bias'] = checkpoint['model'].pop('tag_clf.bias')

    self.model = NERTagger(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
    # strict=False: modules skipped at save time are not in the checkpoint
    self.model.load_state_dict(checkpoint['model'], strict=False)

    # there is a possible issue with the delta embeddings.
    # specifically, with older models trained without the delta
    # embedding matrix
    # if those models have been trained with the embedding
    # modifications saved as part of the base embedding,
    # we need to resave the model with the updated embedding
    # otherwise the resulting model will be broken
    if 'delta' not in self.model.vocab and 'word_emb.weight' in checkpoint['model'].keys() and 'word_emb' in self.model.unsaved_modules:
        logger.debug("Removing word_emb from unsaved_modules so that resaving %s will keep the saved embedding", filename)
        self.model.unsaved_modules.remove('word_emb')
def get_known_tags(self):
    """
    List the entity types this model can produce, sorted.

    Only tag column 0 of the tag vocab is examined.  Vocab prefix entries
    and 'O' are left out, and an S-/B-/I-/E- prefix is stripped.
    """
    entity_types = set()
    for tag in self.vocab['tag'].items(0):
        if tag in VOCAB_PREFIX or tag == 'O':
            continue
        has_position_prefix = len(tag) > 2 and tag[:2] in ('S-', 'B-', 'I-', 'E-')
        entity_types.add(tag[2:] if has_position_prefix else tag)
    return sorted(entity_types)
================================================
FILE: stanza/models/ner/utils.py
================================================
"""
Utility functions for dealing with NER tagging.
"""
import logging
from stanza.models.common.vocab import EMPTY
logger = logging.getLogger('stanza')

# values that mean "this word has no tag" in the input data
EMPTY_TAG = ('_', '-', '', None)
# the empty markers plus the outside-of-entity tag
EMPTY_OR_O_TAG = EMPTY_TAG + ('O',)
def is_basic_scheme(all_tags):
    """
    Check if a basic tagging scheme is used. Return True if so.

    Args:
        all_tags: a list of NER tags.  None entries (missing tags, which
            process_tags allows) carry no scheme information and are skipped.

    Returns:
        True if the tagging scheme does not use B-, I-, etc, otherwise False
    """
    for tag in all_tags:
        # len(None) would raise, so missing tags must be skipped explicitly
        if tag is None:
            continue
        if len(tag) > 2 and tag[:2] in ('B-', 'I-', 'S-', 'E-', 'B_', 'I_', 'S_', 'E_'):
            return False
    return True
def is_bio_scheme(all_tags):
    """
    Report whether the given tags follow the BIO tagging scheme.

    Empty tags and 'O' are ignored.  Every other tag must be longer than two
    characters and start with B-, I-, B_ or I_.

    Args:
        all_tags: a list of NER tags

    Returns:
        True if the tagging scheme is BIO, otherwise False
    """
    bio_prefixes = ('B-', 'I-', 'B_', 'I_')
    return all(tag in EMPTY_OR_O_TAG or (len(tag) > 2 and tag[:2] in bio_prefixes)
               for tag in all_tags)
def to_bio2(tags):
    """
    Convert the original tag sequence to BIO2 format. If the input is already in BIO2 format,
    the original input is returned.

    An I- tag is turned into a B- tag when it starts the sequence or when the
    previous tag is missing (None), 'O', or an entity of a different type.

    Args:
        tags: a list of tags in either BIO or BIO2 format

    Returns:
        new_tags: a list of tags in BIO2 format
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag in EMPTY_OR_O_TAG:
            new_tags.append(tag)
        elif tag[0] == 'I':
            prev_tag = tags[i-1] if i > 0 else None
            # None is an allowed empty tag; slicing it would raise TypeError
            if prev_tag is None or prev_tag == 'O' or prev_tag[1:] != tag[1:]:
                new_tags.append('B' + tag[1:])
            else:
                new_tags.append(tag)
        else:
            new_tags.append(tag)
    return new_tags
def basic_to_bio(tags):
    """
    Turn a basic tag sequence (bare entity types, no B-/I-) into BIO.

    A run of identical consecutive tags becomes B-X followed by I-X.  Empty
    tags and 'O' pass through untouched.  Compose with bio2_to_bioes to get BIOES.

    Args:
        tags: a list of tags in basic format

    Returns:
        new_tags: a list of tags in BIO format
    """
    converted = []
    previous = None  # the raw tag of the previous word; None before the first word
    for tag in tags:
        if tag in EMPTY_OR_O_TAG:
            converted.append(tag)
        elif previous == tag:
            converted.append('I-' + tag)
        else:
            converted.append('B-' + tag)
        previous = tag
    return converted
def bio2_to_bioes(tags):
    """
    Convert the BIO2 tag sequence into a BIOES sequence.

    An I- tag becomes E- and a B- tag becomes S- unless the next tag is an
    I- tag.  Underscore variants (B_, I_) are accepted and rewritten with a dash.

    Args:
        tags: a list of tags in BIO2 format.  Empty tags (including None)
            are passed through.

    Returns:
        new_tags: a list of tags in BIOES format

    Raises:
        Exception: if a tag is too short or has a prefix other than B/I
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag in EMPTY_OR_O_TAG:
            new_tags.append(tag)
            continue
        if len(tag) < 2:
            raise Exception(f"Invalid BIO2 tag found: {tag}")
        prefix, label = tag[:2], tag[2:]
        if prefix not in ('I-', 'I_', 'B-', 'B_'):
            raise Exception(f"Invalid IOB tag found: {tag}")
        next_tag = tags[i+1] if i+1 < len(tags) else None
        # the next tag may be None (an empty tag); slicing it would raise TypeError
        continues = next_tag is not None and next_tag[:2] in ('I-', 'I_')
        if prefix in ('I-', 'I_'):
            # compensate for underscores by always emitting a dash
            new_tags.append(('I-' if continues else 'E-') + label)
        else:
            new_tags.append(('B-' if continues else 'S-') + label)
    return new_tags
def normalize_empty_tags(sentences):
    """
    Replace every empty tag (None, _, -, or blank) with EMPTY.

    The input is a list(sentence) of list(word) of tuple(text, list(tag)),
    which is the shape data.py has while preprocessing the tags.  Each word
    comes back as (text, tuple(tags)).
    """
    normalized = []
    for sentence in sentences:
        new_sentence = []
        for word in sentence:
            word_tags = tuple(EMPTY if tag in EMPTY_TAG else tag for tag in word[1])
            new_sentence.append((word[0], word_tags))
        normalized.append(new_sentence)
    return normalized
def process_tags(sentences, scheme):
    """
    Convert tags in these sentences to bioes

    We allow empty tags ('_', '-', None), which will represent tags
    that do not get any gradient when training

    Args:
        sentences: list of sentences, each a list of (word, tag) pairs where
            tag is either a single tag (str or None) or a tuple/list of tags,
            one per column.  Every sentence must use the same form.
        scheme: target scheme; columns are converted only when this is 'bioes'
            (case-insensitive)

    Returns:
        the sentences with converted tags, in the same form as the input
        (single tags or tuples of tags).  Empty input gives an empty list.

    Raises:
        ValueError: if single tags and lists of tags are mixed, or the
            number of tag columns is not uniform
    """
    all_words = []
    all_tags = []
    converted_tuples = False
    for sent_idx, sent in enumerate(sentences):
        words, tags = zip(*sent)
        all_words.append(words)
        # if we got one dimension tags w/o tuples or lists, make them tuples
        # but we also check that the format is consistent,
        # as otherwise the result being converted might be confusing
        if not converted_tuples and any(tag is None or isinstance(tag, str) for tag in tags):
            if sent_idx > 0:
                raise ValueError("Got a mix of tags and lists of tags. First non-list was in sentence %d" % sent_idx)
            converted_tuples = True
        if converted_tuples:
            if not all(tag is None or isinstance(tag, str) for tag in tags):
                raise ValueError("Got a mix of tags and lists of tags. First tag as a list was in sentence %d" % sent_idx)
            tags = [(tag,) for tag in tags]
        all_tags.append(tags)

    # nothing to convert; this also avoids calling max() on an empty sequence
    if not all_tags:
        return []

    max_columns = max(len(x) for tags in all_tags for x in tags)
    for sent_idx, tags in enumerate(all_tags):
        if any(len(x) < max_columns for x in tags):
            raise ValueError("NER tags not uniform in length at sentence %d. TODO: extend those columns with O" % sent_idx)

    all_convert_bio_to_bioes = []
    all_convert_basic_to_bioes = []
    for column_idx in range(max_columns):
        # check if tag conversion is needed for each column
        # we treat each column separately, although practically
        # speaking it would be pretty weird for a dataset to have BIO
        # in one column and basic in another, for example
        convert_bio_to_bioes = False
        convert_basic_to_bioes = False
        tag_column = [x[column_idx] for sent in all_tags for x in sent]
        is_bio = is_bio_scheme(tag_column)
        is_basic = not is_bio and is_basic_scheme(tag_column)
        if is_bio and scheme.lower() == 'bioes':
            convert_bio_to_bioes = True
            logger.debug("BIO tagging scheme found in input at column %d; converting into BIOES scheme...", column_idx)
        elif is_basic and scheme.lower() == 'bioes':
            convert_basic_to_bioes = True
            logger.debug("Basic tagging scheme found in input at column %d; converting into BIOES scheme...", column_idx)
        all_convert_bio_to_bioes.append(convert_bio_to_bioes)
        all_convert_basic_to_bioes.append(convert_basic_to_bioes)

    result = []
    for words, tags in zip(all_words, all_tags):
        # TODO: add a convert_basic_to_bio option as well
        # process tags
        # tags is a list of each column of tags for each word in this sentence
        # copy the tags to a list so we can edit them
        tags = [list(word_tags) for word_tags in tags]
        for column_idx, (convert_bio_to_bioes, convert_basic_to_bioes) in enumerate(zip(all_convert_bio_to_bioes, all_convert_basic_to_bioes)):
            tag_column = [x[column_idx] for x in tags]
            if convert_basic_to_bioes:
                # if basic, convert tags -> bio -> bioes
                tag_column = bio2_to_bioes(basic_to_bio(tag_column))
            else:
                # first ensure BIO2 scheme
                tag_column = to_bio2(tag_column)
                # then convert to BIOES
                if convert_bio_to_bioes:
                    tag_column = bio2_to_bioes(tag_column)
            for tag_idx, tag in enumerate(tag_column):
                tags[tag_idx][column_idx] = tag
        result.append([(w, tuple(t)) for w, t in zip(words, tags)])

    # undo the tuple wrapping for inputs that had single tags
    if converted_tuples:
        result = [[(word[0], word[1][0]) for word in sentence] for sentence in result]
    return result
def decode_from_bioes(tags):
    """
    Extract entity spans from a BIOES tag sequence; None is read as 'O'.

    B- and S- close any open entity and start a new one (S- closes it again
    immediately).  I- and E- extend the open entity, or open one if none is
    open, and set its type (E- then closes it).  'O' closes the open entity.
    Tags with any other prefix are ignored.

    Args:
        tags: a list of BIOES tags

    Returns:
        A list of dicts with 'start', 'end' (inclusive word indices) and 'type'.
    """
    entities = []
    start = None      # index of the first word of the open entity, None if nothing is open
    end = None        # index of the last word of the open entity
    ent_type = None

    def close_span():
        nonlocal start
        if start is not None:
            entities.append({'start': start, 'end': end, 'type': ent_type})
        start = None

    for position, tag in enumerate(tags):
        if tag is None:
            tag = 'O'
        if tag == 'O':
            close_span()
            continue
        prefix, label = tag[:2], tag[2:]
        if prefix in ('B-', 'S-'):
            close_span()
            start = end = position
            ent_type = label
            if prefix == 'S-':
                close_span()
        elif prefix in ('I-', 'E-'):
            if start is None:
                start = position
            end = position
            ent_type = label
            if prefix == 'E-':
                close_span()

    # an entity may still be open at the end of the sentence
    close_span()
    return entities
def merge_tags(*sequences):
    """
    Merge several BIOES tag sequences into one.

    The first sequence is the base.  Entities from each later sequence are
    copied in only where the merged result is still entirely 'O' over the
    entity's span, so earlier sequences take precedence.

    Raises:
        ValueError: if a later sequence has an I-/E- tag outside a B-...E-
            block, a block mixing entity types, or an unclosed block
    """
    merged = list(sequences[0])
    for extra in sequences[1:]:
        length = len(extra)
        position = 0
        while position < length:
            current = extra[position]
            if current == 'O':
                position += 1
            elif current.startswith("S-"):
                # singletons are copied only onto an 'O'
                if merged[position] == 'O':
                    merged[position] = current
                position += 1
            elif current.startswith("B-"):
                # find the matching E- of this B-...E- block
                label = current[2:]
                stop = position + 1
                while True:
                    if stop == length:
                        raise ValueError("Got a sequence with an unclosed tag: {}".format(extra))
                    candidate = extra[stop]
                    if candidate[2:] != label:
                        raise ValueError("Unexpected tag sequence at idx {}: {}".format(stop, extra))
                    if candidate.startswith("E-"):
                        break
                    if not candidate.startswith("I-"):
                        raise ValueError("Unexpected tag sequence at idx {}: {}".format(stop, extra))
                    stop += 1
                stop += 1
                if all(existing == 'O' for existing in merged[position:stop]):
                    merged[position:stop] = extra[position:stop]
                position = stop
            else:
                raise ValueError("Got unexpected tag sequence at idx {}: {}".format(position, extra))
    return merged
================================================
FILE: stanza/models/ner/vocab.py
================================================
from collections import Counter, OrderedDict
from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab, CompositeVocab
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.models.common.pretrain import PretrainedWordVocab
from stanza.models.pos.vocab import WordVocab
class TagVocab(BaseVocab):
    """ A vocab for the output tag sequence, ordered from most to least frequent tag. """
    def build_vocab(self):
        # count the tag stored in field ``self.idx`` of every word
        tag_counts = Counter(word[self.idx] for sentence in self.data for word in sentence)
        # most_common() orders by descending count and keeps first-seen order for ties,
        # the same as a stable reverse sort on the counts
        by_frequency = [tag for tag, _ in tag_counts.most_common()]
        self._id2unit = VOCAB_PREFIX + by_frequency
        self._unit2id = {unit: position for position, unit in enumerate(self._id2unit)}
def convert_tag_vocab(state_dict):
    """
    Rebuild a saved single-column TagVocab state dict as a CompositeVocab.

    Raises:
        AssertionError: if the saved vocab has 'lower' set, or if the rebuilt
            vocab does not reproduce the saved units in the same order
    """
    if state_dict['lower']:
        raise AssertionError("Did not expect an NER vocab with 'lower' set to True")

    original_units = state_dict['_id2unit']
    # this looks silly, but the vocab builder treats this as words with multiple fields:
    # each tag becomes a one-word sentence whose field 0 (idx=0) is a one-item list of labels
    data = [[[[unit]]] for unit in original_units[len(VOCAB_PREFIX):]]
    vocab = CompositeVocab(data=data, lang=state_dict['lang'], idx=0, sep=None)

    rebuilt_units = vocab._id2unit[0]
    if len(rebuilt_units) != len(original_units):
        raise AssertionError("Failed to construct a new vocab of the same length as the original")
    if rebuilt_units != original_units:
        raise AssertionError("Failed to construct a new vocab in the same order as the original")
    return vocab
class MultiVocab(BaseMultiVocab):
    """
    The named vocabs used by the NER model.

    ``state_dict`` records the class name of each vocab under '_key2class' so
    that ``load_state_dict`` can rebuild it.  Saved 'TagVocab' entries are
    converted to CompositeVocab on load.
    """
    def state_dict(self):
        """ Also save a vocab name to class name mapping in state dict. """
        state = OrderedDict()
        key2class = OrderedDict()
        for k, v in self._vocabs.items():
            state[k] = v.state_dict()
            key2class[k] = type(v).__name__
        state['_key2class'] = key2class
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """
        Rebuild a MultiVocab from the output of ``state_dict``.

        Raises:
            AssertionError: if the '_key2class' mapping is missing
            KeyError: if a vocab has no recorded class name, or its class
                name is not one handled here
        """
        class_dict = {'CharVocab': CharVocab.load_state_dict,
                      'PretrainedWordVocab': PretrainedWordVocab.load_state_dict,
                      'TagVocab': convert_tag_vocab,
                      'CompositeVocab': CompositeVocab.load_state_dict,
                      'WordVocab': WordVocab.load_state_dict}
        new = cls()
        # explicit check rather than ``assert`` so the validation survives ``python -O``
        if '_key2class' not in state_dict:
            raise AssertionError("Cannot find class name mapping in state dict!")
        key2class = state_dict['_key2class']
        for k, v in state_dict.items():
            if k == '_key2class':
                continue
            classname = key2class[k]
            new[k] = class_dict[classname](v)
        return new
================================================
FILE: stanza/models/ner_tagger.py
================================================
"""
Entry point for training and evaluating an NER tagger.
This tagger uses BiLSTM layers with character and word-level representations, and a CRF decoding layer
to produce NER predictions.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
"""
import sys
import os
import time
from datetime import datetime
import argparse
import logging
import numpy as np
import random
import re
import json
import torch
from torch import nn, optim
from stanza.models.ner.data import DataLoader
from stanza.models.ner.trainer import Trainer
from stanza.models.ner import scorer
from stanza.models.common import utils
from stanza.models.common.pretrain import Pretrain
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import *
from stanza.models import _training_logging
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
from stanza.utils.confusion import confusion_to_weighted_f1, format_confusion
logger = logging.getLogger('stanza')
def build_argparse():
    """
    Build the argument parser for training and evaluating the NER tagger.

    Note: several options share a ``dest`` (e.g. --bert_model / --no_bert_model,
    --bert_finetune / --no_bert_finetune).  argparse takes the default of the
    first action declared for a dest, so the declaration order here matters.
    PEFT options are added separately by parse_args.
    """
    parser = argparse.ArgumentParser()
    # data and embedding locations
    parser.add_argument('--data_dir', type=str, default='data/ner', help='Directory of NER data.')
    parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors')
    parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--eval_output_file', type=str, default=None, help='Where to write results: text, gold, pred.  If None, no results file printed')

    # run mode and finetuning of an existing model
    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path')
    parser.add_argument('--finetune_load_name', type=str, default=None, help='Model to load when finetuning')
    parser.add_argument('--train_classifier_only', action='store_true',
                        help='In case of applying Transfer-learning approach and training only the classifier layer this will freeze gradient propagation for all other layers.')
    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")

    # model dimensions
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--char_hidden_dim', type=int, default=100)
    parser.add_argument('--word_emb_dim', type=int, default=100)
    parser.add_argument('--char_emb_dim', type=int, default=100)
    parser.add_argument('--num_layers', type=int, default=1)
    parser.add_argument('--char_num_layers', type=int, default=1)
    parser.add_argument('--pretrain_max_vocab', type=int, default=100000)

    # dropout
    parser.add_argument('--word_dropout', type=float, default=0.01, help="How often to remove a word at training time.  Set to a small value to train unk when finetuning word embeddings")
    parser.add_argument('--locked_dropout', type=float, default=0.0)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rec_dropout', type=float, default=0, help="Word recurrent dropout")
    parser.add_argument('--char_rec_dropout', type=float, default=0, help="Character recurrent dropout")
    parser.add_argument('--char_dropout', type=float, default=0, help="Character-level language model dropout")

    # character model / charlm
    parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off training a character model.")
    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")

    # word embedding handling
    parser.add_argument('--no_lowercase', dest='lowercase', action='store_false', help="Use cased word vectors.")
    parser.add_argument('--no_emb_finetune', dest='emb_finetune', action='store_false', help="Turn off finetuning of the embedding matrix.")
    parser.add_argument('--emb_finetune_known_only', dest='emb_finetune_known_only', action='store_true', help="Finetune the embedding matrix only for words in the embedding.  (Default: finetune words not in the embedding as well)  This may be useful for very large datasets where obscure words are only trained once in a while, such as French-WikiNER")
    parser.add_argument('--no_input_transform', dest='input_transform', action='store_false', help="Do not use input transformation layer before tagger lstm.")

    # tagging scheme
    parser.add_argument('--scheme', type=str, default='bioes', help="The tagging scheme to use: bio or bioes.")
    parser.add_argument('--train_scheme', type=str, default=None, help="The tagging scheme to use when training: bio or bioes.  Overrides --scheme for the training set")

    # transformer options; --no_bert_model / --no_bert_finetune share a dest with the
    # option declared just before them, whose default is the one used
    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
    parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer")
    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
    parser.add_argument('--gradient_checkpointing', default=False, action='store_true', help='Checkpoint intermediate gradients between layers to save memory at the cost of training steps')
    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')
    parser.add_argument('--second_optim', type=str, default=None, help='once first optimizer converged, tune the model again. with: sgd, adagrad, adam or adamax.')
    parser.add_argument('--second_bert_learning_rate', default=0, type=float, help='Secondary stage transformer finetuning learning rate scale')

    parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.")

    # optimization
    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
    parser.add_argument('--optim', type=str, default='sgd', help='sgd, adagrad, adam or adamax.')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-4, help='Minimum learning rate to stop training.')
    parser.add_argument('--second_lr', type=float, default=5e-3, help='Secondary learning rate')
    parser.add_argument('--momentum', type=float, default=0, help='Momentum for SGD.')
    parser.add_argument('--lr_decay', type=float, default=0.5, help="LR decay rate.")
    parser.add_argument('--patience', type=int, default=3, help="Patience for LR decay.")

    # output layers and tagset selection
    parser.add_argument('--connect_output_layers', action='store_true', default=False, help='Connect one output layer to the input of the next output layer.  By default, those layers are all separate')
    # default None (not 0) so loading can tell "unset" from "explicitly 0"
    parser.add_argument('--predict_tagset', type=int, default=None, help='Which tagset to predict if there are multiple tagsets.  Will default to 0.  Default of None allows the model to remember the value from training time, but be overridden at test time')

    parser.add_argument('--ignore_tag_scores', type=str, default=None, help="Which tags to ignore, if any, when scoring dev & test sets")

    # training loop control
    parser.add_argument('--max_steps', type=int, default=200000)
    parser.add_argument('--max_steps_no_improve', type=int, default=2500, help='if the model doesn\'t improve after this many steps, give up or switch to new optimizer.')
    parser.add_argument('--eval_interval', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--max_batch_words', type=int, default=800, help='Long sentences can overwhelm even a large GPU when finetuning a transformer on otherwise reasonable batch sizes.  This cuts off those batches early')
    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')

    # saving
    parser.add_argument('--save_dir', type=str, default='saved_models/ner', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_nertagger.pt", help="File name to save the model")

    parser.add_argument('--seed', type=int, default=1234)
    utils.add_device_args(parser)

    # experiment tracking
    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')
    return parser
def parse_args(args=None):
    """
    Parse command line arguments (including the PEFT options) into a dict.

    Args:
        args: list of argument strings, or None to read sys.argv

    Returns:
        a dict of option name -> value.  Giving --wandb_name also turns on --wandb.
    """
    parser = build_argparse()
    add_peft_args(parser)

    namespace = parser.parse_args(args=args)
    resolve_peft_args(namespace, logger)

    # naming a wandb session implies using wandb
    if namespace.wandb_name:
        namespace.wandb = True

    return vars(namespace)
def main(args=None):
    """
    Command line entry point.

    Seeds the random generators, then trains (returning whatever train
    returns) or evaluates (returning None) depending on --mode.
    """
    args = parse_args(args=args)
    utils.set_random_seed(args['seed'])

    mode = args['mode']
    logger.info("Running NER tagger in {} mode".format(mode))
    if mode != 'train':
        evaluate(args)
        return None
    return train(args)
def load_pretrain(args):
    """
    Load the pretrained word vectors described by ``args``.

    Returns None if pretrained embeddings are turned off (--no_pretrain).
    Otherwise --wordvec_pretrain_file is read if given.  Failing that, the
    raw vectors from --wordvec_file are used, or the default file for the
    shorthand under --wordvec_dir when --wordvec_file is empty.  The result
    is never saved to its own file.
    """
    if not args['pretrain']:
        return None

    pretrain_file = args['wordvec_pretrain_file']
    if pretrain_file:
        return Pretrain(pretrain_file, None, args['pretrain_max_vocab'], save_to_file=False)

    vec_file = args['wordvec_file']
    if len(vec_file) == 0:
        vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    # do not save pretrained embeddings individually
    return Pretrain(None, vec_file, args['pretrain_max_vocab'], save_to_file=False)
def model_file_name(args):
    """Return the path of the NER model file described by ``args`` (model type "nertagger")."""
    model_type = "nertagger"
    return utils.standard_model_file_name(args, model_type)
def get_known_tags(tags):
    """
    Tags are stored in the dataset as a list of list of tags
    This returns a sorted list for each column of tags in the dataset

    Args:
        tags: list (sentences) of lists (words) of per-word tag sequences,
            one entry per tag column

    Returns:
        one sorted list of the distinct tags per column.  If the data has no
        words at all, the result is an empty list.
    """
    # default=0: empty data (or only empty sentences) should not crash max()
    max_columns = max((len(word) for sent in tags for word in sent), default=0)
    known_tags = [set() for _ in range(max_columns)]
    for sent in tags:
        for word in sent:
            for tag_idx, tag in enumerate(word):
                known_tags[tag_idx].add(tag)
    return [sorted(x) for x in known_tags]
def warn_missing_tags(tag_vocab, data_tags, error_msg, bioes_to_bio=False):
    """
    Check for tags missing from the tag_vocab.

    Given a tag_vocab and the known tags in the format used by
    ner.data, go through the tags in the dataset and look for any
    which aren't in the tag_vocab.

    error_msg is something like "training set" or "eval file" to
    indicate where the missing tags came from.

    If bioes_to_bio is set, S- and E- prefixes in the data are mapped to
    B- and I- before comparing.
    """
    model_depth = len(tag_vocab.lens())
    # number of tag columns in the data; default=0 so empty sentences/data don't crash max()
    tag_depth = max((len(tags) for sentence in data_tags for tags in sentence), default=0)
    if tag_depth != model_depth:
        # name the actual source (training set, eval file, ...) rather than assuming a test set
        logger.warning("%s has a different number of tag types compared to the model: %d vs %d", error_msg, tag_depth, model_depth)
    for tag_set_idx in range(min(tag_depth, model_depth)):
        tag_set = tag_vocab.items(tag_set_idx)
        if model_depth > 1:
            current_error_msg = error_msg + " tag set %d" % tag_set_idx
        else:
            current_error_msg = error_msg

        current_tags = {word[tag_set_idx] for sentence in data_tags for word in sentence}
        if bioes_to_bio:
            current_tags = {re.sub("^E-", "I-", re.sub("^S-", "B-", x)) for x in current_tags}
        utils.warn_missing_tags(tag_set, current_tags, current_error_msg)
def train(args):
    """
    Train an NER tagger and save the best model found on the dev set.

    The model path comes from model_file_name(args).  args['save_dir'] (if
    unset) and args['save_name'] are filled in from that path.  With
    --finetune, training starts from an existing model: finetune_load_name
    if given, otherwise the standard model file.

    Training alternates update steps with dev evaluation every
    eval_interval steps.  The stopping criteria are: no improvement for
    max_steps_no_improve steps, max_steps reached, or lr <= min_lr.  When
    they trigger, the best model is reloaded with the second optimizer if
    one is configured and not yet in use; otherwise training stops.

    Returns the final Trainer, or None if the train or dev batches were empty.

    Raises:
        FileNotFoundError: finetune was requested but no model file exists
        ValueError: charlm was requested without charlm_shorthand, or the
            train / dev file has no sentences
    """
    model_file = model_file_name(args)

    save_dir, save_name = os.path.split(model_file)
    utils.ensure_dir(save_dir)
    if args['save_dir'] is None:
        args['save_dir'] = save_dir
    args['save_name'] = save_name

    utils.log_training_args(args, logger)

    pretrain = None
    vocab = None
    trainer = None

    if args['finetune'] and args['finetune_load_name']:
        logger.warning('Finetune is ON. Using model from "{}"'.format(args['finetune_load_name']))
        _, trainer, vocab = load_model(args, args['finetune_load_name'])
    elif args['finetune'] and os.path.exists(model_file):
        logger.warning('Finetune is ON. Using model from "{}"'.format(model_file))
        _, trainer, vocab = load_model(args, model_file)
    else:
        if args['finetune']:
            raise FileNotFoundError('Finetune is set to true but model file is not found: {}'.format(model_file))

        pretrain = load_pretrain(args)

        if pretrain is not None:
            # the pretrained embedding dimension always wins over the command line
            word_emb_dim = pretrain.emb.shape[1]
            if args['word_emb_dim'] and args['word_emb_dim'] != word_emb_dim:
                logger.warning("Embedding file has a dimension of {}. Model will be built with that size instead of {}".format(word_emb_dim, args['word_emb_dim']))
            args['word_emb_dim'] = word_emb_dim

        if args['charlm']:
            if args['charlm_shorthand'] is None:
                raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
            logger.info('Using pretrained contextualized char embedding')
            if not args['charlm_forward_file']:
                args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
            if not args['charlm_backward_file']:
                args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])

    # load data
    logger.info("Loading training data with batch size %d from %s", args['batch_size'], args['train_file'])
    # the JSON data files are read as utf-8 (matching write_ner_results) rather than the platform default
    with open(args['train_file'], encoding="utf-8") as fin:
        train_doc = Document(json.load(fin))
    logger.info("Loaded %d sentences of training data", len(train_doc.sentences))
    if len(train_doc.sentences) == 0:
        raise ValueError("File %s exists but has no usable training data" % args['train_file'])
    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])
    vocab = train_batch.vocab

    logger.info("Loading dev data from %s", args['eval_file'])
    with open(args['eval_file'], encoding="utf-8") as fin:
        dev_doc = Document(json.load(fin))
    logger.info("Loaded %d sentences of dev data", len(dev_doc.sentences))
    if len(dev_doc.sentences) == 0:
        # report the dev file here, not the training file
        raise ValueError("File %s exists but has no usable dev data" % args['eval_file'])
    dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True)

    train_tags = get_known_tags(train_batch.tags)
    logger.info("Training data has %d columns of tags", len(train_tags))
    for tag_idx, tags in enumerate(train_tags):
        logger.info("Tags present in training set at column %d:\n  Tags without BIES markers: %s\n  Tags with B-, I-, E-, or S-: %s",
                    tag_idx,
                    " ".join(sorted(set(i for i in tags if i[:2] not in ('B-', 'I-', 'E-', 'S-')))),
                    " ".join(sorted(set(i[2:] for i in tags if i[:2] in ('B-', 'I-', 'E-', 'S-')))))

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        logger.info("Skip training because no data available...")
        return

    logger.info("Training tagger...")
    if trainer is None: # init if model was not loaded previously from file
        trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],
                          train_classifier_only=args['train_classifier_only'])

    if args['finetune']:
        warn_missing_tags(trainer.vocab['tag'], train_batch.tags, "training set")
    # the evaluation will coerce the tags to the proper scheme,
    # so we won't need to alert for not having S- or E- tags
    bioes_to_bio = args['train_scheme'] == 'bio' and args['scheme'] == 'bioes'
    warn_missing_tags(trainer.vocab['tag'], dev_batch.tags, "dev set", bioes_to_bio=bioes_to_bio)

    # TODO: might still want to add multiple layers of tag evaluation to the scorer
    dev_gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in dev_batch.tags]

    logger.info(trainer.model)

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []
    current_lr = trainer.optimizer.param_groups[0]['lr']
    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

    # LR scheduling
    if args['lr_decay'] > 0:
        # learning rate changes on plateau -- no improvement on model for patience number of epochs
        # change is made as a factor of the learning rate decay
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, mode='max', factor=args['lr_decay'],
                                                               patience=args['patience'], min_lr=args['min_lr'])
    else:
        scheduler = None

    if args['wandb']:
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_ner" % args['shorthand']
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('train_loss', summary='min')
        wandb.run.define_metric('dev_score', summary='max')
        # track gradients!
        wandb.watch(trainer.model, log_freq=4, log="gradients")

    # start training
    last_best_step = 0
    train_loss = 0
    is_second_optim = False
    while True:
        should_stop = False
        for batch in train_batch:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                                              max_steps, loss, duration, current_lr))

            if global_step % args['eval_interval'] == 0:
                # eval on dev
                logger.info("Evaluating on dev set...")
                dev_preds = []
                # separate name so the training batch variable is not shadowed
                for dev_chunk in dev_batch:
                    dev_preds += trainer.predict(dev_chunk)
                _, _, dev_score, _ = scorer.score_by_entity(dev_preds, dev_gold_tags, ignore_tags=args['ignore_tag_scores'])

                train_loss = train_loss / args['eval_interval'] # avg loss per batch
                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
                if args['wandb']:
                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
                train_loss = 0

                # save best model
                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
                    trainer.save(model_file)
                    last_best_step = global_step
                    logger.info("New best model saved.")

                dev_score_history += [dev_score]
                logger.info("")

                # lr schedule
                if scheduler is not None:
                    scheduler.step(dev_score)

                if args['log_norms']:
                    trainer.model.log_norms()

            # check stopping
            current_lr = trainer.optimizer.param_groups[0]['lr']
            if (global_step - last_best_step) >= args['max_steps_no_improve'] or global_step >= args['max_steps'] or current_lr <= args['min_lr']:
                if (global_step - last_best_step) >= args['max_steps_no_improve']:
                    logger.info("{} steps without improvement...".format((global_step - last_best_step)))
                if not is_second_optim and args['second_optim'] is not None:
                    logger.info("Switching to second optimizer: {}".format(args['second_optim']))
                    logger.info('Reloading best model to continue from current local optimum')
                    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],
                                      train_classifier_only=args['train_classifier_only'], model_file=model_file, second_optim=True)
                    is_second_optim = True
                    last_best_step = global_step
                    current_lr = trainer.optimizer.param_groups[0]['lr']
                else:
                    logger.info("stopping...")
                    should_stop = True
                    break

        if should_stop:
            break

        train_batch.reshuffle()

    logger.info("Training ended with {} steps.".format(global_step))

    if args['wandb']:
        wandb.finish()

    if len(dev_score_history) > 0:
        best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
        logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
    else:
        logger.info("Dev set never evaluated.  Saving final model.")
        trainer.save(model_file)

    return trainer
def write_ner_results(filename, batch, preds, predict_tagset):
    """
    Dump word / gold tag / predicted tag triples to a tab-separated file.

    Each token goes on its own line as "word<TAB>gold<TAB>pred".  A blank
    line follows every sentence.  Gold tags are taken from column
    predict_tagset of batch.tags.  Raises ValueError if batch.tags and
    preds have different lengths.
    """
    num_gold, num_pred = len(batch.tags), len(preds)
    if num_gold != num_pred:
        raise ValueError("Unexpected batch vs pred lengths: %d vs %d" % (num_gold, num_pred))

    with open(filename, "w", encoding="utf-8") as fout:
        sentence_idx = 0
        for chunk in batch:
            # chunk[0] holds the words and chunk[5] the original ordering
            # a namedtuple would make this cleaner without being much slower
            for words in utils.unsort(chunk[0], chunk[5]):
                # TODO: if we change the predict_tagset mechanism, will have to change this
                gold_tags = [columns[predict_tagset] for columns in batch.tags[sentence_idx]]
                pred_tags = preds[sentence_idx]
                sentence_idx += 1
                fout.writelines("%s\t%s\t%s\n" % (word, gold, pred)
                                for word, gold, pred in zip(words, gold_tags, pred_tags))
                fout.write("\n")
def evaluate(args):
    """Load the NER model implied by args and score it on args['eval_file']."""
    loaded_args, trainer, vocab = load_model(args, model_file_name(args))
    eval_file = args['eval_file']
    return evaluate_model(loaded_args, trainer, vocab, eval_file)
def evaluate_model(loaded_args, trainer, vocab, eval_file):
    """
    Score a loaded NER model on a JSON eval file and log the results.

    loaded_args: args as returned by load_model.  Reads save_dir, save_name,
        batch_size, predict_tagset, train_scheme, scheme, ignore_tag_scores,
        log_norms, shorthand and eval_output_file.
    trainer, vocab: the loaded Trainer and its vocab.
    eval_file: path to the JSON document to evaluate on.

    Logs the entity-level score, per-entity F1, weighted non-O token F1 and
    the token confusion matrix.  If eval_output_file is set, per-token
    gold/predicted tags are written there.  Returns the confusion value
    produced by scorer.score_by_token.
    """
    if loaded_args['log_norms']:
        trainer.model.log_norms()

    model_file = os.path.join(loaded_args['save_dir'], loaded_args['save_name'])
    logger.debug("Loaded model for eval from %s", model_file)
    logger.debug("Using the %d tagset for evaluation", loaded_args['predict_tagset'])

    # load data
    logger.info("Loading data with batch size {}...".format(loaded_args['batch_size']))
    # read as utf-8 explicitly (matching write_ner_results) instead of the platform default encoding
    with open(eval_file, encoding="utf-8") as fin:
        doc = Document(json.load(fin))
    batch = DataLoader(doc, loaded_args['batch_size'], loaded_args, vocab=vocab, evaluation=True, bert_tokenizer=trainer.model.bert_tokenizer)
    # a model trained on BIO is evaluated after coercing tags, so S-/E- tags should not be reported as missing
    bioes_to_bio = loaded_args['train_scheme'] == 'bio' and loaded_args['scheme'] == 'bioes'
    warn_missing_tags(trainer.vocab['tag'], batch.tags, "eval_file", bioes_to_bio=bioes_to_bio)

    logger.info("Start evaluation...")
    preds = []
    for chunk in batch:
        preds += trainer.predict(chunk)

    # TODO: might still want to add multiple layers of tag evaluation to the scorer
    predict_tagset = trainer.args['predict_tagset']
    gold_tags = [[x[predict_tagset] for x in tags] for tags in batch.tags]

    _, _, score, entity_f1 = scorer.score_by_entity(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])
    _, _, _, confusion = scorer.score_by_token(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])
    logger.info("Weighted f1 for non-O tokens: %5f", confusion_to_weighted_f1(confusion, exclude=["O"]))
    logger.info("NER tagger score: %s %s %s %.2f", loaded_args['shorthand'], model_file, eval_file, score*100)
    entity_f1_lines = ["%s: %.2f" % (x, y*100) for x, y in entity_f1.items()]
    logger.info("NER Entity F1 scores:\n  %s", "\n  ".join(entity_f1_lines))
    logger.info("NER token confusion matrix:\n{}".format(format_confusion(confusion)))

    if loaded_args['eval_output_file']:
        write_ner_results(loaded_args['eval_output_file'], batch, preds, predict_tagset)

    return confusion
def load_model(args, model_file):
    """
    Load an NER Trainer from model_file, overlaying selected runtime args.

    Only the charlm file paths (when present in args) and a non-None
    predict_tagset are passed to the Trainer.  After loading, any arg whose
    name ends in _dir or _file, plus a handful of evaluation-time settings,
    replaces the value stored with the model.  save_dir and save_name are
    set from model_file.

    Returns (loaded_args, trainer, vocab).
    """
    trainer_args = {}
    for key in ('charlm_forward_file', 'charlm_backward_file'):
        if key in args:
            trainer_args[key] = args[key]
    if args['predict_tagset'] is not None:
        trainer_args['predict_tagset'] = args['predict_tagset']

    pt = load_pretrain(args)
    trainer = Trainer(args=trainer_args, model_file=model_file, pretrain=pt,
                      device=args['device'], train_classifier_only=args['train_classifier_only'])
    loaded_args = trainer.args
    vocab = trainer.vocab

    # runtime settings that take precedence over the saved config
    runtime_keys = ('batch_size', 'ignore_tag_scores', 'log_norms', 'mode', 'scheme', 'shorthand')
    for key in args:
        if key.endswith(('_dir', '_file')) or key in runtime_keys:
            loaded_args[key] = args[key]

    loaded_args['save_dir'], loaded_args['save_name'] = os.path.split(model_file)
    return loaded_args, trainer, vocab
# Run main() only when this module is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/parser.py
================================================
"""
Entry point for training and evaluating a dependency parser.
This implementation combines a deep biaffine graph-based parser with linearization and distance features.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
"""
"""
Training and evaluation for the parser.
"""
import io
import sys
import os
import copy
import shutil
import time
import argparse
import logging
import numpy as np
import random
import zipfile
import torch
from torch import nn, optim
import stanza.models.depparse.data as data
from stanza.models.depparse.data import DataLoader
from stanza.models.depparse.trainer import Trainer
from stanza.models.depparse import scorer
from stanza.models.common import utils
from stanza.models.common import pretrain
from stanza.models.common.data import augment_punct
from stanza.models.common.doc import *
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
from stanza.models.common.utils import log_training_args
from stanza.utils.conll import CoNLL
from stanza.models import _training_logging
logger = logging.getLogger('stanza')
def build_argparse():
    """
    Build the command-line argument parser for dependency parser training and evaluation.

    Returns an argparse.ArgumentParser.  PEFT and device options are added by
    add_peft_args and utils.add_device_args respectively.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='data/depparse', help='Root dir for saving models.')
    parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors.')
    parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
    parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help="Don't score the eval file - perhaps it has no gold labels, for example.  Cannot be used at training time")
    parser.add_argument('--output_latex', default=False, action='store_true', help='Output the per-relation table in Latex form')

    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--lang', type=str, help='Language')
    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")

    parser.add_argument('--hidden_dim', type=int, default=400)
    parser.add_argument('--char_hidden_dim', type=int, default=400)
    parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400)
    parser.add_argument('--deep_biaff_output_dim', type=int, default=160)

    # As an additional option, we implement arc embeddings
    # described in https://arxiv.org/pdf/2501.09451
    #   Scaling Graph-Based Dependency Parsing with Arc Vectorization and Attention-Based Refinement
    #   Nicolas Floquet, Joseph Le Roux, Nadi Tomeh, Thierry Charnois
    # Unfortunately, the current implementation and hyperparameters do not seem to help
    # when combined with a transformer as the input embedding
    # LAS scores on a few dev sets, UD 2.17, averaged over 5 seeds.
    # This is with a version where the arc -> unlabeled is one layer, arc -> label is two layers.
    # Using two layers for the arc -> unlabeled hurts scores a bit more.
    #   treebank      w/      w/o
    #   en_ewt      93.46    93.47
    #   de_gsd      89.02    89.12
    #   it_vit      90.15    90.19
    # However, this is without the transformer over the arcs, which is
    # an important component of making the arcs more useful
    parser.add_argument('--use_arc_embedding', action='store_true', default=False, help='Use arc embeddings, as per Scaling Graph-Based Dependency Parsing')
    parser.add_argument('--no_use_arc_embedding', dest='use_arc_embedding', action='store_false', help="Don't use arc embeddings")

    parser.add_argument('--word_emb_dim', type=int, default=75)
    parser.add_argument('--word_cutoff', type=int, default=None, help='How common a word must be to include it in the finetuned word embedding.  If not set, small word vector files will be 0, larger will be %d' % utils.DEFAULT_WORD_CUTOFF)
    parser.add_argument('--char_emb_dim', type=int, default=100)
    parser.add_argument('--tag_emb_dim', type=int, default=50)
    parser.add_argument('--no_upos', dest='use_upos', action='store_false', default=True, help="Don't use upos tags as part of the tag embedding")
    parser.add_argument('--no_xpos', dest='use_xpos', action='store_false', default=True, help="Don't use xpos tags as part of the tag embedding")
    parser.add_argument('--no_ufeats', dest='use_ufeats', action='store_false', default=True, help="Don't use ufeats as part of the tag embedding")
    parser.add_argument('--transformed_dim', type=int, default=125)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--char_num_layers', type=int, default=1)
    parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
    parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
    parser.add_argument('--word_dropout', type=float, default=0.33)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rec_dropout', type=float, default=0, help="Recurrent dropout")
    parser.add_argument('--char_rec_dropout', type=float, default=0, help="Recurrent dropout")
    parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off character model.")
    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")

    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
    parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer")
    parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')
    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
    parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer')
    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')
    parser.add_argument('--second_bert_learning_rate', default=1e-3, type=float, help='Secondary stage transformer finetuning learning rate scale')
    parser.add_argument('--bert_start_finetuning', default=200, type=int, help='When to start finetuning the transformer')
    parser.add_argument('--bert_warmup_steps', default=200, type=int, help='How many steps for a linear warmup when finetuning the transformer')
    parser.add_argument('--bert_weight_decay', default=0.0, type=float, help='Weight decay bert parameters by this much')

    parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.")
    parser.add_argument('--no_linearization', dest='linearization', action='store_false', help="Turn off linearization term.")
    parser.add_argument('--no_distance', dest='distance', action='store_false', help="Turn off distance term.")

    # Originally, we used a single adam optimizer, stopping after 1000 stalled iterations,
    # with a couple of other hyperparameters corresponding to: TODO
    #   --max_steps_before_stop 1000
    #   --beta2 0.95
    #   --lr 3e-3
    #   --weight_decay 0.0
    #   --optim adam
    #   --no_second_optim
    # Later experiments found the current defaults helped the results
    # on several different datasets (using a transformer as the input embedding).
    # These experiments are averaged across 5 models,
    # with multiple early stopping values as well.
    #   5 model dev avg LAS        1 stage  1 stage 2k  1 stage 4k     2 stage
    #   de_gsd                       89.03       89.50       89.71       89.83
    #   en_ewt                       93.47       93.69       93.74       93.89
    #   fi_tdt                       92.16       92.56       92.69       93.15
    #   it_vit                       90.12       90.37       90.44       90.60
    #   ta_ttb                       71.26       71.39       71.45       72.19
    #   zh-hans_gsdsimp              85.47       85.69       85.76       85.89
    #
    #   5 model test avg LAS       1 stage  1 stage 2k  1 stage 4k     2 stage
    #   de_gsd                       86.60       86.96       87.04       87.09
    #   en_ewt                       93.37       93.51       93.55       93.72
    #   fi_tdt                       92.56       92.92       93.10       93.47
    #   it_vit                       90.51       90.74       90.75       90.88
    #   ta_ttb                       68.22       68.27       68.42       69.06
    #   zh-hans_gsdsimp              85.66       85.92       86.04       86.34
    #
    # In addition to these experiments, we ran multiple alternate optimizer combinations, none of which
    # were a clear improvement over AdaDelta+Adam:
    #
    #   rmsprop --weight_decay 1e-5 --lr 0.0001
    #   adamw --second_lr 0.0001
    #   madgrad --second_lr 0.00008
    #   5 model dev avg LAS        ada+adam     rms+adam    ada+adamw  ada+madgrad
    #   de_gsd                        89.83        89.80        89.67        89.55
    #   en_ewt                        93.89        93.97        93.92        93.90
    #   fi_tdt                        93.15        92.95        93.03        93.08
    #   it_vit                        90.60        90.64        90.58        90.54
    #   ta_ttb                        72.19        71.86        72.18        72.24
    #   zh-hans_gsdsimp               85.89        85.60        85.97        85.92
    #
    #   5 model test avg LAS       ada+adam     rms+adam    ada+adamw  ada+madgrad
    #   de_gsd                        87.09        87.26        87.06        87.08
    #   en_ewt                        93.72        93.73        93.75        93.73
    #   fi_tdt                        93.47        93.30        93.43        93.44
    #   it_vit                        90.88        90.95        90.90        90.85
    #   ta_ttb                        69.06        68.45        69.05        69.26
    #   zh-hans_gsdsimp               86.34        85.86        86.27        86.23
    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
    parser.add_argument('--optim', type=str, default='adadelta', help='adadelta, sgd, adagrad, adam or adamax.')
    parser.add_argument('--second_optim', type=str, default="adam", help='sgd, adagrad, adam or adamax.')
    parser.add_argument('--no_second_optim', dest='second_optim', action='store_const', const=None, help="Don't use the second optimizer")
    parser.add_argument('--lr', type=float, default=2.0, help='Learning rate')
    parser.add_argument('--second_lr', type=float, default=0.0002, help='Secondary stage learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.00001, help='Weight decay for the first optimizer')
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--second_optim_start_step', type=int, default=10000, help='If set, switch to the second optimizer when stalled or at this step regardless of performance.  Normally, the optimizer only switches when the dev scores have stalled for --max_steps_before_stop steps')
    parser.add_argument('--second_warmup_steps', type=int, default=200, help="If set, give the 2nd optimizer a linear warmup.  Idea being that the optimizer won't have a good grasp on the initial gradients and square gradients when it first starts")

    parser.add_argument('--max_steps', type=int, default=50000)
    parser.add_argument('--eval_interval', type=int, default=100)
    parser.add_argument('--checkpoint_interval', type=int, default=500)
    parser.add_argument('--max_steps_before_stop', type=int, default=2000)
    parser.add_argument('--batch_size', type=int, default=5000)
    # the help text previously ended with a stray double quote
    parser.add_argument('--second_batch_size', type=int, default=None, help='Use a different batch size for the second optimizer.  Can be relevant for models with different transformer finetuning settings between optimizers, for example, where the larger batch size is impossible for FT the transformer')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')
    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')
    parser.add_argument('--save_dir', type=str, default='saved_models/depparse', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_parser.pt", help="File name to save the model")
    parser.add_argument('--continue_from', type=str, default=None, help="File name to preload the model to continue training from")

    parser.add_argument('--seed', type=int, default=1234)
    add_peft_args(parser)
    utils.add_device_args(parser)

    parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct.  Default of None will aim for roughly 10%%')

    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')

    parser.add_argument('--train_size', type=int, default=None, help='If specified, randomly select this many sentences from the training data')
    return parser
def parse_args(args=None):
    """
    Parse parser command-line options into a plain dict.

    PEFT options are resolved via resolve_peft_args.  Supplying
    --wandb_name implicitly turns on --wandb.
    """
    namespace = build_argparse().parse_args(args=args)
    resolve_peft_args(namespace, logger)
    if namespace.wandb_name:
        # naming a wandb run only makes sense if wandb logging is enabled
        namespace.wandb = True
    return vars(namespace)
def main(args=None):
    """Parse args, seed the RNGs, then dispatch to train() or evaluate() based on --mode."""
    args = parse_args(args=args)
    utils.set_random_seed(args['seed'])
    mode = args['mode']
    logger.info("Running parser in {} mode".format(mode))
    handler = train if mode == 'train' else evaluate
    return handler(args)
def model_file_name(args):
    """Return the standard save path for a dependency parser model described by args."""
    model_type = "parser"
    return utils.standard_model_file_name(args, model_type)
# TODO: refactor with the other load_pretrain implementations
def load_pretrain(args):
    """
    Return a pretrain.Pretrain for args, or None if pretrained embeddings are disabled.

    If the located pretrain file already exists, no raw vector file is
    passed.  Otherwise the vector file is --wordvec_file, falling back to
    the default word vector file for the shorthand.
    """
    if not args['pretrain']:
        return None
    pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
    vec_file = None
    if not os.path.exists(pretrain_file):
        vec_file = args['wordvec_file'] or utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    return pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
def predict_dataset(trainer, dev_batch):
    """
    Run trainer.predict over every batch of dev_batch.

    Returns the predictions restored to the original sentence order
    via dev_batch.data_orig_idx.
    """
    predictions = []
    if len(dev_batch) > 0:
        predictions = [pred for chunk in dev_batch for pred in trainer.predict(chunk)]
    return utils.unsort(predictions, dev_batch.data_orig_idx)
def train(args):
model_file = model_file_name(args)
utils.ensure_dir(os.path.split(model_file)[0])
# load pretrained vectors if needed
pretrain = load_pretrain(args)
args['word_cutoff'] = utils.update_word_cutoff(pretrain, args['word_cutoff'])
# TODO: refactor. the exact same thing is done in the tagger
if args['charlm']:
if args['charlm_shorthand'] is None:
raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
logger.info('Using pretrained contextualized char embedding')
if not args['charlm_forward_file']:
args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
if not args['charlm_backward_file']:
args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
utils.log_training_args(args, logger)
# load data
logger.info("Loading data with batch size {}...".format(args['batch_size']))
train_file = args['train_file']
if zipfile.is_zipfile(train_file):
logger.info("Decompressing %s" % train_file)
train_data = []
with zipfile.ZipFile(train_file) as zin:
for zipped_train_file in zin.namelist():
with zin.open(zipped_train_file) as fin:
logger.info("Reading %s from %s" % (zipped_train_file, train_file))
train_str = fin.read()
train_str = train_str.decode("utf-8")
train_file_data, _, _ = CoNLL.conll2dict(input_str=train_str)
logger.info("Train File {} from {}, Data Size: {}".format(zipped_train_file, train_file, len(train_file_data)))
train_data.extend(train_file_data)
else:
train_data, _, _ = CoNLL.conll2dict(input_file=args['train_file'])
logger.info("Train File {}, Data Size: {}".format(train_file, len(train_data)))
# possibly augment the training data with some amount of fake data
# based on the options chosen
logger.info("Original data size: {}".format(len(train_data)))
if args['train_size']:
if len(train_data) < args['train_size']:
random.shuffle(train_data)
train_data = train_data[:args['train_size']]
logger.info("Limiting training data to %d entries", len(train_data))
else:
logger.info("Train data less than %d already, not limiting train data", args['train_size'])
# build the training data once, before augmentation, so that random variation
# (which might be different based on the random seed)
# doesn't have an effect on the vocab being cut off at the word limit
# otherwise different models will have different vocabs
# based on how often the words were duplicated in the augmentation
# TODO: put the augmentation into the dataloader,
# such as is done with the POS or the tokenizer
train_doc = Document(train_data)
train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, evaluation=False)
vocab = train_batch.vocab
train_data.extend(augment_punct(train_data, args['augment_nopunct'],
keep_original_sentences=False))
logger.info("Augmented data size: {}".format(len(train_data)))
train_doc = Document(train_data)
train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False)
dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
# skip training if the language does not have training or dev data
if len(train_batch) == 0 or len(dev_batch) == 0:
logger.info("Skip training because no data available...")
sys.exit(0)
if args['wandb']:
import wandb
wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_depparse" % args['shorthand']
wandb.init(name=wandb_name, config=args)
wandb.run.define_metric('train_loss', summary='min')
wandb.run.define_metric('dev_score', summary='max')
logger.info("Training parser...")
checkpoint_file = None
if args.get("checkpoint"):
# calculate checkpoint file name from the save filename
checkpoint_file = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
args["checkpoint_save_name"] = checkpoint_file
if args.get("checkpoint") and os.path.exists(args["checkpoint_save_name"]):
trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["checkpoint_save_name"], device=args['device'], ignore_model_config=True)
if len(trainer.dev_score_history) > 0:
logger.info("Continuing from checkpoint %s Model was previously trained for %d steps, with a best dev score of %.4f", args["checkpoint_save_name"], trainer.global_step, max(trainer.dev_score_history))
elif args["continue_from"]:
if not os.path.exists(args["continue_from"]):
raise FileNotFoundError("--continue_from specified, but the file %s does not exist" % args["continue_from"])
trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["continue_from"], device=args['device'], ignore_model_config=True, reset_history=True)
else:
trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])
max_steps = args['max_steps']
current_lr = args['lr']
global_start_time = time.time()
format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
is_second_stage = False
# start training
train_loss = 0
if args['log_norms']:
trainer.model.log_norms()
while True:
do_break = False
for i, batch in enumerate(train_batch):
start_time = time.time()
trainer.global_step += 1
loss = trainer.update(batch, eval=False) # update step
train_loss += loss
# will checkpoint if we switch optimizers or score a new best score
force_checkpoint = False
if trainer.global_step % args['log_step'] == 0:
duration = time.time() - start_time
logger.info(format_str.format(trainer.global_step, max_steps, loss, duration, current_lr))
if trainer.global_step % args['eval_interval'] == 0:
# eval on dev
logger.info("Evaluating on dev set...")
dev_preds = predict_dataset(trainer, dev_batch)
dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])
system_pred_file = "{:C}\n\n".format(dev_batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, args['eval_file'])
train_loss = train_loss / args['eval_interval'] # avg loss per batch
logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(trainer.global_step, train_loss, dev_score))
if args['wandb']:
wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
train_loss = 0
# save best model
trainer.dev_score_history += [dev_score]
if dev_score >= max(trainer.dev_score_history):
trainer.last_best_step = trainer.global_step
trainer.save(model_file)
logger.info("new best model saved.")
force_checkpoint = True
for scheduler_name, scheduler in trainer.scheduler.items():
logger.info('scheduler %s learning rate: %s', scheduler_name, scheduler.get_last_lr())
if args['log_norms']:
trainer.model.log_norms()
if not is_second_stage and args.get('second_optim', None) is not None:
if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and trainer.global_step >= args['second_optim_start_step']):
logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
global_step = trainer.global_step
args["second_stage"] = True
# if the loader gets a model file, it uses secondary optimizer
# (because of the second_stage = True argument)
trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain,
model_file=model_file, device=args['device'])
logger.info('Reloading best model to continue from current local optimum')
dev_preds = predict_dataset(trainer, dev_batch)
dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])
system_pred_file = "{:C}\n\n".format(dev_batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, args['eval_file'])
logger.info("Reloaded model with dev score %.4f", dev_score)
is_second_stage = True
trainer.global_step = global_step
trainer.last_best_step = global_step
if args['second_batch_size'] is not None:
train_batch.set_batch_size(args['second_batch_size'])
force_checkpoint = True
else:
if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop']:
do_break = True
break
if trainer.global_step % args['eval_interval'] == 0 or force_checkpoint:
# if we need to save checkpoint, do so
# (save after switching the optimizer, if applicable, so that
# the new optimizer is the optimizer used if a restart happens)
if checkpoint_file is not None:
trainer.save(checkpoint_file, save_optimizer=True)
logger.info("new model checkpoint saved.")
if trainer.global_step >= args['max_steps']:
do_break = True
break
if do_break: break
train_batch.reshuffle()
logger.info("Training ended with {} steps.".format(trainer.global_step))
if args['wandb']:
wandb.finish()
if len(trainer.dev_score_history) > 0:
# TODO: technically the iteration position will be wrong if
# the eval_interval changed when running from a checkpoint
# could fix this by saving step & score instead of just score
best_f, best_eval = max(trainer.dev_score_history)*100, np.argmax(trainer.dev_score_history)+1
logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
else:
logger.info("Dev set never evaluated. Saving final model.")
trainer.save(model_file)
return trainer, _
def evaluate(args):
    """Load a saved parser and run it over ``args['eval_file']``.

    The model path comes from ``model_file_name(args)`` and the pretrained
    vectors from ``load_pretrain(args)``.  Only the two charlm paths from
    ``args`` are handed to the Trainer as its ``args``.

    Returns:
        (trainer, doc) where doc is whatever ``evaluate_trainer`` returns
    """
    model_file = model_file_name(args)
    # load pretrained vectors if needed
    pretrain = load_pretrain(args)

    charlm_args = {
        'charlm_forward_file': args.get('charlm_forward_file', None),
        'charlm_backward_file': args.get('charlm_backward_file', None),
    }

    logger.info("Loading model from: {}".format(model_file))
    trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=charlm_args)

    if args['log_norms']:
        trainer.model.log_norms()

    evaluated_doc = evaluate_trainer(args, trainer, pretrain)
    return trainer, evaluated_doc
def evaluate_trainer(args, trainer, pretrain):
    """Predict heads/deprels for ``args['eval_file']`` with ``trainer``.

    Settings whose keys end in ``_dir`` or ``_file``, plus ``shorthand`` and
    ``mode``, are copied from ``args`` into ``trainer.args`` (mutating it in
    place) before the data is loaded.  If ``args['output_file']`` is set the
    predictions are written there in CoNLL format.  If ``args['gold_labels']``
    is set, the eval file is treated as gold: every word must have a deprel,
    named dependency scores are reported and the overall score is logged.

    Returns:
        the evaluated Document, with HEAD and DEPREL set from the predictions

    Raises:
        ValueError: a gold word has a None deprel
    """
    output_path = args['output_file']
    loaded_args = trainer.args
    vocab = trainer.vocab

    # carry over paths and identifying settings from the current invocation
    for key in args:
        if key.endswith(('_dir', '_file')) or key in ('shorthand', 'mode'):
            loaded_args[key] = args[key]

    logger.info("Loading data with batch size {}...".format(args['batch_size']))
    doc = CoNLL.conll2doc(input_file=args['eval_file'])
    batch = DataLoader(doc, args['batch_size'], loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
    preds = predict_dataset(trainer, batch)

    flat_preds = [pred for sentence_preds in preds for pred in sentence_preds]
    batch.doc.set([HEAD, DEPREL], flat_preds)
    if output_path:
        CoNLL.write_doc2conll(batch.doc, output_path)

    if not args['gold_labels']:
        return batch.doc

    gold_doc = CoNLL.conll2doc(input_file=args['eval_file'])
    # Check for None ... otherwise an inscrutable error occurs later in the scorer
    for sent_idx, sentence in enumerate(gold_doc.sentences):
        for word_idx, word in enumerate(sentence.words):
            if word.deprel is None:
                raise ValueError("Gold document {} has a None at sentence {} word {}\n{:C}".format(args['eval_file'], sent_idx, word_idx, sentence))

    scorer.score_named_dependencies(batch.doc, gold_doc, args['output_latex'])
    predicted_conll = io.StringIO("{:C}\n\n".format(batch.doc))
    _, _, score = scorer.score(predicted_conll, args['eval_file'])
    logger.info("Parser score on %s file %s: %.2f", args['shorthand'], args['eval_file'], score*100)
    return batch.doc


if __name__ == '__main__':
    main()
================================================
FILE: stanza/models/pos/__init__.py
================================================
================================================
FILE: stanza/models/pos/build_xpos_vocab_factory.py
================================================
import argparse
from collections import defaultdict
import logging
import os
import re
import sys
from zipfile import ZipFile
from stanza.models.common.constant import treebank_to_short_name
from stanza.models.pos.xpos_vocab_utils import DEFAULT_KEY, choose_simplest_factory, XPOSType
from stanza.models.common.doc import *
from stanza.utils.conll import CoNLL
from stanza.utils import default_paths
# treebank names matching this (eg zh-hans_gsdsimp) are already shorthands
SHORTNAME_RE = re.compile(r"[a-z-]+_[a-z0-9]+")
# processed POS datasets are read from here; also the default for --treebanks
DATA_DIR = default_paths.get_default_paths()['POS_DATA_DIR']
logger = logging.getLogger('stanza')
def get_xpos_factory(shorthand, fn):
    """Choose the XPOS vocab description for one treebank from its training data.

    Looks for ``{shorthand}.train.in.conllu`` in DATA_DIR first, then for
    ``{shorthand}.train.in.zip``.  For a zip, the first member file whose
    words carry any xpos is used.

    Args:
        shorthand: short treebank name, used to locate the training files
        fn: full treebank name, used in the error message

    Returns:
        what ``choose_simplest_factory`` selects for the treebank's
        (text, upos, xpos, feats) sentences

    Raises:
        ValueError: a zip was found, but none of its files has any xpos
        FileNotFoundError: no training data exists for the treebank
    """
    logger.info('Resolving vocab option for {}...'.format(shorthand))
    doc = None
    train_file = os.path.join(DATA_DIR, '{}.train.in.conllu'.format(shorthand))
    if os.path.exists(train_file):
        doc = CoNLL.conll2doc(input_file=train_file)
    else:
        zip_file = os.path.join(DATA_DIR, '{}.train.in.zip'.format(shorthand))
        if os.path.exists(zip_file):
            with ZipFile(zip_file) as zin:
                for member_name in zin.namelist():
                    doc = CoNLL.conll2doc(input_file=member_name, zip_file=zip_file)
                    if any(word.xpos for sentence in doc.sentences for word in sentence.words):
                        break
                else:
                    # loop finished without break: no member had any xpos
                    raise ValueError('Found training data in {}, but none of the files contained had xpos'.format(zip_file))

    if doc is None:
        # without the training file, there's not much we can do
        # (the unreachable fallback to DEFAULT_KEY that followed this raise was removed)
        raise FileNotFoundError('Training data for {} not found. To generate the XPOS vocabulary '
                                'for this treebank properly, please run the following command first:\n'
                                ' python3 stanza/utils/datasets/prepare_pos_treebank.py {}'.format(fn, fn))

    data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
    return choose_simplest_factory(data, shorthand)
def main():
    """Regenerate the XPOS vocab factory module from the processed POS data.

    ``--treebanks`` is either a directory of processed datasets (the default,
    DATA_DIR), whose file names are cut at the first '.' to get treebank
    names, or a file listing one treebank per line.  Each treebank is
    resolved with ``get_xpos_factory``, the treebanks are grouped by the
    chosen description, and the grouping is written as Python source to
    ``--output_file``.

    Raises:
        ValueError: ``--treebanks`` is neither a directory nor an existing file
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--treebanks', type=str, default=DATA_DIR, help="Treebanks to process - directory with processed datasets or a file with a list")
    parser.add_argument('--output_file', type=str, default="stanza/models/pos/xpos_vocab_factory.py", help="Where to write the results")
    args = parser.parse_args()

    output_file = args.output_file
    if os.path.isdir(args.treebanks):
        # if the path is a directory of datasets (which is the default if --treebanks is not set)
        # we use those datasets to prepare the xpos factories
        treebanks = sorted({x.split(".", maxsplit=1)[0] for x in os.listdir(args.treebanks)})
    elif os.path.exists(args.treebanks):
        # maybe it's a file with a list of names
        with open(args.treebanks, encoding="utf-8") as fin:
            treebanks = sorted({line.strip() for line in fin if line.strip()})
    else:
        raise ValueError("Cannot figure out which treebanks to use. Please set the --treebanks parameter")

    # pass the joined names as a logging argument rather than %-formatting eagerly
    logger.info("Processing the following treebanks: %s", " ".join(treebanks))

    fullnames = list(treebanks)
    # names which already look like shorthands are kept as they are
    shorthands = [treebank if SHORTNAME_RE.match(treebank) else treebank_to_short_name(treebank)
                  for treebank in treebanks]

    # For each treebank, we would like to find the XPOS Vocab configuration that minimizes
    # the number of total classes needed to predict by all tagger classifiers.  This is
    # achieved by enumerating different options of separators that different treebanks might
    # use, and comparing that to treating the XPOS tags as separate categories (using a
    # WordVocab).
    mapping = defaultdict(list)
    for sh, fn in zip(shorthands, fullnames):
        factory = get_xpos_factory(sh, fn)
        mapping[factory].append(sh)
        # these two datasets also get an entry under an alternate name
        if sh == 'zh-hans_gsdsimp':
            mapping[factory].append('zh_gsdsimp')
        elif sh == 'no_bokmaal':
            mapping[factory].append('nb_bokmaal')
    mapping[DEFAULT_KEY].append('en_test')

    # Generate code.  This takes the XPOS vocabulary classes selected above, and generates the
    # actual factory class as seen in models.pos.xpos_vocab_factory.
    with open(output_file, 'w', encoding="utf-8") as f:
        # width of the longest shorthand, used to align the dict entries
        max_len = max(max(len(x) for x in mapping[key]) for key in mapping)
        print('''# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
# Please don't edit it!
import logging
from stanza.models.pos.vocab import WordVocab, XPOSVocab
from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory
# using a sublogger makes it easier to test in the unittests
logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')
XPOS_DESCRIPTIONS = {''', file=f)

        for key_idx, key in enumerate(mapping):
            if key_idx > 0:
                print(file=f)
            # the separator text depends only on the description, not on the shorthand
            if key.sep is None:
                sep = 'None'
            else:
                sep = "'%s'" % key.sep
            for shorthand in sorted(mapping[key]):
                # +2 to max_len for the ''
                # this format string is left justified (either would be okay, probably)
                print((" {:%ds}: XPOSDescription({}, {})," % (max_len+2)).format("'%s'" % shorthand, key.xpos_type, sep), file=f)

        print('''}
def xpos_vocab_factory(data, shorthand):
if shorthand not in XPOS_DESCRIPTIONS:
logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)
desc = choose_simplest_factory(data, shorthand)
if shorthand in XPOS_DESCRIPTIONS:
if XPOS_DESCRIPTIONS[shorthand] != desc:
# log instead of throw
# otherwise, updating datasets would be unpleasant
logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)
else:
logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
return build_xpos_vocab(desc, data, shorthand)
''', file=f)

    logger.info('Done!')


if __name__ == "__main__":
    main()
================================================
FILE: stanza/models/pos/data.py
================================================
import random
import logging
import copy
import torch
from collections import namedtuple
from torch.utils.data import DataLoader as DL
from torch.utils.data.sampler import Sampler
from torch.nn.utils.rnn import pad_sequence
from stanza.models.common.bert_embedding import filter_data, needs_length_filter
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.utils import DEFAULT_WORD_CUTOFF, simplify_punct
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, CharVocab
from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
from stanza.models.common.doc import *
logger = logging.getLogger('stanza')

# One preprocessed sentence, as built by Dataset.preprocess
DataSample = namedtuple("DataSample", ["word", "char", "upos", "xpos", "feats", "pretrain", "text"])

# One collated batch, as returned by the Dataset collate function
DataBatch = namedtuple(
    "DataBatch",
    [
        "words", "words_mask", "wordchars", "wordchars_mask",
        "upos", "xpos", "ufeats", "pretrained",
        "orig_idx", "word_orig_idx", "lens", "word_lens",
        "text", "idx",
    ],
)
class Dataset:
    """Tagger dataset built from a stanza Document.

    Sentences (text, upos, xpos, feats) are read from ``doc``, optionally
    length-filtered for bert and subsampled for training, then mapped to vocab
    ids and stored as ``DataSample`` entries in ``self.data``.  ``__getitem__``
    turns an entry into tensors and, when training, may randomly mask out the
    sentence-final punctuation (``args['augment_nopunct']``).  Use
    ``to_loader`` / ``to_length_limited_loader`` to get collated ``DataBatch``es.
    """
    def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):
        """
        Args:
            doc: the Document to read sentences from
            args: dict of options; ``pretrain``, ``sample_train``, ``bert_model``
                and ``augment_nopunct`` are read here (``word_cutoff`` and
                ``shorthand`` too, when the vocab is built)
            pretrain: pretrained embeddings whose ``.vocab`` is used when
                ``args['pretrain']`` is set; may be None
            vocab: an existing MultiVocab; built from ``doc`` when None
            evaluation: if True, no subsampling and no augmentation happen
            sort_during_eval: stored as ``self.sort_during_eval``
            bert_tokenizer: used to filter out sentences too long for the bert model
            kwargs: ignored
        """
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc

        if vocab is None:
            self.vocab = Dataset.init_vocab([doc], args)
        else:
            self.vocab = vocab

        # a column counts as present unless every entry is missing (None or '_')
        self.has_upos = not all(x is None or x == '_' for x in doc.get(UPOS, as_sentences=False))
        self.has_xpos = not all(x is None or x == '_' for x in doc.get(XPOS, as_sentences=False))
        self.has_feats = not all(x is None or x == '_' for x in doc.get(FEATS, as_sentences=False))

        data = self.load_doc(self.doc)
        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # randomly subsample the training data (never the eval data)
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)

        self.data = data
        self.num_examples = len(data)
        # upos ids which the punctuation augmentation may drop from the end of a sentence
        self.__punct_tags = self.vocab["upos"].map(["PUNCT"])
        self.augment_nopunct = self.args.get("augment_nopunct", 0.0)

    @staticmethod
    def init_vocab(docs, args):
        """Build the MultiVocab (char, word, upos, xpos, feats) from the sentences of ``docs``.

        Raises:
            ValueError: the feats vocab could not be built from the data
        """
        cutoff = args['word_cutoff'] if args.get('word_cutoff') is not None else DEFAULT_WORD_CUTOFF
        data = [x for doc in docs for x in Dataset.load_doc(doc)]
        charvocab = CharVocab(data, args['shorthand'])
        wordvocab = WordVocab(data, args['shorthand'], cutoff=cutoff, lower=True)
        uposvocab = WordVocab(data, args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, args['shorthand'])
        try:
            featsvocab = FeatureVocab(data, args['shorthand'], idx=3)
        except ValueError as e:
            raise ValueError("Unable to build features vocab.  Please check the Features column of your data for an error which may match the following description.") from e
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """Map each sentence to a DataSample of vocab ids.

        Every field except ``text`` is wrapped in a one-element list.  Without
        a pretrain vocab, ``pretrain`` is all PAD_ID.
        """
        processed = []
        for sent in data:
            processed_sent = DataSample(
                word = [vocab['word'].map([w[0] for w in sent])],
                char = [[vocab['char'].map([x for x in w[0]]) for w in sent]],
                upos = [vocab['upos'].map([w[1] for w in sent])],
                xpos = [vocab['xpos'].map([w[2] for w in sent])],
                feats = [vocab['feats'].map([w[3] for w in sent])],
                pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]
                            if pretrain_vocab is not None
                            else [[PAD_ID] * len(sent)]),
                text = [w[0] for w in sent]
            )
            processed.append(processed_sent)

        return processed

    def __len__(self):
        return len(self.data)

    def __mask(self, upos):
        """Return a boolean tensor shaped like ``upos``; True marks elements to remove.

        With probability ``self.augment_nopunct``, the final element is masked
        when its tag is one of the PUNCT ids -- unless every element of
        ``upos`` has that tag.  Otherwise the mask is all False.
        """
        # creates all false mask
        mask = torch.zeros_like(upos, dtype=torch.bool)

        ### augmentation 1: punctuation augmentation ###
        # tags that need to be checked, currently only PUNCT
        if random.uniform(0,1) < self.augment_nopunct:
            for i in self.__punct_tags:
                # generate a mask for the last element
                last_element = torch.zeros_like(upos, dtype=torch.bool)
                last_element[..., -1] = True
                # OR the candidate into the existing mask: the last word is
                # removed (set to True) if it has this tag
                #
                # if every word has this tag (eg a lone punctuation),
                # no masking is performed
                if not torch.all(upos.eq(torch.tensor([[i]]))):
                    mask |= ((upos == i) & (last_element))

        return mask

    def __getitem__(self, key):
        """Return ``(DataSample, key)`` for the entry at index ``key``.

        The stored id lists are converted to fresh tensors (``words``, ``upos``,
        ``xpos``, ``ufeats``, ``pretrained``); ``char`` and ``text`` stay lists.
        upos / xpos / ufeats are None when the dataset has no such column.

        When training on data with upos, ``__mask`` may select the final word
        (depending on ``augment_nopunct``).  A selected word has its tensor
        entries set to PAD_ID and is removed from ``char`` and ``text``.  This
        is redone on every access, so each pass over the data sees a fresh
        random augmentation.

        The returned sample is not padded for batching; use ``to_loader()``
        to get batches suitable for the model.
        """
        sample = self.data[key]

        # torch.tensor copies the data, so the masking edits below do not
        # clobber the lists owned by the Dataset
        # TODO: only store single lists per data entry?
        words = torch.tensor(sample.word[0])
        upos = torch.tensor(sample.upos[0]) if self.has_upos else None
        xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None
        ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None
        pretrained = torch.tensor(sample.pretrain[0])

        # char & raw_text stay as python lists
        char = sample.char[0]
        raw_text = sample.text

        # The mask is True for the elements we want to remove; it is
        # only built for training data which has upos tags
        if self.has_upos and upos is not None and not self.eval:
            mask = self.__mask(upos)
        else:
            mask = None

        if mask is not None:
            for masked in mask.nonzero():
                position = masked.item()
                words[position] = PAD_ID
                if upos is not None:
                    upos[position] = PAD_ID
                if xpos is not None:
                    # TODO: test the multi-dimension xpos
                    xpos[position, ...] = PAD_ID
                if ufeats is not None:
                    ufeats[position, ...] = PAD_ID
                pretrained[position] = PAD_ID
                char = char[:position] + char[position+1:]
                raw_text = raw_text[:position] + raw_text[position+1:]

        return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key

    def __iter__(self):
        """Yield ``(DataSample, index)`` for every entry, via ``__getitem__``."""
        for i in range(len(self)):
            yield self.__getitem__(i)

    def to_loader(self, **kwargs):
        """Converts self to a DataLoader; ``kwargs`` are passed to the DataLoader."""
        return DL(self,
                  collate_fn=Dataset.__collate_fn,
                  **kwargs)

    def to_length_limited_loader(self, batch_size, maximum_tokens):
        """DataLoader whose batches come from a LengthLimitedBatchSampler."""
        sampler = LengthLimitedBatchSampler(self, batch_size, maximum_tokens)
        return DL(self,
                  collate_fn=Dataset.__collate_fn,
                  batch_sampler = sampler)

    @staticmethod
    def __collate_fn(data):
        """Pack a list of ``(DataSample, index)`` pairs into one DataBatch.

        Sentences are sorted by their non-PAD length and the words' character
        lists are flattened and sorted by length; ``orig_idx`` and
        ``word_orig_idx`` are the index lists returned by ``sort_all`` for those
        two sorts (presumably to restore the original order -- see sort_all).
        Sequences are padded with PAD_ID.  upos / xpos / ufeats are None in the
        batch if any sample lacks them.
        """
        (data, idx) = zip(*data)
        (words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)

        # sort sentences by lens for easy RNN operations
        lens = [torch.sum(x != PAD_ID) for x in words]
        (words, wordchars, upos, xpos,
         ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,
                                                         ufeats, pretrained, text), lens)
        lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

        # combine all words into one large list, and sort for easy charRNN ops
        wordchars = [w for sent in wordchars for w in sent]
        word_lens = [len(x) for x in wordchars]
        (wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
        word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

        # We now pad everything
        words = pad_sequence(words, True, PAD_ID)
        if None not in upos:
            upos = pad_sequence(upos, True, PAD_ID)
        else:
            upos = None
        if None not in xpos:
            xpos = pad_sequence(xpos, True, PAD_ID)
        else:
            xpos = None
        if None not in ufeats:
            ufeats = pad_sequence(ufeats, True, PAD_ID)
        else:
            ufeats = None
        pretrained = pad_sequence(pretrained, True, PAD_ID)
        wordchars = get_long_tensor(wordchars, len(word_lens))

        # and finally create masks for the padding indices
        words_mask = torch.eq(words, PAD_ID)
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
                         pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)

    @staticmethod
    def load_doc(doc):
        """Read (text, upos, xpos, feats) per word, with None replaced and punctuation simplified."""
        data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
        data = Dataset.resolve_none(data)
        data = simplify_punct(data)
        return data

    @staticmethod
    def resolve_none(data):
        """Replace every None field with '_', mutating ``data`` in place, and return it."""
        for sentence in data:
            for token in sentence:
                for feat_idx, feat in enumerate(token):
                    if feat is None:
                        token[feat_idx] = '_'
        return data
class LengthLimitedBatchSampler(Sampler):
    """
    Groups item indices into batches of at most batch_size items, starting a
    new batch whenever the next item would push the batch's token count past
    maximum_tokens.

    The intent is to avoid GPU OOM when one sentence is much longer than
    expected, which would otherwise leave a batch too large for the GPU.

    An item longer than maximum_tokens on its own gets a batch to itself.
    A falsy maximum_tokens disables the token limit entirely.
    """

    def __init__(self, data, batch_size, maximum_tokens):
        """
        All batches are computed here, so __len__ and __iter__ only read them back.

        data is iterated as (item, index) pairs and len(item.word) is taken
        as the item's token count.
        """
        self.data = data
        self.batch_size = batch_size
        self.maximum_tokens = maximum_tokens

        self.batches = batches = []
        pending = []
        pending_tokens = 0
        for item, item_idx in data:
            n_tokens = len(item.word)

            if maximum_tokens and n_tokens > maximum_tokens:
                # oversized item: close whatever is pending, then isolate it
                if pending:
                    batches.append(pending)
                    pending, pending_tokens = [], 0
                batches.append([item_idx])
                continue

            too_many_items = len(pending) + 1 > batch_size
            too_many_tokens = maximum_tokens and n_tokens + pending_tokens > maximum_tokens
            if too_many_items or too_many_tokens:
                batches.append(pending)
                pending, pending_tokens = [], 0

            pending.append(item_idx)
            pending_tokens += n_tokens

        if pending:
            batches.append(pending)

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        # a fresh list is yielded each time, leaving the stored batch untouched
        for batch in self.batches:
            yield list(batch)
class ShuffledDataset:
    """A wrapper around one or more datasets which shuffles the data in batch_size chunks

    This means that if multiple datasets are passed in, the batches
    from each dataset are shuffled together, with one batch being
    entirely members of the same dataset.

    The main use case of this is that in the tagger, there are cases
    where batches from different datasets will have different
    properties, such as having or not having UPOS tags.  We found that
    it is actually somewhat tricky to make the model's loss function
    (in model.py) properly represent batches with mixed w/ and w/o
    property, whereas keeping one entire batch together makes it a lot
    easier to process.

    The mechanism for the shuffling is that the iterator first makes a
    list long enough to represent each batch from each dataset,
    tracking the index of the dataset it is coming from, then shuffles
    that list.  Another alternative would be to use a weighted
    randomization approach, but this is very simple and the memory
    requirements are not too onerous.

    Note that the batch indices are wasteful in the case of only one
    underlying dataset, which is actually the most common use case,
    but the overhead is small enough that it probably isn't worth
    special casing the one dataset version.
    """
    def __init__(self, datasets, batch_size):
        """
        Args:
            datasets: objects providing ``to_loader(batch_size=..., shuffle=...)``
            batch_size: batch size passed to each loader
        """
        self.batch_size = batch_size
        self.datasets = datasets
        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]

    def __iter__(self):
        """Yield every batch of every loader once, in a random interleaving."""
        iterators = [iter(loader) for loader in self.loaders]
        # one entry per batch, naming the loader that batch comes from
        indices = [loader_idx
                   for loader_idx, loader in enumerate(self.loaders)
                   for _ in range(len(loader))]
        random.shuffle(indices)
        for idx in indices:
            yield next(iterators[idx])

    def __len__(self):
        """Number of batches one pass of ``__iter__`` yields.

        Previously this summed the sizes of the underlying datasets (number of
        examples), which disagreed with what iteration produces.  It is still
        zero exactly when there is no data.
        """
        return sum(len(loader) for loader in self.loaders)
================================================
FILE: stanza/models/pos/model.py
================================================
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
from stanza.models.common.bert_embedding import extract_bert_embeddings
from stanza.models.common.biaffine import BiaffineScorer
from stanza.models.common.foundation_cache import load_bert, load_charlm
from stanza.models.common.hlstm import HighwayLSTM
from stanza.models.common.dropout import WordDropout
from stanza.models.common.utils import attach_bert_model
from stanza.models.common.vocab import CompositeVocab
from stanza.models.common.char_model import CharacterModel
from stanza.models.common import utils
# shared stanza logger used by the tagger model
logger = logging.getLogger("stanza")
class Tagger(nn.Module):
def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
super().__init__()
self.vocab = vocab
self.args = args
self.share_hid = share_hid
self.unsaved_modules = []
# input layers
input_size = 0
if self.args['word_emb_dim'] > 0:
# frequent word embeddings
self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
input_size += self.args['word_emb_dim']
if not share_hid:
# upos embeddings
self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)
if self.args['char'] and self.args['char_emb_dim'] > 0:
if self.args.get('charlm', None):
if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
logger.debug("POS model loading charmodels: %s and %s", args['charlm_forward_file'], args['charlm_backward_file'])
self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache))
self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache))
# optionally add a input transformation layer
if self.args.get('charlm_transform_dim', 0):
self.charmodel_forward_transform = nn.Linear(self.charmodel_forward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
self.charmodel_backward_transform = nn.Linear(self.charmodel_backward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
input_size += self.args['charlm_transform_dim'] * 2
else:
self.charmodel_forward_transform = None
self.charmodel_backward_transform = None
input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
else:
bidirectional = args.get('char_bidirectional', False)
self.charmodel = CharacterModel(args, vocab, bidirectional=bidirectional)
if bidirectional:
self.trans_char = nn.Linear(self.args['char_hidden_dim'] * 2, self.args['transformed_dim'], bias=False)
else:
self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
input_size += self.args['transformed_dim']
self.peft_name = peft_name
attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
if self.args.get('bert_model', None):
# TODO: refactor bert_hidden_layers between the different models
if args.get('bert_hidden_layers', False):
# The average will be offset by 1/N so that the default zeros
# represents an average of the N layers
self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
nn.init.zeros_(self.bert_layer_mix.weight)
else:
# an average of layers 2, 3, 4 will be used
# (for historic reasons)
self.bert_layer_mix = None
input_size += self.bert_model.config.hidden_size
if self.args['pretrain']:
# pretrained embeddings, by default this won't be saved into model file
self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
input_size += self.args['transformed_dim']
# recurrent layers
self.taggerlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
# classifiers
self.upos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'])
self.upos_clf = nn.Linear(self.args['deep_biaff_hidden_dim'], len(vocab['upos']))
self.upos_clf.weight.data.zero_()
self.upos_clf.bias.data.zero_()
if share_hid:
clf_constructor = lambda insize, outsize: nn.Linear(insize, outsize)
else:
self.xpos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'] if not isinstance(vocab['xpos'], CompositeVocab) else self.args['composite_deep_biaff_hidden_dim'])
self.ufeats_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['composite_deep_biaff_hidden_dim'])
clf_constructor = lambda insize, outsize: BiaffineScorer(insize, self.args['tag_emb_dim'], outsize)
if isinstance(vocab['xpos'], CompositeVocab):
self.xpos_clf = nn.ModuleList()
for l in vocab['xpos'].lens():
self.xpos_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))
else:
self.xpos_clf = clf_constructor(self.args['deep_biaff_hidden_dim'], len(vocab['xpos']))
if share_hid:
self.xpos_clf.weight.data.zero_()
self.xpos_clf.bias.data.zero_()
self.ufeats_clf = nn.ModuleList()
for l in vocab['feats'].lens():
if share_hid:
self.ufeats_clf.append(clf_constructor(self.args['deep_biaff_hidden_dim'], l))
self.ufeats_clf[-1].weight.data.zero_()
self.ufeats_clf[-1].bias.data.zero_()
else:
self.ufeats_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))
# criterion
self.crit = nn.CrossEntropyLoss(ignore_index=0) # ignore padding
self.drop = nn.Dropout(args['dropout'])
self.worddrop = WordDropout(args['word_dropout'])
def add_unsaved_module(self, name, module):
    """
    Attach `module` as attribute `name` and remember `name` in self.unsaved_modules.

    The trainer's save() drops every state_dict entry whose first dotted component is
    listed in unsaved_modules, so modules registered here are not written to checkpoints.
    """
    # record the name first, then bind the module, matching the original ordering
    self.unsaved_modules += [name]
    setattr(self, name, module)
def log_norms(self):
    """Delegate to utils.log_norms for this model (presumably logs its parameter norms)."""
    utils.log_norms(self)
# Run the tagger over one batch.
# Each word is embedded with any of: the finetuned word embedding, the pretrained embedding,
# a character model or pretrained forward/backward charlms, and a transformer.  The
# concatenated features are encoded with the highway BiLSTM, and UPOS, XPOS and UFeats are
# predicted for every word.
# The tensor arguments are presumably batch-first padded tensors (TODO confirm against the
# data loader).  upos / xpos / ufeats may be None, in which case they add no loss.
# sentlens is handed to pack_padded_sequence with its default enforce_sorted=True, so the
# sentences must be in decreasing length order.
# Returns (loss, preds):
#   loss  - sum of the cross-entropy losses (label id 0 is ignored as padding) of every tag
#           whose gold tensor was given; the float 0.0 if none were given
#   preds - [upos, xpos, ufeats] predicted ids padded back to batch-first form; composite
#           XPOS and UFeats predictions have one trailing column per sub-classifier
def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text):
# padded batch-first tensor -> PackedSequence over the real words only
def pack(x):
return pack_padded_sequence(x, sentlens, batch_first=True)
# every entry is a PackedSequence; their flat .data tensors are concatenated feature-wise below
inputs = []
if self.args['word_emb_dim'] > 0:
word_emb = self.word_emb(word)
word_emb = pack(word_emb)
inputs += [word_emb]
if self.args['pretrain']:
pretrained_emb = self.pretrained_emb(pretrained)
pretrained_emb = self.trans_pretrained(pretrained_emb)
pretrained_emb = pack(pretrained_emb)
inputs += [pretrained_emb]
# inverse of pack for a flat per-word tensor: restore batch-first padding.
# All inputs are packed with the same sentlens, so the first one's batch_sizes is used.
def pad(x):
return pad_packed_sequence(PackedSequence(x, inputs[0].batch_sizes), batch_first=True)[0]
if self.args['char'] and self.args['char_emb_dim'] > 0:
if self.args.get('charlm', None):
# pretrained character LMs return one tensor per sentence; optionally project
# each one, then pad and pack like the other inputs
all_forward_chars = self.charmodel_forward.build_char_representation(text)
assert isinstance(all_forward_chars, list)
if self.charmodel_forward_transform is not None:
all_forward_chars = [self.charmodel_forward_transform(x) for x in all_forward_chars]
all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))
all_backward_chars = self.charmodel_backward.build_char_representation(text)
if self.charmodel_backward_transform is not None:
all_backward_chars = [self.charmodel_backward_transform(x) for x in all_backward_chars]
all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))
inputs += [all_forward_chars, all_backward_chars]
else:
# character model trained with the tagger; its packed output is dropped out and projected
char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
inputs += [char_reps]
if self.bert_model is not None:
device = next(self.parameters()).device
# transformer features are detached unless bert_finetune is set AND the model is training
processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=False,
num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
detach=not self.args.get('bert_finetune', False) or not self.training,
peft_name=self.peft_name)
if self.bert_layer_mix is not None:
# add the average so that the default behavior is to
# take an average of the N layers, and anything else
# other than that needs to be learned
# TODO: refactor this
processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]
processed_bert = pad_sequence(processed_bert, batch_first=True)
inputs += [pack(processed_bert)]
# one row per word: all features side by side, then WordDropout (given drop_replacement as
# its replacement vector) and regular dropout
lstm_inputs = torch.cat([x.data for x in inputs], 1)
lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
lstm_inputs = self.drop(lstm_inputs)
lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)
# the learned initial hidden/cell states are expanded to the batch size
lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(self.taggerlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.taggerlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
# from here on, work with the flat per-word tensor
lstm_outputs = lstm_outputs.data
# UPOS always has its own hidden layer and linear classifier
upos_hid = F.relu(self.upos_hid(self.drop(lstm_outputs)))
upos_pred = self.upos_clf(self.drop(upos_hid))
preds = [pad(upos_pred).max(2)[1]]
if upos is not None:
upos = pack(upos).data
loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1))
else:
loss = 0.0
if self.share_hid:
# shared hidden layer: XPOS/UFeats classifiers read the UPOS hidden representation
xpos_hid = upos_hid
ufeats_hid = upos_hid
clffunc = lambda clf, hid: clf(self.drop(hid))
else:
# separate hidden layers; the classifiers also take a UPOS embedding, built from the gold
# UPOS while training (when given) and from the predicted UPOS otherwise
xpos_hid = F.relu(self.xpos_hid(self.drop(lstm_outputs)))
ufeats_hid = F.relu(self.ufeats_hid(self.drop(lstm_outputs)))
if self.training and upos is not None:
upos_emb = self.upos_emb(upos)
else:
upos_emb = self.upos_emb(upos_pred.max(1)[1])
clffunc = lambda clf, hid: clf(self.drop(hid), self.drop(upos_emb))
if xpos is not None: xpos = pack(xpos).data
if isinstance(self.vocab['xpos'], CompositeVocab):
# composite XPOS: one classifier per part; per-part predictions are joined on the last dim
xpos_preds = []
for i in range(len(self.vocab['xpos'])):
xpos_pred = clffunc(self.xpos_clf[i], xpos_hid)
if xpos is not None:
loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos[:, i].view(-1))
xpos_preds.append(pad(xpos_pred).max(2, keepdim=True)[1])
preds.append(torch.cat(xpos_preds, 2))
else:
xpos_pred = clffunc(self.xpos_clf, xpos_hid)
if xpos is not None:
loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos.view(-1))
preds.append(pad(xpos_pred).max(2)[1])
# UFeats: one classifier per entry of the feats vocab, joined on the last dim
ufeats_preds = []
if ufeats is not None: ufeats = pack(ufeats).data
for i in range(len(self.vocab['feats'])):
ufeats_pred = clffunc(self.ufeats_clf[i], ufeats_hid)
if ufeats is not None:
loss += self.crit(ufeats_pred.view(-1, ufeats_pred.size(-1)), ufeats[:, i].view(-1))
ufeats_preds.append(pad(ufeats_pred).max(2, keepdim=True)[1])
preds.append(torch.cat(ufeats_preds, 2))
return loss, preds
================================================
FILE: stanza/models/pos/scorer.py
================================================
"""
Utils and wrappers for scoring taggers.
"""
import logging
from stanza.models.common.utils import ud_scores
logger = logging.getLogger('stanza')

def score(system_conllu_file, gold_conllu_file, verbose=True, eval_type='AllTags'):
    """
    Score a tagged CoNLL-U file against a gold CoNLL-U file.

    Note the argument order: the system file comes first here, but ud_scores takes the gold file first.

    Args:
        system_conllu_file: predicted CoNLL-U file.
        gold_conllu_file: reference CoNLL-U file.
        verbose: if True, log the F1 (as a percentage) for UPOS, XPOS, UFeats and AllTags.
        eval_type: which entry of the evaluation to report.

    Returns:
        (precision, recall, f1) for eval_type.
    """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)
    selected = evaluation[eval_type]
    precision, recall, f1 = selected.precision, selected.recall, selected.f1

    if verbose:
        f1_percent = [evaluation[name].f1 * 100 for name in ('UPOS', 'XPOS', 'UFeats', 'AllTags')]
        logger.info("UPOS\tXPOS\tUFeats\tAllTags")
        logger.info("{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}".format(*f1_percent))

    return precision, recall, f1
================================================
FILE: stanza/models/pos/trainer.py
================================================
"""
A trainer class to handle training and testing of models.
"""
import sys
import logging
import torch
from torch import nn
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.common import utils, loss
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.pos.model import Tagger
from stanza.models.pos.vocab import MultiVocab
logger = logging.getLogger('stanza')

def unpack_batch(batch, device):
    """
    Split a data-loader batch into model inputs and bookkeeping values.

    The first 8 entries are moved to `device` (entries that are None stay None).
    Entries 8..12 are returned untouched, in order: orig_idx, word_orig_idx,
    sentlens, wordlens, text.

    Returns:
        (inputs, orig_idx, word_orig_idx, sentlens, wordlens, text)
    """
    tensors = [None if item is None else item.to(device) for item in batch[:8]]
    orig_idx, word_orig_idx, sentlens, wordlens, text = (batch[pos] for pos in range(8, 13))
    return tensors, orig_idx, word_orig_idx, sentlens, wordlens, text
class Trainer(BaseTrainer):
    """
    Trainer for the POS / UFeats Tagger.

    Owns the Tagger (self.model), its MultiVocab (self.vocab), the config dict
    (self.args), the optimizers built by utils.get_split_optimizer (self.optimizers)
    and any LR schedulers (self.schedulers; currently only a warmup scheduler for a
    finetuned transformer).
    """
    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None):
        """
        Build the trainer, either from a checkpoint or from scratch.

        args: config dict.  Required when building from scratch; when loading,
              its entries override the saved config.
        vocab: MultiVocab for a model built from scratch.
        pretrain: pretrained word embeddings; pretrain.emb is given to the Tagger when
                  pretrain is not None.
        model_file: if given, config, vocab and weights are read from it (see load).
        device: device the model is moved to.
        foundation_cache: forwarded to the Tagger, and when loading also to the transformer loaders.
        """
        if model_file is not None:
            # load everything from file
            self.load(model_file, pretrain, args=args, foundation_cache=foundation_cache)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab

            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
            peft_name = None
            if self.args['use_peft']:
                # fine tune the bert if we're using peft
                self.args['bert_finetune'] = True
                peft_name = "pos"
                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)

            emb_matrix = pretrain.emb if pretrain is not None else None
            self.model = Tagger(args, vocab, emb_matrix=emb_matrix, share_hid=args['share_hid'], foundation_cache=foundation_cache,
                                bert_model=bert_model, bert_tokenizer=bert_tokenizer,
                                force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)

        self.model = self.model.to(device)
        # The peft flag is stored under 'use_peft' (set above and in load, read by save).
        # No 'peft' key exists, so looking that key up always gave is_peft=False.
        self.optimizers = utils.get_split_optimizer(self.args['optim'], self.model, self.args['lr'],
                                                    betas=(0.9, self.args['beta2']), eps=1e-6,
                                                    weight_decay=self.args.get('initial_weight_decay', None),
                                                    bert_learning_rate=self.args.get('bert_learning_rate', 0.0),
                                                    is_peft=self.args.get("use_peft", False))

        self.schedulers = {}
        if self.args.get('bert_finetune', None):
            # local import: transformers is only needed when the transformer is finetuned
            import transformers
            warmup_scheduler = transformers.get_linear_schedule_with_warmup(
                self.optimizers["bert_optimizer"],
                # todo late starting?
                0, self.args["max_steps"])
            self.schedulers["bert_scheduler"] = warmup_scheduler

    def update(self, batch, eval=False):
        """
        Compute the loss on one batch and, unless eval is True, take an optimizer step.

        Returns the loss as a Python float.  If the model's loss compares equal to 0.0
        (the Tagger returns 0.0 when no gold tags were supplied), it is returned as-is
        with no backward pass.  Every optimizer and scheduler is stepped after gradient clipping.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            for optimizer in self.optimizers.values():
                optimizer.zero_grad()
        loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
        if loss == 0.0:
            return loss

        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        for optimizer in self.optimizers.values():
            optimizer.step()
        for scheduler in self.schedulers.values():
            scheduler.step()
        return loss_val

    def predict(self, batch, unsort=True):
        """
        Tag one batch.

        Returns one list per sentence, holding an [upos, xpos, feats] string triple per word.
        If unsort is True, the sentences are put back in their original order using orig_idx.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs

        self.model.eval()
        batch_size = word.size(0)
        _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
        upos_seqs = [self.vocab['upos'].unmap(sent) for sent in preds[0].tolist()]
        xpos_seqs = [self.vocab['xpos'].unmap(sent) for sent in preds[1].tolist()]
        feats_seqs = [self.vocab['feats'].unmap(sent) for sent in preds[2].tolist()]

        # predictions are padded, so only the first sentlens[i] words of each sentence are kept
        pred_tokens = [[[upos_seqs[i][j], xpos_seqs[i][j], feats_seqs[i][j]] for j in range(sentlens[i])] for i in range(batch_size)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def save(self, filename, skip_modules=True):
        """
        Save the model weights, vocab and config to filename with torch.save.

        With skip_modules=True, weights of modules listed in self.model.unsaved_modules are
        left out (e.g. the pretrained embeddings, which are saved in a separate file).
        When use_peft is set, the peft adapter weights are stored under 'bert_lora'.
        Failures other than KeyboardInterrupt / SystemExit are logged as a warning, not raised.
        """
        model_state = self.model.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
            for k in skipped:
                del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }
        if self.args.get('use_peft', False):
            # Hide import so that peft dependency is optional
            from peft import get_peft_model_state_dict
            params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)

        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception as e:
            logger.warning(f"Saving failed... {e} continuing anyway.")

    def load(self, filename, pretrain, args=None, foundation_cache=None):
        """
        Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input,
        and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.

        args, if not None, override entries of the saved config.  If the checkpoint holds
        finetuned transformer weights (keys starting with "bert_model."), foundation_cache is
        wrapped once in NoTransformerFoundationCache, so the finetuned transformer is not
        reused through the cache.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args is not None:
            self.args.update(args)

        # preserve old models which were created before transformers were added
        if 'bert_model' not in self.args:
            self.args['bert_model'] = None

        lora_weights = checkpoint.get('bert_lora')
        if lora_weights:
            logger.debug("Found peft weights for POS; loading a peft adapter")
            self.args["use_peft"] = True

        # computed once and reused below, instead of rescanning the keys in two places
        has_finetuned_bert = any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
        # whether foundation_cache has already been wrapped, so it is wrapped (and logged) only once
        transformer_hidden = False

        # TODO: refactor this common block of code with NER
        force_bert_saved = False
        peft_name = None
        if self.args.get('use_peft', False):
            force_bert_saved = True
            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "pos", foundation_cache)
            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
            logger.debug("Loaded peft with name %s", peft_name)
        else:
            if has_finetuned_bert:
                logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
                foundation_cache = NoTransformerFoundationCache(foundation_cache)
                transformer_hidden = True
                force_bert_saved = True
            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)

        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
        # load model
        emb_matrix = None
        if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None
            emb_matrix = pretrain.emb
        if has_finetuned_bert and not transformer_hidden:
            # peft path: the transformer is already loaded, but the cache handed to the
            # Tagger must still hide the transformer
            logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
            foundation_cache = NoTransformerFoundationCache(foundation_cache)
        self.model = Tagger(self.args, self.vocab, emb_matrix=emb_matrix, share_hid=self.args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
        self.model.load_state_dict(checkpoint['model'], strict=False)
================================================
FILE: stanza/models/pos/vocab.py
================================================
from collections import Counter, OrderedDict
from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab
from stanza.models.common.vocab import CompositeVocab, VOCAB_PREFIX, EMPTY, EMPTY_ID
class WordVocab(BaseVocab):
    """
    Vocab over the field at position `idx` of every word, optionally lowercased.

    Units listed in `ignore` are left out of the vocab.  unit2id maps them to the id of
    EMPTY, and, whenever `ignore` is non-empty, id2unit maps EMPTY_ID back to '_'.
    """
    def __init__(self, data=None, lang="", idx=0, cutoff=0, lower=False, ignore=None):
        # assigned before the base constructor runs: build_vocab reads self.ignore
        # (presumably the base constructor triggers build_vocab - confirm in BaseVocab)
        self.ignore = [] if ignore is None else ignore
        super().__init__(data, lang=lang, idx=idx, cutoff=cutoff, lower=lower)
        # persist `ignore` together with the base class state
        self.state_attrs += ['ignore']

    def id2unit(self, id):
        if not (len(self.ignore) > 0 and id == EMPTY_ID):
            return super().id2unit(id)
        return '_'

    def unit2id(self, unit):
        if len(self.ignore) == 0 or unit not in self.ignore:
            return super().unit2id(unit)
        return self._unit2id[EMPTY]

    def build_vocab(self):
        # the idx-th field of every word in every sentence, lowercased if requested
        column = (w[self.idx] for sent in self.data for w in sent)
        counts = Counter(u.lower() for u in column) if self.lower else Counter(column)
        # drop rare units and explicitly ignored units
        kept = [u for u in counts if not (counts[u] < self.cutoff or u in self.ignore)]
        # most frequent first; the stable sort keeps first-seen order among ties
        kept.sort(key=counts.__getitem__, reverse=True)
        self._id2unit = VOCAB_PREFIX + kept
        self._unit2id = {unit: i for i, unit in enumerate(self._id2unit)}

    def __iter__(self):
        # the EMPTY shenanigans above make list() look really weird
        # when using the __len__ / __getitem__ paradigm,
        # but yielding items like this works fine
        yield from self._id2unit

    def __str__(self):
        units = ",".join("|%s|" % x for x in self._id2unit)
        return "<{}: {}>".format(type(self), units)
class XPOSVocab(CompositeVocab):
    """
    CompositeVocab for the xpos column.

    Defaults: unkeyed (keyed=False) with an empty separator (sep="").
    build_xpos_vocab passes a treebank-specific separator.
    """
    def __init__(self, data=None, lang="", idx=0, sep="", keyed=False):
        super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed)
class FeatureVocab(CompositeVocab):
    """
    CompositeVocab for morphological features.

    Defaults: keyed (keyed=True) with separator '|', presumably matching the UD FEATS
    layout `Key=Value|Key=Value`.
    """
    def __init__(self, data=None, lang="", idx=0, sep="|", keyed=True):
        super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed)
class MultiVocab(BaseMultiVocab):
    """
    Collection of named vocabs whose state dict also records each vocab's class name,
    so load_state_dict can rebuild the right vocab type.
    """
    def state_dict(self):
        """ Also save a vocab name to class name mapping in state dict. """
        state = OrderedDict()
        state.update((name, vocab.state_dict()) for name, vocab in self._vocabs.items())
        state['_key2class'] = OrderedDict((name, type(vocab).__name__) for name, vocab in self._vocabs.items())
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """
        Rebuild a MultiVocab from state_dict(), using '_key2class' to pick each vocab's class.
        Only CharVocab, WordVocab, XPOSVocab and FeatureVocab are recognized.
        """
        class_dict = {'CharVocab': CharVocab,
                      'WordVocab': WordVocab,
                      'XPOSVocab': XPOSVocab,
                      'FeatureVocab': FeatureVocab}
        new = cls()
        assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!"
        key2class = state_dict['_key2class']
        for name, vocab_state in state_dict.items():
            if name != '_key2class':
                new[name] = class_dict[key2class[name]].load_state_dict(vocab_state)
        return new
================================================
FILE: stanza/models/pos/xpos_vocab_factory.py
================================================
# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
# Please don't edit it!
import logging
from stanza.models.pos.vocab import WordVocab, XPOSVocab
from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory
# using a sublogger makes it easier to test in the unittests
logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')
# Known treebank shorthand -> XPOSDescription(xpos_type, sep), consumed by build_xpos_vocab:
#   XPOSType.XPOS entries build an XPOSVocab with the given sep (presumably the character each
#     xpos tag is split on; '' is passed through as-is to CompositeVocab);
#   XPOSType.WORD entries build a WordVocab that treats each xpos tag as one atomic label.
# xpos_vocab_factory re-derives the description from the training data and logs an error
# when it disagrees with this table.
# NOTE(review): this table is generated; hand edits (including these comments) may be lost on regeneration.
XPOS_DESCRIPTIONS = {
# XPOSType.XPOS with the empty separator
'af_afribooms' : XPOSDescription(XPOSType.XPOS, ''),
'ar_padt' : XPOSDescription(XPOSType.XPOS, ''),
'bg_btb' : XPOSDescription(XPOSType.XPOS, ''),
'ca_ancora' : XPOSDescription(XPOSType.XPOS, ''),
'cs_cac' : XPOSDescription(XPOSType.XPOS, ''),
'cs_cltt' : XPOSDescription(XPOSType.XPOS, ''),
'cs_fictree' : XPOSDescription(XPOSType.XPOS, ''),
'cs_pdt' : XPOSDescription(XPOSType.XPOS, ''),
'en_partut' : XPOSDescription(XPOSType.XPOS, ''),
'es_ancora' : XPOSDescription(XPOSType.XPOS, ''),
'es_combined' : XPOSDescription(XPOSType.XPOS, ''),
'fr_partut' : XPOSDescription(XPOSType.XPOS, ''),
'gd_arcosg' : XPOSDescription(XPOSType.XPOS, ''),
'gl_ctg' : XPOSDescription(XPOSType.XPOS, ''),
'gl_treegal' : XPOSDescription(XPOSType.XPOS, ''),
'grc_perseus' : XPOSDescription(XPOSType.XPOS, ''),
'hr_set' : XPOSDescription(XPOSType.XPOS, ''),
'is_gc' : XPOSDescription(XPOSType.XPOS, ''),
'is_icepahc' : XPOSDescription(XPOSType.XPOS, ''),
'is_modern' : XPOSDescription(XPOSType.XPOS, ''),
'it_combined' : XPOSDescription(XPOSType.XPOS, ''),
'it_isdt' : XPOSDescription(XPOSType.XPOS, ''),
'it_markit' : XPOSDescription(XPOSType.XPOS, ''),
'it_parlamint' : XPOSDescription(XPOSType.XPOS, ''),
'it_partut' : XPOSDescription(XPOSType.XPOS, ''),
'it_postwita' : XPOSDescription(XPOSType.XPOS, ''),
'it_twittiro' : XPOSDescription(XPOSType.XPOS, ''),
'it_vit' : XPOSDescription(XPOSType.XPOS, ''),
'la_perseus' : XPOSDescription(XPOSType.XPOS, ''),
'la_udante' : XPOSDescription(XPOSType.XPOS, ''),
'lt_alksnis' : XPOSDescription(XPOSType.XPOS, ''),
'lv_lvtb' : XPOSDescription(XPOSType.XPOS, ''),
'ro_nonstandard' : XPOSDescription(XPOSType.XPOS, ''),
'ro_rrt' : XPOSDescription(XPOSType.XPOS, ''),
'ro_simonero' : XPOSDescription(XPOSType.XPOS, ''),
'sk_snk' : XPOSDescription(XPOSType.XPOS, ''),
'sl_ssj' : XPOSDescription(XPOSType.XPOS, ''),
'sl_sst' : XPOSDescription(XPOSType.XPOS, ''),
'sr_set' : XPOSDescription(XPOSType.XPOS, ''),
'ta_ttb' : XPOSDescription(XPOSType.XPOS, ''),
'uk_iu' : XPOSDescription(XPOSType.XPOS, ''),
# XPOSType.WORD: whole xpos tags used as atomic labels
'be_hse' : XPOSDescription(XPOSType.WORD, None),
'bxr_bdt' : XPOSDescription(XPOSType.WORD, None),
'cop_scriptorium': XPOSDescription(XPOSType.WORD, None),
'cu_proiel' : XPOSDescription(XPOSType.WORD, None),
'cy_ccg' : XPOSDescription(XPOSType.WORD, None),
'da_ddt' : XPOSDescription(XPOSType.WORD, None),
'de_gsd' : XPOSDescription(XPOSType.WORD, None),
'de_hdt' : XPOSDescription(XPOSType.WORD, None),
'el_gdt' : XPOSDescription(XPOSType.WORD, None),
'el_gud' : XPOSDescription(XPOSType.WORD, None),
'en_atis' : XPOSDescription(XPOSType.WORD, None),
'en_combined' : XPOSDescription(XPOSType.WORD, None),
'en_craft' : XPOSDescription(XPOSType.WORD, None),
'en_eslspok' : XPOSDescription(XPOSType.WORD, None),
'en_ewt' : XPOSDescription(XPOSType.WORD, None),
'en_genia' : XPOSDescription(XPOSType.WORD, None),
'en_gum' : XPOSDescription(XPOSType.WORD, None),
'en_gumreddit' : XPOSDescription(XPOSType.WORD, None),
'en_mimic' : XPOSDescription(XPOSType.WORD, None),
'en_test' : XPOSDescription(XPOSType.WORD, None),
'es_gsd' : XPOSDescription(XPOSType.WORD, None),
'et_edt' : XPOSDescription(XPOSType.WORD, None),
'et_ewt' : XPOSDescription(XPOSType.WORD, None),
'eu_bdt' : XPOSDescription(XPOSType.WORD, None),
'fa_perdt' : XPOSDescription(XPOSType.WORD, None),
'fa_seraji' : XPOSDescription(XPOSType.WORD, None),
'fi_tdt' : XPOSDescription(XPOSType.WORD, None),
'fr_combined' : XPOSDescription(XPOSType.WORD, None),
'fr_gsd' : XPOSDescription(XPOSType.WORD, None),
'fr_parisstories': XPOSDescription(XPOSType.WORD, None),
'fr_rhapsodie' : XPOSDescription(XPOSType.WORD, None),
'fr_sequoia' : XPOSDescription(XPOSType.WORD, None),
'fro_profiterole': XPOSDescription(XPOSType.WORD, None),
'ga_idt' : XPOSDescription(XPOSType.WORD, None),
'ga_twittirish' : XPOSDescription(XPOSType.WORD, None),
'got_proiel' : XPOSDescription(XPOSType.WORD, None),
'grc_proiel' : XPOSDescription(XPOSType.WORD, None),
'grc_ptnk' : XPOSDescription(XPOSType.WORD, None),
'gv_cadhan' : XPOSDescription(XPOSType.WORD, None),
'hbo_ptnk' : XPOSDescription(XPOSType.WORD, None),
'he_combined' : XPOSDescription(XPOSType.WORD, None),
'he_htb' : XPOSDescription(XPOSType.WORD, None),
'he_iahltknesset': XPOSDescription(XPOSType.WORD, None),
'he_iahltwiki' : XPOSDescription(XPOSType.WORD, None),
'hi_hdtb' : XPOSDescription(XPOSType.WORD, None),
'hsb_ufal' : XPOSDescription(XPOSType.WORD, None),
'hu_szeged' : XPOSDescription(XPOSType.WORD, None),
'hy_armtdp' : XPOSDescription(XPOSType.WORD, None),
'hy_bsut' : XPOSDescription(XPOSType.WORD, None),
'hyw_armtdp' : XPOSDescription(XPOSType.WORD, None),
'id_csui' : XPOSDescription(XPOSType.WORD, None),
'it_old' : XPOSDescription(XPOSType.WORD, None),
'ka_glc' : XPOSDescription(XPOSType.WORD, None),
'kk_ktb' : XPOSDescription(XPOSType.WORD, None),
'kmr_mg' : XPOSDescription(XPOSType.WORD, None),
'kpv_lattice' : XPOSDescription(XPOSType.WORD, None),
'ky_ktmu' : XPOSDescription(XPOSType.WORD, None),
'la_proiel' : XPOSDescription(XPOSType.WORD, None),
'lij_glt' : XPOSDescription(XPOSType.WORD, None),
'lt_hse' : XPOSDescription(XPOSType.WORD, None),
'lzh_kyoto' : XPOSDescription(XPOSType.WORD, None),
'mr_ufal' : XPOSDescription(XPOSType.WORD, None),
'mt_mudt' : XPOSDescription(XPOSType.WORD, None),
'myv_jr' : XPOSDescription(XPOSType.WORD, None),
'nb_bokmaal' : XPOSDescription(XPOSType.WORD, None),
'nds_lsdc' : XPOSDescription(XPOSType.WORD, None),
'nn_nynorsk' : XPOSDescription(XPOSType.WORD, None),
'nn_nynorsklia' : XPOSDescription(XPOSType.WORD, None),
'no_bokmaal' : XPOSDescription(XPOSType.WORD, None),
'orv_birchbark' : XPOSDescription(XPOSType.WORD, None),
'orv_rnc' : XPOSDescription(XPOSType.WORD, None),
'orv_torot' : XPOSDescription(XPOSType.WORD, None),
'ota_boun' : XPOSDescription(XPOSType.WORD, None),
'pcm_nsc' : XPOSDescription(XPOSType.WORD, None),
'pt_bosque' : XPOSDescription(XPOSType.WORD, None),
'pt_cintil' : XPOSDescription(XPOSType.WORD, None),
'pt_dantestocks' : XPOSDescription(XPOSType.WORD, None),
'pt_gsd' : XPOSDescription(XPOSType.WORD, None),
'pt_petrogold' : XPOSDescription(XPOSType.WORD, None),
'pt_porttinari' : XPOSDescription(XPOSType.WORD, None),
'qpm_philotis' : XPOSDescription(XPOSType.WORD, None),
'qtd_sagt' : XPOSDescription(XPOSType.WORD, None),
'ru_gsd' : XPOSDescription(XPOSType.WORD, None),
'ru_poetry' : XPOSDescription(XPOSType.WORD, None),
'ru_syntagrus' : XPOSDescription(XPOSType.WORD, None),
'ru_taiga' : XPOSDescription(XPOSType.WORD, None),
'sa_vedic' : XPOSDescription(XPOSType.WORD, None),
'sme_giella' : XPOSDescription(XPOSType.WORD, None),
'swl_sslc' : XPOSDescription(XPOSType.WORD, None),
'sq_staf' : XPOSDescription(XPOSType.WORD, None),
'te_mtg' : XPOSDescription(XPOSType.WORD, None),
'tr_atis' : XPOSDescription(XPOSType.WORD, None),
'tr_boun' : XPOSDescription(XPOSType.WORD, None),
'tr_framenet' : XPOSDescription(XPOSType.WORD, None),
'tr_imst' : XPOSDescription(XPOSType.WORD, None),
'tr_kenet' : XPOSDescription(XPOSType.WORD, None),
'tr_penn' : XPOSDescription(XPOSType.WORD, None),
'tr_tourism' : XPOSDescription(XPOSType.WORD, None),
'ug_udt' : XPOSDescription(XPOSType.WORD, None),
'uk_parlamint' : XPOSDescription(XPOSType.WORD, None),
'vi_vtb' : XPOSDescription(XPOSType.WORD, None),
'wo_wtb' : XPOSDescription(XPOSType.WORD, None),
'xcl_caval' : XPOSDescription(XPOSType.WORD, None),
'zh-hans_gsdsimp': XPOSDescription(XPOSType.WORD, None),
'zh-hant_gsd' : XPOSDescription(XPOSType.WORD, None),
'zh_gsdsimp' : XPOSDescription(XPOSType.WORD, None),
# XPOSType.XPOS with a non-empty separator
'en_lines' : XPOSDescription(XPOSType.XPOS, '-'),
'fo_farpahc' : XPOSDescription(XPOSType.XPOS, '-'),
'ja_gsd' : XPOSDescription(XPOSType.XPOS, '-'),
'ja_gsdluw' : XPOSDescription(XPOSType.XPOS, '-'),
'sv_lines' : XPOSDescription(XPOSType.XPOS, '-'),
'ur_udtb' : XPOSDescription(XPOSType.XPOS, '-'),
'fi_ftb' : XPOSDescription(XPOSType.XPOS, ','),
'orv_ruthenian' : XPOSDescription(XPOSType.XPOS, ','),
'id_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
'ko_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
'ko_kaist' : XPOSDescription(XPOSType.XPOS, '+'),
'ko_ksl' : XPOSDescription(XPOSType.XPOS, '+'),
'qaf_arabizi' : XPOSDescription(XPOSType.XPOS, '+'),
'la_ittb' : XPOSDescription(XPOSType.XPOS, '|'),
'la_llct' : XPOSDescription(XPOSType.XPOS, '|'),
'nl_alpino' : XPOSDescription(XPOSType.XPOS, '|'),
'nl_lassysmall' : XPOSDescription(XPOSType.XPOS, '|'),
'sv_talbanken' : XPOSDescription(XPOSType.XPOS, '|'),
'pl_lfg' : XPOSDescription(XPOSType.XPOS, ':'),
'pl_pdb' : XPOSDescription(XPOSType.XPOS, ':'),
}
def xpos_vocab_factory(data, shorthand):
    """
    Build the xpos vocab for treebank `shorthand` from `data`.

    The vocab type is always re-derived from the data with choose_simplest_factory.
    For a treebank listed in XPOS_DESCRIPTIONS, disagreement with the recorded
    description is logged as an error (not raised) and the freshly derived description
    is used.  An unlisted treebank gets warnings before and after the choice.
    """
    known = shorthand in XPOS_DESCRIPTIONS
    if not known:
        logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)

    desc = choose_simplest_factory(data, shorthand)

    if not known:
        logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
    elif XPOS_DESCRIPTIONS[shorthand] != desc:
        # log instead of throw
        # otherwise, updating datasets would be unpleasant
        logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)

    return build_xpos_vocab(desc, data, shorthand)
================================================
FILE: stanza/models/pos/xpos_vocab_utils.py
================================================
from collections import namedtuple
from enum import Enum
import logging
import os
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.models.pos.vocab import XPOSVocab, WordVocab
class XPOSType(Enum):
    """
    How a treebank's xpos column becomes a vocab.

    XPOS: a composite XPOSVocab built with a separator.
    WORD: a WordVocab treating each whole tag as one label.
    """
    XPOS = 1
    WORD = 2


# (xpos_type, sep): sep only matters for XPOSType.XPOS and is None for WORD entries
XPOSDescription = namedtuple('XPOSDescription', 'xpos_type sep')
# fallback description: a plain word-level vocab
DEFAULT_KEY = XPOSDescription(xpos_type=XPOSType.WORD, sep=None)

logger = logging.getLogger('stanza')
def filter_data(data, idx):
    """
    Keep only the sentences in which every token has a non-None value at position idx.

    Sentences are kept in their original order; an empty sentence is kept.
    """
    def _complete(sentence):
        # read the field of every token first, so indexing errors surface for the whole sentence
        column = [token[idx] for token in sentence]
        return all(value is not None for value in column)

    return [sentence for sentence in data if _complete(sentence)]
def choose_simplest_factory(data, shorthand):
    """
    Pick the smallest xpos vocab description for this data.

    Sentences with a None xpos (field 2) on any token are dropped first.  If the
    word-level vocab has at most 20 entries beyond VOCAB_PREFIX, DEFAULT_KEY is returned.
    Otherwise each separator is tried with an XPOSVocab, and the one giving the smallest
    total number of part units beyond VOCAB_PREFIX wins, if it beats the word-level size.
    """
    logger.info(f'Original length = {len(data)}')
    data = filter_data(data, idx=2)
    logger.info(f'Filtered length = {len(data)}')

    word_vocab = WordVocab(data, shorthand, idx=2, ignore=["_"])
    best_key, best_size = DEFAULT_KEY, len(word_vocab) - len(VOCAB_PREFIX)
    # small tagsets are kept as atomic labels
    if best_size <= 20:
        return best_key

    for sep in ('', '-', '+', '|', ',', ':'): # separators
        composite = XPOSVocab(data, shorthand, idx=2, sep=sep)
        size = sum(len(units) - len(VOCAB_PREFIX) for units in composite._id2unit.values())
        if size < best_size:
            best_key, best_size = XPOSDescription(XPOSType.XPOS, sep), size
    return best_key
def build_xpos_vocab(description, data, shorthand):
    """
    Instantiate the xpos vocab described by `description` (an XPOSDescription).

    XPOSType.WORD gives a WordVocab over field 2 that ignores '_'.  Any other type
    gives an XPOSVocab over field 2 using description.sep.
    """
    if description.xpos_type is not XPOSType.WORD:
        return XPOSVocab(data, shorthand, idx=2, sep=description.sep)
    return WordVocab(data, shorthand, idx=2, ignore=["_"])
================================================
FILE: stanza/models/tagger.py
================================================
"""
Entry point for training and evaluating a POS/morphological features tagger.
This tagger uses highway BiLSTM layers with character and word-level representations, and biaffine classifiers
to produce consistent POS and UFeats predictions.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
"""
import argparse
import logging
import io
import os
import time
import zipfile
import numpy as np
import torch
from torch import nn, optim
from stanza.models.pos.data import Dataset, ShuffledDataset
from stanza.models.pos.trainer import Trainer
from stanza.models.pos import scorer
from stanza.models.common import utils
from stanza.models.common import pretrain
from stanza.models.common.doc import *
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
from stanza.models import _training_logging
from stanza.utils.conll import CoNLL
# module-level logger; uses the same 'stanza' logger name as the pos trainer and scorer modules
logger = logging.getLogger('stanza')
def build_argparse():
    """
    Build the command line parser for training and evaluating the POS tagger.

    parse_args() post-processes the parsed namespace (PEFT args, the
    augment_nopunct default, wandb) and converts it to a dict.

    Order matters for flag pairs sharing a dest (eg --bert_finetune /
    --no_bert_finetune): argparse takes a dest's default from the first action
    registered for it, and store_false actions default to True, so keep the
    positive flag before its --no_ counterpart.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='data/pos', help='Root dir for the POS data files.')
    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.')
    parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--train_file', type=str, default=None, help='Input file for training.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input file for scoring.')
    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
    parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help="Don't score the eval file - perhaps it has no gold labels, for example. Cannot be used at training time")

    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--lang', type=str, help='Language')
    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")

    # model dimensions
    parser.add_argument('--hidden_dim', type=int, default=200)
    parser.add_argument('--char_hidden_dim', type=int, default=400)
    parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400)
    parser.add_argument('--composite_deep_biaff_hidden_dim', type=int, default=100)
    parser.add_argument('--word_emb_dim', type=int, default=75, help='Dimension of the finetuned word embedding. Set to 0 to turn off')
    parser.add_argument('--word_cutoff', type=int, default=None, help='How common a word must be to include it in the finetuned word embedding. If not set, small word vector files will be 0, larger will be %d' % utils.DEFAULT_WORD_CUTOFF)
    parser.add_argument('--char_emb_dim', type=int, default=100)
    parser.add_argument('--tag_emb_dim', type=int, default=50)
    parser.add_argument('--charlm_transform_dim', type=int, default=None, help='Transform the pretrained charlm to this dimension. If not set, no transform is used')
    parser.add_argument('--transformed_dim', type=int, default=125)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--char_num_layers', type=int, default=1)
    parser.add_argument('--pretrain_max_vocab', type=int, default=250000)

    # regularization
    parser.add_argument('--word_dropout', type=float, default=0.33)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rec_dropout', type=float, default=0, help="Recurrent dropout")
    parser.add_argument('--char_rec_dropout', type=float, default=0, help="Recurrent dropout for the character model")

    # TODO: refactor charlm arguments for models which use it?
    parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off character model.")
    parser.add_argument('--char_bidirectional', dest='char_bidirectional', action='store_true', help="Use a bidirectional version of the non-pretrained charlm. Doesn't help much, makes the models larger")
    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")
    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")

    # transformer options; each --no_ flag must stay after its positive counterpart
    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
    parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer")
    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')

    parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.")
    parser.add_argument('--share_hid', action='store_true', help="Share hidden representations for UPOS, XPOS and UFeats.")
    parser.set_defaults(share_hid=False)

    # optimization
    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamw, adamax, or adadelta. madgrad as an optional dependency')
    parser.add_argument('--second_optim', type=str, default='amsgrad', help='Optimizer for the second half of training. Default is Adam with AMSGrad')
    parser.add_argument('--second_optim_reload', default=False, action='store_true', help='Reload the best model instead of continuing from current model if the first optimizer stalls out. This does not seem to help, but might be useful for further experiments')
    parser.add_argument('--no_second_optim', action='store_const', const=None, dest='second_optim', help="Don't use a second optimizer - only use the first optimizer")
    parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')
    parser.add_argument('--second_lr', type=float, default=None, help='Alternate learning rate for the second optimizer')
    parser.add_argument('--initial_weight_decay', type=float, default=None, help='Optimizer weight decay for the first optimizer')
    parser.add_argument('--second_weight_decay', type=float, default=None, help='Optimizer weight decay for the second optimizer')
    parser.add_argument('--beta2', type=float, default=0.95)

    # training schedule, logging and saving
    parser.add_argument('--max_steps', type=int, default=50000)
    parser.add_argument('--eval_interval', type=int, default=100)
    parser.add_argument('--fix_eval_interval', dest='adapt_eval_interval', action='store_false',
                        help="Use fixed evaluation interval for all treebanks, otherwise by default the interval will be increased for larger treebanks.")
    parser.add_argument('--max_steps_before_stop', type=int, default=3000, help='Changes learning method or early terminates after this many steps if the dev scores are not improving')
    parser.add_argument('--batch_size', type=int, default=250)
    parser.add_argument('--batch_maximum_tokens', type=int, default=5000, help='When run in a Pipeline, limit a batch to this many tokens to help avoid OOM for long sentences')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')
    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')
    parser.add_argument('--save_dir', type=str, default='saved_models/pos', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_tagger.pt", help="File name to save the model")
    parser.add_argument('--save_each', default=False, action='store_true', help="Save each checkpoint to its own model. Will take up a bunch of space")

    parser.add_argument('--seed', type=int, default=1234)
    add_peft_args(parser)
    utils.add_device_args(parser)

    parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct. Default of None will aim for roughly 50%%')

    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
    return parser
def parse_args(args=None):
    """
    Parse the tagger command line and return the options as a dict.

    After argparse runs:
      - PEFT arguments are resolved via resolve_peft_args
      - augment_nopunct defaults to 0.25 when not given
      - giving --wandb_name turns on --wandb

    Raises ValueError if tag_emb_dim is 0 while share_hid is off.
    """
    parsed = build_argparse().parse_args(args=args)
    resolve_peft_args(parsed, logger)

    if parsed.augment_nopunct is None:
        parsed.augment_nopunct = 0.25
    if parsed.wandb_name:
        parsed.wandb = True

    # without share_hid, the predicted tags are embedded and fed to the next layer,
    # so a zero-sized tag embedding cannot work
    if parsed.tag_emb_dim == 0 and not parsed.share_hid:
        raise ValueError("Cannot have tag_emb_dim==0 with share_hid==False, as the tags will be embedded for the next layer")

    return vars(parsed)
def main(args=None):
    """
    Entry point: parse `args`, set the random seed, then train or evaluate.

    Returns whatever train() or evaluate() returns.
    """
    args = parse_args(args=args)
    utils.set_random_seed(args['seed'])
    mode = args['mode']
    logger.info("Running tagger in {} mode".format(mode))
    runner = train if mode == 'train' else evaluate
    return runner(args)
def model_file_name(args):
    """Return the tagger model path, as computed by utils.standard_model_file_name for the "tagger" model type."""
    model_kind = "tagger"
    return utils.standard_model_file_name(args, model_kind)
def save_each_file_name(args):
    """
    Return a %-format template for per-checkpoint model files.

    "_%05d" is spliced in before the extension of model_file_name(args),
    eg models.pt -> models_%05d.pt; fill it in with `template % step`.
    """
    root, extension = os.path.splitext(model_file_name(args))
    return f"{root}_%05d{extension}"
def load_pretrain(args):
    """
    Build the pretrained word vector object, or return None when --no_pretrain is set.

    If the pretrain file located by pretrain.find_pretrain_file already exists, no raw
    vector file is passed.  Otherwise the raw vectors come from --wordvec_file, or
    from utils.get_wordvec_file when that is not set.
    """
    if not args['pretrain']:
        return None

    pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
    if os.path.exists(pretrain_file):
        vec_file = None
    elif args['wordvec_file']:
        vec_file = args['wordvec_file']
    else:
        vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    return pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
def get_eval_type(dev_batch):
    """
    Choose which column(s) the scorer should evaluate.

    If there are no feats and exactly one of UPOS / XPOS is present, only that
    column is scored ("UPOS" or "XPOS").  Otherwise "AllTags" is used.
    """
    if not dev_batch.has_feats:
        if dev_batch.has_xpos and not dev_batch.has_upos:
            return "XPOS"
        if dev_batch.has_upos and not dev_batch.has_xpos:
            return "UPOS"
    return "AllTags"
def _read_training_docs(train_file):
    """
    Read one training path and return a list of (description, Document) pairs.

    A zip archive contributes one pair per member, each decoded as utf-8 CoNLL-U
    text and described as "<zip path> <member name>".  Any other path is read as a
    single CoNLL-U file and described by its own path.
    """
    if not zipfile.is_zipfile(train_file):
        logger.info("Reading %s" % train_file)
        # conll2dict gives a list of sentences, each a list of per-word dicts of conll attributes
        sentences, _, _ = CoNLL.conll2dict(input_file=train_file)
        logger.info("Train File {}, Data Size: {}".format(train_file, len(sentences)))
        return [(train_file, Document(sentences))]

    logger.info("Decompressing %s" % train_file)
    found = []
    with zipfile.ZipFile(train_file) as zin:
        for member in zin.namelist():
            with zin.open(member) as fin:
                logger.info("Reading %s from %s" % (member, train_file))
                text = fin.read().decode("utf-8")
                sentences, _, _ = CoNLL.conll2dict(input_str=text)
                logger.info("Train File {} from {}, Data Size: {}".format(member, train_file, len(sentences)))
                found.append(("%s %s" % (train_file, member), Document(sentences)))
    return found


def load_training_data(args, pretrain):
    """
    Read all training files and build the shared vocab, datasets and batches.

    args['train_file'] may hold several paths separated by ';' (zip files are
    expanded member by member).  Each file becomes its own Dataset, all sharing
    one vocab.

    Returns (vocab, train_data, train_batches).
    Raises RuntimeError if no sentences were read, or if a file that has UPOS tags
    contains a blank ('_' or None) UPOS.
    """
    train_files = []
    train_docs = []
    for raw_file in args['train_file'].split(";"):
        for description, doc in _read_training_docs(raw_file):
            train_files.append(description)
            train_docs.append(doc)

    total_sentences = sum(len(doc.sentences) for doc in train_docs)
    if total_sentences == 0:
        raise RuntimeError("Training data for the tagger is empty: %s" % args['train_file'])

    # One Dataset per input file so that a batch never mixes files with and
    # without a given column.  This lets the model see batches where UPOS is
    # available and batches where it is not, so it learns to output _ for
    # empty columns.
    vocab = Dataset.init_vocab(train_docs, args)
    train_data = [Dataset(doc, args, pretrain, vocab=vocab, evaluation=False) for doc in train_docs]

    for description, dataset in zip(train_files, train_data):
        if not dataset.has_upos:
            logger.info("No UPOS in %s" % description)
        if not dataset.has_xpos:
            logger.info("No XPOS in %s" % description)
        if not dataset.has_feats:
            logger.info("No feats in %s" % description)

    # Reject partially tagged UPOS: otherwise the model learns to output blanks
    # for some words, which is confusing (and definitely throws off the depparse).
    # Treating those words as masked out would be another option.
    # NOTE: only the UPOS column is checked; partially tagged XPOS is not rejected here.
    for description, dataset in zip(train_files, train_data):
        if not dataset.has_upos:
            continue
        upos_sentences = dataset.doc.get(UPOS, as_sentences=True)
        for sentence_idx, sentence in enumerate(upos_sentences):
            for word_idx, upos in enumerate(sentence):
                if upos is None or upos == '_':
                    conll = "{:C}".format(dataset.doc.sentences[sentence_idx])
                    raise RuntimeError("Found a blank tag in the UPOS at sentence %d word %d of %s.\n%s" % ((sentence_idx+1), (word_idx+1), train_files[train_files.index(description)] if False else description, conll))

    # If *any* dataset has a column, that is enough data to train on it.
    # If none do, mark every dataset as having it so the model is trained
    # to always output blanks for that column.
    for attribute in ('has_upos', 'has_xpos', 'has_feats'):
        if not any(getattr(dataset, attribute) for dataset in train_data):
            for dataset in train_data:
                setattr(dataset, attribute, True)

    train_batches = ShuffledDataset(train_data, args["batch_size"])
    return vocab, train_data, train_batches
def train(args):
    """
    Train a POS / morphological feature tagger.

    Training data comes from args['train_file'] (see load_training_data).  The dev
    set args['eval_file'] is scored every args['eval_interval'] steps.  The best
    model by dev score is saved to model_file_name(args); with --save_each every
    evaluated checkpoint is also saved to its own file.

    When the dev score has not improved for args['max_steps_before_stop'] steps,
    training switches once to args['second_optim'] (if set) and otherwise stops
    early.  Training always stops after args['max_steps'] steps.

    Returns (trainer, None), or (None, None) if there is no training or dev data.
    The second element is a placeholder so the result unpacks like the
    (trainer, doc) pair returned by evaluate().
    """
    model_file = model_file_name(args)
    utils.ensure_dir(os.path.split(model_file)[0])

    if args['save_each']:
        # so models.pt -> models_00001.pt, etc
        model_save_each_file = save_each_file_name(args)
        logger.info("Saving each checkpoint to %s" % model_save_each_file)

    # load pretrained vectors if needed
    pretrain = load_pretrain(args)
    args['word_cutoff'] = utils.update_word_cutoff(pretrain, args['word_cutoff'])

    if args['charlm']:
        if args['charlm_shorthand'] is None:
            raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
        logger.info('Using pretrained contextualized char embedding')
        if not args['charlm_forward_file']:
            args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
        if not args['charlm_backward_file']:
            args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])

    # load data
    logger.info("Loading data with batch size {}...".format(args['batch_size']))
    vocab, train_data, train_batches = load_training_data(args, pretrain)

    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
    dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
    dev_batch = dev_data.to_loader(batch_size=args["batch_size"])
    eval_type = get_eval_type(dev_data)

    # skip training if the language does not have training or dev data
    # sum(...) to check if all of the training files are empty
    if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0:
        logger.info("Skip training because no data available...")
        return None, None

    if args['wandb']:
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tagger" % args['shorthand']
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('train_loss', summary='min')
        wandb.run.define_metric('dev_score', summary='max')

    logger.info("Training tagger...")
    foundation_cache = FoundationCache()
    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'], foundation_cache=foundation_cache)

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []
    # the lr reported in the step log; updated when switching optimizers
    current_lr = args['lr']
    format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
    logger.debug("Training model on device %s", next(trainer.model.parameters()).device)

    if args['adapt_eval_interval']:
        args['eval_interval'] = utils.get_adaptive_eval_interval(dev_data.num_examples, 2000, args['eval_interval'])
        logger.info("Evaluating the model every {} steps...".format(args['eval_interval']))

    if args['save_each']:
        logger.info("Saving initial checkpoint to %s" % (model_save_each_file % global_step))
        trainer.save(model_save_each_file % global_step)

    using_amsgrad = False
    last_best_step = 0
    # start training
    train_loss = 0
    if args['log_norms']:
        trainer.model.log_norms()
    while True:
        do_break = False
        for batch in train_batches:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))
                if args['log_norms']:
                    trainer.model.log_norms()

            if global_step % args['eval_interval'] == 0:
                # eval on dev
                logger.info("Evaluating on dev set...")
                dev_preds = []
                indices = []
                with torch.no_grad():
                    for dev_b in dev_batch:
                        dev_preds += trainer.predict(dev_b)
                        indices.extend(dev_b[-1])
                # the dev set is sorted for batching; restore the original sentence order
                dev_preds = utils.unsort(dev_preds, indices)
                dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x])

                system_pred_file = io.StringIO("{:C}\n\n".format(dev_data.doc))
                _, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)

                train_loss = train_loss / args['eval_interval'] # avg loss per batch
                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
                if args['wandb']:
                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
                train_loss = 0

                if args['save_each']:
                    logger.info("Saving checkpoint to %s" % (model_save_each_file % global_step))
                    trainer.save(model_save_each_file % global_step)

                # save best model
                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
                    last_best_step = global_step
                    trainer.save(model_file)
                    logger.info("new best model saved.")

                dev_score_history += [dev_score]

            if global_step - last_best_step >= args['max_steps_before_stop']:
                if not using_amsgrad and args['second_optim'] is not None:
                    logger.info("Switching to second optimizer: {}".format(args['second_optim']))
                    if args['second_optim_reload']:
                        logger.info('Reloading best model to continue from current local optimum')
                        trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain, model_file=model_file, device=args['device'], foundation_cache=foundation_cache)
                    # give the second optimizer a full patience window
                    last_best_step = global_step
                    using_amsgrad = True
                    lr = args['second_lr']
                    if lr is None:
                        lr = args['lr']
                    trainer.optimizer = utils.get_optimizer(args['second_optim'], trainer.model, lr=lr, betas=(.9, args['beta2']), eps=1e-6, weight_decay=args['second_weight_decay'])
                    current_lr = lr
                else:
                    logger.info("Early termination: have not improved in {} steps".format(args['max_steps_before_stop']))
                    do_break = True
                    break

            if global_step >= args['max_steps']:
                do_break = True
                break

        if do_break:
            break

    logger.info("Training ended with {} steps.".format(global_step))

    if args['wandb']:
        wandb.finish()

    if len(dev_score_history) > 0:
        best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
        logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
    else:
        logger.info("Dev set never evaluated. Saving final model.")
        trainer.save(model_file)

    # previously returned `_`, which is unbound when the dev set was never
    # evaluated (UnboundLocalError) and otherwise was a leftover scorer value
    return trainer, None
def evaluate(args):
    """
    Load the saved tagger named by `args` and run it on args['eval_file'].

    The charlm forward/backward paths from `args` (None when absent) are passed to
    the Trainer as its load-time args.

    Returns (trainer, doc), where doc holds the predicted tags.
    """
    model_file = model_file_name(args)
    pretrain = load_pretrain(args)
    load_args = {
        key: args.get(key, None)
        for key in ('charlm_forward_file', 'charlm_backward_file')
    }

    logger.info("Loading model from: {}".format(model_file))
    trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
    return trainer, evaluate_trainer(args, trainer, pretrain)
def evaluate_trainer(args, trainer, pretrain):
    """
    Tag args['eval_file'] with an already loaded trainer.

    Before building the dataset, these keys are copied from `args` into trainer.args
    (in place): keys ending in '_dir' or '_file', plus 'shorthand' and 'mode'.

    Predictions are written to args['output_file'] if it is set.  They are scored
    against the eval file unless --no_gold_labels was given.

    Returns the Document with the predicted UPOS / XPOS / FEATS.
    """
    output_file = args['output_file']
    loaded_args, vocab = trainer.args, trainer.vocab

    # load config: paths and a few settings from the command line override the saved ones
    for k in args:
        if k.endswith('_dir') or k.endswith('_file') or k in ('shorthand', 'mode'):
            loaded_args[k] = args[k]

    # load data
    logger.info("Loading data with batch size {}...".format(args['batch_size']))
    doc = CoNLL.conll2doc(input_file=args['eval_file'])
    dev_data = Dataset(doc, loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
    dev_batch = dev_data.to_loader(batch_size=args['batch_size'])
    eval_type = get_eval_type(dev_data)

    preds = []
    if len(dev_batch) > 0:
        logger.info("Start evaluation...")
        indices = []
        with torch.no_grad():
            for b in dev_batch:
                preds += trainer.predict(b)
                indices.extend(b[-1])
        # the dataset was sorted for batching; restore the original sentence order
        preds = utils.unsort(preds, indices)
    # otherwise skip eval: the dev data is empty, and previously `indices` was
    # unbound here, which crashed the unsort call

    # write to file and score
    dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x])
    if output_file:
        CoNLL.write_doc2conll(dev_data.doc, output_file)

    if args['gold_labels']:
        pred_conll = io.StringIO("{:C}\n\n".format(dev_data.doc))
        _, _, score = scorer.score(pred_conll, args['eval_file'], eval_type=eval_type)
        logger.info("POS Tagger score: %s %.2f", args['shorthand'], score*100)

    return dev_data.doc
if __name__ == '__main__':
    # Script entry point: main() with no args lets argparse read sys.argv
    main()
================================================
FILE: stanza/models/tokenization/__init__.py
================================================
================================================
FILE: stanza/models/tokenization/data.py
================================================
from bisect import bisect_right
from collections import defaultdict
from copy import copy
import numpy as np
import random
import logging
import re
import torch
from torch.utils.data import Dataset
from stanza.models.common.utils import sort_with_indices, unsort
from stanza.models.tokenization.vocab import Vocab
# Shared 'stanza' logger used for tokenizer data loading messages
logger = logging.getLogger('stanza')
def filter_consecutive_whitespaces(para):
    """
    Collapse runs of ' ' units in a paragraph of (char, label) pairs.

    Only the first space of each run is kept, together with its label; the other
    spaces (and their labels) are dropped.  Returns a new list of (char, label)
    tuples and leaves `para` untouched.
    """
    kept = []
    previous = None
    for char, label in para:
        if not (char == ' ' and previous == ' '):
            kept.append((char, label))
        previous = char
    return kept
# A newline, optional whitespace, then another newline: used to split raw text
# (and label files) into paragraphs
NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
# Used for the 'numeric' feature: digit groups joined by runs of ',' or '.',
# optionally followed by trailing ',' / '.'
# this was (r'^([\d]+[,\.]*)+$')
# but the runtime on that can explode exponentially
# for example, on 111111111111111111111111a
NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
# Any single whitespace character; each one is normalized to ' ' when loading text
WHITESPACE_RE = re.compile(r'\s')
class TokenizationDataset:
    """
    Base dataset for the tokenizer.

    Reads raw text (from input_files['txt'] or input_text) and, optionally,
    per-character integer labels (from input_files['label']).  Stores them in
    self.data as a list of paragraphs, where paragraphs are separated by blank
    lines in the input.

    Each paragraph is a list of (unit, label) pairs where unit is one character.
    Every whitespace character is normalized to ' ' and runs of spaces are collapsed.
    """

    def __init__(self, tokenizer_args, input_files=None, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
        """
        tokenizer_args: dict of tokenizer options (eg 'skip_newline', 'feat_funcs',
            'max_seqlen', 'use_dictionary', 'num_dict_feat')
        input_files: dict with 'txt' and 'label' paths, either may be None.
            Defaults to {'txt': None, 'label': None}
        input_text: raw text used instead of reading input_files['txt']
        vocab: unit vocab, used via unit2id
        evaluation: if True, paragraphs are not split into sentences by label and
            no sentence is dropped for length
        dictionary: optional mapping with "words", "prefixes" and "suffixes"
            collections, used for dictionary features
        """
        super().__init__(*args, **kwargs) # forwards all unused arguments
        self.args = tokenizer_args
        self.eval = evaluation
        self.dictionary = dictionary
        self.vocab = vocab

        # None instead of a dict literal default, so no mutable default is shared between calls
        if input_files is None:
            input_files = {'txt': None, 'label': None}
        txt_file = input_files['txt']
        label_file = input_files['label']

        # set up text from file or input string
        assert txt_file is not None or input_text is not None
        if input_text is None:
            with open(txt_file, encoding="utf-8") as f:
                text = f.read().rstrip()
        else:
            text = input_text

        # split into paragraphs on blank lines, dropping empty ones
        text_chunks = [pt.rstrip() for pt in NEWLINE_WHITESPACE_RE.split(text)]
        text_chunks = [pt for pt in text_chunks if pt]

        if label_file is not None:
            with open(label_file, encoding="utf-8") as f:
                labels = f.read().rstrip()
            labels = [pt.rstrip() for pt in NEWLINE_WHITESPACE_RE.split(labels)]
            # one digit per character
            labels = [map(int, pt) for pt in labels if pt]
        else:
            labels = [[0] * len(pt) for pt in text_chunks]

        skip_newline = self.args.get('skip_newline', False)
        self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces
                      for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten
                     for pt, pc in zip(text_chunks, labels)]

        # remove consecutive whitespaces
        self.data = [filter_consecutive_whitespaces(x) for x in self.data]

    def labels(self):
        """
        Returns a list of the labels for all of the sentences in this DataLoader

        Used at eval time to compare to the results, for example.  Each entry is a
        numpy array holding the label of every unit of one paragraph in self.data.
        """
        return [np.array([unit[1] for unit in sent]) for sent in self.data]

    def extract_dict_feat(self, para, idx):
        """
        Extract dictionary features for the character at position idx of para.

        For each window size w in 1..num_dict_feat:
          - the forward feature is 1 if the string from idx through idx+w is in
            self.dictionary["words"];
          - the backward feature is 1 if the string from idx-w through idx is.
        A direction stops extending once its string is not in
        self.dictionary["prefixes"] (forward) or self.dictionary["suffixes"]
        (backward), and the loop exits when both have stopped.

        Returns a list of 2 * num_dict_feat ints: forward features, then backward.
        """
        num_feats = self.args['num_dict_feat']
        length = len(para)
        dict_forward_feats = [0] * num_feats
        dict_backward_feats = [0] * num_feats
        # NOTE(review): neighbouring characters are lowercased but the character at
        # idx is not, so a capitalized start/end never matches a lowercase dictionary.
        # Left as-is because changing it would change the features trained models saw.
        forward_word = para[idx][0]
        backward_word = para[idx][0]
        prefix = True
        suffix = True
        for window in range(1, num_feats + 1):
            # forward: extend while in bounds and still a known prefix
            if idx + window <= length - 1 and prefix:
                forward_word += para[idx + window][0].lower()
                dict_forward_feats[window - 1] = 1 if forward_word in self.dictionary["words"] else 0
                if forward_word not in self.dictionary["prefixes"]:
                    prefix = False
            # backward: similar to forward, extending to the left
            if idx - window >= 0 and suffix:
                backward_word = para[idx - window][0].lower() + backward_word
                dict_backward_feats[window - 1] = 1 if backward_word in self.dictionary["words"] else 0
                if backward_word not in self.dictionary["suffixes"]:
                    suffix = False
            # if cannot find both prefix and suffix, then exit the loop
            if not prefix and not suffix:
                break
        return dict_forward_feats + dict_backward_feats

    def para_to_sentences(self, para):
        """
        Convert a paragraph to a list of processed sentences.

        Each (unit, label) gets a feature vector built from args['feat_funcs'],
        plus dictionary features when args['use_dictionary'] is set.

        When not evaluating, the paragraph is cut after every unit labelled 2 or 4
        (end of sentence), and pieces longer than args['max_seqlen'] are dropped.
        When evaluating, the whole paragraph becomes a single sentence.

        Each sentence is a tuple (unit ids, labels, features, raw units).
        Raises ValueError for an unknown feature function name.
        """
        funcs = []
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # skip for position-dependent features
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))
            funcs.append(func)

        # stacking all featurize functions
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        last_index = len(para) - 1

        res = []
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)
            # position-dependent features
            if use_end_of_para:
                feats.append(1 if i == last_index else 0)
            if use_start_of_para:
                feats.append(1 if i == 0 else 0)
            # if dictionary feature is selected
            if use_dictionary:
                feats = feats + self.extract_dict_feat(para, i)

            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)
            if not self.eval and (label == 2 or label == 4): # end of sentence
                if len(current_units) <= self.args['max_seqlen']:
                    # get rid of sentences that are too long during training of the tokenizer
                    res.append(process_sentence(current_units, current_labels, current_feats))
                current_units.clear()
                current_labels.clear()
                current_feats.clear()

        if len(current_units) > 0:
            if self.eval or len(current_units) <= self.args['max_seqlen']:
                res.append(process_sentence(current_units, current_labels, current_feats))

        return res

    def advance_old_batch(self, eval_offsets, old_batch):
        """
        Advance to a new position in a batch where we have partially processed the batch

        If we have previously built a batch of data and made predictions on them, then when we are trying to make
        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
        In this case, eval_offsets index within the old_batch to advance the strings to process.

        eval_offsets is modified in place: each offset is clamped to its row's length.
        Returns (units, labels, features, raw_units) for the rest of each row, padded
        to the longest remainder (pad id / label -1 / zero features / '<PAD>').
        """
        # Padding positions are found by comparing against the pad id.  The pad
        # token is '<PAD>'; looking up '' here made the pad id fall back to the
        # unk id, so unknown units were miscounted as padding.
        padid = self.vocab.unit2id('<PAD>')
        ounits, olabels, ofeatures, oraw = old_batch
        feat_size = ofeatures.shape[-1]
        lens = (ounits != padid).sum(1).tolist()

        # clamp first, so an offset past the end of its row cannot make pad_len negative
        for i, length in enumerate(lens):
            eval_offsets[i] = min(eval_offsets[i], length)
        pad_len = max(l - i for i, l in zip(eval_offsets, lens))

        units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
        raw_units = []
        for i in range(len(ounits)):
            start, end = eval_offsets[i], lens[i]
            remaining = end - start
            units[i, :remaining] = ounits[i, start:end]
            labels[i, :remaining] = olabels[i, start:end]
            features[i, :remaining] = ofeatures[i, start:end]
            raw_units.append(oraw[i][start:end] + ['<PAD>'] * (pad_len - remaining))

        return units, labels, features, raw_units
def build_move_punct_set(data, move_back_prob):
    """
    Find which punctuation marks are eligible for the punct / whitespace augmentation.

    Starts from a fixed candidate set and removes every candidate that, somewhere in
    `data`, appears in a context the checks below consider unsafe.  Only interior
    positions of each chunk are examined.

    data: list of chunks, each a list of (char, label) pairs
    move_back_prob: not used by this function -- NOTE(review): presumably kept for
        the caller's convenience; confirm before removing
    Returns the remaining set of punctuation characters.

    Note that membership checks use the set as already pruned, so the result can
    depend on the order of `data`.
    """
    move_punct = {',', ':', '!', '.', '?', '"', '(', ')'}
    for chunk in data:
        # ignore positions at the start and end of a chunk
        for idx in range(1, len(chunk)-1):
            if chunk[idx][0] not in move_punct:
                continue
            # label 0 -- presumably: the punct does not end its token
            if chunk[idx][1] == 0:
                if chunk[idx+1][0].isspace() and not chunk[idx-1][0].isdigit():
                    # this check removes punct which isn't ending a word...
                    # honestly that's a rather unusual situation
                    # VI has |3, 5| as a complete token
                    # so we also eliminate isdigit()
                    move_punct.remove(chunk[idx][0])
                continue
            # we skip isdigit() because we will intentionally not
            # create things that look like decimal numbers
            if not chunk[idx-1][0].isspace() and chunk[idx-1][0] not in move_punct and not chunk[idx-1][0].isdigit():
                # this check eliminates things like '.' after 'Mr.'
                move_punct.remove(chunk[idx][0])
                continue
    return move_punct
def build_known_mwt(data, mwt_expansions):
    """
    Collect MWT surface strings from the labelled training data.

    A unit labelled 3 ends an MWT.  Its text starts just after the previous unit
    with a non-zero label, with leading whitespace skipped.  An MWT is kept only if:
      - it spans more than the label-3 unit itself,
      - it appears in `mwt_expansions`,
      - its expansion has at most two entries.

    Returns a set of MWT strings.
    """
    found = set()
    for chunk in data:
        for end, unit in enumerate(chunk):
            if unit[1] == 3:
                # walk back over the label-0 units of this token
                start = end - 1
                while start >= 0 and chunk[start][1] == 0:
                    start -= 1
                start += 1
                while chunk[start][0].isspace():
                    start += 1
                if start != end:
                    surface = "".join(piece[0] for piece in chunk[start:end + 1])
                    # TODO: could split 3 word tokens as well
                    if surface in mwt_expansions and len(mwt_expansions[surface]) <= 2:
                        found.add(surface)
    return found
class DataLoader(TokenizationDataset):
"""
This is the training version of the dataset.
"""
def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None):
super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)
self.vocab = vocab if vocab is not None else self.init_vocab()
# data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.
# At evaluation time, each paragraph is treated as single "sentence" as we don't know a priori where
# sentence breaks occur. We make prediction from left to right for each paragraph and move forward to
# the last predicted sentence break to start afresh.
self.sentences = [self.para_to_sentences(para) for para in self.data]
self.init_sent_ids()
logger.debug(f"{len(self.sentence_ids)} sentences loaded.")
punct_move_back_prob = args.get('punct_move_back_prob', 0.0)
if punct_move_back_prob > 0.0:
self.move_punct = build_move_punct_set(self.data, punct_move_back_prob)
if len(self.move_punct) > 0:
logger.debug('Based on the training data, will augment space/punct combinations {}'.format(self.move_punct))
else:
logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace')
split_mwt_prob = args.get('split_mwt_prob', 0.0)
if split_mwt_prob > 0.0 and not evaluation:
self.mwt_expansions = mwt_expansions
self.known_mwt = build_known_mwt(self.data, mwt_expansions)
if len(self.known_mwt) > 0:
logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt))
else:
logger.debug('Based on the training data, there are NO MWT to split at training time')
augment_final_punct_prob = 0.0 if evaluation else args.get('augment_final_punct_prob', 0.0)
if augment_final_punct_prob > 0:
self.augmentations = defaultdict(list)
AUGMENT_PAIRS = [("?", "?"),
("?", "︖"),
("?", "﹖"),
("?", "⁇"),
("!", "!"),
("!", "︕"),
("!", "﹗"),
("!", "‼"),]
for orig, target in AUGMENT_PAIRS:
if self.augment_vocab(self.vocab, self.data, orig, target):
logger.debug('Based on the training data, augmenting |%s| to |%s|' % (orig, target))
self.augmentations[orig].append(target)
if self.augment_vocab(self.vocab, self.data, target, orig):
logger.debug('Based on the training data, augmenting |%s| to |%s|' % (target, orig))
self.augmentations[target].append(orig)
def __len__(self):
    """Number of (paragraph index, sentence index) pairs available in this dataset."""
    sentence_ids = self.sentence_ids
    return len(sentence_ids)
def init_vocab(self):
    """Build and return a tokenization Vocab over this dataset's paragraphs for self.args['lang']."""
    return Vocab(self.data, self.args['lang'])
@staticmethod
def augment_vocab(vocab, data, existing_unit, new_unit):
    """
    Decide whether sentence-final `existing_unit` may be augmented into `new_unit`.

    Only the last unit of each entry in `data` is inspected.  Augmentation is
    allowed (returns True) only if `existing_unit` is in the vocab, at least one
    entry ends with it, and no entry already ends with `new_unit`.  When allowed,
    `new_unit` is appended to the vocab if it is not there yet.
    """
    if existing_unit not in vocab:
        return False

    final_units = [sentence[-1][0] for sentence in data]
    # a final unit equal to new_unit is never also counted as existing_unit
    new_unit_count = sum(1 for unit in final_units if unit == new_unit)
    existing_unit_count = sum(1 for unit in final_units if unit != new_unit and unit == existing_unit)

    if existing_unit_count == 0 or new_unit_count > 0:
        return False

    if new_unit not in vocab:
        vocab.append(new_unit)
    logger.debug("Found %d |%s| and %d |%s|", new_unit_count, new_unit, existing_unit_count, existing_unit)
    return True
def init_sent_ids(self):
    """
    Rebuild the flat sentence index over self.sentences.

    self.sentence_ids lists every (paragraph index, sentence index) pair in
    order, and self.cumlen[k] is the total length (of element 0) of the first
    k sentences, so self.cumlen has one more entry than self.sentence_ids.
    """
    self.sentence_ids = [(para_idx, sent_idx)
                         for para_idx, paragraph in enumerate(self.sentences)
                         for sent_idx in range(len(paragraph))]
    self.cumlen = [0]
    for para_idx, sent_idx in self.sentence_ids:
        sentence_len = len(self.sentences[para_idx][sent_idx][0])
        self.cumlen.append(self.cumlen[-1] + sentence_len)
def has_mwt(self):
    """
    Return True if any unit in self.data carries a label greater than 2
    (the MWT labels), otherwise False.
    """
    # presumably called at most once (at training time), so the answer is not cached
    return any(word[1] > 2 for sentence in self.data for word in sentence)
def shuffle(self):
    """Shuffle the sentence order inside every paragraph in place, then rebuild the sentence index."""
    for paragraph_sentences in self.sentences:
        random.shuffle(paragraph_sentences)
    # sentence_ids / cumlen depend on the order, so they must be recomputed
    self.init_sent_ids()
def move_last_char(self, sentence):
    """
    Augmentation: insert a space before the final character of a sentence.

    Applies only when the sentence has more than one unit, is shorter than
    max_seqlen, its last label is 2 and the label before it is nonzero.
    Returns the re-encoded sentence from para_to_sentences, or None when the
    sentence is not eligible.
    """
    raw_units, labels = sentence[3], sentence[1]
    if not 1 < len(raw_units) < self.args['max_seqlen']:
        return None
    if labels[-1] != 2 or labels[-2] == 0:
        return None

    new_units = [(unit, int(label)) for unit, label in zip(raw_units[:-1], labels[:-1])]
    new_units += [(' ', 0), (raw_units[-1], int(labels[-1]))]
    return self.para_to_sentences(new_units)
def split_mwt(self, sentence):
    """
    Augmentation: replace one randomly chosen MWT with its two expanded words.

    A token ending with label 3 is treated as an MWT.  If its text is a key of
    self.mwt_expansions, it is replaced by expansion[0] + ' ' + expansion[1],
    each ending with label 1.  This teaches the tokenizer not to treat the whole
    character sequence with added spaces as an MWT, which can otherwise happen
    in some corner cases.

    Returns the re-encoded sentence from para_to_sentences, or None if no
    eligible MWT was found.
    """
    raw_units, labels = sentence[3], sentence[1]
    if not 1 < len(raw_units) < self.args['max_seqlen']:
        return None

    mwt_ends = [idx for idx, label in enumerate(labels) if label == 3]
    if not mwt_ends:
        return None
    mwt_end = mwt_ends[random.randint(0, len(mwt_ends) - 1)]

    # walk back over label-0 characters to find where this token starts
    mwt_start = mwt_end
    while mwt_start > 0 and labels[mwt_start - 1] == 0:
        mwt_start -= 1
    # whitespace between the previous token and this one is not part of the MWT
    while raw_units[mwt_start].isspace():
        mwt_start += 1
    if mwt_start == mwt_end:
        return None

    mwt = "".join(raw_units[mwt_start:mwt_end + 1])
    if mwt not in self.mwt_expansions:
        return None

    all_units = [(unit, int(label)) for unit, label in zip(raw_units, labels)]
    expansion = self.mwt_expansions[mwt]

    def word_units(word):
        # every character continues the word except the last, which ends it (label 1)
        units = [(char, 0) for char in word]
        units[-1] = (units[-1][0], 1)
        return units

    split_units = word_units(expansion[0]) + [(' ', 0)] + word_units(expansion[1])
    return self.para_to_sentences(all_units[:mwt_start] + split_units + all_units[mwt_end + 1:])
def move_punct_back(self, sentence):
    """
    Augmentation: remove the whitespace in front of punctuation from self.move_punct.

    A character is eligible when it is in self.move_punct and preceded by
    whitespace.  It is skipped when the character before that whitespace is a
    digit, so that decimal numbers are not accidentally created.  Tokens made of
    all punctuation (such as '...') are also moved, which is why labels are not
    checked here.

    Returns the re-encoded sentence from para_to_sentences, or None when
    nothing is eligible.
    """
    raw_units, labels = sentence[3], sentence[1]
    if len(raw_units) <= 1 or len(raw_units) >= self.args['max_seqlen']:
        return None

    spaces_to_drop = set()
    for idx in range(1, len(raw_units)):
        if raw_units[idx] not in self.move_punct or not raw_units[idx - 1].isspace():
            continue
        if idx > 1 and raw_units[idx - 2].isdigit():
            continue
        spaces_to_drop.add(idx - 1)
    if not spaces_to_drop:
        return None

    all_units = [(unit, int(label)) for unit, label in zip(raw_units, labels)]
    new_units = [unit for position, unit in enumerate(all_units) if position not in spaces_to_drop]
    return self.para_to_sentences(new_units)
def augment_final_punct(self, sentence):
    """
    Augmentation: swap the sentence-final unit for one of its registered alternatives.

    self.augmentations maps a final unit (e.g. '?') to the list of units it may be
    replaced with.  Applies only to sentences with more than one unit that are
    shorter than max_seqlen and whose last unit has an entry.

    Returns the re-encoded sentence from para_to_sentences, or None.
    """
    raw_units, labels = sentence[3], sentence[1]
    if not 1 < len(raw_units) < self.args['max_seqlen']:
        return None
    if raw_units[-1] not in self.augmentations:
        return None

    augmented = random.choice(self.augmentations[raw_units[-1]])
    new_units = [(unit, int(label)) for unit, label in zip(raw_units[:-1], labels[:-1])]
    # int() like every other label here, consistent with the other augmentations
    new_units.append((augmented, int(labels[-1])))
    return self.para_to_sentences(new_units)
def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
    '''
    Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction.

    At evaluation time (eval_offsets given), row k of the batch starts at global
    character offset eval_offsets[k] (resolved through self.cumlen) and
    concatenates the following sentences of the same paragraph until max_seqlen
    is reached; all rows are padded to the longest row + 1.
    At training time, min(batch_size, #sentences) random sentences are chosen and
    each is extended with more random sentences (optionally augmented) up to
    max_seqlen.

    :param eval_offsets: list of global character offsets, or None for training
    :param unit_dropout: training-only probability of replacing a non-pad unit with <UNK>
    :param feat_unit_dropout: training-only probability of zeroing the whole
                              dictionary feature vector of a non-pad unit
    :return (units, labels, features, raw_units): int64 tensor of unit ids
            (batch, pad_len), int64 tensor of labels with -1 at padding,
            float32 tensor (batch, pad_len, feat_size), and a list of raw unit
            lists padded with '<PAD>'
    '''
    feat_size = len(self.sentences[0][0][2][0])
    # The vocab's special entries.  They must be distinct: padding positions are
    # excluded from unit dropout below, and dropped units are replaced by <UNK>.
    unkid = self.vocab.unit2id('<UNK>')
    padid = self.vocab.unit2id('<PAD>')

    def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
        # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed
        # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences
        # from the entire dataset until we reach max_seqlen.
        drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
        drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
        move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0)
        move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0)
        split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0)
        augment_final_punct_prob = 0.0 if self.eval else self.args.get('augment_final_punct_prob', 0.0)

        pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
        sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]
        total_len = len(sentences[0][0])

        # The offending sentence is shown as raw_unit/label pairs (element 3 holds
        # the raw units, element 1 the labels).  Zipping the sentence tuple itself
        # produced 1-tuples, which made building this message raise IndexError.
        assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}'.format(self.args['max_seqlen'], total_len, ' '.join(["{}/{}".format(*x) for x in zip(self.sentences[pid][sid][3], self.sentences[pid][sid][1])]))
        if self.eval:
            for sid1 in range(sid+1, len(self.sentences[pid])):
                total_len += len(self.sentences[pid][sid1][0])
                sentences.append(self.sentences[pid][sid1])

                if total_len >= self.args['max_seqlen']:
                    break
        else:
            while True:
                pid1, sid1 = random.choice(self.sentence_ids)
                total_len += len(self.sentences[pid1][sid1][0])
                sentences.append(self.sentences[pid1][sid1])

                if total_len >= self.args['max_seqlen']:
                    break

        if move_last_char_prob > 0.0:
            for sentence_idx, sentence in enumerate(sentences):
                if random.random() < move_last_char_prob:
                    # the sentence might not be eligible, such as
                    # already having a space or not having a sentence final punct,
                    # so we need to do a two step checking process here
                    new_sentence = self.move_last_char(sentence)
                    if new_sentence is not None:
                        sentences[sentence_idx] = new_sentence[0]
                        total_len += 1

        if move_punct_back_prob > 0.0:
            for sentence_idx, sentence in enumerate(sentences):
                if random.random() < move_punct_back_prob:
                    # the sentence might not be eligible, such as
                    # not having a space separated punct,
                    # so we need to do a two step checking process here
                    new_sentence = self.move_punct_back(sentence)
                    if new_sentence is not None:
                        total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                        sentences[sentence_idx] = new_sentence[0]

        if split_mwt_prob > 0.0:
            for sentence_idx, sentence in enumerate(sentences):
                if random.random() < split_mwt_prob:
                    new_sentence = self.split_mwt(sentence)
                    if new_sentence is not None:
                        total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                        sentences[sentence_idx] = new_sentence[0]

        if augment_final_punct_prob > 0.0:
            for sentence_idx, sentence in enumerate(sentences):
                # gate on this augmentation's own probability
                # (previously split_mwt_prob was checked here by mistake)
                if random.random() < augment_final_punct_prob:
                    new_sentence = self.augment_final_punct(sentence)
                    if new_sentence is not None:
                        sentences[sentence_idx] = new_sentence[0]

        if drop_sents and len(sentences) > 1:
            if total_len > self.args['max_seqlen']:
                sentences = sentences[:-1]
            if len(sentences) > 1:
                p = [.5 ** i for i in range(1, len(sentences) + 1)] # drop a large number of sentences with smaller probability
                cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                sentences = sentences[:cutoff+1]

        units = np.concatenate([s[0] for s in sentences])
        labels = np.concatenate([s[1] for s in sentences])
        feats = np.concatenate([s[2] for s in sentences])
        raw_units = [x for s in sentences for x in s[3]]

        if not self.eval:
            cutoff = self.args['max_seqlen']
            units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

        if drop_last_char: # can only happen in non-eval mode
            if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                # training text ended with a sentence end position
                # and that word was a single character
                # and the previous character ended the word
                units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]
                # word end -> sentence end, mwt end -> sentence mwt end
                labels[-1] = labels[-1] + 1

        return units, labels, feats, raw_units

    if eval_offsets is not None:
        # find max padding length
        pad_len = 0
        for eval_offset in eval_offsets:
            if eval_offset < self.cumlen[-1]:
                pair_id = bisect_right(self.cumlen, eval_offset) - 1
                pair = self.sentence_ids[pair_id]
                pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))

        # +1 so that every row ends with at least one pad
        pad_len += 1
        id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
        pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
        offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]

        offsets_pairs = list(zip(offsets, pairs))
    else:
        id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
        offsets_pairs = [(0, x) for x in id_pairs]
        pad_len = self.args['max_seqlen']

    # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors
    units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
    labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
    features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
    raw_units = []
    for i, (offset, pair) in enumerate(offsets_pairs):
        u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
        units[i, :len(u_)] = u_
        labels[i, :len(l_)] = l_
        features[i, :len(f_), :] = f_
        raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

    if unit_dropout > 0 and not self.eval:
        # dropout characters/units at training time and replace them with UNKs
        mask = np.random.random_sample(units.shape) < unit_dropout
        mask[units == padid] = 0
        units[mask] = unkid
        # every raw_units row has exactly pad_len entries, matching mask's shape
        for i, j in zip(*np.nonzero(mask)):
            raw_units[i][j] = '<UNK>'

    # dropout unit feature vector in addition to only torch.dropout in the model.
    # experiments showed that only torch.dropout hurts the model
    # we believe it is because the dict feature vector is mostly sparse so it makes
    # more sense to drop out the whole vector instead of only single element.
    if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
        mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
        mask_feat[units == padid] = 0
        # a (batch, pad_len) boolean mask zeroes the whole feature vector of each selected unit
        features[mask_feat] = 0

    units = torch.from_numpy(units)
    labels = torch.from_numpy(labels)
    features = torch.from_numpy(features)

    return units, labels, features, raw_units
class SortedDataset(Dataset):
    """
    Holds a TokenizationDataset for use in a torch DataLoader

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism.  Updating output_predictions
    to use this class as a wrapper to a TokenizationDataset means the
    calculation of features can happen in parallel, saving quite a
    bit of time.

    Paragraphs are sorted by length (longest first); use unsort() to
    restore the original order of per-paragraph results.
    """
    def __init__(self, dataset):
        super().__init__()

        self.dataset = dataset
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # This will return a single sample
        #   np: index in character map
        #   np: tokenization label
        #   np: features
        #   list: original text as one length strings
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        # module-level unsort helper, not this method
        return unsort(arr, self.indices)

    def collate(self, samples):
        """
        Pad a list of single-sentence samples into batch tensors.

        :return (units, labels, features, raw_units): units padded with the
                vocab's <PAD> id, labels (int32) padded with -1, zero-padded
                float32 features, and raw units with a trailing '<PAD>'
        :raises ValueError if any sample has more than one sentence
        """
        if any(len(x) > 1 for x in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        feat_size = samples[0][0][2].shape[-1]
        # must be the vocab's <PAD> entry; an empty string would map to the
        # same id as unknown units
        padid = self.dataset.vocab.unit2id('<PAD>')

        # +1 so that all samples end with at least one pad
        pad_len = max(len(x[0][3]) for x in samples) + 1

        units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)
        raw_units = []
        for i, sample in enumerate(samples):
            u_, l_, f_, r_ = sample[0]
            units[i, :len(u_)] = torch.from_numpy(u_)
            labels[i, :len(l_)] = torch.from_numpy(l_)
            features[i, :len(f_), :] = torch.from_numpy(f_)
            raw_units.append(r_ + ['<PAD>'])

        return units, labels, features, raw_units
================================================
FILE: stanza/models/tokenization/model.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
from stanza.models.common.foundation_cache import load_charlm
class Tokenizer(nn.Module):
    """
    Character-level BiLSTM tokenizer / sentence splitter.

    For every character position, forward() returns log-probabilities over
    3 classes (no boundary, token end, token + sentence end) or, when
    args['use_mwt'] is set, 5 classes (adding MWT end and MWT + sentence end).
    """
    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, foundation_cache=None):
        """
        :param args: config dict; reads feat_dim, rnn_layers, conv_res, use_mwt,
                     hierarchical, tok_noise and optionally charlm_forward_file,
                     hier_conv_res, hier_invtemp
        :param nchars: size of the character vocab
        :param emb_dim: character embedding size
        :param hidden_dim: LSTM hidden size per direction
        :param dropout: dropout for embeddings / hidden layers
        :param feat_dropout: dropout for the dictionary features
        :param foundation_cache: passed to load_charlm so charlms can be shared
        """
        super().__init__()

        self.unsaved_modules = []

        self.args = args
        feat_dim = args['feat_dim']

        # id 0 is used for padding
        self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)

        self.input_dim = emb_dim + feat_dim

        charmodel = None
        if args is not None and args.get('charlm_forward_file', None):
            charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)
            charmodels = nn.ModuleList([charmodel_forward])
            charmodel = CharacterLanguageModelWordAdapter(charmodels)
            self.input_dim += charmodel.hidden_dim()
        # always set the attribute (possibly None) so forward can test it;
        # it is excluded from saved checkpoints
        self.add_unsaved_module("charmodel", charmodel)

        self.rnn = nn.LSTM(self.input_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)

        if self.args['conv_res'] is not None:
            self.conv_res = nn.ModuleList()
            self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]

            for si, size in enumerate(self.conv_sizes):
                l = nn.Conv1d(self.input_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
                self.conv_res.append(l)

            if self.args.get('hier_conv_res', False):
                self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1)

        self.tok_clf = nn.Linear(hidden_dim * 2, 1)
        self.sent_clf = nn.Linear(hidden_dim * 2, 1)
        if self.args['use_mwt']:
            self.mwt_clf = nn.Linear(hidden_dim * 2, 1)

        if args['hierarchical']:
            in_dim = hidden_dim * 2
            self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
            self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            if self.args['use_mwt']:
                self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.dropout_feat = nn.Dropout(feat_dropout)

        self.toknoise = nn.Dropout(self.args['tok_noise'])

    def add_unsaved_module(self, name, module):
        """Set `module` as attribute `name` and record it so it is skipped when saving."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def forward(self, x, feats, lengths, raw=None):
        """
        :param x: unit ids, (batch, seq) since the LSTMs are batch_first
        :param feats: per-unit features, concatenated to the embeddings on dim 2
        :param lengths: per-row lengths for pack_padded_sequence
        :param raw: raw units, only used by the optional charlm
        :return log-probabilities, (batch, seq, 3) or (batch, seq, 5) with use_mwt
        """
        emb = self.embeddings(x)
        if self.charmodel is not None and raw is not None:
            char_emb = self.charmodel(raw, wrap=False)
            emb = torch.cat([emb, char_emb], dim=2)
        emb = self.dropout(emb)
        feats = self.dropout_feat(feats)

        emb = torch.cat([emb, feats], 2)
        # keep the padded tensor in `emb`: the conv branch below needs it, and a
        # PackedSequence has no transpose()
        packed_emb = pack_padded_sequence(emb, lengths, batch_first=True)
        inp, _ = self.rnn(packed_emb)
        inp, _ = pad_packed_sequence(inp, batch_first=True)

        if self.args['conv_res'] is not None:
            # pad_packed_sequence pads only to max(lengths); align the conv input with it
            conv_input = emb[:, :inp.size(1)].transpose(1, 2).contiguous()
            if not self.args.get('hier_conv_res', False):
                for l in self.conv_res:
                    inp = inp + l(conv_input).transpose(1, 2).contiguous()
            else:
                hid = []
                for l in self.conv_res:
                    hid += [l(conv_input)]
                hid = torch.cat(hid, 1)
                hid = F.relu(hid)
                hid = self.dropout(hid)
                inp = inp + self.conv_res2(hid).transpose(1, 2).contiguous()

        inp = self.dropout(inp)

        tok0 = self.tok_clf(inp)
        sent0 = self.sent_clf(inp)
        if self.args['use_mwt']:
            mwt0 = self.mwt_clf(inp)

        if self.args['hierarchical']:
            inp2 = inp
            if self.args['hier_invtemp'] > 0:
                inp2 = inp2 * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp'])))
            inp2 = pack_padded_sequence(inp2, lengths, batch_first=True)
            inp2, _ = self.rnn2(inp2)
            inp2, _ = pad_packed_sequence(inp2, batch_first=True)

            inp2 = self.dropout(inp2)

            tok0 = tok0 + self.tok_clf2(inp2)
            sent0 = sent0 + self.sent_clf2(inp2)
            if self.args['use_mwt']:
                mwt0 = mwt0 + self.mwt_clf2(inp2)

        nontok = F.logsigmoid(-tok0)
        tok = F.logsigmoid(tok0)
        nonsent = F.logsigmoid(-sent0)
        sent = F.logsigmoid(sent0)
        if self.args['use_mwt']:
            nonmwt = F.logsigmoid(-mwt0)
            mwt = F.logsigmoid(mwt0)

        if self.args['use_mwt']:
            pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
        else:
            pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)

        return pred
================================================
FILE: stanza/models/tokenization/tokenize_files.py
================================================
"""Use a Stanza tokenizer to turn a text file into one tokenized paragraph per line
For example, the output of this script is suitable for GloVe
Currently this *only* supports tokenization, no MWT splitting.
It would also be beneficial to have an option to convert spaces into
NBSP, underscore, or some other marker to make it easier to process
languages such as Vietnamese (VI), whose words can contain spaces
"""
import argparse
import io
import os
import time
import re
import zipfile
import torch
import stanza
from stanza.models.common.utils import open_read_text, default_device
from stanza.models.tokenization.data import TokenizationDataset
from stanza.models.tokenization.utils import output_predictions
from stanza.pipeline.tokenize_processor import TokenizeProcessor
from stanza.utils.get_tqdm import get_tqdm
tqdm = get_tqdm()
# Document separator: a newline, any whitespace (including more newlines), then a newline,
# i.e. one or more blank lines split the input text into separate "documents"
NEWLINE_SPLIT_RE = re.compile(r"\n\s*\n")
def tokenize_to_file(tokenizer, fin, fout, chunk_size=500):
    """
    Tokenize the entire contents of fin and write one line per document to fout.

    Documents are separated by blank lines (NEWLINE_SPLIT_RE) and are handed to
    tokenizer.bulk_process chunk_size at a time.  All sentences of a document
    end up on the same output line, with tokens separated by single spaces.
    """
    documents = NEWLINE_SPLIT_RE.split(fin.read())
    num_documents = len(documents)
    for chunk_start in tqdm(range(0, num_documents, chunk_size), leave=False):
        chunk_end = min(chunk_start + chunk_size, num_documents)
        processed = tokenizer.bulk_process([stanza.Document([], text=text) for text in documents[chunk_start:chunk_end]])
        for document in processed:
            sentence_lines = [" ".join(token.text for token in sentence.tokens)
                              for sentence in document.sentences]
            fout.write(" ".join(sentence_lines))
            fout.write("\n")
def main(args=None):
    """
    Command line entry point: tokenize the input files (plain text or .zip
    archives) into a single output file, one tokenized document per line.

    Refuses to run if the output file already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, default="sd", help="Which language to use for tokenization")
    parser.add_argument("--tokenize_model_path", type=str, default=None, help="Specific tokenizer model to use")
    parser.add_argument("input_files", type=str, nargs="+", help="Which input files to tokenize")
    parser.add_argument("--output_file", type=str, default="glove.txt", help="Where to write the tokenized output")
    parser.add_argument("--model_dir", type=str, default=None, help="Where to get models for a Pipeline (None => default models dir)")
    parser.add_argument("--chunk_size", type=int, default=500, help="How many 'documents' to use in a chunk when tokenizing.  This is separate from the tokenizer batching - this limits how much memory gets used at once, since we don't need to store an entire file in memory at once")
    args = parser.parse_args(args=args)

    if os.path.exists(args.output_file):
        print("Cowardly refusing to overwrite existing output file %s" % args.output_file)
        return

    if args.tokenize_model_path:
        config = { "model_path": args.tokenize_model_path,
                   "check_requirements": False }
        tokenizer = TokenizeProcessor(config, pipeline=None, device=default_device())
    else:
        pipe = stanza.Pipeline(lang=args.lang, processors="tokenize", model_dir=args.model_dir)
        tokenizer = pipe.processors["tokenize"]

    with open(args.output_file, "w", encoding="utf-8") as fout:
        for filename in tqdm(args.input_files):
            if filename.endswith(".zip"):
                with zipfile.ZipFile(filename) as zin:
                    input_names = zin.namelist()
                    for input_name in tqdm(input_names, leave=False):
                        if input_name.endswith("/"):
                            # directory entries have no text to tokenize
                            continue
                        # open each member in turn (previously input_names[0] was reopened every time)
                        with zin.open(input_name) as raw_fin:
                            fin = io.TextIOWrapper(raw_fin, encoding='utf-8')
                            tokenize_to_file(tokenizer, fin, fout, chunk_size=args.chunk_size)
            else:
                with open_read_text(filename, encoding="utf-8") as fin:
                    tokenize_to_file(tokenizer, fin, fout, chunk_size=args.chunk_size)
if __name__ == '__main__':
    # run as a script: arguments come from sys.argv via argparse in main()
    main()
================================================
FILE: stanza/models/tokenization/trainer.py
================================================
import sys
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from stanza.models.common import utils
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.tokenization.utils import create_dictionary
from .model import Tokenizer
from .vocab import Vocab
logger = logging.getLogger('stanza')

class Trainer(BaseTrainer):
    """
    Wraps a Tokenizer model together with its vocab, lexicon/dictionary,
    loss function and optimizer, and handles saving / loading checkpoints.
    """
    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, foundation_cache=None):
        """
        Either load everything from model_file, or build a fresh model from args.

        :param foundation_cache: shared cache for pretrained pieces such as the
                                 charlm; forwarded to the Tokenizer in both cases
        """
        # TODO: make a test of the training w/ and w/o charlm
        if model_file is not None:
            # load everything from file
            self.load(model_file, args, foundation_cache)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.lexicon = list(lexicon) if lexicon is not None else None
            self.dictionary = dictionary
            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], foundation_cache=foundation_cache)
        self.model = self.model.to(device)
        # label -1 marks padding and is ignored by the loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
        self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])
        self.feat_funcs = self.args.get('feat_funcs', None)
        self.lang = self.args['lang'] # language determines how token normalization is done

    def update(self, inputs):
        """Run one training step on a (units, labels, features, text) batch and return the loss."""
        self.model.train()
        units, labels, features, text = inputs
        lengths = [len(x) for x in text]

        device = next(self.model.parameters()).device
        units = units.to(device)
        labels = labels.to(device)
        features = features.to(device)

        pred = self.model(units, features, lengths, text)

        self.optimizer.zero_grad()
        classes = pred.size(2)
        loss = self.criterion(pred.view(-1, classes), labels.view(-1))

        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()

        return loss.item()

    def predict(self, inputs):
        """Return the model's per-unit class log-probabilities for a batch as a numpy array."""
        self.model.eval()
        units, _, features, text = inputs
        lengths = [len(x) for x in text]

        device = next(self.model.parameters()).device
        units = units.to(device)
        features = features.to(device)

        # inference only: avoid building (and holding on to) the autograd graph
        with torch.no_grad():
            pred = self.model(units, features, lengths, text)
        return pred.cpu().numpy()

    def save(self, filename, skip_modules=True):
        """
        Save model weights, vocab, lexicon and config to filename.

        Failures are logged and ignored so that training can continue.
        """
        model_state = None
        if self.model is not None:
            model_state = self.model.state_dict()
            # skip saving modules like the pretrained charlm
            if skip_modules:
                skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
                for k in skipped:
                    del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            # save and load lexicon as list instead of set so
            # we can use weights_only=True
            'lexicon': list(self.lexicon) if self.lexicon is not None else None,
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except Exception:
            # best effort, but do not swallow KeyboardInterrupt / SystemExit,
            # and keep the reason for the failure in the log
            logger.warning("Saving failed... continuing anyway.", exc_info=True)

    def load(self, filename, args, foundation_cache):
        """
        Load a checkpoint written by save().

        :param args: optional overrides; currently only charlm_forward_file is honored,
                     and only if the saved model was trained with a charlm
        :param foundation_cache: forwarded to the Tokenizer so a charlm can be shared
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args is not None and args.get('charlm_forward_file', None) is not None:
            if checkpoint['config'].get('charlm_forward_file') is None:
                # if the saved model didn't use a charlm, we skip the charlm here
                # otherwise the loaded model weights won't fit in the newly created model
                self.args['charlm_forward_file'] = None
            else:
                self.args['charlm_forward_file'] = args['charlm_forward_file']
        if self.args.get('use_mwt', None) is None:
            # Default to True as many currently saved models
            # were built with mwt layers
            self.args['use_mwt'] = True
        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], foundation_cache=foundation_cache)
        # strict=False: unsaved modules such as the charlm are not in the checkpoint
        self.model.load_state_dict(checkpoint['model'], strict=False)
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
        self.lexicon = checkpoint['lexicon']
        if self.lexicon is not None:
            self.lexicon = set(self.lexicon)
            self.dictionary = create_dictionary(self.lexicon)
        else:
            self.dictionary = None
================================================
FILE: stanza/models/tokenization/utils.py
================================================
from collections import Counter
from copy import copy
import json
import numpy as np
import re
import logging
import os
from torch.utils.data import DataLoader as TorchDataLoader
import stanza.utils.default_paths as default_paths
from stanza.models.common.utils import ud_scores, harmonic_mean
from stanza.models.common.doc import Document
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import *
from stanza.models.tokenization.data import SortedDataset
logger = logging.getLogger('stanza')
# default data locations; load_lexicon looks up TOKENIZE_DATA_DIR here
paths = default_paths.get_default_paths()
def create_dictionary(lexicon):
    """
    Create the dictionary used for dictionary features when tokenizing
    languages with multi-syllable words, such as vi, zh or th.

    The result holds three sets:
      "words"    - every lexicon entry longer than one character
      "prefixes" - every proper prefix of those words
      "suffixes" - every proper suffix of those words
    The prefix/suffix sets make partial matches cheap to check during data preparation.

    :param lexicon - iterable of words used to create the dictionary
    :return a dict {"words": set, "prefixes": set, "suffixes": set}
    """
    dictionary = {"words": set(), "prefixes": set(), "suffixes": set()}
    for word in lexicon:
        # single characters are not multi-syllable words; repeated words add nothing new
        if len(word) <= 1 or word in dictionary["words"]:
            continue
        dictionary["words"].add(word)
        for length in range(1, len(word)):
            dictionary["prefixes"].add(word[:length])
            dictionary["suffixes"].add(word[-length:])
    return dictionary
def create_lexicon(shorthand=None, train_path=None, external_path=None):
    """
    This function is to create a lexicon to store all the words from the training set and external dictionary.
    This lexicon will be saved with the model and will be used to create dictionary when the model is loaded.
    The idea of separating lexicon and dictionary in two different phases is a good tradeoff between time and space.
    Note that we eliminate all the long words but less frequently appeared in the lexicon by only taking 95-percentile
    list of words.

    Either source may be omitted (None); when only the external dictionary is given,
    the training-data word count is simply 0.

    :param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp
    :param train_path - path to conllu train file
    :param external_path - path to external dict, expected to be inside the training dataset dir with format of: SHORTHAND-externaldict.txt
    :return (lexicon, num_dict_feat): the set of distinct words, and the max word length kept
    :raises FileNotFoundError if a given path does not exist
    """
    lexicon = set()
    length_freq = []
    # number of words contributed by the training data; must exist even when train_path is None
    count_word = 0
    # this regex is to check if a character is an actual Thai character as seems .isalpha() python method doesn't pick up Thai accent characters..
    pattern_thai = re.compile(r"(?:[^\d\W]+)|\s")

    def check_valid_word(shorthand, word):
        """
        This function is to check if the word are multi-syllable words and not numbers.
        For vi, whitespaces are syllable-separator.
        """
        if shorthand.startswith("vi_"):
            return len(word.split(" ")) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word))
        elif shorthand.startswith("th_"):
            return len(word) > 1 and any(map(pattern_thai.match, word)) and not any(map(str.isdigit, word))
        else:
            return len(word) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word))

    # checking for words in the training set to add them to lexicon.
    if train_path is not None:
        if not os.path.isfile(train_path):
            raise FileNotFoundError(f"Cannot open train set at {train_path}")

        train_doc = CoNLL.conll2doc(input_file=train_path)

        for train_sent in train_doc.sentences:
            train_words = [x.text for x in train_sent.tokens if x.is_mwt()] + [x.text for x in train_sent.words]
            for word in train_words:
                word = word.lower()
                if check_valid_word(shorthand, word) and word not in lexicon:
                    lexicon.add(word)
                    length_freq.append(len(word))
        count_word = len(lexicon)
        logger.info(f"Added {count_word} words from the training data to the lexicon.")

    # checking for external dictionary and add them to lexicon.
    if external_path is not None:
        if not os.path.isfile(external_path):
            raise FileNotFoundError(f"Cannot open external dictionary at {external_path}")

        with open(external_path, "r", encoding="utf-8") as external_file:
            lines = external_file.readlines()
        for line in lines:
            word = line.lower().replace("\n", "")
            if check_valid_word(shorthand, word) and word not in lexicon:
                lexicon.add(word)
                length_freq.append(len(word))
        logger.info(f"Added another {len(lexicon) - count_word} words from the external dict to dictionary.")

    # automatically calculate the number of dictionary features (window size to look for words) based on the frequency of word length
    # take the length at 95-percentile to eliminate all the longest (maybe) compounds words in the lexicon
    num_dict_feat = int(np.percentile(length_freq, 95))
    lexicon = {word for word in lexicon if len(word) <= num_dict_feat }
    logger.info(f"Final lexicon consists of {len(lexicon)} words after getting rid of long words.")

    return lexicon, num_dict_feat
def load_lexicon(args):
    """
    Build the lexicon used for the dictionary features of a tokenizer.

    Words come from the training data and/or an external dictionary, both looked
    up in paths["TOKENIZE_DATA_DIR"]:
      {shorthand}.train.gold.conllu
      {shorthand}-externaldict.txt     (for example, vi_vlsp-externaldict.txt)
    Either file may be missing, but not both.

    :param args: dict containing at least "shorthand", eg vi_vlsp
    :return (lexicon, num_dict_feat) as returned by create_lexicon
    :raises FileNotFoundError if neither file exists
    """
    shorthand = args["shorthand"]
    tokenize_dir = paths["TOKENIZE_DATA_DIR"]
    # keep the expected locations so the error message can report them
    expected_train_path = f"{tokenize_dir}/{shorthand}.train.gold.conllu"
    expected_external_dict_path = f"{tokenize_dir}/{shorthand}-externaldict.txt"

    external_dict_path = expected_external_dict_path
    if not os.path.exists(external_dict_path):
        logger.info(f"External dictionary not found! Looked in {external_dict_path} Checking training data...")
        external_dict_path = None

    train_path = expected_train_path
    if not os.path.exists(train_path):
        logger.info(f"Training dataset does not exist, thus cannot create dictionary {shorthand}")
        train_path = None

    if train_path is None and external_dict_path is None:
        raise FileNotFoundError(f"Cannot find training set / external dictionary at {expected_train_path} and {expected_external_dict_path}")

    return create_lexicon(shorthand, train_path, external_dict_path)
def load_mwt_dict(filename):
    """
    Returns a dict from an MWT to its most common expansion and count.
    Other less common expansions are discarded; on a tie the first one seen is kept.

    The file is JSON: a list of [[mwt, expansion], count] entries.

    :param filename: path to the JSON file, or None
    :return dict mapping mwt -> (expansion, count), or None if filename is None
    """
    if filename is None:
        return None

    # MWT files routinely contain non-ASCII text, so do not rely on the platform default encoding
    with open(filename, 'r', encoding='utf-8') as f:
        mwt_dict0 = json.load(f)

    mwt_dict = dict()
    for item in mwt_dict0:
        (key, expansion), count = item
        if key not in mwt_dict or mwt_dict[key][1] < count:
            mwt_dict[key] = (expansion, count)

    return mwt_dict
def process_sentence(sentence, mwt_dict=None):
    """
    Turn the predicted tokens of one sentence into a list of word dicts.

    sentence: list of (token_text, label, position_info) triples, where
        position_info is None or a (start_char, end_char) pair
    mwt_dict: optional dict from MWT surface form to (expansion, count)

    A token labeled 3 or 4 is an MWT.  If mwt_dict has an expansion for it
    (exact form first, then lowercased), the token is emitted with a range ID
    followed by one entry per expanded word.  Otherwise the token becomes a
    single word, with MISC set to 'MWT=Yes' when it was labeled as an MWT.
    Empty tokens that are not expanded are dropped.
    """
    words = []
    word_count = 0
    for token, label, position_info in sentence:
        is_mwt = label == 3 or label == 4

        expansion = None
        if is_mwt and mwt_dict is not None:
            if token in mwt_dict:
                expansion = mwt_dict[token][0]
            elif token.lower() in mwt_dict:
                expansion = mwt_dict[token.lower()][0]

        if expansion is not None:
            # the surface token covers the IDs of all of its expanded words
            mwt_entry = {ID: (word_count + 1, word_count + len(expansion)), TEXT: token}
            if position_info is not None:
                mwt_entry[START_CHAR] = position_info[0]
                mwt_entry[END_CHAR] = position_info[1]
            words.append(mwt_entry)
            for expanded_word in expansion:
                word_count += 1
                words.append({ID: (word_count, ), TEXT: expanded_word})
            continue

        if len(token) <= 0:
            continue

        entry = {ID: (word_count + 1, ), TEXT: token}
        if position_info is not None:
            entry[START_CHAR] = position_info[0]
            entry[END_CHAR] = position_info[1]
        if is_mwt:
            # no known expansion: leave it for the MWT processor to handle
            entry[MISC] = 'MWT=Yes'
        words.append(entry)
        word_count += 1
    return words
# Regular expressions used by update_pred_regex to keep emails and URLs together as single tokens.
# https://stackoverflow.com/questions/201323/how-to-validate-an-email-address-using-a-regular-expression
# NOTE(review): the email pattern only lists lowercase letters and MASK_RE is compiled without
# re.IGNORECASE, so addresses containing uppercase letters may match only partially or not at all
EMAIL_RAW_RE = r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(?:2(?:5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(?:2(?:5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])"""
# https://stackoverflow.com/questions/3809401/what-is-a-good-regular-expression-to-match-a-url
# modification: disallow " as opposed to all ^\s
# (i.e. the [^\s"] classes stop at a double quote as well as at whitespace)
# NOTE(review): in the last alternative the trailing (?:\.[^\s"]{2,}) group is not optional, so a bare
# name such as stanford.edu only matches when more '.'-separated text follows it - confirm this is intended
URL_RAW_RE = r"""(?:https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s"]{2,}|www\.[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s"]{2,}|https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9]+\.[^\s"]{2,}|www\.[a-zA-Z0-9]+\.[^\s"]{2,})|[a-zA-Z0-9]+\.(?:gov|org|edu|net|com|co)(?:\.[^\s"]{2,})"""
# one pattern matching either an email address or a URL
MASK_RE = re.compile(f"(?:{EMAIL_RAW_RE}|{URL_RAW_RE})")
def find_spans(raw):
    """
    Split `raw` at its padding entries.

    Returns (start, end) index pairs, end exclusive, for the maximal runs of
    `raw` that contain no padding entry.  Empty runs (between adjacent pads,
    or a pad at the very start) are omitted.  If `raw` has no padding at all,
    the single span (0, len(raw)) is returned, even when `raw` is empty.
    """
    pad_positions = [position for position, char in enumerate(raw) if char == '']
    if not pad_positions:
        return [(0, len(raw))]

    spans = []
    start = 0
    for pad in pad_positions:
        if pad != start:
            spans.append((start, pad))
        start = pad + 1
    if start < len(raw):
        spans.append((start, len(raw)))
    return spans
def update_pred_regex(raw, pred):
    """
    Force every email address and URL (anything MASK_RE matches) to be a single token.

    raw: the characters of one paragraph; padding entries split it into
        independently searched spans (see find_spans)
    pred: per-character labels for raw; modified in place for efficiency, and also returned

    Within each match, every label except the last is set to 0 (no split).
    The last character becomes a token end (1) unless it already carries a split.

    TODO: this might work better as a constraint on the inference
    """
    for start, end in find_spans(raw):
        segment = "".join(raw[start:end])
        for match in MASK_RE.finditer(segment):
            first = start + match.start()
            last = start + match.end() - 1
            for position in range(first, last):
                pred[position] = 0
            if pred[last] == 0:
                pred[last] = 1
    return pred
# matches any single whitespace character; decode_predictions uses it to turn all whitespace in orig_text into plain spaces
SPACE_RE = re.compile(r'\s')
# splits a token into runs of non-space characters, each kept together with its leading spaces
# (the capture group makes re.split return the pieces themselves, interleaved with empty strings)
SPACE_SPLIT_RE = re.compile(r'( *[^ ]+)')
def predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, num_workers):
    """
    The guts of the prediction method
    Calls trainer.predict() over and over until we have predictions for all of the text

    trainer: model wrapper; the argmax over axis 2 of trainer.predict(batch) gives one label per character
    data_generator: the dataset to predict on; it is wrapped in a SortedDataset and must also
        provide advance_old_batch() for paragraphs longer than max_seqlen
    batch_size, num_workers: passed to the torch DataLoader
    max_seqlen: paragraphs longer than this are predicted in windows of at most max_seqlen characters
    use_regex_tokens: if True, update_pred_regex is applied to each paragraph's predictions

    Returns (all_preds, all_raw), put back in the original order by sorted_data.unsort:
      all_preds: one label array per paragraph, cut to the paragraph length
      all_raw: the characters of each paragraph, with the padding removed
    """
    all_preds = []
    all_raw = []

    sorted_data = SortedDataset(data_generator)
    dataloader = TorchDataLoader(sorted_data, batch_size=batch_size, collate_fn=sorted_data.collate, num_workers=num_workers)
    for batch_idx, batch in enumerate(dataloader):
        # batch[3] holds the raw characters of each paragraph, padded with the padding marker ''
        # batch[0..2] are presumably (num_paragraphs, length) tensors - TODO confirm
        num_sentences = len(batch[3])
        # being sorted by descending length, we need to use 0 as the longest sentence
        N = len(batch[3][0])
        for paragraph in batch[3]:
            all_raw.append(list(paragraph))

        if N <= max_seqlen:
            pred = np.argmax(trainer.predict(batch), axis=2)
        else:
            # TODO: we could shortcircuit some processing of
            # long strings of PAD by tracking which rows are finished
            idx = [0] * num_sentences     # characters already predicted, per paragraph
            adv = [0] * num_sentences     # characters consumed in the latest window, per paragraph
            para_lengths = [x.index('') for x in batch[3]]
            pred = [[] for _ in range(num_sentences)]
            while True:
                # the comprehension's N is local to it; the outer N is left unchanged
                ens = [min(N - idx1, max_seqlen) for idx1, N in zip(idx, para_lengths)]
                en = max(ens)
                batch1 = batch[0][:, :en], batch[1][:, :en], batch[2][:, :en], [x[:en] for x in batch[3]]
                pred1 = np.argmax(trainer.predict(batch1), axis=2)

                for j in range(num_sentences):
                    # positions labeled 2 or 4 end a sentence (adding boolean arrays acts as OR)
                    sentbreaks = np.where((pred1[j] == 2) + (pred1[j] == 4))[0]
                    # keep the whole window if there is no sentence break or this is the final window;
                    # otherwise keep up to the last sentence break, so the next window starts on a new sentence
                    if len(sentbreaks) <= 0 or idx[j] >= para_lengths[j] - max_seqlen:
                        advance = ens[j]
                    else:
                        advance = np.max(sentbreaks) + 1

                    pred[j] += [pred1[j, :advance]]
                    idx[j] += advance
                    adv[j] = advance

                if all([idx1 >= N for idx1, N in zip(idx, para_lengths)]):
                    break
                # once we've made predictions on a certain number of characters for each paragraph (recorded in `adv`),
                # we skip the first `adv` characters to make the updated batch
                batch = data_generator.advance_old_batch(adv, batch)

            pred = [np.concatenate(p, 0) for p in pred]

        for par_idx in range(num_sentences):
            # position of this paragraph in all_raw, relying on every batch but the last being full
            offset = batch_idx * batch_size + par_idx

            raw = all_raw[offset]
            par_len = raw.index('')
            raw = raw[:par_len]
            all_raw[offset] = raw
            # the last character of a paragraph always ends a sentence:
            # 0/1 become 2 (sentence end), 3 becomes 4 (MWT + sentence end)
            if pred[par_idx][par_len-1] < 2:
                pred[par_idx][par_len-1] = 2
            elif pred[par_idx][par_len-1] > 2:
                pred[par_idx][par_len-1] = 4
            if use_regex_tokens:
                all_preds.append(update_pred_regex(raw, pred[par_idx][:par_len]))
            else:
                all_preds.append(pred[par_idx][:par_len])

    # undo the length sorting done by SortedDataset
    all_preds = sorted_data.unsort(all_preds)
    all_raw = sorted_data.unsort(all_raw)

    return all_preds, all_raw
def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, max_seqlen=1000, orig_text=None, no_ssplit=False, use_regex_tokens=True, num_workers=0, postprocessor=None):
    """
    Tokenize everything in data_generator and build the resulting doc.

    output_file: if given, the doc is also written there with CoNLL.dict2conll
    trainer: the tokenizer; trainer.args supplies batch_size, shorthand and skip_newline
    data_generator: the dataset to tokenize
    vocab, mwt_dict: passed on to decode_predictions
    max_seqlen: window size for long paragraphs; values below 1000 are raised to 1000
    orig_text: the original text, used to compute character offsets
    no_ssplit: if True, paragraphs are not split into sentences
    use_regex_tokens: if True, emails and URLs are forced to be single tokens
    num_workers: passed to the torch DataLoader
    postprocessor: optional callable that corrects the tokenization (see postprocess_doc)

    Returns (oov_count, offset, all_preds, doc)
    """
    batch_size = trainer.args['batch_size']
    max_seqlen = max(1000, max_seqlen)

    all_preds, all_raw = predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, num_workers)

    use_la_ittb_shorthand = trainer.args['shorthand'] == 'la_ittb'
    skip_newline = trainer.args['skip_newline']
    oov_count, offset, doc = decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand)

    # If we are provided a postprocessor, we prepare a list of pre-tokenized words and mwt flags and
    # call the postprocessor for analysis.
    if postprocessor:
        # postprocess_doc needs some text to recompute offsets against; without orig_text,
        # stitch it back together from the characters the model saw
        text_for_postprocessor = orig_text if orig_text else "".join("".join(paragraph) for paragraph in all_raw)
        doc = postprocess_doc(doc, postprocessor, text_for_postprocessor)

    if output_file:
        CoNLL.dict2conll(doc, output_file)
    return oov_count, offset, all_preds, doc
def postprocess_doc(doc, postprocessor, orig_text=None, all_raw=None):
    """
    Apply a postprocessor to the doc and rebuild the doc from its output.

    doc: list of sentences, each a list of word dicts with at least a "text" key
    postprocessor: callable receiving, per sentence, a list of tokens.  Each token is the
        word text, or a (text, True) pair for words whose misc is "MWT=Yes".  It must return
        the same nesting, where each token may be:
          - a string: a regular token
          - a (text, bool) pair: a token and whether it is an MWT
          - a (text, expansion) pair: an MWT with a manual expansion, given as an iterable of words
    orig_text: the text the tokens must be found in, used to compute offsets
    all_raw: optional per-paragraph character sequences, stitched together as the text
        when orig_text is not given

    Returns the rebuilt doc in the same list-of-lists-of-dicts format.
    Raises ValueError if neither orig_text nor all_raw is available.
    """
    if orig_text:
        raw_text = orig_text
    elif all_raw is not None:
        # template to compare the stitched text against
        raw_text = "".join("".join(paragraph) for paragraph in all_raw)
    else:
        raise ValueError("postprocess_doc needs either orig_text or all_raw to compute token offsets")

    # get a list of all the words in the "draft" document to pass to the postprocessor
    # the words array looks like [["words", "words", "words"], ["words", ("i_am_a_mwt", True), "I_am_not"]]
    # and the postprocessor is expected to return in the same format
    words = [[((word["text"], True)
               if word.get("misc") == "MWT=Yes"
               else word["text"]) for word in sentence]
             for sentence in doc]

    # perform correction with the postprocessor
    postprocessor_return = postprocessor(words)

    # collect the words, MWT flags and manual expansions separately
    corrected_words = []
    corrected_mwts = []
    corrected_expansions = []

    for sent in postprocessor_return:
        sent_words = []
        sent_mwts = []
        sent_expansions = []
        for word in sent:
            if isinstance(word, str):
                # a bare string defaults to a regular, non-MWT word
                sent_words.append(word)
                sent_mwts.append(False)
                sent_expansions.append(None)
            elif isinstance(word[1], bool):
                sent_words.append(word[0])
                sent_mwts.append(word[1])
                sent_expansions.append(None)
            else:
                sent_words.append(word[0])
                sent_mwts.append(True)
                # expansions are marked in a space-separated list, which
                # `stanza.common.doc.set_mwt_expansions` reads and splits again
                # by splitting by spaces. Therefore, to serialize the users' supplied MWT
                # information, we join them by spaces to be split later by
                # `set_mwt_expansions`.
                sent_expansions.append(" ".join(word[1]))
        corrected_words.append(sent_words)
        corrected_mwts.append(sent_mwts)
        corrected_expansions.append(sent_expansions)

    # check postprocessor output
    token_lens = [len(i) for i in corrected_words]
    mwt_lens = [len(i) for i in corrected_mwts]
    assert token_lens == mwt_lens, "Postprocessor returned token and MWT lists of different length! Token list lengths %s, MWT list lengths %s" % (token_lens, mwt_lens)

    # reassemble document. offsets and oov shouldn't change
    return reassemble_doc_from_tokens(corrected_words, corrected_mwts,
                                      corrected_expansions, raw_text)
def reassemble_doc_from_tokens(tokens, mwts, expansions, raw_text):
    """Assemble a Stanza document list format from a list of string tokens, calculating offsets as needed.

    Parameters
    ----------
    tokens : List[List[str]]
        A list of sentences, which includes string tokens.
    mwts : List[List[bool]]
        Whether or not each of the tokens are MWTs to be analyzed by the MWT system.
    expansions : List[List[Optional[str]]]
        For each token, a user-defined MWT expansion, or None if no expansion is given.
        postprocess_doc passes each expansion as a single space-separated string of words.
    raw_text : str
        The raw text off of which we can compare offsets.  The tokens must occur in it in order.

    Returns
    -------
    List[List[Dict]]
        List of words and their offsets, used as `doc`.

    Raises
    ------
    ValueError
        If a token cannot be found in raw_text after the end of the previous token.
    """
    # oov count and offset stays the same; doc gets regenerated
    new_offset = 0
    corrected_doc = []
    for sent_words, sent_mwts, sent_expansions in zip(tokens, mwts, expansions):
        sentence_doc = []

        for word_id, (word, mwt, expansion) in enumerate(zip(sent_words, sent_mwts, sent_expansions), start=1):
            try:
                offset_index = raw_text.index(word, new_offset)
            except ValueError as e:
                sub_start = max(0, new_offset - 20)
                sub_end = min(len(raw_text), new_offset + 20)
                sub = raw_text[sub_start:sub_end]
                raise ValueError("Could not find word |%s| starting from char_offset %d. Surrounding text: |%s|. \n Hint: did you accidentally add/subtract a symbol/character such as a space when combining tokens?" % (word, new_offset, sub)) from e

            wd = {
                "id": (word_id,), "text": word,
                "start_char": offset_index,
                "end_char": offset_index+len(word)
            }
            if expansion:
                # tokens with a manual expansion are marked for set_mwt_expansions below instead of MWT=Yes
                wd["manual_expansion"] = True
            elif mwt:
                wd["misc"] = "MWT=Yes"

            sentence_doc.append(wd)

            # start the next search after the previous word ended
            new_offset = offset_index+len(word)

        corrected_doc.append(sentence_doc)

    # use the built in MWT system to expand MWTs
    doc = Document(corrected_doc, raw_text)
    doc.set_mwt_expansions([j
                            for i in expansions
                            for j in i if j],
                           process_manual_expanded=True)
    return doc.to_dict()
def decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand):
    """
    Decode the predictions into a document of words

    Once everything is fed through the tokenizer model, it's time to decode the predictions
    into actual tokens and sentences that the rest of the pipeline uses

    vocab: tokenizer vocab, used to normalize tokens and count OOV characters; may be None
    mwt_dict: MWT expansions passed on to process_sentence; may be None
    orig_text: the original text; if given, every token gets (start_char, end_char) offsets into it
    all_raw: per paragraph, the sequence of characters that was tokenized
    all_preds: per paragraph, one label per character
        (0 = no split, 1 = token end, 2 = sentence end, 3 = MWT end, 4 = MWT and sentence end)
    no_ssplit: if True, sentence breaks are ignored and each paragraph becomes one sentence
    skip_newline: if True, whitespace may appear between the characters of a token in orig_text
    use_la_ittb_shorthand: if True, ':' and ';' always end a sentence

    Returns (oov_count, offset, doc):
      oov_count: number of characters the vocab maps to its unknown id
      offset: number of characters processed
      doc: list of sentences as built by process_sentence

    Raises ValueError if a token cannot be located in orig_text, or if characters are
    left over after the last token boundary of a paragraph.
    """
    offset = 0
    oov_count = 0
    doc = []

    # every whitespace character becomes a plain space so tokens can be found by text search
    text = SPACE_RE.sub(' ', orig_text) if orig_text is not None else None
    char_offset = 0

    if vocab is not None:
        UNK_ID = vocab.unit2id('')

    for raw, pred in zip(all_raw, all_preds):
        current_tok = ''
        current_sent = []

        for t, p in zip(raw, pred):
            if t == '':
                # padding marker: no more real characters in this paragraph
                break
            # hack la_ittb
            if use_la_ittb_shorthand and t in (":", ";"):
                p = 2
            offset += 1
            if vocab is not None and vocab.unit2id(t) == UNK_ID:
                oov_count += 1

            current_tok += t
            if p >= 1:
                if vocab is not None:
                    tok = vocab.normalize_token(current_tok)
                else:
                    tok = current_tok
                assert '\t' not in tok, tok
                if len(tok) <= 0:
                    current_tok = ''
                    continue
                if orig_text is not None:
                    st = -1
                    for part in SPACE_SPLIT_RE.split(current_tok):
                        if not part:
                            continue
                        if skip_newline:
                            # allow arbitrary whitespace between the characters of the part
                            part_pattern = re.compile(r'\s*'.join(re.escape(c) for c in part))
                            match = part_pattern.search(text, char_offset)
                            if match is None:
                                sub = text[max(0, char_offset - 20):min(len(text), char_offset + 20)]
                                raise ValueError("Could not find |%s| starting from char_offset %d. Surrounding text: |%s|" % (part, char_offset, sub))
                            st0 = match.start(0) - char_offset
                            partlen = match.end(0) - match.start(0)
                            lstripped = match.group(0).lstrip()
                        else:
                            try:
                                st0 = text.index(part, char_offset) - char_offset
                            except ValueError as e:
                                sub_start = max(0, char_offset - 20)
                                sub_end = min(len(text), char_offset + 20)
                                sub = text[sub_start:sub_end]
                                raise ValueError("Could not find |%s| starting from char_offset %d. Surrounding text: |%s|" % (part, char_offset, sub)) from e
                            partlen = len(part)
                            lstripped = part.lstrip()
                        # the token starts at the first non-space character of its first part
                        if st < 0:
                            st = char_offset + st0 + (partlen - len(lstripped))
                        char_offset += st0 + partlen
                    position_info = (st, char_offset)
                else:
                    position_info = None
                current_sent.append((tok, p, position_info))
                current_tok = ''
                if (p == 2 or p == 4) and not no_ssplit:
                    doc.append(process_sentence(current_sent, mwt_dict))
                    current_sent = []

        if len(current_tok) > 0:
            raise ValueError("Finished processing tokens, but there is still text left!")
        if len(current_sent):
            doc.append(process_sentence(current_sent, mwt_dict))

    return oov_count, offset, doc
def match_tokens_with_text(sentences, orig_text):
    """
    Turn pretokenized text plus the original text into a Document.

    sentences: list of sentences, each a list of token strings
    orig_text: the source text, which must be exactly the tokens, in order,
        separated by zero or more whitespace characters

    Raises ValueError if orig_text deviates from that in any way.
    """
    raw_chars = list("".join(word for sentence in sentences for word in sentence))

    # label the characters as if the model had predicted this tokenization:
    # 1 on the last character of each token, 2 on the last character of each sentence
    labels = [0] * len(raw_chars)
    end = 0
    for sentence in sentences:
        for word in sentence:
            end += len(word)
            labels[end - 1] = 1
        labels[end - 1] = 2

    _, _, doc = decode_predictions(None, None, orig_text, [raw_chars], [labels], False, False, False)
    doc = Document(doc, orig_text)

    # anything other than whitespace after the final token means the tokens did not cover orig_text
    last_end = doc.sentences[-1].tokens[-1].end_char
    if orig_text[last_end:].strip():
        raise ValueError("Finished processing tokens, but there is still text left!")
    return doc
def eval_model(args, trainer, batches, vocab, mwt_dict):
    """
    Score the tokenizer against the gold character labels of `batches`.

    Predictions are written to args['conll_file'] by output_predictions, then compared
    character by character with batches.labels().  Three F1 scores are logged:
      token F1    - any boundary (labels 1-4)
      sentence F1 - sentence boundaries only (labels 2 and 4)
      mwt F1      - token boundaries, where regular tokens (1, 2) and MWTs (3, 4) are distinct classes

    Returns harmonic_mean of the three scores with [1, 1, .01] (presumably weights) as second argument.
    """
    _, _, all_preds, _ = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen'])

    all_preds = np.concatenate(all_preds, 0)
    labels = np.concatenate(batches.labels())

    def f1(pred, gold, mapping):
        """
        F1 over segment boundaries after mapping the raw labels through `mapping`.

        A boundary is a true positive only when both sides put the same class there
        and the previous boundary was at the same position on both sides, i.e. the
        whole segment ending here matches.
        """
        pred = [mapping[p] for p in pred]
        gold = [mapping[g] for g in gold]

        lastp = -1
        lastg = -1
        tp = 0
        fp = 0
        fn = 0
        for i, (p, g) in enumerate(zip(pred, gold)):
            if p == g > 0 and lastp == lastg:
                lastp = i
                lastg = i
                tp += 1
            elif p > 0 and g > 0:
                # both sides have a boundary, but the class or the segment start differs
                lastp = i
                lastg = i
                fp += 1
                fn += 1
            elif p > 0:
                # and g == 0
                lastp = i
                fp += 1
            elif g > 0:
                lastg = i
                fn += 1

        if tp == 0:
            return 0
        else:
            return 2 * tp / (2 * tp + fp + fn)

    f1tok = f1(all_preds, labels, {0:0, 1:1, 2:1, 3:1, 4:1})
    f1sent = f1(all_preds, labels, {0:0, 1:0, 2:1, 3:0, 4:1})
    f1mwt = f1(all_preds, labels, {0:0, 1:1, 2:1, 3:2, 4:2})
    logger.info(f"{args['shorthand']}: token F1 = {f1tok*100:.2f}, sentence F1 = {f1sent*100:.2f}, mwt F1 = {f1mwt*100:.2f}")
    return harmonic_mean([f1tok, f1sent, f1mwt], [1, 1, .01])
================================================
FILE: stanza/models/tokenization/vocab.py
================================================
from collections import Counter
import re
from stanza.models.common.vocab import BaseVocab
from stanza.models.common.vocab import UNK, PAD
# matches a single whitespace character; normalize_token replaces each match with a plain space
SPACE_RE = re.compile(r'\s')
class Vocab(BaseVocab):
    """
    Vocab of the minimal units the tokenizer reads.

    Id 0 is PAD and id 1 is UNK.  The remaining units follow in order of
    decreasing frequency in self.data.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # for these languages spaces are removed from tokens entirely (see normalize_token)
        self.lang_replaces_spaces = self.lang.startswith(('zh', 'ja', 'ko'))

    def build_vocab(self):
        # each paragraph in self.data is a sequence of entries whose first element is the unit
        counts = Counter(self.normalize_unit(unit[0]) for para in self.data for unit in para)
        # sorted() is stable, so units with equal counts keep their first-seen order
        by_frequency = sorted(counts, key=lambda unit: counts[unit], reverse=True)
        self._id2unit = [PAD, UNK] + by_frequency
        self._unit2id = {unit: idx for idx, unit in enumerate(self._id2unit)}

    def append(self, unit):
        """Add `unit` at the end of the vocab, giving it the next free id."""
        self._unit2id[unit] = len(self._id2unit)
        self._id2unit.append(unit)

    def normalize_unit(self, unit):
        # Normalize minimal units used by the tokenizer (currently the identity)
        return unit

    def normalize_token(self, token):
        """Strip leading whitespace, map every whitespace character to a plain space, and drop spaces for zh/ja/ko."""
        normalized = SPACE_RE.sub(' ', token.lstrip())
        return normalized.replace(' ', '') if self.lang_replaces_spaces else normalized
================================================
FILE: stanza/models/tokenizer.py
================================================
"""
Entry point for training and evaluating a neural tokenizer.
This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of
recurrent and convolutional architectures.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
Updated: This new version of tokenizer model incorporates the dictionary feature, especially useful for languages that
have multi-syllable words such as Vietnamese, Chinese or Thai. In summary, a lexicon contains all unique words found in
training dataset and external lexicon (if any) is created during training and saved alongside the model after training.
Using this lexicon, a dictionary is created which includes "words", "prefixes" and "suffixes" sets. During data preparation,
dictionary features are extracted at each character position, "looking ahead" and "looking backward" to see whether the characters there form
words found in the dictionary. The window size (or the dictionary feature length) is set to the 95th percentile of the lengths of all the existing
words in the lexicon; this eliminates the less frequent but long words (avoiding a high-dimensional feature vector). Prefixes
and suffixes are used to stop early during the window-dictionary checking process.
"""
import argparse
from copy import copy
import logging
import random
import numpy as np
import os
import torch
import json
from stanza.models.common import utils
from stanza.models.tokenization.trainer import Trainer
from stanza.models.tokenization.data import DataLoader, TokenizationDataset
from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary
from stanza.models import _training_logging
# log through the shared 'stanza' logger
logger = logging.getLogger('stanza')
def build_argparse():
    """
    Build the command line argument parser for training and running the tokenizer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--txt_file', type=str, help="Input plaintext file")
    parser.add_argument('--label_file', type=str, default=None, help="Character-level label file")
    parser.add_argument('--mwt_json_file', type=str, default=None, help="JSON file for MWT expansions")
    parser.add_argument('--conll_file', type=str, default=None, help="CoNLL file for output")
    parser.add_argument('--dev_txt_file', type=str, help="(Train only) Input plaintext file for the dev set")
    parser.add_argument('--dev_label_file', type=str, default=None, help="(Train only) Character-level label file for the dev set")
    parser.add_argument('--dev_conll_gold', type=str, default=None, help="(Train only) CoNLL-U file for the dev set for early stopping")
    parser.add_argument('--lang', type=str, help="Language")
    parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")

    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.")

    parser.add_argument('--emb_dim', type=int, default=32, help="Dimension of unit embeddings")
    parser.add_argument('--hidden_dim', type=int, default=64, help="Dimension of hidden units")
    parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters. ,, separates layers and , separates filter sizes in the same layer.")
    # these two flags switch features off; their dest holds True unless the flag is given
    parser.add_argument('--no-residual', dest='residual', action='store_false', help="Disable the linear residual connections")
    parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="Disable the \"hierarchical\" RNN tokenizer")
    parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers")
    parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well")
    parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN")
    parser.add_argument('--rnn_layers', type=int, default=1, help="Layers of RNN in the tokenizer")
    parser.add_argument('--use_dictionary', action='store_true', help="Use dictionary feature. The lexicon is created using the training data and external dict (if any) expected to be found under the same folder of training dataset, formatted as SHORTHAND-externaldict.txt where each line in this file is a word. For example, data/tokenize/zh_gsdsimp-externaldict.txt")

    parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm to clip to")
    parser.add_argument('--anneal', type=float, default=.999, help="Anneal the learning rate by this amount when dev performance deteriorate")
    parser.add_argument('--anneal_after', type=int, default=2000, help="Anneal the learning rate no earlier than this step")
    parser.add_argument('--lr0', type=float, default=2e-3, help="Initial learning rate")
    parser.add_argument('--dropout', type=float, default=0.33, help="Dropout probability")
    parser.add_argument('--unit_dropout', type=float, default=0.33, help="Unit dropout probability")
    parser.add_argument('--feat_dropout', type=float, default=0.05, help="Features dropout probability for each element in feature vector")
    parser.add_argument('--feat_unit_dropout', type=float, default=0.33, help="The whole feature of units dropout probability")
    parser.add_argument('--tok_noise', type=float, default=0.02, help="Probability to induce noise to the input of the higher RNN")
    parser.add_argument('--sent_drop_prob', type=float, default=0.2, help="Probability to drop sentences at the end of batches during training uniformly at random.  Idea is to fake paragraph endings.")
    parser.add_argument('--last_char_drop_prob', type=float, default=0.2, help="Probability to drop the last char of a block of text during training, uniformly at random.  Idea is to fake a document ending w/o sentence final punctuation, hopefully to avoid the tokenizer learning to always tokenize the last character as a period")
    parser.add_argument('--last_char_move_prob', type=float, default=0.02, help="Probability to move the sentence final punctuation of a sentence during training, uniformly at random.  Idea is to teach the tokenizer that a space separated sentence final punct still ends the sentence")
    parser.add_argument('--punct_move_back_prob', type=float, default=0.02, help="Probability to move a comma in the sentence one over, removing the previous space, during training.  Idea is to teach the tokenizer that commas can appear next to words even in languages where the dataset doesn't allow it, such as Vietnamese")
    parser.add_argument('--split_mwt_prob', type=float, default=0.01, help="Probability to split an MWT into its component pieces and turn it into separate words")
    parser.add_argument('--augment_final_punct_prob', type=float, default=0.05, help="Probability to replace a ? with a ? or other similar augmentations")
    parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay")
    parser.add_argument('--max_seqlen', type=int, default=100, help="Maximum sequence length to consider at a time")
    parser.add_argument('--batch_size', type=int, default=32, help="Batch size to use")
    parser.add_argument('--epochs', type=int, default=10, help="Total epochs to train the model for")
    parser.add_argument('--steps', type=int, default=50000, help="Steps to train the model for, if unspecified use epochs")
    parser.add_argument('--report_steps', type=int, default=20, help="Update step interval to report loss")
    parser.add_argument('--shuffle_steps', type=int, default=100, help="Step interval to shuffle each paragraph in the generator")
    parser.add_argument('--eval_steps', type=int, default=200, help="Step interval to evaluate the model on the dev set for early stopping")
    parser.add_argument('--max_steps_before_stop', type=int, default=5000, help='Early terminates after this many steps if the dev scores are not improving')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_tokenizer.pt", help="File name to save the model")
    parser.add_argument('--load_name', type=str, default=None, help="File name to load a saved model")
    parser.add_argument('--save_dir', type=str, default='saved_models/tokenize', help="Directory to save models in")
    utils.add_device_args(parser)
    parser.add_argument('--seed', type=int, default=1234)

    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")

    parser.add_argument('--use_mwt', dest='use_mwt', default=None, action='store_true', help='Whether or not to include mwt output layers.  If set to None, this will be determined by examining the training data for MWTs')
    parser.add_argument('--no_use_mwt', dest='use_mwt', action='store_false', help='Whether or not to include mwt output layers')

    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')
    return parser
def parse_args(args=None):
    """
    Parse tokenizer arguments into a plain dict.

    args: list of argument strings, or None to read them from sys.argv
    Giving --wandb_name also turns on --wandb.
    """
    namespace = build_argparse().parse_args(args=args)
    # naming a wandb run implies that wandb logging is wanted
    if namespace.wandb_name:
        namespace.wandb = True
    return vars(namespace)
def model_file_name(args):
    """
    Work out the model file path from args['save_name'].

    The save_name template may use {shorthand} and {embedding}.  embedding is "charlm"
    when both --charlm and --charlm_forward_file are given, else "nocharlm".
    The result is inside args['save_dir'], unless the file does not exist there
    but does exist at the formatted name itself.
    """
    use_charlm = args['charlm'] and args['charlm_forward_file']
    embedding = "charlm" if use_charlm else "nocharlm"
    save_name = args['save_name'].format(shorthand=args['shorthand'], embedding=embedding)
    logger.info("Saving to: %s", save_name)

    in_save_dir = os.path.join(args['save_dir'], save_name)
    if os.path.exists(in_save_dir) or not os.path.exists(save_name):
        return in_save_dir
    return save_name
def main(args=None):
    """Parse arguments, fill in derived settings, then train or evaluate depending on --mode."""
    args = parse_args(args=args)

    utils.set_random_seed(args['seed'])

    logger.info("Running tokenizer in {} mode".format(args['mode']))

    feat_funcs = ['space_before', 'capitalized', 'numeric', 'end_of_para', 'start_of_para']
    args['feat_funcs'] = feat_funcs
    args['feat_dim'] = len(feat_funcs)
    args['save_name'] = model_file_name(args)

    save_dir = os.path.split(args['save_name'])[0]
    utils.ensure_dir(save_dir)

    runner = train if args['mode'] == 'train' else evaluate
    return runner(args)
def train(args):
    """
    Train a tokenizer model.

    Trains on args['txt_file'] / args['label_file'] and evaluates every args['eval_steps']
    steps on args['dev_txt_file'] / args['dev_label_file']; the best model on the dev set is
    saved to args['save_name'].  args is updated in place with derived settings
    (num_dict_feat, feat_dim, vocab_size and, if unset, use_mwt).

    Returns (trainer, None)
    """
    if args['use_dictionary']:
        #load lexicon
        lexicon, args['num_dict_feat'] = load_lexicon(args)
        #create the dictionary
        dictionary = create_dictionary(lexicon)
        #adjust the feat_dim
        args['feat_dim'] += args['num_dict_feat'] * 2
    else:
        args['num_dict_feat'] = 0
        lexicon = None
        dictionary = None

    mwt_dict = load_mwt_dict(args['mwt_json_file'])
    # load_mwt_dict returns None when no MWT file was given (--mwt_json_file defaults to None);
    # otherwise keep just the expansion of each MWT, dropping its count
    if mwt_dict is not None:
        mwt_expansions = {x: y[0] for x, y in mwt_dict.items()}
    else:
        mwt_expansions = {}

    train_input_files = {
        'txt': args['txt_file'],
        'label': args['label_file']
    }
    train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary, mwt_expansions=mwt_expansions)
    vocab = train_batches.vocab

    args['vocab_size'] = len(vocab)

    dev_input_files = {
        'txt': args['dev_txt_file'],
        'label': args['dev_label_file']
    }
    dev_batches = TokenizationDataset(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=dictionary)

    if args['use_mwt'] is None:
        args['use_mwt'] = train_batches.has_mwt()
        logger.info("Found {}mwts in the training data.  Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt']))

    trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'], foundation_cache=None)

    if args['load_name'] is not None:
        load_name = os.path.join(args['save_dir'], args['load_name'])
        trainer.load(load_name)
    trainer.change_lr(args['lr0'])

    N = len(train_batches)
    steps = args['steps'] if args['steps'] is not None else int(N * args['epochs'] / args['batch_size'] + .5)
    lr = args['lr0']

    prev_dev_score = -1
    best_dev_score = -1
    best_dev_step = -1

    if args['wandb']:
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tokenizer" % args['shorthand']
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('train_loss', summary='min')
        wandb.run.define_metric('dev_score', summary='max')

    for step in range(1, steps+1):
        batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout'])

        loss = trainer.update(batch)
        if step % args['report_steps'] == 0:
            logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss))
            if args['wandb']:
                wandb.log({'train_loss': loss}, step=step)

        if args['shuffle_steps'] > 0 and step % args['shuffle_steps'] == 0:
            train_batches.shuffle()

        if step % args['eval_steps'] == 0:
            dev_score = eval_model(args, trainer, dev_batches, vocab, mwt_dict)
            if args['wandb']:
                wandb.log({'dev_score': dev_score}, step=step)
            reports = ['Dev score: {:6.3f}'.format(dev_score * 100)]
            # anneal the learning rate whenever the dev score drops
            if step >= args['anneal_after'] and dev_score < prev_dev_score:
                reports += ['lr: {:.6f} -> {:.6f}'.format(lr, lr * args['anneal'])]
                lr *= args['anneal']
                trainer.change_lr(lr)

            prev_dev_score = dev_score

            if dev_score > best_dev_score:
                reports += ['New best dev score!']
                best_dev_score = dev_score
                best_dev_step = step
                trainer.save(args['save_name'])
            elif best_dev_step > 0 and step - best_dev_step > args['max_steps_before_stop']:
                reports += ['Stopping training after {} steps with no improvement'.format(step - best_dev_step)]
                logger.info('\t'.join(reports))
                break

            logger.info('\t'.join(reports))

    if args['wandb']:
        wandb.finish()

    if best_dev_step > -1:
        logger.info('Best dev score={} at step {}'.format(best_dev_score, best_dev_step))
    else:
        logger.info('Dev set never evaluated.  Saving final model')
        trainer.save(args['save_name'])

    return trainer, None
def evaluate(args):
    """
    Run a trained tokenizer over args['txt_file'], writing the result to args['conll_file'].

    The model is loaded from args['load_name'] if set, else args['save_name'].  Settings in
    trainer.args (presumably those saved with the model) override the command line values,
    except for keys ending in _file and device, mode, save_dir, load_name, save_name.

    Returns (trainer, doc)
    """
    mwt_dict = load_mwt_dict(args['mwt_json_file'])
    trainer = Trainer(args=args, model_file=args['load_name'] or args['save_name'], device=args['device'], foundation_cache=None)
    loaded_args, vocab = trainer.args, trainer.vocab

    # file paths and run-specific settings keep the values given for this run
    for k in loaded_args:
        if not k.endswith('_file') and k not in ['device', 'mode', 'save_dir', 'load_name', 'save_name']:
            args[k] = loaded_args[k]

    eval_input_files = {
        'txt': args['txt_file'],
        'label': args['label_file']
    }

    batches = TokenizationDataset(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=trainer.dictionary)

    oov_count, N, _, doc = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen'])

    # N counts the characters processed; an empty input would otherwise divide by zero
    if N > 0:
        logger.info("OOV rate: {:6.3f}% ({:6d}/{:6d})".format(oov_count / N * 100, oov_count, N))
    else:
        logger.info("OOV rate: no characters were processed")
    return trainer, doc
if __name__ == '__main__':
    # command line entry point; main() -> parse_args(None) reads the arguments from sys.argv
    main()
================================================
FILE: stanza/models/wl_coref.py
================================================
"""
Runs experiments with CorefModel.
Try 'python wl_coref.py -h' for more details.
Code based on
https://github.com/KarelDO/wl-coref/tree/master
https://arxiv.org/abs/2310.06165
This was a fork of
https://github.com/vdobrovolskii/wl-coref
https://aclanthology.org/2021.emnlp-main.605/
If you use Stanza's coref module in your work, please cite the following:
@misc{doosterlinck2023cawcoref,
title={CAW-coref: Conjunction-Aware Word-level Coreference Resolution},
author={Karel D'Oosterlinck and Semere Kiros Bitew and Brandon Papineau and Christopher Potts and Thomas Demeester and Chris Develder},
year={2023},
eprint={2310.06165},
archivePrefix={arXiv},
primaryClass={cs.CL},
url = "https://arxiv.org/abs/2310.06165",
}
@inproceedings{dobrovolskii-2021-word,
title = "Word-Level Coreference Resolution",
author = "Dobrovolskii, Vladimir",
booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
month = nov,
year = "2021",
address = "Online and Punta Cana, Dominican Republic",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2021.emnlp-main.605",
pages = "7670--7675"
}
"""
import argparse
from contextlib import contextmanager
import datetime
import logging
import os
import random
import sys
import dataclasses
import time
import numpy as np # type: ignore
import torch # type: ignore
from stanza.models.common.utils import set_random_seed
from stanza.models.coref.model import CorefModel
logger = logging.getLogger('stanza')
@contextmanager
def output_running_time():
    """
    Log the time elapsed inside the ``with`` body.

    When the context exits, a message of the form
    "Total running time: H:MM:SS" is logged at INFO level.  This happens
    whether the body finished normally or raised; any exception still
    propagates.
    """
    # monotonic clock: long training runs are unaffected by system clock adjustments
    start = time.monotonic()
    try:
        yield
    finally:
        # truncate only the difference to whole seconds; truncating start and
        # end separately (as int(time.time()) did) could be off by a full second
        delta = datetime.timedelta(seconds=int(time.monotonic() - start))
        logger.info(f"Total running time: {delta}")
def deterministic() -> None:
    """Ask cuDNN for deterministic kernels and switch off its benchmark autotuning."""
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # type: ignore
    cudnn.deterministic = True  # type: ignore
if __name__ == "__main__":
    # command line entry point: train or evaluate one experiment from the config file
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", choices=("train", "eval"))
    parser.add_argument("experiment")
    parser.add_argument("--config_file", default="config.toml")
    parser.add_argument("--data_split", choices=("train", "dev", "test"),
                        default="test",
                        help="Data split to be used for evaluation."
                             " Defaults to 'test'."
                             " Ignored in 'train' mode.")
    parser.add_argument("--batch_size", type=int,
                        help="Adjust to override the config value of anaphoricity "
                             "batch size if you are experiencing out-of-memory "
                             "issues")
    parser.add_argument("--disable_singletons", action="store_true",
                        help="don't predict singletons")
    parser.add_argument("--full_pairwise", action="store_true",
                        help="use speaker and document embeddings")
    parser.add_argument("--hidden_size", type=int,
                        help="Adjust the anaphoricity scorer hidden size")
    parser.add_argument("--rough_k", type=int,
                        help="Adjust the number of dummies to keep")
    parser.add_argument("--n_hidden_layers", type=int,
                        help="Adjust the anaphoricity scorer hidden layers")
    parser.add_argument("--dummy_mix", type=float,
                        help="Adjust the dummy mix")
    parser.add_argument("--bert_finetune_begin_epoch", type=float,
                        help="Adjust the bert finetune begin epoch")
    parser.add_argument("--bert_model", type=str,
                        help="Use this transformer for the given experiment")
    parser.add_argument("--warm_start", action="store_true",
                        help="If set, the training will resume from the"
                             " last checkpoint saved if any. Ignored in"
                             " evaluation modes."
                             " Incompatible with '--weights'.")
    parser.add_argument("--weights",
                        help="Path to file with weights to load."
                             " If not supplied, in 'eval' mode the latest"
                             " weights of the experiment will be loaded;"
                             " in 'train' mode no weights will be loaded.")
    parser.add_argument("--word_level", action="store_true",
                        help="If set, output word-level conll-formatted"
                             " files in evaluation modes. Ignored in"
                             " 'train' mode.")
    parser.add_argument("--learning_rate", default=None, type=float,
                        help="If set, update the learning rate for the model")
    parser.add_argument("--bert_learning_rate", default=None, type=float,
                        help="If set, update the learning rate for the transformer")
    parser.add_argument("--save_dir", default=None,
                        help="If set, update the save directory for writing models")
    parser.add_argument("--save_name", default=None,
                        help="If set, update the save name for writing models (otherwise, section name)")
    parser.add_argument("--score_lang", default=None,
                        help="only score a particular language for eval")
    parser.add_argument("--log_norms", action="store_true", default=None,
                        help="If set, log all of the trainable norms each epoch. Very noisy!")
    parser.add_argument("--seed", type=int, default=2020,
                        help="Random seed to set")
    parser.add_argument("--lang_lr_attenuation", type=str, default=None,
                        help="A comma-separated list of languages where the LR will be scaled by 1/epoch, such as --lang_lr_attenuation=es,en,de,...")
    parser.add_argument("--lang_lr_weights", type=str, default=None,
                        help="A comma-separated list of languages and their weights of LR scaling for different languages, such as es=0.5,en=1.0,...")
    # --max_train_len must be registered before --no_max_train_len: both share
    # dest="max_train_len" and argparse takes the default from the first one
    parser.add_argument("--max_train_len", type=int, default=5000,
                        help="Skip any documents longer than this maximum length")
    parser.add_argument("--no_max_train_len", action="store_const", const=float("inf"), dest="max_train_len",
                        help="Do not skip any documents for being too long")
    parser.add_argument("--train_epochs", type=int, default=None,
                        help="Train this many epochs")
    parser.add_argument("--plateau_epochs", type=int, default=None,
                        help="Stop training if plateaued for this many epochs (only applies if positive)")
    parser.add_argument("--train_data", default=None, help="File to use for train data")
    parser.add_argument("--dev_data", default=None, help="File to use for dev data")
    parser.add_argument("--test_data", default=None, help="File to use for test data")
    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name', default=False)
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
    args = parser.parse_args()

    if args.warm_start and args.weights is not None:
        raise ValueError("The following options are incompatible: '--warm_start' and '--weights'")

    set_random_seed(args.seed)
    deterministic()

    config = CorefModel._load_config(args.config_file, args.experiment)

    def apply_overrides(entries):
        # each entry is (argument name, config field, truthy_only); with
        # truthy_only the override needs a truthy value, otherwise any non-None value
        for arg_name, field_name, truthy_only in entries:
            value = getattr(args, arg_name)
            if (bool(value) if truthy_only else value is not None):
                setattr(config, field_name, value)

    apply_overrides((
        ("batch_size", "a_scoring_batch_size", True),
        ("hidden_size", "hidden_size", True),
        ("n_hidden_layers", "n_hidden_layers", True),
        ("learning_rate", "learning_rate", False),
        ("bert_model", "bert_model", False),
        ("bert_learning_rate", "bert_learning_rate", False),
        ("bert_finetune_begin_epoch", "bert_finetune_begin_epoch", False),
        ("dummy_mix", "dummy_mix", False),
        ("save_dir", "save_dir", False),
    ))
    # the experiment (config section) name is the fallback save name
    config.save_name = args.save_name or args.experiment
    apply_overrides((
        ("rough_k", "rough_k", False),
        ("log_norms", "log_norms", False),
        ("full_pairwise", "full_pairwise", True),
    ))
    if args.disable_singletons:
        config.singletons = False
    apply_overrides(tuple((name, name, True) for name in (
        "train_data", "dev_data", "test_data", "max_train_len",
        "train_epochs", "plateau_epochs", "lang_lr_attenuation", "lang_lr_weights",
    )))

    if args.mode == "train":
        if args.wandb:
            import wandb
            wandb_name = args.wandb_name or f"wl_coref_{args.experiment}"
            wandb.init(name=wandb_name, config=dataclasses.asdict(config), project="stanza")
            for metric_name, summary in (('train_c_loss', 'min'),
                                         ('train_s_loss', 'min'),
                                         ('dev_score', 'max')):
                wandb.run.define_metric(metric_name, summary=summary)

        model = CorefModel(config=config)
        if args.weights is not None or args.warm_start:
            model.load_weights(path=args.weights, map_location="cpu",
                               noexception=args.warm_start)
        with output_running_time():
            model.train(args.wandb)
    else:
        config_update = {
            'log_norms': False if args.log_norms is None else args.log_norms
        }
        if args.test_data:
            config_update['test_data'] = args.test_data

        if args.weights is None and config.save_name is not None:
            args.weights = config.save_name
        # if the weights path does not exist as given, try the same name with
        # a .pt suffix, then both forms inside the save directory
        if not os.path.exists(args.weights):
            candidates = [args.weights + ".pt"]
            if config.save_dir:
                candidates.append(os.path.join(config.save_dir, args.weights))
                candidates.append(os.path.join(config.save_dir, args.weights + ".pt"))
            for candidate in candidates:
                if os.path.exists(candidate):
                    args.weights = candidate
                    break

        model = CorefModel.load_model(path=args.weights, map_location="cpu",
                                      ignore={"bert_optimizer", "general_optimizer",
                                              "bert_scheduler", "general_scheduler"},
                                      config_update=config_update)
        results = model.evaluate(data_split=args.data_split,
                                 word_level_conll=args.word_level,
                                 eval_lang=args.score_lang)
        print("\t".join(str(round(value, 3)) for value in results))
================================================
FILE: stanza/pipeline/__init__.py
================================================
================================================
FILE: stanza/pipeline/_constants.py
================================================
""" Module defining constants """
# string constants for processor names
# These names serve as processor registry keys (e.g. @register_processor(CONSTITUENCY)).
# They are also the prefixes of per-processor pipeline options: a keyword such as
# tokenize_pretokenized is split on its first '_' into the processor name
# ('tokenize') and that processor's option ('pretokenized').
LANGID = 'langid'
TOKENIZE = 'tokenize'
MWT = 'mwt'
POS = 'pos'
LEMMA = 'lemma'
DEPPARSE = 'depparse'
NER = 'ner'
SENTIMENT = 'sentiment'
CONSTITUENCY = 'constituency'
COREF = 'coref'
MORPHSEG = 'morphseg'
================================================
FILE: stanza/pipeline/constituency_processor.py
================================================
"""
Processor that attaches a constituency tree to a sentence
"""
from stanza.models.constituency.trainer import Trainer
from stanza.models.common import doc
from stanza.models.common.utils import sort_with_indices, unsort
from stanza.utils.get_tqdm import get_tqdm
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
tqdm = get_tqdm()
@register_processor(CONSTITUENCY)
class ConstituencyProcessor(UDProcessor):
    """
    Attach a constituency tree to every sentence of a document.

    Each sentence is parsed from (word text, POS tag) pairs, so tokenize and
    pos are required unless the processor is configured as ``pretagged``.
    """
    # requirements this processor satisfies for processors that run later
    PROVIDES_DEFAULT = {CONSTITUENCY}
    # requirements which must be satisfied before this processor runs
    REQUIRES_DEFAULT = {TOKENIZE, POS}
    # sentences per parser batch when the config gives no batch_size
    DEFAULT_BATCH_SIZE = 50

    def _set_up_requires(self):
        """Drop the tokenize/pos requirements when the input is declared pretagged."""
        self._pretagged = self._config.get('pretagged')
        self._requires = set() if self._pretagged else type(self).REQUIRES_DEFAULT

    def _set_up_model(self, config, pipeline, device):
        """
        Load the parser from config['model_path'].

        The pretrain and charlm paths come from the pipeline config; a
        transformer, if the model uses one, is chosen by the saved model.
        """
        load_args = {
            "wordvec_pretrain_file": config.get('pretrain_path', None),
            "charlm_forward_file": config.get('forward_charlm_path', None),
            "charlm_backward_file": config.get('backward_charlm_path', None),
            "device": device,
        }
        self._trainer = Trainer.load(filename=config['model_path'],
                                     args=load_args,
                                     foundation_cache=pipeline.foundation_cache)
        self._model = self._trainer.model
        self._model.eval()

        # batch size is measured in sentences
        self._batch_size = int(config.get('batch_size', ConstituencyProcessor.DEFAULT_BATCH_SIZE))
        self._tqdm = config.get('tqdm', False)

    def _set_up_final_config(self, config):
        """Layer the pipeline config on top of the options saved with the model."""
        merged = {}
        for key, value in self._model.args.items():
            if not UDProcessor.filter_out_option(key):
                merged[key] = value
        merged.update(config)
        self._config = merged

    def process(self, document):
        """Parse every sentence of document and store the resulting trees on its sentences."""
        tag_field = 'xpos' if self._model.uses_xpos() else 'upos'
        tagged = [[(word.text, getattr(word, tag_field)) for word in sentence.words]
                  for sentence in document.sentences]
        # parse in order of decreasing length, then restore the document order
        tagged, original_indices = sort_with_indices(tagged, key=len, reverse=True)
        if self._tqdm:
            tagged = tqdm(tagged)
        parsed = self._model.parse_tagged_words(tagged, self._batch_size)
        document.set(CONSTITUENCY, unsort(parsed, original_indices), to_sentence=True)
        return document

    def get_constituents(self):
        """
        Return the set of constituent labels known by this model.

        From a pipeline, query it with
          pipeline.processors["constituency"].get_constituents()
        """
        return set(self._model.constituents)
================================================
FILE: stanza/pipeline/core.py
================================================
"""
Pipeline that runs a configurable sequence of processors (langid, tokenize, mwt, pos, lemma, depparse, ner, sentiment, constituency, coref)
"""
import argparse
import collections
from enum import Enum
import io
import itertools
import sys
import logging
import json
import os
from stanza.pipeline._constants import *
from stanza.models.common.constant import langcode_to_lang
from stanza.models.common.doc import Document
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.common.utils import default_device
from stanza.pipeline.processor import Processor, ProcessorRequirementsException
from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS
from stanza.pipeline.langid_processor import LangIDProcessor
from stanza.pipeline.tokenize_processor import TokenizeProcessor
from stanza.pipeline.mwt_processor import MWTProcessor
from stanza.pipeline.pos_processor import POSProcessor
from stanza.pipeline.lemma_processor import LemmaProcessor
from stanza.pipeline.constituency_processor import ConstituencyProcessor
from stanza.pipeline.coref_processor import CorefProcessor
from stanza.pipeline.depparse_processor import DepparseProcessor
from stanza.pipeline.sentiment_processor import SentimentProcessor
from stanza.pipeline.ner_processor import NERProcessor
from stanza.resources.common import DEFAULT_MODEL_DIR, DEFAULT_RESOURCES_URL, DEFAULT_RESOURCES_VERSION, ModelSpecification, add_dependencies, add_mwt, download_models, download_resources_json, flatten_processor_list, load_resources_json, maintain_processor_list, process_pipeline_parameters, set_logging_level, sort_processors
from stanza.resources.default_packages import PACKAGES
from stanza.utils.conll import CoNLL, CoNLLError
from stanza.utils.helper_func import make_table
logger = logging.getLogger('stanza')
class DownloadMethod(Enum):
    """
    How much a Pipeline is allowed to download before loading models.

    NONE: download nothing, not even HF transformers; loading will likely fail
          unless every resource is already in place.
    REUSE_RESOURCES: keep the existing resources.json and models, fetching only
                     models which are missing.
    DOWNLOAD_RESOURCES: fetch a fresh resources.json and overwrite models which
                        are out of date.
    """
    NONE = 1                # offline
    REUSE_RESOURCES = 2     # fill in missing models only
    DOWNLOAD_RESOURCES = 3  # refresh resources.json and stale models
class LanguageNotDownloadedError(FileNotFoundError):
    """A model file is missing because the whole model directory for its language is absent."""
    def __init__(self, lang, lang_dir, model_path):
        message = (f'Could not find the model file {model_path}. '
                   f'The expected model directory {lang_dir} is missing. '
                   f'Perhaps you need to run stanza.download("{lang}")')
        super().__init__(message)
        self.lang, self.lang_dir, self.model_path = lang, lang_dir, model_path
class UnsupportedProcessorError(FileNotFoundError):
    """A processor was requested for a language whose resources do not list it."""
    def __init__(self, processor, lang):
        message = (f'Processor {processor} is not known for language {lang}. '
                   f'If you have created your own model, please specify the {processor}_model_path parameter when creating the pipeline.')
        super().__init__(message)
        self.processor, self.lang = processor, lang
class IllegalPackageError(ValueError):
    """The packages requested for a processor cannot be used together (e.g. several packages where only one is supported)."""
    def __init__(self, msg):
        ValueError.__init__(self, msg)
class PipelineRequirementsException(Exception):
    """
    Raised when one or more processors could not be built because their requirements were not met.

    Wraps the list of ProcessorRequirementsException from the individual
    processors; str() of this exception shows all of their messages.
    """
    def __init__(self, processor_req_fails):
        self._processor_req_fails = processor_req_fails
        self.build_message()

    @property
    def processor_req_fails(self):
        """The per-processor requirement failures, as passed to the constructor."""
        return self._processor_req_fails

    def build_message(self):
        """Store in self.message the failure messages, one per line, after two leading newlines."""
        body = "\n".join(str(req_fail.message) for req_fail in self.processor_req_fails)
        self.message = "\n\n" + body + "\n"

    def __str__(self):
        return self.message
def build_default_config_option(model_specs):
    """
    Build a config option for a couple situations: lemma=identity, processor is a variant

    Returns the option name and value as a tuple, for example
    ("tokenize_with_<package>", True) or ("lemma_use_identity", True),
    or None if neither situation applies.

    Refactored from build_default_config so that we can reuse it when
    downloading all models

    Raises IllegalPackageError if a variant or the identity lemmatizer is
    requested together with other packages.
    """
    # handle case when processor variants are used
    if any(model_spec.package in PROCESSOR_VARIANTS[model_spec.processor] for model_spec in model_specs):
        if len(model_specs) > 1:
            # the generator's loop variable is not visible here, so name the
            # processor from the first spec
            # (assumes all specs belong to one processor -- TODO confirm)
            raise IllegalPackageError("Variant processor selected for {}, but multiple packages requested".format(model_specs[0].processor))
        return f"{model_specs[0].processor}_with_{model_specs[0].package}", True
    # handle case when identity is specified as lemmatizer
    elif any(model_spec.processor == LEMMA and model_spec.package == 'identity' for model_spec in model_specs):
        if len(model_specs) > 1:
            raise IllegalPackageError("Identity processor selected for lemma, but multiple packages requested")
        return f"{LEMMA}_use_identity", True
    return None
def filter_variants(model_specs):
    """
    Keep only the (processor, specs) pairs that need real model files.

    Pairs whose specs turn into a plain config option (a processor variant or
    the identity lemmatizer, see build_default_config_option) are dropped;
    the rest keep their original order.
    """
    kept = []
    for processor, specs in model_specs:
        if build_default_config_option(specs) is None:
            kept.append((processor, specs))
    return kept
# given a language and models path, build a default configuration
def build_default_config(resources, lang, model_dir, load_list):
    """
    Build the default config dict for the (processor, model specs) pairs in load_list.

    Model files are expected at <model_dir>/<lang>/<processor>/<package>.pt.
    resources is accepted but not currently used.

    For most processors this sets "<processor>_model_path" to a single path
    and "<processor>_<dep>_path" for each dependency.  NER may load several
    models at once, so it gets a list in "ner_model_path" and a parallel
    list of dependency maps in "ner_dependencies"; each map may contain
    forward_charlm_path, backward_charlm_path, or any other deps.  The user
    will be able to override the defaults using a semicolon separated string.
    TODO: at least use the same config pattern for all other models

    Raises IllegalPackageError if a non-NER processor has several packages.
    """
    def model_file(proc, package):
        return os.path.join(model_dir, lang, proc, package + '.pt')

    config = {}
    for processor, model_specs in load_list:
        option = build_default_config_option(model_specs)
        if option is not None:
            # an option replaces the model files for this processor entirely
            option_name, option_value = option
            config[option_name] = option_value
            continue

        model_paths = [model_file(processor, spec.package) for spec in model_specs]
        dependencies = [spec.dependencies for spec in model_specs]

        if processor == NER:
            config[f"{processor}_model_path"] = model_paths
            config[f"{processor}_dependencies"] = [
                {f"{dep_processor}_path": model_file(dep_processor, dep_model)
                 for dep_processor, dep_model in (dependency_block or [])}
                for dependency_block in dependencies
            ]
            continue

        if len(model_specs) > 1:
            raise IllegalPackageError("Specified multiple packages for {}, which currently only handles one package".format(processor))

        config[f"{processor}_model_path"] = model_paths[0]
        for dep_processor, dep_model in dependencies[0] or []:
            config[f"{processor}_{dep_processor}_path"] = model_file(dep_processor, dep_model)
    return config
def normalize_download_method(download_method):
    """
    Convert a user-supplied download method to a DownloadMethod.

    None becomes DownloadMethod.NONE and a string is looked up
    case-insensitively by member name (ValueError if unknown).  Any other
    value, such as a DownloadMethod, is returned unchanged.
    """
    if isinstance(download_method, str):
        try:
            return DownloadMethod[download_method.upper()]
        except KeyError as e:
            raise ValueError("Unknown download method %s" % download_method) from e
    return DownloadMethod.NONE if download_method is None else download_method
class Pipeline:
    """
    Run a sequence of processors (tokenize, mwt, pos, lemma, depparse, ...) for one language.

    At construction time the pipeline may refresh resources.json and download
    missing models (see DownloadMethod).  It then decides which processors
    and packages to load, builds a config dict from the defaults plus any
    ``<processor>_<option>`` keyword arguments, and instantiates every
    processor.  Calling the pipeline runs the loaded processors in
    PIPELINE_NAMES order.
    """
    def __init__(self,
                 lang='en',
                 dir=DEFAULT_MODEL_DIR,
                 package='default',
                 processors={},
                 logging_level=None,
                 verbose=None,
                 use_gpu=None,
                 model_dir=None,
                 download_method=DownloadMethod.DOWNLOAD_RESOURCES,
                 resources_url=DEFAULT_RESOURCES_URL,
                 resources_branch=None,
                 resources_version=DEFAULT_RESOURCES_VERSION,
                 resources_filepath=None,
                 proxies=None,
                 foundation_cache=None,
                 device=None,
                 allow_unknown_language=False,
                 **kwargs):
        # NOTE(review): processors={} is a shared mutable default; here it is only
        # passed to process_pipeline_parameters -- confirm that helper never mutates it
        self.lang, self.dir, self.kwargs = lang, dir, kwargs
        # model_dir is an alternate name for dir, honored only if dir was left at its default
        if model_dir is not None and dir == DEFAULT_MODEL_DIR:
            self.dir = model_dir

        # set global logging level
        set_logging_level(logging_level, verbose)

        self.download_method = normalize_download_method(download_method)
        # refresh resources.json when asked to, or when reusing resources but there are none yet
        if (self.download_method is DownloadMethod.DOWNLOAD_RESOURCES or
            (self.download_method is DownloadMethod.REUSE_RESOURCES and not os.path.exists(os.path.join(self.dir, "resources.json")))):
            logger.info("Checking for updates to resources.json in case models have been updated. Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES")
            download_resources_json(self.dir,
                                    resources_url=resources_url,
                                    resources_branch=resources_branch,
                                    resources_version=resources_version,
                                    resources_filepath=resources_filepath,
                                    proxies=proxies)

        # processors can use this to save on the effort of loading
        # large sub-models, such as pretrained embeddings, bert, etc
        if foundation_cache is None:
            self.foundation_cache = FoundationCache(local_files_only=(self.download_method is DownloadMethod.NONE))
        else:
            self.foundation_cache = FoundationCache(foundation_cache, local_files_only=(self.download_method is DownloadMethod.NONE))

        # process different pipeline parameters
        lang, self.dir, package, processors = process_pipeline_parameters(lang, self.dir, package, processors)

        # Load resources.json to obtain latest packages.
        logger.debug('Loading resource file...')
        resources = load_resources_json(self.dir, resources_filepath)
        if lang in resources:
            if 'alias' in resources[lang]:
                logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"')
                lang = resources[lang]['alias']
            lang_name = resources[lang]['lang_name'] if 'lang_name' in resources[lang] else ''
        elif allow_unknown_language:
            logger.warning("Trying to create pipeline for unsupported language: %s", lang)
            lang_name = langcode_to_lang(lang)
        else:
            logger.warning("Unsupported language: %s If trying to add a new language, consider using allow_unknown_language=True", lang)
            lang_name = langcode_to_lang(lang)

        # Maintain load list
        if lang in resources:
            self.load_list = maintain_processor_list(resources, lang, package, processors, maybe_add_mwt=(not kwargs.get("tokenize_pretokenized")))
            self.load_list = add_dependencies(resources, lang, self.load_list)
            if self.download_method is not DownloadMethod.NONE:
                # skip processors which aren't downloaded from our collection
                download_list = [x for x in self.load_list if x[0] in resources.get(lang, {})]
                # skip variants
                download_list = filter_variants(download_list)
                # gather up the model list...
                download_list = flatten_processor_list(download_list)
                # download_models will skip models we already have
                download_models(download_list,
                                resources=resources,
                                lang=lang,
                                model_dir=self.dir,
                                resources_version=resources_version,
                                proxies=proxies,
                                log_info=False)
        elif allow_unknown_language:
            self.load_list = [(proc, [ModelSpecification(processor=proc, package='default', dependencies=None)])
                              for proc in list(processors.keys())]
        else:
            self.load_list = []
        # explicit <processor>_model_path kwargs add or replace entries in the load list
        self.load_list = self.update_kwargs(kwargs, self.load_list)
        if len(self.load_list) == 0:
            if lang not in resources or PACKAGES not in resources[lang]:
                raise ValueError(f'No processors to load for language {lang}. Language {lang} is currently unsupported')
            else:
                raise ValueError('No processors to load for language {}. Please check if your language or package is correctly set.'.format(lang))
        load_table = make_table(['Processor', 'Package'], [(row[0], ";".join(model_spec.package for model_spec in row[1])) for row in self.load_list])
        logger.info(f'Loading these models for language: {lang} ({lang_name}):\n{load_table}')

        self.config = build_default_config(resources, lang, self.dir, self.load_list)
        # user kwargs take precedence over the defaults
        self.config.update(kwargs)

        # Load processors
        self.processors = {}

        # configs that are the same for all processors
        pipeline_level_configs = {'lang': lang, 'mode': 'predict'}

        if device is None:
            if use_gpu is None or use_gpu == True:
                device = default_device()
            else:
                device = 'cpu'
            if use_gpu == True and device == 'cpu':
                logger.warning("GPU requested, but is not available!")
        self.device = device
        logger.info("Using device: {}".format(self.device))

        # set up processors
        pipeline_reqs_exceptions = []
        for item in self.load_list:
            processor_name, _ = item
            logger.info('Loading: ' + processor_name)
            curr_processor_config = self.filter_config(processor_name, self.config)
            curr_processor_config.update(pipeline_level_configs)
            # TODO: this is obviously a hack
            # a better solution overall would be to make a pretagged version of the pos annotator
            # and then subsequent modules can use those tags without knowing where those tags came from
            if "pretagged" in self.config and "pretagged" not in curr_processor_config:
                curr_processor_config["pretagged"] = self.config["pretagged"]
            logger.debug('With settings: ')
            logger.debug(curr_processor_config)
            try:
                # try to build processor, throw an exception if there is a requirements issue
                self.processors[processor_name] = NAME_TO_PROCESSOR_CLASS[processor_name](config=curr_processor_config,
                                                                                          pipeline=self,
                                                                                          device=self.device)
            except ProcessorRequirementsException as e:
                # if there was a requirements issue, add it to list which will be printed at end
                pipeline_reqs_exceptions.append(e)
                # add the broken processor to the loaded processors for the sake of analyzing the validity of the
                # entire proposed pipeline, but at this point the pipeline will not be built successfully
                self.processors[processor_name] = e.err_processor
            except FileNotFoundError as e:
                # For a FileNotFoundError, we try to guess if there's
                # a missing model directory or file.  If so, we
                # suggest the user try to download the models
                if 'model_path' in curr_processor_config:
                    model_path = curr_processor_config['model_path']
                    if e.filename == model_path or (isinstance(model_path, (tuple, list)) and e.filename in model_path):
                        model_path = e.filename
                    # a list of model paths (eg, several NER models) which does
                    # not contain the missing file gives us nothing to diagnose,
                    # and os.path.split would raise a TypeError on it
                    if not isinstance(model_path, (tuple, list)):
                        missing_dir, model_name = os.path.split(model_path)
                        lang_dir = os.path.dirname(missing_dir)
                        if lang_dir and not os.path.exists(lang_dir):
                            # model files for this language can't be found in the expected directory
                            raise LanguageNotDownloadedError(lang, lang_dir, model_path) from e
                        # resources has no entry for lang when allow_unknown_language was used
                        if processor_name not in resources.get(lang, {}):
                            # user asked for a model which doesn't exist for this language?
                            raise UnsupportedProcessorError(processor_name, lang) from e
                        if not os.path.exists(model_path):
                            model_name, _ = os.path.splitext(model_name)
                            # TODO: before recommending this, check that such a thing exists in resources.json.
                            # currently that case is handled by ignoring the model, anyway
                            raise FileNotFoundError('Could not find model file %s, although there are other models downloaded for language %s. Perhaps you need to download a specific model. Try: stanza.download(lang="%s",package=None,processors={"%s":"%s"})' % (model_path, lang, lang, processor_name, model_name)) from e

                # if we couldn't find a more suitable description of the
                # FileNotFoundError, just raise the old error
                raise

        # if there are any processor exceptions, throw an exception to indicate pipeline build failure
        if pipeline_reqs_exceptions:
            logger.info('\n')
            raise PipelineRequirementsException(pipeline_reqs_exceptions)

        logger.info("Done loading processors!")

    @staticmethod
    def update_kwargs(kwargs, processor_list):
        """
        Apply ``<processor>_model_path`` kwargs to the (processor, specs) list.

        Each such kwarg replaces that processor's specs with a single spec
        whose package is the (abbreviated) path.  The spec keeps the
        dependencies of the processor's first existing spec, if there was one.
        The result is re-sorted with sort_processors.
        """
        processor_dict = {processor: [{'package': model_spec.package, 'dependencies': model_spec.dependencies} for model_spec in model_specs]
                          for (processor, model_specs) in processor_list}
        for key, value in kwargs.items():
            pieces = key.split('_', 1)
            if len(pieces) == 1:
                continue
            k, v = pieces
            if v == 'model_path':
                # a shortened form of the path, used as the package name in the load table
                package = value if len(value) < 25 else value[:10]+ '...' + value[-10:]
                original_spec = processor_dict.get(k, [])
                if len(original_spec) > 0:
                    dependencies = original_spec[0].get('dependencies')
                else:
                    dependencies = None
                processor_dict[k] = [{'package': package, 'dependencies': dependencies}]

        processor_list = [(processor, [ModelSpecification(processor=processor, package=model_spec['package'], dependencies=model_spec['dependencies']) for model_spec in processor_dict[processor]]) for processor in processor_dict]
        processor_list = sort_processors(processor_list)
        return processor_list

    @staticmethod
    def filter_config(prefix, config_dict):
        """Return the options of config_dict whose keys start with ``<prefix>_``, with that prefix removed."""
        filtered_dict = {}
        for key in config_dict.keys():
            pieces = key.split('_', 1)  # split tokenize_pretokenize to tokenize+pretokenize
            if len(pieces) == 1:
                continue
            k, v = pieces
            if k == prefix:
                filtered_dict[v] = config_dict[key]
        return filtered_dict

    @property
    def loaded_processors(self):
        """
        Return all currently loaded processors in execution order.
        :return: list of Processor instances
        """
        return [self.processors[processor_name] for processor_name in PIPELINE_NAMES if self.processors.get(processor_name)]

    def process(self, doc, processors=None):
        """
        Run the pipeline

        processors: allow for a list of processors used by this pipeline action
          can be list, tuple, set, or comma separated string
          (whitespace around the names in a comma separated string is ignored)
          if None, use all the processors this pipeline knows about
          MWT is added if necessary
          otherwise, no care is taken to make sure prerequisites are followed...
            some of the annotators, such as depparse, will check, but others
            will fail in some unusual manner or just have really bad results
        """
        assert any([isinstance(doc, str), isinstance(doc, list),
                    isinstance(doc, Document)]), 'input should be either str, list or Document'

        # empty bulk process
        if isinstance(doc, list) and len(doc) == 0:
            return []

        # determine whether we are in bulk processing mode for multiple documents
        bulk = (isinstance(doc, list) and len(doc) > 0 and isinstance(doc[0], Document))

        # various options to limit the processors used by this pipeline action
        if processors is None:
            processors = PIPELINE_NAMES
        elif not isinstance(processors, (str, list, tuple, set)):
            raise ValueError("Cannot process {} as a list of processors to run".format(type(processors)))
        else:
            if isinstance(processors, str):
                processors = {x.strip() for x in processors.split(",")}
            else:
                processors = set(processors)
            if TOKENIZE in processors and MWT in self.processors and MWT not in processors:
                logger.debug("Requested processors for pipeline did not have mwt, but pipeline needs mwt, so mwt is added")
                processors.add(MWT)
        # run in canonical pipeline order regardless of the order requested
        processors = [x for x in PIPELINE_NAMES if x in processors]

        for processor_name in processors:
            if self.processors.get(processor_name):
                process = self.processors[processor_name].bulk_process if bulk else self.processors[processor_name].process
                doc = process(doc)
        return doc

    def process_conllu(self, doc, ignore_gapping=True, processors=None):
        """
        Convenience method: treat the doc as a conllu text, convert it, and process it accordingly

        If processors is None, every loaded processor except tokenize and mwt
        is run, since the conllu text already supplies the tokenization.
        """
        if processors is None:
            processors = set(self.processors.keys())
            processors.discard(TOKENIZE)
            processors.discard(MWT)
        doc = CoNLL.conll2doc(input_str=doc, ignore_gapping=ignore_gapping)
        return self.process(doc, processors=processors)

    def bulk_process(self, docs, *args, **kwargs):
        """
        Run the pipeline in bulk processing mode

        Expects a list of str or a list of Docs
        """
        # Wrap each text as a Document unless it is already such a document
        docs = [doc if isinstance(doc, Document) else Document([], text=doc) for doc in docs]
        return self.process(docs, *args, **kwargs)

    def stream(self, docs, batch_size=50, *args, **kwargs):
        """
        Go through an iterator of documents in batches, yield processed documents

        sentence indices will be counted across the entire iterator
        """
        if not isinstance(docs, collections.abc.Iterator):
            docs = iter(docs)
        def next_batch():
            batch = []
            for _ in range(batch_size):
                try:
                    next_doc = next(docs)
                    batch.append(next_doc)
                except StopIteration:
                    return batch
            return batch

        sentence_start_index = 0
        batch = next_batch()
        while batch:
            batch = self.bulk_process(batch, *args, **kwargs)
            for doc in batch:
                doc.reindex_sentences(sentence_start_index)
                sentence_start_index += len(doc.sentences)
                yield doc
            batch = next_batch()

    def __str__(self):
        """
        Assemble the processors in order to make a simple description of the pipeline
        """
        processors = ["%s=%s" % (x, str(self.processors[x])) for x in PIPELINE_NAMES if x in self.processors]
        # the template needs a %s placeholder; formatting "" with an argument raises TypeError
        return "<Pipeline: %s>" % ", ".join(processors)

    def __call__(self, doc, processors=None):
        return self.process(doc, processors)
def main():
    """
    Command-line entry point: run a Pipeline over ``--input_file`` and print the result.

    If the input parses as CoNLL-U, the pipeline is built with
    ``tokenize_pretokenized=True`` so the existing tokenization is reused;
    otherwise the file is read as UTF-8 text.  Unrecognized command-line
    arguments are tolerated and ignored.  The processed document is printed
    with the ``C`` format spec.
    """
    # TODO: can add a bunch more arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--lang', type=str, default='en', help='Language of the pipeline to use')
    parser.add_argument('--input_file', type=str, required=True, help='Input file to read')
    parser.add_argument('--processors', type=str, default='tokenize,pos,lemma,depparse', help='Processors to use')
    parser.add_argument('--package', type=str, default='default', help='Which package to use')
    parser.add_argument('--tokenize_no_ssplit', default=False, action='store_true', help="Don't ssplit")
    parser.add_argument('--tokenize_pretokenized', default=False, action='store_true', help="Text is pretokenized")
    # unknown arguments are accepted but not used
    args, _ = parser.parse_known_args()

    try:
        doc = CoNLL.conll2doc(args.input_file)
        extra_args = {
            "tokenize_pretokenized": True
        }
    except CoNLLError:
        # pass the filename so the %s placeholder is actually filled in
        logger.debug("Input file %s does not appear to be a conllu file. Will read it as a text file", args.input_file)
        with open(args.input_file, encoding="utf-8") as fin:
            doc = fin.read()
        extra_args = {}

    extra_args['package'] = args.package
    if args.tokenize_no_ssplit:
        extra_args['tokenize_no_ssplit'] = True
    if args.tokenize_pretokenized:
        extra_args['tokenize_pretokenized'] = True

    pipe = Pipeline(args.lang, processors=args.processors, **extra_args)
    doc = pipe(doc)
    print("{:C}".format(doc))

if __name__ == '__main__':
    main()
================================================
FILE: stanza/pipeline/coref_processor.py
================================================
"""
Processor that attaches coref annotations to a document
"""
from stanza.models.common.utils import misc_to_space_after
from stanza.models.coref.coref_chain import CorefMention, CorefChain
from stanza.models.common.doc import Word
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
import torch
def extract_text(document, sent_id, start_word, end_word):
    """
    Return the surface text of words ``start_word`` .. ``end_word`` (0-based,
    end exclusive) in sentence ``sent_id`` of ``document``.

    A multi-word token lying entirely inside the range contributes its token
    text once.  A multi-word token only partly covered contributes the
    individual word texts.  Pieces are joined using the spacing recorded in
    each piece's misc field (via misc_to_space_after), and the spacing after
    the final piece is dropped so the mention has no trailing whitespace.
    """
    sentence = document.sentences[sent_id]

    # coref indexes words from 0, while token/word ids start at 1:
    # move both bounds into id space (stop_id stays exclusive)
    first_id = start_word + 1
    stop_id = end_word + 1

    pieces = []
    current = first_id
    while current < stop_id:
        word = sentence.words[current - 1]
        token = word.parent
        if isinstance(token.id, int) or len(token.id) == 1:
            # ordinary single-word token
            pieces.append(token)
            current += 1
        elif token.id[0] >= first_id and token.id[1] < stop_id:
            # the whole MWT fits in the span: use the token text and skip its words
            pieces.append(token)
            current = token.id[1] + 1
        else:
            # MWT only partly covered: fall back to this word alone
            pieces.append(word)
            current += 1

    chunks = []
    for piece in pieces:
        chunks.append(piece.text)
        chunks.append(misc_to_space_after(piece.misc))
    # discard the separator after the final piece
    return "".join(chunks[:-1])
@register_processor(COREF)
class CorefProcessor(UDProcessor):
    """
    Processor that runs the coref model over a document and stores the
    predicted chains in ``document.coref``.

    If the model predicts zero anaphora (and ``use_zeros`` is not disabled),
    empty words are added to the affected sentences and the corresponding
    mentions point at those empty words.
    """
    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([COREF])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE])

    def _set_up_model(self, config, pipeline, device):
        """
        Load the CorefModel from ``config['model_path']`` and read options.

        Config keys read here: ``model_path`` (required), ``log_norms``,
        ``batch_size`` (overrides the model's ``a_scoring_batch_size``) and
        ``use_zeros`` (a bool, or a string where anything other than
        "false", case-insensitive, counts as True).

        Raises:
            ImportError: if the coref model module cannot be imported.
        """
        try:
            from stanza.models.coref.model import CorefModel
        except ImportError as err:
            raise ImportError("Please install the transformers and peft libraries before using coref! Try `pip install -e .[transformers]`.") from err

        # set up model
        # currently, the model has everything packaged in it
        # (except its config)
        # TODO: separate any pretrains if possible
        # TODO: add device parameter to the load mechanism
        config_update = {'log_norms': config.get('log_norms', False),
                         'device': device}
        model = CorefModel.load_model(path=config['model_path'],
                                      ignore={"bert_optimizer", "general_optimizer",
                                              "bert_scheduler", "general_scheduler"},
                                      config_update=config_update,
                                      foundation_cache=pipeline.foundation_cache)
        if config.get('batch_size', None):
            model.config.a_scoring_batch_size = int(config['batch_size'])
        model.training = False
        self._model = model

        # coref_use_zeros=False will turn off creating new nodes and attaching mentions to those zero nodes
        self._use_zeros = config.get('use_zeros', True)
        if isinstance(self._use_zeros, str):
            self._use_zeros = self._use_zeros.lower() != 'false'

    @staticmethod
    def _mention_bounds(span, sent_ids, word_pos):
        """
        Convert a model span (flat word indices, end exclusive) into
        ``(sent_id, start_word, end_word)``: 0-based, end-exclusive word
        positions inside that sentence.
        """
        sent_id = sent_ids[span[0]]
        start_word = word_pos[span[0]]
        # fiddle -1 / +1 so as to avoid problems with coref
        # clusters that end at exactly the end of a document
        end_word = word_pos[span[1] - 1] + 1
        return sent_id, start_word, end_word

    def process(self, document):
        """
        Run coreference resolution on ``document``.

        Sets ``document.coref`` to a list of CorefChain objects (one per
        non-empty predicted cluster) and returns the same document.

        Raises:
            ValueError: if the model predicts a mention crossing a sentence boundary.
        """
        sentences = document.sentences

        # Flatten the document into one word sequence for the model, keeping
        # for every flat position its sentence index and 0-based word index.
        cased_words = []
        sent_ids = []
        word_pos = []
        speaker = []
        for sent_idx, sentence in enumerate(sentences):
            for word_idx, word in enumerate(sentence.words):
                cased_words.append(word.text)
                sent_ids.append(sent_idx)
                word_pos.append(word_idx)
                if sentence.speaker:
                    speaker.append(sentence.speaker)
                else:
                    speaker.append("_")

        coref_input = {
            "document_id": "wb_doc_1",
            "cased_words": cased_words,
            "sent_id": sent_ids,
            "speaker": speaker,
        }
        coref_input = self._model.build_doc(coref_input)
        results = self._model.run(coref_input)

        # Handle zero anaphora - zero_scores is always predicted
        zero_nodes_created = self._handle_zero_anaphora(document, results, sent_ids, word_pos)

        clusters = []
        for cluster_idx, span_cluster in enumerate(results.span_clusters):
            if len(span_cluster) == 0:
                continue

            # Sort mentions by span while carrying each span's head word along,
            # so the (cluster_idx, word_idx) zero-node lookup stays attached to
            # the right span even if sorting changes the order.
            # NOTE(review): assumes span_clusters[c][i] is the span predicted for
            # head word word_clusters[c][i], as the index-based lookup implies.
            heads = results.word_clusters[cluster_idx] if results.word_clusters is not None else None
            order = sorted(range(len(span_cluster)), key=lambda i: span_cluster[i])
            spans = [span_cluster[i] for i in order]
            head_words = [heads[i] if heads is not None else None for i in order]

            # check there are no sentence crossings before
            # manipulating the spans, since we will expect it to
            # be this way for multiple usages of the spans
            for span in spans:
                if sent_ids[span[1] - 1] != sent_ids[span[0]]:
                    raise ValueError("The coref model predicted a span that crossed two sentences! Please send this example to us on our github")

            # treat the longest span as the representative
            # break ties using the first one
            # IF there is the POS processor, and it adds upos tags
            # to the sentence, ties are broken first by maximum
            # number of PROPN words and then earliest in the document
            max_len = 0
            best_span = None
            max_propn = 0
            for span_idx, (span, word_idx) in enumerate(zip(spans, head_words)):
                if zero_nodes_created.get((cluster_idx, word_idx)):
                    # zero mentions have no surface text to represent the chain
                    continue
                sent_id, start_word, end_word = self._mention_bounds(span, sent_ids, word_pos)
                # very UD specific test for most number of proper nouns in a mention
                # will do nothing if POS is not active (they will all be None)
                num_propn = sum(word.pos == 'PROPN' for word in sentences[sent_id].words[start_word:end_word])
                span_len = span[1] - span[0]
                if span_len > max_len or (span_len == max_len and num_propn > max_propn):
                    max_len = span_len
                    best_span = span_idx
                    max_propn = num_propn

            mentions = []
            for span, word_idx in zip(spans, head_words):
                zero_node = zero_nodes_created.get((cluster_idx, word_idx))
                if zero_node:
                    # the mention is attached to the empty word: its start
                    # and end are both the empty word's (tuple) id
                    zero_sent_id, zero_word_id = zero_node
                    mentions.append(CorefMention(zero_sent_id, zero_word_id, zero_word_id))
                else:
                    sent_id, start_word, end_word = self._mention_bounds(span, sent_ids, word_pos)
                    mentions.append(CorefMention(sent_id, start_word, end_word))

            # if we ended up with no best span, then our "representative text"
            # is just underscore
            if best_span is not None:
                representative = mentions[best_span]
                representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
            else:
                representative_text = "_"

            chain = CorefChain(len(clusters), mentions, representative_text, best_span)
            clusters.append(chain)

        document.coref = clusters
        return document

    def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
        """
        Handle zero anaphora by creating zero (empty) nodes.

        ``results.zero_scores`` is indexed like the flattened
        ``results.word_clusters``.  Every word whose score is > 0 gets a new
        empty word inserted just before it in its sentence.

        Returns:
            dict mapping ``(cluster_idx, word_idx)`` to ``(sent_id, zero_word_id)``,
            which ``process`` uses to replace those mentions with the empty
            word.  Empty if zeros are disabled or the model produced no
            scores or clusters.
        """
        if results.zero_scores is None or results.word_clusters is None:
            return {}
        if not self._use_zeros:
            return {}

        zero_scores = results.zero_scores.squeeze(-1) if results.zero_scores.dim() > 1 else results.zero_scores

        # Flatten word_clusters to get the word indices that correspond to zero_scores;
        # cluster_mapping maps each flat position to the cluster it came from
        cluster_word_ids = []
        cluster_mapping = {}
        counter = 0
        for cluster_idx, cluster in enumerate(results.word_clusters):
            for _ in range(len(cluster)):
                cluster_mapping[counter] = cluster_idx
                counter += 1
            cluster_word_ids.extend(cluster)

        # Find indices where zero_scores > 0
        zero_indices = (zero_scores > 0.0).nonzero()

        # this dict maps (cluster_id, word_id) to (sent_id, zero_word_id)
        # which overrides span_clusters
        zero_to_coref = {}
        for zero_idx in zero_indices:
            zero_idx = zero_idx.item()
            if zero_idx >= len(cluster_word_ids):
                continue
            word_idx = cluster_word_ids[zero_idx]
            sent_id = sent_ids[word_idx]
            word_id = word_pos[word_idx]
            sentence = document.sentences[sent_id]

            # word_id is the mention word's 0-based position, i.e. the 1-based id
            # of the word before it, so an empty node numbered (word_id, k) sits
            # immediately before the mention word.  CoNLL-U numbers empty nodes
            # per position (N.1, N.2, ...), so count only the empty words that
            # are already attached at this same position.
            already_here = sum(1 for empty in sentence._empty_words
                               if isinstance(empty.id, (tuple, list)) and empty.id[0] == word_id)
            zero_word_id = (word_id, already_here + 1)
            zero_word = Word(sentence, {
                "text": "_",
                "lemma": "_",
                "id": zero_word_id
            })
            sentence._empty_words.append(zero_word)

            # Track this zero node for adding to coreference clusters
            cluster_idx = cluster_mapping[zero_idx]
            zero_to_coref[(cluster_idx, word_idx)] = (sent_id, zero_word_id)
        return zero_to_coref
================================================
FILE: stanza/pipeline/demo/README.md
================================================
## Interactive Demo for Stanza
### Requirements
stanza, flask
### Run the demo locally
1. Make sure you know how to disable your browser's CORS rule. For Chrome, [this extension](https://mybrowseraddon.com/access-control-allow-origin.html) works pretty well.
2. From this directory, start the Stanza demo server
```bash
export FLASK_APP=demo_server.py
flask run
```
3. In `stanza-brat.js`, uncomment the line near the top that declares `serverAddress` and point it at the address where Flask is serving the demo server (usually `http://localhost:5000`).
4. Open `stanza-brat.html` in your browser (with CORS disabled) and enjoy!
### Common issues
Make sure you have downloaded the models for the language you want to test out locally before submitting requests to the server! (Models can be obtained with `import stanza; stanza.download()`.)
================================================
FILE: stanza/pipeline/demo/__init__.py
================================================
================================================
FILE: stanza/pipeline/demo/demo_server.py
================================================
from flask import Flask, request, abort
import json
import stanza
import os
# Flask app serving files from this module's own directory at the URL root
app = Flask(
    __name__,
    static_url_path='',
    static_folder=os.path.abspath(os.path.dirname(__file__)),
)
# language code -> stanza.Pipeline, filled lazily by the annotate view
pipelineCache = {}
def get_file(path):
    """Resolve ``path`` against this module's directory, echo the result to stdout, and return it."""
    here = os.path.dirname(os.path.abspath(__file__))
    resolved = os.path.join(here, path)
    print(resolved)
    return resolved
@app.route('/<path:path>')
@app.route('/static/fonts/<path:path>')
def static_file(path):
    """
    Serve one of the demo's whitelisted static files.

    ``index.html`` maps to ``stanza-brat.html``.  Any other path is rejected
    with HTTP 403.  The routes need the ``<path:path>`` variable: without it
    Flask would call this view with no ``path`` argument.
    """
    # NOTE(review): with static_url_path='' Flask also registers its own
    # '/<path:filename>' static route for this folder; confirm which rule wins,
    # since only this view applies the whitelist.
    if path in ['stanza-brat.css', 'stanza-brat.js', 'stanza-parseviewer.js', 'loading.gif',
                'favicon.png', 'stanza-logo.png',
                'Astloch-Bold.ttf', 'Liberation_Sans-Regular.ttf', 'PT_Sans-Caption-Web-Regular.ttf']:
        return app.send_static_file(path)
    elif path == 'index.html':
        # exact match: `path in 'index.html'` was a substring test that also
        # accepted paths like 'index', 'html' or 'x'
        return app.send_static_file('stanza-brat.html')
    else:
        abort(403)
@app.route('/', methods=['GET'])
def index():
    """Serve the demo landing page for GET / by delegating to static_file('index.html')."""
    landing = 'index.html'
    return static_file(landing)
@app.route('/', methods=['POST'])
def annotate():
    """
    Annotate the POSTed text and return CoreNLP-style JSON for the brat front end.

    The language comes from the ``pipelineLanguage`` query parameter.  The
    text is the first key of the submitted form data.  One stanza.Pipeline
    per language is built lazily and kept in ``pipelineCache``.

    Responds with HTTP 400 if the language or the text is missing.
    """
    lang = request.args.get('pipelineLanguage', '')
    if not lang:
        abort(400)
    form_keys = list(request.form.keys())
    if not form_keys:
        abort(400)
    text = form_keys[0]

    if lang not in pipelineCache:
        pipelineCache[lang] = stanza.Pipeline(lang=lang, use_gpu=False)
    res = pipelineCache[lang](text)

    annotated_sentences = []
    for sentence in res.sentences:
        tokens = []
        deps = []
        for word in sentence.words:
            ner = word.parent.ner
            tokens.append({'index': word.id, 'word': word.text, 'lemma': word.lemma,
                           'pos': word.xpos, 'upos': word.upos, 'feats': word.feats,
                           # drop the two-character tag prefix (e.g. "B-") from entity labels
                           'ner': ner if ner is None or ner == 'O' else ner[2:]})
            if word.head is None:
                # no dependency parse available (e.g. depparse not in the pipeline)
                continue
            # head 0 is the root; words[head-1] would be words[-1], the last word
            governor_gloss = 'ROOT' if word.head == 0 else sentence.words[word.head - 1].text
            deps.append({'dep': word.deprel, 'governor': word.head, 'governorGloss': governor_gloss,
                         'dependent': word.id, 'dependentGloss': word.text})
        annotated_sentences.append({'basicDependencies': deps, 'tokens': tokens})
        if hasattr(sentence, 'constituency') and sentence.constituency is not None:
            annotated_sentences[-1]['parse'] = str(sentence.constituency)

    return json.dumps({'sentences': annotated_sentences})
def create_app():
    """Return the module-level Flask application object."""
    flask_app = app
    return flask_app

if __name__ == "__main__":
    # when run as a script, listen on all interfaces on port 8080
    app.run(port=8080, host='0.0.0.0')
================================================
FILE: stanza/pipeline/demo/stanza-brat.css
================================================
/* Dark red text */
.red {
color:#990000
}
/* Page wrapper: the negative bottom margin equals the bottom padding,
   presumably leaving room for .footer (sticky-footer layout) */
#wrap {
min-height: 100%;
height: auto;
margin: 0 auto -6ex;
padding: 0 0 6ex;
}
.pattern_tab {
margin: 1ex;
}
.pattern_brat {
margin-top: 1ex;
}
/* Small gray caption text */
.label {
color: #777777;
font-size: small;
}
/* Full-width light gray footer bar */
.footer {
bottom: 0;
width: 100%;
/* Set the fixed height of the footer here */
height: 5ex;
padding-top: 1ex;
margin-top: 1ex;
background-color: #f5f5f5;
}
/* Spacing above error messages */
.corenlp_error {
margin-top: 2ex;
}
/* Styling for parse graph */
/* Default box for every graph node */
.node rect {
stroke: #333;
fill: #fff;
}
/* Green fill for non-terminal (rule) nodes */
.parse-RULE rect {
fill: #C0D9AF;
}
/* Pale yellow fill for terminal (word) nodes */
.parse-TERMINAL rect {
stroke: #333;
fill: #EEE8AA;
}
/* Yellow outline for a highlighted node */
.node.highlighted {
stroke: #ffff00;
}
/* Default graph edges */
.edgePath path {
stroke: #333;
fill: #333;
stroke-width: 1.5px;
}
/* Lighter gray for parse-tree edges */
.parse-EDGE path {
stroke: DarkGray;
fill: DarkGray;
stroke-width: 1.5px;
}
/* Italic brand font for the logo text */
.logo {
font-family: "Lato", "Gill Sans MT", "Gill Sans", "Helvetica", "Arial", sans-serif;
font-style: italic;
}
================================================
FILE: stanza/pipeline/demo/stanza-brat.html
================================================
================================================
FILE: stanza/pipeline/demo/stanza-brat.js
================================================
// Takes Stanford CoreNLP JSON output (var data = ... in data.js)
// and uses brat to render everything.
//var serverAddress = 'http://localhost:5000';

// Load Brat libraries
var bratLocation = 'https://nlp.stanford.edu/js/brat/';
// bratLocation already ends in '/', so the relative paths below must not
// begin with one (otherwise every URL contains '//').
head.js(
  // External libraries
  bratLocation + 'client/lib/jquery.svg.min.js',
  bratLocation + 'client/lib/jquery.svgdom.min.js',

  // brat helper modules
  bratLocation + 'client/src/configuration.js',
  bratLocation + 'client/src/util.js',
  bratLocation + 'client/src/annotation_log.js',
  bratLocation + 'client/lib/webfont.js',

  // brat modules
  bratLocation + 'client/src/dispatcher.js',
  bratLocation + 'client/src/url_monitor.js',
  bratLocation + 'client/src/visualizer.js',

  // parse viewer
  './stanza-parseviewer.js'
);
// Dagre (https://github.com/cpettitt/dagre) draws the constituency parse,
// since it renders parse trees better than the brat visualization.
var useDagre = true;

// Default query, plus the sentences / text of the most recent render.
var currentQuery = 'The quick brown fox jumped over the lazy dog.',
    currentSentences = '',
    currentText = '';
// ----------------------------------------------------------------------------
// HELPERS
// ----------------------------------------------------------------------------
/**
 * Add String.prototype.startsWith for browsers that lack it.
 * Like the native method, an optional second argument gives the position
 * at which matching starts.  Only the candidate prefix is compared, rather
 * than searching the whole string with indexOf and checking for index 0.
 */
if (typeof String.prototype.startsWith !== 'function') {
  String.prototype.startsWith = function (search, position) {
    var needle = String(search);
    var pos = position > 0 ? position | 0 : 0;
    return this.substring(pos, pos + needle.length) === needle;
  };
}
/**
 * True when `value` is numeric and parses to a whole number.
 * Uses Math.floor rather than `x | 0`: the bitwise form truncates to
 * 32 bits and so rejected integers outside the int32 range (e.g. 3000000000).
 */
function isInt(value) {
  if (isNaN(value)) {
    return false;
  }
  var parsed = parseFloat(value);
  return isFinite(parsed) && Math.floor(parsed) === parsed;
}
/**
 * Penn Treebank escape tokens mapped back to the characters they stand for.
 * Used when rebuilding the display text from the token list.
 */
var tokensMap = {
  "-LRB-": "(",
  "-RRB-": ")",
  "-LSB-": "[",
  "-RSB-": "]",
  "-LCB-": "{",
  "-RCB-": "}",
  "``": "\"",
  "''": "\""
};
/**
 * Pick the brat background color for a Penn Treebank (xpos) tag.
 * The tag's first letter selects the color, with exact-match special cases
 * for CD, CC*, LS and FW.  null and unrecognized tags are light gray.
 */
function posColor(posTag) {
  if (posTag === null) {
    return '#E3E3E3';
  }
  switch (posTag.charAt(0)) {
    case 'N':
      return '#A4BCED';
    case 'V':
    case 'M':
      return '#ADF6A2';
    case 'P':
      return '#CCDAF6';
    case 'I':
    case 'T':
      return '#FFE8BE';
    case 'R':
    case 'W':
    case 'J':
      return '#FFFDA8';
    case 'D':
      return '#CCADF6';
    case 'E':
    case 'S':
      return '#E4CBF6';
    case 'C':
      if (posTag === 'CD') {
        return '#CCADF6';
      }
      return posTag.startsWith('CC') ? '#FFFFFF' : '#E3E3E3';
    case 'L':
    case 'F':
      return (posTag === 'LS' || posTag === 'FW') ? '#FFFFFF' : '#E3E3E3';
    default:
      return '#E3E3E3';
  }
}
/**
* A mapping from part of speech tag to the associated
* visualization color
*/
function uposColor(posTag) {
if (posTag === null) {
return '#E3E3E3';
} else if (posTag === 'NOUN' || posTag === 'PROPN') {
return '#A4BCED';
} else if (posTag.startsWith('V') || posTag === 'AUX') {
return '#ADF6A2';
} else if (posTag === 'PART') {
return '#CCDAF6';
} else if (posTag === 'ADP') {
return '#FFE8BE';
} else if (posTag === 'ADV' || posTag.startsWith('PRON')) {
return '#FFFDA8';
} else if (posTag === 'NUM' || posTag === 'DET') {
return '#CCADF6';
} else if (posTag === 'ADJ') {
return '#FFFDA8';
} else if (posTag.startsWith('E') || posTag.startsWith('S')) {
return '#E4CBF6';
} else if (posTag.startsWith('CC')) {
return '#FFFFFF';
} else if (posTag === 'X' || posTag === 'FW') {
return '#FFFFFF';
} else {
return '#E3E3E3';
}
}
/**
 * Pick the brat background color for a named-entity type.
 * Both long (PERSON, ORGANIZATION, LOCATION) and short (PER, ORG, LOC)
 * labels are accepted.  null and unrecognized types are light gray.
 * switch compares with ===, consistent with every other check in these
 * color helpers (the old `nerTag == 'LOC'` was the lone loose comparison).
 */
function nerColor(nerTag) {
  switch (nerTag) {
    case 'PERSON':
    case 'PER':
      return '#FFCCAA';
    case 'ORGANIZATION':
    case 'ORG':
      return '#8FB2FF';
    case 'MISC':
      return '#F1F447';
    case 'LOCATION':
    case 'LOC':
      return '#95DFFF';
    case 'DATE':
    case 'TIME':
    case 'SET':
      return '#9AFFE6';
    case 'MONEY':
      return '#FFFFFF';
    case 'PERCENT':
      return '#FFA22B';
    default:
      return '#E3E3E3';
  }
}
/**
 * Pick the brat background color for a sentence sentiment label
 * (green for positive through red for negative; anything else is gray).
 */
function sentimentColor(sentiment) {
  switch (sentiment) {
    case "VERY POSITIVE":
      return '#00FF00';
    case "POSITIVE":
      return '#7FFF00';
    case "NEUTRAL":
      return '#FFFF00';
    case "NEGATIVE":
      return '#FF7F00';
    case "VERY NEGATIVE":
      return '#FF0000';
    default:
      return '#E3E3E3';
  }
}
/**
 * Build the comma-separated annotator list: always "tokenize,ssplit",
 * followed by the value of every selected option in #annotators.
 */
function annotators() {
  var names = ["tokenize,ssplit"];
  $('#annotators').find('option:selected').each(function () {
    names.push(String($(this).val()));
  });
  return names.join(",");
}
/**
 * Current local date and time formatted as YYYY-MM-DDTHH:MM:SS.
 * Month, day, hour, minute and second are zero-padded to two digits.
 */
function date() {
  var pad = function (n) {
    return ('0' + n).slice(-2);
  };
  var now = new Date();
  var day = [now.getFullYear(), pad(now.getMonth() + 1), pad(now.getDate())].join('-');
  var time = [pad(now.getHours()), pad(now.getMinutes()), pad(now.getSeconds())].join(':');
  return day + 'T' + time;
}
//-----------------------------------------------------------------------------
// Constituency parser
//-----------------------------------------------------------------------------
// Converts bracketed constituency parse strings (each sentence's `parse`
// field) into tree objects stored on the sentence as `parseTree`.
function ConstituencyParseProcessor() {
// Recursively nest a flat token array ('(' / ')' / atoms) into arrays.
// Consumes `input` from the front with shift(), so every recursive call
// shares and advances the same array.  At end of input it returns the
// last top-level group (list.pop()).
var parenthesize = function (input, list) {
if (list === undefined) {
return parenthesize(input, []);
} else {
var token = input.shift();
if (token === undefined) {
return list.pop();
} else if (token === "(") {
list.push(parenthesize(input, []));
return parenthesize(input, list);
} else if (token === ")") {
return list;
} else {
return parenthesize(input, list.concat(token));
}
}
};
// Turn nested arrays into tree nodes:
//  - [label, "word"] becomes a terminal {label, text, isTerminal: true}
//  - [label, child...] becomes {label, children}; note list.shift()
//    mutates the input array, and each object child gets a `parent`
//    back-reference (so the result is cyclic)
//  - anything shorter than 2 elements is returned unchanged
var toTree = function (list) {
if (list.length === 2 && typeof list[1] === 'string') {
return {label: list[0], text: list[1], isTerminal: true};
} else if (list.length >= 2) {
var label = list.shift();
var node = {label: label};
var rest = list.map(function (x) {
var t = toTree(x);
if (typeof t === 'object') {
t.parent = node;
}
return t;
});
node.children = rest;
return node;
} else {
return list;
}
};
// Walk the tree left to right, giving every node a token span
// [tokenStart, tokenEnd) and each terminal its token and tokenIndex.
// Returns the next unused token index (`index` defaults to 0).
var indexTree = function (tree, tokens, index) {
index = index || 0;
if (tree.isTerminal) {
tree.token = tokens[index];
tree.tokenIndex = index;
tree.tokenStart = index;
tree.tokenEnd = index + 1;
return index + 1;
} else if (tree.children) {
tree.tokenStart = index;
for (var i = 0; i < tree.children.length; i++) {
var child = tree.children[i];
index = indexTree(child, tokens, index);
}
tree.tokenEnd = index;
}
return index;
};
// Split a parse string into tokens.  Segments at even split indices
// (outside double quotes) get spaces around every paren; in quoted
// segments spaces are temporarily replaced by a placeholder so they
// survive the whitespace split, then restored.
var tokenize = function (input) {
return input.split('"')
.map(function (x, i) {
if (i % 2 === 0) { // not in string
return x.replace(/\(/g, ' ( ')
.replace(/\)/g, ' ) ');
} else { // in string
return x.replace(/ /g, "!whitespace!");
}
})
.join('"')
.trim()
.split(/\s+/)
.map(function (x) {
return x.replace(/!whitespace!/g, " ");
});
};
// Parse string -> indexed tree; returns undefined when parenthesize
// does not yield an array.
var convertParseStringToTree = function (input, tokens) {
var p = parenthesize(tokenize(input));
if (Array.isArray(p)) {
var tree = toTree(p);
// Correlate tree with tokens
indexTree(tree, tokens);
return tree;
}
};
// Attach `parseTree` to every sentence of `annotation` that has a `parse` string.
this.process = function(annotation) {
for (var i = 0; i < annotation.sentences.length; i++) {
var s = annotation.sentences[i];
if (s.parse) {
s.parseTree = convertParseStringToTree(s.parse, s.tokens);
}
}
}
}
// ----------------------------------------------------------------------------
// RENDER
// ----------------------------------------------------------------------------
/**
* Render a given JSON data structure
*/
function render(data, reverse) {
// Tweak arguments
if (typeof reverse !== 'boolean') {
reverse = false;
}
// Error checks
if (typeof data.sentences === 'undefined') { return; }
/**
* Register an entity type (a tag) for Brat
*/
var entityTypesSet = {};
var entityTypes = [];
function addEntityType(name, type, coarseType) {
if (typeof coarseType === "undefined") {
coarseType = type;
}
// Don't add duplicates
if (entityTypesSet[type]) return;
entityTypesSet[type] = true;
// Get the color of the entity type
color = '#ffccaa';
if (name === 'POS') {
color = posColor(type);
} else if (name === 'UPOS') {
color = uposColor(type);
} else if (name === 'NER') {
color = nerColor(coarseType);
} else if (name === 'NNER') {
color = nerColor(coarseType);
} else if (name === 'COREF') {
color = '#FFE000';
} else if (name === 'ENTITY') {
color = posColor('NN');
} else if (name === 'RELATION') {
color = posColor('VB');
} else if (name === 'LEMMA') {
color = '#FFFFFF';
} else if (name === 'SENTIMENT') {
color = sentimentColor(type);
} else if (name === 'LINK') {
color = '#FFFFFF';
} else if (name === 'KBP_ENTITY') {
color = '#FFFFFF';
}
// Register the type
entityTypes.push({
type: type,
labels : [type],
bgColor: color,
borderColor: 'darken'
});
}
/**
* Register a relation type (an arc) for Brat
*/
var relationTypesSet = {};
var relationTypes = [];
function addRelationType(type, symmetricEdge) {
// Prevent adding duplicates
if (relationTypesSet[type]) return;
relationTypesSet[type] = true;
// Default arguments
if (typeof symmetricEdge === 'undefined') { symmetricEdge = false; }
// Add the type
relationTypes.push({
type: type,
labels: [type],
dashArray: (symmetricEdge ? '3,3' : undefined),
arrowHead: (symmetricEdge ? 'none' : undefined),
});
}
//
// Construct text of annotation
//
currentText = []; // GLOBAL
currentSentences = data.sentences; // GLOBAL
data.sentences.forEach(function(sentence) {
for (var i = 0; i < sentence.tokens.length; ++i) {
var token = sentence.tokens[i];
var word = token.word;
if (!(typeof tokensMap[word] === "undefined")) {
word = tokensMap[word];
}
if (i > 0) { currentText.push(' '); }
token.characterOffsetBegin = currentText.length;
for (var j = 0; j < word.length; ++j) {
currentText.push(word[j]);
}
token.characterOffsetEnd = currentText.length;
}
currentText.push('\n');
});
currentText = currentText.join('');
//
// Shared variables
// These are what we'll render in BRAT
//
// (pos)
var posEntities = [];
// (upos)
var uposEntities = [];
// (lemma)
var lemmaEntities = [];
// (ner)
var nerEntities = [];
var nerEntitiesNormalized = [];
// (sentiment)
var sentimentEntities = [];
// (entitylinking)
var linkEntities = [];
// (dependencies)
var depsRelations = [];
var deps2Relations = [];
// (openie)
var openieEntities = [];
var openieEntitiesSet = {};
var openieRelations = [];
var openieRelationsSet = {};
// (kbp)
var kbpEntities = [];
var kbpEntitiesSet = [];
var kbpRelations = [];
var kbpRelationsSet = [];
var cparseEntities = [];
var cparseRelations = [];
//
// Loop over sentences.
// This fills in the variables above.
//
for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
var sentence = data.sentences[sentI];
var index = sentence.index;
var tokens = sentence.tokens;
var deps = sentence['basicDependencies'];
var deps2 = sentence['enhancedPlusPlusDependencies'];
var parseTree = sentence['parseTree'];
// POS tags
/**
* Generate a POS tagged token id
*/
function posID(i) {
return 'POS_' + sentI + '_' + i;
}
var noXPOS = true;
if (tokens.length > 0 && typeof tokens[0].pos !== 'undefined' && tokens[0].pos !== null) {
noXPOS = false;
for (var i = 0; i < tokens.length; i++) {
var token = tokens[i];
var pos = token.pos;
var begin = parseInt(token.characterOffsetBegin);
var end = parseInt(token.characterOffsetEnd);
addEntityType('POS', pos);
posEntities.push([posID(i), pos, [[begin, end]]]);
}
}
// Universal POS tags
/**
* Generate a POS tagged token id
*/
function uposID(i) {
return 'UPOS_' + sentI + '_' + i;
}
if (tokens.length > 0 && typeof tokens[0].upos !== 'undefined') {
for (var i = 0; i < tokens.length; i++) {
var token = tokens[i];
var upos = token.upos;
var begin = parseInt(token.characterOffsetBegin);
var end = parseInt(token.characterOffsetEnd);
addEntityType('UPOS', upos);
uposEntities.push([uposID(i), upos, [[begin, end]]]);
}
}
// Constituency parse
// Carries the same assumption as NER
if (parseTree && !useDagre) {
var parseEntities = [];
var parseRels = [];
function processParseTree(tree, index) {
tree.visitIndex = index;
index++;
if (tree.isTerminal) {
parseEntities[tree.visitIndex] = uposEntities[tree.tokenIndex];
return index;
} else if (tree.children) {
addEntityType('PARSENODE', tree.label);
parseEntities[tree.visitIndex] =
['PARSENODE_' + sentI + '_' + tree.visitIndex, tree.label,
[[tokens[tree.tokenStart].characterOffsetBegin, tokens[tree.tokenEnd-1].characterOffsetEnd]]];
var parentEnt = parseEntities[tree.visitIndex];
for (var i = 0; i < tree.children.length; i++) {
var child = tree.children[i];
index = processParseTree(child, index);
var childEnt = parseEntities[child.visitIndex];
addRelationType('pc');
parseRels.push(['PARSEEDGE_' + sentI + '_' + parseRels.length, 'pc', [['parent', parentEnt[0]], ['child', childEnt[0]]]]);
}
}
return index;
}
processParseTree(parseTree, 0);
cparseEntities = cparseEntities.concat(cparseEntities, parseEntities);
cparseRelations = cparseRelations.concat(parseRels);
}
// Dependency parsing
/**
* Process a dependency tree from JSON to Brat relations
*/
function processDeps(name, deps) {
var relations = [];
// Format: [${ID}, ${TYPE}, [[${ARGNAME}, ${TARGET}], [${ARGNAME}, ${TARGET}]]]
for (var i = 0; i < deps.length; i++) {
var dep = deps[i];
var governor = dep.governor - 1;
var dependent = dep.dependent - 1;
if (governor == -1) continue;
addRelationType(dep.dep);
relations.push([name + '_' + sentI + '_' + i, dep.dep, [['governor', uposID(governor)], ['dependent', uposID(dependent)]]]);
}
return relations;
}
// Actually add the dependencies
if (typeof deps !== 'undefined') {
depsRelations = depsRelations.concat(processDeps('dep', deps));
}
if (typeof deps2 !== 'undefined') {
deps2Relations = deps2Relations.concat(processDeps('dep2', deps2));
}
// Lemmas
if (tokens.length > 0 && typeof tokens[0].lemma !== 'undefined') {
for (var i = 0; i < tokens.length; i++) {
var token = tokens[i];
var lemma = token.lemma;
var begin = parseInt(token.characterOffsetBegin);
var end = parseInt(token.characterOffsetEnd);
addEntityType('LEMMA', lemma);
lemmaEntities.push(['LEMMA_' + sentI + '_' + i, lemma, [[begin, end]]]);
}
}
// NER tags
// Assumption: contiguous occurrence of one non-O is a single entity
var noNER = true;
if (tokens.some(function(token) { return token.ner; })) {
noNER = false;
for (var i = 0; i < tokens.length; i++) {
var ner = tokens[i].ner || 'O';
var normalizedNER = tokens[i].normalizedNER;
if (typeof normalizedNER === "undefined") {
normalizedNER = ner;
}
if (ner == 'O') continue;
var j = i;
while (j < tokens.length - 1 && tokens[j+1].ner == ner) j++;
addEntityType('NER', ner, ner);
nerEntities.push(['NER_' + sentI + '_' + i, ner, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
if (ner != normalizedNER) {
addEntityType('NNER', normalizedNER, ner);
nerEntities.push(['NNER_' + sentI + '_' + i, normalizedNER, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
}
i = j;
}
}
// Sentiment
if (typeof sentence.sentiment !== "undefined") {
var sentiment = sentence.sentiment.toUpperCase().replace("VERY", "VERY ");
addEntityType('SENTIMENT', sentiment);
sentimentEntities.push(['SENTIMENT_' + sentI, sentiment,
[[tokens[0].characterOffsetBegin, tokens[tokens.length - 1].characterOffsetEnd]]]);
}
// Entity Links
// Carries the same assumption as NER
if (tokens.length > 0) {
for (var i = 0; i < tokens.length; i++) {
var link = tokens[i].entitylink;
if (link == 'O' || typeof link === 'undefined') continue;
var j = i;
while (j < tokens.length - 1 && tokens[j+1].entitylink == link) j++;
addEntityType('LINK', link);
linkEntities.push(['LINK_' + sentI + '_' + i, link, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
i = j;
}
}
// Open IE
// Helper Functions
function openieID(span) {
return 'OPENIEENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1];
}
function addEntity(span, role) {
// Don't add duplicate entities
if (openieEntitiesSet[[sentI, span, role]]) return;
openieEntitiesSet[[sentI, span, role]] = true;
// Add the entity
openieEntities.push([openieID(span), role,
[[tokens[span[0]].characterOffsetBegin,
tokens[span[1] - 1].characterOffsetEnd ]] ]);
}
function addRelation(gov, dep, role) {
// Don't add duplicate relations
if (openieRelationsSet[[sentI, gov, dep, role]]) return;
openieRelationsSet[[sentI, gov, dep, role]] = true;
// Add the relation
openieRelations.push(['OPENIESUBJREL_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],
role,
[['governor', openieID(gov)],
['dependent', openieID(dep)] ] ]);
}
// Render OpenIE
if (typeof sentence.openie !== 'undefined') {
// Register the entities + relations we'll need
addEntityType('ENTITY', 'Entity');
addEntityType('RELATION', 'Relation');
addRelationType('subject');
addRelationType('object');
// Loop over triples
for (var i = 0; i < sentence.openie.length; ++i) {
var subjectSpan = sentence.openie[i].subjectSpan;
var relationSpan = sentence.openie[i].relationSpan;
var objectSpan = sentence.openie[i].objectSpan;
if (parseInt(relationSpan[0]) < 0 || parseInt(relationSpan[1]) < 0) {
continue; // This is a phantom relation
}
var begin = parseInt(token.characterOffsetBegin);
// Add the entities
addEntity(subjectSpan, 'Entity');
addEntity(relationSpan, 'Relation');
addEntity(objectSpan, 'Entity');
// Add the relations
addRelation(relationSpan, subjectSpan, 'subject');
addRelation(relationSpan, objectSpan, 'object');
}
} // End OpenIE block
//
// KBP
//
// Helper Functions
function kbpEntity(span) {
return 'KBPENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1];
}
function addKBPEntity(span, role) {
// Don't add duplicate entities
if (kbpEntitiesSet[[sentI, span, role]]) return;
kbpEntitiesSet[[sentI, span, role]] = true;
// Add the entity
kbpEntities.push([kbpEntity(span), role,
[[tokens[span[0]].characterOffsetBegin,
tokens[span[1] - 1].characterOffsetEnd ]] ]);
}
function addKBPRelation(gov, dep, role) {
// Don't add duplicate relations
if (kbpRelationsSet[[sentI, gov, dep, role]]) return;
kbpRelationsSet[[sentI, gov, dep, role]] = true;
// Add the relation
kbpRelations.push(['KBPRELATION_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],
role,
[['governor', kbpEntity(gov)],
['dependent', kbpEntity(dep)] ] ]);
}
if (typeof sentence.kbp !== 'undefined') {
// Register the entities + relations we'll need
addRelationType('subject');
addRelationType('object');
// Loop over triples
for (var i = 0; i < sentence.kbp.length; ++i) {
var subjectSpan = sentence.kbp[i].subjectSpan;
var subjectLink = 'Entity';
for (var k = subjectSpan[0]; k < subjectSpan[1]; ++k) {
if (subjectLink == 'Entity' &&
typeof tokens[k] !== 'undefined' &&
tokens[k].entitylink != 'O' &&
typeof tokens[k].entitylink !== 'undefined') {
subjectLink = tokens[k].entitylink
}
}
addEntityType('KBP_ENTITY', subjectLink);
var objectSpan = sentence.kbp[i].objectSpan;
var objectLink = 'Entity';
for (var k = objectSpan[0]; k < objectSpan[1]; ++k) {
if (objectLink == 'Entity' &&
typeof tokens[k] !== 'undefined' &&
tokens[k].entitylink != 'O' &&
typeof tokens[k].entitylink !== 'undefined') {
objectLink = tokens[k].entitylink
}
}
addEntityType('KBP_ENTITY', objectLink);
var relation = sentence.kbp[i].relation;
var begin = parseInt(token.characterOffsetBegin);
// Add the entities
addKBPEntity(subjectSpan, subjectLink);
addKBPEntity(objectSpan, objectLink);
// Add the relations
addKBPRelation(subjectSpan, objectSpan, relation);
}
} // End KBP block
} // End sentence loop
//
// Coreference
//
var corefEntities = [];
var corefRelations = [];
if (typeof data.corefs !== 'undefined') {
addRelationType('coref', true);
addEntityType('COREF', 'Mention');
var clusters = Object.keys(data.corefs);
clusters.forEach( function (clusterId) {
var chain = data.corefs[clusterId];
if (chain.length > 1) {
for (var i = 0; i < chain.length; ++i) {
var mention = chain[i];
var id = 'COREF' + mention.id;
var tokens = data.sentences[mention.sentNum - 1].tokens;
corefEntities.push([id, 'Mention',
[[tokens[mention.startIndex - 1].characterOffsetBegin,
tokens[mention.endIndex - 2].characterOffsetEnd ]] ]);
if (i > 0) {
var lastId = 'COREF' + chain[i - 1].id;
corefRelations.push(['COREF' + chain[i-1].id + '_' + chain[i].id,
'coref',
[['governor', lastId],
['dependent', id] ] ]);
}
}
}
});
} // End coreference block
//
// Actually render the elements
//
/**
* Helper function to render a given set of entities / relations
* to a Div, if it exists.
*/
function embed(container, entities, relations, reverse) {
var text = currentText;
if (reverse) {
var length = currentText.length;
for (var i = 0; i < entities.length; ++i) {
var offsets = entities[i][2][0];
var tmp = length - offsets[0];
offsets[0] = length - offsets[1];
offsets[1] = tmp;
}
text = text.split("").reverse().join("");
}
if ($('#' + container).length > 0) {
Util.embed(container,
{entity_types: entityTypes, relation_types: relationTypes},
{text: text, entities: entities, relations: relations}
);
}
}
function reportna(container, text) {
$('#' + container).text(text);
}
// Render each annotation
head.ready(function() {
if (!noXPOS) {
embed('pos', posEntities);
} else {
reportna('pos', 'XPOS is not available for this language at this time.')
}
embed('upos', uposEntities);
embed('lemma', lemmaEntities);
if (!noNER) {
embed('ner', nerEntities);
} else {
reportna('ner', 'NER is not available for this language at this time.')
}
embed('entities', linkEntities);
if (!useDagre) {
embed('parse', cparseEntities, cparseRelations);
}
embed('deps', uposEntities, depsRelations);
embed('deps2', posEntities, deps2Relations);
embed('coref', corefEntities, corefRelations);
embed('openie', openieEntities, openieRelations);
embed('kbp', kbpEntities, kbpRelations);
embed('sentiment', sentimentEntities);
// Constituency parse
// Uses d3 and dagre-d3 (not brat)
if ($('#parse').length > 0 && useDagre) {
var parseViewer = new ParseViewer({ selector: '#parse' });
parseViewer.showAnnotation(data);
$('#parse').addClass('svg').css('display', 'block');
}
});
} // End render function
/**
 * Render a TokensRegex response into the #tokensregex div with brat.
 *
 * data.sentences[s] is the list of matches for sentence s; each match has
 * begin/end token indices, and its "$name" or integer-named keys are capture
 * groups with their own begin/end token indices.
 */
function renderTokensregex(data) {
  /**
   * Register an entity type (a tag) for Brat, once per type name.
   */
  var entityTypesSet = {};
  var entityTypes = [];
  function addEntityType(type, color) {
    // Don't add duplicates
    if (entityTypesSet[type]) return;
    entityTypesSet[type] = true;
    // Set the color
    if (typeof color === 'undefined') {
      color = '#ADF6A2';
    }
    // Register the type
    entityTypes.push({
      type: type,
      labels : [type],
      bgColor: color,
      borderColor: 'darken'
    });
  }

  var entities = [];
  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
    var tokens = currentSentences[sentI].tokens;
    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {
      var match = data.sentences[sentI][matchI];
      // Add groups ("$name" or numeric keys); `var` keeps groupName local
      // instead of leaking it onto the global object
      for (var groupName in match) {
        if (groupName.startsWith("$") || isInt(groupName)) {
          addEntityType(groupName, '#FFFDA8');
          var groupBegin = parseInt(tokens[match[groupName].begin].characterOffsetBegin);
          var groupEnd = parseInt(tokens[match[groupName].end - 1].characterOffsetEnd);
          entities.push(['TOK_' + sentI + '_' + matchI + '_' + groupName,
                         groupName,
                         [[groupBegin, groupEnd]]]);
        }
      }
      // Add the whole match (match.end is exclusive)
      addEntityType('match', '#ADF6A2');
      var begin = parseInt(tokens[match.begin].characterOffsetBegin);
      var end = parseInt(tokens[match.end - 1].characterOffsetEnd);
      entities.push(['TOK_' + sentI + '_' + matchI + '_match',
                     'match',
                     [[begin, end]]]);
    }
  }

  Util.embed('tokensregex',
    {entity_types: entityTypes, relation_types: []},
    {text: currentText, entities: entities, relations: []}
  );
} // END renderTokensregex()
/**
 * Render a Semgrex response into the #semgrex div with brat.
 *
 * Each match becomes a 'match' entity; each named node of the match
 * ("$name" or integer-named key) becomes its own entity, linked to the
 * match by a dashed 'semgrex' relation.
 */
function renderSemgrex(data) {
  /**
   * Register an entity type (a tag) for Brat, once per type name.
   */
  var entityTypesSet = {};
  var entityTypes = [];
  function addEntityType(type, color) {
    // Don't add duplicates
    if (entityTypesSet[type]) return;
    entityTypesSet[type] = true;
    // Set the color
    if (typeof color === 'undefined') {
      color = '#ADF6A2';
    }
    // Register the type
    entityTypes.push({
      type: type,
      labels : [type],
      bgColor: color,
      borderColor: 'darken'
    });
  }
  // Local, so rendering semgrex results does not overwrite a global relationTypes
  var relationTypes = [{
    type: 'semgrex',
    labels: ['-'],
    dashArray: '3,3',
    arrowHead: 'none'
  }];

  var entities = [];
  var relations = [];
  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
    var tokens = currentSentences[sentI].tokens;
    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {
      var match = data.sentences[sentI][matchI];
      var matchId = 'SEM_' + sentI + '_' + matchI + '_match';
      // Add match (match.end is exclusive)
      addEntityType('match', '#ADF6A2');
      var begin = parseInt(tokens[match.begin].characterOffsetBegin);
      var end = parseInt(tokens[match.end - 1].characterOffsetEnd);
      entities.push([matchId,
                     'match',
                     [[begin, end]]]);
      // Add groups
      for (var key in match) {
        if (key.startsWith("$") || isInt(key)) {
          var group = match[key];
          // Strip only the "$" sigil; numeric group names are used unchanged
          // (blindly dropping the first character would turn "12" into "2")
          var groupName = key.startsWith("$") ? key.substring(1) : key;
          var groupId = 'SEM_' + sentI + '_' + matchI + '_' + groupName;
          // (add node)
          addEntityType(groupName, '#FFFDA8');
          var groupBegin = parseInt(tokens[group.begin].characterOffsetBegin);
          var groupEnd = parseInt(tokens[group.end - 1].characterOffsetEnd);
          entities.push([groupId,
                         groupName,
                         [[groupBegin, groupEnd]]]);
          // (add relation from the match to this node)
          relations.push(['SEMGREX_' + sentI + '_' + matchI + '_' + groupName,
                          'semgrex',
                          [['governor', matchId],
                           ['dependent', groupId] ] ]);
        }
      }
    }
  }

  Util.embed('semgrex',
    {entity_types: entityTypes, relation_types: relationTypes},
    {text: currentText, entities: entities, relations: relations}
  );
} // END renderSemgrex
/**
 * Render a Tregex response: the raw JSON, pretty-printed, in #tregex.
 *
 * The response echoes the user's input text, so it is inserted with .text()
 * (escaped) rather than concatenated into an HTML string; the <pre> keeps the
 * 4-space indentation of the pretty-printed JSON.
 */
function renderTregex(data) {
  var json = JSON.stringify(data, null, 4);
  $('#tregex').empty().append($('<pre></pre>').text(json));
} // END renderTregex
// ----------------------------------------------------------------------------
// MAIN
// ----------------------------------------------------------------------------
/**
* MAIN()
*
* The entry point of the page
*/
$(document).ready(function() {
// Some initial styling
$('.chosen-select').chosen();
$('.chosen-container').css('width', '100%');
// Language-specific changes
$('#language').on('change', function() {
$('#text').attr('dir', '');
if ($('#language').val() === 'ar' ||
$('#language').val() === 'fa' ||
$('#language').val() === 'he' ||
$('#language').val() === 'ur') {
$('#text').attr('dir', 'rtl');
}
if ($('#language').val() === 'ar') {
$('#text').attr('placeholder', 'على سبيل المثال، قفز الثعلب البني السريع فوق الكلب الكسول.');
} else if ($('#language').val() === 'en') {
$('#text').attr('placeholder', 'e.g., The quick brown fox jumped over the lazy dog.');
} else if ($('#language').val() === 'zh') {
$('#text').attr('placeholder', '例如,快速的棕色狐狸跳过了懒惰的狗。');
} else if ($('#language').val() === 'zh-Hant') {
$('#text').attr('placeholder', '例如,快速的棕色狐狸跳過了懶惰的狗。');
} else if ($('#language').val() === 'fr') {
$('#text').attr('placeholder', 'Par exemple, le renard brun rapide a sauté sur le chien paresseux.');
} else if ($('#language').val() === 'de') {
$('#text').attr('placeholder', 'Z. B. sprang der schnelle braune Fuchs über den faulen Hund.');
} else if ($('#language').val() === 'es') {
$('#text').attr('placeholder', 'Por ejemplo, el rápido zorro marrón saltó sobre el perro perezoso.');
} else if ($('#language').val() === 'ur') {
$('#text').attr('placeholder', 'میرا نام علی ہے');
} else {
$('#text').attr('placeholder', 'Unknown language for placeholder query: ' + $('#language').val());
}
});
// Submit on shift-enter
$('#text').keydown(function (event) {
if (event.keyCode == 13) {
if(event.shiftKey){
event.preventDefault(); // don't register the enter key when pressed
return false;
}
}
});
$('#text').keyup(function (event) {
if (event.keyCode == 13) {
if(event.shiftKey){
$('#submit').click(); // submit the form when the enter key is released
event.stopPropagation();
return false;
}
}
});
// Submit on clicking the 'submit' button
$('#submit').click(function() {
// Get the text to annotate
currentQuery = $('#text').val();
if (currentQuery.trim() == '') {
if ($('#language').val() === 'ar') {
currentQuery = 'قفز الثعلب البني السريع فوق الكلب الكسول.';
} else if ($('#language').val() === 'en') {
currentQuery = 'The quick brown fox jumped over the lazy dog.';
} else if ($('#language').val() === 'zh') {
currentQuery = '快速的棕色狐狸跳过了懒惰的狗。';
} else if ($('#language').val() === 'zh-Hant') {
currentQuery = '快速的棕色狐狸跳過了懶惰的狗。';
} else if ($('#language').val() === 'fr') {
currentQuery = 'Le renard brun rapide a sauté sur le chien paresseux.';
} else if ($('#language').val() === 'de') {
currentQuery = 'Sprang der schnelle braune Fuchs über den faulen Hund.';
} else if ($('#language').val() === 'es') {
currentQuery = 'El rápido zorro marrón saltó sobre el perro perezoso.';
} else if ($('#language').val() === 'ur') {
currentQuery = 'میرا نام علی ہے';
} else {
currentQuery = 'Unknown language for default query: ' + $('#language').val();
}
$('#text').val(currentQuery);
}
// Update the UI
$('#submit').prop('disabled', true);
$('#annotations').hide();
$('#patterns_row').hide();
$('#loading').show();
// Run query
$.ajax({
type: 'POST',
url: serverAddress + '?properties=' + encodeURIComponent(
'{"annotators": "' + annotators() + '", "date": "' + date() + '"}') +
'&pipelineLanguage=' + encodeURIComponent($('#language').val()),
data: encodeURIComponent(currentQuery), //jQuery doesn't automatically URI encode strings
dataType: 'json',
contentType: "application/x-www-form-urlencoded;charset=UTF-8",
responseType: "application/json",
success: function(data) {
$('#submit').prop('disabled', false);
if (typeof data === 'undefined' || data.sentences == undefined) {
alert("Failed to reach server!");
} else {
// Process constituency parse
var constituencyParseProcessor = new ConstituencyParseProcessor();
constituencyParseProcessor.process(data);
// Empty divs
$('#annotations').empty();
// Re-render divs
function createAnnotationDiv(id, annotator, selector, label) {
// (make sure we requested that element)
if (annotators().split(",").indexOf(annotator) < 0) {
return;
}
// (make sure the data contains that element)
ok = false;
if (typeof data[selector] !== 'undefined') {
ok = true;
} else if (typeof data.sentences !== 'undefined' && data.sentences.length > 0) {
if (typeof data.sentences[0][selector] !== 'undefined') {
ok = true;
} else if (typeof data.sentences[0].tokens != 'undefined' && data.sentences[0].tokens.length > 0) {
// (make sure the annotator select is in at least one of the tokens of any sentence)
ok = data.sentences.some(function(sentence) {
return sentence.tokens.some(function(token) {
return typeof token[selector] !== 'undefined';
});
});
}
}
// (render the element)
if (ok) {
$('#annotations').append('' + label + ':
');
}
}
// (create the divs)
// div id annotator field_in_data label
createAnnotationDiv('pos', 'pos', 'pos', 'Part-of-Speech (XPOS)' );
createAnnotationDiv('upos', 'upos', 'upos', 'Universal Part-of-Speech');
createAnnotationDiv('lemma', 'lemma', 'lemma', 'Lemmas' );
createAnnotationDiv('ner', 'ner', 'ner', 'Named Entity Recognition');
createAnnotationDiv('deps', 'depparse', 'basicDependencies', 'Universal Dependencies' );
createAnnotationDiv('parse', 'parse', 'parseTree', 'Constituency Parse' );
//createAnnotationDiv('deps2', 'depparse', 'enhancedPlusPlusDependencies', 'Enhanced++ Dependencies' );
//createAnnotationDiv('openie', 'openie', 'openie', 'Open IE' );
//createAnnotationDiv('coref', 'coref', 'corefs', 'Coreference' );
//createAnnotationDiv('entities', 'entitylink', 'entitylink', 'Wikidict Entities' );
//createAnnotationDiv('kbp', 'kbp', 'kbp', 'KBP Relations' );
//createAnnotationDiv('sentiment','sentiment', 'sentiment', 'Sentiment' );
// Update UI
$('#loading').hide();
$('.corenlp_error').remove(); // Clear error messages
$('#annotations').show();
// Render
var reverse = ($('#language').val() === 'ar' || $('#language').val() === 'fa' || $('#language').val() === 'he' || $('#language').val() === 'ur');
render(data, reverse);
// Render patterns
//$('#annotations').append('CoreNLP Tools: '); // TODO(gabor) a strange place to add this header to
//$('#patterns_row').show();
}
},
error: function(data) {
DATA = data;
var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('corenlp_error').attr('role', 'alert')
var button = $('× ');
var message = $(' ').text(data.responseText);
button.appendTo(alertDiv);
message.appendTo(alertDiv);
$('#loading').hide();
alertDiv.appendTo($('#errors'));
$('#submit').prop('disabled', false);
}
});
event.preventDefault();
event.stopPropagation();
return false;
});
// Support passing parameters on page launch, via window.location.hash parameters.
// Example: http://localhost:9000/#text=foo%20bar&annotators=pos,lemma,ner
(function() {
var rawParams = window.location.hash.slice(1).split("&");
var params = {};
rawParams.forEach(function(paramKV) {
paramKV = paramKV.split("=");
if (paramKV.length === 2) {
var key = paramKV[0];
var value = paramKV[1];
params[key] = value;
}
});
if (params.text) {
var text = decodeURIComponent(params.text);
$('#text').val(text);
}
if (params.annotators) {
var annotators = params.annotators.split(",");
// De-select everything
$('#annotators').find('option').each(function() {
$(this).prop('selected', false);
});
// Select the specified ones.
annotators.forEach(function(a) {
$('#annotators').find('option[value="'+a+'"]').prop('selected', true);
});
// Refresh Chosen
$('#annotators').trigger('chosen:updated');
}
if (params.text || params.annotators) {
// Finally, let's auto-submit.
$('#submit').click();
}
})();
$('#form_tokensregex').submit( function (e) {
// Don't actually submit the form
e.preventDefault();
// Get text
if ($('#tokensregex_search').val().trim() == '') {
$('#tokensregex_search').val('(?$foxtype [{pos:JJ}]+ ) fox');
}
var pattern = $('#tokensregex_search').val();
// Remove existing annotation
$('#tokensregex').remove();
// Make ajax call
$.ajax({
type: 'POST',
url: serverAddress + '/tokensregex?pattern=' + encodeURIComponent(
pattern.replace("&", "\\&").replace('+', '\\+')) +
'&properties=' + encodeURIComponent(
'{"annotators": "' + annotators() + '", "date": "' + date() + '"}') +
'&pipelineLanguage=' + encodeURIComponent($('#language').val()),
data: encodeURIComponent(currentQuery),
success: function(data) {
$('.tokensregex_error').remove(); // Clear error messages
$('
').appendTo($('#div_tokensregex'));
renderTokensregex(data);
},
error: function(data) {
var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tokensregex_error').attr('role', 'alert')
var button = $('× ');
var message = $(' ').text(data.responseText);
button.appendTo(alertDiv);
message.appendTo(alertDiv);
alertDiv.appendTo($('#div_tokensregex'));
}
});
});
$('#form_semgrex').submit( function (e) {
// Don't actually submit the form
e.preventDefault();
// Get text
if ($('#semgrex_search').val().trim() == '') {
$('#semgrex_search').val('{pos:/VB.*/} >nsubj {}=subject >/nmod:.*/ {}=prep_phrase');
}
var pattern = $('#semgrex_search').val();
// Remove existing annotation
$('#semgrex').remove();
// Add missing required annotators
var requiredAnnotators = annotators().split(',');
if (requiredAnnotators.indexOf('depparse') < 0) {
requiredAnnotators.push('depparse');
}
// Make ajax call
$.ajax({
type: 'POST',
url: serverAddress + '/semgrex?pattern=' + encodeURIComponent(
pattern.replace("&", "\\&").replace('+', '\\+')) +
'&properties=' + encodeURIComponent(
'{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') +
'&pipelineLanguage=' + encodeURIComponent($('#language').val()),
data: encodeURIComponent(currentQuery),
success: function(data) {
$('.semgrex_error').remove(); // Clear error messages
$('
').appendTo($('#div_semgrex'));
renderSemgrex(data);
},
error: function(data) {
var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('semgrex_error').attr('role', 'alert')
var button = $('× ');
var message = $(' ').text(data.responseText);
button.appendTo(alertDiv);
message.appendTo(alertDiv);
alertDiv.appendTo($('#div_semgrex'));
}
});
});
$('#form_tregex').submit( function (e) {
// Don't actually submit the form
e.preventDefault();
// Get text
if ($('#tregex_search').val().trim() == '') {
$('#tregex_search').val('NP < NN=animal');
}
var pattern = $('#tregex_search').val();
// Remove existing annotation
$('#tregex').remove();
// Add missing required annotators
var requiredAnnotators = annotators().split(',');
if (requiredAnnotators.indexOf('parse') < 0) {
requiredAnnotators.push('parse');
}
// Make ajax call
$.ajax({
type: 'POST',
url: serverAddress + '/tregex?pattern=' + encodeURIComponent(
pattern.replace("&", "\\&").replace('+', '\\+')) +
'&properties=' + encodeURIComponent(
'{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') +
'&pipelineLanguage=' + encodeURIComponent($('#language').val()),
data: encodeURIComponent(currentQuery),
success: function(data) {
$('.tregex_error').remove(); // Clear error messages
$('
').appendTo($('#div_tregex'));
renderTregex(data);
},
error: function(data) {
var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tregex_error').attr('role', 'alert')
var button = $('× ');
var message = $(' ').text(data.responseText);
button.appendTo(alertDiv);
message.appendTo(alertDiv);
alertDiv.appendTo($('#div_tregex'));
}
});
});
});
================================================
FILE: stanza/pipeline/demo/stanza-parseviewer.js
================================================
//'use strict';
//d3 || require('d3');
//var dagreD3 = require('dagre-d3');
//var jquery = require('jquery');
//var $ = jquery;
/**
 * Viewer that draws constituency parse trees as an SVG graph (d3 + dagre-d3).
 *
 * params.selector             selector of the container element
 * params.onClickNodeCallback  optional; called with a node's data on click
 * params.onHoverNodeCallback  optional; called with a node's data on mouseover
 */
var ParseViewer = function(params) {
  this.onClickNodeCallback = params.onClickNodeCallback;
  this.onHoverNodeCallback = params.onHoverNodeCallback;
  // When true, updateGraphPosition() sizes the SVG from the laid-out graph
  this.fitToGraph = true;
  this.selector = params.selector;
  this.container = $(this.selector);
  this.init();
  return this;
};

// Lower bounds (px) for the automatically computed canvas size
ParseViewer.MIN_WIDTH = 100;
ParseViewer.MIN_HEIGHT = 100;
ParseViewer.prototype.constructor = ParseViewer;

// Canvas width: the container's width, but never below MIN_WIDTH
ParseViewer.prototype.getAutoWidth = function () {
  var available = this.container.width();
  return Math.max(ParseViewer.MIN_WIDTH, available);
};

// Canvas height: the container's height minus 20px, but never below MIN_HEIGHT
ParseViewer.prototype.getAutoHeight = function () {
  var available = this.container.height() - 20;
  return Math.max(ParseViewer.MIN_HEIGHT, available);
};
ParseViewer.prototype.init = function () {
var canvasWidth = this.getAutoWidth();
var canvasHeight = this.getAutoHeight();
this.parseElem = d3.select(this.selector)
.append('svg')
.attr({'width': canvasWidth, 'height': canvasHeight})
.style({'width': canvasWidth, 'height': canvasHeight});
console.log(this.parseElem);
this.graph = null;
this.graphRendered = false;
this.controls = $('
');
this.container.append(this.controls);
};
/**
 * Build a dagre-d3 graph from a list of parse-tree roots.
 *
 * Nodes are numbered in depth-first (pre-order) visiting order starting at 1,
 * with a single counter shared across all roots. The resulting graph is
 * available as `this.graph`.
 */
var GraphBuilder = function(roots) {
  // Create the input graph
  this.graph = new dagreD3.graphlib.Graph()
    .setGraph({})
    .setDefaultEdgeLabel(function () {
      return {};
    });
  this.visitIndex = 0;
  for (var i = 0; i < roots.length; i++) {
    this.build(roots[i]);
  }
};

/**
 * Add `node` (and, recursively, its subtree) to the graph.
 *
 * Stores the assigned id on the node as `node.visitIndex` (mutating the input
 * tree) so that children can attach an edge from their `parent`. A terminal
 * node gets one extra child node holding its `text`.
 */
GraphBuilder.prototype.build = function(node) {
  // Track my visit index
  this.visitIndex++;
  node.visitIndex = this.visitIndex;
  // Add a node
  var nodeData = node; // TODO: replace with semantic data
  this.graph.setNode(node.visitIndex, { label: node.label, class: 'parse-RULE', data: nodeData });
  if (node.parent) {
    this.graph.setEdge(node.parent.visitIndex, node.visitIndex, {
      class: 'parse-EDGE'
    });
  }
  if (node.isTerminal) {
    // Leaf node for the word itself
    this.visitIndex++;
    var leafIndex = this.visitIndex;
    this.graph.setNode(leafIndex, { label: node.text, class: 'parse-TERMINAL', data: nodeData });
    this.graph.setEdge(node.visitIndex, leafIndex, {
      class: 'parse-EDGE'
    });
  } else if (node.children) {
    for (var i = 0; i < node.children.length; i++) {
      this.build(node.children[i]);
    }
  }
};
/**
 * Resize and re-center the SVG for graph `g`.
 * With fitToGraph set, the caller's minimums are replaced: the width follows
 * the laid-out graph and the height follows getAutoHeight().
 */
ParseViewer.prototype.updateGraphPosition = function (svg, g, minWidth, minHeight) {
  var width = this.fitToGraph ? g.graph().width : minWidth;
  var height = this.fitToGraph ? this.getAutoHeight() : minHeight;
  adjustGraphPositioning(svg, g, width, height);
};
/**
 * Size `svg` to hold the laid-out graph `g` and center the graph horizontally
 * with a 20px top offset.
 *
 * Note: the height is first set to max(minHeight, graph height + 40) and then
 * unconditionally overwritten with graph height + 40, so minHeight has no
 * lasting effect on the final height.
 */
function adjustGraphPositioning(svg, g, minWidth, minHeight) {
  var graphWidth = g.graph().width;
  var graphHeight = g.graph().height;
  var size = {
    'width': Math.max(minWidth, graphWidth),
    'height': Math.max(minHeight, graphHeight + 40)
  };
  svg.attr(size);
  svg.style(size);
  // Horizontal centering based on the width just applied
  var group = svg.select('g');
  var offsetX = (svg.attr('width') - graphWidth) / 2;
  group.attr('transform', 'translate(' + offsetX + ', 20)');
  var fittedHeight = graphHeight + 40;
  svg.attr('height', fittedHeight);
  svg.style('height', fittedHeight);
}
/**
 * Draw graph `g` into the first <g> of `svg` with dagre-d3, attach the
 * click / mouseover callbacks to the rendered nodes, then size and center the
 * SVG. The `parse` argument is accepted but not used here.
 */
ParseViewer.prototype.renderGraph = function (svg, g, parse) {
  // Create the renderer and draw the final graph
  var render = new dagreD3.render();
  var svgGroup = svg.select('g');
  render(svgGroup, g);
  var scope = this;
  var nodes = svgGroup.selectAll('g.node');
  // The datum handed to each handler is looked up as a node id via g.node()
  nodes.on('click', function (v) {
    if (scope.onClickNodeCallback) {
      scope.onClickNodeCallback(g.node(v).data);
    }
  });
  nodes.on('mouseover', function (v) {
    if (scope.onHoverNodeCallback) {
      scope.onHoverNodeCallback(g.node(v).data);
    }
  });
  this.updateGraphPosition(svg, g, svg.attr('width'), svg.attr('height'));
  this.graphRendered = true;
};
// Display a single parse tree.
ParseViewer.prototype.showParse = function (root) {
  this.showParses([root]);
};

/**
 * Replace the current drawing with the given parse trees.
 * If the container is hidden, graphRendered is set to false so that
 * onResize() renders the stored graph later.
 */
ParseViewer.prototype.showParses = function (roots) {
  var g = new GraphBuilder(roots).graph;
  // Round the corners of every node
  g.nodes().forEach(function (id) {
    var attrs = g.node(id);
    attrs.ry = 5;
    attrs.rx = 5;
  });
  var svg = this.parseElem;
  svg.selectAll('*').remove();
  // Fresh group for renderGraph() to draw into
  svg.append('g');
  this.graph = g;
  this.parse = roots;
  if (!this.container.is(':visible')) {
    this.graphRendered = false;
  } else if (roots.length > 0) {
    this.renderGraph(svg, this.graph, this.parse);
  }
};

// Display the parseTree of every sentence in an annotation that has one.
ParseViewer.prototype.showAnnotation = function (annotation) {
  var parses = annotation.sentences
    .filter(function (s) { return s && s.parseTree; })
    .map(function (s) { return s.parseTree; });
  this.showParses(parses);
};

/**
 * Adapt the canvas to the container's current size: render a graph that has
 * not been drawn yet, reposition one that has, or just resize an empty canvas.
 */
ParseViewer.prototype.onResize = function () {
  var width = this.getAutoWidth();
  var height = this.getAutoHeight();
  var svg = this.parseElem;
  var haveGraph = svg.select('g') && this.graph;
  if (haveGraph && this.graphRendered) {
    this.updateGraphPosition(svg, this.graph, width, height);
    return;
  }
  svg.attr({'width': width, 'height': height});
  svg.style({'width': width, 'height': height});
  if (haveGraph) {
    this.renderGraph(svg, this.graph, this.parse);
  }
};
// Exports
//module.exports = ParseViewer;
================================================
FILE: stanza/pipeline/depparse_processor.py
================================================
"""
Processor for performing dependency parsing
"""
import torch
from stanza.models.common import doc
from stanza.models.common.utils import unsort
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.models.depparse.data import DataLoader
from stanza.models.depparse.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
# these imports trigger the "register_variant" decorations
from stanza.pipeline.external.corenlp_converter_depparse import ConverterDepparse
DEFAULT_SEPARATE_BATCH = 150


@register_processor(name=DEPPARSE)
class DepparseProcessor(UDProcessor):
    """Predict a head and dependency relation for every word of a document."""

    # requirements this processor satisfies
    PROVIDES_DEFAULT = set([DEPPARSE])
    # requirements that must be met before this processor runs
    REQUIRES_DEFAULT = set([TOKENIZE, POS, LEMMA])

    def __init__(self, config, pipeline, device):
        # default set before the base class constructor runs
        self._pretagged = None
        super().__init__(config, pipeline, device)

    def _set_up_requires(self):
        # with the 'pretagged' option set, no upstream processors are required
        self._pretagged = self._config.get('pretagged')
        self._requires = set() if self._pretagged else self.__class__.REQUIRES_DEFAULT

    def _set_up_model(self, config, pipeline, device):
        # a trainer passed in through the config is used as-is
        self._trainer = config.get('trainer')
        if self._trainer is not None:
            return
        if 'pretrain_path' in config:
            self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path'])
        else:
            self._pretrain = None
        charlm_args = {
            'charlm_forward_file': config.get('forward_charlm_path', None),
            'charlm_backward_file': config.get('backward_charlm_path', None),
        }
        self._trainer = Trainer(args=charlm_args,
                                pretrain=self.pretrain,
                                model_file=config['model_path'],
                                device=device,
                                foundation_cache=pipeline.foundation_cache)

    def get_known_relations(self):
        """
        Return a list of relations which this processor can produce
        """
        return [label for label in self.vocab['deprel']._unit2id.keys()
                if label not in VOCAB_PREFIX]

    def process(self, document):
        """
        Parse ``document`` and return it with HEAD/DEPREL set and dependencies built.

        Raises:
            ValueError: if some word has neither UPOS nor XPOS.
            RuntimeError: re-raised with batching advice when CUDA runs out of memory.
        """
        if hasattr(self, '_variant'):
            return self._variant.process(document)

        pos_missing = any(word.upos is None and word.xpos is None
                          for sentence in document.sentences
                          for word in sentence.words)
        if pos_missing:
            raise ValueError("POS not run before depparse!")

        try:
            min_separate = self.config.get('min_length_to_batch_separately', DEFAULT_SEPARATE_BATCH)
            batch = DataLoader(document, self.config['batch_size'], self.config, self.pretrain,
                               vocab=self.vocab, evaluation=True,
                               sort_during_eval=self.config.get('sort_during_eval', True),
                               min_length_to_batch_separately=min_separate)
            with torch.no_grad():
                preds = []
                for chunk in batch:
                    preds.extend(self.trainer.predict(chunk))
            # restore the original sentence order if the loader sorted it
            if batch.data_orig_idx is not None:
                preds = unsort(preds, batch.data_orig_idx)
            flat_preds = [word_pred for sentence_preds in preds for word_pred in sentence_preds]
            batch.doc.set((doc.HEAD, doc.DEPREL), flat_preds)
            # build dependencies based on predictions
            for sentence in batch.doc.sentences:
                sentence.build_dependencies()
            return batch.doc
        except RuntimeError as e:
            if not str(e).startswith("CUDA out of memory. Tried to allocate"):
                raise
            new_message = str(e) + " ... You may be able to compensate for this by separating long sentences into their own batch with a parameter such as depparse_min_length_to_batch_separately=150 or by limiting the overall batch size with depparse_batch_size=400."
            raise RuntimeError(new_message) from e
================================================
FILE: stanza/pipeline/external/__init__.py
================================================
================================================
FILE: stanza/pipeline/external/corenlp_converter_depparse.py
================================================
"""
A depparse processor which converts constituency trees using CoreNLP
"""
from stanza.pipeline._constants import TOKENIZE, CONSTITUENCY, DEPPARSE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
from stanza.server.dependency_converter import DependencyConverter
@register_processor_variant(DEPPARSE, 'converter')
class ConverterDepparse(ProcessorVariant):
    """
    Depparse variant which delegates to a CoreNLP DependencyConverter.

    Only English pipelines are accepted.
    """

    # requirements that must be met before this variant runs
    REQUIRES_DEFAULT = set([TOKENIZE, CONSTITUENCY])

    def __init__(self, config):
        if config['lang'] != 'en':
            raise ValueError("Constituency to dependency converter only works for English")
        # TODO: get the classpath from config instead of the literal "$CLASSPATH"
        # TODO: the pipe opened here is never explicitly closed.  One option is to
        #   make the Pipeline a context manager whose __exit__ frees resources,
        #   although some (e.g. GPU allocations) might linger regardless, so it is
        #   unclear whether the cleanup is worth attempting.
        self.converter = DependencyConverter(classpath="$CLASSPATH")
        self.converter.open_pipe()

    def process(self, document):
        """Return whatever the converter produces for ``document``."""
        converted = self.converter.process(document)
        return converted
================================================
FILE: stanza/pipeline/external/jieba.py
================================================
"""
Processors related to Jieba in the pipeline.
"""
import re
import warnings
from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
def check_jieba():
    """
    Check that Jieba can be imported.

    Returns:
        True when the import succeeds.

    Raises:
        ImportError: with installation instructions if Jieba is missing; the
            underlying ImportError is chained as ``__cause__``.
    """
    try:
        import jieba  # noqa: F401 -- imported only to test availability
    except ImportError as e:
        raise ImportError(
            "Jieba is used but not installed on your machine. Go to https://pypi.org/project/jieba/ for installation instructions."
        ) from e
    return True
@register_processor_variant(TOKENIZE, 'jieba')
class JiebaTokenizer(ProcessorVariant):
    """Chinese tokenizer backed by Jieba, with punctuation-based sentence splitting."""

    # Tokens that end a sentence (unless no_ssplit is set): the ideographic full
    # stop, the fullwidth exclamation / question marks, and their ASCII forms.
    # Written as escapes so the fullwidth forms cannot be confused with ASCII.
    _SENTENCE_END_PUNCT = frozenset(['\u3002', '\uff01', '\uff1f', '!', '?'])

    def __init__(self, config):
        """ Construct a Jieba-based tokenizer by loading the Jieba pipeline.
        Sentences are split after sentence-final punctuation tokens
        (see _SENTENCE_END_PUNCT) unless config['no_ssplit'] is set.
        """
        if config['lang'] not in ['zh', 'zh-hans', 'zh-hant']:
            raise Exception("Jieba tokenizer is currently only allowed in Chinese (simplified or traditional) pipelines.")
        # Suppress a DeprecationWarning about pkg_resources raised from jieba.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning, module="jieba")
            check_jieba()
            import jieba
        self.nlp = jieba
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the Jieba tokenizer and wrap the results into a Doc object.

        Each token records its character span in MISC as start_char/end_char.
        Whitespace-only tokens advance the offset but are not emitted.
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the Jieba tokenizer.")

        sentences = []
        current_sentence = []
        offset = 0
        for token in self.nlp.cut(text, cut_all=False):
            start = offset
            offset += len(token)
            # isspace() requires the whole token to be whitespace; an
            # anchored-at-start regex would also drop tokens that merely
            # begin with whitespace
            if token.isspace():
                continue
            current_sentence.append({
                doc.TEXT: token,
                doc.MISC: f"{doc.START_CHAR}={start}|{doc.END_CHAR}={offset}",
            })
            if not self.no_ssplit and token in self._SENTENCE_END_PUNCT:
                sentences.append(current_sentence)
                current_sentence = []
        if current_sentence:
            sentences.append(current_sentence)
        return doc.Document(sentences, text)
================================================
FILE: stanza/pipeline/external/pythainlp.py
================================================
"""
Processors related to PyThaiNLP in the pipeline.
GitHub Home: https://github.com/PyThaiNLP/pythainlp
"""
from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
def check_pythainlp():
    """
    Make sure the pythainlp package can be imported.

    Returns True when the import succeeds; otherwise raises ImportError
    with installation instructions.
    """
    try:
        import pythainlp  # noqa: F401 -- imported only to probe availability
    except ImportError:
        hint = (
            "The pythainlp library is required. "
            "Try to install it with `pip install pythainlp`. "
            "Go to https://github.com/PyThaiNLP/pythainlp for more information."
        )
        raise ImportError(hint)
    else:
        return True
@register_processor_variant(TOKENIZE, 'pythainlp')
class PyThaiNLPTokenizer(ProcessorVariant):
    """ Tokenizer variant that delegates Thai sentence and word segmentation to PyThaiNLP. """

    def __init__(self, config):
        """ Construct a PyThaiNLP-based tokenizer.

        Note that we always use the default tokenizers of PyThaiNLP for sentence and word segmentation.
        Currently this is a CRF model (crfcut) for sentence segmentation and a dictionary-based model (newmm) for word segmentation.

        Raises Exception if config['lang'] is not 'th', and ImportError
        (from check_pythainlp) if pythainlp is not installed.
        """
        if config['lang'] != 'th':
            raise Exception("PyThaiNLP tokenizer is only allowed in Thai pipeline.")
        check_pythainlp()
        from pythainlp.tokenize import sent_tokenize as pythai_sent_tokenize
        from pythainlp.tokenize import word_tokenize as pythai_word_tokenize
        self.pythai_sent_tokenize = pythai_sent_tokenize
        self.pythai_word_tokenize = pythai_word_tokenize
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the PyThaiNLP tokenizer and wrap the results into a Doc object.

        document may be a str or a Stanza Document (its .text is tokenized).
        Whitespace tokens are dropped, and segments that contain no
        non-whitespace token do not produce an (empty) sentence.
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the PyThaiNLP tokenizer.")

        if self.no_ssplit:
            # skip sentence segmentation
            sent_strs = [text]
        else:
            sent_strs = self.pythai_sent_tokenize(text, engine='crfcut')

        sentences = []
        offset = 0
        for sent_str in sent_strs:
            # Re-anchor the running offset on the original text, in case the
            # sentence splitter does not return contiguous slices of it.
            # If the sentence cannot be located, keep the running offset.
            sent_start = text.find(sent_str, offset)
            if sent_start >= 0:
                offset = sent_start

            current_sentence = []
            for token_str in self.pythai_word_tokenize(sent_str, engine='newmm'):
                # by default pythainlp will output whitespace as a token
                # we need to skip these tokens to be consistent with other tokenizers
                if token_str.isspace():
                    offset += len(token_str)
                    continue
                current_sentence.append({
                    doc.TEXT: token_str,
                    doc.MISC: f"{doc.START_CHAR}={offset}|{doc.END_CHAR}={offset+len(token_str)}"
                })
                offset += len(token_str)

            # a whitespace-only (or empty) segment yields no tokens: don't emit an empty sentence
            if current_sentence:
                sentences.append(current_sentence)
        return doc.Document(sentences, text)
================================================
FILE: stanza/pipeline/external/spacy.py
================================================
"""
Processors related to spaCy in the pipeline.
"""
from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
def check_spacy():
    """
    Make sure spaCy can be imported.

    Returns True when the import succeeds; otherwise raises ImportError
    pointing at the spaCy installation instructions.
    """
    try:
        import spacy  # noqa: F401 -- imported only to probe availability
    except ImportError:
        hint = "spaCy is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
        raise ImportError(hint)
    else:
        return True
@register_processor_variant(TOKENIZE, 'spacy')
class SpacyTokenizer(ProcessorVariant):
    """ Tokenizer variant that delegates English tokenization and sentence splitting to spaCy. """

    def __init__(self, config):
        """ Construct a spaCy-based tokenizer by loading the spaCy pipeline.

        Only English is supported. A rule-based "sentencizer" pipe is added so
        that sentence splitting does not require a dependency parser.

        Raises Exception if config['lang'] is not 'en', and ImportError if
        spaCy is not installed.
        """
        if config['lang'] != 'en':
            raise Exception("spaCy tokenizer is currently only allowed in English pipeline.")
        try:
            import spacy
            from spacy.lang.en import English
        except ImportError:
            raise ImportError(
                "spaCy 2.0+ is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
            )

        # Create a Tokenizer with the default settings for English
        # including punctuation rules and exceptions
        self.nlp = English()
        # by default spacy uses dependency parser to do ssplit
        # we need to add a sentencizer for fast rule-based ssplit
        if spacy.__version__.startswith("2."):
            self.nlp.add_pipe(self.nlp.create_pipe("sentencizer"))
        else:
            self.nlp.add_pipe("sentencizer")
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the spaCy tokenizer and wrap the results into a Doc object.

        document may be a str or a Stanza Document (its .text is tokenized).
        Whitespace-only tokens are skipped and empty sentences are dropped;
        with no_ssplit all tokens are merged into a single sentence (or no
        sentence at all when there are no tokens).
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the spaCy tokenizer.")

        spacy_doc = self.nlp(text)
        sentences = []
        for sent in spacy_doc.sents:
            tokens = [
                {
                    doc.TEXT: tok.text,
                    doc.MISC: f"{doc.START_CHAR}={tok.idx}|{doc.END_CHAR}={tok.idx+len(tok.text)}"
                }
                for tok in sent
                # spaCy can emit tokens for extra whitespace (runs of spaces, newlines);
                # skip them to be consistent with the other external tokenizers
                if not tok.text.isspace()
            ]
            if tokens:
                sentences.append(tokens)

        # if no_ssplit is set, flatten all the sentences into one sentence
        if self.no_ssplit:
            flat = [t for s in sentences for t in s]
            sentences = [flat] if flat else []
        return doc.Document(sentences, text)
================================================
FILE: stanza/pipeline/external/sudachipy.py
================================================
"""
Processors related to SudachiPy in the pipeline.
GitHub Home: https://github.com/WorksApplications/SudachiPy
"""
import re
from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
def check_sudachipy():
    """
    Make sure both sudachipy and sudachidict_core can be imported.

    Returns True when both imports succeed; otherwise raises ImportError
    with installation instructions for the two packages.
    """
    try:
        # imported only to probe availability
        import sudachipy  # noqa: F401
        import sudachidict_core  # noqa: F401
    except ImportError:
        hint = (
            "Both sudachipy and sudachidict_core libraries are required. "
            "Try install them with `pip install sudachipy sudachidict_core`. "
            "Go to https://github.com/WorksApplications/SudachiPy for more information."
        )
        raise ImportError(hint)
    else:
        return True
@register_processor_variant(TOKENIZE, 'sudachipy')
class SudachiPyTokenizer(ProcessorVariant):
    """ Tokenizer variant that delegates Japanese word segmentation to SudachiPy. """

    # Tokens that close a sentence when sentence splitting is enabled.
    # The full-width forms are written as escapes so they cannot be confused
    # with (or normalized into) their ASCII look-alikes:
    #   U+3002 IDEOGRAPHIC FULL STOP, U+FF01 FULLWIDTH EXCLAMATION MARK,
    #   U+FF1F FULLWIDTH QUESTION MARK, plus ASCII '!' and '?'.
    SENTENCE_FINAL_PUNCT = frozenset(['\u3002', '\uff01', '\uff1f', '!', '?'])

    def __init__(self, config):
        """ Construct a SudachiPy-based tokenizer.

        Note that this tokenizer splits sentences on sentence-final punctuation
        tokens (see SENTENCE_FINAL_PUNCT) unless config['no_ssplit'] is set.

        Raises Exception if config['lang'] is not 'ja', and ImportError
        (from check_sudachipy) if SudachiPy or its dictionary is missing.
        """
        if config['lang'] != 'ja':
            raise Exception("SudachiPy tokenizer is only allowed in Japanese pipelines.")
        check_sudachipy()
        from sudachipy import dictionary
        self.tokenizer = dictionary.Dictionary().create()
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the SudachiPy tokenizer and wrap the results into a Doc object.

        document may be a str or a Stanza Document (its .text is tokenized).
        Character offsets come from SudachiPy's token.begin()/token.end().
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the SudachiPy tokenizer.")

        # we use the default sudachipy tokenization mode (i.e., mode C)
        # more config needs to be added to support other modes
        tokens = self.tokenizer.tokenize(text)

        sentences = []
        current_sentence = []
        for token in tokens:
            token_text = token.surface()
            # by default sudachipy will output whitespace as a token
            # we need to skip these tokens to be consistent with other tokenizers
            if token_text.isspace():
                continue
            current_sentence.append({
                doc.TEXT: token_text,
                doc.MISC: f"{doc.START_CHAR}={token.begin()}|{doc.END_CHAR}={token.end()}"
            })
            if not self.no_ssplit and token_text in self.SENTENCE_FINAL_PUNCT:
                sentences.append(current_sentence)
                current_sentence = []
        if current_sentence:
            sentences.append(current_sentence)
        return doc.Document(sentences, text)
================================================
FILE: stanza/pipeline/langid_processor.py
================================================
"""
Processor for determining language of text.
"""
import emoji
import re
import stanza
import torch
from stanza.models.common.doc import Document
from stanza.models.langid.model import LangIDBiLSTM
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
@register_processor(name=LANGID)
class LangIDProcessor(UDProcessor):
    """
    Class for detecting language of text.

    Sets the .lang attribute of each processed Document using a LangIDBiLSTM model.
    """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([LANGID])

    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([])

    # default max sequence length
    MAX_SEQ_LENGTH_DEFAULT = 1000

    def _set_up_model(self, config, pipeline, device):
        batch_size = config.get("batch_size", 64)
        self._model = LangIDBiLSTM.load(path=config["model_path"], device=device,
                                        batch_size=batch_size, lang_subset=config.get("lang_subset"))
        self._char_index = self._model.char_to_idx
        self._clean_text = config.get("clean_text")

    def _text_to_tensor(self, docs):
        """
        Map a list of strings to a batch tensor of character ids (dtype long),
        on the same device as the model.

        All strings must have the same length so the result is rectangular.
        Characters not in the model's character index map to its "UNK" id.
        """
        device = next(self._model.parameters()).device
        unk_idx = self._char_index["UNK"]
        all_docs = [[self._char_index.get(c, unk_idx) for c in text] for text in docs]
        return torch.tensor(all_docs, device=device, dtype=torch.long)

    def _id_langs(self, batch_tensor):
        """
        Identify languages for each sequence in a batch tensor
        """
        predictions = self._model.prediction_scores(batch_tensor)
        prediction_labels = [self._model.idx_to_tag[prediction] for prediction in predictions]
        return prediction_labels

    # regexes for cleaning text
    http_regex = re.compile(r"https?:\/\/t\.co/[a-zA-Z0-9]+")
    handle_regex = re.compile("@[a-zA-Z0-9_]+")
    hashtag_regex = re.compile("#[a-zA-Z]+")
    punctuation_regex = re.compile("[!.]+")
    all_regexes = [http_regex, handle_regex, hashtag_regex, punctuation_regex]

    @staticmethod
    def clean_text(text):
        """
        Process text to improve language id performance. Main emphasis is on tweets, this method removes shortened
        urls, hashtags, handles, and punctuation and emoji.

        If the cleaned text would be all whitespace, the whitespace is kept
        rather than stripped to an empty string.
        """
        for regex in LangIDProcessor.all_regexes:
            text = regex.sub(" ", text)

        # emojize first so that :shortcode: sequences are also replaced
        text = emoji.emojize(text)
        text = emoji.replace_emoji(text, replace=' ')

        if text.strip():
            text = text.strip()

        return text

    def _process_list(self, docs):
        """
        Identify the language of each element of a list of strings or Documents.

        Strings are wrapped in empty Documents (each element is checked, so a
        mixed list is accepted). Each Document's .lang is set in place, and the
        list of Documents is returned; an empty input gives an empty list.
        """
        if len(docs) == 0:
            # TODO: decide on a standard for other kinds of bad input
            return []

        docs = [Document([], doc) if isinstance(doc, str) else doc for doc in docs]

        # batch texts of equal length together so each batch forms a rectangular tensor
        docs_by_length = {}
        for doc in docs:
            text = LangIDProcessor.clean_text(doc.text) if self._clean_text else doc.text
            docs_by_length.setdefault(len(text), []).append((doc, text))

        for pairs in docs_by_length.values():
            texts = [text for _, text in pairs]
            predictions = self._id_langs(self._text_to_tensor(texts))
            for (doc, _), lang in zip(pairs, predictions):
                doc.lang = lang

        return docs

    def process(self, doc):
        """
        Handle single str or Document
        """
        wrapped_doc = [doc]
        return self._process_list(wrapped_doc)[0]

    def bulk_process(self, docs):
        """
        Handle list of strings or Documents
        """
        return self._process_list(docs)
================================================
FILE: stanza/pipeline/lemma_processor.py
================================================
"""
Processor for performing lemmatization
"""
from itertools import compress
import torch
from stanza.models.common import doc
from stanza.models.lemma.data import DataLoader
from stanza.models.lemma.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
# Word fields fetched together as (text, upos) when the seq2seq lemmatizer is
# ensembled with the dictionary lemmatizer in LemmaProcessor.process.
WORD_TAGS = [doc.TEXT, doc.UPOS]
@register_processor(name=LEMMA)
class LemmaProcessor(UDProcessor):
    """
    Pipeline processor that fills in the lemma of every word.

    Runs in one of three modes chosen from the config:
      - use_identity: each lemma is a copy of the word text
      - dict_only: lemmas come only from the trainer's dictionary lookup
      - otherwise: a seq2seq model, optionally ensembled with the dictionary
        (ensemble_dict) and then updated by contextual lemmatizers if any
    Empty lemmas are replaced by '_'.
    """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([LEMMA])
    # set of processor requirements for this processor
    # pos will be added later for non-identity lemmatizers
    REQUIRES_DEFAULT = set([TOKENIZE])
    # batch size used in identity mode
    DEFAULT_BATCH_SIZE = 5000

    def __init__(self, config, pipeline, device):
        # run lemmatizer in identity mode
        self._use_identity = None
        self._pretagged = None
        super().__init__(config, pipeline, device)

    @property
    def use_identity(self):
        return self._use_identity

    def _set_up_model(self, config, pipeline, device):
        if config.get('use_identity') in ['True', True]:
            self._use_identity = True
            self._config = config
            self.config['batch_size'] = LemmaProcessor.DEFAULT_BATCH_SIZE
        else:
            # the lemmatizer only looks at one word when making
            # decisions, not the surrounding context
            # therefore, we can save some time by remembering what
            # we did the last time we saw any given word,pos
            # since a long running program will remember everything
            # (unless we go back and make it smarter)
            # we make this an option, not the default
            # TODO: need to update the cache to skip the contextual lemmatizer
            self.store_results = config.get('store_results', False)
            self._use_identity = False
            args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                    'charlm_backward_file': config.get('backward_charlm_path', None)}
            lemma_classifier_args = dict(args)
            lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)
            self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)

    def _set_up_requires(self):
        self._pretagged = self._config.get('pretagged', None)
        if self._pretagged:
            self._requires = set()
        elif self.config.get('pos') and not self.use_identity:
            self._requires = LemmaProcessor.REQUIRES_DEFAULT.union(set([POS]))
        else:
            self._requires = LemmaProcessor.REQUIRES_DEFAULT

    def process(self, document):
        """ Lemmatize document and return the annotated Document. """
        if self.use_identity:
            batch = DataLoader(document, self.config['batch_size'], self.config, evaluation=True, conll_only=True)
            preds = [word.text for sent in batch.doc.sentences for word in sent.words]
        else:
            batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)
            if self.config.get('dict_only', False):
                preds = self.trainer.predict_dict(batch.doc.get([doc.TEXT, doc.UPOS]))
            else:
                # keep all model inference (seq2seq and contextual) out of autograd
                with torch.no_grad():
                    preds = self._predict_with_model(document, batch)

        # map empty string lemmas to '_'
        preds = [x if x else '_' for x in preds]
        batch.doc.set([doc.LEMMA], preds)
        return batch.doc

    def _predict_with_model(self, document, batch):
        """ Run the seq2seq lemmatizer (optionally ensembled with the dictionary) plus any contextual lemmatizers. """
        ensemble_dict = self.config.get('ensemble_dict', False)
        if ensemble_dict:
            # skip the seq2seq model when we can
            skip = self.trainer.skip_seq2seq(batch.doc.get([doc.TEXT, doc.UPOS]))
            # although there is no explicit use of caseless or lemma_caseless in this processor,
            # it shows up in the config which gets passed to the DataLoader,
            # possibly affecting its results
            seq2seq_batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab,
                                       evaluation=True, skip=skip, expand_unk_vocab=True)
        else:
            seq2seq_batch = batch

        preds = []
        edits = []
        for b in seq2seq_batch:
            ps, es = self.trainer.predict(b, self.config['beam_size'], seq2seq_batch.vocab)
            preds += ps
            if es is not None:
                edits += es

        if ensemble_dict:
            word_tags = batch.doc.get(WORD_TAGS)
            words = [x[0] for x in word_tags]
            preds = self.trainer.postprocess([x for x, y in zip(words, skip) if not y], preds, edits=edits)
            if self.store_results:
                new_word_tags = compress(word_tags, (not s for s in skip))
                new_predictions = [(x[0], x[1], y) for x, y in zip(new_word_tags, preds)]
                self.trainer.train_dict(new_predictions, update_word_dict=False)
            # expand seq2seq predictions to the same size as all words;
            # skipped words get '' and are filled in by the ensemble
            i = 0
            preds1 = []
            for s in skip:
                if s:
                    preds1.append('')
                else:
                    preds1.append(preds[i])
                    i += 1
            preds = self.trainer.ensemble(word_tags, preds1)
        else:
            preds = self.trainer.postprocess(batch.doc.get([doc.TEXT]), preds, edits=edits)

        if self.trainer.has_contextual_lemmatizers():
            preds = self.trainer.update_contextual_preds(batch.doc, preds)
        return preds
================================================
FILE: stanza/pipeline/morphseg_processor.py
================================================
from stanza.pipeline.core import UnsupportedProcessorError
from stanza.pipeline.processor import UDProcessor, register_processor
from stanza.pipeline._constants import MORPHSEG, TOKENIZE
@register_processor(name=MORPHSEG)
class MorphSegProcessor(UDProcessor):
    """ Pipeline processor that attaches morpheme segmentations (word.morphemes) using morphseg. """

    PROVIDES_DEFAULT = {MORPHSEG}
    REQUIRES_DEFAULT = {TOKENIZE}

    def __init__(self, config, pipeline, device):
        # state is set up directly here instead of via super().__init__()
        self._config = config
        self._pipeline = pipeline
        self._set_up_requires()
        self._set_up_provides()
        self._set_up_model(config, pipeline, device)

    def _set_up_model(self, config, pipeline, device):
        """ Build the morphseg segmenter from a local model file or the pretrained model for config['lang']. """
        try:
            from morphseg import MorphemeSegmenter
        except ImportError:
            raise ImportError(
                "morphseg is required for morpheme segmentation. "
                "Install it with: pip install morphseg"
            )

        lang = config.get('lang', 'en')
        model_path = config.get('morphseg_model_path', None)
        if not model_path:
            segmenter_kwargs = dict(lang=lang, load_pretrained=True)
        else:
            segmenter_kwargs = dict(lang=lang, load_pretrained=False,
                                    model_filepath=model_path, is_local=True)
        self._segmenter = MorphemeSegmenter(**segmenter_kwargs)

        # no sequence labeller means there is no model for this language
        if self._segmenter.sequence_labeller is None:
            raise UnsupportedProcessorError("morphseg", lang)

    def process(self, document):
        """ Segment every word of document into morphemes, storing a list of strings on word.morphemes. """
        # flatten the document, remembering where each word came from
        texts = []
        positions = []
        for sent_idx, sent in enumerate(document.sentences):
            for word_idx, word in enumerate(sent.words or ()):
                texts.append(word.text)
                positions.append((sent_idx, word_idx))

        if not texts:
            return document

        # morphseg's labeller consumes normalized, lowercased character lists
        sources = [list(self._segmenter.normalize_for_morphology(text)) for text in texts]

        # one batched prediction over every word in the document
        predictions = self._segmenter.sequence_labeller.predict(sources=sources)

        from morphseg.training.oracle import rules2sent

        # turn each prediction into a list of morphemes
        segmentations = []
        for pred in predictions:
            symbols = [align_pos.symbol for align_pos in pred.alignment]
            joined = rules2sent(source=symbols, actions=pred.prediction)
            segmentations.append(joined.split(' @@'))  # morphseg's default delimiter

        for (sent_idx, word_idx), morphemes in zip(positions, segmentations):
            document.sentences[sent_idx].words[word_idx].morphemes = morphemes

        return document
================================================
FILE: stanza/pipeline/multilingual.py
================================================
"""
Class for running multilingual pipelines
"""
from collections import OrderedDict
import copy
import logging
from typing import Union
from stanza.models.common.doc import Document
from stanza.models.common.utils import default_device
from stanza.pipeline.core import Pipeline, DownloadMethod
from stanza.pipeline._constants import *
from stanza.resources.common import DEFAULT_MODEL_DIR, get_language_resources, load_resources_json
# module logger; shares the 'stanza' logger name used across the package
logger = logging.getLogger('stanza')
class MultilingualPipeline:
    """
    Pipeline for handling multilingual data. Takes in text, detects language, and routes request to pipeline for that
    language.

    You can specify options to individual language pipelines with the lang_configs field.
    For example, if you want English pipelines to have NER, but want to turn that off for French, you can do:
        lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse,ner"},
                        "fr": {"processors": "tokenize,pos,lemma,depparse"}}
        pipeline = MultilingualPipeline(lang_configs=lang_configs)

    You can also pass in a defaultdict created in such a way that it provides default parameters for each language.
    For example, in order to only get tokenization for each language:
    (remembering that the Pipeline will automagically add MWT to a language which uses MWT):
        from collections import defaultdict
        lang_configs = defaultdict(lambda: dict(processors="tokenize"))
        pipeline = MultilingualPipeline(lang_configs=lang_configs)

    download_method can be set as in Pipeline to turn off downloading
    of the .json config or turn off downloading of everything

    Other options:
      ld_batch_size: batch size for the language id processor, unless
          lang_id_config already sets langid_batch_size
      max_cache_size: how many language-specific Pipelines to keep loaded;
          the least recently used one is evicted first
      restrict: limit language id to the languages in lang_configs, unless
          lang_id_config already sets langid_lang_subset
      processors: default processors (comma separated str or list) for
          languages whose config does not set "processors"
    """

    def __init__(self,
                 model_dir: str = DEFAULT_MODEL_DIR,
                 lang_id_config: dict = None,
                 lang_configs: dict = None,
                 ld_batch_size: int = 64,
                 max_cache_size: int = 10,
                 use_gpu: bool = None,
                 restrict: bool = False,
                 device: str = None,
                 download_method: DownloadMethod = DownloadMethod.DOWNLOAD_RESOURCES,
                 # python 3.6 compatibility - maybe want to update to 3.7 at some point
                 processors: Union[str, list] = None,
    ):
        # set up configs and cache for various language pipelines
        self.model_dir = model_dir
        self.lang_id_config = {} if lang_id_config is None else copy.deepcopy(lang_id_config)
        self.lang_configs = {} if lang_configs is None else copy.deepcopy(lang_configs)
        self.max_cache_size = max_cache_size
        # OrderedDict so we can use it as a LRU cache
        # most recent Pipeline goes to the end, pop the oldest one
        # when we run out of space
        self.pipeline_cache = OrderedDict()
        if processors is None:
            self.default_processors = None
        elif isinstance(processors, str):
            self.default_processors = [x.strip() for x in processors.split(",")]
        else:
            self.default_processors = list(processors)

        self.download_method = download_method
        if 'download_method' not in self.lang_id_config:
            self.lang_id_config['download_method'] = self.download_method

        # ld_batch_size is forwarded as the langid processor's batch_size
        # unless the caller already chose one in lang_id_config
        if 'langid_batch_size' not in self.lang_id_config:
            self.lang_id_config['langid_batch_size'] = ld_batch_size

        # if lang is not in any of the lang_configs, update them to
        # include the lang parameter.  otherwise, the default language
        # will always be used...
        for lang in self.lang_configs:
            if 'lang' not in self.lang_configs[lang]:
                self.lang_configs[lang]['lang'] = lang

        if restrict and 'langid_lang_subset' not in self.lang_id_config:
            known_langs = sorted(self.lang_configs.keys())
            # known_langs is a list: test for emptiness, not equality with 0
            if not known_langs:
                logger.warning("MultilingualPipeline asked to restrict to lang_configs, but lang_configs was empty.  Ignoring...")
            else:
                logger.debug("Restricting MultilingualPipeline to %s", known_langs)
                self.lang_id_config['langid_lang_subset'] = known_langs

        # set use_gpu
        if device is None:
            if use_gpu is None or use_gpu:
                device = default_device()
            else:
                device = 'cpu'
        self.device = device

        # build language id pipeline
        self.lang_id_pipeline = Pipeline(dir=self.model_dir, lang='multilingual', processors="langid",
                                         device=self.device, **self.lang_id_config)
        # load the resources so that we can refer to it later when building a new pipeline
        # note that it was either downloaded or not based on download_method when building the lang_id_pipeline
        self.resources = load_resources_json(self.model_dir)

    def _update_pipeline_cache(self, lang):
        """
        Do any necessary updates to the pipeline cache for this language.  This includes building a new
        pipeline for the lang, and possibly clearing out a language with the old last access date.
        """
        # update request history
        if lang in self.pipeline_cache:
            self.pipeline_cache.move_to_end(lang, last=True)

        # update language configs
        # try/except to allow for a defaultdict
        try:
            lang_config = self.lang_configs[lang]
        except KeyError:
            lang_config = {'lang': lang}
            self.lang_configs[lang] = lang_config

        # if a defaultdict is passed in, the defaultdict might not contain 'lang'
        # so even though we tried adding 'lang' in the constructor, we'll check again here
        if 'lang' not in lang_config:
            lang_config['lang'] = lang

        if 'download_method' not in lang_config:
            lang_config['download_method'] = self.download_method

        if 'processors' not in lang_config and self.default_processors:
            lang_resources = get_language_resources(self.resources, lang)
            lang_processors = [x for x in self.default_processors if x in lang_resources]
            if lang_processors != self.default_processors:
                logger.info("Not all requested processors %s available for %s.  Loading %s instead", self.default_processors, lang, lang_processors)
            lang_config['processors'] = ",".join(lang_processors)

        if 'device' not in lang_config:
            lang_config['device'] = self.device

        # update pipeline cache
        if lang not in self.pipeline_cache:
            logger.debug("Loading unknown language in MultilingualPipeline: %s", lang)
            # evict least recently used pipelines until there is room;
            # the emptiness check keeps a max_cache_size <= 0 from popping an empty cache
            while self.pipeline_cache and len(self.pipeline_cache) >= self.max_cache_size:
                self.pipeline_cache.popitem(last=False)
            self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **lang_config)

    def process(self, doc):
        """
        Run language detection on a string, a Document, or a list of either, route to language specific pipeline

        A list input may mix strings and Documents.  A list is returned only when a list was given;
        an empty list is returned unchanged in content (no models are run).
        """
        # only return a list if given a list
        singleton_input = not isinstance(doc, list)
        docs = [doc] if singleton_input else doc
        if not docs:
            return []

        docs = [Document([], text=d) if isinstance(d, str) else d for d in docs]

        # run language identification
        docs_w_langid = self.lang_id_pipeline.process(docs)

        # create language specific batches
        lang_batches = {}
        for doc_idx, d in enumerate(docs_w_langid):
            logger.debug("Language for document %d: %s", doc_idx, d.lang)
            lang_batches.setdefault(d.lang, []).append(d)

        # run through each language, submit a batch to the language specific pipeline
        # (the language pipelines annotate the Documents in place)
        for lang, lang_docs in lang_batches.items():
            self._update_pipeline_cache(lang)
            self.pipeline_cache[lang](lang_docs)

        # only return a list if given a list
        if singleton_input:
            return docs_w_langid[0]
        return docs_w_langid

    def __call__(self, doc):
        doc = self.process(doc)
        return doc
================================================
FILE: stanza/pipeline/mwt_processor.py
================================================
"""
Processor for performing multi-word-token expansion
"""
import io
import torch
from stanza.models.mwt.data import DataLoader
from stanza.models.mwt.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
@register_processor(MWT)
class MWTProcessor(UDProcessor):
    """ Pipeline processor that expands multi-word tokens into their component words. """

    # requirement this processor satisfies once it has run
    PROVIDES_DEFAULT = set([MWT])
    # requirements that must already be satisfied before this processor runs
    REQUIRES_DEFAULT = set([TOKENIZE])

    def _set_up_model(self, config, pipeline, device):
        self._trainer = Trainer(model_file=config['model_path'], device=device)

    def build_batch(self, document):
        """ Wrap document in an evaluation-mode DataLoader using this processor's vocab. """
        return DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)

    def process(self, document):
        """ Predict MWT expansions for document and return the expanded Document. """
        batch = self.build_batch(document)
        expansions = batch.doc.get_mwt_expansions(evaluation=True)

        if len(batch) == 0:
            # nothing to expand, so don't run the model at all
            preds = []
        elif self.config['dict_only']:
            preds = self.trainer.predict_dict(expansions)
        else:
            preds = []
            with torch.no_grad():
                for chunk in batch.to_loader():
                    preds += self.trainer.predict(chunk, never_decode_unk=True, vocab=batch.vocab)
            if self.config.get('ensemble_dict', False):
                preds = self.trainer.ensemble(expansions, preds)

        batch.doc.set_mwt_expansions(preds, process_manual_expanded=False)
        return batch.doc

    def bulk_process(self, docs):
        """
        MWT processor counts some statistics on the individual docs, so we need to separately redo those stats
        """
        processed = super().bulk_process(docs)
        for processed_doc in processed:
            processed_doc._count_words()
        return processed
================================================
FILE: stanza/pipeline/ner_processor.py
================================================
"""
Processor for performing named entity tagging.
"""
import torch
import logging
from stanza.models.common import doc
from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
from stanza.models.common.utils import unsort
from stanza.models.ner.data import DataLoader
from stanza.models.ner.trainer import Trainer
from stanza.models.ner.utils import merge_tags
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
# module logger; shares the 'stanza' logger name used across the package
logger = logging.getLogger('stanza')
@register_processor(name=NER)
class NERProcessor(UDProcessor):
# set of processor requirements this processor fulfills
PROVIDES_DEFAULT = set([NER])
# set of processor requirements for this processor
REQUIRES_DEFAULT = set([TOKENIZE])
def _get_dependencies(self, config, dep_name):
dependencies = config.get(dep_name, None)
if dependencies is not None:
dependencies = dependencies.split(";")
dependencies = [x if x else None for x in dependencies]
else:
dependencies = [x.get(dep_name) for x in config.get('dependencies', [])]
return dependencies
def _set_up_model(self, config, pipeline, device):
# set up trainer
model_paths = config.get('model_path')
if isinstance(model_paths, str):
model_paths = model_paths.split(";")
charlm_forward_files = self._get_dependencies(config, 'forward_charlm_path')
charlm_backward_files = self._get_dependencies(config, 'backward_charlm_path')
pretrain_files = self._get_dependencies(config, 'pretrain_path')
# allow predict_tagset to be specified as an int
# (which only applies to the first model)
# or as a string ";" separated list of ints
self._predict_tagset = {}
predict_tagset = config.get('predict_tagset', None)
if predict_tagset:
if isinstance(predict_tagset, int):
self._predict_tagset[0] = predict_tagset
else:
predict_tagset = predict_tagset.split(";")
for piece_idx, piece in enumerate(predict_tagset):
if piece:
self._predict_tagset[piece_idx] = int(piece)
self.trainers = []
for (model_path, pretrain_path, charlm_forward, charlm_backward) in zip(model_paths, pretrain_files, charlm_forward_files, charlm_backward_files):
logger.debug("Loading %s with pretrain %s, forward charlm %s, backward charlm %s", model_path, pretrain_path, charlm_forward, charlm_backward)
pretrain = pipeline.foundation_cache.load_pretrain(pretrain_path) if pretrain_path else None
args = {'charlm_forward_file': charlm_forward,
'charlm_backward_file': charlm_backward}
predict_tagset = self._predict_tagset.get(len(self.trainers), None)
if predict_tagset is not None:
args['predict_tagset'] = predict_tagset
try:
trainer = Trainer(args=args, model_file=model_path, pretrain=pretrain, device=device, foundation_cache=pipeline.foundation_cache)
except ForwardCharlmNotFoundError as e:
raise ForwardCharlmNotFoundError("Could not find the forward charlm %s. Please specify the correct path with ner_forward_charlm_path" % e.filename, e.filename) from None
except BackwardCharlmNotFoundError as e:
raise BackwardCharlmNotFoundError("Could not find the backward charlm %s. Please specify the correct path with ner_backward_charlm_path" % e.filename, e.filename) from None
self.trainers.append(trainer)
self._trainer = self.trainers[0]
self.model_paths = model_paths
def _set_up_final_config(self, config):
""" Finalize the configurations for this processor, based off of values from a UD model. """
# set configurations from loaded model
if len(self.trainers) == 0:
raise RuntimeError("Somehow there are no models loaded!")
self._vocab = self.trainers[0].vocab
self.configs = []
for trainer in self.trainers:
loaded_args = trainer.args
# filter out unneeded args from model
loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
loaded_args.update(config)
self.configs.append(loaded_args)
self._config = self.configs[0]
def __str__(self):
return "NERProcessor(%s)" % ";".join(self.model_paths)
def mark_inactive(self):
""" Drop memory intensive resources if keeping this processor around for reasons other than running it. """
super().mark_inactive()
self.trainers = None
def process(self, document):
with torch.no_grad():
all_preds = []
for trainer, config in zip(self.trainers, self.configs):
# set up a eval-only data loader and skip tag preprocessing
batch = DataLoader(document, config['batch_size'], config, vocab=trainer.vocab, evaluation=True, preprocess_tags=False, bert_tokenizer=trainer.model.bert_tokenizer)
preds = []
for i, b in enumerate(batch):
preds += trainer.predict(b)
all_preds.append(preds)
# for each sentence, gather a list of predictions
# merge those predictions into a single list
# earlier models will have precedence
preds = [merge_tags(*x) for x in zip(*all_preds)]
batch.doc.set([doc.NER], [y for x in preds for y in x], to_token=True)
batch.doc.set([doc.MULTI_NER], [tuple(y) for x in zip(*all_preds) for y in zip(*x)], to_token=True)
# collect entities into document attribute
total = len(batch.doc.build_ents())
logger.debug(f'{total} entities found in document.')
return batch.doc
def bulk_process(self, docs):
"""
NER processor has a collation step after running inference
"""
docs = super().bulk_process(docs)
for doc in docs:
doc.build_ents()
return docs
def get_known_tags(self, model_idx=0):
    """
    Return the tags known by one of the loaded models.

    The S-, B-, etc. prefixes are removed and O is not included.
    Use model_idx to choose the model when the processor has several.
    """
    chosen_trainer = self.trainers[model_idx]
    return chosen_trainer.get_known_tags()
================================================
FILE: stanza/pipeline/pos_processor.py
================================================
"""
Processor for performing part-of-speech tagging
"""
import torch
from stanza.models.common import doc
from stanza.models.common.utils import unsort
from stanza.models.common.vocab import VOCAB_PREFIX, CompositeVocab
from stanza.models.pos.data import Dataset
from stanza.models.pos.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
from stanza.utils.get_tqdm import get_tqdm
# progress-bar wrapper; POSProcessor.process wraps its batch iterator with it
# when the processor's 'tqdm' option is set
tqdm = get_tqdm()
@register_processor(name=POS)
class POSProcessor(UDProcessor):
    """ Processor that assigns UPOS, XPOS and FEATS annotations to the words of a Document. """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([POS])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE])

    def _set_up_model(self, config, pipeline, device):
        """
        Load the POS model, plus its pretrained word vectors when
        config contains a 'pretrain_path'.

        Optional forward/backward charlm paths are forwarded to the Trainer.
        """
        # get pretrained word vectors
        self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
        args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                'charlm_backward_file': config.get('backward_charlm_path', None)}
        # set up trainer
        self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], device=device, args=args, foundation_cache=pipeline.foundation_cache)
        # show a progress bar in process() only when explicitly requested
        self._tqdm = 'tqdm' in config and config['tqdm']

    def __str__(self):
        return "POSProcessor(%s)" % self.config['model_path']

    def get_known_xpos(self):
        """
        Returns the xpos tags known by this model

        A plain vocab, or a CompositeVocab with a single part, yields a list
        of tags.  For a multi-part CompositeVocab the per-part mappings are
        expected to be a dict (keyed composite); the result is then a dict
        of key -> set of values known for that key.
        """
        xpos_vocab = self.vocab['xpos']
        if isinstance(xpos_vocab, CompositeVocab):
            if len(xpos_vocab) == 1:
                # The per-part unit->id mappings live in _unit2id, so index
                # that.  Indexing the vocab object itself (xpos_vocab[0])
                # goes through the vocab's item lookup and does not return
                # a per-part mapping.
                # NOTE(review): relies on a single-part CompositeVocab
                # storing _unit2id as an indexable sequence of dicts.
                return [k for k in xpos_vocab._unit2id[0].keys() if k not in VOCAB_PREFIX]
            else:
                return {k: v.keys() - VOCAB_PREFIX for k, v in xpos_vocab._unit2id.items()}
        return [k for k in xpos_vocab._unit2id.keys() if k not in VOCAB_PREFIX]

    def is_composite_xpos(self):
        """
        Returns if the xpos tags are part of a composite vocab
        """
        return isinstance(self.vocab['xpos'], CompositeVocab)

    def get_known_upos(self):
        """
        Returns the upos tags known by this model
        """
        keys = [k for k in self.vocab['upos']._unit2id.keys() if k not in VOCAB_PREFIX]
        return keys

    def get_known_feats(self):
        """
        Returns the features known by this model, as a dict of
        feature name -> set of known values
        """
        values = {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['feats']._unit2id.items()}
        return values

    def process(self, document):
        """
        Annotate the words of `document` with UPOS, XPOS and FEATS and
        return the Document held by the evaluation dataset.
        """
        # currently, POS models are saved w/o the batch_maximum_tokens flag
        maximum_tokens = self.config.get('batch_maximum_tokens', 5000)
        dataset = Dataset(
            document, self.config, self.pretrain, vocab=self.vocab, evaluation=True,
            sort_during_eval=True)
        batch = iter(dataset.to_length_limited_loader(batch_size=self.config['batch_size'], maximum_tokens=maximum_tokens))
        preds = []
        idx = []
        with torch.no_grad():
            if self._tqdm:
                batch = tqdm(batch)
            for b in batch:
                # presumably the original sentence positions (the dataset is
                # sorted during eval); unsort() below uses them to restore order
                idx.extend(b[-1])
                preds += self.trainer.predict(b)
        preds = unsort(preds, idx)
        dataset.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x])
        return dataset.doc
================================================
FILE: stanza/pipeline/processor.py
================================================
"""
Base classes for processors
"""
from abc import ABC, abstractmethod
from stanza.models.common.doc import Document
from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS
class ProcessorRequirementsException(Exception):
    """
    Exception indicating a processor's requirements will not be met.

    Constructing it marks the offending processor inactive (dropping its
    resources) and builds a message listing which requirements are
    fulfilled and which are missing.
    """

    def __init__(self, processors_list, err_processor, provided_reqs):
        self._err_processor = err_processor
        # mark the broken processor as inactive, drop resources
        self.err_processor.mark_inactive()
        self._processors_list = processors_list
        self._provided_reqs = provided_reqs
        self.build_message()
        # hand the message to Exception as well, so e.args / repr(e) carry it
        super().__init__(self.message)

    @property
    def err_processor(self):
        """ The processor that raised the exception """
        return self._err_processor

    @property
    def processor_type(self):
        """ Class name of the processor that raised the exception """
        return type(self.err_processor).__name__

    @property
    def processors_list(self):
        """ Names of the processors requested for the pipeline """
        return self._processors_list

    @property
    def provided_reqs(self):
        """ Requirements provided by the processors loaded so far """
        return self._provided_reqs

    def build_message(self):
        """ Compose the human-readable error message stored in self.message. """
        self.message = (f"---\nPipeline Requirements Error!\n"
                        f"\tProcessor: {self.processor_type}\n"
                        f"\tPipeline processors list: {','.join(self.processors_list)}\n"
                        f"\tProcessor Requirements: {self.err_processor.requires}\n"
                        f"\t\t- fulfilled: {self.err_processor.requires.intersection(self.provided_reqs)}\n"
                        f"\t\t- missing: {self.err_processor.requires - self.provided_reqs}\n"
                        f"\nThe processors list provided for this pipeline is invalid. Please make sure all "
                        f"prerequisites are met for every processor.\n\n")

    def __str__(self):
        return self.message
class Processor(ABC):
    """
    Base class for all processors.

    Construction resolves an optional variant (enabled with a
    `with_<variant>` config flag), determines which annotations the
    processor requires and provides, and checks that the owning pipeline
    can satisfy those requirements.
    """

    def __init__(self, config, pipeline, device):
        # overall config for the processor
        self._config = config
        # the pipeline building this processor (a processor is only meant to belong to one pipeline)
        self._pipeline = pipeline
        self._set_up_variants(config, device)

        # requirements come from the variant when it defines them, otherwise from the class
        requires_from_variant = self._set_up_variant_requires()
        if not requires_from_variant:
            self._set_up_requires()
        self._set_up_provides()

        # raises ProcessorRequirementsException when the pipeline cannot satisfy our requirements
        self._check_requirements()

        # an overriding variant stands in for this processor's entire process() step
        if hasattr(self, '_variant') and self._variant.OVERRIDE:
            self.process = self._variant.process

    def __str__(self):
        """
        Simple description of the processor: name(model), or just name
        when no model_path is configured
        """
        name = type(self).__name__
        model = None if self._config is None else self._config.get('model_path')
        return name if model is None else "{}({})".format(name, model)

    @abstractmethod
    def process(self, doc):
        """ Process a Document. This is the main method of a processor. """
        pass

    def bulk_process(self, docs):
        """ Process a list of Documents one at a time; subclasses should override with something more efficient where possible. """
        if not hasattr(self, '_variant'):
            return [self.process(doc) for doc in docs]
        return self._variant.bulk_process(docs)

    def _set_up_provides(self):
        """ Use the class-level PROVIDES_DEFAULT as the requirements this processor fulfills. """
        self._provides = type(self).PROVIDES_DEFAULT

    def _set_up_requires(self):
        """ Use the class-level REQUIRES_DEFAULT as this processor's requirements. """
        self._requires = type(self).REQUIRES_DEFAULT

    def _set_up_variant_requires(self):
        """
        Take the requirements from the variant, if there is one that defines them.

        Returns True iff self._requires was set from the variant.
        """
        if not hasattr(self, '_variant'):
            return False
        variant = self._variant
        if hasattr(variant, '_set_up_requires'):
            # the variant computes its own requirements
            variant._set_up_requires()
            self._requires = variant._requires
            return True
        elif hasattr(type(variant), 'REQUIRES_DEFAULT'):
            self._requires = type(variant).REQUIRES_DEFAULT
            return True
        return False

    def _set_up_variants(self, config, device):
        """ Instantiate the first registered variant whose with_<variant> flag is truthy in config, if any. """
        processor_name = list(type(self).PROVIDES_DEFAULT)[0]
        registered = PROCESSOR_VARIANTS[processor_name]
        enabled = [variant for variant in registered if config.get(f'with_{variant}', False)]
        if enabled:
            # a variant replaces the neural model, so there is no trainer
            self._trainer = None
            self._variant = registered[enabled[0]](config)

    @property
    def config(self):
        """ Configurations for the processor """
        return self._config

    @property
    def pipeline(self):
        """ The pipeline that this processor belongs to """
        return self._pipeline

    @property
    def provides(self):
        """ Requirements this processor fulfills """
        return self._provides

    @property
    def requires(self):
        """ Requirements this processor needs from earlier processors """
        return self._requires

    def _check_requirements(self):
        """
        Raise ProcessorRequirementsException if the pipeline's already loaded
        processors leave any of this processor's requirements unmet.

        Skipped entirely when the config sets check_requirements to a falsy value.
        """
        if not self.config.get("check_requirements", True):
            return
        provided = [processor.provides for processor in self.pipeline.loaded_processors]
        # the trailing empty set guarantees set.union always receives an argument
        provided.append(set([]))
        provided_reqs = set.union(*provided)
        if self.requires - provided_reqs:
            load_names = [item[0] for item in self.pipeline.load_list]
            raise ProcessorRequirementsException(load_names, self, provided_reqs)
class ProcessorVariant(ABC):
    """ Base class for all processor variants """

    # When True the variant takes over completely: its process() replaces the
    # processor's own, so it must handle every config the processor would.
    OVERRIDE = False

    @abstractmethod
    def process(self, doc):
        """
        Process a document that is potentially preprocessed by the processor.
        This is the main method of a processor variant.

        If `OVERRIDE` is set to True, all preprocessing by the processor would be bypassed, and the processor variant
        would serve as a drop-in replacement of the entire processor, and has to be able to interpret all the configs
        that are typically handled by the processor it replaces.
        """
        pass

    def bulk_process(self, docs):
        """ Process a list of Documents one by one; override with a batched implementation where possible. """
        return list(map(self.process, docs))
class UDProcessor(Processor):
    """ Base class for the neural UD Processors (tokenize,mwt,pos,lemma,depparse,sentiment,constituency) """

    def __init__(self, config, pipeline, device):
        super().__init__(config, pipeline, device)
        # model resources; how they get loaded is up to each subclass
        self._pretrain = None
        self._trainer = None
        self._vocab = None
        # a variant supplies its own implementation, so no model is loaded for it
        if not hasattr(self, '_variant'):
            self._set_up_model(config, pipeline, device)
        # build the final config for the processor
        self._set_up_final_config(config)

    @abstractmethod
    def _set_up_model(self, config, pipeline, device):
        pass

    def _set_up_final_config(self, config):
        """
        Finalize the configurations for this processor: the loaded model's
        saved args (minus non-processor options), overridden by `config`.
        """
        merged = {}
        if self._trainer is not None:
            saved_args = self._trainer.args
            self._vocab = self._trainer.vocab
            # keep only options that matter for running the processor
            merged = {key: value for key, value in saved_args.items()
                      if not UDProcessor.filter_out_option(key)}
        merged.update(config)
        self._config = merged

    def mark_inactive(self):
        """ Drop memory intensive resources if keeping this processor around for reasons other than running it. """
        self._trainer = None
        self._vocab = None

    @property
    def pretrain(self):
        """ Pretrained word vectors, if the subclass loaded any """
        return self._pretrain

    @property
    def trainer(self):
        """ The loaded model trainer, or None """
        return self._trainer

    @property
    def vocab(self):
        """ The loaded model's vocab, or None """
        return self._vocab

    @staticmethod
    def filter_out_option(option):
        """ Return True for non-processor configurations: *_file / *_dir paths and a few training/environment options. """
        if option.endswith(('_file', '_dir')):
            return True
        return option in ('device', 'cpu', 'cuda', 'dev_conll_gold', 'epochs', 'lang', 'mode', 'save_name', 'shorthand')

    def bulk_process(self, docs):
        """
        Most processors operate on the sentence level, where each sentence is processed independently and processors can benefit
        a lot from the ability to combine sentences from multiple documents for faster batched processing. This is a transparent
        implementation that allows these processors to batch process a list of Documents as if they were from a single Document.
        """
        if hasattr(self, '_variant'):
            return self._variant.bulk_process(docs)

        pooled_sentences = [sentence for document in docs for sentence in document.sentences]
        combined_doc = Document([])
        combined_doc.sentences = pooled_sentences
        combined_doc.num_tokens = sum(document.num_tokens for document in docs)
        combined_doc.num_words = sum(document.num_words for document in docs)
        # annotations are attached to the shared sentence objects, so the
        # original documents receive them without any copying back
        self.process(combined_doc)
        return docs
class ProcessorRegisterException(Exception):
    """ Exception indicating processor or processor registration failure """

    def __init__(self, processor_class, expected_parent):
        self._processor_class = processor_class
        self._expected_parent = expected_parent
        self.build_message()
        # hand the message to Exception as well, so e.args / repr(e) carry it
        super().__init__(self.message)

    def build_message(self):
        """ Compose the error message stored in self.message. """
        self.message = f"Failed to register '{self._processor_class}'. It must be a subclass of '{self._expected_parent}'."

    def __str__(self):
        return self.message
def register_processor(name):
    """
    Class decorator that registers a Processor subclass under `name`.

    The class is recorded in NAME_TO_PROCESSOR_CLASS and `name` is appended
    to PIPELINE_NAMES.  Raises ProcessorRegisterException if the decorated
    class is not a Processor subclass.
    """
    def wrapper(processor_class):
        if issubclass(processor_class, Processor):
            NAME_TO_PROCESSOR_CLASS[name] = processor_class
            PIPELINE_NAMES.append(name)
            return processor_class
        raise ProcessorRegisterException(processor_class, Processor)
    return wrapper
def register_processor_variant(name, variant):
    """
    Class decorator that registers a ProcessorVariant subclass as variant
    `variant` of the processor `name` in PROCESSOR_VARIANTS.

    Raises ProcessorRegisterException if the decorated class is not a
    ProcessorVariant subclass.
    """
    def wrapper(variant_class):
        if issubclass(variant_class, ProcessorVariant):
            PROCESSOR_VARIANTS[name][variant] = variant_class
            return variant_class
        raise ProcessorRegisterException(variant_class, ProcessorVariant)
    return wrapper
================================================
FILE: stanza/pipeline/registry.py
================================================
from collections import defaultdict
# Filled by register_processor:
#   processor name -> processor class
NAME_TO_PROCESSOR_CLASS = {}
#   processor names, in the order they were registered
PIPELINE_NAMES = []
# Filled by register_processor_variant:
#   processor name -> {variant name -> variant class}
PROCESSOR_VARIANTS = defaultdict(dict)
================================================
FILE: stanza/pipeline/sentiment_processor.py
================================================
"""Processor that attaches a sentiment score to a sentence
The model used is generally a model trained on the Stanford
Sentiment Treebank or some similar dataset. When run, this processor
attaches a score in the form of a string to each sentence in the
document.
TODO: a possible way to generalize this would be to make it a
ClassifierProcessor and have "sentiment" be an option.
"""
import dataclasses
import torch
from types import SimpleNamespace
from stanza.models.classifiers.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
@register_processor(SENTIMENT)
class SentimentProcessor(UDProcessor):
    """ Attaches the label predicted by a sentiment classifier to every sentence of a Document. """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([SENTIMENT])
    # set of processor requirements for this processor
    # TODO: a constituency based model needs CONSTITUENCY as well
    # issue: by the time we load the model in Processor.__init__,
    # the requirements are already prepared
    REQUIRES_DEFAULT = set([TOKENIZE])
    # default batch size, measured in words per batch
    DEFAULT_BATCH_SIZE = 5000

    def _set_up_model(self, config, pipeline, device):
        """
        Load the classifier named by config['model_path'], together with
        the optional pretrain / charlm files it may use.

        Raises FileNotFoundError if no model path is configured, which can
        happen when the language has no sentiment model.
        """
        # get pretrained word vectors
        pretrain_path = config.get('pretrain_path', None)
        forward_charlm_path = config.get('forward_charlm_path', None)
        backward_charlm_path = config.get('backward_charlm_path', None)
        # elmo does not have a convenient way to download intermediate
        # models the way stanza downloads charlms & pretrains or
        # transformers downloads bert etc
        # however, elmo in general is not as good as using a
        # transformer, so it is unlikely we will ever fix this
        args = SimpleNamespace(device=device,
                               charlm_forward_file=forward_charlm_path,
                               charlm_backward_file=backward_charlm_path,
                               wordvec_pretrain_file=pretrain_path,
                               elmo_model=None,
                               use_elmo=False,
                               save_dir=None)
        # .get() so that a missing key also reaches the descriptive error
        # below instead of surfacing as a bare KeyError
        filename = config.get('model_path')
        if filename is None:
            raise FileNotFoundError("No model specified for the sentiment processor. Perhaps it is not supported for the language. {}".format(config))
        # set up model
        trainer = Trainer.load(filename=filename,
                               args=args,
                               foundation_cache=pipeline.foundation_cache)
        self._trainer = trainer
        self._model = trainer.model
        self._model_type = self._model.config.model_type
        # batch size counted as words
        self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE)

    def _set_up_final_config(self, config):
        """
        Merge the model's own config (minus non-processor options) with the
        caller's config, which takes precedence.
        """
        loaded_args = dataclasses.asdict(self._model.config)
        loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
        loaded_args.update(config)
        self._config = loaded_args

    def process(self, document):
        """ Label every sentence of `document` under SENTIMENT and return the same Document. """
        sentences = self._model.extract_sentences(document)
        with torch.no_grad():
            labels = self._model.label_sentences(sentences, batch_size=self._batch_size)
        # TODO: allow a classifier processor for any attribute, not just sentiment
        document.set(SENTIMENT, labels, to_sentence=True)
        return document
================================================
FILE: stanza/pipeline/tokenize_processor.py
================================================
"""
Processor for performing tokenization
"""
import copy
import io
import logging
import torch
from stanza.models.tokenization.data import TokenizationDataset
from stanza.models.tokenization.trainer import Trainer
from stanza.models.tokenization.utils import output_predictions
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
from stanza.pipeline.registry import PROCESSOR_VARIANTS
from stanza.models.common import doc
# these imports trigger the "register_processor_variant" decorators
from stanza.pipeline.external.jieba import JiebaTokenizer
from stanza.pipeline.external.spacy import SpacyTokenizer
from stanza.pipeline.external.sudachipy import SudachiPyTokenizer
from stanza.pipeline.external.pythainlp import PyThaiNLPTokenizer
logger = logging.getLogger('stanza')

# Text substituted for tokens longer than the max sequence length.  Such
# tokens are replaced to avoid downstream GPU memory issues.  A visible
# placeholder is used rather than "", which would leave zero-length tokens
# in the output.
TOKEN_TOO_LONG_REPLACEMENT = "<UNK>"
# class for running the tokenizer
@register_processor(name=TOKENIZE)
class TokenizeProcessor(UDProcessor):
    """ Splits raw text into sentences and tokens, or wraps pretokenized input into a Document. """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([TOKENIZE])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([])
    # default max sequence length
    MAX_SEQ_LENGTH_DEFAULT = 1000

    def _set_up_model(self, config, pipeline, device):
        """
        Load the neural tokenizer (not needed for pretokenized input) and
        validate the optional postprocessor.

        Raises ValueError if a truthy, non-callable postprocessor is configured.
        """
        # set up trainer
        if config.get('pretokenized'):
            self._trainer = None
        else:
            args = {'charlm_forward_file': config.get('forward_charlm_path', None)}
            self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)

        # get and typecheck the postprocessor
        postprocessor = config.get('postprocessor')
        if postprocessor and callable(postprocessor):
            self._postprocessor = postprocessor
        elif not postprocessor:
            self._postprocessor = None
        else:
            raise ValueError("Tokenizer received 'postprocessor' option of unrecognized type; postprocessor must be callable. Got %s" % postprocessor)

    def process_pre_tokenized_text(self, input_src):
        """
        Pretokenized text can be provided in 2 manners:

        1.) str, tokenized by whitespace, sentence split by newline
        2.) list of token lists, each token list represents a sentence

        Returns (raw_text, document).  raw_text joins all tokens with single
        spaces.  document is a list of sentences, each a list of token dicts
        whose MISC field records start_char / end_char offsets into raw_text
        (assuming every sentence has at least one token).

        Raises ValueError if input_src is neither a str nor a list.
        """
        document = []
        if isinstance(input_src, str):
            # blank lines are skipped, so they never create empty sentences
            sentences = [sent.strip().split() for sent in input_src.strip().split('\n') if len(sent.strip()) > 0]
        elif isinstance(input_src, list):
            sentences = input_src
        else:
            # without this branch `sentences` would be unbound below
            raise ValueError("Pretokenized input must be a str or a list of token lists. Got %s" % str(type(input_src)))

        idx = 0
        for sentence in sentences:
            sent = []
            for token_id, token in enumerate(sentence):
                sent.append({doc.ID: (token_id + 1, ), doc.TEXT: token, doc.MISC: f'start_char={idx}|end_char={idx + len(token)}'})
                # tokens are separated by exactly one space in raw_text
                idx += len(token) + 1
            document.append(sent)
        raw_text = ' '.join([' '.join(sentence) for sentence in sentences])
        return raw_text, document

    def process(self, document):
        """
        Tokenize `document` and return a new Document.

        `document` may be a str or a Document.  With the 'pretokenized'
        option it is handed to process_pre_tokenized_text (a Document is
        returned unchanged).  With 'no_ssplit' a list of strings is also
        accepted and joined with blank lines.

        Raises ValueError for any other input type when neither option is set.
        """
        if not (isinstance(document, str) or isinstance(document, doc.Document) or (self.config.get('pretokenized') or self.config.get('no_ssplit', False))):
            raise ValueError("If neither 'pretokenized' or 'no_ssplit' option is enabled, the input to the TokenizerProcessor must be a string or a Document object. Got %s" % str(type(document)))

        if isinstance(document, doc.Document):
            if self.config.get('pretokenized'):
                return document
            document = document.text

        if self.config.get('pretokenized'):
            raw_text, document = self.process_pre_tokenized_text(document)
            return doc.Document(document, raw_text)

        if hasattr(self, '_variant'):
            return self._variant.process(document)

        raw_text = '\n\n'.join(document) if isinstance(document, list) else document

        max_seq_len = self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT)

        # set up batches
        batches = TokenizationDataset(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary)
        # get dict data
        with torch.no_grad():
            _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None,
                                                   max_seq_len,
                                                   orig_text=raw_text,
                                                   no_ssplit=self.config.get('no_ssplit', False),
                                                   num_workers=self.config.get('num_workers', 0),
                                                   postprocessor=self._postprocessor)

        # replace excessively long tokens with TOKEN_TOO_LONG_REPLACEMENT to avoid downstream GPU memory issues in POS
        for sentence in document:
            for token in sentence:
                if len(token['text']) > max_seq_len:
                    token['text'] = TOKEN_TOO_LONG_REPLACEMENT

        return doc.Document(document, raw_text)

    def bulk_process(self, docs):
        """
        The tokenizer cannot use UDProcessor's sentence-level cross-document batching interface, and requires special handling.
        Essentially, this method concatenates the text of multiple documents with "\n\n", tokenizes it with the neural tokenizer,
        then splits the result into the original Documents and recovers the original character offsets.
        """
        if hasattr(self, '_variant'):
            return self._variant.bulk_process(docs)

        if self.config.get('pretokenized'):
            res = []
            for document in docs:
                if len(document.sentences) > 0:
                    # perhaps this is a document already tokenized,
                    # being sent back in for more analysis / reparsing / etc?
                    # in that case, no need to try to tokenize it
                    # based on whitespace tokenizing the document text
                    # (which, interestingly, may not even exist depending on
                    # how the document was created)
                    # by making a whole deepcopy, the original Document is unchanged
                    res.append(copy.deepcopy(document))
                else:
                    raw_text, token_dicts = self.process_pre_tokenized_text(document.text)
                    res.append(doc.Document(token_dicts, raw_text))
            return res

        combined_text = '\n\n'.join([thisdoc.text for thisdoc in docs])
        processed_combined = self.process(doc.Document([], text=combined_text))

        # postprocess sentences and tokens to reset back pointers and char offsets
        charoffset = 0
        sentst = senten = 0
        for thisdoc in docs:
            # advance past every sentence that ends within this document's text
            while senten < len(processed_combined.sentences) and processed_combined.sentences[senten].tokens[-1].end_char - charoffset <= len(thisdoc.text):
                senten += 1

            sentences = processed_combined.sentences[sentst:senten]
            thisdoc.sentences = sentences
            for sent in sentences:
                # fix doc back pointers for sentences
                sent._doc = thisdoc

                # fix char offsets for tokens and words
                for token in sent.tokens:
                    token._start_char -= charoffset
                    token._end_char -= charoffset
                    if token.words:  # not-yet-processed MWT can leave empty tokens
                        for word in token.words:
                            word._start_char -= charoffset
                            word._end_char -= charoffset

            # Here we need to fix up the SpacesAfter for the very last token
            # and the SpacesBefore for the first token of the next doc
            # After all, we had connected the text with \n\n
            # Need to be careful about this - in a case such as
            #    "  -text one- "
            #    "  -text two- "
            # We want the SpacesBefore for the second document to reflect
            # the extra space at the start of its text
            # and the SpacesAfter for the first document to reflect
            # the whitespace after its text
            if len(sentences) > 0:
                last_token = sentences[-1].tokens[-1]
                last_whitespace = thisdoc.text[last_token.end_char:]
                last_token.spaces_after = last_whitespace

                first_token = sentences[0].tokens[0]
                first_whitespace = thisdoc.text[:first_token.start_char]
                first_token.spaces_before = first_whitespace

            thisdoc.num_tokens = sum(len(sent.tokens) for sent in sentences)
            thisdoc.num_words = sum(len(sent.words) for sent in sentences)
            sentst = senten

            # the next document's text starts after this one plus the "\n\n" separator
            charoffset += len(thisdoc.text) + 2

        return docs
================================================
FILE: stanza/protobuf/CoreNLP_pb2.py
================================================
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: CoreNLP.proto
# Protobuf Python Version: 4.25.5
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
# Default protobuf symbol database. protoc emits this lookup unconditionally;
# nothing else in this generated module refers to _sym_db.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 \x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! 
\x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? 
\x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xf6\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? \x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f 
\x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! \x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x0f\n\x07mwtMisc\x18N \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r\x12\r\n\x05index\x18O \x01(\r\x12\x12\n\nemptyIndex\x18P \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r 
\x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 \x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x9b\x04\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x12/\n\x05token\x18\x04 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x14\n\x08rootNode\x18\x05 \x03(\rB\x02\x10\x01\x1aX\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x12\x12\n\nemptyIndex\x18\x04 \x01(\r\x1a\xd6\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12\x13\n\x0bsourceEmpty\x18\x08 \x01(\r\x12\x13\n\x0btargetEmpty\x18\t \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 
\x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDependency\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 
\x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n \x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 
\x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xfb\x06\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 
\x02(\t\x1a\x80\x01\n\tNamedEdge\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0e\n\x06source\x18\x02 \x02(\x05\x12\x0e\n\x06target\x18\x03 \x02(\x05\x12\x0c\n\x04reln\x18\x04 \x01(\t\x12\x0f\n\x07isExtra\x18\x05 \x01(\x08\x12\x12\n\nsourceCopy\x18\x06 \x01(\r\x12\x12\n\ntargetCopy\x18\x07 \x01(\r\x1a-\n\x0eVariableString\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\r\n\x05value\x18\x02 \x02(\t\x1a\xe6\x02\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x12\x42\n\x04\x65\x64ge\x18\x06 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedEdge\x12L\n\tvarstring\x18\x07 \x03(\x0b\x32\x39.edu.stanford.nlp.pipeline.SemgrexResponse.VariableString\x12\x15\n\rsentenceIndex\x18\x04 \x01(\x05\x12\x14\n\x0csemgrexIndex\x18\x05 \x01(\x05\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"\xf0\x01\n\x0fSsurgeonRequest\x12\x45\n\x08ssurgeon\x18\x01 \x03(\x0b\x32\x33.edu.stanford.nlp.pipeline.SsurgeonRequest.Ssurgeon\x12\x39\n\x05graph\x18\x02 \x03(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x1a[\n\x08Ssurgeon\x12\x0f\n\x07semgrex\x18\x01 \x01(\t\x12\x11\n\toperation\x18\x02 \x03(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12\r\n\x05notes\x18\x04 \x01(\t\x12\x10\n\x08language\x18\x05 \x01(\t\"\xbc\x01\n\x10SsurgeonResponse\x12J\n\x06result\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.SsurgeonResponse.SsurgeonResult\x1a\\\n\x0eSsurgeonResult\x12\x39\n\x05graph\x18\x01 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0f\n\x07\x63hanged\x18\x02 \x01(\x08\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 
\x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref\"\xb4\x01\n\x12\x46lattenedParseTree\x12\x41\n\x05nodes\x18\x01 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.FlattenedParseTree.Node\x1a[\n\x04Node\x12\x12\n\x08openNode\x18\x01 \x01(\x08H\x00\x12\x13\n\tcloseNode\x18\x02 \x01(\x08H\x00\x12\x0f\n\x05value\x18\x03 \x01(\tH\x00\x12\r\n\x05score\x18\x04 \x01(\x01\x42\n\n\x08\x63ontents\"\xf6\x01\n\x15\x45valuateParserRequest\x12N\n\x08treebank\x18\x01 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult\x1a\x8c\x01\n\x0bParseResult\x12;\n\x04gold\x18\x01 \x02(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x12@\n\tpredicted\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"E\n\x16\x45valuateParserResponse\x12\n\n\x02\x66\x31\x18\x01 \x02(\x01\x12\x0f\n\x07kbestF1\x18\x02 \x01(\x01\x12\x0e\n\x06treeF1\x18\x03 \x03(\x01\"\xc8\x01\n\x0fTsurgeonRequest\x12H\n\noperations\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TsurgeonRequest.Operation\x12<\n\x05trees\x18\x02 
\x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x1a-\n\tOperation\x12\x0e\n\x06tregex\x18\x01 \x02(\t\x12\x10\n\x08tsurgeon\x18\x02 \x03(\t\"P\n\x10TsurgeonResponse\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x85\x01\n\x11MorphologyRequest\x12\x46\n\x05words\x18\x01 \x03(\x0b\x32\x37.edu.stanford.nlp.pipeline.MorphologyRequest.TaggedWord\x1a(\n\nTaggedWord\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\"\x9a\x01\n\x12MorphologyResponse\x12I\n\x05words\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.MorphologyResponse.WordTagLemma\x1a\x39\n\x0cWordTagLemma\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\x12\r\n\x05lemma\x18\x03 \x02(\t\"Z\n\x1a\x44\x65pendencyConverterRequest\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x90\x02\n\x1b\x44\x65pendencyConverterResponse\x12`\n\x0b\x63onversions\x18\x01 \x03(\x0b\x32K.edu.stanford.nlp.pipeline.DependencyConverterResponse.DependencyConversion\x1a\x8e\x01\n\x14\x44\x65pendencyConversion\x12\x39\n\x05graph\x18\x01 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12;\n\x04tree\x18\x02 
\x01(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos')
# Turn the serialized file descriptor registered above into descriptor
# objects and generated message classes. The builder writes its results
# directly into this module's namespace through _globals; the underscore
# names used below ('_DOCUMENT', '_DEPENDENCYGRAPH', ...) are among the
# entries it adds.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'CoreNLP_pb2', _globals)
# The fix-ups below only run with the pure-Python descriptor implementation,
# i.e. when the C++-backed descriptors are not in use.
if _descriptor._USE_C_DESCRIPTORS == False:
  # Drop parsed options in favour of their raw serialized form:
  # - file options: java_package "edu.stanford.nlp.pipeline" (field 1) and
  #   java_outer_classname "CoreNLPProtos" (field 8);
  # - DependencyGraph.root / rootNode: b'\020\001' is FieldOptions field 2
  #   (packed) set to true.
  _globals['DESCRIPTOR']._options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\031edu.stanford.nlp.pipelineB\rCoreNLPProtos'
  _globals['_DEPENDENCYGRAPH'].fields_by_name['root']._options = None
  _globals['_DEPENDENCYGRAPH'].fields_by_name['root']._serialized_options = b'\020\001'
  _globals['_DEPENDENCYGRAPH'].fields_by_name['rootNode']._options = None
  _globals['_DEPENDENCYGRAPH'].fields_by_name['rootNode']._serialized_options = b'\020\001'
  # Byte ranges [start, end) of each enum/message descriptor inside the
  # serialized file passed to AddSerializedFile (e.g. _DOCUMENT begins at
  # offset 45, right after the file name, package and its own tag/length).
  # Nested types (e.g. _DEPENDENCYGRAPH_NODE) fall inside their parent's range.
  # Top-level enums:
  _globals['_LANGUAGE']._serialized_start=13585
  _globals['_LANGUAGE']._serialized_end=13748
  _globals['_SENTIMENT']._serialized_start=13750
  _globals['_SENTIMENT']._serialized_end=13854
  _globals['_NATURALLOGICRELATION']._serialized_start=13857
  _globals['_NATURALLOGICRELATION']._serialized_end=14004
  # Messages, in declaration order:
  _globals['_DOCUMENT']._serialized_start=45
  _globals['_DOCUMENT']._serialized_end=782
  _globals['_SENTENCE']._serialized_start=785
  _globals['_SENTENCE']._serialized_end=2820
  _globals['_TOKEN']._serialized_start=2823
  _globals['_TOKEN']._serialized_end=4477
  _globals['_QUOTE']._serialized_start=4480
  _globals['_QUOTE']._serialized_end=4964
  _globals['_PARSETREE']._serialized_start=4967
  _globals['_PARSETREE']._serialized_end=5166
  _globals['_DEPENDENCYGRAPH']._serialized_start=5169
  _globals['_DEPENDENCYGRAPH']._serialized_end=5708
  _globals['_DEPENDENCYGRAPH_NODE']._serialized_start=5403
  _globals['_DEPENDENCYGRAPH_NODE']._serialized_end=5491
  _globals['_DEPENDENCYGRAPH_EDGE']._serialized_start=5494
  _globals['_DEPENDENCYGRAPH_EDGE']._serialized_end=5708
  _globals['_COREFCHAIN']._serialized_start=5711
  _globals['_COREFCHAIN']._serialized_end=6037
  _globals['_COREFCHAIN_COREFMENTION']._serialized_start=5836
  _globals['_COREFCHAIN_COREFMENTION']._serialized_end=6037
  _globals['_MENTION']._serialized_start=6040
  _globals['_MENTION']._serialized_end=7175
  _globals['_INDEXEDWORD']._serialized_start=7177
  _globals['_INDEXEDWORD']._serialized_end=7265
  _globals['_SPEAKERINFO']._serialized_start=7267
  _globals['_SPEAKERINFO']._serialized_end=7319
  _globals['_SPAN']._serialized_start=7321
  _globals['_SPAN']._serialized_end=7355
  _globals['_TIMEX']._serialized_start=7357
  _globals['_TIMEX']._serialized_end=7476
  _globals['_ENTITY']._serialized_start=7479
  _globals['_ENTITY']._serialized_end=7698
  _globals['_RELATION']._serialized_start=7701
  _globals['_RELATION']._serialized_end=7884
  _globals['_OPERATOR']._serialized_start=7887
  _globals['_OPERATOR']._serialized_end=8065
  _globals['_POLARITY']._serialized_start=8068
  _globals['_POLARITY']._serialized_end=8621
  _globals['_NERMENTION']._serialized_start=8624
  _globals['_NERMENTION']._serialized_end=8973
  _globals['_SENTENCEFRAGMENT']._serialized_start=8975
  _globals['_SENTENCEFRAGMENT']._serialized_end=9064
  _globals['_TOKENLOCATION']._serialized_start=9066
  _globals['_TOKENLOCATION']._serialized_end=9124
  _globals['_RELATIONTRIPLE']._serialized_start=9127
  _globals['_RELATIONTRIPLE']._serialized_end=9537
  _globals['_MAPSTRINGSTRING']._serialized_start=9539
  _globals['_MAPSTRINGSTRING']._serialized_end=9584
  _globals['_MAPINTSTRING']._serialized_start=9586
  _globals['_MAPINTSTRING']._serialized_end=9628
  _globals['_SECTION']._serialized_start=9631
  _globals['_SECTION']._serialized_end=9883
  _globals['_SEMGREXREQUEST']._serialized_start=9886
  _globals['_SEMGREXREQUEST']._serialized_end=10114
  _globals['_SEMGREXREQUEST_DEPENDENCIES']._serialized_start=9992
  _globals['_SEMGREXREQUEST_DEPENDENCIES']._serialized_end=10114
  _globals['_SEMGREXRESPONSE']._serialized_start=10117
  _globals['_SEMGREXRESPONSE']._serialized_end=11008
  _globals['_SEMGREXRESPONSE_NAMEDNODE']._serialized_start=10208
  _globals['_SEMGREXRESPONSE_NAMEDNODE']._serialized_end=10253
  _globals['_SEMGREXRESPONSE_NAMEDRELATION']._serialized_start=10255
  _globals['_SEMGREXRESPONSE_NAMEDRELATION']._serialized_end=10298
  _globals['_SEMGREXRESPONSE_NAMEDEDGE']._serialized_start=10301
  _globals['_SEMGREXRESPONSE_NAMEDEDGE']._serialized_end=10429
  _globals['_SEMGREXRESPONSE_VARIABLESTRING']._serialized_start=10431
  _globals['_SEMGREXRESPONSE_VARIABLESTRING']._serialized_end=10476
  _globals['_SEMGREXRESPONSE_MATCH']._serialized_start=10479
  _globals['_SEMGREXRESPONSE_MATCH']._serialized_end=10837
  _globals['_SEMGREXRESPONSE_SEMGREXRESULT']._serialized_start=10839
  _globals['_SEMGREXRESPONSE_SEMGREXRESULT']._serialized_end=10919
  _globals['_SEMGREXRESPONSE_GRAPHRESULT']._serialized_start=10921
  _globals['_SEMGREXRESPONSE_GRAPHRESULT']._serialized_end=11008
  _globals['_SSURGEONREQUEST']._serialized_start=11011
  _globals['_SSURGEONREQUEST']._serialized_end=11251
  _globals['_SSURGEONREQUEST_SSURGEON']._serialized_start=11160
  _globals['_SSURGEONREQUEST_SSURGEON']._serialized_end=11251
  _globals['_SSURGEONRESPONSE']._serialized_start=11254
  _globals['_SSURGEONRESPONSE']._serialized_end=11442
  _globals['_SSURGEONRESPONSE_SSURGEONRESULT']._serialized_start=11350
  _globals['_SSURGEONRESPONSE_SSURGEONRESULT']._serialized_end=11442
  _globals['_TOKENSREGEXREQUEST']._serialized_start=11444
  _globals['_TOKENSREGEXREQUEST']._serialized_end=11531
  _globals['_TOKENSREGEXRESPONSE']._serialized_start=11534
  _globals['_TOKENSREGEXRESPONSE']._serialized_end=11957
  _globals['_TOKENSREGEXRESPONSE_MATCHLOCATION']._serialized_start=11633
  _globals['_TOKENSREGEXRESPONSE_MATCHLOCATION']._serialized_end=11690
  _globals['_TOKENSREGEXRESPONSE_MATCH']._serialized_start=11693
  _globals['_TOKENSREGEXRESPONSE_MATCH']._serialized_end=11872
  _globals['_TOKENSREGEXRESPONSE_PATTERNMATCH']._serialized_start=11874
  _globals['_TOKENSREGEXRESPONSE_PATTERNMATCH']._serialized_end=11957
  _globals['_DEPENDENCYENHANCERREQUEST']._serialized_start=11960
  _globals['_DEPENDENCYENHANCERREQUEST']._serialized_end=12134
  _globals['_FLATTENEDPARSETREE']._serialized_start=12137
  _globals['_FLATTENEDPARSETREE']._serialized_end=12317
  _globals['_FLATTENEDPARSETREE_NODE']._serialized_start=12226
  _globals['_FLATTENEDPARSETREE_NODE']._serialized_end=12317
  _globals['_EVALUATEPARSERREQUEST']._serialized_start=12320
  _globals['_EVALUATEPARSERREQUEST']._serialized_end=12566
  _globals['_EVALUATEPARSERREQUEST_PARSERESULT']._serialized_start=12426
  _globals['_EVALUATEPARSERREQUEST_PARSERESULT']._serialized_end=12566
  _globals['_EVALUATEPARSERRESPONSE']._serialized_start=12568
  _globals['_EVALUATEPARSERRESPONSE']._serialized_end=12637
  _globals['_TSURGEONREQUEST']._serialized_start=12640
  _globals['_TSURGEONREQUEST']._serialized_end=12840
  _globals['_TSURGEONREQUEST_OPERATION']._serialized_start=12795
  _globals['_TSURGEONREQUEST_OPERATION']._serialized_end=12840
  _globals['_TSURGEONRESPONSE']._serialized_start=12842
  _globals['_TSURGEONRESPONSE']._serialized_end=12922
  _globals['_MORPHOLOGYREQUEST']._serialized_start=12925
  _globals['_MORPHOLOGYREQUEST']._serialized_end=13058
  _globals['_MORPHOLOGYREQUEST_TAGGEDWORD']._serialized_start=13018
  _globals['_MORPHOLOGYREQUEST_TAGGEDWORD']._serialized_end=13058
  _globals['_MORPHOLOGYRESPONSE']._serialized_start=13061
  _globals['_MORPHOLOGYRESPONSE']._serialized_end=13215
  _globals['_MORPHOLOGYRESPONSE_WORDTAGLEMMA']._serialized_start=13158
  _globals['_MORPHOLOGYRESPONSE_WORDTAGLEMMA']._serialized_end=13215
  _globals['_DEPENDENCYCONVERTERREQUEST']._serialized_start=13217
  _globals['_DEPENDENCYCONVERTERREQUEST']._serialized_end=13307
  _globals['_DEPENDENCYCONVERTERRESPONSE']._serialized_start=13310
  _globals['_DEPENDENCYCONVERTERRESPONSE']._serialized_end=13582
  _globals['_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION']._serialized_start=13440
  _globals['_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION']._serialized_end=13582
# @@protoc_insertion_point(module_scope)
================================================
FILE: stanza/protobuf/__init__.py
================================================
from __future__ import absolute_import
from io import BytesIO
import warnings
from google.protobuf.internal.encoder import _EncodeVarint
from google.protobuf.internal.decoder import _DecodeVarint
from google.protobuf.message import DecodeError
from .CoreNLP_pb2 import *
def parseFromDelimitedString(obj, buf, offset=0):
    """
    Parse one length-delimited protobuf message from @buf into @obj.

    Stanford CoreNLP uses the Java "writeDelimitedTo" function, which
    writes the size of the message as a varint before writing the
    serialized message itself. This function reads that size prefix
    starting at @offset and then parses the message bytes that follow it.

    If the message bytes cannot be decoded, a RuntimeWarning is issued
    instead of raising, and @obj may be incomplete or empty.

    @param obj: protobuf message instance to fill in
    @param buf: bytes holding one or more delimited messages
    @param offset: position in @buf where the size prefix starts
    @returns how many bytes of @buf were consumed starting at @offset
             (the size prefix plus the message body).
    """
    # _DecodeVarint returns the decoded size together with the *absolute*
    # position in buf just past the varint, so it must not be added to
    # offset again when slicing out the message body.
    size, pos = _DecodeVarint(buf, offset)
    try:
        obj.ParseFromString(buf[pos:pos + size])
    except DecodeError:
        warnings.warn("Failed to decode a serialized output from CoreNLP server. An incomplete or empty object will be returned.",
                      RuntimeWarning)
    return pos + size - offset
def writeToDelimitedString(obj, stream=None):
    """
    Serialize @obj in the length-delimited format written by the Java
    "writeDelimitedTo" function: the message size as a varint, followed by
    the serialized message bytes.

    @param obj: the protobuf message to write
    @param stream: a writable binary stream; a new BytesIO is created if None
    @returns the stream that was written to
    """
    if stream is None:
        stream = BytesIO()
    # serialize once and reuse the bytes for both the length prefix and the
    # payload; if serialization fails, nothing is written to the stream
    data = obj.SerializeToString()
    _EncodeVarint(stream.write, len(data), True)
    stream.write(data)
    return stream
def to_text(sentence):
    """
    Helper routine that converts a Sentence protobuf to a string from
    its tokens.

    Every token contributes its `word`.  Every token except the first is
    also preceded by its `before` field; the first token's `before` is
    deliberately skipped.
    """
    # collect pieces and join once instead of repeated string concatenation
    pieces = []
    for i, tok in enumerate(sentence.token):
        if i != 0:
            pieces.append(tok.before)
        pieces.append(tok.word)
    return "".join(pieces)
================================================
FILE: stanza/resources/__init__.py
================================================
================================================
FILE: stanza/resources/common.py
================================================
"""
Common utilities for Stanza resources.
"""
from collections import defaultdict, namedtuple
import errno
import hashlib
import json
import logging
import os
from pathlib import Path
import requests
import shutil
import tempfile
import zipfile
from platformdirs import user_cache_dir
from tqdm.auto import tqdm
from stanza.utils.helper_func import make_table
from stanza.pipeline._constants import TOKENIZE, MWT, POS, LEMMA, DEPPARSE, NER, SENTIMENT
from stanza.pipeline.registry import PIPELINE_NAMES, PROCESSOR_VARIANTS
from stanza.resources.default_packages import PACKAGES
from stanza._version import __resources_version__
# module logger; this is the 'stanza' logger
logger = logging.getLogger('stanza')
# set home dir for default
# per-user cache directory from platformdirs, versioned by __resources_version__
USER_CACHE_DIR = user_cache_dir('stanza', 'StanfordNLP', __resources_version__)
# resources_url values 'stanford' / 'stanfordnlp' are mapped to this URL
# by download_resources_json
STANFORDNLP_RESOURCES_URL = 'https://nlp.stanford.edu/software/stanza/stanza-resources/'
# a branch name is appended to this to form a GitHub resources URL
STANZA_RESOURCES_GITHUB = 'https://raw.githubusercontent.com/stanfordnlp/stanza-resources/'
# each of the following defaults can be overridden with an environment variable
DEFAULT_RESOURCES_URL = os.getenv('STANZA_RESOURCES_URL', STANZA_RESOURCES_GITHUB + 'main')
DEFAULT_RESOURCES_VERSION = os.getenv(
    'STANZA_RESOURCES_VERSION',
    __resources_version__
)
# 'default' means "use the 'url' entry of resources.json" (see expand_model_url)
DEFAULT_MODEL_URL = os.getenv('STANZA_MODEL_URL', 'default')
DEFAULT_MODEL_DIR = os.getenv(
    'STANZA_RESOURCES_DIR',
    os.path.join(USER_CACHE_DIR, 'resources')
)
# model types which are not pipeline processors but may still be used as
# processor keys when maintain_processor_list is called with allow_pretrain=True
PRETRAIN_NAMES = ("pretrain", "forward_charlm", "backward_charlm")
class ResourcesFileNotFoundError(FileNotFoundError):
    """Raised when the resources.json file is missing from disk.

    The offending path is kept in the `resources_filepath` attribute.
    """

    def __init__(self, resources_filepath):
        message = f"Resources file not found at: {resources_filepath} Try to download the model again."
        super().__init__(message)
        self.resources_filepath = resources_filepath
class UnknownLanguageError(ValueError):
    """Raised for a language code that is not present in resources.json.

    The requested code is available as `unknown_language`.
    """

    def __init__(self, unknown):
        message = f"Unknown language requested: {unknown}"
        super().__init__(message)
        self.unknown_language = unknown
class UnknownProcessorError(ValueError):
    """Raised for a processor name that is not recognized.

    The requested name is available as `unknown_processor`.
    """

    def __init__(self, unknown):
        message = f"Unknown processor type requested: {unknown}"
        super().__init__(message)
        self.unknown_processor = unknown
# One model to fetch/load: the processor type, its package name, and the
# dependencies (None when built by maintain_processor_list; a tuple of
# (model, package) pairs after add_dependencies)
ModelSpecification = namedtuple('ModelSpecification', ['processor', 'package', 'dependencies'])
def ensure_dir(path):
    """
    Make sure the directory `path` exists, creating any missing parents.

    Does nothing if the directory is already there.
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
def get_md5(path, chunk_size=1 << 20):
    """
    Get the MD5 hex digest of the file at `path`.

    The file is hashed `chunk_size` bytes at a time, so large model files
    are never loaded into memory all at once.

    Raises OSError (e.g. FileNotFoundError) if the file cannot be read.
    If the exception has no filename, it is filled in with `path`.
    """
    digest = hashlib.md5()
    try:
        with open(path, 'rb') as fin:
            for chunk in iter(lambda: fin.read(chunk_size), b''):
                digest.update(chunk)
    except OSError as e:
        if not e.filename:
            e.filename = path
        raise
    return digest.hexdigest()
def unzip(path, filename):
    """
    Fully unzip the archive `filename` located in the directory `path`.

    All members are extracted into `path` itself.
    """
    logger.debug('Unzip: %s/%s...', path, filename)
    with zipfile.ZipFile(os.path.join(path, filename)) as f:
        f.extractall(path)
def get_root_from_zipfile(filename):
    """
    Get the root directory from an archived zip file.

    Returns the directory part of the archive's first entry.  This assumes,
    without checking, that the first entry lives in the archive's root dir.

    Raises AssertionError if the archive has no entries.
    """
    # the context manager closes the archive; previously the handle leaked
    with zipfile.ZipFile(filename, "r") as zf:
        entries = zf.filelist
    assert len(entries) > 0, \
        f"Zip file at {filename} seems to be corrupted. Please check it."
    return os.path.dirname(entries[0].filename)
def file_exists(path, md5):
    """
    Return True only when `path` exists and its MD5 digest equals `md5`.

    A missing file short-circuits to False without hashing anything.
    """
    if not os.path.exists(path):
        return False
    return get_md5(path) == md5
def assert_file_exists(path, md5=None, alternate_md5=None):
    """
    Check that `path` exists and, if `md5` is given, that its digest matches.

    A digest equal to `alternate_md5` is also accepted; that case is only
    logged at debug level.

    Raises FileNotFoundError if `path` is missing.  Raises ValueError if
    the digest matches neither `md5` nor `alternate_md5`.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(errno.ENOENT, "Cannot find expected file", path)
    if not md5:
        return
    actual_md5 = get_md5(path)
    if actual_md5 == md5:
        return
    if actual_md5 != alternate_md5:
        raise ValueError("md5 for %s is %s, expected %s" % (path, actual_md5, md5))
    logger.debug("Found a possibly older version of file %s, md5 %s instead of %s", path, alternate_md5, md5)
def download_file(url, path, proxies, raise_for_status=False):
    """
    Download a URL into a file as specified by `path`.

    The response body is streamed in 131072-byte chunks.  A tqdm progress
    bar is shown only when the 'stanza' logger level is 0, 10 (DEBUG) or
    20 (INFO).

    proxies: passed through to requests.get
    raise_for_status: if True, an HTTP error status raises (via
        Response.raise_for_status) before `path` is opened.  Otherwise the
        response body is written to `path` whatever the status.

    Returns the HTTP status code of the response.
    """
    verbose = logger.level in [0, 10, 20]
    # stream=True keeps the connection open until the body has been read;
    # using the response as a context manager releases it even when the
    # download is interrupted by an exception
    with requests.get(url, stream=True, proxies=proxies) as r:
        if raise_for_status:
            r.raise_for_status()
        file_size = r.headers.get('content-length', None)
        if file_size:
            file_size = int(file_size)
        default_chunk_size = 131072
        desc = 'Downloading ' + url
        with open(path, 'wb') as f, \
                tqdm(total=file_size, unit='B', unit_scale=True,
                     disable=not verbose, desc=desc) as pbar:
            for chunk in r.iter_content(chunk_size=default_chunk_size):
                # ignore empty chunks
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))
        return r.status_code
def request_file(url, path, proxies=None, md5=None, raise_for_status=False, log_info=True, alternate_md5=None):
    """
    Download `url` to `path` unless a file with the expected md5 is already there.

    Wraps download_file().  It also creates the parent directory of `path`
    if needed and verifies the result with assert_file_exists().

    alternate_md5 allows for an alternate md5 that is acceptable (such as if
    an older version of a file is okay).  Status messages go to logger.info
    when log_info is set, otherwise to logger.debug.
    Note that with md5=None an existing file never matches in file_exists,
    so it is downloaded again.
    """
    log_msg = logger.info if log_info else logger.debug
    target_dir = Path(path).parent
    ensure_dir(target_dir)

    if file_exists(path, md5):
        log_msg(f'File exists: {path}')
        return

    # Download into a scratch directory beside `path`, then move the
    # finished file into place with os.replace().  Concurrent processes
    # (a common situation with resources.json) then never clobber each
    # other with partially downloaded files.
    with tempfile.TemporaryDirectory(dir=target_dir) as scratch_dir:
        scratch_path = os.path.join(scratch_dir, os.path.basename(path))
        download_file(url, scratch_path, proxies, raise_for_status)
        os.replace(scratch_path, path)

    assert_file_exists(path, md5, alternate_md5)
    log_msg(f'Downloaded file to {path}')
def sort_processors(processor_list, order=None):
    """
    Sort [name, models] items into pipeline order.

    Items whose name appears in `order` (PIPELINE_NAMES by default) come
    first, in that order.  Items with any other name (e.g. pretrain or
    charlm entries requested for download) are then appended in their
    original order; only the first item of each such name is kept.

    Previously, the loop for the extra names appended a stale loop variable
    instead of the unmatched item, so it produced duplicates and dropped
    the extra entries.
    """
    names = PIPELINE_NAMES if order is None else order
    sorted_list = [item for name in names for item in processor_list if item[0] == name]
    # going just by the official processor names drops any names which
    # are not an official processor but might still be useful, so
    # append them to the end
    seen = {item[0] for item in sorted_list}
    for item in processor_list:
        if item[0] not in seen:
            sorted_list.append(item)
            seen.add(item[0])
    return sorted_list
def add_mwt(processors, resources, lang):
    """Add mwt if tokenize is passed without mwt.

    If tokenize is in the list, but mwt is not, and there is a corresponding
    tokenize and mwt pair in the resources file, mwt is added so no missing
    mwt errors are raised.

    `processors` is modified in place: processors[MWT] is set to the same
    package name as processors[TOKENIZE].  Nothing is done when the
    tokenize value is not a single str (e.g. a list of packages).  The
    dict lookups below would otherwise fail with an unhashable-type
    TypeError.
    """
    value = processors[TOKENIZE]
    if not isinstance(value, str):
        return
    lang_resources = resources[lang]
    # either the named package bundles an mwt model...
    package_has_mwt = MWT in lang_resources.get(PACKAGES, {}).get(value, {})
    # ...or tokenize and mwt each have a model with this name
    models_have_mwt = value in lang_resources.get(TOKENIZE, {}) and value in lang_resources.get(MWT, {})
    if package_has_mwt or models_have_mwt:
        logger.warning("Language %s package %s expects mwt, which has been added", lang, value)
        processors[MWT] = value
def maintain_processor_list(resources, lang, package, processors, allow_pretrain=False, maybe_add_mwt=True):
    """
    Given a parsed resources file, language, and possible package
    and/or processors, expands the package to the list of processors

    resources: the parsed resources.json dict
    lang: language code; used as a key into `resources`
    package: a package name (str) or None.  Processors already resolved
      from `processors` take precedence over those implied by the package.
    processors: None, or a dict mapping processor name -> package name (str)
      or a list/tuple of package names.  NOTE: when maybe_add_mwt is set,
      add_mwt may add an MWT entry to this dict in place.
    allow_pretrain: also accept the names in PRETRAIN_NAMES as processor keys
    maybe_add_mwt: if tokenize is requested without mwt, try to add mwt

    Returns a list of processors, ordered by sort_processors
    Each item in the list of processors is a pair:
    name, then a list of ModelSpecification
    so, for example:
    [['pos', [ModelSpecification(processor='pos', package='gsd', dependencies=None)]],
     ['depparse', [ModelSpecification(processor='depparse', package='gsd', dependencies=None)]]]

    Raises ValueError if a processor name or value has the wrong type, and
    UnknownProcessorError for an unrecognized processor name.
    """
    # processor name -> list of package names, in the order they were resolved
    processor_list = defaultdict(list)
    # resolve processor models
    if processors:
        logger.debug(f'Processing parameter "processors"...')
        if maybe_add_mwt and TOKENIZE in processors and MWT not in processors:
            add_mwt(processors, resources, lang)
        for key, plist in processors.items():
            if not isinstance(key, str):
                raise ValueError("Processor names must be strings")
            if not isinstance(plist, (tuple, list, str)):
                raise ValueError("Processor values must be strings")
            if isinstance(plist, str):
                plist = [plist]
            if key not in PIPELINE_NAMES:
                # pretrain / charlm names are only accepted when explicitly allowed
                if not allow_pretrain or key not in PRETRAIN_NAMES:
                    raise UnknownProcessorError(key)
            for value in plist:
                # check if keys and values can be found
                if key in resources[lang] and value in resources[lang][key]:
                    logger.debug(f'Found {key}: {value}.')
                    processor_list[key].append(value)
                # allow values to be default in some cases
                # (value names a package; use that package's model for this processor)
                elif value in resources[lang][PACKAGES] and key in resources[lang][PACKAGES][value]:
                    logger.debug(
                        f'Found {key}: {resources[lang][PACKAGES][value][key]}.'
                    )
                    processor_list[key].append(resources[lang][PACKAGES][value][key])
                # optional defaults will be activated if specifically turned on
                elif value in resources[lang][PACKAGES] and 'optional' in resources[lang][PACKAGES][value] and key in resources[lang][PACKAGES][value]['optional']:
                    logger.debug(
                        f"Found {key}: {resources[lang][PACKAGES][value]['optional'][key]}."
                    )
                    processor_list[key].append(resources[lang][PACKAGES][value]['optional'][key])
                # allow processors to be set to variants that we didn't implement
                elif value in PROCESSOR_VARIANTS[key]:
                    logger.debug(
                        f'Found {key}: {value}. '
                        f'Using external {value} variant for the {key} processor.'
                    )
                    processor_list[key].append(value)
                # allow lemma to be set to "identity"
                elif key == LEMMA and value == 'identity':
                    logger.debug(
                        f'Found {key}: {value}. Using identity lemmatizer.'
                    )
                    processor_list[key].append(value)
                # not a processor in the officially supported processor list
                elif key not in resources[lang]:
                    logger.debug(
                        f'{key}: {value} is not officially supported by Stanza, '
                        f'loading it anyway.'
                    )
                    processor_list[key].append(value)
                # cannot find the package for a processor and warn user
                else:
                    logger.warning(
                        f'Can not find {key}: {value} from official model list. '
                        f'Ignoring it.'
                    )
    # resolve package
    if package:
        logger.debug(f'Processing parameter "package"...')
        if PACKAGES in resources[lang] and package in resources[lang][PACKAGES]:
            # a named package: fill in every non-optional processor it lists,
            # unless that processor was already set explicitly above
            for key, value in resources[lang][PACKAGES][package].items():
                if key != 'optional' and key not in processor_list:
                    logger.debug(f'Found {key}: {value}.')
                    processor_list[key].append(value)
        else:
            # not a named package: treat it as a model name and use it for
            # every processor that has a model of that name
            flag = False
            for key in PIPELINE_NAMES:
                if key not in resources[lang]: continue
                if package in resources[lang][key]:
                    flag = True
                    if key not in processor_list:
                        logger.debug(f'Found {key}: {package}.')
                        processor_list[key].append(package)
                    else:
                        logger.debug(
                            f'{key}: {package} is overwritten by '
                            f'{key}: {processors[key]}.'
                        )
            if not flag: logger.warning((f'Can not find package: {package}.'))
    processor_list = [[key, [ModelSpecification(processor=key, package=value, dependencies=None) for value in plist]] for key, plist in processor_list.items()]
    processor_list = sort_processors(processor_list)
    return processor_list
def add_dependencies(resources, lang, processor_list):
    """
    Fill in the `dependencies` of every ModelSpecification in processor_list.

    processor_list has the shape returned by maintain_processor_list.  The
    list is updated in place and also returned: each item's second element
    is replaced by a tuple of ModelSpecifications.  Each dependency is a
    (model, package) pair read from resources[lang][processor][package]
    ['dependencies']; a missing entry gives an empty tuple.  External
    processor variants and the identity lemmatizer are left untouched.

    for example:
    [['pos', (ModelSpecification(processor='pos', package='gsd', dependencies=(('pretrain', 'gsd'),)),)],
     ['depparse', (ModelSpecification(processor='depparse', package='gsd', dependencies=(('pretrain', 'gsd'),)),)]]
    """
    lang_resources = resources[lang]
    for entry in processor_list:
        processor, specs = entry
        updated_specs = []
        for spec in specs:
            is_external_variant = spec.package in PROCESSOR_VARIANTS[processor]
            is_identity_lemma = processor == LEMMA and spec.package == 'identity'
            if not (is_external_variant or is_identity_lemma):
                model_info = lang_resources.get(processor, {}).get(spec.package, {})
                deps = [(dep['model'], dep['package']) for dep in model_info.get('dependencies', [])]
                spec = spec._replace(dependencies=tuple(deps))
                logger.debug("Found dependencies %s for processor %s model %s", deps, processor, spec.package)
            updated_specs.append(spec)
        entry[1] = tuple(updated_specs)
    return processor_list
def flatten_processor_list(processor_list):
    """
    The flattened processor list is just a list of types & packages
    For example:
    [['pos', 'gsd'], ['depparse', 'gsd'], ['pretrain', 'gsd']]

    Every model in processor_list contributes one [processor, package]
    pair.  Then each distinct dependency is appended once, in the order it
    was first encountered.  Previously set() was used, which made the
    order of the dependencies vary from run to run.
    """
    flattened_processor_list = []
    # dict used as an insertion-ordered set of (model, package) tuples
    unique_dependencies = {}
    for processor, model_specs in processor_list:
        for model_spec in model_specs:
            flattened_processor_list.append([processor, model_spec.package])
            for dependency in model_spec.dependencies or ():
                unique_dependencies[tuple(dependency)] = None
    dependencies_list = [list(dependency) for dependency in unique_dependencies]
    for processor, package in dependencies_list:
        logger.debug(f'Find dependency {processor}: {package}.')
    flattened_processor_list += dependencies_list
    return flattened_processor_list
def set_logging_level(logging_level, verbose):
    """
    Configure the 'stanza' logger and return its resulting numeric level.

    `verbose` takes priority: a value equal to False selects 'ERROR' and a
    value equal to True selects 'INFO'.  Otherwise `logging_level` (a
    case-insensitive level name) is used.  If that is None, the current
    level is kept, except that an unset level (0) is raised to INFO.

    Raises ValueError for an unrecognized level name.
    """
    # == rather than `is` on purpose: 0 and 1 are treated like False / True
    if verbose == False:
        logging_level = 'ERROR'
    elif verbose == True:
        logging_level = 'INFO'

    if logging_level is not None:
        requested = logging_level.upper()
        valid_levels = ('DEBUG', 'INFO', 'WARNING', 'WARN', 'ERROR', 'CRITICAL', 'FATAL')
        if requested not in valid_levels:
            raise ValueError(
                f"Unrecognized logging level for pipeline: "
                f"{requested}. Must be one of {', '.join(valid_levels)}."
            )
        logger.setLevel(requested)
    elif logger.level == 0:
        # the default of INFO is normally set in stanza.__init__, but the
        # user may have changed it via the logging API; guard against NOTSET
        logger.setLevel('INFO')
    return logger.level
def process_pipeline_parameters(lang, model_dir, package, processors):
    """
    Validate and normalize the user-facing pipeline / download parameters.

    lang: str or None; stripped and lower-cased
    model_dir: str or None; stripped
    package: None, a str, or (only when processors is a str/list/tuple) a
      dict of processor name -> package name
    processors: None; a comma separated str or a list/tuple of processor
      names; or a dict of processor name -> package name (or a list/tuple
      of package names)

    String names are stripped and lower-cased consistently, whichever form
    they arrive in.  When processors is a str/list/tuple, it is turned into
    a dict mapping each processor to its package, taken from `package` with
    'default' as the fallback.  The returned package is then None.

    Returns (lang, model_dir, package, processors).
    Raises TypeError for arguments of an unsupported type.
    """
    def _norm(value):
        # strip + lower-case strings; pass anything else through unchanged
        # so later validation (e.g. maintain_processor_list) can report it
        return value.strip().lower() if isinstance(value, str) else value

    # Check parameter types and convert values to lower case
    if isinstance(lang, str):
        lang = lang.strip().lower()
    elif lang is not None:
        raise TypeError(
            f"The parameter 'lang' should be str, "
            f"but got {type(lang).__name__} instead."
        )

    if isinstance(model_dir, str):
        model_dir = model_dir.strip()
    elif model_dir is not None:
        raise TypeError(
            f"The parameter 'model_dir' should be str, "
            f"but got {type(model_dir).__name__} instead."
        )

    if isinstance(processors, (str, list, tuple)):
        # Special case: processors is str, compatible with older version
        # also allow for setting alternate packages for these processors
        # via the package argument
        if package is None:
            # each processor will be 'default' for this language
            package = defaultdict(lambda: 'default')
        elif isinstance(package, str):
            # same, but now the named package will be the default instead
            default = _norm(package)
            package = defaultdict(lambda: default)
        elif isinstance(package, dict):
            # the dictionary of packages will be used to build the processors dict
            # any processor not specified in package will be 'default'
            package = defaultdict(lambda: 'default',
                                  {_norm(k): _norm(v) for k, v in package.items()})
        else:
            raise TypeError(
                f"The parameter 'package' should be None, str, or dict, "
                f"but got {type(package).__name__} instead."
            )
        if isinstance(processors, str):
            processors = processors.split(",")
        names = [_norm(name) for name in processors]
        processors = {name: package[name] for name in names}
        package = None
    elif isinstance(processors, dict):
        processors = {
            _norm(k): ([_norm(v_i) for v_i in v] if isinstance(v, (tuple, list)) else _norm(v))
            for k, v in processors.items()
        }
    elif processors is not None:
        raise TypeError(
            f"The parameter 'processors' should be dict, str, list, or tuple, "
            f"but got {type(processors).__name__} instead."
        )

    if isinstance(package, str):
        package = package.strip().lower()
    elif package is not None:
        raise TypeError(
            f"The parameter 'package' should be str, or a dict if 'processors' is a str, "
            f"but got {type(package).__name__} instead."
        )
    return lang, model_dir, package, processors
def download_resources_json(model_dir=DEFAULT_MODEL_DIR,
                            resources_url=DEFAULT_RESOURCES_URL,
                            resources_branch=None,
                            resources_version=DEFAULT_RESOURCES_VERSION,
                            resources_filepath=None,
                            proxies=None):
    """
    Downloads resources.json to obtain latest packages.

    resources_branch: only used when resources_url is left at its default;
      selects that branch of the GitHub resources repository
    resources_url: a base URL, or the short names 'stanford' / 'stanfordnlp'
    resources_filepath: destination file; defaults to model_dir/resources.json

    The file fetched is {resources_url}/resources_{resources_version}.json.
    HTTP errors are raised rather than written to disk.
    """
    if resources_branch is not None and resources_url == DEFAULT_RESOURCES_URL:
        resources_url = STANZA_RESOURCES_GITHUB + resources_branch
    # handle short name for resources urls; otherwise treat it as url
    if resources_url.lower() in ('stanford', 'stanfordnlp'):
        resources_url = STANFORDNLP_RESOURCES_URL
    json_url = f'{resources_url}/resources_{resources_version}.json'
    logger.debug('Downloading resource file from %s', json_url)

    if resources_filepath is None:
        destination = os.path.join(model_dir, 'resources.json')
    else:
        destination = resources_filepath
    request_file(json_url, destination, proxies, raise_for_status=True)
def load_resources_json(model_dir=DEFAULT_MODEL_DIR, resources_filepath=None):
    """
    Read and parse resources.json.

    The file is resources_filepath if given, else model_dir/resources.json.
    Raises ResourcesFileNotFoundError if it does not exist.
    """
    filepath = os.path.join(model_dir, 'resources.json') if resources_filepath is None else resources_filepath
    if not os.path.exists(filepath):
        raise ResourcesFileNotFoundError(filepath)
    with open(filepath, encoding="utf-8") as fin:
        return json.load(fin)
def get_language_resources(resources, lang):
    """
    Get the resources for a lang from an already loaded resources json, following 'alias' if needed

    Returns None if `lang` is not a key of `resources`.  Alias chains are
    followed until an entry without 'alias' is reached.
    Raises KeyError if an alias points to a missing language.
    Raises ValueError if the aliases form a cycle; previously this looped forever.
    """
    if lang not in resources:
        return None
    lang_resources = resources[lang]
    visited = {lang}
    while 'alias' in lang_resources:
        lang = lang_resources['alias']
        if lang in visited:
            raise ValueError(f"Circular alias in resources involving language {lang}")
        visited.add(lang)
        lang_resources = resources[lang]
    return lang_resources
def list_available_languages(model_dir=DEFAULT_MODEL_DIR,
                             resources_url=DEFAULT_RESOURCES_URL,
                             resources_branch=None,
                             resources_version=DEFAULT_RESOURCES_VERSION,
                             proxies=None):
    """
    Download resources.json into model_dir and return the sorted list of
    language codes that are not aliases.
    """
    download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)
    resources = load_resources_json(model_dir)
    # string-valued top level entries (such as "url") are not languages, and
    # alias entries (e.g. German -> de) are skipped to avoid duplicates
    return sorted(
        lang for lang, lang_resources in resources.items()
        if not isinstance(lang_resources, str) and 'alias' not in lang_resources
    )
def expand_model_url(resources, model_url):
    """
    Resolve the model URL template.

    'default' (case-insensitive) means the 'url' entry of `resources`;
    anything else is returned unchanged.
    """
    if model_url.lower() != 'default':
        return model_url
    return resources['url']
def download_models(download_list,
                    resources,
                    lang,
                    model_dir=DEFAULT_MODEL_DIR,
                    resources_version=DEFAULT_RESOURCES_VERSION,
                    model_url=DEFAULT_MODEL_URL,
                    proxies=None,
                    log_info=True):
    """
    Download each [processor, package] pair in download_list.

    Each model is saved to model_dir/lang/processor/package.pt.  The file is
    checked against the md5 (or alternate_md5) recorded in
    resources[lang][processor][package].

    Raises ValueError if a processor/package pair has no entry (or no md5)
    in resources.  Only that lookup is translated to ValueError.  A
    KeyError from elsewhere, such as an unknown placeholder in the URL
    template, is no longer misreported as a bad model name.
    """
    lang_name = resources.get(lang, {}).get('lang_name', lang)
    download_table = make_table(['Processor', 'Package'], download_list)
    log_msg = logger.info if log_info else logger.debug
    log_msg(
        f'Downloading these customized packages for language: '
        f'{lang} ({lang_name})...\n{download_table}'
    )

    url = expand_model_url(resources, model_url)

    # Download packages
    for key, value in download_list:
        try:
            model_info = resources[lang][key][value]
            md5 = model_info['md5']
        except KeyError as e:
            raise ValueError(
                f'Cannot find the following processor and model name combination: '
                f'{key}, {value}. Please check if you have provided the correct model name.'
            ) from e
        request_file(
            url.format(resources_version=resources_version, lang=lang, filename=f"{key}/{value}.pt"),
            os.path.join(model_dir, lang, key, f'{value}.pt'),
            proxies,
            md5=md5,
            log_info=log_info,
            alternate_md5=model_info.get('alternate_md5', None)
        )
# main download function
def download(
        lang='en',
        model_dir=DEFAULT_MODEL_DIR,
        package='default',
        processors=None,
        logging_level=None,
        verbose=None,
        resources_url=DEFAULT_RESOURCES_URL,
        resources_branch=None,
        resources_version=DEFAULT_RESOURCES_VERSION,
        model_url=DEFAULT_MODEL_URL,
        proxies=None,
        download_json=True
):
    """
    Download models for `lang` into `model_dir`.

    With package='default' and no processors, the language's default.zip
    bundle is downloaded and unzipped.  Otherwise the requested package
    and/or processors are resolved against resources.json (see
    maintain_processor_list), their dependencies are added, and each model
    file is downloaded individually.

    processors: None, a comma separated str, a list/tuple of names, or a
      dict of name -> package (see process_pipeline_parameters).  The
      default is now None instead of a shared mutable {}; both mean "no
      processors".
    download_json: if False, reuse an existing resources.json in model_dir.
      It is still downloaded if missing.

    Returns the download list: [processor, package] pairs, or
    [['zip', 'default.zip']] when the default bundle was used.
    Raises UnknownLanguageError if `lang` is not in resources.json.
    """
    # set global logging level
    set_logging_level(logging_level, verbose)
    # process different pipeline parameters
    lang, model_dir, package, processors = process_pipeline_parameters(
        lang, model_dir, package, processors
    )

    if download_json or not os.path.exists(os.path.join(model_dir, 'resources.json')):
        if not download_json:
            logger.warning("Asked to skip downloading resources.json, but the file does not exist. Downloading anyway")
        download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)

    resources = load_resources_json(model_dir)
    if lang not in resources:
        raise UnknownLanguageError(lang)
    if 'alias' in resources[lang]:
        logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"')
        lang = resources[lang]['alias']
    lang_name = resources.get(lang, {}).get('lang_name', lang)
    url = expand_model_url(resources, model_url)

    # Default: download zipfile and unzip
    if package == 'default' and (processors is None or len(processors) == 0):
        logger.info(
            f'Downloading default packages for language: {lang} ({lang_name}) ...'
        )
        # want the URL to become, for example:
        # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip
        # so we hopefully start from
        # https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}
        request_file(
            url.format(resources_version=resources_version, lang=lang, filename="default.zip"),
            os.path.join(model_dir, lang, 'default.zip'),
            proxies,
            md5=resources[lang]['default_md5'],
        )
        unzip(os.path.join(model_dir, lang), 'default.zip')
        download_list = [['zip', 'default.zip']]
    # Customize: maintain download list
    else:
        download_list = maintain_processor_list(resources, lang, package, processors, allow_pretrain=True)
        download_list = add_dependencies(resources, lang, download_list)
        download_list = flatten_processor_list(download_list)
        download_models(download_list=download_list,
                        resources=resources,
                        lang=lang,
                        model_dir=model_dir,
                        resources_version=resources_version,
                        model_url=model_url,
                        proxies=proxies,
                        log_info=True)
    logger.info(f'Finished downloading models and saved to {model_dir}')
    return download_list
================================================
FILE: stanza/resources/default_packages.py
================================================
"""
Constants for default packages, default pretrains, charlms, etc
Separated from prepare_resources.py so that other modules can use the
same lists / maps without importing the resources script and possibly
causing a circular import
"""
import copy
# all languages will have a map which represents the available packages
# (this is the key of that map inside each language's resources entry,
# e.g. resources[lang][PACKAGES] in stanza/resources/common.py)
PACKAGES = "packages"

# default treebank for languages
# language code -> default treebank name; build_default_pretrains (below)
# starts from these names when building default_pretrains
default_treebanks = {
    "ab": "abnc",
    "af": "afribooms",
    # currently not publicly released! sent to us from the group developing this resource
    "ang": "nerthus",
    "ar": "padt",
    "be": "hse",
    "bg": "btb",
    "bxr": "bdt",
    "ca": "ancora",
    "cop": "scriptorium",
    "cs": "pdt",
    "cu": "proiel",
    "cy": "ccg",
    "da": "ddt",
    "de": "combined",
    "el": "gdt",
    "en": "combined",
    "es": "combined",
    "et": "edt",
    "eu": "bdt",
    "fa": "perdt",
    "fi": "tdt",
    "fo": "farpahc",
    "fr": "combined",
    "fro": "profiterole",
    "ga": "idt",
    "gd": "arcosg",
    "gl": "ctg",
    "got": "proiel",
    "grc": "perseus",
    "gv": "cadhan",
    "hbo": "ptnk",
    "he": "combined",
    "hi": "hdtb",
    "hr": "set",
    "hsb": "ufal",
    "hu": "szeged",
    "hy": "armtdp",
    "hyw": "armtdp",
    "id": "gsd",
    "is": "icepahc",
    "it": "combined",
    "ja": "combined",
    "ka": "glc",
    "kk": "ktb",
    "kmr": "mg",
    "ko": "kaist",
    "kpv": "lattice",
    "ky": "ktmu",
    "la": "ittb",
    "lij": "glt",
    "lt": "alksnis",
    "lv": "lvtb",
    "lzh": "kyoto",
    "mr": "ufal",
    "mt": "mudt",
    "my": "ucsy",
    "myv": "jr",
    "nb": "bokmaal",
    "nds": "lsdc",
    "nl": "alpino",
    "nn": "nynorsk",
    "olo": "kkpp",
    "orv": "torot",
    "ota": "boun",
    "pcm": "nsc",
    "pl": "pdb",
    "pt": "bosque",
    "qaf": "arabizi",
    "qpm": "philotis",
    "qtd": "sagt",
    "ro": "rrt",
    "ru": "syntagrus",
    "sa": "vedic",
    "sd": "isra",
    "sk": "snk",
    "sl": "ssj",
    "sme": "giella",
    "sq": "combined",
    "sr": "set",
    "sv": "talbanken",
    "swl": "sslc",
    "ta": "ttb",
    "te": "mtg",
    "th": "tud",
    "tr": "imst",
    "ug": "udt",
    "uk": "iu",
    "ur": "udtb",
    "vi": "vtb",
    "wo": "wtb",
    "xcl": "caval",
    "zh-hans": "gsdsimp",
    "zh-hant": "gsd",
    "multilingual": "ud"
}
# languages that get no default pretrain: build_default_pretrains removes
# them from the defaults it derives from default_treebanks
no_pretrain_languages = {
    "cop",
    "olo",
    "orv",
    "pcm",
    "qaf",  # the QAF treebank is code switched and Romanized, so existing resources are not easy to reuse
    "qpm",  # we have discussed deriving this from a language neighboring Pomak, but that hasn't happened yet
    "qtd",
    "swl",
    "multilingual",  # special case so that all languages with a default treebank are represented somewhere
}
# in some cases, we give the pretrain a name other than the original
# name for the UD dataset
# we will eventually do this for all of the pretrains
# language code -> pretrain name; build_default_pretrains applies these on
# top of the treebank-derived defaults, so a language listed here but not in
# default_treebanks (e.g. "or") is added as a new entry
specific_default_pretrains = {
    "ab": "fasttextwiki",
    "af": "fasttextwiki",
    "ang": "nerthus",
    "ar": "conll17",
    "be": "fasttextwiki",
    "bg": "conll17",
    "bxr": "fasttextwiki",
    "ca": "conll17",
    "cs": "conll17",
    "cu": "conll17",
    "cy": "fasttext157",
    "da": "conll17",
    "de": "conll17",
    "el": "conll17",
    "en": "conll17",
    "es": "conll17",
    "et": "conll17",
    "eu": "conll17",
    "fa": "conll17",
    "fi": "conll17",
    "fo": "fasttextwiki",
    "fr": "conll17",
    "fro": "conll17",
    "ga": "conll17",
    "gd": "fasttextwiki",
    "gl": "conll17",
    "got": "fasttextwiki",
    "grc": "conll17",
    "gv": "fasttext157",
    "hbo": "utah",
    "he": "conll17",
    "hi": "conll17",
    "hr": "conll17",
    "hsb": "fasttextwiki",
    "hu": "conll17",
    "hy": "isprasglove",
    "hyw": "isprasglove",
    "id": "conll17",
    "is": "fasttext157",
    "it": "conll17",
    "ja": "conll17",
    "ka": "fasttext157",
    "kk": "fasttext157",
    "kmr": "fasttextwiki",
    "ko": "conll17",
    "kpv": "fasttextwiki",
    "ky": "fasttext157",
    "la": "conll17",
    "lij": "fasttextwiki",
    "lt": "fasttextwiki",
    "lv": "conll17",
    "lzh": "fasttextwiki",
    "mr": "fasttextwiki",
    "mt": "fasttextwiki",
    "my": "ucsy",
    "myv": "mokha",
    "nb": "conll17",
    "nds": "fasttext157",
    "nl": "conll17",
    "nn": "conll17",
    "or": "fasttext157",
    "ota": "conll17",
    "pl": "conll17",
    "pt": "conll17",
    "ro": "conll17",
    "ru": "conll17",
    "sa": "fasttext157",
    "sd": "isra",
    "sk": "conll17",
    "sl": "conll17",
    "sme": "fasttextwiki",
    "sq": "fasttext157",
    "sr": "fasttextwiki",
    "sv": "conll17",
    "ta": "fasttextwiki",
    "te": "fasttextwiki",
    "th": "fasttext157",
    "tr": "conll17",
    "ug": "conll17",
    "uk": "conll17",
    "ur": "conll17",
    "vi": "conll17",
    "wo": "fasttextwiki",
    "xcl": "caval",
    "zh-hans": "fasttext157",
    "zh-hant": "conll17",
}
def build_default_pretrains(default_treebanks):
    """
    Derive the default pretrain name for every language.

    Starts from `default_treebanks` (a pretrain is named after its treebank
    unless overridden).  It drops the languages in no_pretrain_languages,
    then applies specific_default_pretrains, which may overwrite entries
    or add new languages.  Returns a new dict; the argument is not modified.
    """
    candidates = dict(default_treebanks)
    pretrains = {
        lang: treebank
        for lang, treebank in candidates.items()
        if lang not in no_pretrain_languages
    }
    pretrains.update(specific_default_pretrains)
    return pretrains
# language code -> default pretrain name, computed once at import time
default_pretrains = build_default_pretrains(default_treebanks)

# per-language overrides: dataset name -> pretrain name
pos_pretrains = {
    "en": {
        "craft": "biomed",
        "genia": "biomed",
        "mimic": "mimic",
    },
}

# NOTE: this is the same dict object as pos_pretrains, not a copy (unlike
# the charlm tables below, which use copy.deepcopy), so mutating one
# mutates the other
depparse_pretrains = pos_pretrains

# per-language overrides for NER datasets: dataset name -> pretrain name
ner_pretrains = {
    "ar": {
        "aqmar": "fasttextwiki",
    },
    "de": {
        "conll03": "fasttextwiki",
        # the bert version of germeval uses the smaller vector file
        "germeval2014": "fasttextwiki",
    },
    "en": {
        "anatem": "biomed",
        "bc4chemd": "biomed",
        "bc5cdr": "biomed",
        "bionlp13cg": "biomed",
        "jnlpba": "biomed",
        "linnaeus": "biomed",
        "ncbi_disease": "biomed",
        "s800": "biomed",
        "ontonotes": "fasttextcrawl",
        # the stanza-train sample NER model should use the default NER pretrain
        # for English, that is the same as ontonotes
        "sample": "fasttextcrawl",
        "conll03": "glove",
        "i2b2": "mimic",
        "radiology": "mimic",
    },
    "es": {
        "ancora": "fasttextwiki",
        "conll02": "fasttextwiki",
    },
    "nl": {
        "conll02": "fasttextwiki",
        "wikiner": "fasttextwiki",
    },
    "ru": {
        "wikiner": "fasttextwiki",
    },
    "th": {
        "lst20": "fasttext157",
    },
}
# default charlms for languages
# language code -> default character language model name
default_charlms = {
    "af": "oscar",
    "ang": "nerthus1024",
    "ar": "ccwiki",
    "bg": "conll17",
    "da": "oscar",
    "de": "newswiki",
    "en": "1billion",
    "es": "newswiki",
    "fa": "conll17",
    "fi": "conll17",
    "fr": "newswiki",
    "he": "oscar",
    "hi": "oscar",
    "id": "oscar2023",
    "it": "conll17",
    "ja": "conll17",
    "kk": "oscar",
    "mr": "l3cube",
    "my": "oscar",
    "nb": "conll17",
    "nl": "ccwiki",
    "pl": "oscar",
    "pt": "oscar2023",
    "ru": "newswiki",
    "sd": "isra",
    "sv": "conll17",
    "te": "oscar2022",
    "th": "oscar",
    "tr": "conll17",
    "uk": "conll17",
    "vi": "conll17",
    "zh-hans": "gigaword"
}

# per-language, per-dataset charlm overrides; None means no charlm is used
pos_charlms = {
    "en": {
        # none of the English charlms help with craft or genia
        "craft": None,
        "genia": None,
        "mimic": "mimic",
    },
    "tr": { # no idea why, but this particular one goes down in dev score
        "boun": None,
    },
}

# independent copies, so each table can be modified without affecting pos_charlms
depparse_charlms = copy.deepcopy(pos_charlms)
lemma_charlms = copy.deepcopy(pos_charlms)
tokenizer_charlms = copy.deepcopy(pos_charlms)

# per-language, per-dataset charlm overrides for NER; None means no charlm
ner_charlms = {
    "en": {
        "conll03": "1billion",
        "ontonotes": "1billion",
        "anatem": "pubmed",
        "bc4chemd": "pubmed",
        "bc5cdr": "pubmed",
        "bionlp13cg": "pubmed",
        "i2b2": "mimic",
        "jnlpba": "pubmed",
        "linnaeus": "pubmed",
        "ncbi_disease": "pubmed",
        "radiology": "mimic",
        "s800": "pubmed",
    },
    "hu": {
        "combined": None,
    },
    "nn": {
        "norne": None,
    },
}
# default ner for languages
# language code -> default NER model name
default_ners = {
    "af": "nchlt",
    "ang": "oedt_charlm",
    "ar": "aqmar_charlm",
    "bg": "bsnlp19",
    "da": "ddt",
    "de": "germeval2014",
    "en": "ontonotes-ww-multi_charlm",
    "es": "conll02",
    "fa": "arman",
    "fi": "turku",
    "fr": "wikinergold_charlm",
    "he": "iahlt_charlm",
    "hi": "ilner_charlm",
    "hu": "combined",
    "hy": "armtdp",
    "it": "fbk",
    "ja": "gsd",
    "kk": "kazNERD",
    "mr": "l3cube",
    "my": "ucsy",
    "nb": "norne",
    "nl": "conll02",
    "nn": "norne",
    "pl": "nkjp",
    "ru": "wikiner",
    "sd": "siner",
    "sv": "suc3shuffle",
    "te": "ilner_charlm",
    "th": "lst20",
    "tr": "starlang",
    "uk": "languk",
    "ur": "ilner_nocharlm",
    "vi": "vlsp",
    "zh-hans": "ontonotes",
}

# a few languages have sentiment classifier models
default_sentiment = {
    "en": "sstplus_charlm",
    "de": "sb10k_charlm",
    "es": "tass2020_charlm",
    "mr": "l3cube_charlm",
    "vi": "vsfc_charlm",
    "zh-hans": "ren_charlm",
}

# also, a few languages (very few, currently) have constituency parser models
default_constituency = {
    "da": "arboretum_charlm",
    "de": "spmrl_charlm",
    "en": "ptb3-revised_charlm",
    "es": "combined_charlm",
    "id": "icon_charlm",
    "it": "vit_charlm",
    "ja": "alt_charlm",
    "pt": "cintil_charlm",
    # Turkish is listed in optional_constituency instead
    #"tr": "starlang_charlm",
    "vi": "vlsp22_charlm",
    "zh-hans": "ctb-51_charlm",
}

# constituency models that exist but are not part of a language's defaults
optional_constituency = {
    "tr": "starlang_charlm",
}

# an alternate tokenizer for languages which aren't trained from a base UD source
default_tokenizer = {
    "my": "alt",
}
# ideally we would have a less expensive model as the base model
# (so for now there is no default coref model for any language)
#default_coref = {
#    "en": "ontonotes_roberta-large_finetuned",
#}

# language code -> coref model name, for languages where coref is
# available but not enabled by default
optional_coref = {
    "ca": "udcoref_xlm-roberta-lora",
    "cs": "udcoref_xlm-roberta-lora",
    "de": "udcoref_xlm-roberta-lora",
    "en": "udcoref_xlm-roberta-lora",
    "es": "udcoref_xlm-roberta-lora",
    "fr": "udcoref_xlm-roberta-lora",
    "he": "iahlt_xlm-roberta-lora",
    "hi": "deeph_muril-large-cased-lora",
    # UD Coref has both nb and nn datasets for Norwegian
    "nb": "udcoref_xlm-roberta-lora",
    "nn": "udcoref_xlm-roberta-lora",
    "pl": "udcoref_xlm-roberta-lora",
    "ru": "udcoref_xlm-roberta-lora",
    "ta": "kbc_muril-large-cased-lora",
}

# (the bare string below is never used at runtime; it only serves as a
# section comment for the TRANSFORMERS table that follows)
"""
default transformers to use for various languages
we try to document why we choose a particular model in each case
"""
# lang -> HuggingFace model name of the default transformer for that language
TRANSFORMERS = {
    # We tested three candidate AR models on POS, Depparse, and NER
    #
    # POS: padt dev set scores, AllTags
    # depparse: padt dev set scores, LAS
    # NER: dev scores on a random split of AQMAR, entity scores
    #
    #                                              pos    depparse  ner
    # none (pt & charlm only)                      94.08  83.49     84.19
    # asafaya/bert-base-arabic                     95.10  84.96     85.98
    # aubmindlab/bert-base-arabertv2               95.33  85.28     84.93
    # aubmindlab/araelectra-base-discriminator     95.66  85.83     86.10
    "ar": "aubmindlab/araelectra-base-discriminator",

    # https://huggingface.co/Maltehb/danish-bert-botxo
    # contrary to normal expectations, this hurts F1
    # on a dev split by about 1 F1
    # "da": "Maltehb/danish-bert-botxo",
    #
    # the multilingual bert is a marginal improvement for conparse
    #
    # December 2022 update:
    # there are quite a few Danish transformers available on HuggingFace
    # here are the results of training a constituency parser with adadelta/adamw
    # on each of them (the two columns are presumably adadelta then adamw):
    #
    # no bert                                 0.8245  0.8230
    # alexanderfalk/danbert-small-cased       0.8236  0.8286
    # Geotrend/distilbert-base-da-cased       0.8268  0.8306
    # sarnikowski/convbert-small-da-cased     0.8322  0.8341
    # bert-base-multilingual-cased            0.8341  0.8342
    # vesteinn/ScandiBERT-no-faroese          0.8373  0.8408
    # Maltehb/danish-bert-botxo               0.8383  0.8408
    # vesteinn/ScandiBERT                     0.8421  0.8475
    #
    # Also, two models have token windows too short for use with the
    # Danish dataset:
    #   jonfd/electra-small-nordic
    #   Maltehb/aelaectra-danish-electra-small-cased
    #
    "da": "vesteinn/ScandiBERT",

    # As of April 2022, the bert models available have a weird
    # tokenizer issue where soft hyphen causes it to crash.
    # We attempt to compensate for that in the dev branch
    #
    # NER scores
    # model                                          dev    test
    # xlm-roberta-large                              86.56  85.23
    # bert-base-german-cased                         87.59  86.95
    # dbmdz/bert-base-german-cased                   88.27  87.47
    # german-nlp-group/electra-base-german-uncased   88.60  87.09
    #
    # constituency scores w/ peft, March 2024 model, in-order
    # model               dev    test
    # xlm-roberta-base    95.17  93.34
    # xlm-roberta-large   95.86  94.46 (!!!)
    # bert-base           95.24  93.24
    # dbmdz/bert          95.32  93.33
    # german/electra      95.72  94.05
    #
    # POS scores
    # model               dev    test
    # None                88.65  87.28
    # xlm-roberta-large   89.21  88.11
    # bert-base           89.52  88.42
    # dbmdz/bert          89.67  88.54
    # german/electra      89.98  88.66
    #
    # depparse scores, LAS
    # model               dev    test
    # None                87.76  84.37
    # xlm-roberta-large   89.00  85.79
    # bert-base           88.72  85.40
    # dbmdz/bert          88.70  85.14
    # german/electra      89.21  86.06
    "de": "german-nlp-group/electra-base-german-uncased",

    # experiments on various forms of roberta & electra
    #  https://huggingface.co/roberta-base
    #  https://huggingface.co/roberta-large
    #  https://huggingface.co/google/electra-small-discriminator
    #  https://huggingface.co/google/electra-base-discriminator
    #  https://huggingface.co/google/electra-large-discriminator
    #
    # experiments using the different models for POS tagging,
    # dev set, including WV and charlm, AllTags score:
    #  roberta-base:   95.67
    #  roberta-large:  95.98
    #  electra-small:  95.31
    #  electra-base:   95.90
    #  electra-large:  96.01
    #
    # depparse scores, dev set, no finetuning, with WV and charlm
    #                   UAS    LAS    CLAS   MLAS   BLEX
    #  roberta-base:    93.16  91.20  89.87  89.38  89.87
    #  roberta-large:   93.47  91.56  90.13  89.71  90.13
    #  electra-small:   92.17  90.02  88.25  87.66  88.25
    #  electra-base:    93.42  91.44  90.10  89.67  90.10
    #  electra-large:   94.07  92.17  90.99  90.53  90.99
    #
    # conparse scores, dev & test set, with WV and charlm
    #  roberta_base:   96.05  95.60
    #  roberta_large:  95.95  95.60
    #  electra-small:  95.33  95.04
    #  electra-base:   96.09  95.98
    #  electra-large:  96.25  96.14
    #
    # conparse scores w/ finetune, dev & test set, with WV and charlm
    #  roberta_base:   96.07  95.81
    #  roberta_large:  96.37  96.41   (!!!)
    #  electra-small:  95.62  95.36
    #  electra-base:   96.21  95.94
    #  electra-large:  96.40  96.32
    #
    "en": "google/electra-large-discriminator",

    # TODO need to test, possibly compare with others
    "es": "bertin-project/bertin-roberta-base-spanish",

    # NER scores for a couple Persian options:
    # none:
    # dev:  2022-04-23 01:44:53 INFO: fa_arman 79.46
    # test: 2022-04-23 01:45:03 INFO: fa_arman 80.06
    #
    # HooshvareLab/bert-fa-zwnj-base
    # dev:  2022-04-23 02:43:44 INFO: fa_arman 80.87
    # test: 2022-04-23 02:44:07 INFO: fa_arman 80.81
    #
    # HooshvareLab/roberta-fa-zwnj-base
    # dev:  2022-04-23 16:23:25 INFO: fa_arman 81.23
    # test: 2022-04-23 16:23:48 INFO: fa_arman 81.11
    #
    # HooshvareLab/bert-base-parsbert-uncased
    # dev:  2022-04-26 10:42:09 INFO: fa_arman 82.49
    # test: 2022-04-26 10:42:31 INFO: fa_arman 83.16
    "fa": 'HooshvareLab/bert-base-parsbert-uncased',

    # NER scores for a couple options:
    # none:
    # dev:  2022-03-04 INFO: fi_turku 83.45
    # test: 2022-03-04 INFO: fi_turku 86.25
    #
    # bert-base-multilingual-cased
    # dev:  2022-03-04 INFO: fi_turku 85.23
    # test: 2022-03-04 INFO: fi_turku 89.00
    #
    # TurkuNLP/bert-base-finnish-cased-v1:
    # dev:  2022-03-04 INFO: fi_turku 88.41
    # test: 2022-03-04 INFO: fi_turku 91.36
    "fi": "TurkuNLP/bert-base-finnish-cased-v1",

    # POS dev set tagging results for French
    # (columns presumably UPOS XPOS UFeats AllTags -- TODO confirm):
    #  No bert:
    #    98.60  100.00  98.55  98.04
    #  dbmdz/electra-base-french-europeana-cased-discriminator
    #    98.70  100.00  98.69  98.24
    #  benjamin/roberta-base-wechsel-french
    #    98.71  100.00  98.75  98.26
    #  camembert/camembert-large
    #    98.75  100.00  98.75  98.30
    #  camembert-base
    #    98.78  100.00  98.77  98.33
    #
    # GSD depparse dev set results for French
    # (columns presumably UAS LAS CLAS MLAS BLEX -- TODO confirm):
    #  No bert:
    #    95.83 94.52 91.34 91.10 91.34
    #  camembert/camembert-large
    #    96.80 95.71 93.37 93.13 93.37
    # TODO: the rest of the chart
    "fr": "camembert/camembert-large",

    # Ancient Greek has a surprising number of transformers available
    #  Model               POS     Depparse LAS
    #  None                0.8812  0.7684
    #  Microbert M         0.8883  0.7706
    #  Microbert MX        0.8910  0.7755
    #  Microbert MXP       0.8916  0.7742
    #  Pranaydeeps Bert    0.9139  0.7987
    "grc": "pranaydeeps/Ancient-Greek-BERT",

    # a couple possibilities to experiment with for Hebrew
    # dev scores for POS and depparse
    # https://huggingface.co/imvladikon/alephbertgimmel-base-512
    #   UPOS    XPOS  UFeats AllTags
    #   97.25  97.25  92.84  91.81
    #   UAS    LAS    CLAS   MLAS   BLEX
    #   94.42  92.47  89.49  88.82  89.49
    #
    # https://huggingface.co/onlplab/alephbert-base
    #   UPOS    XPOS  UFeats AllTags
    #   97.37  97.37  92.50  91.55
    #   UAS    LAS    CLAS   MLAS   BLEX
    #   94.06  92.12  88.80  88.13  88.80
    #
    # https://huggingface.co/avichr/heBERT
    #   UPOS    XPOS  UFeats AllTags
    #   97.09  97.09  92.36  91.28
    #   UAS    LAS    CLAS   MLAS   BLEX
    #   94.29  92.30  88.99  88.38  88.99
    "he": "imvladikon/alephbertgimmel-base-512",

    # can also experiment with xlm-roberta
    # on a coref dataset from IITH, span F1:
    #                       dev      test
    # xlm-roberta-large     0.63635  0.66579
    # muril-large           0.65369  0.68290
    "hi": "google/muril-large-cased",

    # https://huggingface.co/xlm-roberta-base
    # Scores by entity for armtdp NER on 18 labels:
    # no bert :          86.68
    # xlm-roberta-base : 89.31
    "hy": "xlm-roberta-base",

    # Indonesian POS experiments: dev set of GSD
    # python3 stanza/utils/training/run_pos.py id_gsd --no_bert
    # python3 stanza/utils/training/run_pos.py id_gsd --bert_model ...
    # also ran on the ICON constituency dataset
    #  model                                       POS        CON
    #  no_bert                                     89.95      84.74
    #  flax-community/indonesian-roberta-large     89.78 (!)   xxx
    #  flax-community/indonesian-roberta-base      90.14       xxx
    #  indobenchmark/indobert-base-p2              90.09
    #  indobenchmark/indobert-base-p1              90.14
    #  indobenchmark/indobert-large-p1             90.19
    #  indolem/indobert-base-uncased               90.21      88.60
    #  cahya/bert-base-indonesian-1.5G             90.32      88.15
    #  cahya/roberta-base-indonesian-1.5G          90.40      87.27
    "id": "indolem/indobert-base-uncased",

    # from https://github.com/idb-ita/GilBERTo
    # annoyingly, it doesn't handle cased text
    # supposedly there is an argument "do_lower_case"
    # but that still leaves a lot of unk tokens
    # "it": "idb-ita/gilberto-uncased-from-camembert",
    #
    # from https://github.com/musixmatchresearch/umberto
    # on NER, this gets 88.37 dev and 91.02 test
    # another option is dbmdz/bert-base-italian-cased,
    # which gets 87.27 dev and 90.32 test
    #
    # in-order constituency parser on the VIT dev set:
    # dbmdz/bert-base-italian-cased:                       0.8079
    # dbmdz/bert-base-italian-xxl-cased:                   0.8195
    # Musixmatch/umberto-commoncrawl-cased-v1:             0.8256
    # dbmdz/electra-base-italian-xxl-cased-discriminator:  0.8314
    #
    # FBK NER dev set:
    # dbmdz/bert-base-italian-cased:                       87.76
    # Musixmatch/umberto-commoncrawl-cased-v1:             88.62
    # dbmdz/bert-base-italian-xxl-cased:                   88.84
    # dbmdz/electra-base-italian-xxl-cased-discriminator:  89.91
    #
    # combined UD POS dev set:                             UPOS    XPOS  UFeats  AllTags
    # dbmdz/bert-base-italian-cased:                       98.62  98.53  98.06  97.49
    # dbmdz/bert-base-italian-xxl-cased:                   98.61  98.54  98.07  97.58
    # dbmdz/electra-base-italian-xxl-cased-discriminator:  98.64  98.54  98.14  97.61
    # Musixmatch/umberto-commoncrawl-cased-v1:             98.56  98.45  98.13  97.62
    "it": "dbmdz/electra-base-italian-xxl-cased-discriminator",

    # for Japanese
    # there are others that would also work,
    # but they require different tokenizers instead of being
    # plug & play
    #
    # Constituency scores on ALT (in-order)
    # no bert: 90.68 dev, 91.40 test
    # rinna:   91.54 dev, 91.89 test
    "ja": "rinna/japanese-roberta-base",

    # could also try:
    # l3cube-pune/marathi-bert-v2
    #  or
    # https://huggingface.co/l3cube-pune/hindi-marathi-dev-roberta
    # l3cube-pune/hindi-marathi-dev-roberta
    #
    # depparse ufal dev scores:
    #  no transformer              74.89 63.70 57.43 53.01 57.43
    #  l3cube-pune/marathi-roberta 76.48 66.21 61.20 57.60 61.20
    "mr": "l3cube-pune/marathi-roberta",

    "or": "google/muril-large-cased",

    # https://huggingface.co/allegro/herbert-base-cased
    # Scores by entity on the NKJP NER task:
    # no bert (dev/test):                         88.64/88.75
    # herbert-base-cased (dev/test):              91.48/91.02,
    # herbert-large-cased (dev/test):             92.25/91.62
    # sdadas/polish-roberta-large-v2 (dev/test):  92.66/91.22
    "pl": "allegro/herbert-base-cased",

    # experiments on the cintil conparse dataset
    # ran a variety of transformer settings
    # found the following dev set scores after 400 iterations:
    # Geotrend/distilbert-base-pt-cased : not plug & play
    # no bert: 0.9082
    # xlm-roberta-base: 0.9109
    # xlm-roberta-large: 0.9254
    # adalbertojunior/distilbert-portuguese-cased: 0.9300
    # neuralmind/bert-base-portuguese-cased: 0.9307
    # neuralmind/bert-large-portuguese-cased: 0.9343
    "pt": "neuralmind/bert-large-portuguese-cased",

    # hope is actually to build our own using a large text collection
    "sd": "google/muril-large-cased",

    # Tamil options: quite a few, need to run a bunch of experiments
    #                              dev pos    dev depparse las
    # no transformer               82.82        69.12
    # ai4bharat/indic-bert         82.98        70.47
    # lgessler/microbert-tamil-mxp 83.21        69.28
    # monsoon-nlp/tamillion        83.37        69.28
    # l3cube-pune/tamil-bert       85.27        72.53
    # d42kw01f/Tamil-RoBERTa       85.59        70.55
    # google/muril-base-cased      85.67        72.68
    # google/muril-large-cased     86.30        72.45
    #
    # should also consider xlm-roberta-large
    # updated on UD 2.16 data:     dev pos    ner
    # google/muril-large-cased     86.86      65.08
    # xlm-roberta-large                       66.28
    # NOTE(review): only one score was recorded for xlm-roberta-large;
    # it is presumably the NER column -- confirm
    "ta": "google/muril-large-cased",
    "te": "google/muril-large-cased",

    # https://huggingface.co/airesearch/wangchanberta-base-att-spm-uncased
    # this is clearly better than no transformer on a couple datasets:
    #
    #                  TUD dev upos   TUD dev depparse LAS
    # no transformer     91.26          73.57
    # wangchanberta      92.21          76.65
    "th": "airesearch/wangchanberta-base-att-spm-uncased",

    # https://huggingface.co/dbmdz/bert-base-turkish-128k-cased
    # helps the Turkish model quite a bit
    "tr": "dbmdz/bert-base-turkish-128k-cased",

    "ur": "google/muril-large-cased",

    # from https://github.com/VinAIResearch/PhoBERT
    # "vi": "vinai/phobert-base",
    # using 6 or 7 layers of phobert-large is slightly
    # more effective for constituency parsing than
    # using 4 layers of phobert-base
    # ... going beyond 4 layers of phobert-base
    # does not help the scores
    # (the layer count used for vi is set in TRANSFORMER_LAYERS)
    "vi": "vinai/phobert-large",

    # https://github.com/ymcui/Chinese-BERT-wwm
    # there's also hfl/chinese-roberta-wwm-ext-large
    # or hfl/chinese-electra-base-discriminator
    # or hfl/chinese-electra-180g-large-discriminator,
    # which works better than the below roberta on constituency
    # "zh-hans": "hfl/chinese-roberta-wwm-ext",
    # conparse dev scores (averaged over 5):
    #   google bert:  0.9422
    #   hfl bert:     0.9469
    #   hfl roberta:  0.9459
    #   hfl electra:  0.9515
    #   hfl macbert:  0.9530
    # There is also a ShannonAI model, but our current codebase is
    # somehow not compatible
    # further comparing HFL:
    #                POS dev   Depparse dev LAS   NER dev
    #   HFL Electra  96.90     85.66              77.90
    #   HFL Macbert  96.53     84.72              78.46
    # "zh-hans": "hfl/chinese-macbert-large",
    "zh-hans": "hfl/chinese-electra-180g-large-discriminator",
}
# lang -> number of transformer layers to use, for languages where the
# default number was not the best choice
TRANSFORMER_LAYERS = {
    # not clear what the best number is without more experiments,
    # but more than 4 is working better than just 4
    "vi": 7,
}
# HuggingFace model name -> short nickname for that transformer.
# known_nicknames() collects these values.  Nicknames are not unique:
# "bert" (fi and tr), "electra-large" (en and zh), and "xlm-roberta-large"
# (two model names) are each used more than once.
# Dict order matters slightly: known_nicknames uses a stable sort by length,
# so equal-length nicknames keep this insertion order.
TRANSFORMER_NICKNAMES = {
    # ar
    "asafaya/bert-base-arabic": "asafaya-bert",
    "aubmindlab/araelectra-base-discriminator": "aubmind-electra",
    "aubmindlab/bert-base-arabertv2": "aubmind-bert",

    # da
    "vesteinn/ScandiBERT": "scandibert",

    # de
    "bert-base-german-cased": "bert-base-german-cased",
    "dbmdz/bert-base-german-cased": "dbmdz-bert-german-cased",
    "german-nlp-group/electra-base-german-uncased": "german-nlp-electra",

    # en
    "bert-base-multilingual-cased": "mbert",
    "xlm-roberta-large": "xlm-roberta-large",
    "google/electra-large-discriminator": "electra-large",
    "microsoft/deberta-v3-large": "deberta-v3-large",
    "princeton-nlp/Sheared-LLaMA-1.3B": "sheared-llama-1b3",

    # es
    "bertin-project/bertin-roberta-base-spanish": "bertin-roberta",

    # fa
    "HooshvareLab/bert-base-parsbert-uncased": "parsbert",

    # fi
    "TurkuNLP/bert-base-finnish-cased-v1": "bert",

    # fr
    "benjamin/roberta-base-wechsel-french": "wechsel-roberta",
    "camembert-base": "camembert-base",
    "camembert/camembert-large": "camembert-large",
    "dbmdz/electra-base-french-europeana-cased-discriminator": "dbmdz-electra",

    # grc
    "pranaydeeps/Ancient-Greek-BERT": "grc-pranaydeeps",
    "lgessler/microbert-ancient-greek-m": "grc-microbert-m",
    "lgessler/microbert-ancient-greek-mx": "grc-microbert-mx",
    "lgessler/microbert-ancient-greek-mxp": "grc-microbert-mxp",
    "altsoph/bert-base-ancientgreek-uncased": "grc-altsoph",

    # he
    "HeNLP/HeRo": "hero-roberta",
    "imvladikon/alephbertgimmel-base-512": "alephbertgimmel",
    "onlplab/alephbert-base": "alephbert",

    # hy
    "xlm-roberta-base": "xlm-roberta-base",

    # id
    "indolem/indobert-base-uncased": "indobert",
    "indobenchmark/indobert-large-p1": "indobenchmark-large-p1",
    "indobenchmark/indobert-base-p1": "indobenchmark-base-p1",
    "indobenchmark/indobert-lite-large-p1": "indobenchmark-lite-large-p1",
    "indobenchmark/indobert-lite-base-p1": "indobenchmark-lite-base-p1",
    "indobenchmark/indobert-large-p2": "indobenchmark-large-p2",
    "indobenchmark/indobert-base-p2": "indobenchmark-base-p2",
    "indobenchmark/indobert-lite-large-p2": "indobenchmark-lite-large-p2",
    "indobenchmark/indobert-lite-base-p2": "indobenchmark-lite-base-p2",

    # it
    "dbmdz/electra-base-italian-xxl-cased-discriminator": "electra",

    # ja
    "rinna/japanese-roberta-base": "rinna-roberta",

    # mr
    "l3cube-pune/marathi-roberta": "l3cube-marathi-roberta",

    # pl
    "allegro/herbert-base-cased": "herbert",

    # pt
    "neuralmind/bert-large-portuguese-cased": "bertimbau",

    # ta: tamil
    "monsoon-nlp/tamillion": "tamillion",
    "lgessler/microbert-tamil-m": "ta-microbert-m",
    "lgessler/microbert-tamil-mxp": "ta-microbert-mxp",
    "l3cube-pune/tamil-bert": "l3cube-tamil-bert",
    "d42kw01f/Tamil-RoBERTa": "ta-d42kw01f-roberta",

    # th
    "airesearch/wangchanberta-base-att-spm-uncased": "wangchanberta",

    # tr
    "dbmdz/bert-base-turkish-128k-cased": "bert",

    # vi
    "vinai/phobert-base": "phobert-base",
    "vinai/phobert-large": "phobert-large",

    # zh
    "google-bert/bert-base-chinese": "google-bert-chinese",
    "hfl/chinese-bert-wwm": "hfl-bert-chinese",
    "hfl/chinese-macbert-large": "hfl-macbert-chinese",
    "hfl/chinese-roberta-wwm-ext": "hfl-roberta-chinese",
    "hfl/chinese-electra-180g-large-discriminator": "electra-large",
    "ShannonAI/ChineseBERT-base": "shannonai-chinese-bert",

    # multi-lingual Indic
    "ai4bharat/indic-bert": "indic-bert",
    "google/muril-base-cased": "muril-base-cased",
    "google/muril-large-cased": "muril-large-cased",

    # multi-lingual
    "FacebookAI/xlm-roberta-large": "xlm-roberta-large",
}
def known_nicknames():
    """
    Return a list of all the transformer nicknames, longest first.

    The list holds every distinct value of TRANSFORMER_NICKNAMES plus the
    generic "transformer" nickname, which previously unspecific transformer
    models were given.  Several transformers share a nickname (e.g. "bert",
    "electra-large"), so duplicates are removed.

    We return a list, rather than a set, so that the nicknames can be sorted
    in decreasing length; equal-length nicknames keep their table order.
    """
    # dict.fromkeys drops duplicates while preserving first-seen order,
    # which keeps the (stable) sort below deterministic
    nicknames = list(dict.fromkeys([*TRANSFORMER_NICKNAMES.values(), "transformer"]))
    nicknames.sort(key=len, reverse=True)
    return nicknames
================================================
FILE: stanza/resources/installation.py
================================================
"""
Functions for setting up the environments.
"""
import os
import logging
import zipfile
import shutil
from stanza.resources.common import USER_CACHE_DIR, request_file, unzip, \
get_root_from_zipfile, set_logging_level
logger = logging.getLogger('stanza')

# URL template for the CoreNLP model jars.  It can be overridden with the
# $CORENLP_MODEL_URL environment variable; the template is formatted with
# {model} plus {tag} (and may also use {version}).
DEFAULT_CORENLP_MODEL_URL = os.environ.get(
    'CORENLP_MODEL_URL',
    'https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar',
)

# Alternate URL template for the model jars, formatted with {version} and {model}.
BACKUP_CORENLP_MODEL_URL = "http://nlp.stanford.edu/software/stanford-corenlp-{version}-models-{model}.jar"
# URL template for the full CoreNLP distribution zip, formatted with {tag}
# (and optionally {version}).  It is overridden with its own variable,
# $CORENLP_URL: $CORENLP_MODEL_URL is the per-model jar template (a different
# file with a different format), so sharing it would make overriding the
# model location also break the CoreNLP install URL.
DEFAULT_CORENLP_URL = os.getenv(
    'CORENLP_URL',
    'https://huggingface.co/stanfordnlp/CoreNLP/resolve/{tag}/stanford-corenlp-latest.zip'
)
# Where CoreNLP is installed: $CORENLP_HOME if set, otherwise a 'corenlp'
# folder inside the stanza user cache directory.
DEFAULT_CORENLP_DIR = os.environ.get('CORENLP_HOME', os.path.join(USER_CACHE_DIR, 'corenlp'))
# names accepted by download_corenlp_models
AVAILABLE_MODELS = {
    'arabic',
    'chinese',
    'english-extra',
    'english-kbp',
    'french',
    'german',
    'hungarian',
    'italian',
    'spanish',
}
def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_MODEL_URL, logging_level='INFO', proxies=None, force=True):
    """
    An automatic way to download the CoreNLP models.

    Args:
        model: the name of the model; must be one of AVAILABLE_MODELS:
            'arabic', 'chinese', 'english-extra', 'english-kbp', 'french',
            'german', 'hungarian', 'italian', 'spanish'.
            Surrounding whitespace and case are ignored.
        version: the version of the model; 'main' is used as-is as the tag,
            any other value becomes the tag 'v' + version
        dir: the directory to download CoreNLP model into; alternatively can be
            set up with environment variable $CORENLP_HOME
        url: The link to download CoreNLP models.
            It will need {model} and either {version} or {tag} to properly format the URL
        logging_level: accepted for API compatibility; not currently used by this function
        proxies: passed through to request_file for the download
        force: Download model anyway, no matter model file exists or not

    Raises:
        ValueError: if model or version is empty
        KeyError: if model is not one of AVAILABLE_MODELS
        RuntimeError: if the download fails
    """
    dir = os.path.expanduser(dir)
    if not model or not version:
        raise ValueError(
            "Both model and model version should be specified."
        )
    # normalize and validate before logging, so the log names the model
    # that will actually be downloaded
    model = model.strip().lower()
    if model not in AVAILABLE_MODELS:
        raise KeyError(
            f'{model} is currently not supported. '
            # sorted so the message is the same from run to run
            f'Must be one of: {sorted(AVAILABLE_MODELS)}.'
        )
    logger.info(f"Downloading {model} models (version {version}) into directory {dir}")

    # for example:
    #   https://huggingface.co/stanfordnlp/CoreNLP/resolve/v4.2.2/stanford-corenlp-models-french.jar
    tag = version if version == 'main' else 'v' + version
    download_url = url.format(tag=tag, model=model, version=version)
    model_path = os.path.join(dir, f'stanford-corenlp-{version}-models-{model}.jar')

    if os.path.exists(model_path) and not force:
        # Logger.warn is a deprecated alias of Logger.warning
        logger.warning(
            f"Model file {model_path} already exists. "
            f"Please download this model to a new directory.")
        return

    try:
        request_file(
            download_url,
            model_path,
            proxies
        )
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        raise RuntimeError(
            "Downloading CoreNLP model file failed. "
            "Please try manual downloading at: https://stanfordnlp.github.io/CoreNLP/."
        ) from e
def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None, version="main"):
    """
    A fully automatic way to install and set up the CoreNLP library
    to use the client functionality.

    Args:
        dir: the directory to install CoreNLP into; alternatively can be
            set up with environment variable $CORENLP_HOME
        url: The link to download the CoreNLP zip from.
            Needs a {version} or {tag} parameter to specify the version
        logging_level: logging level to use during installation
        proxies: passed through to request_file for the download
        version: the CoreNLP version to install; 'main' is used as-is as the
            tag, any other value becomes the tag 'v' + version

    If dir already exists and is not empty, a warning is logged and nothing
    is installed.

    Raises:
        RuntimeError: if downloading the CoreNLP zip file fails
    """
    dir = os.path.expanduser(dir)
    set_logging_level(logging_level=logging_level, verbose=None)
    if os.path.exists(dir) and len(os.listdir(dir)) > 0:
        # Logger.warn is a deprecated alias of Logger.warning
        logger.warning(
            f"Directory {dir} already exists. "
            f"Please install CoreNLP to a new directory.")
        return

    logger.info(f"Installing CoreNLP package into {dir}")
    # First download the URL package
    zip_path = os.path.join(dir, 'corenlp.zip')
    logger.debug(f"Download to destination file: {zip_path}")
    tag = version if version == 'main' else 'v' + version
    url = url.format(version=version, tag=tag)
    try:
        request_file(url, zip_path, proxies)
    except BaseException as e:
        # A partial download would leave dir non-empty, and the check above
        # would then refuse every later attempt to install into it.
        # (The zip did not exist before: dir was empty or missing.)
        if os.path.exists(zip_path):
            os.remove(zip_path)
        # KeyboardInterrupt, SystemExit and other non-Exception signals
        # propagate unchanged, as before
        if not isinstance(e, Exception):
            raise
        raise RuntimeError(
            "Downloading CoreNLP zip file failed. "
            "Please try manual installation: https://stanfordnlp.github.io/CoreNLP/."
        ) from e

    # Unzip corenlp into dir
    logger.debug("Unzipping downloaded zip file...")
    unzip(dir, 'corenlp.zip')

    # By default CoreNLP will be unzipped into a version-dependent folder,
    # e.g., stanford-corenlp-4.0.0. We need some hack around that and move
    # files back into our designated folder
    logger.debug(f"Moving files into the designated folder at: {dir}")
    corenlp_dirname = get_root_from_zipfile(zip_path)
    corenlp_dirname = os.path.join(dir, corenlp_dirname)
    for f in os.listdir(corenlp_dirname):
        shutil.move(os.path.join(corenlp_dirname, f), dir)

    # Remove original zip and folder
    logger.debug("Removing downloaded zip file...")
    os.remove(zip_path)
    shutil.rmtree(corenlp_dirname)

    # Warn user to set up env.  Normalize both sides so that a CORENLP_HOME
    # containing "~" or a relative path is still recognized as the default.
    default_dir = os.path.abspath(os.path.expanduser(DEFAULT_CORENLP_DIR))
    if os.path.abspath(dir) != default_dir:
        logger.warning(
            f"For customized installation location, please set the `CORENLP_HOME` "
            f"environment variable to the location of the installation. "
            f"In Unix, this is done with `export CORENLP_HOME={dir}`.")
================================================
FILE: stanza/resources/prepare_resources.py
================================================
"""
Converts a directory of models organized by type into a directory organized by language.
Also produces the resources.json file.
For example, on the cluster, you can do this:
python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0 > resources.out 2>&1
nlprun -a stanza-1.2 -q john "python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0" -o resources.out
"""
import argparse
from collections import defaultdict
import json
import os
from pathlib import Path
import hashlib
import shutil
import zipfile
from stanza import __resources_version__
from stanza.models.common.constant import lcode2lang, two_to_three_letters, three_to_two_letters, extra_lang_to_lcodes
from stanza.resources.default_packages import PACKAGES, TRANSFORMERS, TRANSFORMER_NICKNAMES
from stanza.resources.default_packages import *
from stanza.utils.datasets.prepare_lemma_classifier import DATASET_MAPPING as LEMMA_CLASSIFIER_DATASETS
from stanza.utils.get_tqdm import get_tqdm
# resolve which tqdm implementation to use once, at import time, via get_tqdm()
tqdm = get_tqdm()
def parse_args():
    """
    Parse the command line arguments for this script.

    --input_dir and --output_dir are converted to absolute paths.
    --lang may list languages separated by commas, whitespace, or both
    (e.g. "en,fr", "en fr", "en, fr"); it is normalized to a plain
    comma-separated string with no empty entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default="/u/nlp/software/stanza/models/current-models-%s" % __resources_version__, help='Input dir for various models.  Defaults to the recommended home on the nlp cluster')
    parser.add_argument('--output_dir', type=str, default="/u/nlp/software/stanza/models/%s" % __resources_version__, help='Output dir for various models.')
    parser.add_argument('--packages_only', action='store_true', default=False, help='Only build the package maps instead of rebuilding everything')
    parser.add_argument('--lang', type=str, default=None, help='Only process this language or a comma-separated list of languages.  If left blank, will prepare all languages.  To use this argument, a previous prepared resources with all of the languages is necessary.')
    args = parser.parse_args()

    args.input_dir = os.path.abspath(args.input_dir)
    args.output_dir = os.path.abspath(args.output_dir)
    if args.lang is not None:
        # treat commas and whitespace alike: splitting on whitespace alone
        # turned "en, fr" into "en,,fr", and the empty entry later became
        # an empty language code when the string was split on commas
        args.lang = ",".join(args.lang.replace(",", " ").split())
    return args
# Languages that are permitted to lack the usual base models.  The check that
# consults this list is not visible here -- presumably it verifies each
# language has a standard model set.
allowed_empty_languages = [
    # only tokenize and NER for Myanmar right now (soon...)
    "my",
    # currently only an NER, not even a tokenizer, for Oriya
    "or",
]
# map processor name to file ending
# the order of this dict determines the order in which default.zip files are built
# changing it will necessitate rebuilding all of the default.zip files
# not a disaster, but it would involve a bunch of uploading
processor_to_ending = {
    "tokenize": "tokenizer",
    "mwt": "mwt_expander",
    "lemma": "lemmatizer",
    "pos": "tagger",
    "depparse": "parser",
    "pretrain": "pretrain",
    "ner": "nertagger",
    "forward_charlm": "forward_charlm",
    "backward_charlm": "backward_charlm",
    "sentiment": "sentiment",
    "constituency": "constituency",
    "coref": "coref",
    "langid": "langid",
}
# reverse lookup: file ending -> processor name (used by split_model_name)
ending_to_processor = {j: i for i, j in processor_to_ending.items()}
# processor names, in the build order described above
PROCESSORS = list(processor_to_ending.keys())
def ensure_dir(dir):
    """Create ``dir`` along with any missing parents; an existing directory is left alone."""
    target = Path(dir)
    target.mkdir(exist_ok=True, parents=True)
def copy_file(src, dst):
    """Copy ``src`` to ``dst`` with shutil.copy2, creating the destination's parent directory first."""
    destination_parent = Path(dst).parent
    ensure_dir(destination_parent)
    shutil.copy2(src, dst)
def get_md5(path, chunk_size=1 << 20):
    """
    Return the hex MD5 digest of the file at ``path``.

    The file is read in ``chunk_size``-byte pieces inside a ``with`` block,
    so large model files are never held in memory all at once and the file
    handle is always closed (the previous version leaked it).
    """
    digest = hashlib.md5()
    with open(path, 'rb') as fin:
        for chunk in iter(lambda: fin.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def split_model_name(model):
    """
    Split a model filename into (lang, package, processor).

    The last three characters (the file extension, e.g. ".pt") are dropped
    and any remaining '.' is treated as '_'.  Both packages and file endings
    may contain underscores (e.g. "forward_charlm"), so the ending is matched
    against the known endings from longest to shortest -- that way
    "nertagger" is tried before "tagger".  The language is everything before
    the first remaining underscore.

    Raises AssertionError if no known file ending is found.
    """
    stem = model[:-3].replace('.', '_')
    longest_first = sorted(ending_to_processor, key=len, reverse=True)
    matched = next((ending for ending in longest_first if stem.endswith(ending)), None)
    if matched is None:
        raise AssertionError(f"Could not find a processor type in {stem}")

    # +1 also drops the '_' separating the package from the ending
    stem = stem[:-(len(matched) + 1)]
    lang, package = stem.split('_', 1)
    return lang, package, ending_to_processor[matched]
def split_package(package, default_use_charlm=True):
    """
    Strip pretrain/charlm descriptors off a package name.

    Returns (package, uses_pretrain, uses_charlm).  A trailing "_finetuned"
    is removed first.  Then "_nopretrain", "_nocharlm", or "_charlm" decide
    the flags explicitly.  Failing that, a trailing "_<transformer nickname>"
    implies pretrain and charlm.  Anything else is assumed to use a pretrain,
    with the charlm flag coming from default_use_charlm.
    """
    finetuned_suffix = "_finetuned"
    if package.endswith(finetuned_suffix):
        package = package[:-len(finetuned_suffix)]

    explicit_suffixes = (
        ("_nopretrain", False, False),
        ("_nocharlm", True, False),
        ("_charlm", True, True),
    )
    for suffix, uses_pretrain, uses_charlm in explicit_suffixes:
        if package.endswith(suffix):
            return package[:-len(suffix)], uses_pretrain, uses_charlm

    base, separator, nickname = package.rpartition("_")
    if separator and nickname in known_nicknames():
        return base, True, True

    # guess it was a model which wasn't built with the new naming convention of putting the pretrain type at the end
    # assume WV and charlm... if the language / package doesn't allow for one, that should be caught later
    return package, True, default_use_charlm
def get_pretrain_package(lang, package, model_pretrains, default_pretrains):
    """
    Pick the pretrain package for a model.

    Returns None if the package name says it has no pretrain, or if the
    language is in no_pretrain_languages.  Otherwise an entry for this exact
    (lang, package) in model_pretrains wins, then the language default from
    default_pretrains.  Raises RuntimeError if neither has one.
    """
    base_package, uses_pretrain, _ = split_package(package)
    if not uses_pretrain or lang in no_pretrain_languages:
        return None

    per_lang = {}
    if model_pretrains is not None and lang in model_pretrains:
        per_lang = model_pretrains[lang]
    if base_package in per_lang:
        return per_lang[base_package]

    if lang not in default_pretrains:
        raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, base_package))
    return default_pretrains[lang]
def get_charlm_package(lang, package, model_charlms, default_charlms, default_use_charlm=True):
    """
    Pick the charlm package for a model, or None if it does not use one.

    An entry for this exact (lang, package) in model_charlms wins (and may
    itself be None); otherwise the language default from default_charlms is
    used, which is None when the language has no charlm.
    """
    base_package, _, uses_charlm = split_package(package, default_use_charlm)
    if not uses_charlm:
        return None

    lang_overrides = model_charlms[lang] if model_charlms is not None and lang in model_charlms else {}
    if base_package in lang_overrides:
        return lang_overrides[base_package]
    return default_charlms.get(lang, None)
def get_con_dependencies(lang, package):
    """
    Return the dependency list for a constituency parser model.

    So far, this invariant is true: constituency models use the default
    pretrain and charlm for the language.  The charlm is looked up only by
    language, and sometimes there is no charlm for a language that has
    constituency, in which case no charlm entries are added.

    The pretrain entry is only added when get_pretrain_package returns a
    package (it returns None for _nopretrain packages and for languages in
    no_pretrain_languages).  This matches the other get_*_dependencies
    builders, instead of emitting a pretrain dependency whose package is None.
    """
    dependencies = []
    pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)
    if pretrain_package is not None:
        dependencies.append({'model': 'pretrain', 'package': pretrain_package})

    charlm_package = default_charlms.get(lang, None)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
    return dependencies
def get_pos_charlm_package(lang, package):
    """Charlm package for a POS model: pos_charlms overrides, else the language default."""
    return get_charlm_package(lang, package, model_charlms=pos_charlms, default_charlms=default_charlms)
def get_pos_dependencies(lang, package):
    """
    Return the dependency list for a POS tagger model.

    The pretrain comes from pos_pretrains / default_pretrains and the charlm
    from pos_charlms / default_charlms; whichever lookup returns None is left
    out.  A charlm contributes both a forward and a backward entry.
    """
    pretrain = get_pretrain_package(lang, package, pos_pretrains, default_pretrains)
    charlm = get_pos_charlm_package(lang, package)

    deps = []
    if pretrain is not None:
        deps.append({'model': 'pretrain', 'package': pretrain})
    if charlm is not None:
        deps += [{'model': direction, 'package': charlm}
                 for direction in ('forward_charlm', 'backward_charlm')]
    return deps
def get_lemma_pretrain_package(lang, package):
    """
    Return the pretrain package needed by a lemmatizer, or None.

    Only lemmatizers that use both a pretrain and a charlm, and whose
    "<lang>_<package>" appears in LEMMA_CLASSIFIER_DATASETS, get a pretrain;
    currently the contextual lemma classifier is only active for the charlm
    lemmatizers.
    """
    base_package, uses_pretrain, uses_charlm = split_package(package)
    if not (uses_pretrain and uses_charlm):
        return None

    dataset_key = "%s_%s" % (lang, base_package)
    if dataset_key in LEMMA_CLASSIFIER_DATASETS:
        return get_pretrain_package(lang, base_package, {}, default_pretrains)
    return None
def get_lemma_charlm_package(lang, package):
    """Charlm package for a lemmatizer: lemma_charlms overrides, else the language default."""
    return get_charlm_package(lang, package, model_charlms=lemma_charlms, default_charlms=default_charlms)
def get_lemma_dependencies(lang, package):
    """
    Return the dependency list for a lemmatizer model.

    A pretrain entry appears only when get_lemma_pretrain_package finds one;
    a charlm, when present, adds forward and backward entries.
    """
    dependencies = []

    lemma_pretrain = get_lemma_pretrain_package(lang, package)
    if lemma_pretrain is not None:
        dependencies.append({'model': 'pretrain', 'package': lemma_pretrain})

    lemma_charlm = get_lemma_charlm_package(lang, package)
    if lemma_charlm is not None:
        for charlm_model in ('forward_charlm', 'backward_charlm'):
            dependencies.append({'model': charlm_model, 'package': lemma_charlm})

    return dependencies
def get_tokenizer_charlm_package(lang, package):
    """
    Charlm package for a tokenizer.

    Unlike the other processors, a tokenizer package name without an explicit
    descriptor is assumed NOT to use a charlm.
    """
    return get_charlm_package(lang, package,
                              model_charlms=tokenizer_charlms,
                              default_charlms=default_charlms,
                              default_use_charlm=False)
def get_tokenizer_dependencies(lang, package):
    """
    Return the dependency list for a tokenizer model.

    Tokenizers have no pretrain dependency and only ever use the forward charlm.
    """
    forward_charlm = get_tokenizer_charlm_package(lang, package)
    if forward_charlm is None:
        return []
    return [{'model': 'forward_charlm', 'package': forward_charlm}]
def get_depparse_charlm_package(lang, package):
    """Charlm package for a dependency parser: depparse_charlms overrides, else the language default."""
    return get_charlm_package(lang, package, model_charlms=depparse_charlms, default_charlms=default_charlms)
def get_depparse_dependencies(lang, package):
    """
    Return the dependency list for a dependency parser model.

    The pretrain comes from depparse_pretrains / default_pretrains and the
    charlm from depparse_charlms / default_charlms; missing ones are skipped.
    """
    pretrain = get_pretrain_package(lang, package, depparse_pretrains, default_pretrains)
    result = [] if pretrain is None else [{'model': 'pretrain', 'package': pretrain}]

    charlm = get_depparse_charlm_package(lang, package)
    if charlm is not None:
        result.extend([
            {'model': 'forward_charlm', 'package': charlm},
            {'model': 'backward_charlm', 'package': charlm},
        ])
    return result
def get_ner_charlm_package(lang, package):
    """Charlm package for an NER model: ner_charlms overrides, else the language default."""
    return get_charlm_package(lang, package, model_charlms=ner_charlms, default_charlms=default_charlms)
def get_ner_pretrain_package(lang, package):
    """Pretrain package for an NER model: ner_pretrains overrides, else the language default."""
    return get_pretrain_package(lang, package, model_pretrains=ner_pretrains, default_pretrains=default_pretrains)
def get_ner_dependencies(lang, package):
    """
    Return the dependency list for an NER model.

    Entries are added for the pretrain and, as a forward/backward pair, for
    the charlm; either is omitted when its lookup returns None.
    """
    ner_pretrain = get_ner_pretrain_package(lang, package)
    ner_charlm = get_ner_charlm_package(lang, package)

    deps = []
    if ner_pretrain is not None:
        deps.append({'model': 'pretrain', 'package': ner_pretrain})
    if ner_charlm is not None:
        deps.extend({'model': which, 'package': ner_charlm}
                    for which in ('forward_charlm', 'backward_charlm'))
    return deps
def get_sentiment_dependencies(lang, package):
    """
    Return a list of dependencies for the sentiment model

    Generally this will be pretrain, forward & backward charlm

    So far, this invariant is true:
      sentiment models use the default pretrain for the language
      also, they all use the default charlm for a language

    As in the other get_*_dependencies functions, the pretrain entry is only
    added when a pretrain package is actually found.  A {'package': None}
    dependency would otherwise break callers which build filenames as
    package + '.pt'.
    """
    dependencies = []
    # no sentiment-specific pretrain table is passed (None): only the defaults are consulted
    pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)
    if pretrain_package is not None:
        dependencies.append({'model': 'pretrain', 'package': pretrain_package})
    charlm_package = default_charlms.get(lang, None)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
    return dependencies
def get_dependencies(processor, lang, package):
    """
    Get the dependencies for a particular lang/package based on the package name

    The package can include descriptors such as _nopretrain, _nocharlm, _charlm
    which inform whether or not this particular model uses charlm or pretrain

    Returns a list of {'model': ..., 'package': ...} dicts.  Processors that
    have no dependency logic here (e.g. mwt) get an empty list, so the
    return value is always a list.
    """
    if processor == 'depparse':
        return get_depparse_dependencies(lang, package)
    elif processor == 'lemma':
        return get_lemma_dependencies(lang, package)
    elif processor == 'pos':
        return get_pos_dependencies(lang, package)
    elif processor == 'ner':
        return get_ner_dependencies(lang, package)
    elif processor == 'sentiment':
        return get_sentiment_dependencies(lang, package)
    elif processor == 'constituency':
        return get_con_dependencies(lang, package)
    elif processor == 'tokenize':
        return get_tokenizer_dependencies(lang, package)
    # a list (not {}) so the return type matches every branch above
    return []
def process_dirs(args):
    """
    Copy every .pt model under args.input_dir into args.output_dir and record
    it in output_dir/resources.json.

    Models are read from input_dir/<model_dir>/<model>.pt and copied to
    output_dir/<lang>/models/<processor>/<package>.pt, where lang, package
    and processor come from split_model_name.  Each resources entry holds
    the copied file's md5 and, if any, its dependencies.

    If args.lang is set (a comma separated list of languages), the existing
    resources.json is loaded, the entries for those languages are reset,
    and only models for those languages are processed.
    """
    resources_path = os.path.join(args.output_dir, 'resources.json')
    dirs = sorted(os.listdir(args.input_dir))
    selected_langs = set(args.lang.split(",")) if args.lang else None
    resources = {}
    if selected_langs is not None:
        with open(resources_path) as fin:
            resources = json.load(fin)
        # this one language gets overridden
        # if this is not done, and we reuse the old resources,
        # any models which were deleted will still be in the resources
        # (iterate the split list, not the set, to keep key insertion order stable)
        for lang in args.lang.split(","):
            resources[lang] = {}
    for model_dir in dirs:
        print(f"Processing models in {model_dir}")
        models = sorted(os.listdir(os.path.join(args.input_dir, model_dir)))
        for model in tqdm(models):
            if not model.endswith('.pt'):
                continue
            # get processor
            lang, package, processor = split_model_name(model)
            if selected_langs is not None and lang not in selected_langs:
                continue
            # copy file
            input_path = os.path.join(args.input_dir, model_dir, model)
            output_path = os.path.join(args.output_dir, lang, "models", processor, package + '.pt')
            copy_file(input_path, output_path)
            # maintain md5
            md5 = get_md5(output_path)
            # maintain dependencies
            dependencies = get_dependencies(processor, lang, package)
            # maintain resources
            entry = {'md5': md5}
            if dependencies:
                entry['dependencies'] = dependencies
            resources.setdefault(lang, {}).setdefault(processor, {})[package] = entry
    print("Processed initial model directories. Writing preliminary resources.json")
    # context managers so the files are closed (and flushed) deterministically
    with open(resources_path, 'w') as fout:
        json.dump(resources, fout, indent=2)
def get_default_pos_package(lang, ud_package):
    """
    Pick the default POS package name for the UD treebank ud_package.

    Returns ud_package + "_charlm" when a POS charlm is configured,
    otherwise "_nopretrain" for languages in no_pretrain_languages,
    otherwise "_nocharlm".
    """
    if get_pos_charlm_package(lang, ud_package) is not None:
        suffix = "_charlm"
    elif lang in no_pretrain_languages:
        suffix = "_nopretrain"
    else:
        suffix = "_nocharlm"
    return ud_package + suffix
def get_default_depparse_package(lang, ud_package):
    """
    Pick the default depparse package name for the UD treebank ud_package.

    Returns ud_package + "_charlm" when a depparse charlm is configured,
    otherwise "_nopretrain" for languages in no_pretrain_languages,
    otherwise "_nocharlm".
    """
    if get_depparse_charlm_package(lang, ud_package) is not None:
        suffix = "_charlm"
    elif lang in no_pretrain_languages:
        suffix = "_nopretrain"
    else:
        suffix = "_nocharlm"
    return ud_package + suffix
def process_default_zips(args):
    """
    Build output_dir/<lang>/models/default.zip for each language.

    The zip holds every model named in the language's 'default' package
    (except an identity lemmatizer and the 'optional' entry) plus their
    dependencies, stored as <processor>/<package>.pt.  The zip's md5 is
    recorded as resources[lang]['default_md5'] and resources.json is
    rewritten.

    Raises AssertionError for a language missing from default_treebanks
    (unless it is an allowed empty language), and FileNotFoundError if a
    needed model file is not present in output_dir.
    """
    resources_path = os.path.join(args.output_dir, 'resources.json')
    with open(resources_path) as fin:
        resources = json.load(fin)
    for lang in resources:
        # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json
        if lang == 'url':
            continue
        if 'alias' in resources[lang]:
            continue
        if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()):
            continue
        if lang in allowed_empty_languages and lang not in default_treebanks:
            continue
        if lang not in default_treebanks:
            raise AssertionError(f'{lang} not in default treebanks!!!')
        if args.lang and lang not in args.lang.split(","):
            continue
        print(f'Preparing default models for language {lang}')
        models_needed = defaultdict(set)
        packages = resources[lang][PACKAGES]["default"]
        for processor, package in packages.items():
            if processor == 'lemma' and package == 'identity':
                continue
            if processor == 'optional':
                continue
            models_needed[processor].add(package)
            dependencies = get_dependencies(processor, lang, package)
            for dependency in dependencies:
                models_needed[dependency['model']].add(dependency['package'])

        model_files = []
        for processor in PROCESSORS:
            if processor in models_needed:
                for package in sorted(models_needed[processor]):
                    filename = os.path.join(args.output_dir, lang, "models", processor, package + '.pt')
                    if os.path.exists(filename):
                        print("   Model {} package {}: file {}".format(processor, package, filename))
                        model_files.append((filename, processor, package))
                    else:
                        raise FileNotFoundError(f"Processor {processor} package {package} needed for {lang} but cannot be found at {filename}")

        zip_path = os.path.join(args.output_dir, lang, 'models', 'default.zip')
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for filename, processor, package in model_files:
                zipf.write(filename=filename, arcname=os.path.join(processor, package + '.pt'))
        # hash only after the archive is closed, so the md5 covers the complete zip
        resources[lang]['default_md5'] = get_md5(zip_path)
    print("Processed default model zips.  Writing resources.json")
    with open(resources_path, 'w') as fout:
        json.dump(resources, fout, indent=2)
def get_default_processors(resources, lang):
    """
    Build a default package for this language

    Will add each of pos, lemma, depparse, etc if those are available

    Uses the existing models scraped from the language directories into resources.json, as relevant

    Returns a dict of processor name -> package name ("multilingual" only
    gets langid).  Raises AssertionError if no tokenizer can be found, if
    the expected pos/depparse package is missing from resources, or if a
    language not in allowed_empty_languages has no pos/depparse models.
    """
    if lang == "multilingual":
        return {"langid": "ud"}

    default_package = default_treebanks[lang]
    default_processors = {}
    # tokenizer: an explicit per-language override wins; otherwise use the
    # default treebank's tokenizer, falling back to its _nocharlm name
    if lang in default_tokenizer:
        default_processors['tokenize'] = default_tokenizer[lang]
    else:
        tokenize_package = default_package
        if tokenize_package not in resources[lang]['tokenize']:
            tokenize_package = tokenize_package + "_nocharlm"
            if tokenize_package not in resources[lang]['tokenize']:
                raise AssertionError("Can't find a tokenizer package for %s!  Tried %s and %s" % (lang, default_package, tokenize_package))
        default_processors['tokenize'] = tokenize_package

    if 'mwt' in resources[lang] and default_package in resources[lang]['mwt']:
        # if this doesn't happen, we just skip MWT
        default_processors['mwt'] = default_package

    # lemma: prefer _nocharlm, then _charlm (with a warning).  If the language
    # has no lemma models at all, use 'identity' unless it is allowed to be
    # empty.  If lemma models exist but neither variant matches, no lemma
    # entry is added.
    if 'lemma' in resources[lang]:
        expected_lemma = default_package + "_nocharlm"
        if expected_lemma in resources[lang]['lemma']:
            default_processors['lemma'] = expected_lemma
        else:
            expected_lemma = default_package + "_charlm"
            if expected_lemma in resources[lang]['lemma']:
                default_processors['lemma'] = expected_lemma
                print("WARNING: nocharlm lemmatizer for %s model does not exist, but %s does" % (default_package, expected_lemma))
    elif lang not in allowed_empty_languages:
        default_processors['lemma'] = 'identity'

    # pos / depparse: the package name suffix is chosen by
    # get_default_*_package; it must then exist in resources
    if 'pos' in resources[lang]:
        default_processors['pos'] = get_default_pos_package(lang, default_package)
        if default_processors['pos'] not in resources[lang]['pos']:
            raise AssertionError("Expected POS model not in resources: %s" % default_processors['pos'])
    elif lang not in allowed_empty_languages:
        raise AssertionError("Expected to find POS models for language %s" % lang)

    if 'depparse' in resources[lang]:
        default_processors['depparse'] = get_default_depparse_package(lang, default_package)
        if default_processors['depparse'] not in resources[lang]['depparse']:
            raise AssertionError("Expected depparse model not in resources: %s" % default_processors['depparse'])
    elif lang not in allowed_empty_languages:
        raise AssertionError("Expected to find depparse models for language %s" % lang)

    # these come straight from the per-language default tables, unchecked against resources
    if lang in default_ners:
        default_processors['ner'] = default_ners[lang]

    if lang in default_sentiment:
        default_processors['sentiment'] = default_sentiment[lang]

    if lang in default_constituency:
        default_processors['constituency'] = default_constituency[lang]

    optional = get_default_optional_processors(resources, lang)
    if optional:
        default_processors['optional'] = optional

    return default_processors
def get_default_optional_processors(resources, lang):
    """
    Return the optional processors (constituency, coref) configured for lang.

    Entries come from optional_constituency / optional_coref; a processor is
    left out when lang has no entry in its table.  resources is not used.
    """
    lookup_tables = (('constituency', optional_constituency),
                     ('coref', optional_coref))
    optional_processors = {}
    for processor_name, table in lookup_tables:
        if lang in table:
            optional_processors[processor_name] = table[lang]
    return optional_processors
def update_processor_add_transformer(resources, lang, current_processors, processor, transformer):
    """
    Switch current_processors[processor] to its transformer variant, in place.

    The transformer name replaces a '_charlm' or '_nocharlm' suffix in the
    current package name.  The switch only happens if that package exists
    in resources[lang][processor]; otherwise a warning is printed and the
    entry is left alone.  Does nothing if processor is not present.
    """
    if processor not in current_processors:
        return

    transformer_suffix = "_" + transformer
    candidate = current_processors[processor].replace('_charlm', transformer_suffix)
    candidate = candidate.replace('_nocharlm', transformer_suffix)

    if candidate not in resources[lang][processor]:
        print("WARNING: wanted to use %s for %s accurate %s, but that model does not exist" % (candidate, lang, processor))
        return
    current_processors[processor] = candidate
def get_default_accurate(resources, lang):
    """
    A package that, if available, uses charlm and transformer models for each processor

    Starts from get_default_processors and then upgrades:
      - tokenize to a _charlm variant, if that model exists in resources
      - lemma from _nocharlm to _charlm, if a lemma charlm is configured and
        that model exists (otherwise prints a warning)
      - pos, depparse, constituency, sentiment (and ner, only when its package
        ends in _charlm/_nocharlm) to the language's transformer variant,
        when the language has a transformer and the model exists
    The 'optional' entry is replaced by get_optional_accurate's result when
    that result is nonempty.
    """
    default_processors = get_default_processors(resources, lang)

    tokenizer_model = default_processors['tokenize']
    # build the charlm candidate: swap a _nocharlm suffix, or append _charlm
    # when the name has no charlm marker at all
    if tokenizer_model.endswith('_nocharlm'):
        tokenizer_model = tokenizer_model.replace('_nocharlm', '_charlm')
    elif 'charlm' not in tokenizer_model:
        tokenizer_model = tokenizer_model + '_charlm'
    if tokenizer_model.endswith('_charlm') and tokenizer_model in resources[lang]['tokenize']:
        default_processors['tokenize'] = tokenizer_model
        print("TOKENIZE found a charlm version %s for %s default_accurate" % (tokenizer_model, lang))

    if 'lemma' in default_processors and default_processors['lemma'] != 'identity':
        lemma_model = default_processors['lemma']
        lemma_model = lemma_model.replace('_nocharlm', '_charlm')
        # only switch when a charlm is actually configured for this lemma package
        charlm_package = get_lemma_charlm_package(lang, lemma_model)
        if charlm_package is not None:
            if lemma_model in resources[lang]['lemma']:
                default_processors['lemma'] = lemma_model
            else:
                print("WARNING: wanted to use %s for %s default_accurate lemma, but that model does not exist" % (lemma_model, lang))

    # transformer nickname for this language, or None if it has no transformer
    transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None)
    if transformer is not None:
        for processor in ('pos', 'depparse', 'constituency', 'sentiment'):
            update_processor_add_transformer(resources, lang, default_processors, processor, transformer)
        if 'ner' in default_processors and (default_processors['ner'].endswith("_charlm") or default_processors['ner'].endswith("_nocharlm")):
            update_processor_add_transformer(resources, lang, default_processors, "ner", transformer)

    optional = get_optional_accurate(resources, lang)
    if optional:
        default_processors['optional'] = optional

    return default_processors
def get_optional_accurate(resources, lang):
    """
    Return the optional processors for the default_accurate package.

    Starts from get_default_optional_processors and, when the language has
    a transformer, switches each entry to its transformer variant where
    that model exists.  Coref is (re)set from optional_coref.
    """
    optional_processors = get_default_optional_processors(resources, lang)

    transformer_name = TRANSFORMERS.get(lang, None)
    transformer = TRANSFORMER_NICKNAMES.get(transformer_name, None)
    if transformer is not None:
        for processor in ('pos', 'depparse', 'constituency', 'sentiment'):
            update_processor_add_transformer(resources, lang, optional_processors, processor, transformer)

    if lang in optional_coref:
        optional_processors['coref'] = optional_coref[lang]
    return optional_processors
def get_default_fast(resources, lang):
    """
    Build a packages entry which only has the nocharlm models

    Will make it easy for people to use the lower tier of models.
    The result is stored by process_packages as the 'default_fast' package.

    We do this by building the same default package as normal,
    then switching every _charlm model for its _nocharlm version when
    that model exists (printing a warning when it does not).

    We also remove constituency, as it is super slow.

    Note that in the case of a language which doesn't have a charlm,
    that means we wind up building the same for default and default_fast
    """
    fast_processors = get_default_processors(resources, lang)
    # this is a slow model and we don't have non-charlm versions of it yet
    fast_processors.pop('constituency', None)

    for processor, model in fast_processors.items():
        if "_charlm" not in model:
            continue
        nocharlm = model.replace("_charlm", "_nocharlm")
        if nocharlm in resources[lang][processor]:
            # replacing a value (not adding a key) is safe while iterating
            fast_processors[processor] = nocharlm
        else:
            print("WARNING: wanted to use %s for %s default_fast processor %s, but that model does not exist" % (nocharlm, lang, processor))
    return fast_processors
def _get_treebank_processors(resources, lang, package):
    """
    Build the processors dict for the package named after treebank `package`.

    tokenize is required (plain or _nocharlm name).  mwt is added when that
    exact package exists.  pos prefers _charlm over _nocharlm.  lemma
    (prefers _nocharlm) and depparse (prefers _charlm) are only added when
    pos was found.  A depparse without a real lemmatizer gets an 'identity'
    lemma, inserted before depparse.
    """
    lang_resources = resources[lang]
    processors = {}
    # TODO: when we rebuild all the models, make all the tokenizers say _nocharlm
    if package in lang_resources['tokenize']:
        processors["tokenize"] = package
    elif package + "_nocharlm" in lang_resources['tokenize']:
        processors["tokenize"] = package + "_nocharlm"
    else:
        raise AssertionError("Should have found a tokenizer for lang %s package %s" % (lang, package))

    if "mwt" in lang_resources and package in lang_resources["mwt"]:
        processors["mwt"] = package

    if "pos" in lang_resources:
        if package + "_charlm" in lang_resources["pos"]:
            processors["pos"] = package + "_charlm"
        elif package + "_nocharlm" in lang_resources["pos"]:
            processors["pos"] = package + "_nocharlm"

    if "lemma" in lang_resources and "pos" in processors:
        lemma_package = package + "_nocharlm"
        if lemma_package in lang_resources["lemma"]:
            processors["lemma"] = lemma_package
        else:
            lemma_package = package + "_charlm"
            if lemma_package in lang_resources['lemma']:
                processors['lemma'] = lemma_package
                print("WARNING: nocharlm lemmatizer for %s model does not exist, but %s does" % (package, lemma_package))

    if "depparse" in lang_resources and "pos" in processors:
        depparse_package = None
        if package + "_charlm" in lang_resources["depparse"]:
            depparse_package = package + "_charlm"
        elif package + "_nocharlm" in lang_resources["depparse"]:
            depparse_package = package + "_nocharlm"
        # we want to set the lemma first if it's identity
        # THEN set the depparse
        if depparse_package is not None:
            if "lemma" not in processors:
                processors["lemma"] = "identity"
            processors["depparse"] = depparse_package

    return processors


def process_packages(args):
    """
    Build a package for a language's default processors and all of the treebanks specifically used for that language

    For each language in output_dir/resources.json this fills in
    resources[lang][PACKAGES] with 'default' (plus 'default_fast' and
    'default_accurate' for languages with pretrains) and one package per
    tokenizer treebank, then writes resources.json back out.
    """
    resources_path = os.path.join(args.output_dir, 'resources.json')
    with open(resources_path) as fin:
        resources = json.load(fin)

    for lang in resources:
        # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json
        if lang == 'url':
            continue
        if 'alias' in resources[lang]:
            continue
        if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()):
            continue
        if lang in allowed_empty_languages and lang not in default_treebanks:
            continue
        if lang not in default_treebanks:
            raise AssertionError(f'{lang} not in default treebanks!!!')

        if args.lang and lang not in args.lang.split(","):
            continue

        default_processors = get_default_processors(resources, lang)

        # TODO: eventually we can remove default_processors
        # For now, we want to keep this so that v1.5.1 is compatible
        # with the next iteration of resources files
        resources[lang]['default_processors'] = default_processors
        resources[lang][PACKAGES] = {}
        resources[lang][PACKAGES]['default'] = default_processors

        if lang not in no_pretrain_languages and lang != "multilingual":
            default_fast = get_default_fast(resources, lang)
            resources[lang][PACKAGES]['default_fast'] = default_fast

            default_accurate = get_default_accurate(resources, lang)
            resources[lang][PACKAGES]['default_accurate'] = default_accurate

        # Now we loop over each of the tokenizers for this language
        # ... we use this as a proxy for the available UD treebanks
        # This loop also catches things such as "craft" which are
        # included treebanks that aren't UD
        # We then create a package in the packages dict for each of those treebanks
        if 'tokenize' in resources[lang]:
            for package in resources[lang]['tokenize']:
                package, _, _ = split_package(package)
                if package in resources[lang][PACKAGES]:
                    # can happen in the case of a _nocharlm and _charlm version of the tokenizer
                    continue
                resources[lang][PACKAGES][package] = _get_treebank_processors(resources, lang, package)

    print("Processed packages.  Writing resources.json")
    with open(resources_path, 'w') as fout:
        json.dump(resources, fout, indent=2)
def process_lcode(args):
    """
    Rewrite resources.json so that it is keyed by lowercase language code.

    Each language found in lcode2lang gets a 'lang_name' entry, and alias
    entries ({'alias': lcode}) are added for the lowercased language name,
    the matching 2- or 3-letter code, and any code in extra_lang_to_lcodes
    (unless that key is already taken).  The multilingual entry is copied
    unchanged.  Existing alias entries and languages missing from
    lcode2lang (which are reported) are not carried over.
    """
    resources_path = os.path.join(args.output_dir, 'resources.json')
    with open(resources_path) as fin:
        resources = json.load(fin)
    resources_new = {}
    resources_new["multilingual"] = resources["multilingual"]
    for lang in resources:
        if lang == 'multilingual':
            continue
        if 'alias' in resources[lang]:
            continue
        if lang not in lcode2lang:
            print(lang + ' not found in lcode2lang!')
            continue
        lang_name = lcode2lang[lang]
        lcode = lang.lower()
        resources[lang]['lang_name'] = lang_name
        # copy the entry that was just annotated; indexing by lang.lower()
        # raises KeyError if the original key had uppercase letters
        resources_new[lcode] = resources[lang]
        resources_new[lang_name.lower()] = {'alias': lcode}
        if lcode in two_to_three_letters:
            resources_new[two_to_three_letters[lcode]] = {'alias': lcode}
        elif lcode in three_to_two_letters:
            resources_new[three_to_two_letters[lcode]] = {'alias': lcode}
        if lcode in extra_lang_to_lcodes:
            alternative = extra_lang_to_lcodes[lcode].lower()
            if alternative not in resources_new:
                resources_new[alternative] = {'alias': lcode}
    print("Processed lcode aliases.  Writing resources.json")
    with open(resources_path, 'w') as fout:
        json.dump(resources_new, fout, indent=2)
def process_misc(args):
    """
    Add the final hand-written entries to output_dir/resources.json:
    the 'no' -> 'nb' and 'zh' -> 'zh-hans' aliases and the model url template.
    """
    resources_path = os.path.join(args.output_dir, 'resources.json')
    with open(resources_path) as fin:
        resources = json.load(fin)
    resources['no'] = {'alias': 'nb'}
    resources['zh'] = {'alias': 'zh-hans'}
    # This is intended to be unformatted.  expand_model_url in common.py will fill in the raw string
    # with the appropriate values in order to find the needed model file on huggingface
    resources['url'] = 'https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}'
    print("Finalized misc attributes.  Writing resources.json")
    # context managers so the file handles are closed deterministically
    with open(resources_path, 'w') as fout:
        json.dump(resources, fout, indent=2)
def main():
    """
    Build output_dir/resources.json (and the per-language default.zip files)
    from the models in input_dir.

    When args.packages_only is set, the model copying (process_dirs) and
    zip building (process_default_zips) steps are skipped; the package,
    lcode alias and misc steps always run.
    """
    args = parse_args()
    print("Converting models from %s to %s" % (args.input_dir, args.output_dir))
    rebuild_models = not args.packages_only
    if rebuild_models:
        process_dirs(args)
    process_packages(args)
    if rebuild_models:
        process_default_zips(args)
    process_lcode(args)
    process_misc(args)


if __name__ == '__main__':
    main()
================================================
FILE: stanza/resources/print_charlm_depparse.py
================================================
"""
A small utility script to output which depparse models use charlm
(It should skip en_genia, en_craft, but currently doesn't)
Not frequently useful, but seems like the kind of thing that might get used a couple times
"""
from stanza.resources.common import load_resources_json
from stanza.resources.default_packages import default_charlms, depparse_charlms
def list_depparse():
    """
    Return "<lang>_<model>" names of the depparse models that use a charlm.

    Only languages with a default charlm are considered.  A model is left out
    when depparse_charlms explicitly maps it to None.
    """
    charlm_langs = list(default_charlms.keys())
    resources = load_resources_json()
    models = []
    for lang in charlm_langs:
        lang_overrides = depparse_charlms[lang] if lang in depparse_charlms else {}
        for model in resources[lang].get("depparse", {}):
            explicitly_no_charlm = model in lang_overrides and lang_overrides[model] is None
            if not explicitly_no_charlm:
                models.append("%s_%s" % (lang, model))
    return models


if __name__ == "__main__":
    models = list_depparse()
    print(" ".join(models))
================================================
FILE: stanza/server/__init__.py
================================================
from stanza.protobuf import to_text
from stanza.protobuf import Document, Sentence, Token, IndexedWord, Span
from stanza.protobuf import ParseTree, DependencyGraph, CorefChain
from stanza.protobuf import Mention, NERMention, Entity, Relation, RelationTriple, Timex
from stanza.protobuf import Quote, SpeakerInfo
from stanza.protobuf import Operator, Polarity
from stanza.protobuf import SentenceFragment, TokenLocation
from stanza.protobuf import MapStringString, MapIntString
from .client import CoreNLPClient, AnnotationException, TimeoutException, PermanentlyFailedException, StartServer
from .annotator import Annotator
================================================
FILE: stanza/server/annotator.py
================================================
"""
Defines a base class that can be used to annotate.
"""
import io
from multiprocessing import Process
from http.server import BaseHTTPRequestHandler, HTTPServer
from http import client as HTTPStatus
from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString
class Annotator(Process):
    """
    This annotator base class hosts a lightweight server that accepts
    annotation requests from CoreNLP.

    Each annotator simply defines 3 functions: requires, provides and annotate.

    This class takes care of defining appropriate endpoints to interface
    with CoreNLP.
    """
    @property
    def name(self):
        """
        Name of the annotator (used by CoreNLP)
        """
        raise NotImplementedError()

    @property
    def requires(self):
        """
        Requires has to specify all the annotations required before we
        are called.
        """
        raise NotImplementedError()

    @property
    def provides(self):
        """
        The set of annotations guaranteed to be provided when we are done.
        NOTE: that these annotations are either fully qualified Java
        class names or refer to nested classes of
        edu.stanford.nlp.ling.CoreAnnotations (as is the case below).
        """
        raise NotImplementedError()

    def annotate(self, ann):
        """
        @ann: is a protobuf annotation object.
        Actually populate @ann with tokens.
        """
        raise NotImplementedError()

    @property
    def properties(self):
        """
        Defines a Java property to define this annotator to CoreNLP.
        """
        return {
            "customAnnotatorClass.{}".format(self.name): "edu.stanford.nlp.pipeline.GenericWebServiceAnnotator",
            "generic.endpoint": "http://{}:{}".format(self.host, self.port),
            "generic.requires": ",".join(self.requires),
            "generic.provides": ",".join(self.provides),
            }

    class _Handler(BaseHTTPRequestHandler):
        """
        Serves GET /ping/ and POST /annotate/ (delimited protobuf Document).

        `annotator` is the Annotator whose annotate() handles POSTs;
        Annotator.run serves a subclass with it bound to that instance.
        """
        annotator = None

        def __init__(self, request, client_address, server):
            BaseHTTPRequestHandler.__init__(self, request, client_address, server)

        def do_GET(self):
            """
            Handle a ping request
            """
            if not self.path.endswith("/"): self.path += "/"
            if self.path == "/ping/":
                msg = "pong".encode("UTF-8")

                self.send_response(HTTPStatus.OK)
                self.send_header("Content-Type", "text/application")
                self.send_header("Content-Length", len(msg))
                self.end_headers()
                self.wfile.write(msg)
            else:
                self.send_response(HTTPStatus.BAD_REQUEST)
                self.end_headers()

        def do_POST(self):
            """
            Handle an annotate request
            """
            if not self.path.endswith("/"): self.path += "/"
            if self.path == "/annotate/":
                # Read message; a missing or malformed Content-Length used to
                # crash the handler in int(None) / int("abc")
                try:
                    length = int(self.headers.get('content-length'))
                except (TypeError, ValueError):
                    self.send_response(HTTPStatus.BAD_REQUEST)
                    self.end_headers()
                    return
                msg = self.rfile.read(length)

                # Do the annotation
                doc = Document()
                parseFromDelimitedString(doc, msg)
                self.annotator.annotate(doc)

                with io.BytesIO() as stream:
                    writeToDelimitedString(doc, stream)
                    msg = stream.getvalue()

                # write message
                self.send_response(HTTPStatus.OK)
                self.send_header("Content-Type", "application/x-protobuf")
                self.send_header("Content-Length", len(msg))
                self.end_headers()
                self.wfile.write(msg)

            else:
                self.send_response(HTTPStatus.BAD_REQUEST)
                self.end_headers()

    def __init__(self, host="", port=8432):
        """
        Launches a server endpoint to communicate with CoreNLP
        """
        Process.__init__(self)
        self.host, self.port = host, port
        # kept for backward compatibility; run() does not rely on this shared binding
        self._Handler.annotator = self

    def _make_handler(self):
        """
        Return a subclass of _Handler whose `annotator` is this instance.

        The class attribute on the shared _Handler is overwritten by every new
        Annotator, and is not carried into a child process started with the
        "spawn" method, so run() builds its own bound handler instead.
        """
        return type(self._Handler.__name__, (self._Handler,), {'annotator': self})

    def run(self):
        """
        Runs the server using Python's simple HTTPServer.
        TODO: make this multithreaded.
        """
        httpd = HTTPServer((self.host, self.port), self._make_handler())
        sa = httpd.socket.getsockname()
        serve_message = "Serving HTTP on {host} port {port} (http://{host}:{port}/) ..."
        print(serve_message.format(host=sa[0], port=sa[1]))
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            print("\nKeyboard interrupt received, exiting.")
        finally:
            # release the listening socket; serve_forever has already returned
            httpd.server_close()
================================================
FILE: stanza/server/client.py
================================================
"""
Client for accessing Stanford CoreNLP in Python
"""
import atexit
import contextlib
import enum
import io
import os
import re
import requests
import logging
import json
import shlex
import socket
import subprocess
import time
import sys
import uuid
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse
from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString, to_text
__author__ = 'arunchaganty, kelvinguu, vzhong, wmonroe4'

logger = logging.getLogger('stanza')

# pattern tmp props file should follow
# (the dot before "props" is escaped so it only matches a literal ".props")
SERVER_PROPS_TMP_FILE_PATTERN = re.compile(r'corenlp_server-(.*)\.props')
# Check if str is CoreNLP supported language
CORENLP_LANGS = [
    'ar', 'arabic',
    'chinese', 'zh',
    'english', 'en',
    'french', 'fr',
    'de', 'german',
    'hu', 'hungarian',
    'it', 'italian',
    'es', 'spanish',
]

# map shorthands to full language names
LANGUAGE_SHORTHANDS_TO_FULL = dict(
    ar="arabic",
    zh="chinese",
    en="english",
    fr="french",
    de="german",
    hu="hungarian",
    it="italian",
    es="spanish",
)


def is_corenlp_lang(props_str):
    """ Return True if props_str, compared case-insensitively, is a CoreNLP language code or name """
    lowered = props_str.lower()
    return lowered in CORENLP_LANGS
# Validate CoreNLP properties
CORENLP_OUTPUT_VALS = ["conll", "conllu", "json", "serialized", "text", "xml", "inlinexml"]


def validate_corenlp_props(properties=None, annotators=None, output_format=None):
    """ Do basic checks to validate CoreNLP properties

    Raises ValueError if output_format, or properties["outputFormat"] when
    properties is a dict (including dict subclasses such as OrderedDict),
    is not one of CORENLP_OUTPUT_VALS (case-insensitive).
    annotators is accepted for API compatibility but is not checked here.
    """
    if output_format and output_format.lower() not in CORENLP_OUTPUT_VALS:
        raise ValueError(f"{output_format} not a valid CoreNLP outputFormat value! Choose from: {CORENLP_OUTPUT_VALS}")
    # isinstance rather than type(...) == dict so dict subclasses are validated too
    if isinstance(properties, dict):
        if "outputFormat" in properties and properties["outputFormat"].lower() not in CORENLP_OUTPUT_VALS:
            raise ValueError(f"{properties['outputFormat']} not a valid CoreNLP outputFormat value! Choose from: "
                             f"{CORENLP_OUTPUT_VALS}")
class AnnotationException(Exception):
    """
    Raised when there was an error communicating with the CoreNLP server.
    """
class TimeoutException(AnnotationException):
    """
    Raised when the CoreNLP server timed out; a kind of AnnotationException.
    """
class ShouldRetryException(Exception):
    """
    Raised to signal that the service request should be retried.
    """
class PermanentlyFailedException(Exception):
    """
    Raised to signal that the service request should NOT be retried.
    """
class StartServer(enum.Enum):
    """
    Whether a CoreNLPClient should launch its own CoreNLP server.

    FORCE_START and TRY_START both launch a server; DONT_START does not.
    """
    # do not launch a server
    DONT_START = 0
    FORCE_START = 1
    TRY_START = 2
def clean_props_file(props_file):
    """
    Delete props_file if it is a temporary server props file.

    Only an existing regular file whose basename matches
    SERVER_PROPS_TMP_FILE_PATTERN is removed; a falsy props_file is ignored.
    """
    if not props_file:
        return
    # check if there is a temp server props file to remove and remove it
    if os.path.isfile(props_file) and SERVER_PROPS_TMP_FILE_PATTERN.match(os.path.basename(props_file)):
        os.remove(props_file)
class RobustService(object):
    """ Service that resuscitates itself if it is not available.

    Wraps an external server process: start_cmd launches it, stop_cmd (if
    given) is run when stopping, and endpoint + "/ping" is polled to check
    that it is up.  ensure_alive() (re)starts the process as needed and
    waits up to CHECK_ALIVE_TIMEOUT seconds for the ping to succeed.
    """
    # seconds ensure_alive waits for the service to answer /ping before giving up
    CHECK_ALIVE_TIMEOUT = 120

    def __init__(self, start_cmd, stop_cmd, endpoint, stdout=None,
                 stderr=None, be_quiet=False, host=None, port=None, ignore_binding_error=False):
        # commands are given as strings and split shell-style; a falsy command is kept as-is
        self.start_cmd = start_cmd and shlex.split(start_cmd)
        self.stop_cmd = stop_cmd and shlex.split(stop_cmd)
        self.endpoint = endpoint
        self.stdout = stdout
        self.stderr = stderr
        # subprocess.Popen of the server we launched, or None
        self.server = None
        self.is_active = False
        self.be_quiet = be_quiet
        self.host = host
        self.port = port
        # when True, a port that is already bound is treated as an existing server to connect to
        self.ignore_binding_error = ignore_binding_error
        # best-effort termination of the launched process at interpreter exit
        atexit.register(self.atexit_kill)

    def is_alive(self):
        """
        Return True if endpoint + "/ping" answers with a success status.

        Returns False if the process we launched has already exited (this
        check is skipped when ignore_binding_error is set).  Raises
        ShouldRetryException if the endpoint cannot be connected to.
        """
        try:
            if not self.ignore_binding_error and self.server is not None and self.server.poll() is not None:
                return False
            # NOTE(review): no timeout is passed, so a hung server blocks this call indefinitely
            return requests.get(self.endpoint + "/ping").ok
        except requests.exceptions.ConnectionError as e:
            raise ShouldRetryException(e)

    def start(self):
        """
        Launch start_cmd as a subprocess, if there is a start_cmd.

        When host and port are both set, first tries to bind that port to see
        whether it is free.  If binding fails, either reuse the existing
        server (ignore_binding_error) or raise PermanentlyFailedException.
        Re-raises FileNotFoundError from Popen with a hint about Java.
        """
        if self.start_cmd:
            if self.host and self.port:
                with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
                    try:
                        sock.bind((self.host, self.port))
                    except socket.error as e:
                        if self.ignore_binding_error:
                            logger.info(f"Connecting to existing CoreNLP server at {self.host}:{self.port}")
                            self.server = None
                            return
                        else:
                            raise PermanentlyFailedException("Error: unable to start the CoreNLP server on port %d "
                                                         "(possibly something is already running there)" % self.port) from e
            if self.be_quiet:
                # Issue #26: subprocess.DEVNULL isn't supported in python 2.7.
                if hasattr(subprocess, 'DEVNULL'):
                    stderr = subprocess.DEVNULL
                else:
                    stderr = open(os.devnull, 'w')
                stdout = stderr
            else:
                stdout = self.stdout
                stderr = self.stderr
            logger.info(f"Starting server with command: {' '.join(self.start_cmd)}")
            try:
                self.server = subprocess.Popen(self.start_cmd,
                                               stderr=stderr,
                                               stdout=stdout)
            except FileNotFoundError as e:
                raise FileNotFoundError("When trying to run CoreNLP, a FileNotFoundError occurred, which frequently means Java was not installed or was not in the classpath.") from e

    def atexit_kill(self):
        # make some kind of effort to stop the service (such as a
        # CoreNLP server) at the end of the program.  not waiting so
        # that the python script exiting isn't delayed
        if self.server and self.server.poll() is None:
            self.server.terminate()

    def stop(self):
        """
        Stop the launched process (terminate, then kill after a 5 second
        wait), run stop_cmd if there is one, and mark the service inactive.
        stop_cmd is run with check=True, so a failing stop_cmd raises.
        """
        if self.server:
            self.server.terminate()
            try:
                self.server.wait(5)
            except subprocess.TimeoutExpired:
                # Resorting to more aggressive measures...
                self.server.kill()
                try:
                    self.server.wait(5)
                except subprocess.TimeoutExpired:
                    # oh well
                    pass
            self.server = None
        if self.stop_cmd:
            subprocess.run(self.stop_cmd, check=True)
        self.is_active = False

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, _, __, ___):
        self.stop()

    def ensure_alive(self):
        """
        Make sure the service is running and answering pings.

        If the service was active but no longer answers, it is stopped; if
        no process is held, start() is called.  Then polls is_alive() once a
        second and raises PermanentlyFailedException if it has not succeeded
        within CHECK_ALIVE_TIMEOUT seconds.
        """
        # Check if the service is active and alive
        if self.is_active:
            try:
                if self.is_alive():
                    return
                else:
                    self.stop()
            except ShouldRetryException:
                pass

        # If not, try to start up the service.
        if self.server is None:
            self.start()

        # Wait for the service to start up.
        start_time = time.time()
        while True:
            try:
                if self.is_alive():
                    break
            except ShouldRetryException:
                pass

            if time.time() - start_time < self.CHECK_ALIVE_TIMEOUT:
                time.sleep(1)
            else:
                raise PermanentlyFailedException("Timed out waiting for service to come alive.")

        # At this point we are guaranteed that the service is alive.
        self.is_active = True
def resolve_classpath(classpath=None):
    """
    Returns the classpath to use for corenlp.

    Prefers to use the given classpath parameter, if available.  If
    not, uses the CORENLP_HOME environment variable.  Resolves $CLASSPATH
    (the exact string) in either the classpath parameter or $CORENLP_HOME.

    - classpath == '$CLASSPATH' (or classpath is None and CORENLP_HOME is the
      literal '$CLASSPATH'): the CLASSPATH environment variable is returned as is.
    - classpath is None: CORENLP_HOME (default ~/stanza_corenlp) must exist;
      returns that directory joined with "*".
    - otherwise the given classpath is returned unchanged.

    Raises FileNotFoundError if the CoreNLP directory does not exist, or if
    $CLASSPATH was requested but the CLASSPATH variable is not set.
    """
    if classpath == '$CLASSPATH' or (classpath is None and os.getenv("CORENLP_HOME", None) == '$CLASSPATH'):
        classpath = os.getenv("CLASSPATH")
        if classpath is None:
            # previously None was returned here and ended up as a bogus "-cp 'None'" java argument
            raise FileNotFoundError("$CLASSPATH was requested as the CoreNLP classpath, "
                                    "but the CLASSPATH environment variable is not set")
    elif classpath is None:
        classpath = os.getenv("CORENLP_HOME", os.path.join(str(Path.home()), 'stanza_corenlp'))

        if not os.path.exists(classpath):
            raise FileNotFoundError("Please install CoreNLP by running `stanza.install_corenlp()`. If you have installed it, please define "
                                    "$CORENLP_HOME to be location of your CoreNLP distribution or pass in a classpath parameter. "
                                    "$CORENLP_HOME={}".format(os.getenv("CORENLP_HOME")))
        classpath = os.path.join(classpath, "*")
    return classpath
class CoreNLPClient(RobustService):
    """ A client to the Stanford CoreNLP server. """

    DEFAULT_ENDPOINT = "http://localhost:9000"
    DEFAULT_TIMEOUT = 60000
    DEFAULT_THREADS = 5
    DEFAULT_OUTPUT_FORMAT = "serialized"
    DEFAULT_MEMORY = "5G"
    DEFAULT_MAX_CHAR_LENGTH = 100000

    def __init__(self, start_server=StartServer.FORCE_START,
                 endpoint=DEFAULT_ENDPOINT,
                 timeout=DEFAULT_TIMEOUT,
                 threads=DEFAULT_THREADS,
                 annotators=None,
                 pretokenized=False,
                 output_format=None,
                 properties=None,
                 stdout=None,
                 stderr=None,
                 memory=DEFAULT_MEMORY,
                 be_quiet=False,
                 max_char_length=DEFAULT_MAX_CHAR_LENGTH,
                 preload=True,
                 classpath=None,
                 **kwargs):
        """
        Create a client and, unless start_server is StartServer.DONT_START, build the
        command line used to launch a local CoreNLP server.

        :param start_server: a StartServer value.  Booleans are still accepted but
            deprecated: True means FORCE_START, False means DONT_START.
        :param endpoint: server URL.  When starting a server the host must be
            localhost and the URL must contain an explicit port.
        :param timeout: passed to the server as -timeout (milliseconds); HTTP
            requests made by this client wait twice this long
        :param threads: passed to the server as -threads
        :param annotators: default annotators, a comma separated str or a list
        :param pretokenized: if true, the server is started with -preTokenized
        :param output_format: default output format; falls back to
            properties['outputFormat'], then DEFAULT_OUTPUT_FORMAT
        :param properties: a CoreNLP language name, a properties file path, or a dict
        :param memory: Java heap size, passed as -Xmx
        :param be_quiet: passed to the server as -quiet
        :param max_char_length: passed to the server as -maxCharLength
        :param preload: True for a bare -preload, or a list / comma separated str of
            annotators to preload
        :param classpath: CoreNLP classpath, see resolve_classpath
        :param kwargs: extra server flags: ssl, strict, status_port, uriContext, key,
            username, password, blockList, server_id
        """
        # Normalize the deprecated boolean form *before* storing it.  The request
        # methods compare self.start_server against StartServer.DONT_START, so a
        # stored False would otherwise make every request call ensure_alive().
        if isinstance(start_server, bool):
            warning_msg = f"Setting 'start_server' to a boolean value when constructing {self.__class__.__name__} is deprecated and will stop" + \
                " to function in a future version of stanza. Please consider switching to using a value from stanza.server.StartServer."
            logger.warning(warning_msg)
            start_server = StartServer.FORCE_START if start_server is True else StartServer.DONT_START

        # whether or not server should be started by client
        self.start_server = start_server
        self.server_props_path = None
        self.server_start_time = None
        self.server_host = None
        self.server_port = None
        self.server_classpath = None

        # validate properties
        validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)

        # set up client defaults
        self.properties = properties
        self.annotators = annotators
        self.pretokenized = pretokenized
        self.output_format = output_format
        self._setup_client_defaults()

        # start the server
        if start_server is StartServer.FORCE_START or start_server is StartServer.TRY_START:
            # record info for server start
            self.server_start_time = datetime.now()
            # set up default properties for server
            self._setup_server_defaults()
            host, port = urlparse(endpoint).netloc.split(":")
            port = int(port)
            assert host == "localhost", "If starting a server, endpoint must be localhost"
            classpath = resolve_classpath(classpath)
            start_cmd = f"java -Xmx{memory} -cp '{classpath}' edu.stanford.nlp.pipeline.StanfordCoreNLPServer " \
                        f"-port {port} -timeout {timeout} -threads {threads} -maxCharLength {max_char_length} " \
                        f"-quiet {be_quiet} "
            self.server_classpath = classpath
            self.server_host = host
            self.server_port = port

            # set up server defaults
            if self.server_props_path is not None:
                start_cmd += f" -serverProperties {self.server_props_path}"

            # possibly set pretokenized
            if self.pretokenized:
                start_cmd += " -preTokenized"

            # set annotators for server default
            # (_setup_client_defaults has already joined a list into a str)
            if self.annotators is not None:
                start_cmd += f" -annotators {self.annotators}"

            # specify what to preload, if anything
            if preload:
                if isinstance(preload, bool):
                    # -preload flag means to preload all default annotators
                    start_cmd += " -preload"
                elif isinstance(preload, list):
                    # turn list into comma separated list string, only preload these annotators
                    start_cmd += f" -preload {','.join(preload)}"
                elif isinstance(preload, str):
                    # comma separated list of annotators
                    start_cmd += f" -preload {preload}"

            # set outputFormat for server default
            # if no output format requested by user, set to serialized
            start_cmd += f" -outputFormat {self.output_format}"

            # additional options for server:
            # - server_id
            # - ssl
            # - status_port
            # - uriContext
            # - strict
            # - key
            # - username
            # - password
            # - blockList
            for kw in ['ssl', 'strict']:
                if kwargs.get(kw) is not None:
                    start_cmd += f" -{kw}"
            for kw in ['status_port', 'uriContext', 'key', 'username', 'password', 'blockList', 'server_id']:
                if kwargs.get(kw) is not None:
                    start_cmd += f" -{kw} {kwargs.get(kw)}"
            stop_cmd = None
        else:
            start_cmd = stop_cmd = None
            host = port = None

        super(CoreNLPClient, self).__init__(start_cmd, stop_cmd, endpoint,
                                            stdout, stderr, be_quiet, host=host, port=port, ignore_binding_error=(start_server == StartServer.TRY_START))

        self.timeout = timeout

    def _setup_client_defaults(self):
        """
        Do some processing of annotators and output_format specified for the client.
        If interacting with an externally started server, these will be defaults for annotate() calls.

        Annotators are normalized to a comma separated str.  A missing output
        format is taken from properties['outputFormat'] when properties is a
        dict, otherwise DEFAULT_OUTPUT_FORMAT.
        :return: None
        """
        # normalize annotators to str
        if self.annotators is not None:
            self.annotators = self.annotators if isinstance(self.annotators, str) else ",".join(self.annotators)

        # handle case where no output format is specified
        if self.output_format is None:
            if isinstance(self.properties, dict) and 'outputFormat' in self.properties:
                self.output_format = self.properties['outputFormat']
            else:
                self.output_format = CoreNLPClient.DEFAULT_OUTPUT_FORMAT

    def _setup_server_defaults(self):
        """
        Set up the default properties for the server.

        The properties argument can take on one of 3 value types

        1. File path on system or in CLASSPATH (e.g. /path/to/server.props or StanfordCoreNLP-french.properties
        2. Name of a Stanford CoreNLP supported language (e.g. french or fr)
        3. Python dictionary (properties written to tmp file for Java server, erased at end)

        In addition, an annotators list and output_format can be specified directly with arguments. These
        will overwrite any settings in the specified properties.

        If no properties are specified, the standard Stanford CoreNLP English server will be launched. The outputFormat
        will be set to 'serialized' and use the ProtobufAnnotationSerializer.

        Sets self.server_props_path to the value passed to the server as -serverProperties.
        """
        # ensure properties is str or dict
        if self.properties is None or (not isinstance(self.properties, str) and not isinstance(self.properties, dict)):
            if self.properties is not None:
                logger.warning('properties passed invalid value (not a str or dict), setting properties = {}')
            self.properties = {}

        # check if properties is a string, pass on to server which can handle
        if isinstance(self.properties, str):
            # try to translate to Stanford CoreNLP language name, or assume properties is a path
            if is_corenlp_lang(self.properties):
                shorthand = self.properties.lower()
                if shorthand in LANGUAGE_SHORTHANDS_TO_FULL:
                    # index with the lowercased key that was just checked;
                    # indexing with the raw value raised KeyError for e.g. "FR"
                    self.properties = LANGUAGE_SHORTHANDS_TO_FULL[shorthand]
                logger.info(
                    f"Using CoreNLP default properties for: {self.properties}. Make sure to have "
                    f"{self.properties} models jar (available for download here: "
                    f"https://stanfordnlp.github.io/CoreNLP/) in CLASSPATH")
            else:
                if not os.path.isfile(self.properties):
                    logger.warning(f"{self.properties} does not correspond to a file path. Make sure this file is in "
                                   f"your CLASSPATH.")
            self.server_props_path = self.properties
        elif isinstance(self.properties, dict):
            # make a copy
            server_start_properties = dict(self.properties)
            if self.annotators is not None:
                server_start_properties['annotators'] = self.annotators
            if self.output_format is not None and isinstance(self.output_format, str):
                server_start_properties['outputFormat'] = self.output_format
            # write desired server start properties to tmp file
            # set up to erase on exit
            tmp_path = write_corenlp_props(server_start_properties)
            logger.info(f"Writing properties to tmp file: {tmp_path}")
            atexit.register(clean_props_file, tmp_path)
            self.server_props_path = tmp_path

    def _request(self, buf, properties, reset_default=False, **kwargs):
        """
        Send a request to the CoreNLP server.

        :param (str | bytes) buf: data to be sent with the request
        :param (dict) properties: properties that the server expects
        :param (bool) reset_default: sent as the resetDefault parameter
        :param kwargs: extra arguments for requests.post; a username + password
            pair is turned into HTTP basic auth
        :return: request result
        :raises TimeoutException: if the HTTP request times out
        :raises AnnotationException: for any other request failure
        """
        if self.start_server is not StartServer.DONT_START:
            self.ensure_alive()

        try:
            input_format = properties.get("inputFormat", "text")
            if input_format == "text":
                ctype = "text/plain; charset=utf-8"
            elif input_format == "serialized":
                ctype = "application/x-protobuf"
            else:
                raise ValueError("Unrecognized inputFormat " + input_format)
            # handle auth
            if 'username' in kwargs and 'password' in kwargs:
                kwargs['auth'] = requests.auth.HTTPBasicAuth(kwargs['username'], kwargs['password'])
                kwargs.pop('username')
                kwargs.pop('password')
            r = requests.post(self.endpoint,
                              params={'properties': str(properties), 'resetDefault': str(reset_default).lower()},
                              data=buf, headers={'content-type': ctype},
                              timeout=(self.timeout*2)/1000, **kwargs)
            r.raise_for_status()
            return r
        except requests.exceptions.Timeout as e:
            raise TimeoutException("Timeout requesting to CoreNLPServer. Maybe server is unavailable or your document is too long") from e
        except requests.exceptions.RequestException as e:
            if e.response is not None and e.response.text is not None:
                raise AnnotationException(e.response.text) from e
            elif e.args:
                raise AnnotationException(e.args[0]) from e
            raise AnnotationException() from e

    def annotate(self, text, annotators=None, output_format=None, properties=None, reset_default=None, **kwargs):
        """
        Send a request to the CoreNLP server.

        :param (str | unicode) text: raw text for the CoreNLPServer to parse
        :param (list | string) annotators: list of annotators to use
        :param (str) output_format: output type from server: serialized, json, text, conll, conllu, or xml
        :param (dict | str) properties: additional request properties (written on top of defaults),
            or a CoreNLP language name, which becomes pipelineLanguage and defaults reset_default to True
        :param (bool) reset_default: don't use server defaults

        Precedence for settings:

        1. annotators and output_format args
        2. Values from properties dict
        3. Client defaults self.annotators and self.output_format (set during client construction)
        4. Server defaults

        Additional request parameters (apart from CoreNLP pipeline properties) such as 'username' and 'password'
        can be specified with the kwargs.

        :return: request result: a dict for json, a Document for serialized, a str for
            text/conllu/conll/xml, otherwise the raw response
        """
        # validate request properties
        validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)
        # set request properties
        request_properties = {}

        # start with client defaults
        if self.annotators is not None:
            request_properties['annotators'] = self.annotators
        if self.output_format is not None:
            request_properties['outputFormat'] = self.output_format

        # add values from properties arg
        # handle str case
        if isinstance(properties, str):
            if is_corenlp_lang(properties):
                properties = {'pipelineLanguage': properties.lower()}
                if reset_default is None:
                    reset_default = True
            else:
                raise ValueError(f"Unrecognized properties keyword {properties}")

        # isinstance rather than type() == dict, so dict subclasses such as
        # OrderedDict are merged instead of being silently ignored
        if isinstance(properties, dict):
            request_properties.update(properties)

        # if annotators list is specified, override with that
        # also can use the annotators field the object was created with
        if isinstance(annotators, str):
            request_properties['annotators'] = annotators
        elif isinstance(annotators, (list, tuple)):
            request_properties['annotators'] = ",".join(annotators)

        # if output format is specified, override with that
        if isinstance(output_format, str):
            request_properties['outputFormat'] = output_format

        # make the request
        # if not explicitly set or the case of pipelineLanguage, reset_default should be None
        if reset_default is None:
            reset_default = False
        r = self._request(text.encode('utf-8'), request_properties, reset_default, **kwargs)
        if request_properties["outputFormat"] == "json":
            return r.json()
        elif request_properties["outputFormat"] == "serialized":
            doc = Document()
            parseFromDelimitedString(doc, r.content)
            return doc
        elif request_properties["outputFormat"] in ["text", "conllu", "conll", "xml"]:
            return r.text
        else:
            return r

    def update(self, doc, annotators=None, properties=None):
        """
        Send an existing Document back to the server to run more annotators on it.

        The request always uses serialized protobuf input and output, whatever
        the caller passes in ``properties``: the body is a serialized Document
        and the reply is parsed as one.  The caller's dict is not modified.

        :param doc: Document to send
        :param annotators: annotators to run, a str or a list
        :param properties: extra request properties
        :return: a new Document parsed from the server response
        """
        properties = dict(properties) if properties is not None else {}
        properties.update({
            'inputFormat': 'serialized',
            'outputFormat': 'serialized',
            'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
        })
        if annotators:
            properties['annotators'] = annotators if isinstance(annotators, str) else ",".join(annotators)
        with io.BytesIO() as stream:
            writeToDelimitedString(doc, stream)
            msg = stream.getvalue()

        r = self._request(msg, properties)
        doc = Document()
        parseFromDelimitedString(doc, r.content)
        return doc

    def tokensregex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):
        """
        Run a TokensRegex pattern on the server's /tokensregex endpoint.

        :param to_words: if true, flatten the result with regex_matches_to_indexed_words
        :return: the decoded json response (or the flattened list)
        """
        matches = self.__regex('/tokensregex', text, pattern, filter, annotators, properties)
        if to_words:
            matches = regex_matches_to_indexed_words(matches)
        return matches

    def semgrex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):
        """
        Run a Semgrex pattern on the server's /semgrex endpoint.

        :param to_words: if true, flatten the result with regex_matches_to_indexed_words
        :return: the decoded json response (or the flattened list)
        """
        matches = self.__regex('/semgrex', text, pattern, filter, annotators, properties)
        if to_words:
            matches = regex_matches_to_indexed_words(matches)
        return matches

    def fill_tree_proto(self, tree, proto_tree):
        """
        Recursively copy ``tree`` (anything with .label and .children) into a
        parse tree proto: the label goes to .value, each child to a new .child entry.
        """
        if tree.label:
            proto_tree.value = tree.label
        for child in tree.children:
            proto_child = proto_tree.child.add()
            self.fill_tree_proto(child, proto_child)

    def tregex(self, text=None, pattern=None, filter=False, annotators=None, properties=None, trees=None):
        """
        Run a Tregex pattern on the server's /tregex endpoint.

        :param text: raw text to parse and search (ignored when ``trees`` is given)
        :param pattern: the Tregex pattern; required
        :param annotators: annotators to use.  If None, the client's annotators are
            used with parse appended when missing, or "tokenize,ssplit,pos,parse"
            if the client has none.  Explicitly passed annotators are used as-is.
        :param properties: extra request properties (the caller's dict is not modified)
        :param trees: optional already-parsed trees; they are sent as a serialized
            Document instead of text
        :return: the decoded json response
        """
        # parse is not included by default in some of the pipelines,
        # so we may need to manually override the annotators
        # to include parse in order for tregex to do anything
        if annotators is None:
            if self.annotators is not None:
                assert isinstance(self.annotators, str)
                pieces = [piece.strip() for piece in self.annotators.split(",")]
                annotators = self.annotators if "parse" in pieces else self.annotators + ",parse"
            else:
                annotators = "tokenize,ssplit,pos,parse"

        if pattern is None:
            raise ValueError("Cannot have None as a pattern for tregex")

        # TODO: we could also allow for passing in a complete document,
        # along with the original text, so that the spans returns are more accurate
        if trees is not None:
            properties = dict(properties) if properties is not None else {}
            properties['inputFormat'] = 'serialized'
            if 'serializer' not in properties:
                properties['serializer'] = 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
            doc = Document()
            full_text = []
            for tree_idx, tree in enumerate(trees):
                sentence = doc.sentence.add()
                sentence.sentenceIndex = tree_idx
                sentence.tokenOffsetBegin = len(full_text)
                leaves = tree.leaf_labels()
                full_text.extend(leaves)
                sentence.tokenOffsetEnd = len(full_text)
                self.fill_tree_proto(tree, sentence.parseTree)
                for word in leaves:
                    token = sentence.token.add()
                    # the other side uses both value and word, weirdly enough
                    token.value = word
                    token.word = word
                    # without the actual tokenization, at least we can
                    # stop the words from running together
                    token.after = " "
            doc.text = " ".join(full_text)
            with io.BytesIO() as stream:
                writeToDelimitedString(doc, stream)
                text = stream.getvalue()

        return self.__regex('/tregex', text, pattern, filter, annotators, properties)

    def __regex(self, path, text, pattern, filter, annotators=None, properties=None):
        """
        Send a regex-related request to the CoreNLP server.

        :param (str | unicode) path: the path for the regex endpoint
        :param text: raw text (str) or serialized Document (bytes) for the CoreNLPServer to apply the regex
        :param (str | unicode) pattern: regex pattern
        :param (bool) filter: option to filter sentences that contain matches, if false returns matches
        :param annotators: annotators to request, a str or a list
        :param properties: extra request properties; the caller's dict is not modified.
            inputFormat defaults to 'text' but an explicit value (e.g. 'serialized'
            from tregex) is kept.
        :return: request result
        :raises TimeoutException: if the server reports a timeout
        :raises AnnotationException: for other HTTP errors or a non-json reply
        """
        if self.start_server is not StartServer.DONT_START:
            self.ensure_alive()

        properties = dict(properties) if properties is not None else {}
        properties.setdefault('inputFormat', 'text')
        properties.setdefault('serializer', 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer')

        if annotators:
            properties['annotators'] = ",".join(annotators) if isinstance(annotators, list) else annotators

        # force output for regex requests to be json
        properties['outputFormat'] = 'json'
        # if the server is trying to send back character offsets, it
        # should send back codepoints counts as well in case the text
        # has extra wide characters
        properties['tokenize.codepoint'] = 'true'

        try:
            # Error occurs unless put properties in params
            input_format = properties.get("inputFormat", "text")
            if input_format == "text":
                ctype = "text/plain; charset=utf-8"
            elif input_format == "serialized":
                ctype = "application/x-protobuf"
            else:
                raise ValueError("Unrecognized inputFormat " + input_format)
            # change request method from `get` to `post` as required by CoreNLP
            r = requests.post(
                self.endpoint + path, params={
                    'pattern': pattern,
                    'filter': filter,
                    'properties': str(properties)
                },
                data=text.encode('utf-8') if isinstance(text, str) else text,
                headers={'content-type': ctype},
                timeout=(self.timeout*2)/1000,
            )
            r.raise_for_status()
            if r.encoding is None:
                r.encoding = "utf-8"
            return json.loads(r.text)
        except requests.HTTPError as e:
            if r.text.startswith("Timeout"):
                raise TimeoutException(r.text) from e
            else:
                raise AnnotationException(r.text) from e
        except json.JSONDecodeError as e:
            raise AnnotationException(r.text) from e

    def scenegraph(self, text, properties=None):
        """
        Send a request to the server which processes the text using SceneGraph

        This will require a new CoreNLP release, 4.5.5 or later

        :param text: the text to process (str or bytes)
        :param properties: extra request properties; the caller's dict is not modified
        :return: the decoded json response
        """
        # since we're using requests ourself,
        # check if the server has started or not
        if self.start_server is not StartServer.DONT_START:
            self.ensure_alive()

        properties = dict(properties) if properties is not None else {}
        # the only thing the scenegraph knows how to use is text
        properties['inputFormat'] = 'text'
        ctype = "text/plain; charset=utf-8"
        # the json output format is much more useful
        properties['outputFormat'] = 'json'
        try:
            r = requests.post(
                self.endpoint + "/scenegraph",
                params={
                    'properties': str(properties)
                },
                data=text.encode('utf-8') if isinstance(text, str) else text,
                headers={'content-type': ctype},
                timeout=(self.timeout*2)/1000,
            )
            r.raise_for_status()
            if r.encoding is None:
                r.encoding = "utf-8"
            return json.loads(r.text)
        except requests.HTTPError as e:
            if r.text.startswith("Timeout"):
                raise TimeoutException(r.text) from e
            else:
                raise AnnotationException(r.text) from e
        except json.JSONDecodeError as e:
            raise AnnotationException(r.text) from e
def read_corenlp_props(props_path):
    """
    Read a Stanford CoreNLP properties file into a dict

    Every non-blank line that is not a comment is split at its first '='.
    Both key and value are stripped of surrounding whitespace, so a file written
    by write_corenlp_props (which emits ``key = value``) reads back to the
    original strings.  Lines whose first non-blank character is '#' or '!' are
    comments.  A line without '=' maps its text to an empty value.

    :param props_path: path of the properties file
    :return: dict mapping property name to value string
    """
    props_dict = {}
    with open(props_path) as props_file:
        for entry_line in props_file:
            stripped = entry_line.strip()
            if not stripped or stripped.startswith(('#', '!')):
                continue
            key, _, value = entry_line.partition('=')
            props_dict[key.strip()] = value.strip()
    return props_dict
def write_corenlp_props(props_dict, file_path=None):
    """
    Write a Stanford CoreNLP properties dict to a file

    Each entry becomes ``key = value`` followed by a blank line.  List values
    are joined with commas; everything else is formatted as-is.  Without a
    file_path, a name of the form ``corenlp_server-<16 hex chars>.props`` is
    generated (relative to the current directory) and sanity checked against
    SERVER_PROPS_TMP_FILE_PATTERN.

    :return: the path that was written
    """
    if file_path is None:
        file_path = f"corenlp_server-{uuid.uuid4().hex[:16]}.props"
        # confirm tmp file path matches pattern
        assert SERVER_PROPS_TMP_FILE_PATTERN.match(file_path)

    def _render(value):
        return ",".join(value) if isinstance(value, list) else value

    with open(file_path, 'w') as props_file:
        props_file.writelines(f'{key} = {_render(value)}\n\n' for key, value in props_dict.items())
    return file_path
def regex_matches_to_indexed_words(matches):
    """
    Transforms tokensregex and semgrex matches to indexed words.

    Each sentence entry in ``matches['sentences']`` is a dict of matches plus a
    'length' key.  Every match (except 'length') is copied and tagged with the
    index of the sentence it came from.

    :param matches: unprocessed regex matches
    :return: flat array of indexed words
    """
    words = []
    for sentence_index, sentence_matches in enumerate(matches['sentences']):
        for key, match in sentence_matches.items():
            if key == 'length':
                continue
            words.append(dict(match, sentence=sentence_index))
    return words
# Public names of this module, i.e. what ``from stanza.server.client import *`` exports.
__all__ = ["CoreNLPClient", "AnnotationException", "TimeoutException", "to_text"]
================================================
FILE: stanza/server/dependency_converter.py
================================================
"""
A converter from constituency trees to dependency trees using CoreNLP's UniversalEnglish converter.
ONLY works on English.
"""
import stanza
from stanza.protobuf import DependencyConverterRequest, DependencyConverterResponse
from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext
# Java main class that handles DependencyConverterRequest protos; used both for
# one-shot requests (send_converter_request) and the long-lived DependencyConverter.
CONVERTER_JAVA = "edu.stanford.nlp.trees.ProcessDependencyConverterRequest"
def send_converter_request(request, classpath=None):
    """
    Run one DependencyConverterRequest through a fresh java process and
    return the parsed DependencyConverterResponse.
    """
    return send_request(request,
                        response_type=DependencyConverterResponse,
                        java_main=CONVERTER_JAVA,
                        classpath=classpath)
def build_request(doc):
    """
    Request format is simple: one tree per sentence in the document

    Each sentence's constituency tree is flattened with build_tree (no score)
    and appended, in sentence order, to the request's trees field.
    """
    request = DependencyConverterRequest()
    request.trees.extend(build_tree(sentence.constituency, None) for sentence in doc.sentences)
    return request
def process_doc(doc, classpath=None):
    """
    Convert the constituency trees in the document,
    then attach the resulting dependencies to the sentences

    This launches a separate java process for the single request (via
    send_request); DependencyConverter keeps one process open for many documents.

    :param doc: document whose sentences each have a .constituency tree
    :param classpath: CoreNLP classpath, resolved with resolve_classpath
    :return: doc, modified in place.  It is returned for parity with
        DependencyConverter.process; previously this returned None.
    """
    request = build_request(doc)
    response = send_converter_request(request, classpath=classpath)
    attach_dependencies(doc, response)
    return doc
def attach_dependencies(doc, response):
    """
    Copy the converted dependency graphs in ``response`` onto the words of ``doc``.

    For each sentence, every edge sets head = edge.source and deprel = edge.dep
    on word edge.target (1-based).  The single word with no incoming edge
    becomes the root (head 0, deprel "root").  Then sentence.build_dependencies()
    is called.

    :param doc: document whose sentences line up one-to-one with response.conversions
    :param response: the DependencyConverterResponse
    :raises ValueError: if the counts disagree, an edge refers to a word outside
        the sentence, a word has two parents, or there is not exactly one root
    """
    if len(doc.sentences) != len(response.conversions):
        raise ValueError("Sent %d sentences but got back %d conversions" % (len(doc.sentences), len(response.conversions)))
    for sent_idx, (sentence, conversion) in enumerate(zip(doc.sentences, response.conversions)):
        graph = conversion.graph
        num_words = len(sentence.words)

        # The deterministic conversion should have an equal number of words and one fewer edge
        # ... the root is represented by a word with no parent
        if num_words != len(graph.node):
            raise ValueError("Sentence %d of the conversion should have %d words but got back %d nodes in the graph" % (sent_idx, num_words, len(graph.node)))
        if num_words != len(graph.edge) + 1:
            raise ValueError("Sentence %d of the conversion should have %d edges (one per word, except the root) but got back %d edges in the graph" % (sent_idx, num_words - 1, len(graph.edge)))

        expected_nodes = set(range(1, num_words + 1))
        targets = set()
        for edge in graph.edge:
            # Validate both endpoints before indexing: a target of 0 would
            # otherwise silently overwrite words[-1], the last word
            if edge.target not in expected_nodes:
                raise ValueError("Edge target %d is out of range for sentence %d with %d words" % (edge.target, sent_idx, num_words))
            if edge.source not in expected_nodes:
                raise ValueError("Edge source %d is out of range for sentence %d with %d words" % (edge.source, sent_idx, num_words))
            if edge.target in targets:
                raise ValueError("Found two parents of %d in sentence %d" % (edge.target, sent_idx))
            targets.add(edge.target)
            # -1 since the words are 0 indexed in the sentence,
            # but we count dependencies from 1
            sentence.words[edge.target-1].head = edge.source
            sentence.words[edge.target-1].deprel = edge.dep

        roots = expected_nodes - targets
        if len(roots) != 1:
            raise ValueError("Sentence %d of the conversion should have exactly one root but got %d" % (sent_idx, len(roots)))
        root = roots.pop()
        sentence.words[root-1].head = 0
        sentence.words[root-1].deprel = "root"
        sentence.build_dependencies()
class DependencyConverter(JavaProtobufContext):
    """
    Context window for the dependency converter

    This is a context window which keeps a process open. Should allow
    for multiple requests without launching new java processes each time.

    Use as ``with DependencyConverter(classpath) as converter: converter.process(doc)``.
    """
    def __init__(self, classpath=None):
        super().__init__(classpath, DependencyConverterResponse, CONVERTER_JAVA)

    def process(self, doc):
        """
        Converts a constituency tree to dependency trees for each of the sentences in the document

        The document is modified in place and also returned.
        """
        response = self.process_request(build_request(doc))
        attach_dependencies(doc, response)
        return doc
def main():
    """
    Demo: parse two English sentences with a stanza pipeline and convert their
    constituency trees to dependencies.  The first uses a one-shot java call
    (process_doc); the second goes through a DependencyConverter context.
    The classpath is taken from the $CLASSPATH environment variable.
    """
    pipeline = stanza.Pipeline('en',
                               processors='tokenize,pos,constituency')

    first_doc = pipeline('I like blue antennae.')
    print("{:C}".format(first_doc))
    process_doc(first_doc, classpath="$CLASSPATH")
    print("{:C}".format(first_doc))

    second_doc = pipeline('And I cannot lie.')
    print("{:C}".format(second_doc))
    with DependencyConverter(classpath="$CLASSPATH") as converter:
        converter.process(second_doc)
    print("{:C}".format(second_doc))
# Run the demo only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
================================================
FILE: stanza/server/java_protobuf_requests.py
================================================
from collections import deque
import subprocess
from stanza.models.common.utils import misc_to_space_after
from stanza.models.constituency.parse_tree import Tree
from stanza.protobuf import DependencyGraph, FlattenedParseTree
from stanza.server.client import resolve_classpath
def send_request(request, response_type, java_main, classpath=None):
    """
    Use subprocess to run a Java protobuf processor on the given request

    The serialized request is written to the java process's stdin.  Its whole
    stdout is parsed into a fresh ``response_type()``.  A non-zero exit status
    raises subprocess.CalledProcessError.

    Returns the protobuf response
    """
    resolved = resolve_classpath(classpath)
    if resolved is None:
        raise ValueError("Classpath is None, Perhaps you need to set the $CLASSPATH or $CORENLP_HOME environment variable to point to a CoreNLP install.")

    command = ["java", "-cp", resolved, java_main]
    completed = subprocess.run(command,
                               input=request.SerializeToString(),
                               stdout=subprocess.PIPE,
                               check=True)

    response = response_type()
    response.ParseFromString(completed.stdout)
    return response
def add_tree_nodes(proto_tree, tree, score):
    """
    Append ``tree`` to a FlattenedParseTree in pre-order.

    Each subtree is written as: an open node (carrying ``score`` if not None),
    a node with the subtree's label, one node per leaf child or a recursive
    expansion for each non-leaf child, then a close node.
    """
    opener = proto_tree.nodes.add()
    opener.openNode = True
    if score is not None:
        opener.score = score

    proto_tree.nodes.add().value = tree.label

    for child in tree.children:
        if not child.is_leaf():
            # branches recurse, without a score of their own
            add_tree_nodes(proto_tree, child, None)
            continue
        # leaves get just one node
        proto_tree.nodes.add().value = child.label

    proto_tree.nodes.add().closeNode = True
def build_tree(tree, score):
    """
    Builds a FlattenedParseTree from CoreNLP.proto

    Populates the value field from tree.label and iterates through the
    children via tree.children. Should work on any tree structure
    which follows that layout

    The score will be added to the top node (if it is not None)

    Operates by recursively calling add_tree_nodes
    """
    flattened = FlattenedParseTree()
    add_tree_nodes(flattened, tree, score)
    return flattened
def from_tree(proto_tree):
    """
    Convert a FlattenedParseTree back into a Tree

    returns Tree, score
      (score might be None if it is missing)

    The first score found on any node is returned.  Open nodes are pushed on
    a stack, value nodes become leaf Trees, and each close node folds
    everything back to the matching open into a subtree whose label is the
    first item after the open.

    :raises ValueError: for any malformed input: an unlabeled node, a close
        with no matching open, an empty subtree, unclosed nodes, or no nodes at all
    """
    score = None
    stack = deque()
    for node in proto_tree.nodes:
        if node.HasField("score") and score is None:
            score = node.score

        if node.openNode:
            if len(stack) > 0 and isinstance(stack[-1], FlattenedParseTree.Node) and stack[-1].openNode:
                raise ValueError("Got a proto with no label on a node: {}".format(proto_tree))
            stack.append(node)
            continue
        if not node.closeNode:
            child = Tree(label=node.value)
            # TODO: do something with the score
            stack.append(child)
            continue

        # must be a close operation...
        if len(stack) <= 1:
            raise ValueError("Got a proto with too many close operations: {}".format(proto_tree))
        # on a close operation, pop until we hit the open
        # then turn everything in that span into a new node
        children = []
        next_node = stack.pop()
        while not isinstance(next_node, FlattenedParseTree.Node):
            children.append(next_node)
            if not stack:
                # ran out of stack without finding the matching open;
                # previously this surfaced as an IndexError from deque.pop
                raise ValueError("Got a proto with too many close operations: {}".format(proto_tree))
            next_node = stack.pop()
        if len(children) == 0:
            raise ValueError("Got a proto with an open immediately followed by a close: {}".format(proto_tree))
        children.reverse()
        label = children[0]
        children = children[1:]
        subtree = Tree(label=label.label, children=children)
        stack.append(subtree)

    if not stack:
        raise ValueError("Got a proto with no nodes: {}".format(proto_tree))
    if len(stack) > 1:
        raise ValueError("Got a proto which does not close all of the nodes: {}".format(proto_tree))
    tree = stack.pop()
    if not isinstance(tree, Tree):
        raise ValueError("Got a proto which was just one Open operation: {}".format(proto_tree))
    return tree, score
def add_token(token_list, word, token):
    """
    Add a token to a proto request.

    CoreNLP tokens have components of both word and token from stanza.

    We pass along "after" but not "before"

    :param token_list: repeated token field to append to (a new entry is made with token_list.add())
    :param word: the stanza word supplying text, lemma, tags, feats and misc
    :param token: the stanza token containing ``word``, or None for an "extra"
        (empty) word, in which case word.id is indexed as (index, empty index)
    :raises AssertionError: if token is None but word.id is a plain int
    """
    if token is None and isinstance(word.id, int):
        raise AssertionError("Only expected word w/o token for 'extra' words")

    query_token = token_list.add()
    query_token.word = word.text
    query_token.value = word.text
    if word.lemma is not None:
        query_token.lemma = word.lemma
    if word.xpos is not None:
        query_token.pos = word.xpos
    if word.upos is not None:
        query_token.coarseTag = word.upos
    if word.feats and word.feats != "_":
        # feats is a "Key=Value|Key=Value" string; it is sent as parallel key / value lists
        for feature in word.feats.split("|"):
            key, value = feature.split("=", maxsplit=1)
            query_token.conllUFeatures.key.append(key)
            query_token.conllUFeatures.value.append(value)
    if token is not None:
        if token.ner is not None:
            query_token.ner = token.ner
        # a token with more than one id is a multi-word token (MWT)
        if token is not None and len(token.id) > 1:
            query_token.mwtText = token.text
            query_token.isMWT = True
            query_token.isFirstMWT = token.id[0] == word.id
        if token.id[-1] != word.id:
            # if we are not the last word of an MWT token
            # we are absolutely not followed by space
            pass
        else:
            query_token.after = token.spaces_after
        query_token.index = word.id
    else:
        # presumably empty words won't really be written this way,
        # but we can still keep track of it
        query_token.after = misc_to_space_after(word.misc)
        query_token.index = word.id[0]
        query_token.emptyIndex = word.id[1]
    if word.misc and word.misc != "_":
        query_token.conllUMisc = word.misc
    if token is not None and token.misc and token.misc != "_":
        query_token.mwtMisc = token.misc
def add_sentence(request_sentences, sentence, num_tokens):
    """
    Add the tokens for this stanza sentence to a list of protobuf sentences

    One proto token is added per word (so an MWT contributes several).  The
    sentence's token offsets are [num_tokens, num_tokens + word count).

    :return: the newly added proto sentence
    """
    request_sentence = request_sentences.add()
    request_sentence.tokenOffsetBegin = num_tokens
    word_count = sum(len(token.words) for token in sentence.tokens)
    request_sentence.tokenOffsetEnd = num_tokens + word_count
    for token in sentence.tokens:
        for word in token.words:
            add_token(request_sentence.token, word, token)
    return request_sentence
def add_word_to_graph(graph, word, sent_idx):
    """
    Add a node and possibly an edge for a word in a basic dependency graph.

    The node gets the 1-based sentence index and the word's index (plus
    emptyIndex for empty words, whose id is a pair).  An edge from word.head
    is added unless the head is 0 (root) or None; a missing deprel is sent as "_".
    """
    plain_id = isinstance(word.id, int)

    node = graph.node.add()
    node.sentenceIndex = sent_idx + 1
    if plain_id:
        node.index = word.id
    else:
        node.index, node.emptyIndex = word.id[0], word.id[1]

    if word.head == 0 or word.head is None:
        return

    edge = graph.edge.add()
    edge.source = word.head
    if plain_id:
        edge.target = word.id
    else:
        edge.target, edge.targetEmpty = word.id[0], word.id[1]
    # the receiving side doesn't like null as a dependency
    edge.dep = "_" if word.deprel is None else word.deprel
def convert_networkx_graph(graph_proto, sentence, sent_idx):
    """
    Fill a DependencyGraph proto from a sentence's enhanced dependencies.

    Despite the name, the argument is a sentence object: every word of
    sentence.tokens, then every word of sentence.empty_words, is added as a
    proto token.  Then sentence._enhanced_dependencies (presumably a networkx
    directed multigraph, judging by predecessors / get_edge_data -- confirm) is
    walked to add one node per non-root target and one edge per
    (source, target, deprel).

    :param graph_proto: DependencyGraph proto to fill; modified in place
    :param sentence: the sentence to convert
    :param sent_idx: 0-based sentence index; nodes are written with sent_idx + 1
    :return: graph_proto
    """
    for token in sentence.tokens:
        for word in token.words:
            add_token(graph_proto.token, word, token)
    for word in sentence.empty_words:
        add_token(graph_proto.token, word, None)

    dependencies = sentence._enhanced_dependencies
    for target in dependencies:
        if target == 0:
            # don't need to send the explicit root
            continue
        for source in dependencies.predecessors(target):
            if source == 0:
                # unlike with basic, we need to send over the roots,
                # as the enhanced can have loops
                # The node for this target is only added after this loop, so its
                # position in graph_proto.node will be the current length.
                graph_proto.rootNode.append(len(graph_proto.node))
                continue
            # NOTE(review): iterating get_edge_data yields the keys of the parallel
            # edges, which are used directly as the relation -- assumes the graph
            # was built with deprels as edge keys; confirm
            for deprel in dependencies.get_edge_data(source, target):
                edge = graph_proto.edge.add()
                # ids are ints for regular words or (index, empty index) pairs;
                # an empty index of 0 means "not an empty word" and is not sent
                if isinstance(source, int):
                    edge.source = source
                else:
                    edge.source = source[0]
                    if source[1] != 0:
                        edge.sourceEmpty = source[1]
                if isinstance(target, int):
                    edge.target = target
                else:
                    edge.target = target[0]
                    if target[1] != 0:
                        edge.targetEmpty = target[1]
                edge.dep = deprel
        node = graph_proto.node.add()
        node.sentenceIndex = sent_idx + 1
        # the nodes in the networkx graph are indexed from 1, not counting the root
        if isinstance(target, int):
            node.index = target
        else:
            node.index = target[0]
            if target[1] != 0:
                node.emptyIndex = target[1]
    return graph_proto
def features_to_string(features):
    """
    Turn a proto feature map (parallel .key / .value lists) into a
    "Key=Value|Key=Value" string, or None if there are no features.
    """
    if not features or len(features.key) == 0:
        return None
    pairs = zip(features.key, features.value)
    return "|".join(f"{key}={value}" for key, value in pairs)
def misc_space_pieces(misc):
    """
    Return only the space-related misc pieces

    The pieces kept are SpaceAfter, SpacesAfter and SpacesBefore, in their
    original order.  None, "" and "_" are returned unchanged; a misc with no
    space pieces gives None.
    """
    if misc in (None, "", "_"):
        return misc
    space_keys = ("SpaceAfter", "SpacesAfter", "SpacesBefore")
    kept = [piece for piece in misc.split("|") if piece.partition("=")[0] in space_keys]
    return "|".join(kept) if kept else None
def remove_space_misc(misc):
    """
    Remove any pieces from misc which are space-related

    SpaceAfter, SpacesAfter and SpacesBefore pieces are dropped and the rest
    keep their order.  None, "" and "_" are returned unchanged; if nothing
    remains, None is returned.
    """
    if misc in (None, "", "_"):
        return misc
    space_keys = ("SpaceAfter", "SpacesAfter", "SpacesBefore")
    remaining = [piece for piece in misc.split("|") if piece.partition("=")[0] not in space_keys]
    return "|".join(remaining) if remaining else None
def substitute_space_misc(misc, space_misc):
    """
    Replace the space-related pieces of ``misc`` with those in ``space_misc``.

    ``space_misc`` is expected to contain only space pieces, e.g. the output
    of misc_space_pieces.  Before-pieces are SpacesBefore (the key that
    misc_space_pieces / remove_space_misc use) or the legacy SpaceBefore.
    After-pieces are SpaceAfter or SpacesAfter.  Keys are compared exactly.
    Previously only "SpaceBefore" was recognized, so SpacesBefore pieces
    raised AssertionError or were duplicated.

    Each before/after piece of ``misc`` is replaced in place by the matching
    new piece; old ones with no replacement are dropped.  New pieces with no
    old counterpart are appended (after-piece first, then before-piece).
    None, "" and "_" are treated as empty for both arguments.

    :return: the new misc string, or None if it would be empty
    :raises AssertionError: if space_misc contains a non-space piece
    """
    before_keys = ("SpacesBefore", "SpaceBefore")
    after_keys = ("SpaceAfter", "SpacesAfter")

    space_misc_after = None
    space_misc_before = None
    if space_misc and space_misc != "_":
        for piece in space_misc.split("|"):
            key = piece.split("=", maxsplit=1)[0]
            if key in before_keys:
                space_misc_before = piece
            elif key in after_keys:
                space_misc_after = piece
            else:
                raise AssertionError("An unknown piece wound up in the misc space fields: %s" % piece)

    new_pieces = []
    if misc and misc != "_":
        for piece in misc.split("|"):
            key = piece.split("=", maxsplit=1)[0]
            if key in before_keys:
                if space_misc_before:
                    new_pieces.append(space_misc_before)
                    space_misc_before = None
            elif key in after_keys:
                if space_misc_after:
                    new_pieces.append(space_misc_after)
                    space_misc_after = None
            else:
                new_pieces.append(piece)

    if space_misc_after:
        new_pieces.append(space_misc_after)
    if space_misc_before:
        new_pieces.append(space_misc_before)
    if len(new_pieces) == 0:
        return None
    return "|".join(new_pieces)
class JavaProtobufContext(object):
    """
    A generic context for sending requests to a java program using protobufs in a subprocess

    The java program is started with "-multiple".  Each message, in either
    direction, is a 4 byte big-endian length followed by that many bytes of
    serialized protobuf.  close_pipe sends a length of 0, which presumably
    tells the java side to exit -- confirm against the java implementation.
    """
    def __init__(self, classpath, build_response, java_main, extra_args=None):
        """
        :param classpath: CoreNLP classpath, resolved with resolve_classpath
        :param build_response: zero-argument callable (e.g. a proto class) that makes
            an empty response message to parse into
        :param java_main: fully qualified java main class to run
        :param extra_args: extra command line arguments for the java program
        """
        self.classpath = resolve_classpath(classpath)
        self.build_response = build_response
        self.java_main = java_main
        if extra_args is None:
            extra_args = []
        self.extra_args = extra_args
        self.pipe = None

    def open_pipe(self):
        """Launch the java process with stdin and stdout connected to pipes."""
        self.pipe = subprocess.Popen(["java", "-cp", self.classpath, self.java_main, "-multiple"] + self.extra_args,
                                     stdin=subprocess.PIPE,
                                     stdout=subprocess.PIPE)

    def close_pipe(self):
        """
        Ask a still-running java process to stop, then release our ends of the pipes.

        Safe to call when the pipe was never opened or is already closed.
        """
        if self.pipe is None:
            return
        pipe = self.pipe
        self.pipe = None
        try:
            if pipe.poll() is None:
                pipe.stdin.write((0).to_bytes(4, 'big'))
                pipe.stdin.flush()
        except BrokenPipeError:
            # the process exited between poll() and write(); nothing left to tell it
            pass
        finally:
            # previously the stream handles were never closed and leaked
            for stream in (pipe.stdin, pipe.stdout):
                if stream is not None:
                    try:
                        stream.close()
                    except OSError:
                        pass

    def __enter__(self):
        self.open_pipe()
        return self

    def __exit__(self, type, value, traceback):
        self.close_pipe()

    def process_request(self, request):
        """
        Send one request over the open pipe and return the parsed response.

        :raises RuntimeError: if the pipe is not open
        :raises BrokenPipeError: if the java process stops before sending a complete response
        """
        if self.pipe is None:
            raise RuntimeError("Pipe to java process is not open or was closed")

        text = request.SerializeToString()
        self.pipe.stdin.write(len(text).to_bytes(4, 'big'))
        self.pipe.stdin.write(text)
        self.pipe.stdin.flush()
        response_length = self.pipe.stdout.read(4)
        if len(response_length) < 4:
            raise BrokenPipeError("Could not communicate with java process!")
        response_length = int.from_bytes(response_length, "big")
        response_text = self.pipe.stdout.read(response_length)
        # a short read means the process hit EOF mid-message; parsing the
        # truncated bytes would give an error or a silently incomplete response
        if len(response_text) < response_length:
            raise BrokenPipeError("Java process closed the pipe after sending %d of %d response bytes" % (len(response_text), response_length))
        response = self.build_response()
        response.ParseFromString(response_text)
        return response
================================================
FILE: stanza/server/main.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simple shell program which pipes text, one document per line, through a CoreNLP server and writes the annotations out as JSON
"""
import corenlp
import json
import re
import csv
import sys
from collections import namedtuple, OrderedDict
# Numbers in --props values, with an optional sign.  These are used with
# fullmatch so that values such as "3.5" or "3abc" are not taken as ints
FLOAT_RE = re.compile(r"[-+]?\d*\.\d+")
INT_RE = re.compile(r"[-+]?\d+")


def dictstr(arg):
    """
    Parse a key=value string as a tuple (key, value) that can be provided as an argument to dict()

    Only the first = separates key from value, so the value may itself
    contain =.  The value becomes a bool for true/false (any case), an
    int or float if the whole value looks like a number, and otherwise
    stays a string.

    Raises ValueError if arg has no =, which argparse reports as an
    invalid argument.
    """
    key, value = arg.split("=", 1)
    lowered = value.lower()
    if lowered in ("true", "false"):
        # bool("false") is True, so compare against the text instead
        value = lowered == "true"
    elif INT_RE.fullmatch(value):
        value = int(value)
    elif FLOAT_RE.fullmatch(value):
        value = float(value)
    return (key, value)
def do_annotate(args):
    """
    Annotate each line of args.input with a CoreNLP server, writing one json result per line to args.output

    Lines starting with # are skipped.  With args.sentence_mode each
    line is treated as exactly one sentence, and only that sentence's
    json is written.  The server is started with args.annotators,
    args.props (a list of (key, value) pairs, or None) and args.memory.
    """
    args.props = dict(args.props) if args.props else {}
    if args.sentence_mode:
        args.props["ssplit.isOneSentence"] = True
    # memory was previously parsed from --memory but never given to the server
    with corenlp.CoreNLPClient(annotators=args.annotators, properties=args.props,
                               memory=args.memory, be_quiet=not args.verbose_server) as client:
        for line in args.input:
            if line.startswith("#"):
                continue
            ann = client.annotate(line.strip(), output_format=args.format)
            if args.format == "json":
                if args.sentence_mode:
                    ann = ann["sentences"][0]
                args.output.write(json.dumps(ann))
                args.output.write("\n")
def main():
    """
    Command line entry point: read the options, then annotate the input
    """
    import argparse
    parser = argparse.ArgumentParser(description='Annotate data')
    parser.add_argument('-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
                        help="Input file to process; each line contains one document (default: stdin)")
    parser.add_argument('-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
                        help="File to write annotations to (default: stdout)")
    parser.add_argument('-f', '--format', choices=["json",], default="json",
                        help="Output format")
    parser.add_argument('-a', '--annotators', nargs="+", type=str, default=["tokenize ssplit lemma pos"],
                        help="A list of annotators")
    parser.add_argument('-s', '--sentence-mode', action="store_true",
                        help="Assume each line of input is a sentence.")
    parser.add_argument('-v', '--verbose-server', action="store_true",
                        help="Server is made verbose")
    parser.add_argument('-m', '--memory', type=str, default="4G",
                        help="Memory to use for the server")
    parser.add_argument('-p', '--props', nargs="+", type=dictstr,
                        help="Properties as a list of key=value pairs")
    parser.set_defaults(func=do_annotate)

    options = parser.parse_args()
    if options.func is not None:
        options.func(options)
        return
    parser.print_help()
    sys.exit(1)


if __name__ == "__main__":
    main()
================================================
FILE: stanza/server/morphology.py
================================================
"""
Direct pipe connection to the Java CoreNLP Morphology class
Only effective for English. Must be supplied with PTB scheme xpos, not upos
"""
from stanza.protobuf import MorphologyRequest, MorphologyResponse
from stanza.server.java_protobuf_requests import send_request, JavaProtobufContext
# java class which answers MorphologyRequests
MORPHOLOGY_JAVA = "edu.stanford.nlp.process.ProcessMorphologyRequest"


def send_morphology_request(request):
    """
    Send one MorphologyRequest to java via send_request and return the MorphologyResponse

    For many requests, the Morphology context below avoids starting java each time.
    """
    response_type = MorphologyResponse
    return send_request(request, response_type, MORPHOLOGY_JAVA)
def build_request(words, xpos_tags):
    """
    Turn a list of words and a list of tags into a request

    tags must be xpos, not upos

    Raises ValueError if words and xpos_tags have different lengths,
    since zip would otherwise silently drop the unmatched entries
    """
    # materialize so that generators are still accepted and can be measured
    words = list(words)
    xpos_tags = list(xpos_tags)
    if len(words) != len(xpos_tags):
        raise ValueError("Got %d words but %d xpos tags" % (len(words), len(xpos_tags)))
    request = MorphologyRequest()
    for word, tag in zip(words, xpos_tags):
        tagged_word = request.words.add()
        tagged_word.word = word
        tagged_word.xpos = tag
    return request
def process_text(words, xpos_tags):
    """
    Look up the lemma of each word given its xpos tag, using a one-off java process

    Returns a MorphologyResponse from CoreNLP.proto.  The tags must be
    xpos, not upos.  Use the Morphology context for repeated calls.
    """
    return send_morphology_request(build_request(words, xpos_tags))
class Morphology(JavaProtobufContext):
    """
    Keeps a single java Morphology process open across requests

    Use as a context manager.  Reusing one process for many requests is
    much faster than calling process_text over and over, which starts a
    new java process every time.
    """
    def __init__(self, classpath=None):
        super().__init__(classpath, MorphologyResponse, MORPHOLOGY_JAVA)

    def process(self, words, xpos_tags):
        """
        Return a MorphologyResponse with the lemma of each word/xpos pair
        """
        return self.process_request(build_request(words, xpos_tags))
def main():
    """
    Smoke test: lemmatize a fixed sentence with both the one-off and the persistent interface
    """
    # TODO: turn this into a unit test, once a new CoreNLP is released
    words = ["Jennifer", "has", "the", "prettiest", "antennae"]
    tags = ["NNP", "VBZ", "DT", "JJS", "NNS"]
    expected = ["Jennifer", "have", "the", "pretty", "antenna"]

    lemmas = [entry.lemma for entry in process_text(words, tags).words]
    print(lemmas)
    assert lemmas == expected

    with Morphology() as morph:
        lemmas = [entry.lemma for entry in morph.process(words, tags).words]
        assert lemmas == expected


if __name__ == '__main__':
    main()
================================================
FILE: stanza/server/parser_eval.py
================================================
"""
This module runs a Java process to evaluate treebank predictions using CoreNLP
"""
from collections import namedtuple
import sys
import stanza
from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse
from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext
from stanza.models.constituency.tree_reader import read_treebank
# java class which scores parser output sent to it as protobuf requests
EVALUATE_JAVA = "edu.stanford.nlp.parser.metrics.EvaluateExternalParser"
# one gold tree plus the predictions for it; build_request only reads gold and predictions
ParseResult = namedtuple("ParseResult", ['gold', 'predictions', 'state', 'constituents'])
# a predicted tree paired with its score
ScoredTree = namedtuple("ScoredTree", ['tree', 'score'])
def build_request(treebank):
    """
    Build an EvaluateParserRequest from a treebank of parse results

    treebank is an iterable of ParseResult (or anything else with .gold
    and .predictions attributes).  Only gold and predictions are used.

    predictions is a list where each entry is either a Tree, or a tuple
    whose first two items are (prediction, score) -- for example a
    ScoredTree, or a (prediction, score, state) triple.  Anything after
    the score, such as the state, is ignored, and score may be None.

    Note that for now, only one tree is measured, but this may be extensible in the future
    Trees should be in the form of a Tree from parse_tree.py

    Raises RuntimeError if a predicted tree cannot be converted
    """
    request = EvaluateParserRequest()
    for raw_result in treebank:
        gold = raw_result.gold
        predictions = raw_result.predictions
        parse_result = request.treebank.add()
        parse_result.gold.CopyFrom(build_tree(gold, None))
        for pred in predictions:
            if isinstance(pred, tuple):
                # ScoredTree, (tree, score) or (tree, score, state);
                # a plain two-way unpack would reject the documented triples
                prediction, score = pred[0], pred[1]
            else:
                prediction = pred
                score = None
            try:
                parse_result.predicted.append(build_tree(prediction, score))
            except Exception as e:
                raise RuntimeError("Unable to build parser request from tree {}".format(pred)) from e
    return request
def collate(gold_treebank, predictions_treebank):
    """
    Turns a list of gold and prediction into a evaluation object

    Returns a list of ParseResult, each holding one gold tree and a
    single predicted tree.

    Raises ValueError if the two treebanks have different lengths,
    rather than silently evaluating only the shorter prefix
    """
    gold_treebank = list(gold_treebank)
    predictions_treebank = list(predictions_treebank)
    if len(gold_treebank) != len(predictions_treebank):
        raise ValueError("Gold treebank has %d trees but predictions have %d" % (len(gold_treebank), len(predictions_treebank)))
    return [ParseResult(gold, [prediction], None, None)
            for gold, prediction in zip(gold_treebank, predictions_treebank)]
class EvaluateParser(JavaProtobufContext):
    """
    Keeps one java parser-evaluation process open across requests

    kbest: if set, passed to java as "-evalPCFGkBest <kbest> -evals pcfgTopK"
    silent: if True, passed to java as "-evals summary=False"
    """
    def __init__(self, classpath=None, kbest=None, silent=False):
        extra_args = []
        if kbest is not None:
            extra_args += ["-evalPCFGkBest", "{}".format(kbest), "-evals", "pcfgTopK"]
        if silent:
            extra_args += ["-evals", "summary=False"]
        super().__init__(classpath, EvaluateParserResponse, EVALUATE_JAVA, extra_args=extra_args)

    def process(self, treebank):
        """
        Score treebank (in the format accepted by build_request) and return an EvaluateParserResponse
        """
        return self.process_request(build_request(treebank))
def main():
    """
    Evaluate a predicted treebank against a gold treebank and print the F1

    usage: python3 -m stanza.server.parser_eval <gold_file> <predicted_file>
    """
    if len(sys.argv) < 3:
        print("usage: python3 -m stanza.server.parser_eval <gold_file> <predicted_file>", file=sys.stderr)
        sys.exit(2)
    gold = read_treebank(sys.argv[1])
    predictions = read_treebank(sys.argv[2])
    treebank = collate(gold, predictions)
    with EvaluateParser() as ep:
        response = ep.process(treebank)
    # the response was previously thrown away
    print("F1: {}".format(response.f1))


if __name__ == '__main__':
    main()
================================================
FILE: stanza/server/semgrex.py
================================================
"""Invokes the Java semgrex on a document
The server client has a method "semgrex" which sends text to Java
CoreNLP for processing with a semgrex (SEMantic GRaph regEX) query:
https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html
However, this operates on text using the CoreNLP tools, which means
the dependency graphs may not align with stanza's depparse module, and
this also limits the languages for which it can be used. This module
allows for running semgrex commands on the graphs produced by
depparse.
To use, first process text into a doc using stanza.Pipeline
Next, pass the processed doc and a list of semgrex patterns to
process_doc in this module. It will run the java semgrex module as a
subprocess and return the result in the form of a SemgrexResponse,
whose description is in the proto file included with stanza.
A minimal example is the main method of this module.
Note that launching the subprocess is potentially quite expensive
relative to the search if used many times on small documents. Ideally
larger texts would be processed, and all of the desired semgrex
patterns would be run at once. The worst thing to do would be to call
this multiple times on a large document, one invocation per semgrex
pattern, as that would serialize the document each time.
Included here is a context manager which allows for keeping the same
java process open for multiple requests. This saves on the subprocess
launching time. It is still important not to wastefully serialize the
same document over and over, though.
"""
import argparse
from collections import namedtuple
import copy
import os
import re
import stanza
from stanza.protobuf import SemgrexRequest, SemgrexResponse
from stanza.server.java_protobuf_requests import send_request, add_token, add_word_to_graph, JavaProtobufContext, convert_networkx_graph
from stanza.utils.conll import CoNLL
# java class which answers SemgrexRequests
SEMGREX_JAVA = "edu.stanford.nlp.semgraph.semgrex.ProcessSemgrexRequest"
# a semgrex pattern plus the comments annotate_doc attaches where it matches
SemgrexQuery = namedtuple("SemgrexQuery", "pattern comments")


def send_semgrex_request(request):
    """
    Send one SemgrexRequest to java via send_request and return the SemgrexResponse
    """
    response_type = SemgrexResponse
    return send_request(request, response_type, SEMGREX_JAVA)
def build_request(doc, semgrex_patterns, enhanced=False):
    """
    Build a SemgrexRequest which applies every pattern to every sentence of doc

    semgrex_patterns: a single pattern string, or a list of pattern
      strings and/or SemgrexQuery
    enhanced: if True, each sentence graph is built with
      convert_networkx_graph instead of from the basic dependencies
    """
    if isinstance(semgrex_patterns, str):
        semgrex_patterns = [semgrex_patterns]

    request = SemgrexRequest()
    for query_or_text in semgrex_patterns:
        if isinstance(query_or_text, SemgrexQuery):
            request.semgrex.append(query_or_text.pattern)
        else:
            request.semgrex.append(query_or_text)

    for sent_idx, sentence in enumerate(doc.sentences):
        query = request.query.add()
        if enhanced:
            # tokens will be added on to the graph object
            convert_networkx_graph(query.graph, sentence, sent_idx)
            continue
        for token in sentence.tokens:
            for word in token.words:
                add_token(query.token, word, token)
                add_word_to_graph(query.graph, word, sent_idx)
    return request
def process_doc(doc, *semgrex_patterns, enhanced=False):
    """
    Run semgrex patterns over the dependency graphs of a stanza doc

    Runs java as a subprocess for this one request and returns its
    SemgrexResponse (see CoreNLP.proto).  With enhanced=True the enhanced
    dependencies are searched instead of the basic ones.
    """
    return send_semgrex_request(build_request(doc, semgrex_patterns, enhanced=enhanced))
class Semgrex(JavaProtobufContext):
    """
    Semgrex context window

    This is a context window which keeps a process open. Should allow
    for multiple requests without launching new java processes each time.
    """
    def __init__(self, classpath=None):
        super(Semgrex, self).__init__(classpath, SemgrexResponse, SEMGREX_JAVA)

    def process(self, doc, *semgrex_patterns, enhanced=False):
        """
        Apply each of the semgrex patterns to each of the dependency trees in doc

        enhanced: search the enhanced dependencies instead of the basic
          ones, matching the option of the same name on process_doc
        """
        request = build_request(doc, semgrex_patterns, enhanced=enhanced)
        return self.process_request(request)
def annotate_doc(doc, semgrex_result, semgrex_patterns, matches_only, exclude_matches):
    """
    Put comments on the sentences which describe the matching semgrex patterns

    doc is deep copied first, so the caller's doc is left unchanged.

    semgrex_result: the SemgrexResponse from running semgrex_patterns on doc
      (for example, from process_doc)
    semgrex_patterns: a single pattern string, or a list of pattern strings
      and/or SemgrexQuery.  It is indexed by each match's semgrexIndex, so
      pass the same patterns, in the same order, used to build the request.
    matches_only: if True, the returned doc only keeps sentences with at
      least one match.  Takes precedence over exclude_matches.
    exclude_matches: if True (and matches_only is False), the returned doc
      only keeps sentences with no match at all.

    For each match, a "# semgrex pattern |...| matched at ..." comment is
    added, followed by that pattern's SemgrexQuery comments.  For each
    pattern with matches, "# highlight tokens" / "# highlight deprels"
    comments are added.  Patterns with no match get a "did not match!"
    comment; on matched sentences this only happens when matches_only is False.

    Returns the annotated copy of doc.
    """
    doc = copy.deepcopy(doc)
    if isinstance(semgrex_patterns, str):
        semgrex_patterns = [semgrex_patterns]
    semgrex_patterns = [x if isinstance(x, SemgrexQuery) else SemgrexQuery(x, []) for x in semgrex_patterns]
    # first pass: the index of every sentence with at least one match
    matched_ids = set()
    for sentence_result in semgrex_result.result:
        for pattern_result in sentence_result.result:
            for match in pattern_result.match:
                matched_ids.add(match.sentenceIndex)
    # patterns are written inside single-line comments, so flatten newlines
    pattern_texts = [semgrex_pattern.pattern.replace("\n", " ") for semgrex_pattern in semgrex_patterns]
    matching_sentences = []
    for sentence_result in semgrex_result.result:
        sentence_matched = False
        matched_semgrex_ids = set()
        for pattern_result in sentence_result.result:
            if len(pattern_result.match) == 0:
                continue
            # collected per pattern, so each pattern gets its own highlight comments
            highlight_tokens = []
            highlight_edges = []
            for match in pattern_result.match:
                sentence_matched = True
                # NOTE(review): presumably every match in one sentence_result shares
                # the same sentenceIndex; `sentence` is reused after this loop
                sentence = doc.sentences[match.sentenceIndex]
                semgrex_pattern = semgrex_patterns[match.semgrexIndex]
                pattern_text = pattern_texts[match.semgrexIndex]
                matched_semgrex_ids.add(match.semgrexIndex)
                # matchIndex appears to be 1-based, hence the -1 into sentence.words
                match_word = "%d:%s" % (match.matchIndex, sentence.words[match.matchIndex-1].text)
                if len(match.node) == 0:
                    node_matches = ""
                else:
                    node_matches = ["%s=%d:%s" % (node.name, node.matchIndex, sentence.words[node.matchIndex-1].text)
                                    for node in match.node]
                    node_matches = " " + " ".join(node_matches)
                if len(match.varstring) == 0:
                    var_values = ""
                else:
                    var_values = ["%s=%s" % (v.name, v.value) for v in match.varstring]
                    var_values = " " + " ".join(var_values)
                sentence.add_comment("# semgrex pattern |%s| matched at %s%s%s" % (pattern_text, match_word, node_matches, var_values))
                for comment in semgrex_pattern.comments:
                    sentence.add_comment("# semgrex comment: %s" % comment)
                highlight_tokens.append(match.matchIndex)
                for edge in match.edge:
                    highlight_edges.append(edge.target)
            if len(highlight_tokens) > 0:
                sentence.add_comment("# highlight tokens = %s" % (" ".join("%d" % x for x in highlight_tokens)))
            if len(highlight_edges) > 0:
                sentence.add_comment("# highlight deprels = %s" % (" ".join("%d" % x for x in highlight_edges)))
        # `sentence` is only bound here when sentence_matched is True
        if sentence_matched and not matches_only:
            for semgrex_idx, pattern_text in enumerate(pattern_texts):
                if semgrex_idx not in matched_semgrex_ids:
                    sentence.add_comment("# semgrex pattern |%s| did not match!" % pattern_text)
        if sentence_matched:
            matching_sentences.append(sentence)
    # sentences with no match for any pattern
    nonmatching_sentences = [sentence for sentence_idx, sentence in enumerate(doc.sentences) if sentence_idx not in matched_ids]
    for sentence in nonmatching_sentences:
        for semgrex_idx, pattern_text in enumerate(pattern_texts):
            sentence.add_comment("# semgrex pattern |%s| did not match!" % pattern_text)
    if matches_only:
        doc.sentences = matching_sentences
    elif exclude_matches:
        doc.sentences = nonmatching_sentences
    return doc
def main():
    """
    Runs a toy example, or can run a given semgrex expression on the given input file.

    For example:
    python3 -m stanza.server.semgrex --input demo/semgrex_sample.conllu

    --input may also be a directory, in which case each file in it
      (optionally filtered with --input_filter) is processed
    --semgrex_file to read the patterns from a file: a line starting
      with # is a comment for the next pattern, and a blank line
      discards any pending comments
    --no_matches_only to print every sentence, not only the matching ones
    --exclude_matches to print only the sentences which did NOT match
    --print_input to also print the input
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default=None, help='Process this file or directory')
    parser.add_argument('--input_filter', type=str, default=None, help='Only process files that match this regex')
    parser.add_argument('semgrex', type=str, nargs="*", default=["{}=source >obj=zzz {}=target"], help="Semgrex to apply to the text. The default looks for sentences with objects")
    parser.add_argument('--semgrex_file', type=str, default=None, help="File to read semgrex patterns from - relevant in case the pattern you want to use doesn't work well on the command line, for example")
    parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy")
    parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output - gets kind of noisy")
    parser.add_argument('--matches_only', action='store_true', default=True, help="Only print the matching sentences")
    parser.add_argument('--no_matches_only', dest='matches_only', action='store_false', help="Print all of the sentences, not only the matching sentences")
    parser.add_argument('--exclude_matches', action='store_true', default=False, help="Only print the NON-matching sentences")
    parser.add_argument('--enhanced', action='store_true', default=False, help='Use the enhanced dependencies instead of the basic')
    parser.add_argument('--no_combined_doc', dest='combined_doc', action='store_false', default=True, help='By default, combine all the input docs into one big document. Allows for easier secondary processing like sorting')
    args = parser.parse_args()

    if args.exclude_matches:
        # annotate_doc gives matches_only precedence, and matches_only
        # defaults to True, so --exclude_matches would otherwise do nothing
        args.matches_only = False

    if args.semgrex_file:
        with open(args.semgrex_file, encoding="utf-8") as fin:
            args.semgrex = [x.strip() for x in fin.readlines()]

    semgrex_patterns = []
    current_comments = []
    for line in args.semgrex:
        if not line:
            current_comments = []
        elif line.startswith("#"):
            current_comments.append(line[1:].strip())
        else:
            semgrex_patterns.append(SemgrexQuery(line, current_comments))
            current_comments = []
    args.semgrex = semgrex_patterns

    if args.input:
        if os.path.isfile(args.input):
            docs = [CoNLL.conll2doc(input_file=args.input, ignore_gapping=False)]
        else:
            filenames = sorted(os.listdir(args.input))
            if args.input_filter:
                input_filter = re.compile(args.input_filter)
                filenames = [x for x in filenames if input_filter.match(x)]
            filenames = [os.path.join(args.input, filename) for filename in filenames]
            filenames = [filename for filename in filenames if os.path.isfile(filename)]
            docs = [CoNLL.conll2doc(input_file=filename, ignore_gapping=False) for filename in filenames]
    else:
        nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse')
        docs = [nlp('Uro ruined modern. Fortunately, Wotc banned him.')]

    # with no input files there is nothing to combine (docs[0] would fail)
    if args.combined_doc and docs:
        sentences = [sent for doc in docs for sent in doc.sentences]
        docs = [docs[0]]
        docs[0].sentences = sentences

    for doc in docs:
        if args.print_input:
            print("{:C}".format(doc))
            print()
            print("-" * 75)
            print()
        semgrex_result = process_doc(doc, *args.semgrex, enhanced=args.enhanced)
        doc = annotate_doc(doc, semgrex_result, args.semgrex, args.matches_only, args.exclude_matches)
        if len(doc.sentences) > 0:
            print("{:C}\n".format(doc))


if __name__ == '__main__':
    main()
================================================
FILE: stanza/server/ssurgeon.py
================================================
"""Invokes the Java ssurgeon on a document
"ssurgeon" sends text to Java CoreNLP for processing with a ssurgeon
(Semantic graph SURGEON) query
The main program in this file gives a very short intro to how to use it.
"""
import argparse
from collections import namedtuple
import copy
import os
import re
import sys
from stanza.models.common.utils import misc_to_space_after, space_after_to_misc
from stanza.protobuf import SsurgeonRequest, SsurgeonResponse
from stanza.server import java_protobuf_requests
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, NER, Word, Token, Sentence
# java class which answers SsurgeonRequests
SSURGEON_JAVA = "edu.stanford.nlp.semgraph.semgrex.ssurgeon.ProcessSsurgeonRequest"
# one semgrex pattern plus the ssurgeon operations to apply where it matches;
# ssurgeon_id and notes default to None, language to UniversalEnglish
SsurgeonEdit = namedtuple("SsurgeonEdit",
                          "semgrex_pattern ssurgeon_edits ssurgeon_id notes language",
                          defaults=[None, None, "UniversalEnglish"])


def parse_ssurgeon_edits(ssurgeon_text):
    """
    Parse a text of ssurgeon edits into a list of SsurgeonEdit

    Edits are separated by blank lines; a line holding only whitespace
    also counts as blank.  Within an edit, lines starting with # (after
    any leading whitespace) are comments, which are joined into the
    notes.  The first remaining line is the semgrex pattern and the rest
    are the ssurgeon operations.  Blocks containing only comments are
    skipped.  Each edit's ssurgeon_id is its 1-based block number,
    counting the skipped blocks.
    """
    ssurgeon_text = ssurgeon_text.strip()
    # \s* (rather than requiring consecutive newlines) lets a whitespace-only
    # line separate two edits instead of merging them
    ssurgeon_blocks = re.split(r"\n\s*\n", ssurgeon_text)
    ssurgeon_edits = []
    for idx, block in enumerate(ssurgeon_blocks):
        # strip first, so indented comment lines are still recognized as comments
        lines = [line.strip() for line in block.split("\n")]
        comments = [line[1:].strip() for line in lines if line.startswith("#")]
        lines = [line for line in lines if line and not line.startswith("#")]
        if len(lines) == 0:
            # was a block of entirely comments
            continue
        notes = " ".join(comments)
        semgrex = lines[0]
        ssurgeon = lines[1:]
        ssurgeon_edits.append(SsurgeonEdit(semgrex, ssurgeon, "%d" % (idx + 1), notes))
    return ssurgeon_edits
def read_ssurgeon_edits(edit_file):
    """
    Read a utf-8 file of ssurgeon edits and parse it with parse_ssurgeon_edits
    """
    with open(edit_file, encoding="utf-8") as fin:
        contents = fin.read()
    return parse_ssurgeon_edits(contents)
def send_ssurgeon_request(request):
    """
    Send one SsurgeonRequest to java via send_request and return the SsurgeonResponse
    """
    response_type = SsurgeonResponse
    return java_protobuf_requests.send_request(request, response_type, SSURGEON_JAVA)
def build_request(doc, ssurgeon_edits):
    """
    Build a SsurgeonRequest applying every edit to every sentence of doc

    ssurgeon_edits is a list of SsurgeonEdit.  The optional id, notes and
    language fields are only set when they are not None.

    Raises RuntimeError, naming the sentence, if a sentence cannot be converted.
    """
    request = SsurgeonRequest()
    for edit in ssurgeon_edits:
        edit_proto = request.ssurgeon.add()
        edit_proto.semgrex = edit.semgrex_pattern
        for op in edit.ssurgeon_edits:
            edit_proto.operation.append(op)
        if edit.ssurgeon_id is not None:
            edit_proto.id = edit.ssurgeon_id
        if edit.notes is not None:
            edit_proto.notes = edit.notes
        if edit.language is not None:
            edit_proto.language = edit.language

    try:
        for sent_idx, sentence in enumerate(doc.sentences):
            graph = request.graph.add()
            for token in sentence.tokens:
                for word in token.words:
                    java_protobuf_requests.add_token(graph.token, word, token)
                    java_protobuf_requests.add_word_to_graph(graph, word, sent_idx)
    except Exception as e:
        raise RuntimeError("Failed to process sentence {}:\n{:C}".format(sent_idx, sentence)) from e
    return request
def build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
    """
    Build a SsurgeonRequest for a single semgrex pattern and its list of ssurgeon operations
    """
    return build_request(doc, [SsurgeonEdit(semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)])
def process_doc(doc, ssurgeon_edits):
    """
    Apply a list of SsurgeonEdit to the stanza doc in a one-off java process

    Returns the SsurgeonResponse from CoreNLP.proto; convert_response_to_doc
    turns it back into a stanza doc.
    """
    request = build_request(doc, ssurgeon_edits)
    response = send_ssurgeon_request(request)
    return response
def process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
    """
    Convenience wrapper: apply one semgrex pattern and its ssurgeon operations to doc

    Returns the SsurgeonResponse
    """
    return send_ssurgeon_request(build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes))
def build_word_entry(word_index, graph_word):
    """
    Convert one word of a ssurgeon result graph into a dict for building a stanza Sentence

    HEAD and DEPREL are not filled in here.  The extra keys is_mwt,
    is_first_mwt, mwt_text and mwt_misc carry the graph's MWT information.
    MISC starts from the word's "after" spacing, merged with the word's
    CoNLL-U misc when it has one.
    """
    def or_none(value):
        # empty values become None
        return value if value else None

    word_entry = {
        ID: word_index,
        TEXT: or_none(graph_word.word),
        LEMMA: or_none(graph_word.lemma),
        UPOS: or_none(graph_word.coarseTag),
        XPOS: or_none(graph_word.pos),
        FEATS: java_protobuf_requests.features_to_string(graph_word.conllUFeatures),
        DEPS: None,
        NER: or_none(graph_word.ner),
        MISC: None,
        # TODO: fix the character offsets?  the text positions
        # may shift across all of the sentences after an edit
        START_CHAR: None,
        END_CHAR: None,
        # plain string keys; presumably these must not collide with the
        # constants above
        "is_mwt": graph_word.isMWT,
        "is_first_mwt": graph_word.isFirstMWT,
        "mwt_text": graph_word.mwtText,
        "mwt_misc": graph_word.mwtMisc,
    }

    # TODO: do "before" as well
    misc = space_after_to_misc(graph_word.after)
    if graph_word.conllUMisc:
        misc = java_protobuf_requests.substitute_space_misc(graph_word.conllUMisc, misc)
    word_entry[MISC] = misc
    return word_entry
def convert_response_to_doc(doc, semgrex_response, add_missing_text):
    """
    Build a new doc from doc, replacing each sentence with the graph returned by ssurgeon

    doc is deep copied, so the caller's doc is not modified.

    semgrex_response: despite the name, the SsurgeonResponse for doc
      (for example, from process_doc).  Its results are paired with
      doc.sentences in order; zip stops at the shorter of the two, so
      sentences without a result are left unchanged.
    add_missing_text: if True, add a "# text = ..." comment to sentences
      which did not already have one.  An existing text comment is
      always rewritten to match the new tokens.

    Raises RuntimeError, chained to the original error, if a sentence
    cannot be converted.
    """
    doc = copy.deepcopy(doc)
    try:
        for sent_idx, (sentence, ssurgeon_result) in enumerate(zip(doc.sentences, semgrex_response.result)):
            # EditNode is currently bugged... :/
            # TODO: change this after next CoreNLP release (after 4.5.3)
            #if not ssurgeon_result.changed:
            #    continue
            ssurgeon_graph = ssurgeon_result.graph
            # word dicts in graph order, plus graph index -> position in tokens
            tokens = []
            token_id_to_idx = {}
            for graph_node, graph_word in zip(ssurgeon_graph.node, ssurgeon_graph.token):
                word_entry = build_word_entry(graph_node.index, graph_word)
                token_id_to_idx[graph_node.index] = len(tokens)
                tokens.append(word_entry)
            for root in ssurgeon_graph.root:
                tokens[token_id_to_idx[root]][HEAD] = 0
                tokens[token_id_to_idx[root]][DEPREL] = "root"
            for edge in ssurgeon_graph.edge:
                # can't do anything about the extra dependencies for now
                # TODO: put them all in .deps
                if edge.isExtra:
                    continue
                tokens[token_id_to_idx[edge.target]][HEAD] = edge.source
                tokens[token_id_to_idx[edge.target]][DEPREL] = edge.dep
            tokens.sort(key=lambda x: x[ID])
            # for any MWT, produce a token_entry which represents the word range
            mwt_tokens = []
            for word_start_idx, word in enumerate(tokens):
                if not word["is_first_mwt"]:
                    # inner MWT words lose their space info; the MWT token carries it
                    if word["is_mwt"]:
                        word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])
                    mwt_tokens.append(word)
                    continue
                # find one past the last word of this MWT
                word_end_idx = word_start_idx + 1
                while word_end_idx < len(tokens) and tokens[word_end_idx]["is_mwt"] and not tokens[word_end_idx]["is_first_mwt"]:
                    word_end_idx += 1
                mwt_token_entry = {
                    # the tokens don't fencepost the way lists do
                    ID: (tokens[word_start_idx][ID], tokens[word_end_idx-1][ID]),
                    TEXT: word["mwt_text"],
                    NER: word[NER],
                    # use the SpaceAfter=No (or not) from the last word in the token
                    MISC: None,
                }
                # NOTE: the last word's MISC is read here, before that word's own
                # iteration strips its space pieces, so this ordering matters
                mwt_token_entry[MISC] = java_protobuf_requests.misc_space_pieces(tokens[word_end_idx-1][MISC])
                if tokens[word_end_idx-1]["mwt_misc"]:
                    mwt_token_entry[MISC] = java_protobuf_requests.substitute_space_misc(tokens[word_end_idx-1]["mwt_misc"], mwt_token_entry[MISC])
                word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])
                mwt_tokens.append(mwt_token_entry)
                mwt_tokens.append(word)
            old_comments = list(sentence.comments)
            sentence = Sentence(mwt_tokens, doc)
            # rebuild the sentence text from the new tokens and their spaces_after
            token_text = []
            for token_idx, token in enumerate(sentence.tokens):
                token_text.append(token.text)
                if token_idx == len(sentence.tokens) - 1:
                    break
                token_text.append(token.spaces_after)
            sentence_text = "".join(token_text)
            # carry over the old comments, replacing any text comment with the new text
            found_text = False
            for comment in old_comments:
                if comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text="):
                    sentence.add_comment("# text = " + sentence_text)
                    found_text = True
                else:
                    sentence.add_comment(comment)
            if not found_text and add_missing_text:
                sentence.add_comment("# text = " + sentence_text)
            doc.sentences[sent_idx] = sentence
            sentence.rebuild_dependencies()
    except Exception as e:
        raise RuntimeError("Ssurgeon could not process sentence {}\nSsurgeon result:\n{}\nOriginal sentence:\n{:C}".format(sent_idx, ssurgeon_result, sentence)) from e
    return doc
class Ssurgeon(java_protobuf_requests.JavaProtobufContext):
    """
    Keeps one java ssurgeon process open, so that many requests can be
    sent without launching a new java process each time
    """
    def __init__(self, classpath=None):
        super().__init__(classpath, SsurgeonResponse, SSURGEON_JAVA)

    def process(self, doc, ssurgeon_edits):
        """
        Apply each edit in ssurgeon_edits (a list of SsurgeonEdit) to each dependency graph in doc
        """
        return self.process_request(build_request(doc, ssurgeon_edits))

    def process_one_operation(self, doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
        """
        Convenience method: apply one semgrex pattern and its ssurgeon operations to doc
        """
        return self.process_request(build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes))
SAMPLE_DOC = """
# sent_id = 271
# text = Hers is easy to clean.
# previous = What did the dealer like about Alex's car?
# comment = extraction/raising via "tough extraction" and clausal subject
1 Hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nsubj _ _
2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _
3 easy easy ADJ JJ Degree=Pos 0 root _ _
4 to to PART TO _ 5 mark _ _
5 clean clean VERB VB VerbForm=Inf 3 csubj _ SpaceAfter=No
6 . . PUNCT . _ 5 punct _ _
"""
def main():
    """
    Apply ssurgeon edits to CoNLL-U data and write out the result

    Without --input, the edits are applied to SAMPLE_DOC and printed.
    With a file as --input, the result goes to --output (a file, or an
    existing directory to write into under the input's filename), or
    back over the input file if --output is not given.
    With a directory as --input, each file matching --input_filter is
    written under the same name in --output, which defaults to the input
    directory.  Files are processed in sorted order.

    The edits come from --edit_file, or else from --semgrex plus the
    positional ssurgeon operations.
    """
    # for Windows, so that we aren't randomly printing garbage (or just failing to print)
    try:
        sys.stdout.reconfigure(encoding='utf-8')
    except AttributeError:
        # TODO: deprecate 3.6 support after the next release
        pass

    # The default semgrex detects sentences in the UD_English-Pronouns dataset which have both nsubj and csubj on the same word.
    # The default ssurgeon transforms the unwanted csubj to advcl
    # See https://github.com/UniversalDependencies/docs/issues/923
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default=None, help="Input file / directory to process (otherwise will process a sample text)")
    parser.add_argument('--output', type=str, default=None, help="Output location (otherwise will write back to the input directory)")
    parser.add_argument('--stdout', action='store_true', default=False, help='Output to stdout')
    parser.add_argument('--input_filter', type=str, default=".*[.]conllu", help="If processing a directory, only process files from --input that match this filter - regex, not shell filter. Default: %(default)s")
    parser.add_argument('--no_input_filter', action='store_const', const=None, dest="input_filter", help="Remove the default input filename filter")
    parser.add_argument('--edit_file', type=str, default=None, help="File to get semgrex and ssurgeon rules from")
    parser.add_argument('--semgrex', type=str, default="{}=source >nsubj {} >csubj=bad {}", help="Semgrex to apply to the text. A default detects words which have both an nsubj and a csubj. Default: %(default)s")
    parser.add_argument('ssurgeon', type=str, default=["relabelNamedEdge -edge bad -reln advcl"], nargs="*", help="Ssurgeon edits to apply based on the Semgrex. Can have multiple edits in a row. A default exists to transform csubj into advcl. Default: %(default)s")
    parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy. Default: %(default)s")
    parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output - gets kind of noisy")
    parser.add_argument('--no_add_missing_text', dest='add_missing_text', action='store_false', help="By default, the tool will add a #text comment if one does not exist. This leaves that blank")
    args = parser.parse_args()

    if args.edit_file:
        ssurgeon_edits = read_ssurgeon_edits(args.edit_file)
    else:
        ssurgeon_edits = [SsurgeonEdit(args.semgrex, args.ssurgeon)]

    if args.input:
        if os.path.isfile(args.input):
            docs = [CoNLL.conll2doc(input_file=args.input)]
            if args.output is None:
                outputs = [args.input]
            elif os.path.isdir(args.output):
                # writing into an existing directory keeps the input's filename
                outputs = [os.path.join(args.output, os.path.basename(args.input))]
            else:
                outputs = [args.output]
            input_output = zip(docs, outputs)
        else:
            if not args.output:
                args.output = args.input
            os.makedirs(args.output, exist_ok=True)

            def read_docs():
                # sorted so the files are processed (and reported) in a stable order
                for doc_filename in sorted(os.listdir(args.input)):
                    if args.input_filter and not re.match(args.input_filter, doc_filename):
                        continue
                    doc_path = os.path.join(args.input, doc_filename)
                    if not os.path.isfile(doc_path):
                        continue
                    output_path = os.path.join(args.output, doc_filename)
                    print("Processing %s to %s" % (doc_path, output_path))
                    yield CoNLL.conll2doc(input_file=doc_path), output_path
            input_output = read_docs()
    else:
        docs = [CoNLL.conll2doc(input_str=SAMPLE_DOC)]
        outputs = [None]
        input_output = zip(docs, outputs)
        args.stdout = True

    for doc, output in input_output:
        if args.print_input:
            print("{:C}".format(doc))
        ssurgeon_request = build_request(doc, ssurgeon_edits)
        ssurgeon_response = send_ssurgeon_request(ssurgeon_request)
        updated_doc = convert_response_to_doc(doc, ssurgeon_response, args.add_missing_text)
        if output is not None:
            with open(output, "w", encoding="utf-8") as fout:
                fout.write("{:C}\n\n".format(updated_doc))
        if args.stdout:
            print("{:C}\n".format(updated_doc))


if __name__ == '__main__':
    main()
================================================
FILE: stanza/server/tokensregex.py
================================================
"""Invokes the Java tokensregex on a document
This runs tokensregex over docs processed with stanza models.
https://nlp.stanford.edu/software/tokensregex.html
A minimal example is the main method of this module.
"""
import stanza
from stanza.protobuf import TokensRegexRequest, TokensRegexResponse
from stanza.server.java_protobuf_requests import send_request, add_sentence
def send_tokensregex_request(request):
    """
    Send one TokensRegexRequest to java via send_request and return the TokensRegexResponse
    """
    java_main = "edu.stanford.nlp.ling.tokensregex.ProcessTokensRegexRequest"
    return send_request(request, TokensRegexResponse, java_main)
def process_doc(doc, *patterns):
    """
    Run one or more tokensregex patterns over a stanza Document

    doc: a stanza Document; its text and sentences are copied into the request
    patterns: tokensregex pattern strings

    Returns the TokensRegexResponse sent back by the Java side
    """
    request = TokensRegexRequest()
    request.pattern.extend(patterns)

    request_doc = request.doc
    request_doc.text = doc.text
    # running count of words in the sentences already added; passed to
    # add_sentence (presumably as the starting token offset of the sentence)
    word_offset = 0
    for sentence in doc.sentences:
        add_sentence(request_doc.sentence, sentence, word_offset)
        word_offset += sum(len(token.words) for token in sentence.tokens)

    return send_tokensregex_request(request)
def main():
    """
    Small demonstration: tokenize two sentences and search them for two patterns
    """
    # only the tokenizer is needed for these literal patterns; a heavier
    # pipeline such as processors='tokenize,pos,lemma,ner' can be used instead
    pipe = stanza.Pipeline('en',
                           processors='tokenize')
    annotated = pipe('Uro ruined modern. Fortunately, Wotc banned him')
    response = process_doc(annotated, "him", "ruined")
    print(response)
# Run the tokensregex demonstration when this module is executed as a script
if __name__ == '__main__':
    main()
================================================
FILE: stanza/server/tsurgeon.py
================================================
"""Invokes the Java tsurgeon on a list of trees
Included with CoreNLP is a mechanism for modifying trees based on
existing patterns within a tree. The patterns are found using tregex:
https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/trees/tregex/TregexPattern.html
The modifications are then performed using tsurgeon:
https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/trees/tregex/tsurgeon/Tsurgeon.html
This module accepts Tree objects as produced by the conparser and
returns the modified trees that result from one or more tsurgeon
operations.
"""
from stanza.models.constituency import tree_reader
from stanza.models.constituency.parse_tree import Tree
from stanza.protobuf import TsurgeonRequest, TsurgeonResponse
from stanza.server.java_protobuf_requests import send_request, build_tree, from_tree, JavaProtobufContext
# Java class run (through send_request or JavaProtobufContext) to process TsurgeonRequests
TSURGEON_JAVA = "edu.stanford.nlp.trees.tregex.tsurgeon.ProcessTsurgeonRequest"
def send_tsurgeon_request(request):
    """
    Run a single TsurgeonRequest through a one-off Java process

    Returns the parsed TsurgeonResponse
    """
    response = send_request(request, TsurgeonResponse, TSURGEON_JAVA)
    return response
def build_request(trees, operations):
    """
    Build the TsurgeonRequest object

    trees: a single Tree or an iterable of Trees
    operations: either a single operation, (tregex, tsurgeon, tsurgeon, ...),
      or an iterable of such operations.  Each operation is one tregex
      pattern followed by at least one tsurgeon operation

    Raises ValueError if no operations are given, or if an operation is
    empty or has a tregex but no tsurgeon
    """
    if isinstance(trees, Tree):
        trees = (trees,)

    request = TsurgeonRequest()
    for tree in trees:
        request.trees.append(build_tree(tree, 0.0))

    # materialize first: the all() probe below would otherwise partially
    # consume a generator before the main loop sees it
    operations = tuple(operations)
    if not operations:
        raise ValueError("Expected at least one [tregex, tsurgeon, ...] operation")
    # a flat sequence of strings is a single operation, not a list of operations
    if all(isinstance(x, str) for x in operations):
        operations = (operations,)

    for operation in operations:
        if len(operation) == 0:
            raise ValueError("Expected [tregex, tsurgeon, ...] but got an empty operation")
        if len(operation) == 1:
            raise ValueError("Expected [tregex, tsurgeon, ...] but just got a tregex")
        operation_request = request.operations.add()
        operation_request.tregex = operation[0]
        for tsurgeon in operation[1:]:
            operation_request.tsurgeon.append(tsurgeon)
    return request
def process_trees(trees, *operations):
    """
    Apply the given tsurgeon operations to the given trees in a one-off Java call

    trees: a Tree or a list of Trees
    operations: one or more (tregex, tsurgeon, ...) sequences, or a single
      flat tregex, tsurgeon, ... sequence

    Returns a list of the modified trees, ie, already converted back from
    protobuf with from_tree
    """
    response = send_tsurgeon_request(build_request(trees, operations))
    modified = []
    for proto_tree in response.trees:
        modified.append(from_tree(proto_tree)[0])
    return modified
class Tsurgeon(JavaProtobufContext):
    """
    Tsurgeon context window

    Keeps a single Java process open so that several tsurgeon requests
    can be sent without launching a new java process for each one.
    Intended to be used in a `with` block, like the other
    JavaProtobufContext subclasses.
    """
    def __init__(self, classpath=None):
        super().__init__(classpath, TsurgeonResponse, TSURGEON_JAVA)

    def process(self, trees, *operations):
        """
        Same contract as process_trees, but reuses this context's Java process
        """
        response = self.process_request(build_request(trees, operations))
        return [from_tree(proto_tree)[0] for proto_tree in response.trees]
def main():
    """
    Demonstrate a single tsurgeon operation: relabel the WP node of a parse tree
    """
    text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    parsed = tree_reader.read_trees(text)
    # (tregex pattern naming the node, tsurgeon operation on that name)
    relabel_op = ("WP=wp", "relabel wp WWWPPP")
    print(process_trees(parsed, relabel_op))
# Run the tsurgeon demonstration when this module is executed as a script
if __name__ == '__main__':
    main()
================================================
FILE: stanza/server/ud_enhancer.py
================================================
import stanza
from stanza.protobuf import DependencyEnhancerRequest, Document, Language
from stanza.server.java_protobuf_requests import send_request, add_sentence, JavaProtobufContext
# Java class run (through send_request or JavaProtobufContext) to enhance UD dependency graphs
ENHANCER_JAVA = "edu.stanford.nlp.trees.ud.ProcessUniversalEnhancerRequest"
def build_enhancer_request(doc, language, pronouns_pattern):
    """
    Build a DependencyEnhancerRequest for a dependency-parsed stanza Document

    doc: a stanza Document whose words have head / deprel set
    language: "en"/"english" or "zh"/"zh-hans"/"chinese" (case-insensitive)
    pronouns_pattern: a relative pronoun pattern to send instead of a language
      (presumably a regex interpreted by the Java side - confirm)

    Exactly one of language and pronouns_pattern must be set.

    Raises ValueError if both or neither of language / pronouns_pattern are
    set, or if the language is not supported
    """
    if bool(language) == bool(pronouns_pattern):
        raise ValueError("Should set exactly one of language and pronouns_pattern")

    request = DependencyEnhancerRequest()
    if pronouns_pattern:
        # Python protobuf messages expose fields as attributes; there is no
        # Java-style setRelativePronouns() method on the generated class
        request.relativePronouns = pronouns_pattern
    elif language.lower() in ("en", "english"):
        request.language = Language.UniversalEnglish
    elif language.lower() in ("zh", "zh-hans", "chinese"):
        request.language = Language.UniversalChinese
    else:
        raise ValueError("Sorry, but language " + language + " is not supported yet. Either set a pronouns pattern or file an issue at https://stanfordnlp.github.io/stanza suggesting a mechanism for converting this language")

    request_doc = request.document
    request_doc.text = doc.text
    num_tokens = 0
    for sent_idx, sentence in enumerate(doc.sentences):
        request_sentence = add_sentence(request_doc.sentence, sentence, num_tokens)
        num_tokens = num_tokens + sum(len(token.words) for token in sentence.tokens)

        # build the basic dependency graph for this sentence: one node per
        # word (1-indexed), plus an edge from each word's head to the word
        graph = request_sentence.basicDependencies
        word_index = 0
        for token in sentence.tokens:
            for word in token.words:
                # TODO: refactor with the bit in java_protobuf_requests
                word_index = word_index + 1
                node = graph.node.add()
                node.sentenceIndex = sent_idx
                node.index = word_index

                # words with head 0 (the root in UD numbering) get no incoming edge
                if word.head != 0:
                    edge = graph.edge.add()
                    edge.source = word.head
                    edge.target = word_index
                    edge.dep = word.deprel

    return request
def process_doc(doc, language=None, pronouns_pattern=None):
    """
    Enhance the dependencies of a stanza Document in a one-off Java call

    Exactly one of language or pronouns_pattern should be given.
    Returns the protobuf Document produced by the Java enhancer.
    """
    enhancer_request = build_enhancer_request(doc, language, pronouns_pattern)
    return send_request(enhancer_request, Document, ENHANCER_JAVA)
class UniversalEnhancer(JavaProtobufContext):
    """
    UniversalEnhancer context window

    Keeps a single Java enhancer process open so that many documents can be
    processed without launching a new java process for each request.

    Exactly one of language or pronouns_pattern must be supplied; otherwise
    a ValueError is raised.
    """
    def __init__(self, language=None, pronouns_pattern=None, classpath=None):
        super().__init__(classpath, Document, ENHANCER_JAVA)
        # both set, or neither set, is an error
        if (not language) == (not pronouns_pattern):
            raise ValueError("Should set exactly one of language and pronouns_pattern")
        self.language = language
        self.pronouns_pattern = pronouns_pattern

    def process(self, doc):
        """
        Enhance the dependencies of doc using this context's Java process,
        returning the resulting protobuf Document
        """
        enhancer_request = build_enhancer_request(doc, self.language, self.pronouns_pattern)
        return self.process_request(enhancer_request)
def main():
    """
    Parse one English sentence and print its enhanced dependencies
    """
    pipe = stanza.Pipeline('en',
                           processors='tokenize,pos,lemma,depparse')
    with UniversalEnhancer(language="en") as ud_enhancer:
        parsed = pipe("This is the car that I bought")
        enhanced = ud_enhancer.process(parsed)
        print(enhanced.sentence[0].enhancedDependencies)
# Run the enhancer demonstration when this module is executed as a script
if __name__ == '__main__':
    main()
================================================
FILE: stanza/tests/__init__.py
================================================
"""
Utilities for testing
"""
import os
import re
from platformdirs import user_cache_dir
from stanza import __resources_version__
# Environment Variables
# set this to specify working directory of tests
TEST_HOME_VAR = 'STANZA_TEST_HOME'

# Global Variables
TEST_DIR_BASE_NAME = 'stanza_test'
# when STANZA_TEST_HOME is unset or empty, fall back to a per-user cache
# directory which is keyed on the resources version
TEST_WORKING_DIR = os.getenv(TEST_HOME_VAR) or user_cache_dir(TEST_DIR_BASE_NAME, 'StanfordNLP', __resources_version__)

TEST_MODELS_DIR = f'{TEST_WORKING_DIR}/models'
TEST_CORENLP_DIR = f'{TEST_WORKING_DIR}/corenlp_dir'

# server resources
SERVER_TEST_PROPS = f'{TEST_WORKING_DIR}/scripts/external_server.properties'

# language resources
LANGUAGE_RESOURCES = {}

TOKENIZE_MODEL = 'tokenizer.pt'
MWT_MODEL = 'mwt_expander.pt'
POS_MODEL = 'tagger.pt'
POS_PRETRAIN = 'pretrain.pt'
LEMMA_MODEL = 'lemmatizer.pt'
DEPPARSE_MODEL = 'parser.pt'
DEPPARSE_PRETRAIN = 'pretrain.pt'

MODEL_FILES = [TOKENIZE_MODEL, MWT_MODEL, POS_MODEL, POS_PRETRAIN, LEMMA_MODEL, DEPPARSE_MODEL, DEPPARSE_PRETRAIN]

# English resources
EN_KEY = 'en'
EN_SHORTHAND = 'en_ewt'
# models
EN_MODELS_DIR = f'{TEST_MODELS_DIR}/{EN_SHORTHAND}_models'
EN_MODEL_FILES = [f'{EN_MODELS_DIR}/{EN_SHORTHAND}_{fname}' for fname in MODEL_FILES]

# French resources
FR_KEY = 'fr'
FR_SHORTHAND = 'fr_gsd'
# regression file paths
FR_TEST_IN = f'{TEST_WORKING_DIR}/in/fr_gsd.test.txt'
FR_TEST_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out'
FR_TEST_GOLD_OUT = f'{FR_TEST_OUT}.gold'
# models
FR_MODELS_DIR = f'{TEST_MODELS_DIR}/{FR_SHORTHAND}_models'
FR_MODEL_FILES = [f'{FR_MODELS_DIR}/{FR_SHORTHAND}_{fname}' for fname in MODEL_FILES]

# Other language resources
AR_SHORTHAND = 'ar_padt'
DE_SHORTHAND = 'de_gsd'
KK_SHORTHAND = 'kk_ktb'
KO_SHORTHAND = 'ko_gsd'
# utils for clean up
# only allow removal of dirs/files in this approved list
REMOVABLE_PATHS = ['en_ewt_models', 'en_ewt_tokenizer.pt', 'en_ewt_mwt_expander.pt', 'en_ewt_tagger.pt',
                   'en_ewt.pretrain.pt', 'en_ewt_lemmatizer.pt', 'en_ewt_parser.pt', 'fr_gsd_models',
                   'fr_gsd_tokenizer.pt', 'fr_gsd_mwt_expander.pt', 'fr_gsd_tagger.pt', 'fr_gsd.pretrain.pt',
                   'fr_gsd_lemmatizer.pt', 'fr_gsd_parser.pt', 'ar_padt_models', 'ar_padt_tokenizer.pt',
                   'ar_padt_mwt_expander.pt', 'ar_padt_tagger.pt', 'ar_padt.pretrain.pt', 'ar_padt_lemmatizer.pt',
                   'ar_padt_parser.pt', 'de_gsd_models', 'de_gsd_tokenizer.pt', 'de_gsd_mwt_expander.pt',
                   'de_gsd_tagger.pt', 'de_gsd.pretrain.pt', 'de_gsd_lemmatizer.pt', 'de_gsd_parser.pt',
                   'kk_ktb_models', 'kk_ktb_tokenizer.pt', 'kk_ktb_mwt_expander.pt', 'kk_ktb_tagger.pt',
                   'kk_ktb.pretrain.pt', 'kk_ktb_lemmatizer.pt', 'kk_ktb_parser.pt', 'ko_gsd_models',
                   'ko_gsd_tokenizer.pt', 'ko_gsd_mwt_expander.pt', 'ko_gsd_tagger.pt', 'ko_gsd.pretrain.pt',
                   'ko_gsd_lemmatizer.pt', 'ko_gsd_parser.pt']
# EN_MODEL_FILES / FR_MODEL_FILES build the pretrain file name from
# MODEL_FILES as "<shorthand>_pretrain.pt", while the list above only
# approves "<shorthand>.pretrain.pt".  Approve both spellings so that
# safe_rm can empty a models directory whichever name is on disk;
# otherwise the leftover file makes the final os.rmdir fail
REMOVABLE_PATHS += [f'{shorthand}_pretrain.pt'
                    for shorthand in (EN_SHORTHAND, FR_SHORTHAND, AR_SHORTHAND,
                                      DE_SHORTHAND, KK_SHORTHAND, KO_SHORTHAND)]
def safe_rm(path_to_rm):
    """
    Safely remove a directory of files or a file

    1.) check path exists, files are files, dirs are dirs
    2.) only remove things on approved list REMOVABLE_PATHS
    3.) assert no longer exists

    For a directory, only its immediate files are considered (subdirectories
    are not recursed into).  The directory itself is removed only if its own
    name is on the approved list.  If unapproved entries remain inside an
    approved directory, os.rmdir raises OSError because it is not empty.
    """
    # just return if path doesn't exist
    if not os.path.exists(path_to_rm):
        return

    # handle directory
    if os.path.isdir(path_to_rm):
        files_to_rm = [os.path.join(path_to_rm, fname) for fname in os.listdir(path_to_rm)]
        # normpath so that a trailing separator does not turn the basename into ''
        dir_name = os.path.basename(os.path.normpath(path_to_rm))
        dir_to_rm = path_to_rm if dir_name in REMOVABLE_PATHS else None
    else:
        files_to_rm = [path_to_rm]
        dir_to_rm = None

    # clear out files
    for file_to_rm in files_to_rm:
        if os.path.isfile(file_to_rm) and os.path.basename(file_to_rm) in REMOVABLE_PATHS:
            os.remove(file_to_rm)
            assert not os.path.exists(file_to_rm), f'Error removing: {file_to_rm}'

    # clear out directory
    if dir_to_rm is not None and os.path.isdir(dir_to_rm):
        os.rmdir(dir_to_rm)
        assert not os.path.exists(dir_to_rm), f'Error removing: {dir_to_rm}'
def compare_ignoring_whitespace(predicted, expected):
    """
    Assert that two strings match once whitespace differences are ignored

    Each string is stripped, runs of spaces and tabs are collapsed to a
    single space, and then Windows line endings are turned into newlines.
    """
    def normalize(text):
        collapsed = re.sub('[ \t]+', ' ', text.strip())
        return collapsed.replace('\r\n', '\n')

    assert normalize(predicted) == normalize(expected)
================================================
FILE: stanza/tests/classifiers/__init__.py
================================================
================================================
FILE: stanza/tests/classifiers/test_classifier.py
================================================
import glob
import os
import pytest
import numpy as np
import torch
import stanza
import stanza.models.classifier as classifier
import stanza.models.classifiers.data as data
from stanza.models.classifiers.trainer import Trainer
from stanza.models.common import pretrain
from stanza.models.common import utils
from stanza.tests import TEST_MODELS_DIR
from stanza.tests.classifiers.test_data import train_file, dev_file, test_file, DATASET, SENTENCES
# every test in this module is tagged with the pipeline and travis markers
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

# dimension of the fake word vectors written by the fake_embeddings fixture
EMB_DIM = 5
@pytest.fixture(scope="module")
def fake_embeddings(tmp_path_factory):
    """
    Write random embeddings for the words in SENTENCES and convert them to a Pretrain

    The words are lowercased and sorted, and the alphabetically last one is
    left out of the file (presumably so that at least one word is unknown).

    Returns the path to the converted .pt file
    """
    # could set np random seed here
    vocab = sorted({word.lower() for sentence in SENTENCES for word in sentence})[:-1]

    data_dir = tmp_path_factory.mktemp("data")
    txt_path = data_dir / "embedding.txt"
    pt_path = data_dir / "embedding.pt"

    vectors = np.random.random((len(vocab), EMB_DIM))
    with open(txt_path, "w", encoding="utf-8") as fout:
        for word, vector in zip(vocab, vectors):
            fout.write(word + "\t" + "\t".join(str(x) for x in vector) + "\n")

    # the assert below checks that loading the text file produced the .pt file
    pt = pretrain.Pretrain(str(pt_path), str(txt_path))
    pt.load()
    assert os.path.exists(pt_path)
    return pt_path
class TestClassifier:
    """
    Build, train, save and reload small sentiment classifiers on the fake data
    """
    def build_model(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):
        """
        Build a model to be used by one of the later tests

        If checkpoint_file is given, the trainer (including its optimizer) is
        loaded from that checkpoint instead of being built from scratch.

        Returns (trainer, train_set, args)
        """
        save_dir = str(tmp_path / "classifier")
        save_name = "model.pt"
        args = ["--save_dir", save_dir,
                "--save_name", save_name,
                "--wordvec_pretrain_file", str(fake_embeddings),
                "--filter_channels", "20",
                "--fc_shapes", "20,10",
                "--train_file", str(train_file),
                "--dev_file", str(dev_file),
                "--max_epochs", "2",
                "--batch_size", "60"]
        if extra_args is not None:
            args = args + extra_args
        args = classifier.parse_args(args)
        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
        if checkpoint_file:
            trainer = Trainer.load(checkpoint_file, args, load_optimizer=True)
        else:
            trainer = Trainer.build_new_model(args, train_set)
        return trainer, train_set, args

    def run_training(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):
        """
        Iterate a couple times over a model

        Returns (trainer, save_filename, checkpoint_file).  When no
        checkpoint_file is given, the default checkpoint name is used.
        """
        trainer, train_set, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args, checkpoint_file)
        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
        labels = data.dataset_labels(train_set)

        save_filename = os.path.join(args.save_dir, args.save_name)
        if checkpoint_file is None:
            checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)
        classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)
        return trainer, save_filename, checkpoint_file

    def test_build_model(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test that building a basic model works
        """
        self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])

    def test_save_load(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test that a basic model can save & load
        """
        trainer, _, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])

        save_filename = os.path.join(args.save_dir, args.save_name)
        trainer.save(save_filename)

        # load once by bare save name (presumably resolved against save_dir)
        # and once by the full path
        args.load_name = args.save_name
        trainer = Trainer.load(args.load_name, args)
        args.load_name = save_filename
        trainer = Trainer.load(args.load_name, args)

    def test_train_basic(self, tmp_path, fake_embeddings, train_file, dev_file):
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])

    def test_train_bilstm(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test w/ and w/o bilstm variations of the classifier
        """
        args = ["--bilstm", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--no_bilstm"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

    def test_train_maxpool_width(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test various maxpool widths

        Also sets --filter_channels to a multiple of 2 but not of 3 for
        the test to make sure the math is done correctly on a non-divisible width
        """
        args = ["--maxpool_width", "1", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--maxpool_width", "2", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--maxpool_width", "3", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

    def test_train_conv_2d(self, tmp_path, fake_embeddings, train_file, dev_file):
        args = ["--filter_sizes", "(3,4,5)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--filter_sizes", "((3,2),)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

    def test_train_filter_channels(self, tmp_path, fake_embeddings, train_file, dev_file):
        args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--no_bilstm"]
        trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
        assert trainer.model.fc_input_size == 40

        args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "15,20", "--no_bilstm"]
        trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
        # 50 = 2x15 for the 2d conv (over 5 dim embeddings) + 20
        assert trainer.model.fc_input_size == 50

    def test_train_bert(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test on a tiny Bert WITHOUT finetuning, which hopefully does not take up too much disk space or memory
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model])
        assert os.path.exists(save_filename)
        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
        # check that the bert model wasn't saved as part of the classifier
        assert not saved_model['params']['config']['force_bert_saved']
        assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())

    def test_finetune_bert(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune"])
        assert os.path.exists(save_filename)
        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
        # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer
        assert saved_model['params']['config']['force_bert_saved']
        assert any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())

    def test_finetune_bert_layers(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory, using 2 layers

        As an added bonus (or eager test), load the finished model and continue
        training from there.  Then check that the initial model and
        the middle model are different, then that the middle model and
        final model are different
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models"])
        assert os.path.exists(save_filename)

        save_path = os.path.split(save_filename)[0]

        initial_model = glob.glob(os.path.join(save_path, "*E0000*"))
        assert len(initial_model) == 1
        initial_model = initial_model[0]
        initial_model = torch.load(initial_model, lambda storage, loc: storage, weights_only=True)

        second_model_file = glob.glob(os.path.join(save_path, "*E0002*"))
        assert len(second_model_file) == 1
        second_model_file = second_model_file[0]
        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)

        for layer_idx in range(2):
            bert_names = [x for x in second_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." % layer_idx in x]
            assert len(bert_names) > 0
            assert all(x in initial_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)
            assert not all(torch.allclose(initial_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)

        # put some random marker in the file to look for later,
        # check the continued training didn't clobber the expected file
        assert "asdf" not in second_model
        second_model["asdf"] = 1234
        torch.save(second_model, second_model_file)

        trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file)

        second_model_file_redo = glob.glob(os.path.join(save_path, "*E0002*"))
        assert len(second_model_file_redo) == 1
        assert second_model_file == second_model_file_redo[0]
        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)
        assert "asdf" in second_model

        fifth_model_file = glob.glob(os.path.join(save_path, "*E0005*"))
        assert len(fifth_model_file) == 1

        final_model = torch.load(fifth_model_file[0], lambda storage, loc: storage, weights_only=True)
        for layer_idx in range(2):
            bert_names = [x for x in final_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." % layer_idx in x]
            assert len(bert_names) > 0
            assert all(x in final_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)
            assert not all(torch.allclose(final_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)

    def test_finetune_peft(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test on a tiny Bert with PEFT finetuning
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler"])
        assert os.path.exists(save_filename)
        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
        # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer, but only in peft form
        assert saved_model['params']['config']['bert_model'] == bert_model
        assert saved_model['params']['config']['force_bert_saved']
        assert saved_model['params']['config']['use_peft']
        assert not saved_model['params']['config']['has_charlm_forward']
        assert not saved_model['params']['config']['has_charlm_backward']
        assert len(saved_model['params']['bert_lora']) > 0
        assert any(x.find(".pooler.") >= 0 for x in saved_model['params']['bert_lora'])
        assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora'])
        assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())

        # The Pipeline should load and run a PEFT trained model,
        # although obviously we don't expect the results to do
        # anything correct
        pipeline = stanza.Pipeline("en", download_method=None, model_dir=TEST_MODELS_DIR, processors="tokenize,sentiment", sentiment_model_path=save_filename, sentiment_pretrain_path=str(fake_embeddings))
        doc = pipeline("This is a test")
        # the model is random, so only check that a sentiment was assigned
        assert doc.sentences[0].sentiment is not None

    def test_finetune_peft_restart(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test that if we restart training on a peft model, the peft weights change
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models"])
        assert os.path.exists(save_file)
        saved_model = torch.load(save_file, lambda storage, loc: storage, weights_only=True)
        assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora'])

        trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file)

        save_path = os.path.split(save_file)[0]

        initial_model_file = glob.glob(os.path.join(save_path, "*E0000*"))
        assert len(initial_model_file) == 1
        initial_model_file = initial_model_file[0]
        initial_model = torch.load(initial_model_file, lambda storage, loc: storage, weights_only=True)

        second_model_file = glob.glob(os.path.join(save_path, "*E0002*"))
        assert len(second_model_file) == 1
        second_model_file = second_model_file[0]
        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)

        final_model_file = glob.glob(os.path.join(save_path, "*E0005*"))
        assert len(final_model_file) == 1
        final_model_file = final_model_file[0]
        final_model = torch.load(final_model_file, lambda storage, loc: storage, weights_only=True)

        # params in initial_model & second_model start with "base_model.model."
        # whereas params in final_model start directly with "encoder" or "pooler"
        initial_lora = initial_model['params']['bert_lora']
        second_lora = second_model['params']['bert_lora']
        final_lora = final_model['params']['bert_lora']
        for side in ("_A.", "_B."):
            for layer in (".0.", ".1."):
                initial_params = sorted([x for x in initial_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                second_params = sorted([x for x in second_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                # startswith already returns a bool; no need to compare it to 0
                final_params = sorted([x for x in final_lora if x.startswith("encoder.") and x.find(side) > 0 and x.find(layer) > 0])
                assert len(initial_params) > 0
                assert len(initial_params) == len(second_params)
                assert len(initial_params) == len(final_params)
                for x, y in zip(second_params, final_params):
                    assert x.endswith(y)
                    if side != "_A.":  # the A tensors don't move very much, if at all
                        assert not torch.allclose(initial_lora.get(x), second_lora.get(x))
                        assert not torch.allclose(second_lora.get(x), final_lora.get(y))
================================================
FILE: stanza/tests/classifiers/test_constituency_classifier.py
================================================
import os
import pytest
import stanza
import stanza.models.classifier as classifier
import stanza.models.classifiers.data as data
from stanza.models.classifiers.trainer import Trainer
from stanza.tests import TEST_MODELS_DIR
from stanza.tests.classifiers.test_classifier import fake_embeddings
from stanza.tests.classifiers.test_data import train_file_with_trees, dev_file_with_trees
from stanza.models.common import utils
from stanza.tests.constituency.test_trainer import build_trainer, TREEBANK
# every test in this module is tagged with the pipeline and travis markers
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
class TestConstituencyClassifier:
@pytest.fixture(scope="class")
def constituency_model(self, fake_embeddings, tmp_path_factory):
args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
trainer = build_trainer(str(fake_embeddings), *args, treebank=TREEBANK)
trainer_pt = str(tmp_path_factory.mktemp("constituency") / "constituency.pt")
trainer.save(trainer_pt, save_optimizer=False)
return trainer_pt
def build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
"""
Build a Constituency Classifier model to be used by one of the later tests
"""
save_dir = str(tmp_path / "classifier")
save_name = "model.pt"
args = ["--save_dir", save_dir,
"--save_name", save_name,
"--model_type", "constituency",
"--constituency_model", constituency_model,
"--wordvec_pretrain_file", str(fake_embeddings),
"--fc_shapes", "20,10",
"--train_file", str(train_file_with_trees),
"--dev_file", str(dev_file_with_trees),
"--max_epochs", "2",
"--batch_size", "60"]
if extra_args is not None:
args = args + extra_args
args = classifier.parse_args(args)
train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
trainer = Trainer.build_new_model(args, train_set)
return trainer, train_set, args
def run_training(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
"""
Iterate a couple times over a model
"""
trainer, train_set, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args)
dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
labels = data.dataset_labels(train_set)
save_filename = os.path.join(args.save_dir, args.save_name)
checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)
classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)
return trainer, train_set, args
def test_build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
"""
Test that building a basic constituency-based model works
"""
self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
def test_save_load(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
"""
Test that a constituency model can save & load
"""
trainer, _, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
save_filename = os.path.join(args.save_dir, args.save_name)
trainer.save(save_filename)
args.load_name = args.save_name
trainer = Trainer.load(args.load_name, args)
def test_train_basic(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
def test_train_pipeline(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
"""
Test that writing out a temp model, then loading it in the pipeline is a thing that works
"""
trainer, _, args = self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
save_filename = os.path.join(args.save_dir, args.save_name)
assert os.path.exists(save_filename)
assert os.path.exists(args.constituency_model)
pipeline_args = {"lang": "en",
"download_method": None,
"model_dir": TEST_MODELS_DIR,
"processors": "tokenize,pos,constituency,sentiment",
"tokenize_pretokenized": True,
"constituency_model_path": args.constituency_model,
"constituency_pretrain_path": args.wordvec_pretrain_file,
"constituency_backward_charlm_path": None,
"constituency_forward_charlm_path": None,
"sentiment_model_path": save_filename,
"sentiment_pretrain_path": args.wordvec_pretrain_file,
"sentiment_backward_charlm_path": None,
"sentiment_forward_charlm_path": None}
pipeline = stanza.Pipeline(**pipeline_args)
doc = pipeline("This is a test")
# since the model is random, we have no expectations for what the result actually is
assert doc.sentences[0].sentiment is not None
def test_train_all_words(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_all_words'])
self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_all_words'])
def test_train_top_layer(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_top_layer'])
self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_top_layer'])
def test_train_attn(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
    """
    Train with node attention combined with each all_words setting, then with node attention off
    """
    flag_sets = (
        ['--constituency_node_attn', '--no_constituency_all_words'],
        ['--constituency_node_attn', '--constituency_all_words'],
        ['--no_constituency_node_attn'],
    )
    for extra_args in flag_sets:
        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args)
================================================
FILE: stanza/tests/classifiers/test_data.py
================================================
import json
import pytest
import stanza.models.classifiers.data as data
from stanza.models.classifiers.utils import WVType
from stanza.models.common.vocab import PAD, UNK
from stanza.models.constituency.parse_tree import Tree
# three toy sentences; the sentiment label of each is its index as a string
SENTENCES = [
    ["I", "hate", "the", "Opal", "banning"],
    ["Tell", "my", "wife", "hello"], # obviously this is the neutral result
    ["I", "like", "Sh'reyan", "'s", "antennae"],
]

DATASET = [{"sentiment": str(label), "text": sentence}
           for label, sentence in enumerate(SENTENCES)]

# bracketed constituency parses of SENTENCES, in the same order
TREES = [
    "(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))",
    "(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))",
    "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))",
]

# same as DATASET, with each sentence's tree attached under "constituency"
DATASET_WITH_TREES = [{"sentiment": str(label), "text": sentence, "constituency": tree}
                      for label, (sentence, tree) in enumerate(zip(SENTENCES, TREES))]
def _write_json_dataset(tmp_path_factory, dataset, filename):
    """
    Write dataset as json into a fresh temp "data" directory and return the file's path

    dataset: a list of json-serializable items, eg DATASET * 20
    filename: the name of the file to create inside the new directory
    """
    path = tmp_path_factory.mktemp("data") / filename
    with open(path, "w", encoding="utf-8") as fout:
        json.dump(dataset, fout, ensure_ascii=False)
    return path

@pytest.fixture(scope="module")
def train_file(tmp_path_factory):
    """Train split: DATASET repeated 20 times"""
    return _write_json_dataset(tmp_path_factory, DATASET * 20, "train.json")

@pytest.fixture(scope="module")
def dev_file(tmp_path_factory):
    """Dev split: DATASET repeated twice"""
    return _write_json_dataset(tmp_path_factory, DATASET * 2, "dev.json")

@pytest.fixture(scope="module")
def test_file(tmp_path_factory):
    """Test split: DATASET once"""
    return _write_json_dataset(tmp_path_factory, DATASET, "test.json")

@pytest.fixture(scope="module")
def train_file_with_trees(tmp_path_factory):
    """Train split with constituency trees: DATASET_WITH_TREES repeated 20 times"""
    return _write_json_dataset(tmp_path_factory, DATASET_WITH_TREES * 20, "train_trees.json")

@pytest.fixture(scope="module")
def dev_file_with_trees(tmp_path_factory):
    """Dev split with constituency trees: DATASET_WITH_TREES repeated twice"""
    return _write_json_dataset(tmp_path_factory, DATASET_WITH_TREES * 2, "dev_trees.json")
class TestClassifierData:
    def test_read_data(self, train_file):
        """
        Test reading of the json format
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        assert len(train_set) == 60

    def test_read_data_with_trees(self, train_file, train_file_with_trees):
        """
        Test reading of the json format when constituency trees are included

        Each item's tree should be parsed into a Tree, and the rest of each
        item should match what is read from the same data without trees
        """
        train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1)
        assert len(train_trees_set) == 60
        for idx, x in enumerate(train_trees_set):
            assert isinstance(x.constituency, Tree)
            assert str(x.constituency) == TREES[idx % len(TREES)]

        # previously this was read and then thrown away; use it to check
        # that including the trees does not change the text or labels
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        assert len(train_set) == len(train_trees_set)
        for plain, with_tree in zip(train_set, train_trees_set):
            assert plain.sentiment == with_tree.sentiment
            assert plain.text == with_tree.text

    def test_dataset_vocab(self, train_file):
        """
        Converting a dataset to vocab should have a specific set of words along with PAD and UNK
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        vocab = data.dataset_vocab(train_set)
        # the vocab is expected to be lowercased
        expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y])
        assert set(vocab) == expected

    def test_dataset_labels(self, train_file):
        """
        Test the extraction of labels from a dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = data.dataset_labels(train_set)
        assert labels == ["0", "1", "2"]

    def test_sort_by_length(self, train_file):
        """
        There are two unique lengths in the toy dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        sorted_dataset = data.sort_dataset_by_len(train_set)
        assert list(sorted_dataset.keys()) == [4, 5]
        # one of the three sentences has length 4, the other two have length 5
        assert len(sorted_dataset[4]) == len(train_set) // 3
        assert len(sorted_dataset[5]) == 2 * len(train_set) // 3

    def test_check_labels(self, train_file):
        """
        Check that an exception is thrown for an unknown label
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = sorted({x["sentiment"] for x in DATASET})
        assert len(labels) > 1
        data.check_labels(labels, train_set)
        # with only the first label allowed, the other labels in the data are unknown
        with pytest.raises(RuntimeError):
            data.check_labels(labels[:1], train_set)
================================================
FILE: stanza/tests/classifiers/test_process_utils.py
================================================
"""
A few tests of the utils module for the sentiment datasets
"""
import os
import pytest
import stanza
from stanza.models.classifiers import data
from stanza.models.classifiers.data import SentimentDatum
from stanza.models.classifiers.utils import WVType
from stanza.utils.datasets.sentiment import process_utils
from stanza.tests import TEST_MODELS_DIR
from stanza.tests.classifiers.test_data import train_file, dev_file, test_file
def test_write_list(tmp_path, train_file):
    """
    Test that writing a single list of items to an output file works

    The file written by write_list should read back as an identical dataset
    """
    original = data.read_dataset(train_file, WVType.OTHER, 1)
    output_path = tmp_path / "foo.json"
    process_utils.write_list(output_path, original)
    reloaded = data.read_dataset(output_path, WVType.OTHER, 1)
    assert reloaded == original
def test_write_dataset(tmp_path, train_file, dev_file, test_file):
    """
    Test that writing all three parts of a dataset works

    write_dataset should create exactly the train, dev and test files for
    the given shorthand, and each should read back as the split it came from
    """
    splits = [data.read_dataset(path, WVType.OTHER, 1) for path in (train_file, dev_file, test_file)]
    process_utils.write_dataset(splits, tmp_path, "en_test")

    expected_files = ['en_test.train.json', 'en_test.dev.json', 'en_test.test.json']
    assert sorted(os.listdir(tmp_path)) == sorted(expected_files)

    for name, original in zip(expected_files, splits):
        reloaded = data.read_dataset(tmp_path / name, WVType.OTHER, 1)
        assert reloaded == original
def test_read_snippets(tmp_path):
    """
    Test the basic operation of the read_snippets function

    Each tab separated line has its text in column 1 and its sentiment in
    column 2, which is converted to an integer label via the mapping
    """
    filename = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\n",
            "FOO\tThis is a second sentence\tsad\n"]
    with open(filename, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
    mapping = {"happy": 0, "sad": 1}
    snippets = process_utils.read_snippets(filename, 2, 1, "en", mapping, nlp=nlp)

    assert len(snippets) == 2
    expected = [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence'])]
    assert snippets == expected
def test_read_snippets_two_columns(tmp_path):
    """
    Test what happens when multiple columns are combined for the sentiment value

    With a tuple of sentiment columns, the mapping is keyed by the tuple of
    values found in those columns
    """
    filename = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\tfoo\n",
            "FOO\tThis is a second sentence\tsad\tbar\n",
            "FOO\tThis is a third sentence\tsad\tfoo\n"]
    with open(filename, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
    mapping = {("happy", "foo"): 0, ("sad", "bar"): 1, ("sad", "foo"): 2}
    snippets = process_utils.read_snippets(filename, (2, 3), 1, "en", mapping, nlp=nlp)

    assert len(snippets) == 3
    expected = [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence']),
                SentimentDatum(sentiment=2, text=['This', 'is', 'a', 'third', 'sentence'])]
    assert snippets == expected
================================================
FILE: stanza/tests/common/__init__.py
================================================
================================================
FILE: stanza/tests/common/test_bert_embedding.py
================================================
import pytest
import torch
from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings
# apply the travis and pipeline markers to every test in this module
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
# model name handed to load_bert and extract_bert_embeddings by every test below
# (presumably a very small HuggingFace test model, so these tests stay fast)
BERT_MODEL = "hf-internal-testing/tiny-bert"
@pytest.fixture(scope="module")
def tiny_bert():
    """
    Load BERT_MODEL once for the whole module

    Returns the pair from load_bert: the model (which has .parameters())
    and, presumably, its tokenizer
    """
    model, tokenizer = load_bert(BERT_MODEL)
    return model, tokenizer
def test_load_bert(tiny_bert):
    """
    Test that loading the bert produces both a model and a tokenizer

    The loading itself happens in the tiny_bert fixture; previously this test
    asserted nothing, so also check that neither piece came back as None
    """
    m, t = tiny_bert
    assert m is not None
    assert t is not None
def test_run_bert(tiny_bert):
    """
    Check that embeddings can be extracted for one short sentence without error
    """
    model, tokenizer = tiny_bert
    # run on whichever device the model's weights are on
    device = next(model.parameters()).device
    sentences = [["This", "is", "a", "test"]]
    extract_bert_embeddings(BERT_MODEL, tokenizer, model, sentences, device, True)
def test_run_bert_empty_word(tiny_bert):
    """
    A sentence with an empty word should get the same embeddings as the
    same sentence with "-" in that position
    """
    model, tokenizer = tiny_bert
    device = next(model.parameters()).device
    with_dash = extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "-", "a", "test"]], device, True)
    with_blank = extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "", "a", "test"]], device, True)
    # one result per input sentence
    assert len(with_dash) == 1
    assert torch.allclose(with_dash[0], with_blank[0])
================================================
FILE: stanza/tests/common/test_char_model.py
================================================
"""
Currently tests a few configurations of files for creating a charlm vocab
Also has a skeleton test of loading & saving a charlm
"""
from collections import Counter
import glob
import lzma
import os
import tempfile
import pytest
from stanza.models import charlm
from stanza.models.common import char_model
from stanza.tests import TEST_MODELS_DIR
# apply the travis and pipeline markers to every test in this module
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
# sample texts for building charlm vocabs.  neither contains a "Q", which
# the vocab tests use as a character that must not end up in the vocab
fake_text_1 = """
Unban mox opal!
I hate watching Peppa Pig
"""
fake_text_2 = """
This is plastic cheese
"""
class TestCharModel:
    @staticmethod
    def _write_sample(path, text):
        """
        Write text to path as utf-8, compressing it with lzma if path ends in .xz
        """
        opener = lzma.open if path.endswith(".xz") else open
        with opener(path, "wt", encoding="utf-8") as fout:
            fout.write(text)

    @staticmethod
    def _assert_vocab_covers(vocab, *texts):
        """
        Assert every character of every text is in vocab, and that "Q" is not

        None of the sample texts contain a Q, so it should never be in the vocab
        """
        for text in texts:
            for char in text:
                assert char in vocab
        assert "Q" not in vocab

    @staticmethod
    def _load_english_charlm(subdir):
        """
        Load an English charlm from TEST_MODELS_DIR/en/<subdir>

        eg, stanza_test/models/en/forward_charlm/1billion.pt
        glob returns files in arbitrary order, so the candidates are sorted
        to make the choice deterministic when several models are present
        """
        models_path = os.path.join(TEST_MODELS_DIR, "en", subdir, "*")
        models = sorted(glob.glob(models_path))
        # we expect at least one English model downloaded for the tests
        assert len(models) >= 1
        return char_model.CharacterLanguageModel.load(models[0])

    def test_single_file_vocab(self):
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt")
            self._write_sample(sample_file, fake_text_1)
            vocab = char_model.build_charlm_vocab(sample_file)
            self._assert_vocab_covers(vocab, fake_text_1)

    def test_single_file_xz_vocab(self):
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt.xz")
            self._write_sample(sample_file, fake_text_1)
            vocab = char_model.build_charlm_vocab(sample_file)
            self._assert_vocab_covers(vocab, fake_text_1)

    def test_single_file_dir_vocab(self):
        """
        Building from a directory should use the single file inside it
        """
        with tempfile.TemporaryDirectory() as tempdir:
            self._write_sample(os.path.join(tempdir, "text.txt"), fake_text_1)
            vocab = char_model.build_charlm_vocab(tempdir)
            self._assert_vocab_covers(vocab, fake_text_1)

    def test_multiple_files_vocab(self):
        """
        A directory with both a plain and an .xz file should use the text of both
        """
        with tempfile.TemporaryDirectory() as tempdir:
            self._write_sample(os.path.join(tempdir, "t1.txt"), fake_text_1)
            self._write_sample(os.path.join(tempdir, "t2.txt.xz"), fake_text_2)
            vocab = char_model.build_charlm_vocab(tempdir)
            self._assert_vocab_covers(vocab, fake_text_1, fake_text_2)

    def test_cutoff_vocab(self):
        """
        With cutoff=2, characters seen fewer than 2 times should be left out of the vocab
        """
        with tempfile.TemporaryDirectory() as tempdir:
            self._write_sample(os.path.join(tempdir, "t1.txt"), fake_text_1)
            self._write_sample(os.path.join(tempdir, "t2.txt.xz"), fake_text_2)
            vocab = char_model.build_charlm_vocab(tempdir, cutoff=2)
            counts = Counter(fake_text_1) + Counter(fake_text_2)
            for letter, count in counts.most_common():
                if count < 2:
                    assert letter not in vocab
                else:
                    assert letter in vocab

    def test_build_model(self):
        """
        Test the whole thing on a small dataset for an iteration or two
        """
        with tempfile.TemporaryDirectory() as tempdir:
            eval_file = os.path.join(tempdir, "en_test.dev.txt")
            self._write_sample(eval_file, fake_text_1)
            train_file = os.path.join(tempdir, "en_test.train.txt")
            with open(train_file, "w", encoding="utf-8") as fout:
                for _ in range(1000):
                    fout.write(fake_text_1)
                    fout.write("\n")
                    fout.write(fake_text_2)
                    fout.write("\n")
            save_name = 'en_test.forward.pt'
            vocab_save_name = 'en_test.vocab.pt'
            checkpoint_save_name = 'en_test.checkpoint.pt'
            args = ['--train_file', train_file,
                    '--eval_file', eval_file,
                    '--eval_steps', '0', # eval once per epoch
                    '--epochs', '2',
                    '--cutoff', '1',
                    '--batch_size', '%d' % len(fake_text_1),
                    '--shorthand', 'en_test',
                    '--save_dir', tempdir,
                    '--save_name', save_name,
                    '--vocab_save_name', vocab_save_name,
                    '--checkpoint_save_name', checkpoint_save_name]
            args = charlm.parse_args(args)
            charlm.train(args)
            assert os.path.exists(os.path.join(tempdir, vocab_save_name))

            # test that saving & loading of the model worked
            save_path = os.path.join(tempdir, save_name)
            assert os.path.exists(save_path)
            model = char_model.CharacterLanguageModel.load(save_path)

            # test that saving & loading of the checkpoint worked
            checkpoint_path = os.path.join(tempdir, checkpoint_save_name)
            assert os.path.exists(checkpoint_path)
            checkpoint_model = char_model.CharacterLanguageModel.load(checkpoint_path)
            # both files come from the same training run
            assert model.is_forward_lm == checkpoint_model.is_forward_lm

            trainer = char_model.CharacterLanguageModelTrainer.load(args, checkpoint_path)
            assert trainer.global_step > 0
            assert trainer.epoch == 2

            # quick test to verify this method works with a trained model
            charlm.get_current_lr(trainer, args)

            # test loading a vocab built by the training method...
            vocab = charlm.load_char_vocab(os.path.join(tempdir, vocab_save_name))
            trainer = char_model.CharacterLanguageModelTrainer.from_new_model(args, vocab)
            # ... and test the get_current_lr for an untrained model as well
            # this test is super "eager"
            assert charlm.get_current_lr(trainer, args) == args['lr0']

    @pytest.fixture(scope="class")
    def english_forward(self):
        # eg, stanza_test/models/en/forward_charlm/1billion.pt
        return self._load_english_charlm("forward_charlm")

    @pytest.fixture(scope="class")
    def english_backward(self):
        # eg, stanza_test/models/en/backward_charlm/1billion.pt
        return self._load_english_charlm("backward_charlm")

    def test_load_model(self, english_forward, english_backward):
        """
        Check that basic loading functions work
        """
        assert english_forward.is_forward_lm
        assert not english_backward.is_forward_lm

    def test_save_load_model(self, english_forward, english_backward):
        """
        Load, save, and load again
        """
        with tempfile.TemporaryDirectory() as tempdir:
            for model in (english_forward, english_backward):
                save_file = os.path.join(tempdir, "resaved", "charlm.pt")
                model.save(save_file)
                reloaded = char_model.CharacterLanguageModel.load(save_file)
                assert model.is_forward_lm == reloaded.is_forward_lm
================================================
FILE: stanza/tests/common/test_chuliu_edmonds.py
================================================
"""
Test some use cases of the chuliu_edmonds algorithm
(currently just the tarjan implementation)
"""
import numpy as np
import pytest
from stanza.models.common.chuliu_edmonds import tarjan
# apply the travis and pipeline markers to every test in this module
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def test_tarjan_basic():
    """
    Head arrays which form trees (no cycles) should produce an empty list
    """
    acyclic_graphs = ([0, 4, 4, 4, 0],
                      [0, 2, 0, 4, 2, 2])
    for heads in acyclic_graphs:
        assert tarjan(np.array(heads)) == []
def test_tarjan_cycle():
    """
    Each cycle in a head array should be reported as a boolean mask over the nodes
    """
    # 1 -> 3 -> 2 -> 1 is the only cycle
    single = tarjan(np.array([0, 3, 1, 2]))
    assert len(single) == 1
    np.testing.assert_array_equal(single[0], np.array([False, True, True, True]))

    # the same cycle, plus a second one among nodes 4, 5, 6
    double = tarjan(np.array([0, 3, 1, 2, 5, 6, 4]))
    assert len(double) == 2
    expected_masks = [np.array([False, True, True, True, False, False, False]),
                      np.array([False, False, False, False, True, True, True])]
    for mask, expected_mask in zip(double, expected_masks):
        np.testing.assert_array_equal(mask, expected_mask)
================================================
FILE: stanza/tests/common/test_common_data.py
================================================
import pytest
import stanza
from stanza.tests import *
from stanza.models.common.data import get_augment_ratio, augment_punct
# apply the travis and pipeline markers to every test in this module
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def test_augment_ratio():
    """
    Check get_augment_ratio on a toy dataset of the integers 1 through 10
    """
    values = list(range(1, 11))

    def should_augment(x):
        return x >= 3

    def can_augment(x):
        return x >= 4

    # check that zero is returned if no augmentation is needed
    # which will be the case since 2 are already satisfactory
    assert get_augment_ratio(values, should_augment, can_augment, desired_ratio=0.1) == 0.0

    # swapping the two predicates should throw an error
    with pytest.raises(AssertionError):
        get_augment_ratio(values, can_augment, should_augment)

    # with a desired ratio of 0.4,
    # there are already 2 that don't need augmenting
    # and 7 that are eligible to be augmented
    # so 2/7 will need to be augmented
    assert get_augment_ratio(values, should_augment, can_augment, desired_ratio=0.4) == pytest.approx(2/7)
def test_augment_punct():
    """
    Augmenting the only sentence, which ends in ".", should drop that final "."
    """
    sentences = [["Simple", "test", "."]]

    def ends_in_period(sentence):
        return sentence[-1] == "."

    augmented = augment_punct(sentences, 1.0, ends_in_period, ends_in_period)
    assert augmented == [["Simple", "test"]]
================================================
FILE: stanza/tests/common/test_confusion.py
================================================
"""
Test a couple simple confusion matrices and output formats
"""
from collections import defaultdict
import pytest
from stanza.utils.confusion import format_confusion, confusion_to_f1, confusion_to_macro_f1, confusion_to_weighted_f1
# apply the travis and pipeline markers to every test in this module
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
@pytest.fixture
def simple_confusion():
    """
    A small confusion matrix, indexed as confusion[gold][predicted] = count
    """
    entries = [("B-ORG", "B-ORG", 1),
               ("B-ORG", "B-PER", 1),
               ("E-ORG", "E-ORG", 1),
               ("E-ORG", "E-PER", 1),
               ("O", "O", 4)]
    confusion = defaultdict(lambda: defaultdict(int))
    for gold, predicted, count in entries:
        confusion[gold][predicted] = count
    return confusion
@pytest.fixture
def short_confusion():
    """
    Same thing, but with a short name. This should not be sorted by entity type

    The gold tag "A" has no entity type, so the matrix labels are ordered differently
    """
    entries = [("A", "B-ORG", 1),
               ("B-ORG", "B-PER", 1),
               ("E-ORG", "E-ORG", 1),
               ("E-ORG", "E-PER", 1),
               ("O", "O", 4)]
    confusion = defaultdict(lambda: defaultdict(int))
    for gold, predicted, count in entries:
        confusion[gold][predicted] = count
    return confusion
# expected format_confusion output for simple_confusion: rows are gold tags,
# columns are predicted tags (the "t\\p" header).  [1:-1] removes only the
# newline after the opening quotes and the one before the closing quotes
EXPECTED_SIMPLE_OUTPUT = """
t\\p O B-ORG E-ORG B-PER E-PER
O 4 0 0 0 0
B-ORG 0 1 0 1 0
E-ORG 0 0 1 0 1
B-PER 0 0 0 0 0
E-PER 0 0 0 0 0
"""[1:-1] # slice instead of strip(), which would also remove leading/trailing spaces
# expected format_confusion output for short_confusion
EXPECTED_SHORT_OUTPUT = """
t\\p O A B-ORG B-PER E-ORG E-PER
O 4 0 0 0 0 0
A 0 0 1 0 0 0
B-ORG 0 0 0 1 0 0
B-PER 0 0 0 0 0 0
E-ORG 0 0 0 0 1 1
E-PER 0 0 0 0 0 0
"""[1:-1]
# expected format_confusion(short_confusion, hide_blank=True) output
EXPECTED_HIDE_BLANK_SHORT_OUTPUT = """
t\\p O B-ORG E-ORG B-PER E-PER
O 4 0 0 0 0
A 0 1 0 0 0
B-ORG 0 0 0 1 0
E-ORG 0 0 1 0 1
"""[1:-1]
def test_simple_output(simple_confusion):
    """The simple matrix should render exactly as EXPECTED_SIMPLE_OUTPUT"""
    rendered = format_confusion(simple_confusion)
    assert EXPECTED_SIMPLE_OUTPUT == rendered
def test_short_output(short_confusion):
    """The short-name matrix should render exactly as EXPECTED_SHORT_OUTPUT"""
    rendered = format_confusion(short_confusion)
    assert EXPECTED_SHORT_OUTPUT == rendered
def test_hide_blank_short_output(short_confusion):
    """With hide_blank=True, the short-name matrix should render as EXPECTED_HIDE_BLANK_SHORT_OUTPUT"""
    rendered = format_confusion(short_confusion, hide_blank=True)
    assert EXPECTED_HIDE_BLANK_SHORT_OUTPUT == rendered
def test_macro_f1(simple_confusion, short_confusion):
    """Check the macro f1 of both example matrices"""
    expected_scores = [(simple_confusion, 0.466666666666),
                       (short_confusion, 0.277777777777)]
    for confusion, score in expected_scores:
        assert confusion_to_macro_f1(confusion) == pytest.approx(score)
def test_weighted_f1(simple_confusion, short_confusion):
    """Check the weighted f1 of both example matrices, with and without the O tag"""
    with_o = [(simple_confusion, 0.83333333),
              (short_confusion, 0.66666666)]
    for confusion, score in with_o:
        assert confusion_to_weighted_f1(confusion) == pytest.approx(score)

    without_o = [(simple_confusion, 0.66666666),
                 (short_confusion, 0.33333333)]
    for confusion, score in without_o:
        assert confusion_to_weighted_f1(confusion, exclude=["O"]) == pytest.approx(score)
================================================
FILE: stanza/tests/common/test_constant.py
================================================
"""
Test the conversion to lcodes and splitting of dataset names
"""
import tempfile
import pytest
import stanza
from stanza.models.common.constant import treebank_to_short_name, lang_to_langcode, is_right_to_left, two_to_three_letters, langlower2lcode
from stanza.tests import *
# apply the travis and pipeline markers to every test in this module
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def test_treebank():
    """
    Test the entire treebank name conversion
    """
    cases = [
        # conversion of a UD_ name
        ("hi_hdtb", "UD_Hindi-HDTB"),
        # conversion of names without UD
        ("hi_fire2013", "Hindi-fire2013"),
        ("hi_fire2013", "Hindi-Fire2013"),
        ("hi_fire2013", "Hindi-FIRE2013"),
        # already short names are generally preserved
        ("hi_fire2013", "hi-fire2013"),
        ("hi_fire2013", "hi_fire2013"),
        # a special case
        ("zh-hant_pud", "UD_Chinese-PUD"),
        # a special case already converted once
        ("zh-hant_pud", "zh-hant_pud"),
        ("zh-hant_pud", "zh-hant-pud"),
        ("zh-hans_gsdsimp", "zh-hans_gsdsimp"),
        # three letter codes and language names are normalized to the two letter code
        ("wo_masakhane", "wo_masakhane"),
        ("wo_masakhane", "wol_masakhane"),
        ("wo_masakhane", "Wol_masakhane"),
        ("wo_masakhane", "wolof_masakhane"),
        ("wo_masakhane", "Wolof_masakhane"),
    ]
    for expected, treebank in cases:
        assert expected == treebank_to_short_name(treebank)
def test_lang_to_langcode():
    """
    The language name or code, in any capitalization, should map to "hi"
    """
    for name in ("Hindi", "HINDI", "hindi", "HI", "hi"):
        assert "hi" == lang_to_langcode(name)
def test_right_to_left():
    """
    Arabic is right to left and English is not, whether given by code or by name
    """
    for lang in ("ar", "Arabic"):
        assert is_right_to_left(lang)
    for lang in ("en", "English"):
        assert not is_right_to_left(lang)
def test_two_to_three():
    """
    Wolof maps to "wo" from its name or three letter code, and "wo" maps back to "wol"
    """
    for name in ("Wolof", "wol"):
        assert lang_to_langcode(name) == "wo"
    assert "wo" in two_to_three_letters
    assert two_to_three_letters["wo"] == "wol"
def test_langlower():
    """
    Language names are matched case-insensitively, and alternate names share a code
    """
    assert lang_to_langcode("WOLOF") == "wo"
    assert lang_to_langcode("nOrWeGiAn") == "nb"
    for alias in ("soi", "sohi"):
        assert "soj" == langlower2lcode[alias]
================================================
FILE: stanza/tests/common/test_data_conversion.py
================================================
"""
Basic tests of the data conversion
"""
import io
import pytest
import tempfile
from zipfile import ZipFile
import stanza
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import Document
from stanza.tests import *
# apply the pipeline marker to every test in this module
pytestmark = pytest.mark.pipeline
# data for testing: one French sentence, including the multi-word token
# "du" (words 6-7), as raw CoNLL-U rows (CONLL) and as the list of word
# dicts that CoNLL.convert_conll is expected to produce from it (DICT)
CONLL = [[['1', 'Nous', 'il', 'PRON', '_', 'Number=Plur|Person=1|PronType=Prs', '3', 'nsubj', '_', 'start_char=0|end_char=4'],
['2', 'avons', 'avoir', 'AUX', '_', 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', '3', 'aux:tense', '_', 'start_char=5|end_char=10'],
['3', 'atteint', 'atteindre', 'VERB', '_', 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', '0', 'root', '_', 'start_char=11|end_char=18'],
['4', 'la', 'le', 'DET', '_', 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', '5', 'det', '_', 'start_char=19|end_char=21'],
['5', 'fin', 'fin', 'NOUN', '_', 'Gender=Fem|Number=Sing', '3', 'obj', '_', 'start_char=22|end_char=25'],
['6-7', 'du', '_', '_', '_', '_', '_', '_', '_', 'start_char=26|end_char=28'],
['6', 'de', 'de', 'ADP', '_', '_', '8', 'case', '_', '_'],
['7', 'le', 'le', 'DET', '_', 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', '8', 'det', '_', '_'],
['8', 'sentier', 'sentier', 'NOUN', '_', 'Gender=Masc|Number=Sing', '5', 'nmod', '_', 'start_char=29|end_char=36'],
['9', '.', '.', 'PUNCT', '_', '_', '3', 'punct', '_', 'start_char=36|end_char=37']]]
DICT = [[{'id': (1,), 'text': 'Nous', 'lemma': 'il', 'upos': 'PRON', 'feats': 'Number=Plur|Person=1|PronType=Prs', 'head': 3, 'deprel': 'nsubj', 'misc': 'start_char=0|end_char=4'},
{'id': (2,), 'text': 'avons', 'lemma': 'avoir', 'upos': 'AUX', 'feats': 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', 'head': 3, 'deprel': 'aux:tense', 'misc': 'start_char=5|end_char=10'},
{'id': (3,), 'text': 'atteint', 'lemma': 'atteindre', 'upos': 'VERB', 'feats': 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', 'head': 0, 'deprel': 'root', 'misc': 'start_char=11|end_char=18'},
{'id': (4,), 'text': 'la', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', 'head': 5, 'deprel': 'det', 'misc': 'start_char=19|end_char=21'},
{'id': (5,), 'text': 'fin', 'lemma': 'fin', 'upos': 'NOUN', 'feats': 'Gender=Fem|Number=Sing', 'head': 3, 'deprel': 'obj', 'misc': 'start_char=22|end_char=25'},
{'id': (6, 7), 'text': 'du', 'misc': 'start_char=26|end_char=28'},
{'id': (6,), 'text': 'de', 'lemma': 'de', 'upos': 'ADP', 'head': 8, 'deprel': 'case'},
{'id': (7,), 'text': 'le', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', 'head': 8, 'deprel': 'det'},
{'id': (8,), 'text': 'sentier', 'lemma': 'sentier', 'upos': 'NOUN', 'feats': 'Gender=Masc|Number=Sing', 'head': 5, 'deprel': 'nmod', 'misc': 'start_char=29|end_char=36'},
{'id': (9,), 'text': '.', 'lemma': '.', 'upos': 'PUNCT', 'head': 3, 'deprel': 'punct', 'misc': 'start_char=36|end_char=37'}]]
def test_conll_to_dict():
    """
    Converting the raw CONLL rows should produce DICT, and the second return
    value should hold one empty collection per sentence
    """
    sentence_dicts, empty_words = CoNLL.convert_conll(CONLL)
    assert sentence_dicts == DICT
    assert len(sentence_dicts) == len(empty_words)
    assert all(len(words) == 0 for words in empty_words)
def test_dict_to_conll():
    """
    A Document built from DICT should print, without comments, as the CONLL rows
    """
    document = Document(DICT)
    # :c = no comments
    conll_text = "{:c}".format(document)
    rows = [[line.split("\t") for line in sentence_text.split("\n")]
            for sentence_text in conll_text.split("\n\n")]
    assert rows == CONLL
def test_dict_to_doc_and_doc_to_dict():
    """
    Test the conversion from raw dict to Document and back

    Building the Document turns the start_char|end_char misc entries into
    separate start_char and end_char fields, so to_dict() emits them as
    separate fields.  Printing a Document rebuilt from those dicts should
    fold them back into misc, giving the original CONLL rows.
    """
    round_tripped = Document(Document(DICT).to_dict())
    conll_text = "{:c}".format(round_tripped)
    rows = [[line.split("\t") for line in sentence_text.split("\n")]
            for sentence_text in conll_text.split("\n\n")]
    assert rows == CONLL
# sample is two sentences long so that the tests check multiple sentences
# each sentence has three comment lines (sent_id, genre, text); the first
# sentence also carries an unusual deprel=list:goeswith in its misc column
RUSSIAN_SAMPLE="""
# sent_id = yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253
# genre = review
# text = Как- то слишком мало цветов получают актёры после спектакля.
1 Как как-то ADV _ Degree=Pos|PronType=Ind 7 advmod _ SpaceAfter=No
2 - - PUNCT _ _ 3 punct _ _
3 то то PART _ _ 1 list _ deprel=list:goeswith
4 слишком слишком ADV _ Degree=Pos 5 advmod _ _
5 мало мало ADV _ Degree=Pos 6 advmod _ _
6 цветов цветок NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Plur 7 obj _ _
7 получают получать VERB _ Aspect=Imp|Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ _
8 актёры актер NOUN _ Animacy=Anim|Case=Nom|Gender=Masc|Number=Plur 7 nsubj _ _
9 после после ADP _ _ 10 case _ _
10 спектакля спектакль NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Sing 7 obl _ SpaceAfter=No
11 . . PUNCT _ _ 7 punct _ _
# sent_id = 4
# genre = social
# text = В женщине важна верность, а не красота.
1 В в ADP _ _ 2 case _ _
2 женщине женщина NOUN _ Animacy=Anim|Case=Loc|Gender=Fem|Number=Sing 3 obl _ _
3 важна важный ADJ _ Degree=Pos|Gender=Fem|Number=Sing|Variant=Short 0 root _ _
4 верность верность NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 3 nsubj _ SpaceAfter=No
5 , , PUNCT _ _ 8 punct _ _
6 а а CCONJ _ _ 8 cc _ _
7 не не PART _ Polarity=Neg 8 advmod _ _
8 красота красота NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 4 conj _ SpaceAfter=No
9 . . PUNCT _ _ 3 punct _ _
""".strip()
# expected sentence.text and sentence.sent_id for the two sentences above, in order
RUSSIAN_TEXT = ["Как- то слишком мало цветов получают актёры после спектакля.", "В женщине важна верность, а не красота."]
RUSSIAN_IDS = ["yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253", "4"]
def check_russian_doc(doc):
    """
    Check that doc holds the contents of RUSSIAN_SAMPLE

    Shared by the tests which load the Russian sample in different ways
    (from a string, a file, a zip file, or after writing it back out)
    """
    expected_lines = RUSSIAN_SAMPLE.split("\n")
    assert len(doc.sentences) == 2

    first_comments = doc.sentences[0].comments
    for line_idx in range(3):
        assert expected_lines[line_idx] == first_comments[line_idx]

    for sent_idx, (expected_text, expected_id, sentence) in enumerate(zip(RUSSIAN_TEXT, RUSSIAN_IDS, doc.sentences)):
        assert expected_text == sentence.text
        assert expected_id == sentence.sent_id
        assert sent_idx == sentence.index
        assert len(sentence.comments) == 3
        assert not sentence.has_enhanced_dependencies()

    # printing with comments should reproduce the first sentence's comment lines
    conll_sentences = format(doc, "C").split("\n\n")
    assert len(conll_sentences) == 2
    first_sentence_lines = conll_sentences[0].split("\n")
    assert len(first_sentence_lines) == 14
    for line_idx in range(3):
        assert expected_lines[line_idx] == first_sentence_lines[line_idx]

    # assert that the weird deprel=list:goeswith was properly handled
    third_word = doc.sentences[0].words[2]
    assert third_word.head == 1
    assert third_word.deprel == "list:goeswith"
def test_write_russian_doc(tmp_path):
    """
    Specifically test the write_doc2conll method
    """
    filename = tmp_path / "russian.conll"
    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
    check_russian_doc(doc)
    CoNLL.write_doc2conll(doc, filename)

    with open(filename, encoding="utf-8") as fin:
        written = fin.read()
    # the conll docs have to end with \n\n
    assert written.endswith("\n\n")

    # compare against the original from the second sentence on, ignoring the
    # surrounding whitespace.  the first sentence is skipped because its
    # "deprel=list:goeswith" is weird; that deprel is checked in check_russian_doc
    marker = "# sent_id = 4"
    written = written.strip()
    assert written[written.find(marker):] == RUSSIAN_SAMPLE[RUSSIAN_SAMPLE.find(marker):]

    # the written file should also read back as the same document
    check_russian_doc(CoNLL.conll2doc(filename))
# random sentence from EN_Pronouns, with several comment lines
# (newdoc, sent_id, text, previous, comment) which must survive a write
ENGLISH_SAMPLE = """
# newdoc
# sent_id = 1
# text = It is hers.
# previous = Which person owns this?
# comment = copular subject
1 It it PRON PRP Number=Sing|Person=3|PronType=Prs 3 nsubj _ _
2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _
3 hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
4 . . PUNCT . _ 3 punct _ _
""".strip()
def test_write_to_io():
    """
    write_doc2conll should also accept a file-like object such as a StringIO
    """
    doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)
    buffer = io.StringIO()
    CoNLL.write_doc2conll(doc, buffer)
    written = buffer.getvalue()
    assert written.endswith("\n\n")
    assert written.strip() == ENGLISH_SAMPLE
def test_write_doc2conll_append(tmp_path):
    """
    Writing the same doc twice, the second time with mode="a", should leave both copies in the file
    """
    doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)
    filename = tmp_path / "english.conll"
    CoNLL.write_doc2conll(doc, filename)
    CoNLL.write_doc2conll(doc, filename, mode="a")
    # read back as utf-8, as the other tests in this file do, rather than
    # depending on the platform's default encoding
    with open(filename, encoding="utf-8") as fin:
        text = fin.read()
    expected = (ENGLISH_SAMPLE + "\n\n") * 2
    assert text == expected
def test_doc_with_comments():
    """
    Test that a doc with comments gets converted back with comments
    """
    check_russian_doc(CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE))
def test_unusual_misc():
    """
    The above RUSSIAN_SAMPLE resulted in a blank misc field in one particular implementation of the conll code
    (the below test would fail)
    """
    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
    sentences = format(doc, "C").split("\n\n")
    assert len(sentences) == 2
    rows = sentences[0].split("\n")
    assert len(rows) == 14
    for row in rows:
        columns = row.split("\t")
        # comment lines have no tabs; word lines should have all ten columns
        assert len(columns) in (1, 10)
        if len(columns) == 10:
            # none of the columns may be blank
            assert all(columns)
def test_file():
    """
    Write RUSSIAN_SAMPLE to a file on disk and load the doc back from it
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        conll_path = os.path.join(scratch_dir, "russian.conll")
        with open(conll_path, "w", encoding="utf-8") as handle:
            handle.write(RUSSIAN_SAMPLE)
        loaded = CoNLL.conll2doc(input_file=conll_path)
        check_russian_doc(loaded)
def test_zip_file():
    """
    Load RUSSIAN_SAMPLE from a member file stored inside a zip archive
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        archive_path = os.path.join(scratch_dir, "russian.zip")
        member_name = "russian.conll"
        with ZipFile(archive_path, "w") as archive:
            with archive.open(member_name, "w") as member:
                member.write(RUSSIAN_SAMPLE.encode())
        # the archive is closed before it is read back
        loaded = CoNLL.conll2doc(input_file=member_name, zip_file=archive_path)
        check_russian_doc(loaded)
# A single sentence with no MWTs whose misc column carries start_char, end_char
# and ner on every word.  test_simple_ner_conversion expects these to become
# token fields and then be written back out as this exact text.
SIMPLE_NER = """
# text = Teferi's best friend is Karn
# sent_id = 0
1 Teferi _ _ _ _ 0 _ _ start_char=0|end_char=6|ner=S-PERSON
2 's _ _ _ _ 1 _ _ start_char=6|end_char=8|ner=O
3 best _ _ _ _ 2 _ _ start_char=9|end_char=13|ner=O
4 friend _ _ _ _ 3 _ _ start_char=14|end_char=20|ner=O
5 is _ _ _ _ 4 _ _ start_char=21|end_char=23|ner=O
6 Karn _ _ _ _ 5 _ _ start_char=24|end_char=28|ner=S-PERSON
""".strip()
def test_simple_ner_conversion():
    """
    NER tags in the misc column become token.ner values, and the doc is
    written back out unchanged
    """
    doc = CoNLL.conll2doc(input_str=SIMPLE_NER)
    assert len(doc.sentences) == 1
    tokens = doc.sentences[0].tokens
    assert len(tokens) == 6
    expected_tags = ["S-PERSON", "O", "O", "O", "O", "S-PERSON"]
    for token, tag in zip(tokens, expected_tags):
        assert token.ner == tag
        # ner, start_char and end_char are stored as dedicated token fields,
        # so neither the token's misc nor its single word's misc keeps them
        assert not token.misc
        assert len(token.words) == 1
        assert not token.words[0].misc
    assert "{:C}".format(doc) == SIMPLE_NER
# Like SIMPLE_NER, but "John's" is a multi-word token (3-4).  The offsets and
# the ner tag are on the MWT line; its two words have an empty misc column.
MWT_NER = """
# text = This makes John's headache worse
# sent_id = 0
1 This _ _ _ _ 0 _ _ start_char=0|end_char=4|ner=O
2 makes _ _ _ _ 1 _ _ start_char=5|end_char=10|ner=O
3-4 John's _ _ _ _ _ _ _ start_char=11|end_char=17|ner=S-PERSON
3 John _ _ _ _ 2 _ _ _
4 's _ _ _ _ 3 _ _ _
5 headache _ _ _ _ 4 _ _ start_char=18|end_char=26|ner=O
6 worse _ _ _ _ 5 _ _ start_char=27|end_char=32|ner=O
""".strip()
def test_mwt_ner_conversion():
    """
    Test that tokens including MWT get properly created with NER tags

    Note that this kind of thing happens with the EWT tokenizer for English, for example
    """
    doc = CoNLL.conll2doc(input_str=MWT_NER)
    assert len(doc.sentences) == 1
    sentence = doc.sentences[0]
    assert len(sentence.tokens) == 5
    assert not sentence.has_enhanced_dependencies()
    EXPECTED_NER = ["O", "O", "S-PERSON", "O", "O"]
    EXPECTED_WORDS = [1, 1, 2, 1, 1]
    for token, ner, expected_words in zip(sentence.tokens, EXPECTED_NER, EXPECTED_WORDS):
        assert token.ner == ner
        # check that the ner, start_char, end_char fields were not put on the token's misc
        # those should all be set as specific fields on the token
        assert not token.misc
        assert len(token.words) == expected_words
        # they should also not reach the misc field of ANY word in the token;
        # checking only words[0] would skip the second word of the John's MWT
        for word in token.words:
            assert not word.misc
    conll = "{:C}".format(doc)
    assert conll == MWT_NER
# The same sentence written three ways, for testing CoNLL-U output options:
# with start_char/end_char on every line (including the MWT's words) ...
ALL_OFFSETS_CONLLU = """
# text = This makes John's headache worse
# sent_id = 0
1 This _ _ _ _ 0 _ _ start_char=0|end_char=4
2 makes _ _ _ _ 1 _ _ start_char=5|end_char=10
3-4 John's _ _ _ _ _ _ _ start_char=11|end_char=17
3 John _ _ _ _ 2 _ _ start_char=11|end_char=15
4 's _ _ _ _ 3 _ _ start_char=15|end_char=17
5 headache _ _ _ _ 4 _ _ start_char=18|end_char=26
6 worse _ _ _ _ 5 _ _ SpaceAfter=No|start_char=27|end_char=32
""".strip()
# ... with the offsets removed (SpaceAfter=No is kept) ...
NO_OFFSETS_CONLLU = """
# text = This makes John's headache worse
# sent_id = 0
1 This _ _ _ _ 0 _ _ _
2 makes _ _ _ _ 1 _ _ _
3-4 John's _ _ _ _ _ _ _ _
3 John _ _ _ _ 2 _ _ _
4 's _ _ _ _ 3 _ _ _
5 headache _ _ _ _ 4 _ _ _
6 worse _ _ _ _ 5 _ _ SpaceAfter=No
""".strip()
# ... and with both the offsets and the comment lines removed
NO_COMMENTS_NO_OFFSETS_CONLLU = """
1 This _ _ _ _ 0 _ _ _
2 makes _ _ _ _ 1 _ _ _
3-4 John's _ _ _ _ _ _ _ _
3 John _ _ _ _ 2 _ _ _
4 's _ _ _ _ 3 _ _ _
5 headache _ _ _ _ 4 _ _ _
6 worse _ _ _ _ 5 _ _ SpaceAfter=No
""".strip()
def test_no_offsets_output():
    """
    "{:C}" writes everything, "-o" drops the start_char/end_char offsets,
    and lowercase "c" additionally drops the comment lines
    """
    doc = CoNLL.conll2doc(input_str=ALL_OFFSETS_CONLLU)
    assert len(doc.sentences) == 1
    assert len(doc.sentences[0].tokens) == 5
    expected_by_format = [
        ("{:C}", ALL_OFFSETS_CONLLU),
        ("{:C-o}", NO_OFFSETS_CONLLU),
        ("{:c-o}", NO_COMMENTS_NO_OFFSETS_CONLLU),
    ]
    for fmt, expected in expected_by_format:
        assert fmt.format(doc) == expected
# A random sentence from et_ewt-ud-train.conllu
# which we use to test the deps conversion for multiple deps:
# word 4 ("neid") has two enhanced dependencies, 3:obj|9:nsubj
ESTONIAN_DEPS = """
# newpar
# sent_id = aia_foorum_37
# text = Sestpeale ei mõistagi neid, kes koduaias sortidega tegelevad.
1 Sestpeale sest_peale ADV D _ 3 advmod 3:advmod _
2 ei ei AUX V Polarity=Neg 3 aux 3:aux _
3 mõistagi mõistma VERB V Connegative=Yes|Mood=Ind|Tense=Pres|VerbForm=Fin|Voice=Act 0 root 0:root _
4 neid tema PRON P Case=Par|Number=Plur|Person=3|PronType=Prs 3 obj 3:obj|9:nsubj SpaceAfter=No
5 , , PUNCT Z _ 9 punct 9:punct _
6 kes kes PRON P Case=Nom|Number=Plur|PronType=Int,Rel 9 nsubj 4:ref _
7 koduaias kodu_aed NOUN S Case=Ine|Number=Sing 9 obl 9:obl _
8 sortidega sort NOUN S Case=Com|Number=Plur 9 obl 9:obl _
9 tegelevad tegelema VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 4 acl:relcl 4:acl SpaceAfter=No
10 . . PUNCT Z _ 3 punct 3:punct _
""".strip()
def test_deps_conversion():
    """
    A word with several enhanced deps keeps all of them, joined by |, and
    the doc is written back out unchanged
    """
    doc = CoNLL.conll2doc(input_str=ESTONIAN_DEPS)
    assert len(doc.sentences) == 1
    only_sentence = doc.sentences[0]
    assert len(only_sentence.tokens) == 10
    assert only_sentence.has_enhanced_dependencies()
    # word 4 ("neid") is the one with two enhanced deps
    assert only_sentence.words[3].deps == "3:obj|9:nsubj"
    assert "{:C}".format(doc) == ESTONIAN_DEPS
# An Estonian sentence containing one empty node (5.1) which is referenced
# only from the enhanced deps column
ESTONIAN_EMPTY_DEPS = """
# sent_id = ewtb2_000035_15
# text = Ja paari aasta pärast rôômalt maasikatele ...
1 Ja ja CCONJ J _ 3 cc 5.1:cc _
2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
4 pärast pärast ADP K AdpType=Post 3 case 3:case _
5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes
7 ... ... PUNCT Z _ 3 punct 5.1:punct _
""".strip()
# The same sentence truncated after the empty node, so 5.1 is the last line
ESTONIAN_EMPTY_END_DEPS = """
# sent_id = ewtb2_000035_15
# text = Ja paari aasta pärast rôômalt maasikatele ...
1 Ja ja CCONJ J _ 3 cc 5.1:cc _
2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
4 pärast pärast ADP K AdpType=Post 3 case 3:case _
5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
""".strip()
def test_empty_deps_conversion():
    """
    A sentence with an empty word (5.1) between regular words can be read
    and written back out
    """
    regular_words = 7
    check_empty_deps_conversion(ESTONIAN_EMPTY_DEPS, regular_words)
def test_empty_deps_at_end_conversion():
    """
    Same round trip as test_empty_deps_conversion, but with the empty word
    on the last line of the sentence
    """
    regular_words = 5
    check_empty_deps_conversion(ESTONIAN_EMPTY_END_DEPS, regular_words)
def check_empty_deps_conversion(input_str, expected_words):
    """
    Read a single sentence containing exactly one empty word and check that
    it survives the round trip back to CoNLL-U.

    input_str: CoNLL-U text for one sentence
    expected_words: number of regular words (and tokens) in the sentence;
        the empty word is counted separately
    """
    # NOTE(review): ignore_gapping=False presumably is what keeps the empty
    # word instead of dropping it - confirm against CoNLL.conll2doc
    doc = CoNLL.conll2doc(input_str=input_str, ignore_gapping=False)
    assert len(doc.sentences) == 1
    assert len(doc.sentences[0].tokens) == expected_words
    assert len(doc.sentences[0].words) == expected_words
    assert len(doc.sentences[0].empty_words) == 1
    conll = "{:C}".format(doc)
    assert conll == input_str
    sentence_dict = doc.sentences[0].to_dict()
    # to_dict includes the empty word as one extra entry
    assert len(sentence_dict) == expected_words + 1
    # currently this is true for both of the examples we run:
    # the empty word is 5.1, so it follows the fifth regular word
    assert sentence_dict[5]['id'] == (5, 1)
    # redo the above checks to make sure
    # there are no weird bugs in the accessors
    assert len(doc.sentences) == 1
    assert len(doc.sentences[0].tokens) == expected_words
    assert len(doc.sentences[0].words) == expected_words
    assert len(doc.sentences[0].empty_words) == 1
# The same sentence as ESTONIAN_EMPTY_DEPS, preceded by a "# doc_id" comment
ESTONIAN_DOC_ID = """
# doc_id = this_is_a_doc
# sent_id = ewtb2_000035_15
# text = Ja paari aasta pärast rôômalt maasikatele ...
1 Ja ja CCONJ J _ 3 cc 5.1:cc _
2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
4 pärast pärast ADP K AdpType=Post 3 case 3:case _
5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes
7 ... ... PUNCT Z _ 3 punct 5.1:punct _
""".strip()
def test_read_doc_id():
    """
    The doc_id comment is preserved on output and exposed as sentence.doc_id
    """
    doc = CoNLL.conll2doc(input_str=ESTONIAN_DOC_ID, ignore_gapping=False)
    round_trip = "{:C}".format(doc)
    assert round_trip == ESTONIAN_DOC_ID
    assert doc.sentences[0].doc_id == 'this_is_a_doc'
# Word 6 (Karn) has head 8, but the sentence only has 6 words
SIMPLE_DEPENDENCY_INDEX_ERROR = """
# text = Teferi's best friend is Karn
# sent_id = 0
# notes = this sentence has a dependency index outside the sentence. it should throw an IndexError
1 Teferi _ _ _ _ 0 root _ start_char=0|end_char=6|ner=S-PERSON
2 's _ _ _ _ 1 dep _ start_char=6|end_char=8|ner=O
3 best _ _ _ _ 2 dep _ start_char=9|end_char=13|ner=O
4 friend _ _ _ _ 3 dep _ start_char=14|end_char=20|ner=O
5 is _ _ _ _ 4 dep _ start_char=21|end_char=23|ner=O
6 Karn _ _ _ _ 8 dep _ start_char=24|end_char=28|ner=S-PERSON
""".strip()
def test_read_dependency_errors():
    """
    A head index pointing past the end of the sentence raises IndexError
    """
    with pytest.raises(IndexError):
        CoNLL.conll2doc(input_str=SIMPLE_DEPENDENCY_INDEX_ERROR)
# Four English sentences: two with doc_id = doc_1, one with doc_id = doc_2,
# and a final sentence with no doc_id comment.  Only lstrip is applied, so the
# first line of the string is the first "# doc_id" comment.
MULTIPLE_DOC_IDS = """
# doc_id = doc_1
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0020
# text = His mother was also killed in the attack.
1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 2 nmod:poss 2:nmod:poss _
2 mother mother NOUN NN Number=Sing 5 nsubj:pass 5:nsubj:pass _
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 5 aux:pass 5:aux:pass _
4 also also ADV RB _ 5 advmod 5:advmod _
5 killed kill VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
6 in in ADP IN _ 8 case 8:case _
7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _
8 attack attack NOUN NN Number=Sing 5 obl 5:obl:in SpaceAfter=No
9 . . PUNCT . _ 5 punct 5:punct _
# doc_id = doc_1
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0028
# text = This item is a small one and easily missed.
1 This this DET DT Number=Sing|PronType=Dem 2 det 2:det _
2 item item NOUN NN Number=Sing 6 nsubj 6:nsubj|9:nsubj:pass _
3 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
5 small small ADJ JJ Degree=Pos 6 amod 6:amod _
6 one one NOUN NN Number=Sing 0 root 0:root _
7 and and CCONJ CC _ 9 cc 9:cc _
8 easily easily ADV RB _ 9 advmod 9:advmod _
9 missed miss VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 6 conj 6:conj:and SpaceAfter=No
10 . . PUNCT . _ 6 punct 6:punct _
# doc_id = doc_2
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0029
# text = But in my view it is highly significant.
1 But but CCONJ CC _ 8 cc 8:cc _
2 in in ADP IN _ 4 case 4:case _
3 my my PRON PRP$ Case=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs 4 nmod:poss 4:nmod:poss _
4 view view NOUN NN Number=Sing 8 obl 8:obl:in _
5 it it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 8 nsubj 8:nsubj _
6 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 8 cop 8:cop _
7 highly highly ADV RB _ 8 advmod 8:advmod _
8 significant significant ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No
9 . . PUNCT . _ 8 punct 8:punct _
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0040
# text = The trial begins again Nov.28.
1 The the DET DT Definite=Def|PronType=Art 2 det 2:det _
2 trial trial NOUN NN Number=Sing 3 nsubj 3:nsubj _
3 begins begin VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
4 again again ADV RB _ 3 advmod 3:advmod _
5 Nov. November PROPN NNP Abbr=Yes|Number=Sing 3 obl:tmod 3:obl:tmod SpaceAfter=No
6 28 28 NUM CD NumForm=Digit|NumType=Card 5 nummod 5:nummod SpaceAfter=No
7 . . PUNCT . _ 3 punct 3:punct _
""".lstrip()
def test_read_multiple_doc_ids():
    """
    conll2multi_docs starts a new document whenever the doc_id comment
    changes; a sentence without a doc_id stays in the current document
    """
    docs = CoNLL.conll2multi_docs(input_str=MULTIPLE_DOC_IDS)
    assert [len(d.sentences) for d in docs] == [2, 2]

    # drop the very first line, the "# doc_id = doc_1" comment, so the first
    # sentence no longer belongs to doc_1
    without_first_doc_id = "\n".join(MULTIPLE_DOC_IDS.split("\n")[1:])
    docs = CoNLL.conll2multi_docs(input_str=without_first_doc_id)
    assert [len(d.sentences) for d in docs] == [1, 1, 2]
# A single sentence with two comment lines.  Only lstrip is applied, so the
# trailing newline is kept; test_line_numbers compares against "{:C}\n".
ENGLISH_TEST_SENTENCE = """
# text = This is a test
# sent_id = 0
1 This this PRON DT Number=Sing|PronType=Dem 4 nsubj _ start_char=0|end_char=4
2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ start_char=5|end_char=7
3 a a DET DT Definite=Ind|PronType=Art 4 det _ start_char=8|end_char=9
4 test test NOUN NN Number=Sing 0 root _ SpaceAfter=No|start_char=10|end_char=14
""".lstrip()
def test_convert_dict():
    """
    convert_dict turns doc.to_dict() output into per-sentence lists of
    10-column string rows
    """
    doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE)
    expected = [[
        ['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'],
        ['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'],
        ['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'],
        ['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', '_', 'SpaceAfter=No|start_char=10|end_char=14'],
    ]]
    assert CoNLL.convert_dict(doc.to_dict()) == expected
def test_line_numbers():
    """
    keep_line_numbers=True records each word's line number without changing
    the CoNLL-U or dict output
    """
    doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE, keep_line_numbers=True)
    # currently the line numbers are not output in the conllu format
    assert "{:C}\n".format(doc) == ENGLISH_TEST_SENTENCE
    # currently the line numbers are not output in the dict format
    expected = [[
        ['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'],
        ['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'],
        ['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'],
        ['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', '_', 'SpaceAfter=No|start_char=10|end_char=14'],
    ]]
    assert CoNLL.convert_dict(doc.to_dict()) == expected
    # the test sentence has two comment lines before the first word
    for expected_line, word in enumerate(doc.sentences[0].words, start=2):
        assert word.line_number == expected_line
# A single sentence (sent_id from GUM) carrying "# speaker" and "# addressee"
# comments, used to test the sentence.speaker accessor
SPEAKER_EXAMPLE = """
# sent_id = GUM_fiction_pag-57
# speaker = Siri
# addressee = Pag
# text = "Sorry."
1 " " PUNCT `` _ 2 punct 2:punct Discourse=joint-sequence_m:130->128:1:_|SpaceAfter=No
2 Sorry sorry ADJ JJ Degree=Pos 0 root 0:root MSeg=Sorr-y|SpaceAfter=No
3 . . PUNCT . _ 2 punct 2:punct SpaceAfter=No
4 " " PUNCT '' _ 2 punct 2:punct _
""".lstrip()
def test_speaker():
    """
    sentence.speaker mirrors the "# speaker = ..." comment: it is read from
    the comment, and setting it (or clearing it with None) rewrites the comment
    """
    doc = CoNLL.conll2doc(input_str=SPEAKER_EXAMPLE)
    assert len(doc.sentences) == 1
    sentence = doc.sentences[0]

    def speaker_comments():
        return [c for c in sentence.comments if c.startswith("# speaker")]

    assert sentence.speaker == 'Siri'
    assert "# speaker = Siri" in sentence.comments

    sentence.speaker = "foo"
    assert sentence.speaker == 'foo'
    assert speaker_comments()
    assert "# speaker = foo" in sentence.comments

    sentence.speaker = None
    assert not speaker_comments()
    assert sentence.speaker is None

    sentence.speaker = "Siri"
    assert sentence.speaker == 'Siri'
    assert "# speaker = Siri" in sentence.comments
================================================
FILE: stanza/tests/common/test_data_objects.py
================================================
"""
Basic tests of the stanza data objects, especially the setter/getter routines
"""
import pytest
import stanza
from stanza.models.common.doc import Document, Sentence, Word
from stanza.tests import *
pytestmark = pytest.mark.pipeline

# data for testing
EN_DOC = "This is a test document. Pretty cool!"
# expected "UPOS_XPOS" string for every word of EN_DOC, one tuple per sentence
EN_DOC_UPOS_XPOS = (
    ('PRON_DT', 'AUX_VBZ', 'DET_DT', 'NOUN_NN', 'NOUN_NN', 'PUNCT_.'),
    ('ADV_RB', 'ADJ_JJ', 'PUNCT_.'),
)
EN_DOC2 = "Chris Manning wrote a sentence. Then another."
@pytest.fixture(scope="module")
def nlp_pipeline():
    """English pipeline built from the test models directory, shared by this module"""
    return stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en')
def test_readonly(nlp_pipeline):
    """
    A property added with add_property and no setter can be read, but
    assigning to it raises ValueError
    """
    Document.add_property('some_property', 123)
    processed = nlp_pipeline(EN_DOC)
    assert processed.some_property == 123
    with pytest.raises(ValueError):
        processed.some_property = 456
def test_getter(nlp_pipeline):
    """
    A getter-only Word property is computed from each word's upos and xpos
    """
    Word.add_property('upos_xpos', getter=lambda self: f"{self.upos}_{self.xpos}")
    processed = nlp_pipeline(EN_DOC)
    observed = tuple(
        tuple(word.upos_xpos for word in sentence.words)
        for sentence in processed.sentences
    )
    assert EN_DOC_UPOS_XPOS == observed
def test_setter_getter(nlp_pipeline):
    """
    A Sentence property with both a getter and a setter stores an int in
    _classname and exposes it as a string label
    """
    int2str = dict(enumerate(['ok', 'good', 'bad']))
    str2int = {label: idx for idx, label in int2str.items()}

    def setter(self, value):
        self._classname = str2int[value]

    def getter(self):
        if self._classname is None:
            return None
        return int2str[self._classname]

    Sentence.add_property('classname', getter=getter, setter=setter)
    sentence = nlp_pipeline(EN_DOC).sentences[0]
    sentence.classname = 'good'
    assert sentence._classname == 1
    # don't try this at home: bypass the setter and write the backing field
    sentence._classname = 2
    assert sentence.classname == 'bad'
def test_backpointer(nlp_pipeline):
    """
    Entities, words and tokens all point back to the sentence containing them
    """
    doc = nlp_pipeline(EN_DOC2)
    first_entity = doc.ents[0]
    first_sentence = doc.sentences[0]
    last_sentence = doc.sentences[-1]
    assert first_entity.sent is first_sentence
    assert list(doc.iter_words())[0].sent is first_sentence
    assert list(doc.iter_tokens())[-1].sent is last_sentence
================================================
FILE: stanza/tests/common/test_doc.py
================================================
import pytest
import stanza
from stanza.tests import *
from stanza.models.common.doc import Document, ID, TEXT, NER, CONSTITUENCY, SENTIMENT
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

@pytest.fixture
def sentences_dict():
    """Two short sentences of raw token dicts (ID and TEXT only) for building a Document"""
    first = [{ID: 1, TEXT: "unban"}, {ID: 2, TEXT: "mox"}, {ID: 3, TEXT: "opal"}]
    second = [{ID: 4, TEXT: "ban"}, {ID: 5, TEXT: "Lurrus"}]
    return [first, second]
@pytest.fixture
def doc(sentences_dict):
    """A Document built directly from the sentences_dict fixture"""
    return Document(sentences_dict)
def test_basic_values(doc, sentences_dict):
    """
    Constructing a doc from raw dicts sets up the sentences, their back
    pointer to the doc, and the token text
    """
    assert len(doc.sentences) == len(sentences_dict)
    for sentence, raw_sentence in zip(doc.sentences, sentences_dict):
        assert sentence.doc == doc
        assert len(sentence.tokens) == len(raw_sentence)
        observed_text = [token.text for token in sentence.tokens]
        assert observed_text == [raw_token[TEXT] for raw_token in raw_sentence]
def test_set_sentence(doc):
    """
    doc.set with to_sentence=True stores one value per sentence
    """
    doc.set(fields="sentiment", contents=["4", "0"], to_sentence=True)
    first, second = doc.sentences[0], doc.sentences[1]
    assert first.sentiment == "4"
    assert second.sentiment == "0"
def test_set_tokens(doc):
    """
    doc.set with to_token=True stores one value per token, and doc.get with
    from_token=True reads them back in order
    """
    tags = ["O", "ARTIFACT", "ARTIFACT", "O", "CAT"]
    doc.set(fields=NER, contents=tags, to_token=True)
    assert doc.get(NER, from_token=True) == tags
def test_constituency_comment(doc):
    """
    Setting the constituency tree on a doc adds exactly one constituency
    comment per sentence, and setting it again replaces that comment
    """
    def constituency_comments(sentence):
        return [c for c in sentence.comments if c.startswith("# constituency")]

    def set_and_check(trees):
        doc.set(fields=CONSTITUENCY, contents=trees, to_sentence=True)
        for sentence, expected in zip(doc.sentences, trees):
            found = constituency_comments(sentence)
            assert len(found) == 1
            assert found[0].endswith(expected)

    assert all(not constituency_comments(sentence) for sentence in doc.sentences)
    # currently nothing is checking that the items are actually trees
    set_and_check(["asdf", "zzzz"])
    # replacing the trees must replace the comment rather than add a second one
    set_and_check(["zzzz", "asdf"])
def test_sentiment_comment(doc):
    """
    Test that setting the sentiment on a doc sets the sentiment comment,
    and that setting it again replaces the comment
    """
    for sentence in doc.sentences:
        assert len([x for x in sentence.comments if x.startswith("# sentiment")]) == 0
    # currently nothing is checking that the items are valid sentiment values
    sentiments = ["1", "2"]
    doc.set(fields=SENTIMENT,
            contents=sentiments,
            to_sentence=True)
    for sentence, expected in zip(doc.sentences, sentiments):
        sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")]
        assert len(sentiment_comments) == 1
        assert sentiment_comments[0].endswith(expected)
    # Test that if we replace the sentiments with updated values, the comment
    # is replaced as well rather than a second sentiment comment being added
    sentiments = ["3", "4"]
    doc.set(fields=SENTIMENT,
            contents=sentiments,
            to_sentence=True)
    for sentence, expected in zip(doc.sentences, sentiments):
        sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")]
        assert len(sentiment_comments) == 1
        assert sentiment_comments[0].endswith(expected)
def test_sent_id_comment(doc):
    """
    Test that setting the sent_id on a sentence sets the sent_id comment,
    including when the sentences are reindexed
    """
    for sent_idx, sentence in enumerate(doc.sentences):
        assert len([x for x in sentence.comments if x.startswith("# sent_id")]) == 1
        assert sentence.sent_id == "%d" % sent_idx
    doc.sentences[0].sent_id = "foo"
    assert doc.sentences[0].sent_id == "foo"
    assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1
    assert "# sent_id = foo" in doc.sentences[0].comments
    doc.reindex_sentences(10)
    for sent_idx, sentence in enumerate(doc.sentences):
        assert sentence.sent_id == "%d" % (sent_idx + 10)
        # check the current sentence; checking doc.sentences[0] here would
        # only ever verify the first sentence
        assert len([x for x in sentence.comments if x.startswith("# sent_id")]) == 1
        assert "# sent_id = %d" % (sent_idx + 10) in sentence.comments
    # adding a sent_id comment by hand replaces the old one and updates the field
    doc.sentences[0].add_comment("# sent_id = bar")
    assert doc.sentences[0].sent_id == "bar"
    assert "# sent_id = bar" in doc.sentences[0].comments
    assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1
def test_doc_id_comment(doc):
    """
    Setting doc_id on a sentence adds a single doc_id comment, and adding a
    doc_id comment by hand updates the field
    """
    first = doc.sentences[0]

    def doc_id_comments():
        return [c for c in first.comments if c.startswith("# doc_id")]

    assert first.doc_id is None
    assert len(doc_id_comments()) == 0

    first.doc_id = "foo"
    assert len(doc_id_comments()) == 1
    assert "# doc_id = foo" in first.comments
    assert first.doc_id == "foo"

    first.add_comment("# doc_id = bar")
    assert len(doc_id_comments()) == 1
    assert first.doc_id == "bar"
@pytest.fixture(scope="module")
def pipeline():
    """Default pipeline built from the test models directory, shared by this module"""
    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR)
    return nlp
def test_serialized(pipeline):
    """
    Brief test of the serialized format

    A doc that goes through to_serialized / from_serialized keeps its NER
    entities and its per-sentence constituency and sentiment
    """
    original = pipeline("John Bauer works at Stanford")
    assert len(original.ents) == 2
    restored = Document.from_serialized(original.to_serialized())
    assert len(restored.sentences) == 1
    assert len(restored.ents) == 2
    for field in ("constituency", "sentiment"):
        assert getattr(original.sentences[0], field) == getattr(restored.sentences[0], field)
================================================
FILE: stanza/tests/common/test_dropout.py
================================================
import pytest
import torch
import stanza
from stanza.models.common.dropout import WordDropout
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def test_word_dropout():
    """
    Test that word_dropout randomly drops out entire rows of a tensor (the
    whole final dimension at once) and leaves the remaining rows untouched

    Doing 600 small rows should be super fast, but it leaves us with
    something like a 1 in 10^180 chance of the test failing.  Not very
    common, in other words
    """
    wd = WordDropout(0.5)
    batch = torch.randn(600, 4)
    dropped = wd(batch)
    # the one time any of this happens, it's going to be really confusing
    assert not torch.allclose(batch, dropped)

    # a dropped row must be exactly zero everywhere; only checking that the
    # row sums to zero would also accept a row with cancelling nonzero values
    zeroed = (dropped == 0).all(dim=-1)
    # same default tolerances as torch.allclose, applied row by row
    kept = torch.isclose(dropped, batch).all(dim=-1)
    assert bool((zeroed | kept).all())
    num_zeros = int(zeroed.sum())
    assert 0 < num_zeros < batch.shape[0]
================================================
FILE: stanza/tests/common/test_foundation_cache.py
================================================
import glob
import os
import shutil
import tempfile
import pytest
import stanza
from stanza.models.common.foundation_cache import FoundationCache, load_charlm
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def test_charlm_cache():
    """
    FoundationCache keeps a loaded charlm, so it can still return the model
    after the file it was loaded from has been deleted
    """
    pattern = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*")
    candidates = glob.glob(pattern)
    # we expect at least one English model downloaded for the tests
    assert len(candidates) >= 1
    source_model = candidates[0]
    cache = FoundationCache()
    with tempfile.TemporaryDirectory(dir=".") as scratch_dir:
        copied_model = os.path.join(scratch_dir, "charlm.pt")
        shutil.copy2(source_model, copied_model)
        # a direct load works while the copy exists
        load_charlm(copied_model)
        # loading through the cache saves the model in the cache
        cache.load_charlm(copied_model)
    # the temporary directory is gone now, so a direct load must fail...
    with pytest.raises(FileNotFoundError):
        load_charlm(copied_model)
    # ...but the cache remembers the model
    cache.load_charlm(copied_model)
================================================
FILE: stanza/tests/common/test_pretrain.py
================================================
import os
import tempfile
import pytest
import numpy as np
import torch
from stanza.models.common import pretrain
from stanza.models.common.vocab import UNK_ID
from stanza.tests import *
# tag every test in this module with the travis and pipeline markers
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def check_vocab(vocab):
    """
    Check the vocab built from the tiny test embedding: 4 base entries plus
    the 3 words actually present in the file
    """
    assert len(vocab) == 7
    for expected_word in ('unban', 'mox', 'opal'):
        assert expected_word in vocab
def check_embedding(emb, unk=False):
    """
    Compare emb against the expected tiny embedding matrix

    The first four rows (the base vocab entries) are zero and the next three
    hold 1..12 in row-major order.  With unk=True, the UNK_ID row is
    expected to be all -1 instead.
    """
    expected = np.concatenate([np.zeros((4, 4)),
                               np.arange(1.0, 13.0).reshape(3, 4)])
    if unk:
        expected[UNK_ID] = -1
    np.testing.assert_allclose(emb, expected)
def check_pretrain(pt):
    """
    Verify both the vocab and the embedding matrix of a loaded Pretrain
    against the tiny test embedding
    """
    vocab = pt.vocab
    check_vocab(vocab)
    embedding = pt.emb
    check_embedding(embedding)
def _load_and_check(**source):
    """Build a Pretrain from the given source file without saving it, then check it"""
    pt = pretrain.Pretrain(save_to_file=False, **source)
    check_pretrain(pt)

def test_text_pretrain():
    _load_and_check(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.txt')

def test_xz_pretrain():
    _load_and_check(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz')

def test_gz_pretrain():
    _load_and_check(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.gz')

def test_zip_pretrain():
    _load_and_check(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.zip')

def test_csv_pretrain():
    _load_and_check(csv_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.csv')
def test_resave_pretrain():
    """
    Test saving a pretrain and then loading from the existing file

    The second Pretrain is given a vec_filename that presumably does not
    exist, so it has to rely on the .pt file written by the first one.
    """
    saved = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".pt", delete=False)
    try:
        saved.close()
        # the .pt file starts out empty, which exercises the fallback from a
        # broken existing pretrain to the vector file, plus saving the result
        first = pretrain.Pretrain(filename=saved.name,
                                  vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz')
        check_pretrain(first)
        reloaded = pretrain.Pretrain(filename=saved.name,
                                     vec_filename='unban_mox_opal')
        check_pretrain(reloaded)
        raw = torch.load(saved.name, weights_only=True)
        check_embedding(raw['emb'])
    finally:
        os.unlink(saved.name)
# Text embedding with a "rows cols" header (3 words, 4 dims) whose first
# vocab entry contains an ascii space ("unban mox")
SPACE_PRETRAIN="""
3 4
unban mox 1 2 3 4
opal 5 6 7 8
foo 9 10 11 12
""".strip()
def test_whitespace():
    """
    Test reading a pretrain with an ascii space in one of its words

    The word with a space in it should still have the correct number of
    dimensions read, with the space converted to nbsp
    """
    handle = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".txt", delete=False)
    try:
        handle.write(SPACE_PRETRAIN.encode())
        handle.close()
        loaded = pretrain.Pretrain(vec_filename=handle.name, save_to_file=False)
        check_embedding(loaded.emb)
        assert "unban\xa0mox" in loaded.vocab
        # the plain-space spelling also works because of the normalize_unit in vocab.py
        assert "unban mox" in loaded.vocab
    finally:
        os.unlink(handle.name)
# The same three 4-dimensional vectors check_embedding expects, without the
# leading "rows cols" header line
NO_HEADER_PRETRAIN="""
unban 1 2 3 4
mox 5 6 7 8
opal 9 10 11 12
""".strip()
def test_no_header():
    """
    Check loading a pretrain with no rows,cols header
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as scratch_dir:
        vec_path = os.path.join(scratch_dir, "tiny.txt")
        with open(vec_path, "w", encoding="utf-8") as handle:
            handle.write(NO_HEADER_PRETRAIN)
        loaded = pretrain.Pretrain(vec_filename=vec_path, save_to_file=False)
        check_embedding(loaded.emb)
# Headerless embedding whose last row is an explicit <unk> vector of all -1
UNK_PRETRAIN="""
unban 1 2 3 4
mox 5 6 7 8
opal 9 10 11 12
<unk> -1 -1 -1 -1
""".strip()

def test_unk_at_end():
    """
    Check loading a pretrain with <unk> at the end, like GloVe does

    Named differently from test_no_header: reusing that name would shadow
    the no-header test so that pytest only ever collected one of the two.
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir:
        filename = os.path.join(tmpdir, "tiny.txt")
        with open(filename, "w", encoding="utf-8") as fout:
            fout.write(UNK_PRETRAIN)
        pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False)
        # the <unk> row should fill the UNK_ID slot rather than add a new word
        check_embedding(pt.emb, unk=True)
================================================
FILE: stanza/tests/common/test_relative_attn.py
================================================
import pytest
import torch
from stanza.models.common.relative_attn import RelativeAttention
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def _assert_matches_value(result, value, position):
    """
    Assert that the attention output equals the value projection at one
    sequence position; on failure the difference is shown in the message
    """
    actual = result[:, position, :]
    expected = value[:, position, :]
    assert torch.allclose(actual, expected, atol=1e-06), actual - expected

def test_attn():
    foo = RelativeAttention(d_model=100, num_heads=2, window=8, dropout=0.0)
    bar = torch.randn(10, 13, 100)
    result = foo(bar)
    assert result.shape == bar.shape
    value = foo.value(bar)
    _assert_matches_value(result, value, -1)
    assert not torch.allclose(result[:, 0, :], value[:, 0, :])

def test_shorter_sequence():
    # originally this was failing because the batch was smaller than the window
    foo = RelativeAttention(d_model=20, num_heads=2, window=5, dropout=0.0)
    bar = torch.randn(10, 3, 20)
    result = foo(bar)
    assert result.shape == bar.shape
    value = foo.value(bar)
    _assert_matches_value(result, value, -1)
    assert not torch.allclose(result[:, 0, :], value[:, 0, :])

def test_reverse():
    # with reverse=True the roles of the first and last positions swap
    foo = RelativeAttention(d_model=100, num_heads=2, window=8, reverse=True, dropout=0.0)
    bar = torch.randn(10, 13, 100)
    result = foo(bar)
    assert result.shape == bar.shape
    value = foo.value(bar)
    _assert_matches_value(result, value, 0)
    assert not torch.allclose(result[:, -1, :], value[:, -1, :])
================================================
FILE: stanza/tests/common/test_short_name_to_treebank.py
================================================
import pytest
import stanza
from stanza.models.common import short_name_to_treebank
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def test_short_name():
    """The short name en_ewt maps to the UD_English-EWT treebank name"""
    expected = "UD_English-EWT"
    assert short_name_to_treebank.short_name_to_treebank("en_ewt") == expected
def test_canonical_name():
    """
    canonical_treebank_name fixes capitalization, expands short names, and
    returns unrecognized names unchanged
    """
    cases = [
        ("UD_URDU-UDTB", "UD_Urdu-UDTB"),
        ("ur_udtb", "UD_Urdu-UDTB"),
        ("Unban_Mox_Opal", "Unban_Mox_Opal"),
    ]
    for raw_name, canonical in cases:
        assert short_name_to_treebank.canonical_treebank_name(raw_name) == canonical
================================================
FILE: stanza/tests/common/test_utils.py
================================================
import lzma
import os
import tempfile
import pytest
import stanza
import stanza.models.common.utils as utils
from stanza.tests import *
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def test_wordvec_not_found():
"""
get_wordvec_file should fail if neither word2vec nor fasttext exists
"""
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
with pytest.raises(FileNotFoundError):
utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
def test_word2vec_xz():
"""
Test searching for word2vec and xz files
"""
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
# make a fake directory for English word vectors
word2vec_dir = os.path.join(temp_dir, 'word2vec', 'English')
os.makedirs(word2vec_dir)
# make a fake English word vector file
fake_file = os.path.join(word2vec_dir, 'en.vectors.xz')
fout = open(fake_file, 'w')
fout.close()
# get_wordvec_file should now find this fake file
filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
assert filename == fake_file
def test_fasttext_txt():
"""
Test searching for fasttext and txt files
"""
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
# make a fake directory for English word vectors
fasttext_dir = os.path.join(temp_dir, 'fasttext', 'English')
os.makedirs(fasttext_dir)
# make a fake English word vector file
fake_file = os.path.join(fasttext_dir, 'en.vectors.txt')
fout = open(fake_file, 'w')
fout.close()
# get_wordvec_file should now find this fake file
filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
assert filename == fake_file
def test_wordvec_type():
"""
If we supply our own wordvec type, get_wordvec_file should find that
"""
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
# make a fake directory for English word vectors
google_dir = os.path.join(temp_dir, 'google', 'English')
os.makedirs(google_dir)
# make a fake English word vector file
fake_file = os.path.join(google_dir, 'en.vectors.txt')
fout = open(fake_file, 'w')
fout.close()
# get_wordvec_file should now find this fake file
filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo', wordvec_type='google')
assert filename == fake_file
# this file won't be found using the normal defaults
with pytest.raises(FileNotFoundError):
utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
def test_sort_with_indices():
data = [[1, 2, 3], [4, 5], [6]]
ordered, orig_idx = utils.sort_with_indices(data, key=len)
assert ordered == ([6], [4, 5], [1, 2, 3])
assert orig_idx == (2, 1, 0)
unsorted = utils.unsort(ordered, orig_idx)
assert data == unsorted
def test_empty_sort_with_indices():
ordered, orig_idx = utils.sort_with_indices([])
assert len(ordered) == 0
assert len(orig_idx) == 0
unsorted = utils.unsort(ordered, orig_idx)
assert [] == unsorted
def test_split_into_batches():
data = []
for i in range(5):
data.append(["Unban", "mox", "opal", str(i)])
data.append(["Do", "n't", "ban", "Urza", "'s", "Saga", "that", "card", "is", "great"])
data.append(["Ban", "Ragavan"])
# small batches will put one element in each interval
batches = utils.split_into_batches(data, 5)
assert batches == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
# this one has a batch interrupted in the middle by a large element
batches = utils.split_into_batches(data, 8)
assert batches == [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7)]
# this one has the large element at the start of its own batch
batches = utils.split_into_batches(data[1:], 8)
assert batches == [(0, 2), (2, 4), (4, 5), (5, 6)]
# overloading the test! assert that the key & reverse is working
ordered, orig_idx = utils.sort_with_indices(data, key=len, reverse=True)
assert [len(x) for x in ordered] == [10, 4, 4, 4, 4, 4, 2]
# this has the large element at the start
batches = utils.split_into_batches(ordered, 8)
assert batches == [(0, 1), (1, 3), (3, 5), (5, 7)]
# double check that unsort is working as expected
assert data == utils.unsort(ordered, orig_idx)
def test_find_missing_tags():
assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC"]) == []
assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"]) == ['ORG']
assert utils.find_missing_tags([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]]) == ['ORG']
def test_open_read_text():
"""
test that we can read either .xz or regular txt
"""
TEXT = "this is a test"
with tempfile.TemporaryDirectory() as tempdir:
# test text file
filename = os.path.join(tempdir, "foo.txt")
with open(filename, "w") as fout:
fout.write(TEXT)
with utils.open_read_text(filename) as fin:
in_text = fin.read()
assert TEXT == in_text
assert fin.closed
# the context should close the file when we throw an exception!
try:
with utils.open_read_text(filename) as finex:
assert not finex.closed
raise ValueError("unban mox opal!")
except ValueError:
pass
assert finex.closed
# test xz file
filename = os.path.join(tempdir, "foo.txt.xz")
with lzma.open(filename, "wt") as fout:
fout.write(TEXT)
with utils.open_read_text(filename) as finxz:
in_text = finxz.read()
assert TEXT == in_text
assert finxz.closed
# the context should close the file when we throw an exception!
try:
with utils.open_read_text(filename) as finexxz:
assert not finexxz.closed
raise ValueError("unban mox opal!")
except ValueError:
pass
assert finexxz.closed
def test_checkpoint_name():
"""
Test some expected results for the checkpoint names
"""
# use os.path.split so that the test is agnostic of file separator on Linux or Windows
checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm.pt", None)
assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint.pt")
checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", None)
assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint")
checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", "othername.pt")
assert os.path.split(checkpoint) == ("saved_models", "othername.pt")
def test_punct_simplification():
"""
Test a punctuation simplification that should make it so unexpected
question/exclamation marks types are processed into ? and !
"""
test = [[["!!!!"],
["‼‼‼‼"],
["????"],
["?!?!"],
["??︖"],
["?foo"],
["bar!"]]]
test = utils.simplify_punct(test)
expected = [[['!'], ['!'], ['?'], ['?'], ['?'], ['?foo'], ['bar!']]]
assert test == expected
================================================
FILE: stanza/tests/constituency/__init__.py
================================================
================================================
FILE: stanza/tests/constituency/test_convert_arboretum.py
================================================
"""
Test a couple different classes of trees to check the output of the Arboretum conversion
Note that the text has been removed
"""
import os
import tempfile
import pytest
from stanza.server import tsurgeon
from stanza.tests import TEST_WORKING_DIR
from stanza.utils.datasets.constituency import convert_arboretum
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
PROJ_EXAMPLE="""
"""
NOT_FIX_NONPROJ_EXAMPLE="""
"""
NONPROJ_EXAMPLE="""
"""
def test_projective_example():
"""
Test reading a basic tree, along with some further manipulations from the conversion program
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
test_name = os.path.join(tempdir, "proj.xml")
with open(test_name, "w", encoding="utf-8") as fout:
fout.write(PROJ_EXAMPLE)
sentences = convert_arboretum.read_xml_file(test_name)
assert len(sentences) == 1
tree, words = convert_arboretum.process_tree(sentences[0])
expected_tree = "(s (fcl (prop s2_1) (v-fin s2_2) (pron-pers s2_3) (adjp (adj s2_4) (pp (prp s2_5) (np (art s2_6) (adj s2_7) (n s2_8)))) (pu s2_9)))"
assert str(tree) == expected_tree
assert [w.word for w in words.values()] == ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', '.']
assert not convert_arboretum.word_sequence_missing_words(tree)
with tsurgeon.Tsurgeon() as tsurgeon_processor:
assert tree == convert_arboretum.check_words(tree, tsurgeon_processor)
# check that the words can be replaced as expected
replaced_tree = convert_arboretum.replace_words(tree, words)
expected_tree = "(s (fcl (prop A) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
assert str(replaced_tree) == expected_tree
assert convert_arboretum.split_underscores(replaced_tree) == replaced_tree
# fake a word which should be split
words['s2_1'] = words['s2_1']._replace(word='foo_bar')
replaced_tree = convert_arboretum.replace_words(tree, words)
expected_tree = "(s (fcl (prop foo_bar) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
assert str(replaced_tree) == expected_tree
expected_tree = "(s (fcl (np (prop foo) (prop bar)) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
assert str(convert_arboretum.split_underscores(replaced_tree)) == expected_tree
def test_not_fix_example():
"""
Test that a non-projective tree which we don't have a heuristic for quietly fails
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
test_name = os.path.join(tempdir, "nofix.xml")
with open(test_name, "w", encoding="utf-8") as fout:
fout.write(NOT_FIX_NONPROJ_EXAMPLE)
sentences = convert_arboretum.read_xml_file(test_name)
assert len(sentences) == 1
tree, words = convert_arboretum.process_tree(sentences[0])
assert not convert_arboretum.word_sequence_missing_words(tree)
with tsurgeon.Tsurgeon() as tsurgeon_processor:
assert convert_arboretum.check_words(tree, tsurgeon_processor) is None
def test_fix_proj_example():
"""
Test that a non-projective tree can be rearranged as expected
Note that there are several other classes of non-proj tree we could test as well...
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
test_name = os.path.join(tempdir, "fix.xml")
with open(test_name, "w", encoding="utf-8") as fout:
fout.write(NONPROJ_EXAMPLE)
sentences = convert_arboretum.read_xml_file(test_name)
assert len(sentences) == 1
tree, words = convert_arboretum.process_tree(sentences[0])
assert not convert_arboretum.word_sequence_missing_words(tree)
# the 4 and 5 are moved inside the 3-6 node
expected_orig = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (v-pcp2 s9_6)) (prop s9_4) (adv s9_5) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))"
expected_proj = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (prop s9_4) (adv s9_5) (v-pcp2 s9_6)) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))"
assert str(tree) == expected_orig
with tsurgeon.Tsurgeon() as tsurgeon_processor:
assert str(convert_arboretum.check_words(tree, tsurgeon_processor)) == expected_proj
================================================
FILE: stanza/tests/constituency/test_convert_it_vit.py
================================================
"""
Test a couple different classes of trees to check the output of the VIT conversion
A couple representative trees are included, but hopefully not enough
to be a problem in terms of our license.
One of the tests is currently disabled as it relies on tregex & tsurgeon features
not yet released
"""
import io
import os
import tempfile
import pytest
from stanza.server import tsurgeon
from stanza.utils.conll import CoNLL
from stanza.utils.datasets.constituency import convert_it_vit
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# just a sample! don't sue us please
CON_SAMPLE = """
#ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]]
#ID=sent_00318 dirsp-[fc-[congf-tuttavia, f-[sn-[sq-[ind-qualche], n-problema], ir_infl-[vsupir-potrebbe, vcl-esserci], compc-[clit-ci, sp-[p-per, sn-[art-la, n-commissione, sa-[ag-esteri], f2-[sp-[part-alla, relob-cui, sn-[n-presidenza]], f-[ibar-[vc-è], compc-[sn-[n-candidato], sn-[art-l, n-esponente, spd-[pd-di, sn-[mw-Alleanza, npro-Nazionale]], sn-[mw-Mirko, nh-Tremaglia]]]]]]]]]], dirs-':', f3-[sn-[art-una, n-candidatura, sc-[q-più, sa-[ppas-subìta], sc-[ccong-che, sa-[ppas-gradita]], compt-[spda-[partda-dalla, sn-[mw-Lega, npro-Nord, punt-',', f2-[rel-che, fc-[congf-tuttavia, f-[ir_infl-[vsupir-dovrebbe, vit-rispettare], compt-[sn-[art-gli, n-accordi]]]]]]]]]], punto-.]]
#ID=sent_00589 f-[sn-[art-l, n-ottimismo, spd-[pd-di, sn-[nh-Kantor]]], ir_infl-[vsupir-potrebbe, congf-però, vcl-rivelarsi], compc-[sn-[in-ancora, art-una, nt-volta], sa-[ag-prematuro]], punto-.]
"""
UD_SAMPLE = """
# sent_id = VIT-2
# text = Negli ultimi anni la dinamica dei polo di attrazione è stata sempre più caratterizzata dall'emergere di una crescente concorrenza che si è progressivamente spostata dalle singole imprese ai sistemi economici e territoriali, determinando l'esigenza di una riconsiderazione dei rapporti esistenti tra soggetti produttivi e ambiente in cui questi operano.
1-2 Negli _ _ _ _ _ _ _ _
1 In in ADP E _ 4 case _ _
2 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 4 det _ _
3 ultimi ultimo ADJ A Gender=Masc|Number=Plur 4 amod _ _
4 anni anno NOUN S Gender=Masc|Number=Plur 16 obl _ _
5 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 6 det _ _
6 dinamica dinamica NOUN S Gender=Fem|Number=Sing 16 nsubj:pass _ _
7-8 dei _ _ _ _ _ _ _ _
7 di di ADP E _ 9 case _ _
8 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 9 det _ _
9 polo polo NOUN S Gender=Masc|Number=Sing 6 nmod _ _
10 di di ADP E _ 11 case _ _
11 attrazione attrazione NOUN S Gender=Fem|Number=Sing 9 nmod _ _
12 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux _ _
13 stata essere AUX VA Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 16 aux:pass _ _
14 sempre sempre ADV B _ 15 advmod _ _
15 più più ADV B _ 16 advmod _ _
16 caratterizzata caratterizzare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 0 root _ _
17-18 dall' _ _ _ _ _ _ _ SpaceAfter=No
17 da da ADP E _ 19 case _ _
18 l' il DET RD Definite=Def|Number=Sing|PronType=Art 19 det _ _
19 emergere emergere NOUN S Gender=Masc|Number=Sing 16 obl _ _
20 di di ADP E _ 23 case _ _
21 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 23 det _ _
22 crescente crescente ADJ A Number=Sing 23 amod _ _
23 concorrenza concorrenza NOUN S Gender=Fem|Number=Sing 19 nmod _ _
24 che che PRON PR PronType=Rel 28 nsubj _ _
25 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 28 expl _ _
26 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 28 aux _ _
27 progressivamente progressivamente ADV B _ 28 advmod _ _
28 spostata spostare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 23 acl:relcl _ _
29-30 dalle _ _ _ _ _ _ _ _
29 da da ADP E _ 32 case _ _
30 le il DET RD Definite=Def|Gender=Fem|Number=Plur|PronType=Art 32 det _ _
31 singole singolo ADJ A Gender=Fem|Number=Plur 32 amod _ _
32 imprese impresa NOUN S Gender=Fem|Number=Plur 28 obl _ _
33-34 ai _ _ _ _ _ _ _ _
33 a a ADP E _ 35 case _ _
34 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 35 det _ _
35 sistemi sistema NOUN S Gender=Masc|Number=Plur 28 obl _ _
36 economici economico ADJ A Gender=Masc|Number=Plur 35 amod _ _
37 e e CCONJ CC _ 38 cc _ _
38 territoriali territoriale ADJ A Number=Plur 36 conj _ SpaceAfter=No
39 , , PUNCT FF _ 28 punct _ _
40 determinando determinare VERB V VerbForm=Ger 28 advcl _ _
41 l' il DET RD Definite=Def|Number=Sing|PronType=Art 42 det _ SpaceAfter=No
42 esigenza esigenza NOUN S Gender=Fem|Number=Sing 40 obj _ _
43 di di ADP E _ 45 case _ _
44 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 45 det _ _
45 riconsiderazione riconsiderazione NOUN S Gender=Fem|Number=Sing 42 nmod _ _
46-47 dei _ _ _ _ _ _ _ _
46 di di ADP E _ 48 case _ _
47 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 48 det _ _
48 rapporti rapporto NOUN S Gender=Masc|Number=Plur 45 nmod _ _
49 esistenti esistente VERB V Number=Plur 48 acl _ _
50 tra tra ADP E _ 51 case _ _
51 soggetti soggetto NOUN S Gender=Masc|Number=Plur 49 obl _ _
52 produttivi produttivo ADJ A Gender=Masc|Number=Plur 51 amod _ _
53 e e CCONJ CC _ 54 cc _ _
54 ambiente ambiente NOUN S Gender=Masc|Number=Sing 51 conj _ _
55 in in ADP E _ 56 case _ _
56 cui cui PRON PR PronType=Rel 58 obl _ _
57 questi questo PRON PD Gender=Masc|Number=Plur|PronType=Dem 58 nsubj _ _
58 operano operare VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 54 acl:relcl _ SpaceAfter=No
59 . . PUNCT FS _ 16 punct _ _
# sent_id = VIT-318
# text = Tuttavia qualche problema potrebbe esserci per la commissione esteri alla cui presidenza è candidato l'esponente di Alleanza Nazionale Mirko Tremaglia: una candidatura più subìta che gradita dalla Lega Nord, che tuttavia dovrebbe rispettare gli accordi.
1 Tuttavia tuttavia CCONJ CC _ 5 cc _ _
2 qualche qualche DET DI Number=Sing|PronType=Ind 3 det _ _
3 problema problema NOUN S Gender=Masc|Number=Sing 5 nsubj _ _
4 potrebbe potere AUX VA Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux _ _
5-6 esserci _ _ _ _ _ _ _ _
5 esser essere VERB V VerbForm=Inf 0 root _ _
6 ci ci PRON PC Clitic=Yes|Number=Plur|Person=1|PronType=Prs 5 expl _ _
7 per per ADP E _ 9 case _ _
8 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 9 det _ _
9 commissione commissione NOUN S Gender=Fem|Number=Sing 5 obl _ _
10 esteri estero ADJ A Gender=Masc|Number=Plur 9 amod _ _
11-12 alla _ _ _ _ _ _ _ _
11 a a ADP E _ 14 case _ _
12 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 14 det _ _
13 cui cui DET DR PronType=Rel 14 det:poss _ _
14 presidenza presidenza NOUN S Gender=Fem|Number=Sing 16 obl _ _
15 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux:pass _ _
16 candidato candidare VERB V Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 9 acl:relcl _ _
17 l' il DET RD Definite=Def|Number=Sing|PronType=Art 18 det _ SpaceAfter=No
18 esponente esponente NOUN S Number=Sing 16 nsubj:pass _ _
19 di di ADP E _ 20 case _ _
20 Alleanza Alleanza PROPN SP _ 18 nmod _ _
21 Nazionale Nazionale PROPN SP _ 20 flat:name _ _
22 Mirko Mirko PROPN SP _ 18 nmod _ _
23 Tremaglia Tremaglia PROPN SP _ 22 flat:name _ SpaceAfter=No
24 : : PUNCT FC _ 22 punct _ _
25 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 26 det _ _
26 candidatura candidatura NOUN S Gender=Fem|Number=Sing 22 appos _ _
27 più più ADV B _ 28 advmod _ _
28 subìta subire VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 26 advcl _ _
29 che che CCONJ CC _ 30 cc _ _
30 gradita gradito ADJ A Gender=Fem|Number=Sing 28 amod _ _
31-32 dalla _ _ _ _ _ _ _ _
31 da da ADP E _ 33 case _ _
32 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 33 det _ _
33 Lega Lega PROPN SP _ 28 obl:agent _ _
34 Nord Nord PROPN SP _ 33 flat:name _ SpaceAfter=No
35 , , PUNCT FC _ 33 punct _ _
36 che che PRON PR PronType=Rel 39 nsubj _ _
37 tuttavia tuttavia CCONJ CC _ 39 cc _ _
38 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 39 aux _ _
39 rispettare rispettare VERB V VerbForm=Inf 33 acl:relcl _ _
40 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 41 det _ _
41 accordi accordio NOUN S Gender=Masc|Number=Plur 39 obj _ SpaceAfter=No
42 . . PUNCT FS _ 5 punct _ _
# sent_id = VIT-591
# text = L'ottimismo di Kantor potrebbe però rivelarsi ancora una volta prematuro.
1 L' il DET RD Definite=Def|Number=Sing|PronType=Art 2 det _ SpaceAfter=No
2 ottimismo ottimismo NOUN S Gender=Masc|Number=Sing 7 nsubj _ _
3 di di ADP E _ 4 case _ _
4 Kantor Kantor PROPN SP _ 2 nmod _ _
5 potrebbe potere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux _ _
6 però però ADV B _ 7 advmod _ _
7-8 rivelarsi _ _ _ _ _ _ _ _
7 rivelar rivelare VERB V VerbForm=Inf 0 root _ _
8 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 7 expl _ _
9 ancora ancora ADV B _ 7 advmod _ _
10 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 11 det _ _
11 volta volta NOUN S Gender=Fem|Number=Sing 7 obl _ _
12 prematuro prematuro ADJ A Gender=Masc|Number=Sing 7 xcomp _ SpaceAfter=No
13 . . PUNCT FS _ 7 punct _ _
"""
def test_process_mwts():
# dei appears multiple times
# the verb/pron esserci will be ignored
expected_mwts = {'Negli': ('In', 'gli'), 'dei': ('di', 'i'), "dall'": ('da', "l'"), 'dalle': ('da', 'le'), 'ai': ('a', 'i'), 'alla': ('a', 'la'), 'dalla': ('da', 'la')}
ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)
mwts = convert_it_vit.get_mwt(ud_train_data)
assert expected_mwts == mwts
def test_raw_tree():
con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))
expected_ids = ["#ID=sent_00002", "#ID=sent_00318", "#ID=sent_00589"]
expected_trees = ["(ROOT (cp (sp (part negli) (sn (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd dei) (sn (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda dall) (sn (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda dalle) (sn (sa (ag singole)) (n imprese))) (sp (part ai) (sn (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd dei) (sn (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))",
"(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part alla) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda dalla) (sn (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))",
"(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"]
assert len(con_sentences) == 3
for sentence, expected_id, expected_tree in zip(con_sentences, expected_ids, expected_trees):
assert sentence[0] == expected_id
tree = convert_it_vit.raw_tree(sentence[1])
assert str(tree) == expected_tree
def test_update_mwts():
con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))
ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)
mwt_map = convert_it_vit.get_mwt(ud_train_data)
expected_trees=["(ROOT (cp (sp (part In) (sn (art gli) (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd di) (sn (art i) (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda da) (sn (art l') (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda da) (sn (art le) (sa (ag singole)) (n imprese))) (sp (part a) (sn (art i) (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd di) (sn (art i) (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))",
"(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part a) (art la) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda da) (sn (art la) (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))",
"(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (clit si) (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"]
with tsurgeon.Tsurgeon() as tsurgeon_processor:
for con_sentence, ud_sentence, expected_tree in zip(con_sentences, ud_train_data.sentences, expected_trees):
con_tree = convert_it_vit.raw_tree(con_sentence[1])
updated_tree, _ = convert_it_vit.update_mwts_and_special_cases(con_tree, ud_sentence, mwt_map, tsurgeon_processor)
assert str(updated_tree) == expected_tree
CON_PERCENT_SAMPLE = """
ID#sent_00020 f-[sn-[art-il, n-tesoro], ibar-[vt-mette], compt-[sp-[part-sul, sn-[n-mercato]], sn-[art-il, num-51%, sp-[p-a, sn-[num-2, n-lire]], sp-[p-per, sn-[n-azione]]]], punto-.]
ID#sent_00022 dirsp-[f3-[sn-[art-le, n-novità]], dirs-':', f3-[coord-[sn-[n-voto, spd-[pd-di, sn-[n-lista]]], cong-e, sn-[n-tetto, sp-[part-agli, sn-[n-acquisti]], sv3-[vppt-limitato, comppas-[sp-[part-allo, sn-[num-0/5%]]]]]], punto-.]]
ID#sent_00517 dirsp-[fc-[f-[sn-[art-l, n-aumento, sa-[ag-mensile], spd-[pd-di, sn-[nt-aprile]]], ibar-[ause-è, vppc-stato], compc-[sq-[q-dell_, sn-[num-1/3%]], sp-[p-contro, sn-[art-lo, num-0/7/0/8%, spd-[partd-degli, sn-[sa-[ag-ultimi], num-due, sn-[nt-mesi]]]]]]]]]
ID#sent_01117 fc-[f-[sn-[art-La, sa-[ag-crescente], n-ripresa, spd-[partd-dei, sn-[n-beni, spd-[pd-di, sn-[n-consumo]]]]], ibar-[vin-deriva], savv-[avv-esclusivamente], compin-[spda-[partda-dal, sn-[n-miglioramento, f2-[spd-[pd-di, sn-[relob-cui]], f-[ibar-[ausa-hanno, vppin-beneficiato], compin-[sn-[n-beni, coord-[sa-[ag-durevoli, fp-[par-'(', sn-[num-plus4/5%], par-')']], cong-e, sa-[ag-semidurevoli, fp-[par-'(', sn-[num-plus1/5%], par-')']]]]]]]]]]], punt-',', fs-[cosu-mentre, f-[sn-[art-i, n-beni, sa-[neg-non, ag-durevoli], fp-[par-'(', sn-[num-min1%], par-')']], ibar-[vt-accusano], cong-ancora, compt-[sn-[art-un, sa-[ag-evidente], n-ritardo]]]], punto-.]
"""
CON_PERCENT_LEAVES = [
['il', 'tesoro', 'mette', 'sul', 'mercato', 'il', '51', '%%', 'a', '2', 'lire', 'per', 'azione', '.'],
['le', 'novità', ':', 'voto', 'di', 'lista', 'e', 'tetto', 'agli', 'acquisti', 'limitato', 'allo', '0,5', '%%', '.'],
['l', 'aumento', 'mensile', 'di', 'aprile', 'è', 'stato', "dell'", '1,3', '%%', 'contro', 'lo', '0/7,0/8', '%%', 'degli', 'ultimi', 'due', 'mesi'],
# the plus and min look bad, but they get cleaned up when merging with the UD version of the dataset
['La', 'crescente', 'ripresa', 'dei', 'beni', 'di', 'consumo', 'deriva', 'esclusivamente', 'dal', 'miglioramento', 'di', 'cui', 'hanno', 'beneficiato', 'beni', 'durevoli', '(', 'plus4,5', '%%', ')', 'e', 'semidurevoli', '(', 'plus1,5', '%%', ')', ',', 'mentre', 'i', 'beni', 'non', 'durevoli', '(', 'min1', '%%', ')', 'accusano', 'ancora', 'un', 'evidente', 'ritardo', '.']
]
def test_read_percent():
con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_PERCENT_SAMPLE))
assert len(con_sentences) == len(CON_PERCENT_LEAVES)
for (_, raw_tree), expected_leaves in zip(con_sentences, CON_PERCENT_LEAVES):
tree = convert_it_vit.raw_tree(raw_tree)
words = tree.leaf_labels()
if expected_leaves is None:
print(words)
else:
assert words == expected_leaves
================================================
FILE: stanza/tests/constituency/test_convert_starlang.py
================================================
"""
Test a couple different classes of trees to check the output of the Starlang conversion
"""
import os
import tempfile
import pytest
from stanza.utils.datasets.constituency import convert_starlang
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )"
def test_read_tree():
"""
Test a basic tree read
"""
tree = convert_starlang.read_tree(TREE)
assert "(ROOT (S (NP (NP Bayan) (NP Haag)) (VP (NP Elianti) (VP çalar)) (. .)))" == str(tree)
def test_missing_word():
"""
Test that an error is thrown if the word is missing
"""
tree_text = TREE.replace("turkish=", "foo=")
with pytest.raises(ValueError):
tree = convert_starlang.read_tree(tree_text)
def test_bad_label():
"""
Test that an unexpected label results in an error
"""
tree_text = TREE.replace("(S", "(s")
with pytest.raises(ValueError):
tree = convert_starlang.read_tree(tree_text)
================================================
FILE: stanza/tests/constituency/test_ensemble.py
================================================
"""
Add a simple test of the Ensemble's inference path
This just reuses one model several times - that should still check the main loop, at least
"""
import pytest
from stanza import Pipeline
from stanza.models.constituency import text_processing
from stanza.models.constituency import tree_reader
from stanza.models.constituency.ensemble import Ensemble, EnsembleTrainer
from stanza.models.constituency.text_processing import parse_tokenized_sentences
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
@pytest.fixture(scope="module")
def pipeline():
return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos, constituency", tokenize_pretokenized=True)
@pytest.fixture(scope="module")
def saved_ensemble(tmp_path_factory, pipeline):
tmp_path = tmp_path_factory.mktemp("ensemble")
# test the ensemble by reusing the same parser multiple times
con_processor = pipeline.processors["constituency"]
model = con_processor._model
args = dict(model.args)
foundation_cache = pipeline.foundation_cache
model_path = con_processor._config['model_path']
# reuse the same model 3 times just to make sure the code paths are working
filenames = [model_path, model_path, model_path]
ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache)
save_path = tmp_path / "ensemble.pt"
ensemble.save(save_path)
return ensemble, save_path, args, foundation_cache
def check_basic_predictions(trees):
    """
    Assert that the parse results contain exactly two entries, each with a
    single prediction, and that the predicted trees are the expected parses of
    "This is a test" and "This is another test"
    """
    all_predictions = [result.predictions for result in trees]
    assert len(all_predictions) == 2
    assert all(len(preds) == 1 for preds in all_predictions)

    produced = ["{}".format(preds[0].tree) for preds in all_predictions]
    assert produced == ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                        "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
def test_ensemble_inference(pipeline):
    """
    Build an in-memory ensemble of three copies of the same parser and
    check its parses of two pretokenized sentences
    """
    processor = pipeline.processors["constituency"]
    model_args = dict(processor._model.args)
    model_path = processor._config['model_path']
    # reusing one model three times is enough to exercise the ensemble code paths
    trainer = EnsembleTrainer.from_files(model_args, [model_path] * 3, foundation_cache=pipeline.foundation_cache)

    sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]]
    parsed = parse_tokenized_sentences(model_args, trainer.model, [pipeline], sentences)
    check_basic_predictions(parsed)
def test_ensemble_save(saved_ensemble):
    """
    Requesting the saved_ensemble fixture is the whole test: the fixture
    builds and saves the ensemble, so any error while saving surfaces here.
    Loading the saved file is covered by test_ensemble_save_load.
    """
def test_ensemble_save_load(pipeline, saved_ensemble):
    """
    Reload the ensemble written by the saved_ensemble fixture and check its predictions
    """
    save_path, args, foundation_cache = saved_ensemble[1:]
    loaded = EnsembleTrainer.load(save_path, args, foundation_cache=foundation_cache)

    sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]]
    check_basic_predictions(parse_tokenized_sentences(args, loaded.model, [pipeline], sentences))
def test_parse_text(tmp_path, pipeline, saved_ensemble):
    """
    Parse a two-line raw text file with the saved ensemble using
    text_processing.load_model_parse_text, then read the written output back
    as a treebank and compare it against the expected trees
    """
    model_path, base_args = saved_ensemble[1], saved_ensemble[2]

    input_path = tmp_path / "test_input.txt"
    input_path.write_text("This is a test\nThis is another test\n")
    output_file = str(tmp_path / "test_output.txt")

    # copy the fixture's args so the shared dict is not modified
    args = dict(base_args, tokenized_file=str(input_path), predict_file=output_file)
    text_processing.load_model_parse_text(args, model_path, [pipeline])

    parsed = ["{}".format(tree) for tree in tree_reader.read_treebank(output_file)]
    assert parsed == ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                      "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
def test_pipeline(saved_ensemble):
    """
    Use the saved ensemble as the constituency model of a regular Pipeline
    and check the parse of a single sentence.

    The other processors are loaded from TEST_MODELS_DIR, the same directory
    the module's `pipeline` fixture uses. With download_method=None, relying on
    the Pipeline's default model directory would make this test depend on
    whatever models happen to be installed there.
    """
    _, model_path, _, foundation_cache = saved_ensemble
    nlp = Pipeline("en",
                   dir=TEST_MODELS_DIR,
                   processors="tokenize,pos,constituency",
                   constituency_model_path=str(model_path),
                   foundation_cache=foundation_cache,
                   download_method=None)
    doc = nlp("This is a test")
    tree = "{}".format(doc.sentences[0].constituency)
    assert tree == "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))"
================================================
FILE: stanza/tests/constituency/test_in_order_compound_oracle.py
================================================
import pytest
from stanza.models.constituency import in_order_compound_oracle
from stanza.models.constituency import tree_reader
from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme
from stanza.models.constituency.transition_sequence import build_treebank
from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
# markers applied to every test in this module
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# A sample tree from PTB with a triple unary transition (at a location other than root)
# Here we test the incorrect closing of various brackets
TRIPLE_UNARY_START_TREE = """
( (S
(PRN
(S
(NP-SBJ (-NONE- *) )
(VP (VB See) )))
(, ,)
(NP-SBJ
(NP (DT the) (JJ other) (NN rule) )
(PP (IN of)
(NP (NN thumb) ))
(PP (IN about)
(NP (NN ballooning) )))))
"""
TREES = [TRIPLE_UNARY_START_TREE]
# all of the sample trees joined into one string for tree_reader.read_trees
TREEBANK = "\n".join(TREES)
# root labels passed to every repair function by get_repairs
ROOT_LABELS = ["ROOT"]
@pytest.fixture(scope="module")
def trees():
    """
    The TREEBANK trees after prune_none() and simplify_labels(),
    e.g. the (-NONE- *) subject is removed and NP-SBJ becomes NP
    """
    parsed = tree_reader.read_trees(TREEBANK)
    cleaned = [tree.prune_none().simplify_labels() for tree in parsed]
    assert len(cleaned) == len(TREES)
    return cleaned
@pytest.fixture(scope="module")
def gold_sequences(trees):
    """IN_ORDER_COMPOUND transition sequences for each tree of the `trees` fixture"""
    return build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
def get_repairs(gold_sequence, wrong_transition, repair_fn):
    """
    Try repair_fn at every position of gold_sequence, as if the model had
    chosen wrong_transition instead of the gold transition at that position.

    Returns a list of (idx, repaired_sequence) tuples, one for each position
    where repair_fn returned something other than None
    """
    repairs = []
    for idx, gold_transition in enumerate(gold_sequence):
        repaired = repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None)
        if repaired is not None:
            repairs.append((idx, repaired))
    return repairs
def test_fix_shift_close():
    """
    Check fix_shift_close_error with a CloseConstituent as the wrong transition
    on TRIPLE_UNARY_START_TREE, comparing each repaired tree to its expected form
    """
    parsed = [t.prune_none().simplify_labels() for t in tree_reader.read_trees(TRIPLE_UNARY_START_TREE)]
    assert len(parsed) == 1
    gold_tree = parsed[0]
    sequences = build_treebank(parsed, TransitionScheme.IN_ORDER_COMPOUND)

    # there are three places in this tree where a long bracket (more than 2 subtrees)
    # could theoretically be closed and then reopened
    repairs = get_repairs(sequences[0], CloseConstituent(), in_order_compound_oracle.fix_shift_close_error)
    assert len(repairs) == 3

    expected_trees = ["(ROOT (S (S (PRN (S (VP (VB See)))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
                      "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other)) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
                      "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)))) (PP (IN about) (NP (NN ballooning))))))"]
    for (_, repaired_sequence), expected in zip(repairs, expected_trees):
        repaired_tree = reconstruct_tree(gold_tree, repaired_sequence, transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
        assert str(repaired_tree) == expected
def test_fix_open_close():
    """
    Apply fix_open_close_error with a CloseConstituent as the wrong transition
    at every position of the TRIPLE_UNARY_START_TREE gold sequence.

    The exact repaired trees are not pinned down yet, so each repair is still
    printed for inspection. The test does check that at least one repair is
    found, and that every repaired sequence rebuilds a tree over the same
    words as the gold tree.
    """
    trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE)
    trees = [t.prune_none().simplify_labels() for t in trees]
    assert len(trees) == 1
    tree = trees[0]
    gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
    repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_open_close_error)
    assert len(repairs) > 0
    print("------------------")
    for repair in repairs:
        print(repair)
        repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
        print("{:P}".format(repaired_tree))
        # a repair may move brackets around, but it must not lose or add words
        assert repaired_tree.leaf_labels() == tree.leaf_labels()
================================================
FILE: stanza/tests/constituency/test_in_order_oracle.py
================================================
import itertools
import pytest
from stanza.models.constituency import parse_transitions
from stanza.models.constituency import tree_reader
from stanza.models.constituency.base_model import SimpleModel
from stanza.models.constituency.in_order_oracle import *
from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme
from stanza.models.constituency.transition_sequence import build_treebank
from stanza.tests import *
from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
# markers applied to every test in this module
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# A sample tree from PTB with a single unary transition (at a location other than root)
SINGLE_UNARY_TREE = """
( (S
(NP-SBJ-1 (DT A) (NN record) (NN date) )
(VP (VBZ has) (RB n't)
(VP (VBN been)
(VP (VBN set)
(NP (-NONE- *-1) ))))
(. .) ))
"""
# For reference: this appears to be the IN_ORDER transition sequence of SINGLE_UNARY_TREE
# with the trace and full labels still present (i.e. before prune_none / simplify_labels)
# [Shift, OpenConstituent(('NP-SBJ-1',)), Shift, Shift, CloseConstituent, OpenConstituent(('S',)), Shift, OpenConstituent(('VP',)), Shift, Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), CloseConstituent, CloseConstituent, CloseConstituent, CloseConstituent, Shift, CloseConstituent, OpenConstituent(('ROOT',)), CloseConstituent]
# A sample tree from PTB with a double unary transition (at a location other than root)
DOUBLE_UNARY_TREE = """
( (S
(NP-SBJ
(NP (RB Not) (PDT all) (DT those) )
(SBAR
(WHNP-3 (WP who) )
(S
(NP-SBJ (-NONE- *T*-3) )
(VP (VBD wrote) ))))
(VP (VBP oppose)
(NP (DT the) (NNS changes) ))
(. .) ))
"""
# A sample tree from PTB with a triple unary transition (at a location other than root)
# The triple unary is at the START of the next bracket, which affects how the
# dynamic oracle repairs the transition sequence
TRIPLE_UNARY_START_TREE = """
( (S
(PRN
(S
(NP-SBJ (-NONE- *) )
(VP (VB See) )))
(, ,)
(NP-SBJ
(NP (DT the) (JJ other) (NN rule) )
(PP (IN of)
(NP (NN thumb) ))
(PP (IN about)
(NP (NN ballooning) )))))
"""
# A sample tree from PTB with a triple unary transition (at a location other than root)
# The triple unary is at the END of the next bracket, which affects how the
# dynamic oracle repairs the transition sequence
TRIPLE_UNARY_END_TREE = """
( (S
(NP (NNS optimists) )
(VP (VBP expect)
(S
(NP-SBJ-4 (NNP Hong) (NNP Kong) )
(VP (TO to)
(VP (VB hum)
(ADVP-CLR (RB along) )
(SBAR-MNR (RB as)
(S
(NP-SBJ (-NONE- *-4) )
(VP (-NONE- *?*)
(ADVP-TMP (IN before) ))))))))))
"""
# the order of TREES determines the order of the unary_trees / gold_sequences fixtures
TREES = [SINGLE_UNARY_TREE, DOUBLE_UNARY_TREE, TRIPLE_UNARY_START_TREE, TRIPLE_UNARY_END_TREE]
TREEBANK = "\n".join(TREES)
# Trees with wide brackets (three or more children), used by the close/shift/shift tests
NOUN_PHRASE_TREE = """
( (NP
(NP (NNP Chicago) (POS 's))
(NNP Goodman)
(NNP Theatre)))
"""
WIDE_NP_TREE = """
( (S
(NP-SBJ (DT These) (NNS studies))
(VP (VBP demonstrate)
(SBAR (IN that)
(S
(NP-SBJ (NNS mice))
(VP (VBP are)
(NP-PRD
(NP (DT a)
(ADJP (JJ practical)
(CC and)
(JJ powerful))
(JJ experimental) (NN system))
(SBAR
(WHADVP-2 (-NONE- *0*))
(S
(NP-SBJ (-NONE- *PRO*))
(VP (TO to)
(VP (VB study)
(NP (DT the) (NN genetics)))))))))))))
"""
WIDE_TREES = [NOUN_PHRASE_TREE, WIDE_NP_TREE]
WIDE_TREEBANK = "\n".join(WIDE_TREES)
# root labels passed to every repair function by get_repairs
ROOT_LABELS = ["ROOT"]
def get_repairs(gold_sequence, wrong_transition, repair_fn):
    """
    Run repair_fn at each index of gold_sequence, pretending wrong_transition
    was chosen in place of the gold transition at that index.

    Returns the (idx, repaired_sequence) pairs for every index where
    repair_fn returned a repair instead of None
    """
    candidates = ((idx, repair_fn(gold, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None))
                  for idx, gold in enumerate(gold_sequence))
    return [candidate for candidate in candidates if candidate[1] is not None]
@pytest.fixture(scope="module")
def unary_trees():
    """
    The four unary sample trees (in TREES order) after prune_none() and simplify_labels()
    """
    cleaned = [tree.prune_none().simplify_labels() for tree in tree_reader.read_trees(TREEBANK)]
    assert len(cleaned) == len(TREES)
    return cleaned
@pytest.fixture(scope="module")
def gold_sequences(unary_trees):
    """IN_ORDER transition sequences for each of the unary_trees"""
    return build_treebank(unary_trees, TransitionScheme.IN_ORDER)
@pytest.fixture(scope="module")
def wide_trees():
    """
    The wide-bracket sample trees (in WIDE_TREES order) after prune_none() and simplify_labels()
    """
    cleaned = [tree.prune_none().simplify_labels() for tree in tree_reader.read_trees(WIDE_TREEBANK)]
    assert len(cleaned) == len(WIDE_TREES)
    return cleaned
def test_wrong_open_root(gold_sequences):
    """
    Check the dynamic oracle's repair when S is opened where the gold sequence opens ROOT
    """
    wrong_open = OpenConstituent("S")
    root_open = OpenConstituent("ROOT")
    close = CloseConstituent()
    for sequence in gold_sequences:
        # every gold sequence ends with Open(ROOT), Close
        assert sequence[-2] == root_open
        repairs = get_repairs(sequence, wrong_open, fix_wrong_open_root_error)
        # ROOT occurs at only one spot in the sequence, so exactly one
        # location allows an S/ROOT replacement...
        assert len(repairs) == 1
        idx, repaired = repairs[0]
        # ...and that location is the second-to-last transition
        assert idx == len(sequence) - 2
        # the repair inserts the wrong open followed by a Close before the gold
        # ending, giving the model another chance to close the tree
        assert repaired == sequence[:-2] + [wrong_open, close] + sequence[-2:]
def test_wrong_open_unary_chain(gold_sequences):
    """
    Test the repairs of an open/open error if it is effectively a skipped unary transition,
    using fix_wrong_open_unary_chain.

    This test used to be called test_missed_unary, the same name as a later test
    in this module. The later definition replaced it at import time, so pytest
    never collected it. It is now named after the repair function it exercises.
    """
    wrong_transition = OpenConstituent("S")
    repairs = get_repairs(gold_sequences[0], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 0
    # here we are simulating picking NT-S instead of NT-VP
    # the DOUBLE_UNARY tree has one location where this is relevant, index 11
    repairs = get_repairs(gold_sequences[1], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 1
    assert repairs[0][0] == 11
    assert repairs[0][1] == gold_sequences[1][:11] + gold_sequences[1][13:]
    # the TRIPLE_UNARY_START tree has two locations where this is relevant
    # at index 1, the pattern goes (S (VP ...))
    # so choosing S instead of VP means you can skip the VP and only miss that one bracket
    # at index 5, the pattern goes (S (PRN (S (VP ...))) (...))
    # note that this is capturing a unary transition into a larger constituent
    # skipping the PRN is satisfactory
    repairs = get_repairs(gold_sequences[2], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 2
    assert repairs[0][0] == 1
    assert repairs[0][1] == gold_sequences[2][:1] + gold_sequences[2][3:]
    assert repairs[1][0] == 5
    assert repairs[1][1] == gold_sequences[2][:5] + gold_sequences[2][7:]
    # The TRIPLE_UNARY_END tree has 2 sections of tree for a total of 3 locations
    # where the repair might happen
    # Surprisingly the unary transition at the very start can only be
    # repaired by skipping it and using the outer S transition instead
    # The second repair overall (first repair in the second location)
    # should have a double skip to reach the S node
    repairs = get_repairs(gold_sequences[3], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 3
    assert repairs[0][0] == 1
    assert repairs[0][1] == gold_sequences[3][:1] + gold_sequences[3][3:]
    assert repairs[1][0] == 21
    assert repairs[1][1] == gold_sequences[3][:21] + gold_sequences[3][25:]
    assert repairs[2][0] == 23
    assert repairs[2][1] == gold_sequences[3][:23] + gold_sequences[3][25:]
def test_open_with_stuff(unary_trees, gold_sequences):
    """
    Check fix_wrong_open_stuff_unary when S is opened in place of a gold transition.
    Each unary tree has exactly one such repair with the listed result,
    except TRIPLE_UNARY_START_TREE (None), which has none.
    """
    wrong_transition = OpenConstituent("S")
    expected_trees = [
        "(ROOT (S (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))",
        "(ROOT (S (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))",
        None,
        "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))"
    ]
    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_trees):
        repairs = get_repairs(sequence, wrong_transition, fix_wrong_open_stuff_unary)
        if expected is None:
            assert len(repairs) == 0
            continue
        assert len(repairs) == 1
        assert str(reconstruct_tree(tree, repairs[0][1])) == expected
def test_general_open(gold_sequences):
    """
    Opening SBARQ (a label none of the sample trees use) in place of a gold
    open should be repairable at all but one of the opens by fix_wrong_open_general,
    and each repair swaps in the wrong open and leaves the rest of the sequence alone
    """
    wrong_transition = OpenConstituent("SBARQ")
    for sequence in gold_sequences:
        num_opens = sum(isinstance(transition, OpenConstituent) for transition in sequence)
        repairs = get_repairs(sequence, wrong_transition, fix_wrong_open_general)
        assert len(repairs) == num_opens - 1
        for idx, repaired in repairs:
            assert len(repaired) == len(sequence)
            assert repaired[idx] == wrong_transition
            assert repaired[:idx] == sequence[:idx]
            assert repaired[idx+1:] == sequence[idx+1:]
def test_missed_unary(unary_trees, gold_sequences):
    """
    Check fix_missed_unary when a Close or a Shift is chosen instead of the gold transition.

    Each expected repair is (idx, n): the repaired sequence is the gold
    sequence with the n transitions starting at idx removed.
    """
    expected_close_results = [
        [(12, 2)],
        [(11, 4), (13, 2)],
        # (NP NN thumb) and (NP NN ballooning) are both candidates for this repair
        [(18, 2), (24, 2)],
        [(21, 6), (23, 4), (25, 2)],
    ]
    expected_shift_results = [
        (),
        (),
        (),
        # (ADVP-CLR (RB along)) is followed by a shift
        [(16, 2)],
    ]

    def check(sequence, wrong_transition, expected):
        repairs = get_repairs(sequence, wrong_transition, fix_missed_unary)
        assert len(repairs) == len(expected)
        for (idx, repaired), (expected_idx, expected_len) in zip(repairs, expected):
            assert idx == expected_idx
            assert repaired == sequence[:expected_idx] + sequence[expected_idx+expected_len:]

    close_transition = CloseConstituent()
    shift_transition = Shift()
    for _, sequence, expected_close, expected_shift in zip(unary_trees, gold_sequences, expected_close_results, expected_shift_results):
        check(sequence, close_transition, expected_close)
        check(sequence, shift_transition, expected_shift)
def test_open_shift(unary_trees, gold_sequences):
    """
    Check fix_open_shift when a Shift is chosen instead of the gold transition:
    for each tree, the (index, resulting tree) of every repair is listed
    """
    expected_repairs = [
        [(7, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"),
         (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VBN been) (VP (VBN set))) (. .)))")],
        [(7, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WP who) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         (9, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         (19, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose) (NP (DT the) (NNS changes)) (. .)))"),
         (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (DT the) (NNS changes)) (. .)))")],
        [(14, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"),
         (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"),
         (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about) (NP (NN ballooning)))))")],
        [(5, "(ROOT (S (NP (NNS optimists)) (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (10, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (RB as) (S (VP (ADVP (IN before))))))))))")]
    ]
    shift_transition = Shift()
    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):
        repairs = get_repairs(sequence, shift_transition, fix_open_shift)
        assert len(repairs) == len(expected)
        for (idx, repaired), (expected_idx, expected_tree) in zip(repairs, expected):
            assert idx == expected_idx
            assert str(reconstruct_tree(tree, repaired)) == expected_tree
def test_open_close(unary_trees, gold_sequences):
    """
    Check fix_open_close when a Close is chosen instead of the gold transition:
    for each tree, the (index, resulting tree) of every repair is listed
    """
    expected_repairs = [
        [(7, "(ROOT (S (S (NP (DT A) (NN record) (NN date)) (VBZ has)) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"),
         (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VP (VBZ has) (RB n't) (VBN been)) (VP (VBN set))) (. .)))")],
        # missed the WHNP. The surrounding SBAR cannot be created, either
        [(7, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         # missed the SBAR
         (9, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who))) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         # missed the VP around "oppose the changes"
         (19, "(ROOT (S (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose)) (NP (DT the) (NNS changes)) (. .)))"),
         # missed the NP in "the changes", which produces a rather poor tree
         (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VP (VBP oppose) (DT the)) (NNS changes)) (. .)))")],
        [(14, "(ROOT (S (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule))) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"),
         (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (IN of)) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"),
         (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about)) (NP (NN ballooning)))))")],
        [(5, "(ROOT (S (S (NP (NNS optimists)) (VBP expect)) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (10, "(ROOT (S (NP (NNS optimists)) (VP (VP (VBP expect) (NP (NNP Hong) (NNP Kong))) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (S (NP (NNP Hong) (NNP Kong)) (TO to)) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (VP (TO to) (VB hum)) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (ADVP (RB along)) (RB as)) (S (VP (ADVP (IN before))))))))))")]
    ]
    close_transition = CloseConstituent()
    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):
        repairs = get_repairs(sequence, close_transition, fix_open_close)
        assert len(repairs) == len(expected)
        for (idx, repaired), (expected_idx, expected_tree) in zip(repairs, expected):
            assert idx == expected_idx
            assert str(reconstruct_tree(tree, repaired)) == expected_tree
def test_shift_close(unary_trees, gold_sequences):
    """
    Test the fix for a shift -> close.
    These errors can occur almost anywhere and the fix is simple,
    so only a few cases are checked in detail.
    """
    close_transition = CloseConstituent()

    repairs = get_repairs(gold_sequences[0], close_transition, fix_shift_close)
    assert len(repairs) == 7
    first_tree = reconstruct_tree(unary_trees[0], repairs[0][1])
    assert str(first_tree) == "(ROOT (S (NP (NP (DT A)) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))"

    for sequence_idx, expected_count in ((1, 8), (2, 8)):
        assert len(get_repairs(gold_sequences[sequence_idx], close_transition, fix_shift_close)) == expected_count

    repairs = get_repairs(gold_sequences[3], close_transition, fix_shift_close)
    assert len(repairs) == 9
    # The repair at index 16 is special because it occurs as part of a unary:
    # the gold sequence goes unary, shift, and instead we close where the
    # unary should be.  The unary would have created "(ADVP (RB along))"
    at_16 = [rep for rep in repairs if rep[0] == 16]
    if not at_16:
        raise AssertionError("Did not find an expected repair location")
    expected_tree = "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))"
    assert str(reconstruct_tree(unary_trees[3], at_16[0][1])) == expected_tree
def test_close_open_shift_nested(unary_trees, gold_sequences):
    """
    Check fix_close_open_shift_nested when a Shift is chosen instead of the gold
    transition. expected maps repair index -> resulting tree for each unary tree.
    """
    expected_trees = [{},
                      {4: "(ROOT (S (NP (RB Not) (PDT all) (DT those) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"},
                      {4: "(ROOT (S (VP (VB See)) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
                       13: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"},
                      {}]
    shift_transition = Shift()
    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_trees):
        repairs = get_repairs(sequence, shift_transition, fix_close_open_shift_nested)
        assert len(repairs) == len(expected)
        # when expected is empty, the length check above means there is nothing to loop over
        for idx, repaired in repairs:
            assert idx in expected
            assert str(reconstruct_tree(tree, repaired)) == expected[idx]
def check_repairs(trees, gold_sequences, expected_trees, transition, repair_fn):
    """
    For each (tree, gold sequence, expected) triple, collect the repairs repair_fn
    makes when `transition` is chosen instead of the gold transition.

    expected maps repair index -> expected tree string; the set of repair
    indices and each resulting tree must match it exactly.  If expected is
    None, nothing is checked and the repairs are printed instead, which helps
    when writing the expected values for a new test.
    """
    for tree_idx, (gold_tree, gold_sequence, expected) in enumerate(zip(trees, gold_sequences, expected_trees)):
        repairs = get_repairs(gold_sequence, transition, repair_fn)

        if expected is None:
            # debugging aid: dump everything needed to fill in the expected trees
            print("---------------------")
            print("{:P}".format(gold_tree))
            print(gold_sequence)
            for repair_idx, repaired_sequence in repairs:
                print("---------------------")
                print(gold_sequence)
                print(repaired_sequence)
                result_tree = reconstruct_tree(gold_tree, repaired_sequence)
                print("{:P}".format(gold_tree))
                print("{:P}".format(result_tree))
                print(tree_idx)
                print(repair_idx)
                print(result_tree)
            continue

        assert len(repairs) == len(expected)
        for repair_idx, repaired_sequence in repairs:
            assert repair_idx in expected
            result_tree = reconstruct_tree(gold_tree, repaired_sequence)
            assert str(result_tree) == expected[repair_idx]
def test_close_open_shift_unambiguous(unary_trees, gold_sequences):
    """
    Shift chosen instead of the gold transition, repaired by
    fix_close_open_shift_unambiguous_bracket: in the expected trees a bracket
    that should have closed (e.g. the NP over "optimists") absorbs the following constituent
    """
    expected_trees = [{},
                      {8: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who) (S (VP (VBD wrote)))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"},
                      {},
                      {2: "(ROOT (S (NP (NNS optimists) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))",
                       9: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"}]
    check_repairs(unary_trees, gold_sequences, expected_trees, Shift(), fix_close_open_shift_unambiguous_bracket)
def test_close_open_shift_ambiguous_early(unary_trees, gold_sequences):
    """
    Shift chosen instead of the gold transition, repaired by
    fix_close_open_shift_ambiguous_bracket_early: in the expected trees the
    stretched bracket takes in only the next constituent and then closes
    (e.g. the NP over "A record date" absorbs the VP but not the final ".")
    """
    expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))))) (. .)))"},
                      {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes)))) (. .)))"},
                      {2: "(ROOT (S (PRN (S (VP (VB See) (, ,)))) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
                       6: "(ROOT (S (PRN (S (VP (VB See))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"},
                      {}]
    check_repairs(unary_trees, gold_sequences, expected_trees, Shift(), fix_close_open_shift_ambiguous_bracket_early)
def test_close_open_shift_ambiguous_late(unary_trees, gold_sequences):
    """
    Shift chosen instead of the gold transition, repaired by
    fix_close_open_shift_ambiguous_bracket_late: in the expected trees the
    stretched bracket keeps absorbing constituents until its parent would close
    (e.g. the NP over "A record date" absorbs both the VP and the final ".")
    """
    expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))))"},
                      {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .))))"},
                      {2: "(ROOT (S (PRN (S (VP (VB See) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))))",
                       6: "(ROOT (S (PRN (S (VP (VB See))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))"},
                      {}]
    check_repairs(unary_trees, gold_sequences, expected_trees, Shift(), fix_close_open_shift_ambiguous_bracket_late)
def test_close_shift_shift(unary_trees, wide_trees):
    """
    Test that close -> shift is repaired when exactly one block is shifted after the missed close.
    The two wide trees expect no repair at all: with two or more blocks after
    the missed close, this oracle action must not apply.
    """
    test_trees = unary_trees + wide_trees
    sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)
    expected_trees = [{15: "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .))))"},
                      {24: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes)) (. .))))"},
                      {20: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning)))))))"},
                      {17: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"},
                      {},
                      {}]
    check_repairs(test_trees, sequences, expected_trees, Shift(), fix_close_shift_shift_unambiguous)
def test_close_shift_shift_early(unary_trees, wide_trees):
    """
    Test that close -> shift is repaired when multiple blocks are shifted after the missed close,
    with the bracket closed as early as possible: in the expected tree the ADJP
    absorbs only "experimental", and "system" stays outside it.
    The single-block cases (the unary trees) must not be repaired here, so the
    two repair functions can be tested separately.
    WIDE_NP_TREE was added specifically to provide this situation.
    """
    test_trees = unary_trees + wide_trees
    sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)
    expected_trees = [{}] * 5 + [
        {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental)) (NN system)) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}
    ]
    check_repairs(test_trees, sequences, expected_trees, Shift(), fix_close_shift_shift_ambiguous_early)
def test_close_shift_shift_late(unary_trees, wide_trees):
    """
    Test that close -> shift is repaired when multiple blocks are shifted after the missed close,
    with the bracket closed as late as possible: in the expected tree the ADJP
    absorbs both "experimental" and "system".
    The single-block cases (the unary trees) must not be repaired here, so the
    two repair functions can be tested separately.
    WIDE_NP_TREE was added specifically to provide this situation.
    """
    test_trees = unary_trees + wide_trees
    sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)
    expected_trees = [{}] * 5 + [
        {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental) (NN system))) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}
    ]
    check_repairs(test_trees, sequences, expected_trees, Shift(), fix_close_shift_shift_ambiguous_late)
================================================
FILE: stanza/tests/constituency/test_lstm_model.py
================================================
import os
import pytest
import torch
from stanza.models.common import pretrain
from stanza.models.common.utils import set_random_seed
from stanza.models.constituency import parse_transitions
from stanza.tests import *
from stanza.tests.constituency import test_parse_transitions
from stanza.tests.constituency.test_trainer import build_trainer
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

@pytest.fixture(scope="module")
def pretrain_file():
    """Path of the small tiny_emb.pt embedding file in the test working dir"""
    return "{}/in/tiny_emb.pt".format(TEST_WORKING_DIR)
def build_model(pretrain_file, *args):
    """
    Build a trainer with test_trainer.build_trainer and return its model.

    A fixed set of flags (no multistage, small attention / hidden sizes,
    use_lattn) is passed first, followed by any extra flags in args.
    """
    # By default, we turn off multistage, since that can turn off various other structures in the initial training
    default_flags = ['--no_multistage', '--pattn_num_layers', '4', '--pattn_d_model', '256', '--hidden_size', '128', '--use_lattn']
    return build_trainer(pretrain_file, *default_flags, *args).model
@pytest.fixture(scope="module")
def unary_model(pretrain_file):
    """Module-wide model built with the TOP_DOWN_UNARY transition scheme"""
    scheme_flags = ("--transition_scheme", "TOP_DOWN_UNARY")
    return build_model(pretrain_file, *scheme_flags)
def test_initial_state(unary_model):
    """Run the shared parse_transitions initial-state test against the LSTM model"""
    test_parse_transitions.test_initial_state(unary_model)
def test_shift(pretrain_file):
    """Run the shared parse_transitions shift test against a default LSTM model"""
    # TODO: might be good to include some tests specifically for shift
    # in the context of a model with unaries
    test_parse_transitions.test_shift(build_model(pretrain_file))
def test_unary(unary_model):
    """Run the shared parse_transitions unary test against the TOP_DOWN_UNARY model"""
    test_parse_transitions.test_unary(unary_model)
def test_unary_requires_root(unary_model):
    """Run the shared parse_transitions unary-requires-root test against the TOP_DOWN_UNARY model"""
    test_parse_transitions.test_unary_requires_root(unary_model)
def test_open(unary_model):
    """Run the shared parse_transitions open test against the TOP_DOWN_UNARY model"""
    test_parse_transitions.test_open(unary_model)
def test_compound_open(pretrain_file):
model = build_model(pretrain_file, '--transition_scheme', "TOP_DOWN_COMPOUND")
test_parse_transitions.test_compound_open(model)
def test_in_order_open(pretrain_file):
model = build_model(pretrain_file, '--transition_scheme', "IN_ORDER")
test_parse_transitions.test_in_order_open(model)
def test_close(unary_model):
test_parse_transitions.test_close(unary_model)
def run_forward_checks(model, num_states=1):
    """
    Run a couple small transitions and a forward pass on the given model

    The model outputs are not checked in any way; only transition
    legality and num_opens are asserted.  This function allows for
    testing that building models with various options results in a
    functional model.

    model: the parser model to exercise
    num_states: how many identical initial states to build and step
      together as one batch.  Every state receives the same transitions,
      so the legality and num_opens checks are made on all of them,
      not only the first.
    """
    def apply_all(states, transition):
        # apply the same transition to every state in the batch,
        # then run a forward pass over the resulting states
        states = model.bulk_apply(states, [transition] * len(states))
        model(states)
        return states

    states = test_parse_transitions.build_initial_state(model, num_states)
    model(states)

    shift = parse_transitions.Shift()
    states = apply_all(states, shift)

    open_transition = parse_transitions.OpenConstituent("NP")
    assert all(open_transition.is_legal(state, model) for state in states)
    states = apply_all(states, open_transition)
    assert all(state.num_opens == 1 for state in states)

    states = apply_all(states, shift)
    states = apply_all(states, shift)
    assert all(state.num_opens == 1 for state in states)

    # now should have "Mox", "Opal" on the constituents
    close_transition = parse_transitions.CloseConstituent()
    assert all(close_transition.is_legal(state, model) for state in states)
    states = apply_all(states, close_transition)
    assert all(state.num_opens == 0 for state in states)
def test_unary_forward(unary_model):
    """
    The forward pass of the TOP_DOWN_UNARY model must not crash after the
    transitions applied by run_forward_checks

    Whether the forward pass produces reasonable answers is not checked
    """
    run_forward_checks(unary_model)

def test_lstm_forward(pretrain_file):
    """Forward checks on one default model, first with one state and then with a batch of two"""
    model = build_model(pretrain_file)
    for batch_size in (1, 2):
        run_forward_checks(model, num_states=batch_size)

def test_lstm_layers(pretrain_file):
    """Forward checks with --num_lstm_layers set to 1, 2, and 3"""
    for num_layers in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_lstm_layers', num_layers)
        run_forward_checks(model)
def test_multiple_output_forward(pretrain_file):
    """
    Forward checks with 1, 2, and 3 output layers, always on top of 2 LSTM layers
    """
    for num_output_layers in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_output_layers', num_output_layers, '--num_lstm_layers', '2')
        run_forward_checks(model)

def test_no_tag_embedding_forward(pretrain_file):
    """
    The model must keep working with the tag embedding on (dim 20) or off (dim 0)
    """
    for tag_dim in ('20', '0'):
        model = build_model(pretrain_file, '--tag_embedding_dim', tag_dim)
        run_forward_checks(model)

def test_forward_combined_dummy(pretrain_file):
    """
    Forward checks with the combined dummy / open node embedding switched on and off
    """
    for dummy_flag in ('--combined_dummy_embedding', '--no_combined_dummy_embedding'):
        model = build_model(pretrain_file, dummy_flag)
        run_forward_checks(model)

def test_nonlinearity_init(pretrain_file):
    """
    Each nonlinearity choice (and its initialization) must produce a model whose forward pass runs
    """
    for nonlinearity in ('relu', 'tanh', 'silu'):
        model = build_model(pretrain_file, '--nonlinearity', nonlinearity)
        run_forward_checks(model)
def test_forward_charlm(pretrain_file):
    """
    Load the en test charlms and run the forward checks with them

    Only the shapes are exercised here; the charlm output itself is not checked.
    Both 'none' and 'words' sentence boundary vectors are tried.
    """
    charlm_dir = os.path.join(TEST_MODELS_DIR, "en")
    forward_charlm_path = os.path.join(charlm_dir, "forward_charlm", "1billion.pt")
    backward_charlm_path = os.path.join(charlm_dir, "backward_charlm", "1billion.pt")
    assert os.path.exists(forward_charlm_path), "Need to download en test models (or update path to the forward charlm)"
    assert os.path.exists(backward_charlm_path), "Need to download en test models (or update path to the backward charlm)"

    charlm_args = ('--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path)
    for boundary in ('none', 'words'):
        model = build_model(pretrain_file, *charlm_args, '--sentence_boundary_vectors', boundary)
        run_forward_checks(model)

def test_forward_bert(pretrain_file):
    """
    Forward checks with a tiny Bert, chosen to keep disk and memory use small
    """
    model = build_model(pretrain_file, '--bert_model', "hf-internal-testing/tiny-bert")
    run_forward_checks(model)

def test_forward_xlnet(pretrain_file):
    """
    Forward checks with a tiny xlnet, chosen to keep disk and memory use small
    """
    model = build_model(pretrain_file, '--bert_model', "hf-internal-testing/tiny-random-xlnet")
    run_forward_checks(model)
def test_forward_sentence_boundaries(pretrain_file):
    """
    Forward checks for each --sentence_boundary_vectors setting
    """
    for boundary in ('everything', 'words', 'none'):
        model = build_model(pretrain_file, '--sentence_boundary_vectors', boundary)
        run_forward_checks(model)

def test_forward_constituency_composition(pretrain_file):
    """
    Forward checks, with a batch of two states, for each constituency composition function
    """
    compositions = ('bilstm', 'max', 'key', 'untied_key', 'untied_max',
                    'bilstm_max', 'tree_lstm', 'tree_lstm_cx', 'bigram', 'attn')
    for composition in compositions:
        model = build_model(pretrain_file, '--constituency_composition', composition)
        run_forward_checks(model, num_states=2)

def test_forward_key_position(pretrain_file):
    """
    UNTIED_KEY and KEY compositions, each with reduce_position 0 and 32
    """
    for composition in ('untied_key', 'key'):
        for reduce_position in ('0', '32'):
            model = build_model(pretrain_file, '--constituency_composition', composition, '--reduce_position', reduce_position)
            run_forward_checks(model, num_states=2)
def test_forward_attn_hidden_size(pretrain_file):
    """
    Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size

    Both the default reduce_heads and an explicit --reduce_heads 10 are
    checked, and each reconfigured model is run through the forward checks.
    """
    model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129')
    assert model.hidden_size >= 129
    assert model.hidden_size % model.reduce_heads == 0
    run_forward_checks(model, num_states=2)

    model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129', '--reduce_heads', '10')
    # 129 is expected to be rounded up to the next multiple of 10
    assert model.hidden_size == 130
    assert model.reduce_heads == 10
    # also run the forward pass, so a shape mismatch in the resized layers is caught
    run_forward_checks(model, num_states=2)
def test_forward_partitioned_attention(pretrain_file):
    """
    Forward checks with 8 partitioned attention heads & layers, and with none at all
    """
    for pattn_size in ('8', '0'):
        model = build_model(pretrain_file, '--pattn_num_heads', pattn_size, '--pattn_num_layers', pattn_size)
        run_forward_checks(model)

def test_forward_labeled_attention(pretrain_file):
    """
    Forward checks with labeled attention on, off, and on with combined input
    """
    lattn_on = ('--lattn_d_proj', '64', '--lattn_d_l', '16')
    variants = [
        lattn_on,
        ('--lattn_d_proj', '0', '--lattn_d_l', '0'),
        lattn_on + ('--lattn_combined_input',),
    ]
    for lattn_args in variants:
        model = build_model(pretrain_file, *lattn_args)
        run_forward_checks(model)

def test_lattn_partitioned(pretrain_file):
    """
    Forward checks for labeled attention with and without --lattn_partitioned
    """
    for partition_flag in ('--lattn_partitioned', '--no_lattn_partitioned'):
        model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', partition_flag)
        run_forward_checks(model)
def test_lattn_projection(pretrain_file):
    """
    Test the labeled attention input projection, --lattn_d_input_proj

    A projection of 256 on top of a pattn_d_model of 1024 is rejected
    with a ValueError when lattn is partitioned, but accepted when it is
    not.  A projection of 768 with the default pattn_d_model, and no
    projection at all, should both work.
    """
    # only building the model goes inside the raises block, so the test
    # cannot pass on a ValueError that comes from the forward pass instead
    with pytest.raises(ValueError):
        # this is too small
        build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256', '--lattn_partitioned')

    model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256')
    run_forward_checks(model)

    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '768')
    run_forward_checks(model)

    # check that it works if we turn off the projection,
    # in case having it on becomes the default
    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '0')
    run_forward_checks(model)
def test_forward_timing_choices(pretrain_file):
    """
    Forward checks for the 'sin' and 'learned' --pattn_timing (position encoding) options
    """
    for timing in ('sin', 'learned'):
        model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', timing)
        run_forward_checks(model)

def test_transition_stack(pretrain_file):
    """
    Forward checks for the attn transition stack (1 and 4 heads) and the lstm transition stack
    """
    no_attn_layers = ('--pattn_num_layers', '0', '--lattn_d_proj', '0')
    stack_options = [
        ('--transition_stack', 'attn', '--transition_heads', '1'),
        ('--transition_stack', 'attn', '--transition_heads', '4'),
        ('--transition_stack', 'lstm'),
    ]
    for stack_args in stack_options:
        model = build_model(pretrain_file, *no_attn_layers, *stack_args)
        run_forward_checks(model)

def test_constituent_stack(pretrain_file):
    """
    Forward checks for the attn constituent stack (1 and 4 heads) and the lstm constituent stack
    """
    no_attn_layers = ('--pattn_num_layers', '0', '--lattn_d_proj', '0')
    stack_options = [
        ('--constituent_stack', 'attn', '--constituent_heads', '1'),
        ('--constituent_stack', 'attn', '--constituent_heads', '4'),
        ('--constituent_stack', 'lstm'),
    ]
    for stack_args in stack_options:
        model = build_model(pretrain_file, *no_attn_layers, *stack_args)
        run_forward_checks(model)
def test_different_transition_sizes(pretrain_file):
    """
    The model must still work when the transition embedding dim and the
    transition hidden size differ, with and without sentence boundary vectors
    """
    no_attn_layers = ('--pattn_num_layers', '0', '--lattn_d_proj', '0')
    # (transition_embedding_dim, transition_hidden_size)
    size_pairs = [('10', '10'), ('20', '10'), ('10', '20')]
    for boundary in ('everything', 'none'):
        for embedding_dim, hidden_size in size_pairs:
            model = build_model(pretrain_file,
                                *no_attn_layers,
                                '--transition_embedding_dim', embedding_dim, '--transition_hidden_size', hidden_size,
                                '--sentence_boundary_vectors', boundary)
            run_forward_checks(model)
def test_relative_attention(pretrain_file):
    """Forward checks with relative attention instead of labeled attention, without rattn_cat"""
    rattn_args = ('--no_use_lattn', '--use_rattn', '--rattn_heads', '10')
    model = build_model(pretrain_file, *rattn_args, '--no_rattn_cat')
    run_forward_checks(model)

def test_relative_attention_cat(pretrain_file):
    """
    With rattn_cat on, the word input size must be larger than with it off
    """
    rattn_args = ('--no_use_lattn', '--use_rattn', '--rattn_heads', '10')
    input_sizes = {}
    for cat_flag in ('--rattn_cat', '--no_rattn_cat'):
        model = build_model(pretrain_file, *rattn_args, cat_flag)
        run_forward_checks(model)
        input_sizes[cat_flag] = model.word_input_size
    assert input_sizes['--rattn_cat'] > input_sizes['--no_rattn_cat']

def test_relative_attention_directional(pretrain_file):
    """Forward checks with the forward direction and then the reverse direction of rattn turned off"""
    for direction_flag in ('--no_rattn_forward', '--no_rattn_reverse'):
        model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', direction_flag, '--no_rattn_cat')
        run_forward_checks(model)
def _run_rattn_sink_checks(pretrain_file, mode_flag):
    """
    Forward check rattn models with 1 and then 2 sinks, each first with
    a window of 2 and then without a window argument

    mode_flag is the one extra flag that distinguishes the sink tests below
    """
    for num_sinks in ('1', '2'):
        for window_args in (('--rattn_window', '2'), ()):
            model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', mode_flag, *window_args, '--rattn_sinks', num_sinks)
            run_forward_checks(model)

def test_relative_attention_sinks(pretrain_file):
    """rattn sinks without rattn_cat"""
    _run_rattn_sink_checks(pretrain_file, '--no_rattn_cat')

def test_relative_attention_cat_sinks(pretrain_file):
    """rattn sinks with rattn_cat"""
    _run_rattn_sink_checks(pretrain_file, '--rattn_cat')

def test_relative_attention_endpoint_sinks(pretrain_file):
    """rattn sinks with --rattn_use_endpoint_sinks"""
    _run_rattn_sink_checks(pretrain_file, '--rattn_use_endpoint_sinks')
def test_lstm_tree_forward(pretrain_file):
    """
    Forward checks for the tree_lstm composition with 1, 2, and 3 tree LSTM layers
    """
    for num_layers in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_tree_lstm_layers', num_layers, '--constituency_composition', 'tree_lstm')
        run_forward_checks(model)

def test_lstm_tree_cx_forward(pretrain_file):
    """
    Forward checks for the tree_lstm_cx composition with 1, 2, and 3 tree LSTM layers
    """
    for num_layers in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_tree_lstm_layers', num_layers, '--constituency_composition', 'tree_lstm_cx')
        run_forward_checks(model)

def test_maxout(pretrain_file):
    """
    Output layers with and without maxout

    Without maxout the last output layer is used directly (it has .weight);
    with maxout_k = k it wraps a linear layer producing k outputs per transition.
    """
    model = build_model(pretrain_file, '--maxout_k', '0')
    run_forward_checks(model)
    # checking the output size also implicitly checks the layer type,
    # which guards against a particularly silly bug
    assert model.output_layers[-1].weight.shape[0] == len(model.transitions)

    for k in (2, 3):
        model = build_model(pretrain_file, '--maxout_k', str(k))
        run_forward_checks(model)
        assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * k
def check_structure_test(pretrain_file, args1, args2):
    """
    Test that the "copy" method copies the parameters from one model to another
    Also check that the copied models produce the same results

    pretrain_file: path to the pretrain used to build both models
    args1: extra args for the source model, `other`
    args2: extra args for the destination model, `model`

    The two models are built under different random seeds so their
    weights start out different.  After model.copy_with_new_structure(other)
    the compared weights must match, and applying one Shift to an initial
    state must give the same hidden values in both models.
    """
    set_random_seed(1000)
    other = build_model(pretrain_file, *args1)
    other.eval()

    set_random_seed(1001)
    model = build_model(pretrain_file, *args2)
    model.eval()

    # the different seeds should leave the weights different before the copy
    assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
    assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)

    model.copy_with_new_structure(other)

    assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
    assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
    # the norms will be the same, as the non-zero values are all the same
    assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))

    # now, check that applying one transition to an initial state
    # results in the same values in the output states for both models
    # as the pattn layer inputs are 0, the output values should be equal
    shift = [parse_transitions.Shift()]
    model_states = test_parse_transitions.build_initial_state(model, 1)
    model_states = model.bulk_apply(model_states, shift)

    other_states = test_parse_transitions.build_initial_state(other, 1)
    other_states = other.bulk_apply(other_states, shift)

    other_state = other_states[0]
    model_state = model_states[0]

    # zip() stops at the shorter stack, so compare the lengths first
    # to make sure every entry is actually checked
    assert len(other_state.word_queue) == len(model_state.word_queue)
    for i, j in zip(other_state.word_queue, model_state.word_queue):
        assert torch.allclose(i.hx, j.hx, atol=1e-07)

    assert len(other_state.transitions) == len(model_state.transitions)
    for i, j in zip(other_state.transitions, model_state.transitions):
        assert torch.allclose(i.lstm_hx, j.lstm_hx)
        assert torch.allclose(i.lstm_cx, j.lstm_cx)

    assert len(other_state.constituents) == len(model_state.constituents)
    for i, j in zip(other_state.constituents, model_state.constituents):
        assert (i.value is None) == (j.value is None)
        if i.value is not None:
            assert torch.allclose(i.value.tree_hx, j.value.tree_hx, atol=1e-07)
        assert torch.allclose(i.lstm_hx, j.lstm_hx)
        assert torch.allclose(i.lstm_cx, j.lstm_cx)
def test_copy_with_new_structure_same(pretrain_file):
    """
    Test that copying the structure with no changes works as expected
    """
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'])

def test_copy_with_new_structure_untied(pretrain_file):
    """
    Test copying from a MAX composition model to an UNTIED_MAX composition model
    """
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'MAX'],
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'UNTIED_MAX'])

def test_copy_with_new_structure_pattn(pretrain_file):
    """
    Test copying from a model with 0 pattn layers to one with 1 pattn layer
    (lattn_d_proj stays 0 in both)
    """
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])

def test_copy_with_new_structure_both(pretrain_file):
    """
    Test copying from a model with 0 pattn layers and lattn_d_proj 0
    to one with 1 pattn layer and lattn_d_proj 32
    """
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])

def test_copy_with_new_structure_lattn(pretrain_file):
    """
    Test copying between two models with 1 pattn layer, where only the
    destination has a nonzero lattn_d_proj (32)
    """
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'],
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
def test_parse_tagged_words(pretrain_file):
    """
    Small test which doesn't check results, just execution

    The tree structure from an untrained model is arbitrary, but its
    preterminals must still be exactly the input words with the input tags.
    """
    model = build_model(pretrain_file)
    sentence = [("I", "PRP"), ("am", "VBZ"), ("Luffa", "NNP")]
    # we don't expect a useful tree out of a random model
    # so we don't check the structure of the result
    # just check that it works without crashing
    result = model.parse_tagged_words([sentence], 10)
    assert len(result) == 1
    pts = list(result[0].yield_preterminals())
    # zip() would silently stop early if a word were dropped, so check the count first
    assert len(pts) == len(sentence)
    for (word, tag), pt in zip(sentence, pts):
        assert pt.children[0].label == word
        assert pt.label == tag
================================================
FILE: stanza/tests/constituency/test_parse_transitions.py
================================================
import pytest
from stanza.models.constituency import parse_transitions
from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT
from stanza.models.constituency.parse_transitions import TransitionScheme, Shift, CloseConstituent, OpenConstituent
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def build_initial_state(model, num_states=1):
    """
    Build num_states initial states for the tagged sentence "Unban Mox Opal"

    Each state gets its own copy of the (word, tag) list.  The helper
    asserts that the model returned num_states states and that none of
    them has any transitions applied yet.
    """
    tagged_words = [("Unban", "VB"), ("Mox", "NNP"), ("Opal", "NNP")]
    sentences = [list(tagged_words) for _ in range(num_states)]
    states = model.initial_state_from_words(sentences)
    assert len(states) == num_states
    assert not any(state.num_transitions != 0 for state in states)
    return states
def test_initial_state(model=None):
    """
    Check the sizes and counters of a freshly built state for "Unban Mox Opal"

    model: the model to test with; a SimpleModel is built if None
    """
    model = SimpleModel() if model is None else model
    states = build_initial_state(model)
    assert len(states) == 1

    initial = states[0]
    assert initial.sentence_length == 3
    assert initial.num_opens == 0
    # each stack has a sentinel value at the end
    assert len(initial.word_queue) == 5
    assert len(initial.constituents) == 1
    assert len(initial.transitions) == 1
    assert initial.word_position == 0
def test_shift(model=None):
    """
    Open ROOT and S, then shift each of the three words in turn, checking
    the stack sizes, the word position, and Shift legality after every step

    model: the model to test with; a SimpleModel is built if None
    """
    if model is None:
        model = SimpleModel()
    state = build_initial_state(model)[0]
    for label in ("ROOT", "S"):
        state = parse_transitions.OpenConstituent(label).apply(state, model)

    shift = parse_transitions.Shift()
    assert shift.is_legal(state, model)
    assert len(state.word_queue) == 5
    assert state.word_position == 0

    # (word_position, stack size, words remaining) after each shift
    # the stack size is 4 after the first shift because of the dummy created by the opens
    expected_steps = [(1, 4, True), (2, 5, True), (3, 6, False)]
    for position, stack_size, words_remain in expected_steps:
        state = shift.apply(state, model)
        assert len(state.word_queue) == 5
        assert len(state.constituents) == stack_size
        assert len(state.transitions) == stack_size
        if words_remain:
            assert shift.is_legal(state, model)
        else:
            assert not shift.is_legal(state, model)
        assert state.word_position == position
        if words_remain:
            assert not state.empty_word_queue()
        else:
            assert state.empty_word_queue()

    # the words come back off the constituent stack in reverse order
    constituents = state.constituents
    for depth, word in enumerate(('Opal', 'Mox', 'Unban')):
        if depth > 0:
            constituents = constituents.pop()
        assert model.get_top_constituent(constituents).children[0].label == word
def test_initial_unary(model=None):
    """
    A CompoundUnary is illegal as the very first transition,
    whether it has one label or several

    model: the model to test with; a SimpleModel is built if None
    """
    if model is None:
        model = SimpleModel()
    state = build_initial_state(model)[0]
    for labels in (('ROOT', 'VP'), ('VP',)):
        unary = parse_transitions.CompoundUnary(*labels)
        assert unary.label == labels
        assert not unary.is_legal(state, model)

def test_unary(model=None):
    """
    Shift one word, then apply CompoundUnary('S', 'VP') to it, and check
    the resulting (S (VP (VB Unban))) chain

    model: the model to test with; a SimpleModel is built if None
    """
    if model is None:
        model = SimpleModel()
    state = build_initial_state(model)[0]
    state = parse_transitions.Shift().apply(state, model)

    # this is technically the wrong parse but we're being lazy
    unary = parse_transitions.CompoundUnary('S', 'VP')
    assert unary.is_legal(state, model)
    state = unary.apply(state, model)
    # a second unary right away is not allowed
    assert not unary.is_legal(state, model)

    node = model.get_top_constituent(state.constituents)
    for expected_label in ('S', 'VP'):
        assert node.label == expected_label
        assert len(node.children) == 1
        node = node.children[0]
    assert node.label == 'VB'
    assert node.is_preterminal()
def test_unary_requires_root(model=None):
    """
    With TOP_DOWN_UNARY, once (S ...) has been closed over the whole
    sentence, Open, Close and an NP unary are all illegal, while a ROOT
    unary is legal and finishes the parse

    model: the model to test with; a TOP_DOWN_UNARY SimpleModel is built if None
    """
    if model is None:
        model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)
    state = build_initial_state(model)[0]

    open_transition = parse_transitions.OpenConstituent("S")
    assert open_transition.is_legal(state, model)
    state = open_transition.apply(state, model)

    shift = parse_transitions.Shift()
    for _ in range(3):
        assert shift.is_legal(state, model)
        state = shift.apply(state, model)
    assert not shift.is_legal(state, model)

    close_transition = parse_transitions.CloseConstituent()
    assert close_transition.is_legal(state, model)
    state = close_transition.apply(state, model)
    assert not open_transition.is_legal(state, model)
    assert not close_transition.is_legal(state, model)

    np_unary = parse_transitions.CompoundUnary("NP")
    assert not np_unary.is_legal(state, model)
    root_unary = parse_transitions.CompoundUnary("ROOT")
    assert root_unary.is_legal(state, model)

    assert not state.finished(model)
    state = root_unary.apply(state, model)
    assert not root_unary.is_legal(state, model)
    assert state.finished(model)
def test_open(model=None):
    """
    OpenConstituent is legal while words remain, illegal after 21 stacked
    opens, and illegal once every word has been shifted

    model: the model to test with; a SimpleModel is built if None
    """
    if model is None:
        model = SimpleModel()
    shift = parse_transitions.Shift()
    open_transition = parse_transitions.OpenConstituent("VP")

    state = build_initial_state(model)[0]
    for _ in range(2):
        state = shift.apply(state, model)
    assert state.num_opens == 0

    assert open_transition.is_legal(state, model)
    state = open_transition.apply(state, model)
    assert open_transition.is_legal(state, model)
    assert state.num_opens == 1

    # check that it is illegal if there are too many opens already
    for _ in range(20):
        state = open_transition.apply(state, model)
    assert not open_transition.is_legal(state, model)
    assert state.num_opens == 21

    # check that it is illegal if the state is out of words
    state = build_initial_state(model)[0]
    for _ in range(3):
        state = shift.apply(state, model)
    assert not open_transition.is_legal(state, model)
def test_compound_open(model=None):
    """
    A compound OpenConstituent("ROOT", "S") followed by three shifts and
    one close yields ROOT with a single S child over all three words

    model: the model to test with; a SimpleModel is built if None
    """
    if model is None:
        model = SimpleModel()
    state = build_initial_state(model)[0]

    open_transition = parse_transitions.OpenConstituent("ROOT", "S")
    assert open_transition.is_legal(state, model)
    shift = parse_transitions.Shift()
    close_transition = parse_transitions.CloseConstituent()

    state = open_transition.apply(state, model)
    for _ in range(3):
        state = shift.apply(state, model)
    state = close_transition.apply(state, model)

    root = model.get_top_constituent(state.constituents)
    assert root.label == 'ROOT'
    assert len(root.children) == 1

    sentence_node = root.children[0]
    assert sentence_node.label == 'S'
    assert len(sentence_node.children) == 3
    for child, word in zip(sentence_node.children, ('Unban', 'Mox', 'Opal')):
        assert child.children[0].label == word
def test_in_order_open(model=None):
    """
    Walk through an IN_ORDER parse of "Unban Mox Opal", checking which
    transitions are legal at each step

    The sequence builds (VP Unban), opens S over it, builds (NP Mox Opal)
    inside the S, closes the NP and the S, and finally opens ROOT.
    Along the way it checks that ROOT cannot be opened in the middle of
    the parse, even at a point where an ordinary open is legal.
    The state after the final ROOT open is not checked.

    model: the model to test with; a SimpleModel using the IN_ORDER
      transition scheme is built if None
    """
    if model is None:
        model = SimpleModel(TransitionScheme.IN_ORDER)
    state = build_initial_state(model)[0]

    # IN_ORDER starts with a single shift; a second shift before any open is rejected
    shift = parse_transitions.Shift()
    assert shift.is_legal(state, model)
    state = shift.apply(state, model)
    assert not shift.is_legal(state, model)

    # open VP over the first word; a second open right after it is rejected
    open_vp = parse_transitions.OpenConstituent("VP")
    assert open_vp.is_legal(state, model)
    state = open_vp.apply(state, model)
    assert not open_vp.is_legal(state, model)

    # close (VP Unban)
    close_trans = parse_transitions.CloseConstituent()
    assert close_trans.is_legal(state, model)
    state = close_trans.apply(state, model)

    # open S with the VP as its first child; again no open directly after an open
    open_s = parse_transitions.OpenConstituent("S")
    assert open_s.is_legal(state, model)
    state = open_s.apply(state, model)
    assert not open_vp.is_legal(state, model)

    # check that root transitions won't happen in the middle of a parse
    open_root = parse_transitions.OpenConstituent("ROOT")
    assert not open_root.is_legal(state, model)

    # build (NP (NNP Mox) (NNP Opal))
    open_np = parse_transitions.OpenConstituent("NP")
    assert shift.is_legal(state, model)
    state = shift.apply(state, model)
    assert open_np.is_legal(state, model)
    # make sure root can't happen in places where an arbitrary open is legal
    assert not open_root.is_legal(state, model)
    state = open_np.apply(state, model)
    assert shift.is_legal(state, model)
    state = shift.apply(state, model)

    # the first close finishes the NP, the second finishes the S
    assert close_trans.is_legal(state, model)
    state = close_trans.apply(state, model)
    assert close_trans.is_legal(state, model)
    state = close_trans.apply(state, model)

    # with everything closed, ROOT may now be opened
    assert open_root.is_legal(state, model)
    state = open_root.apply(state, model)
def test_too_many_unaries_close():
"""
This tests rejecting Close at the start of a sequence after too many unary transitions
The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence
"""
model = SimpleModel(TransitionScheme.IN_ORDER)
state = build_initial_state(model)[0]
shift = parse_transitions.Shift()
assert shift.is_legal(state, model)
state = shift.apply(state, model)
open_np = parse_transitions.OpenConstituent("NP")
close_trans = parse_transitions.CloseConstituent()
for _ in range(UNARY_LIMIT):
assert open_np.is_legal(state, model)
state = open_np.apply(state, model)
assert close_trans.is_legal(state, model)
state = close_trans.apply(state, model)
assert open_np.is_legal(state, model)
state = open_np.apply(state, model)
assert not close_trans.is_legal(state, model)
def test_too_many_unaries_open():
    """
    Check that Open is refused mid-sequence after too many unary transitions.

    Under IN_ORDER, an Open followed directly by a Close forms a unary.  The
    sequence shifts a word, opens an NP, and shifts a second word.  It then
    repeats Open/Close UNARY_LIMIT times.  After that, another Open(NP) must
    be illegal.
    """
    model = SimpleModel(TransitionScheme.IN_ORDER)
    state = build_initial_state(model)[0]

    def take(current, transition):
        # every transition applied through this helper must be legal
        assert transition.is_legal(current, model)
        return transition.apply(current, model)

    shift = parse_transitions.Shift()
    state = take(state, shift)

    open_np = parse_transitions.OpenConstituent("NP")
    close_trans = parse_transitions.CloseConstituent()
    state = take(state, open_np)
    # directly after this Open, a second Open must be rejected
    assert not open_np.is_legal(state, model)
    state = take(state, shift)

    for _ in range(UNARY_LIMIT):
        state = take(state, open_np)
        state = take(state, close_trans)

    assert not open_np.is_legal(state, model)
def test_close(model=None):
    """
    Build a whole (VP word (NP Mox Opal)) subtree and verify its shape.

    The transitions are Open(VP), Shift, Open(NP), Shift, Shift, Close, Close.
    num_opens is checked along the way.  The finished tree is checked for its
    labels, its children, and the words Mox and Opal.  The recorded
    transition history is checked as well.

    model defaults to a SimpleModel when not supplied.
    """
    if model is None:
        model = SimpleModel()
    state = build_initial_state(model)[0]

    open_vp = parse_transitions.OpenConstituent("VP")
    assert open_vp.is_legal(state, model)
    state = open_vp.apply(state, model)
    assert state.num_opens == 1

    shift = parse_transitions.Shift()
    assert shift.is_legal(state, model)
    state = shift.apply(state, model)

    open_np = parse_transitions.OpenConstituent("NP")
    assert open_np.is_legal(state, model)
    state = open_np.apply(state, model)
    assert state.num_opens == 2

    for _ in range(2):
        assert shift.is_legal(state, model)
        state = shift.apply(state, model)
    # no further Shift is legal at this point
    assert not shift.is_legal(state, model)
    assert state.num_opens == 2

    # now should have "mox", "opal" on the constituents
    close = parse_transitions.CloseConstituent()
    for opens_left in (1, 0):
        assert close.is_legal(state, model)
        state = close.apply(state, model)
        assert state.num_opens == opens_left
    assert not close.is_legal(state, model)

    vp_node = model.get_top_constituent(state.constituents)
    assert vp_node.label == 'VP'
    assert len(vp_node.children) == 2
    np_node = vp_node.children[1]
    assert np_node.label == 'NP'
    assert len(np_node.children) == 2
    assert np_node.children[0].is_preterminal()
    assert np_node.children[1].is_preterminal()
    assert np_node.children[0].children[0].label == 'Mox'
    assert np_node.children[1].children[0].label == 'Opal'

    # extra one for None at the start of the TreeStack
    assert len(state.constituents) == 2
    history = [open_vp, shift, open_np, shift, shift, close, close]
    assert state.all_transitions(model) == history
def test_in_order_compound_finalize(model=None):
    """
    Check that Finalize is only legal at the end of a sequence.

    Under IN_ORDER_COMPOUND, the sequence Shift, Open(NP), Shift, Shift, Close
    builds a single NP.  Finalize("ROOT") must be illegal before each of the
    first four steps.  It must be legal once everything is closed, and
    illegal again after it has been applied.  The final top constituent is
    labeled ROOT.

    model defaults to an IN_ORDER_COMPOUND SimpleModel when not supplied.
    """
    if model is None:
        model = SimpleModel(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
    state = build_initial_state(model)[0]
    finalize = parse_transitions.Finalize("ROOT")
    shift = parse_transitions.Shift()

    def step(current, transition):
        # the transition must be legal, and it must not yet be time to finalize
        assert transition.is_legal(current, model)
        assert not finalize.is_legal(current, model)
        return transition.apply(current, model)

    state = step(state, shift)
    open_transition = parse_transitions.OpenConstituent("NP")
    state = step(state, open_transition)
    assert state.num_opens == 1
    state = step(state, shift)
    state = step(state, shift)

    close_transition = parse_transitions.CloseConstituent()
    assert close_transition.is_legal(state, model)
    state = close_transition.apply(state, model)
    assert state.num_opens == 0
    assert not close_transition.is_legal(state, model)

    # only now, with everything closed, may Finalize be applied, and only once
    assert finalize.is_legal(state, model)
    state = finalize.apply(state, model)
    assert not finalize.is_legal(state, model)

    top = model.get_top_constituent(state.constituents)
    assert top.label == 'ROOT'
def test_hashes():
    """
    Check that equal transitions hash equally and collapse inside a set.

    A fresh Shift, CompoundUnary, OpenConstituent or CloseConstituent built
    with the same arguments must be found in a set holding an earlier
    instance.  Transitions built with different arguments must be distinct,
    e.g. CompoundUnary("asdf") vs CompoundUnary("asdf", "zzzz"), or
    CompoundUnary("asdf") vs OpenConstituent("asdf").
    """
    seen = set()

    shift = parse_transitions.Shift()
    assert shift not in seen
    seen.add(shift)
    assert shift in seen
    # a brand new Shift compares equal to the one already stored
    shift = parse_transitions.Shift()
    assert shift in seen
    for _ in range(5):
        seen.add(shift)
    assert len(seen) == 1

    unary = parse_transitions.CompoundUnary("asdf")
    assert unary not in seen
    seen.add(unary)
    assert unary in seen

    # a different label chain is a different transition
    unary = parse_transitions.CompoundUnary("asdf", "zzzz")
    assert unary not in seen
    for _ in range(3):
        seen.add(unary)
    unary = parse_transitions.CompoundUnary("asdf", "zzzz")
    assert unary in seen

    oc = parse_transitions.OpenConstituent("asdf")
    assert oc not in seen
    seen.add(oc)
    assert oc in seen
    for _ in range(2):
        seen.add(oc)
    assert len(seen) == 4
    assert parse_transitions.OpenConstituent("asdf") in seen

    cc = parse_transitions.CloseConstituent()
    assert cc not in seen
    for _ in range(3):
        seen.add(cc)
    assert cc in seen
    cc = parse_transitions.CloseConstituent()
    assert cc in seen
    assert len(seen) == 5
def test_sort():
    """
    Check the canonical ordering of transitions.

    Sorting an unordered collection (a set) must put Shift first, then
    Close, then the CompoundUnary transitions, then the OpenConstituent
    transitions in label order.
    """
    canonical = [
        parse_transitions.Shift(),
        parse_transitions.CloseConstituent(),
        parse_transitions.CompoundUnary("NP"),
        parse_transitions.CompoundUnary("NP", "VP"),
        parse_transitions.OpenConstituent("mox"),
        parse_transitions.OpenConstituent("opal"),
        parse_transitions.OpenConstituent("unban"),
    ]
    assert sorted(set(canonical)) == canonical
def test_check_transitions():
    """
    Test that check_transitions passes or fails a couple simple, small test cases

    An identical transition set passes.  A compound Open built only from
    labels present in `transitions` passes.  A compound Open that uses a
    label missing from `transitions` ("ZP") must raise RuntimeError.
    """
    transitions = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}
    other = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}
    parse_transitions.check_transitions(transitions, other, "test")

    # This will get a pass because it is a unary made out of existing unaries
    other = {Shift(), CloseConstituent(), OpenConstituent("NP", "VP")}
    parse_transitions.check_transitions(transitions, other, "test")

    # This should fail: ZP does not appear in `transitions`.
    # The set is built outside pytest.raises so that only check_transitions
    # itself can produce the expected RuntimeError.
    other = {Shift(), CloseConstituent(), OpenConstituent("NP", "ZP")}
    with pytest.raises(RuntimeError):
        parse_transitions.check_transitions(transitions, other, "test")
================================================
FILE: stanza/tests/constituency/test_parse_tree.py
================================================
import pytest
from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency import tree_reader
from stanza.tests import *
# Tag every test in this module with the pipeline and travis markers
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def test_leaf_preterminal():
    """
    Check is_leaf / is_preterminal and str() on a hand-built three-level tree.

    foo is a bare leaf.  bar wraps foo, passed directly rather than in a
    list, and is therefore a preterminal.  baz wraps [bar] and is neither a
    leaf nor a preterminal.
    """
    leaf = Tree(label="foo")
    assert leaf.is_leaf()
    assert not leaf.is_preterminal()
    assert len(leaf.children) == 0
    assert str(leaf) == 'foo'

    preterminal = Tree(label="bar", children=leaf)
    assert not preterminal.is_leaf()
    assert preterminal.is_preterminal()
    assert len(preterminal.children) == 1
    assert str(preterminal) == "(bar foo)"

    internal = Tree(label="baz", children=[preterminal])
    assert not internal.is_leaf()
    assert not internal.is_preterminal()
    assert len(internal.children) == 1
    assert str(internal) == "(baz (bar foo))"
def test_yield_preterminals():
    """
    yield_preterminals should produce every (TAG word) node, left to right.
    """
    parsed = tree_reader.read_trees("((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))")
    found = [node for node in parsed[0].yield_preterminals()]
    assert len(found) == 3
    assert str(found) == "[(VB Unban), (NNP Mox), (NNP Opal)]"
def test_depth():
    """
    A bare (foo) node has depth 0; the ROOT/S/VP/VB/Unban tree has depth 4.
    """
    parsed = tree_reader.read_trees("(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))")
    assert parsed[0].depth() == 0
    assert parsed[1].depth() == 4
def test_unique_labels():
    """
    Test getting the unique labels from a tree

    The same question is read twice, yet each constituent label is reported
    once, in sorted order.  The list includes the ROOT label the reader gives
    to the unlabeled outer bracket.  Assumes tree_reader works, which should
    be fine since it is tested elsewhere.
    """
    question = "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(" ".join([question] * 2))
    found = Tree.get_unique_constituent_labels(trees)
    assert found == ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP']
def test_unique_tags():
    """
    get_unique_tags should return the preterminal tags of a tree, sorted and
    each listed once.
    """
    parsed = tree_reader.read_trees("((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))")
    assert Tree.get_unique_tags(parsed) == ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']
def test_unique_words():
    """
    get_unique_words should return the leaf words of a tree, sorted and each
    listed once.
    """
    parsed = tree_reader.read_trees("((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))")
    assert Tree.get_unique_words(parsed) == ['?', 'Who', 'in', 'seat', 'sits', 'this']
def test_rare_words():
    """
    Test getting the rare words from a set of trees

    "this", "seat" and "?" occur in both trees, while "Who", "sits" and "in"
    occur only in the first.  Called with 0.5, get_rare_words should return
    exactly the latter group, sorted.
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
    trees = tree_reader.read_trees(text)
    words = Tree.get_rare_words(trees, 0.5)
    expected = ['Who', 'in', 'sits']
    assert words == expected
def test_common_words():
    """
    Test getting the most common words from a set of trees

    "this", "seat" and "?" each occur twice across the two trees, and every
    other word occurs once.  Asking for the 3 most common words should
    therefore return exactly those three, sorted.
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
    trees = tree_reader.read_trees(text)
    words = Tree.get_common_words(trees, 3)
    expected = ['?', 'seat', 'this']
    assert words == expected
def test_root_labels():
    """
    Check get_root_labels on unlabeled and explicitly labeled top brackets.

    An unlabeled outer bracket is reported as ROOT, whether there is one tree
    or three back-to-back copies.  When the top brackets carry their own
    labels, each distinct label is reported, sorted.
    """
    question = "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    assert ["ROOT"] == Tree.get_root_labels(tree_reader.read_trees(question))
    assert ["ROOT"] == Tree.get_root_labels(tree_reader.read_trees(question * 3))
    assert ["BAR", "FOO"] == Tree.get_root_labels(tree_reader.read_trees("(FOO) (BAR)"))
def test_prune_none():
    """
    prune_none should drop -NONE- preterminals and any parents they leave empty.

    Three cases are covered:
    * a single dead node;
    * a dead node whose parent (WHNP) becomes empty and is removed;
    * a constituent (NP) whose children are all dead.
    """
    cases = [
        ("((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (-NONE- in) (NP (DT this) (NN seat))))) (. ?)))",
         "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (NP (DT this) (NN seat))))) (. ?)))"),
        ("((SBARQ (WHNP (-NONE- Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))",
         "(ROOT (SBARQ (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"),
        ("((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (-NONE- this) (-NONE- seat))))) (. ?)))",
         "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"),
    ]
    for source, pruned_text in cases:
        parsed = tree_reader.read_trees(source)
        assert len(parsed) == 1
        pruned = parsed[0].prune_none()
        assert pruned_text == str(pruned)
def test_simplify_labels():
    """
    simplify_labels should strip decorations such as -FOO, #ASDF and =1 from
    constituent labels.  A bare "-" tag and word must be left untouched.
    """
    source = "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
    target = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
    simplified = [tree.simplify_labels() for tree in tree_reader.read_trees(source)]
    assert len(simplified) == 1
    assert target == str(simplified[0])
def test_remap_constituent_labels():
    """
    remap_constituent_labels should rename only the constituents named in the map.
    """
    label_map = {"SBARQ": "FOO"}
    source = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    target = "(ROOT (FOO (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    remapped = [tree.remap_constituent_labels(label_map) for tree in tree_reader.read_trees(source)]
    assert len(remapped) == 1
    assert target == str(remapped[0])
def test_remap_constituent_words():
    """
    remap_words should replace each leaf word found in the map.  Words not in
    the map (here "?") are kept as they are.
    """
    word_map = {"Who": "unban", "sits": "mox", "in": "opal"}
    source = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    target = "(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"
    remapped = [tree.remap_words(word_map) for tree in tree_reader.read_trees(source)]
    assert len(remapped) == 1
    assert target == str(remapped[0])
def test_replace_words():
    """
    replace_words should substitute a new list of words for the leaves in
    order, keeping the tree structure and tags.
    """
    source = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    target = "(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"
    parsed = tree_reader.read_trees(source)
    assert len(parsed) == 1
    replaced = parsed[0].replace_words(["unban", "mox", "opal", "?"])
    assert target == str(replaced)
def test_compound_constituents():
    """
    get_compound_constituents should report, as tuples of labels, the stacks
    of constituents linked by unary branches, e.g. ('ROOT', 'SBARQ') and
    ('SQ', 'VP').  Results are gathered over all the given trees and sorted.
    """
    # TODO: add skinny trees like this to the various transition tests
    skinny = "((VP (VB Unban)))"
    question = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"

    found = Tree.get_compound_constituents(tree_reader.read_trees(skinny))
    assert found == [('ROOT', 'VP')]

    found = Tree.get_compound_constituents(tree_reader.read_trees(question))
    assert found == [('PP',), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP',)]

    found = Tree.get_compound_constituents(tree_reader.read_trees(skinny + " " + question))
    assert found == [('PP',), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP',)]
def test_equals():
    """
    Check one tree from the actual dataset for ==
    when built with compound Open, this didn't work because of a silly bug

    A tree equals itself.  Two independently parsed copies of the same text
    are distinct objects that still compare equal.
    """
    text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"
    first_read = tree_reader.read_trees(text)
    assert len(first_read) == 1
    original = first_read[0]
    assert original == original

    duplicate = tree_reader.read_trees(text)[0]
    assert original is not duplicate
    assert original == duplicate
# This tree was causing the model to barf on CTB7,
# although it turns out the problem was just the
# depth of the unary, not the list.
# test_count_unaries expects count_unary_depth() == 5 for it; presumably that
# comes from the (NP (CP (CP (IP (VP (VV 违法)))))) chain of single-child nodes.
CHINESE_LONG_LIST_TREE = """
(ROOT
(IP
(NP (NNP 证券法))
(VP
(PP
(IN 对)
(NP
(DNP
(NP
(NP (NNP 中国))
(NP
(NN 证券)
(NN 市场)))
(DEC 的))
(NP (NN 运作))))
(, ,)
(PP
(PP
(IN 从)
(NP
(NP (NN 股票))
(NP (VV 发行) (EC 、) (VV 交易))))
(, ,)
(PP
(VV 到)
(NP
(NP (NN 上市) (NN 公司) (NN 收购))
(EC 、)
(NP (NN 证券) (NN 交易所))
(EC 、)
(NP (NN 证券) (NN 公司))
(EC 、)
(NP (NN 登记) (NN 结算) (NN 机构))
(EC 、)
(NP (NN 交易) (NN 服务) (NN 机构))
(EC 、)
(NP (NN 证券业) (NN 协会))
(EC 、)
(NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))
(CC 和)
(NP
(DNP
(NP (CP (CP (IP (VP (VV 违法))))))
(DEC 的))
(NP (NN 法律) (NN 责任))))))
(ADVP (RB 都))
(VP
(VV 作)
(AS 了)
(NP
(ADJP (JJ 详细))
(NP (NN 规定)))))
(. 。)))
"""
# Variant built from the long coordinated list above: the list items sit
# under a many-child node labeled ASDF, which in turn sits under an
# NP/CP/CP/IP/VP unary chain.  test_count_unaries expects
# count_unary_depth() == 5 here as well.
WEIRD_UNARY = """
(DNP
(NP (CP (CP (IP (VP (ASDF
(NP (NN 上市) (NN 公司) (NN 收购))
(EC 、)
(NP (NN 证券) (NN 交易所))
(EC 、)
(NP (NN 证券) (NN 公司))
(EC 、)
(NP (NN 登记) (NN 结算) (NN 机构))
(EC 、)
(NP (NN 交易) (NN 服务) (NN 机构))
(EC 、)
(NP (NN 证券业) (NN 协会))
(EC 、)
(NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))))))))
(DEC 的))
"""
def test_count_unaries():
    """
    CHINESE_LONG_LIST_TREE and WEIRD_UNARY should each parse to one tree
    with a unary depth of 5.
    """
    for text in (CHINESE_LONG_LIST_TREE, WEIRD_UNARY):
        parsed = tree_reader.read_trees(text)
        assert len(parsed) == 1
        assert parsed[0].count_unary_depth() == 5
def test_str_bracket_labels():
    """
    The "L" format spec writes each constituent as "(_LABEL ... )_LABEL".
    """
    parsed = tree_reader.read_trees("((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))")
    assert len(parsed) == 1
    rendered = format(parsed[0], "L")
    assert rendered == "(_ROOT (_S (_VP (_VB Unban )_VB )_VP (_NP (_NNP Mox )_NNP (_NNP Opal )_NNP )_NP )_S )_ROOT"
def test_all_leaves_are_preterminals():
    """
    all_leaves_are_preterminals holds for a fully tagged tree.  It fails when
    the tree contains a leaf, such as (Mox), that is not the child of a
    preterminal.
    """
    cases = (
        ("((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))", True),
        ("((S (VP (VB Unban)) (NP (Mox) (NNP Opal))))", False),
    )
    for text, should_hold in cases:
        parsed = tree_reader.read_trees(text)
        assert len(parsed) == 1
        assert bool(parsed[0].all_leaves_are_preterminals()) is should_hold
def test_latex():
    """
    Test the latex format for trees

    The "T" format spec renders a LaTeX "\\Tree [.LABEL ... ]" string.  As
    the expected output shows, ROOT and the preterminal tags are omitted.
    """
    source = "(ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ nice) (NNS antennae)))))"
    tree = tree_reader.read_trees(source)[0]
    assert format(tree, "T") == "\\Tree [.S [.NP Jennifer ] [.VP has [.NP nice antennae ] ] ]"
def test_pretty_print():
    """
    Pretty print a couple trees - newlines & indentation

    "{:P}" should produce the multi-line pretty-printed form of a tree,
    terminated by a newline.  "{:O}" should produce the one-line form; for
    these two trees, joining the one-line forms with a space must reproduce
    the input text exactly.
    """
    text = "(ROOT (S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal)))) (ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric)))))))"
    trees = tree_reader.read_trees(text)
    assert len(trees) == 2
    # expected pretty-printed layout of the first (short) tree
    expected = """(ROOT
(S
(VP (VB Unban))
(NP (NNP Mox) (NNP Opal))))
"""
    assert "{:P}".format(trees[0]) == expected
    # expected pretty-printed layout of the second, longer tree
    expected = """(ROOT
(S
(NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission))
(VP
(VBD authorized)
(NP
(NP
(DT an)
(ADJP (CD 11.5))
(NN %)
(NN rate)
(NN increase))
(PP
(IN at)
(NP (NNP Tucson) (NNP Electric)))))))
"""
    assert "{:P}".format(trees[1]) == expected
    # the one-line format round-trips back to the original input text
    assert text == "{:O} {:O}".format(*trees)
def test_reverse():
    """
    reverse() should mirror the tree so that every node's children appear in
    the opposite order.
    """
    parsed = tree_reader.read_trees("(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB lick) (NP (NP (NNP Jennifer) (POS 's)) (NNS antennae))))))))")
    assert len(parsed) == 1
    mirrored = parsed[0].reverse()
    assert str(mirrored) == "(ROOT (S (VP (S (VP (VP (NP (NNS antennae) (NP (POS 's) (NNP Jennifer))) (VB lick)) (TO to))) (VBP want)) (NP (PRP I))))"
================================================
FILE: stanza/tests/constituency/test_positional_encoding.py
================================================
import pytest
import torch
from stanza import Pipeline
from stanza.models.constituency.positional_encoding import SinusoidalEncoding, AddSinusoidalEncoding
from stanza.tests import *
# Tag every test in this module with the pipeline and travis markers
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def test_positional_encoding():
    """
    Encoding a single position yields one row of width model_dim.
    """
    encoder = SinusoidalEncoding(model_dim=10, max_len=6)
    result = encoder(torch.tensor([5]))
    assert result.shape == (1, 10)
    # TODO: check the values
def test_resize():
    """
    Asking for position 5 from an encoding built with max_len=3 must still
    return a (1, model_dim) result.  Per the test name, the table is
    presumably resized internally.
    """
    encoder = SinusoidalEncoding(model_dim=10, max_len=3)
    result = encoder(torch.tensor([5]))
    assert result.shape == (1, 10)
def test_arange():
    """
    Encoding positions 0..3 with an initial max_len of 2 returns four rows of
    width model_dim and leaves max_len() at 4.
    """
    encoder = SinusoidalEncoding(model_dim=10, max_len=2)
    result = encoder(torch.arange(4))
    assert result.shape == (4, 10)
    assert encoder.max_len() == 4
def test_add():
    """
    AddSinusoidalEncoding should add the same positional offset to any input.

    The offset is measured by encoding a zero tensor.  It must then match
    (output - input) for a random single-item batch, and for each item of a
    random two-item batch.
    """
    encoder = AddSinusoidalEncoding(d_model=10, max_len=4)
    offset = encoder(torch.zeros(1, 4, 10))

    single = torch.randn(1, 4, 10)
    assert torch.allclose(encoder(single) - single, offset, atol=1e-07)

    batch = torch.randn(2, 4, 10)
    encoded = encoder(batch)
    for item in range(2):
        assert torch.allclose(encoded[item] - batch[item], offset, atol=1e-07)
================================================
FILE: stanza/tests/constituency/test_selftrain_vi_quad.py
================================================
"""
Test some of the methods in the vi_quad dataset
Uses a small section of the dataset as a test
"""
import pytest
from stanza.utils.datasets.constituency import selftrain_vi_quad
from stanza.tests import *
# Tag every test in this module with the pipeline and travis markers
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
SAMPLE_TEXT = """
{"version": "1.1", "data": [{"title": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng", "paragraphs": [{"qas": [{"question": "T\u00ean g\u1ecdi n\u00e0o \u0111\u01b0\u1ee3c Ph\u1ea1m V\u0103n \u0110\u1ed3ng s\u1eed d\u1ee5ng khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m?", "answers": [{"answer_start": 507, "text": "L\u00e2m B\u00e1 Ki\u1ec7t"}], "id": "uit_01__05272_0_1"}, {"question": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec trong b\u1ed9 m\u00e1y Nh\u00e0 n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam?", "answers": [{"answer_start": 60, "text": "Th\u1ee7 t\u01b0\u1edbng"}], "id": "uit_01__05272_0_2"}, {"question": "Giai \u0111o\u1ea1n n\u0103m 1955-1976, Ph\u1ea1m V\u0103n \u0110\u1ed3ng n\u1eafm gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec?", "answers": [{"answer_start": 245, "text": "Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a"}], "id": "uit_01__05272_0_3"}], "context": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng (1 th\u00e1ng 3 n\u0103m 1906 \u2013 29 th\u00e1ng 4 n\u0103m 2000) l\u00e0 Th\u1ee7 t\u01b0\u1edbng \u0111\u1ea7u ti\u00ean c\u1ee7a n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam t\u1eeb n\u0103m 1976 (t\u1eeb n\u0103m 1981 g\u1ecdi l\u00e0 Ch\u1ee7 t\u1ecbch H\u1ed9i \u0111\u1ed3ng B\u1ed9 tr\u01b0\u1edfng) cho \u0111\u1ebfn khi ngh\u1ec9 h\u01b0u n\u0103m 1987. Tr\u01b0\u1edbc \u0111\u00f3 \u00f4ng t\u1eebng gi\u1eef ch\u1ee9c v\u1ee5 Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a t\u1eeb n\u0103m 1955 \u0111\u1ebfn n\u0103m 1976. \u00d4ng l\u00e0 v\u1ecb Th\u1ee7 t\u01b0\u1edbng Vi\u1ec7t Nam t\u1ea1i v\u1ecb l\u00e2u nh\u1ea5t (1955\u20131987). \u00d4ng l\u00e0 h\u1ecdc tr\u00f2, c\u1ed9ng s\u1ef1 c\u1ee7a Ch\u1ee7 t\u1ecbch H\u1ed3 Ch\u00ed Minh. 
\u00d4ng c\u00f3 t\u00ean g\u1ecdi th\u00e2n m\u1eadt l\u00e0 T\u00f4, \u0111\u00e2y t\u1eebng l\u00e0 b\u00ed danh c\u1ee7a \u00f4ng. \u00d4ng c\u00f2n c\u00f3 t\u00ean g\u1ecdi l\u00e0 L\u00e2m B\u00e1 Ki\u1ec7t khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m (Ch\u1ee7 nhi\u1ec7m l\u00e0 H\u1ed3 H\u1ecdc L\u00e3m)."}, {"qas": [{"question": "S\u1ef1 ki\u1ec7n quan tr\u1ecdng n\u00e0o \u0111\u00e3 di\u1ec5n ra v\u00e0o ng\u00e0y 20/7/1954?", "answers": [{"answer_start": 364, "text": "b\u1ea3n Hi\u1ec7p \u0111\u1ecbnh \u0111\u00ecnh ch\u1ec9 chi\u1ebfn s\u1ef1 \u1edf Vi\u1ec7t Nam, Campuchia v\u00e0 L\u00e0o \u0111\u00e3 \u0111\u01b0\u1ee3c k\u00fd k\u1ebft th\u1eeba nh\u1eadn t\u00f4n tr\u1ecdng \u0111\u1ed9c l\u1eadp, ch\u1ee7 quy\u1ec1n, c\u1ee7a n\u01b0\u1edbc Vi\u1ec7t Nam, L\u00e0o v\u00e0 Campuchia"}], "id": "uit_01__05272_1_1"}, {"question": "Ch\u1ee9c v\u1ee5 m\u00e0 Ph\u1ea1m V\u0103n \u0110\u1ed3ng \u0111\u1ea3m nhi\u1ec7m t\u1ea1i H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng?", "answers": [{"answer_start": 33, "text": "Tr\u01b0\u1edfng ph\u00e1i \u0111o\u00e0n Ch\u00ednh ph\u1ee7"}], "id": "uit_01__05272_1_2"}, {"question": "H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng c\u00f3 t\u00ednh ch\u1ea5t nh\u01b0 th\u1ebf n\u00e0o?", "answers": [{"answer_start": 262, "text": "r\u1ea5t c\u0103ng th\u1eb3ng v\u00e0 ph\u1ee9c t\u1ea1p"}], "id": "uit_01__05272_1_3"}]}]}]}
"""
# Every question in SAMPLE_TEXT, in the order they appear there; this is what parse_quad should return
EXPECTED = ['Tên gọi nào được Phạm Văn Đồng sử dụng khi làm Phó chủ nhiệm cơ quan Biện sự xứ tại Quế Lâm?', 'Phạm Văn Đồng giữ chức vụ gì trong bộ máy Nhà nước Cộng hòa Xã hội chủ nghĩa Việt Nam?', 'Giai đoạn năm 1955-1976, Phạm Văn Đồng nắm giữ chức vụ gì?', 'Sự kiện quan trọng nào đã diễn ra vào ngày 20/7/1954?', 'Chức vụ mà Phạm Văn Đồng đảm nhiệm tại Hội nghị Genève về Đông Dương?', 'Hội nghị Genève về Đông Dương có tính chất như thế nào?']
def test_read_file():
    """
    parse_quad should extract every question, in document order, from the
    SQuAD-style JSON sample.
    """
    questions = selftrain_vi_quad.parse_quad(SAMPLE_TEXT)
    assert questions == EXPECTED
================================================
FILE: stanza/tests/constituency/test_text_processing.py
================================================
"""
Run through the various text processing methods for using the parser on text files / directories
Uses a simple tree where the parser should always get it right, but things could potentially go wrong
"""
import glob
import os
import pytest
from stanza import Pipeline
from stanza.models.constituency import text_processing
from stanza.models.constituency import tree_reader
from stanza.tests import TEST_MODELS_DIR
# Tag every test in this module with the pipeline and travis markers
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
@pytest.fixture(scope="module")
def pipeline():
    """
    English tokenize/pos/constituency pipeline over pretokenized input.

    Built once per test module from TEST_MODELS_DIR and shared by that
    module's tests.
    """
    nlp = Pipeline(dir=TEST_MODELS_DIR,
                   lang="en",
                   processors="tokenize, pos, constituency",
                   tokenize_pretokenized=True)
    return nlp
def test_read_tokenized_file(tmp_path):
    """
    read_tokenized_file splits each line on whitespace and turns "_" inside a
    token back into a space.  For these plain lines it reports an id of None
    per line.
    """
    input_path = tmp_path / "test_input.txt"
    # test that the underscore token comes back with spaces
    input_path.write_text("This is a_small test\nLine two\n")

    sentences, ids = text_processing.read_tokenized_file(str(input_path))
    assert sentences == [['This', 'is', 'a small', 'test'], ['Line', 'two']]
    assert ids == [None, None]
def test_parse_tokenized_sentences(pipeline):
    """
    Parsing one pretokenized sentence gives one result holding one scored
    tree.  That tree should be the obvious parse of "This is a test".
    """
    model = pipeline.processors["constituency"]._model
    settings = model.args
    results = text_processing.parse_tokenized_sentences(settings, model, [pipeline], [["This", "is", "a", "test"]])

    predictions = [result.predictions for result in results]
    assert len(predictions) == 1
    (only_prediction,) = predictions
    assert len(only_prediction) == 1
    assert format(only_prediction[0].tree) == "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))"
def test_parse_text(tmp_path, pipeline):
    """
    parse_text reads pretokenized sentences, one per line, from
    tokenized_file and writes their parses to predict_file.  read_treebank
    must be able to load the results back.
    """
    model = pipeline.processors["constituency"]._model
    settings = model.args

    input_path = tmp_path / "test_input.txt"
    input_path.write_text("This is a test\nThis is another test\n")
    output_file = str(tmp_path / "test_output.txt")

    text_processing.parse_text(settings, model, [pipeline], tokenized_file=str(input_path), predict_file=output_file)

    parsed = [format(tree) for tree in tree_reader.read_treebank(output_file)]
    assert parsed == ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                      "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
def test_parse_dir(tmp_path, pipeline):
    """
    parse_dir should parse every file in an input directory and write one
    output file per input file into the output directory.

    Each of the two input files holds a single sentence, so each output file
    should contain exactly one tree.
    """
    con_processor = pipeline.processors["constituency"]
    model = con_processor._model
    args = model.args

    raw_dir = str(tmp_path / "input")
    os.makedirs(raw_dir)
    raw_f1 = str(tmp_path / "input" / "f1.txt")
    raw_f2 = str(tmp_path / "input" / "f2.txt")
    output_dir = str(tmp_path / "output")
    with open(raw_f1, "w") as fout:
        fout.write("This is a test")
    with open(raw_f2, "w") as fout:
        fout.write("This is another test")

    text_processing.parse_dir(args, model, [pipeline], raw_dir, output_dir)

    # sorting is expected to pair f1's output with the first tree and f2's
    # with the second (output names presumably follow the input names)
    output_files = sorted(glob.glob(os.path.join(output_dir, "*")))
    expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                      "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
    # zip() stops at the shorter list; without this check, a parse_dir that
    # wrote too few (or no) files would let the loop below check nothing
    assert len(output_files) == len(expected_trees)
    for output_file, expected_tree in zip(output_files, expected_trees):
        trees = tree_reader.read_treebank(output_file)
        assert len(trees) == 1
        assert "{}".format(trees[0]) == expected_tree
def test_load_model_parse_text(tmp_path, pipeline):
    """
    load_model_parse_text should load the parser from model_path and parse
    args['tokenized_file'] into args['predict_file'].

    This test must not reuse the name test_parse_text: a second def with that
    name replaces the first at import time, so pytest would only run one of
    the two tests.
    """
    con_processor = pipeline.processors["constituency"]
    model = con_processor._model
    # copy, so the keys set below do not modify the shared model's args
    args = dict(model.args)
    model_path = con_processor._config['model_path']

    raw_file = str(tmp_path / "test_input.txt")
    with open(raw_file, "w") as fout:
        fout.write("This is a test\nThis is another test\n")
    output_file = str(tmp_path / "test_output.txt")
    args['tokenized_file'] = raw_file
    args['predict_file'] = output_file

    text_processing.load_model_parse_text(args, model_path, [pipeline])

    trees = tree_reader.read_treebank(output_file)
    trees = ["{}".format(x) for x in trees]
    expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                      "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
    assert trees == expected_trees
================================================
FILE: stanza/tests/constituency/test_top_down_oracle.py
================================================
import pytest
from stanza.models.constituency.base_model import SimpleModel
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, TransitionScheme
from stanza.models.constituency.top_down_oracle import *
from stanza.models.constituency.transition_sequence import build_sequence
from stanza.models.constituency.tree_reader import read_trees
from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
# Tag every test in this module with the pipeline and travis markers
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# Small example tree: (S (NP Jennifer Sh'reyan) (VP has (NP nice antennae))) under an unlabeled root
OPEN_SHIFT_EXAMPLE_TREE = """
( (S
(NP (NNP Jennifer) (NNP Sh\'reyan))
(VP (VBZ has)
(NP (RB nice) (NNS antennae)))))
"""
# Tree that exposed an error in the Open/Shift repair: its S begins with three nested NP constituents
OPEN_SHIFT_PROBLEM_TREE = """
(ROOT (S (NP (NP (NP (DT The) (`` ``) (JJ Thin) (NNP Man) ('' '') (NN series)) (PP (IN of) (NP (NNS movies)))) (, ,) (CONJP (RB as) (RB well) (IN as)) (NP (JJ many) (NNS others)) (, ,)) (VP (VBD based) (NP (PRP$ their) (JJ entire) (JJ comedic) (NN appeal)) (PP (IN on) (NP (NP (DT the) (NN star) (NNS detectives) (POS ')) (JJ witty) (NNS quips) (CC and) (NNS puns))) (SBAR (IN as) (S (NP (NP (JJ other) (NNS characters)) (PP (IN in) (NP (DT the) (NNS movies)))) (VP (VBD were) (VP (VBN murdered)))))) (. .)))
"""
# Root labels that get_single_repair passes to every repair function
ROOT_LABELS = ["ROOT"]
def get_single_repair(gold_sequence, wrong_transition, repair_fn, idx, *args, **kwargs):
    """
    Run repair_fn as if wrong_transition had been predicted at position idx.

    The following are passed to repair_fn positionally: the gold transition
    at idx, wrong_transition, the full gold sequence, idx, ROOT_LABELS, and
    two None placeholders.  Any extra args/kwargs follow.  Whatever repair_fn
    returns is returned unchanged.
    """
    gold_transition = gold_sequence[idx]
    return repair_fn(gold_transition, wrong_transition, gold_sequence, idx,
                     ROOT_LABELS, None, None, *args, **kwargs)
def build_state(model, tree, num_transitions):
    """
    Replay the first num_transitions gold TOP_DOWN transitions for tree.

    The gold transition sequence for tree is built, and a starting state is
    created from it with model.initial_state_from_gold_trees.  The first
    num_transitions transitions are then applied one at a time, each
    asserted legal first.

    Returns the resulting state.
    """
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    states = model.initial_state_from_gold_trees([tree], [transitions])
    for idx, t in enumerate(transitions[:num_transitions]):
        # the message names the gold sequence being replayed; referring to an
        # undefined name here would turn a failed assertion into a NameError
        assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, transitions)
        states = model.bulk_apply(states, [t])
    state = states[0]
    return state
def test_fix_open_shift():
    """
    fix_one_open_shift: when a Shift is predicted where the gold sequence has
    a single Open, the repaired sequence drops that Open and its matching
    Close.

    Checked at index 2 (the NP directly under S) and at index 8 (the NP
    inside the VP).
    """
    trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE)
    assert len(trees) == 1
    gold = build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN)

    gold_expected = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # the first NP's Open and Close are gone
    early_fix = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # the NP inside the VP loses its Open and Close
    late_fix = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert gold == gold_expected

    for position, repaired_expected in ((2, early_fix), (8, late_fix)):
        repaired = get_single_repair(gold, Shift(), fix_one_open_shift, position)
        assert repaired == repaired_expected
def test_fix_open_shift_observed_error():
    """
    Ran into an error on this tree, need to fix it
    The problem is the multiple Open in a row all need to be removed when a Shift happens

    A Shift is predicted at index 2, where the gold sequence opens three
    nested NPs in a row.  fix_one_open_shift gives up and returns None.
    fix_multiple_open_shift removes all three Opens together with the
    Closes that ended those NPs.
    """
    trees = read_trees(OPEN_SHIFT_PROBLEM_TREE)
    assert len(trees) == 1
    gold = build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN)

    assert get_single_repair(gold, Shift(), fix_one_open_shift, 2) is None
    repaired = get_single_repair(gold, Shift(), fix_multiple_open_shift, 2)

    # Relative to the gold sequence, the three Open(NP) after ROOT/S are removed.
    # So are the three Closes that ended those NPs: one after the six words of
    # the innermost NP, one after its PP, and one after the trailing comma.
    # Everything from the VP onward is unchanged.
    expected = [
        OpenConstituent('ROOT'), OpenConstituent('S'),
        # The `` Thin Man '' series
        Shift(), Shift(), Shift(), Shift(), Shift(), Shift(),
        # (PP of (NP movies))
        OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(),
        # , (CONJP as well as) (NP many others) ,
        Shift(), OpenConstituent('CONJP'), Shift(), Shift(), Shift(), CloseConstituent(),
        OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), Shift(),
        # (VP based (NP their entire comedic appeal) ...
        OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(),
        # ... (PP on (NP (NP the star detectives ') witty quips and puns)) ...
        OpenConstituent('PP'), Shift(), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(),
        CloseConstituent(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(),
        # ... (SBAR as (S ... were murdered)))
        OpenConstituent('SBAR'), Shift(), OpenConstituent('S'), OpenConstituent('NP'), OpenConstituent('NP'),
        Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'),
        Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(),
        OpenConstituent('VP'), Shift(), OpenConstituent('VP'), Shift(),
        CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(),
        # final period, then close S and ROOT
        Shift(), CloseConstituent(), CloseConstituent(),
    ]
    assert repaired == expected
def test_open_open_ambiguous_unary_fix():
    """
    Predict Open(VP) where the gold sequence has Open(NP) at index 2.

    The unary repair wraps only the NP in the spurious VP and closes it
    immediately after the NP closes.
    """
    parsed = read_trees(OPEN_SHIFT_EXAMPLE_TREE)
    assert len(parsed) == 1
    gold_tree = parsed[0]

    gold_sequence = build_sequence(gold_tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_repair = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert gold_sequence == expected_original

    repaired = get_single_repair(gold_sequence, OpenConstituent('VP'), fix_open_open_ambiguous_unary, 2)
    assert repaired == expected_repair
def test_open_open_ambiguous_later_fix():
    """
    Predict Open(VP) where the gold sequence has Open(NP) at index 2.

    The "later" repair keeps the spurious VP open over the rest of the S,
    so one extra CloseConstituent appears at the very end.
    """
    parsed = read_trees(OPEN_SHIFT_EXAMPLE_TREE)
    assert len(parsed) == 1
    gold_tree = parsed[0]

    gold_sequence = build_sequence(gold_tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_repair = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert gold_sequence == expected_original

    repaired = get_single_repair(gold_sequence, OpenConstituent('VP'), fix_open_open_ambiguous_later, 2)
    assert repaired == expected_repair
# Gold tree for the basic close/shift repair tests: the ADJP is followed
# by exactly one more word ("tool") inside the NP, so absorbing that word
# into the ADJP is the only possible repair
CLOSE_SHIFT_EXAMPLE_TREE = """
( (NP (DT a)
(ADJP (NN stock) (HYPH -) (VBG picking))
(NN tool)))
"""
# not intended to be a correct tree
# Same as CLOSE_SHIFT_EXAMPLE_TREE, but the ADJP is nested inside an extra VP
# so the close/shift error happens two levels deep
CLOSE_SHIFT_DEEP_EXAMPLE_TREE = """
( (NP (DT a)
(VP (ADJP (NN stock) (HYPH -) (VBG picking)))
(NN tool)))
"""
# not intended to be a correct tree
# Same as CLOSE_SHIFT_EXAMPLE_TREE, but the word after the ADJP is wrapped in
# its own NP, so the gold sequence has an Open before the next Shift
CLOSE_SHIFT_OPEN_EXAMPLE_TREE = """
( (NP (DT a)
(ADJP (NN stock) (HYPH -) (VBG picking))
(NP (NN tool))))
"""
# Two words follow the ADJP, so after a close/shift error it is ambiguous
# whether the ADJP should be closed immediately or after both words
CLOSE_SHIFT_AMBIGUOUS_TREE = """
( (NP (DT a)
(ADJP (NN stock) (HYPH -) (VBG picking))
(NN tool)
(NN foo)))
"""
def test_fix_close_shift_ambiguous_immediate():
    """
    Test the result when a close/shift error occurs and we want to close the new, incorrect constituent immediately

    The gold sequence closes the ADJP at index 7.  Predicting a Shift there
    instead pulls "tool" into the ADJP.  The immediate repair closes the ADJP
    right after that Shift, leaving "foo" attached to the NP.

    (Previously this test exercised fix_close_shift_ambiguous_later and
    asserted the later-close sequence, which contradicted its name.)
    """
    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_immediate, 7)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    # the ADJP takes one extra word, then is closed before "foo" is shifted
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original
    assert new_sequence == expected_update
def test_fix_close_shift_ambiguous_later():
    """
    Test the result when a close/shift error occurs and we want to keep the
    new, incorrect constituent open until its parent would have closed

    The gold sequence closes the ADJP at index 7.  Predicting a Shift there
    instead pulls "tool" into the ADJP.  The later repair keeps the ADJP open
    so that "foo" is absorbed as well.

    (Previously this test exercised fix_close_shift_ambiguous_immediate and
    asserted the immediate-close sequence, which contradicted its name.)
    """
    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_later, 7)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    # the ADJP takes both remaining words before anything is closed
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original
    assert new_sequence == expected_update
def test_oracle_with_optional_level():
    """
    Check that the oracle only applies the ambiguous close/shift repair when asked to.

    With an empty string in the third TopDownOracle argument, the error at
    index 7 is classified as OTHER_CLOSE_SHIFT and no new sequence is produced.
    With "CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR" there, the immediate repair
    is returned.
    """
    gold_tree = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)[0]
    gold_sequence = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    immediate_repair = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]

    assert build_sequence(gold_tree, transition_scheme=TransitionScheme.TOP_DOWN) == gold_sequence

    # the wrong prediction is the Shift at index 8, made in the state reached after 7 gold transitions
    predicted = gold_sequence[8]

    plain_oracle = TopDownOracle(ROOT_LABELS, 1, "", "")
    model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY, root_labels=ROOT_LABELS)
    state = build_state(model, gold_tree, 7)
    repair_type, repaired = plain_oracle.fix_error(pred_transition=predicted, model=model, state=state)
    assert repair_type is RepairType.OTHER_CLOSE_SHIFT
    assert repaired is None

    immediate_oracle = TopDownOracle(ROOT_LABELS, 1, "CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR", "")
    repair_type, repaired = immediate_oracle.fix_error(pred_transition=predicted, model=model, state=state)
    assert repair_type is RepairType.CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR
    assert repaired == immediate_repair
def test_fix_close_shift():
    """
    fix_close_shift repairs an unambiguous close/shift error and refuses an ambiguous one.

    With a single word after the ADJP, the repair absorbs it into the ADJP.
    With two words after the ADJP, the repair returns None.
    """
    # unambiguous case: exactly one word after the ADJP
    simple_trees = read_trees(CLOSE_SHIFT_EXAMPLE_TREE)
    assert len(simple_trees) == 1
    simple_sequence = build_sequence(simple_trees[0], transition_scheme=TransitionScheme.TOP_DOWN)
    simple_repair = get_single_repair(simple_sequence, simple_sequence[8], fix_close_shift, 7)
    assert simple_sequence == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    assert simple_repair == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]

    # ambiguous case: two words after the ADJP, so no repair is offered
    ambiguous_trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(ambiguous_trees) == 1
    ambiguous_sequence = build_sequence(ambiguous_trees[0], transition_scheme=TransitionScheme.TOP_DOWN)
    assert get_single_repair(ambiguous_sequence, ambiguous_sequence[8], fix_close_shift, 7) is None
def test_fix_close_shift_deeper_tree():
    """
    fix_close_shift handles an error where two constituents (VP and ADJP) must be closed.

    The Shift at index 10 is predicted in place of the Close at index 8.  The
    repair is expected to be the same whether or not count_opens is set.
    """
    parsed = read_trees(CLOSE_SHIFT_DEEP_EXAMPLE_TREE)
    assert len(parsed) == 1
    gold_sequence = build_sequence(parsed[0], transition_scheme=TransitionScheme.TOP_DOWN)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]

    for use_open_count in (True, False):
        repaired = get_single_repair(gold_sequence, gold_sequence[10], fix_close_shift, 8, count_opens=use_open_count)
        assert gold_sequence == expected_original
        assert repaired == expected_update
def test_fix_close_shift_open_tree():
    """
    A close/shift error where the next gold constituent starts with an Open.

    fix_close_shift with count_opens=False rejects this case (returns None).
    fix_close_shift_with_opens repairs it: the word of the trailing NP goes
    straight into the ADJP, and that NP disappears from the sequence.
    """
    parsed = read_trees(CLOSE_SHIFT_OPEN_EXAMPLE_TREE)
    assert len(parsed) == 1
    gold_sequence = build_sequence(parsed[0], transition_scheme=TransitionScheme.TOP_DOWN)
    wrong_shift = gold_sequence[9]

    assert get_single_repair(gold_sequence, wrong_shift, fix_close_shift, 7, count_opens=False) is None

    repaired = get_single_repair(gold_sequence, wrong_shift, fix_close_shift_with_opens, 7)
    assert gold_sequence == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert repaired == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
# A single PP follows the NP, so predicting Open(PP) instead of closing the
# NP has exactly one sensible repair: the PP attaches inside the NP
CLOSE_OPEN_EXAMPLE_TREE = """
( (VP (VBZ eat)
(NP (NN spaghetti))
(PP (IN with) (DT a) (NN fork))))
"""
# The constituent after the NP is an NP, not a PP, so a predicted
# Open(PP) does not match the gold label and the repair should be rejected
CLOSE_OPEN_DIFFERENT_LABEL_TREE = """
( (VP (VBZ eat)
(NP (NN spaghetti))
(NP (DT a) (NN fork))))
"""
# Two PPs follow the NP, which makes the close/open repair ambiguous
# (absorb only the first PP into the NP, or both)
CLOSE_OPEN_TWO_LABELS_TREE = """
( (VP (VBZ eat)
(NP (NN spaghetti))
(PP (IN with) (DT a) (NN fork))
(PP (IN in) (DT a) (NN restaurant))))
"""
def test_fix_close_open():
    """
    Predict Open(PP) where the gold sequence closes the NP at index 5.

    The repaired sequence opens the PP inside the NP, and an extra
    CloseConstituent is added at the end to compensate.
    """
    parsed = read_trees(CLOSE_OPEN_EXAMPLE_TREE)
    assert len(parsed) == 1
    gold_sequence = build_sequence(parsed[0], transition_scheme=TransitionScheme.TOP_DOWN)

    # sanity check the positions the repair relies on
    assert isinstance(gold_sequence[5], CloseConstituent)
    assert gold_sequence[6] == OpenConstituent("PP")

    repaired = get_single_repair(gold_sequence, gold_sequence[6], fix_close_open_correct_open, 5)
    assert gold_sequence == [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert repaired == [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
def test_fix_close_open_invalid():
    """
    fix_close_open_correct_open refuses to repair when the gold next constituent
    is not the single predicted PP (different label, or two following constituents).
    """
    for tree_text in (CLOSE_OPEN_DIFFERENT_LABEL_TREE, CLOSE_OPEN_TWO_LABELS_TREE):
        parsed = read_trees(tree_text)
        assert len(parsed) == 1
        gold_sequence = build_sequence(parsed[0], transition_scheme=TransitionScheme.TOP_DOWN)

        assert isinstance(gold_sequence[5], CloseConstituent)
        assert isinstance(gold_sequence[6], OpenConstituent)

        assert get_single_repair(gold_sequence, OpenConstituent("PP"), fix_close_open_correct_open, 5) is None
def test_fix_close_open_ambiguous_immediate():
    """
    Ambiguous close/open repair that closes the wrongly extended NP early.

    With check_close=False, fix_close_open_correct_open lets the NP absorb only
    the first PP; the second PP stays attached to the VP.
    """
    parsed = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)
    assert len(parsed) == 1
    gold_tree = parsed[0]
    gold_sequence = build_sequence(gold_tree, transition_scheme=TransitionScheme.TOP_DOWN)

    assert isinstance(gold_sequence[5], CloseConstituent)
    assert isinstance(gold_sequence[6], OpenConstituent)
    # the unrepaired sequence should round-trip to the gold tree
    assert gold_tree == reconstruct_tree(gold_tree, gold_sequence, transition_scheme=TransitionScheme.TOP_DOWN)

    repaired = get_single_repair(gold_sequence, OpenConstituent("PP"), fix_close_open_correct_open, 5, check_close=False)
    rebuilt = reconstruct_tree(gold_tree, repaired, transition_scheme=TransitionScheme.TOP_DOWN)

    expected_text = """
( (VP (VBZ eat)
(NP (NN spaghetti)
(PP (IN with) (DT a) (NN fork)))
(PP (IN in) (DT a) (NN restaurant))))
"""
    assert rebuilt == read_trees(expected_text)[0]
def test_fix_close_open_ambiguous_later():
    """
    Ambiguous close/open repair that keeps the wrongly extended NP open.

    fix_close_open_correct_open_ambiguous_later lets the NP absorb both PPs.
    """
    parsed = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)
    assert len(parsed) == 1
    gold_tree = parsed[0]
    gold_sequence = build_sequence(gold_tree, transition_scheme=TransitionScheme.TOP_DOWN)

    assert isinstance(gold_sequence[5], CloseConstituent)
    assert isinstance(gold_sequence[6], OpenConstituent)
    # the unrepaired sequence should round-trip to the gold tree
    assert gold_tree == reconstruct_tree(gold_tree, gold_sequence, transition_scheme=TransitionScheme.TOP_DOWN)

    repaired = get_single_repair(gold_sequence, OpenConstituent("PP"), fix_close_open_correct_open_ambiguous_later, 5, check_close=False)
    rebuilt = reconstruct_tree(gold_tree, repaired, transition_scheme=TransitionScheme.TOP_DOWN)

    expected_text = """
( (VP (VBZ eat)
(NP (NN spaghetti)
(PP (IN with) (DT a) (NN fork))
(PP (IN in) (DT a) (NN restaurant)))))
"""
    assert rebuilt == read_trees(expected_text)[0]
# Each entry is (gold tree, expected tree after repair, index of a Shift in
# the gold TOP_DOWN sequence that is mispredicted as a CloseConstituent).
# test_shift_close prints debugging output instead of asserting when the
# index or the expected tree is None.
SHIFT_CLOSE_EXAMPLES = [
("((S (NP (DT an) (NML (NNP Oct) (CD 19)) (NN review))))", "((S (NP (DT an) (NML (NNP Oct) (CD 19))) (NN review)))", 8),
("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
"((S (NP (` `) (NP (DT The)) (NN Misanthrope) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", 6),
("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
"((S (NP (` `) (NP (DT The) (NN Misanthrope))) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre)))))", 8),
("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
"((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman)) (NNP Theatre)))))", 13),
]
def test_shift_close():
    """
    Run fix_shift_close over every entry in SHIFT_CLOSE_EXAMPLES.

    For each example, the Shift at the given index is replaced by a
    CloseConstituent.  The repaired sequence is rebuilt into a tree and
    compared with the expected tree.  An example with a None index or a None
    expected tree only prints its sequences, as a debugging aid for writing
    new examples.
    """
    for gold_text, expected_text, shift_index in SHIFT_CLOSE_EXAMPLES:
        parsed = read_trees(gold_text)
        assert len(parsed) == 1
        gold_tree = parsed[0]
        gold_sequence = build_sequence(gold_tree, transition_scheme=TransitionScheme.TOP_DOWN)

        if shift_index is None:
            print(gold_sequence)
            continue

        assert isinstance(gold_sequence[shift_index], Shift)
        repaired = get_single_repair(gold_sequence, CloseConstituent(), fix_shift_close, shift_index)
        rebuilt = reconstruct_tree(gold_tree, repaired, transition_scheme=TransitionScheme.TOP_DOWN)

        if expected_text is not None:
            expected_trees = read_trees(expected_text)
            assert len(expected_trees) == 1
            assert rebuilt == expected_trees[0]
        else:
            print(gold_sequence)
            print(repaired)
            print("{:P}".format(rebuilt))
def test_shift_open_ambiguous_unary():
    """
    Test what happens if a Shift is turned into an Open in an ambiguous manner

    The unary repair wraps only the next word in the spurious ZZ constituent
    and closes it right away; the rest of the ADJP is unchanged.
    """
    parsed = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(parsed) == 1
    gold_sequence = build_sequence(parsed[0], transition_scheme=TransitionScheme.TOP_DOWN)
    assert gold_sequence == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]

    repaired = get_single_repair(gold_sequence, OpenConstituent("ZZ"), fix_shift_open_ambiguous_unary, 4)
    assert repaired == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
def test_shift_open_ambiguous_later():
    """
    Test what happens if a Shift is turned into an Open in an ambiguous manner

    The later repair keeps the spurious ZZ constituent open over the rest of
    the ADJP's words, closing it just before the ADJP closes.
    """
    parsed = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(parsed) == 1
    gold_sequence = build_sequence(parsed[0], transition_scheme=TransitionScheme.TOP_DOWN)
    assert gold_sequence == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]

    repaired = get_single_repair(gold_sequence, OpenConstituent("ZZ"), fix_shift_open_ambiguous_later, 4)
    assert repaired == [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
================================================
FILE: stanza/tests/constituency/test_trainer.py
================================================
from collections import defaultdict
import logging
import pathlib
import tempfile
import pytest
import torch
from torch import nn
from torch import optim
from stanza import Pipeline
from stanza.models import constituency_parser
from stanza.models.common import pretrain
from stanza.models.common.bert_embedding import load_bert, load_tokenizer
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.common.utils import set_random_seed
from stanza.models.constituency import lstm_model
from stanza.models.constituency.parse_transitions import Transition
from stanza.models.constituency import parser_training
from stanza.models.constituency import trainer
from stanza.models.constituency import tree_reader
from stanza.tests import *
# markers applied to every test in this module
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# raise this logger to WARNING so messages below that level are suppressed during the tests
logger = logging.getLogger('stanza.constituency.trainer')
logger.setLevel(logging.WARNING)
# A tiny four-tree treebank used as training data by the tests below.
# build_trainer uses its last tree as the dev set; write_treebanks writes it
# twice into the train file and once into the eval file.
TREEBANK = """
( (S
(VP (VBG Enjoying)
(NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition)))
(. .)))
( (NP
(VP (VBG Sitting)
(PP (IN in)
(NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station)))
(VP (VBG waiting)
(PP (IN for)
(NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train)))))
(. .)))
( (S
(NP (PRP I))
(VP
(ADVP (RB really))
(VBP hate)
(NP (DT the) (NNP @MBTA)))))
( (S
(S (VP (VB Seek)))
(CC and)
(S (NP (PRP ye))
(VP (MD shall)
(VP (VB find))))
(. .)))
"""
def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):
    """
    Build a constituency parser trainer from the given treebank text.

    The last tree is also used as the dev set, and no silver trees are given.
    Extra command-line style arguments in *args are appended after the
    --wordvec_pretrain_file option.  Asserts that the built model is an
    LSTMModel, then returns the trainer.
    """
    # TODO: build a fake embedding some other way?
    train_trees = tree_reader.read_trees(treebank)
    dev_trees = train_trees[-1:]

    parsed_args = constituency_parser.parse_args(['--wordvec_pretrain_file', wordvec_pretrain_file, *args])

    # might be None, unless we're testing loading an existing model
    load_name = parsed_args['load_name']
    built_trainer, _, _, _ = parser_training.build_trainer(parsed_args, train_trees, dev_trees, [], FoundationCache(), load_name)
    assert isinstance(built_trainer.model, lstm_model.LSTMModel)
    return built_trainer
class TestTrainer:
@pytest.fixture(scope="class")
def wordvec_pretrain_file(self):
return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
@pytest.fixture(scope="class")
def tiny_random_xlnet(self, tmp_path_factory):
"""
Download the tiny-random-xlnet model and make a concrete copy of it
The issue here is that the "random" nature of the original
makes it difficult or impossible to test that the values in
the transformer don't change during certain operations.
Saving a concrete instantiation of those random numbers makes
it so we can test there is no difference when training only a
subset of the layers, for example
"""
xlnet_name = 'hf-internal-testing/tiny-random-xlnet'
xlnet_model, xlnet_tokenizer = load_bert(xlnet_name)
path = str(tmp_path_factory.mktemp('tiny-random-xlnet'))
xlnet_model.save_pretrained(path)
xlnet_tokenizer.save_pretrained(path)
return path
@pytest.fixture(scope="class")
def tiny_random_bart(self, tmp_path_factory):
"""
Download the tiny-random-bart model and make a concrete copy of it
Issue is the same as with tiny_random_xlnet
"""
bart_name = 'hf-internal-testing/tiny-random-bart'
bart_model, bart_tokenizer = load_bert(bart_name)
path = str(tmp_path_factory.mktemp('tiny-random-bart'))
bart_model.save_pretrained(path)
bart_tokenizer.save_pretrained(path)
return path
def test_initial_model(self, wordvec_pretrain_file):
"""
does nothing, just tests that the construction went okay
"""
args = ['wordvec_pretrain_file', wordvec_pretrain_file]
build_trainer(wordvec_pretrain_file)
def test_save_load_model(self, wordvec_pretrain_file):
"""
Just tests that saving and loading works without crashs.
Currently no test of the values themselves
(checks some fields to make sure they are regenerated correctly)
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
tr = build_trainer(wordvec_pretrain_file)
transitions = tr.model.transitions
# attempt saving
filename = os.path.join(tmpdirname, "parser.pt")
tr.save(filename)
assert os.path.exists(filename)
# load it back in
tr2 = tr.load(filename)
trans2 = tr2.model.transitions
assert(transitions == trans2)
assert all(isinstance(x, Transition) for x in trans2)
def test_relearn_structure(self, wordvec_pretrain_file):
"""
Test that starting a trainer with --relearn_structure copies the old model
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
set_random_seed(1000)
args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
tr = build_trainer(wordvec_pretrain_file, *args)
# attempt saving
filename = os.path.join(tmpdirname, "parser.pt")
tr.save(filename)
set_random_seed(1001)
args = ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--relearn_structure', '--load_name', filename]
tr2 = build_trainer(wordvec_pretrain_file, *args)
assert torch.allclose(tr.model.delta_embedding.weight, tr2.model.delta_embedding.weight)
assert torch.allclose(tr.model.output_layers[0].weight, tr2.model.output_layers[0].weight)
# the norms will be the same, as the non-zero values are all the same
assert torch.allclose(torch.linalg.norm(tr.model.word_lstm.weight_ih_l0), torch.linalg.norm(tr2.model.word_lstm.weight_ih_l0))
def write_treebanks(self, tmpdirname):
train_treebank_file = os.path.join(tmpdirname, "train.mrg")
with open(train_treebank_file, 'w', encoding='utf-8') as fout:
fout.write(TREEBANK)
fout.write(TREEBANK)
eval_treebank_file = os.path.join(tmpdirname, "eval.mrg")
with open(eval_treebank_file, 'w', encoding='utf-8') as fout:
fout.write(TREEBANK)
return train_treebank_file, eval_treebank_file
def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args):
# let's not make the model huge...
args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10',
'--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname,
'--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_start', '0', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'),
'--train_file', train_treebank_file, '--eval_file', eval_treebank_file,
'--epoch_size', '6', '--train_batch_size', '3',
'--shorthand', 'en_test']
args = args + list(additional_args)
args = constituency_parser.parse_args(args)
# just in case we change the defaults in the future
args['wandb'] = None
return args
def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False, foundation_cache=None):
"""
Runs a test of the trainer for a few iterations.
Checks some basic properties of the saved model, but doesn't
check for the accuracy of the results
"""
if extra_args is None:
extra_args = []
extra_args += ['--epochs', '%d' % num_epochs]
train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
if use_silver:
extra_args += ['--silver_file', str(eval_treebank_file)]
args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
each_name = args['save_each_name']
if not exists_ok:
assert not os.path.exists(args['save_name'])
retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True, dir=TEST_MODELS_DIR, foundation_cache=foundation_cache, download_method=None)
trained_model = parser_training.train(args, None, [retag_pipeline])
# check that hooks are in the model if expected
for p in trained_model.model.parameters():
if p.requires_grad:
if args['grad_clipping'] is not None:
assert len(p._backward_hooks) == 1
else:
assert p._backward_hooks is None
# check that the model can be loaded back
assert os.path.exists(args['save_name'])
peft_name = trained_model.model.peft_name
tr = trainer.Trainer.load(args['save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
assert tr.optimizer is not None
assert tr.scheduler is not None
assert tr.epochs_trained >= 1
for p in tr.model.parameters():
if p.requires_grad:
assert p._backward_hooks is None
tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
assert tr.optimizer is not None
assert tr.scheduler is not None
assert tr.epochs_trained == num_epochs
for i in range(1, num_epochs+1):
model_name = each_name % i
assert os.path.exists(model_name)
tr = trainer.Trainer.load(model_name, load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
assert tr.epochs_trained == i
assert tr.batches_trained == (4 * i if use_silver else 2 * i)
return args, trained_model
def test_train(self, wordvec_pretrain_file):
"""
Test the whole thing for a few iterations on the fake data
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
self.run_train_test(wordvec_pretrain_file, tmpdirname)
def test_early_dropout(self, wordvec_pretrain_file):
"""
Test the whole thing for a few iterations on the fake data
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
args = ['--early_dropout', '3']
_, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
model = model.model
dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
for name, module in dropouts:
assert module.p == 0.0, "Dropout module %s was not set to 0 with early_dropout"
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
# test that when turned off, early_dropout doesn't happen
args = ['--early_dropout', '-1']
_, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
model = model.model
dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
if all(module.p == 0.0 for _, module in dropouts):
raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1")
def test_train_silver(self, wordvec_pretrain_file):
"""
Test the whole thing for a few iterations on the fake data
This tests that it works if you give it a silver file
The check for the use of the silver data is that the
number of batches trained should go up
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True)
def test_train_checkpoint(self, wordvec_pretrain_file):
"""
Test the whole thing for a few iterations, then restart
This tests that the 5th iteration save file is not rewritten
and that the iterations continue to 10
TODO: could make it more robust by verifying that only 5 more
epochs are trained. Perhaps a "most recent epochs" could be
saved in the trainer
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False)
save_5 = args['save_each_name'] % 5
save_10 = args['save_each_name'] % 10
assert os.path.exists(save_5)
assert not os.path.exists(save_10)
save_5_stat = pathlib.Path(save_5).stat()
self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True)
assert os.path.exists(save_5)
assert os.path.exists(save_10)
assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime
def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
args = ['--multistage', '--pattn_num_layers', '1']
if use_lattn:
args += ['--lattn_d_proj', '16']
if extra_args:
args += extra_args
args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
word_input_sizes = defaultdict(list)
for i in range(1, 9):
model_name = each_name % i
assert os.path.exists(model_name)
tr = trainer.Trainer.load(model_name, load_optimizer=True)
assert tr.epochs_trained == i
word_input_sizes[tr.model.word_input_size].append(i)
if use_lattn:
# there should be three stages: no attn, pattn, pattn+lattn
assert len(word_input_sizes) == 3
word_input_keys = sorted(word_input_sizes.keys())
assert word_input_sizes[word_input_keys[0]] == [1, 2, 3]
assert word_input_sizes[word_input_keys[1]] == [4, 5]
assert word_input_sizes[word_input_keys[2]] == [6, 7, 8]
else:
# with no lattn, there are two stages: no attn, pattn
assert len(word_input_sizes) == 2
word_input_keys = sorted(word_input_sizes.keys())
assert word_input_sizes[word_input_keys[0]] == [1, 2, 3]
assert word_input_sizes[word_input_keys[1]] == [4, 5, 6, 7, 8]
def test_multistage_lattn(self, wordvec_pretrain_file):
"""
Test a multistage training for a few iterations on the fake data
This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)
def test_multistage_no_lattn(self, wordvec_pretrain_file):
"""
Test a multistage training for a few iterations on the fake data
This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)
def test_multistage_optimizer(self, wordvec_pretrain_file):
"""
Test that the correct optimizers are built for a multistage training process
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
extra_args = ['--optim', 'adamw']
self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)
# check that the optimizers which get rebuilt when loading
# the models are adadelta for the first half of the
# multistage, then adamw
each_name = os.path.join(tmpdirname, 'each_%02d.pt')
for i in range(1, 3):
model_name = each_name % i
tr = trainer.Trainer.load(model_name, load_optimizer=True)
assert tr.epochs_trained == i
assert isinstance(tr.optimizer, optim.Adadelta)
# double check that this is actually a valid test
assert not isinstance(tr.optimizer, optim.AdamW)
for i in range(4, 8):
model_name = each_name % i
tr = trainer.Trainer.load(model_name, load_optimizer=True)
assert tr.epochs_trained == i
assert isinstance(tr.optimizer, optim.AdamW)
def test_grad_clip_hooks(self, wordvec_pretrain_file):
"""
Verify that grad clipping is not saved with the model, but is attached at training time
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
args = ['--grad_clipping', '25']
self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
def test_analyze_trees(self, wordvec_pretrain_file):
test_str = "(ROOT (S (NP (PRP I)) (VP (VBP wan) (S (VP (TO na) (VP (VB lick) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))))) (ROOT (S (NP (DT This) (NN interface)) (VP (VBZ sucks))))"
test_tree = tree_reader.read_trees(test_str)
assert len(test_tree) == 2
args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
tr = build_trainer(wordvec_pretrain_file, *args)
results = tr.model.analyze_trees(test_tree)
assert len(results) == 2
assert len(results[0].predictions) == 1
assert results[0].predictions[0].tree == test_tree[0]
assert results[0].state is not None
assert isinstance(results[0].state.score, torch.Tensor)
assert results[0].state.score.shape == torch.Size([])
assert len(results[0].constituents) == 9
assert results[0].constituents[-1].value == test_tree[0]
# the way the results are built, the next-to-last entry
# should be the thing just below the root
assert results[0].constituents[-2].value == test_tree[0].children[0]
assert len(results[1].predictions) == 1
assert results[1].predictions[0].tree == test_tree[1]
assert results[1].state is not None
assert isinstance(results[1].state.score, torch.Tensor)
assert results[1].state.score.shape == torch.Size([])
assert len(results[1].constituents) == 4
assert results[1].constituents[-1].value == test_tree[1]
assert results[1].constituents[-2].value == test_tree[1].children[0]
def bert_weights_allclose(self, bert_model, parser_model):
"""
Return True if all bert weights are close, False otherwise
"""
for name, parameter in bert_model.named_parameters():
other_name = "bert_model." + name
other_parameter = parser_model.model.get_parameter(other_name)
if not torch.allclose(parameter.cpu(), other_parameter.cpu()):
return False
return True
def frozen_transformer_test(self, wordvec_pretrain_file, transformer_name):
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
foundation_cache = FoundationCache()
args = ['--bert_model', transformer_name]
args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args, foundation_cache=foundation_cache)
bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
assert self.bert_weights_allclose(bert_model, trained_model)
checkpoint = torch.load(args['save_name'], lambda storage, loc: storage, weights_only=True)
params = checkpoint['params']
# check that the bert model wasn't saved in the model
assert all(not x.startswith("bert_model.") for x in params['model'].keys())
# make sure we're looking at the right thing
assert any(x.startswith("output_layers.") for x in params['model'].keys())
# check that the cached model is used as expected when loading a bert model
trained_model = trainer.Trainer.load(args['save_name'], foundation_cache=foundation_cache)
assert trained_model.model.bert_model is bert_model
def test_bert_frozen(self, wordvec_pretrain_file):
"""
Check that the parameters of the bert model don't change when training a basic model
"""
self.frozen_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
def test_xlnet_frozen(self, wordvec_pretrain_file, tiny_random_xlnet):
"""
Check that the parameters of an xlnet model don't change when training a basic model
"""
self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
def test_bart_frozen(self, wordvec_pretrain_file, tiny_random_bart):
"""
Check that the parameters of an xlnet model don't change when training a basic model
"""
self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_bart)
def test_bert_finetune_one_epoch(self, wordvec_pretrain_file):
"""
Check that the parameters the bert model DO change over a single training step
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
transformer_name = 'hf-internal-testing/tiny-bert'
args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adadelta']
args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=1, extra_args=args)
# check that the weights are different
foundation_cache = FoundationCache()
bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
assert not self.bert_weights_allclose(bert_model, trained_model)
# double check that a new bert is created instead of using the FoundationCache when the bert has been trained
model_name = args['save_name']
assert os.path.exists(model_name)
no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name)
tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)
assert tr.model.bert_model is not bert_model
assert not self.bert_weights_allclose(bert_model, tr)
assert self.bert_weights_allclose(trained_model.model.bert_model, tr)
new_save_name = os.path.join(tmpdirname, "test_resave_bert.pt")
assert not os.path.exists(new_save_name)
tr.save(new_save_name, save_optimizer=False)
tr2 = trainer.Trainer.load(new_save_name, args=no_finetune_args, foundation_cache=foundation_cache)
# check that the resaved model included its finetuned bert weights
assert tr2.model.bert_model is not bert_model
# the finetuned bert weights should also be scheduled for saving the next time as well
assert not tr2.model.is_unsaved_module("bert_model")
def finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):
"""
Check that the parameters of the transformer DO change when using bert_finetune
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw']
args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
# check that the weights are different
foundation_cache = FoundationCache()
bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
assert not self.bert_weights_allclose(bert_model, trained_model)
# double check that a new bert is created instead of using the FoundationCache when the bert has been trained
no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name)
trained_model = trainer.Trainer.load(args['save_name'], args=no_finetune_args, foundation_cache=foundation_cache)
assert not trained_model.model.args['bert_finetune']
assert not trained_model.model.args['stage1_bert_finetune']
assert trained_model.model.bert_model is not bert_model
def test_bert_finetune(self, wordvec_pretrain_file):
"""
Check that the parameters of a bert model DO change when using bert_finetune
"""
self.finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
def test_xlnet_finetune(self, wordvec_pretrain_file, tiny_random_xlnet):
"""
Check that the parameters of an xlnet model DO change when using bert_finetune
"""
self.finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
def test_stage1_bert_finetune(self, wordvec_pretrain_file):
"""
Check that the parameters the bert model DO change when using stage1_bert_finetune, but only for the first couple steps
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
bert_model_name = 'hf-internal-testing/tiny-bert'
args = ['--bert_model', bert_model_name, '--stage1_bert_finetune', '--optim', 'adamw']
# need to use num_epochs==6 so that epochs 1 and 2 are saved to be different
# a test of 5 or less means that sometimes it will reload the params
# at step 2 to get ready for the following iterations with adamw
args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
# check that the weights are different
foundation_cache = FoundationCache()
bert_model, bert_tokenizer = foundation_cache.load_bert(bert_model_name)
assert not self.bert_weights_allclose(bert_model, trained_model)
# double check that a new bert is created instead of using the FoundationCache when the bert has been trained
no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', bert_model_name, '--optim', 'adamw')
num_epochs = trained_model.model.args['epochs']
each_name = os.path.join(tmpdirname, 'each_%02d.pt')
for i in range(1, num_epochs+1):
model_name = each_name % i
assert os.path.exists(model_name)
tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)
assert tr.model.bert_model is not bert_model
assert not self.bert_weights_allclose(bert_model, tr)
if i >= num_epochs // 2:
assert self.bert_weights_allclose(trained_model.model.bert_model, tr)
# verify that models 1 and 2 are saved to be different
model_name_1 = each_name % 1
model_name_2 = each_name % 2
tr_1 = trainer.Trainer.load(model_name_1, args=no_finetune_args, foundation_cache=foundation_cache)
tr_2 = trainer.Trainer.load(model_name_2, args=no_finetune_args, foundation_cache=foundation_cache)
assert not self.bert_weights_allclose(tr_1.model.bert_model, tr_2)
def one_layer_finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):
"""
Check that the parameters the bert model DO change when using bert_finetune
"""
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
args = ['--bert_model', transformer_name, '--bert_finetune', '--bert_finetune_layers', '1', '--optim', 'adamw', '--bert_finetune_layers', '1']
args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
# check that the weights of the last layer are different,
# but the weights of the earlier layers and
# non-transformer-layers are the same
foundation_cache = FoundationCache()
bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
assert bert_model.config.num_hidden_layers > 1
layer_name = "layer.%d." % (bert_model.config.num_hidden_layers - 1)
for name, parameter in bert_model.named_parameters():
other_name = "bert_model." + name
other_parameter = trained_model.model.get_parameter(other_name)
if layer_name in name:
if 'rel_attn.seg_embed' in name or 'rel_attn.r_s_bias' in name:
# not sure why this happens for xlnet, just roll with it
continue
assert not torch.allclose(parameter.cpu(), other_parameter.cpu())
else:
assert torch.allclose(parameter.cpu(), other_parameter.cpu())
def test_bert_finetune_one_layer(self, wordvec_pretrain_file):
self.one_layer_finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
def test_xlnet_finetune_one_layer(self, wordvec_pretrain_file, tiny_random_xlnet):
self.one_layer_finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
def test_peft_finetune(self, tmp_path, wordvec_pretrain_file):
transformer_name = 'hf-internal-testing/tiny-bert'
args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw', '--use_peft']
args, trained_model = self.run_train_test(wordvec_pretrain_file, str(tmp_path), extra_args=args)
def test_peft_twostage_finetune(self, wordvec_pretrain_file):
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
num_epochs = 6
transformer_name = 'hf-internal-testing/tiny-bert'
args = ['--bert_model', transformer_name, '--stage1_bert_finetune', '--optim', 'adamw', '--use_peft']
args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=num_epochs, extra_args=args)
for epoch in range(num_epochs):
filename_prev = args['save_each_name'] % epoch
filename_next = args['save_each_name'] % (epoch+1)
trainer_prev = trainer.Trainer.load(filename_prev, args=args, load_optimizer=False)
trainer_next = trainer.Trainer.load(filename_next, args=args, load_optimizer=False)
lora_names = [name for name, _ in trainer_prev.model.bert_model.named_parameters() if name.find("lora") >= 0]
if epoch < 2:
assert not any(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),
trainer_next.model.bert_model.get_parameter(name).cpu())
for name in lora_names)
elif epoch > 2:
assert all(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),
trainer_next.model.bert_model.get_parameter(name).cpu())
for name in lora_names)
================================================
FILE: stanza/tests/constituency/test_transformer_tree_stack.py
================================================
import pytest
import torch
from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def test_initial_state():
"""
Test that the initial state has the expected shapes
"""
ts = TransformerTreeStack(3, 5, 0.0)
initial = ts.initial_state()
assert len(initial) == 1
assert initial.value.output.shape == torch.Size([5])
assert initial.value.key_stack.shape == torch.Size([1, 5])
assert initial.value.value_stack.shape == torch.Size([1, 5])
def test_output():
"""
Test that you can get an expected output shape from the TTS
"""
ts = TransformerTreeStack(3, 5, 0.0)
initial = ts.initial_state()
out = ts.output(initial)
assert out.shape == torch.Size([5])
assert torch.allclose(initial.value.output, out)
def test_push_state_single():
"""
Test that stacks are being updated correctly when using a single stack
Values of the attention are not verified, though
"""
ts = TransformerTreeStack(3, 5, 0.0)
initial = ts.initial_state()
rand_input = torch.randn(1, 3)
stacks = ts.push_states([initial], ["A"], rand_input)
stacks = ts.push_states(stacks, ["B"], rand_input)
assert len(stacks) == 1
assert len(stacks[0]) == 3
assert stacks[0].value.value == "B"
assert stacks[0].pop().value.value == "A"
assert stacks[0].pop().pop().value.value is None
def test_push_state_same_length():
"""
Test that stacks are being updated correctly when using 3 stacks of the same length
Values of the attention are not verified, though
"""
ts = TransformerTreeStack(3, 5, 0.0)
initial = ts.initial_state()
rand_input = torch.randn(3, 3)
stacks = ts.push_states([initial, initial, initial], ["A", "A", "A"], rand_input)
stacks = ts.push_states(stacks, ["B", "B", "B"], rand_input)
stacks = ts.push_states(stacks, ["C", "C", "C"], rand_input)
assert len(stacks) == 3
for s in stacks:
assert len(s) == 4
assert s.value.key_stack.shape == torch.Size([4, 5])
assert s.value.value_stack.shape == torch.Size([4, 5])
assert s.value.value == "C"
assert s.pop().value.value == "B"
assert s.pop().pop().value.value == "A"
assert s.pop().pop().pop().value.value is None
def test_push_state_different_length():
"""
Test what happens if stacks of different lengths are passed in
"""
ts = TransformerTreeStack(3, 5, 0.0)
initial = ts.initial_state()
rand_input = torch.randn(2, 3)
one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0]
stacks = [one_step, initial]
stacks = ts.push_states(stacks, ["B", "C"], rand_input)
assert len(stacks) == 2
assert len(stacks[0]) == 3
assert len(stacks[1]) == 2
assert stacks[0].pop().value.value == 'A'
assert stacks[0].value.value == 'B'
assert stacks[1].value.value == 'C'
assert stacks[0].value.key_stack.shape == torch.Size([3, 5])
assert stacks[1].value.key_stack.shape == torch.Size([2, 5])
def test_mask():
"""
Test that a mask prevents the softmax from picking up unwanted values
"""
ts = TransformerTreeStack(3, 5, 0.0)
random_v = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]])
double_v = random_v * 2
value = torch.cat([random_v, double_v], axis=1)
random_k = torch.randn(1, 1, 5)
key = torch.cat([random_k, random_k], axis=1)
query = torch.randn(1, 5)
output = ts.attention(key, query, value)
# when the two keys are equal, we expect the attention to be 50/50
expected_output = (random_v + double_v) / 2
assert torch.allclose(output, expected_output)
# If the first entry is masked out, the second one should be the
# only one represented
mask = torch.zeros(1, 2, dtype=torch.bool)
mask[0][0] = True
output = ts.attention(key, query, value, mask)
assert torch.allclose(output, double_v)
# If the second entry is masked out, the first one should be the
# only one represented
mask = torch.zeros(1, 2, dtype=torch.bool)
mask[0][1] = True
output = ts.attention(key, query, value, mask)
assert torch.allclose(output, random_v)
def test_position():
"""
Test that nothing goes horribly wrong when position encodings are used
Does not actually test the results of the encodings
"""
ts = TransformerTreeStack(4, 5, 0.0, use_position=True)
initial = ts.initial_state()
assert len(initial) == 1
assert initial.value.output.shape == torch.Size([5])
assert initial.value.key_stack.shape == torch.Size([1, 5])
assert initial.value.value_stack.shape == torch.Size([1, 5])
rand_input = torch.randn(2, 4)
one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0]
stacks = [one_step, initial]
stacks = ts.push_states(stacks, ["B", "C"], rand_input)
def test_length_limit():
"""
Test that the length limit drops nodes as the length limit is exceeded
"""
ts = TransformerTreeStack(4, 5, 0.0, length_limit = 2)
initial = ts.initial_state()
assert len(initial) == 1
assert initial.value.output.shape == torch.Size([5])
assert initial.value.key_stack.shape == torch.Size([1, 5])
assert initial.value.value_stack.shape == torch.Size([1, 5])
data = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
stacks = ts.push_states([initial], ["A"], data)
stacks = ts.push_states(stacks, ["B"], data)
assert len(stacks) == 1
assert len(stacks[0]) == 3
assert stacks[0].value.key_stack.shape[0] == 3
assert stacks[0].value.value_stack.shape[0] == 3
stacks = ts.push_states(stacks, ["C"], data)
assert len(stacks) == 1
assert len(stacks[0]) == 4
assert stacks[0].value.key_stack.shape[0] == 3
assert stacks[0].value.value_stack.shape[0] == 3
stacks = ts.push_states(stacks, ["D"], data)
assert len(stacks) == 1
assert len(stacks[0]) == 5
assert stacks[0].value.key_stack.shape[0] == 3
assert stacks[0].value.value_stack.shape[0] == 3
def test_two_heads():
"""
Test that the length limit drops nodes as the length limit is exceeded
"""
ts = TransformerTreeStack(4, 6, 0.0, num_heads = 2)
initial = ts.initial_state()
assert len(initial) == 1
assert initial.value.output.shape == torch.Size([6])
assert initial.value.key_stack.shape == torch.Size([1, 6])
assert initial.value.value_stack.shape == torch.Size([1, 6])
rand_input = torch.randn(2, 4)
one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0]
stacks = [one_step, initial]
stacks = ts.push_states(stacks, ["B", "C"], rand_input)
assert len(stacks) == 2
assert len(stacks[0]) == 3
assert len(stacks[1]) == 2
assert stacks[0].pop().value.value == 'A'
assert stacks[0].value.value == 'B'
assert stacks[1].value.value == 'C'
assert stacks[0].value.key_stack.shape == torch.Size([3, 6])
assert stacks[1].value.key_stack.shape == torch.Size([2, 6])
================================================
FILE: stanza/tests/constituency/test_transition_sequence.py
================================================
import pytest
from stanza.models.constituency import parse_transitions
from stanza.models.constituency import transition_sequence
from stanza.models.constituency import tree_reader
from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT
from stanza.models.constituency.parse_transitions import *
from stanza.tests import *
from stanza.tests.constituency.test_parse_tree import CHINESE_LONG_LIST_TREE
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def reconstruct_tree(tree, sequence, transition_scheme=TransitionScheme.IN_ORDER, unary_limit=UNARY_LIMIT, reverse=False):
"""
Starting from a tree and a list of transitions, build the tree caused by the transitions
"""
model = SimpleModel(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse)
states = model.initial_state_from_gold_trees([tree])
assert(len(states)) == 1
assert states[0].num_transitions == 0
# TODO: could fold this into parse_sentences (similar to verify_transitions in trainer.py)
for idx, t in enumerate(sequence):
assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, sequence)
states = model.bulk_apply(states, [t])
result_tree = states[0].constituents.value
if reverse:
result_tree = result_tree.reverse()
return result_tree
def check_reproduce_tree(transition_scheme):
text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
trees = tree_reader.read_trees(text)
model = SimpleModel(transition_scheme)
transitions = transition_sequence.build_sequence(trees[0], transition_scheme)
states = model.initial_state_from_gold_trees(trees)
assert(len(states)) == 1
state = states[0]
assert state.num_transitions == 0
for t in transitions:
assert t.is_legal(state, model)
state = t.apply(state, model)
# one item for the final tree
# one item for the sentinel at the end
assert len(state.constituents) == 2
# the transition sequence should put all of the words
# from the buffer onto the tree
# one spot left for the sentinel value
assert len(state.word_queue) == 8
assert state.sentence_length == 6
assert state.word_position == state.sentence_length
assert len(state.transitions) == len(transitions) + 1
result_tree = state.constituents.value
assert result_tree == trees[0]
def test_top_down_unary():
check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)
def test_top_down_no_unary():
check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN)
def test_in_order():
check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER)
def test_in_order_compound():
check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
def test_in_order_unary():
check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_UNARY)
def test_all_transitions():
text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
trees = tree_reader.read_trees(text)
model = SimpleModel()
transitions = transition_sequence.build_treebank(trees)
expected = [Shift(), CloseConstituent(), CompoundUnary("ROOT"), CompoundUnary("SQ"), CompoundUnary("WHNP"), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("SBARQ"), OpenConstituent("VP")]
assert transition_sequence.all_transitions(transitions) == expected
def test_all_transitions_no_unary():
text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
trees = tree_reader.read_trees(text)
model = SimpleModel()
transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)
expected = [Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("ROOT"), OpenConstituent("SBARQ"), OpenConstituent("SQ"), OpenConstituent("VP"), OpenConstituent("WHNP")]
assert transition_sequence.all_transitions(transitions) == expected
def test_top_down_compound_unary():
text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"
trees = tree_reader.read_trees(text)
assert len(trees) == 1
model = SimpleModel()
transitions = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND)
states = model.initial_state_from_gold_trees(trees)
assert len(states) == 1
state = states[0]
for t in transitions:
assert t.is_legal(state, model)
state = t.apply(state, model)
result = model.get_top_constituent(state.constituents)
assert trees[0] == result
def test_chinese_tree():
trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)
transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)
assert redone == trees[0]
transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER)
with pytest.raises(AssertionError):
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER)
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)
assert redone == trees[0]
def test_chinese_tree_reversed():
"""
test that the reversed transitions also work
"""
trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)
transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)
assert redone == trees[0]
with pytest.raises(AssertionError):
# turn off reverse - it should fail to rebuild the tree
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)
assert redone == trees[0]
transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER, reverse=True)
with pytest.raises(AssertionError):
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, reverse=True)
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6, reverse=True)
assert redone == trees[0]
with pytest.raises(AssertionError):
# turn off reverse - it should fail to rebuild the tree
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)
assert redone == trees[0]
================================================
FILE: stanza/tests/constituency/test_tree_reader.py
================================================
import pytest
from stanza.models.constituency import tree_reader
from stanza.models.constituency.tree_reader import MixedTreeError, UnclosedTreeError, UnlabeledTreeError
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def test_simple():
"""
Tests reading two simple trees from the same text
"""
text = "(VB Unban) (NNP Opal)"
trees = tree_reader.read_trees(text)
assert len(trees) == 2
assert trees[0].is_preterminal()
assert trees[0].label == 'VB'
assert trees[0].children[0].label == 'Unban'
assert trees[1].is_preterminal()
assert trees[1].label == 'NNP'
assert trees[1].children[0].label == 'Opal'
def test_newlines():
"""
The same test should work if there are newlines
"""
text = "(VB Unban)\n\n(NNP Opal)"
trees = tree_reader.read_trees(text)
assert len(trees) == 2
def test_parens():
"""
Parens should be escaped in the tree files and escaped when written
"""
text = "(-LRB- -LRB-) (-RRB- -RRB-)"
trees = tree_reader.read_trees(text)
assert len(trees) == 2
assert trees[0].label == '-LRB-'
assert trees[0].children[0].label == '('
assert "{}".format(trees[0]) == '(-LRB- -LRB-)'
assert trees[1].label == '-RRB-'
assert trees[1].children[0].label == ')'
assert "{}".format(trees[1]) == '(-RRB- -RRB-)'
def test_complicated():
"""
A more complicated tree that should successfully read
"""
text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
trees = tree_reader.read_trees(text)
assert len(trees) == 1
tree = trees[0]
assert not tree.is_leaf()
assert not tree.is_preterminal()
assert tree.label == 'ROOT'
assert len(tree.children) == 1
assert tree.children[0].label == 'SBARQ'
assert len(tree.children[0].children) == 3
assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.']
# etc etc
def test_one_word():
"""
Check that one node trees are correctly read
probably not super relevant for the parsing use case
"""
text="(FOO) (BAR)"
trees = tree_reader.read_trees(text)
assert len(trees) == 2
assert trees[0].is_leaf()
assert trees[0].label == 'FOO'
assert trees[1].is_leaf()
assert trees[1].label == 'BAR'
def test_missing_close_parens():
"""
Test the unclosed error condition
"""
text = "(Foo) \n (Bar \n zzz"
try:
trees = tree_reader.read_trees(text)
raise AssertionError("Expected an exception")
except UnclosedTreeError as e:
assert e.line_num == 1
def test_mixed_tree():
"""
Test the mixed error condition
"""
text = "(Foo) \n (Bar) \n (Unban (Mox) Opal)"
try:
trees = tree_reader.read_trees(text)
raise AssertionError("Expected an exception")
except MixedTreeError as e:
assert e.line_num == 2
trees = tree_reader.read_trees(text, broken_ok=True)
assert len(trees) == 3
def test_unlabeled_tree():
"""
Test the unlabeled error condition
"""
text = "(ROOT ((Foo) (Bar)))"
try:
trees = tree_reader.read_trees(text)
raise AssertionError("Expected an exception")
except UnlabeledTreeError as e:
assert e.line_num == 0
trees = tree_reader.read_trees(text, broken_ok=True)
assert len(trees) == 1
================================================
FILE: stanza/tests/constituency/test_tree_stack.py
================================================
import pytest
from stanza.models.constituency.tree_stack import TreeStack
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def test_simple():
stack = TreeStack(value=5, parent=None, length=1)
stack = stack.push(3)
stack = stack.push(1)
expected_values = [1, 3, 5]
for value in expected_values:
assert stack.value == value
stack = stack.pop()
assert stack is None
def test_iter():
stack = TreeStack(value=5, parent=None, length=1)
stack = stack.push(3)
stack = stack.push(1)
stack_list = list(stack)
assert list(stack) == [1, 3, 5]
def test_str():
stack = TreeStack(value=5, parent=None, length=1)
stack = stack.push(3)
stack = stack.push(1)
assert str(stack) == "TreeStack(1, 3, 5)"
def test_len():
stack = TreeStack(value=5, parent=None, length=1)
assert len(stack) == 1
stack = stack.push(3)
stack = stack.push(1)
assert len(stack) == 3
def test_long_len():
"""
Original stack had a bug where this took exponential time...
"""
stack = TreeStack(value=0, parent=None, length=1)
for i in range(1, 40):
stack = stack.push(i)
assert len(stack) == 40
================================================
FILE: stanza/tests/constituency/test_utils.py
================================================
import pytest
from stanza import Pipeline
from stanza.models.constituency import tree_reader
from stanza.models.constituency import utils
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
@pytest.fixture(scope="module")
def pipeline():
return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
def test_xpos_retag(pipeline):
"""
Test using the English tagger that trees will be correctly retagged by read_trees using xpos
"""
text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
expected = "((S (VP (VB Find)) (NP (NNP Mox) (NNP Opal)))) ((S (NP (NNP Ragavan)) (VP (VBZ steals) (NP (JJ important) (NNS cards)))))"
trees = tree_reader.read_trees(text)
new_trees = utils.retag_trees(trees, [pipeline], xpos=True)
assert new_trees == tree_reader.read_trees(expected)
def test_upos_retag(pipeline):
"""
Test using the English tagger that trees will be correctly retagged by read_trees using upos
"""
text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
expected = "((S (VP (VERB Find)) (NP (PROPN Mox) (PROPN Opal)))) ((S (NP (PROPN Ragavan)) (VP (VERB steals) (NP (ADJ important) (NOUN cards)))))"
trees = tree_reader.read_trees(text)
new_trees = utils.retag_trees(trees, [pipeline], xpos=False)
assert new_trees == tree_reader.read_trees(expected)
def test_replace_tags():
"""
Test the underlying replace_tags method
Also tests that the method throws exceptions when it is supposed to
"""
text = "((S (VP (X Find)) (NP (X Mox) (X Opal))))"
expected = "((S (VP (A Find)) (NP (B Mox) (C Opal))))"
trees = tree_reader.read_trees(text)
new_tags = ["A", "B", "C"]
new_tree = trees[0].replace_tags(new_tags)
assert new_tree == tree_reader.read_trees(expected)[0]
with pytest.raises(ValueError):
new_tags = ["A", "B"]
new_tree = trees[0].replace_tags(new_tags)
with pytest.raises(ValueError):
new_tags = ["A", "B", "C", "D"]
new_tree = trees[0].replace_tags(new_tags)
================================================
FILE: stanza/tests/constituency/test_vietnamese.py
================================================
"""
A few tests for Vietnamese parsing, which has some difficulties related to spaces in words
Technically some other languages can have this, too, like that one French token
"""
import os
import tempfile
import pytest
from stanza.models.common import pretrain
from stanza.models.constituency import tree_reader
from stanza.tests.constituency.test_trainer import build_trainer
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
VI_TREEBANK = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài Loan) (" ") (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'
VI_TREEBANK_UNDERSCORE = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài_Loan) (" ") (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .)))'
VI_TREEBANK_SIMPLE = '(ROOT (S (NP (" ") (N Đảo) (Np Đài Loan) (" ") (PP (E ở) (NP (N đồng bằng) (NP (N sông) (Np Cửu Long))))) (. .)))'
VI_TREEBANK_PAREN = '(ROOT (S-TTL (NP (PUNCT -LRB-) (N-H Đảo) (Np Đài Loan) (PUNCT -RRB-) (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'
VI_TREEBANK_VLSP = '\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n '
VI_TREEBANK_VLSP_50 = '\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n '
VI_TREEBANK_VLSP_100 = '\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n '
EXPECTED_LABELED_BRACKETS = '(_ROOT (_S (_NP (_" " )_" (_N Đảo )_N (_Np Đài_Loan )_Np (_" " )_" (_PP (_E ở )_E (_NP (_N đồng_bằng )_N (_NP (_N sông )_N (_Np Cửu_Long )_Np )_NP )_NP )_PP )_NP (_. . )_. )_S )_ROOT'
def test_read_vi_tree():
    """
    Reading one Vietnamese tree must keep a multi-syllable leaf (which contains a space) as a single leaf
    """
    first_line = VI_TREEBANK.split("\n")[0]
    parsed = tree_reader.read_trees(first_line)
    assert len(parsed) == 1
    root = parsed[0]
    assert str(root) == first_line
    # walk ROOT -> S-TTL -> first NP -> third child of that NP, eg (Np Đài Loan)
    sentence_node = root.children[0]
    first_np = sentence_node.children[0]
    preterminal = first_np.children[2]
    assert preterminal.is_preterminal()
    assert preterminal.children[0].label == "Đài Loan"
# A tiny text-format embedding: a "4 4" header line followed by 4 entries of 4 values each.
# Two of the words ("Đài Loan", "đồng bằng") contain a space; test_vi_embedding expects all
# four words to be matched against the leaves of VI_TREEBANK.
VI_EMBEDDING = """
4 4
Đảo 0.11 0.21 0.31 0.41
Đài Loan 0.12 0.22 0.32 0.42
đồng bằng 0.13 0.23 0.33 0.43
sông 0.14 0.24 0.34 0.44
""".strip()
def test_vi_embedding():
    """
    Embedding words which contain spaces should still be matched against the leaves of a VI tree
    """
    tree = tree_reader.read_trees(VI_TREEBANK.split("\n")[0])[0]
    leaf_words = set(tree.leaf_labels())
    with tempfile.TemporaryDirectory() as workdir:
        vec_path = os.path.join(workdir, "emb.txt")
        with open(vec_path, "w", encoding="utf-8") as vec_out:
            vec_out.write(VI_EMBEDDING)
        pt_path = os.path.join(workdir, "emb.pt")
        embedding = pretrain.Pretrain(filename=pt_path, vec_filename=vec_path, save_to_file=True)
        embedding.load()
        model = build_trainer(pt_path).model
        # every entry of VI_EMBEDDING, including the two with spaces, occurs in the tree
        assert model.num_words_known(leaf_words) == 4
def test_space_formatting():
    """
    Spaces inside leaves are kept by default; the "_O" format option writes them as underscores
    """
    line = VI_TREEBANK.split("\n")[0]
    parsed = tree_reader.read_trees(line)
    assert len(parsed) == 1
    tree = parsed[0]
    assert str(tree) == line
    # format(x, spec) is exactly what "{:spec}".format(x) calls
    assert format(tree) == VI_TREEBANK
    assert format(tree, "_O") == VI_TREEBANK_UNDERSCORE
def test_vlsp_formatting():
    """
    Check the VLSP-style "_V" / "_Vi" formats, and that "V" rejects trees it cannot render
    """
    source = VI_TREEBANK_PAREN.split("\n")[0]
    parsed = tree_reader.read_trees(source)
    assert len(parsed) == 1
    tree = parsed[0]
    assert str(tree) == source
    assert "{:_V}".format(tree) == VI_TREEBANK_VLSP
    for tree_id, expected in ((50, VI_TREEBANK_VLSP_50), (100, VI_TREEBANK_VLSP_100)):
        tree.tree_id = tree_id
        assert "{:_Vi}".format(tree) == expected
    # a ROOT with no children and a ROOT with several children both raise ValueError
    for bad_text in ("(ROOT)", "(ROOT (1) (2) (3))"):
        bad_tree = tree_reader.read_trees(bad_text)[0]
        with pytest.raises(ValueError):
            "{:V}".format(bad_tree)
def test_language_formatting():
    """
    Check the labeled-bracket "L" format, which turns a parse tree into a 'language' for GPT
    """
    raw = VI_TREEBANK.split("\n")[0]
    simplified = [tree.prune_none().simplify_labels() for tree in tree_reader.read_trees(raw)]
    assert len(simplified) == 1
    only_tree = simplified[0]
    assert str(only_tree) == VI_TREEBANK_SIMPLE
    assert format(only_tree, "L") == EXPECTED_LABELED_BRACKETS
================================================
FILE: stanza/tests/datasets/__init__.py
================================================
================================================
FILE: stanza/tests/datasets/coref/__init__.py
================================================
================================================
FILE: stanza/tests/datasets/coref/test_hebrew_iahlt.py
================================================
import pytest
from stanza import Pipeline
from stanza.tests import TEST_MODELS_DIR
from stanza.utils.datasets.coref.convert_hebrew_iahlt import extract_doc
# every test in this module carries the travis and pipeline marks
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
@pytest.fixture(scope="module")
def tokenizer():
    """Module-scoped Hebrew tokenize-only Pipeline read from TEST_MODELS_DIR, with download_method=None"""
    return Pipeline(lang="he", processors="tokenize", dir=TEST_MODELS_DIR, download_method=None)
# Hebrew sample document text, passed to extract_doc in test_extract_doc
TEXT = """
מבולבלים? גם אנחנו: למסעדנים והמלצרים יש עוד סימני שאלה על הטיפים
הפער בין פסיקת בית הדין לעבודה לבין פסיקה קודמת של בג"ץ, משאיר את הענף בחוסר וודאות, וה -1 בינואר כבר מעבר לפינה . "מבחינתי , הייתי מוסיף לתפריט תוספת שירות של 17% ", אמר בעלים של מסעדה בשדרות
ברשות המיסים מסתפקים במסר עמום באשר לכוונותיהם לאור פסק דין הטיפים שצפוי להיכנס לתוקפו ב-1 בינואר . על פי פרשנותם המקצועית , הבהירו, יש מקום לחייב את כספי הטיפים במע"מ , "עם זאת, הרשות עדין בוחנת את הסוגיה וטרם התקבלה החלטה אופרטיבית בעניין ". ואיך אמורים המסעדנים להיערך בינתיים ליישום הפסיקה ולמחזור השנה הבאה ? ביום חמישי יפגשו אנשי ארגון 'מסעדנים חזקים ביחד' עם מנהל רשות המיסים ערן יעקב, וידרשו תשובות ברורות.
"אני עדיין לא מדבר עם העובדים שלי , ואני גם לא יודע איך להיערך החל מעוד שבועיים", אמר ל'דבר ראשון' ניר שוחט, הבעלים של מסעדת סושי מוטו בשדרות ומוסיף כי יהיה קשה להתאים את הפסיקה למציאות בשטח . "אף אחד לא יודע. יש המון סתירות – עורך הדין אומר דבר אחד ורואה החשבון דבר אחר. עדיין לא הצליחו להבין את החוק לאשורו ".
"מבחינתי , הייתי מוסיף לתפריט תוספת שירות של 17% . זה יגלם גם את המע"מ והטיפים ומזה אני אשלם למלצרים . די כבר עם הטיפים האלה , מספיק."
"""
# A single entity cluster with two mentions.
# NOTE(review): each mention looks like [start, end, attributes] with character offsets into TEXT — confirm against the IAHLT format
CLUSTER = {'metadata': {'name': 'המסעדנים', 'entity': 'person'}, 'mentions': [[28, 35, {}], [572, 581, {}]]}
def test_extract_doc(tokenizer):
    """
    Convert one document holding a single two-mention cluster and check where the mentions land
    """
    raw_doc = {
        'text': TEXT,
        'clusters': [CLUSTER],
        'metadata': {'doc_id': 'test'},
    }
    converted = extract_doc(tokenizer, [raw_doc])
    assert len(converted) == 1
    spans = converted[0].coref_spans
    assert len(spans) == 2
    # presumably keyed by sentence index — verify against extract_doc
    assert spans[1] == [(0, 4, 4)]
    assert spans[6] == [(0, 3, 4)]
================================================
FILE: stanza/tests/datasets/ner/__init__.py
================================================
================================================
FILE: stanza/tests/datasets/ner/test_prepare_ner_file.py
================================================
"""
Test some simple conversions of NER bio files
"""
import pytest
import json
from stanza.models.common.doc import Document
from stanza.utils.datasets.ner.prepare_ner_file import process_dataset
# One BIO sentence of 5 tokens with a two-token PERSON entity at the start
BIO_1 = """
Jennifer B-PERSON
Sh'reyan I-PERSON
has O
lovely O
antennae O
""".strip()
# One BIO sentence of 12 tokens whose last token is itself an entity (LOCATION)
BIO_2 = """
but O
I O
don't O
like O
the O
way O
Jennifer B-PERSON
treated O
Beckett B-PERSON
on O
the O
Cerritos B-LOCATION
""".strip()
def check_json_file(doc, raw_text, expected_sentences, expected_tokens):
    """
    Compare a converted Document against the BIO text it was built from.

    doc: the Document produced from raw_text
    raw_text: BIO text, sentences separated by blank lines, one "word tag" pair per line
    expected_sentences: number of sentences expected in both raw_text and doc
    expected_tokens: either a single int, which applies to every sentence,
      or a list with one token count per sentence
    """
    raw_sentences = raw_text.strip().split("\n\n")
    assert len(raw_sentences) == expected_sentences
    if isinstance(expected_tokens, int):
        # an int is the expected length of EVERY sentence, not just the first one
        expected_tokens = [expected_tokens] * expected_sentences
    # zip() below would silently skip sentences if the list were too short
    assert len(expected_tokens) == expected_sentences
    for raw_sentence, expected_len in zip(raw_sentences, expected_tokens):
        assert len(raw_sentence.strip().split("\n")) == expected_len
    assert len(doc.sentences) == expected_sentences
    for sentence, expected_len in zip(doc.sentences, expected_tokens):
        assert len(sentence.tokens) == expected_len
    for sentence, raw_sentence in zip(doc.sentences, raw_sentences):
        for token, line in zip(sentence.tokens, raw_sentence.strip().split("\n")):
            word, tag = line.strip().split()
            assert token.text == word
            assert token.ner == tag
def write_and_convert(tmp_path, raw_text):
    """
    Write raw_text to a BIO file in tmp_path, convert it with process_dataset, and load the result.

    Returns the Document built from the converted JSON.  The same file names inside tmp_path
    are reused on every call.
    """
    bio_file = tmp_path / "test.bio"
    with open(bio_file, "w", encoding="utf-8") as fout:
        fout.write(raw_text)
    json_file = tmp_path / "json.bio"
    process_dataset(bio_file, json_file)
    # read with an explicit encoding, matching the utf-8 write above, rather than the platform default
    with open(json_file, encoding="utf-8") as fin:
        doc = Document(json.load(fin))
    return doc
def run_test(tmp_path, raw_text, expected_sentences, expected_tokens):
    """Round-trip raw_text through the BIO -> JSON conversion and validate the resulting Document"""
    check_json_file(write_and_convert(tmp_path, raw_text), raw_text, expected_sentences, expected_tokens)
def test_simple(tmp_path):
    """A single 5-token sentence converts cleanly"""
    run_test(tmp_path, BIO_1, expected_sentences=1, expected_tokens=5)
def test_ner_at_end(tmp_path):
    """A sentence whose last token is an entity converts cleanly"""
    run_test(tmp_path, BIO_2, expected_sentences=1, expected_tokens=12)
def test_two_sentences(tmp_path):
    """Two sentences separated by a blank line stay two sentences with the right lengths"""
    combined = "\n\n".join((BIO_1, BIO_2))
    run_test(tmp_path, combined, expected_sentences=2, expected_tokens=[5, 12])
================================================
FILE: stanza/tests/datasets/ner/test_utils.py
================================================
"""
Test the utils file of the NER dataset processing
"""
import pytest
from stanza.utils.datasets.ner.utils import list_doc_entities
from stanza.tests.datasets.ner.test_prepare_ner_file import BIO_1, BIO_2, write_and_convert
def test_list_doc_entities(tmp_path):
    """
    list_doc_entities should return every (words, type) entity in document order, repeats included
    """
    jennifer_shreyan = (('Jennifer', "Sh'reyan"), 'PERSON')
    bio_2_entities = [(('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]
    cases = [
        (BIO_1, [jennifer_shreyan]),
        (BIO_2, bio_2_entities),
        ("\n\n".join([BIO_1, BIO_2]), [jennifer_shreyan] + bio_2_entities),
        ("\n\n".join([BIO_1, BIO_1, BIO_2]), [jennifer_shreyan, jennifer_shreyan] + bio_2_entities),
    ]
    for raw_text, expected in cases:
        doc = write_and_convert(tmp_path, raw_text)
        assert expected == list_doc_entities(doc)
================================================
FILE: stanza/tests/datasets/test_common.py
================================================
"""
Test conllu manipulating routines in stanza/utils/datasets/common.py
"""
import pytest
from stanza.utils.datasets.common import maybe_add_fake_dependencies
# from stanza.tests import *
# every test in this module carries the travis and pipeline marks
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
# Each constant below is a list of CoNLL-U lines.
# A sentence with complete dependencies; maybe_add_fake_dependencies should return it unchanged
DEPS_EXAMPLE="""
# text = Sh'reyan's antennae are hella thicc
1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 3 nmod:poss 3:nmod:poss SpaceAfter=No
2 's 's PART POS _ 1 case 1:case _
3 antennae antenna NOUN NNS Number=Plur 6 nsubj 6:nsubj _
4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
5 hella hella ADV RB _ 6 advmod 6:advmod _
6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _
""".strip().split("\n")
# The same sentence where only word 6 (the root) has a head; every other head / deprel is "_"
ONLY_ROOT_EXAMPLE="""
# text = Sh'reyan's antennae are hella thicc
1 Sh'reyan Sh'reyan PROPN NNP Number=Sing _ _ _ SpaceAfter=No
2 's 's PART POS _ _ _ _ _
3 antennae antenna NOUN NNS Number=Plur _ _ _ _
4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin _ _ _ _
5 hella hella ADV RB _ _ _ _ _
6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _
""".strip().split("\n")
# Expected for ONLY_ROOT_EXAMPLE: word 1 attaches to the root word with "dep", the other words attach to word 1
ONLY_ROOT_EXPECTED="""
# text = Sh'reyan's antennae are hella thicc
1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 6 dep _ SpaceAfter=No
2 's 's PART POS _ 1 dep _ _
3 antennae antenna NOUN NNS Number=Plur 1 dep _ _
4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 1 dep _ _
5 hella hella ADV RB _ 1 dep _ _
6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _
""".strip().split("\n")
# The same sentence with no heads or deprels at all
NO_DEPS_EXAMPLE="""
# text = Sh'reyan's antennae are hella thicc
1 Sh'reyan Sh'reyan PROPN NNP Number=Sing _ _ _ SpaceAfter=No
2 's 's PART POS _ _ _ _ _
3 antennae antenna NOUN NNS Number=Plur _ _ _ _
4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin _ _ _ _
5 hella hella ADV RB _ _ _ _ _
6 thicc thicc ADJ JJ Degree=Pos _ _ _ _
""".strip().split("\n")
# Expected for NO_DEPS_EXAMPLE: word 1 becomes the root and every other word attaches to it with "dep"
NO_DEPS_EXPECTED="""
# text = Sh'reyan's antennae are hella thicc
1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 0 root _ SpaceAfter=No
2 's 's PART POS _ 1 dep _ _
3 antennae antenna NOUN NNS Number=Plur 1 dep _ _
4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 1 dep _ _
5 hella hella ADV RB _ 1 dep _ _
6 thicc thicc ADJ JJ Degree=Pos 1 dep _ _
""".strip().split("\n")
def test_fake_deps_no_change():
    """A sentence which already has a full dependency tree comes back unchanged"""
    assert maybe_add_fake_dependencies(DEPS_EXAMPLE) == DEPS_EXAMPLE
def test_fake_deps_all_tokens():
    """With no heads at all, word 1 becomes root and every other word gets a fake 'dep' arc to it"""
    assert maybe_add_fake_dependencies(NO_DEPS_EXAMPLE) == NO_DEPS_EXPECTED
def test_fake_deps_only_root():
    """With only the root annotated, the existing root is kept and the other words get fake 'dep' arcs"""
    assert maybe_add_fake_dependencies(ONLY_ROOT_EXAMPLE) == ONLY_ROOT_EXPECTED
================================================
FILE: stanza/tests/datasets/test_vietnamese_renormalization.py
================================================
import pytest
import os
from stanza.utils.datasets.vietnamese import renormalize
# every test in this module carries the travis and pipeline marks
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def test_replace_all():
    """replace_all moves misplaced tone marks onto the normalized vowel, eg 'tụy' -> 'tuỵ'"""
    source, normalized = "SỌAmple tụy test file", "SOẠmple tuỵ test file"
    assert renormalize.replace_all(source) == normalized
def test_replace_file(tmp_path):
    """convert_file applies the same renormalization to every line of a file"""
    line = "SỌAmple tụy test file"
    normalized = "SOẠmple tuỵ test file"
    source_path = tmp_path / "orig.txt"
    target_path = tmp_path / "converted.txt"
    with open(source_path, "w", encoding="utf-8") as fout:
        fout.write((line + "\n") * 10)
    renormalize.convert_file(source_path, target_path)
    assert os.path.exists(target_path)
    with open(target_path, encoding="utf-8") as fin:
        converted_lines = fin.readlines()
    assert len(converted_lines) == 10
    for converted in converted_lines:
        assert converted.strip() == normalized
================================================
FILE: stanza/tests/depparse/__init__.py
================================================
================================================
FILE: stanza/tests/depparse/test_depparse_data.py
================================================
"""
Test some pieces of the depparse dataloader
"""
import pytest
from stanza.models import parser
from stanza.models.depparse.data import data_to_batches, DataLoader
from stanza.utils.conll import CoNLL
# every test in this module carries the travis and pipeline marks
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def make_fake_data(*lengths):
    """
    Build fake sentences for the batching tests.

    Sentence i is [[letter] * lengths[i]], where letter is 'A', 'B', 'C', ... in order,
    so each sentence is a one-element list holding its word list.
    """
    return [[[chr(ord('A') + idx)] * size] for idx, size in enumerate(lengths)]
def check_batches(batched_data, expected_sizes, expected_order):
    """
    Check the layout of batches produced by data_to_batches.

    batched_data: list of chunks; each chunk is a list of sentences shaped like make_fake_data's output
    expected_sizes: total number of words expected in each chunk, one entry per chunk
    expected_order: first word of every sentence, across all chunks, in order
    """
    # zip() below would silently ignore extra or missing chunks, so check the chunk count explicitly
    assert len(batched_data) == len(expected_sizes)
    for chunk, size in zip(batched_data, expected_sizes):
        assert sum(len(x[0]) for x in chunk) == size
    word_order = [sentence[0][0] for chunk in batched_data for sentence in chunk]
    assert word_order == expected_order
def test_data_to_batches_eval_mode():
    """
    Tests the chunking of batches in eval_mode with batch_size 5

    Covers sorting vs. not sorting during eval, and batching long sentences separately
    """
    # (sentence lengths, sort_during_eval, min_length_to_batch_separately, expected chunk sizes, expected order)
    cases = [
        ((1, 2, 3), True, None, [5, 1], ['C', 'B', 'A']),
        ((1, 2, 6), True, None, [6, 3], ['C', 'B', 'A']),
        ((3, 2, 1), True, None, [5, 1], ['A', 'B', 'C']),
        ((3, 5, 2), True, None, [5, 5], ['B', 'A', 'C']),
        ((3, 5, 2), False, 3, [3, 5, 2], ['A', 'B', 'C']),
        ((4, 1, 1), False, 3, [4, 2], ['A', 'B', 'C']),
        ((1, 4, 1), False, 3, [1, 4, 1], ['A', 'B', 'C']),
    ]
    for lengths, sort_during_eval, min_separate, sizes, order in cases:
        data = make_fake_data(*lengths)
        batched_data = data_to_batches(data, batch_size=5, eval_mode=True,
                                       sort_during_eval=sort_during_eval,
                                       min_length_to_batch_separately=min_separate)
        check_batches(batched_data[0], sizes, order)
# Two EWT sentences whose final punctuation tokens are '!!!!!' and '?????' (lemmas '!' and '?'),
# even though their "# text" lines end with '.'; used by test_punct_simplification
EWT_PUNCT_SAMPLE = """
# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0048
# text = Bush asked for permission to go to Alabama to work on a Senate campaign.
1 Bush Bush PROPN NNP Number=Sing 2 nsubj 2:nsubj _
2 asked ask VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _
3 for for ADP IN _ 4 case 4:case _
4 permission permission NOUN NN Number=Sing 2 obl 2:obl:for _
5 to to PART TO _ 6 mark 6:mark _
6 go go VERB VB VerbForm=Inf 4 acl 4:acl:to _
7 to to ADP IN _ 8 case 8:case _
8 Alabama Alabama PROPN NNP Number=Sing 6 obl 6:obl:to _
9 to to PART TO _ 10 mark 10:mark _
10 work work VERB VB VerbForm=Inf 6 advcl 6:advcl:to _
11 on on ADP IN _ 14 case 14:case _
12 a a DET DT Definite=Ind|PronType=Art 14 det 14:det _
13 Senate Senate PROPN NNP Number=Sing 14 compound 14:compound _
14 campaign campaign NOUN NN Number=Sing 10 obl 10:obl:on SpaceAfter=No
15 !!!!! ! PUNCT . _ 2 punct 2:punct _
# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0049
# text = His superior officers said OK.
1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nmod:poss 3:nmod:poss _
2 superior superior ADJ JJ Degree=Pos 3 amod 3:amod _
3 officers officer NOUN NNS Number=Plur 4 nsubj 4:nsubj _
4 said say VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _
5 OK ok INTJ UH _ 4 obj 4:obj SpaceAfter=No
6 ????? ? PUNCT . _ 4 punct 4:punct _
"""
def test_punct_simplification():
    """
    Unusual question / exclamation punctuation such as '?????' and '!!!!!'
    should come out of the DataLoader as plain '?' and '!'
    """
    doc = CoNLL.conll2doc(input_str=EWT_PUNCT_SAMPLE)
    args = parser.parse_args(args=["--batch_size", "1000", "--shorthand", "en_test"])
    loader = DataLoader(doc, 5000, args, None)
    all_batches = list(loader)
    expected_words = [
        ['Bush', 'asked', 'for', 'permission', 'to', 'go', 'to', 'Alabama', 'to', 'work', 'on', 'a', 'Senate', 'campaign', '!'],
        ['His', 'superior', 'officers', 'said', 'OK', '?'],
    ]
    # per this assertion, the last element of the first batch holds the (simplified) word text
    assert all_batches[0][-1] == expected_words
if __name__ == '__main__':
    # run the batching test directly when this file is executed as a script
    test_data_to_batches_eval_mode()
================================================
FILE: stanza/tests/depparse/test_parser.py
================================================
"""
Run the dependency parser training for a couple iterations on some fake data
Uses a couple sentences of UD_English-EWT as training/dev data
"""
import os
import pytest
import zipfile
import torch
from stanza.models import parser
from stanza.models.common import pretrain
from stanza.models.depparse.trainer import Trainer
from stanza.tests import TEST_WORKING_DIR
# every test in this module carries the pipeline and travis marks
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# Two UD_English-EWT sentences in CoNLL-U format, used as the training set
TRAIN_DATA = """
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
2 : : PUNCT : _ 1 punct 1:punct _
3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
6 that that SCONJ IN _ 9 mark 9:mark _
7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
10 up up ADP RP _ 9 compound:prt 9:compound:prt _
11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
15 in in ADP IN _ 16 case 16:case _
16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
17 . . PUNCT . _ 1 punct 1:punct _
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
2 of of ADP IN _ 3 case 3:case _
3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
7 by by ADP IN _ 9 case 9:case _
8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
10 of of ADP IN _ 12 case 12:case _
11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
13 of of ADP IN _ 15 case 15:case _
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
16 ! ! PUNCT . _ 6 punct 6:punct _
""".lstrip()
# One short CoNLL-U sentence (without comment lines), used as the dev set
DEV_DATA = """
1 From from ADP IN _ 3 case 3:case _
2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _
3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _
4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _
6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _
7 : : PUNCT : _ 4 punct 4:punct _
""".lstrip()
class TestParser:
    """
    Short end-to-end training runs of the dependency parser on TRAIN_DATA / DEV_DATA
    """
    @pytest.fixture(scope="class")
    def wordvec_pretrain_file(self):
        """Path to the tiny pretrained word vector file in the test working directory"""
        return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'

    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None, zip_train_data=False):
        """
        Run the training for a few iterations, check the saved model loads, and return the trainer

        tmp_path: directory for the train / dev / prediction / model files
        train_text, dev_text: CoNLL-U text for the training and dev sets
        augment_nopunct: if False, --augment_nopunct 0.0 is passed
        extra_args: extra command line arguments appended after the defaults
        zip_train_data: if True, train_text is written as train.conllu inside train.zip
        Returns the trainer returned by parser.main
        """
        train_file = str(tmp_path / "train.zip") if zip_train_data else str(tmp_path / "train.conllu")
        dev_file = str(tmp_path / "dev.conllu")
        pred_file = str(tmp_path / "pred.conllu")
        save_name = "test_parser.pt"
        save_file = str(tmp_path / save_name)
        if zip_train_data:
            with zipfile.ZipFile(train_file, "w") as zout:
                with zout.open('train.conllu', 'w') as fout:
                    fout.write(train_text.encode())
        else:
            with open(train_file, "w", encoding="utf-8") as fout:
                fout.write(train_text)
        with open(dev_file, "w", encoding="utf-8") as fout:
            fout.write(dev_text)
        args = ["--wordvec_pretrain_file", wordvec_pretrain_file,
                "--train_file", train_file,
                "--eval_file", dev_file,
                "--output_file", pred_file,
                "--log_step", "10",
                "--eval_interval", "20",
                "--max_steps", "100",
                "--shorthand", "en_test",
                "--save_dir", str(tmp_path),
                "--save_name", save_name,
                # in case we are doing a bert test
                "--bert_start_finetuning", "10",
                "--bert_warmup_steps", "10",
                "--lang", "en"]
        if not augment_nopunct:
            args.extend(["--augment_nopunct", "0.0"])
        if extra_args is not None:
            args = args + extra_args
        trainer, _ = parser.main(args)
        assert os.path.exists(save_file)
        pt = pretrain.Pretrain(wordvec_pretrain_file)
        # test loading the saved model: constructing the Trainer is the check, the object is not needed
        Trainer(pretrain=pt, model_file=save_file)
        return trainer

    def _check_optimizer_checkpoint(self, trainer, wordvec_pretrain_file, expected_optim):
        """
        Check that training left a model and a checkpoint file, that the trainer holds exactly one
        optimizer of type expected_optim, and that reloading the checkpoint restores the same
        """
        save_dir = trainer.args['save_dir']
        save_name = trainer.args['save_name']
        checkpoint_name = trainer.args["checkpoint_save_name"]
        assert os.path.exists(os.path.join(save_dir, save_name))
        assert checkpoint_name is not None
        assert os.path.exists(checkpoint_name)
        assert len(trainer.optimizer) == 1
        for opt in trainer.optimizer.values():
            assert isinstance(opt, expected_optim)
        pt = pretrain.Pretrain(wordvec_pretrain_file)
        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
        assert checkpoint.optimizer is not None
        assert len(checkpoint.optimizer) == 1
        # check the optimizer of the reloaded checkpoint, not the in-memory trainer
        for opt in checkpoint.optimizer.values():
            assert isinstance(opt, expected_optim)

    def test_train(self, tmp_path, wordvec_pretrain_file):
        """
        Simple test of a few 'epochs' of parser training
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)

    def test_arc_embedding(self, tmp_path, wordvec_pretrain_file):
        """
        Simple test with the arc embedding turned on
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--use_arc_embedding'])

    def test_no_arc_embedding(self, tmp_path, wordvec_pretrain_file):
        """
        Simple test with the arc embedding turned off
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--no_use_arc_embedding'])

    def test_zipfile_train(self, tmp_path, wordvec_pretrain_file):
        """
        Simple test of a few 'epochs' of parser training with the training data inside a zipfile
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, zip_train_data=True)

    def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
        """
        Train with a tiny bert model, using 2 of its hidden layers
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])

    def test_with_bert_finetuning(self, tmp_path, wordvec_pretrain_file):
        """
        Finetuning bert should create a separate bert optimizer and scheduler
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
        assert 'bert_optimizer' in trainer.optimizer.keys()
        assert 'bert_scheduler' in trainer.scheduler.keys()

    def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
        """
        Check that if we save, then load, then save a model with a finetuned bert, that bert isn't lost
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
        assert 'bert_optimizer' in trainer.optimizer.keys()
        assert 'bert_scheduler' in trainer.scheduler.keys()
        save_name = trainer.args['save_name']
        filename = tmp_path / save_name
        assert os.path.exists(filename)
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
        # Test loading the saved model, saving it, and still having bert in it
        # even if we have set bert_finetune to False for this incarnation
        pt = pretrain.Pretrain(wordvec_pretrain_file)
        args = {"bert_finetune": False}
        saved_model = Trainer(pretrain=pt, model_file=filename, args=args)
        saved_model.save(filename)
        # This is the part that would fail if the force_bert_saved option did not exist
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())

    def test_with_peft(self, tmp_path, wordvec_pretrain_file):
        """
        Finetuning bert with --use_peft should also create a bert optimizer and scheduler
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2', '--use_peft'])
        assert 'bert_optimizer' in trainer.optimizer.keys()
        assert 'bert_scheduler' in trainer.scheduler.keys()

    def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file):
        """
        With only --optim adam, both the trainer and the reloaded checkpoint use Adam
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam'])
        self._check_optimizer_checkpoint(trainer, wordvec_pretrain_file, torch.optim.Adam)

    def test_two_optimizers_checkpoint(self, tmp_path, wordvec_pretrain_file):
        """
        With --second_optim sgd starting at step 40 of 100, training ends on SGD,
        and the reloaded checkpoint should use SGD as well
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam', '--second_optim', 'sgd', '--second_optim_start_step', '40'])
        self._check_optimizer_checkpoint(trainer, wordvec_pretrain_file, torch.optim.SGD)
================================================
FILE: stanza/tests/langid/__init__.py
================================================
================================================
FILE: stanza/tests/langid/test_langid.py
================================================
"""
Basic tests of langid module
"""
import pytest
from stanza.models.common.doc import Document
from stanza.pipeline.core import Pipeline
from stanza.pipeline.langid_processor import LangIDProcessor
from stanza.tests import TEST_MODELS_DIR
# every test in this module carries the pipeline and travis marks
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# uncomment to skip every test in this module (replaces the marks above)
#pytestmark = pytest.mark.skip
def _langid_pipeline(**extra_kwargs):
    """Build a multilingual langid-only Pipeline from TEST_MODELS_DIR, forwarding any extra options"""
    return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", **extra_kwargs)

@pytest.fixture(scope="module")
def basic_multilingual():
    """langid pipeline with no language subset restriction"""
    return _langid_pipeline()

@pytest.fixture(scope="module")
def enfr_multilingual():
    """langid pipeline restricted to English and French"""
    return _langid_pipeline(langid_lang_subset=["en", "fr"])

@pytest.fixture(scope="module")
def en_multilingual():
    """langid pipeline restricted to English only"""
    return _langid_pipeline(langid_lang_subset=["en"])

@pytest.fixture(scope="module")
def clean_multilingual():
    """langid pipeline with langid_clean_text enabled"""
    return _langid_pipeline(langid_clean_text=True)
def test_langid(basic_multilingual):
    """
    Basic test of language identification: one English and one French sentence
    """
    samples = {
        "en": "This is an English sentence.",
        "fr": "C'est une phrase française.",
    }
    docs = [Document([], text=sentence) for sentence in samples.values()]
    # the pipeline sets doc.lang on each document it is given
    basic_multilingual(docs)
    assert [doc.lang for doc in docs] == list(samples)
def test_langid_benchmark(basic_multilingual):
    """
    Run the lang id model on 500 labeled examples and confirm accuracy >= 0.98.

    Each example is a short snippet with its expected ``doc.lang`` code. On
    failure, the assertion message lists every misclassified example as
    (text, gold label, predicted label), so a regression can be diagnosed from
    the test output alone.
    """
    examples = [
        {"text": "contingentiam in naturalibus causis.", "label": "la"},
        {"text": "I jak opowiadał nieżyjący już pan Czesław", "label": "pl"},
        {"text": "Sonera gilt seit längerem als Übernahmekandidat", "label": "de"},
        {"text": "与银类似,汞也可以与空气中的硫化氢反应。", "label": "zh-hans"},
        {"text": "contradictionem implicat.", "label": "la"},
        {"text": "Bis zu Prozent gingen die Offerten etwa im", "label": "de"},
        {"text": "inneren Sicherheit vorgeschlagene Ausweitung der", "label": "de"},
        {"text": "Multimedia-PDA mit Mini-Tastatur", "label": "de"},
        {"text": "Ponášalo sa to na rovnicu o dvoch neznámych.", "label": "sk"},
        {"text": "이처럼 앞으로 심판의 그 날에 다시 올 메시아가 예수 그리스도이며 , 그는 모든 인류의", "label": "ko"},
        {"text": "Die Arbeitsgruppe bedauert , dass der weit über", "label": "de"},
        {"text": "И только раз довелось поговорить с ним не вполне", "label": "ru"},
        {"text": "de a-l lovi cu piciorul și conștiința că era", "label": "ro"},
        {"text": "relación coas pretensións do demandante e que, nos", "label": "gl"},
        {"text": "med petdeset in sedemdeset", "label": "sl"},
        {"text": "Catalunya; el Consell Comarcal del Vallès Oriental", "label": "ca"},
        {"text": "kunnen worden.", "label": "nl"},
        {"text": "Witkin je ve většině ohledů zcela jiný.", "label": "cs"},
        {"text": "lernen, so zu agieren, dass sie positive oder auch", "label": "de"},
        {"text": "olurmuş...", "label": "tr"},
        {"text": "sarcasmo de Altman, desde as «peruas» que discutem", "label": "pt"},
        {"text": "خلاف فوجداری مقدمہ درج کرے۔", "label": "ur"},
        {"text": "Norddal kommune :", "label": "no"},
        {"text": "dem Windows-.-Zeitalter , soll in diesem Jahr", "label": "de"},
        {"text": "przeklętych ucieleśniają mit poety-cygana,", "label": "pl"},
        {"text": "We do not believe the suspect has ties to this", "label": "en"},
        {"text": "groziņu pīšanu.", "label": "lv"},
        {"text": "Senior Vice-President David M. Thomas möchte", "label": "de"},
        {"text": "neomylně vybral nějakou knihu a začetl se.", "label": "cs"},
        {"text": "Statt dessen darf beispielsweise der Browser des", "label": "de"},
        {"text": "outubro, alcançando R $ bilhões em .", "label": "pt"},
        {"text": "(Porte, ), as it does other disciplines", "label": "en"},
        {"text": "uskupení se mylně domnívaly, že podporu", "label": "cs"},
        {"text": "Übernahme von Next Ende an dem System herum , das", "label": "de"},
        {"text": "No podemos decir a la Hacienda que los alemanes", "label": "es"},
        {"text": "и рѣста еи братья", "label": "orv"},
        {"text": "الذي اتخذ قرارا بتجميد اعلان الدولة الفلسطينية", "label": "ar"},
        {"text": "uurides Rootsi sõjaarhiivist toodud . sajandi", "label": "et"},
        {"text": "selskapets penger til å pusse opp sin enebolig på", "label": "no"},
        {"text": "средней полосе и севернее в Ярославской,", "label": "ru"},
        {"text": "il-massa żejda fil-ġemgħat u superġemgħat ta'", "label": "mt"},
        {"text": "The Global Beauties on internetilehekülg, mida", "label": "et"},
        {"text": "이스라엘 인들은 하나님이 그 큰 팔을 펴 이집트 인들을 치는 것을 보고 하나님을 두려워하며", "label": "ko"},
        {"text": "Snad ještě dodejme jeden ekonomický argument.", "label": "cs"},
        {"text": "Spalio d. vykusiame pirmajame rinkimų ture", "label": "lt"},
        {"text": "und schlechter Journalismus ein gutes Geschäft .", "label": "de"},
        {"text": "Du sodiečiai sėdi ant potvynio apsemtų namų stogo.", "label": "lt"},
        {"text": "цей є автентичним.", "label": "uk"},
        {"text": "Și îndegrabă fu cu îngerul mulțime de șireaguri", "label": "ro"},
        {"text": "sobra personal cualificado.", "label": "es"},
        {"text": "Tako se u Njemačkoj dvije trećine liječnika služe", "label": "hr"},
        {"text": "Dual-Athlon-Chipsatz noch in diesem Jahr", "label": "de"},
        {"text": "यहां तक कि चीन के चीफ ऑफ जनरल स्टाफ भी भारत का", "label": "hi"},
        {"text": "Li forestier du mont avale", "label": "fro"},
        {"text": "Netzwerken für Privatanwender zu bewundern .", "label": "de"},
        {"text": "만해는 승적을 가진 중이 결혼할 수 없다는 불교의 계율을 시대에 맞지 않는 것으로 보았다", "label": "ko"},
        {"text": "balance and weight distribution but not really for", "label": "en"},
        {"text": "og så e # tente vi opp den om morgonen å sfyrte", "label": "nn"},
        {"text": "변화는 의심의 여지가 없는 것이지만 반면에 진화는 논쟁의 씨앗이다 .", "label": "ko"},
        {"text": "puteare fac aceastea.", "label": "ro"},
        {"text": "Waitt seine Führungsmannschaft nicht dem", "label": "de"},
        {"text": "juhtimisega, tulid sealt.", "label": "et"},
        {"text": "Veränderungen .", "label": "de"},
        {"text": "banda en el Bayer Leverkusen de la Bundesliga de", "label": "es"},
        {"text": "В туже зиму посла всеволодъ сн҃а своѥго ст҃ослава", "label": "orv"},
        {"text": "пославъ приведе я мастеры ѿ грекъ", "label": "orv"},
        {"text": "En un nou escenari difícil d'imaginar fa poques", "label": "ca"},
        {"text": "καὶ γὰρ τινὲς αὐτοὺς εὐεργεσίαι εἶχον ἐκ Κροίσου", "label": "grc"},
        {"text": "직접적인 관련이 있다 .", "label": "ko"},
        {"text": "가까운 듯하면서도 멀다 .", "label": "ko"},
        {"text": "Er bietet ein ähnliches Leistungsniveau und", "label": "de"},
        {"text": "民都洛水牛是獨居的,並不會以群族聚居。", "label": "zh-hant"},
        {"text": "την τρομοκρατία.", "label": "el"},
        {"text": "hurbiltzen diren neurrian.", "label": "eu"},
        {"text": "Ah dimenticavo, ma tutta sta caciara per fare un", "label": "it"},
        {"text": "На первом этапе (-) прошла так называемая", "label": "ru"},
        {"text": "of games are on the market.", "label": "en"},
        {"text": "находится Мост дружбы, соединяющий узбекский и", "label": "ru"},
        {"text": "lessié je voldroie que li saint fussent aporté", "label": "fro"},
        {"text": "Дошла очередь и до Гималаев.", "label": "ru"},
        {"text": "vzácným suknem táhly pouští, si jednou chtěl do", "label": "cs"},
        {"text": "E no terceiro tipo sitúa a familias (%), nos que a", "label": "gl"},
        {"text": "وجابت دوريات امريكية وعراقية شوارع المدينة، فيما", "label": "ar"},
        {"text": "Jeg har bodd her i år .", "label": "no"},
        {"text": "Pohrozil, že odbory zostří postoj, pokud se", "label": "cs"},
        {"text": "tinham conseguido.", "label": "pt"},
        {"text": "Nicht-Erkrankten einen Anfangsverdacht für einen", "label": "de"},
        {"text": "permanece em aberto.", "label": "pt"},
        {"text": "questi possono promettere rendimenti fino a un", "label": "it"},
        {"text": "Tema juurutatud kahevedurisüsteemita oleksid", "label": "et"},
        {"text": "Поведение внешне простой игрушки оказалось", "label": "ru"},
        {"text": "Bundesländern war vom Börsenverein des Deutschen", "label": "de"},
        {"text": "acció, 'a mesura que avanci l'estiu, amb l'augment", "label": "ca"},
        {"text": "Dove trovare queste risorse? Jay Naidoo, ministro", "label": "it"},
        {"text": "essas gordurinhas.", "label": "pt"},
        {"text": "Im zweiten Schritt sollen im übernächsten Jahr", "label": "de"},
        {"text": "allveelaeva pole enam vaja, kuna külm sõda on läbi", "label": "et"},
        {"text": "उपद्रवी दुकानों को लूटने के साथ ही उनमें आग लगा", "label": "hi"},
        {"text": "@user nella sfortuna sei fortunata ..", "label": "it"},
        {"text": "математических школ в виде грозовых туч.", "label": "ru"},
        {"text": "No cambiaremos nunca nuestra forma de jugar por un", "label": "es"},
        {"text": "dla tej klasy ani wymogów minimalnych, z wyjątkiem", "label": "pl"},
        {"text": "en todo el mundo, mientras que en España consiguió", "label": "es"},
        {"text": "политики считать надежное обеспечение военной", "label": "ru"},
        {"text": "gogoratzen du, genio alemana delakoaren", "label": "eu"},
        {"text": "Бычий глаз.", "label": "ru"},
        {"text": "Opeření se v pravidelných obdobích obnovuje", "label": "cs"},
        {"text": "I no és només la seva, es tracta d'una resposta", "label": "ca"},
        {"text": "오경을 가르쳤다 .", "label": "ko"},
        {"text": "Nach der so genannten Start-up-Periode vergibt die", "label": "de"},
        {"text": "Saulista huomasi jo lapsena , että hänellä on", "label": "fi"},
        {"text": "Министерство культуры сочло нецелесообразным, и", "label": "ru"},
        {"text": "znepřátelené tábory v Tádžikistánu předseda", "label": "cs"},
        {"text": "καὶ ἦν ὁ λαὸς προσδοκῶν τὸν Ζαχαρίαν καὶ ἐθαύμαζον", "label": "grc"},
        {"text": "Вечером, в продукте, этот же человек говорил о", "label": "ru"},
        {"text": "lugar á formación de xuizos máis complexos.", "label": "gl"},
        {"text": "cheaper, in the end?", "label": "en"},
        {"text": "الوزارة في شأن صفقات بيع الشركات العامة التي تم", "label": "ar"},
        {"text": "tärkeintä elämässäni .", "label": "fi"},
        {"text": "Виконання Мінських угод було заблоковано Росією та", "label": "uk"},
        {"text": "Aby szybko rozpoznać żołnierzy desantu, należy", "label": "pl"},
        {"text": "Bankengeschäfte liegen vorn , sagte Strothmann .", "label": "de"},
        {"text": "продолжение работы.", "label": "ru"},
        {"text": "Metro AG plant Online-Offensive", "label": "de"},
        {"text": "nu vor veni, și să vor osîndi, aceia nu pot porni", "label": "ro"},
        {"text": "Ich denke , es geht in Wirklichkeit darum , NT bei", "label": "de"},
        {"text": "de turism care încasează contravaloarea", "label": "ro"},
        {"text": "Aurkaria itotzea da helburua, baloia lapurtu eta", "label": "eu"},
        {"text": "com a centre de formació en Tecnologies de la", "label": "ca"},
        {"text": "oportet igitur quod omne agens in agendo intendat", "label": "la"},
        {"text": "Jerzego Andrzejewskiego, oparty na chińskich", "label": "pl"},
        {"text": "sau một vài câu chuyện xã giao không dính dáng tới", "label": "vi"},
        {"text": "что экономическому прорыву жесткий авторитарный", "label": "ru"},
        {"text": "DRAM-Preisen scheinen DSPs ein", "label": "de"},
        {"text": "Jos dajan nubbái: Mana!", "label": "sme"},
        {"text": "toți carii ascultară de el să răsipiră.", "label": "ro"},
        {"text": "odpowiedzialności, które w systemie własności", "label": "pl"},
        {"text": "Dvomesečno potovanje do Mollenda v Peruju je", "label": "sl"},
        {"text": "d'entre les agències internacionals.", "label": "ca"},
        {"text": "Fahrzeugzugangssysteme gefertigt und an viele", "label": "de"},
        {"text": "in an answer to the sharers' petition in Cuthbert", "label": "en"},
        {"text": "Europa-Domain per Verordnung zu regeln .", "label": "de"},
        {"text": "#Balotelli. Su ebay prezzi stracciati per Silvio", "label": "it"},
        {"text": "Ne na košickém trávníku, ale už včera v letadle se", "label": "cs"},
        {"text": "zaměstnanosti a investičních strategií.", "label": "cs"},
        {"text": "Tatínku, udělej den", "label": "cs"},
        {"text": "frecuencia con Mary.", "label": "es"},
        {"text": "Свеаборге.", "label": "ru"},
        {"text": "opatření slovenské strany o certifikaci nejvíce", "label": "cs"},
        {"text": "En todas me decían: 'Espera que hagamos un estudio", "label": "es"},
        {"text": "Die Demonstration sollte nach Darstellung der", "label": "de"},
        {"text": "Ci vorrà un assoluto rigore se dietro i disavanzi", "label": "it"},
        {"text": "Tatínku, víš, že Honzovi odešla maminka?", "label": "cs"},
        {"text": "Die Anzahl der Rechner wuchs um % auf und die", "label": "de"},
        {"text": "האמריקאית על אדמת סעודיה עלולה לסבך את ישראל, אין", "label": "he"},
        {"text": "Volán Egyesülés, a Közlekedési Főfelügyelet is.", "label": "hu"},
        {"text": "Schejbala, který stejnou hru s velkým úspěchem", "label": "cs"},
        {"text": "depends on the data type of the field.", "label": "en"},
        {"text": "Umsatzwarnung zu Wochenbeginn zeitweise auf ein", "label": "de"},
        {"text": "niin heti nukun .", "label": "fi"},
        {"text": "Mobilfunkunternehmen gegen die Anwendung der so", "label": "de"},
        {"text": "sapessi le intenzioni del governo Monti e dell'UE", "label": "it"},
        {"text": "Di chi è figlia Martine Aubry?", "label": "it"},
        {"text": "avec le reste du monde.", "label": "fr"},
        {"text": "Այդ մաքոքը ինքնին նոր չէ, աշխարհը արդեն մի քանի", "label": "hy"},
        {"text": "și în cazul destrămării cenaclului.", "label": "ro"},
        {"text": "befriedigen kann , und ohne die auftretenden", "label": "de"},
        {"text": "Κύκνον τ̓ ἐξεναρεῖν καὶ ἀπὸ κλυτὰ τεύχεα δῦσαι.", "label": "grc"},
        {"text": "færdiguddannede.", "label": "da"},
        {"text": "Schmidt war Sohn eines Rittergutsbesitzers.", "label": "de"},
        {"text": "и вдаша попадь ѡпрати", "label": "orv"},
        {"text": "cine nu știe învățătură”.", "label": "ro"},
        {"text": "détacha et cette dernière tenta de tuer le jeune", "label": "fr"},
        {"text": "Der har saka også ei lengre forhistorie.", "label": "nn"},
        {"text": "Pieprz roztłuc w moździerzu, dodać do pasty,", "label": "pl"},
        {"text": "Лежа за гребнем оврага, как за бруствером, Ушаков", "label": "ru"},
        {"text": "gesucht habe, vielen Dank nochmals!", "label": "de"},
        {"text": "инструментальных сталей, повышения", "label": "ru"},
        {"text": "im Halbfinale Patrick Smith und im Finale dann", "label": "de"},
        {"text": "البنوك التريث في منح تسهيلات جديدة لمنتجي حديد", "label": "ar"},
        {"text": "una bolsa ventral, la cual se encuentra debajo de", "label": "es"},
        {"text": "za SETimes.", "label": "sr"},
        {"text": "de Irak, a un piloto italiano que había violado el", "label": "es"},
        {"text": "Er könne sich nicht erklären , wie die Zeitung auf", "label": "de"},
        {"text": "Прохорова.", "label": "ru"},
        {"text": "la democrazia perde sulla tecnocrazia? #", "label": "it"},
        {"text": "entre ambas instituciones, confirmó al medio que", "label": "es"},
        {"text": "Austlandet, vart det funne om lag førti", "label": "nn"},
        {"text": "уровнями власти.", "label": "ru"},
        {"text": "Dá tedy primáři úplatek, a často ne malý.", "label": "cs"},
        {"text": "brillantes del acto, al llevar a cabo en el", "label": "es"},
        {"text": "eee druga zadeva je majhen priročen gre kamorkoli", "label": "sl"},
        {"text": "Das ATX-Board paßt in herkömmliche PC-ATX-Gehäuse", "label": "de"},
        {"text": "Za vodné bylo v prvním pololetí zaplaceno v ČR", "label": "cs"},
        {"text": "Даже на полсантиметра.", "label": "ru"},
        {"text": "com la del primer tinent d'alcalde en funcions,", "label": "ca"},
        {"text": "кількох оповідань в цілості — щось на зразок того", "label": "uk"},
        {"text": "sed ad divitias congregandas, vel superfluum", "label": "la"},
        {"text": "Norma Talmadge, spela mot Valentino i en version", "label": "sv"},
        {"text": "Dlatego chciał się jej oświadczyć w niezwykłym", "label": "pl"},
        {"text": "будут выступать на одинаковых снарядах.", "label": "ru"},
        {"text": "Orang-orang terbunuh di sana.", "label": "id"},
        {"text": "لدى رايت شقيق اسمه أوسكار, وهو يعمل كرسام للكتب", "label": "ar"},
        {"text": "Wirklichkeit verlagerten und kaum noch", "label": "de"},
        {"text": "как перемешивают костяшки перед игрой в домино, и", "label": "ru"},
        {"text": "В средине дня, когда солнце светило в нашу", "label": "ru"},
        {"text": "d'aventure aux rôles de jeune romantique avec une", "label": "fr"},
        {"text": "My teď hledáme organizace, jež by s námi chtěly", "label": "cs"},
        {"text": "Urteilsfähigkeit einbüßen , wenn ich eigene", "label": "de"},
        {"text": "sua appartenenza anche a voci diverse da quella in", "label": "it"},
        {"text": "Aufträge dieses Jahr verdoppeln werden .", "label": "de"},
        {"text": "M.E.: Miała szanse mnie odnaleźć, gdyby naprawdę", "label": "pl"},
        {"text": "secundum contactum virtutis, cum careat dimensiva", "label": "la"},
        {"text": "ezinbestekoa dela esan zuen.", "label": "eu"},
        {"text": "Anek hurbiltzeko eskatzen zion besaulkitik, eta", "label": "eu"},
        {"text": "perfectius alio videat, quamvis uterque videat", "label": "la"},
        {"text": "Die Strecke war anspruchsvoll und führte unter", "label": "de"},
        {"text": "саморазоблачительным уроком, западные СМИ не", "label": "ru"},
        {"text": "han representerer radikal islamisme .", "label": "no"},
        {"text": "Què s'hi respira pel que fa a la reforma del", "label": "ca"},
        {"text": "previsto para também ser desconstruido.", "label": "pt"},
        {"text": "Ὠκεανοῦ βαθυκόλποις ἄνθεά τ̓ αἰνυμένην, ῥόδα καὶ", "label": "grc"},
        {"text": "para jovens de a anos nos Cieps.", "label": "pt"},
        {"text": "संघर्ष को अंजाम तक पहुंचाने का ऐलान किया है ।", "label": "hi"},
        {"text": "objeví i u nás.", "label": "cs"},
        {"text": "kvitteringer.", "label": "da"},
        {"text": "This report is no exception.", "label": "en"},
        {"text": "Разлепват доносниците до избирателните списъци", "label": "bg"},
        {"text": "anderem ihre Bewegungsfreiheit in den USA", "label": "de"},
        {"text": "Ñu tegoon ca kaw gor ña ay njotti bopp yu kenn", "label": "wo"},
        {"text": "Struktur kann beispielsweise der Schwerpunkt mehr", "label": "de"},
        {"text": "% la velocidad permitida, la sanción es muy grave.", "label": "es"},
        {"text": "Teles-Einstieg in ADSL-Markt", "label": "de"},
        {"text": "ettekäändeks liiga suure osamaksu.", "label": "et"},
        {"text": "als Indiz für die geänderte Marktpolitik des", "label": "de"},
        {"text": "quod quidem aperte consequitur ponentes", "label": "la"},
        {"text": "de negociación para el próximo de junio.", "label": "es"},
        {"text": "Tyto důmyslné dekorace doznaly v poslední době", "label": "cs"},
        {"text": "največjega uspeha doslej.", "label": "sl"},
        {"text": "Paul Allen je jedan od suosnivača Interval", "label": "hr"},
        {"text": "Federal (Seac / DF) eo Sindicato das Empresas de", "label": "pt"},
        {"text": "Quartal mit . Mark gegenüber dem gleichen Quartal", "label": "de"},
        {"text": "otros clubes y del Barça B saldrán varios", "label": "es"},
        {"text": "Jaskula (Pol.) -", "label": "cs"},
        {"text": "umožnily říci, že je možné přejít k mnohem", "label": "cs"},
        {"text": "اعلن الجنرال تومي فرانكس قائد القوات الامريكية", "label": "ar"},
        {"text": "Telekom-Chef Ron Sommer und der Vorstandssprecher", "label": "de"},
        {"text": "My, jako průmyslový a finanční holding, můžeme", "label": "cs"},
        {"text": "voorlichting onder andere betrekking kan hebben:", "label": "nl"},
        {"text": "Hinrichtung geistig Behinderter applaudiert oder", "label": "de"},
        {"text": "wie beispielsweise Anzahl erzielte Klicks ,", "label": "de"},
        {"text": "Intel-PC-SDRAM-Spezifikation in der Version . (", "label": "de"},
        {"text": "plângere în termen de zile de la comunicarea", "label": "ro"},
        {"text": "и Испания ще изгубят втория си комисар в ЕК.", "label": "bg"},
        {"text": "इसके चलते इस आदिवासी जनजाति का क्षरण हो रहा है ।", "label": "hi"},
        {"text": "aunque se mostró contrario a establecer un", "label": "es"},
        {"text": "des letzten Jahres von auf Millionen Euro .", "label": "de"},
        {"text": "Ankara se također poziva da u cijelosti ratificira", "label": "hr"},
        {"text": "herunterlädt .", "label": "de"},
        {"text": "стрессовую ситуацию для организма, каковой", "label": "ru"},
        {"text": "Státního shromáždění (parlamentu).", "label": "cs"},
        {"text": "diskutieren , ob und wie dieser Dienst weiterhin", "label": "de"},
        {"text": "Verbindungen zu FPÖ-nahen Polizisten gepflegt und", "label": "de"},
        {"text": "Pražského volebního lídra ovšem nevybírá Miloš", "label": "cs"},
        {"text": "Nach einem Bericht der Washington Post bleibt das", "label": "de"},
        {"text": "للوضع آنذاك، لكني في قرارة نفسي كنت سعيداً لما", "label": "ar"},
        {"text": "не желаят запазването на статуквото.", "label": "bg"},
        {"text": "Offenburg gewesen .", "label": "de"},
        {"text": "ἐὰν ὑμῖν εἴπω οὐ μὴ πιστεύσητε", "label": "grc"},
        {"text": "all'odiato compagno di squadra Prost, il quale", "label": "it"},
        {"text": "historischen Gänselieselbrunnens.", "label": "de"},
        {"text": "למידע מלווייני הריגול האמריקאיים העוקבים אחר", "label": "he"},
        {"text": "οὐδὲν ἄρα διαφέρεις Ἀμάσιος τοῦ Ἠλείου, ὃν", "label": "grc"},
        {"text": "movementos migratorios.", "label": "gl"},
        {"text": "Handy und ein Spracherkennungsprogramm sämtliche", "label": "de"},
        {"text": "Kümne aasta jooksul on Eestisse ohjeldamatult", "label": "et"},
        {"text": "H.G. Bücknera.", "label": "pl"},
        {"text": "protiv krijumčarenja, ili pak traženju ukidanja", "label": "hr"},
        {"text": "Topware-Anteile mehrere Millionen Mark gefordert", "label": "de"},
        {"text": "Maar de mensen die nu over Van Dijk bij FC Twente", "label": "nl"},
        {"text": "poidan experimentar as percepcións do interesado,", "label": "gl"},
        {"text": "Miał przecież w kieszeni nóż.", "label": "pl"},
        {"text": "Avšak žádná z nich nepronikla za hranice přímé", "label": "cs"},
        {"text": "esim. helpottamalla luottoja muiden", "label": "fi"},
        {"text": "Podle předběžných výsledků zvítězila v", "label": "cs"},
        {"text": "Nicht nur das Web-Frontend , auch die", "label": "de"},
        {"text": "Regierungsinstitutionen oder Universitäten bei", "label": "de"},
        {"text": "Խուլեն Լոպետեգիին, պատճառաբանելով, որ վերջինս", "label": "hy"},
        {"text": "Афганистана, где в последние дни идут ожесточенные", "label": "ru"},
        {"text": "лѧхове же не идоша", "label": "orv"},
        {"text": "Mit Hilfe von IBMs Chip-Management-Systemen sollen", "label": "de"},
        {"text": ", als Manager zu Telefonica zu wechseln .", "label": "de"},
        {"text": "którym zajmuje się człowiek, zmienia go i pozwala", "label": "pl"},
        {"text": "činí kyperských liber, to je asi USD.", "label": "cs"},
        {"text": "Studienplätze getauscht werden .", "label": "de"},
        {"text": "учёных, орнитологов признают вид.", "label": "ru"},
        {"text": "acordare a concediilor prevăzute de legislațiile", "label": "ro"},
        {"text": "at større innsats for fornybar, berekraftig energi", "label": "nn"},
        {"text": "Politiet veit ikkje kor mange personar som deltok", "label": "nn"},
        {"text": "offentligheten av unge , sinte menn som har", "label": "no"},
        {"text": "însuși în jurul lapunei, care încet DISPARE în", "label": "ro"},
        {"text": "O motivo da decisão é evitar uma sobrecarga ainda", "label": "pt"},
        {"text": "El Apostolado de la prensa contribuye en modo", "label": "es"},
        {"text": "Teltow ( Kreis Teltow-Fläming ) ist Schmitt einer", "label": "de"},
        {"text": "grozījumus un iesniegt tos Apvienoto Nāciju", "label": "lv"},
        {"text": "Gestalt einer deutschen Nationalmannschaft als", "label": "de"},
        {"text": "D überholt zu haben , konterte am heutigen Montag", "label": "de"},
        {"text": "Softwarehersteller Oracle hat im dritten Quartal", "label": "de"},
        {"text": "Během nich se ekonomické podmínky mohou radikálně", "label": "cs"},
        {"text": "Dziki kot w górach zeskakuje z kamienia.", "label": "pl"},
        {"text": "Ačkoliv ligový nováček prohrál, opět potvrdil, že", "label": "cs"},
        {"text": "des Tages , Portraits internationaler Stars sowie", "label": "de"},
        {"text": "Communicator bekannt wurde .", "label": "de"},
        {"text": "τῷ δ’ ἄρα καὶ αὐτῷ ἡ γυνή ἐπίτεξ ἐοῦσα πᾶσαν", "label": "grc"},
        {"text": "Triadú tenia, mentre redactava 'Dies de memòria',", "label": "ca"},
        {"text": "دستهجمعی در درخشندگی ماه سیمگون زمزمه ستاینده و", "label": "fa"},
        {"text": "Книгу, наполненную мелочной заботой об одежде,", "label": "ru"},
        {"text": "putares canem leporem persequi.", "label": "la"},
        {"text": "В дальнейшем эта яркость слегка померкла, но в", "label": "ru"},
        {"text": "offizielles Verfahren gegen die Telekom", "label": "de"},
        {"text": "podrían haber sido habitantes de la Península", "label": "es"},
        {"text": "Grundlage für dieses Verfahren sind spezielle", "label": "de"},
        {"text": "Rechtsausschuß vorgelegten Entwurf der Richtlinie", "label": "de"},
        {"text": "Im so genannten Portalgeschäft sei das Unternehmen", "label": "de"},
        {"text": "ⲏ ⲉⲓϣⲁⲛϥⲓ ⲛⲉⲓⲇⲱⲗⲟⲛ ⲉⲧϩⲙⲡⲉⲕⲏⲓ ⲙⲏ ⲉⲓⲛⲁϣϩⲱⲡ ⲟⲛ ⲙⲡⲣⲏ", "label": "cop"},
        {"text": "juego podían matar a cualquier herbívoro, pero", "label": "es"},
        {"text": "Nach Angaben von Axent nutzen Unternehmen aus der", "label": "de"},
        {"text": "hrdiny Havlovy Zahradní slavnosti (premiéra ) se", "label": "cs"},
        {"text": "Een zin van heb ik jou daar", "label": "nl"},
        {"text": "hat sein Hirn an der CeBIT-Kasse vergessen .", "label": "de"},
        {"text": "καὶ τοὺς ἐκπλαγέντας οὐκ ἔχειν ἔτι ἐλεγχομένους", "label": "grc"},
        {"text": "nachgewiesenen langfristigen Kosten , sowie den im", "label": "de"},
        {"text": "jučer nakon četiri dana putovanja u Helsinki.", "label": "hr"},
        {"text": "pašto paslaugos teikėjas gali susitarti su", "label": "lt"},
        {"text": "В результате, эти золотые кадры переходят из одной", "label": "ru"},
        {"text": "द फाइव-ईयर एंगेजमेंट में अभिनय किया जिसमें जैसन", "label": "hi"},
        {"text": "výpis o počtu akcií.", "label": "cs"},
        {"text": "Enfin, elles arrivent à un pavillon chinois", "label": "fr"},
        {"text": "Tentu saja, tren yang berhubungandengan", "label": "id"},
        {"text": "Arbeidarpartiet og SV har sikra seg fleirtal mot", "label": "nn"},
        {"text": "eles: 'Tudo isso está errado' , disse um", "label": "pt"},
        {"text": "The islands are in their own time zone, minutes", "label": "en"},
        {"text": "Auswahl debütierte er am .", "label": "de"},
        {"text": "Bu komisyonlar, arazilerini satın almak için", "label": "tr"},
        {"text": "Geschütze gegen Redmond aufgefahren .", "label": "de"},
        {"text": "Time scything the hours, but at the top, over the", "label": "en"},
        {"text": "Di musim semi , berharap mengadaptasi Tintin untuk", "label": "id"},
        {"text": "крупнейшей геополитической катастрофой XX века.", "label": "ru"},
        {"text": "Rajojen avaaminen ei suju ongelmitta .", "label": "fi"},
        {"text": "непроницаемым, как для СССР.", "label": "ru"},
        {"text": "Ma non mancano le polemiche.", "label": "it"},
        {"text": "Internet als Ort politischer Diskussion und auch", "label": "de"},
        {"text": "incomplets.", "label": "ca"},
        {"text": "Su padre luchó al lado de Luis Moya, primer Jefe", "label": "es"},
        {"text": "informazione.", "label": "it"},
        {"text": "Primacom bietet für Telekom-Kabelnetz", "label": "de"},
        {"text": "Oświadczenie prezydencji w imieniu Unii", "label": "pl"},
        {"text": "foran rattet i familiens gamle Baleno hvis døra på", "label": "no"},
        {"text": "[speaker:laughter]", "label": "sl"},
        {"text": "Dog med langt mindre utstyr med seg.", "label": "nn"},
        {"text": "dass es nicht schon mit der anfänglichen", "label": "de"},
        {"text": "इस पर दोनों पक्षों में नोकझोंक शुरू हो गई ।", "label": "hi"},
        {"text": "کے ترجمان منیش تیواری اور دگ وجئے سنگھ نے بھی یہ", "label": "ur"},
        {"text": "dell'Assemblea Costituente che posseggono i", "label": "it"},
        {"text": "и аште вьси съблазнѧтъ сѧ нъ не азъ", "label": "cu"},
        {"text": "In Irvine hat auch das Logistikunternehmen Atlas", "label": "de"},
        {"text": "законодательных норм, принимаемых существующей", "label": "ru"},
        {"text": "Κροίσῳ προτείνων τὰς χεῖρας ἐπικατασφάξαι μιν", "label": "grc"},
        {"text": "МИНУСЫ: ИНФЛЯЦИЯ И КРИЗИС В ЖИВОТНОВОДСТВЕ.", "label": "ru"},
        {"text": "unterschiedlicher Meinung .", "label": "de"},
        {"text": "Jospa joku ystävällinen sielu auttaisi kassieni", "label": "fi"},
        {"text": "Añadió que, en el futuro se harán otros", "label": "es"},
        {"text": "Sessiz tonlama hem Fince, hem de Kuzey Sami", "label": "tr"},
        {"text": "nicht ihnen gehört und sie nicht alles , was sie", "label": "de"},
        {"text": "Etelästä Kuivajärveen laskee Tammelan Liesjärvestä", "label": "fi"},
        {"text": "ICANNs Vorsitzender Vint Cerf warb mit dem Hinweis", "label": "de"},
        {"text": "Norsk politikk frå til kan dermed, i", "label": "nn"},
        {"text": "Głosowało posłów.", "label": "pl"},
        {"text": "Danny Jones -- smithjones@ev.net", "label": "en"},
        {"text": "sebeuvědomění moderní civilizace sehrála lučavka", "label": "cs"},
        {"text": "относительно спокойный сон: тому гарантия", "label": "ru"},
        {"text": "A halte voiz prist li pedra a crïer", "label": "fro"},
        {"text": "آنها امیدوارند این واکسن بهزودی در دسترس بیماران", "label": "fa"},
        {"text": "vlastní důstojnou vousatou tváří.", "label": "cs"},
        {"text": "ora aprire la strada a nuove cause e alimentare il", "label": "it"},
        {"text": "Die Zahl der Vielleser nahm von auf Prozent zu ,", "label": "de"},
        {"text": "Finanzvorstand von Hotline-Dienstleister InfoGenie", "label": "de"},
        {"text": "entwickeln .", "label": "de"},
        {"text": "incolumità pubblica.", "label": "it"},
        {"text": "lehtija televisiomainonta", "label": "fi"},
        {"text": "joistakin kohdista eri mieltä.", "label": "fi"},
        {"text": "Hlavně anglická nezávislá scéna, Dead Can Dance,", "label": "cs"},
        {"text": "pásmech od do bodů bodové stupnice.", "label": "cs"},
        {"text": "Zu Beginn des Ersten Weltkrieges zählte das", "label": "de"},
        {"text": "Així van sorgir, damunt els antics cementiris,", "label": "ca"},
        {"text": "In manchem Gedicht der spätern Alten, wie zum", "label": "de"},
        {"text": "gaweihaida jah insandida in þana fairƕu jus qiþiþ", "label": "got"},
        {"text": "Beides sollte gelöscht werden!", "label": "de"},
        {"text": "modifiqués la seva petició inicial de anys de", "label": "ca"},
        {"text": "В день открытия симпозиума состоялась закладка", "label": "ru"},
        {"text": "tõestatud.", "label": "et"},
        {"text": "ἵππῳ πίπτει αὐτοῦ ταύτῃ", "label": "grc"},
        {"text": "bisher nie enttäuscht!", "label": "de"},
        {"text": "De bohte ollu tuollárat ja suttolaččat ja", "label": "sme"},
        {"text": "Klarsignal från röstlängdsläsaren, tre tryck i", "label": "sv"},
        {"text": "Tvůrcem nového termínu je Joseph Fisher.", "label": "cs"},
        {"text": "Nie miałem czasu na reakcję twierdzi Norbert,", "label": "pl"},
        {"text": "potentia Schöpfer.", "label": "de"},
        {"text": "Un poquito caro, pero vale mucho la pena;", "label": "es"},
        {"text": "οὔ τε γὰρ ἴφθιμοι Λύκιοι Δαναῶν ἐδύναντο τεῖχος", "label": "grc"},
        {"text": "vajec, sladového výtažku a některých vitamínových", "label": "cs"},
        {"text": "Настоящие герои, те, чьи истории потом", "label": "ru"},
        {"text": "praesumptio:", "label": "la"},
        {"text": "Olin justkui nende vastutusel.", "label": "et"},
        {"text": "Jokainen keinahdus tuo lähemmäksi hetkeä jolloin", "label": "fi"},
        {"text": "ekonomicky výhodných způsobů odvodnění těžkých,", "label": "cs"},
        {"text": "Poprvé ve své historii dokázala v kvalifikaci pro", "label": "cs"},
        {"text": "zpracovatelského a spotřebního průmyslu bude nutné", "label": "cs"},
        {"text": "Windows CE zu integrieren .", "label": "de"},
        {"text": "Armangué, a través d'un decret, ordenés l'aturada", "label": "ca"},
        {"text": "to, co nás Evropany spojuje, než to, co nás od", "label": "cs"},
        {"text": "ergänzt durch einen gesetzlich verankertes", "label": "de"},
        {"text": "Насчитал, что с начала года всего три дня были", "label": "ru"},
        {"text": "Borisovu tražeći od njega da prihvati njenu", "label": "sr"},
        {"text": "la presenza di ben veleni diversi: . chili di", "label": "it"},
        {"text": "καὶ τῶν ἐκλεκτῶν ἀγγέλων ἵνα ταῦτα φυλάξῃς χωρὶς", "label": "grc"},
        {"text": "pretraživale obližnju bolnicu i stambene zgrade u", "label": "hr"},
        {"text": "An rund Katzen habe Wolf seine Spiele getestet ,", "label": "de"},
        {"text": "investigating since March.", "label": "en"},
        {"text": "Tonböden (Mullböden).", "label": "de"},
        {"text": "Stálý dopisovatel LN v SRN Bedřich Utitz", "label": "cs"},
        {"text": "červnu předložené smlouvy.", "label": "cs"},
        {"text": "πνεύματι ᾧ ἐλάλει", "label": "grc"},
        {"text": ".%의 신장세를 보였다.", "label": "ko"},
        {"text": "Foae verde, foi de nuc, Prin pădure, prin colnic,", "label": "ro"},
        {"text": "διαπέμψας ἄλλους ἄλλῃ τοὺς μὲν ἐς Δελφοὺς ἰέναι", "label": "grc"},
        {"text": "المسلمين أو أي تيار سياسي طالما عمل ذلك التيار في", "label": "ar"},
        {"text": "As informações são da Dow Jones.", "label": "pt"},
        {"text": "Milliarde DM ausgestattet sein .", "label": "de"},
        {"text": "De utgår fortfarande från att kvinnans jämlikhet", "label": "sv"},
        {"text": "Sneeuw maakte in Davos bij de voorbereiding een", "label": "nl"},
        {"text": "De ahí que en este mercado puedan negociarse", "label": "es"},
        {"text": "intenzívnějšímu sbírání a studiu.", "label": "cs"},
        {"text": "और औसकर ४.० पैकेज का प्रयोग किया गया है ।", "label": "hi"},
        {"text": "Adipati Kuningan karena Kuningan menjadi bagian", "label": "id"},
        {"text": "Svako je bar jednom poželeo da mašine prosto umeju", "label": "sr"},
        {"text": "Im vergangenen Jahr haben die Regierungen einen", "label": "de"},
        {"text": "durat motus, aliquid fit et non est;", "label": "la"},
        {"text": "Dominować będą piosenki do tekstów Edwarda", "label": "pl"},
        {"text": "beantwortet .", "label": "de"},
        {"text": "О гуманитариях было кому рассказывать, а вот за", "label": "ru"},
        {"text": "Helsingin kaupunki riitautti vuokrasopimuksen", "label": "fi"},
        {"text": "chợt tan biến.", "label": "vi"},
        {"text": "avtomobil ločuje od drugih.", "label": "sl"},
        {"text": "Congress has proven itself ineffective as a body.", "label": "en"},
        {"text": "मैक्सिको ने इस तरह का शो इस समय आयोजित करने का", "label": "hi"},
        {"text": "No minimum order amount.", "label": "en"},
        {"text": "Convertassa .", "label": "fi"},
        {"text": "Как это можно сделать?", "label": "ru"},
        {"text": "tha mi creidsinn gu robh iad ceart cho saor shuas", "label": "gd"},
        {"text": "실제 일제는 이런 만해의 논리를 묵살하고 한반도를 침략한 다음 , 이어 만주를 침략하고", "label": "ko"},
        {"text": "Da un semplice richiamo all'ordine fino a grandi", "label": "it"},
        {"text": "pozoruhodný nejen po umělecké stránce, jež", "label": "cs"},
        {"text": "La comida y el servicio aprueban.", "label": "es"},
        {"text": "again, connected not with each other but to the", "label": "en"},
        {"text": "Protokol výslovně stanoví, že nikdo nemůže být", "label": "cs"},
        {"text": "ఒక విషయం అడగాలని ఉంది .", "label": "te"},
        {"text": "Безгранично почитая дирекцию, ловя на лету каждое", "label": "ru"},
        {"text": "rovnoběžných růstových vrstev, zůstávají krychlové", "label": "cs"},
        {"text": "प्रवेश और पूर्व प्रधानमंत्री लाल बहादुर शास्त्री", "label": "hi"},
        {"text": "Bronzen medaille in de Europese marathon.", "label": "nl"},
        {"text": "- gadu vecumā viņi to nesaprot.", "label": "lv"},
        {"text": "Realizó sus estudios primarios en la Escuela Julia", "label": "es"},
        {"text": "cuartos de final, su clasificación para la final a", "label": "es"},
        {"text": "Sem si pro něho přiletí americký raketoplán, na", "label": "cs"},
        {"text": "Way to go!", "label": "en"},
        {"text": "gehört der neuen SPD-Führung unter Parteichef", "label": "de"},
        {"text": "Somit simuliert der Player mit einer GByte-Platte", "label": "de"},
        {"text": "Berufung auf kommissionsnahe Kreise , die bereits", "label": "de"},
        {"text": "Dist Clarïen", "label": "fro"},
        {"text": "Schon nach den Gerüchten , die Telekom wolle den", "label": "de"},
        {"text": "Software von NetObjects ist nach Angaben des", "label": "de"},
        {"text": "si enim per legem iustitia ergo Christus gratis", "label": "la"},
        {"text": "ducerent in ipsam magis quam in corpus christi,", "label": "la"},
        {"text": "Neustar-Melbourne-IT-Partnerschaft NeuLevel .", "label": "de"},
        {"text": "forderte dagegen seine drastische Verschärfung.", "label": "de"},
        {"text": "pemmican på hundrede forskellige måder.", "label": "da"},
        {"text": "Lehån, själv matematiklärare, visar hur den nya", "label": "sv"},
        {"text": "I highly recommend his shop.", "label": "en"},
        {"text": "verità, giovani fedeli prostratevi #amen", "label": "it"},
        {"text": "उत्तर प्रदेश के अध्यक्ष पद से हटाए गए विनय कटियार", "label": "hi"},
        {"text": "() روزی مےں کشادگی ہوتی ہے۔", "label": "ur"},
        {"text": "Prozessorgeschäft profitieren kann , stellen", "label": "de"},
        {"text": "školy začalo počítat pytle s moukou a zjistilo, že", "label": "cs"},
        {"text": "प्रभावशाली पर गैर सरकारी लोगों के घरों में भी", "label": "hi"},
        {"text": "geschichtslos , oder eine Farce , wie sich", "label": "de"},
        {"text": "Ústrednými mocnosťami v marci však spôsobilo, že", "label": "sk"},
        {"text": "التسليح بدون مبرر، واستمرار الأضرار الناجمة عن فرض", "label": "ar"},
        {"text": "Například Pedagogická fakulta Univerzity Karlovy", "label": "cs"},
        {"text": "nostris ut eriperet nos de praesenti saeculo", "label": "la"},
    ]
    docs = [Document([], text=example["text"]) for example in examples]
    basic_multilingual(docs)

    # (text, gold label, predicted label) for every wrong prediction
    errors = [(example["text"], example["label"], doc.lang)
              for example, doc in zip(examples, docs)
              if doc.lang != example["label"]]
    # integer count / len, i.e. the same float value the original
    # sum-of-booleans / len produced, so the threshold comparison is unchanged
    accuracy = (len(docs) - len(errors)) / len(docs)
    assert accuracy >= 0.98, f"accuracy {accuracy:.4f} < 0.98; mispredicted (text, gold, predicted): {errors}"
def test_text_cleaning(basic_multilingual, clean_multilingual):
    """
    Check that enabling text cleaning in the langid processor fixes the prediction
    for French text containing hashtags or a URL.

    Without cleaning, both strings are labeled Italian; with cleaning, both become French.
    """
    raw_texts = [
        "Bonjour le monde! #thisisfrench #ilovefrance",
        "Bonjour le monde! https://t.co/U0Zjp3tusD",
    ]
    documents = [Document([], text=raw) for raw in raw_texts]

    basic_multilingual(documents)
    assert [d.lang for d in documents] == ["it", "it"]

    # the second pipeline must actually have cleaning turned on for this test to mean anything
    assert clean_multilingual.processors["langid"]._clean_text
    clean_multilingual(documents)
    assert [d.lang for d in documents] == ["fr", "fr"]
def test_emoji_cleaning():
    """
    LangIDProcessor.clean_text should remove emoji, both :shortcode: style and raw emoji characters
    """
    cases = (
        ("Sh'reyan has nice antennae :thumbs_up:", "Sh'reyan has nice antennae"),
        ("This is🐱 a cat", "This is a cat"),
    )
    for raw, cleaned in cases:
        assert LangIDProcessor.clean_text(raw) == cleaned
def test_lang_subset(basic_multilingual, enfr_multilingual, en_multilingual):
    """
    Restricting the langid model to a subset of languages forces its output into that subset

    The unrestricted model labels both French strings as Italian.  Restricting to
    en/fr yields French, and restricting to en alone yields English.
    """
    texts = ["Bonjour le monde! #thisisfrench #ilovefrance",
             "Bonjour le monde! https://t.co/U0Zjp3tusD"]
    docs = [Document([], text=t) for t in texts]

    def current_langs():
        return [d.lang for d in docs]

    basic_multilingual(docs)
    assert current_langs() == ["it", "it"]

    assert enfr_multilingual.processors["langid"]._model.lang_subset == ["en", "fr"]
    enfr_multilingual(docs)
    assert current_langs() == ["fr", "fr"]

    assert en_multilingual.processors["langid"]._model.lang_subset == ["en"]
    en_multilingual(docs)
    assert current_langs() == ["en", "en"]
def test_lang_subset_unlikely_language(en_multilingual):
    """
    Test that the language subset masking chooses a legal language, even if all legal languages are extremely unlikely

    The input is Chinese text and the pipeline is restricted to English, so the
    only legal answer is "en".  The second half runs the underlying model directly
    and checks that its raw score for English on this input is negative.  That
    confirms the first half really exercised the case where the only allowed
    language is an unlikely one.
    """
    sentences = ["你好" * 200]
    docs = [Document([], text=text) for text in sentences]
    en_multilingual(docs)
    assert [doc.lang for doc in docs] == ["en"]

    # bypass the subset masking and look at the raw model output for the same text
    processor = en_multilingual.processors['langid']
    model = processor._model
    text_tensor = processor._text_to_tensor(sentences)
    en_idx = model.tag_to_idx['en']
    predictions = model(text_tensor)
    assert predictions[0, en_idx] < 0, "If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English. Update the test by picking a different combination of languages & input"
================================================
FILE: stanza/tests/langid/test_multilingual.py
================================================
"""
Tests specifically for the MultilingualPipeline
"""
from collections import defaultdict
import pytest
from stanza.pipeline.multilingual import MultilingualPipeline
from stanza.tests import TEST_MODELS_DIR
# Apply the `pipeline` and `travis` pytest markers to every test in this module
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=True, **kwargs):
    """
    Build a MultilingualPipeline, run it on one English and one French sentence, and check the output

    en_has_dependencies, fr_has_dependencies: whether that language's output is expected
      to contain a dependency parse.  When False, dependencies_string() must be empty.
    kwargs: forwarded to MultilingualPipeline.  If lang_configs is not given, both en and fr
      are configured with tokenize,pos,lemma,depparse.
    """
    en_text = "This is an English sentence."
    en_words = ["This", "is", "an", "English", "sentence", "."]
    en_dep_lines = (
        "('This', 5, 'nsubj')",
        "('is', 5, 'cop')",
        "('an', 5, 'det')",
        "('English', 5, 'amod')",
        "('sentence', 0, 'root')",
        "('.', 5, 'punct')",
    )
    en_gold = "\n".join(en_dep_lines) if en_has_dependencies else ""

    fr_text = "C'est une phrase française."
    fr_words = ["C'", "est", "une", "phrase", "française", "."]
    fr_dep_lines = (
        "(\"C'\", 4, 'nsubj')",
        "('est', 4, 'cop')",
        "('une', 4, 'det')",
        "('phrase', 0, 'root')",
        "('française', 4, 'amod')",
        "('.', 4, 'punct')",
    )
    fr_gold = "\n".join(fr_dep_lines) if fr_has_dependencies else ""

    # an explicitly passed lang_configs (even an empty one) is used as-is
    if 'lang_configs' not in kwargs:
        kwargs['lang_configs'] = {"en": {"processors": "tokenize,pos,lemma,depparse"},
                                  "fr": {"processors": "tokenize,pos,lemma,depparse"}}
    nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, **kwargs)

    results = nlp([en_text, fr_text])

    en_doc = results[0]
    assert en_doc.lang == "en"
    assert len(en_doc.sentences) == 1
    assert [word.text for word in en_doc.sentences[0].words] == en_words
    assert en_doc.sentences[0].dependencies_string() == en_gold

    fr_doc = results[1]
    assert len(fr_doc.sentences) == 1
    assert fr_doc.lang == "fr"
    assert [word.text for word in fr_doc.sentences[0].words] == fr_words
    assert fr_doc.sentences[0].dependencies_string() == fr_gold
def test_multilingual_pipeline():
    """
    Default configuration: full tokenize/pos/lemma/depparse pipelines for both en and fr
    """
    run_multilingual_pipeline()
def test_multilingual_pipeline_small_cache():
    """
    Same check as the basic test, but with max_cache_size=1 passed to MultilingualPipeline
    """
    run_multilingual_pipeline(max_cache_size=1)
def test_multilingual_config():
    """
    Configure only tokenize for EN

    French is left out of lang_configs entirely and is still expected to get a dependency parse.
    """
    en_only_tokenize = {"en": {"processors": "tokenize"}}
    run_multilingual_pipeline(en_has_dependencies=False, lang_configs=en_only_tokenize)
def test_multilingual_processors_limited():
    """
    Test loading an available subset of processors via the top level processors argument

    A requested processor that does not exist ("zzzzzzzzzz") must not cause a failure.
    """
    scenarios = [
        (False, False, {}, "tokenize"),
        (True, False, {"en": {"processors": "tokenize,pos,lemma,depparse"}}, "tokenize"),
        # this should not fail, as it will drop the zzzzzzzzzz processor for the languages which don't have it
        (False, False, {}, "tokenize,zzzzzzzzzz"),
    ]
    for en_deps, fr_deps, configs, processors in scenarios:
        run_multilingual_pipeline(en_has_dependencies=en_deps, fr_has_dependencies=fr_deps,
                                  lang_configs=configs, processors=processors)
def test_defaultdict_config():
    """
    lang_configs may be a defaultdict, so languages without an explicit entry use the factory default
    """
    def tokenize_only():
        return dict(processors="tokenize")

    # every language falls back to tokenize only
    configs = defaultdict(tokenize_only)
    run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs=configs)

    # en gets the full pipeline explicitly; fr still falls back to tokenize only
    configs = defaultdict(tokenize_only)
    configs["en"] = {"processors": "tokenize,pos,lemma,depparse"}
    run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs=configs)
================================================
FILE: stanza/tests/lemma/__init__.py
================================================
================================================
FILE: stanza/tests/lemma/test_data.py
================================================
"""
Test a couple basic data functions, such as processing a doc for its lemmas
"""
import pytest
from stanza.models.common.doc import Document
from stanza.models.lemma.data import DataLoader
from stanza.utils.conll import CoNLL
# Two English sentences in CoNLL-U format (17 + 16 = 33 words).
# The other sample sentences below are appended to this text, so the expected
# row counts in the tests are expressed relative to these 33 words.
TRAIN_DATA = """
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
2 : : PUNCT : _ 1 punct 1:punct _
3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
6 that that SCONJ IN _ 9 mark 9:mark _
7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
10 up up ADP RP _ 9 compound:prt 9:compound:prt _
11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
15 in in ADP IN _ 16 case 16:case _
16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
17 . . PUNCT . _ 1 punct 1:punct _
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
2 of of ADP IN _ 3 case 3:case _
3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
7 by by ADP IN _ 9 case 9:case _
8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
10 of of ADP IN _ 12 case 12:case _
11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
13 of of ADP IN _ 15 case 15:case _
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
16 ! ! PUNCT . _ 6 punct 6:punct _
""".lstrip()
# One sentence where word 1 has two goeswith dependents (words 2 and 3).
# test_load_goeswith expects these 3 words to be kept when evaluation=True
# and removed when evaluation=False.
GOESWITH_DATA = """
# sent_id = email-enronsent27_01-0041
# newpar id = email-enronsent27_01-p0005
# text = Ken Rice@ENRON COMMUNICATIONS
1 Ken kenrice@enroncommunications X GW Typo=Yes 0 root 0:root _
2 Rice@ENRON _ X GW _ 1 goeswith 1:goeswith _
3 COMMUNICATIONS _ X ADD _ 1 goeswith 1:goeswith _
""".lstrip()
# One sentence with a misspelled word ("targetting") whose MISC column carries
# CorrectForm=targeting.  Used by test_correct_form.
# NOTE: unlike the other sample strings this one is not lstrip()ed, so it starts with a newline.
CORRECT_FORM_DATA = """
# sent_id = weblog-blogspot.com_healingiraq_20040409053012_ENG_20040409_053012-0019
# text = They are targetting ambulances
1 They they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 3 nsubj 3:nsubj _
2 are be AUX VBP Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
3 targetting target VERB VBG Tense=Pres|Typo=Yes|VerbForm=Part 0 root 0:root CorrectForm=targeting
4 ambulances ambulance NOUN NNS Number=Plur 3 obj 3:obj SpaceAfter=No
"""
# Four words, three of which have a blank lemma ("_").  Only "an" has a real lemma,
# so test_load_blank expects 4 extra rows normally but just 1 with skip_blank_lemmas=True.
BLANKS_DATA = """
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0018
# text = Guerrillas killed an engineer, Asi Ali, from Tikrit.
1 Guerrillas _ NOUN NNS Number=Plur 2 nsubj 2:nsubj _
2 killed _ VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _
3 an a DET DT Definite=Ind|PronType=Art 4 det 4:det _
4 engineer _ NOUN NN Number=Sing 2 obj 2:obj SpaceAfter=No
""".lstrip()
def test_load_document():
    """
    Loading TRAIN_DATA yields one 3-element entry per word, with or without evaluation mode
    """
    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)
    for evaluation in (True, False):
        data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=evaluation)
        assert len(data) == 33  # meticulously counted by hand
        assert all(len(x) == 3 for x in data)
def test_load_goeswith():
    """
    The goeswith sentence is kept when evaluation=True and dropped when evaluation=False
    """
    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA + GOESWITH_DATA)
    # evaluation=True: the 33 TRAIN_DATA words plus the 3 goeswith words
    # evaluation=False: the trailing 3 goeswith words are removed
    expectations = ((True, 36), (False, 33))
    for evaluation, expected_len in expectations:
        data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=evaluation)
        assert len(data) == expected_len
        assert all(len(x) == 3 for x in data)
def test_correct_form():
    """
    A CorrectForm in MISC is only used when building training data (evaluation=False)
    """
    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA + CORRECT_FORM_DATA)

    # the 'targeting' correction should not be applied if evaluation=True
    eval_rows = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True)
    assert len(eval_rows) == 37
    assert not any(row[0] == 'targeting' for row in eval_rows)

    # when evaluation=False, the CorrectForm adds an extra row,
    # so the model learns both 'targetting' and 'targeting'
    train_rows = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)
    assert len(train_rows) == 38
    assert any(row[0] == 'targeting' for row in train_rows)
    assert any(row[0] == 'targetting' for row in train_rows)
def test_load_blank():
    """
    skip_blank_lemmas=True drops words whose lemma is blank
    """
    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA + BLANKS_DATA)
    # keeping blanks: 33 words from TRAIN_DATA plus all FOUR words of BLANKS_DATA
    # skipping blanks: only "an" (the one non-blank lemma) is added
    for skip_blank_lemmas, expected_len in ((False, 37), (True, 34)):
        data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=skip_blank_lemmas, evaluation=False)
        assert len(data) == expected_len
        assert all(len(x) == 3 for x in data)
================================================
FILE: stanza/tests/lemma/test_lemma_trainer.py
================================================
"""
Test a couple basic functions - load & save an existing model
"""
import pytest
import glob
import os
import tempfile
import torch
from stanza.models import lemmatizer
from stanza.models.lemma import trainer
from stanza.tests import *
from stanza.utils.training.common import choose_lemma_charlm, build_charlm_args
# Apply the `pipeline` and `travis` pytest markers to every test in this module
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
@pytest.fixture(scope="module")
def english_model():
    """
    Load the English "nocharlm" lemmatizer from TEST_MODELS_DIR once for this module

    Raises FileNotFoundError if lemma models exist but none of them is a nocharlm model.
    """
    candidates = glob.glob(os.path.join(TEST_MODELS_DIR, "en", "lemma", "*"))
    # we expect at least one English model downloaded for the tests
    assert len(candidates) >= 1, "No English lemma models downloaded during setup! Please make sure to run the setup script."

    nocharlm_files = [path for path in candidates if "nocharlm" in path]
    if not nocharlm_files:
        raise FileNotFoundError("Should have downloaded the nocharlm English lemmatizer during setup. Please rerun the setup script.")
    return trainer.Trainer(model_file=nocharlm_files[0])
def test_load_model(english_model):
    """
    Loading happens in the english_model fixture; requesting the fixture is the entire test
    """
def test_save_load_model(english_model):
    """
    Load, save, and load again

    The save path is inside a "resaved" subdirectory that does not exist beforehand,
    so this also exercises saving into a directory which has to be created.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        save_file = os.path.join(tempdir, "resaved", "lemma.pt")
        english_model.save(save_file)
        # make sure something was actually written before trying to reload it
        assert os.path.exists(save_file), "Saving the lemmatizer did not produce %s" % save_file
        # reloading the freshly written file should succeed
        trainer.Trainer(model_file=save_file)
# Tiny English training set (two CoNLL-U sentences) for the short lemmatizer
# training runs in TestLemmatizer
TRAIN_DATA = """
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
2 : : PUNCT : _ 1 punct 1:punct _
3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
6 that that SCONJ IN _ 9 mark 9:mark _
7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
10 up up ADP RP _ 9 compound:prt 9:compound:prt _
11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
15 in in ADP IN _ 16 case 16:case _
16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
17 . . PUNCT . _ 1 punct 1:punct _
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
2 of of ADP IN _ 3 case 3:case _
3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
7 by by ADP IN _ 9 case 9:case _
8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
10 of of ADP IN _ 12 case 12:case _
11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
13 of of ADP IN _ 15 case 15:case _
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
16 ! ! PUNCT . _ 6 punct 6:punct _
""".lstrip()
# A single English sentence used as the eval file for TestLemmatizer.
# Unlike TRAIN_DATA, it has no sent_id / text comment lines.
DEV_DATA = """
1 From from ADP IN _ 3 case 3:case _
2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _
3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _
4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _
6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _
7 : : PUNCT : _ 4 punct 4:punct _
""".lstrip()
class TestLemmatizer:
    """
    Short end-to-end lemmatizer training runs on TRAIN_DATA / DEV_DATA
    """
    @pytest.fixture(scope="class")
    def charlm_args(self):
        """
        Lemmatizer command line arguments for the English "test" charlm found in TEST_MODELS_DIR
        """
        charlm = choose_lemma_charlm("en", "test", "default")
        charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR)
        return charlm_args

    def run_training(self, tmp_path, train_text, dev_text, extra_args=None):
        """
        Run the training for a few iterations, load & return the model

        tmp_path: pytest's tmp_path; receives the train/dev files, predictions, and saved model
        train_text, dev_text: CoNLL-U text written to train.conllu and dev.conllu
        extra_args: optional list of extra command line arguments appended for lemmatizer.main
        """
        pred_file = str(tmp_path / "pred.conllu")

        # this trains a lemmatizer, so don't name the model file after the tagger
        save_name = "test_lemmatizer.pt"
        save_file = str(tmp_path / save_name)

        train_file = str(tmp_path / "train.conllu")
        with open(train_file, "w", encoding="utf-8") as fout:
            fout.write(train_text)

        dev_file = str(tmp_path / "dev.conllu")
        with open(dev_file, "w", encoding="utf-8") as fout:
            fout.write(dev_text)

        args = ["--train_file", train_file,
                "--eval_file", dev_file,
                "--output_file", pred_file,
                "--num_epoch", "2",
                "--log_step", "10",
                "--save_dir", str(tmp_path),
                "--save_name", save_name,
                "--shorthand", "en_test"]
        if extra_args is not None:
            args = args + extra_args

        lemmatizer.main(args)

        assert os.path.exists(save_file)
        saved_model = trainer.Trainer(model_file=save_file)
        return saved_model

    def test_basic_train(self, tmp_path):
        """
        Simple test of a few 'epochs' of lemmatizer training
        """
        self.run_training(tmp_path, TRAIN_DATA, DEV_DATA)

    def test_charlm_train(self, tmp_path, charlm_args):
        """
        Train a few 'epochs' with a charlm, then check the charlm weights were not saved in the model file
        """
        saved_model = self.run_training(tmp_path, TRAIN_DATA, DEV_DATA, extra_args=charlm_args)

        # check that the charlm (the contextual_embedding parameters) wasn't saved in here
        args = saved_model.args
        save_name = os.path.join(args['save_dir'], args['save_name'])
        checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True)
        assert not any(x.startswith("contextual_embedding") for x in checkpoint['model'].keys())
================================================
FILE: stanza/tests/lemma/test_lowercase.py
================================================
import pytest
from stanza.models.lemmatizer import all_lowercase
from stanza.utils.conll import CoNLL
# Two Latin sentences in which every word form is lowercase;
# test_all_lowercase expects all_lowercase() to return True for this document
LATIN_CONLLU = """
# sent_id = train-s1
# text = unde et philosophus dicit felicitatem esse operationem perfectam.
# reference = ittb-scg-s4203
1 unde unde ADV O4 AdvType=Loc|PronType=Rel 4 advmod:lmod _ _
2 et et CCONJ O4 _ 3 advmod:emph _ _
3 philosophus philosophus NOUN B1|grn1|casA|gen1 Case=Nom|Gender=Masc|InflClass=IndEurO|Number=Sing 4 nsubj _ _
4 dicit dico VERB N3|modA|tem1|gen6 Aspect=Imp|InflClass=LatX|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens
5 felicitatem felicitas NOUN C1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 7 nsubj _ _
6 esse sum AUX N3|modH|tem1 Aspect=Imp|Tense=Pres|VerbForm=Inf 7 cop _ _
7 operationem operatio NOUN C1|grn1|casD|gen2|vgr1 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 4 ccomp _ _
8 perfectam perfectus ADJ A1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurA|Number=Sing 7 amod _ SpaceAfter=No
9 . . PUNCT Punc _ 4 punct _ _
# sent_id = train-s2
# text = perfectio autem operationis dependet ex quatuor.
# reference = ittb-scg-s4204
1 perfectio perfectio NOUN C1|grn1|casA|gen2 Case=Nom|Gender=Fem|InflClass=IndEurX|Number=Sing 4 nsubj _ _
2 autem autem PART O4 _ 4 discourse _ _
3 operationis operatio NOUN C1|grn1|casB|gen2|vgr1 Case=Gen|Gender=Fem|InflClass=IndEurX|Number=Sing 1 nmod _ _
4 dependet dependeo VERB K3|modA|tem1|gen6 Aspect=Imp|InflClass=LatE|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens
5 ex ex ADP S4|vgr2 _ 6 case _ _
6 quatuor quattuor NUM G1|gen3|vgr1 NumForm=Word|NumType=Card 4 obl:arg _ SpaceAfter=No
7 . . PUNCT Punc _ 4 punct _ _
""".lstrip()
# One English sentence beginning with a capitalized "You";
# test_not_all_lowercase expects all_lowercase() to return False for this document
ENG_CONLLU = """
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007
# text = You wonder if he was manipulating the market with his bombing targets.
1 You you PRON PRP Case=Nom|Person=2|PronType=Prs 2 nsubj 2:nsubj _
2 wonder wonder VERB VBP Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin 0 root 0:root _
3 if if SCONJ IN _ 6 mark 6:mark _
4 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 6 nsubj 6:nsubj _
5 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
6 manipulating manipulate VERB VBG Tense=Pres|VerbForm=Part 2 ccomp 2:ccomp _
7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _
8 market market NOUN NN Number=Sing 6 obj 6:obj _
9 with with ADP IN _ 12 case 12:case _
10 his his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 12 nmod:poss 12:nmod:poss _
11 bombing bombing NOUN NN Number=Sing 12 compound 12:compound _
12 targets target NOUN NNS Number=Plur 6 obl 6:obl:with SpaceAfter=No
13 . . PUNCT . _ 2 punct 2:punct _
""".lstrip()
def test_all_lowercase():
    """
    Every word form in LATIN_CONLLU is lowercase
    """
    latin_doc = CoNLL.conll2doc(input_str=LATIN_CONLLU)
    assert all_lowercase(latin_doc)
def test_not_all_lowercase():
    """
    ENG_CONLLU contains a capitalized word, so it is not all lowercase
    """
    english_doc = CoNLL.conll2doc(input_str=ENG_CONLLU)
    assert not all_lowercase(english_doc)
================================================
FILE: stanza/tests/lemma_classifier/__init__.py
================================================
================================================
FILE: stanza/tests/lemma_classifier/test_data_preparation.py
================================================
import os
import pytest
import stanza.models.lemma_classifier.utils as utils
import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier
# Apply the `pipeline` and `travis` pytest markers to every test in this module
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# A single sentence whose only "'s" is the copula (lemma "be");
# test_convert_one_sentence expects a label_decoder of just {'be': 0}
EWT_ONE_SENTENCE = """
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002
# text = Here's a Miami Herald interview
1-2 Here's _ _ _ _ _ _ _ _
1 Here here ADV RB PronType=Dem 0 root 0:root _
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _
3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _
5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _
6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _
""".lstrip()
# Train section for the lemma classifier tests: nine sentences, each with one
# "'s" AUX whose lemma is "be" (5 of them) or "have" (4 of them).
# test_convert_dataset expects 9 examples and both labels from this section.
EWT_TRAIN_SENTENCES = """
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002
# text = Here's a Miami Herald interview
1-2 Here's _ _ _ _ _ _ _ _
1 Here here ADV RB PronType=Dem 0 root 0:root _
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _
3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _
5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _
6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0027
# text = But Posada's nearly 80 years old
1 But but CCONJ CC _ 7 cc 7:cc _
2-3 Posada's _ _ _ _ _ _ _ _
2 Posada Posada PROPN NNP Number=Sing 7 nsubj 7:nsubj _
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 cop 7:cop _
4 nearly nearly ADV RB _ 5 advmod 5:advmod _
5 80 80 NUM CD NumForm=Digit|NumType=Card 6 nummod 6:nummod _
6 years year NOUN NNS Number=Plur 7 obl:npmod 7:obl:npmod _
7 old old ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0067
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0011
# text = Now that's a post I can relate to.
1 Now now ADV RB _ 5 advmod 5:advmod _
2-3 that's _ _ _ _ _ _ _ _
2 that that PRON DT Number=Sing|PronType=Dem 5 nsubj 5:nsubj _
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
5 post post NOUN NN Number=Sing 0 root 0:root _
6 I I PRON PRP Case=Nom|Number=Sing|Person=1|PronType=Prs 8 nsubj 8:nsubj _
7 can can AUX MD VerbForm=Fin 8 aux 8:aux _
8 relate relate VERB VB VerbForm=Inf 5 acl:relcl 5:acl:relcl _
9 to to ADP IN _ 8 obl 8:obl SpaceAfter=No
10 . . PUNCT . _ 5 punct 5:punct _
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0073
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0012
# text = hey that's a great blog
1 hey hey INTJ UH _ 6 discourse 6:discourse _
2-3 that's _ _ _ _ _ _ _ _
2 that that PRON DT Number=Sing|PronType=Dem 6 nsubj 6:nsubj _
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
5 great great ADJ JJ Degree=Pos 6 amod 6:amod _
6 blog blog NOUN NN Number=Sing 0 root 0:root SpaceAfter=No
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0089
# text = And It's Not Hard To Do
1 And and CCONJ CC _ 5 cc 5:cc _
2-3 It's _ _ _ _ _ _ _ _
2 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 expl 5:expl _
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
4 Not not PART RB _ 5 advmod 5:advmod _
5 Hard hard ADJ JJ Degree=Pos 0 root 0:root _
6 To to PART TO _ 7 mark 7:mark _
7 Do do VERB VB VerbForm=Inf 5 csubj 5:csubj SpaceAfter=No
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0029
# text = Meanwhile, a decision's been reached
1 Meanwhile meanwhile ADV RB _ 7 advmod 7:advmod SpaceAfter=No
2 , , PUNCT , _ 1 punct 1:punct _
3 a a DET DT Definite=Ind|PronType=Art 4 det 4:det _
4-5 decision's _ _ _ _ _ _ _ _
4 decision decision NOUN NN Number=Sing 7 nsubj:pass 7:nsubj:pass _
5 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux 7:aux _
6 been be AUX VBN Tense=Past|VerbForm=Part 7 aux:pass 7:aux:pass _
7 reached reach VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0138
# text = It's become a guardian of morality
1-2 It's _ _ _ _ _ _ _ _
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|5:nsubj:xsubj _
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
3 become become VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
5 guardian guardian NOUN NN Number=Sing 3 xcomp 3:xcomp _
6 of of ADP IN _ 7 case 7:case _
7 morality morality NOUN NN Number=Sing 5 nmod 5:nmod:of _
# sent_id = email-enronsent15_01-0018
# text = It's got its own bathroom and tv
1-2 It's _ _ _ _ _ _ _ _
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|13:nsubj _
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
3 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
4 its its PRON PRP$ Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 6 nmod:poss 6:nmod:poss _
5 own own ADJ JJ Degree=Pos 6 amod 6:amod _
6 bathroom bathroom NOUN NN Number=Sing 3 obj 3:obj _
7 and and CCONJ CC _ 8 cc 8:cc _
8 tv TV NOUN NN Number=Sing 6 conj 3:obj|6:conj:and SpaceAfter=No
# sent_id = newsgroup-groups.google.com_alt.animals.cat_01ff709c4bf2c60c_ENG_20040418_040100-0022
# text = It's also got the website
1-2 It's _ _ _ _ _ _ _ _
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _
3 also also ADV RB _ 4 advmod 4:advmod _
4 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
5 the the DET DT Definite=Def|PronType=Art 6 det 6:det _
6 website website NOUN NN Number=Sing 4 obj 4:obj|12:obl _
""".lstrip()
# NOTE: these sentences actually come from the EWT train set; this test just writes
# them out as the dev section.  Two "'s" examples: one "have", one "be".
EWT_DEV_SENTENCES = """
# sent_id = answers-20111108104724AAuBUR7_ans-0044
# text = He's only exhibited weight loss and some muscle atrophy
1-2 He's _ _ _ _ _ _ _ _
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _
3 only only ADV RB _ 4 advmod 4:advmod _
4 exhibited exhibit VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
5 weight weight NOUN NN Number=Sing 6 compound 6:compound _
6 loss loss NOUN NN Number=Sing 4 obj 4:obj _
7 and and CCONJ CC _ 10 cc 10:cc _
8 some some DET DT PronType=Ind 10 det 10:det _
9 muscle muscle NOUN NN Number=Sing 10 compound 10:compound _
10 atrophy atrophy NOUN NN Number=Sing 6 conj 4:obj|6:conj:and SpaceAfter=No
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0097
# text = It's a good thing too.
1-2 It's _ _ _ _ _ _ _ _
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
3 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
4 good good ADJ JJ Degree=Pos 5 amod 5:amod _
5 thing thing NOUN NN Number=Sing 0 root 0:root _
6 too too ADV RB _ 5 advmod 5:advmod SpaceAfter=No
7 . . PUNCT . _ 5 punct 5:punct _
""".lstrip()
# NOTE: these sentences actually come from the EWT train set; this test just writes
# them out as the test section.  Two "'s" examples: one "have", one "be".
# Unlike the other samples this string is not lstrip()ed, so it starts with a newline.
EWT_TEST_SENTENCES = """
# sent_id = reviews-162422-0015
# text = He said he's had a long and bad day.
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 2 nsubj 2:nsubj _
2 said say VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _
3-4 he's _ _ _ _ _ _ _ _
3 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _
4 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux 5:aux _
5 had have VERB VBN Tense=Past|VerbForm=Part 2 ccomp 2:ccomp _
6 a a DET DT Definite=Ind|PronType=Art 10 det 10:det _
7 long long ADJ JJ Degree=Pos 10 amod 10:amod _
8 and and CCONJ CC _ 9 cc 9:cc _
9 bad bad ADJ JJ Degree=Pos 7 conj 7:conj:and|10:amod _
10 day day NOUN NN Number=Sing 5 obj 5:obj SpaceAfter=No
11 . . PUNCT . _ 2 punct 2:punct _
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0100
# text = What's a few dead soldiers
1-2 What's _ _ _ _ _ _ _ _
1 What what PRON WP PronType=Int 6 nsubj 6:nsubj _
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
4 few few ADJ JJ Degree=Pos 6 amod 6:amod _
5 dead dead ADJ JJ Degree=Pos 6 amod 6:amod _
6 soldiers soldier NOUN NNS Number=Plur 0 root 0:root _
"""
def write_test_dataset(tmp_path, texts, datasets):
    """
    Write each text as an en_ewt conllu file inside a fake UD_English-EWT directory under tmp_path

    texts and datasets are paired in order; each dataset name becomes the
    section part of the filename (en_ewt-ud-<dataset>.conllu).

    Returns a paths dict with UDBASE (the fake UD root) and
    LEMMA_CLASSIFIER_DATA_DIR (tmp_path/data/lemma_classifier, which is not created here).
    """
    ud_base = tmp_path / "ud"
    treebank_dir = ud_base / "UD_English-EWT"
    os.makedirs(treebank_dir, exist_ok=True)

    for section_text, section_name in zip(texts, datasets):
        conllu_file = treebank_dir / ("en_ewt-ud-%s.conllu" % section_name)
        with open(conllu_file, "w", encoding="utf-8") as fout:
            fout.write(section_text)

    return {
        "UDBASE": ud_base,
        "LEMMA_CLASSIFIER_DATA_DIR": tmp_path / "data" / "lemma_classifier",
    }
def write_english_test_dataset(tmp_path):
    """
    Write the EWT sample sentences as a fake UD_English-EWT treebank under tmp_path

    The train / dev / test sample texts are paired, in order, with
    prepare_lemma_classifier.SECTIONS (presumably train, dev, test; confirm there).
    """
    return write_test_dataset(tmp_path,
                              (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES),
                              prepare_lemma_classifier.SECTIONS)
def convert_english_dataset(tmp_path):
    """
    Write the sample EWT treebank and convert it into lemma classifier data

    The conversion targets the word "'s" with UPOS AUX and lemma labels be|have.
    Returns the list of converted files, which must contain exactly three entries.
    """
    paths = write_english_test_dataset(tmp_path)
    converted = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have")
    assert len(converted) == 3
    return converted
def test_convert_one_sentence(tmp_path):
    """
    Convert a treebank containing only a train section with a single sentence
    """
    paths = write_test_dataset(tmp_path, [EWT_ONE_SENTENCE], ["train"])
    converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["train"])
    assert len(converted_files) == 1

    dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)
    assert len(dataset) == 1
    # the only "'s" in the sentence is the copula
    assert dataset.label_decoder == {'be': 0}

    upos_by_id = dict((idx, tag) for tag, idx in dataset.upos_to_id.items())
    for text_batches, _, upos_batches, _ in dataset:
        assert text_batches == [['Here', "'s", 'a', 'Miami', 'Herald', 'interview']]
        assert [upos_by_id[x] for x in upos_batches[0]] == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN']
def test_convert_dataset(tmp_path):
    """
    Convert the sample EWT treebank and check each converted section

    Every section fits in one batch of size 10; the train batch holds 9
    items and the dev / test batches hold 2 each.  The train section must
    contain both 'be' and 'have' as labels.
    """
    converted_files = convert_english_dataset(tmp_path)
    expected_batch_sizes = (9, 2, 2)
    for section_idx, (filename, expected_size) in enumerate(zip(converted_files, expected_batch_sizes)):
        dataset = utils.Dataset(filename, get_counts=True, batch_size=10, shuffle=False)
        assert len(dataset) == 1
        if section_idx == 0:
            label_decoder = dataset.label_decoder
            assert len(label_decoder) == 2
            assert "be" in label_decoder
            assert "have" in label_decoder
        for text_batches, _, _, _ in dataset:
            assert len(text_batches) == expected_size
================================================
FILE: stanza/tests/lemma_classifier/test_training.py
================================================
import glob
import os
import pytest
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
from stanza.models.lemma_classifier import train_lstm_model
from stanza.models.lemma_classifier import train_transformer_model
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.evaluate_models import evaluate_model
from stanza.tests import TEST_WORKING_DIR
from stanza.tests.lemma_classifier.test_data_preparation import convert_english_dataset
@pytest.fixture(scope="module")
def pretrain_file():
    """Path of the tiny word vector file used as --wordvec_pretrain_file in these tests"""
    embedding_path = f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
    return embedding_path
def _train_evaluate_reload(tmp_path, train_main, model_args):
    """
    Train a lemma classifier on the converted sample EWT data, evaluate it, then reload it from disk

    train_main is the main() of one of the training scripts and is expected to
    return a trainer with a .model attribute; model_args are the script-specific
    arguments, which are placed before the common save/train/eval file arguments.
    """
    converted_files = convert_english_dataset(tmp_path)
    save_name = str(tmp_path / 'lemma.pt')
    train_file = converted_files[0]
    eval_file = converted_files[1]
    train_args = list(model_args) + ['--save_name', save_name,
                                     '--train_file', train_file,
                                     '--eval_file', eval_file]
    trainer = train_main(train_args)
    evaluate_model(trainer.model, eval_file)

    # training should have written the model to save_name, and loading it should work
    assert os.path.exists(save_name)
    model = LemmaClassifier.load(save_name, None)
    assert model is not None

def test_train_lstm(tmp_path, pretrain_file):
    """Train, evaluate, and reload the LSTM lemma classifier"""
    _train_evaluate_reload(tmp_path, train_lstm_model.main,
                           ['--wordvec_pretrain_file', pretrain_file])

def test_train_transformer(tmp_path, pretrain_file):
    """
    Train, evaluate, and reload the transformer lemma classifier

    pretrain_file is requested but not needed: the transformer model uses
    hf-internal-testing/tiny-bert instead of word vectors.
    """
    _train_evaluate_reload(tmp_path, train_transformer_model.main,
                           ['--bert_model', 'hf-internal-testing/tiny-bert'])
================================================
FILE: stanza/tests/morphseg/__init__.py
================================================
================================================
FILE: stanza/tests/morphseg/conftest.py
================================================
"""
Shared pytest fixtures and configuration
"""
import pytest
from morphseg import MorphemeSegmenter
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
@pytest.fixture(scope="session")
def english_segmenter():
    """
    English MorphemeSegmenter shared by the whole test session, so the model is loaded only once
    """
    segmenter = MorphemeSegmenter('en')
    return segmenter
@pytest.fixture(scope="session")
def all_segmenters():
    """
    Map each language in MorphemeSegmenter.PRETRAINED_MODEL_LANGS to a loaded segmenter

    Session scoped, so every language model is loaded once per test session
    """
    return {lang_code: MorphemeSegmenter(lang_code)
            for lang_code in MorphemeSegmenter.PRETRAINED_MODEL_LANGS}
def pytest_configure(config):
    """
    Register the custom 'slow' and 'multilingual' markers with pytest
    """
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "multilingual: marks tests that test multiple languages",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)
================================================
FILE: stanza/tests/morphseg/test_integration.py
================================================
"""
Integration tests for morphseg
"""
import pytest
from morphseg import MorphemeSegmenter
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
class TestIntegration:
    """
    End-to-end checks of MorphemeSegmenter.segment

    The English tests use the session-scoped english_segmenter fixture from
    this directory's conftest.py instead of loading a fresh English model in
    every test.
    """
    def test_full_pipeline(self, english_segmenter):
        """Test complete segmentation pipeline"""
        text = "According to all known laws of aviation, there is no way a bee should be able to fly."
        result = english_segmenter.segment(text, output_string=False)
        # Should segment multiple words
        assert len(result) > 10
        # Each word should have at least one morpheme
        for word_morphemes in result:
            assert len(word_morphemes) >= 1

    def test_consistency_across_modes(self, english_segmenter):
        """Test that list and string output modes are consistent"""
        words = ['running', 'dogs', 'aviation']
        for word in words:
            list_result = english_segmenter.segment(word, output_string=False)
            string_result = english_segmenter.segment(word, output_string=True, delimiter=' @@')
            # String result should be reconstructable from list result
            expected_string = ' @@'.join(list_result[0])
            assert string_result == expected_string, \
                f"List and string outputs don't match for '{word}'"

    def test_unicode_handling(self):
        """Test handling of unicode characters"""
        # no shared French fixture exists, so the French model is loaded here
        segmenter = MorphemeSegmenter('fr')
        text = "café résumé"
        result = segmenter.segment(text, output_string=False)
        assert isinstance(result, list)
        assert len(result) >= 1

    def test_mixed_case(self, english_segmenter):
        """Test handling of mixed case input"""
        # Should normalize to lowercase
        result1 = english_segmenter.segment('Running', output_string=False)
        result2 = english_segmenter.segment('RUNNING', output_string=False)
        result3 = english_segmenter.segment('running', output_string=False)
        # All should produce the same result
        assert result1 == result2 == result3
================================================
FILE: stanza/tests/morphseg/test_morpheme_segmenter.py
================================================
"""
Tests for MorphemeSegmenter class
"""
import pytest
from morphseg import MorphemeSegmenter
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
class TestMorphemeSegmenter:
    """
    Unit tests for MorphemeSegmenter.segment with the English model

    english_segmenter is the session-scoped fixture from this directory's
    conftest.py; this class no longer defines its own class-scoped copy,
    which loaded a second English model for no benefit.
    """
    def test_basic_segmentation(self, english_segmenter):
        """Test basic morpheme segmentation"""
        result = english_segmenter.segment('running', output_string=False)
        assert isinstance(result, list)
        assert len(result) == 1
        assert isinstance(result[0], list)
        assert len(result[0]) >= 1

    def test_multiple_words(self, english_segmenter):
        """Test segmentation of multiple words"""
        result = english_segmenter.segment('running quickly', output_string=False)
        assert isinstance(result, list)
        assert len(result) == 2
        for segmentation in result:
            assert isinstance(segmentation, list)
            assert len(segmentation) >= 1

    def test_known_segmentations(self, english_segmenter):
        """Test known morpheme segmentations"""
        test_cases = {
            'dogs': ['dog', 's'],
            'aviation': ['aviate', 'ion'],
            'known': ['know', 'n'],
        }
        for word, expected in test_cases.items():
            result = english_segmenter.segment(word, output_string=False)
            assert result[0] == expected, f"Expected {expected}, got {result[0]} for '{word}'"

    def test_output_string_mode(self, english_segmenter):
        """Test string output mode"""
        result = english_segmenter.segment('running quickly', output_string=True)
        assert isinstance(result, str)
        assert ' @@' in result  # Default delimiter

    def test_custom_delimiter(self, english_segmenter):
        """Test custom delimiter in output"""
        result = english_segmenter.segment('running', output_string=True, delimiter='-')
        assert isinstance(result, str)
        assert '-' in result or result == 'running'  # May be unsegmented

    def test_empty_input(self, english_segmenter):
        """Test handling of empty input"""
        result = english_segmenter.segment('', output_string=False)
        assert result == []
        result = english_segmenter.segment('', output_string=True)
        assert result == ""

    def test_single_character(self, english_segmenter):
        """Test single character input"""
        result = english_segmenter.segment('a', output_string=False)
        assert isinstance(result, list)
        assert len(result) == 1
        assert result[0] == ['a']

    def test_punctuation(self, english_segmenter):
        """Test handling of punctuation"""
        result = english_segmenter.segment('Hello, world!', output_string=False)
        assert isinstance(result, list)
        # NOTE(review): the intent is that only words, not punctuation, are
        # segmented; this test currently only checks the result is non-empty
        assert len(result) > 0
class TestDeterminism:
    """
    Tests to ensure predictions are deterministic

    Uses the session-scoped english_segmenter fixture from conftest.py rather
    than loading a new English model in each test.  Reusing a segmenter that
    has already made predictions is also a stricter determinism check.
    """
    def test_deterministic_predictions(self, english_segmenter):
        """Test that same input produces same output consistently"""
        test_words = ['running', 'dogs', 'quickly', 'aviation']
        for word in test_words:
            results = [english_segmenter.segment(word, output_string=False) for _ in range(5)]
            # All results should be identical
            for result in results[1:]:
                assert result == results[0], \
                    f"Non-deterministic results for '{word}': {results[0]} vs {result}"

    def test_deterministic_batch(self, english_segmenter):
        """Test determinism with batch processing"""
        text = "The dogs are running quickly through the fields."
        results = [english_segmenter.segment(text, output_string=False) for _ in range(3)]
        # All results should be identical
        for result in results[1:]:
            assert result == results[0], \
                f"Non-deterministic batch results: {results[0]} vs {result}"
class TestMultilingual:
    """
    Loading and single-word segmentation for the pretrained languages
    """
    @pytest.mark.parametrize("lang", ['cs', 'en', 'es', 'fr', 'hu', 'it', 'la', 'ru'])
    def test_language_loading(self, lang):
        """Each supported language code loads a segmenter with a model attached"""
        loaded = MorphemeSegmenter(lang)
        assert loaded.lang == lang
        assert loaded.sequence_labeller is not None

    @pytest.mark.parametrize("lang,word", [
        ('en', 'running'),
        ('es', 'corriendo'),
        ('fr', 'rapidement'),
        ('ru', 'бегущий'),  # Russian instead of German
    ])
    def test_multilingual_segmentation(self, lang, word):
        """Segmenting one word returns a non-empty list for each available language"""
        if lang in MorphemeSegmenter.PRETRAINED_MODEL_LANGS:
            segmentation = MorphemeSegmenter(lang).segment(word, output_string=False)
            assert isinstance(segmentation, list)
            assert len(segmentation) >= 1
        else:
            pytest.skip(f"Language {lang} not supported")
class TestErrorHandling:
    """
    Bad arguments and a missing model should produce clear warnings or errors
    """
    def test_invalid_language(self):
        """An unknown language code warns and leaves the segmenter without a model"""
        with pytest.warns(UserWarning):
            bad_segmenter = MorphemeSegmenter('invalid_lang')
        assert bad_segmenter.sequence_labeller is None

    def test_invalid_input_type(self):
        """Non-string input is rejected with ValueError"""
        segmenter = MorphemeSegmenter('en')
        for bad_input in (123, ['not', 'a', 'string']):
            with pytest.raises(ValueError, match="Input sequence must be a string"):
                segmenter.segment(bad_input)

    def test_invalid_output_string_type(self):
        """A non-boolean output_string is rejected with ValueError"""
        segmenter = MorphemeSegmenter('en')
        with pytest.raises(ValueError, match="output_string must be a boolean"):
            segmenter.segment('test', output_string='yes')

    def test_invalid_delimiter_type(self):
        """A non-string delimiter is rejected with ValueError"""
        segmenter = MorphemeSegmenter('en')
        with pytest.raises(ValueError, match="Delimiter must be a string"):
            segmenter.segment('test', delimiter=123)

    def test_model_not_trained(self):
        """Segmenting without a sequence labeller raises RuntimeError"""
        # this test clears sequence_labeller, so it needs its own segmenter
        # rather than a shared fixture
        segmenter = MorphemeSegmenter('en')
        segmenter.sequence_labeller = None
        with pytest.raises(RuntimeError, match="Model not trained"):
            segmenter.segment('test')
class TestModelState:
    """
    The underlying model should stay out of training mode
    """
    def test_model_in_eval_mode(self):
        """A freshly loaded segmenter's model is not in training mode"""
        fresh = MorphemeSegmenter('en')
        inner_model = fresh.sequence_labeller.model.model
        assert not inner_model.training, \
            "Model should be in eval mode after loading"

    def test_model_stays_in_eval_mode(self):
        """Running predictions does not switch the model into training mode"""
        segmenter = MorphemeSegmenter('en')
        for _ in range(3):
            segmenter.segment('running', output_string=False)
        # read the model again after predicting, in case segment() touched it
        assert not segmenter.sequence_labeller.model.model.training, \
            "Model should remain in eval mode after predictions"
================================================
FILE: stanza/tests/morphseg/test_stanza_integration.py
================================================
"""
Integration tests for Stanza MorphSeg Processor
Tests the morpheme segmentation processor within the Stanza pipeline
"""
import pytest
import stanza
from stanza.models.common.doc import Document
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
class TestMorphSegProcessor:
    """Tests for the MorphSeg processor in Stanza pipeline"""
    @pytest.fixture(scope="class")
    def en_pipeline(self):
        """Create English pipeline with morphseg processor"""
        return stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            model_dir=TEST_MODELS_DIR,
            download_method=None
        )

    def test_processor_loads(self, en_pipeline):
        """Test that morphseg processor loads successfully"""
        assert 'morphseg' in en_pipeline.processors
        assert en_pipeline.processors['morphseg'] is not None

    def test_basic_segmentation(self, en_pipeline):
        """Test basic morpheme segmentation through pipeline"""
        doc = en_pipeline("running")
        assert len(doc.sentences) == 1
        assert len(doc.sentences[0].words) == 1
        word = doc.sentences[0].words[0]
        assert hasattr(word, 'morphemes')
        assert isinstance(word.morphemes, list)
        assert len(word.morphemes) >= 1

    def test_known_segmentations(self, en_pipeline):
        """Test known morpheme segmentations"""
        # Note: These are actual segmentations from the en2 model
        # Some words may be unsegmented depending on the model
        test_cases = {
            'dogs': ['dog', 's'],
            'aviation': ['aviate', 'ion'],
            'known': ['know', 'n'],
        }
        for word_text, expected in test_cases.items():
            doc = en_pipeline(word_text)
            word = doc.sentences[0].words[0]
            assert word.morphemes == expected, \
                f"Expected {expected}, got {word.morphemes} for '{word_text}'"

    def test_segmentation_consistency(self, en_pipeline):
        """Test that segmentation is consistent and produces valid output"""
        words = ['running', 'quickly', 'walked', 'playing']
        for word_text in words:
            doc = en_pipeline(word_text)
            word = doc.sentences[0].words[0]
            # Should have morphemes attribute
            assert hasattr(word, 'morphemes')
            assert isinstance(word.morphemes, list)
            assert len(word.morphemes) >= 1
            # All morphemes should be strings
            for morpheme in word.morphemes:
                assert isinstance(morpheme, str)
                assert len(morpheme) > 0

    def test_multiple_words(self, en_pipeline):
        """Test segmentation of multiple words in a sentence"""
        doc = en_pipeline("The dogs are running quickly.")
        # Check that all words have morphemes attribute
        for sentence in doc.sentences:
            for word in sentence.words:
                assert hasattr(word, 'morphemes')
                assert isinstance(word.morphemes, list)
                assert len(word.morphemes) >= 1

    def test_punctuation_handling(self, en_pipeline):
        """Test that punctuation is handled correctly"""
        doc = en_pipeline("Hello, world!")
        # All tokens should have morphemes, including punctuation
        for sentence in doc.sentences:
            for word in sentence.words:
                assert hasattr(word, 'morphemes')
                # Punctuation should be unsegmented
                if word.text in [',', '!', '.']:
                    assert word.morphemes == [word.text]

    def test_long_text(self, en_pipeline):
        """Test processing of longer text"""
        text = "According to all known laws of aviation, there is no way a bee should be able to fly."
        doc = en_pipeline(text)
        # Should have multiple sentences or one long sentence
        assert len(doc.sentences) >= 1
        # Count words with morpheme segmentation
        total_words = sum(len(sent.words) for sent in doc.sentences)
        assert total_words > 10
        # All words should have morphemes
        for sentence in doc.sentences:
            for word in sentence.words:
                assert hasattr(word, 'morphemes')

    def test_empty_input(self, en_pipeline):
        """Test handling of empty input"""
        doc = en_pipeline("")
        assert len(doc.sentences) == 0

    def test_single_character(self, en_pipeline):
        """Test single character input"""
        doc = en_pipeline("I")
        assert len(doc.sentences) == 1
        word = doc.sentences[0].words[0]
        assert word.morphemes == ['i']  # Normalized to lowercase

    def test_morphemes_attribute_persistence(self, en_pipeline):
        """
        Test that morphemes persist on the document after the pipeline runs

        The morphemes read on a first pass over the document must still be
        present, and unchanged, when the document is walked a second time.
        """
        doc = en_pipeline("running quickly")
        first_pass = [word.morphemes for sentence in doc.sentences for word in sentence.words]
        assert len(first_pass) > 0

        second_pass = []
        for sentence in doc.sentences:
            for word in sentence.words:
                assert hasattr(word, 'morphemes')
                assert word.morphemes is not None
                second_pass.append(word.morphemes)
        # previously the stored morphemes were never compared against anything
        assert second_pass == first_pass
class TestMultilingualMorphSeg:
    """Test morpheme segmentation across different languages"""
    @pytest.mark.parametrize("lang,text,expected_word", [
        ('en', 'running', 'running'),
        ('es', 'corriendo', 'corriendo'),
        ('fr', 'rapidement', 'rapidement'),
        ('cs', 'běžící', 'běžící'),
        ('it', 'correndo', 'correndo'),
    ])
    def test_multilingual_support(self, lang, text, expected_word):
        """
        Test that different languages can be processed

        Only building the pipeline may skip the test, since models may not be
        installed for every language.  Previously the whole body sat inside
        `except Exception`, which also caught AssertionError and turned every
        failed check into a skip.
        """
        try:
            nlp = stanza.Pipeline(
                lang=lang,
                processors='tokenize,morphseg',
                download_method=None
            )
        except Exception as e:
            pytest.skip(f"Language {lang} not available: {e}")

        doc = nlp(text)
        assert len(doc.sentences) >= 1
        assert len(doc.sentences[0].words) >= 1
        # the single input word should come out of tokenization unchanged
        assert doc.sentences[0].tokens[0].text == expected_word
        word = doc.sentences[0].words[0]
        assert hasattr(word, 'morphemes')
        assert isinstance(word.morphemes, list)
class TestMorphSegWithOtherProcessors:
    """Test morphseg processor in combination with other processors"""
    @staticmethod
    def _pipeline_or_skip(processors, skip_reason):
        """
        Build an English pipeline with the given processors, skipping the test if that fails

        Only construction is guarded.  Once the pipeline exists, processing
        errors and assertion failures must surface as failures.  Wrapping the
        assertions in `except Exception` would catch AssertionError and turn
        it into a skip.
        """
        try:
            return stanza.Pipeline(
                lang='en',
                processors=processors,
                download_method=None
            )
        except Exception as e:
            pytest.skip(f"{skip_reason}: {e}")

    def test_with_mwt(self):
        """Test morphseg with MWT processor"""
        nlp = stanza.Pipeline(
            lang='en',
            processors='tokenize,mwt,morphseg',
            download_method=None
        )
        doc = nlp("The dogs are running.")
        for sentence in doc.sentences:
            for word in sentence.words:
                assert hasattr(word, 'morphemes')

    def test_with_pos(self):
        """Test morphseg with POS tagging"""
        nlp = self._pipeline_or_skip('tokenize,pos,morphseg', "POS processor not available")
        doc = nlp("running quickly")
        for sentence in doc.sentences:
            for word in sentence.words:
                # Should have both POS and morphemes
                assert hasattr(word, 'morphemes')
                # Word always defines upos/xpos, so hasattr() proves nothing;
                # check that the tagger actually filled in a tag
                assert word.upos is not None or word.xpos is not None

    def test_with_lemma(self):
        """Test morphseg with lemmatization"""
        nlp = self._pipeline_or_skip('tokenize,pos,lemma,morphseg', "Lemma processor not available")
        doc = nlp("The dogs were running quickly.")
        for sentence in doc.sentences:
            for word in sentence.words:
                # Should have both lemma and morphemes
                assert hasattr(word, 'morphemes')
                # as with upos, check the value rather than the always-present attribute
                assert word.lemma is not None
class TestMorphSegDeterminism:
    """Running the same text through the pipeline repeatedly must give identical morphemes"""
    @staticmethod
    def _doc_morphemes(nlp, text):
        """Process text and return every word's morphemes, flattened in document order"""
        doc = nlp(text)
        return [word.morphemes for sent in doc.sentences for word in sent.words]

    def test_deterministic_results(self):
        """The same text processed three times yields the same morphemes"""
        nlp = stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )
        runs = [self._doc_morphemes(nlp, "running dogs aviation") for _ in range(3)]
        for run in runs[1:]:
            assert run == runs[0], \
                f"Non-deterministic results: {runs[0]} vs {run}"

    def test_batch_determinism(self):
        """Processing a fixed list of texts twice gives the same morphemes both times"""
        nlp = stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )
        texts = [
            "The dogs are running.",
            "Aviation is amazing.",
            "Known facts are helpful."
        ]
        passes = [[self._doc_morphemes(nlp, text) for text in texts] for _ in range(2)]
        assert passes[0] == passes[1]
class TestMorphSegEdgeCases:
    """Unusual inputs should still produce a morphemes attribute on every word"""
    @pytest.fixture(scope="class")
    def en_pipeline(self):
        """English tokenize+morphseg pipeline shared by this class"""
        return stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )

    def test_numbers(self, en_pipeline):
        """Numeric tokens get morphemes"""
        doc = en_pipeline("123 456")
        assert all(hasattr(word, 'morphemes') for sentence in doc.sentences for word in sentence.words)

    def test_mixed_case(self, en_pipeline):
        """Differently cased versions of a word normalize to the same morphemes"""
        first_word_morphemes = [en_pipeline(variant).sentences[0].words[0].morphemes
                                for variant in ("Running", "RUNNING", "running")]
        assert first_word_morphemes[0] == first_word_morphemes[1] == first_word_morphemes[2]

    def test_unicode_characters(self, en_pipeline):
        """Accented characters still produce a list of morphemes"""
        doc = en_pipeline("café résumé")
        for word in (w for sentence in doc.sentences for w in sentence.words):
            assert hasattr(word, 'morphemes')
            assert isinstance(word.morphemes, list)

    def test_special_characters(self, en_pipeline):
        """Emails, currency and percentages still get morphemes"""
        doc = en_pipeline("test@example.com $100 50%")
        assert all(hasattr(word, 'morphemes') for sentence in doc.sentences for word in sentence.words)

    def test_very_long_word(self, en_pipeline):
        """A very long word gets at least one morpheme"""
        first_word = en_pipeline("antidisestablishmentarianism").sentences[0].words[0]
        assert hasattr(first_word, 'morphemes')
        assert len(first_word.morphemes) >= 1

    def test_repeated_words(self, en_pipeline):
        """Repeated copies of a word in one sentence are segmented identically"""
        first, second, third = [word.morphemes for word in en_pipeline("running running running").sentences[0].words][:3]
        assert first == second == third

    def test_whitespace_handling(self, en_pipeline):
        """Spaces, tabs and newlines all separate words"""
        doc = en_pipeline("word1 word2\tword3\nword4")
        all_words = [word for sentence in doc.sentences for word in sentence.words]
        assert len(all_words) >= 4
        for word in all_words:
            assert hasattr(word, 'morphemes')
class TestMorphSegConfiguration:
    """Test different configurations of morphseg processor"""
    def test_custom_model_path(self):
        """
        Test that the default configuration (no custom model_path) loads and annotates text

        Despite the name, no model_path is passed here; the custom-path case is
        test_custom_model_path_with_file.
        """
        nlp = stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )
        doc = nlp("testing")
        assert len(doc.sentences) > 0
        assert hasattr(doc.sentences[0].words[0], 'morphemes')

    def test_custom_model_path_with_file(self):
        """Test loading with an actual custom model file path"""
        # This test would require a custom model file to exist
        # Skip if no custom model is available
        pytest.skip("Custom model path test requires a specific model file")

    def test_processor_requirements(self):
        """Test that a pipeline with morphseg also contains the tokenize processor it requires"""
        nlp = stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )
        # nlp.processors is keyed by processor name.  The former fallback
        # `'tokenize' in str(nlp.processors)` could match any text in the repr
        # and so weakened the check.
        assert 'tokenize' in nlp.processors
        assert 'morphseg' in nlp.processors
class TestMorphSegOutputFormat:
    """The morphemes attribute should always be a non-empty list of strings"""
    @pytest.fixture(scope="class")
    def en_pipeline(self):
        """English tokenize+morphseg pipeline shared by this class"""
        return stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )

    def test_morphemes_is_list(self, en_pipeline):
        """Every word's morphemes is a list"""
        doc = en_pipeline("The dogs are running quickly.")
        for word in (w for sentence in doc.sentences for w in sentence.words):
            assert isinstance(word.morphemes, list)

    def test_morphemes_are_strings(self, en_pipeline):
        """Every individual morpheme is a str"""
        doc = en_pipeline("The dogs are running quickly.")
        for word in (w for sentence in doc.sentences for w in sentence.words):
            for piece in word.morphemes:
                assert isinstance(piece, str)

    def test_morphemes_non_empty(self, en_pipeline):
        """No word ends up with an empty morphemes list"""
        doc = en_pipeline("The dogs are running quickly.")
        for word in (w for sentence in doc.sentences for w in sentence.words):
            assert len(word.morphemes) >= 1

    def test_unsegmented_words(self, en_pipeline):
        """A word left as a single morpheme still carries that morpheme as a str"""
        # Words like 'the', 'is', 'a' typically don't segment
        doc = en_pipeline("The dog is a pet.")
        for word in (w for sentence in doc.sentences for w in sentence.words):
            if len(word.morphemes) == 1:
                # NOTE(review): only the type is checked; the segmentation may be
                # canonical rather than surface, so the morpheme is not compared
                # against word.text
                assert isinstance(word.morphemes[0], str)
class TestMorphSegRepeatedly:
    """One pipeline should handle many documents and multi-sentence documents"""
    def test_sequential_document_processing(self):
        """Several documents processed one after another all get morphemes"""
        nlp = stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )
        texts = [
            "The dogs are running.",
            "Aviation is fascinating.",
            "Programming requires patience."
        ]
        for text in texts:
            doc = nlp(text)
            for word in (w for sentence in doc.sentences for w in sentence.words):
                assert hasattr(word, 'morphemes')
                assert isinstance(word.morphemes, list)

    def test_multi_sentence_document(self):
        """A document with three sentences is split into three and every word gets morphemes"""
        nlp = stanza.Pipeline(
            lang='en',
            processors='tokenize,morphseg',
            download_method=None
        )
        doc = nlp("The dogs are running. Aviation is fascinating. Programming requires patience.")
        assert len(doc.sentences) == 3
        for word in (w for sentence in doc.sentences for w in sentence.words):
            assert hasattr(word, 'morphemes')
            assert isinstance(word.morphemes, list)
================================================
FILE: stanza/tests/mwt/__init__.py
================================================
================================================
FILE: stanza/tests/mwt/test_character_classifier.py
================================================
import os
import pytest
from stanza.models import mwt_expander
from stanza.models.mwt.character_classifier import CharacterClassifier
from stanza.models.mwt.data import DataLoader
from stanza.models.mwt.trainer import Trainer
from stanza.utils.conll import CoNLL
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# Training data for test_train: three CoNLL-U sentences, each with one
# possessive MWT (Elena's / women's / Children's) split into the noun plus 's
ENG_TRAIN = """
# text = Elena's motorcycle tour
1-2 Elena's _ _ _ _ _ _ _ _
1 Elena Elena PROPN NNP Number=Sing 4 nmod:poss 4:nmod:poss _
2 's 's PART POS _ 1 case 1:case _
3 motorcycle motorcycle NOUN NN Number=Sing 4 compound 4:compound _
4 tour tour NOUN NN Number=Sing 0 root 0:root _
# text = women's reproductive health
1-2 women's _ _ _ _ _ _ _ _
1 women woman NOUN NNS Number=Plur 4 nmod:poss 4:nmod:poss _
2 's 's PART POS _ 1 case 1:case _
3 reproductive reproductive ADJ JJ Degree=Pos 4 amod 4:amod _
4 health health NOUN NN Number=Sing 0 root 0:root SpaceAfter=No
# text = The Chernobyl Children's Project
1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
3-4 Children's _ _ _ _ _ _ _ _
3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
4 's 's PART POS _ 3 case 3:case _
5 Project Project PROPN NNP Number=Sing 0 root 0:root _
""".lstrip()
# Dev / gold data for test_train: the "Children's" sentence from ENG_TRAIN,
# whose MWT the trained model is expected to split as "Children 's"
ENG_DEV = """
# text = The Chernobyl Children's Project
1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
3-4 Children's _ _ _ _ _ _ _ _
3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
4 's 's PART POS _ 3 case 3:case _
5 Project Project PROPN NNP Number=Sing 0 root 0:root _
""".lstrip()
def test_train(tmp_path):
    """
    Train a tiny English MWT model and check the saved model reloads and predicts

    Writes ENG_TRAIN / ENG_DEV under tmp_path, trains for 10 epochs with
    mwt_expander, reloads the saved model with Trainer, and runs it over the
    single dev sentence.
    """
    # local names no longer shadow the test function's own name (test_train)
    train_path = str(os.path.join(tmp_path, "en_test.train.conllu"))
    # explicit encoding so the written files do not depend on the platform default
    with open(train_path, "w", encoding="utf-8") as fout:
        fout.write(ENG_TRAIN)
    dev_path = str(os.path.join(tmp_path, "en_test.dev.conllu"))
    with open(dev_path, "w", encoding="utf-8") as fout:
        fout.write(ENG_DEV)
    output_path = str(os.path.join(tmp_path, "en_test.dev.pred.conllu"))
    model_name = "en_test_mwt.pt"
    args = [
        "--data_dir", str(tmp_path),
        "--train_file", train_path,
        "--eval_file", dev_path,
        "--gold_file", dev_path,
        "--lang", "en",
        "--shorthand", "en_test",
        "--output_file", output_path,
        "--save_dir", str(tmp_path),
        "--save_name", model_name,
        "--num_epoch", "10",
    ]
    mwt_expander.main(args=args)

    model = Trainer(model_file=os.path.join(tmp_path, model_name))
    assert model.model is not None
    assert isinstance(model.model, CharacterClassifier)

    doc = CoNLL.conll2doc(input_str=ENG_DEV)
    dataloader = DataLoader(doc, 10, model.args, vocab=model.vocab, evaluation=True, expand_unk_vocab=True)
    preds = []
    for i, batch in enumerate(dataloader.to_loader()):
        assert i == 0  # there should only be one batch
        preds += model.predict(batch, never_decode_unk=True, vocab=dataloader.vocab)
    assert len(preds) == 1
    # it is possible to make a version of the test where this happens almost every time
    # for example, running for 100 epochs makes the model succeed 30 times in a row
    # (never saw a failure)
    # but the one time that failure happened, it would be really annoying
    #assert preds[0] == "Children 's"
================================================
FILE: stanza/tests/mwt/test_english_corner_cases.py
================================================
"""
Test a couple English MWT corner cases which might be more widely applicable to other MWT languages
- a character which is not in the MWT vocab (such as an accented 'i') doesn't result in bizarre splits
- Casing or CASING doesn't get lost in the dictionary lookup
In the English UD datasets, the MWT are composed exactly of the
subwords, so the MWT model should be chopping up the input text rather
than generating new text.
Furthermore, SHE'S and She's should be split "SHE 'S" and "She 's" respectively
"""
import pytest
import stanza
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def test_mwt_unknown_char():
    """
    A possessive containing a letter the MWT model never saw should still split into its own pieces
    """
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    mwt_trainer = pipeline.processors['mwt']._trainer
    assert mwt_trainer.args['force_exact_pieces']

    # find a variant of the letter 'i' which isn't in the training data
    # the MWT model should still recognize a possessive containing this letter
    assert "i" in mwt_trainer.vocab
    unseen_letter = next((ch for ch in "ĩîíìī" if ch not in mwt_trainer.vocab), None)
    if unseen_letter is None:
        raise AssertionError("Need to update the MWT test - all of the non-standard letters 'i' are now in the MWT vocab")

    possessive = "Jenn" + unseen_letter + "fer" + "'s"
    tokens = pipeline("I wanna lick " + possessive + " antennae").sentences[0].tokens
    # both 'wanna' and the possessive are MWTs whose two words exactly reassemble the token
    for token_idx, surface in ((1, 'wanna'), (3, possessive)):
        assert tokens[token_idx].text == surface
        assert len(tokens[token_idx].words) == 2
        assert "".join(w.text for w in tokens[token_idx].words) == surface
def test_english_mwt_casing():
    """
    Test that for a word where the lowercase split is known, the correct casing is still used

    Once upon a time, the logic used in the MWT expander would split
      SHE'S -> she 's
    which is a very surprising tokenization to people expecting
    the original text in the output document
    """
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    mwt_trainer = pipeline.processors['mwt']._trainer

    # many test cases follow the jennifer pattern for some reason, so proactively
    # look for a variant which hasn't made its way into the MWT dictionary
    unknown_name = None
    for num_r in range(1, 20):
        candidate = "jennife" + "r" * num_r + "'s"
        if candidate not in mwt_trainer.expansion_dict and candidate.upper() not in mwt_trainer.expansion_dict:
            unknown_name = candidate.upper()
            break
    if unknown_name is None:
        raise AssertionError("Need a new heuristic for the unknown word in the English MWT!")

    # this SHOULD show up in the expansion dict
    assert "she's" in mwt_trainer.expansion_dict, "Expected |she's| to be in the English MWT expansion dict... perhaps find a different test case"

    expectations = (
        ("JENNIFER HAS NICE ANTENNAE", ['JENNIFER', 'HAS', 'NICE', 'ANTENNAE']),
        (unknown_name + " GOT NICE ANTENNAE", [unknown_name[:-2], "'S", 'GOT', 'NICE', 'ANTENNAE']),
        ("SHE'S GOT NICE ANTENNAE", ['SHE', "'S", 'GOT', 'NICE', 'ANTENNAE']),
        ("She's GOT NICE ANTENNAE", ['She', "'s", 'GOT', 'NICE', 'ANTENNAE']),
    )
    for text, expected_words in expectations:
        assert [w.text for w in pipeline(text).sentences[0].words] == expected_words
================================================
FILE: stanza/tests/mwt/test_prepare_mwt.py
================================================
import pytest
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
from stanza.utils.datasets.prepare_mwt_treebank import check_mwt_composition
SAMPLE_GOOD_TEXT = """
# sent_id = weblog-typepad.com_ripples_20040407125600_ENG_20040407_125600-0057
# text = The Chernobyl Children's Project (http://www.adiccp.org/home/default.asp) offers several ways to help the children of that region.
1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
3-4 Children's _ _ _ _ _ _ _ _
3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
4 's 's PART POS _ 3 case 3:case _
5 Project Project PROPN NNP Number=Sing 9 nsubj 9:nsubj _
6 ( ( PUNCT -LRB- _ 7 punct 7:punct SpaceAfter=No
7 http://www.adiccp.org/home/default.asp http://www.adiccp.org/home/default.asp X ADD _ 5 appos 5:appos SpaceAfter=No
8 ) ) PUNCT -RRB- _ 7 punct 7:punct _
9 offers offer VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
10 several several ADJ JJ Degree=Pos 11 amod 11:amod _
11 ways way NOUN NNS Number=Plur 9 obj 9:obj _
12 to to PART TO _ 13 mark 13:mark _
13 help help VERB VB VerbForm=Inf 11 acl 11:acl:to _
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
15 children child NOUN NNS Number=Plur 13 obj 13:obj _
16 of of ADP IN _ 18 case 18:case _
17 that that DET DT Number=Sing|PronType=Dem 18 det 18:det _
18 region region NOUN NN Number=Sing 15 nmod 15:nmod:of SpaceAfter=No
19 . . PUNCT . _ 9 punct 9:punct _
""".lstrip()
SAMPLE_BAD_TEXT = """
# sent_id = weblog-typepad.com_ripples_20040407125600_ENG_20040407_125600-0057
# text = The Chernobyl Children's Project (http://www.adiccp.org/home/default.asp) offers several ways to help the children of that region.
1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
3-4 Children's _ _ _ _ _ _ _ _
3 Childrez Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
4 's 's PART POS _ 3 case 3:case _
5 Project Project PROPN NNP Number=Sing 9 nsubj 9:nsubj _
6 ( ( PUNCT -LRB- _ 7 punct 7:punct SpaceAfter=No
7 http://www.adiccp.org/home/default.asp http://www.adiccp.org/home/default.asp X ADD _ 5 appos 5:appos SpaceAfter=No
8 ) ) PUNCT -RRB- _ 7 punct 7:punct _
9 offers offer VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
10 several several ADJ JJ Degree=Pos 11 amod 11:amod _
11 ways way NOUN NNS Number=Plur 9 obj 9:obj _
12 to to PART TO _ 13 mark 13:mark _
13 help help VERB VB VerbForm=Inf 11 acl 11:acl:to _
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
15 children child NOUN NNS Number=Plur 13 obj 13:obj _
16 of of ADP IN _ 18 case 18:case _
17 that that DET DT Number=Sing|PronType=Dem 18 det 18:det _
18 region region NOUN NN Number=Sing 15 nmod 15:nmod:of SpaceAfter=No
19 . . PUNCT . _ 9 punct 9:punct _
""".lstrip()
def test_check_mwt_composition(tmp_path):
    """
    The good sample should pass check_mwt_composition.

    The bad sample, where word 3 is "Childrez" under the MWT token
    "Children's", should raise a ValueError.
    """
    def write_sample(name, contents):
        # write one sample into the temporary directory and return its path
        path = tmp_path / name
        with open(path, "w", encoding="utf-8") as fout:
            fout.write(contents)
        return path

    check_mwt_composition(write_sample("good.mwt", SAMPLE_GOOD_TEXT))

    bad_file = write_sample("bad.mwt", SAMPLE_BAD_TEXT)
    with pytest.raises(ValueError):
        check_mwt_composition(bad_file)
================================================
FILE: stanza/tests/mwt/test_utils.py
================================================
"""
Test the MWT resplitting of preexisting tokens without word splits
"""
import pytest
import stanza
from stanza.models.mwt.utils import resplit_mwt
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
@pytest.fixture(scope="module")
def pipeline():
"""
A reusable pipeline with the NER module
"""
return stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,mwt", package="gum")
def test_resplit_keep_tokens(pipeline):
    """
    Resplit pretokenized text into words while enforcing the given token boundaries
    """
    sentences = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
    doc = resplit_mwt(sentences, pipeline)
    assert len(doc.sentences) == 2
    first, second = doc.sentences

    assert len(first.tokens) == 4
    assert [word.text for word in first.tokens[1].words] == ["ca", "n't"]

    # with the token boundaries kept, "I can't" stays one token.
    # The updated GUM MWT splits it the way we want, "I - ca - n't";
    # it used to be split as "I - can - 't"
    assert len(second.tokens) == 2
    assert [word.text for word in second.tokens[0].words] == ["I", "ca", "n't"]
def test_resplit_no_keep_tokens(pipeline):
    """
    Resplit pretokenized text into words without enforcing the given token boundaries
    """
    sentences = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
    doc = resplit_mwt(sentences, pipeline, keep_tokens=False)
    assert len(doc.sentences) == 2
    first, second = doc.sentences

    assert len(first.tokens) == 4
    assert [word.text for word in first.tokens[1].words] == ["ca", "n't"]

    # without enforced boundaries, "I can't" is retokenized into separate tokens
    assert len(second.tokens) == 3
    assert [word.text for word in second.tokens[1].words] == ["ca", "n't"]
================================================
FILE: stanza/tests/ner/__init__.py
================================================
================================================
FILE: stanza/tests/ner/test_bsf_2_beios.py
================================================
"""
Tests the conversion code for the lang_uk NER dataset
"""
import unittest
from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo
import pytest
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
class TestBsf2Beios(unittest.TestCase):
    """
    Tests for convert_bsf with its default converter.

    The expected output uses BIOES tags: S- for single-token entities,
    B-/I-/E- for multi-token ones, and O outside of any entity.
    """

    def _assert_converts(self, text, markup, expected):
        """Check that converting `text` annotated with BSF `markup` gives `expected`."""
        self.assertEqual(expected, convert_bsf(text, markup))

    def test_empty_markup(self):
        self._assert_converts('', '', '')

    def test_1line_markup(self):
        text = 'тележурналіст Василь'
        markup = 'T1 PERS 14 20 Василь'
        expected = '''тележурналіст O
Василь S-PERS'''
        self._assert_converts(text, markup, expected)

    def test_1line_follow_markup(self):
        text = 'тележурналіст Василь .'
        markup = 'T1 PERS 14 20 Василь'
        expected = '''тележурналіст O
Василь S-PERS
. O'''
        self._assert_converts(text, markup, expected)

    def test_1line_2tok_markup(self):
        text = 'тележурналіст Василь Нагірний .'
        markup = 'T1 PERS 14 29 Василь Нагірний'
        expected = '''тележурналіст O
Василь B-PERS
Нагірний E-PERS
. O'''
        self._assert_converts(text, markup, expected)

    def test_1line_Long_tok_markup(self):
        text = 'А в музеї Гуцульщини і Покуття можна '
        markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття'
        expected = '''А O
в O
музеї B-ORG
Гуцульщини I-ORG
і I-ORG
Покуття E-ORG
можна O'''
        self._assert_converts(text, markup, expected)

    def test_2line_2tok_markup(self):
        text = '''тележурналіст Василь Нагірний .
В івано-франківському видавництві «Лілея НВ» вийшла друком'''
        markup = '''T1 PERS 14 29 Василь Нагірний
T2 ORG 67 75 Лілея НВ'''
        expected = '''тележурналіст O
Василь B-PERS
Нагірний E-PERS
. O
В O
івано-франківському O
видавництві O
« O
Лілея B-ORG
НВ E-ORG
» O
вийшла O
друком O'''
        self._assert_converts(text, markup, expected)

    def test_real_markup(self):
        text = '''Через напіввоєнний стан в Україні та збільшення телефонних терористичних погроз українці купуватимуть sim-карти тільки за паспортами .
Про це повідомив начальник управління зв'язків зі ЗМІ адміністрації Держспецзв'язку Віталій Кукса .
Він зауважив , що днями відомство опублікує проект змін до правил надання телекомунікаційних послуг , де будуть прописані норми ідентифікації громадян .
Абонентів , які на сьогодні вже мають sim-карту , за словами Віталія Кукси , реєструватимуть , коли ті звертатимуться в службу підтримки свого оператора мобільного зв'язку .
Однак мобільні оператори побоюються , що таке нововведення помітно зменшить продаж стартових пакетів , адже спеціалізовані магазини є лише у містах .
Відтак купити сімку в невеликих населених пунктах буде неможливо .
Крім того , нова процедура ідентифікації абонентів вимагатиме від операторів мобільного зв'язку додаткових витрат .
- Близько 90 % українських абонентів - це абоненти передоплати .
Якщо мова буде йти навіть про поетапну їх ідентифікацію , зробити це буде складно , довго і дорого .
Мобільним операторам доведеться йти на чималі витрати , пов'язані з укладанням і зберіганням договорів , веденням баз даних , - розповіла « Економічній правді » начальник відділу зв'язків з громадськістю « МТС-Україна » Вікторія Рубан .
'''
        markup = '''T1 LOC 26 33 Україні
T2 ORG 203 218 Держспецзв'язку
T3 PERS 219 232 Віталій Кукса
T4 PERS 449 462 Віталія Кукси
T5 ORG 1201 1219 Економічній правді
T6 ORG 1267 1278 МТС-Україна
T7 PERS 1281 1295 Вікторія Рубан
'''
        expected = '''Через O
напіввоєнний O
стан O
в O
Україні S-LOC
та O
збільшення O
телефонних O
терористичних O
погроз O
українці O
купуватимуть O
sim-карти O
тільки O
за O
паспортами O
. O
Про O
це O
повідомив O
начальник O
управління O
зв'язків O
зі O
ЗМІ O
адміністрації O
Держспецзв'язку S-ORG
Віталій B-PERS
Кукса E-PERS
. O
Він O
зауважив O
, O
що O
днями O
відомство O
опублікує O
проект O
змін O
до O
правил O
надання O
телекомунікаційних O
послуг O
, O
де O
будуть O
прописані O
норми O
ідентифікації O
громадян O
. O
Абонентів O
, O
які O
на O
сьогодні O
вже O
мають O
sim-карту O
, O
за O
словами O
Віталія B-PERS
Кукси E-PERS
, O
реєструватимуть O
, O
коли O
ті O
звертатимуться O
в O
службу O
підтримки O
свого O
оператора O
мобільного O
зв'язку O
. O
Однак O
мобільні O
оператори O
побоюються O
, O
що O
таке O
нововведення O
помітно O
зменшить O
продаж O
стартових O
пакетів O
, O
адже O
спеціалізовані O
магазини O
є O
лише O
у O
містах O
. O
Відтак O
купити O
сімку O
в O
невеликих O
населених O
пунктах O
буде O
неможливо O
. O
Крім O
того O
, O
нова O
процедура O
ідентифікації O
абонентів O
вимагатиме O
від O
операторів O
мобільного O
зв'язку O
додаткових O
витрат O
. O
- O
Близько O
90 O
% O
українських O
абонентів O
- O
це O
абоненти O
передоплати O
. O
Якщо O
мова O
буде O
йти O
навіть O
про O
поетапну O
їх O
ідентифікацію O
, O
зробити O
це O
буде O
складно O
, O
довго O
і O
дорого O
. O
Мобільним O
операторам O
доведеться O
йти O
на O
чималі O
витрати O
, O
пов'язані O
з O
укладанням O
і O
зберіганням O
договорів O
, O
веденням O
баз O
даних O
, O
- O
розповіла O
« O
Економічній B-ORG
правді E-ORG
» O
начальник O
відділу O
зв'язків O
з O
громадськістю O
« O
МТС-Україна S-ORG
» O
Вікторія B-PERS
Рубан E-PERS
. O'''
        self._assert_converts(text, markup, expected)
class TestBsf(unittest.TestCase):
    """Tests for parse_bsf, which turns BSF markup into a list of BsfInfo records."""

    def test_empty_bsf(self):
        self.assertEqual([], parse_bsf(''))

    def test_empty2_bsf(self):
        # whitespace-only markup produces no records either
        self.assertEqual([], parse_bsf(' \n \n'))

    def test_1line_bsf(self):
        parsed = parse_bsf('T1 PERS 103 118 Василь Нагірний')
        self.assertEqual(1, len(parsed))
        self.assertEqual([BsfInfo('T1', 'PERS', 103, 118, 'Василь Нагірний')], parsed)

    def test_2line_bsf(self):
        markup = '''T9 PERS 778 783 Карла
T10 MISC 814 819 міста'''
        parsed = parse_bsf(markup)
        self.assertEqual(2, len(parsed))
        self.assertEqual([BsfInfo('T9', 'PERS', 778, 783, 'Карла'),
                          BsfInfo('T10', 'MISC', 814, 819, 'міста')], parsed)

    def test_multiline_bsf(self):
        # the T4 entity text spans several lines of the markup
        markup = '''T3 PERS 220 235 Андрієм Кіщуком
T4 MISC 251 285 А .
Kubler .
Світло і тіні маестро
T5 PERS 363 369 Кіблер'''
        expected = [
            BsfInfo('T3', 'PERS', 220, 235, 'Андрієм Кіщуком'),
            BsfInfo('T4', 'MISC', 251, 285, '''А .
Kubler .
Світло і тіні маестро'''),
            BsfInfo('T5', 'PERS', 363, 369, 'Кіблер'),
        ]
        parsed = parse_bsf(markup)
        self.assertEqual(len(expected), len(parsed))
        self.assertEqual(expected, parsed)
# Allow running this test module directly with the unittest runner
if __name__ == '__main__':
    unittest.main()
================================================
FILE: stanza/tests/ner/test_bsf_2_iob.py
================================================
"""
Tests the conversion code for the lang_uk NER dataset
"""
import unittest
from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo
import pytest
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
class TestBsf2Iob(unittest.TestCase):
    """
    Tests for convert_bsf with converter='iob'.

    Every entity starts with B- and continues with I-, including
    single-token entities.
    """

    def _assert_iob(self, text, markup, expected):
        """Check that the IOB conversion of `text` with BSF `markup` gives `expected`."""
        self.assertEqual(expected, convert_bsf(text, markup, converter='iob'))

    def test_1line_follow_markup_iob(self):
        text = 'тележурналіст Василь .'
        markup = 'T1 PERS 14 20 Василь'
        expected = '''тележурналіст O
Василь B-PERS
. O'''
        self._assert_iob(text, markup, expected)

    def test_1line_2tok_markup_iob(self):
        text = 'тележурналіст Василь Нагірний .'
        markup = 'T1 PERS 14 29 Василь Нагірний'
        expected = '''тележурналіст O
Василь B-PERS
Нагірний I-PERS
. O'''
        self._assert_iob(text, markup, expected)

    def test_1line_Long_tok_markup_iob(self):
        text = 'А в музеї Гуцульщини і Покуття можна '
        markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття'
        expected = '''А O
в O
музеї B-ORG
Гуцульщини I-ORG
і I-ORG
Покуття I-ORG
можна O'''
        self._assert_iob(text, markup, expected)

    def test_2line_2tok_markup_iob(self):
        text = '''тележурналіст Василь Нагірний .
В івано-франківському видавництві «Лілея НВ» вийшла друком'''
        markup = '''T1 PERS 14 29 Василь Нагірний
T2 ORG 67 75 Лілея НВ'''
        expected = '''тележурналіст O
Василь B-PERS
Нагірний I-PERS
. O
В O
івано-франківському O
видавництві O
« O
Лілея B-ORG
НВ I-ORG
» O
вийшла O
друком O'''
        self._assert_iob(text, markup, expected)

    def test_all_multiline_iob(self):
        # the MISC entity spans several lines of the input text
        text = '''його книжечка «А .
Kubler .
Світло і тіні маестро» .
Причому'''
        markup = '''T4 MISC 15 49 А .
Kubler .
Світло і тіні маестро
'''
        expected = '''його O
книжечка O
« O
А B-MISC
. I-MISC
Kubler I-MISC
. I-MISC
Світло I-MISC
і I-MISC
тіні I-MISC
маестро I-MISC
» O
. O
Причому O'''
        self._assert_iob(text, markup, expected)
# Allow running this test module directly with the unittest runner
if __name__ == '__main__':
    unittest.main()
================================================
FILE: stanza/tests/ner/test_combine_ner_datasets.py
================================================
import json
import os
import pytest
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
from stanza.models.common.doc import Document
from stanza.tests.ner.test_ner_training import write_temp_file, EN_TRAIN_BIO, EN_DEV_BIO
from stanza.utils.datasets.ner import combine_ner_datasets
def test_combine(tmp_path):
    """
    Write two short NER datasets, combine them with combine_ner_datasets,
    and check the number of sentences in each combined shard.

    To keep things simple, every shard reuses the same input text, repeated
    once for train, twice for dev and three times for test.
    """
    shards = ("train", "dev", "test")
    for repeats, shard in enumerate(shards, start=1):
        # eg, 1x, 2x, 3x the test data from test_ner_training
        write_temp_file(tmp_path / ("en_t1.%s.json" % shard), "\n\n".join([EN_TRAIN_BIO] * repeats))
        write_temp_file(tmp_path / ("en_t2.%s.json" % shard), "\n\n".join([EN_DEV_BIO] * repeats))

    combine_ner_datasets.main(["--output_dataset", "en_c", "en_t1", "en_t2",
                               "--input_dir", str(tmp_path),
                               "--output_dir", str(tmp_path)])

    for repeats, shard in enumerate(shards, start=1):
        combined = tmp_path / ("en_c.%s.json" % shard)
        assert os.path.exists(combined)
        with open(combined, encoding="utf-8") as fin:
            doc = Document(json.load(fin))
        assert len(doc.sentences) == repeats * 3
================================================
FILE: stanza/tests/ner/test_convert_amt.py
================================================
"""
Test some of the functions used for converting an AMT json to a Stanza json
"""
import os
import pytest
import stanza
from stanza.utils.datasets.ner import convert_amt
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
TEXT = "Jennifer Sh'reyan has lovely antennae."
def fake_label(label, start_char, end_char):
return {'label': label,
'startOffset': start_char,
'endOffset': end_char}
LABELS = [
fake_label('Person', 0, 8),
fake_label('Person', 9, 17),
fake_label('Person', 0, 17),
fake_label('Andorian', 0, 8),
fake_label('Appendage', 29, 37),
fake_label('Person', 1, 8),
fake_label('Person', 0, 7),
fake_label('Person', 0, 9),
fake_label('Appendage', 29, 38),
]
def fake_labels(*indices):
return [LABELS[x] for x in indices]
def fake_docs(*indices):
return [(TEXT, fake_labels(*indices))]
def test_remove_nesting():
    """
    Check remove_nesting on several orders and combinations of overlapping labels.

    Each case is (label indices given, label indices expected back).
    """
    cases = [
        # this should be unchanged
        ((0, 1), (0, 1)),
        # this should be returned sorted
        ((0, 4, 1), (0, 1, 4)),
        # this should just have one copy
        ((0, 0), (0,)),
        # outer one preferred
        ((0, 2), (2,)),
        ((1, 2), (2,)),
        ((5, 2), (2,)),
        # order doesn't matter
        ((0, 4, 2), (2, 4)),
        ((2, 4, 0), (2, 4)),
        # same span with different labels: first one preferred
        ((0, 3), (0,)),
        ((3, 0), (3,)),
    ]
    for given, expected in cases:
        assert convert_amt.remove_nesting(fake_docs(*given)) == fake_docs(*expected)
def test_process_doc():
    """
    Check the per-token NER tags which process_doc assigns to TEXT for several
    label sets, including labels whose offsets are a little off from the
    token boundaries.
    """
    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)

    def tags_for(*label_indices):
        # the tag is the second field of each entry in the first sentence of the result
        doc = convert_amt.process_doc(TEXT, fake_labels(*label_indices), nlp)
        return [token[1] for token in doc[0]]

    # test a standard case of all the values lining up
    assert tags_for(2, 4) == ["B-Person", "I-Person", "O", "O", "B-Appendage", "O"]
    # test a slightly wrong start index
    assert tags_for(5, 1, 4) == ["B-Person", "B-Person", "O", "O", "B-Appendage", "O"]
    # test a slightly wrong end index
    assert tags_for(6, 1, 4) == ["B-Person", "B-Person", "O", "O", "B-Appendage", "O"]
    # test a slightly wronger end index
    assert tags_for(7, 4) == ["B-Person", "O", "O", "O", "B-Appendage", "O"]
    # test a period at the end of a text - should not be captured
    assert tags_for(7, 8) == ["B-Person", "O", "O", "O", "B-Appendage", "O"]
================================================
FILE: stanza/tests/ner/test_convert_nkjp.py
================================================
import pytest
import io
import os
import xml.etree.ElementTree as ET
from stanza.utils.datasets.ner.convert_nkjp import MORPH_FILE, NER_FILE, extract_entities_from_subfolder, extract_entities_from_sentence, extract_unassigned_subfolder_entities
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# Expected result of extract_unassigned_subfolder_entities on the sample dataset:
# paragraph id -> sentence id -> list of entity dicts.  Only sentence 1.39-s
# has an entity, the orgName "Sił Zbrojnych" covering segments 1.37 and 1.38.
EXPECTED_ENTITIES = {
    '1-p': {
        '1.39-s': [{'ent_id': 'named_1.39-s_n1', 'index': 0, 'orth': 'Sił Zbrojnych', 'ner_type': 'orgName', 'ner_subtype': None, 'targets': ['1.37-seg', '1.38-seg']}],
        '1.56-s': [],
        '1.79-s': []
    },
    '2-p': {
        '2.30-s': [],
        '2.45-s': []
    },
    '3-p': {
        '3.70-s': []
    }
}
@pytest.fixture(scope="module")
def dataset(tmp_path_factory):
dataset_path = tmp_path_factory.mktemp("nkjp_dataset")
sample_path = dataset_path / "sample"
os.mkdir(sample_path)
ann_path = sample_path / NER_FILE
with open(ann_path, "w", encoding="utf-8") as fout:
fout.write(SAMPLE_ANN)
morph_path = sample_path / MORPH_FILE
with open(morph_path, "w", encoding="utf-8") as fout:
fout.write(SAMPLE_MORPHO)
return dataset_path
# A few of the per-segment token dicts expected from extract_entities_from_subfolder
# for sentence 1.39-s: one plain token, then the two segments of the orgName
# entity, tagged B-orgName and I-orgName.
EXPECTED_TOKENS = [
    {'seg_id': '1.1-seg', 'i': 0, 'orth': '2', 'text': '2', 'tag': '_', 'ner': 'O', 'ner_subtype': None},
    {'seg_id': '1.37-seg', 'i': 36, 'orth': 'Sił', 'text': 'Sił', 'tag': '_', 'ner': 'B-orgName', 'ner_subtype': None},
    {'seg_id': '1.38-seg', 'i': 37, 'orth': 'Zbrojnych', 'text': 'Zbrojnych', 'tag': '_', 'ner': 'I-orgName', 'ner_subtype': None},
]
def test_extract_entities_from_subfolder(dataset):
    """
    Extracting the sample subfolder gives one paragraph with one sentence of
    39 entries keyed by segment id, including the tokens in EXPECTED_TOKENS.
    """
    entities = extract_entities_from_subfolder("sample", dataset)
    assert len(entities) == 1
    paragraph = entities['1-p']
    assert len(paragraph) == 1
    sentence = paragraph['1.39-s']
    assert len(sentence) == 39
    for expected in EXPECTED_TOKENS:
        assert sentence[expected['seg_id']] == expected
def test_extract_unassigned(dataset):
    """The entities found in the sample subfolder should match EXPECTED_ENTITIES exactly."""
    assert extract_unassigned_subfolder_entities("sample", dataset) == EXPECTED_ENTITIES
SENTENCE_SAMPLE = """
Sił Zbrojnych
Siły Zbrojne
""".strip()
EMPTY_SENTENCE = """ """
def test_extract_entities_from_sentence():
    """
    The sample sentence yields the entities of sentence 1.39-s;
    the empty sentence yields none.
    """
    cases = (
        (SENTENCE_SAMPLE, EXPECTED_ENTITIES['1-p']['1.39-s']),
        (EMPTY_SENTENCE, []),
    )
    for sample, expected in cases:
        assert extract_entities_from_sentence(ET.fromstring(sample)) == expected
# Sample files written into the test dataset by the `dataset` fixture.
# The annotation sample was picked completely at random, one sample file for testing:
# 610-1-000248/ann_named.xml
# Only the first sentence of that file is used in the morpho file.
# NOTE(review): as shown here these literals hold only text content (words,
# lemmas, and morphological tags) with no element markup, although the source
# is named as an .xml file above.  Confirm against the upstream file.
SAMPLE_ANN = """
Sił Zbrojnych
Siły Zbrojne
""".lstrip()
SAMPLE_MORPHO = """
2
2
2:adj:sg:nom:n:pos
.
.
.:interp
Wezwanie
wezwanie
wezwać
wezwanie:subst:sg:acc:n
,
,
,:interp
o
o
o
ojciec
o:prep:loc
którym
który
który:adj:sg:loc:n:pos
mowa
mowa
mowa:subst:sg:nom:f
w
w
wiek
wielki
wiersz
wieś
wyspa
w:prep:loc:nwok
ust
usta
ustęp
ustęp:brev:pun
.
.
.:interp
1
1
1:adj:sg:loc:m3:pos
,
,
,:interp
doręcza
doręczać
doręcze
doręczać:fin:sg:ter:imperf
się
się
się:qub
na
na
na
na:prep:acc
czternaście
czternaście
czternaście:num:pl:acc:m3:rec
dni
dni
dzień
dzień:subst:pl:gen:m3
przed
przed
przed:prep:inst:nwok
terminem
termin
termin:subst:sg:inst:m3
wykonania
wykonanie
wykonać
wykonać:ger:sg:gen:n:perf:aff
świadczenia
świadczenie
świadczyć
świadczenie:subst:sg:gen:n
,
,
,:interp
z
z
z
zeszyt
z:prep:inst:nwok
wyjątkiem
wyjątek
wyjątek:subst:sg:inst:m3
przypadków
przypadek
przypadek:subst:pl:gen:m3
,
,
,:interp
w
w
wiek
wielki
wiersz
wieś
wyspa
w:prep:loc:nwok
których
który
który:adj:pl:loc:m3:pos
wykonanie
wykonanie
wykonać
wykonać:ger:sg:nom:n:perf:aff
świadczenia
świadczenie
świadczyć
świadczenie:subst:sg:gen:n
następuje
następować
następować:fin:sg:ter:imperf
w
w
wiek
wielki
wiersz
wieś
wyspa
w:prep:loc:nwok
celu
Cela
cel
cel:subst:sg:loc:m3
sprawdzenia
sprawdzić
sprawdzić:ger:sg:gen:n:perf:aff
gotowości
gotowość
gotowość:subst:sg:gen:f
mobilizacyjnej
mobilizacyjny
mobilizacyjny:adj:sg:gen:f:pos
Sił
siła
siły
siła:subst:pl:gen:f
Zbrojnych
zbrojny
zbrojny:adj:pl:gen:f:pos
.
.
.:interp
""".lstrip()
================================================
FILE: stanza/tests/ner/test_convert_starlang_ner.py
================================================
"""
Test a couple different classes of trees to check the output of the Starlang conversion for NER
"""
import os
import tempfile
import pytest
from stanza.utils.datasets.ner import convert_starlang_ner
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )"
def test_read_tree():
    """
    read_tree should turn TREE into (word, NER tag) pairs, one per leaf,
    with a NONE entity reported as 'O'.
    """
    expected = [('Bayan', 'PERSON'), ('Haag', 'PERSON'), ('Elianti', 'O'), ('çalar', 'O'), ('.', 'O')]
    assert convert_starlang_ner.read_tree(TREE) == expected
================================================
FILE: stanza/tests/ner/test_data.py
================================================
import json
import pytest
from stanza.models import ner_tagger
from stanza.models.common.doc import Document
from stanza.models.ner.data import DataLoader
from stanza.tests import TEST_WORKING_DIR
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
ONE_SENTENCE = """
[
[
{
"text": "EU",
"ner": "B-ORG"
},
{
"text": "rejects",
"ner": "O"
},
{
"text": "German",
"ner": "B-MISC"
},
{
"text": "call",
"ner": "O"
},
{
"text": "to",
"ner": "O"
},
{
"text": "boycott",
"ner": "O"
},
{
"text": "Mox",
"ner": "B-MISC"
},
{
"text": "Opal",
"ner": "I-MISC"
},
{
"text": ".",
"ner": "O"
}
]
]
"""
@pytest.fixture(scope="module")
def pretrain_file():
return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
@pytest.fixture(scope="module")
def one_sentence_json_path(tmpdir_factory):
filename = tmpdir_factory.mktemp('data').join("sentence.json")
with open(filename, 'w') as fout:
fout.write(ONE_SENTENCE)
return filename
def test_build_vocab(pretrain_file, one_sentence_json_path, tmp_path):
    """
    Test that when loading a data file, we get back the expected vocabs:

    - 'word' is built from the pretrained embedding (note that 'unban'
      does not occur in the training sentence)
    - 'delta' holds every training word, lowercased
    - 'tag' holds the training tags, converted from BIO to BIOES
    """
    args = ner_tagger.parse_args(["--wordvec_pretrain_file", pretrain_file])
    pt = ner_tagger.load_pretrain(args)
    with open(one_sentence_json_path, encoding="utf-8") as fin:
        train_doc = Document(json.load(fin))
    train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])
    vocab = train_batch.vocab

    # NOTE(review): the leading special-symbol entries appear here as empty
    # strings; confirm they match the vocab's actual special tokens
    pt_words = list(vocab['word'])
    assert pt_words == ['', '', '', '', 'unban', 'mox', 'opal']
    delta_words = list(vocab['delta'])
    assert delta_words == ['', '', '', '', 'eu', 'rejects', 'german', 'call', 'to', 'boycott', 'mox', 'opal', '.']
    tags = list(vocab['tag'])
    assert tags == [[''], [''], [], [''], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']]
def test_build_vocab_ignore_repeats(pretrain_file, one_sentence_json_path, tmp_path):
    """
    Test the vocabs built with --emb_finetune_known_only.

    The 'delta' vocab should then only keep training words which also occur
    in the pretrained embedding ('mox' and 'opal').  The 'word' and 'tag'
    vocabs are the same as without the flag.
    """
    args = ner_tagger.parse_args(["--wordvec_pretrain_file", pretrain_file, "--emb_finetune_known_only"])
    pt = ner_tagger.load_pretrain(args)
    with open(one_sentence_json_path, encoding="utf-8") as fin:
        train_doc = Document(json.load(fin))
    train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])
    vocab = train_batch.vocab

    pt_words = list(vocab['word'])
    assert pt_words == ['', '', '', '', 'unban', 'mox', 'opal']
    delta_words = list(vocab['delta'])
    assert delta_words == ['', '', '', '', 'mox', 'opal']
    tags = list(vocab['tag'])
    assert tags == [[''], [''], [], [''], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']]
================================================
FILE: stanza/tests/ner/test_from_conllu.py
================================================
import pytest
from stanza import Pipeline
from stanza.utils.conll import CoNLL
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
def test_from_conllu():
    """
    If the doc does not have the entire text available, make sure it still safely processes the text

    Test case supplied from user - see issue #1428
    """
    expected_ents = ['February', 'Seattle', 'Pritchett']

    raw_pipe = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", download_method=None)
    doc = raw_pipe("In February, I traveled to Seattle. Dr. Pritchett gave me a new hip")
    # the default NER model ought to find these three
    assert [ent.text for ent in doc.ents] == expected_ents

    # Round trip through CoNLL-U, then rerun NER on the pretokenized result.
    # The bug previously caused a crash, because the Document built from the
    # CoNLL-U did not have the entire document text needed to construct the entities.
    conllu_doc = CoNLL.conll2doc(input_str="{:C}\n\n".format(doc))
    pretok_pipe = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", tokenize_pretokenized=True, download_method=None)
    pretok_pipe(conllu_doc)
    assert [ent.text for ent in conllu_doc.ents] == expected_ents
================================================
FILE: stanza/tests/ner/test_models_ner_scorer.py
================================================
"""
Simple test of the scorer module for NER
"""
import pytest
import stanza
from stanza.tests import *
from stanza.models.ner.scorer import score_by_token, score_by_entity
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def test_ner_scorer():
    """
    Check token-level and entity-level precision, recall and f1 on a small
    set of predicted vs gold BIOES sequences.
    """
    predicted = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'],
                 ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']]
    gold = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'],
            ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']]

    precision, recall, f1, _ = score_by_token(predicted, gold)
    assert precision == pytest.approx(0.625, abs=0.00001)
    assert recall == pytest.approx(0.5, abs=0.00001)
    assert f1 == pytest.approx(0.55555, abs=0.00001)

    precision, recall, f1, per_type_f1 = score_by_entity(predicted, gold)
    assert precision == pytest.approx(0.4, abs=0.00001)
    assert recall == pytest.approx(0.33333, abs=0.00001)
    assert f1 == pytest.approx(0.36363, abs=0.00001)
    assert per_type_f1 == {'LOC': 0.0, 'MISC': 1.0, 'ORG': 0.0, 'PER': 0.5}
================================================
FILE: stanza/tests/ner/test_ner_tagger.py
================================================
"""
Basic testing of the NER tagger.
"""
import os
import pytest
import stanza
from stanza.tests import *
from stanza.models import ner_tagger
from stanza.utils.confusion import confusion_to_macro_f1
import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file
from stanza.utils.training.run_ner import build_pretrain_args
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
EN_DOC = "Chris Manning is a good man. He works in Stanford University."
EN_DOC_GOLD = """
""".strip()
EN_BIO = """
Chris B-PERSON
Manning E-PERSON
is O
a O
good O
man O
. O
He O
works O
in O
Stanford B-ORG
University E-ORG
. O
""".strip().replace(" ", "\t")
EN_EXPECTED_OUTPUT = """
Chris B-PERSON B-PERSON
Manning E-PERSON E-PERSON
is O O
a O O
good O O
man O O
. O O
He O O
works O O
in O O
Stanford B-ORG B-ORG
University E-ORG E-ORG
. O O
""".strip().replace(" ", "\t")
def test_ner():
    """The English tokenize+ner pipeline should report exactly the entities in EN_DOC_GOLD."""
    nlp = stanza.Pipeline(processors='tokenize,ner', dir=TEST_MODELS_DIR, lang='en', logging_level='error')
    doc = nlp(EN_DOC)
    found = '\n'.join(ent.pretty_print() for ent in doc.ents)
    assert found == EN_DOC_GOLD
def test_evaluate(tmp_path):
    """
    Evaluate the English OntoNotes NER model on EN_BIO.

    This simple example should be tagged perfectly: a macro f1 of 1.0, and an
    eval output file matching EN_EXPECTED_OUTPUT.
    """
    package = "ontonotes-ww-multi_charlm"
    model_path = os.path.join(TEST_MODELS_DIR, "en", "ner", package + ".pt")
    assert os.path.exists(model_path), "The {} model should be downloaded as part of setup.py".format(package)

    # pytest creates tmp_path before the test runs, so no makedirs is needed
    test_bio_filename = tmp_path / "test.bio"
    test_json_filename = tmp_path / "test.json"
    test_output_filename = tmp_path / "output.bio"

    with open(test_bio_filename, "w", encoding="utf-8") as fout:
        fout.write(EN_BIO)
    # convert the BIO text into the json file given to the tagger as --eval_file
    prepare_ner_file.process_dataset(test_bio_filename, test_json_filename)

    args = ["--save_name", str(model_path),
            "--eval_file", str(test_json_filename),
            "--eval_output_file", str(test_output_filename),
            "--mode", "predict"]
    # presumably adds the pretrain (and charlm) arguments matching this package
    args = args + build_pretrain_args("en", package, model_dir=TEST_MODELS_DIR, extra_args=[])
    args = ner_tagger.parse_args(args=args)

    confusion = ner_tagger.evaluate(args)
    assert confusion_to_macro_f1(confusion) == pytest.approx(1.0)

    with open(test_output_filename, encoding="utf-8") as fin:
        results = fin.read().strip()
    assert results == EN_EXPECTED_OUTPUT
================================================
FILE: stanza/tests/ner/test_ner_trainer.py
================================================
import pytest
from stanza.tests import *
from stanza.models.ner import trainer
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def test_fix_singleton_tags():
    """
    Run trainer.fix_singleton_tags over a table of (input, expected) tag sequences

    The cases cover lone B- or E- tags becoming S- tags, runs of I- tags
    getting a B- start and an E- end, I- tags of a different type from their
    neighbors becoming S- tags, and valid BIOES sequences coming back unchanged.
    """
    cases = [
        (["O"], ["O"]),
        (["B-PER"], ["S-PER"]),
        (["B-PER", "I-PER"], ["B-PER", "E-PER"]),
        (["B-PER", "O", "B-PER"], ["S-PER", "O", "S-PER"]),
        (["B-PER", "B-PER", "I-PER"], ["S-PER", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "O", "B-PER"], ["B-PER", "E-PER", "O", "S-PER"]),
        (["B-PER", "B-PER", "I-PER", "B-PER"], ["S-PER", "B-PER", "E-PER", "S-PER"]),
        (["B-PER", "I-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]),
        (["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["S-PER", "B-PER", "E-PER"], ["S-PER", "B-PER", "E-PER"]),
        (["E-PER"], ["S-PER"]),
        (["E-PER", "O", "E-PER"], ["S-PER", "O", "S-PER"]),
        (["B-PER", "E-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]),
        (["I-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "I-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "E-PER", "O", "I-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "E-PER", "O", "B-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["I-PER", "I-PER", "I-PER", "O", "I-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
    ]
    for tags, fixed in cases:
        actual = trainer.fix_singleton_tags(tags)
        assert actual == fixed, f"Error converting {tags} to {fixed}"
================================================
FILE: stanza/tests/ner/test_ner_training.py
================================================
import json
import logging
import os
import warnings
import pytest
import torch
# Tag every test in this module for the travis and pipeline test selections
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
from stanza.models import ner_tagger
from stanza.models.ner.trainer import Trainer
from stanza.tests import TEST_WORKING_DIR
from stanza.utils.datasets.ner.prepare_ner_file import process_dataset
# The 'stanza' logger
logger = logging.getLogger('stanza')
# Single-column training data for run_training, written out as .bio by write_temp_file.
# Spaces become the tab separator; lstrip() keeps the trailing newline.
EN_TRAIN_BIO = """
Chris B-PERSON
Manning E-PERSON
is O
a O
good O
man O
. O
He O
works O
in O
Stanford B-ORG
University E-ORG
. O
""".lstrip().replace(" ", "\t")
# Single-column dev data for run_training
EN_DEV_BIO = """
Chris B-PERSON
Manning E-PERSON
is O
part O
of O
Computer B-ORG
Science E-ORG
""".lstrip().replace(" ", "\t")
# Two tag columns per word (a PERSON/ORG tag set and a PER/ORG tag set), converted to
# json with a "multi_ner" list per word by write_temp_2tag.
# NOTE(review): "University" is B-ORG (not E-ORG) in the second column - presumably
# intentional; confirm.
EN_TRAIN_2TAG = """
Chris B-PERSON B-PER
Manning E-PERSON E-PER
is O O
a O O
good O O
man O O
. O O
He O O
works O O
in O O
Stanford B-ORG B-ORG
University E-ORG B-ORG
. O O
""".strip().replace(" ", "\t")
# Same words and first column as EN_TRAIN_2TAG, but the second column is all "-".
# test_two_tag_training_c2_backprop uses it to check that a blank column does not
# update the second classifier.
EN_TRAIN_2TAG_EMPTY2 = """
Chris B-PERSON -
Manning E-PERSON -
is O -
a O -
good O -
man O -
. O -
He O -
works O -
in O -
Stanford B-ORG -
University E-ORG -
. O -
""".strip().replace(" ", "\t")
# Two-column dev data used by every run_two_tag_training call
EN_DEV_2TAG = """
Chris B-PERSON B-PER
Manning E-PERSON E-PER
is O O
part O O
of O O
Computer B-ORG B-ORG
Science E-ORG E-ORG
""".strip().replace(" ", "\t")
@pytest.fixture(scope="module")
def pretrain_file():
    """
    Path of the tiny word embedding file in the test working directory, shared by the whole module
    """
    return "{}/in/tiny_emb.pt".format(TEST_WORKING_DIR)
def write_temp_file(filename, bio_data):
    """
    Write bio_data to a .bio file beside filename, then convert that file to json at filename

    The .bio path is filename with its extension replaced by ".bio".
    """
    stem, _extension = os.path.splitext(filename)
    bio_path = stem + ".bio"
    with open(bio_path, "w", encoding="utf-8") as bio_out:
        bio_out.write(bio_data)
    process_dataset(bio_path, filename)
def write_temp_2tag(filename, bio_data):
    """
    Write multi-column NER data as a json document with a "multi_ner" list per word

    bio_data has one word per line in the form "text<TAB>tag1 tag2 ...", with
    sentences separated by blank lines.  The json written to filename is a list
    of sentences, each a list of {"text": ..., "multi_ner": [tag1, tag2, ...]}.

    Leading or trailing newlines and extra blank lines are ignored.  Previously
    they produced an empty "word" whose tab split raised ValueError.
    """
    doc = []
    for sentence in bio_data.strip("\n").split("\n\n"):
        lines = [line for line in sentence.split("\n") if line.strip()]
        if not lines:
            # a run of extra blank lines, not a real sentence
            continue
        words = []
        for line in lines:
            text, tags = line.split("\t", maxsplit=1)
            words.append({
                "text": text,
                "multi_ner": tags.split()
            })
        doc.append(words)
    with open(filename, "w", encoding="utf-8") as fout:
        json.dump(doc, fout)
def get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args):
    """
    Build the ner_tagger command line for a short training run

    Data and models live under tmp_path (models in tmp_path / "models"), training
    stops after 100 steps with evaluation every 40, and extra_args are appended last.
    """
    return [
        "--data_dir", str(tmp_path),
        "--wordvec_pretrain_file", pretrain_file,
        "--train_file", str(train_json),
        "--eval_file", str(dev_json),
        "--shorthand", "en_test",
        "--max_steps", "100",
        "--eval_interval", "40",
        "--save_dir", str(tmp_path / "models"),
        *extra_args,
    ]
def run_two_tag_training(pretrain_file, tmp_path, *extra_args, train_data=EN_TRAIN_2TAG):
    """
    Briefly train a two-column NER model and return the trainer from ner_tagger.main

    train_data is written as the training set and EN_DEV_2TAG as the dev set,
    both via write_temp_2tag.  extra_args are appended to the command line.
    """
    train_json = tmp_path / "en_test.train.json"
    dev_json = tmp_path / "en_test.dev.json"
    for json_path, tag_data in ((train_json, train_data), (dev_json, EN_DEV_2TAG)):
        write_temp_2tag(json_path, tag_data)
    return ner_tagger.main(get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args))
def test_basic_two_tag_training(pretrain_file, tmp_path):
    """
    Two tag columns should give two tag classifiers, two crits, and two tag vocab sizes
    """
    tagger = run_two_tag_training(pretrain_file, tmp_path)
    model = tagger.model
    assert len(model.tag_clfs) == 2
    assert len(model.crits) == 2
    assert len(tagger.vocab['tag'].lens()) == 2
def test_two_tag_training_backprop(pretrain_file, tmp_path):
    """
    Test that the training is backproping both tags

    A second run with --finetune continues from the first model, so if both
    output layers receive gradients, both classifiers' weights must change.
    """
    first = run_two_tag_training(pretrain_file, tmp_path)
    # the final model has to be saved before restarting
    # (reloading the final checkpoint would also work)
    final_path = os.path.join(first.args['save_dir'], first.args['save_name'])
    first.save(final_path)

    second = run_two_tag_training(pretrain_file, tmp_path, "--finetune")
    assert len(first.model.tag_clfs) == 2
    assert len(second.model.tag_clfs) == 2
    for before, after in zip(first.model.tag_clfs, second.model.tag_clfs):
        assert not torch.allclose(before.weight, after.weight)
def test_two_tag_training_c2_backprop(pretrain_file, tmp_path):
    """
    Test that the training is backproping only one tag if one column is blank

    Finetuning on EN_TRAIN_2TAG_EMPTY2, whose second column is all "-",
    should change the first classifier's weights but leave the second's alone.
    """
    first = run_two_tag_training(pretrain_file, tmp_path)
    # the final model has to be saved before restarting
    # (reloading the final checkpoint would also work)
    final_path = os.path.join(first.args['save_dir'], first.args['save_name'])
    first.save(final_path)

    second = run_two_tag_training(pretrain_file, tmp_path, "--finetune", train_data=EN_TRAIN_2TAG_EMPTY2)
    old_clfs, new_clfs = first.model.tag_clfs, second.model.tag_clfs
    assert len(old_clfs) == 2
    assert len(new_clfs) == 2
    assert not torch.allclose(old_clfs[0].weight, new_clfs[0].weight)
    assert torch.allclose(old_clfs[1].weight, new_clfs[1].weight)
def test_connected_two_tag_training(pretrain_file, tmp_path):
    """
    With --connect_output_layers, the second output layer should also see the first layer's tag scores
    """
    tagger = run_two_tag_training(pretrain_file, tmp_path, "--connect_output_layers")
    classifiers = tagger.model.tag_clfs
    assert len(classifiers) == 2
    assert len(tagger.model.crits) == 2
    assert len(tagger.vocab['tag'].lens()) == 2
    # the second layer's input width is the first layer's input width
    # plus the number of tags known to the first output layer
    expected_width = tagger.vocab['tag'].lens()[0] + classifiers[0].weight.shape[1]
    assert classifiers[1].weight.shape[1] == expected_width
def run_training(pretrain_file, tmp_path, *extra_args):
    """
    Briefly train a single-column NER model on EN_TRAIN_BIO / EN_DEV_BIO and return the trainer

    extra_args are appended to the ner_tagger command line.
    """
    train_json = tmp_path / "en_test.train.json"
    dev_json = tmp_path / "en_test.dev.json"
    for json_path, bio_data in ((train_json, EN_TRAIN_BIO), (dev_json, EN_DEV_BIO)):
        write_temp_file(json_path, bio_data)
    return ner_tagger.main(get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args))
def test_train_model_gpu(pretrain_file, tmp_path):
    """
    Briefly train an NER model (no expectation of correctness) and check that it is on the GPU

    The training still runs without a GPU; only the device check is replaced by a warning.
    """
    tagger = run_training(pretrain_file, tmp_path)
    if torch.cuda.is_available():
        first_param = next(tagger.model.parameters())
        assert str(first_param.device).startswith("cuda")
    else:
        warnings.warn("Cannot check that the NER model is on the GPU, since GPU is not available")
def test_train_model_cpu(pretrain_file, tmp_path):
    """
    Briefly train an NER model with --cpu (no expectation of correctness) and check that it is on the CPU
    """
    trainer = run_training(pretrain_file, tmp_path, "--cpu")
    model = trainer.model
    device = next(model.parameters()).device
    assert str(device).startswith("cpu")
def model_file_has_bert(filename):
    """
    Return True if the checkpoint saved at filename has any "bert_model." entries in its 'model' dict
    """
    checkpoint = torch.load(filename, map_location=lambda storage, loc: storage, weights_only=True)
    model_state = checkpoint['model']
    return any(name.startswith("bert_model.") for name in model_state)
def test_with_bert(pretrain_file, tmp_path):
    """
    Without --bert_finetune, the saved model file should not contain bert_model weights
    """
    tagger = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert')
    saved_model = os.path.join(tagger.args['save_dir'], tagger.args['save_name'])
    assert not model_file_has_bert(saved_model)
def test_with_bert_finetune(pretrain_file, tmp_path):
    """
    With --bert_finetune, the bert weights should be saved in the model file,
    and should survive both a direct save and a reload followed by a save
    """
    tagger = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune')
    save_name = tagger.args['save_name']
    assert model_file_has_bert(os.path.join(tagger.args['save_dir'], save_name))

    foo_path = os.path.join(tmp_path, "foo_" + save_name)
    bar_path = os.path.join(tmp_path, "bar_" + save_name)
    tagger.save(foo_path)
    assert model_file_has_bert(foo_path)

    # TODO: technically this should still work if we turn off bert finetuning when reloading
    reloaded = Trainer(args=tagger.args, model_file=foo_path)
    reloaded.save(bar_path)
    assert model_file_has_bert(bar_path)
def test_with_peft_finetune(pretrain_file, tmp_path):
    """
    With --use_peft, the checkpoint should store 'bert_lora' instead of any bert_model weights,
    and reloading it into a Trainer should not raise
    """
    # TODO: check that the peft tensors are moving when training?
    tagger = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft')
    saved_path = os.path.join(tagger.args['save_dir'], tagger.args['save_name'])
    checkpoint = torch.load(saved_path, map_location=lambda storage, loc: storage, weights_only=True)
    assert 'bert_lora' in checkpoint
    assert all(not name.startswith("bert_model.") for name in checkpoint['model'])

    # test loading
    reloaded = Trainer(args=tagger.args, model_file=saved_path)
================================================
FILE: stanza/tests/ner/test_ner_utils.py
================================================
import pytest
from stanza.tests import *
from stanza.models.common.vocab import EMPTY
from stanza.models.ner import utils
# Tag every test in this module for the travis and pipeline test selections
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
# Three sentences of words shared by every tag set below
WORDS = [["Unban", "Mox", "Opal"], ["Ragavan", "is", "red"], ["Urza", "Lord", "High", "Artificer", "goes", "infinite", "with", "Thopter", "Sword"]]
# BIO input tags, and the same tags using underscores instead of dashes
BIO_TAGS = [["O", "B-ART", "I-ART"], ["B-MONKEY", "O", "B-COLOR"], ["B-PER", "I-PER", "I-PER", "I-PER", "O", "O", "O", "B-WEAPON", "B-WEAPON"]]
BIO_U_TAGS = [["O", "B_ART", "I_ART"], ["B_MONKEY", "O", "B_COLOR"], ["B_PER", "I_PER", "I_PER", "I_PER", "O", "O", "O", "B_WEAPON", "B_WEAPON"]]
# Expected BIOES conversion of BIO_TAGS / BIO_U_TAGS
BIOES_TAGS = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "S-WEAPON", "S-WEAPON"]]
# note the problem with not using BIO tags - the consecutive tags for thopter/sword get treated as one item
BASIC_TAGS = [["O", "ART", "ART"], ["MONKEY", "O", "COLOR"], [ "PER", "PER", "PER", "PER", "O", "O", "O", "WEAPON", "WEAPON"]]
BASIC_BIOES = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "B-WEAPON", "E-WEAPON"]]
# A second, independent tag layer for the same words, used to test two-layer (merged) tags
ALT_BIO = [["O", "B-MANA", "I-MANA"], ["B-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
ALT_BIOES = [["O", "B-MANA", "E-MANA"], ["S-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]]
# ALT tags with the middle sentence untagged (None); the BIOES version leaves the Nones alone
NONE_BIO = [["O", "B-MANA", "I-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
NONE_BIOES = [["O", "B-MANA", "E-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]]
# What normalize_empty_tags should turn NONE_BIO into: the Nones replaced with EMPTY
EMPTY_BIO = [["O", "B-MANA", "I-MANA"], [EMPTY, EMPTY, EMPTY], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
def test_normalize_empty_tags():
    """
    None tags should be replaced with EMPTY, everything else left as it was
    """
    def build(tag_sets):
        # one (word, (tag,)) pair per word, one list per sentence
        return [[(word, (tag,)) for word, tag in zip(words, tags)]
                for words, tags in zip(WORDS, tag_sets)]

    normalized = utils.normalize_empty_tags(build(NONE_BIO))
    assert normalized == build(EMPTY_BIO)
def check_reprocessed_tags(words, input_tags, expected_tags):
    """
    Pair words with input_tags, retag with the bioes scheme, and compare to words paired with expected_tags

    process_tags selectively returns tuples or strings based on the input, so
    pairing expected_tags the same way as the input means the expected output
    needs no special formatting here.
    """
    def pair_up(tag_lists):
        return [list(zip(sentence_words, sentence_tags))
                for sentence_words, sentence_tags in zip(words, tag_lists)]

    retagged = utils.process_tags(sentences=pair_up(input_tags), scheme="bioes")
    assert retagged == pair_up(expected_tags)
def test_process_tags_bio():
    """
    BIO tags should be converted to BIOES

    The alternate tag set is checked as well, so that each layer of the
    two-layer tests is independently known to convert correctly.
    """
    for bio, bioes in ((BIO_TAGS, BIOES_TAGS), (ALT_BIO, ALT_BIOES)):
        check_reprocessed_tags(WORDS, bio, bioes)

def test_process_tags_with_none():
    """
    If a block of tags is all None, the Nones should be skipped over
    """
    check_reprocessed_tags(words=WORDS, input_tags=NONE_BIO, expected_tags=NONE_BIOES)
def merge_tags(*tags):
    """
    Combine several parallel tag sets into one set of per-word tag tuples

    Each argument is a list of sentences of tags.  The result has one list per
    sentence, holding one tuple per word with that word's tag from every set,
    e.g. ("O", "O"), ("B-ART", "B-MANA"), ...
    """
    merged = []
    for sentence_group in zip(*tags):
        merged.append(list(zip(*sentence_group)))
    return merged
def test_combined_tags_bio():
    """
    Two BIO tag layers per word should both be converted to BIOES
    """
    check_reprocessed_tags(WORDS,
                           merge_tags(BIO_TAGS, ALT_BIO),
                           merge_tags(BIOES_TAGS, ALT_BIOES))

def test_combined_tags_mixed():
    """
    A BIO layer paired with a layer already in BIOES: the result should be BIOES in both layers
    """
    input_tags = merge_tags(BIO_TAGS, ALT_BIOES)
    expected_tags = merge_tags(BIOES_TAGS, ALT_BIOES)
    check_reprocessed_tags(words=WORDS, input_tags=input_tags, expected_tags=expected_tags)

def test_process_tags_basic():
    """
    Plain tags with no prefixes should be converted to BIOES
    """
    check_reprocessed_tags(words=WORDS, input_tags=BASIC_TAGS, expected_tags=BASIC_BIOES)

def test_process_tags_bioes():
    """
    This one should not change, naturally: BIOES input comes back as it was
    """
    for already_bioes in (BIOES_TAGS, BASIC_BIOES):
        check_reprocessed_tags(WORDS, already_bioes, already_bioes)
def run_flattened(fn, tags):
    """
    Flatten a list of per-sentence tag lists into a single list and call fn on it

    The comprehension must loop over the sentences first and then over the
    tags in each sentence.  The reversed order (`for x in y for y in tags`)
    reads `y` before it is bound and raises NameError.
    """
    return fn([tag for sentence in tags for tag in sentence])
def test_check_bio():
    """
    Only the BIO tags should be detected as the BIO scheme
    """
    def flatten(tag_lists):
        return [tag for sentence in tag_lists for tag in sentence]

    assert utils.is_bio_scheme(flatten(BIO_TAGS))
    for other_scheme in (BIOES_TAGS, BASIC_TAGS, BASIC_BIOES):
        assert not utils.is_bio_scheme(flatten(other_scheme))

def test_check_basic():
    """
    Only the unprefixed tags should be detected as the basic scheme
    """
    def flatten(tag_lists):
        return [tag for sentence in tag_lists for tag in sentence]

    assert not utils.is_basic_scheme(flatten(BIO_TAGS))
    assert not utils.is_basic_scheme(flatten(BIOES_TAGS))
    assert utils.is_basic_scheme(flatten(BASIC_TAGS))
    assert not utils.is_basic_scheme(flatten(BASIC_BIOES))

def test_underscores():
    """
    Check that the methods work if the inputs are underscores instead of dashes
    """
    flat_tags = [tag for sentence in BIO_U_TAGS for tag in sentence]
    assert not utils.is_basic_scheme(flat_tags)
    check_reprocessed_tags(WORDS, BIO_U_TAGS, BIOES_TAGS)
def test_merge_tags():
    """
    Check a few versions of the tag sequence merging

    Judging by these cases, entities from the first sequence are always kept,
    entities from the second are added only where they do not overlap them,
    and a malformed second sequence raises ValueError.
    """
    seq1 = ["O", "O", "O", "B-FOO", "E-FOO", "O"]
    seq2 = ["S-FOO", "O", "B-FOO", "E-FOO", "O", "O"]
    seq3 = ["B-FOO", "E-FOO", "B-FOO", "E-FOO", "O", "O"]

    good_cases = [
        (seq1, seq2, ["S-FOO", "O", "O", "B-FOO", "E-FOO", "O"]),
        (seq2, seq1, ["S-FOO", "O", "B-FOO", "E-FOO", "O", "O"]),
        (seq1, seq3, ["B-FOO", "E-FOO", "O", "B-FOO", "E-FOO", "O"]),
    ]
    for first, second, expected in good_cases:
        assert utils.merge_tags(first, second) == expected

    malformed_sequences = [
        ["O", "B-FOO", "O", "B-FOO", "E-FOO", "O"],
        ["O", "B-FOO", "O", "B-FOO", "B-FOO", "O"],
        ["O", "B-FOO", "O", "B-FOO", "I-FOO", "O"],
        ["O", "B-FOO", "O", "B-FOO", "I-FOO", "I-FOO"],
    ]
    for malformed in malformed_sequences:
        with pytest.raises(ValueError):
            utils.merge_tags(seq1, malformed)
================================================
FILE: stanza/tests/ner/test_pay_amt_annotators.py
================================================
"""
Simple test for tracking AMT annotator work
"""
import os
import zipfile
import pytest
from stanza.tests import TEST_WORKING_DIR
from stanza.utils.ner import paying_annotators
DATA_SOURCE = os.path.join(TEST_WORKING_DIR, "in", "aws_annotations.zip")

@pytest.fixture(scope="module")
def completed_amt_job_metadata(tmp_path_factory):
    """
    Unzip the sample AWS labeling job into a fresh temp dir and return its ner/aws_labeling_copy path
    """
    assert os.path.exists(DATA_SOURCE)
    extract_root = tmp_path_factory.mktemp("amt_test")
    with zipfile.ZipFile(DATA_SOURCE, 'r') as archive:
        archive.extractall(extract_root)
    return extract_root / "ner" / "aws_labeling_copy"
def test_amt_annotator_track(completed_amt_job_metadata):
    """
    Completed task counts should be reported under the mapped worker names

    Worker1 has no completed tasks in the sample job, so it does not appear in the result.
    """
    # map AMT annotator subs to relevant identifier
    sub_to_name = {
        "7efc17ac-3397-4472-afe5-89184ad145d0": "Worker1",
        "afce8c28-969c-4e73-a20f-622ef122f585": "Worker2",
        "91f6236e-63c6-4a84-8fd6-1efbab6dedab": "Worker3",
        "6f202e93-e6b6-4e1d-8f07-0484b9a9093a": "Worker4",
        "2b674d33-f656-44b0-8f90-d70a1ab71ec2": "Worker5",
    }
    counts = paying_annotators.track_tasks(completed_amt_job_metadata, sub_to_name)
    expected = {'Worker4': 20, 'Worker5': 20, 'Worker2': 3, 'Worker3': 16}
    assert counts == expected

def test_amt_annotator_track_no_map(completed_amt_job_metadata):
    """
    Without a worker map, the counts should be keyed by the raw AMT worker subs
    """
    expected = {
        '6f202e93-e6b6-4e1d-8f07-0484b9a9093a': 20,
        '2b674d33-f656-44b0-8f90-d70a1ab71ec2': 20,
        'afce8c28-969c-4e73-a20f-622ef122f585': 3,
        '91f6236e-63c6-4a84-8fd6-1efbab6dedab': 16,
    }
    assert paying_annotators.track_tasks(completed_amt_job_metadata) == expected
def main():
    """
    Run the annotator tracking tests outside of pytest

    Both tests need the path of the extracted labeling job, which pytest
    normally supplies through the completed_amt_job_metadata fixture.
    Calling them without it raises TypeError, and pytest fixtures cannot be
    called directly.  So the zip file is extracted into a temporary
    directory here, and that path is passed in explicitly.
    """
    import pathlib
    import tempfile

    assert os.path.exists(DATA_SOURCE)
    with tempfile.TemporaryDirectory() as unzip_dir:
        with zipfile.ZipFile(DATA_SOURCE, 'r') as zin:
            zin.extractall(unzip_dir)
        # same layout the fixture returns
        input_path = pathlib.Path(unzip_dir) / "ner" / "aws_labeling_copy"
        test_amt_annotator_track(input_path)
        test_amt_annotator_track_no_map(input_path)

if __name__ == "__main__":
    main()
    print("TESTS COMPLETED!")
================================================
FILE: stanza/tests/ner/test_split_wikiner.py
================================================
"""
Runs a few tests on the split_wikiner file
"""
import os
import tempfile
import pytest
from stanza.utils.datasets.ner import split_wikiner
from stanza.tests import *
# Tag every test in this module for the pipeline and travis test selections
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# two sentences from the Italian dataset, split into many pieces
# to test the splitting functionality.  Each line is "word tag";
# test_read_sentences expects read_sentences to find 20 sentences here
FBK_SAMPLE = """
Il O
Papa O
si O
aggrava O
Le O
condizioni O
di O
Papa O
Giovanni PER
Paolo PER
II PER
si O
sono O
aggravate O
in O
il O
corso O
di O
la O
giornata O
di O
giovedì O
. O
Il O
portavoce O
Navarro PER
Valls PER
ha O
dichiarato O
che O
il O
Santo O
Padre O
in O
la O
giornata O
di O
oggi O
è O
stato O
colpito O
da O
una O
affezione O
altamente O
febbrile O
provocata O
da O
una O
infezione O
documentata O
di O
le O
vie O
urinarie O
. O
A O
il O
momento O
non O
è O
previsto O
il O
ricovero O
a O
il O
Policlinico LOC
Gemelli LOC
, O
come O
ha O
precisato O
il O
responsabile O
di O
il O
dipartimento O
di O
emergenza O
professor O
Rodolfo PER
Proietti PER
. O
"""
def test_read_sentences():
    """
    Write FBK_SAMPLE to a file and check that read_sentences recovers its 20 sentences

    FBK_SAMPLE contains non-ASCII text (e.g. "giovedì", "è"), and the file is
    read back as utf-8.  It is therefore written as utf-8 explicitly instead
    of relying on the platform default encoding, which breaks on systems
    where that default is not utf-8.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        raw_filename = os.path.join(tempdir, "raw.tsv")
        with open(raw_filename, "w", encoding="utf-8") as fout:
            fout.write(FBK_SAMPLE)

        sentences = split_wikiner.read_sentences(raw_filename, "utf-8")
        assert len(sentences) == 20
        # rebuild the text: each word's fields are tab-joined, words are one
        # per line, and sentences are separated by a blank line
        text = [["\t".join(word) for word in sent] for sent in sentences]
        text = ["\n".join(sent) for sent in text]
        text = "\n\n".join(text)
        assert FBK_SAMPLE.strip() == text
def test_write_sentences():
    """
    Round-trip the sample through write_sentences_to_file and read_sentences

    The sentences read from the raw file are written to a copy, and reading
    the copy must give the same sentences back.  The original version re-read
    the raw file instead, so it never checked what write_sentences_to_file
    produced.  The raw file is written as utf-8 to match how it is read,
    since the sample is not pure ASCII.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        raw_filename = os.path.join(tempdir, "raw.tsv")
        with open(raw_filename, "w", encoding="utf-8") as fout:
            fout.write(FBK_SAMPLE)

        sentences = split_wikiner.read_sentences(raw_filename, "utf-8")
        copy_filename = os.path.join(tempdir, "copy.tsv")
        split_wikiner.write_sentences_to_file(sentences, copy_filename)

        sent2 = split_wikiner.read_sentences(copy_filename, "utf-8")
        assert sent2 == sentences
def run_split_wikiner(expected_train=14, expected_dev=3, expected_test=3, **kwargs):
    """
    Runs a test using various parameters to check the results of the splitting process

    FBK_SAMPLE is written to a temp file and split with
    split_wikiner.split_wikiner(outdir, raw_filename, **kwargs).  kwargs must
    include "shuffle" and "test_section", which are read directly here, and
    the output file names assume prefix="it_fbk".

    Checks:
      - the train and dev files exist; the test file exists iff test_section
      - each split has the expected number of sentences
      - together the splits hold exactly the original sentences
        (compared after sorting when shuffle is set)

    The raw file is written as utf-8 because it is read back as utf-8 and
    the sample contains non-ASCII characters.
    """
    with tempfile.TemporaryDirectory() as indir:
        raw_filename = os.path.join(indir, "raw.tsv")
        with open(raw_filename, "w", encoding="utf-8") as fout:
            fout.write(FBK_SAMPLE)

        with tempfile.TemporaryDirectory() as outdir:
            split_wikiner.split_wikiner(outdir, raw_filename, **kwargs)

            train_file = os.path.join(outdir, "it_fbk.train.bio")
            dev_file = os.path.join(outdir, "it_fbk.dev.bio")
            test_file = os.path.join(outdir, "it_fbk.test.bio")

            assert os.path.exists(train_file)
            assert os.path.exists(dev_file)
            if kwargs["test_section"]:
                assert os.path.exists(test_file)
            else:
                assert not os.path.exists(test_file)

            train_sent = split_wikiner.read_sentences(train_file, "utf-8")
            dev_sent = split_wikiner.read_sentences(dev_file, "utf-8")
            assert len(train_sent) == expected_train
            assert len(dev_sent) == expected_dev
            if kwargs["test_section"]:
                test_sent = split_wikiner.read_sentences(test_file, "utf-8")
                assert len(test_sent) == expected_test
            else:
                test_sent = []

            if kwargs["shuffle"]:
                # shuffled splits can only be compared as multisets of sentences
                orig_sents = sorted(split_wikiner.read_sentences(raw_filename, "utf-8"))
                split_sents = sorted(train_sent + dev_sent + test_sent)
            else:
                orig_sents = split_wikiner.read_sentences(raw_filename, "utf-8")
                split_sents = train_sent + dev_sent + test_sent
            assert orig_sents == split_sents
def test_no_shuffle_split():
    """The expected 14/3/3 split, keeping the original sentence order"""
    run_split_wikiner(prefix="it_fbk", test_section=True, shuffle=False)

def test_shuffle_split():
    """The expected 14/3/3 split, with the sentences shuffled"""
    run_split_wikiner(prefix="it_fbk", test_section=True, shuffle=True)

def test_resize():
    """Fractions of 0.6 train and 0.1 dev should leave 12 train, 2 dev, 6 test"""
    run_split_wikiner(expected_train=12, expected_dev=2, expected_test=6,
                      train_fraction=0.6, dev_fraction=0.1,
                      prefix="it_fbk", shuffle=True, test_section=True)

def test_no_test_split():
    """Without a test section, 0.85 train should leave 17 train and 3 dev"""
    run_split_wikiner(expected_train=17, train_fraction=0.85,
                      prefix="it_fbk", shuffle=False, test_section=False)
================================================
FILE: stanza/tests/ner/test_suc3.py
================================================
"""
Tests the conversion code for the SUC3 NER dataset
"""
import os
import tempfile
from zipfile import ZipFile
import pytest
# Tag every test in this module for the travis and pipeline test selections
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
import stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob
# Two sentences in the SUC3 conll layout.  The NER annotation is split over two
# columns: B/I/O, then the entity type (or _).  The last column is presumably a
# source id - not used by these tests.
TEST_CONLL = """
1 Den den PN PN UTR|SIN|DEF|SUB/OBJ _ _ _ _ O _ ac01b-030:2328
2 Gud Gud PM PM NOM _ _ _ _ B myth ac01b-030:2329
3 giver giva VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2330
4 ämbetet ämbete NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2331
5 får få VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2332
6 också också AB AB _ _ _ _ O _ ac01b-030:2333
7 förståndet förstånd NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2334
8 . . MAD MAD _ _ _ _ O _ ac01b-030:2335
1 Han han PN PN UTR|SIN|DEF|SUB _ _ _ _ O _ aa01a-017:227
2 berättar berätta VB VB PRS|AKT _ _ _ _ O _ aa01a-017:228
3 anekdoten anekdot NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:229
4 som som HP HP -|-|- _ _ _ _ O _ aa01a-017:230
5 FN-medlaren FN-medlare NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:231
6 Brian Brian PM PM NOM _ _ _ _ B person aa01a-017:232
7 Urquhart Urquhart PM PM NOM _ _ _ _ I person aa01a-017:233
8 myntat mynta VB VB SUP|AKT _ _ _ _ O _ aa01a-017:234
9 : : MAD MAD _ _ _ _ O _ aa01a-017:235
"""
# Expected .iob output: each word with its two NER columns joined as e.g. B-myth
EXPECTED_IOB = """
Den O
Gud B-myth
giver O
ämbetet O
får O
också O
förståndet O
. O
Han O
berättar O
anekdoten O
som O
FN-medlaren O
Brian B-person
Urquhart I-person
myntat O
: O
"""
def test_read_zip():
    """
    Test creating a fake zip file, then converting it to an .iob file

    The sample contains non-ASCII Swedish text, so it is encoded explicitly as
    utf-8.  The output is read back as utf-8 rather than with the platform
    default encoding, which is not utf-8 everywhere.
    NOTE(review): assumes extract_from_zip writes its output as utf-8, like the
    explicitly utf-8 output in test_read_raw - confirm.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        zip_name = os.path.join(tempdir, "test.zip")
        in_filename = "conll"
        with ZipFile(zip_name, "w") as zout:
            with zout.open(in_filename, "w") as fout:
                fout.write(TEST_CONLL.encode("utf-8"))

        out_filename = os.path.join(tempdir, "iob")
        num = suc_conll_to_iob.extract_from_zip(zip_name, in_filename, out_filename)
        assert num == 2

        with open(out_filename, encoding="utf-8") as fin:
            result = fin.read()
        assert EXPECTED_IOB.strip() == result.strip()
def test_read_raw():
    """
    Test a direct text file conversion w/o the zip file

    The output is written as utf-8, so it is read back as utf-8 too.  Using
    the platform default encoding fails on the Swedish text wherever that
    default is not utf-8.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        in_filename = os.path.join(tempdir, "test.txt")
        with open(in_filename, "w", encoding="utf-8") as fout:
            fout.write(TEST_CONLL)

        out_filename = os.path.join(tempdir, "iob")
        with open(in_filename, encoding="utf-8") as fin, open(out_filename, "w", encoding="utf-8") as fout:
            num = suc_conll_to_iob.extract(fin, fout)
        assert num == 2

        with open(out_filename, encoding="utf-8") as fin:
            result = fin.read()
        assert EXPECTED_IOB.strip() == result.strip()
================================================
FILE: stanza/tests/pipeline/__init__.py
================================================
================================================
FILE: stanza/tests/pipeline/pipeline_device_tests.py
================================================
"""
Utility methods to check that all processors are on the expected device
Refactored since it can be used for multiple pipelines
"""
import warnings
import torch
def check_on_gpu(pipeline):
    """
    Assert that every processor's model is on a cuda device, then run the pipeline once

    Processors with a trainer are checked through trainer.model, the others
    through _model.  Running the pipeline afterwards catches cpu/cuda tensor
    conflicts.  Without CUDA, only a warning is issued.
    """
    if torch.cuda.is_available():
        for proc_name, processor in pipeline.processors.items():
            model = processor._model if processor.trainer is None else processor.trainer.model
            first_device = next(model.parameters()).device
            assert str(first_device).startswith("cuda"), f"Processor {proc_name} was not on the GPU"
        pipeline("This is a small test")
    else:
        warnings.warn("Unable to run the test that checks the pipeline is on the GPU, as there is no GPU available!")
def check_on_cpu(pipeline):
    """
    Assert that every processor's model is on the cpu, then run the pipeline once

    Processors with a trainer are checked through trainer.model, the others
    through _model.  Running the pipeline afterwards catches cpu/cuda tensor
    conflicts.
    """
    for proc_name, processor in pipeline.processors.items():
        model = processor._model if processor.trainer is None else processor.trainer.model
        first_device = next(model.parameters()).device
        assert str(first_device).startswith("cpu"), f"Processor {proc_name} was not on the CPU"
    pipeline("This is a small test")
================================================
FILE: stanza/tests/pipeline/test_arabic_pipeline.py
================================================
"""
Small test of loading the Arabic pipeline
The main goal is to check that nothing goes wrong with RtL languages,
but incidentally this would have caught a bug where the xpos tags
were split into individual pieces instead of reassembled as expected
"""
import pytest
import stanza
from stanza.tests import TEST_MODELS_DIR
pytestmark = pytest.mark.pipeline

def test_arabic_pos_pipeline():
    """
    Tokenize and POS tag one Arabic sentence with the test models

    The first token translates to "and not", which seems common enough that
    we should be able to rely on it having a stable MWT and tag.
    """
    pipe = stanza.Pipeline(processors='tokenize,pos', dir=TEST_MODELS_DIR, download_method=None, lang='ar')
    doc = pipe("ولم يتم اعتقال احد بحسب المتحدث باسم الشرطة.")
    assert len(doc.sentences) == 1
    sentence = doc.sentences[0]
    assert sentence.tokens[0].text == "ولم"
    assert sentence.words[0].xpos == "C---------"
    assert sentence.words[1].xpos == "F---------"
================================================
FILE: stanza/tests/pipeline/test_core.py
================================================
import pytest
import shutil
import tempfile
import stanza
from stanza.tests import *
from stanza.pipeline import core
from stanza.resources.common import get_md5, load_resources_json
pytestmark = pytest.mark.pipeline

def test_pretagged():
    """
    Test that the pipeline does or doesn't build if pos is left out and pretagged is specified
    """
    # with pos included, depparse's requirements are met
    nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,pos,lemma,depparse")

    no_pos = "tokenize,lemma,depparse"
    with pytest.raises(core.PipelineRequirementsException):
        nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=no_pos)

    pretagged_flags = (
        {"depparse_pretagged": True},
        {"pretagged": True},
        # the module specific flag overrides the general flag
        {"depparse_pretagged": True, "pretagged": False},
    )
    for flags in pretagged_flags:
        nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=no_pos, **flags)
def test_download_missing_ner_model():
    """
    Test that the pipeline will automatically download missing models

    Only the tokenizer is downloaded up front.  Building a tokenize,ner
    pipeline should fetch the rest, leaving exactly the model directories
    listed below.
    """
    expected_en_listing = ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize']
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined", verbose=False)
        pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize,ner", package={"ner": "ontonotes_charlm"})

        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        en_dir = os.path.join(test_dir, 'en')
        assert sorted(os.listdir(en_dir)) == expected_en_listing
        ner_dir = os.path.join(en_dir, 'ner')
        assert os.listdir(ner_dir) == ['ontonotes_charlm.pt']
def test_download_missing_resources():
    """
    Test that the pipeline will automatically download missing models

    Starting from an empty directory, building the pipeline should fetch
    resources.json and leave exactly the model directories listed below.
    """
    expected_en_listing = ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize']
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize,ner",
                               package={"tokenize": "combined", "ner": "ontonotes_charlm"})

        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        en_dir = os.path.join(test_dir, 'en')
        assert sorted(os.listdir(en_dir)) == expected_en_listing
        ner_dir = os.path.join(en_dir, 'ner')
        assert os.listdir(ner_dir) == ['ontonotes_charlm.pt']
def test_download_resources_overwrites():
    """
    Test that the DOWNLOAD_RESOURCES method overwrites an existing resources.json

    resources.json is backdated before the second pipeline is built.  Otherwise
    a rewrite that lands within the filesystem's timestamp resolution would
    keep the same mtime, and the test would fail even though the file was
    rewritten.
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        resources_path = os.path.join(test_dir, 'resources.json')

        # push the mtime into the past so that any rewrite is guaranteed to change it
        old_time = os.path.getmtime(resources_path) - 60
        os.utime(resources_path, (old_time, old_time))
        mod_time = os.path.getmtime(resources_path)

        pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
        new_mod_time = os.path.getmtime(resources_path)
        assert mod_time != new_mod_time
def test_reuse_resources_overwrites():
    """
    Test that the REUSE_RESOURCES method does *not* overwrite an existing resources.json

    resources.json is backdated before the second pipeline is built, so an
    unwanted rewrite cannot hide by landing within the filesystem's
    timestamp resolution of the first write.
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        pipe = stanza.Pipeline("en",
                               download_method=core.DownloadMethod.REUSE_RESOURCES,
                               model_dir=test_dir,
                               processors="tokenize",
                               package={"tokenize": "combined"})
        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        resources_path = os.path.join(test_dir, 'resources.json')

        # push the mtime into the past so that any rewrite would change it
        old_time = os.path.getmtime(resources_path) - 60
        os.utime(resources_path, (old_time, old_time))
        mod_time = os.path.getmtime(resources_path)

        pipe = stanza.Pipeline("en",
                               download_method=core.DownloadMethod.REUSE_RESOURCES,
                               model_dir=test_dir,
                               processors="tokenize",
                               package={"tokenize": "combined"})
        new_mod_time = os.path.getmtime(resources_path)
        assert mod_time == new_mod_time
def test_download_not_repeated():
    """
    Test that a model is only downloaded once if it already matches the expected model from the resources file

    The model file is backdated before the pipeline is built, so a re-download
    cannot hide by landing within the filesystem's timestamp resolution of
    the first download.
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined")
        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        en_dir = os.path.join(test_dir, 'en')
        en_dir_listing = sorted(os.listdir(en_dir))
        assert en_dir_listing == ['mwt', 'tokenize']
        tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt")

        # push the mtime into the past so that any rewrite would change it
        old_time = os.path.getmtime(tokenize_path) - 60
        os.utime(tokenize_path, (old_time, old_time))
        mod_time = os.path.getmtime(tokenize_path)

        pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
        assert os.path.getmtime(tokenize_path) == mod_time
def test_download_none():
    """
    Test that download_method=None leaves an existing model alone, while the default download_method repairs it

    Two different Italian tokenizers are downloaded and the "combined" model
    is overwritten with a copy of the "vit" model.  Then:
      - a pipeline built with download_method=None keeps the swapped file
      - a pipeline built with the default download_method replaces it
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("it", model_dir=test_dir, processors="tokenize", package="combined")
        stanza.download("it", model_dir=test_dir, processors="tokenize", package="vit")
        it_dir = os.path.join(test_dir, 'it')
        it_dir_listing = sorted(os.listdir(it_dir))
        assert sorted(it_dir_listing) == ['mwt', 'tokenize']
        combined_path = os.path.join(it_dir, "tokenize", "combined.pt")
        vit_path = os.path.join(it_dir, "tokenize", "vit.pt")
        assert os.path.exists(combined_path)
        assert os.path.exists(vit_path)
        combined_md5 = get_md5(combined_path)
        vit_md5 = get_md5(vit_path)
        # check that the models are different
        # otherwise the test is not testing anything
        assert combined_md5 != vit_md5
        # make "combined" a wrong model by copying "vit" over it
        shutil.copyfile(vit_path, combined_path)
        assert get_md5(combined_path) == vit_md5
        # download_method=None: the wrong model must be left in place
        pipe = stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=None)
        assert get_md5(combined_path) == vit_md5
        # default download_method: the wrong model must be replaced
        pipe = stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
        assert get_md5(combined_path) != vit_md5
def check_download_method_updates(download_method):
    """
    Run a single test of creating a pipeline with a given download_method, checking that the model is updated

    The downloaded tokenizer is overwritten with junk text, so its md5 no
    longer matches resources.json. A Pipeline is then built with
    download_method, and the file must have been replaced by a real model.

    download_method: a core.DownloadMethod value or its string name
    """
    junk = "Unban mox opal!"
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined")
        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']

        en_dir = os.path.join(test_dir, 'en')
        assert sorted(os.listdir(en_dir)) == ['mwt', 'tokenize']

        tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt")
        # overwrite the model with garbage so it no longer matches the expected md5
        with open(tokenize_path, "w") as fout:
            fout.write(junk)

        pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=download_method)

        # Compare contents rather than mtime: on filesystems with coarse
        # timestamp resolution, a redownload within the same tick would leave
        # the mtime unchanged and make an mtime-based check flaky.
        with open(tokenize_path, "rb") as fin:
            assert fin.read() != junk.encode("utf-8")
def test_download_fixed():
    """
    A model whose md5sum disagrees with resources.json is replaced,
    for each DownloadMethod enum value that consults the resources file.
    """
    methods = [core.DownloadMethod.REUSE_RESOURCES, core.DownloadMethod.DOWNLOAD_RESOURCES]
    for method in methods:
        check_download_method_updates(method)
def test_download_strings():
    """
    Same repair check as test_download_fixed, but passing download_method
    to the Pipeline as a plain string instead of an enum value.
    """
    method_names = ["reuse_resources", "download_resources"]
    for method_name in method_names:
        check_download_method_updates(method_name)
def test_limited_pipeline():
    """
    Load a pipeline with several processors, then run it restricted to subsets of them.

    Annotations from processors left out of the subset must be absent.
    Asking for a subset the pipeline cannot satisfy must raise ValueError.
    """
    text = "John Bauer works at Stanford"
    pipe = stanza.Pipeline(processors="tokenize,pos,lemma,depparse,ner", dir=TEST_MODELS_DIR)

    def words_of(document):
        return [word for sentence in document.sentences for word in sentence.words]

    def tokens_of(document):
        return [token for sentence in document.sentences for token in sentence.tokens]

    # all processors
    doc = pipe(text)
    assert all(word.upos is not None for word in words_of(doc))
    assert all(token.ner is not None for token in tokens_of(doc))

    # tokenize + pos only: no NER
    doc = pipe(text, processors=["tokenize", "pos"])
    assert all(word.upos is not None for word in words_of(doc))
    assert not any(token.ner is not None for token in tokens_of(doc))

    # tokenize only: neither POS nor NER
    doc = pipe(text, processors="tokenize")
    assert not any(word.upos is not None for word in words_of(doc))
    assert not any(token.ner is not None for token in tokens_of(doc))

    # tokenize + ner: NER without POS
    doc = pipe(text, processors="tokenize,ner")
    assert not any(word.upos is not None for word in words_of(doc))
    assert all(token.ner is not None for token in tokens_of(doc))

    # expected to fail; presumably depparse's own requirements (pos/lemma) are missing here
    with pytest.raises(ValueError):
        pipe(text, processors="tokenize,depparse")
@pytest.fixture(scope="module")
def unknown_language_name():
    """
    Return a language code that does not appear in the test models' resources.json.

    The code is built by appending "z" to "en" until the result is unused.
    """
    known_languages = load_resources_json(model_dir=TEST_MODELS_DIR)
    candidate = "en"
    while candidate in known_languages:
        candidate += "z"
    assert candidate != "en"
    return candidate
def test_empty_unknown_language(unknown_language_name):
    """
    Building a Pipeline for a language absent from resources.json, without
    allow_unknown_language, must raise ValueError.
    """
    with pytest.raises(ValueError):
        stanza.Pipeline(unknown_language_name, model_dir=TEST_MODELS_DIR, download_method=None)
def test_unknown_language_tokenizer(unknown_language_name):
    """
    Test that loading tokenize works for an unknown language

    The English tokenizer's model path is reused for a Pipeline built with a
    made-up language code and allow_unknown_language=True.
    """
    base_pipe = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
    # even if we one day add MWT to English, the tokenizer by itself should still work
    tokenize_processor = base_pipe.processors["tokenize"]

    # model_dir must be TEST_MODELS_DIR: the fixture only guarantees that the
    # language name is unknown with respect to *that* resources.json
    pipe = stanza.Pipeline(unknown_language_name,
                           model_dir=TEST_MODELS_DIR,
                           processors="tokenize",
                           allow_unknown_language=True,
                           tokenize_model_path=tokenize_processor.config['model_path'],
                           download_method=None)
    doc = pipe("This is a test")
    words = [x.text for x in doc.sentences[0].words]
    assert words == ['This', 'is', 'a', 'test']
def test_unknown_language_mwt(unknown_language_name):
    """
    Test that loading tokenize & mwt works for an unknown language

    The French tokenizer and MWT model paths are reused for a Pipeline built
    with a made-up language code. The pipeline must load both processors and
    must actually expand multi-word tokens (the French "du" -> "de le").
    """
    base_pipe = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, processors="tokenize,mwt", download_method=None)
    assert len(base_pipe.processors) == 2
    tokenize_processor = base_pipe.processors["tokenize"]
    mwt_processor = base_pipe.processors["mwt"]

    pipe = stanza.Pipeline(unknown_language_name,
                           model_dir=TEST_MODELS_DIR,
                           processors="tokenize,mwt",
                           allow_unknown_language=True,
                           tokenize_model_path=tokenize_processor.config['model_path'],
                           mwt_model_path=mwt_processor.config['model_path'],
                           download_method=None)
    assert len(pipe.processors) == 2

    # previously the pipeline was only constructed, never run
    doc = pipe("Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de "
               "l'Industrie et du Numérique.")
    assert len(doc.sentences) > 0
    # the two occurrences of "du" should each expand into two words
    assert doc.num_words > doc.num_tokens
================================================
FILE: stanza/tests/pipeline/test_decorators.py
================================================
"""
Basic tests of registering custom processors and processor variants with the pipeline
"""
import pytest
import stanza
from stanza.models.common.doc import Document
from stanza.pipeline.core import PipelineRequirementsException
from stanza.pipeline.processor import Processor, ProcessorVariant, register_processor, register_processor_variant, ProcessorRegisterException
from stanza.utils.conll import CoNLL
from stanza.tests import *
pytestmark = pytest.mark.pipeline
# data for testing
EN_DOC = "This is a test sentence. This is another!"
# Expected tokens_string() output (sentences joined by a blank line) for EN_DOC
# after the custom "lowercase" processor: one line per token.
# NOTE(review): in this copy every line reads only "]>", so the Token/Word text
# appears to have been stripped during extraction - confirm against upstream.
EN_DOC_LOWERCASE_TOKENS = ''']>
]>
]>
]>
]>
]>
]>
]>
]>
]>'''
# Expected tokens_string() output from the "lol" tokenizer variant: one token per
# whitespace-separated chunk of EN_DOC, all with text "LOL".
# NOTE(review): token content appears stripped here as well - confirm against upstream.
EN_DOC_LOL_TOKENS = ''']>
]>
]>
]>
]>
]>
]>
]>'''
# Expected tokens_string() output when the "cool" lemmatizer variant sets every lemma to "cool".
# NOTE(review): token content appears stripped here as well - confirm against upstream.
EN_DOC_COOL_LEMMAS = ''']>
]>
]>
]>
]>
]>
]>
]>
]>
]>'''
@register_processor("lowercase")
class LowercaseProcessor(Processor):
    ''' Test processor which lowercases the document, token and word text in place '''
    _requires = {'tokenize'}
    _provides = {'lowercase'}

    def __init__(self, config, pipeline, device):
        # Processor.__init__ is intentionally not called: nothing needs setting up
        pass

    def _set_up_model(self, *args):
        pass

    def process(self, doc):
        """ Lowercase doc.text and then every token's and word's text; returns doc """
        doc.text = doc.text.lower()
        for sent in doc.sentences:
            # tokens first, then words, matching a token-loop followed by a word-loop
            for annotated in (*sent.tokens, *sent.words):
                annotated.text = annotated.text.lower()
        return doc
def test_register_processor():
    """ A processor registered with @register_processor can be requested by name in a Pipeline """
    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,lowercase', download_method=None)
    processed = nlp(EN_DOC)
    token_dump = '\n\n'.join(sentence.tokens_string() for sentence in processed.sentences)
    assert token_dump == EN_DOC_LOWERCASE_TOKENS
def test_register_nonprocessor():
    """ @register_processor must reject a plain class that is not a Processor """
    expected_error = ProcessorRegisterException
    with pytest.raises(expected_error):
        @register_processor("nonprocessor")
        class NonProcessor:
            pass
@register_processor_variant("tokenize", "lol")
class LOLTokenizer(ProcessorVariant):
    ''' Tokenizer variant: one token per whitespace-separated chunk, every token's text is LOL '''
    def __init__(self, lang):
        pass

    def process(self, text):
        """ Build a single-sentence Document with one "LOL" token per chunk of text """
        chunk_count = len(text.split())
        sentence = [{'id': (position, ), 'text': 'LOL'} for position in range(1, chunk_count + 1)]
        return Document([sentence], text)
def test_register_processor_variant():
    """ A registered tokenize variant ("lol") can be selected through the processors dict """
    variant_choice = {"tokenize": "lol"}
    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors=variant_choice, package=None, download_method=None)
    processed = nlp(EN_DOC)
    assert '\n\n'.join(sentence.tokens_string() for sentence in processed.sentences) == EN_DOC_LOL_TOKENS
@register_processor_variant("lemma", "cool")
class CoolLemmatizer(ProcessorVariant):
    ''' Lemmatizer variant which assigns the lemma "cool" to every word '''
    # presumably tells the pipeline to use this variant instead of the regular lemma model
    OVERRIDE = True

    def __init__(self, lang):
        pass

    def process(self, document):
        """ Overwrite every word's lemma with "cool" and return the same document """
        every_word = (word for sentence in document.sentences for word in sentence.words)
        for word in every_word:
            word.lemma = "cool"
        return document
def test_register_processor_variant_with_override():
    """ The OVERRIDE "cool" lemma variant runs alongside the normal tokenize and pos models """
    processor_choice = {"tokenize": "combined", "pos": "combined", "lemma": "cool"}
    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors=processor_choice, package=None, download_method=None)
    processed = nlp(EN_DOC)
    token_dump = '\n\n'.join(sentence.tokens_string() for sentence in processed.sentences)
    assert EN_DOC_COOL_LEMMAS == token_dump
def test_register_nonprocessor_variant():
    """ @register_processor_variant must reject a plain class that is not a ProcessorVariant """
    expected_error = ProcessorRegisterException
    with pytest.raises(expected_error):
        @register_processor_variant("tokenize", "nonvariant")
        class NonVariant:
            pass
================================================
FILE: stanza/tests/pipeline/test_depparse.py
================================================
"""
Basic tests of the depparse processor boolean flags
"""
import gc
import pytest
import stanza
from stanza.pipeline.core import PipelineRequirementsException
from stanza.utils.conll import CoNLL
from stanza.tests import *
pytestmark = pytest.mark.pipeline
# data for testing
EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard."
# EN_DOC as CoNLL-U with lemma/upos/xpos/feats already filled in, for the
# depparse_pretagged test. The HEAD column holds placeholder values (each word
# points at the previous one) and DEPREL is empty; the parser is expected to
# replace them with the parses in EN_DOC_DEPENDENCY_PARSES_GOLD.
EN_DOC_CONLLU_PRETAGGED = """
1 Barack Barack PROPN NNP Number=Sing 0 _ _ _
2 Obama Obama PROPN NNP Number=Sing 1 _ _ _
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 2 _ _ _
4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 3 _ _ _
5 in in ADP IN _ 4 _ _ _
6 Hawaii Hawaii PROPN NNP Number=Sing 5 _ _ _
7 . . PUNCT . _ 6 _ _ _
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 0 _ _ _
2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 _ _ _
3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 2 _ _ _
4 president president PROPN NNP Number=Sing 3 _ _ _
5 in in ADP IN _ 4 _ _ _
6 2008 2008 NUM CD NumType=Card 5 _ _ _
7 . . PUNCT . _ 6 _ _ _
1 Obama Obama PROPN NNP Number=Sing 0 _ _ _
2 attended attend VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 1 _ _ _
3 Harvard Harvard PROPN NNP Number=Sing 2 _ _ _
4 . . PUNCT . _ 3 _ _ _
""".lstrip()
# Expected dependencies_string() output, sentences joined by a blank line:
# one (word, head index, deprel) tuple per word, head 0 meaning root.
EN_DOC_DEPENDENCY_PARSES_GOLD = """
('Barack', 4, 'nsubj:pass')
('Obama', 1, 'flat')
('was', 4, 'aux:pass')
('born', 0, 'root')
('in', 6, 'case')
('Hawaii', 4, 'obl')
('.', 4, 'punct')
('He', 3, 'nsubj:pass')
('was', 3, 'aux:pass')
('elected', 0, 'root')
('president', 3, 'xcomp')
('in', 6, 'case')
('2008', 3, 'obl')
('.', 3, 'punct')
('Obama', 2, 'nsubj')
('attended', 0, 'root')
('Harvard', 2, 'obj')
('.', 2, 'punct')
""".strip()
@pytest.fixture(scope="module")
def en_depparse_pipeline():
    """ Module-scoped English pipeline running tokenize, pos, lemma and depparse """
    pipeline = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,pos,lemma,depparse')
    # force a collection pass once the models are loaded (presumably to free loading temporaries)
    gc.collect()
    return pipeline
def test_depparse(en_depparse_pipeline):
    """ Parsing raw EN_DOC end to end reproduces the gold dependency parses """
    parsed = en_depparse_pipeline(EN_DOC)
    per_sentence = [sentence.dependencies_string() for sentence in parsed.sentences]
    assert '\n\n'.join(per_sentence) == EN_DOC_DEPENDENCY_PARSES_GOLD
def test_depparse_with_pretagged_doc():
    """ With depparse_pretagged=True, depparse alone parses an already-tagged CoNLL-U document """
    nlp = stanza.Pipeline(processors='depparse', dir=TEST_MODELS_DIR, lang='en', depparse_pretagged=True)
    pretagged = CoNLL.conll2doc(input_str=EN_DOC_CONLLU_PRETAGGED)
    parsed = nlp(pretagged)
    per_sentence = [sentence.dependencies_string() for sentence in parsed.sentences]
    assert '\n\n'.join(per_sentence) == EN_DOC_DEPENDENCY_PARSES_GOLD
def test_raises_requirements_exception_if_pretagged_not_passed():
    """ A depparse-only pipeline without depparse_pretagged must fail its requirements check """
    with pytest.raises(PipelineRequirementsException):
        stanza.Pipeline(processors='depparse', dir=TEST_MODELS_DIR, lang='en')
================================================
FILE: stanza/tests/pipeline/test_english_pipeline.py
================================================
"""
Basic testing of the English pipeline
"""
import pytest
import stanza
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import Document
from stanza.tests import *
from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# data for testing
EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard."
# the same three sentences, one per document, for the bulk / stream / multidoc tests
EN_DOCS = ["Barack Obama was born in Hawaii.", "He was elected president in 2008.", "Obama attended Harvard."]
# Expected tokens_string() output for EN_DOC: one line per token, sentences joined by a blank line.
# NOTE(review): in this copy each line reads only "]>" - the Token/Word content
# looks stripped during extraction; confirm against upstream.
EN_DOC_TOKENS_GOLD = """
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
""".strip()
# Expected words_string() output for EN_DOC.
# NOTE(review): this is empty in this copy, which looks like extraction damage
# rather than an intended gold value; confirm against upstream.
EN_DOC_WORDS_GOLD = """
""".strip()
# Expected dependencies_string() output: (word, head index, deprel) per word, head 0 = root.
EN_DOC_DEPENDENCY_PARSES_GOLD = """
('Barack', 4, 'nsubj:pass')
('Obama', 1, 'flat')
('was', 4, 'aux:pass')
('born', 0, 'root')
('in', 6, 'case')
('Hawaii', 4, 'obl')
('.', 4, 'punct')
('He', 3, 'nsubj:pass')
('was', 3, 'aux:pass')
('elected', 0, 'root')
('president', 3, 'xcomp')
('in', 6, 'case')
('2008', 3, 'obl')
('.', 3, 'punct')
('Obama', 2, 'nsubj')
('attended', 0, 'root')
('Harvard', 2, 'obj')
('.', 2, 'punct')
""".strip()
# Expected "{:C}" (CoNLL-U) output of the full default pipeline on EN_DOC as one
# document: character offsets run through the whole text, and the whitespace
# between sentences is recorded as SpacesAfter.
EN_DOC_CONLLU_GOLD = """
# text = Barack Obama was born in Hawaii.
# sent_id = 0
# constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .)))
# sentiment = 1
1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON
2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O
4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O
5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O
6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ SpaceAfter=No|start_char=25|end_char=31|ner=S-GPE
7 . . PUNCT . _ 4 punct _ SpacesAfter=\\s\\s|start_char=31|end_char=32|ner=O
# text = He was elected president in 2008.
# sent_id = 1
# constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .)))
# sentiment = 1
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=34|end_char=36|ner=O
2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=37|end_char=40|ner=O
3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=41|end_char=48|ner=O
4 president president NOUN NN Number=Sing 3 xcomp _ start_char=49|end_char=58|ner=O
5 in in ADP IN _ 6 case _ start_char=59|end_char=61|ner=O
6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ SpaceAfter=No|start_char=62|end_char=66|ner=S-DATE
7 . . PUNCT . _ 3 punct _ SpacesAfter=\\s\\s|start_char=66|end_char=67|ner=O
# text = Obama attended Harvard.
# sent_id = 2
# constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .)))
# sentiment = 1
1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=69|end_char=74|ner=S-PERSON
2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=75|end_char=83|ner=O
3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ SpaceAfter=No|start_char=84|end_char=91|ner=S-ORG
4 . . PUNCT . _ 2 punct _ SpaceAfter=No|start_char=91|end_char=92|ner=O
""".strip()
# Expected CoNLL-U when EN_DOCS is processed as three separate documents:
# character offsets restart at 0 in each document, while sent_id keeps counting.
EN_DOC_CONLLU_GOLD_MULTIDOC = """
# text = Barack Obama was born in Hawaii.
# sent_id = 0
# constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .)))
# sentiment = 1
1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON
2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O
4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O
5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O
6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ SpaceAfter=No|start_char=25|end_char=31|ner=S-GPE
7 . . PUNCT . _ 4 punct _ SpaceAfter=No|start_char=31|end_char=32|ner=O
# text = He was elected president in 2008.
# sent_id = 1
# constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .)))
# sentiment = 1
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=0|end_char=2|ner=O
2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=3|end_char=6|ner=O
3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=7|end_char=14|ner=O
4 president president NOUN NN Number=Sing 3 xcomp _ start_char=15|end_char=24|ner=O
5 in in ADP IN _ 6 case _ start_char=25|end_char=27|ner=O
6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ SpaceAfter=No|start_char=28|end_char=32|ner=S-DATE
7 . . PUNCT . _ 3 punct _ SpaceAfter=No|start_char=32|end_char=33|ner=O
# text = Obama attended Harvard.
# sent_id = 2
# constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .)))
# sentiment = 1
1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=5|ner=S-PERSON
2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=6|end_char=14|ner=O
3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ SpaceAfter=No|start_char=15|end_char=22|ner=S-ORG
4 . . PUNCT . _ 2 punct _ SpaceAfter=No|start_char=22|end_char=23|ner=O
""".strip()
# a single whitespace-tokenized sentence, as raw text and as a list of token lists
PRETOKENIZED_TEXT = "Jennifer has lovely blue antennae ."
PRETOKENIZED_PIECES = [PRETOKENIZED_TEXT.split()]
# Expected CoNLL-U from a tokenize-only pipeline on PRETOKENIZED_TEXT
# (no tags; each word's HEAD simply points at the previous word).
EXPECTED_TOKENIZED_ONLY_CONLLU = """
# text = Jennifer has lovely blue antennae .
# sent_id = 0
1 Jennifer _ _ _ _ 0 _ _ start_char=0|end_char=8
2 has _ _ _ _ 1 _ _ start_char=9|end_char=12
3 lovely _ _ _ _ 2 _ _ start_char=13|end_char=19
4 blue _ _ _ _ 3 _ _ start_char=20|end_char=24
5 antennae _ _ _ _ 4 _ _ start_char=25|end_char=33
6 . _ _ _ _ 5 _ _ SpaceAfter=No|start_char=34|end_char=35
""".strip()
# Expected CoNLL-U from the full pretokenized pipeline on the same sentence
EXPECTED_PRETOKENIZED_CONLLU = """
# text = Jennifer has lovely blue antennae .
# sent_id = 0
# constituency = (ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ lovely) (JJ blue) (NNS antennae))) (. .)))
# sentiment = 2
1 Jennifer Jennifer PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=8|ner=S-PERSON
2 has have VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root _ start_char=9|end_char=12|ner=O
3 lovely lovely ADJ JJ Degree=Pos 5 amod _ start_char=13|end_char=19|ner=O
4 blue blue ADJ JJ Degree=Pos 5 amod _ start_char=20|end_char=24|ner=O
5 antennae antenna NOUN NNS Number=Plur 2 obj _ start_char=25|end_char=33|ner=O
6 . . PUNCT . _ 2 punct _ SpaceAfter=No|start_char=34|end_char=35|ner=O
""".strip()
class TestEnglishPipeline:
    """
    End-to-end tests of the default English pipeline against gold annotations.

    Every Pipeline here is built with download_method=None, so the tests only
    read models already present in TEST_MODELS_DIR and never reach the network.
    The default English models needed by all of them are loaded by the
    `pipeline` fixture.
    """

    @pytest.fixture(scope="class")
    def pipeline(self):
        """ Default English pipeline with all default processors """
        return stanza.Pipeline(dir=TEST_MODELS_DIR, download_method=None)

    @pytest.fixture(scope="class")
    def pretokenized_pipeline(self):
        """ Default pipeline which treats its input as already tokenized """
        return stanza.Pipeline(dir=TEST_MODELS_DIR, tokenize_pretokenized=True, download_method=None)

    @pytest.fixture(scope="class")
    def tokenizer_pipeline(self):
        """ Pipeline with only the tokenizer """
        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)

    @pytest.fixture(scope="class")
    def processed_doc(self, pipeline):
        """ Document created by running full English pipeline on a few sentences """
        return pipeline(EN_DOC)

    def test_text(self, processed_doc):
        assert processed_doc.text == EN_DOC

    def test_conllu(self, processed_doc):
        assert "{:C}".format(processed_doc) == EN_DOC_CONLLU_GOLD

    def test_process_conllu(self, pipeline):
        """
        Process a conllu text directly

        This can use the pipeline which still uses tokenization, as
        process_conllu skips the tokenize and mwt processors
        """
        doc = pipeline.process_conllu(EN_DOC_CONLLU_GOLD)
        result = "{:C}".format(doc)
        assert result == EN_DOC_CONLLU_GOLD

    def test_tokens(self, processed_doc):
        assert "\n\n".join([sent.tokens_string() for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD

    def test_words(self, processed_doc):
        assert "\n\n".join([sent.words_string() for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD

    def test_dependency_parse(self, processed_doc):
        assert "\n\n".join([sent.dependencies_string() for sent in processed_doc.sentences]) == \
            EN_DOC_DEPENDENCY_PARSES_GOLD

    def test_empty(self, pipeline):
        # make sure that various models handle the degenerate empty case
        pipeline("")
        pipeline("--")

    def test_bulk_process(self, pipeline):
        """ Double check that the bulk_process method in Pipeline converts documents as expected """
        # it should process strings
        processed = pipeline.bulk_process(EN_DOCS)
        assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC

        # it should pass Documents through successfully
        docs = [Document([], text=t) for t in EN_DOCS]
        processed = pipeline.bulk_process(docs)
        assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC

    def test_empty_bulk_process(self, pipeline):
        """ Previously we had a bug where an empty document list would cause a crash """
        processed = pipeline.bulk_process([])
        assert processed == []

    def test_pretokenized(self, pretokenized_pipeline, tokenizer_pipeline):
        doc = pretokenized_pipeline(PRETOKENIZED_PIECES)
        conllu = "{:C}".format(doc).strip()
        assert conllu == EXPECTED_PRETOKENIZED_CONLLU

        doc = tokenizer_pipeline(PRETOKENIZED_TEXT)
        conllu = "{:C}".format(doc).strip()
        assert conllu == EXPECTED_TOKENIZED_ONLY_CONLLU

        # putting a doc with tokens into the pipeline should also work
        reparsed = pretokenized_pipeline(doc)
        conllu = "{:C}".format(reparsed).strip()
        assert conllu == EXPECTED_PRETOKENIZED_CONLLU

    def test_bulk_pretokenized(self, pretokenized_pipeline, tokenizer_pipeline):
        doc = tokenizer_pipeline(PRETOKENIZED_TEXT)
        conllu = "{:C}".format(doc).strip()
        assert conllu == EXPECTED_TOKENIZED_ONLY_CONLLU

        docs = pretokenized_pipeline([doc, doc])
        assert len(docs) == 2
        for doc in docs:
            conllu = "{:C}".format(doc).strip()
            assert conllu == EXPECTED_PRETOKENIZED_CONLLU

    def test_conll2doc_pretokenized(self, pretokenized_pipeline):
        doc = CoNLL.conll2doc(input_str=EXPECTED_TOKENIZED_ONLY_CONLLU)
        # this was bug from version 1.10.1 sent to us from a user
        # the pretokenized tokenize_processor would try to whitespace tokenize a document
        # even if the document already had sentences & words & stuff
        # not only would that be wrong if the text wouldn't whitespace tokenize into the words
        # (such as with punctuation and SpaceAfter=No),
        # it wouldn't even work in the case of conll2doc, since the document.text wasn't set
        docs = pretokenized_pipeline([doc, doc])
        assert len(docs) == 2
        for doc in docs:
            conllu = "{:C}".format(doc).strip()
            assert conllu == EXPECTED_PRETOKENIZED_CONLLU

    def test_stream(self, pipeline):
        """ Test the streaming interface to the Pipeline """
        # Test all of the documents in one batch
        # (the default batch size is significantly more than |EN_DOCS|)
        processed = list(pipeline.stream(EN_DOCS))
        assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC

        # It should also work on an iterator rather than an iterable
        processed = list(pipeline.stream(iter(EN_DOCS)))
        assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC

        # Stream one at a time
        processed = list(pipeline.stream(EN_DOCS, batch_size=1))
        processed = ["{:C}".format(doc) for doc in processed]
        assert "\n\n".join(processed) == EN_DOC_CONLLU_GOLD_MULTIDOC

    @pytest.fixture(scope="class")
    def processed_multidoc(self, pipeline):
        """ List of Documents from running the full English pipeline on EN_DOCS, one Document per sentence """
        docs = [Document([], text=t) for t in EN_DOCS]
        return pipeline(docs)

    def test_conllu_multidoc(self, processed_multidoc):
        assert "\n\n".join(["{:C}".format(doc) for doc in processed_multidoc]) == EN_DOC_CONLLU_GOLD_MULTIDOC

    def test_tokens_multidoc(self, processed_multidoc):
        assert "\n\n".join([sent.tokens_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD

    def test_words_multidoc(self, processed_multidoc):
        assert "\n\n".join([sent.words_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD

    def test_sentence_indices_multidoc(self, processed_multidoc):
        # sentence indices keep counting across the documents of one batch
        sentences = [sent for doc in processed_multidoc for sent in doc.sentences]
        for sent_idx, sentence in enumerate(sentences):
            assert sent_idx == sentence.index

    def test_dependency_parse_multidoc(self, processed_multidoc):
        assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == \
            EN_DOC_DEPENDENCY_PARSES_GOLD

    @pytest.fixture(scope="class")
    def processed_multidoc_variant(self):
        """ List of Documents from the English pipeline with the spaCy tokenizer variant, run on EN_DOCS """
        docs = [Document([], text=t) for t in EN_DOCS]
        nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors={'tokenize': 'spacy'}, download_method=None)
        return nlp(docs)

    def test_dependency_parse_multidoc_variant(self, processed_multidoc_variant):
        assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc_variant for sent in processed_doc.sentences]) == \
            EN_DOC_DEPENDENCY_PARSES_GOLD

    def test_constituency_parser(self):
        nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", download_method=None)
        doc = nlp("This is a test")
        assert str(doc.sentences[0].constituency) == '(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))'

    def test_on_gpu(self, pipeline):
        """
        The default pipeline should have all the models on the GPU
        """
        check_on_gpu(pipeline)

    def test_on_cpu(self):
        """
        Create a pipeline on the CPU, check that all the models on CPU
        """
        pipeline = stanza.Pipeline("en", dir=TEST_MODELS_DIR, use_gpu=False, download_method=None)
        check_on_cpu(pipeline)
================================================
FILE: stanza/tests/pipeline/test_french_pipeline.py
================================================
"""
Basic testing of French pipeline
The benefit of this test is to verify that the bulk processing works
for languages with MWT in them
"""
import pytest
import stanza
from stanza.models.common.doc import Document
from stanza.tests import *
from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu
pytestmark = pytest.mark.pipeline
# A French sentence with two multi-word tokens ("du" -> "de" + "le"), used to
# check that MWT expansion survives single and bulk processing.
FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \
"l'Industrie et du Numérique."
# Expected str(doc) for FR_MWT_SENTENCE: a JSON list of sentences, each a list
# of token/word dicts. MWT tokens appear as an entry whose "id" is a list of
# word ids, followed by the expanded words, which have no character offsets.
# Compared with compare_ignoring_whitespace, so indentation is not significant.
EXPECTED_RESULT = """
[
[
{
"id": 1,
"text": "Alors",
"lemma": "alors",
"upos": "ADV",
"head": 3,
"deprel": "advmod",
"start_char": 0,
"end_char": 5
},
{
"id": 2,
"text": "encore",
"lemma": "encore",
"upos": "ADV",
"head": 3,
"deprel": "advmod",
"start_char": 6,
"end_char": 12
},
{
"id": 3,
"text": "inconnu",
"lemma": "inconnu",
"upos": "ADJ",
"feats": "Gender=Masc|Number=Sing",
"head": 11,
"deprel": "advcl",
"start_char": 13,
"end_char": 20
},
{
"id": [
4,
5
],
"text": "du",
"start_char": 21,
"end_char": 23
},
{
"id": 4,
"text": "de",
"lemma": "de",
"upos": "ADP",
"head": 7,
"deprel": "case"
},
{
"id": 5,
"text": "le",
"lemma": "le",
"upos": "DET",
"feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art",
"head": 7,
"deprel": "det"
},
{
"id": 6,
"text": "grand",
"lemma": "grand",
"upos": "ADJ",
"feats": "Gender=Masc|Number=Sing",
"head": 7,
"deprel": "amod",
"start_char": 24,
"end_char": 29
},
{
"id": 7,
"text": "public",
"lemma": "public",
"upos": "NOUN",
"feats": "Gender=Masc|Number=Sing",
"head": 3,
"deprel": "obl:arg",
"start_char": 30,
"end_char": 36,
"misc": "SpaceAfter=No"
},
{
"id": 8,
"text": ",",
"lemma": ",",
"upos": "PUNCT",
"head": 3,
"deprel": "punct",
"start_char": 36,
"end_char": 37
},
{
"id": 9,
"text": "Emmanuel",
"lemma": "Emmanuel",
"upos": "PROPN",
"head": 11,
"deprel": "nsubj",
"start_char": 38,
"end_char": 46
},
{
"id": 10,
"text": "Macron",
"lemma": "Macron",
"upos": "PROPN",
"head": 9,
"deprel": "flat:name",
"start_char": 47,
"end_char": 53
},
{
"id": 11,
"text": "devient",
"lemma": "devenir",
"upos": "VERB",
"feats": "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
"head": 0,
"deprel": "root",
"start_char": 54,
"end_char": 61
},
{
"id": 12,
"text": "en",
"lemma": "en",
"upos": "ADP",
"head": 13,
"deprel": "case",
"start_char": 62,
"end_char": 64
},
{
"id": 13,
"text": "2014",
"lemma": "2014",
"upos": "NUM",
"feats": "Number=Plur",
"head": 11,
"deprel": "obl:mod",
"start_char": 65,
"end_char": 69
},
{
"id": 14,
"text": "ministre",
"lemma": "ministre",
"upos": "NOUN",
"feats": "Gender=Masc|Number=Sing",
"head": 11,
"deprel": "xcomp",
"start_char": 70,
"end_char": 78
},
{
"id": 15,
"text": "de",
"lemma": "de",
"upos": "ADP",
"head": 17,
"deprel": "case",
"start_char": 79,
"end_char": 81
},
{
"id": 16,
"text": "l'",
"lemma": "le",
"upos": "DET",
"feats": "Definite=Def|Number=Sing|PronType=Art",
"head": 17,
"deprel": "det",
"start_char": 82,
"end_char": 84,
"misc": "SpaceAfter=No"
},
{
"id": 17,
"text": "Économie",
"lemma": "économie",
"upos": "NOUN",
"feats": "Gender=Fem|Number=Sing",
"head": 14,
"deprel": "nmod",
"start_char": 84,
"end_char": 92,
"misc": "SpaceAfter=No"
},
{
"id": 18,
"text": ",",
"lemma": ",",
"upos": "PUNCT",
"head": 21,
"deprel": "punct",
"start_char": 92,
"end_char": 93
},
{
"id": 19,
"text": "de",
"lemma": "de",
"upos": "ADP",
"head": 21,
"deprel": "case",
"start_char": 94,
"end_char": 96
},
{
"id": 20,
"text": "l'",
"lemma": "le",
"upos": "DET",
"feats": "Definite=Def|Number=Sing|PronType=Art",
"head": 21,
"deprel": "det",
"start_char": 97,
"end_char": 99,
"misc": "SpaceAfter=No"
},
{
"id": 21,
"text": "Industrie",
"lemma": "industrie",
"upos": "NOUN",
"feats": "Gender=Fem|Number=Sing",
"head": 17,
"deprel": "conj",
"start_char": 99,
"end_char": 108
},
{
"id": 22,
"text": "et",
"lemma": "et",
"upos": "CCONJ",
"head": 25,
"deprel": "cc",
"start_char": 109,
"end_char": 111
},
{
"id": [
23,
24
],
"text": "du",
"start_char": 112,
"end_char": 114
},
{
"id": 23,
"text": "de",
"lemma": "de",
"upos": "ADP",
"head": 25,
"deprel": "case"
},
{
"id": 24,
"text": "le",
"lemma": "le",
"upos": "DET",
"feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art",
"head": 25,
"deprel": "det"
},
{
"id": 25,
"text": "Numérique",
"lemma": "numérique",
"upos": "NOUN",
"feats": "Gender=Masc|Number=Sing",
"head": 17,
"deprel": "conj",
"start_char": 115,
"end_char": 124,
"misc": "SpaceAfter=No"
},
{
"id": 26,
"text": ".",
"lemma": ".",
"upos": "PUNCT",
"head": 11,
"deprel": "punct",
"start_char": 124,
"end_char": 125,
"misc": "SpaceAfter=No"
}
]
]
"""
class TestFrenchPipeline:
    """ Tests of the French pipeline, focused on multi-word token (MWT) expansion """

    @pytest.fixture(scope="class")
    def pipeline(self):
        """ Class-scoped French pipeline: tokenize, mwt, pos, lemma, depparse """
        return stanza.Pipeline(processors='tokenize,mwt,pos,lemma,depparse', dir=TEST_MODELS_DIR, lang='fr')

    def test_single(self, pipeline):
        """ One document containing FR_MWT_SENTENCE matches the expected JSON """
        annotated = pipeline(FR_MWT_SENTENCE)
        compare_ignoring_whitespace(str(annotated), EXPECTED_RESULT)

    def test_bulk(self, pipeline):
        """ A list of identical Documents is processed exactly like a single one """
        doc_count = 10
        inputs = [Document([], text=content) for content in [FR_MWT_SENTENCE] * doc_count]
        outputs = pipeline(inputs)
        assert len(outputs) == doc_count
        for annotated in outputs:
            compare_ignoring_whitespace(str(annotated), EXPECTED_RESULT)
            assert len(annotated.sentences) == 1
            # the two "du" tokens each expand to two words: 24 tokens, 26 words
            assert annotated.num_words == 26
            assert annotated.num_tokens == 24

    def test_on_gpu(self, pipeline):
        """
        Every model of the default pipeline should be on the GPU
        """
        check_on_gpu(pipeline)

    def test_on_cpu(self):
        """
        A pipeline built with use_gpu=False should keep every model on the CPU
        """
        cpu_pipeline = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, use_gpu=False)
        check_on_cpu(cpu_pipeline)
================================================
FILE: stanza/tests/pipeline/test_lemmatizer.py
================================================
"""
Basic testing of lemmatization
"""
import pytest
import stanza
from stanza.tests import *
from stanza.models.common.doc import TEXT, UPOS, LEMMA
pytestmark = pytest.mark.pipeline
EN_DOC = "Joe Smith was born in California."
# Expected "text lemma" pairs with lemma_use_identity=True: every lemma equals the word
EN_DOC_IDENTITY_GOLD = """
Joe Joe
Smith Smith
was was
born born
in in
California California
. .
""".strip()
# Expected "text lemma" pairs from the trained lemmatizer model (was -> be, born -> bear)
EN_DOC_LEMMATIZER_MODEL_GOLD = """
Joe Joe
Smith Smith
was be
born bear
in in
California California
. .
""".strip()
def test_identity_lemmatizer():
    """ With lemma_use_identity=True every lemma is simply the word's own text """
    nlp = stanza.Pipeline(processors='tokenize,lemma', dir=TEST_MODELS_DIR, lang='en', lemma_use_identity=True, download_method=None)
    doc = nlp(EN_DOC)
    pairs = "\n".join(f"{word.text} {word.lemma}" for word in doc.iter_words())
    assert EN_DOC_IDENTITY_GOLD == pairs
def test_full_lemmatizer():
    """ The trained English lemmatizer (with POS tags available) produces the gold lemmas """
    nlp = stanza.Pipeline(processors='tokenize,pos,lemma', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    doc = nlp(EN_DOC)
    pairs = "\n".join(f"{word.text} {word.lemma}" for word in doc.iter_words())
    assert EN_DOC_LEMMATIZER_MODEL_GOLD == pairs
def find_unknown_word(lemmatizer, base, max_attempts=10):
    """
    Return a word the lemmatizer has not seen, built by appending "z" to base.

    A candidate counts as unknown when it is not a key of lemmatizer.word_dict
    and is not the word part of any (word, upos) key of lemmatizer.composite_dict.

    lemmatizer: object exposing word_dict and composite_dict (the lemma trainer)
    base: prefix to extend; the first candidate tried is base + "z"
    max_attempts: how many successively longer candidates to try

    Raises RuntimeError if every candidate is already known.
    """
    # collect the composite words once rather than rescanning every key per candidate
    composite_words = {key[0] for key in lemmatizer.composite_dict}
    candidate = base
    for _ in range(max_attempts):
        candidate = candidate + "z"
        if candidate not in lemmatizer.word_dict and candidate not in composite_words:
            return candidate
    raise RuntimeError(f"Could not find a word unknown to the lemmatizer starting from {base!r} after {max_attempts} attempts")
def test_store_results():
    """
    With lemma_store_results=True, the lemmas predicted for words unknown to the
    lemmatizer should be stored in its composite_dict under (word, upos) keys,
    and reprocessing the same text should give identical results
    """
    nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, lemma_store_results=True, download_method=None)
    lemmatizer = nlp.processors["lemma"]._trainer

    az = find_unknown_word(lemmatizer, "a")
    bz = find_unknown_word(lemmatizer, "b")
    cz = find_unknown_word(lemmatizer, "c")
    # try sentences with the order long, short
    # (presumably to check that results are mapped back to the right words
    #  regardless of any length-based reordering during processing)
    text = "I found an " + az + " in my " + bz + ". It was a " + cz
    doc = nlp(text)
    stuff = doc.get([TEXT, UPOS, LEMMA])
    assert len(stuff) == 12
    assert stuff[3][0] == az
    assert stuff[6][0] == bz
    assert stuff[11][0] == cz
    assert lemmatizer.composite_dict[(az, stuff[3][1])] == stuff[3][2]
    assert lemmatizer.composite_dict[(bz, stuff[6][1])] == stuff[6][2]
    assert lemmatizer.composite_dict[(cz, stuff[11][1])] == stuff[11][2]
    # processing the same text again should give the same words, tags, and lemmas
    doc2 = nlp(text)
    stuff2 = doc2.get([TEXT, UPOS, LEMMA])
    assert stuff == stuff2

    dz = find_unknown_word(lemmatizer, "d")
    ez = find_unknown_word(lemmatizer, "e")
    fz = find_unknown_word(lemmatizer, "f")
    # now try sentences with the order short, long
    text = "It was a " + dz + ". I found an " + ez + " in my " + fz
    doc = nlp(text)
    stuff = doc.get([TEXT, UPOS, LEMMA])
    assert len(stuff) == 12
    assert stuff[3][0] == dz
    assert stuff[8][0] == ez
    assert stuff[11][0] == fz
    assert lemmatizer.composite_dict[(dz, stuff[3][1])] == stuff[3][2]
    assert lemmatizer.composite_dict[(ez, stuff[8][1])] == stuff[8][2]
    assert lemmatizer.composite_dict[(fz, stuff[11][1])] == stuff[11][2]
    doc2 = nlp(text)
    stuff2 = doc2.get([TEXT, UPOS, LEMMA])
    assert stuff == stuff2

    # the stored result for az lives in composite_dict, not in word_dict
    assert az not in lemmatizer.word_dict
def test_caseless_lemmatizer():
    """
    Test that setting lemma_caseless=True at Pipeline time changes how the
    lemmatizer handles a capitalized word
    """
    text = "Here is an Excerpt"

    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)
    # with the default settings, the current English model lowercases
    # the lemma of the capitalized final word
    doc = nlp(text)
    assert doc.sentences[0].words[-1].lemma == 'excerpt'

    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None, lemma_caseless=True)
    # with lemma_caseless=True, the lemma keeps the word's original capitalization
    doc = nlp(text)
    assert doc.sentences[0].words[-1].lemma == 'Excerpt'
def test_latin_caseless_lemmatizer():
    """
    The Latin ITTB lemmatizer should be configured as caseless and should
    lemmatize a fully capitalized phrase correctly
    """
    nlp = stanza.Pipeline('la', package='ittb', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)
    assert nlp.processors['lemma'].config['caseless']

    doc = nlp("Quod Erat Demonstrandum")
    assert len(doc.sentences) == 1
    words = doc.sentences[0].words
    assert len(words) == 3
    assert [word.lemma for word in words] == ["qui", "sum", "demonstro"]
def test_contextual_lemmatizer():
    """
    The default_accurate English lemma package should load at least one
    contextual lemmatizer and produce lemmas for a sentence containing 's
    """
    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, package={"lemma": "default_accurate"}, download_method="reuse_resources")
    lemmatizer = nlp.processors['lemma']._trainer
    # the accurate model should have a 's classifier
    assert len(lemmatizer.contextual_lemmatizers) > 0

    # ideally the doc would have 'have' as the lemma for the second
    # word, but maybe it's not always accurate. actually, it works
    # fine at the time of this test
    doc = nlp("He's added a contextual lemmatizer")
    # previously the result was never checked; at minimum every word must get a lemma
    assert len(doc.sentences) == 1
    assert all(word.lemma is not None for word in doc.sentences[0].words)
================================================
FILE: stanza/tests/pipeline/test_pipeline_constituency_processor.py
================================================
import gc
import pytest
import stanza
from stanza.models.common.foundation_cache import FoundationCache
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# data for testing: TEST_TEXT holds three sentences, and TEST_TOKENS holds the
# expected leaf labels of each sentence's constituency tree, in order
TEST_TEXT = "This is a test. Another sentence. Are these sorted?"
TEST_TOKENS = [["This", "is", "a", "test", "."], ["Another", "sentence", "."], ["Are", "these", "sorted", "?"]]
@pytest.fixture(scope="module")
def foundation_cache():
    """
    A single FoundationCache shared by every pipeline built in this module
    """
    # the suite has sometimes held GPU memory long enough to trigger an OOM,
    # so force a collection before building the shared cache
    gc.collect()
    shared_cache = FoundationCache()
    return shared_cache
def check_results(doc):
    """
    Assert that doc has one sentence per entry of TEST_TOKENS and that each
    sentence's constituency tree has the matching leaf labels, in order
    """
    assert len(doc.sentences) == len(TEST_TOKENS)
    for idx, sentence in enumerate(doc.sentences):
        leaves = sentence.constituency.leaf_labels()
        assert leaves == TEST_TOKENS[idx]
def test_sorted_big_batch(foundation_cache):
    """
    Parse TEST_TEXT with the default constituency batch size
    """
    pipe = stanza.Pipeline("en",
                           model_dir=TEST_MODELS_DIR,
                           processors="tokenize,pos,constituency",
                           foundation_cache=foundation_cache,
                           download_method=None)
    check_results(pipe(TEST_TEXT))
def test_comments(foundation_cache):
    """
    The pipeline should attach a "# constituency = ..." comment to every sentence,
    and reassigning a sentence's tree should update that comment in place
    """
    pipe = stanza.Pipeline("en",
                           model_dir=TEST_MODELS_DIR,
                           processors="tokenize,pos,constituency",
                           foundation_cache=foundation_cache,
                           download_method=None)
    doc = pipe(TEST_TEXT)
    check_results(doc)

    for sentence in doc.sentences:
        assert any(comment.startswith("# constituency = ") for comment in sentence.comments)

    doc.sentences[0].constituency = "asdf"
    assert "# constituency = asdf" in doc.sentences[0].comments
    # each sentence should still have exactly one constituency comment
    for sentence in doc.sentences:
        assert sum(1 for comment in sentence.comments if comment.startswith("# constituency")) == 1
def test_illegal_batch_size(foundation_cache):
    """
    A non-integer constituency_batch_size should only raise ValueError
    when the constituency processor is actually part of the pipeline
    """
    # no constituency processor here, so the bogus batch size should not cause an error
    stanza.Pipeline("en",
                    model_dir=TEST_MODELS_DIR,
                    processors="tokenize,pos",
                    constituency_batch_size="zzz",
                    foundation_cache=foundation_cache,
                    download_method=None)
    with pytest.raises(ValueError):
        stanza.Pipeline("en",
                        model_dir=TEST_MODELS_DIR,
                        processors="tokenize,pos,constituency",
                        constituency_batch_size="zzz",
                        foundation_cache=foundation_cache,
                        download_method=None)
def test_sorted_one_batch(foundation_cache):
    """
    Parse TEST_TEXT one sentence per batch; results must come back in the original order
    """
    pipe = stanza.Pipeline("en",
                           model_dir=TEST_MODELS_DIR,
                           processors="tokenize,pos,constituency",
                           constituency_batch_size=1,
                           foundation_cache=foundation_cache,
                           download_method=None)
    check_results(pipe(TEST_TEXT))
def test_sorted_two_batch(foundation_cache):
    """
    Parse TEST_TEXT two sentences per batch; results must come back in the original order
    """
    pipe = stanza.Pipeline("en",
                           model_dir=TEST_MODELS_DIR,
                           processors="tokenize,pos,constituency",
                           constituency_batch_size=2,
                           foundation_cache=foundation_cache,
                           download_method=None)
    check_results(pipe(TEST_TEXT))
def test_get_constituents(foundation_cache):
    """
    The constituency processor should report SBAR among its known constituent labels
    """
    pipe = stanza.Pipeline("en",
                           model_dir=TEST_MODELS_DIR,
                           processors="tokenize,pos,constituency",
                           foundation_cache=foundation_cache,
                           download_method=None)
    constituents = pipe.processors["constituency"].get_constituents()
    assert "SBAR" in constituents
================================================
FILE: stanza/tests/pipeline/test_pipeline_depparse_processor.py
================================================
"""
Basic testing of the dependency parser (depparse) processor
"""
import pytest
import stanza
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.tests import TEST_MODELS_DIR
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
class TestClassifier:
    # NOTE(review): despite the class name, these tests exercise the depparse processor
    @pytest.fixture(scope="class")
    def english_depparse(self):
        """
        Build an English tokenize/pos/lemma/depparse pipeline and hand back its depparse processor
        """
        pipeline = stanza.Pipeline(processors='tokenize,pos,lemma,depparse', dir=TEST_MODELS_DIR, lang='en')
        assert 'depparse' in pipeline.processors
        return pipeline.processors['depparse']

    def test_get_known_relations(self, english_depparse):
        """
        The processor should report a reasonable set of dependency relations.
        Only a few are spot-checked, since the exact set may change with future models.
        """
        relations = english_depparse.get_known_relations()
        assert len(relations) > 5
        assert 'case' in relations
        # the vocab's special prefix entries must not leak into the relation list
        assert not any(prefix in relations for prefix in VOCAB_PREFIX)
================================================
FILE: stanza/tests/pipeline/test_pipeline_mwt_expander.py
================================================
"""
Basic testing of multi-word-token expansion
"""
import pytest
import stanza
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# mwt data for testing
# a French sentence containing two "du" tokens, each of which expands to the words "de" + "le"
FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \
"l'Industrie et du Numérique."
# expected token -> words listing built by test_mwt
# NOTE(review): the word reprs inside "words: [...]" appear empty here; confirm this
# still matches the Word.pretty_print() output that test_mwt produces
FR_MWT_TOKEN_TO_WORDS_GOLD = """
token: Alors words: []
token: encore words: []
token: inconnu words: []
token: du words: [, ]
token: grand words: []
token: public words: []
token: , words: []
token: Emmanuel words: []
token: Macron words: []
token: devient words: []
token: en words: []
token: 2014 words: []
token: ministre words: []
token: de words: []
token: l' words: []
token: Économie words: []
token: , words: []
token: de words: []
token: l' words: []
token: Industrie words: []
token: et words: []
token: du words: [, ]
token: Numérique words: []
token: . words: []
""".strip()
# expected word -> parent token listing built by test_mwt; words from an expanded
# token share a parent whose id is a range such as 4-5
FR_MWT_WORD_TO_TOKEN_GOLD = """
word: Alors token parent:1-Alors
word: encore token parent:2-encore
word: inconnu token parent:3-inconnu
word: de token parent:4-5-du
word: le token parent:4-5-du
word: grand token parent:6-grand
word: public token parent:7-public
word: , token parent:8-,
word: Emmanuel token parent:9-Emmanuel
word: Macron token parent:10-Macron
word: devient token parent:11-devient
word: en token parent:12-en
word: 2014 token parent:13-2014
word: ministre token parent:14-ministre
word: de token parent:15-de
word: l' token parent:16-l'
word: Économie token parent:17-Économie
word: , token parent:18-,
word: de token parent:19-de
word: l' token parent:20-l'
word: Industrie token parent:21-Industrie
word: et token parent:22-et
word: de token parent:23-24-du
word: le token parent:23-24-du
word: Numérique token parent:25-Numérique
word: . token parent:26-.
""".strip()
def test_mwt():
    """
    Expand French MWTs and compare both the token -> words view and the
    word -> parent token view against the gold listings
    """
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='fr', download_method=None)
    doc = pipeline(FR_MWT_SENTENCE)

    token_lines = []
    word_lines = []
    for sent in doc.sentences:
        for token in sent.tokens:
            word_reprs = ", ".join(word.pretty_print() for word in token.words)
            token_lines.append(f'token: {token.text.ljust(9)}\t\twords: [{word_reprs}]')
        for word in sent.words:
            parent_id = "-".join(str(x) for x in word.parent.id)
            word_lines.append(f'word: {word.text.ljust(9)}\t\ttoken parent:{parent_id}-{word.parent.text}')

    assert "\n".join(token_lines).strip() == FR_MWT_TOKEN_TO_WORDS_GOLD
    assert "\n".join(word_lines).strip() == FR_MWT_WORD_TO_TOKEN_GOLD
def test_unknown_character():
    """
    The MWT processor can temporarily extend its character vocab with unknown
    characters. Check that the batch built for a user-reported text really
    contains every character of that text.
    """
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    text = "Björkängshallen's"
    mwt_processor = pipeline.processors["mwt"]

    # the case is only meaningful while some character is missing from the model's vocab
    # (a future MWT model may well cover all of them)
    base_vocab = mwt_processor.trainer.vocab._unit2id
    assert any(ch not in base_vocab for ch in text)

    doc = pipeline(text)
    batch = mwt_processor.build_batch(doc)
    assert all(ch in batch.vocab._unit2id for ch in text)
def test_unknown_word():
    """
    Expand an MWT that never appeared in the MWT training data.
    The old seq2seq MWT model could hallucinate on such words; with the
    CharacterClassifier the expansion should be clean.
    """
    pipe = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    doc = pipe("I read the newspaper's report.")
    assert len(doc.sentences) == 1

    tokens = doc.sentences[0].tokens
    assert len(tokens) == 6
    expanded = tokens[3]
    assert len(expanded.words) == 2
    assert expanded.words[0].text == 'newspaper'

    # confirm the word really is unknown: the MWT dictionary has no expansion for it
    assert pipe.processors["mwt"].trainer.dict_expansion("newspaper's") is None
================================================
FILE: stanza/tests/pipeline/test_pipeline_ner_processor.py
================================================
import pytest
import stanza
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import Document
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# data for testing: three English texts, each processed as its own document
EN_DOCS = ["Barack Obama was born in Hawaii.", "He was elected president in 2008.", "Obama attended Harvard."]
# for each text in EN_DOCS, the entities the NER model should find;
# start_char / end_char are character offsets within that text
EXPECTED_ENTS = [[{
"text": "Barack Obama",
"type": "PERSON",
"start_char": 0,
"end_char": 12
}, {
"text": "Hawaii",
"type": "GPE",
"start_char": 25,
"end_char": 31
}],
[{
"text": "2008",
"type": "DATE",
"start_char": 28,
"end_char": 32
}],
[{
"text": "Obama",
"type": "PERSON",
"start_char": 0,
"end_char": 5
}, {
"text": "Harvard",
"type": "ORG",
"start_char": 15,
"end_char": 22
}]]
def check_entities_equal(doc, expected):
    """
    Assert that doc.ents lines up with expected, a list with one dict per entity.

    Only the keys present in each expected dict are checked, each against the
    entity attribute of the same name.
    """
    assert len(doc.ents) == len(expected)
    for entity_pair in zip(doc.ents, expected):
        actual_entity, wanted = entity_pair
        for attr_name, attr_value in wanted.items():
            assert getattr(actual_entity, attr_name) == attr_value
class TestNERProcessor:
    @pytest.fixture(scope="class")
    def pipeline(self):
        """
        A reusable tokenize + ner pipeline shared by the tests in this class
        """
        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,ner")

    @pytest.fixture(scope="class")
    def processed_doc(self, pipeline):
        """ List of Documents, one per text in EN_DOCS, each produced by a separate pipeline call """
        return [pipeline(text) for text in EN_DOCS]

    @pytest.fixture(scope="class")
    def processed_bulk(self, pipeline):
        """ The texts of EN_DOCS wrapped as empty Documents and processed together in one pipeline call """
        docs = [Document([], text=t) for t in EN_DOCS]
        return pipeline(docs)

    def test_bulk_ents(self, processed_bulk):
        """ Bulk processing should find the same entities as EXPECTED_ENTS, document by document """
        assert len(processed_bulk) == len(EXPECTED_ENTS)
        for doc, expected in zip(processed_bulk, EXPECTED_ENTS):
            check_entities_equal(doc, expected)

    def test_ents(self, processed_doc):
        """ Processing each text separately should find the entities in EXPECTED_ENTS """
        assert len(processed_doc) == len(EXPECTED_ENTS)
        for doc, expected in zip(processed_doc, EXPECTED_ENTS):
            check_entities_equal(doc, expected)
# entities expected from the two-model NER pipeline on the text
# "John Bauer works at Stanford and has hip arthritis. He works for Chris Manning"
EXPECTED_MULTI_ENTS = [{
"text": "John Bauer",
"type": "PERSON",
"start_char": 0,
"end_char": 10
}, {
"text": "Stanford",
"type": "ORG",
"start_char": 20,
"end_char": 28
}, {
"text": "hip arthritis",
"type": "DISEASE",
"start_char": 37,
"end_char": 50
}, {
"text": "Chris Manning",
"type": "PERSON",
"start_char": 66,
"end_char": 79
}]
# per-sentence lists of per-token tag pairs for that same text; the first element of
# each pair holds the DISEASE tags and the second the PERSON / ORG tags
# (presumably following the order in which the NER packages are listed - TODO confirm)
EXPECTED_MULTI_NER = [
[('O', 'B-PERSON'),
('O', 'E-PERSON'),
('O', 'O'),
('O', 'O'),
('O', 'S-ORG'),
('O', 'O'),
('O', 'O'),
('B-DISEASE', 'O'),
('E-DISEASE', 'O'),
('O', 'O')],
[('O', 'O'),
('O', 'O'),
('O', 'O'),
('O', 'B-PERSON'),
('O', 'E-PERSON'),]]
class TestMultiNERProcessor:
    @pytest.fixture(scope="class")
    def pipeline(self):
        """
        A reusable tokenize + ner pipeline that loads TWO ner packages:
        ncbi_disease and ontonotes_charlm
        """
        ner_packages = ["ncbi_disease", "ontonotes_charlm"]
        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,ner", package={"ner": ner_packages})

    def test_multi_example(self, pipeline):
        """
        Entities from both NER models should appear in doc.ents
        """
        doc = pipeline("John Bauer works at Stanford and has hip arthritis. He works for Chris Manning")
        check_entities_equal(doc, EXPECTED_MULTI_ENTS)

    def test_multi_ner(self, pipeline):
        """
        Test that multiple NER labels are correctly assigned in tuples
        """
        doc = pipeline("John Bauer works at Stanford and has hip arthritis. He works for Chris Manning")
        multi_ner = []
        for sentence in doc.sentences:
            multi_ner.append([token.multi_ner for token in sentence.tokens])
        assert multi_ner == EXPECTED_MULTI_NER

    def test_known_tags(self, pipeline):
        """
        With no argument get_known_tags() returns ["DISEASE"]; get_known_tags(1)
        returns the 18 tags of the other model
        """
        ner_processor = pipeline.processors["ner"]
        assert ner_processor.get_known_tags() == ["DISEASE"]
        assert len(ner_processor.get_known_tags(1)) == 18
================================================
FILE: stanza/tests/pipeline/test_pipeline_pos_processor.py
================================================
"""
Basic testing of part of speech tagging
"""
import pytest
import stanza
from stanza.tests import *
pytestmark = pytest.mark.pipeline
# the English sentence tagged by test_part_of_speech
EN_DOC = "Joe Smith was born in California."
# expected tokens_string() output for EN_DOC, one line per token
# NOTE(review): the token reprs appear truncated to "]>" here; confirm this still
# matches what tokens_string() produces
EN_DOC_GOLD = """
]>
]>
]>
]>
]>
]>
]>
""".strip()
@pytest.fixture(scope="module")
def pos_pipeline():
    """
    English tokenize + pos pipeline shared by every test in this module
    """
    return stanza.Pipeline(processors='tokenize,pos',
                           dir=TEST_MODELS_DIR,
                           download_method=None,
                           lang='en')
def test_part_of_speech(pos_pipeline):
    """
    Tag EN_DOC and compare each sentence's tokens_string() against EN_DOC_GOLD
    """
    doc = pos_pipeline(EN_DOC)
    sentence_strings = [sentence.tokens_string() for sentence in doc.sentences]
    assert EN_DOC_GOLD == '\n\n'.join(sentence_strings)
def test_get_known_xpos(pos_pipeline):
    """
    get_known_xpos() should return xpos tags such as DT, not upos tags such as DET
    """
    xpos_tags = pos_pipeline.processors['pos'].get_known_xpos()
    assert 'DT' in xpos_tags       # an xpos tag is present...
    assert 'DET' not in xpos_tags  # ...while a upos tag is not
def test_get_known_upos(pos_pipeline):
    """
    get_known_upos() should return upos tags such as DET, not xpos tags such as DT
    """
    upos_tags = pos_pipeline.processors['pos'].get_known_upos()
    assert 'DET' in upos_tags     # a upos tag is present...
    assert 'DT' not in upos_tags  # ...while an xpos tag is not
def test_get_known_feats(pos_pipeline):
    """
    get_known_feats() maps feature names to their values; Abbr=Yes should be among them
    """
    known_feats = pos_pipeline.processors['pos'].get_known_feats()
    # I appreciate how self-referential the Abbr feat is
    assert 'Abbr' in known_feats
    abbr_values = known_feats['Abbr']
    assert 'Yes' in abbr_values
================================================
FILE: stanza/tests/pipeline/test_pipeline_sentiment_processor.py
================================================
import gc
import pytest
import stanza
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import Document
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
# data for testing: three single-sentence texts, also joined into one multi-sentence text
EN_DOCS = ["Ragavan is terrible and should go away.", "Today is okay.", "Urza's Saga is great."]
EN_DOC = " ".join(EN_DOCS)
# expected sentiment label for each sentence of EN_DOCS, in order
# (presumably 0 = negative, 1 = neutral, 2 = positive - TODO confirm against the sentiment model)
EXPECTED = [0, 1, 2]
class TestSentimentPipeline:
    @pytest.fixture(scope="class")
    def pipeline(self):
        """
        A reusable tokenize + sentiment pipeline shared by the tests in this class
        """
        # presumably frees memory held by earlier tests before more models are loaded
        gc.collect()
        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,sentiment")

    def test_simple(self, pipeline):
        """
        Each text of EN_DOCS, processed on its own, is one sentence with the expected sentiment
        """
        results = []
        for text in EN_DOCS:
            doc = pipeline(text)
            assert len(doc.sentences) == 1
            results.append(doc.sentences[0].sentiment)
        assert EXPECTED == results

    def test_multiple_sentences(self, pipeline):
        """
        Processing all of EN_DOCS as a single text gives one sentiment per sentence, in order
        """
        doc = pipeline(EN_DOC)
        assert len(doc.sentences) == 3
        results = [sentence.sentiment for sentence in doc.sentences]
        assert EXPECTED == results

    def test_empty_text(self, pipeline):
        """
        Test empty text and a text which might get reduced to empty text by removing dashes
        """
        doc = pipeline("")
        assert len(doc.sentences) == 0
        doc = pipeline("--")
        assert len(doc.sentences) == 1
================================================
FILE: stanza/tests/pipeline/test_requirements.py
================================================
"""
Test the requirements functionality for processors
"""
import pytest
import stanza
from stanza.pipeline.core import PipelineRequirementsException
from stanza.pipeline.processor import ProcessorRequirementsException
from stanza.tests import *
pytestmark = pytest.mark.pipeline
def check_exception_vals(req_exception, req_exception_vals):
    """
    Compare a ProcessorRequirementsException against a dict of expected values.

    :param req_exception: the exception to check; must be a ProcessorRequirementsException
    :param req_exception_vals: dict with 'processor_type', 'processors_list' and 'requires' entries
    :return: None
    """
    assert isinstance(req_exception, ProcessorRequirementsException)
    expected = req_exception_vals
    assert req_exception.processor_type == expected['processor_type']
    assert req_exception.processors_list == expected['processors_list']
    failing_processor = req_exception.err_processor
    assert failing_processor.requires == expected['requires']
def test_missing_requirements():
    """
    Try to build several pipelines with bad configs and check that each one raises
    a PipelineRequirementsException whose per-processor failures match the gold values.
    :return: None
    """
    # list of (bad configs, list of gold ProcessorRequirementsExceptions that should be thrown) pairs
    bad_config_lists = [
        # missing tokenize
        (
            # input config
            {'processors': 'pos,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'},
            # 2 expected exceptions
            [
                {'processor_type': 'POSProcessor', 'processors_list': ['pos', 'depparse'], 'provided_reqs': set([]),
                 'requires': set(['tokenize'])},
                {'processor_type': 'DepparseProcessor', 'processors_list': ['pos', 'depparse'],
                 'provided_reqs': set([]), 'requires': set(['tokenize','pos', 'lemma'])}
            ]
        ),
        # no pos when lemma_pos set to True; for english mwt should not be included in the loaded processor list
        (
            # input config
            {'processors': 'tokenize,mwt,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_pos': True},
            # 1 expected exception
            [
                {'processor_type': 'LemmaProcessor', 'processors_list': ['tokenize', 'mwt', 'lemma'],
                 'provided_reqs': set(['tokenize', 'mwt']), 'requires': set(['tokenize', 'pos'])}
            ]
        )
    ]
    # every config must fail; pytest.raises pinpoints a config that unexpectedly builds,
    # instead of only reporting a mismatched failure count at the very end
    for bad_config, gold_exceptions in bad_config_lists:
        with pytest.raises(PipelineRequirementsException) as exc_info:
            stanza.Pipeline(**bad_config)
        processor_req_fails = exc_info.value.processor_req_fails
        assert len(processor_req_fails) == len(gold_exceptions)
        for processor_req_e, gold_exception in zip(processor_req_fails, gold_exceptions):
            # compare the thrown ProcessorRequirementsExceptions against gold
            check_exception_vals(processor_req_e, gold_exception)
================================================
FILE: stanza/tests/pipeline/test_tokenizer.py
================================================
"""
Basic testing of tokenization
"""
import pytest
import stanza
from stanza.tests import *
pytestmark = pytest.mark.pipeline
# the English text tokenized by the tests in this file: three sentences, including the token "Joe's"
EN_DOC = "Joe Smith lives in California. Joe's favorite food is pizza. He enjoys going to the beach."
# the same text with extra spaces, tabs, and newlines between the words
EN_DOC_WITH_EXTRA_WHITESPACE = "Joe Smith \n lives in\n California. Joe's favorite food \tis pizza. \t\t\tHe enjoys \t\tgoing to the beach."
# expected tokens for EN_DOC
# NOTE(review): the token reprs in the gold strings below appear truncated to "]>"; confirm
EN_DOC_GOLD_TOKENS = """
]>
]>
]>
]>
]>
]>
, ]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
""".strip()
# spaCy doesn't have MWT
EN_DOC_SPACY_TOKENS = """
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
""".strip()
# token lists per sentence; ("Joe's", True) presumably flags that token for MWT handling - TODO confirm
EN_DOC_POSTPROCESSOR_TOKENS_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], [("Joe's", True), 'favorite', 'food', 'is', 'pizza', '.'], ['He', 'enjoys', 'going', 'to', 'the', 'beach', '.']]
# token lists in which "Joe's" is split into "Joe" + "'s" and "to the beach" is one token containing spaces
EN_DOC_POSTPROCESSOR_COMBINED_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], ['Joe', "'s", 'favorite', 'food', 'is', 'pizza', '.'], ['He', 'enjoys', 'going', "to the beach", '.']]
# note: unlike the other gold strings in this file, this one is not .strip()ped
EN_DOC_POSTPROCESSOR_COMBINED_TOKENS = """
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
"""
# ensure that the entry above has spaces somewhere to test that spaces work in between tokens
EN_DOC_GOLD_NOSSPLIT_TOKENS = """
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>
]>